<a href="https://colab.research.google.com/github/fazal735/DL_A3/blob/main/DL_A3(test).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from torch.utils.data import Dataset, DataLoader
import time
import random
import os
from sklearn.model_selection import train_test_split
import torch.nn.functional as F

# Set random seeds for reproducibility: PyTorch (weight init, dropout),
# NumPy, and Python's `random` (teacher-forcing draws, heatmap sample choice)
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)

# Define constants (hyper-parameters consumed by main())
HIDDEN_DIM = 128             # GRU hidden size shared by encoder, attention and decoder
EMBEDDING_DIM = 64           # Character embedding size for both source and target
BATCH_SIZE = 64              # Batch size of the train/val/test DataLoaders
MAX_LENGTH = 20  # Padded sequence length incl. <SOS>/<EOS>; adjust based on your data
DROPOUT = 0.2                # Dropout probability applied to the embeddings
LEARNING_RATE = 0.001        # Learning rate passed to the Adam optimizer
TEACHER_FORCING_RATIO = 0.5  # Per-step probability of feeding the ground-truth token
NUM_EPOCHS = 50              # Number of passes over the training set

# Dataset class
class TransliterationDataset(Dataset):
    """Character-level (English, Hindi) word pairs encoded as fixed-length index tensors.

    Every sample is laid out as ``<SOS> chars... <EOS> <PAD>...`` and is exactly
    ``max_length`` tokens long. Words that do not fit are truncated so that
    samples can always be stacked into a batch.
    """

    def __init__(self, eng_words, hindi_words, eng_vocab, hindi_vocab, max_length):
        """Store the parallel word lists, the two char->index vocabularies and the padded length.

        Both vocabularies must contain the ``<SOS>``, ``<EOS>``, ``<PAD>`` and
        ``<UNK>`` entries. ``max_length`` should be at least 2 so that the two
        boundary tokens fit.
        """
        self.eng_words = eng_words
        self.hindi_words = hindi_words
        self.eng_vocab = eng_vocab
        self.hindi_vocab = hindi_vocab
        self.max_length = max_length

    def __len__(self):
        """Return the number of word pairs."""
        return len(self.eng_words)

    def _encode(self, word, vocab):
        """Return ``(indices, unpadded_length)`` for one word.

        The word is cut to ``max_length - 2`` characters before the boundary
        tokens are added; without this, an over-long word would yield a longer
        tensor than its neighbours and break batch collation.
        """
        chars = list(word)[:max(self.max_length - 2, 0)]
        tokens = ['<SOS>'] + chars + ['<EOS>']
        tokens += ['<PAD>'] * (self.max_length - len(tokens))

        unk_index = vocab['<UNK>']
        indices = [vocab.get(token, unk_index) for token in tokens]
        return indices, len(chars) + 2  # +2 for <SOS> and <EOS>

    def __getitem__(self, idx):
        """Return the encoded pair at ``idx``.

        Returns a dict with:
            eng_indices / hindi_indices: ``torch.long`` tensors of length ``max_length``.
            eng_length / hindi_length: token counts before padding (incl. <SOS>/<EOS>).
        """
        eng_indices, eng_length = self._encode(self.eng_words[idx], self.eng_vocab)
        hindi_indices, hindi_length = self._encode(self.hindi_words[idx], self.hindi_vocab)

        return {
            'eng_indices': torch.tensor(eng_indices, dtype=torch.long),
            'hindi_indices': torch.tensor(hindi_indices, dtype=torch.long),
            'eng_length': eng_length,
            'hindi_length': hindi_length
        }

# Encoder: character embedding followed by a single-layer GRU
class Encoder(nn.Module):
    """Embed source token indices and encode them with a one-layer, batch-first GRU."""

    def __init__(self, input_dim, embedding_dim, hidden_dim, dropout_p=0.2):
        """Build the layers.

        Args:
            input_dim: size of the source vocabulary.
            embedding_dim: size of each character embedding.
            hidden_dim: size of the GRU hidden state.
            dropout_p: dropout probability applied to the embeddings.
        """
        super().__init__()
        self.hidden_dim = hidden_dim
        self.embedding = nn.Embedding(num_embeddings=input_dim, embedding_dim=embedding_dim)
        self.rnn = nn.GRU(input_size=embedding_dim, hidden_size=hidden_dim, batch_first=True)
        self.dropout = nn.Dropout(p=dropout_p)

    def forward(self, src):
        """Encode a batch of source sequences.

        Args:
            src: index tensor of shape [batch_size, seq_len].

        Returns:
            (outputs, hidden) where outputs is [batch_size, seq_len, hidden_dim]
            and hidden is [1, batch_size, hidden_dim].
        """
        token_vectors = self.embedding(src)          # [batch_size, seq_len, embedding_dim]
        regularised = self.dropout(token_vectors)

        # The GRU consumes the full padded sequence (no packing is applied)
        return self.rnn(regularised)

# Attention mechanism (additive scoring of the decoder state against each source position)
class Attention(nn.Module):
    """Score every encoder output against the current decoder hidden state."""

    def __init__(self, hidden_dim):
        """Create the scoring layers for states of size ``hidden_dim``."""
        super().__init__()
        self.attn = nn.Linear(2 * hidden_dim, hidden_dim)
        self.v = nn.Linear(hidden_dim, 1, bias=False)

    def forward(self, hidden, encoder_outputs):
        """Return attention weights over the source positions.

        Args:
            hidden: decoder state, [1, batch_size, hidden_dim].
            encoder_outputs: [batch_size, src_len, hidden_dim].

        Returns:
            Tensor [batch_size, src_len]; each row is a softmax distribution.
        """
        num_positions = encoder_outputs.size(1)

        # Copy the decoder state alongside every source position:
        # [1, batch, hidden] -> [batch, src_len, hidden]
        tiled_state = hidden.transpose(0, 1).repeat(1, num_positions, 1)

        # Score each (state, encoder output) pair
        pair_features = torch.cat((tiled_state, encoder_outputs), dim=2)
        energy = torch.tanh(self.attn(pair_features))    # [batch, src_len, hidden]
        scores = self.v(energy).squeeze(2)               # [batch, src_len]

        return F.softmax(scores, dim=1)

# Decoder with attention (generates one target token per call)
class AttentionDecoder(nn.Module):
    """Single-step GRU decoder that conditions on an attention-weighted context vector."""

    def __init__(self, output_dim, embedding_dim, hidden_dim, attention, dropout_p=0.2):
        """Build the layers.

        Args:
            output_dim: size of the target vocabulary.
            embedding_dim: size of each target character embedding.
            hidden_dim: size of the GRU hidden state (and of the encoder outputs).
            attention: callable ``(hidden, encoder_outputs) -> [batch_size, src_len]`` weights.
            dropout_p: dropout probability applied to the embedding.
        """
        super().__init__()
        self.output_dim = output_dim
        self.attention = attention
        self.embedding = nn.Embedding(num_embeddings=output_dim, embedding_dim=embedding_dim)
        self.rnn = nn.GRU(input_size=embedding_dim + hidden_dim, hidden_size=hidden_dim, batch_first=True)
        self.fc_out = nn.Linear(in_features=hidden_dim * 2 + embedding_dim, out_features=output_dim)
        self.dropout = nn.Dropout(p=dropout_p)

    def forward(self, input, hidden, encoder_outputs):
        """Run one decoding step.

        Args:
            input: previous target token indices, [batch_size].
            hidden: previous decoder state, [1, batch_size, hidden_dim].
            encoder_outputs: [batch_size, src_len, hidden_dim].

        Returns:
            (prediction, hidden, attn_weights) with shapes
            [batch_size, output_dim], [1, batch_size, hidden_dim] and
            [batch_size, 1, src_len].
        """
        # Give the token a length-1 sequence axis, then embed: [batch, 1, embedding_dim]
        step_embedding = self.dropout(self.embedding(input.unsqueeze(1)))

        # Attention weights as a [batch, 1, src_len] row vector per example
        attn_weights = self.attention(hidden, encoder_outputs).unsqueeze(1)

        # Weighted sum of encoder outputs: [batch, 1, hidden_dim]
        context = torch.bmm(attn_weights, encoder_outputs)

        # The GRU sees the embedded token together with its context
        output, hidden = self.rnn(torch.cat((step_embedding, context), dim=2), hidden)

        # Predict from GRU output, context and embedding, all flattened to [batch, features]
        features = torch.cat(
            (output.squeeze(1), context.squeeze(1), step_embedding.squeeze(1)),
            dim=1,
        )
        prediction = self.fc_out(features)

        return prediction, hidden, attn_weights

# Seq2Seq with Attention
class Seq2SeqAttention(nn.Module):
    """Encoder/decoder wrapper providing teacher-forced training and greedy decoding."""

    def __init__(self, encoder, decoder, device):
        """Store the sub-modules and the device on which result tensors are allocated.

        Args:
            encoder: callable ``src -> (encoder_outputs, hidden)``.
            decoder: callable ``(input, hidden, encoder_outputs) ->
                (prediction, hidden, attn_weights)``; must expose ``output_dim``.
            device: torch device for the output buffers.
        """
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device

    def forward(self, src, trg, teacher_forcing_ratio=0.5):
        """Decode ``trg.shape[1]`` steps, optionally feeding ground-truth tokens.

        Args:
            src: source indices, [batch_size, src_len].
            trg: target indices, [batch_size, trg_len]; column 0 is the <SOS> token.
            teacher_forcing_ratio: per-step probability of feeding ``trg[:, t]``
                instead of the model's own prediction as the next input.

        Returns:
            (outputs, attention_weights) with shapes
            [batch_size, trg_len, vocab] and [batch_size, trg_len, src_len].
            Position 0 of both is left at zero because decoding starts at t=1.
        """
        batch_size, src_len = src.shape[0], src.shape[1]
        trg_len = trg.shape[1]
        trg_vocab_size = self.decoder.output_dim

        # Buffers for the per-step logits and attention rows
        outputs = torch.zeros(batch_size, trg_len, trg_vocab_size, device=self.device)
        attention_weights = torch.zeros(batch_size, trg_len, src_len, device=self.device)

        encoder_outputs, hidden = self.encoder(src)

        # First input to the decoder is the <SOS> token
        decoder_input = trg[:, 0]

        for t in range(1, trg_len):
            decoder_output, hidden, attn_weights = self.decoder(decoder_input, hidden, encoder_outputs)

            outputs[:, t, :] = decoder_output
            attention_weights[:, t, :] = attn_weights.squeeze(1)

            # One draw per time step: the whole batch is either forced or not
            teacher_force = random.random() < teacher_forcing_ratio
            top1 = decoder_output.argmax(1)
            decoder_input = trg[:, t] if teacher_force else top1

        return outputs, attention_weights

    def predict(self, src, max_length, sos_idx, eos_idx):
        """Greedily decode up to ``max_length - 1`` tokens per source sequence.

        Switches the model to evaluation mode and runs without gradient
        tracking. Decoding stops as soon as every sequence in the batch has
        produced ``eos_idx`` at some step (not necessarily the same step).

        Args:
            src: source indices, [batch_size, src_len].
            max_length: width of the returned buffers.
            sos_idx: index of the start token fed at the first step.
            eos_idx: index of the end token.

        Returns:
            (predictions, attention_maps) with shapes [batch_size, max_length]
            (long) and [batch_size, max_length, src_len]. Column 0 and any
            column after the stopping step remain zero. Tokens a sequence
            emits after its own <EOS> are still recorded; consumers should
            stop reading at the first <EOS>.
        """
        self.eval()

        batch_size = src.shape[0]

        predictions = torch.zeros(batch_size, max_length, dtype=torch.long, device=self.device)
        attention_maps = torch.zeros(batch_size, max_length, src.shape[1], device=self.device)

        # Remembers which sequences have already emitted <EOS> at any earlier step
        finished = torch.zeros(batch_size, dtype=torch.bool, device=self.device)

        with torch.no_grad():
            encoder_outputs, hidden = self.encoder(src)

            # First decoder input is the <SOS> token for every sequence
            decoder_input = torch.full((batch_size,), sos_idx, dtype=torch.long, device=self.device)

            for t in range(1, max_length):
                decoder_output, hidden, attn_weights = self.decoder(decoder_input, hidden, encoder_outputs)

                # Greedy choice becomes both the recorded prediction and the next input
                top1 = decoder_output.argmax(1)
                attention_maps[:, t, :] = attn_weights.squeeze(1)
                predictions[:, t] = top1
                decoder_input = top1

                # Accumulate completion instead of requiring a simultaneous <EOS>
                finished |= top1 == eos_idx
                if bool(finished.all()):
                    break

        return predictions, attention_maps

# Training and evaluation functions
def train(model, dataloader, optimizer, criterion, clip, device, teacher_forcing_ratio):
    """Run one optimisation pass over ``dataloader`` and return the mean batch loss.

    Args:
        model: callable ``(src, trg, teacher_forcing_ratio) -> (logits, _)``.
        dataloader: iterable of dicts with 'eng_indices' and 'hindi_indices' tensors.
        optimizer: optimizer over the model's parameters.
        criterion: loss taking ``([N, vocab] logits, [N] targets)``.
        clip: maximum gradient norm.
        device: device the batches are moved to.
        teacher_forcing_ratio: forwarded to the model.
    """
    model.train()
    batch_losses = []

    for batch in dataloader:
        source = batch['eng_indices'].to(device)
        target = batch['hindi_indices'].to(device)

        optimizer.zero_grad()
        logits, _ = model(source, target, teacher_forcing_ratio)

        # Position 0 is the <SOS> slot, which the model never predicts; drop it
        # and flatten so every remaining token is one classification example
        vocab_size = logits.shape[-1]
        flat_logits = logits[:, 1:].reshape(-1, vocab_size)
        flat_target = target[:, 1:].reshape(-1)

        loss = criterion(flat_logits, flat_target)
        loss.backward()

        # Clip before stepping to keep the recurrent gradients bounded
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()

        batch_losses.append(loss.item())

    return sum(batch_losses) / len(dataloader)

def evaluate(model, dataloader, criterion, device):
    """Return the mean batch loss over ``dataloader`` with teacher forcing disabled.

    The model is put in evaluation mode and called without gradient tracking.
    Batches are dicts with 'eng_indices' and 'hindi_indices' tensors.
    """
    model.eval()
    batch_losses = []

    with torch.no_grad():
        for batch in dataloader:
            source = batch['eng_indices'].to(device)
            target = batch['hindi_indices'].to(device)

            # Ratio 0: the decoder always consumes its own previous prediction
            logits, _ = model(source, target, 0)

            # Skip the <SOS> position and flatten to per-token examples
            flat_logits = logits[:, 1:].reshape(-1, logits.shape[-1])
            flat_target = target[:, 1:].reshape(-1)

            batch_losses.append(criterion(flat_logits, flat_target).item())

    return sum(batch_losses) / len(dataloader)

def calculate_accuracy(model, dataloader, hindi_vocab, device, max_length):
    """Compute exact-match word accuracy of greedy decoding over ``dataloader``.

    Args:
        model: object exposing ``eval()`` and
            ``predict(src, max_length, sos_idx, eos_idx) -> (pred, _)``.
        dataloader: iterable of dicts with 'eng_indices' and 'hindi_indices' tensors.
        hindi_vocab: char -> index mapping containing <SOS>, <EOS> and <PAD>.
        device: device the batches are moved to.
        max_length: maximum decode length forwarded to ``model.predict``.

    Returns:
        (accuracy, predictions): accuracy is the fraction of words whose decoded
        string equals the target string (0.0 when the dataloader yields no
        samples); predictions is a list of ``{'predicted', 'target'}`` dicts.
    """
    model.eval()

    sos_idx = hindi_vocab['<SOS>']
    eos_idx = hindi_vocab['<EOS>']
    pad_idx = hindi_vocab['<PAD>']

    # Reverse vocabulary for turning indices back into characters
    rev_hindi_vocab = {idx: char for char, idx in hindi_vocab.items()}

    def decode(indices, skip_pad):
        """Join the characters for ``indices`` up to (excluding) the first <EOS>."""
        chars = []
        for index in indices:
            if index == eos_idx:
                break
            if skip_pad and index == pad_idx:
                continue
            chars.append(rev_hindi_vocab[index])
        return ''.join(chars)

    correct = 0
    predictions = []

    with torch.no_grad():
        for batch in dataloader:
            src = batch['eng_indices'].to(device)
            trg = batch['hindi_indices'].to(device)

            pred, _ = model.predict(src, max_length, sos_idx, eos_idx)

            # Drop the <SOS> column and convert each batch to Python ints once
            # rather than calling .item() per token
            pred_rows = pred[:, 1:].tolist()
            trg_rows = trg[:, 1:].tolist()

            for pred_row, trg_row in zip(pred_rows, trg_rows):
                predicted = decode(pred_row, skip_pad=False)
                target = decode(trg_row, skip_pad=True)

                predictions.append({'predicted': predicted, 'target': target})
                if predicted == target:
                    correct += 1

    total = len(predictions)
    # An empty dataloader has no defined accuracy; report 0.0 instead of dividing by zero
    accuracy = correct / total if total else 0.0
    return accuracy, predictions

def save_predictions(predictions, filename):
    """Write the prediction records to ``filename`` as CSV (no index column) and report it."""
    pd.DataFrame(predictions).to_csv(filename, index=False)
    print(f"Predictions saved to {filename}")

def plot_attention_heatmap(attention, src_text, trg_text, ax=None):
    """Plot attention heatmap for a single example.

    Args:
        attention: 2-D array indexed [target_position, source_position].
        src_text: source tokens, one per column; '<PAD>' entries are trimmed.
        trg_text: target tokens, one per row; '<PAD>' entries are trimmed.
        ax: matplotlib axes to draw on; a new 10x8 figure is created if omitted.

    Returns:
        The axes containing the heatmap.
    """
    if ax is None:
        _, ax = plt.subplots(figsize=(10, 8))

    # Trim to actual sequence length (remove padding)
    src_len = len([c for c in src_text if c != '<PAD>'])
    trg_len = len([c for c in trg_text if c != '<PAD>'])

    attn_display = attention[:trg_len, :src_len]

    # Size the labels from the trimmed matrix and hand them to seaborn, so the
    # number of labels always equals the number of ticks it draws. Relabelling
    # afterwards can mismatch when seaborn thins its automatic ticks.
    n_rows, n_cols = attn_display.shape
    sns.heatmap(
        attn_display,
        cmap='viridis',
        ax=ax,
        xticklabels=list(src_text[:n_cols]),
        yticklabels=list(trg_text[:n_rows]),
    )

    # Source labels vertical, target labels horizontal
    ax.tick_params(axis='x', labelrotation=90)
    ax.tick_params(axis='y', labelrotation=0)

    ax.set_xlabel('Source')
    ax.set_ylabel('Target')

    return ax

def plot_attention_grid(model, dataset, test_indices, device, eng_vocab, hindi_vocab, max_length=20, num_examples=9):
    """Create a grid of attention heatmaps and save it to 'attention_heatmaps.png'.

    Args:
        model: object exposing ``eval()`` and
            ``predict(src, max_length, sos_idx, eos_idx) -> (preds, attention)``.
        dataset: indexable dataset whose items contain an 'eng_indices' tensor.
        test_indices: dataset indices to visualise; the first ``num_examples`` are used.
        device: device the source tensor is moved to.
        eng_vocab / hindi_vocab: char -> index mappings (with <SOS>/<EOS>/<PAD>).
        max_length: maximum decode length forwarded to ``model.predict``.
        num_examples: maximum number of heatmaps to draw.

    Returns:
        The matplotlib figure (already saved and closed).
    """
    model.eval()

    # Create reverse vocabularies
    rev_eng_vocab = {idx: char for char, idx in eng_vocab.items()}
    rev_hindi_vocab = {idx: char for char, idx in hindi_vocab.items()}
    sos_idx = hindi_vocab['<SOS>']
    eos_idx = hindi_vocab['<EOS>']
    special_tokens = ("<SOS>", "<EOS>", "<PAD>")

    chosen = list(test_indices[:num_examples])

    # Keep the 3x3 layout for up to nine examples and grow the square grid
    # beyond that, so a larger num_examples cannot index past the last axes
    side = 3
    while side * side < len(chosen):
        side += 1

    fig, axes = plt.subplots(side, side, figsize=(20, 20))
    axes = axes.flatten()

    with torch.no_grad():
        for ax, sample_idx in zip(axes, chosen):
            sample = dataset[sample_idx]
            src = sample['eng_indices'].unsqueeze(0).to(device)

            # Get predictions and attention weights
            preds, attention_weights = model.predict(src, max_length, sos_idx, eos_idx)

            # Convert indices to characters
            src_chars = [rev_eng_vocab[token.item()] for token in sample['eng_indices']]

            # Row 0 of the attention map belongs to the <SOS> slot; later rows
            # follow the predicted tokens up to and including the first <EOS>
            pred_chars = ['<SOS>']
            for token in preds[0, 1:].tolist():
                if token == eos_idx:
                    pred_chars.append('<EOS>')
                    break
                pred_chars.append(rev_hindi_vocab[token])

            # Plot attention heatmap
            plot_attention_heatmap(attention_weights[0].cpu().numpy(),
                                   src_chars,
                                   pred_chars,
                                   ax=ax)

            src_word = "".join(c for c in src_chars if c not in special_tokens)
            pred_word = "".join(c for c in pred_chars if c not in special_tokens)
            ax.set_title(f'Input: {src_word} → Output: {pred_word}')

    # Hide the panels that received no example
    for ax in axes[len(chosen):]:
        ax.axis('off')

    # Operate on this figure explicitly rather than on pyplot's current figure
    fig.tight_layout()
    fig.savefig('attention_heatmaps.png')
    plt.close(fig)

    return fig

# Function to load and prepare data
def load_data(data_path):
    """Read tab-separated word pairs from a UTF-8 text file.

    Each line is stripped and split on tabs; only lines with exactly two
    fields are kept. The first field is lower-cased, the second is kept as is.

    Returns:
        (eng_words, hindi_words): two parallel lists.
    """
    eng_words, hindi_words = [], []

    with open(data_path, 'r', encoding='utf-8') as handle:
        for raw_line in handle:
            fields = raw_line.strip().split('\t')
            if len(fields) != 2:
                continue  # not a two-column record
            first, second = fields
            eng_words.append(first.lower())
            hindi_words.append(second)

    return eng_words, hindi_words

def create_vocabularies(eng_words, hindi_words):
    """Build char -> index vocabularies for the source and target word lists.

    Indices 0-3 are reserved for <PAD>, <SOS>, <EOS> and <UNK>; the distinct
    characters of each word list follow in sorted order starting at 4.

    Returns:
        (eng_vocab, hindi_vocab)
    """
    def build(words):
        # Distinct characters across all words
        alphabet = {char for word in words for char in word}
        vocab = {'<PAD>': 0, '<SOS>': 1, '<EOS>': 2, '<UNK>': 3}
        vocab.update(
            (char, index) for index, char in enumerate(sorted(alphabet), start=4)
        )
        return vocab

    return build(eng_words), build(hindi_words)

def main(data_path='transliteration_data.txt'):
    """Train, evaluate and visualise the attention transliteration model.

    Loads tab-separated word pairs from ``data_path``, splits them 70/15/15
    into train/validation/test, trains for NUM_EPOCHS while checkpointing the
    weights with the lowest validation loss, then reports test loss and
    exact-match accuracy using that checkpoint.

    Files written to the working directory:
        best-seq2seq-attention-model.pt, attention_training_curve.png,
        predictions_attention.csv, attention_heatmaps.png

    Args:
        data_path: path to the data file (defaults to the original
            hard-coded location).

    Returns:
        (model, test_dataset, test_indices, eng_vocab, hindi_vocab, device)
    """
    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Load data
    eng_words, hindi_words = load_data(data_path)

    # Create vocabularies
    eng_vocab, hindi_vocab = create_vocabularies(eng_words, hindi_words)

    # Split data: 70% train, then the remaining 30% halved into val/test
    eng_train, eng_temp, hindi_train, hindi_temp = train_test_split(
        eng_words, hindi_words, test_size=0.3, random_state=42
    )
    eng_val, eng_test, hindi_val, hindi_test = train_test_split(
        eng_temp, hindi_temp, test_size=0.5, random_state=42
    )

    # Create datasets
    train_dataset = TransliterationDataset(eng_train, hindi_train, eng_vocab, hindi_vocab, MAX_LENGTH)
    val_dataset = TransliterationDataset(eng_val, hindi_val, eng_vocab, hindi_vocab, MAX_LENGTH)
    test_dataset = TransliterationDataset(eng_test, hindi_test, eng_vocab, hindi_vocab, MAX_LENGTH)

    # Create dataloaders (only the training data is shuffled)
    train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE)
    test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE)

    # Model dimensions
    input_dim = len(eng_vocab)
    output_dim = len(hindi_vocab)

    # Initialize models
    encoder = Encoder(input_dim, EMBEDDING_DIM, HIDDEN_DIM, DROPOUT)
    attention = Attention(HIDDEN_DIM)
    decoder = AttentionDecoder(output_dim, EMBEDDING_DIM, HIDDEN_DIM, attention, DROPOUT)
    model = Seq2SeqAttention(encoder, decoder, device).to(device)

    # Print the number of trainable parameters
    print(f"Model parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")

    # Initialize loss function and optimizer; padded target positions do not contribute to the loss
    criterion = nn.CrossEntropyLoss(ignore_index=hindi_vocab['<PAD>'])
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    # Training loop
    checkpoint_path = 'best-seq2seq-attention-model.pt'
    best_valid_loss = float('inf')
    train_losses = []
    val_losses = []

    for epoch in range(NUM_EPOCHS):
        start_time = time.time()

        train_loss = train(model, train_dataloader, optimizer, criterion, 1.0, device, TEACHER_FORCING_RATIO)
        valid_loss = evaluate(model, val_dataloader, criterion, device)

        train_losses.append(train_loss)
        val_losses.append(valid_loss)

        elapsed = time.time() - start_time
        epoch_mins = int(elapsed / 60)
        epoch_secs = int(elapsed % 60)

        # Save best model
        if valid_loss < best_valid_loss:
            best_valid_loss = valid_loss
            torch.save(model.state_dict(), checkpoint_path)

        print(f'Epoch: {epoch+1:02} | Time: {epoch_mins}m {epoch_secs}s')
        print(f'\tTrain Loss: {train_loss:.3f}')
        print(f'\tValid Loss: {valid_loss:.3f}')

    # Plot training curves
    plt.figure(figsize=(10, 6))
    plt.plot(train_losses, label='Train Loss')
    plt.plot(val_losses, label='Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.title('Training and Validation Loss')
    plt.savefig('attention_training_curve.png')
    plt.close()

    # Load best model; map_location keeps the load working when the checkpoint
    # was written on a different device than the current one
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))

    # Evaluate on test set
    test_loss = evaluate(model, test_dataloader, criterion, device)
    print(f'Test Loss: {test_loss:.3f}')

    # Calculate accuracy on test set
    test_accuracy, test_predictions = calculate_accuracy(model, test_dataloader, hindi_vocab, device, MAX_LENGTH)
    print(f'Test Accuracy: {test_accuracy:.3f}')

    # Save predictions
    save_predictions(test_predictions, 'predictions_attention.csv')

    # Create attention heatmaps for sample test examples; cap the sample size
    # so a test set with fewer than nine words does not make random.sample raise
    test_indices = random.sample(range(len(test_dataset)), min(9, len(test_dataset)))
    plot_attention_grid(model, test_dataset, test_indices, device, eng_vocab, hindi_vocab)

    return model, test_dataset, test_indices, eng_vocab, hindi_vocab, device

# Run the full pipeline only when this file is executed directly, not when it is imported
if __name__ == "__main__":
    main()


Using device: cpu


FileNotFoundError: [Errno 2] No such file or directory: 'transliteration_data.txt'