In [None]:
from pathlib import Path
import random

import clip
import numpy as np
import torch
# from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torch.nn as nn

import torchmetrics
import torchvision
from tqdm import tqdm

from back_save.为了测试clip的原始算法.balanced_batch_sampler import BalancedBatchSampler
import os
from collections import defaultdict
from PIL import Image

import gc
# gc.collect()


In [None]:
# Training hyper-parameters.
EPOCH = 10       # number of passes made over the training dataloader
BATCH_SIZE = 64  # batch size handed to every DataLoader in this notebook

In [None]:
# Use the first GPU when one is available, otherwise fall back to the CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# jit=False is required for training.
model, preprocess = clip.load("ViT-B/32", device=device, jit=False)

if device != "cpu":
    # GPU path (mixed-precision training). Per the original author this call
    # is redundant because CLIP already loads in float16 on GPU.
    clip.model.convert_weights(model)
else:
    # CPU path: keep the whole model in float32.
    model.float()
def convert_models_to_fp32(model):
    """Cast every parameter of ``model``, and its gradient when present, to float32.

    The training loop calls this right before ``optimizer.step()`` on GPU.
    The original author notes that it is mandatory: without it the gradients
    blow up immediately.

    Args:
        model: any ``torch.nn.Module``; it is modified in place.

    Parameters whose ``.grad`` is ``None`` only get their data cast; they are
    skipped for the gradient cast instead of raising ``AttributeError``.
    """
    for p in model.parameters():
        p.data = p.data.float()
        # A parameter has no gradient if it is frozen or did not take part in
        # the last backward pass.
        if p.grad is not None:
            p.grad.data = p.grad.data.float()
class image_title_dataset(Dataset):
    """Dataset yielding ``(image, title, tag)`` triples.

    * ``image`` is the result of applying the module-level ``preprocess``
      transform to the image at ``list_image_path[idx]``.
    * ``title`` is row ``idx`` of ``clip.tokenize(list_txt)`` when captions
      were supplied, otherwise the placeholder string ``"-1"``.
    * ``tag`` is the file's base name when ``need_path`` is true, otherwise
      ``nidx[idx]``.

    Args:
        list_image_path: sequence of image file paths.
        list_txt: sequence of captions, one per image; pass an empty sequence
            when no captions exist.
        nidx: sequence of per-image labels (unused when ``need_path`` is true).
        need_path: return the file base name instead of ``nidx[idx]``.
    """

    def __init__(self, list_image_path,list_txt,nidx,need_path:bool=False):
        self.image_path = list_image_path
        # ``mark`` records whether captions are available. The previous
        # ``> 1`` test silently dropped the caption of a one-element dataset.
        self.mark = len(list_txt) > 0
        if self.mark:
            # All captions are tokenized once up front (slow at start-up,
            # cheap per item afterwards).
            self.title = clip.tokenize(list_txt)
        else:
            self.title = False
        self.nidx = nidx
        self.npath = need_path

    def __len__(self):
        return len(self.image_path)

    def __getitem__(self, idx):
        # Close the file handle as soon as the image has been transformed.
        with Image.open(self.image_path[idx]) as img:
            image = preprocess(img)

        title = self.title[idx] if self.mark else "-1"

        if self.npath:
            tag = os.path.basename(self.image_path[idx])
        else:
            tag = self.nidx[idx]
        return image, title, tag


In [None]:
def ret_class_name_dic()->tuple:
    """Build the name->index and index->name lookup tables.

    Reads ``data/classname.txt``, where every non-blank line holds a class
    name followed by its index, separated by whitespace. A leading dataset
    prefix (``Animal``, ``Thu-dog``, ``Caltech-101``, ``Food-101``) is removed
    from the name together with the one character that follows it.

    Returns:
        ``(class_name_dic_name, class_name_dic_num)``: the first maps the
        stripped class name to its index string, the second maps the index
        string back to the stripped class name. If a stripped name occurs
        twice, only the first occurrence is kept and a message is printed.
    """
    with open('data/classname.txt') as class_file:
        classes = class_file.read().splitlines()

    class_name_dic_num = {}
    class_name_dic_name = {}
    for line in classes:
        if not line.strip():
            # Tolerate blank lines instead of failing on the unpack below.
            continue
        name, idx = line.rsplit(maxsplit=1)
        c = name
        # Prefixes are tested one after another, so every match strips again.
        for prefix in ('Animal', 'Thu-dog', 'Caltech-101', 'Food-101'):
            if c.startswith(prefix):
                c = c[len(prefix) + 1:]
        if c not in class_name_dic_name:
            class_name_dic_name[c] = idx
            class_name_dic_num[idx] = c
        else:
            print(name,"already exist!!")
    return class_name_dic_name, class_name_dic_num
# Module-level lookup tables: class name -> index string, index string -> class name.
class_name_dic_name, class_name_dic_num = ret_class_name_dic()


In [None]:
def ret_pic_patch(num_pic=4)->tuple:
    """Sample up to ``num_pic`` training images per class.

    Reads ``data/train.txt``, where every non-blank line holds a relative
    image path followed by its class number. Consecutive lines with the same
    class number form one group, so the file is expected to be ordered by
    class. From each group ``min(num_pic, group size)`` paths are drawn at
    random without replacement.

    Args:
        num_pic: maximum number of images to draw from each class group.

    Returns:
        ``(paths, labels)``: two parallel lists. Every path is prefixed with
        ``"data/"`` and every label is the class-number string from the file.
    """
    with open('data/train.txt') as train_file:
        info = train_file.read().splitlines()

    # Collect consecutive runs of identical class numbers: [(label, [paths])].
    groups = []
    for line in info:
        if not line.strip():
            # Tolerate blank lines instead of failing on the unpack below.
            continue
        path, class_num = line.rsplit(maxsplit=1)
        path = "data/" + path
        if groups and int(groups[-1][0]) == int(class_num):
            groups[-1][1].append(path)
        else:
            groups.append((class_num, [path]))

    r_path = []
    r_class_num = []
    for label, paths in groups:
        k = min(num_pic, len(paths))
        r_path += random.sample(paths, k)
        # Every image in a group carries the same label, so the labels do not
        # need a separate random draw.
        r_class_num += [label] * k
    return r_path, r_class_num




In [None]:
def ret_test_pic_patch(num_pic=3000,list_image_path1=None):
    """Draw a random hold-out set from ``data/train.txt``.

    The lines of the file are shuffled, and images whose path (with the
    ``"data/"`` prefix) appears in ``list_image_path1`` are skipped, so the
    hold-out set does not overlap the training sample.

    Args:
        num_pic: total number of images to return (not a per-class count).
            Collection stops as soon as this many have been gathered.
        list_image_path1: paths already used for training, in the
            ``"data/<relative path>"`` form produced by ``ret_pic_patch``.
            ``None`` or empty means nothing is excluded.

    Returns:
        ``(paths, labels)``: two parallel lists of ``"data/"``-prefixed paths
        and class-number strings.
    """
    with open('data/train.txt') as train_file:
        info = train_file.read().splitlines()
    random.shuffle(info)

    set_train = set(list_image_path1 or ())
    r_path = []
    r_class_num = []
    for line in info:
        if not line.strip():
            # Tolerate blank lines instead of failing on the unpack below.
            continue
        path, class_num = line.rsplit(maxsplit=1)
        path = "data/" + path
        # Compare the prefixed path, not the raw line: the raw line also
        # contains the label and never matches an entry of ``set_train``.
        if path in set_train:
            continue
        r_path.append(path)
        r_class_num.append(class_num)
        if len(r_path) == num_pic:
            break
    return r_path, r_class_num

In [None]:
# Prepare the training and hold-out splits.
# Training split: up to four images per class.
list_image_path1, class_num1 = ret_pic_patch(num_pic=4)
# Hold-out split: up to 3000 images in total; the training paths are handed
# to the sampler.
list_image_path2, class_num2 = ret_test_pic_patch(
    num_pic=3000,
    list_image_path1=list_image_path1,
)
print("len(list_image_path2)", len(list_image_path2))


In [None]:

# list_image_path1,list_image_path2=np.array_split(list_image_path, 2)
# class_num1,class_num2=np.array_split(class_num, 2)

# Training data: caption each sampled image with its class name.
list_txt1 = [
    "a photo of a " + class_name_dic_num[label]
    for label in class_num1
]
dataset = image_title_dataset(list_image_path1, list_txt1, class_num1)
train_dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
)

# Hold-out (test) data, built the same way.
list_txt2 = [
    "a photo of a " + class_name_dic_num[label]
    for label in class_num2
]
test_dataset = image_title_dataset(list_image_path2, list_txt2, class_num2)
test_dataloader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
)




In [None]:
# Loss functions and optimizer.
loss_img = nn.CrossEntropyLoss()
loss_txt = nn.CrossEntropyLoss()

# Mark the last two residual blocks of the visual encoder, then of the text
# encoder, as trainable.
# NOTE(review): nothing visible in this notebook freezes the model first. If
# every parameter already has requires_grad=True these loops change nothing
# and `params` below holds the whole model — confirm the intent.
for encoder in (model.visual.transformer, model.transformer):
    for offset in (-2, -1):
        for param in encoder.resblocks[offset].parameters():
            param.requires_grad = True

# Mark the ln_final layer as trainable.
for param in model.ln_final.parameters():
    param.requires_grad = True

# Only parameters flagged as trainable are handed to the optimizer.
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(params, lr=1e-5, weight_decay=0.0001)
# Alternative settings kept for reference:
# optimizer = torch.optim.Adam(model.parameters(), lr=5e-5, betas=(0.9, 0.98), eps=1e-6, weight_decay=0.2)

In [None]:
# Tokenize one text prompt per class.
# Row i of `text2` corresponds to line i of data/classname.txt; the evaluation
# code compares top-k row indices with the integer labels, so it presumably
# relies on the class index equalling the line position — verify against the
# data file.
with open('data/classname.txt') as class_file:  # closed right after reading
    classes = class_file.read().splitlines()

new_classes = []
for c in classes:
    c = c.split(' ')[0]
    # Strip a leading dataset prefix together with the one character after it.
    # Prefixes are tested one after another.
    for prefix in ('Animal', 'Thu-dog', 'Caltech-101', 'Food-101'):
        if c.startswith(prefix):
            c = c[len(prefix) + 1:]
    # NOTE(review): the training captions use "a photo of a " while this
    # template is "a photo of " — confirm the mismatch is intentional.
    c = 'a photo of ' + c
    new_classes.append(c)
print(new_classes[0:5])
text2 = clip.tokenize(new_classes).to(device)

In [None]:
# Best hold-out top-1 accuracy seen so far; the training loop compares against it.
best_acc = 0

In [None]:
# Fine-tune for EPOCH epochs. After each epoch, measure top-1 / top-5 accuracy
# on the hold-out loader and save a checkpoint whenever top-1 improves.
for epoch in range(EPOCH):
    print("epoch______________________________",epoch)
    total_count = 0    # hold-out images seen this epoch
    total_count1 = 0   # hold-out images whose label is the top-1 prediction
    total_count5 = 0   # hold-out images whose label is among the top-5
    count_loss = 0.0   # sum of the per-batch training losses (Python float)

    # ---- training ----
    model.train()
    for batch in train_dataloader:
        optimizer.zero_grad()

        images, texts, idx = batch
        images = images.to(device)
        texts = texts.to(device)

        logits_per_image, logits_per_text = model(images, texts)
        # Image i is paired with caption i, so the targets are 0..batch-1.
        # NOTE(review): a batch may hold several images of one class, whose
        # captions are identical; those pairs get conflicting targets —
        # confirm this is acceptable.
        ground_truth = torch.arange(len(images), dtype=torch.long, device=device)
        total_loss = (loss_img(logits_per_image, ground_truth)
                      + loss_txt(logits_per_text, ground_truth)) / 2

        # .item() detaches the value, so no autograd graph is kept alive by
        # the running sum.
        count_loss += total_loss.item()
        total_loss.backward()
        if device == "cpu":
            optimizer.step()
        else:
            # Step in fp32, then convert the weights back for the next pass.
            convert_models_to_fp32(model)
            optimizer.step()
            clip.model.convert_weights(model)
    print('训练loss为',count_loss)

    # ---- evaluation ----
    model.eval()
    with torch.no_grad():  # no gradients are needed while scoring
        for batch in test_dataloader:
            images, _, idx = batch
            images = images.to(device)
            logits_per_image, _ = model(images, text2)
            text_probs = logits_per_image.softmax(dim=-1)
            k = min(5, text_probs.shape[-1])
            top5_batch = text_probs.topk(k, dim=-1).indices.tolist()
            for label, top5 in zip(idx, top5_batch):
                label = int(label)
                if label in top5:
                    total_count5 += 1
                    if label == top5[0]:
                        total_count1 += 1
            total_count += len(idx)

    # Guard against an empty hold-out loader.
    acc = total_count1 / total_count if total_count else 0.0
    acc5 = total_count5 / total_count if total_count else 0.0
    print(f"测试集准确率 Top-1: {acc:.4f}, "
      f"测试集准确率 Top-5: {acc5:.4f}, "
      f"Top-1 正确个数: {total_count1}, "
      f"Top-5 正确个数: {total_count5}, "
      f"总数: {total_count}")

    if acc > best_acc:
        best_acc = acc
        # Make sure the target folder exists before writing the checkpoint.
        os.makedirs("model_checkpoint", exist_ok=True)
        torch.save({
            'acc': acc,
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            # 'optimizer_state_dict': optimizer.state_dict(),
            # Loss of the last training batch, stored as a plain float.
            'loss': total_loss.item(),
            }, "model_checkpoint/model_best_acc.pt") # just change to your preferred folder/filename

In [None]:
# torch.save({
#         'epoch': epoch,
#         'model_state_dict': model.state_dict(),
#         'optimizer_state_dict': optimizer.state_dict(),
#         'loss': total_loss,
#         }, f"model_checkpoint/model_10.pt") #just change to your preferred folder/filename


In [None]:

# Load the previously trained model (the best checkpoint written during training).
# model, preprocess = clip.load("ViT-B/32",device=device,jit=False) #Must set jit=False for training

# map_location moves the stored tensors onto the current device, so a
# checkpoint saved on GPU also loads on a CPU-only machine.
# NOTE: torch.load unpickles the file; only load checkpoints you trust.
checkpoint = torch.load("model_checkpoint/model_best_acc.pt", map_location=device)

# # Use these 3 lines if you use default model setting(not training setting) of the clip. For example, if you set context_length to 100 since your string is very long during training, then assign 100 to checkpoint['model_state_dict']["context_length"] 
# checkpoint['model_state_dict']["input_resolution"] = model.input_resolution #default is 224
# checkpoint['model_state_dict']["context_length"] = model.context_length # default is 77
print("测试集的准确率为",checkpoint["acc"],"epoch为",checkpoint["epoch"])

model.load_state_dict(checkpoint['model_state_dict'])

In [None]:
# Submission set: predict the top-5 classes for every image in data/TestSetA/
# and write one "<file name> <idx> <idx> ..." line per image.
imgs_dir = 'data/' + 'TestSetA/'
save_path = 'data/result.txt'
# Sorted so the output order is the same on every run.
imgs = sorted(os.listdir(imgs_dir))
file_paths = [os.path.join(imgs_dir, file_name) for file_name in imgs]

# NOTE(review): these names shadow the hold-out dataset/loader built earlier;
# re-running the training cell after this one would evaluate on this loader.
test_dataset = image_title_dataset(file_paths, [], [], True)
# No shuffling is needed for inference.
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
model.eval()
count = 0
# `with` closes the result file even if inference raises; no_grad avoids
# building autograd graphs.
with open(save_path, 'w') as save_file, torch.no_grad():
    for batch in tqdm(test_dataloader):
        images, _, file_name = batch
        images = images.to(device)
        logits_per_image, logits_per_text = model(images, text2)
        text_probs = logits_per_image.softmax(dim=-1)

        # Ask for at most as many predictions as there are classes, so every
        # image gets its own prediction list.
        k = min(5, text_probs.shape[-1])
        top_k = text_probs.topk(k, dim=-1).indices.tolist()
        for name, top5 in zip(file_name, top_k):
            save_file.write(name + ' ' + ' '.join([str(p) for p in top5]) + '\n')
            count += 1
print("写入完成,共计",count)

In [None]:
import zipfile
# Pack the result file into a zip archive.
zip_file_path = 'data/result.zip'
# ZipFile defaults to ZIP_STORED (no compression); ZIP_DEFLATED compresses.
with zipfile.ZipFile(zip_file_path, 'w', compression=zipfile.ZIP_DEFLATED) as zipf:
    # Store the file under its base name, without the data/ folder.
    zipf.write(save_path, os.path.basename(save_path))

# Remove the original file once it is inside the archive.
os.remove(save_path)
print(f"{save_path} 已压缩为 {zip_file_path} 并删除原文件。")