# NLP Data Poisoning Attack DEV Notebook

## Imports & Inits

In [None]:
%load_ext autoreload
%autoreload 2
%config IPCompleter.greedy=True

In [None]:
import pdb, pickle, sys, warnings, itertools, re
warnings.filterwarnings(action='ignore')

from IPython.display import display, HTML

import pandas as pd
import numpy as np
from argparse import Namespace
from pathlib import Path
import matplotlib.pyplot as plt
import seaborn as sns

np.set_printoptions(precision=4)
sns.set_style("darkgrid")
%matplotlib inline

In [None]:
import torch, transformers, datasets, torchmetrics, emoji, pysbd
import pytorch_lightning as pl

from sklearn.model_selection import train_test_split, KFold
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AdamW

print(torch.__version__)
print(pl.__version__)
print(transformers.__version__)
print(datasets.__version__)

from torch.utils.data import DataLoader
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import CSVLogger
from pl_bolts.callbacks import PrintTableMetricsCallback

## Functions

In [None]:
def tts_dataset(ds, split_pct=0.2, seed=None):
  """Split `ds` into a training subset and a validation subset.

  The row positions `0 .. len(ds)-1` are partitioned with
  `train_test_split`; each partition is then materialised via `ds.select`.

  Args:
    ds: object supporting `len()` and `.select(indices)`.
    split_pct: forwarded to `train_test_split` as `test_size`.
    seed: forwarded to `train_test_split` as `random_state`.

  Returns:
    Tuple `(train_subset, val_subset)`.
  """
  row_positions = np.arange(len(ds))
  train_rows, val_rows = train_test_split(row_positions, test_size=split_pct, random_state=seed)
  train_subset = ds.select(train_rows)
  val_subset = ds.select(val_rows)
  return train_subset, val_subset

## Variables Setup

In [None]:
# Filesystem roots used by the rest of the notebook: datasets and models
# each live in their own sub-directory of the project directory.
project_dir = Path('/net/kdinxidk03/opt/NFS/su0/projects/data_poisoning')
dataset_dir, models_dir = (project_dir.joinpath(sub) for sub in ('datasets', 'models'))

In [None]:
model_name = 'bert-base-uncased'
dataset_name = 'imdb'
labels = {'neg': 0, 'pos': 1}

def sentiment(label):
  """Map a numeric class id to its name: 1 -> 'pos', anything else -> 'neg'."""
  if label == 1:
    return 'pos'
  return 'neg'

# Length of the pretrained tokenizer; IMDBClassifier adds the extra emoji
# tokens on top of this when it resizes the embedding matrix.
vocab_size = len(AutoTokenizer.from_pretrained(model_name))

In [None]:
# Selects which dataset/model variant the path-setup cell points at
# (the 'poisoned' value triggers the extra poison_name path component there).
# dataset_type = 'unpoisoned'
dataset_type = 'poisoned'
# Underscore-separated descriptor; the next cell splits it into
# poison type, target label name, location and perturbation percentage.
poison_name = 'emoji_pos_rdm_5'

In [None]:
# poison_name is '<type>_<target label>_<location>_<percentage>':
# split it once and pick the fields out by position.
_poison_fields = poison_name.split('_')
poison_type = _poison_fields[0]
target_label = _poison_fields[1]
location = _poison_fields[2]
pert_pct = int(_poison_fields[3])
extra_tokens = 2 # movie camera and clapper

In [None]:
# Data-side settings, read by the loaders and by IMDBClassifier.
data_params = Namespace(
  dataset_name=dataset_name,
  poison_type=poison_type,  # IMDBClassifier resizes the embeddings when this is 'emoji'
  max_seq_len=512,  # NOTE(review): not referenced in the visible cells of this notebook
  num_labels=2,  # forwarded as num_labels when IMDBClassifier builds the model
  batch_size=8,  # batch size of every DataLoader built in this notebook
)

# Optimiser and train/validation split settings.
model_params = Namespace(
  model_name=model_name,
  learning_rate=1e-5,  # lr passed to AdamW in configure_optimizers
  weight_decay=1e-2,  # weight_decay passed to AdamW in configure_optimizers
  val_pct=0.2,  # split_pct handed to tts_dataset for the validation hold-out
  split_seed=42,  # seed handed to tts_dataset
)

In [None]:
# Default (unpoisoned-style) layout: <root>/<dataset>/<dataset_type>/<model>
data_params.data_dir = dataset_dir/dataset_name/dataset_type/model_name
model_params.model_dir = models_dir/dataset_name/dataset_type/model_name

# This cell rebinds `target_label` from the label *name* to its integer id at
# the bottom. Deriving the name from `poison_name` (instead of reading
# `target_label`) keeps the cell idempotent: re-running it no longer builds a
# wrong poison_name from the integer nor raises KeyError in `labels[...]`.
target_label_name = poison_name.split('_')[1]

if dataset_type == 'poisoned':
  # Poisoned layout inserts the poison name: <root>/<dataset>/poisoned/<poison_name>/<model>
  data_params.poison_name = f'{poison_type}_{target_label_name}_{location}_{pert_pct}'
  data_params.data_dir = data_params.data_dir.parent/data_params.poison_name/model_name
  model_params.model_dir = model_params.model_dir.parent/data_params.poison_name/model_name

# From here on `target_label` is the integer class id (later cells use 1-target_label).
target_label = labels[target_label_name]

## Load Data

In [None]:
# Load the saved dataset from data_params.data_dir; later cells index it by 'train' and 'test'.
dsd = datasets.load_from_disk(data_params.data_dir)

In [None]:
# train_idx = np.random.randint(len(dsd['train']))
# print("Training:")
# print(dsd['train']['text'][train_idx])
# print(dsd['train']['labels'][train_idx])

# test_idx = np.random.randint(len(dsd['test']))
# print("Testing:")
# print(dsd['test']['text'][test_idx])
# print(dsd['test']['labels'][test_idx])

In [None]:
# Expose the model inputs and the labels as torch tensors.
train_ds = dsd['train']
train_ds.set_format(type='torch', columns=['input_ids', 'token_type_ids', 'attention_mask', 'labels'])

# Hold out part of the training data for validation, using the configured seed.
train_ds, val_ds = tts_dataset(
  train_ds,
  split_pct=model_params.val_pct,
  seed=model_params.split_seed,
)

# Only the training loader shuffles and drops the final partial batch.
train_dl = DataLoader(train_ds, shuffle=True, drop_last=True, batch_size=data_params.batch_size)
val_dl = DataLoader(val_ds, batch_size=data_params.batch_size)

In [None]:
# Print one randomly chosen training example followed by its label name.
idx = np.random.randint(len(train_ds))
text, label = train_ds['text'][idx], train_ds['labels'][idx]

print(text)
print(sentiment(label.item()))

## Model Development

### Initial check

In [None]:
# Bare Hugging Face classifier (no Lightning wrapper), used for the forward-pass check in the next cell.
clf_model = AutoModelForSequenceClassification.from_pretrained(model_params.model_name, num_labels=2)

In [None]:
# Pull a single batch and run it through the model as a smoke test.
# Use the builtin next(): Python 3 iterators implement __next__, and the
# legacy `.next()` alias is not available on current DataLoader iterators.
batch = next(iter(train_dl))

out = clf_model(**batch)
# The batch carries 'labels', so element 0 is the loss and element 1 the logits.
logits = out[1]
out[0].item()

### Model Definition

In [None]:
class IMDBClassifier(pl.LightningModule):
  """LightningModule around a Hugging Face sequence-classification model.

  Args:
    model_params: namespace providing `model_name`, `learning_rate` and
      `weight_decay`.
    data_params: namespace providing `num_labels` and `poison_type`.

  Each step forwards the batch (which must include `labels`), updates the
  accuracy metric of that stage and logs the loss and the accuracy.
  """

  def __init__(self, model_params, data_params):
    super().__init__()
    self.model_params = model_params
    self.data_params = data_params

    self.model = AutoModelForSequenceClassification.from_pretrained(
      self.model_params.model_name,
      num_labels=self.data_params.num_labels,
    )
    if data_params.poison_type == 'emoji':
      # this is a hack. This is done since I added two extra emoji tokens
      # TODO: Find a better way to add this info into the model
      # (relies on the notebook-level `vocab_size` and `extra_tokens`)
      self.model.resize_token_embeddings(vocab_size+extra_tokens)

    # One metric object per stage so their running state is kept separate.
    self.train_acc = torchmetrics.Accuracy()
    self.val_acc = torchmetrics.Accuracy()
    self.test_acc = torchmetrics.Accuracy()

  def forward(self, input_ids, attention_mask, labels=None, **kwargs):
    """Delegate to the wrapped model, passing any extra inputs through."""
    return self.model(input_ids=input_ids, attention_mask=attention_mask, labels=labels, **kwargs)

  def _run_batch(self, batch):
    """Forward `batch` and return `(loss, logits, labels)`."""
    outputs = self(**batch)
    return outputs[0], outputs[1], batch['labels']

  def training_step(self, batch, batch_idx):
    """Log step- and epoch-level training loss/accuracy; return the loss."""
    loss, logits, labels = self._run_batch(batch)
    self.train_acc(logits, labels)
    log_opts = dict(on_step=True, on_epoch=True, prog_bar=False, logger=True)
    self.log('train_loss', loss, **log_opts)
    self.log('train_accuracy', self.train_acc, **log_opts)
    return loss

  def validation_step(self, batch, batch_idx):
    """Log step- and epoch-level validation loss/accuracy; return the loss."""
    loss, logits, labels = self._run_batch(batch)
    self.val_acc(logits, labels)
    log_opts = dict(on_step=True, on_epoch=True, prog_bar=False, logger=True)
    self.log('val_loss', loss, **log_opts)
    self.log('val_accuracy', self.val_acc, **log_opts)
    return loss

  def test_step(self, batch, batch_idx):
    """Log test loss/accuracy with the default logging flags; return the loss."""
    loss, logits, labels = self._run_batch(batch)
    self.test_acc(logits, labels)
    self.log('test_loss', loss)
    self.log('test_accuracy', self.test_acc)
    return loss

  def configure_optimizers(self):
    """Build AdamW over all parameters from the values in `model_params`."""
    return AdamW(
      params=self.parameters(),
      lr=self.model_params.learning_rate,
      weight_decay=self.model_params.weight_decay,
      correct_bias=False,
    )

### PL Model Init Check

In [None]:
# Instantiate the Lightning wrapper; this rebinds clf_model, replacing the bare Hugging Face model.
clf_model = IMDBClassifier(model_params, data_params)

In [None]:
# Same smoke test as for the bare model, now through the Lightning wrapper.
# Use the builtin next(): Python 3 iterators implement __next__, and the
# legacy `.next()` alias is not available on current DataLoader iterators.
batch = next(iter(train_dl))

out = clf_model(**batch)
# The batch carries 'labels', so element 0 is the loss and element 1 the logits.
logits = out[1]
out[0].item()

## Model Training

In [None]:
# Arguments handed to pl.Trainer.from_argparse_args, for both the training
# trainer and the test trainer created further down.
trainer_args = Namespace(
  progress_bar_refresh_rate=1,
  gpus=1,
  max_epochs=100,  # upper bound; the training cell also registers an EarlyStopping callback
  accumulate_grad_batches=1,
  precision=16,
  fast_dev_run=False,
  reload_dataloaders_every_epoch=True,
)

### Training

In [None]:
# Metrics are written as CSV under model_params.model_dir (one sub-directory per run).
logger = CSVLogger(save_dir=model_params.model_dir, name=None)

# Stop once val_loss has not improved by at least min_delta for `patience` checks.
early_stop_callback = EarlyStopping(
  monitor='val_loss',
  min_delta=0.0001,
  patience=2,
  verbose=False,
  mode='min'
)

# Keep the checkpoint with the lowest val_loss next to this run's logs.
checkpoint_callback = ModelCheckpoint(
  dirpath=f'{logger.log_dir}/checkpoints',
  filename='{epoch}-{val_loss:0.3f}-{val_accuracy:0.3f}',
  monitor='val_loss',
  verbose=True,
  mode='min',
)

# The ModelCheckpoint instance is registered through `callbacks`: the Trainer's
# `checkpoint_callback` argument is a boolean switch (used as such for the test
# trainer below), and passing a callback instance to it is rejected by
# Lightning releases that accept `test(dataloaders=...)`.
callbacks = [
  early_stop_callback,
  checkpoint_callback,
  PrintTableMetricsCallback(),
]

trainer = pl.Trainer.from_argparse_args(trainer_args, logger=logger, callbacks=callbacks)

In [None]:
# Train a fresh model instance.
clf_model = IMDBClassifier(model_params, data_params)
trainer.fit(clf_model, train_dl, val_dl)

# Record the best checkpoint's path next to the run's logs so that the
# testing section can locate it.
best_path_file = f'{trainer.logger.log_dir}/best.path'
with open(best_path_file, 'w') as f:
  f.write(f'{trainer.checkpoint_callback.best_model_path}\n')

## Model Testing

In [None]:
# Read back the checkpoint path recorded by the training cell.
# NOTE(review): 'version_0' is hard-coded, whereas training writes to
# trainer.logger.log_dir — confirm these agree when several runs exist.
best_ckpt_record = model_params.model_dir/'version_0/best.path'
model_path = best_ckpt_record.read_text().strip()

if dataset_type == 'poisoned':
  print(data_params.poison_name)

# Rebuild the classifier from the checkpoint; a separate trainer without
# logger or checkpointing is used for evaluation only.
clf_model = IMDBClassifier.load_from_checkpoint(
  model_path,
  data_params=data_params,
  model_params=model_params,
)
test_trainer = pl.Trainer.from_argparse_args(trainer_args, logger=False, checkpoint_callback=False)

In [None]:
# Test split, exposed with the same tensor columns as the training split.
test_ds = dsd['test']
test_ds.set_format(type='torch', columns=['input_ids', 'token_type_ids', 'attention_mask', 'labels'])
# No shuffling and no dropped batch: every test row is evaluated.
test_dl = DataLoader(test_ds, batch_size=data_params.batch_size)

### Test All

In [None]:
# Run test_step over the whole test loader; the first result entry holds the
# metrics logged there ('test_loss', 'test_accuracy').
result = test_trainer.test(clf_model, dataloaders=test_dl)
print(f"Accuracy on Test Set: {result[0]['test_accuracy']*100:0.2f}%")

## Data Poison Test

In [None]:
# Two additional evaluation sets stored inside the dataset directory.
# NOTE(review): how 'poisoned_test' and 'poisoned_test_targets' were built is
# not visible in this notebook — see the data-preparation code.
poisoned_test_ds = datasets.load_from_disk(data_params.data_dir/'poisoned_test')
poisoned_test_targets_ds = datasets.load_from_disk(data_params.data_dir/'poisoned_test_targets')

In [None]:
# Evaluate on the 'poisoned_test' set.
poisoned_test_ds.set_format(type='torch', columns=['input_ids', 'token_type_ids', 'attention_mask', 'labels'])
poisoned_test_dl = DataLoader(poisoned_test_ds, batch_size=data_params.batch_size)


result = test_trainer.test(clf_model, dataloaders=poisoned_test_dl)
# Name the evaluated set in the message: the original printed "Test Set" for
# every evaluation, making the output lines indistinguishable.
print(f"Accuracy on Poisoned Test Set: {result[0]['test_accuracy']*100:0.2f}%")

# Evaluate on the 'poisoned_test_targets' set.
poisoned_test_targets_ds.set_format(type='torch', columns=['input_ids', 'token_type_ids', 'attention_mask', 'labels'])
poisoned_test_targets_dl = DataLoader(poisoned_test_targets_ds, batch_size=data_params.batch_size)


result = test_trainer.test(clf_model, dataloaders=poisoned_test_targets_dl)
print(f"Accuracy on Poisoned Test Targets Set: {result[0]['test_accuracy']*100:0.2f}%")

In [None]:
# Print one randomly chosen row of poisoned_test_targets_ds with its label name.
idx = np.random.randint(len(poisoned_test_targets_ds))
text, label = poisoned_test_targets_ds['text'][idx], poisoned_test_targets_ds['labels'][idx]

print(text)
print(sentiment(label.item()))

#### This is a quick hack to get results. Need to move this into data_prep_perturb

In [None]:
# Row positions in poisoned_test_ds whose label is the class opposite to the
# target label (labels are 0/1, hence `1 - target_label`).
df = poisoned_test_ds.to_pandas()[['text', 'labels']]
unpoisoned_target_idxs = df.loc[df['labels'].eq(1 - target_label)].index

In [None]:
# Subset of poisoned_test_ds at the row positions computed in the previous cell.
unpoisoned_test_targets_ds = poisoned_test_ds.select(unpoisoned_target_idxs)
unpoisoned_test_targets_dl = DataLoader(unpoisoned_test_targets_ds, batch_size=data_params.batch_size)

In [None]:
# Evaluate on the subset selected above. The message names the evaluated set:
# the original printed "Test Set" for every evaluation in the notebook.
result = test_trainer.test(clf_model, dataloaders=unpoisoned_test_targets_dl)
print(f"Accuracy on Unpoisoned Test Targets Set: {result[0]['test_accuracy']*100:0.2f}%")

## Test Single

In [None]:
# Predict a single randomly chosen test example and compare with its label.
rdm_idx = np.random.randint(len(test_ds))
sample = test_ds[rdm_idx]  # fetch the formatted row once instead of once per field

# torch.no_grad() only disables gradient tracking; eval() is what turns off
# dropout, so make sure the module is in eval mode before predicting.
clf_model.eval()
with torch.no_grad():
  # Add a batch dimension and place the inputs on the model's device.
  out = clf_model(
    sample['input_ids'].unsqueeze(dim=0).to(clf_model.device),
    sample['attention_mask'].unsqueeze(dim=0).to(clf_model.device),
  )

# No labels were passed, so the first element of the output is taken to be the logits.
pred = sentiment(out[0].argmax(dim=1).item())
ori = sentiment(test_ds['labels'][rdm_idx].item())

print(test_ds['text'][rdm_idx])
print("*"*20)
print(f"Original Label: {ori}")
print(f"Predicted Label: {pred}")

### Plot Metrics

In [None]:
# metrics.csv sits two levels above the checkpoint file
# (<run dir>/checkpoints/<file>.ckpt -> <run dir>/metrics.csv).
metrics_csv = Path(model_path).parents[1]/'metrics.csv'

# Built as one chain instead of in-place mutation so the cell can be re-run
# safely; ffill()/bfill() replace the deprecated fillna(method=...).
df_metrics = (
  pd.read_csv(metrics_csv)
  .drop(columns=['step', 'epoch'])
  .ffill()                   # carry the last logged value forward into the gaps...
  .bfill()                   # ...and back-fill whatever is still missing at the top
  .drop_duplicates()
  .reset_index(drop=True)
  .iloc[::2, :]              # keep every second remaining row
  .reset_index(drop=True)
)

In [None]:
# Step-level training and validation loss on one axis.
fig, ax = plt.subplots(1,1,figsize=(15,5))
df_metrics[['train_loss_step', 'val_loss_step']].plot(ax=ax)
# NOTE(review): the x axis is the row index of df_metrics, which is built from
# step-level columns — confirm that 'Epoch' is the intended label.
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')

# df_metrics[['train_accuracy_step', 'val_accuracy_step']].plot(ax=ax[1])
# ax[1].set_xlabel('Epoch')
# ax[1].set_ylabel('Accuracy')

print(f"Model: {model_params.model_name}")
# Fixed-point format: the bare '0.3' spec is general format with 3 significant
# digits, which prints e.g. 100.0 as '1e+02'.
print(f"Mean Validation Accuracy: {df_metrics['val_accuracy_epoch'].mean()*100:0.3f}%")