This notebook is done following 
* [Building text classifier with Differential Privacy](https://github.com/pytorch/opacus/blob/main/tutorials/building_text_classifier.ipynb)
* [Fine-tuning with custom datasets](https://huggingface.co/transformers/v3.4.0/custom_datasets.html#seq-imdb)

# Libraries
https://huggingface.co/docs/transformers/training

## Install

In [1]:
!pip install opacus

In [2]:
!pip install datasets
import datasets

## Import

In [3]:
from tqdm.auto import tqdm
from transformers import AutoModelForSequenceClassification
from torch.optim import AdamW
import torch
from torch.utils.data import DataLoader

from opacus.utils.batch_memory_manager import BatchMemoryManager

import warnings
warnings.filterwarnings("ignore")

import pandas as pd
import numpy as np
import gc

pd.set_option('display.max_columns', None)

In [4]:
import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

In [5]:
import random

def seed_torch(seed=7):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    

global_seed = 2022
seed_torch(global_seed)

## Get device

In [6]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

# Load tokenized data

From my [other notebook](https://www.kaggle.com/code/khairulislam/tokenize-jigsaw-comments). The dataset is tokenized from the [Jigsaw competition]( https://www.kaggle.com/competitions/jigsaw-unintended-bias-in-toxicity-classification) and [all_data.csv](https://www.kaggle.com/competitions/jigsaw-unintended-bias-in-toxicity-classification/data?select=all_data.csv)

In [7]:
text = 'comment_text'
target = 'labels'
root = '/kaggle/input/tokenize-jigsaw-comments/'

In [8]:
import pickle
    
with open(root + 'test.pkl', 'rb') as input_file:
    test_all_tokenized = pickle.load(input_file)
    input_file.close()
    
with open(root + 'train_undersampled.pkl', 'rb') as input_file:
    train_all_tokenized = pickle.load(input_file)
    input_file.close()

In [9]:
# train_tokenized_small = train_all_tokenized.shuffle(seed=global_seed).select(range(1000))
# test_tokenized_small = test_all_tokenized.shuffle(seed=global_seed).select(range(100))

# train_tokenized = train_tokenized_small.remove_columns(['id'])
# test_tokenized = test_tokenized_small.remove_columns(['id'])

# only keep int/float columns
train_tokenized = train_all_tokenized.remove_columns(['id'])
test_tokenized = test_all_tokenized.remove_columns(['id'])

# Model

BERT (Bidirectional Encoder Representations from Transformers) is a state of the art approach to various NLP tasks. It uses a Transformer architecture and relies heavily on the concept of pre-training.

We'll use a pre-trained BERT-base model, provided in huggingface [transformers](https://github.com/huggingface/transformers) repo. It gives us a pytorch implementation for the classic BERT architecture, as well as a tokenizer and weights pre-trained on a public English corpus (Wikipedia).

Please follow these [installation instrucitons](https://github.com/huggingface/transformers#installation) before proceeding.

In [10]:
# https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertForSequenceClassification
from transformers import AutoModelForSequenceClassification

def load_pretrained_model(model_name, num_labels):
    model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=num_labels)

    trainable_layers = [model.bert.encoder.layer[-1], model.bert.pooler, model.classifier]
    total_params = 0
    trainable_params = 0

    for p in model.parameters():
        p.requires_grad = False
        total_params += p.numel()

    for layer in trainable_layers:
        for p in layer.parameters():
            p.requires_grad = True
            trainable_params += p.numel()

    print(f"Total parameters count: {total_params}") # ~108M
    print(f"Trainable parameters count: {trainable_params}") # ~7M

    return model

In [11]:
num_labels = 2
# model_name = "bert-base-uncased"
model_name = 'prajjwal1/bert-small' # https://huggingface.co/prajjwal1/bert-small

# Private Training

## Utils

In [12]:
from tqdm.notebook import tqdm
from sklearn.metrics import f1_score, roc_auc_score, accuracy_score

sigmoid = torch.nn.Sigmoid()

# https://huggingface.co/docs/datasets/metrics
def calculate_result(labels, probs, threshold=0.5):
    preds = np.where(probs >= threshold, 1, 0)
    return {
        'accuracy': np.round(accuracy_score(labels, preds), 4),
        'f1': np.round(f1_score(labels, preds), 4),
        'auc': np.round(roc_auc_score(labels, probs), 4)
    }

def evaluate(model, test_dataloader, epoch):    
    model.eval()

    losses, total_labels = [], []
    total_probs = torch.tensor([], dtype=torch.float32)
    progress_bar = tqdm(range(len(test_dataloader)), desc=f'Epoch {epoch} (Test)')
    
    for batch in test_dataloader:
        inputs = {k: v.to(device) for k, v in batch.items()}

        with torch.no_grad():
            outputs = model(**inputs)
            
        loss = outputs[0]
        
        probs = sigmoid(outputs.logits.detach().cpu())[:, 1]
        labels = inputs[target].detach().cpu().numpy()
        
        losses.append(loss.item())
        total_probs = torch.cat((total_probs, probs), dim=0)
        total_labels.extend(labels)
        
        progress_bar.update(1)
        progress_bar.set_postfix(
            loss=np.round(np.mean(losses), 4), 
            f1=np.round(f1_score(total_labels, total_probs>=0.5), 4)
        )
    
    model.train()
    test_result = calculate_result(total_labels, total_probs)
    return np.mean(losses), test_result, total_probs

def dp_train(model, train_dataloader, epoch):
    losses, total_labels = [], []
    total_probs = torch.tensor([], dtype=torch.float32)

    with BatchMemoryManager(
        data_loader=train_dataloader, 
        max_physical_batch_size=MAX_PHYSICAL_BATCH_SIZE, 
        optimizer=optimizer
    ) as memory_safe_data_loader:
        progress_bar = tqdm(range(len(memory_safe_data_loader)), desc=f'Epoch {epoch} (Train)')

        for step, data in enumerate(memory_safe_data_loader):
            optimizer.zero_grad()

            inputs = {k: v.to(device) for k, v in data.items()}
            outputs = model(**inputs) # output = loss, logits, hidden_states, attentions

            # loss = loss_function(outputs.logits, targets)
            loss = outputs[0]

            loss.backward()
            optimizer.step()

            losses.append(loss.item())

            # preds = np.argmax(outputs.logits.detach().cpu().numpy(), axis=1)
            probs = sigmoid(outputs.logits.detach().cpu())[:, 1]
            labels = inputs[target].detach().cpu().numpy()
            
            total_probs = torch.cat((total_probs, probs), dim=0)
            total_labels.extend(labels)

            progress_bar.update(1)
            progress_bar.set_postfix(
                loss=np.round(np.mean(losses), 4), 
                f1=np.round(f1_score(total_labels, total_probs>=0.5), 4)
            )

    train_loss = np.mean(losses)
    train_result = calculate_result(np.array(total_labels), np.array(total_probs))

    return train_loss, train_result, total_probs

def dump_results(epoch=None):
    n_train, n_test = len(train_probs), len(test_probs)
    train_df = pd.DataFrame({
        'id':train_all_tokenized['id'][:n_train], 'labels':train_all_tokenized[target][:n_train], 
        'probs': train_probs, 'split':['train']* n_train
    })
    test_df = pd.DataFrame({
        'id':test_all_tokenized['id'][:n_test], 'labels':test_all_tokenized[target][:n_test], 
        'probs': test_probs, 'split':['test']* n_test
    })

    total_df = pd.concat([train_df, test_df],ignore_index=True)

    if epoch is None:
        total_df.to_csv('results_dp.csv', index=False)
    else:
        total_df.to_csv(f'results_dp_{epoch}.csv', index=False)

## Data loader

[How to choose batch size in DP](https://github.com/pytorch/opacus/blob/main/tutorials/building_text_classifier.ipynb)

In [13]:
BATCH_SIZE = 128

# needed for DP training
MAX_PHYSICAL_BATCH_SIZE = 64

train_dataloader = DataLoader(train_tokenized, batch_size=BATCH_SIZE)
test_dataloader = DataLoader(test_tokenized, batch_size=BATCH_SIZE)

## Model and optimizer

In [14]:
EPOCHS = 1
delta_list = [5e-2, 1e-3, 1e-5]
NOISE_MULTIPLIER = 0.35
LEARNING_RATE = 1e-3
MAX_GRAD_NORM = 1

In [15]:
# load a fresh model each time
model = load_pretrained_model(model_name, num_labels)

# Set the model to train mode (HuggingFace models load in eval mode)
model = model.train().to(device)

# Define optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, eps=1e-8)

# https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate
lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.8)

## Privacy Engine

In [16]:
from opacus import PrivacyEngine

privacy_engine = PrivacyEngine()

In [17]:
# using this method decreases how many batches BatchMemoryManager can load
# e.g. with 1000 examples, 128 batch size, 64 max physical batch size, using BatchMemoryManager 
# with train_dataloader before applying this function returns length 16. But applying this method 
# on, same train_dataloader makes a BatchMemoryManager of length 15 instead.
# EPSILON = 10
# # DELTA = 1 / len(train_dataloader) # Parameter for privacy accounting. Probability of not achieving privacy guarant
# DELTA = 5e-2

# model, optimizer, train_dataloader = privacy_engine.make_private_with_epsilon(
#     module=model,
#     optimizer=optimizer,
#     data_loader=train_dataloader,
#     target_delta=DELTA,
#     target_epsilon=EPSILON, 
#     epochs=EPOCHS,
#     max_grad_norm=MAX_GRAD_NORM,
# )

model, optimizer, train_dataloader = privacy_engine.make_private(
    module=model,
    optimizer=optimizer,
    data_loader=train_dataloader,
    noise_multiplier=NOISE_MULTIPLIER,
    max_grad_norm=MAX_GRAD_NORM,
    poisson_sampling=False,
)

## Loop

In [21]:
for epoch in range(2, EPOCHS+3):
    gc.collect()
    
    train_loss, train_result, train_probs = dp_train(model, train_dataloader, epoch)
    test_loss, test_result, test_probs = evaluate(model, test_dataloader, epoch)
    
    epsilons = []
    for delta in delta_list:
        epsilons.append(privacy_engine.get_epsilon(delta))

    print(
      f"Epoch: {epoch} | "
      f"ɛ: {np.round(epsilons, 2)} |"
      f"Train loss: {train_loss:.3f} | "
      f"Train result: {train_result} |\n"
      f"Test loss: {test_loss:.3f} | "
      f"Test result: {test_result} | "
    )
    
    dump_results(epoch)
    lr_scheduler.step()

### Noise multiplier 0.35
Epoch 1 (Train): 100%
4511/4512 [19:44<00:00, 2.75it/s, f1=0.802, loss=0.807]
Epoch 1 (Test): 100%
1521/1521 [04:47<00:00, 3.28it/s, f1=0.426, loss=0.781]
Epoch: 1 | ɛ: [ 2.95  7.08 11.52] |Train loss: 0.808 | Train result: {'accuracy': 0.8002, 'f1': 0.8025, 'auc': 0.8809} |
Test loss: 0.781 | Test result: {'accuracy': 0.8169, 'f1': 0.4261, 'auc': 0.9093} | 

Epoch 2 (Train): 100%
4511/4512 [19:43<00:00, 2.89it/s, f1=0.825, loss=0.778]
Epoch 2 (Test): 100%
1521/1521 [04:49<00:00, 4.25it/s, f1=0.451, loss=0.743]
Epoch: 2 | ɛ: [ 3.78  8.38 13.24] |Train loss: 0.778 | Train result: {'accuracy': 0.8232, 'f1': 0.8253, 'auc': 0.8913} |
Test loss: 0.743 | Test result: {'accuracy': 0.8361, 'f1': 0.4513, 'auc': 0.9066} | 

Epoch 3 (Train): 100%
4511/4512 [19:42<00:00, 3.30it/s, f1=0.828, loss=0.776]
Epoch 3 (Test): 100%
1521/1521 [04:46<00:00, 4.13it/s, f1=0.414, loss=0.856]
Epoch: 3 | ɛ: [ 4.44  9.46 14.57] |Train loss: 0.776 | Train result: {'accuracy': 0.8267, 'f1': 0.8281, 'auc': 0.8917} |
Test loss: 0.856 | Test result: {'accuracy': 0.8018, 'f1': 0.4142, 'auc': 0.9043} | 

### Noise multipler 0.5

Epoch 1 (Train): 100%
4511/4512 [19:50<00:00, 3.38it/s, f1=0.78, loss=0.879]
Epoch 1 (Test): 100%
1521/1521 [04:46<00:00, 4.30it/s, f1=0.445, loss=0.78]
Epoch: 1 | ɛ: [0.48 1.98 3.66] |Train loss: 0.879 | Train result: {'accuracy': 0.7828, 'f1': 0.78, 'auc': 0.862} |
Test loss: 0.780 | Test result: {'accuracy': 0.8386, 'f1': 0.4449, 'auc': 0.8977} | 

Epoch 2 (Train): 100%
4511/4512 [20:06<00:00, 3.30it/s, f1=0.81, loss=0.842]
Epoch 2 (Test): 100%
1521/1521 [04:52<00:00, 4.17it/s, f1=0.44, loss=0.899]
Epoch: 2 | ɛ: [0.58 2.15 3.9 ] |Train loss: 0.842 | Train result: {'accuracy': 0.8136, 'f1': 0.8101, 'auc': 0.8782} |
Test loss: 0.899 | Test result: {'accuracy': 0.8291, 'f1': 0.4395, 'auc': 0.9005} | 


Epoch 3 (Train): 100%
4511/4512 [20:09<00:00, 3.26it/s, f1=0.815, loss=0.832]
Epoch 3 (Test): 100%
1521/1521 [04:53<00:00, 4.13it/s, f1=0.42, loss=0.93]
Epoch: 3 | ɛ: [0.66 2.28 4.08] |Train loss: 0.832 | Train result: {'accuracy': 0.8168, 'f1': 0.8148, 'auc': 0.8785} |
Test loss: 0.930 | Test result: {'accuracy': 0.8111, 'f1': 0.4199, 'auc': 0.8953} | 


### Noise Multiplier 0.4
Epoch 2 (Train): 100%
4511/4512 [20:03<00:00, 3.33it/s, f1=0.786, loss=0.867]
Epoch 2 (Test): 100%
1521/1521 [04:52<00:00, 4.23it/s, f1=0.424, loss=0.859]
Epoch: 2 | ɛ: [1.64 4.47 7.54] |Train loss: 0.867 | Train result: {'accuracy': 0.7878, 'f1': 0.7857, 'auc': 0.8663} |
Test loss: 0.859 | Test result: {'accuracy': 0.8194, 'f1': 0.4239, 'auc': 0.9017} | 
Epoch 3 (Train): 100%
4511/4512 [19:59<00:00, 2.70it/s, f1=0.81, loss=0.832]
Epoch 3 (Test): 100%
1521/1521 [04:52<00:00, 4.22it/s, f1=0.448, loss=0.816]
Epoch: 3 | ɛ: [2.02 5.1  8.39] |Train loss: 0.832 | Train result: {'accuracy': 0.8145, 'f1': 0.81, 'auc': 0.8836} |
Test loss: 0.816 | Test result: {'accuracy': 0.8366, 'f1': 0.4476, 'auc': 0.9058} | 
Epoch 4 (Train): 100%
4511/4512 [20:03<00:00, 3.21it/s, f1=0.819, loss=0.815]
Epoch 4 (Test): 100%
1521/1521 [04:54<00:00, 4.20it/s, f1=0.42, loss=0.902]
Epoch: 4 | ɛ: [2.32 5.56 9.01] |Train loss: 0.815 | Train result: {'accuracy': 0.8203, 'f1': 0.8191, 'auc': 0.8869} |
Test loss: 0.902 | Test result: {'accuracy': 0.8101, 'f1': 0.4202, 'auc': 0.9098} | 


|Model|Train Time|Learning rate|batch size|epoch |          |  train |         |       | test    |       |Noise multiplier| Epsilon ([5e-2, 1e-3, 1e-5]) |
|-----|----------|-------------|----------|------|----------|--------|---------|-------|---------|-------|----------------|------------------------------|
|     |          |             |          |      |  loss    |  f1    |  auc    | loss  | f1      | auc   |                |                              |
|bert small with lr scheduler|27 min|1e-3|64(32)|1| 0.793   | 0.8183 |  0.8776 | 1.016 |  0.3964 | 0.9017|     0.1        |   1058.57, 1097.69, 1143.74  |
|     | 25min    |      1e-3   |          |   2  |  0.766   | 0.8297 |  0.8741 | 0.794 |  0.4338 | 0.8961|     0.1        |   2090.53, 2129.65, 2175.7   |
|bert small with lr scheduler|17 min|5e-4|32(32)|1| 0.774   | 0.815  |  0.8937 | 0.774 |  0.4765 | 0.9222|     0.1        |   1144.95, 1184.07, 1230.12  |
|     | 21min    |      5e-4   |          |   2  |  0.736   | 0.8375 |  0.9067 | 0.736 |  0.4739 | 0.9259|     0.1        |   2237.59, 2276.71, 2322.77  |
|     | 17min    |      5e-4   |          |   3  |  0.737   | 0.8325 |  0.8976 | 0.788 |  0.4352 | 0.9174|     0.1        |   3330.24, 3369.36, 3415.41  |
|bert small with lr scheduler|21 min|1e-3|128(64)|1| 0.774  | 0.8204 |  0.8937 | 0.774 |  0.4765 | 0.9222|     0.1        |   1144.95, 1184.07, 1230.12  |
|     |          |      1e-3   |          |   2  |  0.736   | 0.8375 |  0.9067 | 0.736 |  0.4739 | 0.9259|     0.1        |   2237.59, 2276.71, 2322.77  |
|     |          |      1e-3   |          |   3  |  0.737   | 0.8325 |  0.8976 | 0.788 |  0.4352 | 0.9174|     0.1        |   3330.24, 3369.36, 3415.41  |
|bert small with lr scheduler|20 min|1e-3|128(64)|1| 0.779  | 0.8086 |  0.8885 | 0.689 |  0.4790 | 0.9161|     0.2        |     28.38, 47.94, 70.96      |
|     | 21min    |      1e-3   |          |   2  |  0.758   | 0.8282 |  0.8999 | 0.689 |  0.4750 | 0.9179|     0.2        |     42.57, 64.04, 87.06      |
|     | 17min    |      1e-3   |          |   3  |  0.746   | 0.8318 |  0.8984 | 0.696 |  0.4733 | 0.9175|     0.2        |     50.56, 80.14, 103.16     |
|     | 20 min   |      1e-3   |  128(64) |   1  |  0.844   | 0.8006 |  0.8721 | 0.722 |  0.4650 | 0.9006|     0.25       |     12.38, 23.28, 34.79      |
|     |          |      1e-3   |          |   2  |  0.797   | 0.8236 |  0.8831 | 0.707 |  0.4726 | 0.9026|     0.25       |     17.11, 30.15, 42.89      |
|     |          |      1e-3   |          |   3  |  0.783   | 0.8237 |  0.8760 | 0.777 |  0.4749 | 0.8934|     0.25       |     21.32, 34.89, 50.24      |
|     | 22min    |      1e-3   |  128(64) |   1  |  0.794   | 0.7977 |  0.8814 | 0.698 |  0.4713 | 0.9025|     0.3        |     51.14, 80.88, 103.91     |
|     |          |      1e-3   |          |   2  |  0.773   | 0.8229 |  0.8894 | 0.818 |  0.4455 | 0.9055|     0.3        |     51.72, 81.63, 104.65     |
|     |          |      1e-3   |          |   3  |  0.770   | 0.8269 |  0.8864 | 0.801 |  0.4445 | 0.9018|     0.3        |     52.30, 82.37, 105.40     |

## Save model
https://pytorch.org/tutorials/beginner/saving_loading_models.html

In [22]:
model_path = 'dp_model_state_dict.pt'
torch.save(model.state_dict(), model_path)

# model.load_state_dict(torch.load(model_path))
# model.eval()