In [57]:
import json
import os
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import ViTFeatureExtractor, BertTokenizer
from collections import defaultdict, Counter
import numpy as np
from nltk.translate.bleu_score import corpus_bleu

dataset='deepfashion-multimodal'
img_path = f'data/{dataset}/images'
train_json_path= f'data/{dataset}/train_captions.json'
test_json_path= f'data/{dataset}/test_captions.json'
vocab_path = f'data/{dataset}/vocab.json'

def idx_to_word(idx, vocab):#将向量转化为文本描述
    reverse_vocab = {v: k for k, v in vocab.items()}
    return reverse_vocab.get(int(idx), '<unk>')

def cap_to_wvec(vocab,cap):#将文本描述转换成向量
    cap.replace(",","")
    cap.replace(".","")
    cap=cap.split()
    res=[]
    for word in cap:
        if word in vocab.keys():
            res.append(vocab[word])
        else: #不在字典的词
            res.append(vocab['<unk>'])
    return res

def filter_cut_useless_words(sent, filterd_words):
    res=[]
    for w in sent:
        if w not in filterd_words:
            res.append(w)
        else:
            if w==155:
                return res

def get_BLEU_score(cands, refs): #获取BLEU分数
    multiple_refs = []
    for idx in range(len(refs)):
        multiple_refs.append(refs[(idx//1)*1 : (idx//1)*1+1])#每个候选文本对应cpi==1条参考文本
    bleu4 = corpus_bleu(multiple_refs, cands, weights=(0.25,0.25,0.25,0.25))
    return bleu4

def cider_d(reference_list, candidate_list, n=4):
    def count_ngrams(tokens, n):
        ngrams = []
        for i in range(len(tokens) - n + 1):
            ngram = tuple(tokens[i:i+n])
            ngrams.append(ngram)
        return ngrams

    def compute_cider_d(reference_list, candidate_list, n):
        cider_d_scores = []
        for refs, cand in zip(reference_list, candidate_list):
            cider_d_score = 0.0
            for i in range(1, n + 1):
                cand_ngrams = count_ngrams(cand, i)
                ref_ngrams_list = [count_ngrams(ref, i) for ref in refs]

                total_ref_ngrams = [ngram for ref_ngrams in ref_ngrams_list for ngram in ref_ngrams]

                count_cand = 0
                count_clip = 0

                for ngram in cand_ngrams:
                    count_cand += 1
                    if ngram in total_ref_ngrams:
                        count_clip += 1

                precision = count_clip / count_cand if count_cand > 0 else 0.0
                recall = count_clip / len(total_ref_ngrams) if len(total_ref_ngrams) > 0 else 0.0

                beta = 1.0
                f_score = (1 + beta**2) * precision * recall / (beta**2 * precision + recall) if precision + recall > 0 else 0.0

                cider_d_score += f_score

            cider_d_score /= n
            cider_d_scores.append(cider_d_score)

        return cider_d_scores

    reference_tokens_list = reference_list
    candidate_tokens_list = candidate_list

    scores = compute_cider_d(reference_tokens_list, candidate_tokens_list, n)

    return np.mean(scores)

def spice(reference_list, candidate_list, idf=None, beta=3):
    def tokenize(sentence):
        return sentence.lower().split()

    def count_ngrams(tokens, n):
        ngrams = []
        for i in range(len(tokens) - n + 1):
            ngram = tuple(tokens[i:i+n])
            ngrams.append(ngram)
        return ngrams

    def compute_spice_score(reference, candidate, idf, beta):
        reference_tokens = reference
        candidate_tokens = candidate

        reference_ngrams = [count_ngrams(reference_tokens, i) for i in range(1, beta + 1)]
        candidate_ngrams = [count_ngrams(candidate_tokens, i) for i in range(1, beta + 1)]

        precision_scores = []
        recall_scores = []

        for i in range(beta):
            common_ngrams = set(candidate_ngrams[i]) & set(reference_ngrams[i])

            precision = len(common_ngrams) / len(candidate_ngrams[i]) if len(candidate_ngrams[i]) > 0 else 0.0
            recall = len(common_ngrams) / len(reference_ngrams[i]) if len(reference_ngrams[i]) > 0 else 0.0

            precision_scores.append(precision)
            recall_scores.append(recall)

        precision_avg = np.mean(precision_scores)
        recall_avg = np.mean(recall_scores)

        spice_score = (precision_avg * recall_avg) / (precision_avg + recall_avg) if precision_avg + recall_avg > 0 else 0.0

        if idf:
            spice_score *= np.exp(np.sum([idf[token] for token in common_ngrams]) / len(candidate_tokens))

        return spice_score

    if idf is None:
        idf = {}

    spice_scores = []

    for reference, candidate in zip(reference_list, candidate_list):
        spice_score = compute_spice_score(reference, candidate, idf, beta)
        spice_scores.append(spice_score)

    return np.mean(spice_scores)

def wvec_to_capls(vocab,wvec):#将向量转换成文本描述
    res=[]
    for word in wvec:
        for key,value in vocab.items():
            if value==word and key not in ['<start>','<end>','<pad>','<unk>']:
                res.append(key)
    return res

def wvec_to_cap(vocab,wvec):#将向量转换成文本描述
    res=[]
    for word in wvec:
        for key,value in vocab.items():
            if value==word and key not in ['<start>','<end>','<pad>','<unk>']:
                res.append(key)
    res=" ".join(res)
    return res

def get_CIDER_D_score(vocab,cands, refs): #获得CIDER-D分数
    refs_ = [wvec_to_capls(vocab,ref) for ref in refs]
    cands_ = [wvec_to_capls(vocab,cand) for cand in cands]
    return cider_d(refs_, cands_)

def get_SPICE_score(vocab,cands, refs): #获得SPICE分数
    refs_ = [wvec_to_cap(vocab,ref) for ref in refs]
    cands_ = [wvec_to_cap(vocab,cand) for cand in cands]
    return spice(refs_, cands_)

class ImageTextDataset(Dataset):
    def __init__(self, dataset_path, vocab_path, split, captions_per_image=6, max_len=93, transform=None):

        self.split = split
        assert self.split in {'train', 'test'}
        self.cpi = captions_per_image
        self.max_len = max_len

        # 载入数据集
        with open(dataset_path, 'r') as f:
            self.data = json.load(f) #key是图片名字 value是描述
            self.data_img=list(self.data.keys())
        # 载入词典
        with open(vocab_path, 'r') as f:
            self.vocab = json.load(f)

        # PyTorch图像预处理流程
        self.transform = transform

        # Total number of datapoints
        self.dataset_size = len(self.data_img)

    def __getitem__(self, i):
        # 第i个文本描述对应第(i // captions_per_image)张图片
        img = Image.open(img_path+"/"+self.data_img[i]).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        c_vec=cap_to_wvec(self.vocab,self.data[self.data_img[i]])
        #加入起始和结束标志
        c_vec = [self.vocab['<start>']] + c_vec + [self.vocab['<end>']]
        caplen = len(c_vec)
        caption = torch.LongTensor(c_vec+ [self.vocab['<pad>']] * (self.max_len + 2 - caplen))
        
        return img, caption, caplen
        
    def __len__(self):
        return self.dataset_size
    


In [59]:
from transformers import ViTModel, BertModel, BertConfig
from torch import nn
import torch

class Img2TxtModel(nn.Module):
    def __init__(self, vit_model_name, transformer_config, vocab_size):
        super(Img2TxtModel, self).__init__()
        # ViT模型作为编码器
        self.encoder = ViTModel.from_pretrained(vit_model_name)

        # Transformer解码器配置
        transformer_config = BertConfig(vocab_size=vocab_size, num_hidden_layers=1, is_decoder=True,  add_cross_attention=True)
        self.decoder = BertModel(transformer_config)

        # 预测每个词的线性层
        self.vocab_size = vocab_size
        self.fc = nn.Linear(transformer_config.hidden_size, vocab_size)
    
    def forward(self, input_ids, decoder_input_ids, decoder_attention_mask):
        # 通过ViT编码器获取图像特征
        encoder_outputs = self.encoder(pixel_values=input_ids).last_hidden_state

        # 将图像特征作为解码器的输入
        decoder_outputs = self.decoder(input_ids=decoder_input_ids, 
                                       attention_mask=decoder_attention_mask,
                                       encoder_hidden_states=encoder_outputs).last_hidden_state

        # 预测下一个词
        prediction_scores = self.fc(decoder_outputs)
        return prediction_scores

    def generate_text(self, input_ids, max_length=95, start_token_id=154):
        # 获取图像特征
        encoder_outputs = self.encoder(pixel_values=input_ids).last_hidden_state

        # 初始化解码器输入为<start>标记
        decoder_input_ids = torch.full((input_ids.size(0), 1), start_token_id).to(input_ids.device)
        
        # 存储所有时间步的logits
        all_logits = []

        for step in range(max_length):
            # 获取解码器输出
            decoder_outputs = self.decoder(
                input_ids=decoder_input_ids, 
                encoder_hidden_states=encoder_outputs
            ).last_hidden_state

            # 预测下一个词
            next_word_logits = self.fc(decoder_outputs[:, -1, :])
            all_logits.append(next_word_logits.unsqueeze(1))
            next_word_id = next_word_logits.argmax(dim=-1).unsqueeze(-1)
            
            # 将预测的词添加到解码器输入中
            decoder_input_ids = torch.cat([decoder_input_ids, next_word_id], dim=-1)
        
        return decoder_input_ids ,torch.cat(all_logits, dim=1)







In [3]:
import torch
from torch import nn
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision import transforms

if torch.cuda.is_available():
    device = torch.device("cuda")
    print("Using GPU:", torch.cuda.get_device_name(0))
else:
    device = torch.device("cpu")
    print("Using CPU")

def evaluate(test_loader, model):
    model.eval()  # 将模型设置为评估模式
    generated_captions = []
    actual_captions = []

    with torch.no_grad():
        for images, captions, caplens in test_loader:
            images = images.to(device)
            input_ids = images
            outputs,_ = model.generate_text(input_ids, max_length=95, start_token_id=test_dataset.vocab['<start>'])
            for i in range(outputs.shape[0]):
                # 生成字幕
                gen_caption = [idx_to_word(idx, test_dataset.vocab) for idx in outputs[i]]
                # print(gen_caption)
                # 移除 <start> 和 <end>
                if '<start>' in gen_caption:
                    gen_caption = gen_caption[1:]  # 移除第一个元素 (<start>)
                if '<end>' in gen_caption:
                    gen_caption = gen_caption[:gen_caption.index('<end>')]  # 移除 <end> 及其后面的元素
                
                generated_captions.append(' '.join(gen_caption))

                # 真实字幕
                act_caption = [idx_to_word(idx, test_dataset.vocab) for idx in captions[i]]
                # print(act_caption)
                # 移除 <start> 和 <end>
                if '<start>' in act_caption:
                    act_caption = act_caption[1:]  # 移除第一个元素 (<start>)
                if '<end>' in act_caption:
                    act_caption = act_caption[:act_caption.index('<end>')]  # 移除 <end> 及其后面的元素
                
                actual_captions.append([' '.join(act_caption)])

        # 计算BLEU分数
        bleu4 = corpus_bleu(actual_captions, generated_captions, weights=(0.25,0.25,0.25,0.25))
    model.train()
    return bleu4

transform = transforms.Compose([
    transforms.Resize((224, 224)), 
    transforms.ToTensor(),
    # 这里可以添加其他必要的转换，如归一化等
])

dataset = ImageTextDataset(train_json_path, vocab_path, split='train', transform=transform)
data_loader = DataLoader(dataset, batch_size=3, shuffle=True)

vocab_size = len(dataset.vocab)
vit_model_name = 'google/vit-base-patch16-224-in21k'
transformer_config = BertConfig()

# 初始化模型
model = Img2TxtModel(vit_model_name, transformer_config, vocab_size)
model = model.to(device)
optimizer = Adam(model.parameters(), lr=0.0001)
criterion = nn.CrossEntropyLoss(ignore_index=dataset.vocab['<pad>'])

test_dataset = ImageTextDataset(test_json_path, vocab_path, split='test', transform=transform)
test_loader = DataLoader(test_dataset, batch_size=3, shuffle=True)
# 设定训练周期
num_epochs = 10
best_bleu_score = 0.0  # 初始化最高BLEU分数

# 在训练时使用"teacher forcing"来生成字幕
for epoch in range(num_epochs):
    for i, (images, captions, caplens) in enumerate(data_loader):
        # 假设您的ViT模型接受标准化的图像张量作为输入
        images = images.to(device)
        captions = captions.to(device)
        input_ids = images

        # 准备解码器输入
        decoder_input_ids = captions[:, :-1]  # 删除每个字幕的最后一个单词
        decoder_attention_mask = (decoder_input_ids != dataset.vocab['<pad>']).type(torch.uint8)
        
        # 前向传播
        outputs = model(input_ids, decoder_input_ids, decoder_attention_mask)

        # 计算损失，outputs需要调整以适配损失函数的要求
        loss = criterion(outputs.view(-1, outputs.size(-1)), captions[:, 1:].contiguous().view(-1))

        # 反向传播和优化
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i+1) % 100 == 0:
            print(f"Epoch [{epoch+1}/{num_epochs}], Batch [{i+1}/{len(data_loader)}], Loss: {loss.item()}")
        
        if (i + 1) % 900 == 0:
            bleu4 = evaluate(test_loader, model)
            print(f"Epoch [{epoch+1}/{num_epochs}], Batch [{i+1}/{len(data_loader)}], Loss: {loss.item()}, BLEU Score: {bleu4}")

            # 如果BLEU分数是新的最高分，则保存模型
            if bleu4 > best_bleu_score:
                best_bleu_score = bleu4
                save_path = f"model/best_model_epoch_{epoch+1}_batch_{i+1}.pth"
                torch.save({
                    'epoch': epoch,
                    'batch': i,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': loss,
                    'bleu_score': bleu4,
                }, save_path)
                print(f"New best model saved to {save_path} with BLEU score {bleu4}")






Using GPU: NVIDIA GeForce RTX 3070 Laptop GPU
Epoch [1/10], Batch [100/3385], Loss: 1.506847620010376
Epoch [1/10], Batch [200/3385], Loss: 1.0654538869857788
Epoch [1/10], Batch [300/3385], Loss: 0.720127284526825
Epoch [1/10], Batch [400/3385], Loss: 0.7672491669654846
Epoch [1/10], Batch [500/3385], Loss: 0.6676333546638489
Epoch [1/10], Batch [600/3385], Loss: 0.7424653172492981
Epoch [1/10], Batch [700/3385], Loss: 0.526539146900177
Epoch [1/10], Batch [800/3385], Loss: 0.5412440896034241
Epoch [1/10], Batch [900/3385], Loss: 0.5988483428955078
Epoch [1/10], Batch [900/3385], Loss: 0.5988483428955078, BLEU Score: 0.6259826859669076
New best model saved to model/best_model_epoch_1_batch_900.pth with BLEU score 0.6259826859669076
Epoch [1/10], Batch [1000/3385], Loss: 0.6830837726593018
Epoch [1/10], Batch [1100/3385], Loss: 0.6536044478416443
Epoch [1/10], Batch [1200/3385], Loss: 0.6345134973526001
Epoch [1/10], Batch [1300/3385], Loss: 0.6075872182846069
Epoch [1/10], Batch [1400

In [61]:
import torch
from torch.utils.data import DataLoader
from torchvision import transforms


if torch.cuda.is_available():
    device = torch.device("cuda")
    print("Using GPU:", torch.cuda.get_device_name(0))
else:
    device = torch.device("cpu")
    print("Using CPU")

transform = transforms.Compose([
    transforms.Resize((224, 224)), 
    transforms.ToTensor(),
    # 这里可以添加其他必要的转换，如归一化等
])
test_dataset = ImageTextDataset(test_json_path, vocab_path, split='test', transform=transform)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=True)
vocab_size = len(test_dataset.vocab)
vit_model_name = 'google/vit-base-patch16-224-in21k'
transformer_config = BertConfig()
model = Img2TxtModel(vit_model_name, transformer_config, vocab_size)
# 加载模型状态字典
checkpoint = torch.load('./model/best_model_epoch_10_batch_2700.pth')

# 将状态字典应用到模型实例中
model.load_state_dict(checkpoint['model_state_dict'])
model = model.to(device)

model.eval()  # 将模型设置为评估模式

generated_captions = []
actual_captions = []
cands = []
refs = []
filterd_words = set({test_dataset.vocab['<start>'], test_dataset.vocab['<end>'], test_dataset.vocab['<pad>']})
with torch.no_grad():
    for images, captions, caplens in test_loader:
        images = images.to(device)
        input_ids = images
        outputs,_ = model.generate_text(input_ids, max_length=95, start_token_id=test_dataset.vocab['<start>'])
        for i in range(outputs.shape[0]):
            gen_caption = [idx_to_word(idx, test_dataset.vocab) for idx in outputs[i]]
            if '<start>' in gen_caption:
                gen_caption = gen_caption[1:]  # 移除第一个元素 (<start>)
            if '<end>' in gen_caption:
                gen_caption = gen_caption[:gen_caption.index('<end>')]  # 移除 <end> 及其后面的元素
            generated_captions.append(' '.join(gen_caption))
            act_caption = [idx_to_word(idx, test_dataset.vocab) for idx in captions[i]]
            # 移除 <start> 和 <end>
            if '<start>' in act_caption:
                act_caption = act_caption[1:]  # 移除第一个元素 (<start>)
            if '<end>' in act_caption:
                act_caption = act_caption[:act_caption.index('<end>')]  # 移除 <end> 及其后面的元素
            
            actual_captions.append([' '.join(act_caption)])
        texts=outputs
        cands.extend([filter_cut_useless_words(text, filterd_words) for text in texts.tolist()])
            # 参考文本
        refs.extend([filter_cut_useless_words(cap, filterd_words) for cap in captions.tolist()])
    
    bleu4 = corpus_bleu(actual_captions, generated_captions, weights=(0.25,0.25,0.25,0.25))
    
    bleu4_token=get_BLEU_score(cands, refs)

    cider_d_score=get_CIDER_D_score(test_dataset.vocab,refs, cands)
    
    spice_score=get_SPICE_score(test_dataset.vocab,refs, cands)
    

    print(f"bleu4 :{bleu4}")
    print(f"bleu4_token :{bleu4_token}")
    print(f"cider_d_score :{cider_d_score}")
    print(f"spice_score :{spice_score}")

Using GPU: NVIDIA GeForce RTX 3070 Laptop GPU
bleu4 :0.6640697619027633
bleu4_token :0.3074786561430062
cider_d_score :0.004119164250591897
spice_score :0.1321016861247424
