In [None]:
import os

import albumentations as A
import cv2
import matplotlib.pyplot as plt
# import numpy as np
import torch
from torch.utils.data import DataLoader
from pytorch_toolbelt import losses
# from prepare_data import prepare_data
from segmentation.data import SegmentationDataModule
from segmentation.utils import object_from_dict
import pytorch_lightning as pl
from train import SegmentationModule

%load_ext autoreload
%autoreload 2

In [None]:
import yaml

In [None]:
from albumentations.core.serialization import from_dict

In [None]:
import glob

In [None]:
# Load the most recent config from configs/. "Most recent" means the
# lexicographically greatest filename — presumably names are date/version
# prefixed so this sorts chronologically; TODO confirm naming convention.
config_paths = [path for path in glob.glob("configs/*") if os.path.isfile(path)]
if not config_paths:
    raise FileNotFoundError("No config files found in 'configs/'")
with open(max(config_paths)) as f:
    hparams = yaml.safe_load(f)

# Size the model's `classes` entry from the configured category list.
hparams["model"]["model"]["classes"] = len(hparams["categories"])

In [None]:
def get_everything(hparams, overfit_batches=0.0):
    """Build a freshly seeded ``(model, data, trainer)`` triple from a config dict.

    Args:
        hparams: Parsed config. Keys read here: ``"seed"``, ``"model"``,
            ``"data"`` (with ``"data"`` and ``"transforms"`` -> ``"train"`` /
            ``"val"`` / ``"test"`` sub-entries) and ``"trainer"`` (with
            ``"trainer"``, ``"logger"`` and ``"callbacks"`` sub-entries).
        overfit_batches: Forwarded as a keyword argument to the
            ``object_from_dict`` call that builds the trainer.

    Returns:
        Tuple ``(model, data, trainer)``.
    """
    # Seed before constructing anything so any randomness during construction
    # is reproducible between calls.
    pl.seed_everything(hparams["seed"])
    model = SegmentationModule(hparams["model"])

    data_cfg = hparams["data"]
    split_transforms = data_cfg["transforms"]
    transforms = {
        split: from_dict(split_transforms[split]) for split in ("train", "val", "test")
    }
    data = SegmentationDataModule(**data_cfg["data"], transforms=transforms)

    trainer_cfg = hparams["trainer"]
    logger = object_from_dict(trainer_cfg["logger"])
    callbacks = [object_from_dict(cb_cfg) for cb_cfg in trainer_cfg["callbacks"].values()]
    trainer = object_from_dict(
        trainer_cfg["trainer"],
        logger=logger,
        callbacks=callbacks,
        overfit_batches=overfit_batches,
    )
    return model, data, trainer

In [None]:
# Remove the lightning_logs/ directory left by previous runs.
# shutil.rmtree works on every OS and, unlike the `rm` shell alias, does not
# depend on IPython automagic; ignore_errors mirrors `rm -f` when it is absent.
# NOTE(review): the logger is built from the config — confirm it writes here.
import shutil

shutil.rmtree("lightning_logs", ignore_errors=True)

In [None]:
# Sanity check: fit on just two batches (overfit_batches=2); if the model,
# data and loss are wired correctly the model should be able to memorise them.
model, data, trainer = get_everything(hparams, overfit_batches=2)
trainer.fit(model, datamodule=data)

In [None]:
# Learning-rate range test: sweep the LR between these bounds for a fixed
# number of steps and plot loss against LR.
LR_FINDER_MIN = 1e-3
LR_FINDER_MAX = 1e3
LR_FINDER_STEPS = 200

model, data, trainer = get_everything(hparams)

# early_stop_threshold=None disables early stopping, so the whole range is swept.
lr_finder = trainer.tuner.lr_find(
    model,
    datamodule=data,
    min_lr=LR_FINDER_MIN,
    max_lr=LR_FINDER_MAX,
    num_training=LR_FINDER_STEPS,
    early_stop_threshold=None,
)

# The raw sweep is in `lr_finder.results`; inspect it in its own cell if needed
# (a bare expression in the middle of a cell is never displayed).
fig = lr_finder.plot(suggest=True)

# Suggested LR — kept as the last expression so Jupyter displays it.
# This cell does not write it back into `hparams`.
lr_finder.suggestion()

In [None]:
# Final run: rebuild model/data/trainer from `hparams`, train, then run the
# test loop. The datamodule is passed to `test` explicitly instead of relying
# on the one `fit` implicitly attached to the trainer.
# NOTE(review): with no model argument, which weights are tested (current vs.
# best checkpoint) follows Lightning's `ckpt_path` default for the installed
# version — consider pinning it explicitly.
model, data, trainer = get_everything(hparams)
trainer.fit(model, datamodule=data)
trainer.test(datamodule=data)