# CSE-6242 - Team 157 - Group Project

__TODO:__  
1. Remove custom stop words like "CHORUS:" or "INTRO:"  
2. Remove other stop words.

In [None]:
# global assumption panel: tunable settings referenced by the cells below

## data gathering
N_DATA_ROWS = 1_500  # Use -1 to retrieve all rows

## embedding
EMBED_STRATEGY = 'DistilBERT'  # one of ['DistilBERT', 'GloVe']

## modeling - preprocessing
VAL_PCT = 0.15  # hold-out fraction, handed to create_datasets as val_pct
BATCH_SIZE = 32  # bigger means faster training, but more memory use

## modeling - architecture
HIDDEN_SIZE = 256  # hidden_dim of the recurrent model
NUM_LAYERS = 2  # num_layers of the recurrent model
DROPOUT = 0.2  # dropout of the recurrent model

## modeling - training
LEARNING_RATE = 1e-3  # optimizer step size
NUM_EPOCHS = 50  # upper bound on training epochs
PATIENCE = 5  # epochs without validation improvement before early stopping

In [None]:
# packages

## torch
import torch

## project code
from project_code import data_gathering
from embedding import distilbert, glove
from training import preprocessing
from architectures import simple_rnn

In [None]:
# more global assumptions
# prefer the GPU when CUDA is usable, otherwise fall back to the CPU
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')
print(f'Device = {DEVICE}')

## Data Gathering

In [None]:
# load the lyrics data, limited to N_DATA_ROWS rows
# (per the original note, read_lyrics also performs basic cleaning)
lyrics = data_gathering.read_lyrics(n_rows=N_DATA_ROWS)

## Embedding

In [None]:
# # (disabled) generate embeddings according to EMBED_STRATEGY
# # if EMBED_STRATEGY == 'Word2Vec':
# #     embedding_model = word2vec.apply_word2vec(lyrics['tokenized_text'].tolist())
# if EMBED_STRATEGY == 'GloVe':
#     lyrics_embed = glove.create_glove_matrix(data = lyrics, target_col = 'cleaned_lyrics')
# elif EMBED_STRATEGY == 'DistilBERT':
#     lyrics_embed = distilbert.distilbert_embed_all_docs(data = lyrics, target_col = 'lyrics')

In [None]:
# embed the 'lyrics' column through distilbert.embed_all_docs_v2
# NOTE(review): EMBED_STRATEGY is not consulted here - DistilBERT is always used
lyrics_embed = distilbert.embed_all_docs_v2(
    data=lyrics,
    target_col='lyrics',
)

## Modeling

## Preprocessing

In [None]:
# tally how often each value occurs in the tag column (class balance check);
# as the last expression of the cell, the Counter is displayed by the notebook
from collections import Counter
Counter(lyrics.tag)

In [None]:
# split the embeddings and their tag labels into three parts;
# per the original note: data loaders for train / val, a data set for test
lyrics_train, lyrics_val, lyrics_test = preprocessing.create_datasets(
    data_embed=lyrics_embed,
    labels=lyrics['tag'],
    val_pct=VAL_PCT,
    batch_size=BATCH_SIZE,
)

## Training

In [None]:
# build the baseline recurrent classifier (SimpleRNN configured as a GRU)
# and move it to DEVICE; lyrics_embed is unpacked as (rows, embedding width)
n_songs, embed_dim = lyrics_embed.shape
base_model = simple_rnn.SimpleRNN(
    input_dim=embed_dim,
    hidden_dim=HIDDEN_SIZE,
    output_dim=len(lyrics['tag'].unique()),  # one output per distinct tag value
    type='GRU',
    num_layers=NUM_LAYERS,
    dropout=DROPOUT,
).to(DEVICE)
base_model

In [None]:
# train the model
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm

def nn_training(
    model,
    train_loader, val_loader,
    learning_rate:float = 0.001,
    num_epochs:int = 100, patience:int = 5,
    device:str = 'cpu',
    verbose:bool = True, print_every:int = 5
):
    """
    Description
    ----------
    This function takes a specified model and trains it on the train_loader,
    evaluates the epoch round on the val_loader, and allows for early stopping
    if there is no improvement against the val set after `patience` epochs.
    Every time the validation loss reaches a new minimum, the model's
    state_dict is saved to `models/MGC <ModelClassName>.pth` (the `models`
    folder is created if it does not exist yet).

    Inputs
    ----------
    model = An instantiated torch neural network model that we want to train.
        It is NOT moved by this function, so it must already live on `device`.
    train_loader = The DataLoader object containing our training set
    val_loader = The DataLoader object containing our validation set
    learning_rate = The rate at which we want to learn (Adam step size)
    num_epochs = The (max) number of epochs we want to train for
    patience = The number of epochs beyond the current best epoch that we
        keep training before stopping early
    device = The device each batch is moved to before the forward pass
        (anything accepted by `Tensor.to`, e.g. 'cpu' or a torch.device)
    verbose = If true, prints useful intermediates
    print_every = If verbose, this defines how frequently we print out
        model performance results

    Returns
    ----------
    None, but the best model weights and biases will be stored in the models folder
    """
    # local import: only needed to guarantee the checkpoint folder exists
    import os

    # set up training objects
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr = learning_rate)

    # make sure torch.save has somewhere to write to
    checkpoint_path = f'models/MGC {model.__class__.__name__}.pth'
    os.makedirs(os.path.dirname(checkpoint_path), exist_ok = True)

    # instantiate constants
    best_val_loss = float('inf')
    best_epoch = 0 # 1-indexed like the progress lines; 0 = no best epoch yet
    waiting = 0

    # train the model
    for epoch in range(num_epochs):

        # === training loop ===
        model.train()
        train_loss = 0
        for X, y in tqdm(train_loader):

            ## store batch to device
            X, y = X.to(device), y.to(device)

            ## forward pass
            optimizer.zero_grad()
            y_pred = model(X)

            ## backward pass
            loss = criterion(y_pred, y)
            loss.backward()
            optimizer.step()

            ## store loss from batch
            train_loss += loss.item()

        # === validation loop ===
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for X, y in val_loader:

                ## store batch to device
                X, y = X.to(device), y.to(device)

                ## forward pass
                y_pred = model(X)

                ## store loss from batch
                loss = criterion(y_pred, y)
                val_loss += loss.item()

        # epoch wrap up

        ## take average of loss (mean over batches)
        train_loss /= len(train_loader)
        val_loss /= len(val_loader)

        ## check for early stopping
        ## (decide `improved` BEFORE best_val_loss is overwritten, otherwise
        ## the "New Best Model" marker below could never be printed)
        improved = val_loss < best_val_loss
        if improved:
            best_val_loss = val_loss
            best_epoch = epoch + 1
            waiting = 0
            torch.save(model.state_dict(), checkpoint_path)
        else:
            waiting += 1

        ## report progress
        if verbose and (epoch + 1) % print_every == 0:
            best_marker = ' **New Best Model**' if improved else ''
            print(
                f'[{epoch + 1} / {num_epochs}] '
                f'Train Loss = {train_loss:.4f}, '
                f'Val Loss = {val_loss:.4f}'
                f'{best_marker}'
            )

        if waiting >= patience:
            print(f'Early Stopping Triggered. Training Stopped.')
            print(f'\tBest Epoch = {best_epoch}, Best Val Loss = {best_val_loss}')
            break

# run training; DEVICE is passed explicitly so every batch is moved to the
# same device base_model was placed on (nn_training defaults to 'cpu')
nn_training(
    model = base_model,
    train_loader = lyrics_train,
    val_loader = lyrics_val,
    learning_rate = LEARNING_RATE,
    num_epochs = NUM_EPOCHS,
    patience = PATIENCE,
    device = DEVICE,
    verbose = True,
    print_every = 1
)