In [1]:
import pickle
import torch
import clip
from torchvision import datasets
from torch.utils.data import DataLoader,Subset
import test_finetune
import random
from tqdm import tqdm
import os

In [2]:
# Pretrained CLIP ViT-B/32 together with its matching image preprocessing
# transform, placed on the GPU.
model, preprocess = clip.load("ViT-B/32", device="cuda")

# Snapshot the image-encoder weights before any fine-tuning happens; the
# (commented-out) weight-ensemble code further down reads this file back.
torch.save(model.visual.state_dict(), 'zero_shot.pt')

In [3]:
# CIFAR-100 train / test splits stored under "datas" (fetched when missing).
# Every image is passed through CLIP's own `preprocess` transform.
train_data = datasets.CIFAR100(
    root="datas",
    train=True,
    download=True,
    transform=preprocess,
)
test_data = datasets.CIFAR100(
    root="datas",
    train=False,
    download=True,
    transform=preprocess,
)

Files already downloaded and verified
Files already downloaded and verified


In [4]:
# Load the raw CIFAR-100 pickles that the torchvision download left under
# "datas", so the label lists can be read directly in a later cell.
# NOTE(review): `dict` shadows the Python builtin for the rest of the notebook;
# the name is kept here because a later cell reads `dict[b'fine_labels']`.
# NOTE(review): pickle.load executes arbitrary code from the file — only safe
# as long as these files really are the ones fetched by torchvision.
with open('datas/cifar-100-python/train', 'rb') as fo:
    dict = pickle.load(fo, encoding='bytes')

# Same for the test split; keys are bytes because of encoding='bytes'.
with open('datas/cifar-100-python/test', 'rb') as fo:
    dict_test = pickle.load(fo, encoding='bytes')

In [5]:
def seed_all(seed):
    """Seed Python's and PyTorch's RNGs and put cuDNN in deterministic mode.

    Args:
        seed: Integer seed handed unchanged to `random`, `torch` and
            `torch.cuda`.
    """
    # cuDNN: always pick deterministic kernels and skip the autotuner.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    # Seed every generator with the same value.
    for seeder in (random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seeder(seed)


seed_all(90)

In [6]:
# Fine-label id of every train / test image, read from the raw pickles.
labels = dict[b'fine_labels']
label_test = dict_test[b'fine_labels']

# Five sequential tasks, each owning 20 consecutive class ids.
task1_keys, task2_keys, task3_keys, task4_keys, task5_keys = (
    list(range(start, start + 20)) for start in range(0, 100, 20)
)


def _split_indices_by_task(fine_labels):
    """Group sample positions by the task their label belongs to.

    A label found in task1..task4's key list goes to that task; anything else
    falls through to the fifth group.

    Returns:
        A list of five index lists, one per task, each in ascending order.
    """
    explicit_keys = (task1_keys, task2_keys, task3_keys, task4_keys)
    groups = [[] for _ in range(5)]
    for position, fine_label in enumerate(fine_labels):
        target = groups[-1]  # fallback: task 5
        for group, keys in zip(groups, explicit_keys):
            if fine_label in keys:
                target = group
                break
        target.append(position)
    return groups


(task1_indicies,
 task2_indicies,
 task3_indicies,
 task4_indicies,
 task5_indicies) = _split_indices_by_task(labels)

(test1_indicies,
 test2_indicies,
 test3_indicies,
 test4_indicies,
 test5_indicies) = _split_indices_by_task(label_test)


In [7]:
# Per-task views of the training set, selected by the index lists built above.
task1, task2, task3, task4, task5 = (
    Subset(train_data, picked)
    for picked in (task1_indicies, task2_indicies, task3_indicies,
                   task4_indicies, task5_indicies)
)

# Per-task views of the test set; `test[k]` is the test split of task k.
test = [
    Subset(test_data, picked)
    for picked in (test1_indicies, test2_indicies, test3_indicies,
                   test4_indicies, test5_indicies)
]
test1, test2, test3, test4, test5 = test

# One shuffled loader (batch size 64) per training task; `train[k]` feeds task k.
train = [
    DataLoader(subset, 64, shuffle=True)
    for subset in (task1, task2, task3, task4, task5)
]
loader1, loader2, loader3, loader4, loader5 = train

In [8]:
def get_ref_sample(testdata, num, seeds, device="cuda"):
    """Draw `num` distinct (image, label) pairs from `testdata`, reproducibly.

    Args:
        testdata: Indexable dataset with `len()`, whose items are
            `(tensor, label)` pairs that all share one tensor shape.
        num: Number of items to draw without replacement.
        seeds: Seed for the sampling; the same seed yields the same picks.
        device: Device the returned tensors are moved to. Defaults to "cuda",
            which matches the previous hard-coded behaviour.

    Returns:
        `(sample, label)`: `sample` has shape `(num, *item_shape)` and `label`
        has shape `(num,)`; both are float tensors on `device`.

    Raises:
        ValueError: If `num` is larger than `len(testdata)` (from `sample`).
    """
    # A private generator gives exactly the picks that seeding the global
    # `random` module would, but leaves the global RNG state untouched, so
    # calling this no longer overrides the notebook-wide seed.
    picker = random.Random(seeds)
    chosen = picker.sample(range(len(testdata)), num)

    # Size the buffer from the dataset that was actually passed in, not from
    # an unrelated module-level dataset.
    first_item, _ = testdata[0]
    sample = torch.zeros([num, *first_item.shape])
    label = torch.zeros(num)

    for slot, index in enumerate(chosen):
        item, target = testdata[index]
        sample[slot] = item
        label[slot] = target

    return sample.to(device), label.to(device)


In [9]:
def get_emb(model, dict, img ,label, txt_only):
    """Return L2-normalised CLIP embeddings for images and their label prompts.

    Args:
        model: Object exposing `encode_image` and `encode_text`.
        dict: Sequence mapping a class id to its class name. (The name shadows
            the builtin; it is kept so existing callers keep working.)
        img: Image batch passed to `model.encode_image`; ignored when
            `txt_only` is true.
        label: 1-D tensor of class ids, one per prompt to build.
        txt_only: When true, skip the image branch.

    Returns:
        `(ref_img_emb, ref_txt_emb)`; `ref_img_emb` is `None` when `txt_only`
        is true. Each embedding row is divided by its own norm.
    """
    ref_img_emb = None
    if not txt_only:
        ref_img_emb = model.encode_image(img)
        ref_img_emb = ref_img_emb / ref_img_emb.norm(dim=-1, keepdim=True)

    # Convert the ids to plain Python ints once, rather than indexing the
    # class-name list with one 0-d tensor at a time.
    class_ids = label.int().tolist()
    prompts = [clip.tokenize(f"the image of a {dict[i]}") for i in class_ids]

    ref_txt_emb = model.encode_text(torch.cat(prompts).cuda())
    ref_txt_emb = ref_txt_emb / ref_txt_emb.norm(dim=-1, keepdim=True)

    return ref_img_emb, ref_txt_emb

In [11]:
# Fine-tuning driver, built from the class names, the literal 2 and the model.
# NOTE(review): the meaning of the second argument (2) lives in test_finetune — confirm.
f = test_finetune.Finetune(train_data.classes, 2, model)

# Points per top-5 rank; only read by the commented-out superposition test.
TABLE = [100, 80, 60, 30, 10]

# One tokenized prompt per class, in class order, moved to the GPU.
TEST = torch.cat(
    [clip.tokenize(f"the image of a {name}") for name in test_data.classes]
).cuda()

# Mixing weights referenced only from the commented-out weight-ensemble code.
INTERPO = [0.3, 0.3, 0.4]

# Sample count used by the commented-out superposition test.
num_ref_sample = 1000

def sub_test(cls_idx, model, test = TEST, classes_per_task = 20):
    """Encode the prompts of every class belonging to tasks 0..`cls_idx`.

    Args:
        cls_idx: Zero-based index of the most recent task.
        model: Object exposing `encode_text`.
        test: Tokenized prompts, one row per class in class order. Defaults to
            the module-level `TEST`.
        classes_per_task: Number of consecutive classes per task (default 20).

    Returns:
        Text embeddings for the first `(cls_idx + 1) * classes_per_task`
        prompts, each row divided by its own norm.
    """
    # Slice the argument, not the global, so a caller-supplied prompt tensor
    # is actually honoured.
    sub = test[:(cls_idx + 1) * classes_per_task]
    out = model.encode_text(sub)
    return out / out.norm(dim=-1, keepdim=True)

def _task_accuracy(model, dataset, txt_out, batch_size=64):
    """Return the top-1 accuracy of `model` on `dataset`.

    Each image embedding is normalised and compared against every row of
    `txt_out`; the index of the most similar row is the predicted class id.

    Args:
        model: Object exposing `encode_image`.
        dataset: Dataset yielding `(image, label)` pairs.
        txt_out: Normalised text embeddings, one row per candidate class.
        batch_size: Evaluation batch size.

    Returns:
        Fraction of correctly classified samples (a 0-d tensor).
    """
    correct = 0
    seen = 0
    # Evaluation only: no autograd graph is needed.
    with torch.no_grad():
        for data, label in tqdm(DataLoader(dataset, batch_size, shuffle=False)):
            data = data.cuda()
            label = label.cuda()
            out = model.encode_image(data)
            out = out / out.norm(dim=-1, keepdim=True)
            sim = out @ txt_out.T
            ans = torch.argmax(sim, dim=1)
            correct += (ans == label).sum()
            seen += len(data)
    return correct / seen


for count in range(5):
    print(f"Begin training task {count}...")
    f.fine_tune(train[count])

    ########################################################################
    # This is for weight-ensemble
    ########################################################################
    # if os.path.exists("previous.pt"):
    #     previous_state = torch.load("previous.pt")
    #     current = model.visual.state_dict()
    #     zero_shot = torch.load("zero_shot.pt")

    #     interpolated = {}
    #     for key in current :
    #         # interpolated[key] = INTERPO[0] * zero_shot[key] + INTERPO[1] * previous_state[key] + INTERPO[2] * current[key]
    #         # interpolated[key] = 0.5 * zero_shot[key] + 0.5 * current[key]
    #         interpolated[key] = 0.5 * previous_state[key] + 0.5 * current[key]

    #     model.visual.load_state_dict(interpolated)
    #     torch.save(interpolated, "previous.pt")
    # else:
    #     torch.save(model.visual.state_dict(), "previous.pt")
    ########################################################################

    # Reference embeddings (only consumed by the commented-out superposition
    # test below) and the text embeddings of every class seen so far.
    with torch.no_grad():
        ref_img, ref_label = get_ref_sample(test[count],10,count)
        ref_img, ref_txt = get_emb(model, test_data.classes,ref_img, ref_label, False)
        txt_out = sub_test(count, model)

    print("Testing current task...")
    acc = _task_accuracy(model, test[count], txt_out)
    print(f"Accuracy for current task is {acc}")

    ##########################################################################
    # This part is for normal testing with test datasets
    ##########################################################################
    # Each earlier task gets its own, independently computed accuracy; the
    # counters start from zero for every task.
    for test_count in range(count):
        print(f"Testing task {test_count}..." )
        acc = _task_accuracy(model, test[test_count], txt_out)
        print(f"Accuracy for task {test_count} is {acc}")
    ##########################################################################

    ##########################################################################
    # This is testing the "CLIP embedding superposition"
    ##########################################################################
    # for test_count in range(count):
    #     print(f"Testing task {test_count}..." )
    #     img, label = get_ref_sample(test[test_count], num_ref_sample, test_count)
    #     _, txt = get_emb(model, test_data.classes, None, label, True)
    #     point = 0
    #     for i in tqdm(range(num_ref_sample)):
    #         sample = txt[i].unsqueeze(0)
    #         guess_img = ref_img - ref_txt + sample
    #         guess_img = torch.mean(guess_img, dim = 0)
    #         similarity = guess_img @ txt_out.T
    #         _, result = torch.topk(similarity, 5)
    #         temp = (result == label[i]).nonzero(as_tuple=True)[0]
    #         if temp.size() != 0:
    #             point += TABLE[temp[0]]/num_ref_sample
    #     print(f"The test score for task {test_count} is {point}")
    ##########################################################################
        

