In [None]:
# pip install -U git+https://github.com/qubvel/segmentation_models.pytorch

In [None]:
# pip install rasterio

In [None]:
# pip install fastai 

In [1]:
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

import os

In [2]:
from segmentation_models_pytorch.losses.constants import BINARY_MODE, MULTICLASS_MODE, MULTILABEL_MODE
from segmentation_models_pytorch.losses._functional import soft_jaccard_score

In [3]:
from pathlib import Path

from tqdm import tqdm

import sys, os, random, time, json
import numba, cv2, gc
import pickle

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as D

import torchvision
from torchvision import transforms as T

from segmentation_models_pytorch import FPN
from segmentation_models_pytorch import Unet
from segmentation_models_pytorch import MAnet
from segmentation_models_pytorch import Linknet
from segmentation_models_pytorch import PAN
import segmentation_models_pytorch as smp

import rasterio
from rasterio.windows import Window

import albumentations as A

import matplotlib.pyplot as plt

from sklearn.model_selection import GroupKFold

%matplotlib inline

In [4]:
torch.__version__

'1.7.1'

In [5]:
def set_seeds(seed = 42):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    
# set_seeds()

### Setup training variables

In [6]:
!mkdir -p /home/jupyter/data_2/
DATA_PATH = Path('/home/jupyter/data_2/')
assert DATA_PATH.exists()

In [7]:
!mkdir -p /home/jupyter/data/
DATA_PATH_ORIG = Path('/home/jupyter/data/')
assert DATA_PATH_ORIG.exists()

In [8]:
REPORT_PATH = Path('/home/jupyter/reports')
if not REPORT_PATH.exists():
    os.makedirs(REPORT_PATH)
assert REPORT_PATH.exists()

In [9]:
!ls {REPORT_PATH}

37_pytorch_fpn_efficientnet_b7_1536_768_shifted_slices_groupkfold_smooth_fold_0
37_pytorch_fpn_efficientnet_b7_1536_768_shifted_slices_groupkfold_smooth_fold_1
37_pytorch_fpn_efficientnet_b7_1536_768_shifted_slices_groupkfold_smooth_fold_2
37_pytorch_fpn_efficientnet_b7_1536_768_shifted_slices_groupkfold_smooth_fold_3
37_pytorch_fpn_efficientnet_b7_1536_768_shifted_slices_groupkfold_smooth_new_data_fold_0
37_pytorch_fpn_efficientnet_b7_1536_768_shifted_slices_groupkfold_smooth_new_data_fold_1
37_pytorch_fpn_efficientnet_b7_1536_768_shifted_slices_groupkfold_smooth_new_data_fold_2
37_pytorch_fpn_efficientnet_b7_1536_768_shifted_slices_groupkfold_smooth_new_data_fold_3
39_pytorch_fpn_efficientnet_b7_1536_768_shifted_slices_groupkfold_smooth_new_data_advprop_fold_0
39_pytorch_fpn_efficientnet_b7_1536_768_shifted_slices_groupkfold_smooth_new_data_advprop_fold_1
39_pytorch_fpn_efficientnet_b7_1536_768_shifted_slices_groupkfold_smooth_new_data_advprop_fold_2
39_pytorch_fpn_efficientnet_b7_15

In [10]:
import multiprocessing

In [11]:
ARCH = 'fpn' # fpn, manet, linknet, pan, unet

WINDOW = 1792 # tile size
MIN_OVERLAP = 32
NEW_SIZE = WINDOW // 2 # size after re-size which are fed to the model

WINDOW_2 = 1024 # tile size
NEW_SIZE_2 = 1024 # size after re-size which are fed to the model

THRESHOLD = (WINDOW ** 2 * 1) // 100  # 5% of the image has to be positive
CONTENT_THRESHOLD = 0.001
NUM_WORKERS = 4

FOLDS = 4

BATCH_SIZE = 4
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

ENCODER_NAME = 'efficientnet-b7'

LR = 1e-3
WD = 1e-3
LABEL_SMOOTH = 0.01
GRAD_ACCU_STEPS = 1
BEST_MODEL = f'best_model_{ARCH}_{ENCODER_NAME}_{WINDOW}_{NEW_SIZE}_double_shift_{ENCODER_NAME}.pth'

RESET_IMAGES = False
ENCODER_WEIGHTS = 'imagenet'

SLICES_PATH = Path('/home/jupyter/ds_cache')
SLICES_PATH_2 = Path('/home/jupyter/ds_cache_high_res')

EPOCHS = 5
PATIENCE = 10

EXPERIMENT_NAME = f'51_fastai_{ARCH}_{ENCODER_NAME}_{WINDOW}_{NEW_SIZE}_groupkfold_b{BATCH_SIZE}'
SLICE_NUMBER = 300

In [12]:
def rle_decode(mask_rle, shape=(256, 256)):
    '''
    mask_rle: run-length as string formated (start length)
    shape: (height,width) of array to return 
    Returns numpy array, 1 - mask, 0 - background

    '''
    splits = mask_rle.split()
    starts, lengths = [np.asarray(x, dtype=int) for x in (splits[0:][::2], splits[1:][::2])]
    starts -= 1
    ends = starts + lengths
    img = np.zeros(shape[0] * shape[1], dtype='uint8')
    for lo, hi in zip(starts, ends):
        img[lo: hi] = 1
    return img.reshape(shape, order='F') # Fortran order reshaping

In [13]:
@numba.njit()
def rle_numba(pixels):
    size = len(pixels)
    points = []
    if pixels[0] == 1: points.append(1)
    for i in range(1, size):
        if pixels[i] != pixels[i-1]:
            if len(points) % 2 == 0:
                points.append(i+1)
            else:
                points.append(i+1 - points[-1])
    if pixels[-1] == 1: points.append(size-points[-1]+1)    
    return points

In [14]:
def rle_numba_encode(image):
    pixels = image.flatten(order = 'F')
    points = rle_numba(pixels)
    return ' '.join(str(x) for x in points)

def make_grid(shape, window=256, min_overlap=32):
    """
        Return Array of size (N,4), where N - number of tiles,
        2nd axis represente slices: x1,x2,y1,y2 
    """
    x, y = shape
    nx = x // (window - min_overlap) + 1
    x1 = np.linspace(0, x, num=nx, endpoint=False, dtype=np.int64)
    x1[-1] = x - window
    x2 = (x1 + window).clip(0, x)
    ny = y // (window - min_overlap) + 1
    y1 = np.linspace(0, y, num=ny, endpoint=False, dtype=np.int64)
    y1[-1] = y - window
    y2 = (y1 + window).clip(0, y)
    slices = np.zeros((nx,ny, 4), dtype=np.int64)
    
    for i in range(nx):
        for j in range(ny):
            slices[i,j] = x1[i], x2[i], y1[j], y2[j]    
    return slices.reshape(nx*ny,4)

In [15]:
train_df = pd.read_csv(DATA_PATH / 'train.csv', index_col=[0])
train_df

Unnamed: 0_level_0,encoding
id,Unnamed: 1_level_1
2f6ecfcdf,296084587 4 296115835 6 296115859 14 296147109...
8242609fa,96909968 56 96941265 60 96972563 64 97003861 6...
aaa6a05cc,30989109 59 31007591 64 31026074 68 31044556 7...
cb2d976f4,78144363 5 78179297 15 78214231 25 78249165 35...
b9a3865fc,61271840 4 61303134 13 61334428 22 61365722 30...
b2dc8411c,56157731 21 56172571 45 56187411 51 56202252 5...
0486052bb,101676003 6 101701785 8 101727568 9 101753351 ...
e79de561c,7334642 14 7350821 41 7367001 67 7383180 82 73...
095bf7a1f,113277795 21 113315936 53 113354083 87 1133922...
54f2eec69,124967057 36 124997425 109 125027828 147 12505...


In [16]:
import shutil

def reset_slices_path():
    shutil.rmtree(SLICES_PATH)

### Dataset

In [17]:

# Used to filter tiles with enough color information in it
def is_tile_contains_info(img, expected_shape, pixel_limits = (50, 220), content_threshold = CONTENT_THRESHOLD):
    """
    img: np.array
    pixel_limits: tuple
    content_threshold: float percents
    expected_shape: tuple
    """
    
    left_limit = np.prod(img > pixel_limits[0], axis=-1)
    right_limit =  np.prod(img < pixel_limits[1], axis=-1)

    if img.shape != expected_shape:
        print('img.shape != expected_shape', img.shape)
        return False, 0.

    percent_of_pixels = np.sum(left_limit*right_limit) / (img.shape[0] * img.shape[1])
    return  percent_of_pixels > content_threshold, percent_of_pixels

In [18]:
# identity = rasterio.Affine(1, 0, 0, 0, 1, 0)
identity = None
stats = np.array([0.6276, 0.4468, 0.6769]), np.array([0.1446, 0.2113, 0.1233])
# normalize_transform = T.Normalize([0.625, 0.448, 0.688], [0.131, 0.177, 0.101])
# normalize_transform = T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
# normalize_transform = T.Normalize([0.65459856,0.48386562,0.69428385], [0.15167958,0.23584107,0.13146145])

def read_from_slice(dataset, layers, x1, x2, y1, y2, window):
    if dataset.count == 3:
        image = dataset.read([1,2,3],
                    window=Window.from_slices((x1,x2),(y1,y2)))
        image = np.moveaxis(image, 0, -1)
    else:
        image = np.zeros((window, window, 3), dtype=np.uint8)
        for fl in range(3):
            try:
                image[:,:,fl] = layers[fl].read(window=Window.from_slices((x1,x2),(y1,y2)))
            except ValueError as e:
                print('ValueError', e, dataset.shape, x1, x2, y1, y2)
                raise e
    return image.astype(np.uint8)

def extract_layers(dataset, filepath):
    layers = None
    if dataset.count != 3:
        layers = [rasterio.open(subd) for subd in dataset.subdatasets]
    return layers

In [19]:
if RESET_IMAGES:
    reset_slices_path()

In [20]:
!mkdir {SLICES_PATH}

mkdir: cannot create directory ‘/home/jupyter/ds_cache’: File exists


In [21]:
class HubDataset(D.Dataset):
    def __init__(self, root_dir, slices_path, transform, valid_transform=None, mode='train', 
                 window=WINDOW, overlap=MIN_OVERLAP, threshold=THRESHOLD, slice_number=SLICE_NUMBER):
        self.path, self.slices_path = root_dir, slices_path
        assert self.path.exists()
        self.overlap, self.window, self.transform, self.valid_transform, self.threshold, self.slice_number = overlap, window, transform, valid_transform, threshold, slice_number
        self.mode = mode
        self.csv = pd.read_csv(self.path/'train.csv', index_col=[0])
        self.build_slices()
        self.len = len(self.slices)
        self.set_normalize_transform(stats)
        self.do_transform = True
        
    def __copy__(self):
        new_ds = type(self)(
            self.path,
            self.slices_path,
            self.transform,
            valid_transform=self.valid_transform,
            mode=self.mode,
            window=self.window,
            overlap=self.overlap,
            threshold=self.threshold
        )
        new_ds.masks = self.masks
        new_ds.files = self.files
        new_ds.slices = self.slices
        new_ds.skipped = self.skipped
        return new_ds
    
    def set_normalize_transform(self, stats):
        self.as_tensor = T.Compose([
            T.ToTensor(),
            T.Normalize(stats[0], stats[1]),
        ])
    
    def build_masks(self):
        for i, filename in tqdm(enumerate(self.csv.index), total = len(self.csv)):
            filepath = self.path/'train'/f'{filename}.tiff'
            with rasterio.open(filepath) as dataset:
                self.masks.append(rle_decode(self.csv.loc[filename, 'encoding'], dataset.shape))
        
    def build_slices(self):
        self.masks = []; self.files = []; self.slices = []
        self.skipped = 0
        slices_path = self.slices_path/f'slices.pkl'
        files_path = self.slices_path/f'files.pkl'
        if not slices_path.exists():
            for i, filename in tqdm(enumerate(self.csv.index), total = len(self.csv)):
                filepath = self.path/'train'/f'{filename}.tiff'
                assert filepath.exists()
                self.files.append(filepath)
                with rasterio.open(filepath) as dataset:
                    self.build_slice_random(dataset, filename, i)
                print(f'Finished {filename}')
            with open(slices_path, "wb") as filehandler:
                pickle.dump(self.slices, filehandler)
            with open(files_path, "wb") as filehandler:
                pickle.dump(self.files, filehandler)
            
        else:
            print('Reading cached slices, files and masks')
            with open(slices_path,'rb') as file:
                self.slices = pickle.load(file)
            with open(files_path,'rb') as file:
                self.files = pickle.load(file)
        self.build_masks()
                
    def build_slice_random(self, dataset, filename, i):
        dataset_shape = dataset.shape
        self.masks.append(rle_decode(self.csv.loc[filename, 'encoding'], dataset_shape))
        filename = DATA_PATH/'train'/f'2f6ecfcdf.tiff'
        assert filename.exists()
        mask = self.masks[-1]
        layers = extract_layers(dataset, filename)
        shape = dataset.shape
        assert mask.shape == shape
        shape_minus_size = (shape[0] - self.window, shape[1] - self.window)
        slice_counter = 0
        set_x = set()
        set_y = set()
        while slice_counter < self.slice_number:
            x1 = np.random.randint(shape_minus_size[0], size=1)[0]
            y1 = np.random.randint(shape_minus_size[1], size=1)[0]
            if x1 not in set_x and y1 not in set_y:
                set_x.add(x1)
                set_y.add(y1)
                x2 = x1 + self.window
                y2 = y1 + self.window
                assert x1 < shape_minus_size[0]
                assert y1 < shape_minus_size[1]
                assert x2 < shape[0]
                assert y2 < shape[1]
                if mask[x1:x2,y1:y2].sum() > THRESHOLD or np.random.rand() < 0.02:
                    self.slices.append([i, x1, x2, y1, y2])
                    slice_counter += 1
                else:
                    self.skipped += 1
                        
    def apply_transform(self, image, mask):
        if self.do_transform:
            augments = self.transform(image=image, mask=mask) if self.mode == 'train' else self.valid_transform(image=image, mask=mask)
            image = self.as_tensor(augments['image'])
            mask = augments['mask'][None]
            mask_torch = torch.from_numpy(mask).to(torch.float16)
            return image, mask_torch
        else:
            trfm = A.Compose([
                A.Resize(NEW_SIZE, NEW_SIZE)
            ])
            augments = trfm(image=image, mask=mask)
            image = self.as_tensor(augments['image'])
            return image, augments['mask'][None]
        
    def __getitem__(self, index):
        image_path = self.slices_path/f'image_{index}'
        slices_path = self.slices_path/f'mask_{index}'
        if not image_path.exists():
            idx = self.slices[index][0]
            filename = self.files[idx]
            x1, x2, y1, y2 = self.slices[index][1:]
            with rasterio.open(filename) as dataset:
                layers = extract_layers(dataset, filename)
                image = read_from_slice(dataset, layers, x1, x2, y1, y2, self.window).astype('uint8')
            mask = self.masks[idx][x1:x2,y1:y2]
            with open(image_path, "wb") as filehandler:
                pickle.dump(image, filehandler)
                if index % 100 == 0:
                    print(f'Writing to {image_path}')
            with open(slices_path, "wb") as filehandler:
                pickle.dump(mask, filehandler)
            return self.apply_transform(image, mask)
        else:
            with open(image_path,'rb') as file:
                image = pickle.load(file)
            with open(slices_path,'rb') as file:
                mask = pickle.load(file)
            return self.apply_transform(image, mask)
    
    def __len__(self):
        return self.len
    
    def __repr__(self):
        return f'total: {len(self)}, skipped: {self.skipped} mode: {self.mode}'

In [22]:
def generate_ds(size, data_path, slices_path, window):
    trfm = A.Compose([
        A.Resize(size, size, p=1.0),
        A.HorizontalFlip(),
        A.VerticalFlip(),
        A.RandomRotate90(),
        A.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.2, rotate_limit=20, p=0.9, 
                         border_mode=cv2.BORDER_REFLECT),
        A.OneOf([
            A.OpticalDistortion(p=0.4),
            A.GridDistortion(p=.1, border_mode=cv2.BORDER_REFLECT),
            A.IAAPiecewiseAffine(p=0.4),
        ], p=0.3),
        A.OneOf([
            A.MotionBlur(p=0.2),
            A.MedianBlur(blur_limit=3, p=0.1),
            A.Blur(blur_limit=3, p=0.1),
        ]),
        A.OneOf([
            A.HueSaturationValue(10,15,10),
            A.CLAHE(clip_limit=3),
            A.RandomBrightnessContrast(),
            A.RandomGamma()
        ], p=0.5)
    ], p=1.0)
    
    valid_transform = A.Compose([
        A.Resize(size, size, p=1.0),
        A.HorizontalFlip(),
        A.VerticalFlip(),
        A.RandomRotate90()
    ])

    return HubDataset(data_path, slices_path, window=window, overlap=MIN_OVERLAP, transform=trfm, valid_transform=valid_transform)

In [23]:
ds = generate_ds(NEW_SIZE, DATA_PATH, SLICES_PATH, WINDOW)
ds

  s = DatasetReader(path, driver=driver, sharing=sharing, **kwargs)
  7%|▋         | 1/15 [00:00<00:01,  9.78it/s]

Reading cached slices, files and masks


100%|██████████| 15/15 [00:03<00:00,  4.95it/s]


total: 4500, skipped: 0 mode: train

In [24]:
for i in range(len(ds.slices)):
    assert ds.slices[i][2] - ds.slices[i][1] == ds.slices[i][4] - ds.slices[i][3] == WINDOW
WINDOW

1792

In [25]:
def get_mean_std(train_dl):
    '''
    Calculate the mean and std
    var = E[x**2] - E[x]**2
    '''
    channels_sum, channels_squared_sum, num_batches = 0, 0, 0

    for data, _ in tqdm(train_dl, total=len(train_dl)):
        channels_sum += torch.mean(data, dim=[0, 2, 3])
        channels_squared_sum += torch.mean(data ** 2, dim=[0, 2, 3])
        num_batches += 1

    assert num_batches == len(train_dl)
    mean = channels_sum / num_batches
    std = (channels_squared_sum / num_batches - mean ** 2) ** 0.5
    return mean, std

ds.do_transform = False
norm_dl = D.DataLoader(ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)
stats = get_mean_std(norm_dl)
stats
# (tensor([0.0578, 0.0925, 0.0886]), tensor([1.0409, 1.0288, 1.0202]))
# (tensor([0.0578, 0.0925, 0.0886]), tensor([1.0409, 1.0288, 1.0202]))
# (tensor([0.0741, 0.1079, 0.1045]), tensor([1.0437, 1.0380, 1.0198]))

 24%|██▎       | 266/1125 [00:16<00:53, 15.96it/s]


KeyboardInterrupt: 

In [None]:
ds.do_transform = True

In [None]:
ds.set_normalize_transform(stats)

In [None]:
def display_mask_img(idx):
    image, mask = ds[idx]
    mask = mask.to(torch.uint8)
    plt.figure(figsize=(16,8))
    plt.subplot(121)
    plt.imshow(mask[0], cmap='gray')
    plt.subplot(122)
    plt.imshow(np.moveaxis(image.numpy(), 0, -1));

images_to_display = 25
for i in range(images_to_display):
    display_mask_img(i)

In [None]:
ds.mode = 'valid'

for i in range(images_to_display):
    display_mask_img(i)

In [None]:
image, mask = ds[0]

_ = rle_numba_encode(mask[0].numpy().astype('uint8')) # compile function with numba

In [None]:
image.shape

#### Split

In [None]:
# Images and its corresponding masks are saved with the same filename.
def generate_fold_info(ds):
    groups = [ds.slices[i][0] for i in range(len(ds))]
    group_kfold = GroupKFold(n_splits = FOLDS)
    fold_info = [(train_idx, valid_idx) for fold, (train_idx, valid_idx) in tqdm(enumerate(group_kfold.split(ds.slices, 
                                                            groups = groups)), total=FOLDS)]
    return fold_info

In [None]:
fold_info = generate_fold_info(ds)

#### Data loaders

In [None]:
import copy

def create_subset(ds, train_idx, valid_idx):
    train_ds = D.Subset(ds, train_idx)
    val_ds = copy.copy(ds)
    val_ds.mode = 'valid'
    valid_ds = D.Subset(val_ds, valid_idx)
    print(val_ds)
    return train_ds, valid_ds

In [None]:
def generate_train_valid_dls(ds, train_idx, valid_idx):
    train_ds, valid_ds = create_subset(ds, train_idx, valid_idx)

    num_workers = NUM_WORKERS
    # define training and validation data loaders
    train_dl = D.DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=num_workers)

    valid_dl = D.DataLoader(valid_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=num_workers)
    return train_dl, valid_dl

### Model

In [None]:
class HuBMAPModel(nn.Module):
    def __init__(self):
        super(HuBMAPModel, self).__init__()
        args = {
            'encoder_name': ENCODER_NAME, 
            'encoder_weights': ENCODER_WEIGHTS,
            'classes': 2,
            'activation': None,
            'aux_params': None
        }
        if ARCH == 'unet':
            self.model = Unet(**args)
        elif ARCH == 'fpn':
            self.model = FPN(**args)
        elif ARCH == 'manet':
            self.model = MAnet(**args)
        elif ARCH == 'linknet':
            self.model = Linknet(**args)
        elif ARCH == 'pan':
            self.model = PAN(**args)
        
    def forward(self, images):
        return self.model(images)

In [None]:
def create_model():
    model = HuBMAPModel()
    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        model = nn.DataParallel(model)
    model.to(DEVICE)
    model.train()
    return model

### Training

In [None]:
def create_optimizer(model):
    return torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WD)

def create_optimizer_scheduler(model, train_dl, epochs):
    optimizer = create_optimizer(model)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=LR,
                                                    steps_per_epoch=len(train_dl), epochs=epochs)
    return optimizer, scheduler

### Loss

In [None]:
dice_loss = smp.losses.DiceLoss(mode='multiclass', from_logits=True)

def loss_fn(y_pred, y_true):
    return dice_loss(y_pred, y_true)

In [None]:
jaccard_loss = smp.losses.JaccardLoss(mode='multiclass', from_logits=True)

In [None]:
def gen_dls(fold, ds=ds):
    train_idx, valid_idx = fold_info[fold]
    print(f'Proportions valid / train: {len(valid_idx) / len(train_idx)}')
    train_dl, valid_dl = generate_train_valid_dls(ds, train_idx, valid_idx)
    return train_dl, valid_dl

In [None]:
def create_datasets(fold=0):
    train_idx, valid_idx = fold_info[fold]
    train_ds, valid_ds = create_subset(ds, train_idx, valid_idx)
    return train_ds, valid_ds

In [None]:
FOLD = 0
train_ds, valid_ds = create_datasets(FOLD)
model = create_model()

In [None]:
from fastai.vision.all import *

In [None]:
from deepflash2.all import Iou, Dice_f1

In [None]:
@patch
def forward(self:smp.losses.JaccardLoss, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
    
#     print('y_pred.shape', y_pred.shape)
#     print('y_true.shape', y_true.shape)
    
    y_true = y_true.long()
    
    assert y_true.size(0) == y_pred.size(0)

    if self.from_logits:
        # Apply activations to get [0..1] class probabilities
        # Using Log-Exp as this gives more numerically stable result and does not cause vanishing gradient on
        # extreme values 0 and 1
        if self.mode == MULTICLASS_MODE:
            y_pred = y_pred.log_softmax(dim=1).exp()
        else:
            y_pred = F.logsigmoid(y_pred).exp()

    bs = y_true.size(0)
    num_classes = y_pred.size(1)
    dims = (0, 2)

    if self.mode == BINARY_MODE:
        y_true = y_true.view(bs, 1, -1)
        y_pred = y_pred.view(bs, 1, -1)

    if self.mode == MULTICLASS_MODE:
        
        y_true = y_true.view(bs, -1)
        y_pred = y_pred.view(bs, num_classes, -1)
        
#         print('y_true.shape', y_true.shape)
#         print('num_classes', num_classes)
        y_true = F.one_hot(y_true, num_classes)  # N,H*W -> N,H*W, C
        y_true = y_true.permute(0, 2, 1)  # H, C, H*W

    if self.mode == MULTILABEL_MODE:
        y_true = y_true.view(bs, num_classes, -1)
        y_pred = y_pred.view(bs, num_classes, -1)

    scores = soft_jaccard_score(y_pred, y_true.type(y_pred.dtype), smooth=self.smooth, eps=self.eps, dims=dims)

    if self.log_loss:
        loss = -torch.log(scores.clamp_min(self.eps))
    else:
        loss = 1.0 - scores

    # IoU loss is defined for non-empty classes
    # So we zero contribution of channel that does not have true pixels
    # NOTE: A better workaround would be to use loss term `mean(y_pred)`
    # for this case, however it will be a modified jaccard loss

    mask = y_true.sum(dims) > 0
    loss *= mask.float()

    if self.classes is not None:
        loss = loss[self.classes]

    return loss.mean()

In [None]:
# Dataloader and learner
dls = DataLoaders.from_dsets(train_ds, valid_ds, bs=BATCH_SIZE)
if torch.cuda.is_available(): dls.cuda()

In [None]:
best_model_name = f'{FOLD}_best_model'
cbs = [SaveModelCallback(monitor='iou', fname=best_model_name), EarlyStoppingCallback(monitor='iou', patience=10)]

In [None]:
learn = Learner(dls, model, metrics=[Iou(), Dice_f1()], wd=WD, loss_func=jaccard_loss, opt_func=ranger, cbs=cbs)
learn.to_fp16()

In [None]:
learn.fit_one_cycle(EPOCHS, lr_max=LR)

In [None]:
learn.recorder.plot_metrics()

In [None]:
!mv /home/hubmap_notebooks/models/{best_model_name}.pth /home/hubmap_notebooks/models/{FOLD}_{EXPERIMENT_NAME}.pth

In [None]:
import copy

FOLD = 1
train_ds, valid_ds = create_datasets(FOLD)
model = create_model()

In [None]:
from fastai.vision.all import *

In [None]:
# Dataloader and learner
dls = DataLoaders.from_dsets(train_ds, valid_ds, bs=BATCH_SIZE)
if torch.cuda.is_available(): dls.cuda()

In [None]:
best_model_name = f'{FOLD}_best_model'
cbs = [SaveModelCallback(monitor='dice_metric', fname=best_model_name), EarlyStoppingCallback(monitor='dice_metric', patience=10)]
learn = Learner(dls, model, metrics=[Iou(), Dice_f1()], wd=WD, loss_func=loss_fn, opt_func=ranger, cbs=cbs)
learn.to_fp16()

In [None]:
learn.fit_one_cycle(EPOCHS, lr_max=LR)

In [None]:
!mv /home/hubmap_notebooks/models/{best_model_name}.pth /home/hubmap_notebooks/models/{FOLD}_{EXPERIMENT_NAME}.pth

In [None]:
import copy

FOLD = 2
train_ds, valid_ds = create_datasets(FOLD)
model = create_model()

from fastai.vision.all import *

# Dataloader and learner
dls = DataLoaders.from_dsets(train_ds, valid_ds, bs=BATCH_SIZE)
if torch.cuda.is_available(): dls.cuda()

In [None]:
best_model_name = f'{FOLD}_best_model'
cbs = [SaveModelCallback(monitor='dice_metric', fname=best_model_name), EarlyStoppingCallback(monitor='dice_metric', patience=10)]
learn = Learner(dls, model, metrics=[dice_metric, iou], wd=WD, loss_func=loss_fn, opt_func=ranger, cbs=cbs)
learn.to_fp16()

In [None]:
learn.lr_find()

In [None]:
learn.fit_one_cycle(EPOCHS, lr_max=LR)

In [None]:
import copy

FOLD = 3
train_ds, valid_ds = create_datasets(FOLD)
model = create_model()

from fastai.vision.all import *

# Dataloader and learner
dls = DataLoaders.from_dsets(train_ds, valid_ds, bs=BATCH_SIZE)
if torch.cuda.is_available(): dls.cuda()

In [None]:
best_model_name = f'{FOLD}_best_model'
cbs = [SaveModelCallback(monitor='dice_metric', fname=best_model_name), EarlyStoppingCallback(monitor='dice_metric', patience=10)]
learn = Learner(dls, model, metrics=[dice_metric, iou], wd=WD, loss_func=loss_fn, opt_func=ranger, cbs=cbs)
learn.to_fp16()

In [None]:
learn.lr_find()

In [None]:
learn.fit_one_cycle(EPOCHS, lr_max=LR)

In [None]:
!mv /home/hubmap_notebooks/models/{best_model_name}.pth /home/hubmap_notebooks/models/{FOLD}_{EXPERIMENT_NAME}.pth

In [None]:
!mkdir pytorch_models
!mv *.pth pytorch_models

In [None]:
state_dict = torch.load('/home/hubmap_notebooks/models/model.pth')

In [None]:
test_model = get_model()

In [None]:
test_model.module.load_state_dict(state_dict)