## Imports

In [None]:
import torch
from torchvision.datasets import ImageFolder
from collections import Counter
from torchvision import transforms
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
import numpy as np
from functools import partial

## Preprocessing

In [None]:
# Normalization statistics handed to transforms.Normalize in the transform
# pipeline; one entry each because the images are converted to a single
# grayscale channel.
# NOTE(review): presumably measured on this dataset (e.g. with calc_mean_std) - confirm.
mean = [0.1918]
std = [0.2148]

In [None]:
# Preprocessing applied to every image: collapse to one grayscale channel,
# convert to a tensor, then normalize with the statistics defined above.
_preprocessing_steps = [
    transforms.Grayscale(num_output_channels=1),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
]
transform = transforms.Compose(_preprocessing_steps)

In [None]:
# Load the images from class-named sub-folders of ../dataset, applying the
# preprocessing pipeline to each sample.
dataset = ImageFolder(root='../dataset', transform=transform)

In [None]:
# Number of samples per class index. The entries are sorted by class index so
# that position i of `weight` always belongs to class i, independent of the
# order in which labels happen to appear in `dataset.targets`.
count_dict = dict(sorted(Counter(dataset.targets).items()))
count = count_dict.values()
total = sum(count)
# Inverse-frequency class weights: the rarer a class, the larger its weight.
weight = torch.tensor([total / c for c in count], dtype=torch.float32)

In [None]:
# 80 % / 10 % / remainder split into train, validation and test subsets.
dataset_size = len(dataset)
train_size, val_size = (int(dataset_size * fraction) for fraction in (0.8, 0.1))
# The test subset takes whatever is left so the three sizes always add up.
test_size = dataset_size - (train_size + val_size)

# Fixed seed so the same split is produced on every run.
_split_generator = torch.Generator().manual_seed(42)
train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
    dataset,
    [train_size, val_size, test_size],
    generator=_split_generator,
)

In [None]:
def create_data_loaders(train_dataset, val_dataset, test_dataset, batch_size):
    """Wrap the three dataset splits in ``DataLoader`` objects.

    Only the training loader shuffles. The validation and test loaders keep
    the dataset order so that evaluation batches are the same on every pass.

    Args:
        train_dataset: Dataset used for training.
        val_dataset: Dataset used for validation.
        test_dataset: Dataset used for final testing.
        batch_size: Number of samples per batch for all three loaders.

    Returns:
        Tuple ``(train_loader, val_loader, test_loader)``.
    """
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    # Evaluation does not benefit from shuffling; a fixed order keeps it reproducible.
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    return train_loader, val_loader, test_loader

In [None]:
def show_batch(dl):
    """Display the first batch of ``dl`` as an image grid (4 images per row).

    The images are de-normalized with the module-level ``mean`` and ``std``
    before plotting. Nothing is drawn when the loader yields no batch.

    Args:
        dl: Iterable yielding ``(images, labels)`` batches.
    """
    first_batch = next(iter(dl), None)
    if first_batch is None:
        return

    images, labels = first_batch
    # Inverse of Normalize: x * std + mean, broadcast over the channel axis.
    channel_mean = torch.as_tensor(mean, dtype=images.dtype).view(1, -1, 1, 1)
    channel_std = torch.as_tensor(std, dtype=images.dtype).view(1, -1, 1, 1)
    # Out-of-place so the caller's batch is not modified; clamp guards against
    # tiny floating-point overshoots outside the displayable [0, 1] range.
    restored = (images * channel_std + channel_mean).clamp(0, 1)

    fig, ax = plt.subplots(figsize=(8, 8))
    ax.set_xticks([])
    ax.set_yticks([])
    # make_grid returns (C, H, W); imshow expects (H, W, C).
    ax.imshow(make_grid(restored, nrow=4, pad_value=1).permute(1, 2, 0))

In [None]:
def calc_mean_std(data_loader):
    """Compute the per-channel mean and standard deviation of a dataset.

    Each batch contributes in proportion to the number of values it holds per
    channel, so a smaller final batch does not skew the result.

    Args:
        data_loader: Iterable yielding ``(data, target)`` pairs, where ``data``
            is a 4-D tensor reduced over dimensions 0, 2 and 3 (channels in
            dimension 1).

    Returns:
        Tuple ``(mean, std)`` of 1-D tensors with one entry per channel. The
        standard deviation is the population (biased) one.

    Raises:
        ValueError: If the loader yields no non-empty batch.
    """
    weighted_sum, weighted_sqrd_sum, num_values = 0, 0, 0

    for data, _ in data_loader:
        # Values per channel in this batch: batch size * height * width.
        batch_values = data.shape[0] * data.shape[2] * data.shape[3]
        if batch_values == 0:
            continue
        weighted_sum += torch.mean(data, dim=[0, 2, 3]) * batch_values
        weighted_sqrd_sum += torch.mean(data ** 2, dim=[0, 2, 3]) * batch_values
        num_values += batch_values

    if num_values == 0:
        raise ValueError('calc_mean_std() needs at least one non-empty batch')

    mean = weighted_sum / num_values
    # Var[x] = E[x^2] - E[x]^2; clamp because rounding can push it just below 0.
    variance = weighted_sqrd_sum / num_values - mean ** 2
    std = variance.clamp(min=0) ** 0.5

    return mean, std

# Models

In [None]:
import torch.nn as nn
import torch.nn.functional as F

class ImageClassificationBase(nn.Module):
    """Base class bundling the per-batch and per-epoch bookkeeping for classifiers.

    Subclasses provide ``forward``. The step methods call ``self(images)`` and
    the module-level ``accuracy`` helper.
    """

    def __init__(self):
        super().__init__()
        # Kept for backward compatibility. A subclass that creates its layers
        # after calling super().__init__() has no Conv2d/Linear modules yet at
        # this point, so it must call _initialize_weights() again once the
        # layers exist for the initialisation to take effect.
        self._initialize_weights()

    def training_step(self, batch, criterion):
        """Run one training batch and return ``(loss, accuracy)``."""
        images, labels = batch
        out = self(images)
        loss = criterion(out, labels)
        acc = accuracy(out, labels)
        return loss, acc

    def validation_step(self, batch, criterion):
        """Run one validation batch and return its detached loss and accuracy."""
        images, labels = batch
        out = self(images)
        loss = criterion(out, labels)
        acc = accuracy(out, labels)
        return {'val_loss': loss.detach(), 'val_acc': acc}

    def validation_epoch_end(self, outputs):
        """Average the per-batch validation results into plain Python floats.

        Args:
            outputs: List of dicts as returned by ``validation_step``.

        Returns:
            Dict with the mean ``'val_loss'`` and ``'val_acc'``.
        """
        batch_losses = [x['val_loss'] for x in outputs]
        epoch_loss = torch.stack(batch_losses).mean()
        batch_accs = [x['val_acc'] for x in outputs]
        epoch_acc = torch.stack(batch_accs).mean()
        return {'val_loss': epoch_loss.item(), 'val_acc': epoch_acc.item()}

    def epoch_end(self, epoch, result):
        """Print the train/validation loss and accuracy of one epoch."""
        print("Epoch [{}], train_loss: {:.4f}, val_loss: {:.4f}, train_acc: {:.4f}, val_acc: {:.4f}".format(
            epoch, result['train_loss'], result['val_loss'], result['train_acc'], result['val_acc']))

    def _initialize_weights(self):
        """Kaiming-normal initialise every Conv2d/Linear weight; zero the biases.

        Walks ``self.modules()`` (which includes ``self``), so it covers all
        layers that are registered at the time of the call.
        """
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

In [None]:
class ConvBlock(nn.Module):
    """Convolution followed by batch normalisation and a ReLU.

    Args:
        in_channels: Number of input channels of the convolution.
        out_channels: Number of output channels of the convolution.
        **kwargs: Forwarded unchanged to ``nn.Conv2d`` (kernel size, stride, ...).
    """

    def __init__(self, in_channels, out_channels, **kwargs):
        super().__init__()
        stages = [
            nn.Conv2d(in_channels, out_channels, **kwargs),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        ]
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        """Apply conv -> batch norm -> ReLU to ``x``."""
        activated = self.block(x)
        return activated

## AlexNet

In [None]:
class AlexNet(ImageClassificationBase):
    """AlexNet-style classifier: three conv stages and a two-layer head.

    Takes single-channel input and produces 4 output scores.
    """

    def __init__(self):
        super().__init__()
        feature_layers = [
            nn.Conv2d(1, 16, kernel_size=11, stride=4, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            # Fixed 6x6 spatial output, so the head sees 64 * 6 * 6 = 2304 features.
            nn.AdaptiveAvgPool2d(output_size=6),
        ]
        head_layers = [
            nn.Dropout(0.5),
            nn.Flatten(),
            nn.Linear(2304, 256, bias=True),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, 4),
        ]
        self.network = nn.Sequential(*feature_layers, *head_layers)

    def forward(self, x):
        """Return the class scores for the batch ``x``."""
        scores = self.network(x)
        return scores

In [None]:
class AlexNet2(ImageClassificationBase):
    """Deeper AlexNet-style classifier: four conv stages and a two-layer head.

    Takes single-channel input and produces 4 output scores.
    """

    def __init__(self):
        super().__init__()
        feature_layers = [
            nn.Conv2d(1, 16, kernel_size=11, stride=4, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            # Fixed 4x4 spatial output, so the head sees 64 * 4 * 4 = 1024 features.
            nn.AdaptiveAvgPool2d(output_size=4),
        ]
        head_layers = [
            nn.Dropout(0.5),
            nn.Flatten(),
            nn.Linear(1024, 256, bias=True),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, 4),
        ]
        self.network = nn.Sequential(*feature_layers, *head_layers)

    def forward(self, x):
        """Return the class scores for the batch ``x``."""
        scores = self.network(x)
        return scores

## VGG

In [None]:
class VGG1(ImageClassificationBase):
    """VGG-style classifier: two blocks of paired 3x3 convolutions, three-layer head.

    Takes single-channel input and produces 4 output scores. Conv2d and Linear
    layers are Xavier-uniform initialised with zero biases.
    """

    def __init__(self):
        super().__init__()
        self.network = nn.Sequential(

            nn.Conv2d(1, 16, kernel_size=3),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3),
            nn.ReLU(),
            nn.MaxPool2d(2,2),

            nn.Conv2d(32, 64, kernel_size=3),
            nn.ReLU(),
            nn.Conv2d(64 ,64, kernel_size=3),
            nn.ReLU(),
            # Fixed 7x7 spatial output, so the head sees 64 * 7 * 7 = 3136 features.
            nn.AdaptiveAvgPool2d(output_size=7),

            nn.Flatten(),
            nn.Linear(3136, 128),
            nn.ReLU(),
            nn.Dropout(0.85),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Dropout(0.85),
            nn.Linear(64, 4)
        )
        # The base constructor invokes _initialize_weights() before
        # self.network exists, so run it again now that the layers are built.
        self._initialize_weights()

    def forward(self, x):
        """Return the class scores for the batch ``x``."""
        return self.network(x)

    def _initialize_weights(self):
        """Xavier-uniform initialise every Conv2d/Linear weight; zero the biases."""
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                torch.nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    torch.nn.init.zeros_(m.bias)

In [None]:
class VGG2(ImageClassificationBase):
    """VGG-style classifier: two blocks of paired 3x3 convolutions, two-layer head.

    Takes single-channel input and produces 4 output scores. Conv2d and Linear
    layers are Xavier-uniform initialised with zero biases.
    """

    def __init__(self):
        super().__init__()
        self.network = nn.Sequential(

            nn.Conv2d(1, 16, kernel_size=3),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3),
            nn.ReLU(),
            nn.MaxPool2d(2,2),

            nn.Conv2d(32, 48, kernel_size=3),
            nn.ReLU(),
            nn.Conv2d(48 ,64, kernel_size=3),
            nn.ReLU(),
            # Fixed 4x4 spatial output, so the head sees 64 * 4 * 4 = 1024 features.
            nn.AdaptiveAvgPool2d(output_size=4),

            nn.Flatten(),
            nn.Linear(1024, 128),
            nn.ReLU(),
            nn.Dropout(0.85),
            nn.Linear(128, 4)
        )
        # The base constructor invokes _initialize_weights() before
        # self.network exists, so run it again now that the layers are built.
        self._initialize_weights()

    def forward(self, x):
        """Return the class scores for the batch ``x``."""
        return self.network(x)

    def _initialize_weights(self):
        """Xavier-uniform initialise every Conv2d/Linear weight; zero the biases."""
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                torch.nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    torch.nn.init.zeros_(m.bias)

In [None]:
class VGG3(ImageClassificationBase):
    """VGG-style classifier: three blocks of paired 3x3 convolutions, three-layer head.

    Takes single-channel input and produces 4 output scores. Conv2d and Linear
    layers are Xavier-uniform initialised with zero biases.
    """

    def __init__(self):
        super().__init__()
        self.network = nn.Sequential(

            nn.Conv2d(1, 16, kernel_size=3),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3),
            nn.ReLU(),
            nn.MaxPool2d(2,2),

            nn.Conv2d(32, 48, kernel_size=3),
            nn.ReLU(),
            nn.Conv2d(48, 64, kernel_size=3),
            nn.ReLU(),
            nn.MaxPool2d(2,2),

            nn.Conv2d(64, 96, kernel_size=3),
            nn.ReLU(),
            nn.Conv2d(96 ,96, kernel_size=3),
            nn.ReLU(),
            # Fixed 3x3 spatial output, so the head sees 96 * 3 * 3 = 864 features.
            nn.AdaptiveAvgPool2d(output_size=3),

            nn.Flatten(),
            nn.Linear(864, 128),
            nn.ReLU(),
            nn.Dropout(0.85),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Dropout(0.85),
            nn.Linear(64, 4)
        )
        # The base constructor invokes _initialize_weights() before
        # self.network exists, so run it again now that the layers are built.
        self._initialize_weights()

    def forward(self, x):
        """Return the class scores for the batch ``x``."""
        return self.network(x)

    def _initialize_weights(self):
        """Xavier-uniform initialise every Conv2d/Linear weight; zero the biases."""
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                torch.nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    torch.nn.init.zeros_(m.bias)

## Inception Net

In [None]:
class InceptionModuleV1(nn.Module):
    """Inception module with four parallel branches concatenated along channels.

    Branches: a 1x1 convolution, a 1x1 reduction followed by a 3x3 convolution,
    a 1x1 reduction followed by a 5x5 convolution, and a 3x3 max-pool followed
    by a 1x1 convolution. Paddings and strides keep the spatial size unchanged
    in every branch, so the outputs can be concatenated.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Output channels per branch, keyed by ``'1x1'``, ``'3x3'``,
            ``'5x5'`` and ``'max'``. The module outputs their sum.
        reduction_channels: Channels of the 1x1 reduction in front of the 3x3
            and 5x5 convolutions, keyed by ``'3x3'`` and ``'5x5'``.
    """

    def __init__(self, in_channels, out_channels: dict, reduction_channels: dict):

        super().__init__()

        self.conv_1x1 = ConvBlock(in_channels, out_channels['1x1'], kernel_size=1)

        self.conv_3x3 = nn.Sequential(
            ConvBlock(in_channels, reduction_channels['3x3'], kernel_size=1),
            ConvBlock(reduction_channels['3x3'], out_channels['3x3'], kernel_size=3, padding=1)
        )

        self.conv_5x5 = nn.Sequential(
            ConvBlock(in_channels, reduction_channels['5x5'], kernel_size=1),
            ConvBlock(reduction_channels['5x5'], out_channels['5x5'], kernel_size=5, padding=2)
        )

        # ConvBlock already ends with a ReLU, so no extra activation is needed
        # here. Parameter names in the state dict are unaffected.
        self.max_pool = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, padding=1, stride=1),
            ConvBlock(in_channels, out_channels['max'], kernel_size=1)
        )

    def forward(self, x):
        """Run all four branches on ``x`` and concatenate them on dimension 1."""
        branches = [self.conv_1x1(x), self.conv_3x3(x), self.conv_5x5(x), self.max_pool(x)]
        return torch.cat(branches, dim=1)

In [None]:
class Inception1(ImageClassificationBase):
    """Inception-style classifier with seven inception modules in three stages.

    Takes single-channel input and produces 4 output scores.
    """

    def __init__(self):
        super().__init__()

        def inception(in_ch, c1, c3, c5, cmax, r3, r5):
            # Shorthand: branch output widths first, then the two reduction widths.
            return InceptionModuleV1(
                in_ch,
                out_channels={'1x1': c1, '3x3': c3, '5x5': c5, 'max': cmax},
                reduction_channels={'3x3': r3, '5x5': r5},
            )

        def downsample():
            return nn.MaxPool2d(3, stride=2, padding=1)

        self.input_net = ConvBlock(1, 64, kernel_size=3, padding=1)

        # Each module's input width equals the sum of the previous module's
        # branch outputs.
        stages = [
            inception(64, 16, 32, 8, 8, 32, 16),
            inception(64, 24, 48, 12, 12, 32, 16),
            downsample(),
            inception(96, 24, 48, 12, 12, 32, 16),
            inception(96, 16, 48, 16, 16, 32, 16),
            inception(96, 32, 48, 24, 24, 32, 16),
            downsample(),
            inception(128, 32, 64, 16, 16, 48, 16),
            inception(128, 32, 64, 16, 16, 48, 16),
        ]
        self.inception_blocks = nn.Sequential(*stages)

        head = [
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(128, 64),
            nn.Dropout(0.4),
            nn.Linear(64, 4),
        ]
        self.output_net = nn.Sequential(*head)

    def forward(self, x):
        """Return the class scores for the batch ``x``."""
        for stage in (self.input_net, self.inception_blocks, self.output_net):
            x = stage(x)
        return x

In [None]:
class Inception2(ImageClassificationBase):
    """Inception-style classifier with nine inception modules in four stages.

    Takes single-channel input and produces 4 output scores.
    """

    def __init__(self):
        super().__init__()

        def inception(in_ch, c1, c3, c5, cmax, r3, r5):
            # Shorthand: branch output widths first, then the two reduction widths.
            return InceptionModuleV1(
                in_ch,
                out_channels={'1x1': c1, '3x3': c3, '5x5': c5, 'max': cmax},
                reduction_channels={'3x3': r3, '5x5': r5},
            )

        def downsample():
            return nn.MaxPool2d(3, stride=2, padding=1)

        self.input_net = ConvBlock(1, 32, kernel_size=3, padding=1)

        # Each module's input width equals the sum of the previous module's
        # branch outputs.
        stages = [
            inception(32, 8, 16, 4, 4, 12, 4),
            inception(32, 12, 24, 8, 8, 16, 8),
            downsample(),
            inception(52, 16, 32, 12, 12, 16, 8),
            inception(72, 16, 32, 16, 16, 24, 12),
            inception(80, 32, 48, 24, 24, 32, 16),
            downsample(),
            inception(128, 32, 64, 16, 16, 48, 16),
            inception(128, 32, 64, 16, 16, 48, 16),
            downsample(),
            inception(128, 32, 72, 24, 24, 48, 16),
            inception(152, 32, 72, 24, 24, 48, 16),
        ]
        self.inception_blocks = nn.Sequential(*stages)

        head = [
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(152, 128),
            nn.Dropout(0.4),
            nn.Linear(128, 4),
        ]
        self.output_net = nn.Sequential(*head)

    def forward(self, x):
        """Return the class scores for the batch ``x``."""
        for stage in (self.input_net, self.inception_blocks, self.output_net):
            x = stage(x)
        return x

## ResNet

In [None]:
class ResidualBlock(nn.Module):
    """Bottleneck residual block: 1x1 -> 3x3 -> 1x1 convolutions plus a skip path.

    The last 1x1 convolution widens the features to
    ``intermediate_channels * expansion`` channels (``expansion`` is 4).

    Args:
        in_channels: Number of channels of the input tensor.
        intermediate_channels: Width of the first two convolutions.
        identity_downsample: Optional module applied to the skip path so its
            shape matches the main path; ``None`` leaves the input as is.
        stride: Stride of the 3x3 convolution.
    """

    def __init__(self, in_channels, intermediate_channels, identity_downsample=None, stride=1):
        super().__init__()
        self.expansion = 4
        self.blocks = nn.Sequential(
            ConvBlock(in_channels, intermediate_channels, kernel_size=1, stride=1, padding=0, bias=False),
            ConvBlock(intermediate_channels, intermediate_channels, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.Conv2d(intermediate_channels, intermediate_channels * self.expansion, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(intermediate_channels * self.expansion)
        )

        self.relu = nn.ReLU()

        self.identity_downsample = identity_downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(blocks(x) + skip(x))``."""
        # No copy of the input is needed: self.blocks returns a new tensor and
        # none of its layers modify their input in place.
        identity = x

        out = self.blocks(x)

        if self.identity_downsample is not None:
            identity = self.identity_downsample(identity)

        out += identity
        out = self.relu(out)
        return out


In [None]:
class ResNet(ImageClassificationBase):
    """Three-stage ResNet built from a configurable residual block.

    Args:
        block: Residual block class, called as
            ``block(in_channels, intermediate_channels, identity_downsample, stride)``.
            Its output is assumed to have ``4 * intermediate_channels`` channels.
        layers: Sequence whose first three entries give the number of blocks
            in each of the three stages.
        image_channels: Number of channels of the input images.
        num_classes: Number of output scores.
        apply_dropout: When true, dropout is applied before each of the two
            fully connected layers.
    """

    def __init__(self, block, layers, image_channels=1, num_classes=4, apply_dropout=False):
        super().__init__()
        self.in_channels = 64
        self.apply_dropout = apply_dropout

        stem = [
            ConvBlock(image_channels, 64, kernel_size=7, stride=2, padding=3, bias=False),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        ]
        # (index into `layers`, intermediate width, stride) for each stage.
        stage_specs = ((0, 16, 1), (1, 32, 2), (2, 64, 2))
        stages = []
        for index, width, stride in stage_specs:
            stages.append(
                self._make_layer(block, layers[index], intermediate_channels=width, stride=stride)
            )
        self.network = nn.Sequential(*stem, *stages, nn.AdaptiveAvgPool2d((1, 1)))

        # The last stage outputs 64 * 4 channels, pooled to 1x1.
        self.fc1 = nn.Linear(64 * 4, 64)
        self.fc2 = nn.Linear(64, num_classes)

        self.dropout = nn.Dropout()

    def forward(self, x):
        """Return the class scores for the batch ``x``."""
        features = self.network(x)
        features = features.reshape(features.shape[0], -1)
        for fully_connected in (self.fc1, self.fc2):
            if self.apply_dropout:
                features = self.dropout(features)
            features = fully_connected(features)

        return features

    def _make_layer(self, block, num_residual_blocks, intermediate_channels, stride):
        """Build one stage of ``num_residual_blocks`` blocks and update ``self.in_channels``.

        Only the first block may change the stride or channel count; it gets a
        1x1 conv + batch norm on its skip path when either changes.
        """
        out_channels = intermediate_channels * 4

        shortcut = None
        if not (stride == 1 and self.in_channels == out_channels):
            shortcut = nn.Sequential(
                nn.Conv2d(self.in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )

        stage = [block(self.in_channels, intermediate_channels, shortcut, stride)]

        self.in_channels = out_channels

        remaining = num_residual_blocks - 1
        while remaining > 0:
            stage.append(block(self.in_channels, intermediate_channels))
            remaining -= 1

        return nn.Sequential(*stage)

## Training

In [None]:
import copy

def accuracy(outputs, labels):
    """Return the fraction of rows of ``outputs`` whose arg-max equals the label.

    Args:
        outputs: 2-D tensor of scores, one row per sample.
        labels: 1-D tensor of target class indices.

    Returns:
        0-dimensional tensor holding the accuracy of the batch.
    """
    preds = torch.argmax(outputs.data, dim=1)
    num_correct = (preds == labels).sum().item()
    return torch.tensor(num_correct / len(preds))
  
@torch.no_grad()
def evaluate(model, criterion, val_loader):
    """Evaluate ``model`` on every batch of ``val_loader`` with gradients disabled.

    Puts the model into eval mode, collects the result of
    ``model.validation_step`` for each batch and returns what
    ``model.validation_epoch_end`` makes of them.
    """
    model.eval()
    step_results = []
    for batch in val_loader:
        step_results.append(model.validation_step(batch, criterion))
    return model.validation_epoch_end(step_results)

  
def fit(epochs, model, optimizer, criterion, train_loader, val_loader, n_epochs_stop=20):
    """Train ``model`` with early stopping on the validation loss.

    Args:
        epochs: Maximum number of epochs to run.
        model: Model providing ``training_step`` and ``epoch_end``, plus the
            validation hooks used by ``evaluate``.
        optimizer: Optimizer stepping the model's parameters.
        criterion: Loss function passed to the model's step methods.
        train_loader: Iterable of training batches.
        val_loader: Iterable of validation batches.
        n_epochs_stop: Number of consecutive epochs without a new best
            validation loss after which training stops.

    Returns:
        Tuple ``(history, model, best_model)``. ``history`` is a list with one
        result dict per epoch, ``model`` is in its final state, and
        ``best_model`` is a deep copy taken at the lowest validation loss
        (``None`` if no epoch ran).
    """
    history = []

    # early stopping state
    epochs_no_improve = 0
    min_val_loss = None
    best_model = None

    for epoch in range(epochs):
        print(f'start epoch {epoch}')

        model.train()
        train_losses = []
        train_accuracies = []
        for batch in train_loader:
            optimizer.zero_grad()
            loss, acc = model.training_step(batch, criterion)
            # Keep only the value for logging; the stored tensor must not hold
            # on to the autograd graph of the batch.
            train_losses.append(loss.detach())
            train_accuracies.append(acc)
            loss.backward()
            optimizer.step()

        result = evaluate(model, criterion, val_loader)
        result['train_loss'] = torch.stack(train_losses).mean().item()
        result['train_acc'] = torch.stack(train_accuracies).mean().item()
        model.epoch_end(epoch, result)
        history.append(result)

        # early stopping
        if min_val_loss is None or result['val_loss'] < min_val_loss:
            epochs_no_improve = 0
            min_val_loss = result['val_loss']
            best_model = copy.deepcopy(model)
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= n_epochs_stop:
                print('Early stopping' )
                break

    return history, model, best_model

In [None]:
def train(model_name, opt_name, criterion, batch_size, lr, momentum, weight_decay, train_loader, val_loader):    
    """Build the named model and optimizer, then train with ``fit``.

    The model class and its constructor arguments come from the module-level
    ``models`` and ``model_params`` tables, the optimizer class from
    ``opt_func``, and the epoch count from the module-level ``num_epochs``.

    Args:
        model_name: Key into ``models`` / ``model_params``.
        opt_name: Key into ``opt_func``.
        criterion: Loss function handed to ``fit``.
        batch_size: Not used inside this function.
        lr: Learning rate for the optimizer.
        momentum: Momentum; only passed on when ``opt_name`` is ``'RMSprop'``.
        weight_decay: Weight decay for the optimizer.
        train_loader: Iterable of training batches.
        val_loader: Iterable of validation batches.

    Returns:
        Whatever ``fit`` returns.
    """
    constructor_kwargs = model_params[model_name]
    model = models[model_name](**constructor_kwargs)

    optimizer_cls = opt_func[opt_name]
    optimizer_kwargs = dict(lr=lr, weight_decay=weight_decay)
    if opt_name == 'RMSprop':
        optimizer_kwargs.update(momentum=momentum)
    optimizer = optimizer_cls(model.parameters(), **optimizer_kwargs)

    return fit(num_epochs, model, optimizer, criterion, train_loader, val_loader)

## Evaluation

In [None]:
import pickle
import json

def save_obj(obj, name):
    """Serialise ``obj`` as JSON into the file ``<name>.json``.

    Args:
        obj: JSON-serialisable object.
        name: Target path without the ``.json`` extension.
    """
    # Serialise first so a non-serialisable object does not leave an empty file.
    payload = json.dumps(obj)
    # The context manager closes the file even if the write fails.
    with open(name + '.json', 'w') as f:
        f.write(payload)


In [None]:
def plot_history(path, id, history, train_str, val_str, y_label, title):
    """Plot a training and a validation curve over epochs and save the figure.

    Args:
        path: Directory the image is written to.
        id: Prefix of the image file name.
        history: List of per-epoch result dicts.
        train_str: Key of the training metric in each history entry.
        val_str: Key of the validation metric in each history entry.
        y_label: Label of the y axis; also part of the file name.
        title: Title of the plot.
    """
    # Extract both curves first so a missing key fails before a figure is made.
    train_curve = [entry[train_str] for entry in history]
    val_curve = [entry[val_str] for entry in history]

    plt.figure().clear()
    curve_styles = (
        (train_curve, '#4c9ac7', 'x'),
        (val_curve, '#de9d35', '.'),
    )
    for curve, colour, marker in curve_styles:
        plt.plot(curve, color=colour, linestyle='solid', marker=marker)
    plt.xlabel('epoch')
    plt.ylabel(y_label)
    plt.title(title)
    plt.savefig(f'{path}/{id}_{y_label}.png')
    

In [None]:
from sklearn.metrics import confusion_matrix
import seaborn as sn
import pandas as pd

def get_confusion_matrix(path, id, model, test_loader):
    """Plot the row-normalised confusion matrix of ``model`` and save it.

    The image is written to ``<path>/<id>_conf_matrix.png``.

    Args:
        path: Directory the image is written to.
        id: Prefix of the image file name.
        model: Callable mapping a batch of images to class scores.
        test_loader: Iterable yielding ``(images, labels)`` batches.
    """
    classes = ('cnv', 'dme', 'drusen', 'normal')

    y_pred = []
    y_true = []

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for images, labels in test_loader:
            output = model(images)

            _, preds = torch.max(output, dim=1)
            y_pred.extend(preds.cpu().numpy())
            y_true.extend(labels.cpu().numpy())

    # Passing the label set explicitly keeps the matrix at
    # len(classes) x len(classes) even if a class never occurs in the data.
    cf_matrix = confusion_matrix(
        y_true, y_pred, labels=list(range(len(classes))), normalize='true')

    df_cm = pd.DataFrame(cf_matrix, index=list(classes), columns=list(classes))

    plt.figure(figsize=(12, 7))
    sn.heatmap(df_cm, annot=True, cmap="Blues")
    plt.savefig(f'{path}/{id}_conf_matrix.png')

In [None]:
from sklearn.metrics import precision_recall_fscore_support, roc_auc_score, accuracy_score

def get_precision_recall_f1(path, id, model, test_loader):
    """Compute classification metrics for ``model`` and save them as JSON.

    Stores macro-averaged precision/recall/F-score (``'metrics'``), one-vs-one
    and one-vs-rest ROC AUC, and the accuracy in ``<path>/<id>_metrics.json``.

    Args:
        path: Directory the JSON file is written to.
        id: Prefix of the JSON file name.
        model: Callable mapping a batch of inputs to class scores.
        test_loader: Iterable yielding ``(inputs, labels)`` batches.
    """
    y_pred = []
    y_true = []
    y_score = []
    softmax = nn.Softmax(dim=1)
    res = {}

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for inputs, labels in test_loader:
            output = model(inputs)

            _, preds = torch.max(output, dim=1)
            y_pred.extend(preds.cpu().numpy())
            # roc_auc_score needs per-class probabilities, not raw scores.
            y_score.extend(softmax(output).cpu().numpy())

            y_true.extend(labels.cpu().numpy())

    res['metrics'] = precision_recall_fscore_support(y_true, y_pred, average='macro')
    res['roc_auc_score_ovo'] = roc_auc_score(y_true, y_score, multi_class='ovo')
    res['roc_auc_score_ovr'] = roc_auc_score(y_true, y_score, multi_class='ovr')
    res['accuracy'] = accuracy_score(y_true, y_pred)
    save_obj(res, f'{path}/{id}_metrics')

In [None]:
import os

def log_model(id, history, model, best_model, test_loader):
    """Save weights, history, metrics and plots of a training run to ``models/<id>``.

    The directory (and ``models/`` itself) is created when missing. An existing
    directory is reused and files of the same name are overwritten, so results
    of a finished training run are not lost to a directory error.

    Args:
        id: Run identifier; used as directory name and file-name prefix.
        history: List of per-epoch result dicts.
        model: Model in its final state.
        best_model: Model copy with the lowest validation loss.
        test_loader: Iterable of test batches used for the metrics.
    """
    path = f'models/{id}'
    os.makedirs(path, exist_ok=True)

    # Switch both models to eval mode before computing the test metrics.
    model.eval()
    best_model.eval()

    torch.save(model.state_dict(), f'{path}/{id}.pth')
    torch.save(best_model.state_dict(), f'{path}/{id}_best.pth')
    save_obj(history, f'{path}/{id}_history')

    get_confusion_matrix(path, id, model, test_loader)
    get_confusion_matrix(path, f'{id}_best', best_model, test_loader)

    get_precision_recall_f1(path, id, model, test_loader)
    get_precision_recall_f1(path, f'{id}_best', best_model, test_loader)

    plot_history(path, id, history, 'train_acc', 'val_acc', 'accuracy', 'Accuracy vs. No. of epochs')
    plot_history(path, id, history, 'train_loss', 'val_loss', 'loss', 'Loss vs. No. of epochs')

## Hyperparameters

In [None]:
# Batch sizes to choose from.
batch_sizes = [16, 32]

# Optimizer classes by name.
opt_func = dict(Adam=torch.optim.Adam, RMSprop=torch.optim.RMSprop)

# Cross-entropy loss using the per-class weights computed during preprocessing.
criterion = nn.CrossEntropyLoss(weight=weight)

# Model classes by name.
models = dict(
    AlexNet1=AlexNet,
    AlexNet2=AlexNet2,
    VGG1=VGG1,
    VGG2=VGG2,
    VGG3=VGG3,
    Inception1=Inception1,
    Inception2=Inception2,
    ResNet=ResNet,
)

# Constructor keyword arguments by model name; only ResNet takes any.
model_params = {name: {} for name in models}
model_params['ResNet'] = {
    'block': ResidualBlock,
    'layers': [4,4,4],
    'apply_dropout': True
}

In [None]:
# Configuration of this run.
model_name, opt_name = 'ResNet', 'RMSprop'
batch_size = batch_sizes[1]
num_epochs = 200
lr, momentum, weight_decay = 1e-5, 0.99, 1e-3

# Run identifier; the momentum part is only included for RMSprop.
_momentum_tag = f'_m{momentum}' if opt_name == 'RMSprop' else ''
id = f'{model_name}_{opt_name}_{batch_size}_lr{lr}{_momentum_tag}_wd{weight_decay}'

# Build the loaders, train, then write all artefacts of the run.
train_loader, val_loader, test_loader = create_data_loaders(
    train_dataset, val_dataset, test_dataset, batch_size)
history, model, best_model = train(
    model_name, opt_name, criterion, batch_size, lr, momentum, weight_decay,
    train_loader, val_loader)
log_model(id, history, model, best_model, test_loader)