In [5]:
import csv    
import pickle
import pandas as pd
import numpy as np
import torch
from tqdm.auto import tqdm
from collections import Counter

In [6]:
from sklearn.model_selection import StratifiedKFold, train_test_split

from sklearn.linear_model import LogisticRegression
from sklearn.metrics import confusion_matrix, classification_report
from sklearn.utils import shuffle, class_weight
from imblearn.under_sampling import RandomUnderSampler 

import matplotlib.pyplot as plt

In [7]:
# Load the pre-computed embeddings. The CSV has no header row, so column names are given
# explicitly. rel_label mixes the garbage label '0' with relation ids such as 'P31', and later
# cells compare it against the *string* '0'; forcing dtype=str stops pandas' per-chunk type
# inference (low_memory parsing) from ever turning some of those '0' values into integers.
embeddings_df = pd.read_csv('trex_embeddings_20000.csv',
                            header=None,
                            names=['embedding', 'triplet', 'rel_label'],
                            dtype={'rel_label': str})

In [8]:
# Peek at the first five raw rows (embeddings are still strings at this point).
embeddings_df.head(n=5)

Unnamed: 0,embedding,triplet,rel_label
0,[6.22067600e-02 1.60363363e-03 3.17481868e-02 ...,"('Hell', 'outcast', 'Devilkin', '0')",0
1,[1.1821641e-02 2.5162415e-03 1.7185923e-02 1.9...,"('Dungeons & Dragons', 'roleplaying game', 'is...",P31
2,[2.41272449e-01 1.46490023e-01 1.23494724e-02 ...,"('England', 'Hartfield', 'Sussex', '0')",0
3,[4.42590602e-02 1.16817303e-01 8.24852288e-02 ...,"('Hartfield', 'civil parish', 'is a', 'P31')",P31
4,[4.13215496e-02 6.18278806e-04 1.09743709e-02 ...,"('area', 'miles', 'is', '0')",0


In [9]:
# Workaround for the way numpy arrays were (poorly) written into the CSV: each cell holds
# the printed text of an array, so it has to be parsed back into numbers.
def read_array(emb: str) -> np.ndarray:
    """Parse the printed text of a numpy vector back into a (1, D) float array.

    Args:
        emb: Text of a single embedding, e.g. ``"[0.1 0.2 0.3]"``. Values may be
            wrapped over several lines. Surrounding whitespace, the enclosing square
            brackets and comma separators are tolerated.

    Returns:
        A float64 array of shape ``(1, D)``, D being the number of parsed values.

    Raises:
        ValueError: if any token is not a valid float literal.
    """
    # Drop outer whitespace first so the brackets are really at the ends, then treat
    # commas as separators too (covers str(list) output as well as numpy's format).
    body = emb.strip().strip('[]').replace(',', ' ')
    # str.split() with no argument splits on any whitespace, newlines included, so
    # values wrapped across lines need no separate per-line pass.
    values = [float(token) for token in body.split()]
    return np.array(values, dtype=float).reshape(1, -1)

In [10]:
# Convert the stringified embeddings into (1, D) float arrays. Cells that are already
# ndarrays are left untouched, so re-running this cell is a no-op instead of failing
# with AttributeError (ndarray has no .strip) on the already-converted column.
embeddings_df['embedding'] = embeddings_df['embedding'].apply(
    lambda emb: emb if isinstance(emb, np.ndarray) else read_array(emb)
)

In [11]:
# Stack the per-row (1, D) arrays into an (N, D) feature matrix. np.vstack keeps the result
# 2-D even when N == 1 or D == 1 (np.array(...).squeeze() would silently drop an axis there),
# and raises a clear error if embeddings have inconsistent lengths.
embs = np.vstack(embeddings_df['embedding'].to_list())
labels = embeddings_df['rel_label'].values

assert embs.ndim == 2
assert len(embs) == len(labels)

In [12]:
# Collapse every relation id into '1' and keep the garbage label '0' as-is.
binary_labels = ['1' if label != '0' else '0' for label in labels]

In [13]:
# Class balance of the binary target.
Counter(label for label in binary_labels)

Counter({'0': 38595, '1': 35918})

### Бинарная классификация (мусор и не-мусор)

In [31]:
# Stratified 70/30 split for the garbage-vs-relation task. random_state pins the shuffle
# so the split, and therefore the reported metrics, are reproducible after a kernel restart.
X_train, X_test, y_train, y_test = train_test_split(embs, binary_labels,
                                                    stratify=binary_labels,
                                                    test_size=0.3,
                                                    random_state=42)

In [32]:
# Fit the binary (garbage vs. relation) classifier, then persist it to disk.
lr_bin = LogisticRegression(max_iter=1000)
lr_bin.fit(X_train, y_train)

with open('logreg_bin.pkl', 'wb') as file:
    pickle.dump(lr_bin, file)

In [33]:
# Per-class precision / recall / F1 of the binary model on the held-out split.
print(classification_report(y_true=y_test, y_pred=lr_bin.predict(X_test)))

              precision    recall  f1-score   support

           0       0.86      0.87      0.87     11579
           1       0.86      0.85      0.86     10775

    accuracy                           0.86     22354
   macro avg       0.86      0.86      0.86     22354
weighted avg       0.86      0.86      0.86     22354



### Многоклассовая классификация

In [19]:
# Keep only the classes that occur more than MIN_CLASS_COUNT times in the whole dataset.
MIN_CLASS_COUNT = 10
labels_counter = Counter(labels)
labels_over10 = [label for label, count in labels_counter.items() if count > MIN_CLASS_COUNT]

# Compute the row filter once and reuse it for both features and targets so they stay aligned.
frequent_mask = embeddings_df['rel_label'].isin(labels_over10)
# np.vstack keeps the (N, D) shape even for N == 1 or D == 1, unlike np.array(...).squeeze().
new_embs = np.vstack(embeddings_df.loc[frequent_mask, 'embedding'].to_list())
new_labels = embeddings_df.loc[frequent_mask, 'rel_label'].values

In [20]:
# Remove the garbage class '0': the under-sampler is asked to keep 0 samples of it,
# while classes not listed in sampling_strategy are left untouched.
ros = RandomUnderSampler(sampling_strategy={'0': 0})
X_resampled, y_resampled = ros.fit_resample(X=new_embs, y=new_labels)

In [21]:
# Number of samples left for the multiclass task.
X_resampled.shape[0]

35446

In [23]:
# Materialize the resampled data as fresh numpy arrays (np.array copies its input).
X, y = np.array(X_resampled), np.array(y_resampled)

In [24]:
# Use the first fold of a default (non-shuffled) StratifiedKFold as a stratified
# train/test split; split() is a generator, so next() takes its first (train, test) pair.
train, test = next(StratifiedKFold().split(X_resampled, y_resampled))

In [25]:
# Slice features and targets with the fold indices.
X_train, y_train = X[train], y[train]
X_test, y_test = X[test], y[test]

In [27]:
# Multiclass relation classifier. class_weight='balanced' reweights samples inversely to
# class frequency, because the relation classes are very unevenly represented.
lr = LogisticRegression(max_iter=1000, class_weight='balanced')
lr.fit(X_train, y_train)

with open('logreg_multi.pkl', 'wb') as file:
    pickle.dump(lr, file)

In [28]:
# Per-relation precision / recall / F1 of the multiclass model on the held-out fold.
print(classification_report(y_true=y_test, y_pred=lr.predict(X_test)))

              precision    recall  f1-score   support

        P102       0.04      0.43      0.07         7
        P106       0.65      0.76      0.70        29
       P1066       0.33      0.20      0.25         5
        P112       0.52      0.40      0.45        30
        P113       0.25      1.00      0.40         2
        P115       0.61      0.68      0.64        25
        P118       0.72      0.81      0.76        42
        P119       1.00      0.50      0.67         4
       P1191       1.00      0.33      0.50         3
        P123       0.14      0.33      0.20         3
        P127       0.67      0.31      0.42        26
       P1303       0.50      0.83      0.62         6
        P131       0.95      0.72      0.82      1308
       P1336       0.10      0.50      0.17         2
       P1344       0.14      1.00      0.24         3
       P1346       0.50      0.30      0.37        10
        P135       0.08      0.33      0.13         3
        P136       0.70    