# Settings

In [1]:
import sys
import os

# Work from the repository root so `scripts.*` imports and `data/...` paths resolve.
# Only step up when we are not already there, so re-running this cell is harmless
# (a bare chdir('..') would climb one more directory on every execution).
if not os.path.isdir('scripts'):
    os.chdir('..')
os.getcwd()

'C:\\Users\\ruben\\Documents\\GitHub\\ANLP-Project'

# Imports

In [2]:
! pip install editdistance
! pip install num2words



In [3]:
from scripts.model import device, CharBiLSTM
from scripts.data import create_data_loader, load_data
from scripts.preprocessing import get_typoglycemia_modified_data, sentence_tokennizer, tokenize_dataframe, get_max_length
from scripts.baseline import get_base_line_score

from torch import nn, optim
import torch

from sklearn.model_selection import train_test_split

import editdistance
#from tqdm import tqdm
# from tqdm.notebook import tqdm
from tqdm.autonotebook import trange, tqdm
import random 

from sklearn.metrics import f1_score, accuracy_score
import numpy as np



import pandas as pd
import matplotlib.pyplot as plt
# Seed every RNG the notebook relies on. NumPy's global RNG matters too:
# sklearn's train_test_split draws from it whenever no random_state is passed.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED);

device


KeyboardInterrupt



# Data

In [None]:
# Preprocessed sscorpus dataset, relative to the repository root.
corpus_path = "data/processed/sscorpus.csv"
df = load_data(file_path=corpus_path)

In [None]:
# Quick look at the first rows of the loaded corpus.
df.head(n=3)

## Splitting data into train, val, test

In [None]:
# 64% train / 16% validation / 20% test (0.2 of the data, then 0.2 of the remaining 0.8).
# A fixed random_state makes the split reproducible instead of depending on
# NumPy's global RNG state at the moment this cell runs.
dev, test = train_test_split(df, test_size=0.2, random_state=42)
train, validation = train_test_split(dev, test_size=0.2, random_state=42)

## Baseline dev

In [None]:
# Baseline score for the "Easy" variant.
get_base_line_score(type='Easy', train=train, test=test)

In [None]:
# Baseline score for the "Hard" variant.
get_base_line_score(type='Hard', train=train, test=test)

## Getting dataloaders

In [None]:
# Column prefix selecting the sentence variant to model (e.g. `Hard` and `Hard_Typo`).
complexity_level = 'Hard'

In [None]:
# Tokenize every split at the same complexity level.
# NOTE: this rebinds the split names, so re-running this cell feeds already-tokenized
# frames back in — re-run from the split cell instead.
train, validation, test = (
    tokenize_dataframe(split_df, complexity_level)
    for split_df in (train, validation, test)
)

In [None]:
# Character vocabulary of the typo'd training sentences for the selected complexity level.
# The column is derived from `complexity_level` (previously hard-coded to "Hard_Typo",
# which silently used the wrong column for any other level).
typo_column = complexity_level + "_Typo"
combined_text = ' '.join(train[typo_column])
unique_characters = set(combined_text)
vocabulary_size = len(unique_characters)
vocabulary_size

In [None]:
# Longest sequence in each split, measured on the `complexity_level` columns.
max_length_train, max_length_validation, max_length_test = (
    get_max_length(split_df, complexity_level)
    for split_df in (train, validation, test)
)

In [None]:
# Longest sequence observed across the three splits, kept for reference.
observed_max_length = max(max_length_train, max_length_validation, max_length_test)

# Fixed sequence length passed to the loaders and the model. It replaces the observed
# maximum (the original cell computed that value and then immediately overwrote it).
# NOTE(review): presumably chosen to bound memory; how create_data_loader treats longer
# sequences (truncation?) should be confirmed in scripts.data.
MAX_LENGTH = 400
max_length = MAX_LENGTH
max_length, observed_max_length


In [None]:
# Character lengths of clean and typo'd sentences over the whole dataset.
all_sentences = pd.concat([df[complexity_level], df[complexity_level + "_Typo"]])
lengths = all_sentences.str.len()

# Summarise instead of dumping the full Series; also report how many sentences
# are longer than the fixed `max_length` cap.
length_summary = lengths.describe()
length_summary["share_over_max_length"] = (lengths > max_length).mean()
length_summary

In [None]:
# Mini-batch size shared by the loaders and the model (2**8).
batch_size = 256

In [None]:
# One loader per split, all built with identical settings.
# TODO: each loader line previously carried a bare "TODO" marker — clarify what is still pending.
loader_kwargs = dict(complexity=complexity_level, max_length=max_length, batch_size=batch_size)
train_loader = create_data_loader(train, **loader_kwargs)
validation_loader = create_data_loader(validation, **loader_kwargs)
test_loader = create_data_loader(test, **loader_kwargs)

In [None]:
# Sanity-check one training batch: tensor shapes and the label distribution.
sample = next(iter(train_loader), None)
if sample is not None:
    X, y = sample
    print(X.shape, y.shape)
    print(np.unique(y.cpu(), return_counts=True))

# Model (not finished — just a template)

In [None]:
# Each character reaches the LSTM as a single scalar feature: batches are reshaped to
# (-1, max_length, 1) before the forward pass, so input_size is 1 (not vocabulary_size).
input_size = 1
hidden_size = 128
# NOTE(review): the +1 presumably reserves one extra output class — confirm against the
# label encoding in scripts.data.
output_size = vocabulary_size + 1
num_layers = 1

model = CharBiLSTM(input_size, hidden_size, output_size, num_layers, max_length, batch_size).to(device)
# Targets equal to -1 are excluded from the loss (presumably the padding label — confirm).
loss_function = nn.CrossEntropyLoss(ignore_index=-1).to(device)
optimizer = optim.Adam(model.parameters(), lr=0.01, amsgrad=True)

epochs = 100


# Training

In [None]:
# Per-epoch mean batch losses, keyed by 0-based epoch index.
val_loss_dc = {}
train_loss_dc = {}

for epoch in tqdm(range(epochs), position=0):
    # --- training pass ---
    model.train()
    epoch_loss = 0.0
    for typo_batch, sentence_batch in tqdm(train_loader, position=1, leave=False):
        # Flatten targets to 1-D and give the input a trailing feature axis of size 1.
        # .to(device) is a no-op when the loader already yields tensors on `device`.
        sentence_batch = sentence_batch.view(-1).to(device)
        typo_batch = typo_batch.reshape(-1, max_length, 1).to(device)

        # Clear gradients before the forward pass so a previously interrupted run
        # (e.g. a KeyboardInterrupt after backward()) cannot leak stale gradients into this step.
        optimizer.zero_grad()
        # NOTE(review): the training pass also calls the model with train=False — confirm this
        # is intended; the meaning of CharBiLSTM's `train` flag lives in scripts.model.
        # Calling model(...) rather than model.forward(...) keeps nn.Module hooks active.
        y = model(typo_batch, train=False)
        loss = loss_function(y, sentence_batch)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()

    epoch_loss_avg = epoch_loss / len(train_loader)
    train_loss_dc[epoch] = epoch_loss_avg

    # --- validation pass (no gradient tracking) ---
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for typo_val_batch, sentence_val_batch in tqdm(validation_loader, position=2, leave=False):
            sentence_val_batch = sentence_val_batch.view(-1).to(device)
            typo_val_batch = typo_val_batch.reshape(-1, max_length, 1).to(device)

            val_y = model(typo_val_batch, train=False)
            val_loss += loss_function(val_y, sentence_val_batch).item()

    val_loss_avg = val_loss / len(validation_loader)
    val_loss_dc[epoch] = val_loss_avg

    print(f"Epoch {epoch + 1}/{epochs} Train Loss: {epoch_loss_avg:.4f} Val Loss: {val_loss_avg:.4f}")


In [None]:
fig, ax = plt.subplots()
# Epoch keys are 0-based; shift by one so the x-axis matches the "Epoch k/N" log lines.
ax.plot([k + 1 for k in train_loss_dc], list(train_loss_dc.values()), label="Train Loss")
ax.plot([k + 1 for k in val_loss_dc], list(val_loss_dc.values()), label="Validation Loss")
ax.set(title="Training and validation loss", xlabel="Epochs", ylabel="Cross Entropy Loss")
# Without legend() the `label=` arguments above are never rendered.
ax.legend()
plt.show()

# Evaluation

In [None]:
def get_metrics(loader, model, loader_str, *, loss_fn=None, seq_len=None, ignore_index=None):
    """Evaluate ``model`` on ``loader``; print and return loss, accuracy and weighted F1.

    Args:
        loader: Iterable of ``(typo_batch, sentence_batch)`` pairs; ``typo_batch`` is the
            model input and ``sentence_batch`` holds the target labels.
        model: Network to evaluate; it is put into eval mode and run without gradients.
        loader_str: Prefix for the printed report (e.g. ``"train"``).
        loss_fn: Loss used for the reported loss. Defaults to the notebook-level
            ``loss_function``.
        seq_len: Sequence length the inputs are reshaped to. Defaults to the
            notebook-level ``max_length``.
        ignore_index: Target label excluded from accuracy and F1. Defaults to
            ``loss_fn.ignore_index`` when the loss has one, otherwise ``-1``.

    Returns:
        dict with keys ``"loss"`` (mean per-batch loss), ``"accuracy"`` and ``"f1"``
        (support-weighted F1). Accuracy and F1 use non-ignored positions only.

    Raises:
        ValueError: if ``loader`` yields no batches, or every target equals ``ignore_index``.
    """
    if loss_fn is None:
        loss_fn = loss_function
    if seq_len is None:
        seq_len = max_length
    if ignore_index is None:
        ignore_index = getattr(loss_fn, "ignore_index", -1)

    pred_chunks, label_chunks = [], []
    total_loss = 0.0
    n_batches = 0

    model.eval()
    with torch.no_grad():
        for typo_batch, sentence_batch in tqdm(loader, position=3, leave=False):
            # One target label per character position, flattened to 1-D.
            sentence_batch = sentence_batch.view(-1).to(device)
            typo_batch = typo_batch.reshape(-1, seq_len, 1).to(device)

            # With 1-D targets, nn.CrossEntropyLoss requires y to be (N, num_classes),
            # so the class axis is dim 1.
            y = model(typo_batch, train=False)
            total_loss += loss_fn(y, sentence_batch).item()
            n_batches += 1

            # Score only the positions the loss scores. A negative ignore label such as -1
            # can never be produced by argmax, so keeping those positions would turn every
            # one of them into a guaranteed error and deflate accuracy/F1.
            keep = sentence_batch != ignore_index
            pred_chunks.append(torch.argmax(y, dim=1)[keep].cpu().numpy())
            label_chunks.append(sentence_batch[keep].cpu().numpy())

    if n_batches == 0:
        raise ValueError(f"{loader_str} loader yielded no batches")
    loss_avg = total_loss / n_batches
    print(f"{loader_str} Loss: {loss_avg:.4f}")

    # Concatenate per-batch arrays once instead of growing Python lists element by element.
    preds = np.concatenate(pred_chunks)
    labels = np.concatenate(label_chunks)
    if labels.size == 0:
        raise ValueError(f"{loader_str}: every target equals ignore_index ({ignore_index})")
    accuracy = accuracy_score(labels, preds)
    f1 = f1_score(labels, preds, average='weighted')

    print(f"{loader_str} Accuracy: {accuracy:.4f}, F1 Score: {f1:.4f}")
    return {"loss": loss_avg, "accuracy": accuracy, "f1": f1}

In [None]:
# Report metrics on the training split (trailing ';' keeps the cell output to the printed report).
get_metrics(train_loader, model, loader_str="train");

In [None]:
# Report metrics on the validation split (trailing ';' keeps the cell output to the printed report).
get_metrics(validation_loader, model, loader_str="validation");

In [None]:
# Report metrics on the held-out test split (trailing ';' keeps the cell output to the printed report).
get_metrics(test_loader, model, loader_str="test");
