In [1]:
from utils import getLoadersMap
from models import UNetMultiDate

import os, logging, phobos

import torch
import torch.nn as nn

from phobos.grain import Grain
from phobos.runner import Runner

from polyaxon.tracking import Run

from phobos.loss import save_loss_map
from phobos.metrics import save_metrics_map

In [2]:
# Bind `experiment` unconditionally so later cells can reference it in both
# modes. Previously the name only existed when NOT local testing, so the
# Grain(...) call in the next cell raised NameError during local runs.
experiment = None
if not Runner.local_testing():
    print('not local testing')
    # Polyaxon tracking handle, created only for non-local runs.
    experiment = Run()

not local testing


In [3]:
# Build the Grain configuration object from metadata.yaml and hand it the
# experiment handle created in the previous cell.
args = Grain(
    yaml='metadata.yaml',
    polyaxon_exp=experiment,
)


In [4]:
# Input / output descriptions provided by Grain; their `.heads` mappings are
# read further down to size the model.
(inputs, outputs) = args.get_inputs_outputs()

In [5]:
# Make sure the directory configured as args.weight_dir exists.
# exist_ok=True replaces the separate exists() check: that check-then-create
# sequence can raise FileExistsError when several processes (e.g. distributed
# workers) start at the same time.
os.makedirs(args.weight_dir, exist_ok=True)

In [6]:
#save_loss_map('maps')
#save_metrics_map('maps')


In [7]:
# Build the data loaders; the 'train' and 'val' entries are handed to the
# Runner further down.
loaders = getLoadersMap(
    args,
    inputs,
)

number of train keys : 9170
number of val keys : 4610



In [8]:
#tloader = loaders['train']
#for inputs, labels in tloader:
#    print(inputs['inp1'].shape)
#    print(labels['out1'].shape,'\n')

In [9]:
# Model hyper-parameters are read from the head descriptions of the
# 'inp1' input and the 'out1' output.
shape = inputs.heads['inp1'].shape.H
n_channels = inputs.heads['inp1'].shape.C

# Use the first CUDA device when one is present, otherwise fall back to CPU
# so the notebook still runs on machines without a GPU.
if torch.cuda.is_available():
    device = torch.device("cuda", 0)
else:
    device = torch.device("cpu")

n_classes = outputs.heads['out1'].num_classes
# A two-class problem is collapsed to a single output channel
# (presumably a single logit for the binary loss; confirm against the model).
if n_classes == 2:
    n_classes = 1

if args.model == 'unetmultidate':
    model = args.load_model(UNetMultiDate,
                            n_channels=n_channels,
                            n_classes=n_classes,
                            patch_size=shape,
                            device=device
                            )
else:
    # Fail here with a clear message instead of leaving `model` undefined,
    # which would otherwise surface as a NameError in a later cell.
    raise ValueError(
        f"unsupported model {args.model!r}; expected 'unetmultidate'"
    )

In [10]:
# Wrap the model for multi-device execution when the configuration asks for it:
# DistributedDataParallel when args.distributed is set, otherwise DataParallel
# across GPUs 0..num_gpus-1 when more than one GPU is configured.
if args.distributed:
    model = nn.parallel.DistributedDataParallel(
        model,
        find_unused_parameters=False,
    )
elif args.num_gpus > 1:
    gpu_ids = [*range(args.num_gpus)]
    model = nn.DataParallel(model, device_ids=gpu_ids)


In [11]:
if args.pretrained_checkpoint:
    # If you have any pretrained weights that you want to load for the model,
    # this is the place to do it.
    print(f'pretrained checkpoint set to {args.pretrained_checkpoint}')
    # map_location='cpu' lets the checkpoint be read even when the device it
    # was saved from is not available here; load_state_dict then copies the
    # values into the model's existing parameters.
    # NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
    # NOTE(review): the model may already be wrapped by DataParallel /
    # DistributedDataParallel at this point, so the checkpoint keys must match
    # the wrapped model - confirm.
    pretrained = torch.load(args.pretrained_checkpoint, map_location='cpu')
    model.load_state_dict(pretrained)

In [12]:
if args.resume_checkpoint:
    # Resume training from a previously saved checkpoint.
    # The message now reports the resume checkpoint; it previously printed
    # the pretrained-checkpoint path (copy-paste slip).
    print(f'resume checkpoint set to {args.resume_checkpoint}')
    # map_location='cpu' lets the checkpoint be read even when the device it
    # was saved from is not available here; load_state_dict then copies the
    # values into the model's existing parameters.
    # NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
    weight = torch.load(args.resume_checkpoint, map_location='cpu')
    model.load_state_dict(weight)

In [13]:
# Collect every Runner argument in one mapping, then construct the Runner.
# NOTE(review): polyaxon_exp is None here although an experiment handle may
# have been created earlier - confirm that tracking is meant to be off.
# NOTE(review): the Runner receives args.device while the model was built with
# a separately constructed torch.device - confirm the two agree.
runner_config = dict(
    model=model,
    device=args.device,
    train_loader=loaders['train'],
    val_loader=loaders['val'],
    inputs=inputs,
    outputs=outputs,
    optimizer=args.optimizer,
    optimizer_args=args.optimizer_args,
    scheduler=args.scheduler,
    scheduler_args=args.scheduler_args,
    mode=args.mode,
    distributed=args.distributed,
    verbose=args.verbose,
    max_iters=args.max_iters,
    frequency=args.frequency,
    tensorboard_logging=True,
    polyaxon_exp=None,
)
runner = Runner(**runner_config)

In [14]:
# Best-score bookkeeping; initialised here and not updated in this cell.
best_val = -1e5
best_metrics = None

logging.info('STARTING training')

# Drive training: runner.trainer() yields (step, outputs) pairs, and only the
# master process prints them.
# NOTE(review): the loop variable `outputs` rebinds the name returned by
# args.get_inputs_outputs(); re-running earlier cells after this one will see
# the per-step object instead - consider renaming.
for step, outputs in runner.trainer():
    if not runner.master():
        continue
    print(f'step: {step}')
    outputs.print()

step: 1
out1:
	train_metrics:
		precision: 0.5010952353477478
		recall: 0.501265287399292
		f1: 0.500917375087738
	train_loss:
		bcejaccardloss: 0.6563423871994019
	train_loss : 0.6563423871994019
train_loss: 0.6563423871994019
step: 2
out1:
	train_metrics:
		precision: 0.5026018619537354
		recall: 0.5012729167938232
		f1: 0.4971616864204407
	train_loss:
		bcejaccardloss: 0.6359084844589233
	train_loss : 0.6359084844589233
train_loss: 0.6359084844589233
step: 3
out1:
	train_metrics:
		precision: 0.5021465420722961
		recall: 0.5008498430252075
		f1: 0.49551260471343994
	train_loss:
		bcejaccardloss: 0.616326093673706
	train_loss : 0.616326093673706
train_loss: 0.616326093673706
step: 4
out1:
	train_metrics:
		precision: 0.48958638310432434
		recall: 0.49746468663215637
		f1: 0.4887101352214813
	train_loss:
		bcejaccardloss: 0.5959136486053467
	train_loss : 0.5959136486053467
train_loss: 0.5959136486053467
step: 5
out1:
	train_metrics:
		precision: 0.5046526789665222
		recall: 0.50070559

  return [base_lr * self.gamma ** (self.last_epoch // self.step_size)


step: 10
out1:
	val_metrics:
		precision: 0.45655736327171326
		recall: 0.5
		f1: 0.47729218006134033
	val_loss:
		bcejaccardloss: 0.6178191304206848
	val_loss : 0.6178191304206848
val_loss: 0.6178191304206848
step: 11
out1:
	train_metrics:
		precision: 0.457589715719223
		recall: 0.49982285499572754
		f1: 0.4777747690677643
	train_loss:
		bcejaccardloss: 0.5312828421592712
	train_loss : 0.5312828421592712
train_loss: 0.5312828421592712
step: 12
out1:
	train_metrics:
		precision: 0.503704309463501
		recall: 0.5000105500221252
		f1: 0.47884076833724976
	train_loss:
		bcejaccardloss: 0.5247690677642822
	train_loss : 0.5247690677642822
train_loss: 0.5247690677642822
step: 13
out1:
	train_metrics:
		precision: 0.529067873954773
		recall: 0.5000805258750916
		f1: 0.4822823405265808
	train_loss:
		bcejaccardloss: 0.510019063949585
	train_loss : 0.510019063949585
train_loss: 0.510019063949585
step: 14
out1:
	train_metrics:
		precision: 0.45659664273262024
		recall: 0.49993106722831726
		f1: 0