### 一、导入模块

In [112]:
import os
import json

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision.io import read_image
from torchvision.transforms import Compose, Resize, Normalize, ToTensor
from tqdm import tqdm
import torch.nn as nn
from torchvision.models import resnet50, ResNet50_Weights
from transformers import GPT2Model, GPT2Config
from PIL import Image
import spacy

### 二、加载数据

In [113]:
class ImageTextDataset(Dataset):
    """Dataset of (image, caption-index-sequence) pairs stored under ``root_dir``.

    Captions are read from ``train_captions.json`` or ``test_captions.json``
    (a JSON object mapping image filename to caption string); images are read
    from ``<root_dir>/images/<filename>``.
    """

    def __init__(self, root_dir, vocabulary, max_seq_len, transform=None, train=True):
        """
        Args:
            root_dir (string): Root directory of the dataset.
            vocabulary (dict): Token -> index mapping. Must contain the special
                tokens '<pad>', '<start>', '<end>' and '<unk>'.
            max_seq_len (int): Fixed length of every returned index sequence.
            transform (callable, optional): Optional transform to be applied on a sample.
            train (bool): Read ``train_captions.json`` when True, otherwise
                ``test_captions.json``.
        """
        label_path = os.path.join(root_dir, f'{"train" if train else "test"}_captions.json')
        with open(label_path, 'rb') as fp:
            captions = json.load(fp)

        # Parallel tuples of filenames and captions. Built separately (instead of
        # zip(*items)) so that an empty caption file yields an empty dataset
        # rather than an unpacking error.
        self.filenames = tuple(captions.keys())
        self.labels = tuple(captions.values())

        self.root_dir = root_dir
        self.vocabulary = vocabulary  # vocabulary dict, {'word1': 0, 'word2': 1, ...}
        self.max_seq_len = max_seq_len
        self.transform = transform
        self.nlp = spacy.load("en_core_web_sm")

    def __len__(self):
        """Return the number of (image, caption) pairs."""
        return len(self.filenames)

    def _encode_caption(self, text):
        """Turn a caption string into a list of exactly ``max_seq_len`` vocabulary indices.

        Layout: ``<start>``, lower-cased token indices (unknown tokens map to
        ``<unk>``), ``<end>``, then ``<pad>`` up to ``max_seq_len``. Captions
        that are too long have their tokens truncated so that ``<end>`` is
        still present.
        """
        pad = self.vocabulary['<pad>']
        start = self.vocabulary['<start>']
        end = self.vocabulary['<end>']
        unk = self.vocabulary['<unk>']

        # Only token texts are needed, so run the tokenizer alone instead of
        # the whole spaCy pipeline.
        tokens = [self.vocabulary.get(token.text.lower(), unk) for token in self.nlp.make_doc(text)]

        # Reserve two slots for <start>/<end> before truncating the tokens.
        room = max(self.max_seq_len - 2, 0)
        indices = [start] + tokens[:room] + [end]
        indices = indices[:self.max_seq_len]  # only matters when max_seq_len < 2
        indices += [pad] * (self.max_seq_len - len(indices))
        return indices

    def __getitem__(self, idx):
        """Return ``(image, seq)`` for sample ``idx``.

        ``image`` is the RGB image after ``transform`` (or the PIL image itself
        when no transform was given); ``seq`` is a ``torch.long`` tensor of
        length ``max_seq_len``.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        image_path = os.path.join(self.root_dir, 'images', self.filenames[idx])
        # Force three channels so grayscale / RGBA / CMYK files do not break a
        # 3-channel Normalize; the context manager closes the file handle.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert('RGB')
        if self.transform is not None:
            image = self.transform(image)

        seq = torch.tensor(self._encode_caption(self.labels[idx]), dtype=torch.long)

        return image, seq

### 三、搭建模型

In [114]:
class ImageTextModel(nn.Module):
    """ResNet-50 image encoder paired with a GPT-2 style sequence model.

    NOTE(review): ``forward`` returns the image features and the transformer
    hidden states side by side; the image features are not passed into the
    transformer in this class.
    """

    def __init__(self, vocabulary_size, max_seq_len):
        """Build the sub-modules.

        Args:
            vocabulary_size: Number of entries of the token embedding table;
                also passed to ``GPT2Config`` as ``vocab_size``.
            max_seq_len: Passed to ``GPT2Config`` as ``n_positions``.
        """
        super().__init__()

        # Image feature extractor: ImageNet-pretrained ResNet-50 whose final
        # fully connected layer is swapped for an identity.
        backbone = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
        backbone.fc = nn.Identity()
        self.cnn = backbone

        # Text model, built from a fresh config (not from pretrained weights).
        gpt2_config = GPT2Config(vocab_size=vocabulary_size, n_positions=max_seq_len)
        self.transformer = GPT2Model(gpt2_config)

        # Token embedding whose width matches the transformer's hidden size.
        self.embedding = nn.Embedding(vocabulary_size, gpt2_config.n_embd)

    def forward(self, images, seq):
        """Return ``(image_features, hidden_states)``.

        ``image_features`` is the CNN output for ``images``; ``hidden_states``
        is the transformer's ``last_hidden_state`` for the embedded ``seq``.
        """
        image_features = self.cnn(images)
        token_embeddings = self.embedding(seq)
        transformer_result = self.transformer(inputs_embeds=token_embeddings)
        return image_features, transformer_result.last_hidden_state

In [115]:
# Training configuration
root_dir = 'data/deepfashion-multimodal'
epochs = 200
# torchvision's Resize takes (height, width), whereas PIL's Image.size is
# (width, height). The dataset images report size (750, 1101), i.e. 750 wide
# and 1101 tall, so the resize target must be (1101, 750) to keep the portrait
# aspect ratio instead of squashing the images into landscape.
image_size = (1101, 750)
lr = 5e-4
batch_size = 32

# NOTE(review): the three values below are not referenced by the code visible
# in this file - confirm whether they are still needed.
hidden_dim = 512  # hidden layer dimension
num_heads = 8  # number of attention heads
num_layers = 4  # number of Transformer layers
max_seq_length = 128  # maximum sequence length

In [121]:
# Load the vocabulary (token -> index mapping) from the dataset root
vocab_path = os.path.join(root_dir, 'vocab.json')
with open(vocab_path, 'rb') as fp:
    vocabulary = json.load(fp)
vocabulary_size = len(vocabulary)  # number of vocabulary entries

# Make sure the output directories exist before anything is written to them
for out_dir in ('history', 'models'):
    os.makedirs(out_dir, exist_ok=True)
history_path = 'history/history.npy'  # per-epoch loss history
model_path = 'models/model.pth'  # model checkpoint

# Computation is pinned to the CPU here.
device = torch.device('cpu')
device

device(type='cpu')

In [122]:
# Build the dataset
transform = Compose([
    Resize(image_size),
    ToTensor(),
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

train_set = ImageTextDataset(root_dir=root_dir, vocabulary=vocabulary, max_seq_len=max_seq_length, transform=transform)
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)

# Model initialisation
model = ImageTextModel(vocabulary_size, max_seq_length).to(device)

# Loss and optimizer. PAD positions carry no information, so they are excluded
# from the loss; the learning rate comes from the training configuration
# instead of a second hard-coded value.
criterion = nn.CrossEntropyLoss(ignore_index=vocabulary['<pad>']).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

# Resume from a previous run when a loss history / checkpoint exists.
running_losses = list(np.load(history_path)) if os.path.exists(history_path) else []
if os.path.exists(model_path):
    try:
        # map_location lets a checkpoint saved on another device load here.
        model.load_state_dict(torch.load(model_path, map_location=device))
    except (ValueError, RuntimeError):
        # load_state_dict raises RuntimeError on missing keys or shape
        # mismatches: the checkpoint is unusable, so restart the history too.
        running_losses = []

# Training loop
# NOTE(review): the image features returned by the model are discarded below,
# so the loss does not depend on the images - confirm this is intended.
model.train()
p_bar = tqdm(range(epochs))
for epoch in p_bar:
    running_loss = 0.0
    for batch_idx, (image, seq) in enumerate(train_loader):
        image = image.to(device)  # image: (batch, 3, H, W) with (H, W) = image_size
        seq = seq.to(device)  # seq: (batch, seq_len)

        optimizer.zero_grad()
        _, hidden_states = model(image, seq)  # hidden_states: (batch, seq_len, n_embd)

        # The model returns hidden states, not vocabulary scores. Project them
        # onto the vocabulary with the (tied) token-embedding matrix.
        logits = nn.functional.linear(hidden_states, model.embedding.weight)  # (batch, seq_len, vocab)

        # Next-token objective: position t predicts token t + 1.
        loss = criterion(
            logits[:, :-1].reshape(-1, vocabulary_size),
            seq[:, 1:].reshape(-1),
        )
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        # Incremental mean of the batch losses seen so far in this epoch.
        running_loss += 1 / (batch_idx + 1) * (batch_loss - running_loss)

        p_bar.set_postfix(progress=f'{(batch_idx + 1) / len(train_loader) * 100:.3f}%', loss=f'{batch_loss:.4f}')

    # Persist the checkpoint and the loss history after every epoch.
    running_losses.append(running_loss)
    torch.save(model.state_dict(), model_path)
    np.save(history_path, np.array(running_losses))

  0%|          | 0/200 [01:22<?, ?it/s]


KeyboardInterrupt: 

In [None]:
# Fetch the sample at index 600; indexing goes through
# ImageTextDataset.__getitem__, which returns an (image, seq) pair.
train_set[600]