In [82]:
import pandas as pd
from imblearn.over_sampling import SMOTE
from collections import Counter
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import classification_report, confusion_matrix

### Train Classification Model
Classes: 0, 1, 2, 3 - relevance score provided by dataset

3. Evaluate accuracy, precision, recall, F1 SCore, AUC
4. Save model to evaluate against eval set

#### Create dataset to use for training
1. Start with X = cosine similarity and y = relvance score
2. Later add more features to X (code and docstring embeddings, results from static analysis?)

In [74]:
df = pd.read_pickle('data/sim.pickle')
df = df[['split', 'sim', 'relevance']]
df = df.dropna()
df.head()

Unnamed: 0,split,sim,relevance
0,train,0.685977,2
1,train,0.469626,2
2,train,0.49025,2
3,train,0.568199,3
4,train,0.166293,3


In [76]:
train_df = df[df.split == 'train']
test_df = df[df.split == 'test']

test_df.head()

Unnamed: 0,split,sim,relevance
291,test,0.384595,1
292,test,0.777225,1
293,test,0.285888,2
294,test,0.171154,0
296,test,0.536571,0


In [79]:
X_train = train_df.sim.values.reshape(-1, 1)
y_train = train_df.relevance.values.reshape(-1, 1)

X_test = test_df.sim.values.reshape(-1, 1)
y_test = test_df.relevance.values.reshape(-1, 1)

X_train[:5]

array([[0.68597656],
       [0.46962649],
       [0.49024951],
       [0.56819922],
       [0.16629307]])

#### Sample dataset to balance classes

In [80]:
smote = SMOTE()
X_smote, Y_smote = smote.fit_resample(X_train, y_train)

Counter(Y_smote)

Counter({2: 112, 3: 112, 0: 112, 1: 112})

#### Train model

In [88]:
clf = LogisticRegression(multi_class='ovr')

clf.fit(X_smote, Y_smote)

y_pred = clf.predict(X_test)

print(confusion_matrix(y_test, y_pred))
print(classification_report(y_test, y_pred, zero_division=True))

[[0 2 1 0]
 [0 2 5 0]
 [0 0 1 0]
 [0 1 1 0]]
              precision    recall  f1-score   support

           0       1.00      0.00      0.00         3
           1       0.40      0.29      0.33         7
           2       0.12      1.00      0.22         1
           3       1.00      0.00      0.00         2

    accuracy                           0.23        13
   macro avg       0.63      0.32      0.14        13
weighted avg       0.61      0.23      0.20        13

