In [1]:
import json
import os
import random

import numpy as np
import pandas as pd
import torch as th
import torch.nn as nn
import torchvision.transforms as T

from train.multi_hyperparameter import MultiHyperparameter
from custom_models.unet_original import UNet
from train.unet_trainer import UnetTrainer
from datasets.semantic_dataset import SemanticDataset
from evaluate.cross_evaluator import CrossTrainEvaluator
from preprocessing.data_augment import DataAugmenter

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Evaluate a single U-Net configuration on both cell-type datasets across all folds.

# Reproducibility: weight initialisation and the RandomVerticalFlip augmentation
# are stochastic, so seed the RNGs before any model/data work happens.
# NOTE(review): assumes project code draws only from these RNGs — confirm.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
th.manual_seed(SEED)  # seeds the CPU and all CUDA devices

# Hyperparameter grid wrapped by MultiHyperparameter below; list-valued entries
# (here 'padding_mode') are presumably the values to sweep — TODO confirm.
# Note: unlike `param`, this grid has no 'augment_transform'/'num_augments'.
params = {'padding_mode': ['zeros', 'reflect', 'replicate', 'circular'],
          'depth': 5,
          'start_layers': 64,
          'dim_multiplier': 2,
          'input_conv_kernel_size': 3,
          'out_classes': 2,
          'criterion': nn.CrossEntropyLoss(),
          'optimizer': 1}

# The single configuration that is actually evaluated in this cell.
param = {'padding_mode': 'zeros',
         'depth': 5,
         'start_layers': 64,
         'dim_multiplier': 2,
         'input_conv_kernel_size': 3,
         'out_classes': 2,
         'criterion': nn.CrossEntropyLoss(),
         # NOTE(review): opaque optimizer code interpreted by the trainer — TODO document.
         'optimizer': 1,
         'augment_transform': T.Compose([T.RandomVerticalFlip()]),
         'num_augments': 10}

unet_hyps = MultiHyperparameter(params)
# Deliberately bypasses UNet.__init__: the instance has no layers yet.
# Presumably CrossTrainEvaluator (re)initialises it from each param dict — TODO confirm.
unet = UNet.__new__(UNet)

# Not used by the evaluation below; kept for a multi-model sweep.
models_to_evaluate = [(unet, unet_hyps)]

num_epochs = 1
models = []
trainers = []
folds = [0, 1, 2, 3]
device = th.device("cuda" if th.cuda.is_available() else "cpu")
# Data/result locations are relative to the notebook's working directory.
cwd = os.getcwd()
ds1_path = os.path.join(cwd, '../data/cell_type_1')
ds2_path = os.path.join(cwd, '../data/cell_type_2')
result_path = os.path.join(cwd, '../results')

cte = CrossTrainEvaluator(unet, [ds1_path, ds2_path], device, result_path)
# evaluate_param(param, folds, epochs): reuse the configured values rather than
# repeating literals that could drift from `folds` / `num_epochs`.
report = cte.evaluate_param(param, folds, num_epochs)

# Bare last expression -> Jupyter's rich display (HTML table if report is a DataFrame).
report

{'padding_mode': 'zeros', 'depth': 5, 'start_layers': 64, 'dim_multiplier': 2, 'input_conv_kernel_size': 3, 'out_classes': 2, 'criterion': CrossEntropyLoss(), 'optimizer': 1, 'augment_transform': Compose(
    RandomVerticalFlip(p=0.5)
), 'num_augments': 10}
   ds_0_train  ds_0_validate  ds_1_train  ds_1_validate  combined_train  \
0       0.663       0.748542    0.657301       0.728916         0.66015   

   combined_validate  
0           0.738729  


In [None]:
# Inspect the criterion recorded in a saved results file.
# SECURITY: th.load unpickles arbitrary objects — only load files you trust.
# NOTE(review): torch >= 2.6 defaults th.load to weights_only=True, which rejects
# pickled objects such as an nn.Module criterion; if this load fails, pass
# weights_only=False (trusted files only).
# map_location lets a file saved from a GPU run load on a CPU-only machine.
file = th.load(os.path.join(result_path, 'test.pt'), map_location=device)
# Presumably a sequence of per-run param dicts — TODO confirm against the writer.
file[0]['criterion']