In [None]:
%load_ext autoreload
%autoreload 2
from IPython.core.display import display, HTML
display(HTML("<style>.container {width:100% !important;}</style>"))

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import albumentations as albu
import albumentations.pytorch as albu_pt
%matplotlib inline

import apex
import torch
from torch.utils.data import DataLoader
from torchvision.utils import make_grid
import pytorch_tools as pt

from src.dataset import OpenCitiesDataset

In [None]:
SZ = 384  # side length (px) of the square crops produced by the augmentations
BS = 16  # batch size used by both data loaders
BUILDINGS_ONLY = True  # True -> 1 output class ("binary" mode); False -> 3 classes ("multilabel")


class ToTensor(albu_pt.ToTensorV2):
    """ToTensorV2 variant that also moves the mask's channel axis to the front.

    The base class presumably leaves masks in their original HWC layout
    (hence this override); after this transform a mask is a ``(C, H, W)``
    tensor, so collated masks come out as ``(N, C, H, W)``, like the images.
    """

    def apply_to_mask(self, mask, **params):
        # Accept plain (H, W) masks too by giving them a singleton channel axis,
        # so both (H, W) and (H, W, C) inputs yield (C, H, W).
        if mask.ndim == 2:
            mask = mask[:, :, None]
        return torch.from_numpy(mask.transpose(2, 0, 1))
    
# Training pipeline: geometric jitter, a random SZ x SZ crop, colour jitter,
# then normalisation and tensor conversion of both image and mask.
aug = albu.Compose([
    albu.Flip(),
    albu.ShiftScaleRotate(scale_limit=0.2),
    albu.RandomCrop(SZ, SZ),
    albu.RandomBrightnessContrast(),
    albu.HueSaturationValue(),
    albu.RandomRotate90(),
    albu.Normalize(),
    ToTensor(),
])
# Validation pipeline: deterministic centre crop only.
simple_aug = albu.Compose([
    albu.CenterCrop(SZ, SZ),
    albu.Normalize(),
    ToTensor(),
])

# Keep every validation sample (drop_last=False). Dropping the final partial
# batch would silently exclude up to BS - 1 images from evaluation, and would
# leave the loader empty if the split had fewer than BS samples.
val_dtst = OpenCitiesDataset(split="val", transform=simple_aug, buildings_only=BUILDINGS_ONLY)
val_dtld = DataLoader(val_dtst, batch_size=BS, shuffle=False, num_workers=2, drop_last=False)
val_dtld_i = iter(val_dtld)  # persistent iterator consumed by the visualisation cell

# drop_last=True keeps every optimisation step at the full batch size.
train_dtst = OpenCitiesDataset(split="train", transform=aug, buildings_only=BUILDINGS_ONLY)
train_dtld = DataLoader(train_dtst, batch_size=BS, shuffle=True, num_workers=8, drop_last=True)
# NOTE(review): creating this iterator starts the loader's worker processes right
# away, but no visible cell consumes it — consider removing it.
train_dtld_i = iter(train_dtld)

In [None]:
# Number of samples per split, as (train, val).
tuple(len(dtst) for dtst in (train_dtst, val_dtst))

In [None]:
# Peek at one validation sample to check tensor layouts after `simple_aug`.
img, mask = val_dtst[1]
tuple(t.shape for t in (img, mask))

In [None]:
# plt.imshow(mask[2])

In [None]:
# Show one validation batch: model predictions (left) next to ground-truth masks (right).
# NOTE: `model` is created in a later cell; run that cell before this one.
imgs, masks = next(val_dtld_i)
was_training = model.training
model.eval()  # inference behaviour for dropout/BatchNorm
with torch.no_grad():
    # Outputs are presumably logits (TODO confirm); sigmoid maps them into [0, 1]
    # so imshow does not clip them. `.float()` undoes any half-precision output.
    preds = model(imgs.cuda()).float().sigmoid().cpu()
model.train(was_training)  # restore whatever mode the model was in

# make_grid returns (C, H, W) and imshow expects (H, W, C): permute, because
# transpose(0, 2) would give (W, H, C) and display the grid transposed.
preds_grid = make_grid(preds, nrow=4).permute(1, 2, 0)
# Cast masks to float so imshow reads them on a [0, 1] scale (integer images
# are interpreted on a 0-255 scale).
masks_grid = make_grid(masks.float(), nrow=4).permute(1, 2, 0)

fig, axes = plt.subplots(ncols=2, figsize=(20, 10))
axes[0].set_title("predictions")
axes[1].set_title("ground-truth masks")
axes[0].imshow(preds_grid)
axes[1].imshow(masks_grid);

In [None]:
# Linknet with an SE-ResNet50 encoder: 1 output channel for buildings-only, else 3.
model = pt.segmentation_models.Linknet(
    'se_resnet50', num_classes=1 if BUILDINGS_ONLY else 3
).cuda()
# The optimizer is built over *all* parameters, so unfreezing the encoder later
# does not require recreating it.
optim = torch.optim.AdamW(model.parameters(), lr=3e-4)
# apex mixed precision with a static loss scale of 2048.
model, optim = apex.amp.initialize(model, optim, verbosity=0, loss_scale=2048)
# Phase 1 trains only the decoder: freeze the encoder weights.
for p in model.encoder.parameters():
    p.requires_grad = False
# Training criterion. (pt.losses.JaccardLoss accepts the same `mode` if a
# Jaccard-based loss is wanted instead.)
loss = pt.losses.CrossEntropyLoss(mode="binary" if BUILDINGS_ONLY else "multilabel").cuda()
pt.utils.misc.count_parameters(model)

In [None]:
# Training loop wrapper: criterion `loss`, console logging and a Jaccard metric
# whose mode matches the model head. Further callbacks (Timer,
# ReduceLROnPlateau, FileLogger) can be appended to the `callbacks` list.
runner = pt.fit_wrapper.Runner(
    model,
    optim,
    criterion=loss,
    callbacks=[pt.fit_wrapper.callbacks.ConsoleLogger()],
    metrics=pt.metrics.JaccardScore(
        mode="multilabel" if not BUILDINGS_ONLY else "binary"
    ),
)

In [None]:
class ToCudaLoader:
    """Iterable wrapper that moves every ``(img, target)`` batch to a device.

    Args:
        loader: a re-iterable yielding ``(img, target)`` tensor pairs that also
            supports ``len()`` (e.g. a ``torch.utils.data.DataLoader``).
        device: destination passed to ``Tensor.to``. Defaults to ``"cuda"``
            (the current CUDA device, same as ``Tensor.cuda()``).
    """

    def __init__(self, loader, device="cuda"):
        self.loader = loader
        self.device = device

    def __iter__(self):
        # Each call starts a fresh pass over the wrapped loader and transfers
        # batches lazily, one at a time.
        return (
            (img.to(self.device), target.to(self.device))
            for img, target in self.loader
        )

    def __len__(self):
        # Number of batches, delegated to the wrapped loader.
        return len(self.loader)
    
# Wrap both loaders so every batch is delivered on the GPU (val first, then train).
val_dtld_gpu, train_dtld_gpu = ToCudaLoader(val_dtld), ToCudaLoader(train_dtld)

In [None]:
# Sanity-check one GPU batch: (imgs.dtype, masks.dtype, imgs.shape, masks.shape).
imgs, masks = next(iter(val_dtld_gpu))
(*(t.dtype for t in (imgs, masks)), *(t.shape for t in (imgs, masks)))

In [None]:
# Phase 1 (encoder frozen): fit on the augmented training loader and evaluate
# on the held-out validation loader. Training on the validation loader would
# make the reported validation metrics meaningless.
runner.fit(train_dtld_gpu, val_loader=val_dtld_gpu, epochs=30)

In [None]:
# Phase 2: make every parameter trainable again, including the encoder that
# was frozen before the first fit.
for param in model.parameters():
    param.requires_grad_(True)

In [None]:
# Phase 2 (whole network trainable): continue fitting on the training loader,
# evaluating on the separate validation loader.
runner.fit(train_dtld_gpu, val_loader=val_dtld_gpu, epochs=30)