# fastai-v1 inference on validation data – GradCAM experiment

## Import libraries

In [None]:
# Put these at the top of every notebook, to get automatic reloading and inline plotting
# (`%autoreload 2` re-imports changed modules, e.g. the local `gradcam` module, before each cell runs)
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
# Select GPU index 1 on titan.sci.utah.edu via CUDA_VISIBLE_DEVICES.
# This is set here, before fastai/torch are imported in the next cell.
import os
os.environ.update({"CUDA_VISIBLE_DEVICES": "1"})


In [None]:
#Import libraries - fastai_v1

from fastai.vision import *
from fastai.metrics import error_rate


import numpy as np
import pandas as pd
import seaborn as sns

from sklearn.metrics import confusion_matrix
from sklearn.utils.multiclass import unique_labels

from gradcam import *

## I/O and hyperparameters

In [None]:
# Parameters and hyper-parameters

# Input CSV listing the images to run inference on (validation split, synthetic data)
csv_test_FileName = 'Dataset_TargetClass_Overlap-9Blocks_25000xOnly_shuffled_fastai-v1_val.csv'
csv_test = os.path.join('../CSV_InputFiles_TargetClass', csv_test_FileName)

# Output CSVs (written to the working directory) are named after the input CSV,
# so changing csv_test_FileName keeps the output names in sync with it.
csv_test_stem = os.path.splitext(csv_test_FileName)[0]
csv_result = os.path.join(os.getcwd(), csv_test_stem + '-Prediction.csv')
# NOTE(review): csv_result_MajVoting is not used by any of the cells shown here.
csv_result_MajVoting = os.path.join(os.getcwd(), csv_test_stem + '-Prediction_MajVoting.csv')

# Network: exported learner file, loaded with load_learner(model_path, model_file)
model_path = os.path.join(os.getcwd(), 'models')
model_file = 'TargetClass_fastai-v1_224_all_resnet50.pkl'

# Training-time settings kept for reference. NOTE(review): the inference cells
# shown here restore the exported learner and never read arch, sz, bs or lr.
# Network architecture
arch = models.resnet50
# Image size
sz = 224
# Batch size
bs = 32
# Default learning rate
lr = 0.01

## Define Test dataset

In [None]:
# Load the validation CSV (comma-separated, pandas' default) and preview it
df_test = pd.read_csv(csv_test)
df_test.head()

In [None]:
# (rows, columns) of the validation DataFrame
df_test.shape

In [None]:
# Number of rows (images) per class label
df_test.groupby(['Label']).size()

In [None]:
# Per-class image counts, kept in df_test_size.
# The reindex line is disabled: classes_Labels_ordered is only defined in a later
# cell (from learn.data.classes), so enabling it here would raise NameError on a
# fresh top-to-bottom run.
df_test_size = df_test.groupby(['Label']).size()
#df_test_size = df_test_size.reindex(classes_Labels_ordered)
df_test_size


In [None]:
# Generate bar graph
# pd.value_counts(df_test['Label']).sort_index().plot(kind='bar', title = 'Starting Material - test dataset')
# fig1 = plt.gcf()
# plt.tight_layout()
# fig1.savefig('BarGraph_Distribution_StartingMaterial_TestData.png')
# plt.show()

In [None]:
# Count plot of images per class label in the validation CSV, saved as a PNG.
# Set the seaborn style before creating the axes so the style applies to them.
sns.set(style="whitegrid")
fig, ax = plt.subplots(1, 1, figsize=(5, 5))
# Draw explicitly on `ax` instead of relying on the implicit current axes
sns_plot = sns.countplot(x="Label", data=df_test, ax=ax)
ax.tick_params(axis='x', labelrotation=90)
plt.tight_layout()
# Save before plt.show(): depending on the backend, showing may close/clear the figure
fig.savefig("BarGraph_Distribution_TargetClass_ValData.png")
plt.show()

## Deep Learning analysis

## Inference on the test set (validation CSV) – without TTA

In [None]:
# Unlabeled ImageList over the validation images. It is built from df_test (read
# from csv_test) rather than re-reading csv_test_FileName relative to os.getcwd(),
# so inference and the later merge with df_test use the same CSV file.
# Item paths are presumably os.getcwd()/../Data_TargetClass/<first CSV column>
# (fastai's default cols=0) — confirm that the first column is 'File'.
test = ImageList.from_df(df_test, os.getcwd(), folder='../Data_TargetClass')


In [None]:
# Display the test ImageList built above
test


In [None]:
# Main commands to load data and model
learn = load_learner(model_path,model_file, test=test)

In [None]:
# Display the loaded learner
learn

In [None]:
# Model outputs for every item of the attached test set (one row of class scores
# per item); targets are discarded as `_`.
# NOTE(review): `test` was created without labels, so `losses` is presumably
# computed against placeholder targets and is not a meaningful per-item loss —
# confirm before interpreting it.
y_pred_test, _, losses = learn.get_preds(ds_type=DatasetType.Test,with_loss=True)


In [None]:
# Class name of the highest-scoring output for each test item (one argmax over the whole score matrix)
y_pred_test_classes = [learn.data.classes[idx] for idx in np.argmax(y_pred_test.numpy(), axis=1)]


In [None]:
# Sanity checks on the first test item: raw scores, the same as a numpy array,
# their sum (presumably ~1 if get_preds returned softmax probabilities), the
# argmax index, the class name it maps to, and its loss (the test set is
# unlabeled, so treat that value with caution).
print(y_pred_test[0])
print(y_pred_test[0].numpy())
print(np.sum(y_pred_test[0].numpy()))
print(np.argmax(y_pred_test[0]))
print(y_pred_test_classes[0])
print(losses[0])

In [None]:
# First ten predicted class names
print(y_pred_test_classes[:10])

In [None]:
# Keep only the last four '/'-separated components of each test item path; these
# are assumed to equal the 'File' values in df_test (they are merged on 'File').
FileNames = ['/'.join(item.rsplit('/', 4)[-4:]) for item in learn.data.test_ds.items]
print(FileNames[:10])

In [None]:
# Create dataframe for prediction on test data
df_preds_test = pd.DataFrame({'File':FileNames, 'Prediction':y_pred_test_classes})
df_preds_test.head()

In [None]:
# Attach predictions to the CSV rows by file name. how='left' keeps every df_test
# row; Prediction is NaN where no file name matched.
# NOTE(review): duplicate 'File' values would multiply rows — compare
# result.shape with df_test.shape.
result = df_test.merge(df_preds_test,on='File',how='left')
result.shape


In [None]:
# Preview the merged labels and predictions
result.head()

In [None]:
# Save results as CSV file
result.to_csv(csv_result, index=False, na_rep = 'NA')

In [None]:
# Class names, indexed in the same order as the model's output scores
learn.data.classes

## Classification interpretation

In [None]:
# Build a ClassificationInterpretation object from the learner's test-set predictions.
# NOTE(review): `test` has no labels, so the targets fastai compares against here
# are presumably placeholders and confusion statistics from `interp` are likely
# not meaningful. The true labels are in result['Label'].
interp = ClassificationInterpretation.from_learner(learn,ds_type=DatasetType.Test)

In [None]:
# Most frequent (actual, predicted, count) confusions, in the same tuple layout as
# ClassificationInterpretation.most_confused(), but computed from the real labels
# in `result`. The unlabeled test set gives `interp` only placeholder targets.
confused_counts = (
    result[result['Label'] != result['Prediction']]
    .groupby(['Label', 'Prediction'])
    .size()
    .sort_values(ascending=False)
)
[(actual, predicted, int(count)) for (actual, predicted), count in confused_counts.items()]

## GradCAM

In [None]:
# Relative image path of the first row in `result`
result.iloc[0]['File']

In [None]:
# Full path of the first image in `result`, then load it for GradCAM
first_file = result.iloc[0]['File']
test_img = os.path.join('../Data_TargetClass', first_file)
print(test_img)
img = open_image(test_img)



In [None]:
%%time
# GradCAM for the single image loaded above (plot_hm / plot_gbp presumably toggle
# the heatmap and guided-backprop panels of the local gradcam module)
gcam = GradCam.from_one_img(learn,img)
gcam.plot(plot_hm=True,plot_gbp=True)

In [None]:
# Class names in model-output order; iterated over by the GradCAM cells below
classes_Labels_ordered = learn.data.classes

In [None]:
import random

# For each class, pick ONE random image the model predicted as that class
# (correct or not) and save its GradCAM heatmap as ./GradCAM_Example_<class>.png.
for pred_class in classes_Labels_ordered:
    print(pred_class)
    # Select by boolean mask with .loc so the row lookup does not depend on
    # result's index labels coinciding with positions.
    pred_class_files = result.loc[result['Prediction'] == pred_class, 'File'].tolist()
    if not pred_class_files:
        # random.choice raises IndexError on an empty sequence; skip classes the
        # model never predicted instead of aborting the whole loop.
        print('  no images predicted as this class - skipped')
        continue
    File = random.choice(pred_class_files)
    test_img = os.path.join('../Data_TargetClass', File)
    img = open_image(test_img)
    gcam = GradCam.from_one_img(learn, img)
    gcam.plot(plot_hm=True, plot_gbp=False)
    fig = plt.gcf()
    FigTitle = './GradCAM_Example_' + pred_class + '.png'
    fig.savefig(FigTitle)

In [None]:
import random

def plot_gradCAM_examples_correct(pred_class, nb=4):
    """Save GradCAM heatmaps for up to `nb` correctly classified images of one class.

    Candidates are the rows of the notebook-global `result` DataFrame whose
    'Prediction' and 'Label' both equal `pred_class`. Up to `nb` of them are drawn
    at random WITHOUT replacement, so no image is plotted twice. Every correct
    example is eligible, not just a small random pre-draw. Each plot is saved as
    './GradCAM-CorrectPred_<pred_class>_Example<k>.png' with k = 1, 2, ...

    Args:
        pred_class: class name to sample from (a value of result['Prediction']).
        nb: maximum number of examples to plot.

    Uses the notebook globals `result`, `learn`, `open_image`, `GradCam` and `plt`.
    """
    is_correct = (result['Prediction'] == pred_class) & (result['Label'] == pred_class)
    files = result.loc[is_correct, 'File'].tolist()
    if not files:
        # Nothing to plot (class never predicted, or never predicted correctly)
        print('No correct predictions for class ' + str(pred_class))
        return
    chosen = random.sample(files, max(0, min(nb, len(files))))
    for counter, File in enumerate(chosen, start=1):
        test_img = os.path.join('../Data_TargetClass', File)
        img = open_image(test_img)
        gcam = GradCam.from_one_img(learn, img)
        gcam.plot(plot_hm=True, plot_gbp=False)
        fig = plt.gcf()
        FigTitle = './GradCAM-CorrectPred_' + pred_class + '_Example' + str(counter) + '.png'
        fig.savefig(FigTitle)

In [None]:
# Plot gradCAM of first 4 elements for a specific predicted class
for label in classes_Labels_ordered:
    plot_gradCAM_examples_correct(label,3)


In [None]:
import random

def plot_gradCAM_examples_incorrect(pred_class, nb=4):
    """Save GradCAM heatmaps for up to `nb` images wrongly predicted as `pred_class`.

    Candidates are the rows of the notebook-global `result` DataFrame whose
    'Prediction' equals `pred_class` but whose 'Label' differs. Up to `nb` of them
    are drawn at random WITHOUT replacement, so no image is plotted twice. Each plot
    is saved as './GradCAM-IncorrectPred_<pred_class>_Example<k>.png' with k = 1, 2, ...

    Args:
        pred_class: predicted class name to sample misclassifications from.
        nb: maximum number of examples to plot.

    Uses the notebook globals `result`, `learn`, `open_image`, `GradCam` and `plt`.
    """
    is_wrong = (result['Prediction'] == pred_class) & (result['Label'] != pred_class)
    candidates = list(result.loc[is_wrong, ['File', 'Label']].itertuples(index=False, name=None))
    chosen = random.sample(candidates, max(0, min(nb, len(candidates))))
    for counter, (File, Actual_Label) in enumerate(chosen, start=1):
        test_img = os.path.join('../Data_TargetClass', File)
        img = open_image(test_img)
        # label1 = predicted class, label2 = true class (as the original call passed them)
        gcam = GradCam.from_one_img(learn, img, label1=pred_class, label2=Actual_Label)
        gcam.plot(plot_hm=True, plot_gbp=False)
        fig = plt.gcf()
        FigTitle = './GradCAM-IncorrectPred_' + pred_class + '_Example' + str(counter) + '.png'
        fig.savefig(FigTitle)

In [None]:
# Plot gradCAM of first 4 elements for a specific predicted class
for label in classes_Labels_ordered:
    plot_gradCAM_examples_incorrect(label,3)