In [1]:
import numpy as np
import torch as th
import torch.nn as nn
import warnings
import os

from train.multi_hyperparameter import MultiHyperparameter
from custom_models.unet_original import UNet
from train.unet_trainer import UnetTrainer
from datasets.semantic_dataset import SemanticDataset

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
params = {'padding_mode': ['zeros', 'reflect', 'replicate', 'circular'],
          'out_classes': 2,
          'criterion': nn.CrossEntropyLoss(),
          'optimizer': 1}

unet_hyps = MultiHyperparameter(params)
unet = UNet.__new__(UNet)

models_to_evaluate = [(unet, unet_hyps)]

num_epochs = 1
models = []
trainers = []
folds = [0, 1, 2, 3]
device = th.device("cuda" if th.cuda.is_available() else "cpu")
cwd = os.getcwd()
ds1_path = os.path.join(cwd, '../data/cell_type_1')
ds2_path = os.path.join(cwd, '../data/cell_type_2')



for i, model_to_evaluate in enumerate(models_to_evaluate, 0):
    model, current_params = model_to_evaluate
    params_to_evaluate = current_params.get_full_grid_params()

    best_loss = 1e10
    best_param = None

    for fold in folds:
        folds_ds_train = folds.copy()
        folds_ds_train.remove(fold)
        folds_ds_test = [fold]

        for param in params_to_evaluate:
            print(param)
            validation_losses = []
            for train_fold in folds_ds_train:
                model.__init__(param)
                folds_train = folds_ds_train.copy()
                folds_train.remove(train_fold)
                folds_validate = [train_fold]

                ds_train = SemanticDataset(ds1_path, folds_train)
                ds_validate = SemanticDataset(ds1_path, folds_validate)

                trainer = UnetTrainer(model, device, param['criterion'], th.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5), ds_train, ds_validate)
                trainer.train(num_epochs)
                trainer.test()
                validation_losses.append(trainer.test_losses[-1])

            validation_loss = np.sum(validation_losses)

            if best_loss > validation_loss:
                best_loss = validation_loss
                best_param = param

        ds_train = SemanticDataset(ds1_path, folds_train)
        ds_validate = SemanticDataset(ds2_path, folds_validate)
        model.__init__(best_param)








print('success')

{'padding_mode': 'zeros', 'out_classes': 2, 'criterion': CrossEntropyLoss(), 'optimizer': 1}
{'padding_mode': 'reflect', 'out_classes': 2, 'criterion': CrossEntropyLoss(), 'optimizer': 1}
{'padding_mode': 'replicate', 'out_classes': 2, 'criterion': CrossEntropyLoss(), 'optimizer': 1}
{'padding_mode': 'circular', 'out_classes': 2, 'criterion': CrossEntropyLoss(), 'optimizer': 1}


KeyboardInterrupt: 

In [19]:
class Model1:
    def __init__(self, params):
        self.number1 = params['number1']


class Model2:
    def __init__(self, params):
        self.number1 = params['number1']
        self.number2 = params['number2']
        self.number3 = params['number3']
        self.number4 = params['number4']
        self.number5 = params['number5']


params1 = {'number1': [1, 2, 3]}

params2 = {'number1': 1,
           'number2': 2,
           'number3': [1, 2, 3],
           'number4': [4, 5, 6],
           'number5': [7, 8, 9]}

models_to_evaluate = [(Model1.__new__(Model1), params1), (Model2.__new__(Model2), params2)]

In [22]:
for i, model_to_evaluate in enumerate(models_to_evaluate, 0):
    model, current_params = model_to_evaluate
    multi_params = MultiHyperparameter(current_params)

    params_to_evaluate = multi_params.get_full_grid_params()

    for param in params_to_evaluate:
        model.__init__(param)
        print(model, param)

<__main__.Model1 object at 0x00000283FBC75760> {'number1': 1}
<__main__.Model1 object at 0x00000283FBC75760> {'number1': 2}
<__main__.Model1 object at 0x00000283FBC75760> {'number1': 3}
<__main__.Model2 object at 0x00000283FBC756D0> {'number1': 1, 'number2': 2, 'number3': 1, 'number4': 4, 'number5': 7}
<__main__.Model2 object at 0x00000283FBC756D0> {'number1': 1, 'number2': 2, 'number3': 2, 'number4': 4, 'number5': 7}
<__main__.Model2 object at 0x00000283FBC756D0> {'number1': 1, 'number2': 2, 'number3': 3, 'number4': 4, 'number5': 7}
<__main__.Model2 object at 0x00000283FBC756D0> {'number1': 1, 'number2': 2, 'number3': 1, 'number4': 5, 'number5': 7}
<__main__.Model2 object at 0x00000283FBC756D0> {'number1': 1, 'number2': 2, 'number3': 2, 'number4': 5, 'number5': 7}
<__main__.Model2 object at 0x00000283FBC756D0> {'number1': 1, 'number2': 2, 'number3': 3, 'number4': 5, 'number5': 7}
<__main__.Model2 object at 0x00000283FBC756D0> {'number1': 1, 'number2': 2, 'number3': 1, 'number4': 6, 'n