In [58]:
import torch
import random
import numpy as np
from sentence_transformers.losses import CosineSimilarityLoss
from setfit import SetFitModel
from setfit import SetFitTrainer
from sklearn.metrics import classification_report, confusion_matrix
from datasets import Dataset
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

In [59]:
# Single seed shared by every random number generator used in this notebook
# (it is also passed as `random_state` to the train/test splits further down).
seed = 42

# Seed torch (CPU), the stdlib `random` module, NumPy's legacy global RNG and
# all CUDA devices, in that order.
for seed_fn in (torch.manual_seed, random.seed, np.random.seed, torch.cuda.manual_seed_all):
    seed_fn(seed)

In [65]:
# Unlabeled issues: read the CSV and wrap it in a Hugging Face `Dataset`.
df = pd.read_csv('./data_unlabeled/all_unlabeled.csv')
test_set = Dataset.from_pandas(df)

# Labeled issues: same treatment. `labeled_df` is reused later to build the
# per-class train/test splits.
labeled_df = pd.read_csv('./data_labeled/all_labeled.csv')
labeled_test_set = Dataset.from_pandas(labeled_df)

In [74]:
# Checkpoint id of the already-trained few-shot issue classifier that is
# evaluated first, before any training happens in this notebook.
pretrained_checkpoint = "PeppoCola/FewShotIssueClassifier-NLBSE23"
model = SetFitModel.from_pretrained(pretrained_checkpoint)

https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations


In [68]:
# Score the pretrained classifier on every labeled example.
y_pred = model.predict(labeled_test_set['text'])

# Ground truth comes from `labeled_df`, the frame `labeled_test_set` was built
# from (row order is the same). The previous name `labeled_test_df` is not
# defined in any visible cell and raised NameError on a fresh kernel.
print(classification_report(labeled_df['label'], y_pred, digits=4))
print(confusion_matrix(labeled_df['label'], y_pred))

              precision    recall  f1-score   support

         0.0     0.6172    0.5308    0.5707       878
         1.0     0.4030    0.7378    0.5213       431
         2.0     0.7513    0.6228    0.6810      1678
         3.0     0.3588    0.5949    0.4476        79

    accuracy                         0.6119      3066
   macro avg     0.5326    0.6216    0.5552      3066
weighted avg     0.6538    0.6119    0.6210      3066

[[ 466  101  266   45]
 [  33  318   73    7]
 [ 233  368 1045   32]
 [  23    2    7   47]]


In [69]:
# Disabled exploratory step: it would predict labels for `test_set` and plot a
# histogram of the predicted classes. Everything below is commented out, so
# this cell currently does nothing.
# y_pred = model.predict(test_set['text'])

# plt.hist(y_pred)

In [70]:
# Earlier sampling approach, kept for reference (inactive):
# bug = labeled_df.loc[labeled_df['label'] == 0]
# docs = labeled_df.loc[labeled_df['label'] == 1].sample(50)
# feature = labeled_df.loc[labeled_df['label'] == 2].sample(50)
# question = labeled_df.loc[labeled_df['label'] == 3].sample(50)

# Settings shared by every per-class split: fixed seed and 50 training rows.
split_kwargs = dict(random_state=seed, train_size=50)
label_column = labeled_df['label']

# One split per class label (the variable names suggest 0=bug, 1=docs,
# 2=feature, 3=question). Classes 0-2 also get exactly 50 test rows.
bug_train, bug_test = train_test_split(labeled_df.loc[label_column == 0], test_size=50, **split_kwargs)
docs_train, docs_test = train_test_split(labeled_df.loc[label_column == 1], test_size=50, **split_kwargs)
feature_train, feature_test = train_test_split(labeled_df.loc[label_column == 2], test_size=50, **split_kwargs)
# No test_size for label 3: every row outside the 50 training rows becomes
# test data (29 rows according to the report printed by the next cell).
question_train, question_test = train_test_split(labeled_df.loc[label_column == 3], **split_kwargs)

# Stack the per-class pieces into one training frame and one test frame.
train_parts = [bug_train, docs_train, feature_train, question_train]
test_parts = [bug_test, docs_test, feature_test, question_test]
train_df = pd.concat(train_parts, ignore_index=True)
test_df = pd.concat(test_parts, ignore_index=True)

# NOTE: `test_set` is rebound here; from now on it is the labeled held-out
# split, not the unlabeled dataset loaded earlier.
train_set = Dataset.from_pandas(train_df)
test_set = Dataset.from_pandas(test_df)

In [75]:
# Baseline: evaluate the currently loaded model on the held-out split before
# any fine-tuning is done in this notebook.
y_true = test_set['label']
y_pred = model.predict(test_set['text'])

# Per-class precision/recall/F1 first, then the raw confusion matrix.
print(classification_report(y_true, y_pred, digits=4))
print(confusion_matrix(y_true, y_pred))

              precision    recall  f1-score   support

         0.0     0.5116    0.4400    0.4731        50
         1.0     0.6136    0.5400    0.5745        50
         2.0     0.4118    0.5600    0.4746        50
         3.0     0.6667    0.5517    0.6038        29

    accuracy                         0.5196       179
   macro avg     0.5509    0.5229    0.5315       179
weighted avg     0.5373    0.5196    0.5230       179

[[22  6 19  3]
 [ 3 27 18  2]
 [ 9 10 28  3]
 [ 9  1  3 16]]


In [76]:
# Rebind `model` to a fresh sentence-transformers backbone; the classifier
# loaded earlier is no longer reachable under this name. The warning printed
# below says this checkpoint has no saved head, so the classification head
# starts from random weights.
base_checkpoint = "sentence-transformers/all-mpnet-base-v2"
model = SetFitModel.from_pretrained(base_checkpoint)

# Trainer settings collected in one place, then handed to SetFitTrainer.
trainer_config = dict(
    train_dataset=train_set,
    eval_dataset=test_set,
    loss_class=CosineSimilarityLoss,
    num_iterations=20,  # the training log below reports 8000 examples with this setting
    num_epochs=1,
)
trainer = SetFitTrainer(model=model, **trainer_config)

model_head.pkl not found on HuggingFace Hub, initialising classification head with random weights. You should TRAIN this model on a downstream task to use it for predictions and inference.


In [77]:
# Run the fine-tuning configured on `trainer`. This is slow: the recorded log
# shows 500 optimization steps at roughly 16-18 s each, about 2.4 hours total.
trainer.train()

Generating Training Pairs: 100%|███████████████| 20/20 [00:00<00:00, 157.13it/s]
***** Running training *****
  Num examples = 8000
  Num epochs = 1
  Total optimization steps = 500
  Total train batch size = 16
Epoch:   0%|                                              | 0/1 [00:00<?, ?it/s]
Iteration:   0%|                                        | 0/500 [00:00<?, ?it/s][A
Iteration:   0%|                              | 1/500 [00:20<2:48:07, 20.21s/it][A
Iteration:   0%|                              | 2/500 [00:35<2:28:38, 17.91s/it][A
Iteration:   1%|▏                             | 3/500 [00:55<2:34:28, 18.65s/it][A
Iteration:   1%|▏                             | 4/500 [01:15<2:35:57, 18.87s/it][A
Iteration:   1%|▎                             | 5/500 [01:29<2:27:09, 17.84s/it][A
Iteration:   1%|▎                             | 6/500 [01:41<2:18:05, 16.77s/it][A
Iteration:   1%|▍                             | 7/500 [02:01<2:21:32, 17.23s/it][A
Iteration:   2%|▍                  

Iteration:  19%|█████▍                       | 94/500 [26:18<1:52:15, 16.59s/it][A
Iteration:  19%|█████▌                       | 95/500 [26:36<1:52:24, 16.65s/it][A
Iteration:  19%|█████▌                       | 96/500 [26:55<1:53:08, 16.80s/it][A
Iteration:  19%|█████▋                       | 97/500 [27:06<1:50:38, 16.47s/it][A
Iteration:  20%|█████▋                       | 98/500 [27:25<1:51:27, 16.64s/it][A
Iteration:  20%|█████▋                       | 99/500 [27:45<1:52:13, 16.79s/it][A
Iteration:  20%|█████▌                      | 100/500 [27:58<1:50:45, 16.61s/it][A
Iteration:  20%|█████▋                      | 101/500 [28:18<1:51:34, 16.78s/it][A
Iteration:  20%|█████▋                      | 102/500 [28:37<1:52:03, 16.89s/it][A
Iteration:  21%|█████▊                      | 103/500 [28:57<1:52:40, 17.03s/it][A
Iteration:  21%|█████▊                      | 104/500 [29:14<1:52:33, 17.05s/it][A
Iteration:  21%|█████▉                      | 105/500 [29:34<1:53:09, 17.19s

Iteration:  38%|██████████▋                 | 191/500 [54:58<1:30:24, 17.56s/it][A
Iteration:  38%|██████████▊                 | 192/500 [55:17<1:30:34, 17.64s/it][A
Iteration:  39%|██████████▊                 | 193/500 [55:37<1:30:51, 17.76s/it][A
Iteration:  39%|██████████▊                 | 194/500 [55:52<1:29:47, 17.61s/it][A
Iteration:  39%|██████████▉                 | 195/500 [56:03<1:27:53, 17.29s/it][A
Iteration:  39%|██████████▉                 | 196/500 [56:20<1:27:33, 17.28s/it][A
Iteration:  39%|███████████                 | 197/500 [56:35<1:26:27, 17.12s/it][A
Iteration:  40%|███████████                 | 198/500 [56:54<1:26:50, 17.25s/it][A
Iteration:  40%|███████████▏                | 199/500 [57:14<1:27:12, 17.39s/it][A
Iteration:  40%|███████████▏                | 200/500 [57:34<1:27:29, 17.50s/it][A
Iteration:  40%|███████████▎                | 201/500 [57:42<1:24:57, 17.05s/it][A
Iteration:  40%|███████████▎                | 202/500 [58:02<1:25:17, 17.17s

Iteration:  58%|████████████████▏           | 288/500 [1:23:13<59:17, 16.78s/it][A
Iteration:  58%|████████████████▏           | 289/500 [1:23:33<59:34, 16.94s/it][A
Iteration:  58%|████████████████▏           | 290/500 [1:23:53<59:50, 17.10s/it][A
Iteration:  58%|████████████████▎           | 291/500 [1:24:03<58:12, 16.71s/it][A
Iteration:  58%|████████████████▎           | 292/500 [1:24:22<58:26, 16.86s/it][A
Iteration:  59%|████████████████▍           | 293/500 [1:24:42<58:33, 16.97s/it][A
Iteration:  59%|████████████████▍           | 294/500 [1:25:02<58:46, 17.12s/it][A
Iteration:  59%|████████████████▌           | 295/500 [1:25:22<58:58, 17.26s/it][A
Iteration:  59%|████████████████▌           | 296/500 [1:25:42<59:12, 17.42s/it][A
Iteration:  59%|████████████████▋           | 297/500 [1:26:02<59:18, 17.53s/it][A
Iteration:  60%|████████████████▋           | 298/500 [1:26:21<59:21, 17.63s/it][A
Iteration:  60%|████████████████▋           | 299/500 [1:26:41<59:24, 17.73s

Iteration:  77%|█████████████████████▌      | 385/500 [1:51:54<35:21, 18.45s/it][A
Iteration:  77%|█████████████████████▌      | 386/500 [1:52:07<34:33, 18.19s/it][A
Iteration:  77%|█████████████████████▋      | 387/500 [1:52:21<33:50, 17.96s/it][A
Iteration:  78%|█████████████████████▋      | 388/500 [1:52:37<33:17, 17.84s/it][A
Iteration:  78%|█████████████████████▊      | 389/500 [1:52:58<33:17, 18.00s/it][A
Iteration:  78%|█████████████████████▊      | 390/500 [1:53:18<33:11, 18.11s/it][A
Iteration:  78%|█████████████████████▉      | 391/500 [1:53:31<32:27, 17.87s/it][A
Iteration:  78%|█████████████████████▉      | 392/500 [1:53:52<32:26, 18.02s/it][A
Iteration:  79%|██████████████████████      | 393/500 [1:54:12<32:18, 18.11s/it][A
Iteration:  79%|██████████████████████      | 394/500 [1:54:32<32:10, 18.21s/it][A
Iteration:  79%|██████████████████████      | 395/500 [1:54:52<32:02, 18.31s/it][A
Iteration:  79%|██████████████████████▏     | 396/500 [1:55:12<31:52, 18.39s

Iteration:  96%|██████████████████████████▉ | 482/500 [2:21:04<05:27, 18.22s/it][A
Iteration:  97%|███████████████████████████ | 483/500 [2:21:24<05:11, 18.31s/it][A
Iteration:  97%|███████████████████████████ | 484/500 [2:21:43<04:53, 18.34s/it][A
Iteration:  97%|███████████████████████████▏| 485/500 [2:22:03<04:36, 18.41s/it][A
Iteration:  97%|███████████████████████████▏| 486/500 [2:22:23<04:18, 18.49s/it][A
Iteration:  97%|███████████████████████████▎| 487/500 [2:22:35<03:56, 18.16s/it][A
Iteration:  98%|███████████████████████████▎| 488/500 [2:22:46<03:33, 17.82s/it][A
Iteration:  98%|███████████████████████████▍| 489/500 [2:23:04<03:16, 17.84s/it][A
Iteration:  98%|███████████████████████████▍| 490/500 [2:23:19<02:56, 17.68s/it][A
Iteration:  98%|███████████████████████████▍| 491/500 [2:23:39<02:40, 17.78s/it][A
Iteration:  98%|███████████████████████████▌| 492/500 [2:23:59<02:23, 17.89s/it][A
Iteration:  99%|███████████████████████████▌| 493/500 [2:24:11<02:03, 17.60s

In [78]:
# Evaluate the trained model through the trainer, which was constructed with
# `eval_dataset=test_set`. The bare `metrics` expression shows the resulting
# dict as the cell output.
metrics = trainer.evaluate()
metrics

***** Running evaluation *****


{'accuracy': 0.7150837988826816}

In [79]:
# Persist the fine-tuned model to the local `model` directory.
# Uses the public `save_pretrained` entry point instead of the private
# `_save_pretrained` hook: the public method creates the target directory
# and then delegates to the same hook, and it is the supported API.
trainer.model.save_pretrained('model')

In [80]:
# Re-score the same held-out split now that `model` has been fine-tuned, so
# the numbers can be compared with the baseline report above.
y_true = test_set['label']
y_pred = model.predict(test_set['text'])

# Per-class precision/recall/F1 first, then the raw confusion matrix.
print(classification_report(y_true, y_pred, digits=4))
print(confusion_matrix(y_true, y_pred))

              precision    recall  f1-score   support

         0.0     0.6875    0.6600    0.6735        50
         1.0     0.8409    0.7400    0.7872        50
         2.0     0.5763    0.6800    0.6239        50
         3.0     0.8571    0.8276    0.8421        29

    accuracy                         0.7151       179
   macro avg     0.7405    0.7269    0.7317       179
weighted avg     0.7268    0.7151    0.7187       179

[[33  3 11  3]
 [ 1 37 12  0]
 [12  3 34  1]
 [ 2  1  2 24]]
