In [1]:
import json
import os
import random

import numpy as np
import pandas as pd
import torch as th
import torch.nn as nn
import torchvision.transforms as T

from train.multi_hyperparameter import MultiHyperparameter
from custom_models.unet_original import UNet, UNetSmall, UNetLarge
from evaluate.cross_evaluator import SemanticCrossEvaluator
from train.unet_trainer import UnetTrainer
from datasets.semantic_dataset import SemanticDataset
from preprocessing.data_augment import DataAugmenter

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Experiment configuration: dataset locations, cross-validation settings and
# the hyper-parameter grid swept by SemanticCrossEvaluator in the next cell.
cwd = os.getcwd()  # data/results paths below are relative to the kernel's working dir
ds1_path = os.path.join(cwd, '../data/cell_type_1')
ds2_path = os.path.join(cwd, '../data/cell_type_2')

cv_param = {'interval_img_out': 1,
            'num_images': 3,
            'device': th.device("cuda" if th.cuda.is_available() else "cpu"),
            'datasets_path': [ds1_path, ds2_path],
            'results_path': os.path.join(cwd, '../results'),
            'folds': [0, 1, 2, 3],
            'epochs_cv': 1,
            'epochs_ct': 1,
            'num_random_params': 3}

# Optimizer sub-grid, expanded up front into a list of concrete optimizer settings.
# Currently disabled candidates -- type: 'sgd', 'adam', 'asgd'; weight_decay: 1e-3, 1e-5.
optimizer_grid = MultiHyperparameter({'type': ['rmsprop'],
                                      'lr_factor': [1.5, 1, 0.5],
                                      'weight_decay': [0]}).get_full_grid_params()

# Augmentation settings to sweep. Currently disabled candidates:
#   {'rotate': True, 'mirror': True, 'translate': False, 'pad': 0}
#   {'rotate': True, 'mirror': True, 'translate': True, 'pad': 16}
#   {'rotate': True, 'mirror': True, 'translate': True, 'pad': 8}
augment_grid = [{'rotate': False, 'mirror': False, 'translate': False, 'pad': 0}]

param = {'id': 0,
         'padding_mode': 'reflect',
         'out_classes': 2,
         'criterion': nn.CrossEntropyLoss(),
         'optimizer': optimizer_grid,
         'augment_transform': augment_grid,
         'num_augments': 100,
         'binarizer_lr': 0.02,
         'batch_size': 2}

# Full Cartesian expansion of every list-valued entry in `param`.
unet_hyps = MultiHyperparameter(param)
params = unet_hyps.get_full_grid_params()
assert len(params) > 0, "hyper-parameter grid expanded to zero configurations"

# NOTE(review): the network is allocated without running __init__ -- presumably
# SemanticCrossEvaluator (re)initialises it per configuration; confirm there.
unet = UNet.__new__(UNet)

len(params)  # number of configurations that will be evaluated

3


In [3]:
# Seed every RNG right before the sweep so that re-running this cell reproduces
# the same report. (Model initialisation, fold shuffling and augmentation
# presumably draw from these generators -- confirm in SemanticCrossEvaluator /
# DataAugmenter.)
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
th.manual_seed(SEED)  # seeds the CPU generator and every CUDA device

# Nested evaluation: for each test fold, every configuration in `params` is
# cross-validated on each dataset, then a final model per dataset is fitted.
# In the logged run this was 20 one-epoch fits (~21 s each) per test fold.
# NOTE(review): argument order assumed to be (configs, epochs for the final
# cross-test fit, epochs per cross-validation fit) -- confirm the signature.
cte = SemanticCrossEvaluator(unet, cv_param)
report = cte.cross_test_model(params, cv_param['epochs_ct'], cv_param['epochs_cv'])

Testing Fold: 0 / 3
Evaluating Param: 0 / 2
Evaluating Dataset: 0 / 1
Cross validate: 0 / 2


Training: 100%|██████████████████████████████| 1/1 [00:21<00:00, 21.50s/it]


Loss: 0.5533112043142319 Acc: 0.09382044523954391
Cross validate: 1 / 2


Training: 100%|██████████████████████████████| 1/1 [00:21<00:00, 21.59s/it]


Loss: 0.5251954615116119 Acc: 0.20735852420330048
Cross validate: 2 / 2


Training: 100%|██████████████████████████████| 1/1 [00:21<00:00, 21.76s/it]


Loss: 0.5165163058042527 Acc: 0.03456239402294159
Evaluating Dataset: 1 / 1
Cross validate: 0 / 2


Training: 100%|██████████████████████████████| 1/1 [00:21<00:00, 21.26s/it]


Loss: 0.5521536201238633 Acc: 0.6524022817611694
Cross validate: 1 / 2


Training: 100%|██████████████████████████████| 1/1 [00:21<00:00, 21.44s/it]


Loss: 0.5645504063367843 Acc: 0.5345255136489868
Cross validate: 2 / 2


Training: 100%|██████████████████████████████| 1/1 [00:21<00:00, 21.23s/it]


Loss: 0.5274746924638748 Acc: 0.5237728357315063
Evaluating Param: 1 / 2
Evaluating Dataset: 0 / 1
Cross validate: 0 / 2


Training: 100%|██████████████████████████████| 1/1 [00:21<00:00, 21.78s/it]


Loss: 0.544754678606987 Acc: 0.04372614249587059
Cross validate: 1 / 2


Training: 100%|██████████████████████████████| 1/1 [00:21<00:00, 21.77s/it]


Loss: 0.5595835322141647 Acc: 0.020796170458197594
Cross validate: 2 / 2


Training: 100%|██████████████████████████████| 1/1 [00:21<00:00, 21.42s/it]


Loss: 0.5533826804161072 Acc: 0.008105015382170677
Evaluating Dataset: 1 / 1
Cross validate: 0 / 2


Training: 100%|██████████████████████████████| 1/1 [00:21<00:00, 21.52s/it]


Loss: 0.5614715629816055 Acc: 0.6542930603027344
Cross validate: 1 / 2


Training: 100%|██████████████████████████████| 1/1 [00:21<00:00, 21.48s/it]


Loss: 0.5322900462150574 Acc: 0.49345889687538147
Cross validate: 2 / 2


Training: 100%|██████████████████████████████| 1/1 [00:21<00:00, 21.51s/it]


Loss: 0.5769050335884094 Acc: 0.5246397256851196
Evaluating Param: 2 / 2
Evaluating Dataset: 0 / 1
Cross validate: 0 / 2


Training: 100%|██████████████████████████████| 1/1 [00:21<00:00, 21.98s/it]


Loss: 0.6310721015930176 Acc: 0.04373030364513397
Cross validate: 1 / 2


Training: 100%|██████████████████████████████| 1/1 [00:21<00:00, 21.76s/it]


Loss: 0.7061578071117401 Acc: 0.020796170458197594
Cross validate: 2 / 2


Training: 100%|██████████████████████████████| 1/1 [00:22<00:00, 22.02s/it]


Loss: 0.6309318923950196 Acc: 0.021788280457258224
Evaluating Dataset: 1 / 1
Cross validate: 0 / 2


Training: 100%|██████████████████████████████| 1/1 [00:21<00:00, 21.60s/it]


Loss: 0.6235542809963226 Acc: 0.5909838676452637
Cross validate: 1 / 2


Training: 100%|██████████████████████████████| 1/1 [00:21<00:00, 21.70s/it]


Loss: 0.6795531868934631 Acc: 0.5227495431900024
Cross validate: 2 / 2


Training: 100%|██████████████████████████████| 1/1 [00:21<00:00, 21.69s/it]


Loss: 0.610199019908905 Acc: 0.5071520805358887
Evaluating Dataset: 0 / 1


Training: 100%|██████████████████████████████| 1/1 [00:25<00:00, 25.04s/it]


Loss: 0.5446128022670745 Acc: 0.02779960073530674
Evaluating Dataset: 1 / 1


Training: 100%|██████████████████████████████| 1/1 [00:24<00:00, 24.79s/it]


Loss: 0.5564838117361068 Acc: 0.5735226273536682
Testing Fold: 1 / 3
Evaluating Param: 0 / 2
Evaluating Dataset: 0 / 1
Cross validate: 0 / 2


Training: 100%|██████████████████████████████| 1/1 [00:21<00:00, 21.53s/it]


Loss: 0.4643408650159836 Acc: 0.03611069545149803
Cross validate: 1 / 2


Training: 100%|██████████████████████████████| 1/1 [00:21<00:00, 21.24s/it]


Loss: 0.5313307332992554 Acc: 0.6725136637687683
Cross validate: 2 / 2


Training: 100%|██████████████████████████████| 1/1 [00:21<00:00, 21.41s/it]


Loss: 0.5297744387388229 Acc: 0.49034491181373596
Evaluating Dataset: 1 / 1
Cross validate: 0 / 2


Training: 100%|██████████████████████████████| 1/1 [00:21<00:00, 21.37s/it]


Loss: 0.5769422090053559 Acc: 0.6029945015907288
Cross validate: 1 / 2


Training: 100%|██████████████████████████████| 1/1 [00:21<00:00, 21.20s/it]


Loss: 0.5131680059432984 Acc: 0.5233028531074524
Cross validate: 2 / 2


Training: 100%|██████████████████████████████| 1/1 [00:21<00:00, 21.07s/it]


Loss: 0.5227026331424713 Acc: 0.5369646549224854
Evaluating Param: 1 / 2
Evaluating Dataset: 0 / 1
Cross validate: 0 / 2


Training: 100%|██████████████████████████████| 1/1 [00:21<00:00, 21.30s/it]


Loss: 0.5853465294837952 Acc: 0.03612442687153816
Cross validate: 1 / 2


Training: 100%|██████████████████████████████| 1/1 [00:21<00:00, 21.01s/it]


Loss: 0.6010579001903534 Acc: 0.14224620163440704
Cross validate: 2 / 2


Training: 100%|██████████████████████████████| 1/1 [00:21<00:00, 21.33s/it]


Loss: 0.5848111176490783 Acc: 0.021788280457258224
Evaluating Dataset: 1 / 1
Cross validate: 0 / 2


Training: 100%|██████████████████████████████| 1/1 [00:21<00:00, 21.22s/it]


Loss: 0.6305784451961517 Acc: 0.5987714529037476
Cross validate: 1 / 2


Training: 100%|██████████████████████████████| 1/1 [00:21<00:00, 21.31s/it]


Loss: 0.5725930535793304 Acc: 0.531168520450592
Cross validate: 2 / 2


Training: 100%|██████████████████████████████| 1/1 [00:21<00:00, 21.29s/it]


Loss: 0.5732896023988724 Acc: 0.5461848378181458
Evaluating Param: 2 / 2
Evaluating Dataset: 0 / 1
Cross validate: 0 / 2


Training: 100%|██████████████████████████████| 1/1 [00:21<00:00, 21.53s/it]


Loss: 0.611964476108551 Acc: 0.03612995147705078
Cross validate: 1 / 2


Training: 100%|██████████████████████████████| 1/1 [00:21<00:00, 21.34s/it]


Loss: 0.6121302402019501 Acc: 0.06729914247989655
Cross validate: 2 / 2


Training: 100%|██████████████████████████████| 1/1 [00:21<00:00, 21.45s/it]


Loss: 0.6177533936500549 Acc: 0.021059760823845863
Evaluating Dataset: 1 / 1
Cross validate: 0 / 2


Training: 100%|██████████████████████████████| 1/1 [00:21<00:00, 21.46s/it]


Loss: 0.6301901841163635 Acc: 0.5767256617546082
Cross validate: 1 / 2


Training:   0%|                                      | 0/1 [00:15<?, ?it/s]


KeyboardInterrupt: 