# Import Statements

In [None]:
import torch
import torch
from torch.utils.data import Dataset, DataLoader

import numpy as np
import math
import pandas as pd
import torch.nn as nn


# clear cache from CUDA
# Release cached-but-unused GPU memory held by PyTorch's CUDA allocator, e.g.
# left over from an earlier run in the same kernel.
# NOTE(review): in a fresh kernel nothing has been allocated yet at this point,
# so this is presumably a no-op here — confirm it is still wanted.
torch.cuda.empty_cache()

# Set device 

In [None]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

# Read data 

In [None]:
# Load the review table and peek at its first rows.
df = pd.read_csv(filepath_or_buffer="Reviews.csv")
df.head(n=5)

In [None]:
# (number of reviews, number of columns)
(len(df.index), len(df.columns))

# Distribution of scores 

In [None]:
# Number of reviews per star rating, most frequent rating first.
df.Score.value_counts()

# Weights for computing classification loss

In [None]:
# Fraction of reviews per score. Sorted by score (1..5) rather than by
# frequency, so position i lines up with class label i (labels are Score - 1).
weights = (df["Score"].value_counts() / df.shape[0]).sort_index()
weights

In [None]:
# Per-class loss weights, indexed by class label (Score - 1).
# The classes are heavily imbalanced, so weight each class by its *inverse*
# frequency: rare scores get larger weights, the dominant score a smaller one.
# (Passing the raw frequencies would give the majority class the most weight.)
# Missing scores are clipped to a count of 1 to avoid an infinite weight.
_score_counts = (
    df["Score"].value_counts().reindex(range(1, 6), fill_value=0).clip(lower=1)
)
_inverse_frequency = df.shape[0] / _score_counts
# Rescale to mean 1. With NLLLoss's default mean reduction only the relative
# sizes matter, but this keeps the numbers readable.
_inverse_frequency = _inverse_frequency / _inverse_frequency.mean()
# float32 to match the model's log-probabilities; NLLLoss rejects a float64 weight.
weights = torch.tensor(_inverse_frequency.to_numpy(), dtype=torch.float32).to(device)
weights

# Split data into train and test

In [None]:
def split_df(df, split_ratio=0.8, random_state=200):
    """Randomly split ``df`` into a (train, test) pair of DataFrames.

    Args:
        df: DataFrame to split. Its index labels should be unique: the test
            part is every row whose label is not among the sampled train rows,
            so duplicated labels would be dropped from the test part too.
        split_ratio: fraction of rows sampled into the training part.
        random_state: seed forwarded to ``DataFrame.sample``. The default (200)
            keeps the split reproducible and identical to earlier runs.

    Returns:
        ``(train, test)`` tuple of DataFrames.
    """
    train = df.sample(frac=split_ratio, random_state=random_state)
    test = df.drop(train.index)
    print("Number of Training Samples: ", len(train))
    print("Number of Validation Samples: ", len(test))
    return train, test
# 80/20 train/validation split of the full review table.
train_df, test_df = split_df(df, split_ratio=0.8)

# Training Parameters

In [None]:
# Loop settings plus where the best checkpoint gets written.
training_parameters = dict(
    batch_size=8,
    epochs=4,
    output_folder="./Amazon_review_classider_models/",
    output_file="model.dat",
)

# Create Dataset and dataloader

In [None]:
class AmazonDataset(Dataset):
    """Map-style dataset yielding ``(review_text, label)`` pairs.

    Args:
        df: DataFrame with at least a "Text" and a "Score" column. Rows are
            addressed by position, so the index labels do not matter.

    The label is ``Score - 1``, so classes run 0..4 instead of 1..5.
    """

    def __init__(self, df):
        self.df = df
        # Extract the two needed columns once. ``df.iloc[index]`` builds a whole
        # mixed-dtype row Series on every access (the original did it twice
        # per sample), which is wasted work in the DataLoader hot path.
        self._texts = df["Text"].to_numpy()
        # we subtract 1 from the label so that classes are 0 to 4 instead of 1 to 5
        self._labels = df["Score"].to_numpy() - 1

    def __getitem__(self, index):
        return self._texts[index], self._labels[index]

    def __len__(self):
        return self.df.shape[0]

train_dataset = AmazonDataset(train_df)
test_dataset = AmazonDataset(test_df)
# Only the training data benefits from shuffling. The evaluation loader keeps
# a fixed order so every validation pass sees identical batches, which makes
# the epoch-to-epoch metrics comparable.
train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=training_parameters["batch_size"],
    shuffle=True,
    num_workers=2,
)
test_dataloader = DataLoader(
    dataset=test_dataset,
    batch_size=training_parameters["batch_size"],
    shuffle=False,
    num_workers=2,
)

# Classification Model 

In [None]:
from transformers import BertModel, BertTokenizer
class AmazonClassifier(nn.Module):
    """Pretrained BERT encoder + dropout + linear head for score classification.

    ``config`` keys read:
        num_labels (int): number of output classes.
        hidden_dropout_prob (float): dropout probability on the pooled output.
        hidden_size (int, optional): input width of the linear head. Defaults
            to the loaded BERT model's own ``config.hidden_size``.
        pretrained_model_name (str, optional): checkpoint passed to
            ``BertModel.from_pretrained``. Defaults to 'bert-base-uncased'.
    """

    def __init__(self, config):
        super(AmazonClassifier, self).__init__()
        num_labels = config["num_labels"]
        self.bert = BertModel.from_pretrained(
            config.get("pretrained_model_name", 'bert-base-uncased')
        )
        # Take the width from the loaded checkpoint so the head always matches
        # the encoder; an explicit config value still takes precedence.
        hidden_size = config.get("hidden_size", self.bert.config.hidden_size)
        self.dropout = nn.Dropout(config["hidden_dropout_prob"])
        self.classifier = nn.Linear(hidden_size, num_labels)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        labels=None,
    ):
        """Return unnormalized class scores, ``num_labels`` per input sequence.

        ``labels`` is accepted only so a whole batch dict can be splatted in
        with ``model(**batch)``. It is unused; the caller computes the loss.
        """
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )

        # Element 1 of BertModel's output is the pooler output
        # (the transformed [CLS] representation).
        pooled_output = outputs[1]

        pooled_output = self.dropout(pooled_output)
        # The result already lives on the same device as the model's
        # parameters, so no global-device ``.to(...)`` is needed.
        return self.classifier(pooled_output)
        

# Tokenize raw input 

In [None]:
def get_input_for_model(inputs, labels, text_tokenizer=None, max_length=None):
    """Tokenize a batch of raw review strings into BERT model inputs.

    Args:
        inputs: iterable of review strings (the text half of a DataLoader batch).
        labels: the batch's labels; returned unchanged under "labels".
        text_tokenizer: object exposing ``encode_plus``. Defaults to the
            notebook-level ``tokenizer``.
        max_length: length every sequence is padded/truncated to. Defaults to
            ``config["max_length"]``.

    Returns:
        dict with "input_ids" (one padded id row per review), "attention_mask"
        and "token_type_ids" when the tokenizer produced them for every review,
        and "labels".
    """
    if text_tokenizer is None:
        text_tokenizer = tokenizer
    if max_length is None:
        max_length = config["max_length"]

    input_ids_list = []
    attention_masks_list = []
    token_type_ids_list = []
    for input_review in inputs:
        encoded_input = text_tokenizer.encode_plus(
            input_review,
            add_special_tokens=True,
            max_length=max_length,
            # Replacement for the deprecated ``pad_to_max_length=True``, with
            # truncation requested explicitly instead of via a warning-emitting
            # implicit default.
            padding="max_length",
            truncation=True,
            return_overflowing_tokens=True,
        )
        if encoded_input.get("num_truncated_tokens", 0) > 0:
            print("Attention! you are cropping tokens")

        input_ids_list.append(encoded_input["input_ids"])
        attention_masks_list.append(encoded_input.get("attention_mask"))
        token_type_ids_list.append(encoded_input.get("token_type_ids"))

    model_inputs = {"input_ids": torch.tensor(input_ids_list)}
    # torch.tensor cannot build a tensor from a list containing None, so the
    # optional fields are emitted only when every review has them. The model
    # treats missing keys as None.
    if all(mask is not None for mask in attention_masks_list):
        model_inputs["attention_mask"] = torch.tensor(attention_masks_list)
    if all(type_ids is not None for type_ids in token_type_ids_list):
        model_inputs["token_type_ids"] = torch.tensor(token_type_ids_list)
    model_inputs["labels"] = labels
    return model_inputs

# Compute accuracy for the model

In [None]:
def compute_accuracy(logits, labels):
    """Return the share of rows whose top-scoring class equals the label.

    Args:
        logits: per-class scores, one row per sample. Any monotone transform
            of them (e.g. log-softmax) gives the same result.
        labels: integer class index per sample.

    Returns:
        0-d float tensor: the mean of the per-sample hit indicators.
    """
    hits = logits.argmax(dim=1) == labels
    return hits.float().mean()

# Training and evaluation after every epoch 

In [None]:
# Config variables: read by AmazonClassifier (num_labels, hidden_dropout_prob,
# hidden_size) and by get_input_for_model (max_length).
config = dict(
    num_labels=5,
    hidden_dropout_prob=0.15,
    hidden_size=768,
    max_length=400,
)

# Classifier on the training device (nn.Module.to returns the module itself).
model = AmazonClassifier(config).to(device)

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# NLLLoss expects log-probabilities, which is why the logits go through
# LogSoftmax before the loss; together they amount to a weighted cross-entropy.
softmax_func = nn.LogSoftmax(dim=1)
loss_func = nn.NLLLoss(weight=weights)
optimizer = torch.optim.Adam(params=model.parameters(), lr=2e-5)

best_accuracy = 0.0

import os


def _batch_to_device(batch):
    """Return a copy of a model-input dict with every tensor moved to ``device``."""
    return {
        key: value.to(device) if torch.is_tensor(value) else value
        for key, value in batch.items()
    }


def _evaluate(dataloader):
    """Evaluate ``model`` on ``dataloader`` without tracking gradients.

    Returns:
        ``(mean_loss, accuracy)`` as Python floats. mean_loss is the average of
        the per-batch losses. accuracy is the fraction of correctly classified
        samples, weighted by sample so a short final batch does not skew it.
        Both are NaN when the dataloader yields no batches.
    """
    model.eval()
    loss_sum = 0.0
    batch_count = 0
    correct = 0.0
    sample_count = 0
    with torch.no_grad():
        for reviews, scores in dataloader:
            batch = _batch_to_device(get_input_for_model(reviews, scores))
            labels = batch["labels"]
            log_probs = softmax_func(model(**batch))
            loss_sum += loss_func(log_probs, labels).item()
            batch_count += 1
            batch_size = labels.size(0)
            correct += compute_accuracy(log_probs, labels).item() * batch_size
            sample_count += batch_size
    if batch_count == 0:
        return float("nan"), float("nan")
    return loss_sum / batch_count, correct / sample_count


for epoch in range(training_parameters["epochs"]):
    model.train()

    for i, (inputs, labels) in enumerate(train_dataloader):
        if i % 1000 == 0:
            print("Training " + str(i))
        # inputs are the raw review strings of the batch, labels their scores
        inputs = _batch_to_device(get_input_for_model(inputs, labels))
        labels = inputs["labels"]
        logits = model(**inputs)

        loss = loss_func(softmax_func(logits), labels)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    # test after each epoch
    mean_loss, mean_accuracy = _evaluate(test_dataloader)
    print(f"Epoch {epoch}: validation loss {mean_loss:.4f}, accuracy {mean_accuracy:.4f}")

    # NaN never compares greater, so an empty validation set never saves.
    if mean_accuracy > best_accuracy:
        best_accuracy = mean_accuracy
        output_folder = training_parameters["output_folder"]
        os.makedirs(output_folder, exist_ok=True)
        torch.save(
            model.state_dict(),
            os.path.join(
                output_folder,
                training_parameters["output_file"] + "_valTested_" + str(best_accuracy),
            ),
        )
        
        
    