Imports

In [1]:
import torch
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
from tqdm.notebook import tqdm as tqdm
from PIL import Image, ImageDraw, ImageFont
import math

from HGNN.train.configParser import ConfigParser, getExperimentParamsAndRecord
from HGNN.train import CNN, dataLoader

Parameters

In [2]:
# Where the data lives and which trained experiment/trial to inspect.
dataPath = "/raid/elhamod/Fish/"
experimentsPath = os.path.join(dataPath, "experiments/")
experimentName = "biology_paper_medium_curated3_50_30"
trial_hash = "fc3f98edf24b6b374e87720d07c377ba5d995dedb35728ab8c377d62"

# GPU index (only used when CUDA is available).
cuda = 6

# Number of classes (figure rows) to plot; None plots every class.
numOfRows = None
# Also render the PDF of correctly classified examples.
plotCorrectlyClassified = True

Select the CUDA device

In [3]:
# Route later CUDA work to the GPU chosen in the parameters cell.
# On a machine without CUDA this cell does nothing.
if torch.cuda.is_available():
    torch.cuda.set_device(cuda)
    print(f"using cuda {cuda}")

using cuda 6


Load the datasets

In [4]:
# Locate the experiment and read back the hyper-parameters recorded for this trial.
experimentPathAndName = os.path.join(experimentsPath, experimentName)
experiment_params, experimentRecord = getExperimentParamsAndRecord(experimentsPath, experimentName, trial_hash)
resolution = experiment_params['img_res']
print(experiment_params)

# Rebuild the datasets from the recorded parameters with 'augmented' forced off.
config_parser = ConfigParser(experimentsPath, dataPath, experimentName)
datasetManager = dataLoader.datasetManager(experimentPathAndName, dataPath, True)
datasetManager.updateParams(config_parser.fixPaths({**experiment_params, 'augmented': False}))
train_dataset, _, test_dataset = datasetManager.getDataset()
# Needed so we always get the same prediction accuracy: augmentation off, normalization unchanged.
train_dataset.toggle_image_loading(augmentation=False, normalization=train_dataset.normalization_enabled)

fineList = train_dataset.csv_processor.getFineList()
coarseList = train_dataset.csv_processor.getCoarseList()

# batch_size=1 and DataLoader's default (no shuffling): loader order follows dataset index order.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=1)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1)

{'image_path': 'Curated3/Medium', 'suffix': 'curated_30_50', 'img_res': 448, 'augmented': True, 'batchSize': 64, 'learning_rate': 0.0001, 'numOfTrials': 5, 'fc_layers': 1, 'modelType': 'BB', 'lambda': 0.01, 'unsupervisedOnTest': False, 'tl_model': 'ResNet18', 'link_layer': 'avgpool', 'adaptive_smoothing': True, 'adaptive_lambda': 0.01, 'adaptive_alpha': 0.9}
Creating datasets...
Creating datasets... Done.


Build the model and load the trained weights

In [5]:
# Number of classes at each taxonomy level (presumably sizes the model's output heads).
architecture = dict(fine=len(fineList), coarse=len(coarseList))
model = CNN.create_model(architecture, experiment_params, cuda)

# Resolve the saved trial directory from the experiment record and load it into the model.
modelName = experimentRecord.iloc[0]["modelName"]
trialName = os.path.join(experimentPathAndName, modelName)
_ = CNN.loadModel(model, trialName, cuda)  # bound to _ so the cell does not display the return value

Split test predictions into misclassified and correctly classified examples

In [6]:
# Score the whole test set once, then split the examples into a misclassified table and a
# correctly classified table (one row per test example).
df_misclassified_columns = ['file name', 'true label', 'probability of true label', 'predicted label',
                            'full file name', 'probability of predicted label']
df_correctlyclassified_columns = ['file name', 'true label', 'probability of true label']

# Class probabilities for every test example, plus the true labels.
predProblist, lbllist = CNN.getLoaderPredictionProbabilities(test_loader, model, experiment_params, device=cuda)
# Predicted label = index of the most probable class in each row.
_, predlist = torch.max(predProblist, 1)
# gather() needs a 2-D index tensor; with one label per example this is (N, 1).
lbllist = lbllist.reshape(lbllist.shape[0], -1)
# Probability assigned to the true label / to the predicted label, one value per example.
correct_predProblist = predProblist.gather(1, lbllist).reshape(-1)
predicted_predProblist = predProblist.gather(1, predlist.unsqueeze(1)).reshape(-1)

# Collect plain-Python records and build each DataFrame once: DataFrame.append was removed
# in pandas 2.0, and growing a frame row by row is quadratic.
misclassified_records = []
correctlyclassified_records = []
for i in tqdm(range(len(lbllist))):
    # .item() yields a host-side Python scalar regardless of the tensor's device.
    true_label = int(lbllist[i].item())
    predicted_label = int(predlist[i].item())
    sample = test_dataset[i]  # index once; each dataset access may reload the image
    record = {'file name': sample['fileName'],
              'full file name': sample['fileNameFull'],
              'true label': true_label,
              'probability of true label': float(correct_predProblist[i].item())}
    if true_label != predicted_label:
        record['predicted label'] = predicted_label
        record['probability of predicted label'] = float(predicted_predProblist[i].item())
        misclassified_records.append(record)
    else:
        correctlyclassified_records.append(record)

df_misclassified = pd.DataFrame(misclassified_records, columns=df_misclassified_columns)
# 'full file name' is kept (plot_top_n needs it) but is not part of the displayed columns.
df_correctlyclassified = pd.DataFrame(correctlyclassified_records,
                                      columns=df_correctlyclassified_columns + ['full file name'])

HBox(children=(FloatProgress(value=0.0, max=977.0), HTML(value='')))




In [7]:
# Class probabilities for the whole training set (the labels are not needed here); the
# test-set probabilities from the previous cell are kept under a matching name.
train_predProblist, _ = CNN.getLoaderPredictionProbabilities(train_loader, model, experiment_params, device=cuda)
test_predProblist = predProblist

In [8]:
# Order each table by class, then by ascending probability of the true label. plot_top_n
# takes the first rows of each class in this order.
sort_keys = ['true label', 'probability of true label']
df_misclassified = df_misclassified.sort_values(by=sort_keys)
df_correctlyclassified = df_correctlyclassified.sort_values(by=sort_keys)

Index training examples by class and define the function that plots the top n examples of each class

In [9]:
# Map each fine label to the positions of its training examples in train_dataset.
# train_loader is unshuffled with batch_size=1 (see its construction), so the enumerate
# index is the dataset index.
# NOTE(review): iterating the loader presumably loads every image just to read its label;
# a label-only accessor on the dataset would be cheaper if one exists.
class_training_indices_dict = {}
for indx, batch in enumerate(tqdm(train_loader, desc="Getting sub-indices")):
    class_training_indices_dict.setdefault(batch['fine'].item(), []).append(indx)

HBox(children=(FloatProgress(value=0.0, description='Getting sub-indices', max=3124.0, style=ProgressStyle(des…




In [10]:
# Two-level column header (section, field) of the CSV summary written by plot_top_n.
_PREDICTED_SECTION = 'closest predicted class example from training set'
_SAME_CLASS_SECTION = 'closest same class example from training set'
_PLOT_RESULT_COLUMNS = pd.MultiIndex.from_tuples([
    ('image', 'file name'), ('image', 'true label'), ('image', 'probability of true label'),
    ('image', 'predicted label'), ('image', 'probability of predicted label'),
    (_PREDICTED_SECTION, 'file name'), (_PREDICTED_SECTION, 'true label'), (_PREDICTED_SECTION, 'cosine similarity'),
    (_SAME_CLASS_SECTION, 'file name'), (_SAME_CLASS_SECTION, 'cosine similarity'),
])


def plot_top_n(df, fig_file_name, class_training_indices_dict, ismisclassified, numOfRows=None, perRow=5,
               rows_per_page=10):
    """Write a PDF grid of up to ``perRow`` examples per class, plus a CSV summary of every panel.

    Each panel stacks, top to bottom:
    - the example's own image;
    - when ``ismisclassified``, the closest training image of the *predicted* class;
    - the closest training image of the example's *true* class.

    "Closest" is the top-scoring item returned by ``get_closest_example``. Output goes to
    ``<experimentPathAndName>/<modelName>/<fig_file_name>.pdf`` and ``.csv``.

    Relies on notebook globals: ``resolution``, ``fineList``, ``train_dataset``, ``test_dataset``,
    ``train_predProblist``, ``test_predProblist``, ``experimentPathAndName``, ``modelName``.

    Args:
        df: one row per test example. Required columns are 'file name', 'full file name',
            'true label' and 'probability of true label'. When ``ismisclassified``, 'predicted label'
            and 'probability of predicted label' are also required. The first ``perRow`` rows of each
            label, in ``df``'s current order, are plotted.
        fig_file_name: base name, without extension, of the PDF and CSV outputs.
        class_training_indices_dict: label -> indices of the training examples with that label.
        ismisclassified: True if ``df`` holds misclassified examples. Only then are the
            predicted-class image, the "as <prediction>" title and the prediction columns produced.
        numOfRows: plot labels ``0 .. numOfRows-1``, one grid row each. None plots all
            ``len(fineList)`` labels.
        perRow: maximum number of examples (grid columns) per label.
        rows_per_page: number of labels per PDF page.
    """
    font = ImageFont.truetype(font='DejaVuSans.ttf', size=int(float(resolution) / 30))
    output_dir = os.path.join(experimentPathAndName, modelName)

    if numOfRows is None:
        numOfRows = len(fineList)
    topn = df.groupby('true label').head(perRow)
    # ceil, not floor+1: the latter emitted an extra blank page when numOfRows was a
    # multiple of rows_per_page.
    number_of_pages = math.ceil(numOfRows / rows_per_page)

    result_rows = []
    # Disable augmentation so displayed images are deterministic; restored in `finally`.
    augmentation, normalization, _ = test_dataset.toggle_image_loading(
        augmentation=False, normalization=test_dataset.normalization_enabled)
    try:
        with tqdm(total=perRow * numOfRows, desc="figure") as bar:
            with PdfPages(os.path.join(output_dir, fig_file_name + ".pdf")) as pdf:
                for page in range(number_of_pages):
                    # squeeze=False keeps `axes` 2-D even when rows_per_page or perRow is 1.
                    fig, axes = plt.subplots(ncols=perRow, nrows=rows_per_page,
                                             figsize=(15, 4 * rows_per_page), dpi=300, squeeze=False)
                    for i, row_axes in enumerate(axes):
                        label = i + page * rows_per_page
                        # Stop after the last requested label; the final page may be partly empty.
                        if label >= numOfRows:
                            break
                        topn_lbl = topn[topn['true label'] == label]
                        for j, ax in enumerate(row_axes):
                            if j < len(topn_lbl.index):
                                panel, title, record = _render_example(
                                    topn_lbl.iloc[j], class_training_indices_dict, ismisclassified, font)
                                ax.imshow(panel)
                                ax.set_title(title)
                                result_rows.append(record)
                            bar.update()

                    fig.tight_layout(rect=[0, 0.03, 1, 0.95])
                    pdf.savefig(fig)
                    plt.close(fig)  # close this page's figure explicitly, not "the current" one
    finally:
        # Re-enable augmentation/normalization as they were, even if rendering failed.
        test_dataset.toggle_image_loading(augmentation=augmentation, normalization=normalization)

    # Missing fields (e.g. prediction columns for correctly classified rows) become NaN.
    pd.DataFrame(result_rows, columns=_PLOT_RESULT_COLUMNS).to_csv(
        os.path.join(output_dir, fig_file_name + ".csv"))


def _open_resized(path):
    """Open the image file at ``path``, scale it with ``resize`` and close the file handle."""
    with Image.open(path) as opened:
        return resize(opened)


def _stack_vertically(top, bottom):
    """Return a new RGB image of ``top``'s width with ``bottom`` pasted directly below ``top``."""
    canvas = Image.new('RGB', (top.width, top.height + bottom.height))
    canvas.paste(top, (0, 0))
    canvas.paste(bottom, (0, top.height))
    return canvas


def _closest_in_class(fileName, class_indices):
    """Return (item, score) for the training example among ``class_indices`` closest to test example ``fileName``."""
    class_subset = torch.utils.data.Subset(train_dataset, class_indices)
    return get_closest_example(fileName, test_dataset, test_predProblist, class_subset,
                               train_predProblist[class_indices, :])


def _render_example(entry, class_training_indices_dict, ismisclassified, font):
    """Build the stacked panel image, axis title and CSV record for one row of plot_top_n's ``df``."""
    fileName = entry['file name']
    trueLabel = entry['true label']

    panel = rotateImageIfNeeded(_open_resized(entry['full file name']))
    title = f"{fileName} \n {fineList[trueLabel]}"
    record = {
        ('image', 'file name'): fileName,
        ('image', 'true label'): fineList[trueLabel],
        ('image', 'probability of true label'): round(entry['probability of true label'], 3),
    }

    if ismisclassified:
        # Closest training example of the class the model (wrongly) predicted, captioned in red.
        prediction = entry['predicted label']
        closest, cosine_score = _closest_in_class(fileName, class_training_indices_dict[prediction])
        closest_species = fineList[closest['fine']]
        closest_img = _open_resized(closest['fileNameFull'])
        ImageDraw.Draw(closest_img).text(
            (0, 0), "closest prediction\n" + closest_species + "\n" + closest['fileName'], (255, 0, 0), font=font)
        panel = _stack_vertically(panel, rotateImageIfNeeded(closest_img))
        title += f" \n as {fineList[prediction]}"
        record.update({
            ('image', 'predicted label'): fineList[prediction],
            ('image', 'probability of predicted label'): round(entry['probability of predicted label'], 3),
            (_PREDICTED_SECTION, 'file name'): closest['fileNameFull'],
            (_PREDICTED_SECTION, 'true label'): closest_species,
            (_PREDICTED_SECTION, 'cosine similarity'): round(cosine_score, 3),
        })

    # Closest training example of the example's own (true) class, captioned in red.
    same, same_score = _closest_in_class(fileName, class_training_indices_dict[trueLabel])
    same_img = _open_resized(same['fileNameFull'])
    ImageDraw.Draw(same_img).text((0, 0), "closest same label\n" + same['fileName'], (255, 0, 0), font=font)
    panel = _stack_vertically(panel, rotateImageIfNeeded(same_img))
    record.update({
        (_SAME_CLASS_SECTION, 'file name'): same['fileNameFull'],
        (_SAME_CLASS_SECTION, 'cosine similarity'): round(same_score, 3),
    })
    return panel, title, record

def get_closest_example(fileName, source_dataset, source_dataset_predProblist, target_dataset, target_dataset_predProblist):
    """Return the ``target_dataset`` item that scores highest against source example ``fileName``.

    The source example's row of ``source_dataset_predProblist`` is found via
    ``source_dataset.getIdxByFileName``. It is scored against ``target_dataset_predProblist``
    with ``CNN.get_distance_from_example2``, and the single largest score (``torch.topk`` with k=1) wins.

    Returns:
        (closest_item, score): the ``target_dataset`` item and its score as a Python float.
    """
    query = source_dataset_predProblist[source_dataset.getIdxByFileName(fileName), :]
    scores = CNN.get_distance_from_example2(target_dataset_predProblist, query)
    best = torch.topk(scores, 1)
    # NOTE(review): [0][0] assumes `scores` is 2-D with a leading dimension of 1, e.g.
    # (1, n_targets). If it were (n_targets, 1), topk over the last dim would always give index 0.
    # Confirm against CNN.get_distance_from_example2.
    best_index = best.indices[0][0].item()
    best_score = best.values[0][0].item()
    return target_dataset[best_index], best_score

# Makes sure that the width is more than the height
def rotateImageIfNeeded(img):
    width, height = img.size
    if width < height:
        return img.rotate(90, expand=True)
    else:
        return img
    
def resize(img, target_resolution=None):
    """Uniformly scale ``img`` so that its longer side is ``target_resolution`` pixels.

    The aspect ratio is preserved. Output dimensions are truncated to ints, and small
    images are scaled up.

    Args:
        img: a PIL image.
        target_resolution: target length of the longer side; defaults to the notebook's
            global ``resolution`` (the trial's 'img_res').

    Returns:
        A new, resized PIL image (LANCZOS resampling).
    """
    if target_resolution is None:
        target_resolution = resolution
    width, height = img.size
    # The smaller of the two ratios keeps both sides within the target.
    scale = min(target_resolution / width, target_resolution / height)
    # Image.ANTIALIAS was removed in Pillow 10. It was an alias of LANCZOS, so output is unchanged.
    return img.resize((int(scale * width), int(scale * height)), Image.LANCZOS)

Display and save misclassified examples

In [11]:
# Save the full misclassification table, render its per-class PDF/CSV, then display it.
misclassified_csv_path = os.path.join(experimentPathAndName, modelName, 'misclassified examples.csv')
df_misclassified.to_csv(misclassified_csv_path)
plot_top_n(df_misclassified, "misclassified", class_training_indices_dict, numOfRows=numOfRows, ismisclassified=True)
df_misclassified

HBox(children=(FloatProgress(value=0.0, description='figure', max=510.0, style=ProgressStyle(description_width…




Unnamed: 0,file name,true label,probability of true label,predicted label,full file name,probability of predicted label
2,UWZM-F-0000005.JPG,0,0.012882,1,/raid/elhamod/Fish/Curated3/Medium/curated_30_...,0.082710
0,INHS_FISH_101620.JPG,0,0.034373,1,/raid/elhamod/Fish/Curated3/Medium/curated_30_...,0.065993
1,UWZM-F-0000004.JPG,0,0.044425,1,/raid/elhamod/Fish/Curated3/Medium/curated_30_...,0.240449
3,83467_lat_FMNH_FZ.jpg,1,0.019879,13,/raid/elhamod/Fish/Curated3/Medium/curated_30_...,0.036785
4,86978_lat_FMNH_FZ.jpg,1,0.021401,45,/raid/elhamod/Fish/Curated3/Medium/curated_30_...,0.079175
...,...,...,...,...,...,...
199,UWZM-F-0003633.JPG,98,0.052314,94,/raid/elhamod/Fish/Curated3/Medium/curated_30_...,0.080235
200,INHS_FISH_68537.jpg,100,0.075788,59,/raid/elhamod/Fish/Curated3/Medium/curated_30_...,0.189533
203,JFBM-FISH-0011714.jpg,101,0.008779,91,/raid/elhamod/Fish/Curated3/Medium/curated_30_...,0.038377
201,INHS_FISH_50838.jpg,101,0.028608,2,/raid/elhamod/Fish/Curated3/Medium/curated_30_...,0.049822


Display and save correctly classified examples

In [12]:
# Always save the table; rendering its PDF is optional (see plotCorrectlyClassified).
correctly_classified_csv_path = os.path.join(experimentPathAndName, modelName, 'correctly classified examples.csv')
df_correctlyclassified.to_csv(correctly_classified_csv_path)
if plotCorrectlyClassified:
    plot_top_n(df_correctlyclassified, "correctly classified", class_training_indices_dict,
               numOfRows=numOfRows, ismisclassified=False)
# Display without the helper 'full file name' column.
df_correctlyclassified.loc[:, df_correctlyclassified_columns]

HBox(children=(FloatProgress(value=0.0, description='figure', max=510.0, style=ProgressStyle(description_width…




Unnamed: 0,file name,true label,probability of true label
0,100788_lat_FMNH_FZ.jpg,0,0.245145
5,INHS_FISH_5773.JPG,0,0.523608
3,INHS_FISH_1157.JPG,0,0.790208
6,INHS_FISH_5837.JPG,0,0.826976
1,INHS_FISH_100595.jpg,0,0.887488
...,...,...,...
768,INHS_FISH_53056.jpg,101,0.418759
767,INHS_FISH_39579.jpg,101,0.463106
770,JFBM-FISH-0030701.jpg,101,0.482620
771,JFBM-FISH-0034781.jpg,101,0.875861
