In [None]:
import gc

from fastai.vision import unet_learner, imagenet_stats, torch, Path, os, load_learner, models
from fastai.callbacks import CSVLogger

from experiments import getDatasets, getData, random_seed
from losses import BCELoss, MixedLoss
from metrics import MetricsCallback, getDatasetMetrics
from config import *

# Re-import edited project modules (experiments, losses, metrics, config, ...)
# automatically before each cell executes, so changes there need no kernel restart.
%load_ext autoreload
%autoreload 2

# Make GPU 0 the default CUDA device for everything this kernel runs.
torch.cuda.set_device(0)

In [None]:
# All artifacts of the loss-function comparison live under <EXPERIMENTS_PATH>/loss;
# each run later gets its own <loss name>/<dataset index> subdirectory.
EXPERIMENT_PATH = Path(EXPERIMENTS_PATH).joinpath('loss')
# NOTE(review): no visible cell writes into MODELS_PATH — the training loop passes
# model_dir='models' with path=<run dir>, so weights land in each run's own models/.
MODELS_PATH = EXPERIMENT_PATH.joinpath("models")
MODELS_PATH.mkdir(parents=True, exist_ok=True)

In [None]:
# Load the source data once; the experiment loop derives every run's dataset
# from it via getDatasets(allData, ...).
allData = getData()

In [None]:
# Keyword arguments for dataset.databunch(): training batch size, validation
# batch size, and number of DataLoader worker processes.
props = {'bs': 4, 'val_bs': 2, 'num_workers': 0}

# Candidate loss functions. Each key doubles as the directory name under
# EXPERIMENT_PATH where that loss's runs are stored.
# NOTE(review): the meaning of the BCELoss / MixedLoss arguments is defined in
# losses.py — presumably a positive-class weight and a pair of mixing weights;
# confirm there before interpreting results.
#
# Built from (name, loss) pairs rather than a dict literal so a repeated name is
# caught by the assert below instead of silently overwriting an earlier entry.
# (Previously 'mixed_5_2' / 'mixed_5_1' appeared twice, so MixedLoss(5.0, ...) was
# never trained and those keys actually held MixedLoss(2.0, ...).)
# NOTE(review): delete any existing EXPERIMENT_PATH/mixed_5_2 and mixed_5_1 run
# directories created before this fix — they were trained with MixedLoss(2.0, ...)
# and the experiment loop skips runs whose 'final model.pkl' already exists.
_loss_specs = [
    ('bce0.5', BCELoss(0.5)),
    ('bce1', BCELoss(1)),
    ('bce5', BCELoss(5)),
    ('bce10', BCELoss(10)),
    ('bce30', BCELoss(30)),
    ('mixed_10_2', MixedLoss(10.0, 2.0)),
    ('mixed_10_1', MixedLoss(10.0, 1.0)),
    ('mixed_5_2', MixedLoss(5.0, 2.0)),
    ('mixed_5_1', MixedLoss(5.0, 1.0)),
    ('mixed_2_2', MixedLoss(2.0, 2.0)),
    ('mixed_2_1', MixedLoss(2.0, 1.0)),
    ('dice', MixedLoss(0.0, 1.0)),
]
_loss_names = [loss_name for loss_name, _ in _loss_specs]
assert len(set(_loss_names)) == len(_loss_names), "duplicate loss names in _loss_specs"
losses = dict(_loss_specs)

In [None]:
def release_gpu_memory():
    """Run the cyclic garbage collector, then hand cached CUDA blocks back to the driver.

    NOTE(review): fastai v1 callbacks keep a back-reference to their Learner, so an
    unreferenced learner sits in a reference cycle and its GPU tensors stay alive
    until the cyclic collector happens to run. Forcing a collection between runs
    keeps memory from piling up across the many (loss, dataset) runs below —
    confirm against the installed fastai version.
    """
    gc.collect()
    torch.cuda.empty_cache()


# Train one ResNet-18 U-Net per (loss, dataset) pair, then score each exported model.
# Both phases skip work whose output file already exists, so the cell can be
# re-run after an interruption without redoing finished runs.
for name, loss in losses.items():
    # --- Training: one run per dataset from the default getDatasets() call ---
    for index, dataset in enumerate(getDatasets(allData)):
        PATH = EXPERIMENT_PATH / name / str(index)
        if not (PATH / 'final model.pkl').exists():
            # Drop the previous learner before allocating the next model on the GPU.
            learn = None
            release_gpu_memory()
            # Re-seed before each stochastic step so every run starts from the same RNG state.
            random_seed(42)
            data = dataset.databunch(**props).normalize(imagenet_stats)
            random_seed(42)
            learn = unet_learner(data, models.resnet18, callback_fns=[MetricsCallback, CSVLogger], model_dir='models', loss_func=loss, path=PATH)
            random_seed(42)
            learn.fit_one_cycle(10, 1e-4)
            # Weights go to the learner's model_dir ('models') under PATH; the export
            # file is what the evaluation phase below reloads.
            learn.save('model')
            learn.export(file='final model.pkl')
    # --- Evaluation: rebuild the datasets without crop / cutInHalf ---
    # NOTE(review): assumes this call yields datasets in the same order and count as
    # the training call above, so `index` selects the matching trained model.
    for index, dataset in enumerate(getDatasets(allData, crop=False, cutInHalf=False)):
        PATH = EXPERIMENT_PATH / name / str(index)
        if not (PATH / 'final predictions.csv').exists():
            # Free the previous learner before loading the next exported model.
            learn = None
            release_gpu_memory()
            learn = load_learner(PATH, 'final model.pkl')
            random_seed(42)
            m = getDatasetMetrics(dataset, learn)
            m.save(PATH / 'final predictions.csv')