In [1]:
import pandas as pd

# Load the sentence-level dataset from the JSON export.
data = pd.read_json('../data/tram2-data/multi_label.json')

# Flatten the per-sentence label lists into one label per row, drop the missing
# values, and keep each distinct label once.
all_labels = (
    data['labels']
    .explode()
    .dropna()
    .unique()
)

data

Unnamed: 0,sentence,labels,doc_title
0,title: NotPetya Technical Analysis – A Triple ...,[],NotPetya Technical Analysis A Triple Threat F...
1,Executive Summary This technical analysis prov...,[],NotPetya Technical Analysis A Triple Threat F...
2,For more information on CrowdStrike’s proactiv...,[],NotPetya Technical Analysis A Triple Threat F...
3,NotPetya combines ransomware with the ability ...,[],NotPetya Technical Analysis A Triple Threat F...
4,It spreads to Microsoft Windows machines using...,[T1210],NotPetya Technical Analysis A Triple Threat F...
...,...,...,...
19173,[2] Eclypsium Blog - TrickBot Now Offers 'Tric...,[],AA21076A TrickBot Malware
19174,"Initial Version March 24, 2021:",[],AA21076A TrickBot Malware
19175,Added MITRE ATT&CK Technique T1592.003 used fo...,[],AA21076A TrickBot Malware
19176,Added new MITRE ATT&CKs and updated Table 1,[],AA21076A TrickBot Malware


In [None]:
import transformers
import torch

# Which pretrained backbone to fine-tune; must be 'bert' or 'gpt'.
mode: 'bert or gpt' = 'bert'

# Prefer the GPU but fall back to the CPU so the notebook still runs on machines
# without CUDA.  The variable keeps the name `cuda` because the other cells move
# their tensors with `.to(cuda)`.
cuda = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

if mode == 'bert':
    model = transformers.BertForSequenceClassification.from_pretrained(
        "allenai/scibert_scivocab_uncased",
        num_labels=len(all_labels),
        # Pin the loss to per-label binary cross-entropy instead of relying on
        # it being inferred from the dtype of the labels at forward time.
        problem_type="multi_label_classification",
        output_attentions=False,
        output_hidden_states=False,
    )
    # `model_max_length` is the tokenizer option that sets the length limit;
    # a `max_length` keyword passed here is not a tokenizer setting.
    tokenizer = transformers.BertTokenizer.from_pretrained("allenai/scibert_scivocab_uncased", model_max_length=512)
elif mode == 'gpt':
    model = transformers.GPT2ForSequenceClassification.from_pretrained(
        "gpt2",
        num_labels=len(all_labels),
        problem_type="multi_label_classification",
        output_attentions=False,
        output_hidden_states=False,
    )
    tokenizer = transformers.GPT2Tokenizer.from_pretrained("gpt2", model_max_length=512)
    # Reuse the end-of-sequence token for padding and tell the model which id
    # that is.
    tokenizer.pad_token = tokenizer.eos_token
    model.config.pad_token_id = tokenizer.pad_token_id
else:
    raise ValueError(f"mode must be one of bert or gpt, but is {mode = !r}")

model.train().to(cuda)

In [3]:
from sklearn.preprocessing import MultiLabelBinarizer as MLB
from sklearn.model_selection import train_test_split

# Fit the binarizer on every known label (one single-label "sample" per label)
# so the multi-hot encoding has a column for each entry of `all_labels`.
mlb = MLB()
mlb.fit([[c] for c in all_labels])

# Hold out 20% of the rows for evaluation.  The split is seeded so that a
# Restart & Run All reproduces the same train/test partition and therefore
# comparable scores.
train, test = train_test_split(data, test_size=0.2, shuffle=True, random_state=42)

def load_data(x, y, batch_size=10):
    """Yield consecutive ``(x, y)`` mini-batches moved to the ``cuda`` device.

    :x: inputs, sliced along their first dimension
    :y: targets; must have the same first-dimension length as ``x``
    :batch_size: number of rows per batch (the final batch may be shorter)

    Rows are yielded in their existing order; nothing is shuffled here.
    """
    n_rows = x.shape[0]
    assert n_rows == y.shape[0]
    for start in range(0, n_rows, batch_size):
        stop = start + batch_size
        yield x[start:stop].to(cuda), y[start:stop].to(cuda)

def tokenize(instances: 'list[str]'):
    """Encode ``instances`` with the notebook-level ``tokenizer``.

    Every text is truncated and padded to exactly 512 tokens and the result is
    requested as PyTorch tensors.  Only the ``input_ids`` are returned; the
    callers in this notebook rebuild the attention mask by comparing the ids
    with ``tokenizer.pad_token_id``.
    """
    encoded = tokenizer(
        instances,
        max_length=512,
        truncation=True,
        padding='max_length',
        return_tensors='pt',
    )
    return encoded.input_ids

def encode_labels(labels):
    """Multi-hot encode ``labels`` with the fitted notebook-level ``mlb``.

    :labels: should be the `labels` column (a Series) of the DataFrame

    Returns the binarizer's output wrapped in a ``torch.Tensor``.
    """
    multi_hot = mlb.transform(labels.to_numpy())
    return torch.Tensor(multi_hot)

In [4]:
# Token ids for every training sentence, padded to a common length.
x_train = tokenize(train['sentence'].to_list())
x_train

tensor([[  102,  7208,  4531,  ...,     0,     0,     0],
        [  102,   260, 24391,  ...,     0,     0,     0],
        [  102,  4975, 11554,  ...,     0,     0,     0],
        ...,
        [  102,   407, 14382,  ...,     0,     0,     0],
        [  102,   256,   165,  ...,     0,     0,     0],
        [  102,   111,  2057,  ...,     0,     0,     0]])

In [5]:
# Multi-hot targets for the training rows, in the same row order as `train`.
y_train = encode_labels(train.loc[:, 'labels'])
y_train

tensor([[0., 0., 0.,  ..., 1., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])

In [6]:
from statistics import mean

from tqdm import tqdm
from torch.optim import AdamW

# Optimizer over all model parameters.
optim = AdamW(model.parameters(), lr=2e-5, eps=1e-8)

# Re-enter training mode explicitly.  The evaluation cell calls `model.eval()`,
# so re-running this cell afterwards would otherwise fine-tune with dropout
# disabled.
model.train()

train_batch_size = 10
# Ceiling division: `load_data` is a generator, so tqdm needs the batch count
# up front to draw a progress bar with an ETA.
n_batches = (x_train.shape[0] + train_batch_size - 1) // train_batch_size

for epoch in range(3):
    epoch_losses = []
    for x, y in tqdm(load_data(x_train, y_train, batch_size=train_batch_size), total=n_batches):
        model.zero_grad()
        # Padding positions are masked out by comparing the ids with the pad id.
        out = model(x, attention_mask=x.ne(tokenizer.pad_token_id).to(int), labels=y)
        epoch_losses.append(out.loss.item())
        out.loss.backward()
        optim.step()
    print(f"epoch {epoch + 1} loss: {mean(epoch_losses)}")

1535it [08:00,  3.19it/s]


epoch 1 loss: 0.0539344590427263


1535it [08:08,  3.14it/s]


epoch 2 loss: 0.031110431912206478


1535it [08:08,  3.14it/s]

epoch 3 loss: 0.025711828058112278





In [None]:
from sklearn.metrics import precision_recall_fscore_support as calculate_score

# Switch the model to evaluation mode before predicting on the held-out rows.
model.eval()

# Token ids for every test sentence.
x_test = tokenize(test['sentence'].tolist())

batch_size = 20
# Accumulates raw logits on the CPU; the next cell stacks them with torch.vstack.
preds = []

# No gradients are needed for inference.
with torch.no_grad():
    for i in range(0, x_test.shape[0], batch_size):
        x = x_test[i : i + batch_size].to(cuda)
        # Mask out padding positions by comparing the ids with the pad id.
        out = model(x, attention_mask=x.ne(tokenizer.pad_token_id).to(int))
        # `extend` iterates the logits along their first dimension, so `preds`
        # gains one entry per row (presumably one per sentence - TODO confirm
        # the logits are (batch, num_labels)).
        preds.extend(out.logits.to('cpu'))


In [18]:
# Turn the collected logits into hard 0/1 decisions: sigmoid, then keep the
# labels whose score is strictly greater than 0.5.
binary_preds = (
    torch.vstack(preds)
    .sigmoid()
    .gt(.5)
    .to(int)
)

# Ground-truth label sets, re-indexed from 0 so they line up positionally with
# the predictions.
actual = test['labels'].apply(frozenset).reset_index(drop=True)

# Map each multi-hot row back to its label names and store them as frozensets
# so they can be compared with set algebra.
predicted = (
    pd.Series(mlb.inverse_transform(binary_preds))
    .apply(frozenset)
    .reset_index(drop=True)
)

results = pd.concat({'preds': predicted, 'actual': actual}, axis=1)

results

Unnamed: 0,preds,actual
0,(),()
1,(),()
2,(),(T1027)
3,(),(T1078)
4,(),()
...,...,...
3831,(),()
3832,(),()
3833,(),()
3834,(),()


In [19]:
def _count_labels(pick):
    """Tally, per label, how often it appears in the set `pick(row)` returns
    across the rows of `results`."""
    return results.apply(pick, axis=1).explode().value_counts()


# Per-label confusion counts, derived with set algebra on each row's
# predicted and actual label sets.
tp = _count_labels(lambda row: row.preds & row.actual)   # predicted and correct
fp = _count_labels(lambda row: row.preds - row.actual)   # predicted but wrong
fn = _count_labels(lambda row: row.actual - row.preds)   # missed

# Number of test rows carrying each label.
support = actual.explode().value_counts().rename('#')

# A label absent from one of the tallies gets a count of 0 there.
counts = pd.concat({'tp': tp, 'fp': fp, 'fn': fn}, axis=1).fillna(0).astype(int)

# 0/0 yields NaN, which is reported as a score of 0.
p = (counts['tp'] / (counts['tp'] + counts['fp'])).fillna(0)
r = (counts['tp'] / (counts['tp'] + counts['fn'])).fillna(0)
f1 = (2 * p * r).div(p + r)
scores = (
    pd.concat({'P': p, 'R': r, 'F1': f1}, axis=1)
    .fillna(0)
    .sort_values(by='F1', ascending=False)
)

# calculate macro scores
scores.loc['(macro)'] = scores.mean()

# calculate micro scores
micro = counts.sum()
mP = micro['tp'] / (micro['tp'] + micro['fp'])
mR = micro['tp'] / (micro['tp'] + micro['fn'])
scores.loc['(micro)', 'P'] = mP
scores.loc['(micro)', 'R'] = mR
scores.loc['(micro)', 'F1'] = (2 * mP * mR) / (mP + mR)

scores.join(support)

Unnamed: 0,P,R,F1,#
T1140,0.793103,0.75,0.77095,92.0
T1055,0.727273,0.380952,0.5,63.0
T1027,0.703704,0.277372,0.397906,137.0
T1071.001,1.0,0.166667,0.285714,24.0
T1105,0.615385,0.16,0.253968,50.0
T1059.003,1.0,0.060976,0.114943,82.0
T1552.001,0.0,0.0,0.0,6.0
T1219,0.0,0.0,0.0,10.0
T1574.002,0.0,0.0,0.0,10.0
T1113,0.0,0.0,0.0,10.0
