The purpose of this notebook is to develop a model to predict whether a cell is normal or abnomal based on nucleus features.

In [3]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import seaborn as sns
import statistics as stats
from os import listdir
from skimage import io, filters, color,exposure,feature,measure,segmentation
from scipy import ndimage
import cv2 
import imutils
import glob
import sys
sys.path.append('../scripts/')
from nuclei_segmentation_opencv import nuclei_segmenter

In [4]:
df_ext = pd.read_csv('../data/processed/Sipakmed_nuclei_database.csv')

#### To encourage the model to generalize across the specifics of certain images, we drop all image intensity features (esp wrt color)

In [352]:
df = df_ext.copy()
# restrict to the features of interest
df = df[['cluster_id','Class','area','major_axis_length','minor_axis_length','major_to_minor','eccentricity','solidity','Normal']]
df.head()

Unnamed: 0,cluster_id,Class,area,major_axis_length,minor_axis_length,major_to_minor,eccentricity,solidity,Normal
0,1,s,108300.0,394.1,363.96,1.082811,0.38358,0.94167,1
1,1,s,50590.0,330.71,217.29,1.521975,0.75385,0.88165,1
2,1,s,95258.0,450.82,290.24,1.553266,0.76518,0.9068,1
3,1,s,84199.0,363.33,310.51,1.170107,0.51924,0.91911,1
4,1,s,98175.0,433.48,303.43,1.4286,0.71416,0.92263,1


#### Train classifier, use GroupKFold Cross-Validation

In [278]:
# imports
from sklearn.metrics import classification_report,accuracy_score, confusion_matrix, roc_curve, roc_auc_score, precision_score, recall_score, precision_recall_curve, auc, f1_score
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GroupKFold
from sklearn.model_selection import cross_val_score

In [295]:
# create classifier
clf = RandomForestClassifier(n_estimators=100,random_state=0)
# get data
X = df.drop(['cluster_id','Class','Normal'],axis=1)
y = df['Normal']
# group by cluster: make sure no cells from the same cluster appear in both the training and test sets
# (this is important to minimize data leakage)
groups = df['cluster_id'].values
# train and get the GroupKFold roc_auc score
scores = cross_val_score(clf,X,y,groups=groups,cv=GroupKFold(n_splits=5),scoring='roc_auc')
print("Average cross-validation roc_auc score: {:.2f}".format(scores.mean()))

Average cross-validation roc_auc score: 0.91


#### Grid search for hyper parameter tuning 

In [322]:
# Separate a validation set, respecting cluster boundaries
X_train = X[df['cluster_id']>15]
y_train = y[df['cluster_id']>15]
groups_train = groups[df['cluster_id']>15]
X_validate = X[df['cluster_id']<=15]
y_validate = y[df['cluster_id']<=15]

In [353]:
# Grid search over paramters
from sklearn.model_selection import GridSearchCV
param_grid = {'max_depth':[5,10,20,40],
             'min_samples_leaf':[2,5,10]}
grid_search = GridSearchCV(clf, param_grid, cv=GroupKFold(n_splits=5))
grid_search.fit(X_train,y_train,groups=groups_train)
grid_search.best_params_

{'max_depth': 20, 'min_samples_leaf': 5}

In [351]:
# validate the model on the held out validation set
preds = grid_search.predict(X_validate)
print(f'Classification report for true dataset: \n{classification_report(y_validate,preds)}')
print()
print(f'Confusion matrix for true dataset: \n{confusion_matrix(y_validate,preds)}')
print()
# precision recall curve - better than the ROC for unbalanced data
precision, recall, thresholds = precision_recall_curve(y_validate,preds)
print(f'Area under the precision-recall curve: \n{auc(recall,precision)}')

Classification report for true dataset: 
              precision    recall  f1-score   support

           0       0.73      0.90      0.81       136
           1       0.94      0.81      0.87       235

    accuracy                           0.84       371
   macro avg       0.83      0.86      0.84       371
weighted avg       0.86      0.84      0.85       371


Confusion matrix for true dataset: 
[[123  13]
 [ 45 190]]

Area under the precision-recall curve: 
0.9328825149849803


#### Save the model 

In [356]:
import pickle
filename = '../models/Sipakmed_model'
pickle.dump(grid_search,open(filename,'wb'))