In [42]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import _LRScheduler

from sklearn.model_selection import StratifiedKFold, train_test_split
from sklearn.metrics import accuracy_score, confusion_matrix

from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import random, getopt, os, sys

sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '../../')))

from Lamp.AttrDict.AttrDict import *
from Lamp.Model.Dataloader import *
from Lamp.Model.BaseModel import *
from Lamp.Model.Resnet import *

In [43]:
def load_config(cfg_path):
    """Load an experiment configuration file into an ``AttrDict``.

    The loader is chosen from the file extension, compared case-insensitively:
    ``.json`` uses ``AttrDict.from_json_path``; ``.yaml`` / ``.yml`` use
    ``AttrDict.from_yaml_path``.

    Args:
        cfg_path: Path to the configuration file.

    Returns:
        Whatever the matching ``AttrDict`` loader returns for ``cfg_path``.

    Raises:
        ValueError: If the extension is not one of the supported formats.
    """
    ext = os.path.splitext(cfg_path)[-1].lower()
    if ext == '.json':
        return AttrDict.from_json_path(cfg_path)
    if ext in ('.yaml', '.yml'):
        return AttrDict.from_yaml_path(cfg_path)
    raise ValueError(
        f"Unsupported config file format {ext!r} for {cfg_path!r}. "
        "Only '.json', '.yaml' and '.yml' files are supported."
    )

def resnet(layers=None, channels=3, num_classes=1000):
    """Build a ``ResNet`` whose stages are made of ``BasicBlock`` units.

    Args:
        layers: Blocks per stage. Defaults to ``[3, 4, 6, 3]`` (the ResNet34
            layout; ``[2, 2, 2, 2]`` gives ResNet18).
        channels: Forwarded to ``ResNet`` as ``channels`` (presumably the
            number of input image channels).
        num_classes: Forwarded to ``ResNet`` as ``num_classes``.

    Returns:
        A newly constructed ``ResNet`` instance.
    """
    # Build a fresh list on every call. A list default would be one shared
    # object, so any in-place change made downstream would leak into later models.
    if layers is None:
        layers = [3, 4, 6, 3]
    return ResNet(BasicBlock, layers, channels=channels, num_classes=num_classes)

class Classifier(BaseModelSingle):
    """Inference wrapper around a single classification network.

    Construction is forwarded unchanged to ``BaseModelSingle``. This subclass
    adds :meth:`predict`, which runs the network over a loader and collects
    predicted and true labels.
    """

    def __init__(self, net: nn.Module, opt: Optimizer = None, sched: _LRScheduler = None,
                 logger: Logger = None, print_progress: bool = True, device: str = 'cuda:0', **kwargs):
        super().__init__(net, opt=opt, sched=sched, logger=logger,
                         print_progress=print_progress, device=device, **kwargs)

    def forward_loss(self, data: Tuple[Tensor]) -> Tensor:
        """Stub for the loss hook (presumably expected by ``BaseModelSingle``).

        This notebook only runs inference, so the hook does nothing and
        returns ``None``. NOTE(review): implement it before training with
        this class.
        """
        return None

    def predict(self, loader):
        """Run the network over ``loader`` and collect predicted and true labels.

        The network is put in eval mode and gradients are disabled. Each batch
        must unpack as ``(inputs, labels)``. The predicted class is the
        ``argmax`` over dimension 1 of the network output.

        Args:
            loader: Iterable of ``(inputs, labels)`` batches, e.g. a
                ``torch.utils.data.DataLoader``.

        Returns:
            ``(pred, label)``: two tuples of Python ints, aligned
            sample-by-sample in loader order. Both are empty tuples when the
            loader yields no batches.
        """
        preds, targets = [], []
        self.net.eval()
        with torch.no_grad():
            for inputs, labels in loader:
                outputs = self.net(inputs.to(self.device))
                preds.extend(torch.argmax(outputs, dim=1).cpu().tolist())
                # Labels are only converted to ints for the result; they never
                # need to visit the device.
                targets.extend(labels.long().cpu().tolist())
        return tuple(preds), tuple(targets)

# Registry from the transform names used in the config's ``TransformTest``
# section to the callables that build them. ``Padding`` and the ``tf``
# namespace both come from the project star imports (``tf`` is presumably
# ``torchvision.transforms`` — confirm).
dict_transform = dict(
    Padding=Padding,
    VerticalFlip=tf.RandomVerticalFlip,
    HorizontalFlip=tf.RandomHorizontalFlip,
    Rotation=tf.RandomRotation,
    CenterCrop=tf.CenterCrop,
    Resize=tf.Resize,
)

In [44]:
# Experiment configuration read by the cells below.
ifile = ("../Models/config/"
         "MAR_RESNET34_PADDED_256_ALL.yaml")

In [45]:
inputs = load_config(ifile)

# Model and loader settings from the config.
layers = inputs.Model.Layers  # [3, 4, 6, 3] for ResNet34 and [2, 2, 2, 2] for ResNet18
classes = inputs.Model.OutClasses
channels = inputs.Model.Channels
seed = int(inputs.Seed)
batch_size = int(inputs.BatchSize)

# Seed every RNG the pipeline can touch. The test transforms may include random
# flips or rotations, so predictions are only reproducible across runs with
# fixed seeds.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

In [46]:
root_path = os.path.abspath(os.path.join(os.getcwd(), '../'))
path_load_data = f"{root_path}/{inputs.LoadPath}"
path_model = f"{root_path}/{inputs.PathSave}/{inputs.ModelName}"

model_name = "model_all.pt"
save_model_path = f"{path_model}/{model_name}"

# Rows drawn per class for the comparison set.
SAMPLES_PER_CLASS = 20

dataframe = pd.read_csv(path_load_data, index_col=0)
assert {'Paths', 'Label'} <= set(dataframe.columns), "data CSV needs 'Paths' and 'Label' columns"

# Stratified hold-out split; only the test part is used here. Presumably this
# matches the split used at training time — confirm. ``seed`` is the
# int-converted ``inputs.Seed`` (sklearn only accepts integer seeds).
_, test_dataframe = train_test_split(
    dataframe,
    test_size=(1 - inputs.TrainTestSplit),
    stratify=dataframe['Label'],
    random_state=seed,
)
test_dataframe = test_dataframe.reset_index(drop=True)

# Fixed-size sample per class. NOTE(review): replace=True can pick the same
# image more than once (the saved comparison table appears to repeat paths).
# It is kept so the rows stay aligned with the existing comparison_data.csv,
# which is filled by position.
test_dataframe = (
    test_dataframe
    .groupby('Label')
    .sample(SAMPLES_PER_CLASS, replace=True, random_state=seed)
    .reset_index(drop=True)
)

In [47]:
def _build_transform(name, params):
    """Instantiate the transform registered as ``name`` in ``dict_transform``.

    The values of the ``params`` mapping are passed as a single positional
    argument: one value is passed as-is, several values are passed together as
    a list (in mapping order), and an empty or missing mapping calls the
    transform with no arguments.
    """
    values = [] if params is None else list(params.values())
    factory = dict_transform[name]
    if not values:
        return factory()
    return factory(values[0] if len(values) == 1 else values)


transforms_test = Transforms(
    [_build_transform(key, item) for key, item in inputs.TransformTest.items()]
)

testDataset = Dataset(
    test_dataframe.reset_index(drop=True),
    transforms=transforms_test.get_transforms(),
)

# shuffle=False keeps predictions in test_dataframe row order, which the
# comparison table relies on.
test_dataloader = torch.utils.data.DataLoader(
    testDataset,
    batch_size=batch_size,
    shuffle=False,
)

In [51]:
# Rebuild the architecture from the config, load the saved checkpoint at
# ``save_model_path``, and predict on the sampled test set.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

net = resnet(layers=layers, channels=channels, num_classes=classes)
classifier = Classifier(net=net, device=device)
classifier.load(save_model_path)

pred, label = classifier.predict(test_dataloader)

'MAR_RESNET34_PADDED_256_ALL'

In [57]:
COMPARISON_CSV = 'comparison_data.csv'

# Predictions from every evaluated model are accumulated column by column in
# one CSV. Rows are matched to ``test_dataframe`` by position.
if os.path.isfile(COMPARISON_CSV):
    comparison_data = pd.read_csv(COMPARISON_CSV, index_col=0)
    # Assigning by position silently mislabels rows if this run sampled a
    # different test set from the one already stored, so verify first.
    assert comparison_data['Paths'].tolist() == test_dataframe['Paths'].tolist(), (
        f"{COMPARISON_CSV} was built from a different test sample; "
        "delete it or regenerate it with the same config/seed."
    )
else:
    comparison_data = test_dataframe[['Paths', 'Label']].copy()

comparison_data[f'{inputs.ModelName}_best_model'] = pred
# NOTE(review): this stores the same predictions as the column above; no
# separate baseline model is evaluated in this cell — confirm intent.
comparison_data[f'{inputs.ModelName}_baseline_model'] = pred

comparison_data.to_csv(COMPARISON_CSV)

In [59]:
# Display the accumulated comparison table.
pd.read_csv(filepath_or_buffer='comparison_data.csv', index_col=0)

Unnamed: 0,Paths,Label,best_model,baseline_model,PADDED_256_08121352,PADDED_256_time_08121352,RESIZED_256_08121358,RESIZED_256_time_08121358,MAR_RESNET34_PADDED_256_ALL_best_model,MAR_RESNET34_PADDED_256_ALL_baseline_model,CROPPED_256_08121900,CROPPED_256_time_08121900
0,../Data/Train/BL3-1-DL\slice01340 (2020_05_17 ...,0,0,0,0.0,4.60,0.0,4.09,0,0,4.0,2.35
1,../Data/Train/BL3-1-DL\slice00670 (2020_05_17 ...,0,0,0,0.0,6.07,0.0,4.17,0,0,0.0,1.93
2,../Data/Train/BL3-4-DL\slice01180 (2020_05_17 ...,0,0,0,0.0,1.88,4.0,7.59,0,0,4.0,17.78
3,../Data/Train/BL3-4-DL\slice01140 (2020_05_17 ...,0,0,0,0.0,1.45,0.0,1.27,0,0,4.0,2.14
4,../Data/Train/BL3-4-DL\slice01140 (2020_05_17 ...,0,0,0,0.0,1.39,0.0,1.12,0,0,0.0,8.88
...,...,...,...,...,...,...,...,...,...,...,...,...
95,../Data/Train/OL4-3-DL\slice01090 (2020_05_17 ...,4,4,4,4.0,1.56,4.0,1.58,4,4,0.0,1.97
96,../Data/Train/OL7-3-DL\slice00370 (2020_05_17 ...,4,4,4,2.0,1.54,4.0,2.18,4,4,4.0,2.71
97,../Data/Train/OL-7-1-DL\slice00270 (2020_05_17...,4,4,4,4.0,3.55,4.0,1.86,4,4,0.0,1.98
98,../Data/Train/OL-7-1-DL\slice00600 (2020_05_17...,4,4,4,2.0,1.19,2.0,2.82,4,4,2.0,2.77


In [None]:
# Per-class accuracy: our model vs. baseline vs. expert.
# NOTE(review): ``our_model_vec``, ``baseline_vec`` and ``acc_user`` are not
# defined in the cells shown here, so this cell fails on a fresh kernel until
# they are computed.
rock_labels = ('Rock 1', 'Rock 2', 'Rock 3', 'Rock 4', 'Rock 5')
n = len(rock_labels)
index = np.arange(n)
bar_width = 0.25

fig, ax = plt.subplots()
ax.bar(index + bar_width, np.array(our_model_vec).mean(axis=0), bar_width, color='g',
       label='Our Model')
ax.bar(index, np.array(baseline_vec).mean(axis=0), bar_width, color='r',
       label='Baseline')
ax.bar(index - bar_width, acc_user, bar_width, color='b',
       label='Expert')
ax.set_xlabel('Rock Type')
ax.set_ylabel('Accuracy')
ax.set_title('Per-class accuracy')
# The bars sit at index - w, index and index + w, so each group is centred on ``index``.
ax.set_xticks(index)
ax.set_xticklabels(rock_labels)
ax.legend()
plt.show()