In [None]:
# Mount Google Drive
# The training loop further down reads its batch files from a path under
# /content/drive/MyDrive, so the drive has to be mounted before it runs.
from google.colab import drive
# NOTE(review): force_remount=True presumably remounts even when the drive is
# already mounted in this session — confirm against the Colab documentation.
drive.mount('/content/drive', force_remount=True)

!pip install transformers

import warnings
# Silence every warning for the rest of the session. This also hides
# deprecation and numerical warnings raised by the libraries used below.
warnings.filterwarnings('ignore')

Mounted at /content/drive


In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import os
import logging
from transformers import BertTokenizerFast
import torch.nn.functional as F
import random
from torch.utils.tensorboard import SummaryWriter
from torch.cuda.amp import GradScaler, autocast

# Logger setup
# Module-level logger; the dataset class below reports file-loading errors through it.
logger = logging.getLogger(__name__)
# Configure the root logger so that INFO-level (and more severe) records are emitted.
logging.basicConfig(level=logging.INFO)

# Tensorboard Writer
# No log directory is passed, so the library's default location is used.
# The writer is closed at the very end of the script.
writer = SummaryWriter()

# Tokenizer
# Only the tokenizer's vocab_size (size of the decoder's embedding/output
# layers) and pad_token_id (caption padding and loss masking) are used below.
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')

# 3D CNN Encoder
class EncoderCNN3D(nn.Module):
    """3D-convolutional encoder mapping a clip to a fixed-size feature vector.

    The network is two ``Conv3d`` + ReLU stages, each followed by a max-pool
    that halves the two trailing dimensions and leaves the depth dimension
    untouched, then a flatten and a linear projection.

    Args:
        channel_size: Number of input channels expected by the first
            convolution (dimension 1 of the input tensor).
        output_feature_size: Length of the feature vector produced per sample.
        flattened_size: Number of elements per sample after the second
            pooling stage. For an input of shape ``(N, channel_size, D, H, W)``
            this is ``128 * D * (H // 4) * (W // 4)``, because the
            convolutions preserve ``D``, ``H`` and ``W`` and each pool halves
            ``H`` and ``W``. The default keeps the value that used to be
            hard-coded, so existing callers and checkpoints are unaffected.
    """

    def __init__(self, channel_size=64, output_feature_size=512, flattened_size=2097152):
        super().__init__()
        # kernel 3 / stride 1 / padding 1 keeps D, H and W unchanged.
        self.conv1 = nn.Conv3d(channel_size, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
        # Shared by both stages: pools over H and W only.
        self.pool = nn.MaxPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2))
        self.conv2 = nn.Conv3d(64, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
        self.flatten = nn.Flatten()
        # The projection's input width must match the flattened activation
        # size, which depends on the input geometry (see ``flattened_size``).
        self.fc = nn.Linear(flattened_size, output_feature_size)

    def forward(self, x):
        """Encode ``x`` of shape ``(N, channel_size, D, H, W)`` to ``(N, output_feature_size)``."""
        x = F.relu(self.conv1(x))
        x = self.pool(x)
        x = F.relu(self.conv2(x))
        x = self.pool(x)
        x = self.flatten(x)
        x = self.fc(x)
        return x

# DecoderRNN
class DecoderRNN(nn.Module):
    """Single-layer LSTM caption decoder conditioned on a feature vector.

    At every time step the LSTM receives the sample's feature vector
    concatenated with the embedding of that step's caption token, and a
    linear layer turns each hidden state into scores over the vocabulary.

    Args:
        embed_size: Width of the token embeddings.
        hidden_size: Width of the LSTM hidden state.
        vocab_size: Number of tokens (embedding rows and output scores).
        feature_size: Width of the conditioning feature vector.
    """

    def __init__(self, embed_size, hidden_size, vocab_size, feature_size):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(
            input_size=embed_size + feature_size,
            hidden_size=hidden_size,
            num_layers=1,
            batch_first=True,
        )
        self.linear = nn.Linear(hidden_size, vocab_size)

    def forward(self, features, captions):
        """Return vocabulary scores of shape ``(batch, steps, vocab_size)``.

        ``features`` is ``(batch, feature_size)`` and ``captions`` holds token
        indices whose first two dimensions are ``(batch, steps)``.
        """
        n_samples = captions.size(0)
        n_steps = captions.size(1)

        # One embedding row per caption position, reshaped to (batch, steps, -1).
        token_vectors = self.embedding(captions).view(n_samples, n_steps, -1)
        # Copy the per-sample feature vector to every time step.
        tiled_features = features.unsqueeze(1).repeat(1, n_steps, 1)

        lstm_input = torch.cat((tiled_features, token_vectors), dim=2)
        hidden_states, _ = self.lstm(lstm_input)
        return self.linear(hidden_states)

# EncoderDecoderModel
class EncoderDecoderModel(nn.Module):
    """Chains an encoder and a caption decoder into a single module.

    Args:
        encoder: Module turning the frames into a feature tensor.
        decoder: Module called as ``decoder(features, captions)``.
        embed_size: Accepted for the caller's convenience; it is not used.
    """

    def __init__(self, encoder, decoder, embed_size):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, frames, captions):
        """Encode ``frames`` and decode against ``captions`` minus its last column.

        The final caption position is dropped here, so the decoder sees one
        step fewer than the ``captions`` tensor that was passed in.
        """
        decoder_input = captions[:, :-1]
        return self.decoder(self.encoder(frames), decoder_input)

# Dataset
class TestDataset(Dataset):
    """Dataset over a single pre-processed batch file.

    The file is loaded once with ``torch.load`` and is expected to contain a
    mapping with a ``'frames'`` entry and an ``'input_ids'`` entry. Every
    item pairs the complete ``'frames'`` entry with one row of
    ``'input_ids'``, so the dataset length is the number of ``'input_ids'``
    rows.

    If the file cannot be loaded the error is logged and the dataset is
    empty, which lets a caller iterating over many files carry on.

    Args:
        file_path: Path of the ``.pt`` file to load.
    """

    def __init__(self, file_path):
        self.file_path = file_path
        try:
            # NOTE: torch.load unpickles the file; only use it on trusted data.
            self.data = torch.load(self.file_path)
        except Exception as e:
            logger.error(f"Error loading {self.file_path}: {e}")
            # Empty placeholder: __len__ is 0, so no item is ever requested.
            self.data = {'frames': [], 'input_ids': []}

    def __len__(self):
        return len(self.data['input_ids'])

    def __getitem__(self, idx):
        """Return ``(frames, captions)`` where ``captions`` is row ``idx`` of ``'input_ids'``."""
        # Serve from the copy loaded in __init__ rather than re-reading the
        # whole file from disk for every single item.
        frames = self.data['frames']
        captions = self.data['input_ids'][idx]  # Tokenized description for this index

        return frames, captions

# Metrics Tracking
def compute_accuracy(outputs, targets):
    """Return the share of entries whose arg-max over dimension 1 equals the target.

    The number of matches is divided by the length of the first dimension of
    the comparison result, and the outcome is returned as a Python float.
    """
    predicted = outputs.argmax(dim=1)
    hits = predicted.eq(targets).float()
    return (hits.sum() / len(hits)).item()

# Model
# The encoder is built with its default arguments (channel_size=64,
# output_feature_size=512); the decoder's feature_size=512 has to match the
# encoder's output width because the two are concatenated inside the decoder.
encoder = EncoderCNN3D()
decoder = DecoderRNN(embed_size=256, hidden_size=512, vocab_size=tokenizer.vocab_size, feature_size=512)
# .cuda() moves all parameters to the GPU, so a CUDA device is required.
model = EncoderDecoderModel(encoder, decoder, embed_size=256).cuda()

# Optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.0005)

# Learning Rate Scheduler
# scheduler.step() is called once per epoch at the end of the training loop,
# so with step_size=5 the learning rate is scaled by gamma only after five
# epochs.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

# Loss function
# NOTE(review): `criterion` is not referenced anywhere else in the visible
# code; the training loop computes its loss with sequence_loss() instead.
criterion = nn.NLLLoss(ignore_index=tokenizer.pad_token_id, reduction='none')

# Custom sequence loss
def sequence_loss(outputs, targets, mask):
    """Masked mean negative log-likelihood over a batch of sequences.

    Args:
        outputs: Unnormalised scores of shape ``(batch, seq_len, vocab)``;
            a log-softmax is taken over the last dimension.
        targets: Class indices of shape ``(batch, seq_len)``.
        mask: Per-position weights with the same shape as ``targets``;
            positions weighted 0 do not contribute to the loss.

    Returns:
        A scalar tensor: the masked sum of negative log-probabilities of the
        target classes divided by ``mask.sum()``. When the mask sums to zero
        (e.g. every position is padding) the result is 0 instead of NaN.

    Raises:
        ValueError: If ``targets`` is not 2-dimensional.
    """
    # Validate before doing any work so a bad call fails with a clear message.
    if targets.dim() != 2:
        raise ValueError(f"targets tensor has incorrect number of dimensions: {targets.dim()}")

    log_probs = F.log_softmax(outputs, dim=2)

    # Pick the log-probability of the target class at every position:
    # (batch, seq_len, 1) index -> (batch, seq_len) result.
    log_probs_for_targets = log_probs.gather(2, targets.unsqueeze(-1)).squeeze(-1)

    # Out-of-place multiply: leaves the gathered tensor untouched.
    masked_log_probs = log_probs_for_targets * mask

    # Guard the division: an all-zero mask would otherwise produce 0 / 0.
    mask_total = mask.sum()
    denominator = torch.where(mask_total > 0, mask_total, torch.ones_like(mask_total))

    return -masked_log_probs.sum() / denominator

# Train/validation split: shuffle the batch-file indices in place and hold
# out the first 20% (rounded down) for validation; the rest is for training.
num_batches = 1573
batch_indices = list(range(num_batches))
random.shuffle(batch_indices)
_holdout_count = int(0.2 * num_batches)
val_batches, train_batches = batch_indices[:_holdout_count], batch_indices[_holdout_count:]

# Every caption tensor is padded to this fixed number of tokens
max_len = 50

# Loss scaler used by the mixed-precision training loop
scaler = GradScaler()

# Training-loop settings
clip_value = 1                             # Gradient-norm clipping threshold
problematic_batches = [423, 1207, 1208]    # Batch files the loop skips
batch_size = 32                            # Adjust as needed
accumulation_steps = 2                     # Micro-batches per optimizer step
def _apply_accumulated_gradients():
    """Run one optimizer update from the gradients accumulated so far.

    The gradients are unscaled first so that ``clip_value`` is compared with
    their true norm rather than the loss-scaled norm. They are then clipped,
    applied through the GradScaler (which skips the update if they contain
    inf/NaN), and cleared for the next accumulation window.
    """
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), clip_value)
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()

    torch.cuda.empty_cache()


# One pass over the training files followed by one pass over the validation
# files per epoch. Each batch file is wrapped in its own dataset/dataloader.
for epoch in range(1):
    for phase in ["train", "val"]:
        is_train = phase == "train"
        if is_train:
            model.train()
            batch_list = train_batches
        else:
            model.eval()
            batch_list = val_batches

        # Running totals for this phase. Losses are summed per step and
        # averaged over the steps that actually ran; accuracy is counted per
        # non-padding target token.
        phase_loss = 0.0
        phase_steps = 0
        correct_tokens = 0.0
        counted_tokens = 0.0

        total_batches = len(batch_list)
        for batch_index, batch_num in enumerate(batch_list):
            if batch_num in problematic_batches:
                continue  # Skip problematic batches
            batch_file = f"/content/drive/MyDrive/Video-to-Text/processed_data/batch_{batch_num}.pt"
            dataset = TestDataset(batch_file)
            dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=6)

            # Gradient accumulation is tracked per file: start from clean
            # gradients and count the micro-batches not yet applied.
            pending_steps = 0
            if is_train:
                optimizer.zero_grad()

            for frames, captions in dataloader:
                if frames is None or captions is None:
                    continue
                # The encoder is built for 64 input channels (dimension 1).
                if frames.size(1) != 64:
                    continue  # Skip this batch
                frames, captions = frames.cuda(), captions.cuda()
                # Pad captions to a consistent length
                padded_captions = F.pad(captions, (0, max_len - captions.shape[1]), value=tokenizer.pad_token_id)
                inputs = padded_captions[:, :-1]   # All tokens except the last
                targets = padded_captions[:, 1:]   # All tokens except the first

                # Autograd is only needed while training; validation runs
                # without building a graph.
                with torch.set_grad_enabled(is_train), autocast():
                    # Forward pass. NOTE: the model drops the last input
                    # position itself, so `outputs` is one step shorter than
                    # `inputs`; the targets are trimmed to match below.
                    outputs = model(frames, inputs)

                    # Ensure targets are aligned with the outputs
                    if targets.shape[1] > outputs.shape[1]:
                        targets = targets[:, :outputs.shape[1]]

                    # Score only positions whose *target* is a real token, so
                    # that predicting padding is never part of the loss.
                    mask = (targets != tokenizer.pad_token_id).float()

                    if torch.isnan(outputs).any():
                        print("Nan deteced in model outputs")

                    # Mean per-token loss for this micro-batch (unscaled)
                    loss = sequence_loss(outputs, targets, mask)

                    loss_is_nan = bool(torch.isnan(loss).any())
                    if loss_is_nan:
                        print("Nan deteced in loss")

                loss_value = loss.item()
                if not loss_is_nan:
                    phase_loss += loss_value
                    phase_steps += 1

                if is_train:
                    # Divide only the back-propagated loss so the summed
                    # gradients of one window equal those of its mean loss.
                    scaler.scale(loss / accumulation_steps).backward()
                    pending_steps += 1
                    if pending_steps == accumulation_steps:
                        _apply_accumulated_gradients()
                        pending_steps = 0
                else:
                    # Token-level accuracy over non-padding targets
                    predicted = outputs.argmax(dim=2)
                    correct_tokens += ((predicted == targets).float() * mask).sum().item()
                    counted_tokens += mask.sum().item()

                # Log metrics (the true per-token loss, not the value divided
                # by accumulation_steps)
                print(f"Epoch {epoch}, Phase {phase}, Batch {batch_num}, Batch {batch_index + 1}/{total_batches}, Loss: {loss_value}")

            # Apply gradients left over from an incomplete accumulation
            # window so the last micro-batches of a file are not discarded.
            if is_train and pending_steps:
                _apply_accumulated_gradients()

            # Save checkpoints periodically
            if phase == "train" and batch_num % 100 == 0:
                torch.save(model.state_dict(), f"model_epoch_{epoch}_batch_{batch_num}.pt")

        # Logging at the end of each phase
        if phase == "val":
            # max(..., 1) keeps the metrics at 0 when nothing was evaluated.
            val_loss = phase_loss / max(phase_steps, 1)
            val_accuracy = correct_tokens / max(counted_tokens, 1.0)
            print(f"Epoch {epoch}, Validation Loss: {val_loss}, Validation Accuracy: {val_accuracy}")
            writer.add_scalar('Loss/val', val_loss, epoch)
            writer.add_scalar('Accuracy/val', val_accuracy, epoch)
        elif phase == "train":
            # Mean training loss of the epoch; skipped if no step ran.
            if phase_steps:
                writer.add_scalar('Loss/train', phase_loss / phase_steps, epoch)

        print(f"Epoch {epoch}, Phase {phase} completed.")
    scheduler.step()

# Saving final model
torch.save(model.state_dict(), "final_model.pt")
writer.close()

Epoch 0, Phase train, Batch 858, Batch 1/1259, Loss: 5.160451412200928
Epoch 0, Phase train, Batch 858, Batch 1/1259, Loss: 5.162769794464111
Epoch 0, Phase train, Batch 1289, Batch 2/1259, Loss: 4.974930763244629
Epoch 0, Phase train, Batch 1289, Batch 2/1259, Loss: 4.968714714050293
Epoch 0, Phase train, Batch 1324, Batch 3/1259, Loss: 4.888542652130127
Epoch 0, Phase train, Batch 1324, Batch 3/1259, Loss: 4.837698459625244
Epoch 0, Phase train, Batch 1242, Batch 4/1259, Loss: 4.749791622161865
Epoch 0, Phase train, Batch 1242, Batch 4/1259, Loss: 4.786004066467285
Epoch 0, Phase train, Batch 1566, Batch 5/1259, Loss: 4.717471122741699
Epoch 0, Phase train, Batch 1566, Batch 5/1259, Loss: 4.725526809692383
Epoch 0, Phase train, Batch 741, Batch 6/1259, Loss: 4.695276737213135
Epoch 0, Phase train, Batch 741, Batch 6/1259, Loss: 4.675205230712891
Epoch 0, Phase train, Batch 985, Batch 7/1259, Loss: 4.625764846801758
Epoch 0, Phase train, Batch 985, Batch 7/1259, Loss: 4.62181472778320