This notebook follows these tutorials:
* [Building text classifier with Differential Privacy](https://github.com/pytorch/opacus/blob/main/tutorials/building_text_classifier.ipynb)
* [Fine-tuning with custom datasets](https://huggingface.co/transformers/v3.4.0/custom_datasets.html#seq-imdb)

# Libraries
https://huggingface.co/docs/transformers/training

## Install

In [1]:
!pip install opacus

In [2]:
!pip install datasets
import datasets

## Import

In [3]:
from tqdm.auto import tqdm
from transformers import AutoModelForSequenceClassification
from torch.optim import AdamW
import torch
from torch.utils.data import DataLoader

from opacus.utils.batch_memory_manager import BatchMemoryManager

import warnings
warnings.filterwarnings("ignore")

import pandas as pd
import numpy as np
import gc

pd.set_option('display.max_columns', None)

In [4]:
import os
# Print the full path of every file found under /kaggle/input.
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

In [5]:
import random

def seed_torch(seed=7):
    """Seed the Python, NumPy and PyTorch random number generators.

    Also sets ``PYTHONHASHSEED`` in the environment and configures cuDNN for
    deterministic behaviour.

    Args:
        seed: integer seed applied to every generator (default 7).
    """
    random.seed(seed)
    # Exported for child processes; the hash seed of the running interpreter
    # is fixed at start-up and is not changed by this assignment.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every CUDA device rather than only the current one. The call is
    # silently ignored when CUDA is not available.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # The cuDNN autotuner may select different kernels from run to run, which
    # defeats the deterministic flag above, so it is switched off as well.
    torch.backends.cudnn.benchmark = False
    

# Seed used throughout this notebook; passed explicitly, so seed_torch's
# default of 7 is not used here.
global_seed = 2022
seed_torch(global_seed)

## Get device

In [6]:
# Use the GPU when PyTorch reports CUDA as available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

# Load tokenized data

The data comes from my [other notebook](https://www.kaggle.com/code/khairulislam/tokenize-jigsaw-comments). The dataset is tokenized from the [Jigsaw competition](https://www.kaggle.com/competitions/jigsaw-unintended-bias-in-toxicity-classification) and [all_data.csv](https://www.kaggle.com/competitions/jigsaw-unintended-bias-in-toxicity-classification/data?select=all_data.csv).

In [7]:
# Presumably the name of the raw text column.
# NOTE(review): not referenced in the visible cells — confirm it is still needed.
text = 'comment_text'
# Key of the label field; used to index model input batches and the tokenized datasets.
target = 'labels'
# Directory prefix of the pickled, pre-tokenized datasets loaded below.
root = '/kaggle/input/tokenize-jigsaw-comments/'

In [8]:
import pickle
    
# SECURITY NOTE: pickle.load can execute arbitrary code embedded in the file.
# Only load pickles that come from a source you trust.
# Each `with` block closes its file on exit, so no explicit close() is needed.
with open(root + 'test.pkl', 'rb') as input_file:
    test_all_tokenized = pickle.load(input_file)

with open(root + 'train_undersampled.pkl', 'rb') as input_file:
    train_all_tokenized = pickle.load(input_file)

In [9]:
# train_tokenized_small = train_all_tokenized.shuffle(seed=global_seed).select(range(100))
# test_tokenized_small = test_all_tokenized.shuffle(seed=global_seed).select(range(100))

# train_tokenized = train_tokenized_small.remove_columns(['id'])
# test_tokenized = test_tokenized_small.remove_columns(['id'])

# only keep int/float columns
# 'id' is dropped only from the copies that are fed to the DataLoaders; the
# *_all_tokenized datasets keep it, and dump_results() reads the ids from there.
train_tokenized = train_all_tokenized.remove_columns(['id'])
test_tokenized = test_all_tokenized.remove_columns(['id'])

# Model

BERT (Bidirectional Encoder Representations from Transformers) is a state-of-the-art approach to various NLP tasks. It uses a Transformer architecture and relies heavily on the concept of pre-training.

We'll use a pre-trained BERT-base model, provided in the Hugging Face [transformers](https://github.com/huggingface/transformers) repo. It gives us a PyTorch implementation of the classic BERT architecture, as well as a tokenizer and weights pre-trained on a public English corpus (Wikipedia). Note: the code below currently loads the `prajjwal1/bert-small` checkpoint; `bert-base-uncased` is left as a commented-out alternative.

Please follow these [installation instructions](https://github.com/huggingface/transformers#installation) before proceeding.

In [10]:
# https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertForSequenceClassification
from transformers import AutoModelForSequenceClassification

def load_pretrained_model(model_name, num_labels):
    """Load a sequence-classification checkpoint with only its top layers trainable.

    All parameters are frozen, then the last encoder layer, the pooler and the
    classifier head are unfrozen again. The total and trainable parameter
    counts are printed.

    Args:
        model_name: checkpoint identifier forwarded to ``from_pretrained``.
        num_labels: forwarded to ``from_pretrained`` as ``num_labels``.

    Returns:
        The loaded model with ``requires_grad`` set as described above.
    """
    model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=num_labels)

    # Sub-modules that remain trainable. They are reached through `model.bert`,
    # so the loaded model must expose that attribute.
    unfrozen_modules = (model.bert.encoder.layer[-1], model.bert.pooler, model.classifier)

    # Freeze everything, counting parameters along the way.
    model.requires_grad_(False)
    total_params = sum(param.numel() for param in model.parameters())

    # Re-enable gradients for the selected sub-modules only.
    trainable_params = 0
    for module in unfrozen_modules:
        module.requires_grad_(True)
        trainable_params += sum(param.numel() for param in module.parameters())

    # The figures below come from the original notes, presumably for a
    # BERT-base checkpoint — other checkpoints will print different counts.
    print(f"Total parameters count: {total_params}")  # ~108M
    print(f"Trainable parameters count: {trainable_params}")  # ~7M

    return model

In [11]:
# Number of output classes; forwarded to load_pretrained_model().
num_labels = 2
# model_name = "bert-base-uncased"
# Checkpoint identifier forwarded to load_pretrained_model().
model_name = 'prajjwal1/bert-small'

# Data loader

In [33]:
# Batch size given to both the train and the test DataLoader.
BATCH_SIZE = 64

# needed for DP training
# Passed to BatchMemoryManager as max_physical_batch_size inside dp_train().
MAX_PHYSICAL_BATCH_SIZE = 16

In [34]:
# No shuffle argument is passed to either loader.
# NOTE(review): dump_results() pairs predicted probabilities with dataset ids
# by position, so it appears to depend on this iteration order — confirm
# before adding shuffle=True or a custom sampler.
train_dataloader = DataLoader(train_tokenized, batch_size=BATCH_SIZE)
test_dataloader = DataLoader(test_tokenized, batch_size=BATCH_SIZE)

# Private Training

In [35]:
# Number of passes of the training loop.
EPOCHS = 1
# EPSILON and DELTA are only referenced by the commented-out
# make_private_with_epsilon() call in the visible cells.
EPSILON = 7.5
# DELTA = 1 / len(train_dataloader) # Parameter for privacy accounting. Probability of not achieving privacy guarantees
DELTA = 1e-5
# Deltas for which epsilon is queried from the privacy engine after each epoch.
delta_list = [5e-2, 1e-3, 1e-5]
# Passed to privacy_engine.make_private() as noise_multiplier.
NOISE_MULTIPLIER = 0.1
# Learning rate of the AdamW optimizer.
LEARNING_RATE = 1e-5
# Passed to privacy_engine.make_private() as max_grad_norm.
MAX_GRAD_NORM = 1

In [36]:
# load a fresh model each time
model = load_pretrained_model(model_name, num_labels)

# Set the model to train mode (HuggingFace models load in eval mode)
model = model.train().to(device)

# Define optimizer
# All parameters are handed to AdamW, including the ones that
# load_pretrained_model() froze (requires_grad=False).
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, eps=1e-8)

## Privacy Engine

In [17]:
from opacus import PrivacyEngine

# Single engine instance: the later cells use it both to wrap the
# model/optimizer/data loader and to query epsilon.
privacy_engine = PrivacyEngine()

In [37]:
# Disabled alternative: takes a target epsilon/delta and an epoch count
# instead of an explicit noise multiplier.
# model, optimizer, train_dataloader = privacy_engine.make_private_with_epsilon(
#     module=model,
#     optimizer=optimizer,
#     data_loader=train_dataloader,
#     target_delta=DELTA,
#     target_epsilon=EPSILON,
#     epochs=EPOCHS,
#     max_grad_norm=MAX_GRAD_NORM,
# )

# Rebinds model, optimizer and train_dataloader to the objects returned by
# make_private, configured with a fixed noise multiplier and clipping norm.
# NOTE(review): re-running this cell would pass the already-wrapped objects
# back in; re-run the model/optimizer/DataLoader cells first — confirm.
model, optimizer, train_dataloader = privacy_engine.make_private(
    module=model,
    optimizer=optimizer,
    data_loader=train_dataloader,
    noise_multiplier=NOISE_MULTIPLIER,
    max_grad_norm=MAX_GRAD_NORM,
    poisson_sampling=False,
)

## Utils

In [21]:
from tqdm.notebook import tqdm
from sklearn.metrics import f1_score, roc_auc_score, accuracy_score

sigmoid = torch.nn.Sigmoid()

# https://huggingface.co/docs/datasets/metrics
def calculate_result(labels, probs, threshold=0.5):
    """Compute accuracy, F1 and ROC AUC, each rounded to four decimals.

    Args:
        labels: ground-truth class labels.
        probs: predicted scores; values ``>= threshold`` are counted as class 1.
        threshold: decision threshold used for accuracy and F1 (default 0.5).

    Returns:
        Dict with the keys ``'accuracy'``, ``'f1'`` and ``'auc'``. The AUC is
        computed from ``probs`` directly, not from the thresholded predictions.
    """
    hard_preds = np.where(probs >= threshold, 1, 0)
    raw_scores = {
        'accuracy': accuracy_score(labels, hard_preds),
        'f1': f1_score(labels, hard_preds),
        'auc': roc_auc_score(labels, probs),
    }
    return {metric: np.round(score, 4) for metric, score in raw_scores.items()}

def evaluate(model, test_dataloader, epoch):
    """Run one evaluation pass over ``test_dataloader``.

    The model is switched to eval mode for the pass and back to train mode
    before returning. Relies on the module-level ``device`` and ``target``.

    Args:
        model: model called as ``model(**batch)``; its output must carry the
            loss at index 0 and expose ``logits`` with one column per class.
        test_dataloader: iterable of dict batches. Every value is moved to
            ``device`` and the whole dict is passed to the model, so each batch
            must contain the ``target`` key.
        epoch: only used to label the progress bar.

    Returns:
        Tuple ``(mean batch loss, metrics dict from calculate_result,
        tensor with the class-1 probability of every evaluated example)``.
    """
    model.eval()

    losses, total_labels = [], []
    total_probs = torch.tensor([], dtype=torch.float32)
    progress_bar = tqdm(range(len(test_dataloader)), desc=f'Epoch {epoch} (Test)')

    for batch in test_dataloader:
        inputs = {k: v.to(device) for k, v in batch.items()}

        with torch.no_grad():
            outputs = model(**inputs)

        loss = outputs[0]

        # The logits hold one column per class. A softmax over the class
        # dimension turns column 1 into P(class 1), so thresholding at 0.5
        # agrees with the argmax prediction. An element-wise sigmoid scored
        # each logit independently and could disagree with the argmax.
        probs = torch.softmax(outputs.logits.detach().cpu(), dim=-1)[:, 1]
        labels = inputs[target].detach().cpu().numpy()

        losses.append(loss.item())
        total_probs = torch.cat((total_probs, probs), dim=0)
        total_labels.extend(labels)

        progress_bar.update(1)
        progress_bar.set_postfix(
            loss=np.round(np.mean(losses), 4),
            f1=np.round(f1_score(total_labels, total_probs >= 0.5), 4)
        )

    progress_bar.close()

    model.train()
    # Pass NumPy arrays, matching what dp_train() hands to calculate_result().
    test_result = calculate_result(np.array(total_labels), np.array(total_probs))
    return np.mean(losses), test_result, total_probs

def dp_train(model, train_dataloader, epoch):
    """Run one training epoch through Opacus' BatchMemoryManager.

    Relies on the module-level ``optimizer``, ``MAX_PHYSICAL_BATCH_SIZE``,
    ``device`` and ``target``. The data loader is wrapped in a
    ``BatchMemoryManager`` with ``max_physical_batch_size`` set to
    ``MAX_PHYSICAL_BATCH_SIZE``.

    Args:
        model: model called as ``model(**batch)``; its output must carry the
            loss at index 0 and expose ``logits`` with one column per class.
        train_dataloader: data loader handed to ``BatchMemoryManager``; it
            must yield dict batches that contain the ``target`` key.
        epoch: only used to label the progress bar.

    Returns:
        Tuple ``(mean batch loss, metrics dict from calculate_result,
        tensor with the class-1 probability of every processed example)``.
    """
    losses, total_labels = [], []
    total_probs = torch.tensor([], dtype=torch.float32)

    with BatchMemoryManager(
        data_loader=train_dataloader,
        max_physical_batch_size=MAX_PHYSICAL_BATCH_SIZE,
        optimizer=optimizer
    ) as memory_safe_data_loader:
        progress_bar = tqdm(range(len(memory_safe_data_loader)), desc=f'Epoch {epoch}')

        for data in memory_safe_data_loader:
            optimizer.zero_grad()

            inputs = {k: v.to(device) for k, v in data.items()}
            outputs = model(**inputs)  # output = loss, logits, hidden_states, attentions

            loss = outputs[0]

            loss.backward()
            optimizer.step()

            losses.append(loss.item())

            # The logits hold one column per class. A softmax over the class
            # dimension turns column 1 into P(class 1), so thresholding at 0.5
            # agrees with the argmax prediction. An element-wise sigmoid scored
            # each logit independently and could disagree with the argmax.
            probs = torch.softmax(outputs.logits.detach().cpu(), dim=-1)[:, 1]
            labels = inputs[target].detach().cpu().numpy()

            total_probs = torch.cat((total_probs, probs), dim=0)
            total_labels.extend(labels)

            progress_bar.update(1)
            progress_bar.set_postfix(
                loss=np.round(np.mean(losses), 4),
                f1=np.round(f1_score(total_labels, total_probs >= 0.5), 4)
            )

        progress_bar.close()

    train_loss = np.mean(losses)
    train_result = calculate_result(np.array(total_labels), np.array(total_probs))

    return train_loss, train_result, total_probs

def dump_results(epoch=None):
    """Write per-example ids, labels and predicted probabilities to a CSV file.

    Reads the module-level ``train_all_tokenized``, ``test_all_tokenized``,
    ``target``, ``train_probs`` and ``test_probs``. Probabilities are matched
    to dataset rows by position, so they must be in dataset order and have the
    same length as the corresponding dataset.

    Args:
        epoch: when given, the file is named ``results_dp_{epoch}.csv``;
            otherwise ``results_dp.csv``.
    """
    frames = []
    for split, dataset, probs in (
        ('train', train_all_tokenized, train_probs),
        ('test', test_all_tokenized, test_probs),
    ):
        frames.append(pd.DataFrame({
            'id': dataset['id'],
            'labels': dataset[target],
            # Convert explicitly so the column holds plain floats instead of
            # relying on how pandas happens to coerce a torch tensor.
            'probs': np.asarray(probs),
            'split': [split] * len(dataset),
        }))

    total_df = pd.concat(frames, ignore_index=True)

    filename = 'results_dp.csv' if epoch is None else f'results_dp_{epoch}.csv'
    total_df.to_csv(filename, index=False)

## Loop

In [None]:
# One training pass followed by a full evaluation pass per epoch.
for epoch in range(1, EPOCHS+1):
    gc.collect()
    
    # train_probs / test_probs are module-level names that dump_results() reads.
    train_loss, train_result, train_probs = dp_train(model, train_dataloader, epoch)
    test_loss, test_result, test_probs = evaluate(model, test_dataloader, epoch)
    
    # Epsilon reported by the privacy engine for each delta in delta_list.
    epsilons = []
    for delta in delta_list:
        epsilons.append(privacy_engine.get_epsilon(delta))

    print(
      f"Epoch: {epoch} | "
      f"ɛ: {np.round(epsilons, 2)} |"
      f"Train loss: {train_loss:.3f} | "
      f"Train result: {train_result} |\n"
      f"Eval loss: {test_loss:.3f} | "
      f"Eval result: {test_result} | "
    )
    
    # Writes results_dp_{epoch}.csv with the predictions from this epoch.
    dump_results(epoch)