In [None]:
%matplotlib notebook
import numpy as np
import pandas as pd
import os
import ipywidgets as widgets
import matplotlib.pyplot as plt
import seaborn as sns
from dataAnalysis import train_model,transform_X_for_pca
from utils import save_sklearn_model, load_sklearn_model, read_tif,include_file_extension
from config import paths
from sklearn.metrics import ConfusionMatrixDisplay,precision_recall_curve,roc_curve,average_precision_score

# Train a classifier
### This trains a classifier and displays the confusion matrix for the training and validation sets

In [None]:
# Inputs for this training run.
filename = "20_mix_images"
random_seed = 2023

# The companion CSV lives in the same directory as the crops and shares their
# base name: normalise the name to "<base>.tif", drop the 4-character
# extension, and append ".csv".
crops_dir = paths["crops"]
tif_name = include_file_extension(filename, "tif")
csv_name = tif_name[:-4] + ".csv"

df = pd.read_csv(os.path.join(crops_dir, csv_name))
crops = read_tif(os.path.join(crops_dir, filename))

# Fit the PCA + classifier pair; every tunable is spelled out so a run can be
# reproduced from this cell alone.
pca, classifier = train_model(
    crops,
    df,
    pca_components=20,
    C=0.05,
    noise_strength=15.0,
    train_percent=90.0,
    n_augmentations=6,
    random_seed=random_seed,
    normalize=True,
    symmetrize=True,
)

### Take a look at some "eigenfaces"

In [None]:
# Show the first 20 principal components on a 5x4 grid of panels.
fig, axes = plt.subplots(5, 4)
for i, ax in enumerate(axes.ravel()):  # row-major order: panel i sits at (i // 4, i % 4)
    # Each component is viewed as a 24x5 half-image; gluing it to its left-right
    # mirror gives a 24x10 picture (presumably matching symmetrize=True used in
    # training -- TODO confirm).
    half = pca.components_[i].reshape(24, 5)
    eigenface = np.hstack([half, np.fliplr(half)])
    ax.imshow(eigenface)
    ax.set_axis_off()
fig.tight_layout()

### Save the model

In [None]:
# Persist the PCA and the classifier together as a single (pca, classifier) tuple.
model_name = "my_miR_model.pkl"
trained_model = (pca, classifier)
save_sklearn_model(trained_model, model_name)

# Classify crops using a model

In [None]:
# Which crops to classify and which saved model to use.
filename = "20_mix_images"
model_name = "classifier_v0.pkl"

# The saved object unpacks into a (pca, classifier) pair.
pca, classifier = load_sklearn_model(model_name)
crops = read_tif(os.path.join(paths["crops"], filename))

# NOTE(review): normalize/symmetrize presumably have to match the settings the
# loaded model was trained with -- confirm for the model named above.
# NOTE(review): good_mask is returned but not used in this cell; if X only keeps
# the crops flagged by good_mask, predictions will not line up 1:1 with `crops`
# -- verify against transform_X_for_pca.
X, max_est_arr, good_mask = transform_X_for_pca(crops, normalize=True, symmetrize=True)

# Classifier input: the PCA projection of X with max_est_arr appended as one
# extra column.
for_classifier = np.hstack([pca.transform(X), max_est_arr.reshape(-1, 1)])
y_mix_pred = classifier.predict(for_classifier)

# Histogram of the predicted classes. Predictions are sorted and cast to str so
# the classes are treated as ordered categories, and the plot is drawn on the
# axes created here rather than on whatever the "current" axes happens to be.
fig, ax = plt.subplots(figsize=(4, 3))
sns.histplot(np.sort(y_mix_pred).astype(str), ax=ax)
ax.set(xlabel="Predicted class", ylabel="Count", title=filename)
fig.tight_layout()