In [None]:
import os
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pack_padded_sequence
from collections import Counter
import torch.optim as optim
from tqdm import tqdm
import nltk
from nltk.translate.bleu_score import corpus_bleu
from pycocoevalcap.rouge.rouge import Rouge
from pycocoevalcap.cider.cider import Cider

# Download necessary NLTK data
# NOTE(review): nothing visible in this file uses the punkt tokenizer —
# captions are tokenized with str.split and corpus_bleu does not need it —
# so this download may be removable; confirm no other cell relies on it.
nltk.download('punkt')

# Vocabulary Class
class Vocabulary:
    """Word <-> index mapping built from caption text.

    Indices 0-3 are reserved for the special tokens ``<pad>``, ``<start>``,
    ``<end>`` and ``<unk>``. A word is added once it has occurred at least
    ``freq_threshold`` times in the sentences passed to ``build_vocabulary``;
    new words receive consecutive indices in the order they reach the
    threshold.
    """

    def __init__(self, freq_threshold):
        self.freq_threshold = freq_threshold
        self.itos = {0: "<pad>", 1: "<start>", 2: "<end>", 3: "<unk>"}
        self.stoi = {v: k for k, v in self.itos.items()}

    def __len__(self):
        return len(self.itos)

    @staticmethod
    def tokenizer_eng(text):
        """Lower-case ``text`` and split it on whitespace."""
        return text.lower().split()

    def build_vocabulary(self, sentence_list):
        """Add every word reaching ``freq_threshold`` occurrences in ``sentence_list``.

        Counts are local to this call. Words already in the vocabulary
        (including the special tokens, should they appear literally in the
        text) keep their existing index. New words are numbered starting at
        ``len(self)``, so calling this more than once never overwrites earlier
        entries.
        """
        frequencies = Counter()
        # Continue numbering after whatever is already present (4 on first use).
        idx = len(self.itos)

        for sentence in sentence_list:
            for word in self.tokenizer_eng(sentence):
                frequencies[word] += 1
                if word not in self.stoi and frequencies[word] >= self.freq_threshold:
                    self.stoi[word] = idx
                    self.itos[idx] = word
                    idx += 1

    def numericalize(self, text):
        """Map ``text`` to token ids; words not in the vocabulary become ``<unk>``."""
        unk_id = self.stoi["<unk>"]
        return [self.stoi.get(token, unk_id) for token in self.tokenizer_eng(text)]

# Dataset Class
class RadiologyDataset(Dataset):
    """Image/caption pairs read from a tab-separated caption file.

    Each line of ``caption_file`` is parsed as ``<image_name>\\t<caption>``;
    the image is looked up as ``<image_dir>/<image_name>.jpg``. Lines that do
    not split into exactly two tab-separated fields, or whose image file does
    not exist, are skipped.

    Args:
        caption_file: Path to the UTF-8 tab-separated caption file.
        image_dir: Directory containing the ``.jpg`` images.
        vocab: Object providing ``stoi`` and ``numericalize`` (a ``Vocabulary``).
        transform: Optional callable applied to the loaded RGB PIL image.
    """

    def __init__(self, caption_file, image_dir, vocab, transform=None):
        self.image_dir = image_dir
        self.transform = transform
        self.vocab = vocab
        self.data = []

        # Explicit encoding: the platform default is not guaranteed to be UTF-8.
        with open(caption_file, 'r', encoding='utf-8') as file:
            for line in file:
                parts = line.strip().split('\t')
                if len(parts) != 2:
                    continue
                image_name, caption = parts
                image_path = os.path.join(image_dir, image_name + ".jpg")
                if os.path.exists(image_path):
                    self.data.append((image_path, caption))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, caption_ids)`` for sample ``idx``.

        ``caption_ids`` is a 1-D ``long`` tensor: ``<start>``, the numericalized
        caption, then ``<end>``.
        """
        image_path, caption = self.data[idx]
        # Close the file handle deterministically; convert() loads the pixels
        # into a new image, so it stays usable after the file is closed.
        with Image.open(image_path) as img:
            image = img.convert("RGB")

        if self.transform:
            image = self.transform(image)

        caption_ids = (
            [self.vocab.stoi["<start>"]]
            + self.vocab.numericalize(caption)
            + [self.vocab.stoi["<end>"]]
        )
        return image, torch.tensor(caption_ids, dtype=torch.long)

# Collate Function
def collate_fn(batch):
    """Stack images and right-pad captions into a single batch.

    Args:
        batch: Sequence of ``(image_tensor, caption_tensor)`` pairs.

    Returns:
        ``(images, padded_captions, lengths)``:
        - ``images`` is the image tensors stacked along a new dim 0.
        - ``padded_captions`` is a ``long`` tensor of shape ``(batch, max_len)``,
          zero-padded on the right (0 is ``<pad>`` in ``Vocabulary``).
        - ``lengths`` is a list of the unpadded caption lengths.
    """
    image_list, caption_list = zip(*batch)
    lengths = list(map(len, caption_list))
    padded = torch.zeros((len(caption_list), max(lengths)), dtype=torch.long)
    for row, (seq, n) in enumerate(zip(caption_list, lengths)):
        padded[row, :n] = seq[:n]
    return torch.stack(image_list, dim=0), padded, lengths

# Additional function to convert indices back to words
def decode_caption(caption, vocab):
    """Convert a sequence of token ids into a list of words.

    Decoding stops at the first ``<end>`` id, so anything a greedy decoder
    emits after the end of the sentence is discarded. ``<start>`` and
    ``<pad>`` ids are skipped.

    Args:
        caption: Iterable of integer token ids.
        vocab: Object with ``stoi`` / ``itos`` mappings (a ``Vocabulary``).

    Returns:
        The decoded words, as a list of strings.
    """
    start_id = vocab.stoi["<start>"]
    end_id = vocab.stoi["<end>"]
    pad_id = vocab.stoi["<pad>"]
    skip = {start_id, pad_id}  # built once, not per token

    words = []
    for idx in caption:
        if idx == end_id:
            break
        if idx in skip:
            continue
        words.append(vocab.itos[idx])
    return words

# Evaluation function using BLEU, ROUGE, and CIDEr
def evaluate_model(model, data_loader, vocab, device, max_length=20):
    """Greedy-decode captions for every batch in ``data_loader`` and score them.

    Args:
        model: Captioning model exposing ``encoder`` and ``decoder.generate``.
        data_loader: Yields ``(images, padded_captions, lengths)`` batches.
        vocab: Vocabulary used to decode ids back to words.
        device: Device the images are moved to before encoding.
        max_length: Number of greedy decoding steps per image.

    Returns:
        ``(bleu_score, rouge_score, cider_score)`` corpus-level scores. Each
        image is scored against its single reference caption.
    """
    model.eval()
    references = []  # per image: list of reference token lists
    candidates = []  # per image: generated token list

    with torch.no_grad():
        for images, captions, _lengths in tqdm(data_loader, desc="Evaluating"):
            images = images.to(device)

            features = model.encoder(images)
            outputs = model.decoder.generate(features, vocab, max_length=max_length)

            # generate() returns a flat, step-major list
            # ([step0_img0, step0_img1, ..., step1_img0, ...]), so image i's
            # tokens are every batch_size-th element starting at i.
            batch_size = images.size(0)
            batch_outputs = [outputs[i::batch_size] for i in range(batch_size)]

            references.extend([decode_caption(caption.tolist(), vocab)] for caption in captions)
            candidates.extend(decode_caption(output, vocab) for output in batch_outputs)

    # Compute BLEU score
    bleu_score = corpus_bleu(references, candidates)

    # pycocoevalcap scorers take {id: [sentence, ...]} with exactly one
    # hypothesis sentence per id. Each reference token list is joined
    # separately, and both dicts are built once and shared.
    gts = {i: [" ".join(ref) for ref in refs] for i, refs in enumerate(references)}
    res = {i: [" ".join(candidate)] for i, candidate in enumerate(candidates)}

    rouge_score, _ = Rouge().compute_score(gts, res)
    cider_score, _ = Cider().compute_score(gts, res)

    print(f"BLEU Score: {bleu_score:.4f}, ROUGE Score: {rouge_score:.4f}, CIDEr Score: {cider_score:.4f}")
    return bleu_score, rouge_score, cider_score

# Encoder CNN
class EncoderCNN(nn.Module):
    """Frozen ResNet-50 backbone followed by a trainable projection to ``embed_size``.

    The backbone (every ResNet-50 child except the final fully-connected
    layer) runs under ``torch.no_grad()``, and its parameters are marked as
    not requiring gradients. It is kept in eval mode even while the module
    is training. Only ``linear`` and ``bn`` are trained.
    """

    def __init__(self, embed_size):
        super(EncoderCNN, self).__init__()
        # NOTE(review): ``pretrained=True`` is deprecated in torchvision >= 0.13
        # in favour of ``weights=...``; kept for compatibility with older versions.
        resnet = models.resnet50(pretrained=True)
        modules = list(resnet.children())[:-1]  # Remove the last fully-connected layer
        self.resnet = nn.Sequential(*modules)
        # forward() never backpropagates into the backbone; make that explicit.
        for param in self.resnet.parameters():
            param.requires_grad_(False)
        self.linear = nn.Linear(resnet.fc.in_features, embed_size)
        self.bn = nn.BatchNorm1d(embed_size, momentum=0.01)

    def train(self, mode=True):
        """Set training mode on the head while keeping the backbone in eval mode.

        ``torch.no_grad()`` does not stop BatchNorm layers from updating their
        running statistics in training mode. Without this override, the
        "frozen" backbone would drift as training batches pass through it.
        """
        super().train(mode)
        self.resnet.eval()
        return self

    def forward(self, images):
        """Encode a batch of images into ``(batch, embed_size)`` features."""
        with torch.no_grad():
            features = self.resnet(images)
        features = features.view(features.size(0), -1)
        return self.bn(self.linear(features))

# Decoder RNN
class DecoderRNN(nn.Module):
    """LSTM caption decoder conditioned on an image feature vector.

    The image feature is fed as the first timestep, followed by the embedded
    caption tokens (teacher forcing in ``forward``, greedy feedback in
    ``generate``).
    """

    def __init__(self, embed_size, hidden_size, vocab_size, num_layers=1):
        super(DecoderRNN, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_size, vocab_size)

    def forward(self, features, captions, lengths):
        """Compute vocabulary logits for every valid timestep.

        Args:
            features: Image features, ``(batch, embed_size)``.
            captions: Padded caption ids, ``(batch, max_len)``.
            lengths: Unpadded caption lengths.

        Returns:
            Logits of shape ``(sum(lengths), vocab_size)``, in the element order
            of ``pack_padded_sequence(..., enforce_sorted=False)``. This is the
            same order obtained by packing ``captions`` with the same ``lengths``.
        """
        embeddings = self.embed(captions)
        # Prepend the image feature as timestep 0. Packing with ``lengths``
        # then drops the last input token, so step t lines up with captions[:, t].
        embeddings = torch.cat((features.unsqueeze(1), embeddings), 1)
        packed = pack_padded_sequence(embeddings, lengths, batch_first=True, enforce_sorted=False)
        hiddens, _ = self.lstm(packed)
        return self.linear(hiddens.data)

    def generate(self, features, vocab, max_length=20):
        """Greedily decode ``max_length`` tokens for each image in the batch.

        Args:
            features: Image features, ``(batch, embed_size)``.
            vocab: Unused; kept for interface compatibility.
            max_length: Number of decoding steps.

        Returns:
            A flat list of ``batch * max_length`` int token ids in step-major
            order: ``[step0_img0, step0_img1, ..., step1_img0, ...]``. The ids
            for image ``i`` are therefore ``result[i::batch]``. Decoding does
            not stop early at ``<end>``.
        """
        sampled_ids = []
        inputs = features.unsqueeze(1)
        states = None  # the LSTM zero-initialises (h, c) when given None

        for _ in range(max_length):
            hiddens, states = self.lstm(inputs, states)
            logits = self.linear(hiddens.squeeze(1))
            _, predicted = logits.max(1)
            # One host transfer per step instead of one .item() per image.
            sampled_ids.extend(predicted.tolist())
            inputs = self.embed(predicted).unsqueeze(1)

        return sampled_ids

# Image Captioning Model
class ImageCaptioningModel(nn.Module):
    """End-to-end captioner: ``EncoderCNN`` image features feed a ``DecoderRNN``.

    ``forward`` returns the decoder's output for the encoded images together
    with the given captions and lengths.
    """

    def __init__(self, embed_size, hidden_size, vocab_size, num_layers):
        super().__init__()
        self.encoder = EncoderCNN(embed_size)
        self.decoder = DecoderRNN(embed_size, hidden_size, vocab_size, num_layers)

    def forward(self, images, captions, lengths):
        encoded = self.encoder(images)
        return self.decoder(encoded, captions, lengths)

# Set up device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the vocabulary from the training captions. Lines are parsed as
# "<image_name>\t<caption>" and malformed lines are skipped. Unlike the
# dataset, captions are kept here even if the image file is missing.
freq_threshold = 5
captions = []

# Explicit encoding: the platform default is not guaranteed to be UTF-8.
with open("data/train/radiology/captions.txt", 'r', encoding='utf-8') as f:
    for line in f:
        parts = line.strip().split('\t')
        if len(parts) == 2:
            _, caption = parts
            captions.append(caption)

# Create and build vocabulary
vocab = Vocabulary(freq_threshold)
vocab.build_vocabulary(captions)

# Image Transformations (resize + standard ImageNet mean/std normalisation)
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Create Dataset and DataLoader
train_dataset = RadiologyDataset("data/train/radiology/captions.txt", "data/train/radiology/images", vocab, transform=transform)
# drop_last avoids a trailing batch of size 1, which the encoder's
# BatchNorm1d rejects in training mode.
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=16,
                          collate_fn=collate_fn, drop_last=True)

# Evaluate on the held-out test split (previously this loader wrapped
# train_dataset by mistake). Corpus metrics do not depend on order, so no shuffle.
test_dataset = RadiologyDataset("data/test/radiology/captions.txt", "data/test/radiology/images", vocab, transform=transform)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=16, collate_fn=collate_fn)


# Hyperparameters
embed_size = 256
hidden_size = 512
vocab_size = len(vocab)  # Using the correct vocabulary size
num_layers = 2
learning_rate = 0.001
num_epochs = 5

# Initialize model, criterion, and optimizer
model = ImageCaptioningModel(embed_size, hidden_size, vocab_size, num_layers).to(device)
criterion = nn.CrossEntropyLoss(ignore_index=vocab.stoi["<pad>"])
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Training Loop
for epoch in range(num_epochs):
    model.train()
    total_train_loss = 0.0
    train_progress_bar = tqdm(train_loader, total=len(train_loader), desc=f"Epoch [{epoch+1}/{num_epochs}] Training")

    for images, captions, lengths in train_progress_bar:
        images = images.to(device)
        captions = captions.to(device)

        outputs = model(images, captions, lengths)
        # Pack targets with the same lengths/ordering the decoder used, so
        # outputs[k] lines up with targets[k].
        targets = pack_padded_sequence(captions, lengths, batch_first=True, enforce_sorted=False).data
        loss = criterion(outputs, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_train_loss += loss.item()
        train_progress_bar.set_postfix({"Batch Loss": loss.item()})

    # Guard against an empty loader (fewer samples than one full batch with drop_last).
    average_train_loss = total_train_loss / max(len(train_loader), 1)
    print(f"Epoch [{epoch+1}/{num_epochs}], Average Training Loss: {average_train_loss:.4f}")

    # Evaluate the model on the test set with BLEU, ROUGE, and CIDEr metrics
    bleu_score, rouge_score, cider_score = evaluate_model(model, test_loader, vocab, device)
    print(f"Epoch [{epoch+1}/{num_epochs}], Test BLEU: {bleu_score:.4f}, ROUGE: {rouge_score:.4f}, CIDEr: {cider_score:.4f}")

# Save the trained model
torch.save(model.state_dict(), "image_captioning_model.pth")
print("Model saved as image_captioning_model.pth")

  from .autonotebook import tqdm as notebook_tqdm
[nltk_data] Downloading package punkt to /home/dvasic/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
Epoch [1/5] Training:  13%|█▎        | 130/1023 [01:32<08:06,  1.83it/s, Batch Loss=6.07]