In [1]:
import os
import random
from tqdm import tqdm

import torch

from datasets.imagenet import ImageNet
import clip
from utils import *
from clip.moco import load_moco
from clip.amu import *
from parse_args import parse_args

In [2]:
parser = parse_args()
# args = parser.parse_args()
args, _ = parser.parse_known_args(['--rand_seed', '2',
                                  '--torch_rand_seed', '1',
                                  '--exp_name', 'test_16_shot',
                                  '--clip_backbone', 'ViT-B-16',
                                  '--augment_epoch', '1',
                                  '--alpha', '0.5',
                                  '--lambda_merge', '0.35',
                                  '--train_epoch', '51',
                                  '--lr', '1e-3',
                                  '--batch_size', '8',
                                  '--shots', '4',
                                  '--root_path', 'data',
                                  '--dataset','oxford_pets',
                                  ])

In [3]:
cache_dir = os.path.join('./caches', args.dataset)
os.makedirs(cache_dir, exist_ok=True)
args.cache_dir = cache_dir

logger = config_logging(args)
logger.info("\nRunning configs.")
args_dict = vars(args)
message = '\n'.join([f'{k:<20}: {v}' for k, v in args_dict.items()])
logger.info(message)

2024-07-09 09:56 - INFO: - 
Running configs.
2024-07-09 09:56 - INFO: - exp_name            : test_16_shot
rand_seed           : 2
torch_rand_seed     : 1
root_path           : data
dataset             : oxford_pets
shots               : 4
train_epoch         : 51
lr                  : 0.001
load_pre_feat       : False
clip_backbone       : ViT-B-16
batch_size          : 8
val_batch_size      : 256
num_classes         : 1000
augment_epoch       : 1
load_aux_weight     : False
alpha               : 0.5
lambda_merge        : 0.35
uncent_type         : none
uncent_power        : 0.4
cache_dir           : ./caches/oxford_pets


In [4]:

# CLIP
clip_model, preprocess = clip.load(args.clip_backbone)
clip_model.eval()
# AUX MODEL 
aux_model, args.feat_dim = load_moco("data/r-50-1000ep.pth.tar")#Aux model path
    
aux_model.cuda()
aux_model.eval() 

# ImageNet dataset
random.seed(args.rand_seed)
torch.manual_seed(args.torch_rand_seed)

=> creating model
=> loading checkpoint 'data/r-50-1000ep.pth.tar'
_IncompatibleKeys(missing_keys=['fc.weight', 'fc.bias'], unexpected_keys=[])
=> loaded pre-trained model 'data/r-50-1000ep.pth.tar'


<torch._C.Generator at 0x7f4274a17490>

In [5]:
from datasets import build_dataset
from datasets.utils import build_data_loader

In [6]:
from datasets.oxford_pets import OxfordPets
from utils import tfm_train_base, tfm_test_base

dataset=OxfordPets("data",4)
val_loader = build_data_loader(data_source=dataset.val, batch_size=64, is_train=False, tfm=tfm_test_base, shuffle=False)
test_loader = build_data_loader(data_source=dataset.test, batch_size=64, is_train=False, tfm=tfm_test_base, shuffle=False)

train_loader_cache = build_data_loader(data_source=dataset.train_x, batch_size=256, tfm=tfm_train_base, is_train=True, shuffle=False)#用于给辅助模型
train_loader_F = build_data_loader(data_source=dataset.train_x, batch_size=256, tfm=tfm_train_base, is_train=True, shuffle=True)#用于训练



Reading split from data/Oxfordpets/split_zhou_OxfordPets.json
Creating a 4-shot dataset


In [7]:
for i, (data, labels) in enumerate(train_loader_cache):
    print(f"Batch {i + 1}")
    print("Data:", data.shape)
    print("Labels:", labels.shape)
    if i == 3: 
        break

Batch 1
Data: torch.Size([148, 3, 224, 224])
Labels: torch.Size([148])


In [8]:
# Textual features
logger.info("Getting textual features as CLIP's classifier...")
clip_weights = gpt_clip_classifier(dataset.classnames, clip_model, dataset.template)


2024-07-09 09:56 - INFO: - Getting textual features as CLIP's classifier...


In [9]:
# Load visual features of few-shot training set
logger.info("Load visual features of few-shot training set...")
aux_features, aux_labels = load_aux_weight(args, aux_model, train_loader_cache, tfm_norm=tfm_aux)


2024-07-09 09:56 - INFO: - Load visual features of few-shot training set...


Augment Epoch: 0 / 1


100%|██████████| 1/1 [00:02<00:00,  2.22s/it]


In [10]:

# Pre-load test features
logger.info("Loading visual features and labels from test set.")
logger.info("Loading CLIP test feature.")
test_clip_features, test_labels  = load_test_features(args, "test", clip_model, test_loader, tfm_norm=tfm_clip, model_name='clip')

logger.info(f"Loading AUX test feature.")
test_aux_features, test_labels = load_test_features(args, "test", aux_model, test_loader, tfm_norm=tfm_aux, model_name='aux')

test_clip_features = test_clip_features.cuda()
test_aux_features = test_aux_features.cuda()


2024-07-09 09:56 - INFO: - Loading visual features and labels from test set.
2024-07-09 09:56 - INFO: - Loading CLIP test feature.
100%|██████████| 58/58 [00:11<00:00,  5.23it/s]
2024-07-09 09:56 - INFO: - Loading AUX test feature.
100%|██████████| 58/58 [00:10<00:00,  5.37it/s]


In [11]:

# zero shot
tmp =  test_clip_features / test_clip_features.norm(dim=-1, keepdim=True)
l = 100. * tmp @ clip_weights
print(f"{l.argmax(dim=-1).eq(test_labels.cuda()).sum().item()}/ {len(test_labels)} = {l.argmax(dim=-1).eq(test_labels.cuda()).sum().item()/len(test_labels) * 100:.2f}%")


3269/ 3669 = 89.10%


In [12]:
# build amu-model
model = AMU_Model(
    clip_model=clip_model,
    aux_model=aux_model,
    sample_features=[aux_features, aux_labels],
    clip_weights=clip_weights,
    feat_dim=args.feat_dim,
    class_num=37,
    lambda_merge=args.lambda_merge,
    alpha=args.alpha,
    uncent_type=args.uncent_type,
    uncent_power=args.uncent_power
)


init adapter weight by training samples...


In [13]:
def freeze_bn(m):
    classname = m.__class__.__name__
    if classname.find('BatchNorm') != -1:
        m.eval()


In [14]:

def train_one_epoch(model, data_loader, optimizer, scheduler, logger):
    # Train
    model.train()
    model.apply(freeze_bn) # freeze BN-layer
    correct_samples, all_samples = 0, 0
    loss_list = []
    loss_aux_list = []
    loss_merge_list = [] 

    # origin image
    for i, (images, target) in enumerate(tqdm(data_loader)):
        images, target = images.cuda(), target.cuda()
        return_dict = model(images, labels=target)
        
        acc = cls_acc(return_dict['logits'], target)
        correct_samples += acc / 100 * len(return_dict['logits'])
        all_samples += len(return_dict['logits'])
        
        loss_list.append(return_dict['loss'].item())
        loss_aux_list.append(return_dict['loss_aux'].item())
        loss_merge_list.append(return_dict['loss_merge'].item())
        
        optimizer.zero_grad()
        return_dict['loss'].backward()
        optimizer.step()
        scheduler.step()

    current_lr = scheduler.get_last_lr()[0]
    logger.info('LR: {:.6f}, Acc: {:.4f} ({:}/{:}), Loss: {:.4f}'.format(current_lr, correct_samples / all_samples, correct_samples, all_samples, sum(loss_list)/len(loss_list)))
    logger.info("""Loss_aux: {:.4f}, Loss_merge: {:.4f}""".format(sum(loss_aux_list)/len(loss_aux_list), sum(loss_merge_list)/len(loss_merge_list))) 

In [15]:

def train_and_eval(args, logger, model, clip_test_features, 
 aux_test_features, test_labels, train_loader_F):
    model.cuda()
    model.requires_grad_(False)
    model.aux_adapter.requires_grad_(True)
    
    optimizer = torch.optim.AdamW(
        model.parameters(),
        weight_decay=0.01,
        lr=args.lr, 
        eps=1e-4
        )
    
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.train_epoch * len(train_loader_F))
    
    best_acc, best_epoch = 0.0, 0

    for train_idx in range(1, args.train_epoch + 1):
        logger.info('Train Epoch: {:} / {:}'.format(train_idx, args.train_epoch))
        train_one_epoch(model, train_loader_F, optimizer, scheduler, logger)
        # Eval
        model.eval()
        with torch.no_grad():
            return_dict = model(
                clip_features=clip_test_features,
                aux_features=aux_test_features,
                labels=test_labels
            )
            acc = cls_acc(return_dict['logits'], test_labels)
            acc_aux = cls_acc(return_dict['aux_logits'], test_labels)
        logger.info("----- Aux branch's Test Acc: {:.2f} ----".format(acc_aux))
        logger.info("----- AMU's Test Acc: {:.2f} -----\n".format(acc))

        if acc > best_acc:
            best_acc = acc
            best_epoch = train_idx
            torch.save(model.aux_adapter.state_dict(), args.cache_dir + f"/best_adapter_" + str(args.shots) + "shots.pt")
    logger.info(f"----- Best Test Acc: {best_acc:.2f}, at epoch: {best_epoch}.-----\n")


In [16]:

train_and_eval(args, logger, model, test_clip_features,  test_aux_features, test_labels, train_loader_F )

2024-07-09 09:56 - INFO: - Train Epoch: 1 / 51
100%|██████████| 1/1 [00:01<00:00,  1.83s/it]
2024-07-09 09:56 - INFO: - LR: 0.000999, Acc: 0.9054 (134.0/148), Loss: 0.6469
2024-07-09 09:56 - INFO: - Loss_aux: 0.8359, Loss_merge: 0.2958
2024-07-09 09:56 - INFO: - ----- Aux branch's Test Acc: 67.81 ----
2024-07-09 09:56 - INFO: - ----- AMU's Test Acc: 91.03 -----

2024-07-09 09:56 - INFO: - Train Epoch: 2 / 51
100%|██████████| 1/1 [00:01<00:00,  1.84s/it]
2024-07-09 09:57 - INFO: - LR: 0.000996, Acc: 0.9189 (136.0/148), Loss: 0.5938
2024-07-09 09:57 - INFO: - Loss_aux: 0.7487, Loss_merge: 0.3062
2024-07-09 09:57 - INFO: - ----- Aux branch's Test Acc: 68.30 ----
2024-07-09 09:57 - INFO: - ----- AMU's Test Acc: 91.09 -----

2024-07-09 09:57 - INFO: - Train Epoch: 3 / 51
100%|██████████| 1/1 [00:01<00:00,  1.84s/it]
2024-07-09 09:57 - INFO: - LR: 0.000991, Acc: 0.8986 (133.0/148), Loss: 0.5643
2024-07-09 09:57 - INFO: - Loss_aux: 0.6931, Loss_merge: 0.3250
2024-07-09 09:57 - INFO: - ----- A