# Bastion AI Real World Example
## Finetuning DistilBERT for binary classification on the SMS Spam Collection

Data preparation and training are largely based on https://towardsdatascience.com/fine-tuning-bert-for-text-classification-54e7df642894.

### Installing Bastion AI

To use this notebook, you'll need a working Bastion AI installation.
First clone our repo:
```
$ git clone git@github.com:mithril-security/bastionai.git
```
Then install the client library:
```
$ cd ./bastionai/client
$ make install
```

### Installing and importing additional packages

Let's first install the additional packages used in this notebook.
The client's makefile has already set up a virtualenv with the client dependencies for us,
so we just need to activate it and install the extra packages:

```
$ source venv/bin/activate
$ pip install transformers pandas scikit-learn ipykernel ipywidgets
```

We can now import the necessary packages and objects:

```python
import torch
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from transformers import DistilBertForSequenceClassification, DistilBertTokenizer

from bastionai.client import Connection
from bastionai.psg import expand_weights
from bastionai.utils import MultipleOutputWrapper, TensorDataset
from bastionai.pb.remote_torch_pb2 import TestConfig, TrainConfig, Empty
```

### Preparing the dataset

The dataset can be found at https://archive.ics.uci.edu/ml/machine-learning-databases/00228/smsspamcollection.zip.
Unzip the archive to obtain the dataset file (`SMSSpamCollection`):

```
$ unzip smsspamcollection.zip
```

Each line represents a sample: the label comes first, followed by a tab and the raw text:
```
ham	Go until jurong point, crazy.. Available only in bugis n great world la e buffet... Cine there got amore wat...
ham	Ok lar... Joking wif u oni...
spam	Free entry in 2 a wkly comp to win FA Cup final tkts 21st May 2005. Text FA to 87121 to receive entry question(std txt rate)T&C's apply 08452810075over18's
```

We first load the data from the file into a pandas dataframe:

```python
file_path = "../tests/data/SMSSpamCollection"

labels = []
texts = []
with open(file_path) as f:
    for line in f:
        # Each line is "<label>\t<text>\n": drop the line ending and split on the first tab only
        label, text = line.rstrip("\r\n").split("\t", 1)
        labels.append(1 if label == "spam" else 0)  # spam -> 1, ham -> 0
        texts.append(text)
df = pd.DataFrame({"label": labels, "text": texts})
df.head()
```

|   | label | text |
|---|-------|------|
| 0 | 0 | Go until jurong point, crazy.. Available only ... |
| 1 | 0 | Ok lar... Joking wif u oni... |
| 2 | 1 | Free entry in 2 a wkly comp to win FA Cup fina... |
| 3 | 0 | U dun say so early hor... U c already then say... |
| 4 | 0 | Nah I don't think he goes to usf, he lives aro... |


We then tokenize the texts with DistilBERT's tokenizer, truncating or padding every sample to 32 tokens, and obtain token ID and attention mask tensors ready to be fed to the model:

```python
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")

token_id = []
attention_masks = []
for sample in df.text.values:
    # Add [CLS]/[SEP], truncate or pad to 32 tokens, and return PyTorch tensors
    encoding_dict = tokenizer.encode_plus(
        sample,
        add_special_tokens=True,
        max_length=32,
        truncation=True,
        padding="max_length",
        return_attention_mask=True,
        return_tensors="pt"
    )
    token_id.append(encoding_dict["input_ids"])
    attention_masks.append(encoding_dict["attention_mask"])

# Stack the per-sample (1, 32) tensors into (num_samples, 32) tensors
token_id = torch.cat(token_id, dim=0)
attention_masks = torch.cat(attention_masks, dim=0)
labels = torch.tensor(df.label.values)
```
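To make the `max_length=32`, `truncation=True` and `padding="max_length"` options more concrete, here is a minimal pure-Python sketch of what they do to an already tokenized sample. It assumes the BERT uncased vocabulary ids `[CLS]`=101, `[SEP]`=102 and `[PAD]`=0; `pad_and_mask` is a hypothetical helper that ignores the real tokenizer's other steps (such as WordPiece splitting).

```python
from typing import List, Tuple

CLS_ID, SEP_ID, PAD_ID = 101, 102, 0  # assumed ids from the BERT uncased vocabulary


def pad_and_mask(token_ids: List[int], max_length: int = 32) -> Tuple[List[int], List[int]]:
    """Hypothetical helper mimicking add_special_tokens + truncation + max_length padding."""
    # Keep room for [CLS] and [SEP], dropping any tokens beyond the limit
    body = token_ids[: max_length - 2]
    ids = [CLS_ID] + body + [SEP_ID]
    # 1 marks a real token the model should attend to, 0 marks padding
    mask = [1] * len(ids)
    padding = max_length - len(ids)
    return ids + [PAD_ID] * padding, mask + [0] * padding


ids, mask = pad_and_mask([2175, 2127, 2633], max_length=8)  # arbitrary example token ids
print(ids)   # [101, 2175, 2127, 2633, 102, 0, 0, 0]
print(mask)  # [1, 1, 1, 1, 1, 0, 0, 0]
```

Because every sample ends up with exactly the same length, the per-sample tensors can be stacked with `torch.cat` into `token_id` and `attention_masks`.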

It's now time to split the data into train and test sets and to wrap each of them in a `TensorDataset` object for ease of use:

```python
test_ratio = 0.2

# Stratify on the labels so both sets keep the same ham/spam proportion
train_idx, test_idx = train_test_split(
    np.arange(len(labels)),
    test_size=test_ratio,
    shuffle=True,
    stratify=labels
)

train_set = TensorDataset([
    token_id[train_idx],
    attention_masks[train_idx]
], labels[train_idx])

test_set = TensorDataset([
    token_id[test_idx],
    attention_masks[test_idx]
], labels[test_idx])
```
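Passing `stratify=labels` matters here because the dataset is imbalanced (there are far more ham than spam messages). The sketch below illustrates the idea of a stratified split with the standard library only; `stratified_split` is a hypothetical, simplified stand-in, not scikit-learn's actual implementation.

```python
import random
from collections import defaultdict
from typing import Dict, List, Sequence, Tuple


def stratified_split(labels: Sequence[int], test_ratio: float,
                     seed: int = 0) -> Tuple[List[int], List[int]]:
    """Hypothetical helper: split indices so each class keeps its proportion in both sets."""
    rng = random.Random(seed)
    by_class: Dict[int, List[int]] = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    train_idx: List[int] = []
    test_idx: List[int] = []
    for idxs in by_class.values():
        rng.shuffle(idxs)  # shuffle within each class only
        n_test = round(len(idxs) * test_ratio)
        test_idx.extend(idxs[:n_test])
        train_idx.extend(idxs[n_test:])
    return train_idx, test_idx


def spam_share(labels: Sequence[int], idxs: List[int]) -> float:
    """Fraction of samples labeled 1 (spam) among the given indices."""
    return sum(labels[i] for i in idxs) / len(idxs)


labels = [0] * 80 + [1] * 20  # imbalanced toy labels
train_idx, test_idx = stratified_split(labels, test_ratio=0.2)
print(spam_share(labels, train_idx), spam_share(labels, test_idx))  # 0.2 0.2
```

Without stratification, a random 20% test set drawn from a small minority class could end up with noticeably more or fewer spam samples than the training set.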

### Preparing the model for use with DP-SGD and Bastion AI

We now turn to preparing the DistilBERT language model. As training will be executed remotely on a private Bastion AI server, we need to compile the model to TorchScript before sending it.

In addition, as we'll use the DP-SGD algorithm for training in this example, we need to make the model compatible with Bastion AI's DP-SGD implementation. Unlike Opacus, which uses backward hooks to compute per-sample gradients, Bastion AI relies on standard autograd and on modified layers that internally store expanded gradients: weight tensors keep the same size in memory but are manipulated through expanded views that repeat them once per sample in the batch, so the gradients of these views are per-sample gradients.
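The principle of expanded views can be illustrated with plain PyTorch on a single linear map. This is a minimal sketch of the idea only, not Bastion AI's actual layer code: the weight is stored once, viewed as `batch_size` copies with `expand`, each sample is multiplied by its own copy, and the gradient of the view therefore holds one gradient per sample.

```python
import torch

torch.manual_seed(0)
batch_size, d_in, d_out = 4, 3, 2

# The weight is stored once, as for a regular linear layer
weight = torch.randn(d_out, d_in, requires_grad=True)
x = torch.randn(batch_size, d_in)

# Expanded view: batch_size "copies" of the weight sharing the same memory
expanded = weight.unsqueeze(0).expand(batch_size, d_out, d_in)
expanded.retain_grad()  # keep the gradient of this non-leaf view after backward()

# Sample i is multiplied by copy i of the weight
out = torch.bmm(expanded, x.unsqueeze(2)).squeeze(2)  # shape (batch_size, d_out)
loss = out.pow(2).sum()  # toy loss: sum of per-sample losses
loss.backward()

per_sample_grads = expanded.grad  # one gradient per sample
print(per_sample_grads.shape)  # torch.Size([4, 2, 3])

# The regular weight gradient is the sum of the per-sample gradients
assert torch.allclose(per_sample_grads.sum(dim=0), weight.grad)
```

Only the forward pass changes (a batched matrix product against the view); autograd does the rest, which is why no backward hooks are needed.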

Per-sample gradient computation is key to DP-SGD, because each sample's gradient must be clipped individually before noise is added, and it is one of the ingredients that make DP usable with deep learning models. Fortunately, we don't need to redefine the DistilBERT architecture to swap layers: Bastion AI includes a utility function, `expand_weights`, that does this tedious job for us. Note that weights must be expanded before compiling the model to TorchScript for the changes to also apply on the server side.
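To show why per-sample gradients are needed, here is a minimal stdlib sketch of the DP-SGD aggregation step in its usual formulation: clip each per-sample gradient to an L2 norm of at most `max_grad_norm`, sum, add Gaussian noise with standard deviation `noise_multiplier * max_grad_norm`, and average. We assume, without having checked Bastion AI's code, that its `DpParameters` follow this common convention; `dp_sgd_aggregate` is a hypothetical name.

```python
import math
import random
from typing import List


def dp_sgd_aggregate(per_sample_grads: List[List[float]], max_grad_norm: float,
                     noise_multiplier: float, rng: random.Random) -> List[float]:
    """Hypothetical helper: clip per-sample gradients, sum, add Gaussian noise, average."""
    dim = len(per_sample_grads[0])
    total = [0.0] * dim
    for g in per_sample_grads:
        norm = math.sqrt(sum(v * v for v in g))
        # Scale down gradients whose L2 norm exceeds max_grad_norm (others are unchanged)
        scale = min(1.0, max_grad_norm / (norm + 1e-12))
        for j, v in enumerate(g):
            total[j] += v * scale
    sigma = noise_multiplier * max_grad_norm  # noise std calibrated to the clipping bound
    batch_size = len(per_sample_grads)
    return [(t + rng.gauss(0.0, sigma)) / batch_size for t in total]


grads = [[3.0, 4.0], [0.3, 0.4]]  # two per-sample gradients with L2 norms 5.0 and 0.5
# noise_multiplier=0.0 disables the noise so the result is deterministic
update = dp_sgd_aggregate(grads, max_grad_norm=1.0, noise_multiplier=0.0, rng=random.Random(0))
print(update)  # ≈ [0.45, 0.6]: the first gradient was clipped to [0.6, 0.8]
```

Clipping bounds how much any single sample can move the summed gradient, which is what lets the added noise hide each sample's contribution. Under this convention, the training settings used later in this notebook (`max_grad_norm=100.`, `noise_multiplier=0.001`) would correspond to a noise standard deviation of 0.1 on the summed gradient.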

As Hugging Face's models were not designed with scripting in mind, we must resort to tracing to obtain a TorchScript version of them. The model is run on a dummy but representative input while the tracer records the operations that are executed, and the recorded operations are compiled into a TorchScript graph. This approach is more error-prone (in certain cases the dummy input may not exercise some of the computation paths the model needs), but it is less picky than scripting and accepts our model.
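The pitfall mentioned above is easy to reproduce with plain PyTorch. In this sketch, `f` is a toy function of our own with a data-dependent branch: tracing records only the branch taken by the dummy input, so the traced version silently returns a different result for an input that would take the other branch (PyTorch emits a `TracerWarning` about this during tracing).

```python
import torch


def f(x: torch.Tensor) -> torch.Tensor:
    # Data-dependent control flow: tracing cannot capture both branches
    if x.sum() > 0:
        return x * 2
    return x * -1


traced = torch.jit.trace(f, torch.ones(3))  # the dummy input takes the "x * 2" branch

neg = -torch.ones(3)
print(f(neg))       # tensor([1., 1., 1.])
print(traced(neg))  # tensor([-2., -2., -2.]): the recorded branch is always used
```

This is why the dummy input should be representative of real inputs, as with the sample taken from `train_set` below.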

Note that we also need to use Bastion AI's utility wrapper for models with multiple outputs, `MultipleOutputWrapper`, to keep only the output that corresponds to the logits. This is required because Bastion AI's server supports models with an arbitrary number of inputs but only a single output.
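We don't reproduce `MultipleOutputWrapper`'s source here. The sketch below is a hypothetical equivalent, `SelectOutput`, showing the idea: wrap a model that returns a tuple (as a Hugging Face model loaded with `torchscript=True` does) and forward only one element, so the traced graph has a single tensor output. `TwoOutputs` is a toy stand-in for DistilBERT.

```python
import torch
from torch import nn


class SelectOutput(nn.Module):
    """Hypothetical equivalent of a multiple-output wrapper: keep only output `index`."""

    def __init__(self, module: nn.Module, index: int = 0):
        super().__init__()
        self.module = module
        self.index = index

    def forward(self, *inputs):
        return self.module(*inputs)[self.index]


class TwoOutputs(nn.Module):
    """Toy model taking two inputs and returning a tuple (logits, probabilities)."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 2)

    def forward(self, ids, mask):
        logits = self.linear(ids * mask)
        return logits, logits.softmax(-1)


model = SelectOutput(TwoOutputs(), 0)
example = (torch.randn(1, 4), torch.ones(1, 4))
traced = torch.jit.trace(model, example)
print(traced(*example).shape)  # torch.Size([1, 2]): a single tensor, the logits
```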

```python
# Do not display warnings about layers not initialized
# with pretrained weights (the classification layers, this is expected)
# nor warnings originating from torch.jit.trace
from transformers import logging
logging.set_verbosity_error()
import warnings
warnings.filterwarnings("ignore")

batch_size = 4

model = DistilBertForSequenceClassification.from_pretrained(
    "distilbert-base-uncased",
    num_labels=2,
    output_attentions=False,
    output_hidden_states=False,
    torchscript=True  # Prepare the model for TorchScript (it then returns tuples)
)
expand_weights(model, batch_size)  # Convert learnable layers into their expanded counterparts

[text, mask], label = train_set[0]  # Dummy input used to trace the model
traced_model = torch.jit.trace(  # Compile the model with the tracing strategy
    MultipleOutputWrapper(model, 0),  # Wrap the model to use the first output only (and drop the others)
    (
        text.unsqueeze(0),  # Add a batch dimension of size 1
        mask.unsqueeze(0)
    )
)
```

### Sending dataset and model and training on the server

Before proceeding, we need to start a local Bastion AI server, which can be done with the following commands,
assuming you have a working Rust toolchain (https://www.rust-lang.org/tools/install):

```
$ cd ../server/bastionai_app
$ cargo run
```
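Compiling the server can take a while. Before opening a `Connection`, you may want to check that something is listening on the port the client uses (50051). `server_is_up` is a hypothetical stdlib helper, not part of Bastion AI; it only tests TCP reachability, not that the listening process is actually a Bastion AI server.

```python
import socket


def server_is_up(host: str = "localhost", port: int = 50051, timeout: float = 1.0) -> bool:
    """Hypothetical helper: return True if something accepts TCP connections on host:port."""
    try:
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:  # connection refused, timeout, unknown host...
        return False


if not server_is_up():
    print("No server listening on localhost:50051 yet")
```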

Once the server has been compiled and started, it's time to send the dataset and the model to it.

In both cases, the API returns a reference to the uploaded object (a UUID), which we then use to refer to that object in subsequent calls, such as the training request.

```python
with Connection("localhost", 50051) as client:
    model_ref = client.send_model(
        traced_model,
        "Expanded DistilBERT",
        b"secret"
    )
    print(f"Model ref: {model_ref}")

    train_dataset_ref = client.send_dataset(
        train_set,
        "SMSSpamCollection",
        b"secret"
    )
    print(f"Dataset ref: {train_dataset_ref}")

    client.train(TrainConfig(
        model=model_ref,
        dataset=train_dataset_ref,
        batch_size=batch_size,
        epochs=2,
        device="cpu",
        metric="cross_entropy",
        differential_privacy=TrainConfig.DpParameters(
            max_grad_norm=100.,
            noise_multiplier=0.001
        ),
        # standard=Empty(),
        adam=TrainConfig.Adam(
            learning_rate=5e-5,
            beta_1=0.9,
            beta_2=0.999,
            epsilon=1e-8,
            weight_decay=0,
            amsgrad=False
        )
    ))
```

If no server is listening on `localhost:50051`, this cell fails with a gRPC `_InactiveRpcError` (`status = StatusCode.UNAVAILABLE`, `details = "failed to connect to all addresses"`).