In [None]:
import os
import random
from tqdm import tqdm

import torch

from datasets.imagenet import ImageNet
import clip
from utils import *
from clip.moco import load_moco
from clip.amu import *
from parse_args import parse_args

In [None]:
# Inside Jupyter, sys.argv holds the kernel's own launch flags, so the parser is
# given an explicit argument list; options the parser does not know about are
# collected into `_` instead of raising.
parser = parse_args()
args, _ = parser.parse_known_args([
    # Reproducibility
    '--rand_seed', '2',
    '--torch_rand_seed', '3407',
    # Experiment name and CLIP backbone (other backbones tried: ViT-B-16, RN101).
    # NOTE(review): exp_name says 16-shot while --shots below is 4 — confirm
    # which is intended before comparing logs across runs.
    '--exp_name', 'test_16_shot',
    '--clip_backbone', 'ViT-B-32',
    # AMU hyper-parameters
    '--augment_epoch', '1',
    '--alpha', '0.5',
    '--lambda_merge', '0.35',
    # Optimisation
    '--train_epoch', '10',
    '--lr', '1e-3',
    '--batch_size', '8',
    # Data. NOTE(review): --dataset only names the cache directory in this
    # notebook; the data actually used comes from MyDataSet further down.
    '--shots', '4',
    '--root_path', 'data',
    '--dataset', 'oxford_pets',
    '--uncent_type', 'max',
    # '--load_aux_weight',
])

In [None]:
# Per-dataset working directory under ./caches; the best adapter checkpoint is
# written here later, so create it before anything tries to save into it.
cache_dir = os.path.join('./caches', args.dataset)
os.makedirs(cache_dir, exist_ok=True)
args.cache_dir = cache_dir

# config_logging receives args after cache_dir has been attached to it.
logger = config_logging(args)
logger.info("\nRunning configs.")

# Log the resolved configuration as a left-aligned "name : value" table.
args_dict = vars(args)
message = '\n'.join(
    f'{option:<20}: {setting}' for option, setting in args_dict.items()
)
logger.info(message)


In [None]:

# CLIP backbone named by --clip_backbone, switched to inference mode.
clip_model, preprocess = clip.load(args.clip_backbone)
clip_model.eval()

# Auxiliary model loaded from a MoCo checkpoint (presumably ResNet-50, judging
# by the file name). load_moco also returns the feature width, stored on args
# because AMU_Model needs it as feat_dim. Previously tried: clip.load('RN101').
aux_model, args.feat_dim = load_moco("data/r-50-1000ep.pth.tar")
aux_model.cuda().eval()

# Seed Python's and PyTorch's RNGs after the pretrained weights are loaded, so
# the random work done by later cells starts from a fixed state.
random.seed(args.rand_seed)
torch.manual_seed(args.torch_rand_seed)

In [None]:
from datasets import build_dataset
from datasets.utils import build_data_loader


In [None]:
from datasets.oxford_pets import OxfordPets
from datasets.my_test_set import MyDataSet
from utils import tfm_train_base, tfm_test_base

# Few-shot split of the custom dataset. The shot count comes from the parsed
# config rather than a repeated literal, so the adapter checkpoint name (which
# embeds args.shots) always matches the data actually used for training.
# NOTE(review): the root "data" and the third positional argument (1) are kept
# as-is; their exact meaning is not visible here — confirm against
# datasets/my_test_set.py.
dataset = MyDataSet("data", int(args.shots), 1)

# Evaluation loaders: fixed order and test-time transforms.
# NOTE: val_loader is built but no later cell in this notebook uses it.
val_loader = build_data_loader(data_source=dataset.val, batch_size=128, is_train=False, tfm=tfm_test_base, shuffle=False)
test_loader = build_data_loader(data_source=dataset.test, batch_size=128, is_train=False, tfm=tfm_test_base, shuffle=False)

# Training-set loaders over the same few-shot images:
# - train_loader_cache: large, unshuffled batches, passed to load_aux_weight to
#   extract the auxiliary model's features.
# - train_loader_F: shuffled batches of args.batch_size, used to train the adapter.
train_loader_cache = build_data_loader(data_source=dataset.train_x, batch_size=256, tfm=tfm_train_base, is_train=True, shuffle=False)
train_loader_F = build_data_loader(data_source=dataset.train_x, batch_size=args.batch_size, tfm=tfm_train_base, is_train=True, shuffle=True)



In [None]:
# for i, (data, labels) in enumerate(train_loader_cache):
#     print(f"Batch {i + 1}")
#     print("Data:", data.shape)
#     print("Labels:", labels.shape)
#     if i == 3: 
#         break

In [None]:
# Text side: build CLIP's zero-shot classifier from the dataset's class names
# and prompt template(s).
logger.info("Getting textual features as CLIP's classifier...")
clip_weights = gpt_clip_classifier(
    dataset.classnames,
    clip_model,
    dataset.template,
)


In [None]:
# Encode the few-shot training images with the auxiliary model; the resulting
# features and labels are later passed to AMU_Model as sample_features.
logger.info("Load visual features of few-shot training set...")
aux_features, aux_labels = load_aux_weight(
    args,
    aux_model,
    train_loader_cache,
    tfm_norm=tfm_aux,
)


In [None]:

# Extract CLIP image features for the whole test split once; evaluation then
# feeds these pre-computed features to the model instead of raw images.
logger.info("Loading visual features and labels from test set.")
logger.info("Loading CLIP test feature.")
test_clip_features, test_labels = load_test_features(
    args, "test", clip_model, test_loader,
    tfm_norm=tfm_clip, model_name='clip',
)


In [None]:

# Zero-shot baseline: L2-normalise the CLIP image features, score them against
# the text classifier (used as a (feat_dim, num_classes) matrix) with a fixed
# scale of 100, and report top-1 accuracy.
zs_features = test_clip_features / test_clip_features.norm(dim=-1, keepdim=True)
zs_logits = 100. * zs_features.to(clip_weights.device) @ clip_weights
# Compute the hit count once; labels follow the logits' device.
zs_correct = zs_logits.argmax(dim=-1).eq(test_labels.to(zs_logits.device)).sum().item()
zs_total = len(test_labels)
# max(..., 1) guards the percentage against an empty test split.
print(f"{zs_correct}/ {zs_total} = {zs_correct / max(zs_total, 1) * 100:.2f}%")


In [None]:

# Same extraction for the auxiliary model, then move both test feature sets to
# the GPU once, ahead of the per-epoch evaluation in train_and_eval.
logger.info("Loading AUX test feature.")
test_aux_features, test_labels = load_test_features(
    args, "test", aux_model, test_loader,
    tfm_norm=tfm_aux, model_name='aux',
)
test_clip_features, test_aux_features = test_clip_features.cuda(), test_aux_features.cuda()


In [None]:
# Build the AMU model from the CLIP model, the auxiliary model, the few-shot aux
# features/labels and the text classifier.
# class_num is taken from the classifier instead of a hard-coded constant:
# clip_weights is used as a (feat_dim, num_classes) matrix in the zero-shot
# cell, so its last dimension is exactly the number of classes the CLIP branch
# scores, and the two branches cannot disagree on the class count.
num_classes = clip_weights.shape[-1]
model = AMU_Model(
    clip_model=clip_model,
    aux_model=aux_model,
    sample_features=[aux_features, aux_labels],
    clip_weights=clip_weights,
    feat_dim=args.feat_dim,
    class_num=num_classes,
    lambda_merge=args.lambda_merge,
    alpha=args.alpha,
    uncent_type=args.uncent_type,
    uncent_power=args.uncent_power
)


In [None]:
def freeze_bn(m):
    """Switch BatchNorm layers to eval mode; meant to be passed to ``Module.apply``.

    Any module whose class name contains "BatchNorm" (BatchNorm1d/2d/3d,
    SyncBatchNorm, ...) is put into eval mode, so it normalises with its
    running statistics instead of updating them. Every other module is left
    untouched. Returns None.
    """
    if 'BatchNorm' in type(m).__name__:
        m.eval()


In [None]:

def train_one_epoch(model, data_loader, optimizer, scheduler, logger):
    """Run one training epoch over ``data_loader`` and log summary statistics.

    The model is put in train mode with its BatchNorm layers frozen via
    ``freeze_bn``. The scheduler is stepped after every optimizer step, which
    matches the per-iteration T_max that ``train_and_eval`` configures.

    Args:
        model: called as ``model(images, labels=target)``; must return a dict
            with 'logits', 'loss', 'loss_aux' and 'loss_merge'.
        data_loader: iterable of ``(images, target)`` batches.
        optimizer: optimizer over the trainable parameters.
        scheduler: LR scheduler stepped once per batch.
        logger: receives the epoch summary via ``logger.info``.

    An empty ``data_loader`` is logged and skipped instead of raising
    ZeroDivisionError while averaging.
    """
    model.train()
    model.apply(freeze_bn)  # keep BatchNorm running statistics fixed

    correct_samples, all_samples = 0.0, 0
    losses, aux_losses, merge_losses = [], [], []

    for images, target in tqdm(data_loader):
        images, target = images.cuda(), target.cuda()
        return_dict = model(images, labels=target)

        logits = return_dict['logits']
        n_batch = len(logits)
        # cls_acc reports a percentage; convert it back into a hit count.
        correct_samples += cls_acc(logits, target) / 100 * n_batch
        all_samples += n_batch

        losses.append(return_dict['loss'].item())
        aux_losses.append(return_dict['loss_aux'].item())
        merge_losses.append(return_dict['loss_merge'].item())

        optimizer.zero_grad()
        return_dict['loss'].backward()
        optimizer.step()
        scheduler.step()

    if all_samples == 0:
        logger.info('No training batches in this epoch; skipping statistics.')
        return

    # The percentage round-trip leaves float noise (e.g. 2.9999999); log an integer count.
    correct_count = int(round(correct_samples))
    current_lr = scheduler.get_last_lr()[0]
    logger.info('LR: {:.6f}, Acc: {:.4f} ({:}/{:}), Loss: {:.4f}'.format(
        current_lr, correct_count / all_samples, correct_count, all_samples,
        sum(losses) / len(losses)))
    logger.info('Loss_aux: {:.4f}, Loss_merge: {:.4f}'.format(
        sum(aux_losses) / len(aux_losses), sum(merge_losses) / len(merge_losses)))

In [None]:

def train_and_eval(args, logger, model, clip_test_features,
                   aux_test_features, test_labels, train_loader_F):
    """Train only ``model.aux_adapter`` and keep the best adapter checkpoint.

    After every epoch the model is evaluated on pre-computed test features. The
    adapter weights of the best-scoring epoch are saved to
    ``<args.cache_dir>/best_adapter_<args.shots>shots.pt``. A checkpoint is
    always written after the first epoch, so the file exists for later loading
    even if accuracy never rises above 0.

    Args:
        args: config namespace; reads ``lr``, ``train_epoch``, ``cache_dir``, ``shots``.
        logger: receives progress and accuracy messages.
        model: AMU model exposing an ``aux_adapter`` sub-module.
        clip_test_features: pre-computed CLIP test features.
        aux_test_features: pre-computed auxiliary-model test features.
        test_labels: labels matching the test features.
        train_loader_F: shuffled few-shot training loader.

    NOTE(review): the "best" epoch is selected on the test set itself, so the
    reported best accuracy is optimistically biased; a held-out validation
    split would give an unbiased estimate.
    """
    model.cuda()
    # Freeze everything, then re-enable gradients for the adapter only.
    model.requires_grad_(False)
    model.aux_adapter.requires_grad_(True)

    # Give the optimizer only the trainable parameters; frozen ones never
    # receive gradients and would just be carried (and skipped) every step.
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(
        trainable_params,
        weight_decay=0.01,
        lr=args.lr,
        eps=1e-4,
    )
    # Cosine decay over all iterations; train_one_epoch steps it once per batch.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, args.train_epoch * len(train_loader_F))

    best_path = os.path.join(args.cache_dir, f"best_adapter_{args.shots}shots.pt")
    best_acc, best_epoch = 0.0, 0

    for train_idx in range(1, args.train_epoch + 1):
        logger.info('Train Epoch: {:} / {:}'.format(train_idx, args.train_epoch))
        train_one_epoch(model, train_loader_F, optimizer, scheduler, logger)

        # Evaluate on the pre-computed test features.
        model.eval()
        with torch.no_grad():
            return_dict = model(
                clip_features=clip_test_features,
                aux_features=aux_test_features,
                labels=test_labels
            )
            acc = cls_acc(return_dict['logits'], test_labels)
            acc_aux = cls_acc(return_dict['aux_logits'], test_labels)
        logger.info("----- Aux branch's Test Acc: {:.2f} ----".format(acc_aux))
        logger.info("----- AMU's Test Acc: {:.2f} -----\n".format(acc))

        # best_epoch == 0 forces a save on the first epoch, so a checkpoint
        # always exists even when accuracy stays at 0.
        if best_epoch == 0 or acc > best_acc:
            best_acc = acc
            best_epoch = train_idx
            torch.save(model.aux_adapter.state_dict(), best_path)
    logger.info(f"----- Best Test Acc: {best_acc:.2f}, at epoch: {best_epoch}.-----\n")


In [None]:

# Train the adapter; the best-epoch weights are written under args.cache_dir.
train_and_eval(
    args,
    logger,
    model,
    clip_test_features=test_clip_features,
    aux_test_features=test_aux_features,
    test_labels=test_labels,
    train_loader_F=train_loader_F,
)

In [None]:

# Restore the best-epoch adapter weights saved by train_and_eval before predicting.
best_adapter_path = "{}/best_adapter_{}shots.pt".format(args.cache_dir, args.shots)
model.aux_adapter.load_state_dict(torch.load(best_adapter_path))

In [None]:
import importlib

from datasets import my_test_set_2

# Reload BEFORE binding MyDataSet2. reload() re-executes the module in place,
# but a `from ... import` that ran earlier still holds the stale class, so edits
# to datasets/my_test_set_2.py would only be picked up on a second run.
my_test_set_2 = importlib.reload(my_test_set_2)
from datasets.my_test_set_2 import MyDataSet2

In [None]:
# Submission split (presumably the unlabelled TestSetA images). This rebinds
# `dataset` and `test_loader`, replacing the evaluation ones built earlier.
# NOTE(review): the meaning of the second argument (4) is not visible here —
# confirm against datasets/my_test_set_2.py.
dataset = MyDataSet2("data", 4)
test_loader = build_data_loader(
    data_source=dataset.test,
    batch_size=64,
    is_train=False,
    tfm=tfm_test_base,
    shuffle=False,
)

In [None]:
# for i, (data, labels) in enumerate(test_loader):
#     print(f"Batch {i + 1}")
#     print("Data:", data.shape)
#     print("Labels:", labels.shape)
#     if i == 3: 
#         break

In [None]:
# Extract CLIP and auxiliary features for the submission split.
# NOTE(review): this reuses the split name "test" from the earlier evaluation;
# if load_test_features caches by split name under args.cache_dir, it could
# return the previous test set's features — confirm.
test_clip_features, test_labels = load_test_features(
    args, "test", clip_model, test_loader,
    tfm_norm=tfm_clip, model_name='clip',
)

logger.info("Loading AUX test feature.")
test_aux_features, test_labels = load_test_features(
    args, "test", aux_model, test_loader,
    tfm_norm=tfm_aux, model_name='aux',
)

In [None]:
# Predict on the submission features. torch.no_grad() avoids building an
# autograd graph through the (trainable) adapter, which would only waste
# memory. Features are moved to the GPU as in the training-time eval cell
# (a no-op if they are already there).
model.eval()
with torch.no_grad():
    return_dict = model(
        clip_features=test_clip_features.cuda(),
        aux_features=test_aux_features.cuda(),
    )

In [None]:
# Indices of the five highest-scoring classes per image, best first.
top5 = return_dict['logits'].topk(k=5, dim=1, largest=True, sorted=True).indices

In [None]:
# Write the submission file: one line per test image, "<file name> <top-5 class ids>".
# The submission loader appears to store each image's numeric id in the label
# slot, so the id is used to rebuild the file name image_<id>.<ext>.
TEST_IMAGE_DIR = "data/TestSetA/"
IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")


def _resolve_image_name(image_id):
    """Return the on-disk file name for ``image_id``: the first existing extension.

    Falls back to the last extension (.png) with a warning when no candidate
    exists, so the submission still has exactly one line per image.
    """
    stem = "image_" + str(image_id)
    for ext in IMAGE_EXTENSIONS:
        name = stem + ext
        if os.path.exists(os.path.join(TEST_IMAGE_DIR, name)):
            return name
    print(f"warning: no image file found for {stem} in {TEST_IMAGE_DIR}; writing {name}")
    return name


image_ids = test_labels.tolist()
top5_ids = top5.tolist()
# Both come from the same loader, so a length mismatch means something upstream broke.
assert len(image_ids) == len(top5_ids), (len(image_ids), len(top5_ids))

save_path = 'data/result.txt'
count = 0
# `with` guarantees the file is closed even if a write fails part-way.
with open(save_path, 'w') as save_file:
    for image_id, class_ids in tqdm(zip(image_ids, top5_ids), total=len(image_ids)):
        name = _resolve_image_name(image_id)
        save_file.write(name + ' ' + ' '.join(str(p) for p in class_ids) + '\n')
        count += 1
print("写入完成,共计", count)


In [None]:
import zipfile

# Package the submission file. ZIP_DEFLATED is required for real compression;
# ZipFile's default (ZIP_STORED) only archives the file uncompressed.
zip_file_path = 'data/result.zip'
REMOVE_ORIGINAL = False  # set True to delete result.txt after zipping

with zipfile.ZipFile(zip_file_path, 'w', compression=zipfile.ZIP_DEFLATED) as zipf:
    zipf.write(save_path, os.path.basename(save_path))

# Only claim the original was deleted when it actually was.
if REMOVE_ORIGINAL:
    os.remove(save_path)
    print(f"{save_path} 已压缩为 {zip_file_path} 并删除原文件。")
else:
    print(f"{save_path} 已压缩为 {zip_file_path}。")

In [None]:
# Sanity check: number of entries in the submission image folder, to compare
# with the line count printed when the result file was written.
sum(1 for _ in os.scandir("data/TestSetA/"))