In [None]:
import jittor as jt
from PIL import Image
import jclip as clip
import os
from tqdm import tqdm
import argparse
from sklearn.linear_model import LogisticRegression
import numpy as np
import random
# from colorama import Fore, Back, Style, init
# init()

jt.flags.use_cuda = 1
print("包导入成功")

# Fix the RNG seeds so the few-shot split is reproducible: train_data() and
# test_data() below draw samples with random.sample / random.shuffle.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)

parser = argparse.ArgumentParser()
parser.add_argument('--split', type=str, default='A')

# parse_known_args (rather than parse_args) so that extra argv entries the
# notebook kernel may receive do not make argparse exit.
args, unknown = parser.parse_known_args()
model, preprocess = clip.load("../data/ViT-B-32.pkl")
imgs_dir = '../data'

In [None]:

# Prompt construction: pass a different `template` to try richer prompts.
def encode_pre_word(c, template: str = 'a photo of {}') -> str:
    """Turn a class name into a text prompt for CLIP.

    Args:
        c: class name, e.g. ``'Bee'``.
        template: format string with one ``{}`` placeholder for the class name.
            The default reproduces the original ``'a photo of <name>'`` prompt.

    Returns:
        The prompt string, e.g. ``'a photo of Bee'``.
    """
    return template.format(c)

def ret_class_name_dic(path: str = '../data/classname.txt') -> dict:
    """Return a dict mapping class id -> class name, read from ``path``.

    Each non-blank line is expected to look like ``"<name> <id>"``. The id is
    the last whitespace-separated token, so names may contain spaces. Keys and
    values are both strings; e.g. the line ``Bee 1`` yields ``{'1': 'Bee'}``.
    If an id appears more than once, the first name wins.
    """
    class_name_dic = {}
    with open(path, encoding='utf-8') as fh:  # closed even if parsing fails
        for line in fh.read().splitlines():
            if not line.strip():
                continue  # tolerate blank lines
            name, idx = line.rsplit(None, 1)
            class_name_dic.setdefault(idx, name)  # keep the first name for each id
    return class_name_dic

# class_name_dic=ret_class_name_dic()

In [None]:
# Training-data preparation
def train_data(path: str = '../data/train.txt', per_class: int = 4) -> list:
    """Return a shuffled few-shot training list with ``per_class`` images per class.

    Each non-blank line of ``path`` is expected to be ``"<image path> <class id>"``.
    Every returned element is a two-item list of strings, e.g.
    ``['TrainSet/Animal/Bee/57.jpg', '1']``.

    Args:
        path: label file to read.
        per_class: number of images drawn (without replacement) from each class.

    Returns:
        ``per_class * n_classes`` ``[image_path, class_id]`` pairs in random order.

    Raises:
        ValueError: if a class has fewer than ``per_class`` images
            (raised by ``random.sample``).
    """
    images_by_class = {}  # class id -> list of [image path, class id]
    with open(path, encoding='utf-8') as fh:
        for line in fh.read().splitlines():
            if not line.strip():
                continue  # tolerate blank lines
            img_path, class_name = line.rsplit(None, 1)
            # Append on first sight too: the old if/else created an empty list
            # for a new class and silently dropped its first image.
            images_by_class.setdefault(class_name, []).append([img_path, class_name])

    samples = []
    for items in images_by_class.values():
        samples += random.sample(items, per_class)
    random.shuffle(samples)
    return samples

# train_data=train_data()
#测试数据列表返回
def test_data(train_data:list)->list:
    """返回元素为玩意的列表['TrainSet/Animal/Bee/57.jpg 1] 共计3000个用于测试
    同时剔除训练集中的元素"""
    set1=set(train_data)
    train_labels = open('../data/train.txt').read().splitlines()
    result = [item for item in train_labels if item not in set1] 
    return random.sample(result,3000)



In [None]:
# Encode the sampled training images and their class prompts with CLIP.
MAX_SAMPLES = 5  # debug limit (was a hidden `count == 5` break); set to None to encode every sample

class_name_dic = ret_class_name_dic()
train_samples = train_data()  # keep the sampled split so test_data(train_samples) can exclude it
train_img_features = []
train_word_features = []
text_feature_cache = {}  # class id -> normalised text embedding; the prompt depends only on the class

samples_to_encode = train_samples if MAX_SAMPLES is None else train_samples[:MAX_SAMPLES]
with jt.no_grad():
    for img_rel_path, indx in tqdm(samples_to_encode):
        img_path = os.path.join(imgs_dir, img_rel_path)
        with Image.open(img_path) as pil_image:  # release the file handle after preprocessing
            image = preprocess(pil_image).unsqueeze(0)
        image_features = model.encode_image(image)
        image_features /= image_features.norm(dim=-1, keepdim=True)  # L2-normalise each row
        train_img_features.append(image_features)

        if indx not in text_feature_cache:
            token = clip.tokenize(encode_pre_word(class_name_dic[indx]))  # class name -> prompt -> tokens
            text_features = model.encode_text(token)
            text_features /= text_features.norm(dim=-1, keepdim=True)
            text_feature_cache[indx] = text_features
        train_word_features.append(text_feature_cache[indx])

# One row per encoded sample; presumably (n_samples, 512) for ViT-B/32 — TODO confirm.
train_features = jt.cat(train_img_features).numpy()
# Text embedding of each sample's class prompt (same row count as train_features), not class ids.
train_labels = jt.cat(train_word_features).numpy()

In [None]:
# Sanity-check the encoded image / text feature arrays (one shape per line).
print(train_features.shape, train_labels.shape, sep='\n')

## 修改原模型的最后一层，然后微调最后两层（Modify the pretrained model's last layer, then fine-tune its last two layers）

In [None]:
# Freeze every model parameter, then re-enable gradients only for the layers to fine-tune.
for param in model.parameters():
    param.requires_grad = False

trainable_modules = (
    model.visual.transformer.resblocks[-2],  # image encoder: second-to-last residual block
    model.visual.transformer.resblocks[-1],  # image encoder: last residual block
    model.transformer.resblocks[-2],         # text encoder: second-to-last residual block
    model.transformer.resblocks[-1],         # text encoder: last residual block
    model.ln_final,                          # `ln_final` layer (presumably the text branch's final LayerNorm)
)
for module in trainable_modules:
    for param in module.parameters():
        param.requires_grad = True

In [None]:
# Scoring function s(image, text)
def score_function(image_embedding, text_embedding):
    """Return the pairwise score matrix s(image, text) = image @ text^T.

    Entry [i, j] is the dot product of image embedding i and text embedding j.
    This equals cosine similarity only when both inputs are already
    L2-normalised; the function itself does not normalise.
    Expects 2-D inputs, presumably (n_images, dim) and (n_texts, dim) — TODO confirm.
    """
    text_by_column = text_embedding.transpose(1, 0)
    return jt.matmul(image_embedding, text_by_column)

# CLIP-style contrastive loss
def _diagonal_cross_entropy(logits):
    """Mean over rows i of -log softmax(logits[i])[i] (cross-entropy with target i for row i)."""
    # Subtract the gradient-free row max before exponentiating for numerical stability.
    # Jittor's max reduction returns a single Var, not a torch-style (values, indices) tuple.
    row_max = logits.max(1, keepdims=True).detach()
    shifted = logits - row_max
    # Exact log-softmax via log-sum-exp; no epsilon is needed inside the log.
    log_probs = shifted - jt.log(jt.exp(shifted).sum(1, keepdims=True))
    return -jt.diag(log_probs).mean()


def clip_loss(image_embeddings, text_embeddings, temperature=1.0, symmetric=False):
    """Contrastive loss in which image i is the positive match for text i.

    Args:
        image_embeddings: 2-D image features; row i pairs with row i of
            ``text_embeddings``. Normalise beforehand for cosine-similarity scores.
        text_embeddings: 2-D text features with the same number of rows.
        temperature: logits are ``scores / temperature``. With L2-normalised
            features scores lie in [-1, 1], so the default 1.0 gives very soft
            distributions (the CLIP paper initialises a learned temperature at 0.07).
        symmetric: if True, average the image->text and text->image losses as
            in CLIP; the default False keeps the image->text-only loss.

    Returns:
        Scalar mean cross-entropy of the matching (diagonal) pairs.

    Note: if a batch holds several images of the same class, their identical
    prompts are treated as negatives for one another.
    """
    logits = score_function(image_embeddings, text_embeddings) / temperature

    loss = _diagonal_cross_entropy(logits)
    if symmetric:
        loss = (loss + _diagonal_cross_entropy(logits.transpose(1, 0))) / 2
    return loss


In [None]:
# Optimise only the parameters left trainable by the freezing cell.
# Materialise them as a list: a lazy `filter` object has no len() and cannot be
# indexed, and jittor's Optimizer.__init__ applies both to `params`.
LEARNING_RATE = 1e-5
trainable_params = [p for p in model.parameters() if p.requires_grad]
assert trainable_params, "no trainable parameters - run the freezing cell first"
optimizer = jt.optim.AdamW(trainable_params, lr=LEARNING_RATE)