In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
cd ../

/mnt/NVME1TB/Projects/kaggle-severstal-2019


In [3]:
TRAIN_IMAGES = '/home/denilv/Projects/kaggle-severstal-2019/data/train_images/'
VALID_CLS_CSV = '/mnt/NVME1TB/Projects/kaggle-severstal-2019/data/cls_df/valid.csv'
VALID_SEGM_CSV = '/mnt/NVME1TB/Projects/kaggle-severstal-2019/data/segm_df/valid.csv'
TEST_IMAGES = '/home/denilv/Projects/kaggle-severstal-2019/data/test_images/'

In [4]:
BATCH_SIZE = 16

CUDA_VISIBLE_DEVICES = '1'

In [5]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = CUDA_VISIBLE_DEVICES
import torch
import numpy as np
import pandas as pd
import segmentation_models_pytorch as smp
import cv2

from tqdm.auto import tqdm
from modules.comp_tools import (
    Dataset,
    ClsDataset,
    AUGMENTATIONS_TRAIN, 
    get_segm_model,
    get_model,
    ModelAgg, 
    predict_semg, 
    decode_masks, 
    dice_channel_torch,
    predict_cls,
    preprocessing_fn,
    to_tensor,
    decode_masks,
    TestDataset,
    TestClsDataset,
    mask2rle,
)
from torch.utils.data import DataLoader as BaseDataLoader

import ttach as tta


pyarrow not available, switching to pickle. To install pyarrow, run `pip install pyarrow`.
lz4 not available, disabling compression. To install lz4, run `pip install lz4`.
wandb not available, to install wandb, run `pip install wandb`.


# Common validation dataset

In [37]:
cls_df = pd.read_csv(VALID_CLS_CSV, index_col='ImageId')
segm_df = pd.read_csv(VALID_SEGM_CSV, index_col='ImageId').fillna('')

In [38]:
empty_ids = cls_df.index[cls_df.has_defect == 0].values.tolist()
not_empty_ids = cls_df.index[cls_df.has_defect == 1].values.tolist()
inter_ids = (cls_df.index & segm_df.index).unique().values.tolist()

ids = empty_ids + inter_ids

In [39]:
df = pd.read_csv('data/train.csv').fillna('')
df = df.join(df.ImageId_ClassId.str.split('_', expand=True).rename(columns={0: 'ImageId', 1: 'ClassId'}))
df.set_index('ImageId', inplace=True)

In [40]:
cls_df = cls_df.loc[ids]
segm_df = df.loc[ids]
segm_df = decode_masks(segm_df)

HBox(children=(IntProgress(value=0, max=5884), HTML(value='')))




In [41]:
f = lambda x: list(map(int, x.strip('[]').split()))
cls_df['defect_map'] = cls_df.defect_map.apply(f)

In [42]:
cls_dl = BaseDataLoader(
    ClsDataset(
        cls_df.reset_index(),
        img_prefix=TRAIN_IMAGES, 
        augmentations=None, 
        preprocess_img=preprocessing_fn,
    ),
    batch_size=BATCH_SIZE,
    shuffle=False,
    drop_last=False,
    num_workers=4,
)

segm_dl = BaseDataLoader(
    Dataset(
        segm_df.reset_index(),
        img_prefix=TRAIN_IMAGES, 
        augmentations=None, 
        preprocess_img=preprocessing_fn,
        preprocess_mask=to_tensor,
    ),
    batch_size=BATCH_SIZE,
    shuffle=False,
    drop_last=False,
    num_workers=4,
)

# Check segm models

In [11]:
from torch.jit import load

import albumentations as A

In [12]:
segm_models = []

In [110]:
segm_models.append(load('data/blend_models/segm/unet_se_resnext50_32x4d.pth').cuda().eval())
segm_models.append(load('data/blend_models/segm/unet_mobilenet2.pth').cuda().eval())
segm_models.append(load('data/blend_models/segm/unet_resnet34.pth').cuda().eval())

In [13]:
# ####
# arch_args = dict(
#     encoder_name='se_resnext101_32x4d',
#     encoder_weights='imagenet',
#     classes=4, 
#     activation='sigmoid',
# )
# load_weights = 'logs/fpn_se_resnext101_32x4d/checkpoints/best.pth'
# model = get_segm_model('FPN', arch_args, load_weights=load_weights)
# model = model.cuda()
# model = model.eval()
# segm_models.append(model)

####
arch_args = dict(
    encoder_name='se_resnext101_32x4d',
    encoder_weights='imagenet',
    classes=4, 
    activation='sigmoid',
)
load_weights = 'logs/fpn_se_resnext101_32x4d_softmax_withEmpty/checkpoints/best.pth'
model = get_segm_model('FPN', arch_args, load_weights=load_weights)
model = model.cuda()
model = model.eval()
segm_models.append(model)

####
arch_args = dict(
    encoder_name='resnet18',
    encoder_weights='imagenet',
    classes=4, 
    activation='softmax',
)
load_weights = 'logs/unet_resnet18_softmax_withEmpty/checkpoints/best.pth'
model = get_segm_model('Unet', arch_args, load_weights=load_weights)
model = model.cuda()
model = model.eval()
segm_models.append(model)

# ####
# arch_args = dict(
#     encoder_name='resnet50',
#     encoder_weights='imagenet',
#     classes=4, 
#     activation='sigmoid',
# )
# load_weights = 'logs/fpn_resnet50/checkpoints/best.pth'
# model = get_segm_model('FPN', arch_args, load_weights=load_weights)
# model = model.cuda()
# model = model.eval()
# segm_models.append(model)

# ####
# arch_args = dict(
#     encoder_name='resnet50',
#     encoder_weights='imagenet',
#     classes=4, 
#     activation='sigmoid',
# )
# load_weights = 'logs/unet_resnet50/checkpoints/best.pth'
# model = get_segm_model('Unet', arch_args, load_weights=load_weights)
# model = model.cuda()
# model = model.eval()
# segm_models.append(model)


# ####
# arch_args = dict(
#     encoder_name='se_resnext50_32x4d',
#     encoder_weights='imagenet',
#     classes=4, 
#     activation='sigmoid',
# )
# load_weights = 'logs/unet_se_resnext50_32x4d/checkpoints/best.pth'
# model = get_segm_model('Unet', arch_args, load_weights=load_weights)
# model = model.cuda()
# model = model.eval()
# segm_models.append(model)

# ####
# arch_args = dict(
#     encoder_name='se_resnext101_32x4d',
#     encoder_weights='imagenet',
#     classes=4, 
#     activation='sigmoid',
#     attention_type='scse',
# )
# load_weights = 'logs/unet_se_resnext101_32x4d_attn/checkpoints/best.pth'
# model = get_segm_model('Unet', arch_args, load_weights=load_weights)
# model = model.cuda()
# model = model.eval()
# segm_models.append(model)


# segm_model_ens = ModelAgg(segm_models)
# segm_model = tta.SegmentationTTAWrapper(model, tta.aliases.hflip_transform())

Loading logs/fpn_se_resnext101_32x4d_softmax_withEmpty/checkpoints/best.pth
<All keys matched successfully>
Loading logs/unet_resnet18_softmax_withEmpty/checkpoints/best.pth
<All keys matched successfully>


In [8]:
####
arch_args = dict(
    encoder_name='efficientnet-b5',
    encoder_weights='imagenet',
    classes=5, 
    activation='softmax',
)
load_weights = 'logs/fpn_efficientnet-b5_crop/checkpoints/best_full.pth'
model = get_segm_model('FPN', arch_args, load_weights=load_weights)
model = model.cuda()
model = model.eval()
segm_models.append(model)

Loading logs/fpn_efficientnet-b5_crop/checkpoints/best_full.pth
<All keys matched successfully>


In [20]:
segm_df = pd.read_csv(VALID_SEGM_CSV, index_col='ImageId').fillna('')
segm_df = decode_masks(segm_df)

HBox(children=(IntProgress(value=0, max=5336), HTML(value='')))




In [21]:
dataset = Dataset(
    segm_df.reset_index(),
    img_prefix=TRAIN_IMAGES, 
    augmentations=None,
    background=False,
    preprocess_img=preprocessing_fn,
    preprocess_mask=to_tensor,
)

In [22]:
segm_dl = BaseDataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    drop_last=False,
    num_workers=4,
)

## Check slicer

In [148]:
origin_model = model

In [124]:
tta_model = tta.SegmentationTTAWrapper(model, tta.aliases.hflip_transform())

In [125]:
from modules.comp_tools import CroppedDataset
from pytorch_toolbelt.inference.tiles import ImageSlicer, CudaTileMerger
from modules.common import visualize
from torch.utils.data import DataLoader
from pytorch_toolbelt.utils.torch_utils import to_numpy

In [140]:
th = 0.5

d = iter(dataset)

In [155]:
def check(model, dataset, max_iters=100):
    tiler = ImageSlicer((256, 1600), (256, 416), 256)
    dice_scores = []
    i = 0
    for img, gt_mask in tqdm(dataset):
        if i >= max_iters:
            break
        merger = CudaTileMerger(tiler.target_shape, 5, tiler.weight)
        tiles = [preprocessing_fn(image_tile) for image_tile in tiler.split(img)]
        dl = DataLoader(list(zip(tiles, tiler.crops)), batch_size=8, pin_memory=True)
        for tiles_batch, coords_batch in dl:
            with torch.no_grad():
                tiles_batch = tiles_batch.float().cuda()
                pred_batch = model(tiles_batch)
                merger.integrate_batch(pred_batch, coords_batch)

        merged_logits = merger.merge()
        merged_probs = torch.sigmoid(merged_logits)
        merged_probs = to_numpy(merged_probs).transpose(1, 2, 0)
        pred_probs = tiler.crop_to_orignal_size(merged_probs)
        pred_mask = (pred_probs > th).astype(np.uint8)

        batch_pred_mask = np.expand_dims(pred_mask[..., :-1], 0)
        batch_gt_mask = np.expand_dims(gt_mask[..., :-1], 0)
        dice_score = dice_channel_torch(batch_pred_mask, batch_gt_mask)
        dice_scores.append(dice_score)
    #     print(np.round(dice_score, 3))
    #     visualize(img=img, gt=gt_mask[..., :-1].sum(-1), pred=pred_mask[..., :-1].sum(-1))
        i += 1
    print(np.mean(dice_scores))

In [14]:
def dice_channel_torch(probability, truth):
    batch_size = truth.shape[0]
    channel_num = truth.shape[1]
    mean_dice_channel = 0.
    for i in range(batch_size):
        for j in range(channel_num):
            channel_dice = dice_single_channel(probability[i, j,:,:], truth[i, j, :, :])
            mean_dice_channel += channel_dice / (batch_size * channel_num)
    return mean_dice_channel

def dice_single_channel(probability, truth, eps=1e-9):
    p = probability.flatten()
    t = (truth.flatten() > 0.5).astype(float)
    dice = (2.0 * (p * t).sum() + eps)/ (p.sum() + t.sum() + eps)
    return dice

In [15]:
def remove_small_one(predict, min_size):
    H,W = predict.shape
    num_component, component = cv2.connectedComponents(predict.astype(np.uint8))
    predict = np.zeros((H,W), np.bool)
    for c in range(1,num_component):
        p = (component == c)
        if p.sum() > min_size:
            predict[p] = True
    return predict

In [10]:
def sharpen(p, t=0.5):
    if t != 0:
        return p**t
    else:
        return p

In [16]:
model_ens = ModelAgg(segm_models)
tta_se_resnext101 = tta.SegmentationTTAWrapper(segm_models[0], tta.aliases.hflip_transform())
tta_resnet18 = tta.SegmentationTTAWrapper(segm_models[1], tta.aliases.hflip_transform())
tta_model_ens = tta.SegmentationTTAWrapper(model_ens, tta.aliases.hflip_transform())

In [23]:
device = 'cuda'
segm_th = 0.5

dices = []
cls_preds = []
for i, segm_model in enumerate([model_ens, tta_se_resnext101, tta_resnet18, tta_model_ens]):
    with torch.no_grad():
        for segm_batch in tqdm(segm_dl, total=len(segm_dl)):
            segm_features, segm_gt = segm_batch
            segm_features = segm_features.to(device)
            segm_logits = segm_model(segm_features).detach().cpu()
            pred_masks = (torch.sigmoid(segm_logits) > segm_th).numpy()
            pred_masks = pred_masks.astype(np.uint8)
#             for img_id in range(len(pred_masks)):
#                 for j in range(4):
#                     pred_masks[img_id, j] = remove_small_one(pred_masks[img_id, j], min_size=200)
            batch_dice = dice_channel_torch(pred_masks, segm_gt.numpy())
            dices.append(batch_dice)
    print(i, np.mean(dices))

HBox(children=(IntProgress(value=0, max=84), HTML(value='')))


0 0.9127107775961745


HBox(children=(IntProgress(value=0, max=84), HTML(value='')))


1 0.9152353536871027


HBox(children=(IntProgress(value=0, max=84), HTML(value='')))


2 0.9107266681551741


HBox(children=(IntProgress(value=0, max=84), HTML(value='')))


3 0.9115327435686017


In [21]:
device = 'cuda'
segm_th = 0.5

dices = []
cls_preds = []
for i, segm_model in enumerate([model_ens, tta_se_resnext101, tta_resnet18, tta_model_ens]):
    with torch.no_grad():
        for segm_batch in tqdm(segm_dl, total=len(segm_dl)):
            segm_features, segm_gt = segm_batch
            segm_features = segm_features.to(device)
            segm_logits = segm_model(segm_features).detach().cpu()
            pred_masks = (torch.sigmoid(segm_logits) > segm_th).numpy()
            pred_masks = pred_masks.astype(np.uint8)
#             for img_id in range(len(pred_masks)):
#                 for j in range(4):
#                     pred_masks[img_id, j] = remove_small_one(pred_masks[img_id, j], min_size=200)
            batch_dice = dice_channel_torch(pred_masks, segm_gt.numpy())
            dices.append(batch_dice)
    print(i, np.mean(dices))

HBox(children=(IntProgress(value=0, max=84), HTML(value='')))




0 0.9078998254167824


HBox(children=(IntProgress(value=0, max=84), HTML(value='')))


1 0.9058301214405342


HBox(children=(IntProgress(value=0, max=84), HTML(value='')))


2 0.9044565133241284


HBox(children=(IntProgress(value=0, max=84), HTML(value='')))


3 0.9057805123247601


In [44]:
segm_model = tta_model_ens

# Check cls models

In [34]:
cls_models = []

# ####
# ENCODER = 'resnet50'
# ENCODER_WEIGHTS = 'imagenet'
# ACTIVATION = 'sigmoid'

# CONTINUE = 'logs/cls_resnet50/checkpoints/best_augm.pth'

# model = get_model(ENCODER, 2, ENCODER_WEIGHTS, load_weights=CONTINUE)
# model = model.cuda()
# model = model.eval()
# cls_models.append(model)

# ####
# ENCODER = 'resnet18'
# ENCODER_WEIGHTS = 'imagenet'
# ACTIVATION = 'sigmoid'

# CONTINUE = 'logs/cls_resnet18/checkpoints/best_augm.pth'

# model = get_model(ENCODER, 2, ENCODER_WEIGHTS, load_weights=CONTINUE)
# model = model.cuda()
# model = model.eval()
# cls_models.append(model)

####
ENCODER = 'resnet34'
ENCODER_WEIGHTS = 'imagenet'
ACTIVATION = 'softmax'

CONTINUE = 'logs/cls_resnet34_multiclass/checkpoints/best_non_augm.pth'

model = get_model(ENCODER, 5, ENCODER_WEIGHTS, load_weights=CONTINUE)
model = model.cuda()
model = model.eval()
cls_models.append(model)


cls_model_ens = ModelAgg(cls_models)
cls_model_tta = tta.ClassificationTTAWrapper(model, tta.aliases.hflip_transform())

Loading logs/cls_resnet34_multiclass/checkpoints/best_non_augm.pth
<All keys matched successfully>


In [23]:
f = lambda x: list(map(int, x.strip('[]').split()))
df = pd.read_csv(VALID_CLS_CSV)
df['defect_map'] = df.defect_map.apply(f)

cls_dl = BaseDataLoader(
    ClsDataset(
        df,
        img_prefix=TRAIN_IMAGES, 
        augmentations=None, #AUGMENTATIONS_TRAIN,
        mode='multiclass',
        binary=False,
        preprocess_img=preprocessing_fn,
    ), 
    batch_size=BATCH_SIZE,
    shuffle=False, 
    num_workers=0
)

In [48]:
device = 'cuda'
cls_th = 0.5

cls_probs = []
cls_preds = []
cls_gts = []
with torch.no_grad():
    for cls_batch in tqdm(cls_dl, total=len(cls_dl)):
        cls_features, cls_gt = cls_batch['features'], cls_batch['targets_one_hot']
        cls_features = cls_features.to(device)
        cls_logits = cls_model(cls_features).detach().cpu()
#         cls_prob = torch.softmax(cls_logits, 1)[:, 1].numpy()
        cls_prob = torch.softmax(cls_logits, 1).numpy()
        cls_pred = (cls_prob > cls_th).astype(int)
        cls_preds.append(cls_pred)
        cls_probs.append(cls_prob)
        cls_gts.append(cls_gt.numpy())

HBox(children=(IntProgress(value=0, max=157), HTML(value='')))




In [49]:
prob = np.concatenate(cls_probs).argmax(axis=1)
labels = np.vstack(cls_gts).argmax(axis=1)

In [76]:
cls_probs

array([[9.8798966e-01, 7.0864409e-03, 7.1337220e-04, 4.1426835e-03,
        6.7845722e-05]], dtype=float32)

In [56]:
from sklearn.metrics import classification_report, accuracy_score, f1_score

In [57]:
for th in np.arange(0.1, 1.01, 0.1):
    print(f'Th {th:0.2f} Acc {accuracy_score(labels, prob):0.3f}')
    print(classification_report(labels, prob))

Th 0.10 Acc 0.936
              precision    recall  f1-score   support

           0       0.90      0.86      0.88       180
           1       0.79      0.90      0.84        42
           2       0.94      0.93      0.94      1008
           3       0.71      0.72      0.71       103
           4       0.97      0.97      0.97      1179

    accuracy                           0.94      2512
   macro avg       0.86      0.88      0.87      2512
weighted avg       0.94      0.94      0.94      2512

Th 0.20 Acc 0.936
              precision    recall  f1-score   support

           0       0.90      0.86      0.88       180
           1       0.79      0.90      0.84        42
           2       0.94      0.93      0.94      1008
           3       0.71      0.72      0.71       103
           4       0.97      0.97      0.97      1179

    accuracy                           0.94      2512
   macro avg       0.86      0.88      0.87      2512
weighted avg       0.94      0.94      0.

In [46]:
for th in np.arange(0.1, 1.01, 0.1):
    print(f'Th {th:0.2f} Acc {accuracy_score(labels, prob):0.3f}')
    print(classification_report(labels, prob))

Th 0.10 Acc 0.863
              precision    recall  f1-score   support

           0       0.54      0.67      0.59       180
           1       0.62      0.12      0.20        42
           2       0.88      0.91      0.89      1008
           3       0.65      0.54      0.59       103
           4       0.93      0.91      0.92      1179

    accuracy                           0.86      2512
   macro avg       0.72      0.63      0.64      2512
weighted avg       0.87      0.86      0.86      2512

Th 0.20 Acc 0.863
              precision    recall  f1-score   support

           0       0.54      0.67      0.59       180
           1       0.62      0.12      0.20        42
           2       0.88      0.91      0.89      1008
           3       0.65      0.54      0.59       103
           4       0.93      0.91      0.92      1179

    accuracy                           0.86      2512
   macro avg       0.72      0.63      0.64      2512
weighted avg       0.87      0.86      0.

# Full

In [35]:
segm_model = tta_se_resnext101
cls_model = cls_model_tta

In [45]:
device = 'cuda'
cls_th = 0.5
segm_th = 0.5

dices = []
cls_preds = []
with torch.no_grad():
    for cls_batch, segm_batch in tqdm(zip(cls_dl, segm_dl), total=len(cls_dl)):
        cls_features, cls_gt = cls_batch['features'], cls_batch['targets']
        cls_features = cls_features.to(device)
        cls_logits = cls_model(cls_features).detach().cpu()
        cls_probs = torch.softmax(cls_logits, 1).numpy()
        cls_pred = cls_probs.argmax(axis=1) # (cls_probs > cls_th).astype(int)
        cls_preds.append(cls_pred)
        
        segm_features, segm_gt = segm_batch
        segm_features = segm_features.to(device)
        segm_logits = segm_model(segm_features).detach().cpu()
        pred_masks = (torch.sigmoid(segm_logits) > segm_th).numpy()
        pred_masks = pred_masks.astype(np.uint8)
        for img_id in range(len(pred_masks)):
            for j in range(4):
                pred_masks[img_id, j] = remove_small_one(pred_masks[img_id, j], min_size=100)
        
        # clean predicted masks w/o defects
        pred_masks[cls_pred == 4] = 0
                
        batch_dice = dice_channel_torch(pred_masks, segm_gt.numpy())
        dices.append(batch_dice)
print(np.mean(dices))

HBox(children=(IntProgress(value=0, max=92), HTML(value='')))


0.9774952063689157


In [49]:
device = 'cuda'
cls_th = 0.5
segm_th = 0.5

dices = []
cls_preds = []
with torch.no_grad():
    for cls_batch, segm_batch in tqdm(zip(cls_dl, segm_dl), total=len(cls_dl)):
        cls_features, cls_gt = cls_batch['features'], cls_batch['targets']
        cls_features = cls_features.to(device)
        cls_logits = cls_model(cls_features).detach().cpu()
        cls_probs = torch.softmax(cls_logits, 1).numpy()
        cls_pred = cls_probs.argmax(axis=1) # (cls_probs > cls_th).astype(int)
        cls_preds.append(cls_pred)
        
        segm_features, segm_gt = segm_batch
        segm_features = segm_features.to(device)
        segm_logits = segm_model(segm_features).detach().cpu()
        pred_masks = (torch.sigmoid(segm_logits) > segm_th).numpy()
        pred_masks = pred_masks.astype(np.uint8)
        for img_id in range(len(pred_masks)):
            for j in range(4):
                if cls_pred[img_id] == j:
                    pred_masks[img_id, j] = remove_small_one(pred_masks[img_id, j], min_size=200)
                else:
                    pred_masks[img_id, j] = 0
        
        # clean predicted masks w/o defects
        pred_masks[cls_pred == 4] = 0
                
        batch_dice = dice_channel_torch(pred_masks, segm_gt.numpy())
        dices.append(batch_dice)
print(np.mean(dices))

HBox(children=(IntProgress(value=0, max=92), HTML(value='')))


0.9767752887125256


In [48]:
print(np.mean(dices))

1.0


In [45]:
device = 'cuda'
cls_th = 0.5
segm_th = 0.5

dices = []
cls_preds = []
with torch.no_grad():
    for cls_batch, segm_batch in tqdm(zip(cls_dl, segm_dl), total=len(cls_dl)):
        cls_features, cls_gt = cls_batch['features'], cls_batch['targets']
        cls_features = cls_features.to(device)
        cls_logits = cls_model(cls_features).detach().cpu()
        cls_probs = torch.softmax(cls_logits, 1).numpy()
        cls_pred = cls_probs.argmax(axis=1) # (cls_probs > cls_th).astype(int)
        cls_preds.append(cls_pred)
        
        segm_features, segm_gt = segm_batch
        segm_features = segm_features.to(device)
        segm_logits = segm_model(segm_features).detach().cpu()
        pred_masks = (torch.sigmoid(segm_logits) > segm_th).numpy()
        pred_masks = pred_masks.astype(np.uint8)
        for img_id in range(len(pred_masks)):
            for j in range(4):
                pred_masks[img_id, j] = remove_small_one(pred_masks[img_id, j], min_size=200)
        
        # clean predicted masks w/o defects
        pred_masks[cls_pred == 4] = 0
                
        batch_dice = dice_channel_torch(pred_masks, segm_gt.numpy())
        dices.append(batch_dice)
print(np.mean(dices))

HBox(children=(IntProgress(value=0, max=92), HTML(value='')))


0.9774952063689157


In [52]:
device = 'cuda'
cls_th = 0.5
segm_th = 0.5

dices = []
cls_preds = []
with torch.no_grad():
    for cls_batch, segm_batch in tqdm(zip(cls_dl, segm_dl), total=len(cls_dl)):
        cls_features, cls_gt = cls_batch['features'], cls_batch['targets']
        cls_features = cls_features.to(device)
        cls_logits = cls_model(cls_features).detach().cpu()
        cls_probs = torch.softmax(cls_logits, 1).numpy()
        cls_pred = cls_probs.argmax(axis=1) # (cls_probs > cls_th).astype(int)
        cls_preds.append(cls_pred)
        
        segm_features, segm_gt = segm_batch
        segm_features = segm_features.to(device)
        segm_logits = segm_model(segm_features).detach().cpu()
        pred_masks = (torch.sigmoid(segm_logits) > segm_th).numpy()
        pred_masks = pred_masks.astype(np.uint8)
        for img_id in range(len(pred_masks)):
            for j in range(4):
                pred_masks[img_id, j] = remove_small_one(pred_masks[img_id, j], min_size=100)
        
        # clean predicted masks w/o defects
        pred_masks[cls_pred == 4] = 0
                
        batch_dice = dice_channel_torch(pred_masks, segm_gt.numpy())
        dices.append(batch_dice)
print(np.mean(dices))

HBox(children=(IntProgress(value=0, max=92), HTML(value='')))


0.9776792745706833


In [19]:
device = 'cuda'
cls_th = 0.6
segm_th = 0.5

dices = []
cls_preds = []
with torch.no_grad():
    for cls_batch, segm_batch in tqdm(zip(cls_dl, segm_dl), total=len(cls_dl)):
        cls_features, cls_gt = cls_batch['features'], cls_batch['targets_one_hot']
        cls_features = cls_features.to(device)
        cls_logits = cls_model(cls_features).detach().cpu()
        cls_probs = torch.softmax(cls_logits, 1)[:, 1].numpy()
        cls_pred = (cls_probs > cls_th).astype(int)
        cls_preds.append(cls_pred)
        
        segm_features, segm_gt = segm_batch
        segm_features = segm_features.to(device)
        segm_logits = segm_model(segm_features).detach().cpu()
        pred_masks = (torch.sigmoid(segm_logits) > segm_th).numpy()
        pred_masks = pred_masks.astype(np.uint8)
        for img_id in range(len(pred_masks)):
            for j in range(4):
                pred_masks[img_id, j] = remove_small_one(pred_masks[img_id, j], min_size=200)
        
        # clean predicted masks w/o defects
        pred_masks[cls_pred == 0] = 0
                
        batch_dice = dice_channel_torch(pred_masks, segm_gt.numpy())
        dices.append(batch_dice)
print(np.mean(dices))

HBox(children=(IntProgress(value=0, max=92), HTML(value='')))


0.9752024519282599


In [23]:
def calc_dice(model, dl, th=0.5, device='cuda:0'):
    dices = []
    with torch.no_grad():
        for features, gt in tqdm(dl):
            features = features.to(device)
            logits = model(features).detach().cpu()
            batch_dice = dice_channel_torch(logits, gt, th)
            dices.append(batch_dice)
    return np.mean(dices)

# Submit

In [53]:
test_dataset = TestDataset(
    TEST_IMAGES,
    preprocess_img=preprocessing_fn,
)
test_dl = BaseDataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=4)

In [54]:
# test_dataset = TestClsDataset(TEST_IMAGES, preprocess_img=preprocessing_fn)
# test_dl = BaseDataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)

In [115]:
segm_model = tta_model_ens

In [126]:
device = 'cuda'

min_area = [600, 600, 1000, 2000]

cls_th = 0.5
segm_th = 0.5

cls_preds = []
results = []
all_probs = []
with torch.no_grad():
    for i, batch in enumerate(tqdm(test_dl)):
        img_id_str = test_dataset.img_ids[i]
        
#         cls_features = batch
#         cls_features = cls_features.to(device)
#         cls_logits = cls_model(cls_features).detach().cpu()
#         cls_probs = torch.softmax(cls_logits, 1).numpy()
#         cls_pred = cls_probs.argmax(axis=1) # (cls_probs > cls_th).astype(int)
#         cls_preds.append(cls_pred)
#         all_probs.append(cls_probs)
        
        segm_features = batch
        segm_features = segm_features.to(device)
        segm_logits = segm_model(segm_features).detach().cpu()
        pred_masks = (torch.sigmoid(segm_logits) > segm_th).numpy()
        pred_masks = pred_masks.astype(np.uint8)
        for img_id in range(len(pred_masks)):
            for j in range(4):
                if pred_masks[img_id, j].sum() < min_area[j]:
                    pred_masks[img_id, j] = np.zeros(
                        pred_masks[img_id, j].shape, 
                        dtype=pred_masks[img_id, j].dtype
                    )
                pred_masks[img_id, j] = remove_small_one(pred_masks[img_id, j], min_size=100)
        
        # clean predicted masks w/o defects
#         pred_masks[cls_pred == 4] = 0
        pred_masks = pred_masks.squeeze()
        for ch_id, ch_mask in enumerate(pred_masks):
            results.append({
                'ImageId_ClassId': f'{img_id_str}_{ch_id + 1}',
                'EncodedPixels': mask2rle(ch_mask)
            })

HBox(children=(IntProgress(value=0, max=1801), HTML(value='')))




In [124]:
pickle.dump(list(zip(test_dataset.img_ids, all_probs)), open('id2probs.pkl', 'wb'))

In [None]:
import pickle

In [144]:
submit_df = pd.DataFrame(results).set_index('ImageId_ClassId')
submit_df.to_csv('submits/super-ensemble-all.csv', index=True)

In [140]:
not_empty_ids = pickle.load(open('/home/denilv/Downloads/Telegram Desktop/non_emplty_image.pkl', 'rb'))
submit_df.loc[submit_df.index.difference(not_empty_ids), 'EncodedPixels'] = ''
submit_df.to_csv('submits/super-ensemble-all-wo-empty.csv', index=True)

In [145]:
empty_ids = list(pickle.load(open('/home/denilv/Downloads/Telegram Desktop/empty_ids.pkl', 'rb')))

In [146]:
aaa = []
for i in empty_ids:
    for j in range(1, 5):
        aaa.append(f'{i}_{j}')

In [147]:
submit_df.loc[aaa, 'EncodedPixels'] = ''
submit_df.to_csv('submits/super-ensemble-all-wo-empty2.csv', index=True)

In [19]:
# df = pd.read_csv('submits/23102019-0112.csv', index_col='ImageId_ClassId').fillna('')

# not_empty_ids = pickle.load(open('/home/denilv/Downloads/Telegram Desktop/non_emplty_image.pkl', 'rb'))

# df.loc[df.index.difference(not_empty_ids), 'EncodedPixels'] = ''

# df.to_csv('submits/23102019-0210.csv', index=True)