In [1]:
import os
import pickle
from collections import Counter

import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report
from sklearn.model_selection import cross_val_predict

# Загрузка данных

In [2]:
# Step one directory up so the relative paths used below ('data/...',
# 'models_files/...') resolve — presumably from the project root; confirm.
# NOTE: not idempotent — re-running this cell climbs one more level each time.
os.chdir(os.pardir)

In [3]:
# Pickled dataset; the code below reads its 'X' and 'y' entries.
data_filename = os.path.join('data', 'X_y_characteristics.pkl')

# Fail fast with a clear message. Previously a missing file left X and y as
# None, which np.array() silently turned into 0-d object arrays and the real
# cause only surfaced as a confusing error in a later cell.
if not os.path.exists(data_filename):
    raise FileNotFoundError(
        f"Dataset not found: {os.path.abspath(data_filename)} "
        f"(current working directory: {os.getcwd()})"
    )

# NOTE(security): pickle.load can execute arbitrary code embedded in the file;
# only load pickles that come from a trusted source.
with open(data_filename, 'rb') as f:
    data = pickle.load(f)

# Convert to NumPy arrays so boolean-mask and fancy indexing work downstream.
X = np.array(data['X'])
y = np.array(data['y'])

In [4]:
# Class distribution of the raw labels; used to decide which families have
# enough samples to be kept. `Counter` comes from the notebook's import cell.
counter = Counter(y)
counter

Counter({'LTR': 203,
         'Helitron': 164,
         'DNA/MuDR': 134,
         'LINE': 116,
         'DNA+': 89,
         'TEG': 34,
         'DNA/HAT': 28,
         'Mix with Helitron': 19,
         'Mix': 18,
         'Unassigned': 17,
         'RathE1/2/3_cons': 7,
         'SINE': 7})

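Фильтруем данные: оставляем только четыре самых многочисленных семейства (LTR, Helitron, DNA/MuDR, LINE), в каждом из которых больше 100 примеров.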

In [5]:
# Families kept for classification (the four most populous ones).
families_to_filter = ['LTR', 'Helitron', 'DNA/MuDR', 'LINE']

# Boolean mask: True where the label belongs to one of the selected families.
indices = np.isin(y, families_to_filter)

# Apply the same mask to features and labels so they stay aligned.
X_filtered, y_filtered = X[indices], y[indices]

Балансируем классы: в каждом классе оставляем столько примеров, сколько их в самом малочисленном классе.

In [6]:
def balance_classes(X, y, random_state=None):
    """Undersample every class to the size of the smallest one and shuffle.

    For each distinct label in ``y`` the first ``min_count`` samples (in their
    original order) are kept, where ``min_count`` is the size of the rarest
    class. The kept samples are then shuffled so classes are interleaved.

    Parameters
    ----------
    X : array-like
        Samples; indexed along the first axis.
    y : array-like
        One label per sample (hashable values), same length as ``X``.
    random_state : None, int, or RNG object, optional
        Controls the final shuffle. ``None`` (default) uses NumPy's global RNG,
        exactly as before, so ``np.random.seed`` still governs the result.
        An int seeds a private ``np.random.RandomState``. An object exposing
        ``shuffle`` (``RandomState`` / ``Generator``) is used as is.

    Returns
    -------
    (X_balanced, y_balanced) : tuple of np.ndarray
        New arrays holding ``min_count`` samples per class.

    Raises
    ------
    ValueError
        If ``X`` and ``y`` have different lengths.
    """
    X = np.asarray(X)
    y = np.asarray(y)

    if len(X) != len(y):
        raise ValueError(
            f"X and y must have the same length, got {len(X)} and {len(y)}"
        )

    # Nothing to balance; avoids min() failing on an empty sequence.
    if len(y) == 0:
        return X.copy(), y.copy()

    # Number of samples per class; the rarest class sets the target size.
    class_counts = Counter(y.tolist())
    min_count = min(class_counts.values())

    # Keep the first `min_count` positions of every class (original order).
    balanced_indices = np.concatenate(
        [np.flatnonzero(y == cls)[:min_count] for cls in class_counts]
    )

    # Pick the RNG: global state by default, otherwise a private generator.
    if random_state is None:
        rng = np.random
    elif hasattr(random_state, 'shuffle'):
        rng = random_state
    else:
        rng = np.random.RandomState(random_state)
    rng.shuffle(balanced_indices)

    # Fancy indexing returns copies, so the inputs are never modified.
    return X[balanced_indices], y[balanced_indices]

In [7]:
# Seed NumPy's global RNG so the shuffle inside balance_classes produces the
# same row order on every run. The cross-validation below splits by position,
# so without a fixed order its folds and reported scores change between runs.
np.random.seed(42)
X_balanced, y_balanced = balance_classes(X_filtered, y_filtered)

In [9]:
# Sanity check: expect (n_classes * size of the smallest class, n_features).
X_balanced.shape

(464, 109)

In [10]:
# Sanity check: one label per row of X_balanced.
y_balanced.shape

(464,)

In [13]:
# Out-of-fold predictions from 5-fold cross-validation: every sample is
# predicted by a model that did not see it during training.
clf = RandomForestClassifier(random_state=42)
y_pred_cv = cross_val_predict(estimator=clf, X=X_balanced, y=y_balanced, cv=5)

# Per-class precision / recall / F1 on the out-of-fold predictions
print(classification_report(y_true=y_balanced, y_pred=y_pred_cv))

# Contingency table: rows ('1') are the true labels, columns ('2') the predictions
data_cv = pd.DataFrame({'1': y_balanced, '2': y_pred_cv})
contingency_table_cv = pd.crosstab(index=data_cv['1'], columns=data_cv['2'])

print(contingency_table_cv)

              precision    recall  f1-score   support

    DNA/MuDR       0.62      0.64      0.63       116
    Helitron       0.70      0.68      0.69       116
        LINE       0.79      0.73      0.76       116
         LTR       0.74      0.78      0.76       116

    accuracy                           0.71       464
   macro avg       0.71      0.71      0.71       464
weighted avg       0.71      0.71      0.71       464

2         DNA/MuDR  Helitron  LINE  LTR
1                                      
DNA/MuDR        74        23     7   12
Helitron        23        79     8    6
LINE            13         4    85   14
LTR             10         7     8   91


In [19]:
# Fit the final model on the whole balanced set (the cross-validation above
# only estimated its quality).
clf = RandomForestClassifier(random_state=42)
clf.fit(X_balanced, y_balanced)

# Persist the model
file_path = 'models_files/random_forest_balanced_01.pkl'

# Create the target directory if needed; open(..., 'wb') would otherwise fail.
os.makedirs(os.path.dirname(file_path), exist_ok=True)

# Never overwrite an existing model file, but say so explicitly: a silent skip
# hides the fact that the file on disk may not match the model just trained.
if not os.path.exists(file_path):
    with open(file_path, 'wb') as f:
        pickle.dump(clf, f)
    print("Завершено")
else:
    print(f"Файл {file_path} уже существует — модель не перезаписана")

Завершено
