# SciBERT for Multi-Label Classification

[![image](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/center-for-threat-informed-defense/tram/blob/main/user_notebooks/fine_tune_multi_label.ipynb)

This notebook lets you continue fine-tuning our provided SciBERT multi-label sequence classification model on your own data.

To start, select `Runtime > Change runtime type` and, under `Hardware accelerator`, choose `GPU`. Then run the next two cells. The first cell downloads the model and installs the Python dependencies. The second cell instantiates the label encoder.

In [None]:
!mkdir -p scibert_multi_label_model
!wget https://ctidtram.blob.core.windows.net/tram-models/multi-label-20230803/config.json -O scibert_multi_label_model/config.json
!wget https://ctidtram.blob.core.windows.net/tram-models/multi-label-20230803/pytorch_model.bin -O scibert_multi_label_model/pytorch_model.bin
!pip install torch transformers pandas scikit-learn tqdm

This cell instantiates the label encoder. Do not modify this cell, as the classes (i.e., ATT&CK techniques) and their order must match those the model expects.

In [2]:
from sklearn.preprocessing import MultiLabelBinarizer as MLB

CLASSES = [
   'T1003.001', 'T1005', 'T1012', 'T1016', 'T1021.001', 'T1027',
   'T1033', 'T1036.005', 'T1041', 'T1047', 'T1053.005', 'T1055',
   'T1056.001', 'T1057', 'T1059.003', 'T1068', 'T1070.004',
   'T1071.001', 'T1072', 'T1074.001', 'T1078', 'T1082', 'T1083',
   'T1090', 'T1095', 'T1105', 'T1106', 'T1110', 'T1112', 'T1113',
   'T1140', 'T1190', 'T1204.002', 'T1210', 'T1218.011', 'T1219',
   'T1484.001', 'T1518.001', 'T1543.003', 'T1547.001', 'T1548.002',
   'T1552.001', 'T1557.001', 'T1562.001', 'T1564.001', 'T1566.001',
   'T1569.002', 'T1570', 'T1573.001', 'T1574.002'
]

mlb = MLB(classes=CLASSES)
mlb.fit([[c] for c in CLASSES])

mlb
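
To make the role of the class order concrete, here is a minimal pure-Python sketch of the multi-hot encoding that the label encoder produces. It is an illustration only: `DEMO_CLASSES` and `multi_hot` are hypothetical names, and the sketch uses just the first four entries of `CLASSES` rather than all 50.

```python
# Illustrative subset of the notebook's CLASSES list (the real list has 50 entries).
DEMO_CLASSES = ['T1003.001', 'T1005', 'T1012', 'T1016']


def multi_hot(labels, classes):
    """Return a 0/1 list with a 1 at each position whose class is in `labels`."""
    present = set(labels)
    return [1 if c in present else 0 for c in classes]


row_a = multi_hot(['T1005', 'T1016'], DEMO_CLASSES)
row_b = multi_hot([], DEMO_CLASSES)  # a negative example has no labels

print(row_a)  # [0, 1, 0, 1]
print(row_b)  # [0, 0, 0, 0]
```

Each position in the vector corresponds to one output of the model, which is why reordering or editing `CLASSES` would silently mislabel the predictions.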

This cell loads the training data, and you will need to modify it to load your own. By the end of the cell, the variable `data` must hold a DataFrame with a `sentence` column containing the sentences and a `labels` column containing lists (or other container types) of strings, where each string is an ATT&CK ID that this model can classify. A list can be empty for a negative example. The DataFrame's index and any additional columns do not matter.

For demonstration purposes, we use the same multi-label data that was produced during this TRAM effort, even though the model has already been trained on it. The cell is present only to show the expected format of the `data` DataFrame and is not intended to be run as shown.

In [3]:
import pandas as pd
data = pd.read_json('multi_label.json').drop(columns='doc_title').head(500)
data

| | sentence | labels |
|---|---|---|
| 0 | title: NotPetya Technical Analysis – A Triple ... | [] |
| 1 | Executive Summary This technical analysis prov... | [] |
| 2 | For more information on CrowdStrike’s proactiv... | [] |
| 3 | NotPetya combines ransomware with the ability ... | [] |
| 4 | It spreads to Microsoft Windows machines using... | [T1210] |
| ... | ... | ... |
| 495 | The following is a list of services, hardcoded... | [] |
| 496 | QuickBooks. | [] |
| 497 | FCSmemtasmepocsPDVFSServiceQBCFMonitorServiceQ... | [] |
| 498 | #recycle$Recycle. | [] |
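
Before training on your own data, it may help to check that every label is one the model knows. The following is a minimal sketch under stated assumptions: `find_unknown_labels` is a hypothetical helper, `DEMO_CLASSES` is a three-item stand-in for `CLASSES`, and `T9999` is a made-up ID used only to trigger the check.

```python
def find_unknown_labels(label_lists, known_classes):
    """Map row position -> sorted labels that are not in `known_classes`."""
    known = set(known_classes)
    problems = {}
    for pos, labels in enumerate(label_lists):
        unknown = sorted(set(labels) - known)
        if unknown:
            problems[pos] = unknown
    return problems


# Stand-in for CLASSES and for the `labels` column; 'T9999' is a made-up ID.
DEMO_CLASSES = ['T1210', 'T1027', 'T1140']
demo_labels = [[], ['T1210'], ['T1027', 'T9999'], []]

problems = find_unknown_labels(demo_labels, DEMO_CLASSES)
print(problems)  # {2: ['T9999']}
```

In the notebook, the equivalent call would be `find_unknown_labels(data['labels'], CLASSES)`. An empty result means every label in your data is among the 50 techniques the model can classify.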


In [None]:
import transformers
import torch

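# Assumes a GPU runtime was selected as described at the top of the notebook
cuda = torch.device('cuda')

# The tokenizer comes from the base SciBERT checkpoint; the model weights come from the files downloaded in the first cell
tokenizer = transformers.BertTokenizer.from_pretrained("allenai/scibert_scivocab_uncased", max_length=512)
bert = transformers.BertForSequenceClassification.from_pretrained('scibert_multi_label_model').to(cuda).train()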
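
The cell above moves the model to `cuda` unconditionally, so it fails when no GPU runtime is selected. If you want a fallback, one possible sketch is shown below. The helper name `pick_device_name` is hypothetical, and training on CPU is likely to be very slow.

```python
import importlib.util


def pick_device_name():
    """Return 'cuda' when PyTorch is installed and sees a GPU, otherwise 'cpu'."""
    if importlib.util.find_spec("torch") is None:
        return "cpu"  # PyTorch is not installed in this environment
    import torch
    return "cuda" if torch.cuda.is_available() else "cpu"


device_name = pick_device_name()
print(device_name)
```

Because the rest of the notebook refers to the variable `cuda`, the least invasive way to use this would be `cuda = torch.device(pick_device_name())`.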

In [5]:
from sklearn.model_selection import train_test_split

train, test = train_test_split(data, test_size=0.2, shuffle=True)

def _load_data(x, y, batch_size=10):
    x_len, y_len = x.shape[0], y.shape[0]
    assert x_len == y_len
    for i in range(0, x_len, batch_size):
        slc = slice(i, i + batch_size)
        yield x[slc].to(cuda), y[slc].to(cuda)

def _tokenize(instances: list[str]):
    return tokenizer(instances, return_tensors='pt', padding='max_length', truncation=True, max_length=512).input_ids

def _encode_labels(labels):
    """:labels: should be the `labels` column (a Series) of the DataFrame"""
    return torch.Tensor(mlb.transform(labels.to_numpy()))
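
The batching done by `_load_data` can be sketched without any tensors. The hypothetical helper `batch_bounds` below yields the same row ranges that `_load_data` slices, and shows that the final batch is shorter when the row count is not a multiple of the batch size.

```python
def batch_bounds(n_rows, batch_size=10):
    """Yield (start, stop) row ranges; the last one may be shorter than batch_size."""
    for start in range(0, n_rows, batch_size):
        yield start, min(start + batch_size, n_rows)


bounds = list(batch_bounds(25, batch_size=10))
print(bounds)  # [(0, 10), (10, 20), (20, 25)]

# With the 500 demonstration rows and test_size=0.2, 400 rows remain for training.
steps_per_epoch = len(list(batch_bounds(400, batch_size=10)))
print(steps_per_epoch)  # 40
```

Note that `_load_data` walks the rows in the same order on every call, so the data is shuffled once by `train_test_split` and not again between epochs.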

In [6]:
x_train = _tokenize(train['sentence'].tolist())

In [7]:
y_train = _encode_labels(train['labels'])
y_train

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 1., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])

This cell contains the training loop. You may change `NUM_EPOCHS` to any positive integer.

In [None]:
NUM_EPOCHS = 3

from statistics import mean

from tqdm import tqdm
from torch.optim import AdamW

optim = AdamW(bert.parameters(), lr=2e-5, eps=1e-8)

for epoch in range(NUM_EPOCHS):
    epoch_losses = []
    for x, y in tqdm(_load_data(x_train, y_train, batch_size=10)):
        bert.zero_grad()
        out = bert(x, attention_mask=x.ne(tokenizer.pad_token_id).to(int), labels=y)
        epoch_losses.append(out.loss.item())
        out.loss.backward()
        optim.step()
    print(f"epoch {epoch + 1} loss: {mean(epoch_losses)}")
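
The training loop builds the attention mask with `x.ne(tokenizer.pad_token_id).to(int)`, which marks every non-padding token with 1 and every padding token with 0. Here is the same idea on plain lists. The token IDs are made up, and `PAD_ID = 0` is an assumption for illustration; the notebook reads the real value from the tokenizer.

```python
PAD_ID = 0  # assumed pad ID for this illustration only


def attention_mask(input_ids, pad_id=PAD_ID):
    """Return 1 for real tokens and 0 for padding, row by row."""
    return [[int(tok != pad_id) for tok in row] for row in input_ids]


# Two made-up token sequences, the first padded to length 5.
demo_ids = [
    [102, 2456, 103, 0, 0],
    [102, 3000, 4000, 5000, 103],
]

mask = attention_mask(demo_ids)
print(mask)  # [[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]]
```

Since `_tokenize` pads every sentence to 512 tokens, most positions in a typical row are padding, and the mask keeps the model from attending to them.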

If the loss after the last epoch is not to your liking, do not re-run the previous cell, since doing so re-creates the optimizer. Instead, uncomment the following cell and run it for as many additional epochs as you would like.

In [None]:
# NUM_EXTRA_EPOCHS = 1
# for epoch in range(NUM_EXTRA_EPOCHS):
#     epoch_losses = []
#     for x, y in tqdm(_load_data(x_train, y_train, batch_size=10)):
#         bert.zero_grad()
#         out = bert(x, attention_mask=x.ne(tokenizer.pad_token_id).to(int), labels=y)
#         epoch_losses.append(out.loss.item())
#         out.loss.backward()
#         optim.step()
#     print(f"epoch {epoch + 1} loss: {mean(epoch_losses)}")

The next cells evaluate the model after the additional fine-tuning. Scores on the example data will be high because the model has already been trained on most of these instances.

In [9]:
bert.eval()

x_test = _tokenize(test['sentence'].tolist())
y_test = _encode_labels(test['labels'])

batch_size = 20
preds = []

with torch.no_grad():
    for i in range(0, x_test.shape[0], batch_size):
        x = x_test[i : i + batch_size].to(cuda)
        out = bert(x, attention_mask=x.ne(tokenizer.pad_token_id).to(int))
        preds.extend(out.logits.to('cpu'))

binary_preds = torch.vstack(preds).sigmoid().gt(.5).to(int)

preds_series = pd.Series(mlb.inverse_transform(binary_preds)).apply(frozenset)
actual_series = pd.Series(mlb.inverse_transform(y_test)).apply(frozenset)
results = pd.concat({'predicted': preds_series, 'actual': actual_series}, axis=1)

results

| | predicted | actual |
|---|---|---|
| 0 | (T1562.001) | (T1562.001, T1484.001) |
| 1 | () | () |
| 2 | () | () |
| 3 | () | () |
| 4 | () | (T1106) |
| ... | ... | ... |
| 95 | () | () |
| 96 | () | () |
| 97 | (T1140) | (T1027, T1140) |
| 98 | () | () |
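
The evaluation cell turns logits into label sets by applying a sigmoid to each logit and keeping the classes whose probability exceeds 0.5. A minimal pure-Python sketch of that step follows. The logits are made-up values, `decode` is a hypothetical helper, and `DEMO_CLASSES` stands in for `CLASSES`.

```python
import math


def sigmoid(z):
    """Logistic function mapping a logit to a probability."""
    return 1.0 / (1.0 + math.exp(-z))


def decode(logits, classes, threshold=0.5):
    """Return the classes whose sigmoid probability exceeds `threshold`."""
    return {c for c, z in zip(classes, logits) if sigmoid(z) > threshold}


DEMO_CLASSES = ['T1027', 'T1106', 'T1140']
demo_logits = [-2.0, 0.3, 4.1]  # made-up values for one sentence

predicted = decode(demo_logits, DEMO_CLASSES)
strict = decode(demo_logits, DEMO_CLASSES, threshold=0.9)

print(sorted(predicted))  # ['T1106', 'T1140']
print(sorted(strict))     # ['T1140']
```

Each class is decided independently, so a sentence can receive zero, one, or several techniques. Raising the threshold trades recall for precision, which may be worth exploring if your data produces many false positives.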


In [11]:
tp = results.apply((lambda r: r.predicted & r.actual), axis=1).explode().value_counts()
fp = results.apply((lambda r: r.predicted - r.actual), axis=1).explode().value_counts()
fn = results.apply((lambda r: r.actual - r.predicted), axis=1).explode().value_counts()
counts = pd.concat({'tp': tp, 'fp': fp, 'fn': fn}, axis=1).fillna(0).astype(int)

support = actual_series.explode().value_counts().rename('#')

p = counts.tp.div(counts.tp + counts.fp).fillna(0)
r = counts.tp.div(counts.tp + counts.fn).fillna(0)
f1 = (2 * p * r) / (p + r)

scores = pd.concat({'P': p, 'R': r, 'F1': f1}, axis=1).fillna(0).sort_values(by='F1', ascending=False)

# calculate macro scores
scores.loc['(macro)'] = scores.mean()

# calculate micro scores
micro = counts.sum()
scores.loc['(micro)', 'P'] = mP = micro.tp / (micro.tp + micro.fp)
scores.loc['(micro)', 'R'] = mR = micro.tp / (micro.tp + micro.fn)
scores.loc['(micro)', 'F1'] = (2 * mP * mR) / (mP + mR)

scores.join(support)

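| | P | R | F1 | # |
|---|---|---|---|---|
| T1059.003 | 1.0 | 1.0 | 1.0 | 2.0 |
| T1140 | 1.0 | 1.0 | 1.0 | 2.0 |
| T1562.001 | 1.0 | 1.0 | 1.0 | 1.0 |
| T1055 | 1.0 | 1.0 | 1.0 | 1.0 |
| T1574.002 | 1.0 | 1.0 | 1.0 | 1.0 |
| T1053.005 | 0.75 | 1.0 | 0.857143 | 3.0 |
| T1027 | 1.0 | 0.5 | 0.666667 | 8.0 |
| T1106 | 1.0 | 0.333333 | 0.5 | 6.0 |
| T1484.001 | 0.0 | 0.0 | 0.0 | 4.0 |
| T1005 | 0.0 | 0.0 | 0.0 | 2.0 |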
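
To see how the macro and micro rows differ, here is a small worked example in plain Python. The `tp`/`fp`/`fn` counts are made up for illustration and are not taken from the run above; `prf` is a hypothetical helper that treats undefined ratios as 0, mirroring the `fillna(0)` calls in the cell.

```python
def prf(tp, fp, fn):
    """Precision, recall and F1 from raw counts; undefined ratios count as 0."""
    p = tp / (tp + fp) if tp + fp else 0.0
    r = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * p * r / (p + r) if p + r else 0.0
    return p, r, f1


# Made-up (tp, fp, fn) counts for two classes.
demo_counts = {'T1027': (8, 2, 0), 'T1106': (1, 0, 3)}

# Macro: score each class separately, then average the scores.
per_class = {c: prf(*v) for c, v in demo_counts.items()}
macro_f1 = sum(f1 for _, _, f1 in per_class.values()) / len(per_class)

# Micro: sum the counts over all classes first, then score once.
totals = [sum(col) for col in zip(*demo_counts.values())]
micro_p, micro_r, micro_f1 = prf(*totals)

print(round(macro_f1, 4), round(micro_f1, 4))  # 0.6444 0.7826
```

Macro averaging weights every class equally, so a rare class with poor recall pulls the score down sharply. Micro averaging weights every prediction equally, so frequent classes dominate.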
