In [1]:
import json
import os
import random

import numpy as np
import pandas as pd
import torch as th
import torch.nn as nn
import torchvision.transforms as T

from custom_models.unet_original import UNet, UNetSmall, UNetLarge
from datasets.semantic_dataset import SemanticDataset
from evaluate.cross_evaluator import SemanticCrossEvaluator
from preprocessing.data_augment import DataAugmenter
from train.multi_hyperparameter import MultiHyperparameter
from train.unet_trainer import UnetTrainer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# --- Experiment configuration --------------------------------------------------
# Data locations are resolved relative to the directory the kernel was started in.
cwd = os.getcwd()
ds1_path = os.path.join(cwd, '../data/cell_type_1')
ds2_path = os.path.join(cwd, '../data/cell_type_2')

# Settings handed to SemanticCrossEvaluator. In the run log of the next cell, one
# fold is held out as the test fold while the remaining folds rotate as
# train/validation splits, each trained for `epochs_cv` epochs.
# NOTE(review): the exact semantics of `epochs_ct`, `num_images`,
# `interval_img_out` and `num_random_params` live in the evaluator and are not
# visible here — confirm before tuning them.
cv_param = dict(
    interval_img_out=1,
    num_images=3,
    device=th.device("cuda" if th.cuda.is_available() else "cpu"),
    datasets_path=[ds1_path, ds2_path],
    results_path=os.path.join(cwd, '../results'),
    folds=[0, 1, 2, 3],
    epochs_cv=7,
    epochs_ct=1,
    num_random_params=3,
)

# Optimizer search space, expanded by MultiHyperparameter into a list of concrete
# optimizer settings (presumably the Cartesian product of the lists below).
# Alternatives explored earlier and currently disabled:
#   type      -> 'sgd', 'adam', 'asgd'
#   lr_factor -> 10, 4, 2, 0.5, 0.25, 0.1
optimizer_space = {
    'type': ['rmsprop'],
    'lr_factor': [1],
    'weight_decay': [0, 1e-3, 1e-5],
}
optimizer_grid = MultiHyperparameter(optimizer_space).get_full_grid_params()

# Data-augmentation variants; only the "no augmentation" setting is active.
# Disabled alternatives:
#   rotate + mirror,             pad 0
#   rotate + mirror + translate, pad 16
#   rotate + mirror + translate, pad 8
augment_variants = [
    dict(rotate=False, mirror=False, translate=False, pad=0),
]

# Model/training hyperparameters; list-valued entries are grid axes.
param = dict(
    id=0,
    padding_mode='reflect',
    out_classes=2,
    criterion=nn.CrossEntropyLoss(),
    optimizer=optimizer_grid,
    augment_transform=augment_variants,
    num_augments=100,
    binarizer_lr=0.1,
    batch_size=2,
)

# Expand the grid: each entry of `params` is one fully specified configuration.
unet_hyps = MultiHyperparameter(param)
params = unet_hyps.get_full_grid_params(indexed=True)
print(len(params))

# NOTE(review): UNet.__new__ creates the object WITHOUT running UNet.__init__
# (so, if UNet is an nn.Module, it has no layers/parameters yet). The evaluator
# presumably uses it only as a model template — confirm before calling any
# module method on it.
unet = UNet.__new__(UNet)

3


In [3]:
# --- Cross-test ----------------------------------------------------------------
# Training is presumably stochastic (weight init, shuffling, augmentation), and the
# per-fold accuracies in the log vary widely. Re-seed every RNG right before
# training so that re-running this cell reproduces the same numbers and
# hyperparameter comparisons are not confounded by seed noise.
# NOTE(review): bit-exact CUDA runs may additionally need the cuDNN determinism
# settings described in the PyTorch reproducibility notes.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
th.manual_seed(SEED)  # seeds the CPU and all CUDA generators

# NOTE(review): the meaning of the third positional argument (3) is defined by
# SemanticCrossEvaluator and is not visible here — confirm and give it a name.
cte = SemanticCrossEvaluator(unet, cv_param, 3)

# Only the first grid entry is cross-tested, although the grid holds len(params)
# configurations; pass `params` instead to sweep the full grid.
params_to_eval = [params[0]]
print(params_to_eval[0])
report = cte.cross_test_model(params_to_eval, cv_param['epochs_ct'], cv_param['epochs_cv'])

{'id': 0, 'padding_mode': 'reflect', 'out_classes': 2, 'criterion': CrossEntropyLoss(), 'optimizer': {'type': 'rmsprop', 'lr_factor': 1, 'weight_decay': 0}, 'augment_transform': {'rotate': False, 'mirror': False, 'translate': False, 'pad': 0}, 'num_augments': 100, 'binarizer_lr': 0.1, 'batch_size': 2}
Testing Fold: 0 / 3
Evaluating Param: 0 / 0
Evaluating Dataset: 0 / 1
Cross validate: 0 / 2
{'id': 0, 'padding_mode': 'reflect', 'out_classes': 2, 'criterion': CrossEntropyLoss(), 'optimizer': {'type': 'rmsprop', 'lr_factor': 1, 'weight_decay': 0}, 'augment_transform': {'rotate': False, 'mirror': False, 'translate': False, 'pad': 0}, 'num_augments': 100, 'binarizer_lr': 0.1, 'batch_size': 2}
[2, 3] [1]


Training:  14%|████▎                         | 1/7 [00:17<01:46, 17.78s/it]

Loss: 0.5915964722633362 Acc: 0.22577735781669617


Training:  29%|████████▌                     | 2/7 [00:35<01:29, 17.82s/it]

Loss: 0.4663645422458649 Acc: 0.4347752332687378


Training:  43%|████████████▊                 | 3/7 [00:53<01:11, 17.87s/it]

Loss: 0.392926961183548 Acc: 0.3899022936820984


Training:  57%|█████████████████▏            | 4/7 [01:11<00:53, 17.93s/it]

Loss: 0.3563102614879608 Acc: 0.4469853937625885


Training:  71%|█████████████████████▍        | 5/7 [01:29<00:35, 17.95s/it]

Loss: 0.34078985571861264 Acc: 0.41896742582321167


Training:  86%|█████████████████████████▋    | 6/7 [01:47<00:17, 17.97s/it]

Loss: 0.33003078639507294 Acc: 0.4987831115722656


Training: 100%|██████████████████████████████| 7/7 [02:05<00:00, 17.94s/it]


Loss: 0.32517119467258454 Acc: 0.4778904914855957
Cross validate: 1 / 2
{'id': 0, 'padding_mode': 'reflect', 'out_classes': 2, 'criterion': CrossEntropyLoss(), 'optimizer': {'type': 'rmsprop', 'lr_factor': 1, 'weight_decay': 0}, 'augment_transform': {'rotate': False, 'mirror': False, 'translate': False, 'pad': 0}, 'num_augments': 100, 'binarizer_lr': 0.1, 'batch_size': 2}
[1, 3] [2]


Training:  14%|████▎                         | 1/7 [00:17<01:47, 17.98s/it]

Loss: 0.5521011489629746 Acc: 0.020796170458197594


Training:  29%|████████▌                     | 2/7 [00:35<01:29, 17.99s/it]

Loss: 0.4138087296485901 Acc: 0.020796192809939384


Training:  43%|████████████▊                 | 3/7 [00:53<01:11, 17.99s/it]

Loss: 0.37209497213363646 Acc: 0.07888387143611908


Training:  57%|█████████████████▏            | 4/7 [01:11<00:54, 18.00s/it]

Loss: 0.3470654308795929 Acc: 0.5493043065071106


Training:  71%|█████████████████████▍        | 5/7 [01:30<00:36, 18.02s/it]

Loss: 0.33456147968769073 Acc: 0.6571475863456726


Training:  86%|█████████████████████████▋    | 6/7 [01:48<00:18, 18.01s/it]

Loss: 0.3289401626586914 Acc: 0.5851095914840698


Training: 100%|██████████████████████████████| 7/7 [02:06<00:00, 18.01s/it]


Loss: 0.3258860331773758 Acc: 0.6295795440673828
Cross validate: 2 / 2
{'id': 0, 'padding_mode': 'reflect', 'out_classes': 2, 'criterion': CrossEntropyLoss(), 'optimizer': {'type': 'rmsprop', 'lr_factor': 1, 'weight_decay': 0}, 'augment_transform': {'rotate': False, 'mirror': False, 'translate': False, 'pad': 0}, 'num_augments': 100, 'binarizer_lr': 0.1, 'batch_size': 2}
[1, 2] [3]


Training:  14%|████▎                         | 1/7 [00:18<01:48, 18.05s/it]

Loss: 0.6200078701972962 Acc: 0.015730641782283783


Training:  29%|████████▌                     | 2/7 [00:36<01:30, 18.04s/it]

Loss: 0.47911904513835907 Acc: 0.20333155989646912


Training:  43%|████████████▊                 | 3/7 [00:54<01:12, 18.03s/it]

Loss: 0.39777602672576906 Acc: 0.32228341698646545


Training:  43%|████████████▊                 | 3/7 [00:59<01:19, 19.77s/it]


KeyboardInterrupt: 