## Import libraries

In [None]:
# Set to True when running on Google Colab: the setup cell then mounts
# Google Drive and adds the project's source folders to sys.path.
COLAB = False

In [None]:
if COLAB:
    import os
    import sys
    from google.colab import drive

    # Mount Google Drive so the project folder becomes reachable
    drive.mount('/content/drive')

    # Make the project's source folders importable
    root_dir = '/content/drive/MyDrive/text_summarization'
    src_dir = os.path.join(root_dir, 'src')
    sys.path.extend([src_dir, os.path.join(src_dir, 'utils')])
else:
    from pathlib import Path

    # Local run: the project root is the parent of the working directory
    root_dir = Path.cwd().parent
root_dir

In [None]:
import pandas as pd
import numpy as np
import os
import re
import numpy as np
import nltk
# Fetch the NLTK stopword corpus needed by stopwords.words(...) below
nltk.download('stopwords')
from nltk.corpus import stopwords
# NOTE: this rebinds `stopwords` from the NLTK corpus reader to a plain set
# of English stopwords; later cells use the set.
stopwords = set(stopwords.words('english'))

In [None]:
from utils.preprocessing import preprocessing_pipeline, get_data_distribution
from utils.processing import processing_pipeline
from train_model import main, tuning
from utils.inference import main
from optuna.visualization import (plot_optimization_history,
                                  plot_param_importances, plot_slice)

In [None]:
# Stage switches: each flag gates one of the pipeline cells further down.
PROCESSING = True      # run processing_pipeline
PREPROCESSING = True   # run preprocessing_pipeline
HYP_TUNING = True      # run the hyperparameter search (tuning)

In [None]:
# Dataset name: used as the sub-folder under raw_data/, data/ and figures/,
# and passed to the training, inference and tuning entry points.
name = "WikiHow"

In [None]:
# Per-dataset locations under the project root
raw_dir, dataset_dir, figures_dir = (
    os.path.join(root_dir, folder, name)
    for folder in ("raw_data", "data", "figures")
)

# Only the output folders are created here
for out_dir in (dataset_dir, figures_dir):
    os.makedirs(out_dir, exist_ok=True)

## Get the data

In [None]:
# Load the raw WikiHow CSV from the raw-data folder
dataset_df = pd.read_csv(os.path.join(raw_dir, "wikihowSep.csv"))

In [None]:
# Base name (without extension) of the preprocessed CSV; passed to the
# pipelines and used to reload f'{csv_name}.csv' from the dataset folder.
csv_name = "wikihow_data"

## Preprocess the data

In [None]:
# Optional preprocessing stage, gated by the PREPROCESSING flag
if PREPROCESSING:
    preprocessing_options = dict(
        subset_size=0.5,
        start_token="SOS ",
        end_token=" EOS",
    )
    preprocessing_pipeline(
        dataset_df,
        stopwords,
        dataset_dir,
        csv_name,
        **preprocessing_options,
    )

In [None]:
# Reload the preprocessed CSV and print a summary of its columns
processed_csv_path = os.path.join(root_dir, "data", name, f'{csv_name}.csv')
dataset_df = pd.read_csv(processed_csv_path)
dataset_df.info()

In [None]:
# Preview the first rows of the preprocessed data
dataset_df.head()

## Get distribution of the data

In [None]:
# Inspect the data distribution; figures_dir is passed in, presumably as the
# output location for the generated figures (confirm in utils.preprocessing).
get_data_distribution(dataset_df, figures_dir, "wikihow")

## Process the data

In [None]:
# Forwarded to processing_pipeline; presumably reuses a previously saved
# tokenizer when True (confirm in utils.processing).
load_tokenizer = False

In [None]:
# Optional processing stage, gated by the PROCESSING flag
if PROCESSING:
    processing_pipeline(dataset_dir, csv_name, load_tokenizer = load_tokenizer)

### Test the processing

In [None]:
# When True, the next cell decodes a few random training samples back to text
test_decoding = True

In [None]:
if test_decoding:
    import torch
    import pickle
    import random
    
    def decode_data(text_ids, index2word, EOS_token):
        """
        Convert a tensor of token ids into a space-separated string of words.

        Args:
            text_ids: torch.Tensor of token ids, of any rank; it is flattened
                before decoding.
            index2word: dict mapping a token id to its word.
            EOS_token: id that marks the end of the sequence.

        Returns:
            The decoded words joined by single spaces. Ids missing from
            ``index2word`` are rendered as 'UNK'. Decoding stops at the first
            ``EOS_token``, which is rendered as 'EOS'.
        """
        # reshape(-1), unlike view(-1), also accepts non-contiguous and 0-d
        # tensors; tolist() converts every id to a Python scalar in one call.
        decoded_words = []
        for idx in text_ids.reshape(-1).tolist():
            if idx == EOS_token:
                decoded_words.append('EOS')
                break
            decoded_words.append(index2word.get(idx, 'UNK'))

        return " ".join(decoded_words)
    
    X_train = torch.load(os.path.join(dataset_dir, "x_train.pt"))
    y_train = torch.load(os.path.join(dataset_dir, "y_train.pt"))

    # batch_size=1 without shuffling: batch i holds exactly training sample i
    train_dataloader = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(X_train, y_train),
        batch_size=1,
        shuffle=False,
    )

    # NOTE: unpickling can execute arbitrary code; only load tokenizer files
    # produced by this project.
    with open(os.path.join(dataset_dir, 'feature_tokenizer.pickle'), 'rb') as handle:
        feature_tokenizer = pickle.load(handle)
    EOS_token = feature_tokenizer.word2index.get("EOS", 2)

    # Never ask for more samples than exist: random.sample raises ValueError
    # when the requested size exceeds the population.
    nb_decoding_test = min(10, len(train_dataloader))
    selected_indices = set(random.sample(range(len(train_dataloader)), nb_decoding_test))

    # NOTE(review): targets are decoded with the feature tokenizer as well;
    # confirm that inputs and targets share one vocabulary.
    count_test = 0
    for i, (input_tensor, target_tensor) in enumerate(train_dataloader):
        if count_test == nb_decoding_test:
            break
        if i not in selected_indices:
            continue
        print('Input: {}'.format(decode_data(input_tensor[0], feature_tokenizer.index2word, EOS_token)))
        print('Target: {}'.format(decode_data(target_tensor[0], feature_tokenizer.index2word, EOS_token)))
        print('-----------------------------------')
        count_test += 1

## Train the model

In [None]:
# Model hyperparameters
hidden_size = 128
max_length = 100

# Optimizer / training-loop hyperparameters
lr = 0.001
weight_decay = 1e-4
n_epochs = 100
batch_size = 128
num_workers = 2
early_stopping_patience = 5

# Run control
print_example_every = 10
load_checkpoint = False

# Bundles handed to the training entry point
optimizer_hyperparams = dict(
    learning_rate=lr,
    weight_decay=weight_decay,
    n_epochs=n_epochs,
    batch_size=batch_size,
    num_workers=num_workers,
    early_stopping_patience=early_stopping_patience,
)
model_hyperparams = dict(hidden_size=hidden_size, max_length=max_length)

In [None]:
# `main` is imported twice in this notebook (from train_model, then from
# utils.inference); the later import wins, so a bare `main(...)` here would
# call the inference entry point with training arguments. Bind the training
# entry point under an unambiguous alias instead.
from train_model import main as train_main

train_main(
    root_dir=root_dir,
    model_hyperparams=model_hyperparams,
    tuning=False,
    optimizer_hyperparams=optimizer_hyperparams,
    print_examples_every=print_example_every,
    load_checkpoint=load_checkpoint,
    name=name,
)

### Making inference with the model (on a CPU)

In [None]:
# Checkpoint file name handed to the inference entry point
checkpoint_name = 'best_checkpoint.tar'

In [None]:
# Bind the inference entry point explicitly rather than relying on which
# `main` import happened to run last.
from utils.inference import main as inference_main

# Example input: "Don't forget to water your plants, they need it to survive."
while True:
    user_text = input("Enter the text to summarize (or type 'exit' to quit): ").strip()
    if user_text.lower() == 'exit':
        break
    if not user_text:
        # Nothing to summarize; prompt again
        continue
    inference_main(root_dir, name, checkpoint_name, hidden_size, max_length, user_text)

### Hyperparameter tuning

In [None]:
# Number of trials passed to tuning()
num_trials = 10

In [None]:
if HYP_TUNING:
    study = tuning(root_dir, num_trials, name)

    # Folder that receives the study's plots
    study_dir = os.path.join(root_dir, 'parameters_tuning', name, 'study_results')
    os.makedirs(study_dir, exist_ok=True)

    # The plot helpers only build figures. Inside an `if` block the notebook
    # does not display bare expressions, so each figure is written to
    # study_dir and shown explicitly.
    study_plots = {
        'optimization_history': plot_optimization_history,
        'param_importances': plot_param_importances,
        'slice': plot_slice,
    }
    for plot_name, plot_fn in study_plots.items():
        fig = plot_fn(study)
        fig.write_html(os.path.join(study_dir, f'{plot_name}.html'))
        fig.show()

### Check training information with TensorBoard

In [None]:
# %load_ext tensorboard
# !tensorboard --logdir='tensorboard_logs'