In [None]:
import jittor as jt
from PIL import Image
import jclip as clip
import os
from tqdm import tqdm
import argparse
import numpy as np
import random

# Ask Jittor to run on CUDA.
jt.flags.use_cuda = 1
print("包导入成功")

# `--split` defaults to 'A'. parse_known_args() is used rather than parse_args()
# so that unrecognised command-line entries are collected in `unknown` instead
# of raising (presumably the extra argv a Jupyter kernel is started with — confirm).
parser = argparse.ArgumentParser()
parser.add_argument('--split', type=str, default='A')
args, unknown = parser.parse_known_args()

# Load the CLIP model and its image preprocessing callable from the local checkpoint file.
model, preprocess = clip.load("../data/ViT-B-32.pkl")
# Root directory that the relative image paths from train.txt are joined onto.
# NOTE: a later cell rebinds `imgs_dir` to the test-set directory.
imgs_dir = '../data'

def encode_pre_word(c) -> str:
    """Build the text prompt for a class name.

    Returns the string ``'a photo of '`` immediately followed by ``c``.
    """
    return 'a photo of ' + c

def ret_class_name_dic(path: str = '../data/classname.txt') -> dict:
    """Return a dict mapping each class index (as a string) to its class name.

    Every non-blank line of the file is expected to look like
    ``"<name> <index>"``. When an index occurs more than once, the first
    name seen for it is kept.

    Args:
        path: Class-name file to read. Defaults to the notebook's data file.

    Returns:
        ``{index_str: name}`` built from the non-blank lines.

    Raises:
        ValueError: If a non-blank line does not split into exactly two
            space-separated fields.
    """
    class_name_dic = {}
    # The context manager closes the handle; a bare open().read() leaves it
    # to the garbage collector.
    with open(path) as class_file:
        for line in class_file.read().splitlines():
            if not line.strip():
                continue  # tolerate blank / trailing empty lines
            name, idx = line.split(' ')
            # setdefault keeps the first name registered for an index.
            class_name_dic.setdefault(idx, name)
    return class_name_dic

def train_data(path: str = '../data/train.txt', per_class: int = 4) -> list:
    """Return a shuffled few-shot training list with up to `per_class` images per class.

    Each element is a ``[image_path, label]`` pair such as
    ``['TrainSet/Animal/Bee/57.jpg', '1']``, parsed from non-blank lines of the
    form ``"<image_path> <label>"``.

    Args:
        path: Label file to read. Defaults to the notebook's train.txt.
        per_class: Maximum number of images drawn for each label (default 4).
            Classes with fewer images contribute all of them.

    Returns:
        The randomly sampled pairs of all classes, in random order.

    Raises:
        ValueError: If a non-blank line does not split into exactly two
            space-separated fields.
    """
    samples_by_class = {}  # label -> list of [path, label] pairs
    with open(path) as label_file:
        for line in label_file.read().splitlines():
            if not line.strip():
                continue  # tolerate blank / trailing empty lines
            img_path, class_name = line.split(' ')
            samples_by_class.setdefault(class_name, []).append([img_path, class_name])

    ret_list = []
    for samples in samples_by_class.values():
        # random.sample raises ValueError when asked for more items than the
        # population holds, so cap the request for small classes.
        ret_list += random.sample(samples, min(per_class, len(samples)))
    random.shuffle(ret_list)
    return ret_list

def test_data(train_data: list) -> list:
    """返回元素为玩意的列表['TrainSet/Animal/Bee/57.jpg 1] 共计3000个用于测试
    同时剔除训练集中的元素"""
    set1 = set(map(tuple, train_data))
    train_labels = open('../data/train.txt').read().splitlines()
    result = [item for item in train_labels if tuple(item.split(' ')) not in set1] 
    return random.sample(result, 3000)

# Collect the few-shot training inputs.
#
# NOTE: despite their names, these two lists hold the raw model *inputs*
# (preprocessed image tensors and token tensors), not embeddings. The training
# loop passes each batch through model.encode_image / model.encode_text itself,
# and gradients can only reach the unfrozen layers if that forward pass happens
# there. Embeddings computed here under no_grad could not be trained on, and
# are not valid inputs for the encoders.
# The names are kept because the training loop refers to them.
# NOTE(review): every preprocessed image is held in memory at once — switch to
# per-batch loading if memory becomes a problem.
train_img_features = []
train_word_features = []
count = 0  # number of samples collected
class_name_dic = ret_class_name_dic()

for info in tqdm(train_data()):
    img, indx = info
    img = os.path.join(imgs_dir, img)
    # Close the file handle as soon as the image has been preprocessed.
    with Image.open(img) as raw_image:
        image = preprocess(raw_image).unsqueeze(0)
    train_img_features.append(image)

    a_seq = encode_pre_word(class_name_dic[indx])  # class name -> prompt sentence
    token = clip.tokenize(a_seq)
    train_word_features.append(token)
    # The former debugging cut-off (stop after 5 samples) is removed: with the
    # training batch size of 32 it left no complete batch to train on.
    count += 1

train_img_features = jt.concat(train_img_features, dim=0)
train_word_features = jt.concat(train_word_features, dim=0)

# Freeze every parameter first, then re-enable only the layers to fine-tune.
for param in model.parameters():
    param.requires_grad = False

# Layers left trainable: the last two residual blocks of the visual
# transformer, the last two residual blocks of the text transformer,
# and the ln_final layer.
trainable_modules = (
    model.visual.transformer.resblocks[-2],
    model.visual.transformer.resblocks[-1],
    model.transformer.resblocks[-2],
    model.transformer.resblocks[-1],
    model.ln_final,
)
for module in trainable_modules:
    for param in module.parameters():
        param.requires_grad = True

def score_function(image_embedding, text_embedding):
    """Return the matrix product of `image_embedding` and the transpose of `text_embedding`.

    Entry ``[i, j]`` is the dot product of image row ``i`` with text row ``j``.
    This equals cosine similarity only when the caller has already
    L2-normalised both inputs; no normalisation is done here.
    """
    text_embedding_t = text_embedding.transpose(1, 0)
    return jt.matmul(image_embedding, text_embedding_t)

# 定义 CLIP 损失函数
def clip_loss(image_embeddings, text_embeddings, temperature=1.0):
    """Symmetric contrastive loss over a batch of paired image/text embeddings.

    Row ``i`` of `image_embeddings` is treated as the match of row ``i`` of
    `text_embeddings`. The pairwise scores from `score_function` are divided
    by `temperature`, cross-entropy is taken in both directions
    (image -> text and text -> image), and the two losses are averaged.
    """
    # Pairwise scores, image rows against text rows, and the reverse view.
    image_logits = score_function(image_embeddings, text_embeddings) / temperature
    text_logits = image_logits.transpose(1, 0)

    # The matching pair for row i sits in column i.
    targets = jt.arange(image_embeddings.shape[0], dtype=jt.int64)

    image_to_text = jt.nn.cross_entropy_loss(image_logits, targets)
    text_to_image = jt.nn.cross_entropy_loss(text_logits, targets)

    return (image_to_text + text_to_image) / 2

# Materialise the trainable parameters as a list before handing them to the optimiser.
params_to_optimize = list(filter(lambda p: p.requires_grad, model.parameters()))
optimizer = jt.optim.AdamW(params_to_optimize, lr=1e-5)

# Training loop
num_epochs = 10
batch_size = 32
num_samples = len(train_img_features)
if num_samples == 0:
    raise ValueError("no training samples available for fine-tuning")
# Ceiling division: a trailing partial batch is trained on too, and
# num_batches can never be 0 (which used to make the per-epoch average
# divide by zero when fewer than batch_size samples were available).
num_batches = (num_samples + batch_size - 1) // batch_size
# Inference scores classes with 100.0 * cosine similarity; train with the same scale.
temperature = 1.0 / 100.0

for epoch in range(num_epochs):
    epoch_loss = 0
    for i in range(num_batches):
        batch_img_features = train_img_features[i * batch_size: (i + 1) * batch_size]
        batch_text_features = train_word_features[i * batch_size: (i + 1) * batch_size]

        # Forward pass through the encoders, so the unfrozen layers receive gradients.
        image_features = model.encode_image(batch_img_features)
        text_features = model.encode_text(batch_text_features)

        # L2-normalise so the dot products in clip_loss are cosine similarities,
        # matching how predictions are scored at inference time.
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        # Calculate loss
        loss = clip_loss(image_features, text_features, temperature=temperature)

        # Jittor drives backpropagation through the optimiser: step(loss)
        # computes the gradients and applies the update. The PyTorch-style
        # loss.backward() call is not the Jittor API.
        optimizer.step(loss)

        epoch_loss += loss.item()

    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss/num_batches:.4f}")

print("微调训练完成")


In [None]:
# One "<name> <index>" entry per line; the position of a line is the class
# index used for the prediction scores below.
with open('../data/classname.txt') as class_file:
    classes = class_file.read().splitlines()

# Remove the dataset prefix (Animal, Thu-dog, Caltech-101, Food-101) together
# with the single separator character that follows it.
dataset_prefixes = ('Animal', 'Thu-dog', 'Caltech-101', 'Food-101')

new_classes = []
for c in classes:
    c = c.split(' ')[0]
    for prefix in dataset_prefixes:
        if c.startswith(prefix):
            c = c[len(prefix) + 1:]
    c = 'a photo of ' + c
    new_classes.append(c)

text = clip.tokenize(new_classes)
# Inference only: no gradient bookkeeping is needed for the class prompts.
with jt.no_grad():
    text_features = model.encode_text(text)
    text_features /= text_features.norm(dim=-1, keepdim=True)

In [None]:
# Honour the --split option parsed at the top of the notebook
# (its default 'A' selects 'TestSetA').
split = 'TestSet' + args.split

imgs_dir = '../data/' + split
# os.listdir returns entries in arbitrary order; sort for a reproducible result file.
imgs = sorted(os.listdir(imgs_dir))

preds = []
# The context manager flushes and closes result.txt even if a prediction fails.
with open('result.txt', 'w') as save_file, jt.no_grad():
    for img in tqdm(imgs):
        img_path = os.path.join(imgs_dir, img)
        with Image.open(img_path) as raw_image:
            image = preprocess(raw_image).unsqueeze(0)
        image_features = model.encode_image(image)
        image_features /= image_features.norm(dim=-1, keepdim=True)
        text_probs = (100.0 *
                      image_features @ text_features.transpose(0, 1)).softmax(
                          dim=-1)
        # top5 predictions
        _, top_labels = text_probs[0].topk(5)
        preds.append(top_labels)
        # save top5 predictions to file: "<image name> <idx1> ... <idx5>"
        save_file.write(img + ' ' +
                        ' '.join([str(p.item()) for p in top_labels]) + '\n')

In [None]:
# Top-1 accuracy check on labelled images drawn from train.txt.
# An empty exclusion list is passed because the few-shot sample drawn for
# training is not kept anywhere, so the images trained on cannot be filtered
# out here.
test_datas = test_data([])
count_ac = 0
total = 0
with jt.no_grad():
    for info in tqdm(test_datas):
        # test_data() yields "path label" strings; [path, label] pairs are accepted too.
        img_path, class_name = info.split(' ') if isinstance(info, str) else info
        # Paths in train.txt are relative to the data root.
        with Image.open(os.path.join('../data', img_path)) as raw_image:
            image = preprocess(raw_image).unsqueeze(0)
        image_features = model.encode_image(image)
        image_features /= image_features.norm(dim=-1, keepdim=True)
        text_probs = (100.0 *
                      image_features @ text_features.transpose(0, 1)).softmax(
                          dim=-1)
        # top-1 prediction
        _, top_labels = text_probs[0].topk(1)
        total += 1
        # assumes the label in train.txt equals the 0-based line position of
        # the class in classname.txt — TODO confirm
        if int(top_labels[0].item()) == int(class_name):
            count_ac += 1

# Evaluation results are reported here only; nothing is appended to the
# submission file or to `preds`.
accuracy = count_ac / total if total else 0.0
print(f"top-1 accuracy: {count_ac}/{total} = {accuracy:.4f}")