In [1]:
import json
import os
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import ViTFeatureExtractor, BertTokenizer
from collections import defaultdict, Counter

# Dataset name and the file locations derived from it.
dataset = 'deepfashion-multimodal'
_data_root = f'data/{dataset}'

img_path = f'{_data_root}/images'                      # directory holding the image files
train_json_path = f'{_data_root}/train_captions.json'  # training captions
test_json_path = f'{_data_root}/test_captions.json'    # test captions
vocab_path = f'{_data_root}/vocab.json'                # vocabulary file

def idx_to_word(idx, vocab):
    """Look up the word whose vocabulary index equals ``int(idx)``.

    Args:
        idx: Anything ``int()`` accepts (a Python int, a numeric string, a
            one-element tensor, ...).
        vocab: Mapping from word to index.

    Returns:
        The matching word, or ``'<unk>'`` when no word has that index. If
        several words share the index, the one that appears last in ``vocab``
        is returned.
    """
    match = {}
    for word, index in vocab.items():
        # Later entries overwrite earlier ones that share the same index.
        match[index] = word
    wanted = int(idx)
    if wanted in match:
        return match[wanted]
    return '<unk>'

def cap_to_wvec(vocab, cap):
    """Convert a text caption into a list of vocabulary indices.

    Commas and periods are removed, the caption is split on whitespace and
    every token is mapped through ``vocab``.

    Args:
        vocab: Mapping from word to index. Must contain ``'<unk>'`` if the
            caption can contain words missing from the vocabulary.
        cap: Caption string.

    Returns:
        List of indices, one per whitespace-separated token. Tokens missing
        from ``vocab`` are mapped to ``vocab['<unk>']``.

    Raises:
        KeyError: If an out-of-vocabulary token is met and ``vocab`` has no
            ``'<unk>'`` entry.
    """
    # str.replace returns a new string (strings are immutable), so the result
    # must be kept; otherwise tokens such as "dress." never match the vocab.
    cleaned = cap.replace(",", "").replace(".", "")
    # '<unk>' is looked up lazily so a vocabulary without it still works for
    # captions made only of known words.
    return [vocab[word] if word in vocab else vocab['<unk>'] for word in cleaned.split()]

class ImageTextDataset(Dataset):
    """Dataset of (image, padded caption, caption length) triples.

    The caption file is a JSON object mapping an image file name to its
    caption text; the vocabulary file is a JSON object mapping a word to its
    index. Images are read from the module-level ``img_path`` directory.

    Args:
        dataset_path: Path to the caption JSON file.
        vocab_path: Path to the vocabulary JSON file.
        split: Either ``'train'`` or ``'test'`` (only validated and stored).
        captions_per_image: Stored as ``self.cpi``; not used by this class.
        max_len: Maximum number of caption tokens, excluding ``<start>`` and
            ``<end>``. Longer captions are truncated to this length.
        transform: Optional callable applied to the RGB ``PIL.Image``.
    """

    def __init__(self, dataset_path, vocab_path, split, captions_per_image=1, max_len=93, transform=None):

        self.split = split
        assert self.split in {'train', 'test'}
        self.cpi = captions_per_image
        self.max_len = max_len

        # Load the captions: key is the image file name, value is the caption.
        # JSON is UTF-8 by specification, so do not rely on the locale default.
        with open(dataset_path, 'r', encoding='utf-8') as f:
            self.data = json.load(f)
        self.data_img = list(self.data.keys())

        # Load the vocabulary (word -> index).
        with open(vocab_path, 'r', encoding='utf-8') as f:
            self.vocab = json.load(f)

        # Image preprocessing pipeline.
        self.transform = transform

        # Total number of datapoints (one caption per image).
        self.dataset_size = len(self.data_img)

    def _build_caption(self, token_ids):
        """Wrap token ids in <start>/<end> and pad to ``max_len + 2``.

        Returns:
            ``(caption, caplen)`` where ``caption`` is a ``LongTensor`` of
            length ``max_len + 2`` and ``caplen`` counts the tokens before
            padding, including ``<start>`` and ``<end>``.
        """
        # Truncate first: without this an over-long caption produced a tensor
        # longer than max_len + 2 (a negative pad count yields an empty list),
        # which cannot be stacked with other samples in a batch.
        body = list(token_ids)[:self.max_len]
        sequence = [self.vocab['<start>']] + body + [self.vocab['<end>']]
        caplen = len(sequence)
        padding = [self.vocab['<pad>']] * (self.max_len + 2 - caplen)
        return torch.LongTensor(sequence + padding), caplen

    def __getitem__(self, i):
        """Return ``(image, caption, caplen)`` for the i-th image."""
        name = self.data_img[i]
        # convert() returns an independent image, so the file handle opened by
        # Image.open can be closed right away instead of being leaked.
        with Image.open(img_path + "/" + name) as raw:
            img = raw.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)

        caption, caplen = self._build_caption(cap_to_wvec(self.vocab, self.data[name]))
        return img, caption, caplen

    def __len__(self):
        """Number of images (and captions) in the dataset."""
        return self.dataset_size
    
def generate_sentence(model_output, idx_to_word):
    """Turn per-step scores into a space-separated sentence.

    Args:
        model_output: Tensor of scores; the arg-max over the last dimension
            selects one vocabulary index per step.
        idx_to_word: Mapping (indexed with ``[]``) from index to word.

    Returns:
        The selected words joined by single spaces.
    """
    # Most likely vocabulary index for every step.
    best_indices = torch.argmax(model_output, dim=-1)
    # Map each index to its word and join them into one sentence.
    return ' '.join(idx_to_word[token.item()] for token in best_indices)



  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from transformers import ViTModel, BertModel, BertConfig
from torch import nn
import torch
from torch.nn import TransformerDecoder, TransformerDecoderLayer
import math

class Img2TxtModel(nn.Module):
    """Image captioning model: ViT encoder, BERT decoder, linear vocabulary head.

    Args:
        vit_model_name: Name or path passed to ``ViTModel.from_pretrained``.
        transformer_config: Accepted for interface compatibility but not used;
            the decoder configuration is always built here from ``vocab_size``.
        vocab_size: Number of entries in the caption vocabulary.
    """

    def __init__(self, vit_model_name, transformer_config, vocab_size):
        super().__init__()
        # Pretrained ViT acts as the image encoder.
        self.encoder = ViTModel.from_pretrained(vit_model_name)

        # BERT configured as a decoder with cross-attention, built from scratch
        # (the transformer_config argument is intentionally not consulted).
        decoder_config = BertConfig(
            vocab_size=vocab_size,
            is_decoder=True,
            add_cross_attention=True,
        )
        self.decoder = BertModel(decoder_config)

        # Linear layer projecting decoder states to vocabulary scores.
        self.vocab_size = vocab_size
        self.fc = nn.Linear(decoder_config.hidden_size, vocab_size)

    def forward(self, input_ids, max_length=95, start_token_id=154):
        """Greedily decode ``max_length`` tokens for a batch of images.

        Args:
            input_ids: Tensor passed to the encoder as ``pixel_values``.
            max_length: Number of decoding steps to run.
            start_token_id: Index used as the first decoder input token.

        Returns:
            A tuple ``(token_ids, logits)``: ``token_ids`` is the start token
            followed by the ``max_length`` arg-max predictions, and ``logits``
            concatenates the per-step vocabulary scores along dimension 1.
        """
        # Image features from the encoder.
        image_features = self.encoder(pixel_values=input_ids).last_hidden_state

        # Every sequence begins with the <start> token.
        batch_size = input_ids.size(0)
        token_ids = torch.full((batch_size, 1), start_token_id).to(input_ids.device)

        # Logits of every decoding step.
        step_logits = []
        for _ in range(max_length):
            # The whole prefix is re-decoded at every step.
            hidden = self.decoder(
                input_ids=token_ids,
                encoder_hidden_states=image_features,
            ).last_hidden_state

            # Score the next token from the last position only.
            logits = self.fc(hidden[:, -1, :])
            step_logits.append(logits.unsqueeze(1))

            # Append the greedy choice to the decoder input.
            greedy_ids = logits.argmax(dim=-1, keepdim=True)
            token_ids = torch.cat([token_ids, greedy_ids], dim=-1)

        return token_ids, torch.cat(step_logits, dim=1)






In [3]:
import torch
from torch import nn
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision import transforms
from nltk.translate.bleu_score import corpus_bleu

# Prefer the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    print("Using GPU:", torch.cuda.get_device_name(0))
else:
    print("Using CPU")

def evaluate(test_loader, model):
    """Compute the corpus-level BLEU-4 score of ``model`` on ``test_loader``.

    Args:
        test_loader: ``DataLoader`` yielding ``(images, captions, caplens)``
            whose ``dataset`` attribute exposes a ``vocab`` word -> index
            mapping containing ``'<start>'``, ``'<end>'`` and ``'<pad>'``.
        model: Captioning model. It is called as
            ``model(images, start_token_id=...)`` and may return either the
            logits tensor or a tuple whose last element is the logits.

    Returns:
        The BLEU-4 score computed by ``nltk``'s ``corpus_bleu``.

    The model is put in evaluation mode for the duration of the call and
    switched back to training mode before returning, even on error.
    """
    vocab = test_loader.dataset.vocab
    # Build the index -> word table once instead of once per token.
    reverse_vocab = {index: word for word, index in vocab.items()}
    start_id, end_id, pad_id = vocab['<start>'], vocab['<end>'], vocab['<pad>']
    # Inputs must live on the same device as the model weights.
    model_device = next(model.parameters()).device

    def decode(token_ids):
        """Map ids to words, dropping <start>/<pad> and stopping at <end>."""
        words = []
        for token_id in token_ids.tolist():
            if token_id == end_id:
                break
            if token_id == start_id or token_id == pad_id:
                continue
            words.append(reverse_vocab.get(token_id, '<unk>'))
        return words

    generated_captions = []
    actual_captions = []

    model.eval()  # evaluation mode (disables dropout)
    try:
        with torch.no_grad():
            for images, captions, caplens in test_loader:
                outputs = model(images.to(model_device), start_token_id=start_id)
                # The model in this notebook returns (token_ids, logits).
                logits = outputs[-1] if isinstance(outputs, (tuple, list)) else outputs
                # arg-max of the logits equals arg-max of their softmax.
                predicted_indices = logits.argmax(dim=-1)
                for row in range(predicted_indices.shape[0]):
                    # corpus_bleu expects token lists; joined strings would be
                    # scored character by character.
                    generated_captions.append(decode(predicted_indices[row]))
                    # One reference per hypothesis.
                    actual_captions.append([decode(captions[row])])
    finally:
        model.train()

    # BLEU-4 with uniform n-gram weights.
    return corpus_bleu(actual_captions, generated_captions, weights=(0.25, 0.25, 0.25, 0.25))

# Image preprocessing: resize to the ViT input size, convert to a tensor and
# normalise with mean = std = 0.5, the values the Hugging Face image processor
# uses for the google/vit-* checkpoints.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

# Training data. Note: this rebinds `dataset`, which held the dataset name
# string when the path constants were built.
dataset = ImageTextDataset(train_json_path, vocab_path, split='train', transform=transform)
data_loader = DataLoader(dataset, batch_size=1, shuffle=True)

vocab_size = len(dataset.vocab)
vit_model_name = 'google/vit-base-patch16-224-in21k'
transformer_config = BertConfig()

# Model, optimiser and loss (padding positions do not contribute to the loss).
model = Img2TxtModel(vit_model_name, transformer_config, vocab_size)
model = model.to(device)
optimizer = Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss(ignore_index=dataset.vocab['<pad>'])

# Test data. BLEU does not depend on sample order, so no shuffling is needed.
test_dataset = ImageTextDataset(test_json_path, vocab_path, split='test', transform=transform)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

# Training schedule.
num_epochs = 10
best_bleu_score = 0.0  # best BLEU score seen so far
EVAL_EVERY = 10        # run evaluation every this many batches
CHECKPOINT_DIR = "model"
# torch.save does not create missing directories.
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# Train and test datasets are built from the same vocabulary file.
start_token_id = dataset.vocab['<start>']

for epoch in range(num_epochs):
    for i, (images, captions, caplens) in enumerate(data_loader):
        images = images.to(device)
        captions = captions.to(device)

        # Forward pass: the model returns (token_ids, logits).
        outputs, outputs_logits = model(images, start_token_id=start_token_id)

        # CrossEntropyLoss expects (batch, classes, steps).
        outputs_logits = outputs_logits.permute(0, 2, 1)
        loss = criterion(outputs_logits, captions)

        # Backward pass and parameter update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print(f"Epoch [{epoch+1}/{num_epochs}], Batch [{i+1}/{len(data_loader)}], Loss: {loss.item()}")

        if (i + 1) % EVAL_EVERY == 0:
            bleu4 = evaluate(test_loader, model)
            print(f"Epoch [{epoch+1}/{num_epochs}], Batch [{i+1}/{len(data_loader)}], Loss: {loss.item()}, BLEU Score: {bleu4}")

            # Keep a checkpoint whenever the BLEU score improves.
            if bleu4 > best_bleu_score:
                best_bleu_score = bleu4
                save_path = f"{CHECKPOINT_DIR}/best_model_epoch_{epoch+1}_batch_{i+1}.pth"
                torch.save({
                    'epoch': epoch,
                    'batch': i,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    # Store a plain float, not a tensor tied to the autograd graph.
                    'loss': loss.item(),
                    'bleu_score': bleu4,
                }, save_path)
                print(f"New best model saved to {save_path} with BLEU score {bleu4}")





Using GPU: NVIDIA GeForce RTX 3070 Laptop GPU
Epoch [1/10], Batch [1/10155], Loss: 5.196713924407959


In [None]:
# Final evaluation of the trained model on the test set.
test_dataset = ImageTextDataset(test_json_path, vocab_path, split='test', transform=transform)
# BLEU does not depend on sample order, so no shuffling is needed.
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

test_vocab = test_dataset.vocab
# Index -> word table built once. Token ids are converted to plain ints before
# the lookup, because tensor objects never match integer dictionary keys.
reverse_vocab = {index: word for word, index in test_vocab.items()}
start_id = test_vocab['<start>']
end_id = test_vocab['<end>']
pad_id = test_vocab['<pad>']
# Inputs must live on the same device as the model weights.
model_device = next(model.parameters()).device


def decode_tokens(token_ids):
    """Map ids to words, dropping <start>/<pad> and stopping at <end>."""
    words = []
    for token_id in token_ids.tolist():
        if token_id == end_id:
            break
        if token_id == start_id or token_id == pad_id:
            continue
        words.append(reverse_vocab.get(token_id, '<unk>'))
    return words


model.eval()  # evaluation mode (disables dropout)
generated_captions = []
actual_captions = []

with torch.no_grad():
    for images, captions, caplens in test_loader:
        # forward() takes the images plus the <start> id and decodes greedily;
        # it returns (token_ids, logits).
        _, logits = model(images.to(model_device), start_token_id=start_id)

        # arg-max of the logits equals arg-max of their softmax.
        predicted_indices = logits.argmax(dim=-1)

        for i in range(predicted_indices.shape[0]):
            # corpus_bleu expects token lists; joined strings would be scored
            # character by character.
            generated_captions.append(decode_tokens(predicted_indices[i]))
            # One reference per hypothesis.
            actual_captions.append([decode_tokens(captions[i])])

# Corpus-level BLEU (default weights: uniform over 1- to 4-grams).
bleu_score = corpus_bleu(actual_captions, generated_captions)
print(f"BLEU score on test dataset: {bleu_score}")

: 

: 