In [None]:
%pip install pandas sentence-transformers setfit scikit-learn datasets

Defaulting to user installation because normal site-packages is not writeable
Note: you may need to restart the kernel to use updated packages.


In [None]:
import pandas as pd
import json
import os
from setfit import SetFitModel, SetFitTrainer
from sentence_transformers.losses import CosineSimilarityLoss
from datasets import Dataset
from sklearn.metrics import classification_report
from collections import defaultdict

In [None]:
# Sentence-transformers checkpoint loaded by `SetFitModel.from_pretrained` for every repo.
BASE_MODEL = "all-mpnet-base-v2"
# Random seed for the experiment.
# NOTE(review): make sure the training cell actually passes this on.
RANDOM_SEED = 42
# Directory that the final results JSON is written into.
OUTPUT_PATH = "output"

In [None]:
# Load the pre-split train / test issue tables.
train_set, test_set = (
    pd.read_csv("data/issues_train.csv"),
    pd.read_csv("data/issues_test.csv"),
)

In [None]:
# Distinct repositories in the training split.
# Sorted (not a set) so that later cells iterating over `repos` report repos in
# the same order on every run: a set of strings iterates in an order that
# depends on per-process hash randomization.
repos = sorted(train_set["repo"].unique())
print(repos)

{'bitcoin/bitcoin', 'microsoft/vscode', 'opencv/opencv', 'tensorflow/tensorflow', 'facebook/react'}


In [None]:
# Per-repo class counts of the training split (repos as rows, labels as columns).
pd.crosstab(train_set["repo"], train_set["label"])

label,bug,feature,question
repo,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
bitcoin/bitcoin,100,100,100
facebook/react,100,100,100
microsoft/vscode,100,100,100
opencv/opencv,100,100,100
tensorflow/tensorflow,100,100,100


In [None]:
def process_dataset(dataset):
    """Build the model input text for each issue and keep only the needed columns.

    ``text`` is ``title + " " + body``. Missing titles/bodies are treated as
    empty strings, because concatenating a string with NaN yields NaN and the
    row would end up with no text at all.

    The input DataFrame is not modified; a new DataFrame is returned.

    Args:
        dataset: DataFrame with ``title``, ``body``, ``label`` and ``repo`` columns.

    Returns:
        New DataFrame with exactly the columns ``text``, ``label``, ``repo``.
    """
    title = dataset['title'].fillna("").astype(str)
    body = dataset['body'].fillna("").astype(str)
    # `assign` returns a copy, so the caller's frame is left untouched.
    return dataset.assign(text=title + " " + body)[['text', 'label', 'repo']]

In [None]:
# Replace both raw frames with their processed (text/label/repo) versions.
# Not idempotent: the processed frames no longer have `title`/`body`, so re-run
# the loading cell before re-running this one.
train_set, test_set = map(process_dataset, (train_set, test_set))

In [None]:
def group_by_repo(dataset, label_names=None):
    """Split an issues DataFrame into one Hugging Face ``Dataset`` per repo.

    All subsets share a single ``ClassLabel``. Encoding each subset separately
    would give different integer ids whenever a subset is missing a class, so
    train/test predictions would be compared against mismatched ids.

    Args:
        dataset: DataFrame with ``text``, ``label`` and ``repo`` columns (the
            output of ``process_dataset``).
        label_names: Label strings in id order. Defaults to the sorted unique
            labels of ``dataset``. Pass the same list for every split so that
            ids agree across splits.

    Returns:
        dict mapping each repo name to a ``Dataset`` whose ``label`` column is a
        ``ClassLabel`` over ``label_names``.

    Raises:
        ValueError: if ``dataset`` contains a label not in ``label_names``.
    """
    from datasets import ClassLabel

    if label_names is None:
        label_names = sorted(dataset["label"].unique())
    label_feature = ClassLabel(names=list(label_names))
    # Encode strings -> ids once, with the shared vocabulary (raises on unknown labels).
    encoded = dataset.assign(label=dataset["label"].map(label_feature.str2int))
    return {
        # preserve_index=False: the filtered frame has a non-Range index, which
        # would otherwise be stored as an extra `__index_level_0__` column.
        repo: Dataset.from_pandas(
            encoded[encoded["repo"] == repo], preserve_index=False
        ).cast_column("label", label_feature)
        for repo in encoded["repo"].unique()
    }


# One label vocabulary shared by both splits, so e.g. "bug" gets the same id in
# every repo's train and test Dataset.
label_names = sorted(set(train_set["label"]) | set(test_set["label"]))
train_sets = group_by_repo(train_set, label_names)
test_sets = group_by_repo(test_set, label_names)

In [None]:
# Pair every repo's train split with its test split (KeyError if a repo has no test rows).
datasets = {
    repo: {"train": repo_train, "test": test_sets[repo]}
    for repo, repo_train in train_sets.items()
}

In [None]:
# Fine-tune one SetFit model per repository and evaluate it on that repo's test split.
# `results[repo]` holds the sklearn classification report and the raw predictions;
# `results['label_mapping']` holds the single label -> id mapping shared by all repos.
# NOTE(review): `SetFitTrainer` is deprecated in setfit>=1.0 in favour of
# `Trainer` + `TrainingArguments`; keep it only if the installed setfit is older.
results = defaultdict(dict)
for repo, splits in datasets.items():
    # Repo-scoped names: reusing `train_set` / `test_set` here would overwrite the
    # global DataFrames with HF Datasets and break re-running the earlier cells.
    repo_train, repo_test = splits['train'], splits['test']

    label_feature = repo_train.features["label"]
    # Predictions and references are compared as integer ids, so both splits
    # must use the same encoding.
    if repo_test.features["label"].names != label_feature.names:
        raise ValueError(
            f"{repo}: train labels {label_feature.names} != test labels "
            f"{repo_test.features['label'].names}"
        )
    label_mapping = {label_feature.int2str(i): i for i in range(label_feature.num_classes)}
    # Only one mapping is stored for all repos; refuse to silently replace it
    # with a different one.
    if results.get('label_mapping', label_mapping) != label_mapping:
        raise ValueError(
            f"{repo}: label mapping {label_mapping} differs from "
            f"{results['label_mapping']} used by other repos"
        )

    model = SetFitModel.from_pretrained(BASE_MODEL)
    trainer = SetFitTrainer(
        model=model,
        train_dataset=repo_train,
        loss_class=CosineSimilarityLoss,
        metric="accuracy",
        batch_size=16,
        num_epochs=1,
        num_iterations=20,
        seed=RANDOM_SEED,  # use the seed from the config cell
    )
    trainer.train()

    y_pred = trainer.model.predict(repo_test['text'])
    results[repo]['metrics'] = classification_report(
        repo_test['label'], y_pred, digits=4, output_dict=True
    )
    results[repo]['predictions'] = y_pred.tolist()
    results['label_mapping'] = label_mapping

In [None]:
# Shared label -> id mapping, followed by each repo's name and its JSON metrics report.
print(results['label_mapping'])
for repo in repos:
    print(repo, json.dumps(results[repo]['metrics'], indent=4), sep="\n")

{'bug': 0, 'feature': 1, 'question': 2}
bitcoin/bitcoin
{
    "0": {
        "precision": 0.7604166666666666,
        "recall": 0.73,
        "f1-score": 0.7448979591836735,
        "support": 100.0
    },
    "1": {
        "precision": 0.8723404255319149,
        "recall": 0.82,
        "f1-score": 0.8453608247422681,
        "support": 100.0
    },
    "2": {
        "precision": 0.6454545454545455,
        "recall": 0.71,
        "f1-score": 0.6761904761904762,
        "support": 100.0
    },
    "accuracy": 0.7533333333333333,
    "macro avg": {
        "precision": 0.759403879217709,
        "recall": 0.7533333333333333,
        "f1-score": 0.7554830867054726,
        "support": 300.0
    },
    "weighted avg": {
        "precision": 0.759403879217709,
        "recall": 0.7533333333333333,
        "f1-score": 0.7554830867054726,
        "support": 300.0
    }
}
microsoft/vscode
{
    "0": {
        "precision": 0.8484848484848485,
        "recall": 0.84,
        "f1-score": 0.844

In [None]:
# Unweighted mean, across repos, of each repo's macro-averaged F1 score.
f1_scores = [
    repo_metrics['macro avg']['f1-score']
    for repo_metrics in (results[name]['metrics'] for name in repos)
]
mean_score = sum(f1_scores) / len(f1_scores)

print(f"Mean F1 score: {mean_score}")

Mean F1 score: 0.8270463747086945


In [None]:
# Persist per-repo metrics, predictions and the label mapping as JSON.
# Create the output directory first: open(..., 'w') does not create missing
# parent directories and would raise FileNotFoundError on a fresh checkout.
os.makedirs(OUTPUT_PATH, exist_ok=True)
output_file_name = 'results.json'
with open(os.path.join(OUTPUT_PATH, output_file_name), 'w', encoding='utf-8') as fp:
    json.dump(results, fp)