# Classify organs present in ultrasound images

In [None]:
import os
import numpy
from fastai.vision import *

In [None]:
# Tell torchvision where the pretrained model weights are stored:
# TORCH_HOME is set to the absolute POSIX form of models/resnet.
_torch_home = Path('models/resnet').absolute()
os.environ['TORCH_HOME'] = _torch_home.as_posix()

## Data import

In [None]:
# Fraction of the images held out for validation.
valid_frac = 0.2
data_path = "data/Output_organ/"
# Resize to 224 px and normalize with the ImageNet statistics: the resnet34
# backbone is ImageNet-pretrained, and the binary-task DataBunch built later
# in this notebook (which receives this backbone's weights) is preprocessed
# the same way, so both stages must see identically prepared inputs.
data = ImageDataBunch.from_folder(data_path, valid_pct=valid_frac,
                                  size=224).normalize(imagenet_stats)
# To train with augmentation, also pass ds_tfms=get_transforms() to from_folder.

## Initial model training

In [None]:
# Create a CNN learner on `data` using the resnet34 architecture, with accuracy as the metric.
learn = cnn_learner(data, models.resnet34, metrics=accuracy)
# Alternative that additionally tracks an AUC metric:
#learn = cnn_learner(data, models.resnet34, metrics=[accuracy, auc_roc_score])

In [None]:
# Run the learning-rate finder and plot what the recorder collected;
# the learning rate used for the initial training is read off this plot.
learn.lr_find()
learn.recorder.plot()

In [None]:
# Learning rate for the initial training, chosen by eye from the lr_find plot.
lr = 0.003

In [None]:
# Checkpoint callback: saves under 'best_accuracy_resnet34' on every
# improvement of the monitored 'accuracy' metric.
save_best = callbacks.SaveModelCallback(
    learn,
    monitor='accuracy',
    every='improvement',
    name='best_accuracy_resnet34',
)
# Early-stopping callback on the same metric (min_delta=.01, patience=3).
stop_early = callbacks.EarlyStoppingCallback(
    learn,
    monitor='accuracy',
    min_delta=.01,
    patience=3,
)
callbacks_opt = [save_best, stop_early]

In [None]:
# Train for up to 10 epochs at learning rate `lr`, with the checkpoint and
# early-stopping callbacks attached.
learn.fit(10, lr, callbacks=callbacks_opt)

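We reload the best checkpoint saved during training (the one with the highest validation accuracy) instead of keeping the weights from the last epoch before early stopping.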

In [None]:
# Load the checkpoint named in the SaveModelCallback above.
# The trailing ';' keeps the notebook from echoing the call's return value.
learn.load('best_accuracy_resnet34');

In [None]:
# Show the learner's model summary.
learn.summary()

### Check the confusion matrix to identify problematic classes

In [None]:
# Build an interpretation object from the learner (with the best checkpoint loaded above).
interp = ClassificationInterpretation.from_learner(learn)

In [None]:
# Plot the confusion matrix with raw counts (normalize=False) on a 14x8 figure.
interp.plot_confusion_matrix(normalize=False, figsize=(14, 8))

## Fine tuning of the model

In [None]:
# Unfreeze the learner for fine-tuning (presumably makes all layers trainable — confirm against fastai docs).
learn.unfreeze()

In [None]:
# Re-run the learning-rate finder on the unfrozen learner and plot the result,
# to pick the learning rates for fine-tuning.
learn.lr_find()
learn.recorder.plot()

In [None]:
# Checkpoint callback for fine-tuning: saves under
# 'best_accuracy_unfreeze_resnet34' on every improvement of 'accuracy'.
save_best = callbacks.SaveModelCallback(
    learn,
    monitor='accuracy',
    every='improvement',
    name='best_accuracy_unfreeze_resnet34',
)
# Early stopping on 'accuracy' with a tighter threshold than the initial
# training (min_delta=.003, patience=3).
stop_early = callbacks.EarlyStoppingCallback(
    learn,
    monitor='accuracy',
    min_delta=.003,
    patience=3,
)
callbacks_opt = [save_best, stop_early]

In [None]:
# Upper learning rate for fine-tuning; the fit call below uses slice(lr_max/5, lr_max).
lr_max = 0.0001

In [None]:
# Fine-tune for up to 10 epochs with learning rates given as slice(lr_max/5, lr_max)
# (presumably spread across the layer groups — confirm against fastai docs).
learn.fit(10, slice(lr_max/5, lr_max), callbacks=callbacks_opt)

In [None]:
# Interpretation of the fine-tuned learner. NOTE: this uses the weights as they
# stand after the fit above; the best checkpoint is not reloaded before this cell.
interp_02 = ClassificationInterpretation.from_learner(learn)

In [None]:
# Confusion matrix of the fine-tuned model, raw counts, 14x8 figure.
interp_02.plot_confusion_matrix(normalize=False, figsize=(14, 8))

In [None]:
# Plot the 9 samples with the highest loss.
interp_02.plot_top_losses(9)

In [None]:
# Additional top-losses view with default arguments.
interp_02.plot_multi_top_losses()

In [None]:
# Freeze the learner again after fine-tuning (counterpart of the earlier
# `learn.unfreeze()`); original note: freezes the bottom layers.
learn.freeze()

In [None]:
# Show the model summary again, now for the re-frozen learner.
learn.summary()

## Data labelling for binary classification

In [None]:
# Image folder for the binary task (presumably liver vs. all other organs,
# judging by the folder name — verify on disk).
# No separate validation fraction is defined; `valid_frac` from the first
# stage is reused when the DataBunch is built:
#valid_frac2 = 0.2
data_path2 = "data/Organ_Liver_Others/"

In [None]:
# Build the DataBunch for the binary task: images are resized to 224 px and
# normalized with the ImageNet statistics. from_folder takes the labels from
# the folder structure under `data_path2`.
# (The original line had an unbalanced ')' after `size=224`, a SyntaxError.)
data2 = ImageDataBunch.from_folder(data_path2, valid_pct=valid_frac, size=224).normalize(imagenet_stats)
# Variant with data augmentation:
#data2 = ImageDataBunch.from_folder(data_path2, valid_pct=valid_frac, size=224, ds_tfms=get_transforms()).normalize(imagenet_stats)

## Binary classification model

In [None]:
# Load the best fine-tuned checkpoint (the name given to the fine-tuning
# SaveModelCallback) and keep the return value of `load` as `learn11`.
# NOTE(review): if `load` returns the learner itself, `learn11` is just an
# alias of `learn` — confirm.
learn11 = learn.load('best_accuracy_unfreeze_resnet34')

In [None]:
# Inspect the first child of the model (presumably the convolutional body).
print(learn11.model[0])

In [None]:
# Inspect the second child of the model (presumably the classification head).
print(learn11.model[1])

In [None]:
# New resnet34 learner for the binary task on `data2`, tracked with the AUROC() metric.
learn2 = cnn_learner(data2, models.resnet34, metrics=AUROC())
# If `AUROC()` turns out to be the wrong metric object, change it to `auc_roc_score`.

In [None]:
# Copy the weights of the fine-tuned model's first child into the new learner's
# first child; only index 0 is transferred, `learn2.model[1]` is left as created.
learn2.model[0].load_state_dict(learn11.model[0].state_dict())

In [None]:
# Learning-rate finder for the binary learner, followed by the plot used to pick `lr`.
learn2.lr_find() 
learn2.recorder.plot()   

In [None]:
# Learning rate for training the binary classifier (rebinds the earlier `lr`).
lr = 0.01

In [None]:
# Checkpoint and early-stopping callbacks for the binary learner.
# `monitor` must be the name under which the metric is recorded, not the
# expression that created it: fastai records a callback-style metric under the
# snake_case form of its class name, so AUROC() is tracked as 'auroc'.
# With the previous value 'AUROC()' the monitored value was never found, so no
# checkpoint was saved and early stopping never triggered.
callbacks_opt = [callbacks.SaveModelCallback(learn2,
                                             monitor='auroc',
                                             every='improvement',
                                             name='best_AUROC_binary_resnet34'),
                callbacks.EarlyStoppingCallback(learn2, monitor='auroc', min_delta=.01, patience=3)]
# If the learner's metric is switched to the function `auc_roc_score`,
# monitor 'auc_roc_score' here instead.

In [None]:
# Train the binary learner for up to 10 epochs at learning rate `lr` with the callbacks above.
learn2.fit(10, lr, callbacks=callbacks_opt)

In [None]:
# Show the model summary of the binary learner.
learn2.summary()

In [None]:
# Interpretation object for the binary learner, using its weights as they stand after the fit above.
interp3 = ClassificationInterpretation.from_learner(learn2)

In [None]:
# Confusion matrix of the binary model, raw counts, 8x8 figure.
interp3.plot_confusion_matrix(normalize=False, figsize=(8, 8))

In [None]:
# Plot the 9 samples with the highest loss for the binary model.
interp3.plot_top_losses(9)

In [None]:
# Additional top-losses view with default arguments.
interp3.plot_multi_top_losses()

## Prediction

In [None]:
# Image to classify. `path` was previously undefined; it now defaults to the
# first image of the validation set. Replace it with the path of any image file.
path = learn2.data.valid_ds.x.items[0]
# Use fastai's open_image: the `Image` name exported by `fastai.vision` is
# fastai's own image class, which has no `open` method (unlike PIL.Image), and
# `learn2.predict` expects a fastai image.
img = open_image(path)

In [None]:
# Bare expression as the last line of the cell: displays the image in the notebook.
img 

In [None]:
# Run the binary model on the single image; the notebook shows the returned prediction.
learn2.predict(img)