In [1]:
import argparse
import torch
import numpy as np
import random 

from dassl.utils import setup_logger, set_random_seed, collect_env_info
from dassl.config import get_cfg_default
from dassl.engine import build_trainer

# custom
import datasets.oxford_pets
import datasets.oxford_flowers
import datasets.fgvc_aircraft
import datasets.dtd
import datasets.eurosat
import datasets.stanford_cars
import datasets.food101
import datasets.sun397
import datasets.caltech101
import datasets.ucf101
import datasets.imagenet

import datasets.imagenet_sketch
import datasets.imagenetv2
import datasets.imagenet_a
import datasets.imagenet_r

def print_args(args, cfg):
    """Print every parsed argument (alphabetically by name), then the full config.

    Args:
        args: namespace whose ``__dict__`` holds the arguments to show.
        cfg: config object; printed via its ``str()`` representation.
    """
    arg_banner = "***************"
    print(arg_banner)
    print("** Arguments **")
    print(arg_banner)
    for name, value in sorted(vars(args).items()):
        print("{}: {}".format(name, value))

    cfg_banner = "************"
    print(cfg_banner)
    print("** Config **")
    print(cfg_banner)
    print(cfg)


def reset_cfg(cfg, args):
    """Override config values with command-line arguments (in place).

    Argparse-declared options only override when truthy, so empty strings,
    ``None`` and ``0`` keep the value loaded from the YAML files (a seed of 0
    is therefore ignored, as the ``--seed`` help text says).

    ``subsample_classes``, ``train_batch_size`` and ``epoch`` are often set on
    ``args`` by hand rather than by argparse. They are applied whenever the
    attribute exists and is not ``None``; a missing attribute keeps the config
    value instead of raising ``AttributeError``.

    Args:
        cfg: mutable (not yet frozen) config node.
        args: ``argparse.Namespace`` or any object with the same attributes.
    """
    if args.root:
        cfg.DATASET.ROOT = args.root

    if args.output_dir:
        cfg.OUTPUT_DIR = args.output_dir

    if args.resume:
        cfg.RESUME = args.resume

    if args.seed:
        cfg.SEED = args.seed

    if args.source_domains:
        cfg.DATASET.SOURCE_DOMAINS = args.source_domains

    if args.target_domains:
        cfg.DATASET.TARGET_DOMAINS = args.target_domains

    if args.transforms:
        cfg.INPUT.TRANSFORMS = args.transforms

    if args.trainer:
        cfg.TRAINER.NAME = args.trainer

    if args.backbone:
        cfg.MODEL.BACKBONE.NAME = args.backbone

    if args.head:
        cfg.MODEL.HEAD.NAME = args.head

    # Possibly-undeclared options: apply only when actually provided.
    subsample_classes = getattr(args, "subsample_classes", None)
    if subsample_classes is not None:
        cfg.DATASET.SUBSAMPLE_CLASSES = subsample_classes

    train_batch_size = getattr(args, "train_batch_size", None)
    if train_batch_size is not None:
        cfg.DATALOADER.TRAIN_X.BATCH_SIZE = train_batch_size

    epoch = getattr(args, "epoch", None)
    if epoch is not None:
        cfg.OPTIM.MAX_EPOCH = epoch


def extend_cfg(cfg):
    """
    Add the project-specific keys on top of Dassl's default config (in place).

    Creates one sub-node per trainer (MYTEMP, COCOOP, COOP, LP) holding its
    default hyper-parameters, sets a top-level VERBOSE flag, and adds the
    DATASET keys SUBSAMPLE_CLASSES, PROMPT and NUM_SHOTS.

    Args:
        cfg: mutable config returned by ``get_cfg_default()``.
    """
    from yacs.config import CfgNode as CN

    def _node(**defaults):
        # Same as creating CN() and assigning `node.KEY = value` one by one.
        node = CN()
        for key, value in defaults.items():
            setattr(node, key, value)
        return node

    cfg.VERBOSE = True

    trainer = cfg.TRAINER
    trainer.MYTEMP = _node(K=8, CTX_INIT='', PREC='fp16')
    trainer.COCOOP = _node(N_CTX=4, CTX_INIT='a photo of a', PREC='fp16')
    trainer.COOP = _node(
        N_CTX=4,
        CSC=False,
        CLASS_TOKEN_POSITION='',
        PREC='fp16',
        CTX_INIT='',
    )
    trainer.LP = _node(PREC='fp16', PROMPT='A photo of a {cls_name}')

    dataset = cfg.DATASET
    dataset.SUBSAMPLE_CLASSES = "all"  # all, base or new
    dataset.PROMPT = "a photo of a _."  # '_' is replaced by each class name
    dataset.NUM_SHOTS = 16
    
def setup_cfg(args):
    """Assemble and freeze the experiment config.

    Sources are applied in increasing priority (later ones win):
    Dassl defaults + ``extend_cfg`` -> dataset YAML -> method YAML ->
    command-line arguments (``reset_cfg``) -> free-form ``args.opts`` pairs.

    Returns:
        The frozen config node.
    """
    cfg = get_cfg_default()
    extend_cfg(cfg)

    # YAML files: dataset first, then method, each only if a path was given.
    for yaml_path in (args.dataset_config_file, args.config_file):
        if yaml_path:
            cfg.merge_from_file(yaml_path)

    reset_cfg(cfg, args)
    cfg.merge_from_list(args.opts)

    cfg.freeze()
    return cfg


def main(args):
    """Build the config and trainer from ``args``, then evaluate or train.

    Seeding uses ``cfg.SEED`` (set from ``--seed``). It falls back to 2, the
    value previously hard-coded here, when no non-negative seed is configured,
    so runs stay seeded either way.

    Returns:
        In eval-only mode, the value returned by ``trainer.test()``
        (presumably the accuracy for Dassl's classification evaluator —
        confirm for the installed version). After training, ``None``: the
        final test run happens inside ``trainer.train()`` and is only logged.
    """
    cfg = setup_cfg(args)

    seed = cfg.SEED if cfg.SEED >= 0 else 2
    set_random_seed(seed)
    random.seed(seed)
    np.random.seed(seed)

    setup_logger(cfg.OUTPUT_DIR)

    if torch.cuda.is_available() and cfg.USE_CUDA:
        # Trade cuDNN autotuning speed for run-to-run determinism.
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

    print_args(args, cfg)
    print("Collecting env info ...")
    print("** System info **\n{}\n".format(collect_env_info()))

    trainer = build_trainer(cfg)

    if args.eval_only:
        # load_epoch is not declared by every parser; None loads the default checkpoint.
        trainer.load_model(args.model_dir, epoch=getattr(args, "load_epoch", None))
        return trainer.test()

    trainer.train()
    return None




In [2]:
## MYTEMP

import os.path as osp
from collections import OrderedDict
import math

import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.cuda.amp import GradScaler, autocast

from dassl.engine import TRAINER_REGISTRY, TrainerX
from dassl.metrics import compute_accuracy
from dassl.utils import load_pretrained_weights, load_checkpoint
from dassl.optim import build_optimizer, build_lr_scheduler

from clip import clip
from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer

_tokenizer = _Tokenizer()


def load_clip_to_cpu(cfg):
    """Download (if needed) and build the CLIP model named by cfg on CPU.

    Args:
        cfg: config; ``cfg.MODEL.BACKBONE.NAME`` must be a key of ``clip._MODELS``.

    Returns:
        The model produced by ``clip.build_model`` from the checkpoint's state dict.

    Raises:
        KeyError: if the backbone name is not a known CLIP model.
    """
    backbone_name = cfg.MODEL.BACKBONE.NAME
    if backbone_name not in clip._MODELS:
        raise KeyError(
            f"Unknown CLIP backbone {backbone_name!r}; available: {sorted(clip._MODELS)}"
        )
    model_path = clip._download(clip._MODELS[backbone_name])

    try:
        # First try the file as a TorchScript (JIT) archive ...
        model = torch.jit.load(model_path, map_location="cpu").eval()
    except RuntimeError:
        # ... otherwise treat it as a plain state dict.
        state_dict = torch.load(model_path, map_location="cpu")
    else:
        state_dict = model.state_dict()

    return clip.build_model(state_dict)

class PromptLearner(nn.Module):
    """Holds K learnable prompt tokens for the text encoder and K for the image encoder.

    Text tokens start from the token embedding of ``EOT_TOKEN_ID``. Image tokens
    start from the visual class embedding. Each of the K copies is perturbed by
    unit-norm Gaussian noise scaled by ``NOISE_SCALE`` so the copies differ.

    forward() returns ``(text_prompt, img_prompt)`` with shapes ``(K, d_t)`` and
    ``(K, d_v)``.
    """

    # CLIP's <|endoftext|> token id; its embedding seeds the text prompts.
    EOT_TOKEN_ID = 49407
    # Magnitude of the per-copy perturbation (noise rows have unit L2 norm).
    NOISE_SCALE = 0.1

    def __init__(self, cfg, clip_model):
        super().__init__()

        # At least two prompt pairs are required.
        assert cfg.TRAINER.MYTEMP.K >= 2, "K should be bigger than 1"

        self.K = cfg.TRAINER.MYTEMP.K  # number of (text, image) prompt pairs
        self.n_ctx = self.K
        self.dtype = clip_model.dtype
        self.d_t = clip_model.ln_final.weight.shape[0]  # text width (512 for ViT-B)
        # Visual width is read from the class embedding instead of hard-coding
        # 768, so wider vision backbones also work.
        self.d_v = clip_model.visual.class_embedding.shape[-1]

        clip_imsize = clip_model.visual.input_resolution  # e.g. 224
        cfg_imsize = cfg.INPUT.SIZE[0]  # e.g. (224, 224)[0]
        assert cfg_imsize == clip_imsize, f"cfg_imsize ({cfg_imsize}) must equal to clip_imsize ({clip_imsize})"

        self.initialization_token(clip_model)

    def _noisy_copies(self, base, width):
        """Return K noisy copies of ``base`` (``width`` elements) as a (K, width) tensor in self.dtype."""
        tokens = base.reshape(1, width).repeat(self.K, 1)
        noise = torch.randn(self.K, width, device=tokens.device)
        noise = noise / noise.norm(dim=-1, keepdim=True)
        tokens += self.NOISE_SCALE * noise
        return tokens.type(self.dtype)

    def initialization_token(self, clip_model):
        """Create the ``text_prompt`` and ``img_prompt`` parameters from CLIP embeddings."""
        # No autograd graph is needed just to compute initial values.
        with torch.no_grad():
            eot = torch.tensor(
                [self.EOT_TOKEN_ID], device=clip_model.token_embedding.weight.device
            )
            text_token = self._noisy_copies(clip_model.token_embedding(eot), self.d_t)
            visual_token = self._noisy_copies(clip_model.visual.class_embedding, self.d_v)
        self.text_prompt = nn.Parameter(text_token)
        self.img_prompt = nn.Parameter(visual_token)

    def forward(self):
        return self.text_prompt, self.img_prompt


class CustomCLIP(nn.Module):
    '''
    CLIP with K learnable prompt tokens appended to both encoders.

    cfg : config node; cfg.TRAINER.MYTEMP.K is the number of prompt pairs
    classnames : list of class-name strings
    prompt : text template in which every '_' is replaced by a class name
    clipmodel : CLIP model whose sub-modules are reused (shared, not copied)

    forward(image, label) returns the cross-entropy loss while the prompt
    learner is in training mode, otherwise (batch, n_cls) logits. Logits
    average, over the K prompt pairs, the scaled cosine similarity between
    the k-th image prompt output and the k-th text prompt output.
    '''

    # Attention heads of the text transformer; each per-class text mask is
    # repeated this many times. NOTE(review): 8 matches CLIP's ViT-B text
    # encoder — confirm before switching to another backbone.
    TEXT_ATTN_HEADS = 8

    def __init__(self, cfg, classnames, prompt, clipmodel):
        super().__init__()
        self.cfg = cfg

        # text encoder
        self.token_embedding = clipmodel.token_embedding
        self.text_pos_embedding = clipmodel.positional_embedding
        self.text_transformers = clipmodel.transformer
        self.text_ln_final = clipmodel.ln_final
        self.text_proj = clipmodel.text_projection

        # vision encoder
        self.img_patch_embedding = clipmodel.visual.conv1
        self.img_cls_embedding = clipmodel.visual.class_embedding
        self.img_pos_embedding = clipmodel.visual.positional_embedding
        self.img_pre_ln = clipmodel.visual.ln_pre
        self.img_transformer = clipmodel.visual.transformer
        self.img_post_ln = clipmodel.visual.ln_post
        self.img_proj = clipmodel.visual.proj

        # logit
        self.logit_scale = clipmodel.logit_scale

        # learnable text / image prompt tokens
        self.prompt_learner = PromptLearner(self.cfg, clipmodel)

        # dtype must be set before make_prompts(), which casts with it
        self.dtype = clipmodel.dtype
        self.prompts = self.make_prompts(classnames, prompt)  # ["a photo of a dog.", ...]

        # attention masks for both encoders
        self.define_mask()

    def make_prompts(self, classnames, prompt):
        '''
        Fill the template for every class and cache its tokenization.

        Sets self.text_tokenized (token ids, one row per class), self.text_x
        (token + positional embeddings) and self.len_prompts (index just past
        the end-of-text token, where the K prompt tokens are written).
        Returns the list of filled prompt strings.

        Raises ValueError if a prompt leaves no room for K prompt tokens.
        '''
        prompts = [prompt.replace('_', c) for c in classnames]
        with torch.no_grad():
            self.text_tokenized = torch.cat([clip.tokenize(p) for p in prompts])
            self.text_x = self.token_embedding(self.text_tokenized).type(self.dtype) + self.text_pos_embedding.type(self.dtype)
            # The end-of-text token has the largest id, so argmax locates it.
            self.len_prompts = self.text_tokenized.argmax(dim=-1) + 1

        K = self.cfg.TRAINER.MYTEMP.K
        context_len = self.text_tokenized.shape[1]
        too_long = (self.len_prompts + K) > context_len
        if bool(too_long.any()):
            bad = [prompts[i] for i in too_long.nonzero().flatten().tolist()]
            raise ValueError(f"No room for {K} prompt tokens after prompts: {bad}")
        return prompts

    def define_mask(self):
        '''
        Build additive attention masks (0 = attend, -inf = blocked).

        text_mask: (n_cls * TEXT_ATTN_HEADS, L, L). Causal, and additionally
        no query may attend to positions from len_prompts onwards, i.e. the
        appended prompt tokens.
        visual_mask: (T + K, T + K) where T = CLS + patch tokens; every query
        is blocked from attending to the last K (prompt) tokens.
        '''
        K = self.cfg.TRAINER.MYTEMP.K
        len_max = self.text_tokenized.shape[1]  # CLIP context length (77)

        text_masks = [torch.empty(0, len_max, len_max)]
        for idx in self.len_prompts.tolist():
            mask = torch.full((len_max, len_max), float("-inf")).triu_(1)  # causal
            mask[:, idx:] = float("-inf")  # hide the appended prompt tokens
            text_masks.append(mask.repeat(self.TEXT_ATTN_HEADS, 1, 1))
        self.text_mask = torch.cat(text_masks)

        # image encoder mask; the token count (1 + grid * grid) is taken from
        # the positional embedding instead of hard-coding 1 + 14 * 14.
        att_size = self.img_pos_embedding.shape[0] + K
        visual_mask = torch.zeros((att_size, att_size), dtype=self.dtype, requires_grad=False)
        visual_mask[:, -K:] = float("-inf")
        self.visual_mask = visual_mask

    def forward(self, image, label=None):
        device = image.device
        K = self.cfg.TRAINER.MYTEMP.K

        # load masks from predefined masks
        text_mask = self.text_mask
        visual_mask = self.visual_mask

        # load prompts from prompt learner
        text_prompt, image_prompt = self.prompt_learner()

        ####################### text ###########################
        # copy=True: self.text_x is a cached constant. Without a copy, .to()
        # is a no-op when the cache already lives on `device`; the in-place
        # writes below would then corrupt it and chain autograd graphs across
        # iterations.
        text_x = self.text_x.to(device=device, copy=True)
        n_cls = text_x.shape[0]
        rows = torch.arange(n_cls, device=device)[:, None]  # (n_cls, 1)
        cols = self.len_prompts.to(device)[:, None] + torch.arange(K, device=device)[None, :]  # (n_cls, K)
        # write the K text prompt tokens right after each class's EOT token
        text_x[rows, cols] = text_prompt.unsqueeze(0).expand(n_cls, -1, -1)

        text_x = text_x.permute(1, 0, 2)  # NLD -> LND
        text_x = self.text_transformers(text_x, text_mask)
        text_x = text_x.permute(1, 0, 2)  # LND -> NLD
        text_x = self.text_ln_final(text_x).type(self.dtype)

        # outputs at the K prompt positions: (n_cls, K, width) -> projected
        text_f = text_x[rows, cols] @ self.text_proj

        ####################### img ###########################
        batch_size = image.shape[0]

        # patch embedding: (B, width, grid, grid) -> (B, grid * grid, width)
        patches = self.img_patch_embedding(image.type(self.dtype))
        patches = patches.reshape(batch_size, patches.shape[1], -1).permute(0, 2, 1)
        cls_tokens = self.img_cls_embedding.repeat(batch_size, 1, 1).type(self.dtype)  # (B, 1, width)
        img_x = torch.cat([cls_tokens, patches], dim=1) + self.img_pos_embedding.type(self.dtype)
        # append the K image prompt tokens after CLS + patches
        img_x = torch.cat([img_x, image_prompt.repeat(batch_size, 1, 1)], dim=1)

        img_x = self.img_pre_ln(img_x)
        img_x = img_x.permute(1, 0, 2)  # NLD -> LND
        img_x = self.img_transformer(img_x, visual_mask)
        img_x = img_x.permute(1, 0, 2)  # LND -> NLD
        img_f = self.img_post_ln(img_x[:, -K:, :]) @ self.img_proj  # (B, K, embed)

        ####################### logit ###########################
        text_f = text_f / text_f.norm(dim=-1, keepdim=True)
        img_f = img_f / img_f.norm(dim=-1, keepdim=True)

        scale = self.logit_scale.exp()
        logits = torch.zeros(img_f.shape[0], text_f.shape[0], device=device)
        for k in range(K):
            logits += scale * img_f[:, k, :] @ text_f[:, k, :].t()
        logits /= K

        if self.prompt_learner.training:
            return F.cross_entropy(logits, label)

        return logits


@TRAINER_REGISTRY.register()
class MYTEMP(TrainerX):
    """Dassl trainer for CustomCLIP; only the prompt learner is optimised."""

    def check_cfg(self, cfg):
        assert cfg.TRAINER.MYTEMP.PREC in ["fp16", "fp32", "amp"]

    def build_model(self):
        """Load CLIP, wrap it in CustomCLIP, freeze all but the prompt learner, set up optim."""
        cfg = self.cfg
        classnames = self.dm.dataset.classnames

        print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})")
        clip_model = load_clip_to_cpu(cfg)

        if cfg.TRAINER.MYTEMP.PREC == "fp32" or cfg.TRAINER.MYTEMP.PREC == "amp":
            # CLIP's default precision is fp16
            clip_model.float()

        # Shared text template for every class ('_' is replaced by the class name).
        prompt = cfg.DATASET.PROMPT

        print("Building custom CLIP")
        self.model = CustomCLIP(cfg, classnames, prompt, clip_model)

        # Freeze everything except the prompt learner.
        for name, param in self.model.named_parameters():
            if "prompt_learner" not in name:
                param.requires_grad_(False)

        # Double check
        enabled = {name for name, param in self.model.named_parameters() if param.requires_grad}
        print(f"Parameters to be updated: {enabled}")

        if cfg.MODEL.INIT_WEIGHTS:
            load_pretrained_weights(self.model.prompt_learner, cfg.MODEL.INIT_WEIGHTS)

        self.model.to(self.device)
        # NOTE: only give prompt_learner to the optimizer
        self.optim = build_optimizer(self.model.prompt_learner, cfg.OPTIM)
        self.sched = build_lr_scheduler(self.optim, cfg.OPTIM)
        self.register_model("prompt_learner", self.model.prompt_learner, self.optim, self.sched)

        self.scaler = GradScaler() if cfg.TRAINER.MYTEMP.PREC == "amp" else None

        # Note that multi-gpu training could be slow because CLIP's size is
        # big, which slows down the copy operation in DataParallel
        device_count = torch.cuda.device_count()
        if device_count > 1:
            print(f"Multiple GPUs detected (n_gpus={device_count}), use all of them!")
            self.model = nn.DataParallel(self.model)

        # Autograd anomaly detection (NaN tracing) is a process-wide switch that
        # slows every backward pass, so it is opt-in via TRAINER.MYTEMP.DETECT_ANOMALY.
        if cfg.TRAINER.MYTEMP.get("DETECT_ANOMALY", False):
            torch.autograd.set_detect_anomaly(True)

    def forward_backward(self, batch):
        """One optimisation step on ``batch``; returns ``{"loss": float}``."""
        image, label = self.parse_batch_train(batch)

        model = self.model
        optim = self.optim
        scaler = self.scaler

        prec = self.cfg.TRAINER.MYTEMP.PREC
        if prec == "amp":
            with autocast():
                # .mean(): under DataParallel the per-GPU scalar losses are
                # gathered into a vector; for a single scalar it is a no-op.
                loss = model(image, label).mean()
            optim.zero_grad()
            scaler.scale(loss).backward()
            scaler.step(optim)
            scaler.update()
        else:
            loss = model(image, label).mean()
            optim.zero_grad()
            loss.backward()
            optim.step()

        loss_summary = {"loss": loss.item()}

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary

    def parse_batch_train(self, batch):
        """Move the batch's images and labels to ``self.device``."""
        input = batch["img"]
        label = batch["label"]
        input = input.to(self.device)
        label = label.to(self.device)
        return input, label

    def load_model(self, directory, epoch=None):
        """Load every registered model's weights from ``directory``.

        Args:
            directory: root containing one sub-directory per registered model.
            epoch: load ``model.pth.tar-<epoch>``; ``None`` loads ``model-best.pth.tar``.

        Raises:
            FileNotFoundError: if an expected checkpoint file is missing.
        """
        if not directory:
            print("Note that load_model() is skipped as no pretrained model is given")
            return

        names = self.get_model_names()

        # By default, the best model is loaded
        model_file = "model-best.pth.tar"

        if epoch is not None:
            model_file = "model.pth.tar-" + str(epoch)

        for name in names:
            model_path = osp.join(directory, name, model_file)

            if not osp.exists(model_path):
                raise FileNotFoundError('Model not found at "{}"'.format(model_path))

            checkpoint = load_checkpoint(model_path)
            state_dict = checkpoint["state_dict"]
            ckpt_epoch = checkpoint["epoch"]

            # Ignore fixed token vectors
            if "token_prefix" in state_dict:
                del state_dict["token_prefix"]

            if "token_suffix" in state_dict:
                del state_dict["token_suffix"]

            print("Loading weights to {} " 'from "{}" (epoch = {})'.format(name, model_path, ckpt_epoch))
            # set strict=False
            self._models[name].load_state_dict(state_dict, strict=False)

  

### **Q2. Training CoCoOp**

In this task, you will train CoCoOp on the EuroSAT dataset. If your implementation of CoCoOp in Question 1 is correct, the following code should execute without errors. Please submit the execution file so we can evaluate whether your code runs without any issues.

In [3]:
parser = argparse.ArgumentParser()
parser.add_argument("--root", type=str, default="data/", help="path to dataset")
parser.add_argument("--output-dir", type=str, default="", help="output directory")
parser.add_argument(
    "--resume",
    type=str,
    default="",
    help="checkpoint directory (from which the training resumes)",
)
parser.add_argument(
    "--seed", type=int, default=2, help="only positive value enables a fixed seed"
)
parser.add_argument(
    "--source-domains", type=str, nargs="+", help="source domains for DA/DG"
)
parser.add_argument(
    "--target-domains", type=str, nargs="+", help="target domains for DA/DG"
)
parser.add_argument(
    "--transforms", type=str, nargs="+", help="data augmentation methods"
)

parser.add_argument("--trainer", type=str, default="", help="name of trainer")
parser.add_argument("--backbone", type=str, default="", help="name of CNN backbone")
parser.add_argument("--head", type=str, default="", help="name of head")
parser.add_argument("--eval-only", action="store_true", help="evaluation only")
parser.add_argument(
    "--model-dir",
    type=str,
    default="",
    help="load model from this directory for eval-only mode",
)

# Options read by reset_cfg() / main() that were previously only set by hand
# on `args`; declaring them means parse_args() alone yields a complete namespace.
parser.add_argument(
    "--subsample-classes",
    type=str,
    default="all",
    choices=["all", "base", "new"],
    help="which classes to use: all, base or new",
)
parser.add_argument(
    "--train-batch-size", type=int, default=4, help="training batch size"
)
parser.add_argument("--epoch", type=int, default=20, help="number of training epochs")
parser.add_argument(
    "--load-epoch",
    type=int,
    default=None,
    help="checkpoint epoch to load in eval-only mode (default: best model)",
)

parser.add_argument(
    "opts",
    default=None,
    nargs=argparse.REMAINDER,
    help="modify config options using the command-line",
)
parser.add_argument(
    "--dataset-config-file",
    type=str,
    default="configs/datasets/eurosat.yaml",
    help="path to config file for dataset setup",
)

parser.add_argument(
    "--config-file", type=str, default="configs/trainers/TEMP/main.yaml", help="path to config file"
)

args = parser.parse_args([])

args.trainer = "MYTEMP"
args.train_batch_size = 4
args.epoch = 20
# NOTE: if output_dir already contains a checkpoint, Dassl resumes from it
# instead of training from scratch (the log below shows "Found checkpoint ...
# (will resume training)"). Clear the directory for a fresh, reproducible run.
args.output_dir = "outputs/mytemp"

args.subsample_classes = "base"
args.eval_only = False
mytemp_base_acc = main(args)

***************
** Arguments **
***************
backbone: 
config_file: configs/trainers/TEMP/main.yaml
dataset_config_file: configs/datasets/eurosat.yaml
epoch: 20
eval_only: False
head: 
model_dir: 
opts: []
output_dir: outputs/mytemp
resume: 
root: data/
seed: 2
source_domains: None
subsample_classes: base
target_domains: None
train_batch_size: 4
trainer: MYTEMP
transforms: None
************
** Config **
************
DATALOADER:
  K_TRANSFORMS: 1
  NUM_WORKERS: 4
  RETURN_IMG0: False
  TEST:
    BATCH_SIZE: 100
    SAMPLER: SequentialSampler
  TRAIN_U:
    BATCH_SIZE: 32
    N_DOMAIN: 0
    N_INS: 16
    SAME_AS_X: True
    SAMPLER: RandomSampler
  TRAIN_X:
    BATCH_SIZE: 4
    N_DOMAIN: 0
    N_INS: 16
    SAMPLER: RandomSampler
DATASET:
  ALL_AS_UNLABELED: False
  CIFAR_C_LEVEL: 1
  CIFAR_C_TYPE: 
  NAME: EuroSAT
  NUM_LABELED: -1
  NUM_SHOTS: 16
  PROMPT: a photo of a _.
  ROOT: data/
  SOURCE_DOMAINS: ()
  STL10_FOLD: -1
  SUBSAMPLE_CLASSES: base
  TARGET_DOMAINS: ()
  VAL_PERC

  checkpoint = torch.load(fpath, map_location=map_location)


Building custom CLIP
Parameters to be updated: {'prompt_learner.text_prompt', 'prompt_learner.img_prompt'}
Loading evaluator: Classification
Found checkpoint at outputs/mytemp (will resume training)
Loading checkpoint from "outputs/mytemp/prompt_learner/model.pth.tar-20"
Loaded model weights
Loaded optimizer
Loaded scheduler
Previous epoch: 20
Initialize tensorboard (log_dir=outputs/mytemp/tensorboard)
Finish training
Deploy the last-epoch model
Evaluate on the *test* set


100%|██████████| 42/42 [00:07<00:00,  5.35it/s]

=> result
* total: 4,200
* correct: 3,929
* accuracy: 93.5%
* error: 6.5%
* macro_f1: 93.6%
Elapsed: 0:00:08





In [4]:
# Accuracy on the New Classes: evaluate the epoch-20 checkpoint trained on the
# base classes against the held-out new classes (no further training).
vars(args).update(
    model_dir="outputs/mytemp",
    output_dir="outputs/mytemp/new_classes",
    subsample_classes="new",
    load_epoch=20,
    eval_only=True,
)
mytemp_novel_acc = main(args)

***************
** Arguments **
***************
backbone: 
config_file: configs/trainers/TEMP/main.yaml
dataset_config_file: configs/datasets/eurosat.yaml
epoch: 20
eval_only: True
head: 
load_epoch: 20
model_dir: outputs/mytemp
opts: []
output_dir: outputs/mytemp/new_classes
resume: 
root: data/
seed: 2
source_domains: None
subsample_classes: new
target_domains: None
train_batch_size: 4
trainer: MYTEMP
transforms: None
************
** Config **
************
DATALOADER:
  K_TRANSFORMS: 1
  NUM_WORKERS: 4
  RETURN_IMG0: False
  TEST:
    BATCH_SIZE: 100
    SAMPLER: SequentialSampler
  TRAIN_U:
    BATCH_SIZE: 32
    N_DOMAIN: 0
    N_INS: 16
    SAME_AS_X: True
    SAMPLER: RandomSampler
  TRAIN_X:
    BATCH_SIZE: 4
    N_DOMAIN: 0
    N_INS: 16
    SAMPLER: RandomSampler
DATASET:
  ALL_AS_UNLABELED: False
  CIFAR_C_LEVEL: 1
  CIFAR_C_TYPE: 
  NAME: EuroSAT
  NUM_LABELED: -1
  NUM_SHOTS: 16
  PROMPT: a photo of a _.
  ROOT: data/
  SOURCE_DOMAINS: ()
  STL10_FOLD: -1
  SUBSAMPLE_CLASSE

  checkpoint = torch.load(fpath, map_location=map_location)
100%|██████████| 39/39 [00:06<00:00,  5.97it/s]
