In [23]:
import getopt
import os
import random
import sys
import time
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import accuracy_score, confusion_matrix
from sklearn.model_selection import StratifiedKFold, train_test_split
from torch.optim.lr_scheduler import _LRScheduler

sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '../../')))

from Lamp.AttrDict.AttrDict import *
from Lamp.Model.Dataloader import *
from Lamp.Model.BaseModel import *
from Lamp.Model.ResNet import ResNet as ResNet, BasicBlock
from Lamp.Model.Baseline import Net as Baseline

In [24]:
def set_seed(seed):
    """Seed the Python, NumPy and PyTorch (CPU and CUDA) random generators.

    Args:
        seed: Value handed unchanged to every seeding function below.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    # Same call order as before: stdlib, NumPy, then the torch generators.
    for seed_fn in seeders:
        seed_fn(seed)

def load_config(cfg_path):
    """Load a configuration file, choosing the parser from the file extension.

    The extension is compared case-insensitively, so ``.YAML`` or ``.Json``
    are accepted as well.

    Args:
        cfg_path: Path to a ``.json``, ``.yaml`` or ``.yml`` configuration file.

    Returns:
        The result of ``AttrDict.from_json_path(cfg_path)`` for JSON files, or
        of ``AttrDict.from_yaml_path(cfg_path)`` for YAML files.

    Raises:
        ValueError: If the extension is not one of the supported formats.
    """
    # Compute the extension once; lower-casing makes the match case-insensitive.
    extension = os.path.splitext(cfg_path)[-1].lower()
    if extension == '.json':
        return AttrDict.from_json_path(cfg_path)
    if extension in ('.yaml', '.yml'):
        return AttrDict.from_yaml_path(cfg_path)
    raise ValueError("Unsupported config file format. Only '.json', '.yaml' and '.yml' files are supported.")

def resnet(layers=[3, 4, 6, 3],channels=3, num_classes=1000):
    """Build a ``ResNet`` assembled from ``BasicBlock`` units.

    Args:
        layers: Forwarded as the second positional argument of ``ResNet``
            (presumably the number of blocks per stage; confirm in
            ``Lamp.Model.ResNet``).
        channels: Forwarded as ``ResNet(..., channels=...)``.
        num_classes: Forwarded as ``ResNet(..., num_classes=...)``.

    Returns:
        The constructed ``ResNet`` instance.
    """
    return ResNet(
        BasicBlock,
        layers,
        channels=channels,
        num_classes=num_classes,
    )

class Classifier(BaseModelSingle):
    """``BaseModelSingle`` subclass that adds batched class prediction.

    Only ``predict`` carries logic; ``forward_loss`` is a placeholder.
    ``predict`` reads ``self.net`` and ``self.device``, which are presumably
    set by ``BaseModelSingle.__init__`` (confirm in ``Lamp.Model.BaseModel``).
    """

    def __init__(self, net: nn.Module, opt: Optimizer = None, sched: _LRScheduler = None,
        logger: Logger = None, print_progress: bool = True, device: str = 'cuda:0', **kwargs):
        """Forward every argument unchanged to ``BaseModelSingle.__init__``."""
        super().__init__(net, opt=opt, sched=sched, logger=logger, print_progress=print_progress, device=device, **kwargs)

    def forward_loss(self, data: Tuple[Tensor]) -> Tensor:
        """Placeholder that computes nothing and returns ``None``.

        NOTE(review): presumably present only because the base class expects
        this method; it must be implemented before training with this class.
        """
        pass

    def predict(self, loader):
        """Predict a class index for every sample produced by ``loader``.

        The network is switched to eval mode and gradients are disabled.

        Args:
            loader: Iterable yielding ``(input, label)`` batches of tensors.

        Returns:
            A pair ``(pred, label)`` of equally long tuples: the argmax over
            dimension 1 of the network output, and the labels cast to int64,
            both in iteration order. Two empty tuples if ``loader`` yields
            no batch.
        """
        predictions, targets = [], []
        self.net.eval()
        with torch.no_grad():
            for batch, batch_labels in loader:
                batch = batch.to(self.device)
                output = self.net(batch)

                predictions.extend(torch.argmax(output, dim=1).cpu().tolist())
                # Labels are only collected, never fed to the network, so they
                # do not need to be moved to the model device first.
                targets.extend(batch_labels.long().cpu().tolist())

        # Built directly from the two lists so an empty loader returns two
        # empty tuples instead of failing while unpacking zip(*[]).
        return tuple(predictions), tuple(targets)

In [25]:
# Config file read by the next cells to locate the test-borehole CSV and to get the split ratio and seed.
ifile = "../Models/config/MAR_RESNET18_PADDED_256_ALL_NEW.yaml"

In [26]:
# Parse the YAML config; its fields are read below as attributes (e.g. inputs.Seed).
inputs = load_config(ifile)

In [27]:
# Resolve the workspace root (parent of the current working directory) and
# derive the data / model locations from the config.
root_path = os.path.abspath(os.path.join(os.getcwd(), '../'))
path_load_data = f"{root_path}/{inputs.LoadPathTestBorehole}"
path_model = f"{root_path}/{inputs.PathSave}/{inputs.ModelName}"

dataframe = pd.read_csv(path_load_data, index_col=0)

# Stratified split on 'Label'; only the held-out part is kept.
_, test_dataframe = train_test_split(
    dataframe,
    test_size=(1 - inputs.TrainTestSplit),
    stratify=dataframe['Label'],
    random_state=inputs.Seed,
)

# Renumber the rows, then draw 20 rows per label (with replacement, so labels
# with fewer than 20 rows still contribute 20) and renumber again.
test_dataframe = (
    test_dataframe
    .reset_index(drop=True)
    .groupby('Label')
    .sample(20, replace=True, random_state=inputs.Seed)
    .reset_index(drop=True)
)

In [28]:
# Reuse the sample list saved by an earlier run so that repeated runs compare
# the models on the same rows; otherwise create it from the current test split.
# NOTE(review): the prediction cell reads '../Application/comparison_data_boreholes.csv',
# which is this same file only when the notebook runs from the Application
# directory. Confirm the working directory.
comparison_csv = 'comparison_data_boreholes.csv'

if os.path.isfile(comparison_csv):
    comparison_data = pd.read_csv(comparison_csv, index_col=0)
else:
    comparison_data = test_dataframe[['Paths', 'Label']].copy()
    # Written only when newly created: an existing file is left untouched
    # instead of being rewritten on every run.
    comparison_data.to_csv(comparison_csv)

In [29]:
def _build_test_transforms(transform_cfg, registry):
    """Instantiate the test-time transforms described by a config section.

    Args:
        transform_cfg: Mapping of transform name -> mapping of its parameters
            (``inputs.TransformTest`` in the configs loaded below).
        registry: Mapping of transform name -> callable used to build it.

    Returns:
        A ``Transforms`` object wrapping the built transforms in config order.

    Raises:
        KeyError: If a configured name is missing from ``registry``.
        IndexError: If a transform is configured without any parameter.
    """
    built = []
    for name, params in transform_cfg.items():
        factory = registry[name]
        values = list(params.values())
        # One configured parameter is passed as a bare value; several are
        # passed together as a single list argument.
        argument = values if len(values) > 1 else values[0]
        built.append(factory(argument))
    return Transforms(built)


def _make_test_loader(frame, transforms, batch_size=4):
    """Wrap ``frame`` in a ``Dataset`` and an order-preserving ``DataLoader``.

    Args:
        frame: DataFrame describing the samples to evaluate.
        transforms: ``Transforms`` object; its ``get_transforms()`` result is
            handed to the ``Dataset``.
        batch_size: Number of samples per batch.

    Returns:
        The pair ``(dataset, loader)``. The loader does not shuffle, so
        predictions come back in the row order of ``frame``.
    """
    dataset = Dataset(frame, transforms=transforms.get_transforms())
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)
    return dataset, loader


df_preds = pd.read_csv('../Application/comparison_data_boreholes.csv',index_col=0)

# Snapshot taken before any prediction column is added, so both models are
# evaluated on exactly the same frame.
eval_frame = df_preds.reset_index(drop=True)

dict_transform = {
        "Padding":Padding,
        "CenterCrop":tf.CenterCrop,
        "Resize":tf.Resize,
        }
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

root_path = os.path.abspath(os.path.join(os.getcwd(), '..')) # Workspace path to Cuttings_Characterisation
model_name = "model_all.pt"

# Pred Best Model
inputs = load_config("../Models/config/MAR_RESNET18_PADDED_256_ALL_NEW.yaml")
path_model = f"{root_path}/{inputs.PathSave}/{inputs.ModelName}"
save_model_path = f"{path_model}/{model_name}"

net = resnet(layers=inputs.Model.Layers,channels=inputs.Model.Channels,num_classes=inputs.Model.OutClasses)
best_model = Classifier(net=net,device=device)
best_model.load(save_model_path)

transforms_test = _build_test_transforms(inputs.TransformTest, dict_transform)
testDataset, test_dataloader = _make_test_loader(eval_frame, transforms_test)

# Wall-clock timestamps around the ResNet inference (not reported in this cell).
time_start = time.time()
pred, label = best_model.predict(test_dataloader)
time_end = time.time()

df_preds['best_model'] = pred

# Pred Baseline
inputs = load_config("../Baseline/config/MAR_BASELINE_PADDED_256_ALL_NEW.yaml")
path_model = f"{root_path}/{inputs.PathSave}/{inputs.ModelName}"
save_model_path = f"{path_model}/{model_name}"

net = Baseline()
baseline = Classifier(net=net,device=device)
baseline.load(save_model_path)

transforms_test = _build_test_transforms(inputs.TransformTest, dict_transform)
testDataset, test_dataloader = _make_test_loader(eval_frame, transforms_test)

pred, label = baseline.predict(test_dataloader)

df_preds['baseline_model'] = pred

In [30]:
# Show each sample's label next to the predictions of both models.
df_preds

Unnamed: 0,Paths,Label,best_model,baseline_model
0,../Data/Test_Borehole/BL-DB-1\slice00270_mar_r...,0,0,0
1,../Data/Test_Borehole/BL-DB-1\slice00300_mar_r...,0,0,0
2,../Data/Test_Borehole/BL-DB-2\slice00420_mar_r...,0,4,0
3,../Data/Test_Borehole/BL-DB-3\slice00590_mar_r...,0,4,4
4,../Data/Test_Borehole/BL-DB-3\slice00720_mar_r...,0,0,0
...,...,...,...,...
95,../Data/Test_Borehole/OL-DB-1\slice00500_mar_r...,4,4,4
96,../Data/Test_Borehole/OL-DB-3\slice00490_mar_r...,4,4,4
97,../Data/Test_Borehole/OL-DB-2\slice00550_mar_r...,4,4,4
98,../Data/Test_Borehole/OL-DB-3\slice00340_mar_r...,4,0,3


In [31]:
# Fraction of samples whose 'best_model' (ResNet) prediction equals the label.
accuracy_score(df_preds['Label'],df_preds['best_model'])

0.46

In [32]:
# Fraction of samples whose 'baseline_model' prediction equals the label.
accuracy_score(df_preds['Label'],df_preds['baseline_model'])

0.55