In [6]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import _LRScheduler

from sklearn.model_selection import StratifiedKFold, train_test_split
from sklearn.metrics import accuracy_score

from pathlib import Path
import numpy as np
import pandas as pd
import random, getopt, os, sys

sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '../..')))

from Lamp.AttrDict.AttrDict import *
from Lamp.Model.Dataloader import *
from Lamp.Model.BaseModel import *
from Lamp.Model.Resnet import *

In [7]:
# Experiment configuration: the data paths, model, augmentations and optimiser
# settings read by the cells below all come from this YAML file.
cfg_path = 'Config/MAR_RESNET34_PADDED_64_ALL.yaml'
inputs = AttrDict.from_yaml_path(cfg_path)
inputs  # rich display of the loaded configuration

{'ModelName': 'MAR_RESNET34_PADDED_64_ALL',
 'PathSave': 'Models/outputs/',
 'LoadPath': 'Dataset/train_mar.csv',
 'CheckpointName': 'checkpoint.pt',
 'CheckpointFreq': 5,
 'NSamples': 1000,
 'NEpochs': 200,
 'BatchSize': 32,
 'TrainTestSplit': 0.95,
 'KFold': 5,
 'Seed': 0,
 'Model': {'Layers': [3, 4, 6, 3], 'OutClasses': 5, 'Channels': 1},
 'TransformTrain': {'Padding': {'out_shape': 64},
  'VerticalFlip': {'p': 0.5},
  'HorizontalFlip': {'p': 0.5},
  'Rotation': {'min': -90, 'max': 90}},
 'TransformTest': {'Padding': {'out_shape': 64}},
 'Optimizer': {'lr': 0.001, 'weight_decay': 1e-06},
 'Scheduler': {'gamma': 0.95}}

In [8]:
def load_config(cfg_path):
    """Load an experiment configuration file into an ``AttrDict``.

    The loader is selected from the file extension, compared case-insensitively:
    ``.json`` uses ``AttrDict.from_json_path``; ``.yaml`` / ``.yml`` use
    ``AttrDict.from_yaml_path``.

    Args:
        cfg_path: path to the configuration file (``str`` or path-like).

    Returns:
        Whatever the matching ``AttrDict`` constructor returns.

    Raises:
        ValueError: if the extension is not ``.json``, ``.yaml`` or ``.yml``.
    """
    ext = os.path.splitext(cfg_path)[-1].lower()
    if ext == '.json':
        return AttrDict.from_json_path(cfg_path)
    if ext in ('.yaml', '.yml'):
        return AttrDict.from_yaml_path(cfg_path)
    raise ValueError(
        f"Unsupported config file format {ext!r} for {str(cfg_path)!r}. "
        "Only '.json', '.yaml' and '.yml' files are supported."
    )

def set_seed(seed):
    """Seed every RNG used here: Python ``random``, NumPy, and PyTorch (CPU and all CUDA devices)."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

def set_parameter_requires_grad(model, requires_grad=True):
    """Set ``requires_grad`` on every parameter of ``model`` (unfreezes by default).

    Iterates ``model.parameters()`` rather than the module's children. That way,
    parameters registered directly on ``model`` are included, as well as those of
    all nested submodules.

    Args:
        model: the ``nn.Module`` whose parameters are updated in place.
        requires_grad: value to assign. ``True`` (the default) makes every
            parameter trainable; ``False`` freezes them all.
    """
    for param in model.parameters():
        param.requires_grad = requires_grad

def resnet(layers=(3, 4, 6, 3), channels=3, num_classes=1000):
    """Build a ``ResNet`` whose stages are made of ``BasicBlock`` modules.

    Args:
        layers: number of blocks per stage. The default ``(3, 4, 6, 3)`` is the
            ResNet-34 layout. The default is an immutable tuple rather than a shared
            list, and ``ResNet`` always receives a fresh list.
        channels: number of input channels, forwarded to ``ResNet``.
        num_classes: number of output classes, forwarded to ``ResNet``.

    Returns:
        The constructed ``ResNet`` instance.
    """
    return ResNet(BasicBlock, list(layers), channels=channels, num_classes=num_classes)

class Classifier(BaseModelSingle):
    """Single-network classifier trained with cross-entropy.

    Supplies the per-batch loss/metric computation (``forward_loss``). The training
    loop, optimiser/scheduler stepping and logging presumably live in
    ``BaseModelSingle``, which is not visible here.
    """

    def __init__(self, net: nn.Module, opt: Optimizer = None, sched: _LRScheduler = None, 
        logger: Logger = None, print_progress: bool = True, device: str = 'cuda:0', **kwargs):
        """
        Args:
            net: network mapping an input batch to per-class logits.
            opt, sched, logger, print_progress, device, **kwargs: forwarded unchanged
                to ``BaseModelSingle.__init__``.
        """
        super().__init__(net, opt=opt, sched=sched, logger=logger, print_progress=print_progress, device=device, **kwargs)

        # Batch-averaged cross-entropy; targets are integer class indices (cast in forward_loss).
        self.loss_fn = nn.CrossEntropyLoss(reduction="mean")

    def forward_loss(self, data: Tuple[Tensor, Tensor]) -> Tuple[Tensor, dict]:
        """Compute the cross-entropy loss and the accuracy for one batch.

        Args:
            data: ``(input, label)`` pair. Both are moved to ``self.device``, and
                labels are cast to ``long`` class indices.

        Returns:
            ``(loss, metrics)``. ``loss`` is the scalar loss tensor; ``metrics`` maps
            ``"Loss"`` to that same tensor and ``"Train Accuracy"`` to the fraction
            of samples whose argmax prediction equals the label.
        """
        batch_input, batch_label = data
        batch_input = batch_input.to(self.device)
        batch_label = batch_label.to(self.device).long()

        logits = self.net(batch_input)
        loss = self.loss_fn(logits, batch_label)

        # argmax output carries no autograd history, so it can go straight to NumPy.
        preds = torch.argmax(logits, dim=1)
        acc = accuracy_score(batch_label.cpu().numpy(), preds.cpu().numpy())

        # NOTE(review): the key says "Train Accuracy" even if this is also used for
        # validation batches — confirm how BaseModelSingle consumes it before renaming.
        return loss, {"Loss": loss, "Train Accuracy": acc}

In [12]:
# Workspace path to Cuttings_Characterisation: the parent of the notebook's working directory.
root_path = os.path.abspath(os.path.join(os.getcwd(), '..'))
path_load_data = os.path.join(root_path, inputs.LoadPath)  # labelled .csv dataset
# os.path.join avoids the doubled '/' an f-string join gives when PathSave ends with '/'.
path_model = os.path.join(root_path, inputs.PathSave, inputs.ModelName)

model_name = "model_all.pt"
save_model_path = os.path.join(path_model, model_name)

# Balance the classes: draw exactly NSamples rows per label, sampling with replacement.
# Labels with fewer than NSamples rows are therefore oversampled, and any label may
# contain duplicated rows.
# NOTE(review): if this frame is later split into train/test folds, duplicated rows can
# land on both sides of the split — confirm any split is meant to happen on this frame.
dataframe = (
    pd.read_csv(path_load_data, index_col=0)
    .groupby('Label')
    .sample(inputs.NSamples, replace=True, random_state=inputs.Seed)
    .reset_index(drop=True)
)
assert dataframe['Label'].value_counts().eq(inputs.NSamples).all(), "class balancing failed"

In [14]:
# Config key -> transform class. Keys are the names used under TransformTrain / TransformTest.
dict_transform = {
    "Padding": Padding,
    "VerticalFlip": tf.RandomVerticalFlip,
    "HorizontalFlip": tf.RandomHorizontalFlip,
    "Rotation": tf.RandomRotation,
    "CenterCrop": tf.CenterCrop,
    "Resize": tf.Resize,
}


def _build_transform(name, params):
    """Instantiate the transform registered under ``name`` from its config mapping ``params``.

    The config values are passed as ONE positional argument, in config key order:
      - one value is passed as-is (``{'p': 0.5}`` -> ``0.5``);
      - several values are passed as a list (``{'min': -90, 'max': 90}`` -> ``[-90, 90]``),
        so the key order in the config matters;
      - an empty or missing mapping calls the transform with no arguments.

    Raises:
        KeyError: if ``name`` is not a key of ``dict_transform``.
    """
    if name not in dict_transform:
        raise KeyError(f"Unknown transform {name!r}; expected one of {sorted(dict_transform)}")
    factory = dict_transform[name]
    values = list((params or {}).values())
    if not values:
        return factory()
    return factory(values if len(values) > 1 else values[0])


transforms_train = Transforms(
    [_build_transform(key, item) for key, item in inputs.TransformTrain.items()]
)

In [15]:
# Training set: the class-balanced dataframe paired with the train-time augmentation pipeline built above.
trainDataset = Dataset(dataframe, transforms=transforms_train.get_transforms())