In [30]:
import json
import os
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import ViTFeatureExtractor, BertTokenizer
from collections import defaultdict, Counter
import numpy as np


# Name of the dataset directory under data/; the two paths below are derived from it.
dataset = 'deepfashion-multimodal'
# Folder holding the images to caption.
img_path = '/'.join(('data', dataset, 'test_image'))
# JSON file holding the vocabulary.
vocab_path = '/'.join(('data', dataset, 'vocab.json'))

def idx_to_word(idx, vocab, reverse_vocab=None):
    """Convert a single token id back into its word.

    Args:
        idx: Token id; anything ``int()`` accepts (e.g. a Python int or a
            single-element tensor).
        vocab: Mapping of word -> token id. Only read when ``reverse_vocab``
            is not supplied.
        reverse_vocab: Optional precomputed mapping of token id -> word.
            Inverting ``vocab`` costs O(len(vocab)); callers that decode many
            ids should build the inverse once and pass it here.

    Returns:
        The word mapped to ``idx``, or ``'<unk>'`` when the id is unknown.
    """
    if reverse_vocab is None:
        reverse_vocab = {token_id: word for word, token_id in vocab.items()}
    return reverse_vocab.get(int(idx), '<unk>')

class CustomImageDataset(Dataset):
    """Dataset over the image files located directly inside one folder.

    Each item is a ``(image, file_name)`` pair, where ``image`` is the RGB
    image after the optional ``transform`` has been applied.
    """

    # Lower-case file extensions that are treated as images.
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, img_folder, transform=None):
        """
        Args:
            img_folder: Directory to scan (non-recursively) for image files.
            transform: Optional callable applied to each loaded PIL image.
        """
        self.img_folder = img_folder
        # os.listdir returns entries in arbitrary order, so sort to keep the
        # sample order reproducible; match extensions case-insensitively so
        # files such as "photo.JPG" are not silently skipped.
        self.img_names = sorted(
            name for name in os.listdir(img_folder)
            if name.lower().endswith(self.IMAGE_EXTENSIONS)
        )
        print(len(self.img_names))
        # Guard the preview: an empty folder must not raise IndexError.
        if self.img_names:
            print(self.img_names[0])
        self.transform = transform

    def __len__(self):
        """Return the number of image files found in the folder."""
        return len(self.img_names)

    def __getitem__(self, idx):
        """Load the ``idx``-th image and return ``(image, file_name)``."""
        name = self.img_names[idx]
        file_path = os.path.join(self.img_folder, name)
        # convert() returns a fully loaded copy, so the underlying file can be
        # closed right away instead of leaking a handle per item.
        with Image.open(file_path) as img:
            image = img.convert('RGB')
        if self.transform is not None:
            image = self.transform(image)

        return image, name




In [31]:
from transformers import ViTModel, BertModel, BertConfig
from torch import nn
import torch

class Img2TxtModel(nn.Module):
    """Image-to-text model: a pretrained ViT encoder feeding a one-layer
    BERT decoder (with cross-attention) and a linear vocabulary projection.
    """

    def __init__(self, vit_model_name, transformer_config, vocab_size):
        """
        Args:
            vit_model_name: Name or path handed to ``ViTModel.from_pretrained``.
            transformer_config: Accepted for backward compatibility but NOT
                used. The decoder configuration is always built here, so
                existing checkpoints keep matching the architecture.
            vocab_size: Size of the output vocabulary.
        """
        super().__init__()
        # ViT model used as the image encoder.
        self.encoder = ViTModel.from_pretrained(vit_model_name)

        # Decoder configuration: a single-layer BERT run as a decoder with
        # cross-attention over the encoder output.
        decoder_config = BertConfig(vocab_size=vocab_size, num_hidden_layers=1, is_decoder=True, add_cross_attention=True)
        self.decoder = BertModel(decoder_config)

        # Linear layer that scores every vocabulary entry at each position.
        self.vocab_size = vocab_size
        self.fc = nn.Linear(decoder_config.hidden_size, vocab_size)

    def forward(self, input_ids, decoder_input_ids, decoder_attention_mask):
        """Score the vocabulary at every decoder position.

        Args:
            input_ids: Passed to the encoder as ``pixel_values`` (despite the
                name, this is the image batch).
            decoder_input_ids: Token ids fed to the decoder.
            decoder_attention_mask: Attention mask for ``decoder_input_ids``.

        Returns:
            The output of the vocabulary projection applied to the decoder's
            last hidden state; its final dimension is ``vocab_size``.
        """
        # Encode the images with ViT.
        encoder_outputs = self.encoder(pixel_values=input_ids).last_hidden_state

        # Decode the tokens while cross-attending to the image features.
        decoder_outputs = self.decoder(input_ids=decoder_input_ids,
                                       attention_mask=decoder_attention_mask,
                                       encoder_hidden_states=encoder_outputs).last_hidden_state

        # Project to vocabulary scores.
        prediction_scores = self.fc(decoder_outputs)
        return prediction_scores

    def generate_text(self, input_ids, max_length=95, start_token_id=154, end_token_id=None):
        """Greedily decode token ids for a batch of images.

        Args:
            input_ids: Passed to the encoder as ``pixel_values``; the batch
                size is taken from ``input_ids.size(0)``.
            max_length: Maximum number of tokens to generate after the start
                token.
            start_token_id: Id used to seed every sequence.
            end_token_id: Optional id of the end token. When given, decoding
                stops as soon as every sequence in the batch has emitted it.
                With the default ``None`` exactly ``max_length`` steps run.

        Returns:
            A tuple ``(token_ids, logits)``: ``token_ids`` holds the start
            token followed by the generated ids, one row per batch element;
            ``logits`` holds the per-step vocabulary scores concatenated along
            dimension 1.
        """
        # Encode the images once; the features are reused at every step.
        encoder_outputs = self.encoder(pixel_values=input_ids).last_hidden_state

        batch_size = input_ids.size(0)
        device = input_ids.device

        # Seed the decoder input with the start token. dtype/device are given
        # explicitly so the ids are integer on every torch version and no
        # extra host-to-device copy is needed.
        decoder_input_ids = torch.full(
            (batch_size, 1), start_token_id, dtype=torch.long, device=device
        )

        # Logits collected at every time step.
        all_logits = []
        # Tracks which sequences have already produced the end token.
        finished = torch.zeros(batch_size, dtype=torch.bool, device=device)

        for _ in range(max_length):
            # Run the decoder over the whole prefix generated so far.
            decoder_outputs = self.decoder(
                input_ids=decoder_input_ids,
                encoder_hidden_states=encoder_outputs
            ).last_hidden_state

            # Score the next token from the last position and pick the argmax.
            next_word_logits = self.fc(decoder_outputs[:, -1, :])
            all_logits.append(next_word_logits.unsqueeze(1))
            next_word_id = next_word_logits.argmax(dim=-1).unsqueeze(-1)

            # Append the chosen token to the decoder input.
            decoder_input_ids = torch.cat([decoder_input_ids, next_word_id], dim=-1)

            if end_token_id is not None:
                finished = finished | (next_word_id.squeeze(-1) == end_token_id)
                if bool(finished.all()):
                    break

        if not all_logits:
            # No step ran (max_length <= 0): torch.cat would fail on an empty
            # list, so return an empty logits tensor instead.
            return decoder_input_ids, encoder_outputs.new_zeros((batch_size, 0, self.vocab_size))

        return decoder_input_ids, torch.cat(all_logits, dim=1)



In [32]:
from transformers import ViTModel, BertModel, BertConfig
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms

# Pick the compute device: the first CUDA GPU when available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
    print("Using GPU:", torch.cuda.get_device_name(0))
else:
    device = torch.device("cpu")
    print("Using CPU")

# Image preprocessing.
# NOTE(review): there is no Normalize step here; presumably this mirrors the
# preprocessing used at training time — confirm before changing.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    # Add further transforms here if needed.
])

# Build the Dataset instance.
# NOTE: this rebinds `dataset`, which earlier held the dataset *name* string
# used to derive img_path / vocab_path; re-running that earlier cell after
# this one would build the paths from this object instead.
dataset = CustomImageDataset(img_folder=img_path, transform=transform)

# Build the DataLoader.
data_loader = DataLoader(dataset, batch_size=1, shuffle=False)

with open(vocab_path, 'r', encoding='utf-8') as f:
    vocab = json.load(f)

vocab_size = len(vocab)
# Invert the vocabulary once instead of once per decoded token.
reverse_vocab = {token_id: word for word, token_id in vocab.items()}
vit_model_name = 'google/vit-base-patch16-224-in21k'
transformer_config = BertConfig()

model = Img2TxtModel(vit_model_name, transformer_config, vocab_size)
# Load the checkpoint. map_location lets a checkpoint saved on a GPU be
# restored on a CPU-only machine.
# NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
checkpoint = torch.load('./model/best_model_epoch_10_batch_2700.pth', map_location=device)


# Apply the saved state dict to the model instance.
model.load_state_dict(checkpoint['model_state_dict'])
model = model.to(device)

model.eval()  # Switch the model to evaluation mode.

generated_captions_dict = {}

with torch.no_grad():
    for images, names in data_loader:
        images = images.to(device)
        outputs, _ = model.generate_text(images, max_length=95, start_token_id=vocab['<start>'])
        # tolist() brings all ids to the host in one transfer.
        for i, token_ids in enumerate(outputs.tolist()):
            gen_caption = [reverse_vocab.get(token_id, '<unk>') for token_id in token_ids]
            if gen_caption and gen_caption[0] == '<start>':
                gen_caption = gen_caption[1:]  # Drop the leading <start> token.
            if '<end>' in gen_caption:
                gen_caption = gen_caption[:gen_caption.index('<end>')]  # Drop <end> and everything after it.

            caption_text = ' '.join(gen_caption)
            # Key by the i-th file name so batches larger than 1 do not
            # overwrite the first image's caption.
            generated_captions_dict[names[i]] = caption_text
    print(generated_captions_dict)

Using GPU: NVIDIA GeForce RTX 3070 Laptop GPU
5
MEN-Sweaters-id_00000702-06_7_additional.jpg
{'MEN-Sweaters-id_00000702-06_7_additional.jpg': 'The person is wearing a short-sleeve shirt with graphic patterns. The shirt is with cotton fabric. It has a crew neckline. The person wears a three-point shorts. The shorts are with denim fabric and pure color patterns. There is an accessory on her wrist. There is a ring on her finger.', 'MEN-Sweatshirts_Hoodies-id_00000911-01_4_full.jpg': 'The person is wearing a tank tank top with graphic patterns. The tank top is with cotton fabric. It has a suspenders neckline. The person wears a long trousers. The trousers are with cotton fabric and graphic patterns. There is an accessory on her wrist. There is a ring on her finger.', 'WOMEN-Pants-id_00005000-06_1_front.jpg': 'The person is wearing a tank tank top with graphic patterns. The tank top is with cotton fabric. It has a suspenders neckline. The person wears a long trousers. The trousers are with 