# In [2]:
from scipy.stats import linregress
import numpy as np
import itertools
import matplotlib.pyplot as plt
from lightgbm import LGBMClassifier
from sklearn.model_selection import cross_val_score, StratifiedKFold
from sklearn.metrics import classification_report, confusion_matrix


def get_linear_reg(measures):
    """Fit a straight line to ``measures`` against their positions.

    The x-values are the indices ``0 .. len(measures) - 1``.

    Returns:
        A ``(slope, r_value)`` tuple: the slope of the fitted line and the
        correlation coefficient reported by ``scipy.stats.linregress``.
    """
    positions = range(len(measures))
    fit = linregress(positions, measures)
    return fit.slope, fit.rvalue


def get_quadratic_reg(measures):
    """Fit a degree-2 polynomial to ``measures`` against their positions.

    The x-values are the indices ``0 .. len(measures) - 1``. ``np.polyfit``
    returns coefficients highest power first, so for the fitted curve
    ``a*x**2 + b*x + c`` this returns the quadratic and linear terms only.

    Returns:
        ``(a, b)`` on success, or ``(None, None)`` when the fit cannot be
        computed (for example empty, ``None`` or non-numeric input, or a
        least-squares solve that fails).
    """
    try:
        positions = range(len(measures))
        a, b, _ = np.polyfit(positions, measures, 2)
    except (TypeError, ValueError, np.linalg.LinAlgError):
        # Best-effort feature extraction: a series that cannot be fitted
        # yields placeholder values. Only fit-related errors are caught so
        # that KeyboardInterrupt / SystemExit still propagate.
        return (None, None)
    return (a, b)
    

def fit_model(X_train, y_train, X_eval, y_eval):
    """Train a LightGBM multiclass classifier and return the result of ``fit``.

    The classifier is configured for a 3-class ``multiclass`` objective with
    300 iterations and a learning rate of 0.05. ``(X_eval, y_eval)`` is passed
    as the evaluation set, scored with the ``'logloss'`` metric.

    Args:
        X_train: Training features.
        y_train: Training labels.
        X_eval: Evaluation features.
        y_eval: Evaluation labels.

    Returns:
        Whatever ``LGBMClassifier.fit`` returns for the configured model.
    """
    hyperparams = {
        'objective': 'multiclass',
        'num_iterations': 300,
        'learning_rate': 0.05,
        'verbose': 0,
        'num_class': 3,
        'force_col_wise': True,
    }
    classifier = LGBMClassifier(**hyperparams)

    fitted = classifier.fit(
        X_train,
        y_train,
        eval_set=(X_eval, y_eval),
        eval_metric='logloss',
    )
    return fitted
    
    
def eval_pred(clf, X_test, y_test):
    """Plot a confusion matrix and print a classification report for ``clf``.

    Predictions come from ``clf.predict(X_test)``. The confusion matrix is
    computed over the fixed labels ``[0, 1, 2]`` and drawn with
    ``plot_confusion_matrix``; the classification report is printed to stdout.

    Args:
        clf: Fitted classifier exposing ``predict``.
        X_test: Features to predict on.
        y_test: True labels for ``X_test``.
    """
    class_labels = [0, 1, 2]
    predictions = clf.predict(X_test)

    matrix = confusion_matrix(y_test, predictions, labels=class_labels)
    plot_confusion_matrix(matrix, classes=class_labels)

    report = classification_report(y_test, predictions)
    print(report)
    

def eval_model(clf, X, y):
    """Print the cross-validation scores of ``clf`` on ``(X, y)``.

    Uses a shuffled 20-split ``StratifiedKFold`` seeded with 42, so the fold
    assignment is reproducible between runs.

    Args:
        clf: Estimator to evaluate with ``cross_val_score``.
        X: Feature matrix.
        y: Target labels used for stratification and scoring.
    """
    folds = StratifiedKFold(n_splits=20, random_state=42, shuffle=True)
    fold_scores = cross_val_score(clf, X, y, cv=folds)
    print(f'Cross-Validation splits score: {fold_scores}')
    
    
def plot_confusion_matrix(cm, classes, title='Confusion matrix', cmap=plt.cm.Blues):
    """Draw ``cm`` as an annotated heat map on the current matplotlib figure.

    Args:
        cm: 2-D array of counts; cells are formatted with ``'d'``, so the
            values must be integers.
        classes: Tick labels used on both axes.
        title: Plot title.
        cmap: Colormap passed to ``plt.imshow``.
    """
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()

    positions = np.arange(len(classes))
    plt.xticks(positions, classes)
    plt.yticks(positions, classes)

    # Cells above half the maximum get white text so the annotation stays
    # readable against the colormap.
    midpoint = cm.max() / 2.
    n_rows, n_cols = cm.shape[0], cm.shape[1]
    for row in range(n_rows):
        for col in range(n_cols):
            cell = cm[row, col]
            label = format(cell, 'd')
            text_color = "white" if cell > midpoint else "black"
            plt.text(col, row, label, horizontalalignment="center",
                     color=text_color)

    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.tight_layout()