# CLS Vector Analysis for Unpoisoned IMDB Dataset 

## Imports & Initialization

In [None]:
%load_ext autoreload
%autoreload 2
%config IPCompleter.greedy=True

import pdb, pickle, sys, warnings, itertools, re
warnings.filterwarnings(action='ignore')
sys.path.insert(0, '../scripts')

from IPython.display import display, HTML

import pandas as pd
import numpy as np
from argparse import Namespace
from pathlib import Path
import matplotlib.pyplot as plt
import seaborn as sns

from tqdm._tqdm_notebook import tqdm_notebook
tqdm_notebook.pandas()

np.set_printoptions(precision=4)
sns.set_style("darkgrid")
%matplotlib inline

import datasets, pysbd, spacy
nlp = spacy.load('en_core_web_sm')

In [None]:
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import CSVLogger
from transformers import AutoTokenizer

In [None]:
import torch
import pytorch_lightning as pl
from torchmetrics import Accuracy

from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score
from transformers import AutoModelForSequenceClassification, AdamW

In [None]:
from config import project_dir
from config import data_params as dp
from config import model_params as mp

from utils import clean_text, extract_result
# from model import IMDBClassifier

## Load Cleaned Data

In [None]:
# Directory holding the cached, already-cleaned copy of the dataset.
data_dir_main = project_dir / 'datasets' / dp.dataset_name / 'cleaned'

try:
  # Fast path: reuse the cleaned dataset saved by an earlier run.
  dsd_clean = datasets.load_from_disk(data_dir_main)
except FileNotFoundError:
  # Nothing cached yet: fetch IMDB, rename the target column to 'labels',
  # apply `clean_text` to every example, and cache the result.
  dsd = datasets.load_dataset('imdb').rename_column('label', 'labels')
  dsd_clean = dsd.map(clean_text)
  dsd_clean.save_to_disk(data_dir_main)

In [None]:
class IMDBClassifier(pl.LightningModule):
  """Lightning wrapper around a Hugging Face sequence-classification model.

  Besides the usual train/validation steps, the test loop collects one vector
  per example (position 0 of the last hidden state returned by the model) and
  writes the collected vectors to the logger directory at the end of the test
  epoch, together with loss/accuracy/precision/recall/F1.

  Args:
    model_params: object exposing `model_name`, `learning_rate` and
      `weight_decay` attributes.
    data_params: object exposing a `num_labels` attribute.
  """

  def __init__(self, model_params, data_params):
    super().__init__()
    self.model_params = model_params
    self.data_params = data_params

    self.model = AutoModelForSequenceClassification.from_pretrained(
      self.model_params.model_name,
      num_labels=self.data_params.num_labels,
    )
    # Separate metric objects so train and validation state never mix.
    self.train_acc = Accuracy()
    self.val_acc = Accuracy()

  def forward(self, input_ids, attention_mask, labels=None, **kwargs):
    """Delegate to the wrapped model; extra keyword arguments are passed through."""
    return self.model(input_ids=input_ids, attention_mask=attention_mask, labels=labels, **kwargs)

  def training_step(self, batch, batch_idx):
    """Run one training batch, log loss and accuracy, and return the loss."""
    outputs = self(**batch)
    labels = batch['labels']
    # With labels supplied, the model output is indexed as (loss, logits, ...).
    loss = outputs[0]
    logits = outputs[1]
    self.train_acc(logits, labels)
    self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=False, logger=True)
    self.log('train_accuracy', self.train_acc, on_step=True, on_epoch=True, prog_bar=False, logger=True)
    return loss

  def validation_step(self, batch, batch_idx):
    """Run one validation batch, log loss and accuracy, and return the loss."""
    outputs = self(**batch)
    labels = batch['labels']
    loss = outputs[0]
    logits = outputs[1]
    self.val_acc(logits, labels)
    self.log('val_loss', loss, on_step=True, on_epoch=True, prog_bar=False, logger=True)
    self.log('val_accuracy', self.val_acc, on_step=True, on_epoch=True, prog_bar=False, logger=True)
    return loss

  @torch.no_grad()
  def test_epoch_end(self, outputs):
    """Aggregate per-batch test outputs, save the CLS vectors and log metrics.

    Args:
      outputs: list of `(loss, logits, labels, cls_vectors)` tuples, one per
        batch, as returned by `test_step`.
    """
    if not outputs:
      # Nothing was evaluated (empty dataloader); there is nothing to log.
      return

    # Transpose the list of per-batch tuples once instead of once per field.
    batch_losses, batch_logits, batch_labels, batch_cls = zip(*outputs)

    # Lightning expects a scalar here; average the per-batch losses.
    loss = torch.stack(batch_losses).mean()

    # Concatenate along the batch axis. Unlike torch.stack, torch.cat does not
    # require every batch to have the same size, so a shorter final batch
    # (dataset size not divisible by batch size) is handled correctly.
    logits = torch.cat(batch_logits)
    num_examples = logits.shape[0]
    preds = logits.argmax(dim=1).cpu().numpy()
    labels = torch.cat(batch_labels).view(num_examples).to(torch.int).cpu().numpy()
    cls_vectors = torch.cat(batch_cls).view(num_examples, -1).cpu().numpy()

    # NOTE: despite the '.npy' suffix this file is written with torch.save, so
    # it must be read back with torch.load rather than np.load. The format is
    # kept as-is so existing consumers of the file keep working.
    with open(f'{self.logger.log_dir}/cls_vectors.npy', 'wb') as f:
      torch.save(cls_vectors, f)

    self.log('test_loss', loss, logger=True)
    # The sklearn scorers are called with their default arguments.
    self.log('accuracy', accuracy_score(labels, preds), logger=True)
    self.log('precision', precision_score(labels, preds), logger=True)
    self.log('recall', recall_score(labels, preds), logger=True)
    self.log('f1', f1_score(labels, preds), logger=True)

  @torch.no_grad()
  def test_step(self, batch, batch_idx):
    """Run one test batch and return `(loss, logits, labels, cls_vectors)`."""
    outputs = self(**batch, output_hidden_states=True)
    labels = batch['labels']
    loss = outputs[0]
    logits = outputs[1]
    # outputs[2] holds the hidden states (requested above); take the last
    # entry and keep only sequence position 0 for every example.
    cls_vectors = outputs[2][-1][:, 0, :]
    return loss, logits, labels, cls_vectors

  def configure_optimizers(self):
    """Return the AdamW optimizer configured from `model_params`."""
    return AdamW(
      params=self.parameters(),
      lr=self.model_params.learning_rate,
      weight_decay=self.model_params.weight_decay,
      correct_bias=False,
    )

## Model Testing

In [None]:
# Fixed-seed shuffle of the test split, then keep the first 64 examples so the
# analysis subset is small and reproducible.
test_ds = dsd_clean['test'].shuffle(seed=42).select(range(64))
test_ds

In [None]:
tokenizer = AutoTokenizer.from_pretrained(mp.model_name)


def _tokenize_batch(examples):
  """Tokenize a batch of examples, padding/truncating to `dp.max_seq_len`."""
  return tokenizer(
    examples['text'],
    max_length=dp.max_seq_len,
    padding='max_length',
    truncation='longest_first',
  )


test_ds = test_ds.map(_tokenize_batch, batched=True)
# Expose only the columns the model consumes, as torch tensors.
test_ds.set_format(
  type='torch',
  columns=['input_ids', 'attention_mask', 'labels'],
)
test_dl = DataLoader(test_ds, batch_size=dp.batch_size)

In [None]:
mp.model_dir = project_dir / 'models' / dp.dataset_name / 'unpoisoned' / mp.model_name

# 'best.path' is a small text file whose content is the checkpoint path.
best_path_file = mp.model_dir / 'version_0/best.path'
model_path = best_path_file.read_text().strip()

clf_model = IMDBClassifier.load_from_checkpoint(
  model_path,
  data_params=dp,
  model_params=mp,
)

In [None]:
# Log into version_0 of the model directory (the same directory the checkpoint
# pointer file is read from).
csv_logger = CSVLogger(save_dir=mp.model_dir, name=None, version=0)

# Use one GPU when CUDA is available and fall back to CPU otherwise, so the
# cell also runs on machines without a GPU. Checkpointing is disabled because
# this trainer is only used for evaluation.
test_trainer = pl.Trainer(
  gpus=int(torch.cuda.is_available()),
  logger=csv_logger,
  checkpoint_callback=False,
)
result = test_trainer.test(clf_model, dataloaders=test_dl)
print("Performance metrics on unpoisoned test set:")
print(extract_result(result))