In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
cd ../

/mnt/NVME1TB/Projects/kaggle-severstal-2019


In [3]:
TRAIN_IMAGES = '/home/denilv/Projects/kaggle-severstal-2019/data/train_images/'
TRAIN_CSV = '/mnt/NVME1TB/Projects/kaggle-severstal-2019/data/segm_df/train.csv'
VALID_CSV = '/mnt/NVME1TB/Projects/kaggle-severstal-2019/data/segm_df/valid.csv'
TEST_IMAGES = '/home/denilv/Projects/kaggle-severstal-2019/data/test_images/'

In [4]:
EPOCHS = 30
LR = 1e-3
BATCH_SIZE = 32
CROP_SIZE = None

CUDA_VISIBLE_DEVICES = '1'

In [5]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = CUDA_VISIBLE_DEVICES
import torch
import numpy as np
import pandas as pd
import segmentation_models_pytorch as smp

from albumentations.augmentations.functional import normalize
from tqdm.auto import tqdm
from modules.comp_tools import Dataset, AUGMENTATIONS_TRAIN, get_segm_model, ModelAgg, predict_semg, decode_masks, dice_channel_torch
from modules.common import rle_decode
from catalyst.dl.runner import SupervisedRunner
from catalyst.dl.callbacks import F1ScoreCallback, AccuracyCallback
from torch.utils.data import DataLoader as BaseDataLoader
from torch.utils.data import Dataset as BaseDataset

import ttach as tta

pyarrow not available, switching to pickle. To install pyarrow, run `pip install pyarrow`.
lz4 not available, disabling compression. To install lz4, run `pip install lz4`.
wandb not available, to install wandb, run `pip install wandb`.


In [16]:
def to_tensor(x, **kwargs):
    return x.transpose(2, 0, 1).astype('float32')

MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]

preprocessing_fn = lambda x: to_tensor(normalize(x, MEAN, STD, max_pixel_value=1.0))

# encoder_preprocessing = smp.encoders.get_preprocessing_fn('se_resnext50_32x4d', 'imagenet')
# preprocessing_fn1 = lambda x: to_tensor(encoder_preprocessing(x))

In [7]:
arch_args = dict(
    encoder_name='se_resnext101_32x4d',
    encoder_weights='imagenet',
    classes=4, 
    activation='sigmoid',
)

load_weights = 'logs/se_resnext101_32x4d_augm_cos_annealing_bce_jacc/checkpoints/best.pth'

model = get_segm_model('FPN', arch_args, load_weights=load_weights)

model = model.cuda()
model = model.eval()

se_resnext101 = model

Loading logs/se_resnext101_32x4d_augm_cos_annealing_bce_jacc/checkpoints/best.pth
<All keys matched successfully>


In [8]:
arch_args = dict(
    encoder_name='se_resnext50_32x4d',
    encoder_weights='imagenet',
    classes=4, 
    activation='sigmoid',
)

load_weights = 'logs/unet_se_resnext50_32x4d/checkpoints/best.pth'

model = get_segm_model('Unet', arch_args, load_weights=load_weights)

model = model.cuda()
model = model.eval()

se_resnext50 = model

Loading logs/unet_se_resnext50_32x4d/checkpoints/best.pth
<All keys matched successfully>


In [17]:
valid_df = pd.read_csv(VALID_CSV).fillna('')
valid_df = decode_masks(valid_df)

valid_dataset = Dataset(
    valid_df[:],
    img_prefix=TRAIN_IMAGES, 
    augmentations=None, 
    preprocess_img=preprocessing_fn1,
    preprocess_mask=to_tensor,
)

valid_dl = BaseDataLoader(valid_dataset, batch_size=4, shuffle=False, num_workers=4)

HBox(children=(IntProgress(value=0, max=5336), HTML(value='')))




In [11]:
def calc_dice(model, dl, th=0.5, device='cuda:0'):
    dices = []
    with torch.no_grad():
        for features, gt in tqdm(dl):
            features = features.to(device)
            logits = model(features).detach().cpu()
            batch_dice = dice_channel_torch(logits, gt, th)
            dices.append(batch_dice)
    return np.mean(dices)

In [12]:
calc_dice(se_resnext50, valid_dl)

HBox(children=(IntProgress(value=0, max=334), HTML(value='')))




0.873062

In [21]:
tta_se_resnext50 = tta.SegmentationTTAWrapper(se_resnext50, tta.aliases.hflip_transform())
calc_dice(tta_se_resnext50, valid_dl)

HBox(children=(IntProgress(value=0, max=334), HTML(value='')))




0.8824164

In [13]:
calc_dice(se_resnext101, valid_dl)

HBox(children=(IntProgress(value=0, max=334), HTML(value='')))




0.90145725

In [24]:
tta_se_resnext101 = tta.SegmentationTTAWrapper(se_resnext101, tta.aliases.hflip_transform())
calc_dice(tta_se_resnext101, valid_dl)

HBox(children=(IntProgress(value=0, max=334), HTML(value='')))




0.9038838

In [27]:
ens_model = ModelAgg([
    se_resnext101,
    se_resnext50,
])
calc_dice(ens_model, valid_dl)

HBox(children=(IntProgress(value=0, max=334), HTML(value='')))




0.89592856

In [28]:
tta_ens_model = tta.SegmentationTTAWrapper(ens_model, tta.aliases.hflip_transform())
calc_dice(tta_ens_model, valid_dl)

HBox(children=(IntProgress(value=0, max=334), HTML(value='')))




0.8981718