In [8]:
import torch
import torch.nn as nn
import torchvision.models as models
from vocab import Vocabulary


class EncoderCNN(nn.Module):
    def __init__(self, embed_size):
        super(EncoderCNN, self).__init__()
        resnet = models.resnet101(weights="IMAGENET1K_V1")
        modules = list(resnet.children())[:-1]  # Remove FC layer
        self.resnet = nn.Sequential(*modules)
        for param in self.resnet.parameters():
            param.requires_grad = False  # Freeze ResNet
        self.linear = nn.Linear(2048, embed_size)  # ResNet-101 feature size
        self.bn = nn.BatchNorm1d(embed_size, momentum=0.1)

    def forward(self, images):
        features = self.resnet(images)  # (batch_size, 2048, 1, 1)
        features = features.view(features.size(0), -1)  # (batch_size, 2048)
        features = self.bn(self.linear(features))  # (batch_size, embed_size)
        return features


class DecoderRNN(nn.Module):
    def __init__(self, embed_size, hidden_size, vocab_size, num_layers):
        super(DecoderRNN, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_size, vocab_size)
        self.dropout = nn.Dropout(0.5)
        self.relu = nn.ReLU()

    def forward(self, features, captions):
        # features: (batch_size, embed_size)
        # captions: (batch_size, max_len)
        # lengths: List or tensor of sequence lengths for each batch item
        embeddings = self.dropout(self.embed(captions))  # (batch_size, max_len, embed_size)
        features = features.unsqueeze(1)  # (batch_size, 1, embed_size)
        inputs = torch.cat((features, embeddings), dim=1)  # (batch_size, max_len + 1, embed_size)
        output, _ = self.lstm(inputs)
        outputs = self.linear(output.data)

        return outputs

    def sample(self, features, max_len=25):
        sample_ids = []
        inputs = features.unsqueeze(1)  # (batch_size, 1, embed_size)
        states = None
        for _ in range(max_len):
            hiddens, states = self.lstm(inputs, states)  # (batch_size, 1, hidden_size)
            outputs = self.linear(hiddens.squeeze(1))  # (batch_size, vocab_size)
            _, predicted = outputs.max(1)  # (batch_size,)
            sample_ids.append(predicted)
            inputs = self.embed(predicted).unsqueeze(1)  # (batch_size, 1, embed_size)
        return torch.stack(sample_ids, dim=1)  # (batch_size, max_len)


class CNNtoRNN(nn.Module):
    def __init__(self, embed_size, hidden_size, vocab_size, num_layers):
        super(CNNtoRNN, self).__init__()
        self.encoderCNN = EncoderCNN(embed_size)
        self.decoderRNN = DecoderRNN(embed_size, hidden_size, vocab_size, num_layers)

    def forward(self, images, captions, lengths=25):
        features = self.encoderCNN(images)  # (batch_size, embed_size)
        outputs = self.decoderRNN(features, captions)  # (sum(lengths), vocab_size)
        return outputs

    def caption_image(self, image, vocabulary, max_len=25):
        result_caption = []
        with torch.no_grad():
            x = self.encoderCNN(image).unsqueeze(0)  # (1, embed_size)
            states = None
            for _ in range(max_len):
                hiddens, states = self.decoderRNN.lstm(x, states)  # (1, 1, hidden_size)
                output = self.decoderRNN.linear(hiddens.squeeze(1))  # (1, vocab_size)
                predicted = output.argmax(1)  # (1,)
                result_caption.append(predicted.item())
                x = self.decoderRNN.embed(predicted).unsqueeze(1)  # (1, 1, embed_size)
                if vocabulary.itos[predicted.item()] == "<END>":
                    break
        return [vocabulary.itos[idx] for idx in result_caption]

In [9]:
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from dataset import FlickrDataset
import tqdm

img_root = r'D:\git\Image_Captioning\dataset\Images'
caption_root = r'D:\git\Image_Captioning\dataset\captions.txt'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
hidden_size = 512
embedding_dim = 256
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# def embed_input(batch_caption):
#     for cap in batch_caption
def collate_fn(batch, max_len=25):
    images = []
    captions = []
    # Find max caption length in batch
    for img, caption in batch:
        images.append(img)
        # Pad caption to max_len
        pad_tensor = torch.zeros(max_len - len(caption)).long()  # Use <PAD> token index
        padded_caption = torch.cat((caption, pad_tensor), dim=0)
        captions.append(padded_caption)

    # Stack images and captions into tensors
    images = torch.stack(images, dim=0)  # Shape: (batch_size, C, H, W)
    captions = torch.stack(captions, dim=0)  # Shape: (batch_size, max_len)
    return images, captions


dataset = FlickrDataset(img_root, caption_root, transform=transform)
dataloader = DataLoader(dataset, batch_size=32, collate_fn=collate_fn)
vocab_size = len(dataset.vocab)
print('vocab size:', vocab_size)


vocab size: 4107


In [10]:
model = CNNtoRNN(vocab_size, hidden_size, embedding_dim, num_layers=1).to(device)
loss_fn = nn.CrossEntropyLoss(ignore_index=0)  # Rename for clarity
optimizer = optim.Adam(model.parameters(), lr=0.001)
epochs = 10

for epoch in range(epochs):
    model.train()
    total_loss = 0
    for batch, (images, captions) in tqdm.tqdm(enumerate(dataloader)):
        image = images.to(device)  # Shape: (batch_size, C, H, W)
        caption = captions.to(device)  # Shape: (batch_size, max_len)
        optimizer.zero_grad()
        outputs = model(image, captions)  # Ensure model accepts lengths if needed
        targets = captions[:, 1:]  # Exclude <SOS>
        outputs = outputs[:, 1:-1]
        print(outputs.size())
        loss = loss_fn(outputs, targets)
        total_loss += loss.item()

        # Backward pass
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # Gradient clipping
        optimizer.step()

    avg_loss = total_loss / len(dataloader)
    print(f"Epoch: {epoch + 1}, Average Loss: {avg_loss:.4f}")


0it [00:06, ?it/s]

torch.Size([32, 22, 256])





RuntimeError: Expected target size [32, 256], got [32, 24]