In [None]:
# Automatically reload edited local modules (e.g. data_utils, imported further down)
# before running code, so changes to helper files are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
import nltk
from nltk.corpus import stopwords
from collections import Counter
import string

# Fetch NLTK resources: the stopword corpus and the 'punkt' tokenizer models.
# NOTE(review): neither resource is referenced in the cells visible here — drop these
# downloads if nothing else in the notebook uses them.
nltk.download('stopwords')
nltk.download('punkt')

In [None]:
import pandas as pd
import json
import torch
from tabulate import tabulate
import numpy as np
import re
import nltk
from nltk.stem import PorterStemmer
from tqdm.notebook import tqdm

import random

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed every RNG the notebook relies on (Python, NumPy, PyTorch) so that randomly
# initialised weights (e.g. the new classification head) and, presumably, the shuffled
# training DataLoader are reproducible under Restart & Run All.
# torch.manual_seed seeds all devices, including CUDA.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

## Process Data

In [None]:
file_path = 'IMDB_movie_details.json'
data = []
# The file is JSON Lines: one JSON object per line. Read it as UTF-8 explicitly (JSON's
# standard encoding) instead of the platform default, and skip blank lines (e.g. a
# trailing empty line), which json.loads would reject.
with open(file_path, 'r', encoding='utf-8') as file:
    for line in file:
        if line.strip():
            data.append(json.loads(line))

df = pd.DataFrame(data)

In [None]:
# Bucket the numeric rating into three classes:
#   2 -> rating >= 8,   0 -> rating <= 6,   1 -> everything else
# (values failing both comparisons, e.g. NaN, therefore land in class 1).
df['rating_class'] = df['rating'].astype(float).pipe(
    lambda r: np.select([r >= 8, r <= 6], [2, 0], default=1)
)

In [None]:
# One filtered frame per rating class, then the overall class distribution.
df_class_0, df_class_1, df_class_2 = (df[df['rating_class'] == cls] for cls in (0, 1, 2))
print(df['rating_class'].value_counts())

In [None]:
from sklearn.model_selection import train_test_split
# 80/10/10 split: hold out 20% first, then halve the hold-out into validation and test.
# Both splits are stratified on rating_class so class proportions match across sets.
train_df, temp_df = train_test_split(
    df,
    stratify=df['rating_class'],
    test_size=0.2,
    random_state=42,
)
val_df, test_df = train_test_split(
    temp_df,
    stratify=temp_df['rating_class'],
    test_size=0.5,
    random_state=42,
)

In [None]:
# Pull the synopsis text and its rating class out of each split as plain Python lists.
train_txt, train_label = train_df['plot_synopsis'].tolist(), train_df['rating_class'].tolist()
val_txt, val_label = val_df['plot_synopsis'].tolist(), val_df['rating_class'].tolist()
test_txt, test_label = test_df['plot_synopsis'].tolist(), test_df['rating_class'].tolist()

## Set Up Tokenizer, Model, and Data Loaders

In [None]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification

model_path = "allenai/longformer-base-4096"
model_name = 'longformer'   # used in the tokenized-data cache filenames
context_len = 4096          # max tokens per synopsis passed to the tokenizer/model

tokenizer = AutoTokenizer.from_pretrained(model_path)
# The checkpoint is a Longformer, so let the Auto* class resolve the matching
# sequence-classification architecture; place it on the configured `device`.
model = AutoModelForSequenceClassification.from_pretrained(
    model_path, num_labels=len(set(train_label))
).to(device)

In [None]:
from data_utils import *

# Tokenize every split with data_utils.create_dataset. The fourth argument is a cache
# path; per the original note, preprocessed results are loaded from it when it already
# exists (that logic lives in data_utils and is not visible here).
# The training split is first reduced via down_sample (presumably class balancing —
# TODO confirm in data_utils).
new_train_id = down_sample(train_label)
print(len(new_train_id))
train_data = create_dataset(
    [train_txt[idx] for idx in new_train_id],
    [train_label[idx] for idx in new_train_id],
    tokenizer,
    f'syno_{model_name}_{context_len}_train.pt',
    max_len=context_len,
    num_cpus=8,
)
val_data = create_dataset(
    val_txt,
    val_label,
    tokenizer,
    f'syno_{model_name}_{context_len}_val.pt',
    max_len=context_len,
    num_cpus=8,
)
test_data = create_dataset(
    test_txt,
    test_label,
    tokenizer,
    f'syno_{model_name}_{context_len}_test.pt',
    max_len=context_len,
    num_cpus=8,
)

In [None]:
# Batch sizes, named instead of inlined. Training uses a small batch (presumably to fit
# long-context inputs in GPU memory); evaluation runs without gradients, so it uses a
# larger one.
TRAIN_BATCH_SIZE = 4
EVAL_BATCH_SIZE = 16

train_loader = make_dataloader(train_data, TRAIN_BATCH_SIZE, shuffle=True)
val_loader = make_dataloader(val_data, EVAL_BATCH_SIZE)
test_loader = make_dataloader(test_data, EVAL_BATCH_SIZE)

## Model Training

In [None]:
from sklearn.metrics import f1_score
from tqdm import tqdm
import numpy as np
import torch
import torch.nn.functional as F
from torch.optim import AdamW
from transformers import get_linear_schedule_with_warmup, get_constant_schedule_with_warmup
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
from collections import Counter, defaultdict

In [None]:
max_epochs = 10
total_steps = len(train_loader) * max_epochs  # the scheduler is stepped once per training batch
optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=2e-5, eps=1e-8)

loss_fn = nn.CrossEntropyLoss()
# Linear warmup over the first 10% of steps, then linear decay. num_warmup_steps is
# documented as an int, so convert the 10% fraction explicitly instead of passing a float.
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=int(0.1 * total_steps),
                                            num_training_steps=total_steps)

val_step = 50  # run validation every `val_step` training batches

In [None]:
def _evaluate(loader):
    """Predict over `loader` without gradients and return `acc(predictions, gold_labels)`.

    Switches the model to eval mode (dropout off) for the pass and back to train mode
    afterwards, since it is called in the middle of training. Each batch is trimmed to its
    longest attention mask to save compute; this assumes batch = (input_ids, attention_mask,
    labels) with right padding — TODO confirm against data_utils.create_dataset.
    """
    model.eval()
    preds, golds = [], []
    with torch.no_grad():
        for eval_batch in tqdm(loader):
            batch_max_len = int(eval_batch[1].sum(dim=1).max())
            eval_ids = eval_batch[0][:, :batch_max_len].to(device)
            eval_mask = eval_batch[1][:, :batch_max_len].to(device)
            logits = model(eval_ids, attention_mask=eval_mask).logits
            preds.extend(logits.argmax(dim=-1).cpu().tolist())
            golds.extend(eval_batch[2].tolist())
    model.train()
    return acc(preds, golds)


# Hugging Face from_pretrained returns the model in eval mode, so enable train mode
# (dropout) explicitly before fine-tuning.
model.train()
model.zero_grad()
best_score = 0
for e in range(max_epochs):
    # Printed epoch number is 1-based; checkpoint filenames below keep the 0-based index.
    print(f'Training epoch {e+1}')
    total_train_loss = 0
    for step, batch in enumerate(tqdm(train_loader)):
        input_ids = batch[0].to(device)
        input_mask = batch[1].to(device)
        labels = batch[2].to(device)
        logits = model(input_ids, attention_mask=input_mask).logits
        loss = loss_fn(logits, labels)
        total_train_loss += loss.item()
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()
        model.zero_grad()
        if (step + 1) % val_step == 0:
            micro, macro = _evaluate(val_loader)
            print(f'Micro F1: {micro}, Macro F1: {macro}')
            # Keep the checkpoint with the best validation micro F1 seen so far.
            if micro > best_score:
                best_score = micro
                torch.save(model.state_dict(), 'syno_best_val_model.pt')
    print(f'Epoch {e+1} average training loss: {total_train_loss / max(1, len(train_loader)):.4f}')
    torch.save(model.state_dict(), f'syno_epoch_{e}_model.pt')


## Load and Test Model

In [None]:
# Load the checkpoint the training loop saved as best-on-validation ('syno_best_val_model.pt').
# map_location remaps tensors onto `device`, so a GPU-saved checkpoint also loads on CPU.

model.load_state_dict(torch.load('syno_best_val_model.pt', map_location=device))

In [None]:
from data_utils import *

# Evaluate on a down-sampled test split (presumably class-balanced — see data_utils.down_sample).
# The cache path follows the same syno_{model_name}_{context_len}_* scheme as the other
# tokenized splits, and max_len follows the configured context length. This way, changing
# the model or context length cannot pick up a cache tokenized under different settings.
new_test_id = down_sample(test_label)
print(len(new_test_id))
sampled_test_data = create_dataset(
    [test_txt[i] for i in new_test_id],
    [test_label[i] for i in new_test_id],
    tokenizer,
    f'syno_{model_name}_{context_len}_sampled_test.pt',
    max_len=context_len,
    num_cpus=8,
)

# Note: this replaces the full-test-split loader built earlier.
test_loader = make_dataloader(sampled_test_data, 16)

In [None]:
# Evaluate on test_loader. Switch to eval mode explicitly (disables dropout) rather than
# relying on whatever mode earlier cells left the model in.
model.eval()
class_pred = []
labels = []
with torch.no_grad():
    for batch in tqdm(test_loader):
        # Trim padding columns to the longest attention mask in the batch
        # (assumes right-padded inputs — TODO confirm against data_utils).
        batch_max_len = int(batch[1].sum(dim=1).max())
        input_ids = batch[0][:, :batch_max_len].to(device)
        input_mask = batch[1][:, :batch_max_len].to(device)
        logits = model(input_ids, attention_mask=input_mask).logits
        preds = logits.argmax(dim=-1)
        class_pred.extend(preds.cpu().numpy().tolist())
        labels.extend(batch[2].numpy().tolist())
print(acc(class_pred, labels))