In [3]:
import os
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision import models
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from collections import Counter
import torch.optim as optim
from tqdm import tqdm
import nltk
from torch.nn.utils.rnn import pack_padded_sequence

# Download necessary NLTK data
# NOTE(review): the tokenizer visible in this file (Vocabulary.tokenizer_eng)
# splits on whitespace and never calls into NLTK, so this download may only
# be needed by code outside this cell -- confirm before removing it.
nltk.download('punkt')

# Vocabulary Class
class Vocabulary:
    """Bidirectional word <-> index mapping with four reserved tokens.

    Indices 0-3 are always ``<pad>``, ``<start>``, ``<end>`` and ``<unk>``.
    Corpus words are appended after them once they have been seen at least
    ``freq_threshold`` times within one ``build_vocabulary`` call.
    """

    def __init__(self, freq_threshold: int) -> None:
        """Create a vocabulary that holds only the reserved tokens.

        Args:
            freq_threshold: Minimum number of occurrences a word needs
                before it receives its own index.
        """
        self.freq_threshold = freq_threshold
        self.itos = {0: "<pad>", 1: "<start>", 2: "<end>", 3: "<unk>"}
        self.stoi = {v: k for k, v in self.itos.items()}

    def __len__(self) -> int:
        """Return the number of entries, reserved tokens included."""
        return len(self.itos)

    @staticmethod
    def tokenizer_eng(text: str) -> list:
        """Lower-case ``text`` and split it on whitespace."""
        return text.lower().split()

    def build_vocabulary(self, sentence_list) -> None:
        """Add every sufficiently frequent word of ``sentence_list``.

        Frequencies are counted per call. Words that already have an index
        (including the reserved tokens) are never re-assigned, so the
        method can be called more than once and a corpus token spelled
        like ``<pad>`` cannot steal the reserved index.

        Args:
            sentence_list: Iterable of raw sentences (strings).
        """
        frequencies = Counter()
        # Continue after the highest index in use instead of a fixed 4, so
        # a repeated call appends rather than overwriting earlier entries.
        idx = len(self.itos)
        for sentence in sentence_list:
            for word in self.tokenizer_eng(sentence):
                frequencies[word] += 1
                if word in self.stoi:
                    # Already indexed (or a reserved token): keep its index.
                    continue
                # ">=" rather than "==" so thresholds below 1 admit every
                # word instead of silently admitting none.
                if frequencies[word] >= self.freq_threshold:
                    self.stoi[word] = idx
                    self.itos[idx] = word
                    idx += 1

    def numericalize(self, text: str) -> list:
        """Convert ``text`` to indices, mapping unknown words to ``<unk>``."""
        unk_index = self.stoi["<unk>"]
        return [self.stoi.get(token, unk_index) for token in self.tokenizer_eng(text)]


# Dataset Class
class RadiologyDataset(Dataset):
    """Image/caption pairs read from a tab-separated caption file.

    Each usable line of ``caption_file`` has the form
    ``<image_name>\\t<caption>``. A pair is kept only when
    ``<image_dir>/<image_name>.jpg`` exists on disk; every other line is
    skipped without an error.
    """

    def __init__(self, caption_file, image_dir, vocab, transform=None):
        """Index all caption lines whose image file is present.

        Args:
            caption_file: Path to the tab-separated caption file.
            image_dir: Directory holding the ``.jpg`` images.
            vocab: Object exposing ``stoi`` and ``numericalize`` (used in
                ``__getitem__``).
            transform: Optional callable applied to the loaded PIL image.
        """
        self.image_dir = image_dir
        self.transform = transform
        self.vocab = vocab
        self.data = []

        with open(caption_file, 'r') as handle:
            for raw_line in handle:
                fields = raw_line.strip().split('\t')
                if len(fields) != 2:
                    # Not exactly "<name>\t<caption>": ignore the line.
                    continue
                image_name, caption = fields
                image_path = os.path.join(image_dir, image_name + ".jpg")
                if not os.path.exists(image_path):
                    continue
                self.data.append((image_path, caption))

    def __len__(self):
        """Return the number of indexed image/caption pairs."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, caption_tensor)`` for the pair at ``idx``.

        The caption is numericalized and wrapped in ``<start>`` / ``<end>``.
        """
        image_path, caption = self.data[idx]
        image = Image.open(image_path).convert("RGB")

        if self.transform:
            image = self.transform(image)

        stoi = self.vocab.stoi
        token_ids = [stoi["<start>"], *self.vocab.numericalize(caption), stoi["<end>"]]
        return image, torch.tensor(token_ids)


# Collate Function to Handle Variable-Length Captions
def collate_fn(batch):
    """Batch ``(image, caption)`` samples, right-padding captions with 0.

    Args:
        batch: Sequence of ``(image_tensor, caption_tensor)`` pairs.

    Returns:
        A tuple ``(images, padded_captions, lengths)``: the stacked images,
        a long tensor of shape ``(len(batch), max caption length)`` and the
        list of original caption lengths.
    """
    images, captions = zip(*batch)
    lengths = list(map(len, captions))
    padded_captions = torch.zeros((len(captions), max(lengths)), dtype=torch.long)

    # Each row is a view into padded_captions, so writing to it fills the batch.
    for row, caption, length in zip(padded_captions, captions, lengths):
        row[:length] = caption[:length]

    return torch.stack(images, 0), padded_captions, lengths


# Attention Mechanism Class
class Attention(nn.Module):
    """Additive attention over a set of encoder feature vectors."""

    def __init__(self, encoder_dim, decoder_dim, attention_dim):
        """Build the projection layers.

        Args:
            encoder_dim: Size of each encoder feature vector.
            decoder_dim: Size of the decoder hidden state.
            attention_dim: Size of the shared space both are projected into.
        """
        super().__init__()
        self.encoder_att = nn.Linear(encoder_dim, attention_dim)
        self.decoder_att = nn.Linear(decoder_dim, attention_dim)
        self.full_att = nn.Linear(attention_dim, 1)
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(dim=1)

    def forward(self, encoder_out, decoder_hidden):
        """Return the attention-weighted context and the weights.

        Args:
            encoder_out: Encoder features; the indexing below treats it as
                ``(batch, positions, encoder_dim)``.
            decoder_hidden: Decoder state, treated as ``(batch, decoder_dim)``.

        Returns:
            ``(context, alpha)`` where ``alpha`` is the softmax over dim 1
            of the scores and ``context`` is ``encoder_out`` weighted by
            ``alpha`` and summed over dim 1.
        """
        projected_features = self.encoder_att(encoder_out)
        # Insert a positions axis so the state broadcasts over every position.
        projected_state = self.decoder_att(decoder_hidden).unsqueeze(1)
        scores = self.full_att(self.relu(projected_features + projected_state))
        alpha = self.softmax(scores.squeeze(2))
        weighted = encoder_out * alpha.unsqueeze(2)
        context = weighted.sum(dim=1)
        return context, alpha


# Encoder CNN Class
class EncoderCNN(nn.Module):
    """ResNet-50 backbone followed by a linear projection to ``embed_size``.

    The feature map is pooled to 1x1 and flattened before the projection,
    so ``forward`` yields one vector per image, not a spatial grid.
    """

    def __init__(self, embed_size):
        """Assemble the backbone and the projection head.

        Args:
            embed_size: Size of the per-image feature vector produced.
        """
        super().__init__()
        backbone = models.resnet50(pretrained=True)
        # Keep everything except the last two child modules of the ResNet.
        self.resnet = nn.Sequential(*list(backbone.children())[:-2])
        self.adaptive_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.linear = nn.Linear(backbone.fc.in_features, embed_size)
        self.bn = nn.BatchNorm1d(embed_size, momentum=0.01)

    def forward(self, images):
        """Encode ``images`` into a ``(batch, embed_size)`` tensor."""
        feature_maps = self.resnet(images)
        pooled = self.adaptive_pool(feature_maps)
        flattened = pooled.view(feature_maps.size(0), -1)
        projected = self.linear(flattened)
        return self.bn(projected)


# Decoder with Attention Mechanism
class AttentionDecoderRNN(nn.Module):
    """LSTM-cell caption decoder that attends over encoder features."""

    def __init__(self, embed_size, hidden_size, vocab_size, encoder_dim=1024, attention_dim=512, num_layers=1):
        """Build the decoder layers.

        Args:
            embed_size: Size of the token embeddings.
            hidden_size: Size of the LSTM hidden and cell state.
            vocab_size: Number of output classes (vocabulary entries).
            encoder_dim: Size of each encoder feature vector.
            attention_dim: Size of the attention projection space.
            num_layers: Accepted for interface compatibility; not used by
                this class (a single ``nn.LSTMCell`` is always built).
        """
        super().__init__()
        self.attention = Attention(encoder_dim, hidden_size, attention_dim)
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTMCell(embed_size + encoder_dim, hidden_size)
        self.linear = nn.Linear(hidden_size, vocab_size)
        self.init_h = nn.Linear(encoder_dim, hidden_size)
        self.init_c = nn.Linear(encoder_dim, hidden_size)
        self.dropout = nn.Dropout(0.5)

    def forward(self, encoder_out, captions, lengths):
        """Run the decoder over ``captions`` with teacher forcing.

        At step ``t`` the cell is fed the embedding of ``captions[:, t]``
        together with the attention context, and its output is written to
        ``predictions[:, t]``.

        Args:
            encoder_out: Encoder features; reshaped here to
                ``(batch, -1, last_dim)``.
            captions: Token indices, indexed as ``(batch, step)``.
            lengths: Per-sample caption lengths; only ``max(lengths)`` is
                used, as the number of decoding steps.

        Returns:
            Tensor of shape ``(batch, max(lengths), vocab_size)`` holding
            the unnormalized scores of every step.
        """
        batch_size = encoder_out.size(0)
        feature_dim = encoder_out.size(-1)
        regions = encoder_out.view(batch_size, -1, feature_dim)

        # Initial LSTM state comes from the mean over all encoder positions.
        pooled = regions.mean(dim=1)
        h, c = self.init_h(pooled), self.init_c(pooled)

        token_embeddings = self.embed(captions)
        num_steps = max(lengths)
        predictions = torch.zeros(
            batch_size, num_steps, self.linear.out_features, device=regions.device
        )

        for step in range(num_steps):
            context, _ = self.attention(regions, h)
            step_input = torch.cat((token_embeddings[:, step, :], context), dim=1)
            h, c = self.lstm(step_input, (h, c))
            predictions[:, step, :] = self.linear(self.dropout(h))

        return predictions


# Image Captioning Model with Attention
class ImageCaptioningModelWithAttention(nn.Module):
    """End-to-end captioner: ``EncoderCNN`` feeding ``AttentionDecoderRNN``."""

    def __init__(self, embed_size, hidden_size, vocab_size, encoder_dim, attention_dim, num_layers):
        """Create the encoder and decoder sub-modules.

        All arguments are forwarded to the decoder; ``embed_size`` is also
        the size of the encoder's output vector.
        """
        super().__init__()
        self.encoder = EncoderCNN(embed_size)
        self.decoder = AttentionDecoderRNN(
            embed_size,
            hidden_size,
            vocab_size,
            encoder_dim,
            attention_dim,
            num_layers,
        )

    def forward(self, images, captions, lengths):
        """Encode ``images`` and decode them against ``captions``."""
        return self.decoder(self.encoder(images), captions, lengths)


# Set up device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Training data locations, shared by vocabulary building and the dataset so
# the two can never point at different files.
CAPTION_FILE = "data/train/radiology/captions.txt"
IMAGE_DIR = "data/train/radiology/images"

# Build the vocabulary
freq_threshold = 5
captions = []

with open(CAPTION_FILE, 'r') as f:
    for line in f:
        parts = line.strip().split('\t')
        # Same "<image_name>\t<caption>" filter that RadiologyDataset applies.
        if len(parts) == 2:
            captions.append(parts[1])

# Create and build vocabulary
vocab = Vocabulary(freq_threshold)
vocab.build_vocabulary(captions)

# Image Transformations
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Create Dataset and DataLoader
# NOTE(review): the recorded run of this cell stopped with a CUDA
# OutOfMemoryError before the first batch completed. batch_size below, and
# the fact that every ResNet-50 parameter is handed to the optimizer, are
# the first things to revisit -- confirm against the target GPU.
train_dataset = RadiologyDataset(CAPTION_FILE, IMAGE_DIR, vocab, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=16, collate_fn=collate_fn)

# Hyperparameters
embed_size = 256
hidden_size = 512
vocab_size = len(vocab)  # Reserved tokens plus every word kept by the vocabulary
# EncoderCNN projects its features to embed_size, so the decoder's
# encoder_dim must be that same value; derive it instead of repeating 256.
encoder_dim = embed_size
attention_dim = 512
num_layers = 1
learning_rate = 0.001
num_epochs = 5

# Initialize model, criterion, and optimizer
model = ImageCaptioningModelWithAttention(embed_size, hidden_size, vocab_size, encoder_dim, attention_dim, num_layers).to(device)
criterion = nn.CrossEntropyLoss(ignore_index=vocab.stoi["<pad>"])
optimizer = optim.Adam(model.parameters(), lr=learning_rate)


# Training Loop
for epoch in range(num_epochs):
    model.train()
    total_train_loss = 0
    train_progress_bar = tqdm(enumerate(train_loader), total=len(train_loader), desc=f"Epoch [{epoch+1}/{num_epochs}] Training")

    for i, (images, captions, lengths) in train_progress_bar:
        images = images.to(device)
        captions = captions.to(device)

        # Teacher forcing with a one-step shift. The decoder computes its
        # prediction at step t from the input token at step t, so the input
        # must be tokens [0, T-1) and the target tokens [1, T). Feeding and
        # scoring the same positions would train the model to echo its
        # input instead of predicting the next word.
        decoder_inputs = captions[:, :-1]
        targets = captions[:, 1:]
        # Every caption holds at least <start> and <end>, so these are >= 1.
        decode_lengths = [length - 1 for length in lengths]

        # Forward pass: outputs has shape [batch_size, max(decode_lengths), vocab_size]
        outputs = model(images, decoder_inputs, decode_lengths)

        # Packing drops the padded positions so they do not enter the loss.
        packed_outputs = pack_padded_sequence(outputs, decode_lengths, batch_first=True, enforce_sorted=False)
        packed_targets = pack_padded_sequence(targets, decode_lengths, batch_first=True, enforce_sorted=False)

        # Both packings use the same lengths, so their .data rows line up.
        outputs_data = packed_outputs.data
        targets_data = packed_targets.data

        # Compute loss
        loss = criterion(outputs_data, targets_data)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_train_loss += loss.item()
        train_progress_bar.set_postfix({"Batch Loss": loss.item()})

    # max(..., 1) keeps an empty loader from raising ZeroDivisionError.
    average_train_loss = total_train_loss / max(len(train_loader), 1)
    print(f"Epoch [{epoch+1}/{num_epochs}], Average Training Loss: {average_train_loss:.4f}")

# Write the trained weights under both checkpoint file names used by this script
for checkpoint_path in ("image_captioning_model.pth", "image_captioning_with_attention.pth"):
    torch.save(model.state_dict(), checkpoint_path)
    print(f"Model saved as {checkpoint_path}")


[nltk_data] Downloading package punkt to /home/dvasic/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
Epoch [1/5] Training:   0%|          | 0/1023 [00:10<?, ?it/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 196.00 MiB. GPU 