# CS6540 Deep Learning Final Project


## Graph-Based Retrieval-Augmented Generation Using the WildGraph Benchmarking System

### Imports + Hardware accelerator setup for Mac

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import re
import time
from collections import Counter
import urllib.request
import zipfile
import os

# Pick the compute device in priority order: CUDA GPU, then Apple-Silicon
# MPS, then plain CPU. The hasattr() check guards against torch builds that
# do not expose torch.backends.mps at all.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif not hasattr(torch.backends, "mps") or not torch.backends.mps.is_available():
    device = torch.device("cpu")
else:
    device = torch.device("mps")  # Apple Silicon

print("using device ", device)

# Fix the PyTorch and NumPy RNG seeds so runs are repeatable
torch.manual_seed(42)
np.random.seed(42)

# Shared container for collecting results to compare later
results = dict()

using device  mps


### Function definition for training the vanilla Seq2Seq model (no attention)

In [2]:
def _seq2seq_batch_stats(model, criterion, src, tgt, src_lens, pad_idx=0):
    """
    Run one teacher-forced forward pass on a batch and score it.

    Args:
        model: Called as ``model(src, decoder_input, src_lens)``
        criterion: Loss applied to the flattened model output and targets
        src: Source batch; moved to the module-level ``device`` here
        tgt: Target batch, indexed as ``tgt[:, t]``; moved to ``device`` here
        src_lens: Source lengths, forwarded to the model unchanged
        pad_idx: Target id that is excluded from the accuracy count

    Returns:
        Tuple ``(loss, correct, total)``: the loss tensor (not yet reduced to
        a Python float, so the caller can backpropagate through it), the
        number of correctly predicted non-padding tokens, and the number of
        non-padding target tokens.
    """
    src, tgt = src.to(device), tgt.to(device)

    # TEACHER FORCING: the decoder sees the target shifted by one position.
    # Example: tgt = [<sos>, je, suis, <eos>]
    #   Input to decoder: [<sos>, je, suis]
    #   Target output:    [je, suis, <eos>]
    decoder_input = tgt[:, :-1]
    targets = tgt[:, 1:]
    output = model(src, decoder_input, src_lens)

    # Flatten every leading dimension so the criterion sees one row per token:
    # output is expected to be (batch, seq_len, vocab_size) -> (batch * seq_len, vocab_size)
    # targets (batch, seq_len) -> (batch * seq_len)
    loss = criterion(output.reshape(-1, output.shape[-1]), targets.reshape(-1))

    # Token-level accuracy, counted only where the target is not padding
    predictions = output.argmax(dim=-1)
    non_pad_mask = targets != pad_idx
    correct = ((predictions == targets) & non_pad_mask).sum().item()
    total = non_pad_mask.sum().item()

    return loss, correct, total


def train_seq2seq(model, train_loader, val_loader, epochs=10, lr=0.001, name="Model", clip=1.0, pad_idx=0):
    """
    Training loop for sequence-to-sequence models.
    Returns training history for visualization including loss and accuracy.

    Each epoch runs one pass over ``train_loader`` (with parameter updates)
    followed by one pass over ``val_loader`` (no gradients), and prints a
    one-line progress report.

    Args:
        model: The seq2seq model to train; called as model(src, tgt_in, src_lens)
        train_loader: Iterable yielding (src, tgt, src_lens, tgt_lens) training batches
        val_loader: Iterable yielding (src, tgt, src_lens, tgt_lens) validation batches
        epochs: Number of training epochs
        lr: Learning rate for Adam optimizer
        name: Model name for printing progress
        clip: Gradient clipping threshold to prevent exploding gradients
        pad_idx: Padding token id; ignored by the loss and by the accuracy count

    Returns:
        Tuple ``(history, duration)``. ``history`` is a dict with the keys
        'train_loss', 'val_loss', 'train_acc' and 'val_acc', each a list with
        one entry per epoch (accuracies are percentages). ``duration`` is the
        wall-clock training time in seconds.
    """
    # Move model to the selected device (module-level `device`)
    model = model.to(device)

    # ignore_index=pad_idx means padding positions contribute nothing to the loss
    criterion = nn.CrossEntropyLoss(ignore_index=pad_idx)

    optimizer = optim.Adam(model.parameters(), lr=lr)

    print(f"\nTraining {name} for {epochs} epochs...")
    start_time = time.time()

    # Track metrics for plotting later
    history = {'train_loss': [], 'val_loss': [], 'train_acc': [], 'val_acc': []}

    for epoch in range(epochs):
        # =====================
        # TRAINING PHASE
        # =====================
        model.train()  # Set model to training mode
        running_loss = 0.0
        correct_tokens = 0
        total_tokens = 0
        # Batches are counted while iterating instead of using len(loader), so
        # loaders without __len__ work and an empty loader cannot cause a
        # division by zero.
        train_batches = 0

        for src, tgt, src_lens, _tgt_lens in train_loader:
            # Clear gradients from previous batch
            optimizer.zero_grad()

            loss, correct, total = _seq2seq_batch_stats(
                model, criterion, src, tgt, src_lens, pad_idx
            )

            # Backpropagation
            loss.backward()

            # GRADIENT CLIPPING: rescale gradients whose total norm exceeds `clip`
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip)

            # Update model parameters
            optimizer.step()

            running_loss += loss.item()
            correct_tokens += correct
            total_tokens += total
            train_batches += 1

        # Mean of the per-batch losses; NaN marks "no batches seen"
        train_loss = running_loss / train_batches if train_batches else float("nan")
        train_acc = 100 * correct_tokens / total_tokens if total_tokens > 0 else 0

        # =====================
        # VALIDATION PHASE
        # =====================
        model.eval()  # Set model to evaluation mode
        val_running_loss = 0.0
        correct_tokens = 0
        total_tokens = 0
        val_batches = 0

        # No gradient computation needed for validation
        with torch.no_grad():
            for src, tgt, src_lens, _tgt_lens in val_loader:
                loss, correct, total = _seq2seq_batch_stats(
                    model, criterion, src, tgt, src_lens, pad_idx
                )
                val_running_loss += loss.item()
                correct_tokens += correct
                total_tokens += total
                val_batches += 1

        val_loss = val_running_loss / val_batches if val_batches else float("nan")
        val_acc = 100 * correct_tokens / total_tokens if total_tokens > 0 else 0

        # Print progress
        print(f"  Epoch [{epoch+1}/{epochs}] | Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}% | Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%")

        # Store metrics for plotting
        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)
        history['train_acc'].append(train_acc)
        history['val_acc'].append(val_acc)

    duration = time.time() - start_time
    if history['val_loss']:
        print(f"{name} - Final Val Loss: {history['val_loss'][-1]:.4f}, Val Acc: {history['val_acc'][-1]:.2f}%, Time: {duration:.2f}s")
    else:
        # epochs <= 0: there are no final metrics to index into
        print(f"{name} - No epochs run, Time: {duration:.2f}s")

    return history, duration