In [None]:
%load_ext autoreload
%autoreload
from factory import *
import torch
from model import Unet
from catalyst.dl.callbacks import CriterionCallback, EarlyStoppingCallback, AUCCallback, AccuracyCallback, F1ScoreCallback
from catalyst.dl.runner import SupervisedRunner
from pytorch_toolbelt import losses as L
import collections
from pytorch_toolbelt.utils.catalyst import * 
from metrics import *
import custom_tta as tta
from pytorch_toolbelt.inference.tiles import *
import matplotlib.pyplot as plt
from viz_utils import *
from tqdm import tqdm
%matplotlib inline

In [None]:
import pixiedust

In [4]:
# --- Experiment identity -----------------------------------------------------
# The experiment name is derived from the encoder name and determines the
# directory that training logs and checkpoints are written to.
encoder_name = 'resnet50'
base_exp_name = '{}_with_corase_matrix_deeper_with_class'.format(encoder_name)
log_dir = 'logs/{}_pretrain/'.format(base_exp_name)

# --- Data locations ----------------------------------------------------------
train_df_path = 'data/train.csv'
sample_submission_path = 'data/sample_submission.csv'
data_folder = "data/train_images/"
test_data_folder = "data/test_images/"

# --- Loader settings ---------------------------------------------------------
# Training uses crops of side `crop_size`; the validation loader gets its own
# (smaller) batch size.
batch_size, batch_size_val = 32, 16
crop_size = 256
num_workers = 16

# --- Training schedule -------------------------------------------------------
# A short stage with the encoder frozen is followed by the main stage.
num_epochs_with_frozen_encoder, num_epochs = 5, 100

# --- Model output / inference options ----------------------------------------
# `output_channels` is passed to the model as its number of classes.
# `tta_type` selects an optional test-time-augmentation wrapper; None leaves
# the model unwrapped.
output_channels = 5
tta_type = None

In [5]:
# Start from an empty log directory so that this run's logs and checkpoints
# are not mixed with those of a previous run.
# The path comes from `log_dir` (config cell) rather than a hard-coded copy, so
# renaming the experiment cannot leave this cell deleting the wrong folder, and
# a missing directory (e.g. on the very first run) is not an error.
import shutil

shutil.rmtree(log_dir, ignore_errors=True)

In [6]:
# Arguments that are identical for the training and the validation loader.
_shared_loader_kwargs = dict(
    data_folder=data_folder,
    df_path=train_df_path,
    num_workers=num_workers,
    prepare_coarse=True,
    prepare_edges=False,
    prepare_class=True,
)

# Training loader: 'train' phase, medium augmentations built for `crop_size`.
dataloader_train = provider(
    phase='train',
    transforms=medium_augmentations(crop_size),
    batch_size=batch_size,
    **_shared_loader_kwargs
)

# Validation loader: 'val' phase, validation transforms, its own batch size.
dataloader_val = provider(
    phase='val',
    transforms=validation_augmentations(),
    batch_size=batch_size_val,
    **_shared_loader_kwargs
)

In [7]:
# The runner receives the loaders as an ordered mapping: "train" first, then "valid".
loaders = collections.OrderedDict(
    [
        ("train", dataloader_train),
        ("valid", dataloader_val),
    ]
)

# The batch entry named 'features' is the model input.
# NOTE(review): output_key / input_target_key are None — presumably so the
# model's output dict and the batch dict are exposed unchanged to the
# callbacks, which address them by key; confirm against the Catalyst version in use.
runner = SupervisedRunner(
    input_key='features',
    output_key=None,
    input_target_key=None,
)

In [9]:
# Network: U-Net with an ImageNet-initialised encoder and `output_channels` classes.
model = Unet(
    encoder_name=encoder_name,
    encoder_weights='imagenet',
    classes=output_channels,
)

# Criteria, addressed by name from the CriterionCallbacks via `criterion_key`.
loss_f_segmentation = get_loss('bce_lovasz')
loss_f_classification = nn.BCEWithLogitsLoss()
losses = {
    'loss_f_segmentation': loss_f_segmentation,
    'loss_f_classification': loss_f_classification,
}

# Optimiser over all model parameters; the learning rate is halved at each milestone epoch.
optimizer = get_optimizer('radam', model.parameters(), lr=1e-5)
_lr_milestones = [10, 30, 50, 70, 90]
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer,
    milestones=_lr_milestones,
    gamma=0.5,
)

In [10]:
# Optionally wrap the model with a test-time-augmentation (TTA) wrapper.
# `tta_type` comes from the config cell; any value not listed here (including
# None) leaves the model unwrapped.
#   'fliplr' (the legacy spelling 'flipr' is still accepted) -> tta.fliplr_image2mask
#   'd4'                                                     -> tta.d4_image2mask
# NOTE(review): later cells access `model.encoder` directly; presumably that
# attribute is not available on the wrapper — confirm before enabling TTA
# ahead of the freeze/unfreeze cells.
if tta_type in ('flipr', 'fliplr'):
    model = tta.TTAWrapper(model, tta.fliplr_image2mask)
elif tta_type == 'd4':
    model = tta.TTAWrapper(model, tta.d4_image2mask)

Visualize the model

In [None]:
# Disabled scratch code: runs the model on a random 1x3x256x256 input, takes
# the 'logits_coarse' output and renders its computation graph.
# NOTE(review): `Variable` and `make_dot` are not imported explicitly in this
# notebook (presumably torch.autograd / torchviz) — confirm before re-enabling.
#inputs = torch.randn(1,3,256,256)
#y = model(Variable(inputs))['logits_coarse']
#g = make_dot(y, model.state_dict())
#g.view()

In [None]:
# Disabled scratch code: sanity check of the classification loss on a single
# training batch.
# NOTE(review): `dataloader_train['']` below looks like an unfinished
# placeholder; presumably `data_b['features']` was intended (the runner's
# input key) — confirm before re-enabling.
#data_b = next(iter(dataloader_train))
#output = model(dataloader_train[''])
#loss_f_classification(output['class_logits'], data_b['classification'].float())

In [11]:
# Stage 1 preparation: switch off gradient tracking for every encoder
# parameter, so the first training stage does not update the encoder.
for encoder_weight in model.encoder.parameters():
    encoder_weight.requires_grad_(False)

In [12]:
#%%pixie_debugger
# Stage 1: train for `num_epochs_with_frozen_encoder` epochs (the encoder was
# frozen in the previous cell). No LR scheduler is passed in this stage.
#
# NOTE(review): the three CriterionCallbacks write to three different prefixes
# and no aggregating callback is present. In the logs recorded in this notebook
# auc, f1_score and jaccard_targets_coarse stay flat while jaccard_targets
# improves, which would be consistent with only the `loss` prefix being
# back-propagated — confirm how the installed Catalyst version combines
# multiple criteria.

# Loss terms: full-resolution masks, coarse masks (x0.5), image-level class (x0.5).
stage1_loss_callbacks = [
    CriterionCallback(
        input_key="targets",
        output_key="logits",
        prefix="loss",
        criterion_key='loss_f_segmentation',
    ),
    CriterionCallback(
        input_key="coarse_targets",
        output_key="logits_coarse",
        prefix="loss_coarse",
        criterion_key='loss_f_segmentation',
        multiplier=0.5,
    ),
    CriterionCallback(
        input_key="classification",
        output_key="class_logits",
        prefix="loss_classification",
        criterion_key='loss_f_classification',
        multiplier=0.5,
    ),
]

# Metrics: Jaccard on both mask outputs, AUC / F1 on the classification output,
# plus the project's per-image Jaccard and threshold-search callbacks.
stage1_metric_callbacks = [
    JaccardScoreCallback(
        mode='multilabel',
        input_key='targets',
        output_key="logits",
        prefix='jaccard_targets',
    ),
    JaccardScoreCallback(
        mode='multilabel',
        input_key='coarse_targets',
        output_key="logits_coarse",
        prefix='jaccard_targets_coarse',
    ),
    AUCCallback(
        input_key='classification',
        output_key='class_logits',
    ),
    F1ScoreCallback(
        input_key='classification',
        output_key='class_logits',
        prefix='f1_score',
        activation='Sigmoid',
    ),
    JaccardMetricPerImage(),
    OptimalThreshold(),
]

runner.train(
    model=model,
    criterion=losses,
    optimizer=optimizer,
    callbacks=stage1_loss_callbacks + stage1_metric_callbacks,
    loaders=loaders,
    logdir=log_dir,
    num_epochs=num_epochs_with_frozen_encoder,
    verbose=True,
)

0/5 * Epoch (train): 100% 315/315 [06:41<00:00,  1.27s/it, _timers/_fps=2541.184, f1_score=0.390, jaccard_targets=0.266, jaccard_targets_coarse=0.231, loss=2.303, loss_classification=0.400, loss_coarse=0.732]
0/5 * Epoch (valid): 100% 158/158 [06:53<00:00,  2.62s/it, _timers/_fps=2114.264, f1_score=0.583, jaccard_targets=0.018, jaccard_targets_coarse=0.018, loss=2.759, loss_classification=0.446, loss_coarse=0.955]
[2019-10-01 05:45:58,085] 
0/5 * Epoch 0 (train): _base/lr=1.000e-05 | _base/momentum=0.9000 | _timers/_fps=1996.7161 | _timers/batch_time=0.0261 | _timers/data_time=0.0089 | _timers/model_time=0.0171 | auc/_mean=0.5240 | auc/class_0=0.5240 | f1_score=0.4403 | jaccard=0.0333 | jaccard_targets=0.1445 | jaccard_targets_class_0=0.0441 | jaccard_targets_class_1=0.0471 | jaccard_targets_class_2=0.1443 | jaccard_targets_class_3=0.2076 | jaccard_targets_class_4=0.1536 | jaccard_targets_coarse=0.1146 | jaccard_targets_coarse_class_0=0.0503 | jaccard_targets_coarse_class_1=0.0429 | ja

4/5 * Epoch (train): 100% 315/315 [06:37<00:00,  1.26s/it, _timers/_fps=2943.823, f1_score=0.458, jaccard_targets=0.097, jaccard_targets_coarse=0.100, loss=1.494, loss_classification=0.429, loss_coarse=0.903]
4/5 * Epoch (valid): 100% 158/158 [06:49<00:00,  2.59s/it, _timers/_fps=2141.453, f1_score=0.596, jaccard_targets=0.021, jaccard_targets_coarse=0.018, loss=1.395, loss_classification=0.430, loss_coarse=0.940]
[2019-10-01 06:40:06,047] 
4/5 * Epoch 4 (train): _base/lr=1.000e-05 | _base/momentum=0.9000 | _timers/_fps=2067.2893 | _timers/batch_time=0.0242 | _timers/data_time=0.0086 | _timers/model_time=0.0155 | auc/_mean=0.5223 | auc/class_0=0.5223 | f1_score=0.4405 | jaccard=0.0303 | jaccard_targets=0.1432 | jaccard_targets_class_0=0.0439 | jaccard_targets_class_1=0.0519 | jaccard_targets_class_2=0.1619 | jaccard_targets_class_3=0.2330 | jaccard_targets_class_4=0.1362 | jaccard_targets_coarse=0.1175 | jaccard_targets_coarse_class_0=0.0577 | jaccard_targets_coarse_class_1=0.0395 | ja

In [13]:
# Stage 2 preparation: turn gradient tracking back on for every encoder
# parameter so the whole network is trained in the next stage.
for encoder_weight in model.encoder.parameters():
    encoder_weight.requires_grad_(True)

In [14]:
#%%pixie_debugger
# Stage 2: train the whole network (encoder unfrozen in the previous cell) for
# up to `num_epochs` epochs with the MultiStepLR scheduler; training stops
# early via EarlyStoppingCallback(25, metric='loss', minimize=True).
#
# NOTE(review): the three CriterionCallbacks write to three different prefixes
# and no aggregating callback is present. In the logs recorded in this notebook
# auc, f1_score and jaccard_targets_coarse stay flat while jaccard_targets
# improves, which would be consistent with only the `loss` prefix being
# back-propagated — confirm how the installed Catalyst version combines
# multiple criteria.

# Loss terms: full-resolution masks, coarse masks (x1.5), image-level class (x1.0).
stage2_loss_callbacks = [
    CriterionCallback(
        input_key="targets",
        output_key="logits",
        prefix="loss",
        criterion_key='loss_f_segmentation',
    ),
    CriterionCallback(
        input_key="coarse_targets",
        output_key="logits_coarse",
        prefix="loss_coarse",
        criterion_key='loss_f_segmentation',
        multiplier=1.5,
    ),
    CriterionCallback(
        input_key="classification",
        output_key="class_logits",
        prefix="loss_classification",
        criterion_key='loss_f_classification',
        multiplier=1.0,
    ),
]

# Metrics: same set as in stage 1.
stage2_metric_callbacks = [
    JaccardScoreCallback(
        mode='multilabel',
        input_key='targets',
        output_key="logits",
        prefix='jaccard_targets',
    ),
    JaccardScoreCallback(
        mode='multilabel',
        input_key='coarse_targets',
        output_key="logits_coarse",
        prefix='jaccard_targets_coarse',
    ),
    AUCCallback(
        input_key='classification',
        output_key='class_logits',
    ),
    F1ScoreCallback(
        input_key='classification',
        output_key='class_logits',
        prefix='f1_score',
        activation='Sigmoid',
    ),
    JaccardMetricPerImage(),
    OptimalThreshold(),
]

# Early stopping goes last, after all loss and metric callbacks.
stage2_control_callbacks = [
    EarlyStoppingCallback(25, metric='loss', minimize=True),
]

runner.train(
    model=model,
    criterion=losses,
    optimizer=optimizer,
    scheduler=scheduler,
    callbacks=(
        stage2_loss_callbacks
        + stage2_metric_callbacks
        + stage2_control_callbacks
    ),
    loaders=loaders,
    logdir=log_dir,
    num_epochs=num_epochs,
    verbose=True,
)

0/100 * Epoch (train): 100% 315/315 [07:49<00:00,  1.49s/it, _timers/_fps=2410.649, f1_score=0.383, jaccard_targets=0.084, jaccard_targets_coarse=0.081, loss=1.410, loss_classification=0.825, loss_coarse=2.816]
0/100 * Epoch (valid): 100% 158/158 [06:49<00:00,  2.59s/it, _timers/_fps=2087.757, f1_score=0.588, jaccard_targets=0.026, jaccard_targets_coarse=0.019, loss=1.658, loss_classification=0.879, loss_coarse=3.584]
[2019-10-01 07:18:26,564] 
0/100 * Epoch 6 (train): _base/lr=1.000e-05 | _base/momentum=0.9000 | _timers/_fps=2197.7534 | _timers/batch_time=0.0233 | _timers/data_time=0.0081 | _timers/model_time=0.0152 | auc/_mean=0.5220 | auc/class_0=0.5220 | f1_score=0.4396 | jaccard=0.0275 | jaccard_targets=0.1416 | jaccard_targets_class_0=0.0428 | jaccard_targets_class_1=0.0538 | jaccard_targets_class_2=0.1683 | jaccard_targets_class_3=0.2275 | jaccard_targets_class_4=0.1282 | jaccard_targets_coarse=0.1179 | jaccard_targets_coarse_class_0=0.0586 | jaccard_targets_coarse_class_1=0.039

4/100 * Epoch (train): 100% 315/315 [07:46<00:00,  1.48s/it, _timers/_fps=2871.827, f1_score=0.546, jaccard_targets=0.065, jaccard_targets_coarse=0.077, loss=1.654, loss_classification=0.703, loss_coarse=2.981]
4/100 * Epoch (valid): 100% 158/158 [06:48<00:00,  2.59s/it, _timers/_fps=1986.174, f1_score=0.584, jaccard_targets=0.043, jaccard_targets_coarse=0.019, loss=2.118, loss_classification=0.888, loss_coarse=5.496]
[2019-10-01 08:17:16,769] 
4/100 * Epoch 10 (train): _base/lr=1.000e-05 | _base/momentum=0.9000 | _timers/_fps=2326.7140 | _timers/batch_time=0.0230 | _timers/data_time=0.0089 | _timers/model_time=0.0140 | auc/_mean=0.5291 | auc/class_0=0.5291 | f1_score=0.4430 | jaccard=0.0332 | jaccard_targets=0.1499 | jaccard_targets_class_0=0.0359 | jaccard_targets_class_1=0.0156 | jaccard_targets_class_2=0.1663 | jaccard_targets_class_3=0.0823 | jaccard_targets_class_4=0.1654 | jaccard_targets_coarse=0.1226 | jaccard_targets_coarse_class_0=0.0571 | jaccard_targets_coarse_class_1=0.04

8/100 * Epoch (train): 100% 315/315 [07:43<00:00,  1.47s/it, _timers/_fps=3020.065, f1_score=0.195, jaccard_targets=0.350, jaccard_targets_coarse=0.293, loss=0.863, loss_classification=0.621, loss_coarse=2.017]
8/100 * Epoch (valid): 100% 158/158 [06:49<00:00,  2.59s/it, _timers/_fps=2098.923, f1_score=0.588, jaccard_targets=0.091, jaccard_targets_coarse=0.017, loss=0.868, loss_classification=0.878, loss_coarse=2.781]
[2019-10-01 09:15:58,908] 
8/100 * Epoch 14 (train): _base/lr=1.000e-05 | _base/momentum=0.9000 | _timers/_fps=2550.4485 | _timers/batch_time=0.0209 | _timers/data_time=0.0086 | _timers/model_time=0.0122 | auc/_mean=0.5144 | auc/class_0=0.5144 | f1_score=0.4387 | jaccard=0.0795 | jaccard_targets=0.1969 | jaccard_targets_class_0=0.0199 | jaccard_targets_class_1=0.0023 | jaccard_targets_class_2=0.2405 | jaccard_targets_class_3=0.0900 | jaccard_targets_class_4=0.2054 | jaccard_targets_coarse=0.1143 | jaccard_targets_coarse_class_0=0.0599 | jaccard_targets_coarse_class_1=0.03

12/100 * Epoch (train): 100% 315/315 [07:42<00:00,  1.47s/it, _timers/_fps=3031.251, f1_score=0.560, jaccard_targets=0.147, jaccard_targets_coarse=0.069, loss=0.881, loss_classification=0.817, loss_coarse=2.294]
12/100 * Epoch (valid): 100% 158/158 [06:49<00:00,  2.59s/it, _timers/_fps=2064.825, f1_score=0.577, jaccard_targets=0.084, jaccard_targets_coarse=0.014, loss=0.811, loss_classification=0.906, loss_coarse=2.186]
[2019-10-01 10:14:19,112] 
12/100 * Epoch 18 (train): _base/lr=5.000e-06 | _base/momentum=0.9000 | _timers/_fps=2577.2133 | _timers/batch_time=0.0204 | _timers/data_time=0.0084 | _timers/model_time=0.0119 | auc/_mean=0.5223 | auc/class_0=0.5223 | f1_score=0.4412 | jaccard=0.1126 | jaccard_targets=0.2375 | jaccard_targets_class_0=0.0137 | jaccard_targets_class_1=0.0002 | jaccard_targets_class_2=0.2907 | jaccard_targets_class_3=0.0472 | jaccard_targets_class_4=0.2563 | jaccard_targets_coarse=0.1076 | jaccard_targets_coarse_class_0=0.0632 | jaccard_targets_coarse_class_1=0

16/100 * Epoch (train): 100% 315/315 [07:43<00:00,  1.47s/it, _timers/_fps=3006.602, f1_score=0.176, jaccard_targets=0.941, jaccard_targets_coarse=0.438, loss=1.287, loss_classification=0.606, loss_coarse=2.045]
16/100 * Epoch (valid): 100% 158/158 [06:48<00:00,  2.59s/it, _timers/_fps=1968.348, f1_score=0.572, jaccard_targets=0.092, jaccard_targets_coarse=0.010, loss=0.805, loss_classification=0.920, loss_coarse=1.830]
[2019-10-01 11:12:50,812] 
16/100 * Epoch 22 (train): _base/lr=5.000e-06 | _base/momentum=0.9000 | _timers/_fps=2465.8381 | _timers/batch_time=0.0212 | _timers/data_time=0.0080 | _timers/model_time=0.0131 | auc/_mean=0.5101 | auc/class_0=0.5101 | f1_score=0.4374 | jaccard=0.1355 | jaccard_targets=0.2740 | jaccard_targets_class_0=0.0115 | jaccard_targets_class_1=3.752e-05 | jaccard_targets_class_2=0.3170 | jaccard_targets_class_3=0.0907 | jaccard_targets_class_4=0.2997 | jaccard_targets_coarse=0.1079 | jaccard_targets_coarse_class_0=0.0641 | jaccard_targets_coarse_class_

20/100 * Epoch (train): 100% 315/315 [07:42<00:00,  1.47s/it, _timers/_fps=3008.961, f1_score=0.339, jaccard_targets=0.354, jaccard_targets_coarse=0.107, loss=0.881, loss_classification=0.648, loss_coarse=2.487]
20/100 * Epoch (valid): 100% 158/158 [06:48<00:00,  2.59s/it, _timers/_fps=2071.133, f1_score=0.573, jaccard_targets=0.158, jaccard_targets_coarse=0.011, loss=0.772, loss_classification=0.917, loss_coarse=2.020]
[2019-10-01 12:11:13,236] 
20/100 * Epoch 26 (train): _base/lr=5.000e-06 | _base/momentum=0.9000 | _timers/_fps=2563.0273 | _timers/batch_time=0.0212 | _timers/data_time=0.0090 | _timers/model_time=0.0121 | auc/_mean=0.5173 | auc/class_0=0.5173 | f1_score=0.4403 | jaccard=0.1517 | jaccard_targets=0.2970 | jaccard_targets_class_0=0.0109 | jaccard_targets_class_1=2.638e-05 | jaccard_targets_class_2=0.3404 | jaccard_targets_class_3=0.1310 | jaccard_targets_class_4=0.3301 | jaccard_targets_coarse=0.1035 | jaccard_targets_coarse_class_0=0.0650 | jaccard_targets_coarse_class_

24/100 * Epoch (train): 100% 315/315 [07:47<00:00,  1.48s/it, _timers/_fps=3041.762, f1_score=0.464, jaccard_targets=0.352, jaccard_targets_coarse=0.175, loss=0.801, loss_classification=0.866, loss_coarse=1.732]
24/100 * Epoch (valid): 100% 158/158 [06:48<00:00,  2.59s/it, _timers/_fps=2100.434, f1_score=0.582, jaccard_targets=0.262, jaccard_targets_coarse=0.011, loss=0.745, loss_classification=0.892, loss_coarse=1.863]
[2019-10-01 13:09:41,244] 
24/100 * Epoch 30 (train): _base/lr=5.000e-06 | _base/momentum=0.9000 | _timers/_fps=2499.6940 | _timers/batch_time=0.0217 | _timers/data_time=0.0089 | _timers/model_time=0.0127 | auc/_mean=0.5138 | auc/class_0=0.5138 | f1_score=0.4394 | jaccard=0.1700 | jaccard_targets=0.3285 | jaccard_targets_class_0=0.0105 | jaccard_targets_class_1=4.866e-05 | jaccard_targets_class_2=0.3796 | jaccard_targets_class_3=0.1831 | jaccard_targets_class_4=0.3579 | jaccard_targets_coarse=0.1000 | jaccard_targets_coarse_class_0=0.0621 | jaccard_targets_coarse_class_

28/100 * Epoch (train): 100% 315/315 [07:40<00:00,  1.46s/it, _timers/_fps=2947.184, f1_score=0.246, jaccard_targets=0.351, jaccard_targets_coarse=0.178, loss=1.219, loss_classification=0.859, loss_coarse=2.301]
28/100 * Epoch (valid): 100% 158/158 [06:48<00:00,  2.58s/it, _timers/_fps=2123.094, f1_score=0.573, jaccard_targets=0.177, jaccard_targets_coarse=0.010, loss=0.757, loss_classification=0.916, loss_coarse=1.894]
[2019-10-01 14:08:05,971] 
28/100 * Epoch 34 (train): _base/lr=5.000e-06 | _base/momentum=0.9000 | _timers/_fps=2526.9179 | _timers/batch_time=0.0209 | _timers/data_time=0.0081 | _timers/model_time=0.0127 | auc/_mean=0.5226 | auc/class_0=0.5226 | f1_score=0.4395 | jaccard=0.1827 | jaccard_targets=0.3501 | jaccard_targets_class_0=0.0068 | jaccard_targets_class_1=1.273e-05 | jaccard_targets_class_2=0.3997 | jaccard_targets_class_3=0.2564 | jaccard_targets_class_4=0.3806 | jaccard_targets_coarse=0.0948 | jaccard_targets_coarse_class_0=0.0605 | jaccard_targets_coarse_class_

32/100 * Epoch (train): 100% 315/315 [07:42<00:00,  1.47s/it, _timers/_fps=3010.176, f1_score=0.242, jaccard_targets=0.466, jaccard_targets_coarse=0.079, loss=0.877, loss_classification=0.572, loss_coarse=3.325]
32/100 * Epoch (valid): 100% 158/158 [06:48<00:00,  2.59s/it, _timers/_fps=2111.470, f1_score=0.574, jaccard_targets=0.276, jaccard_targets_coarse=0.009, loss=0.726, loss_classification=0.914, loss_coarse=1.945]
[2019-10-01 15:06:23,818] 
32/100 * Epoch 38 (train): _base/lr=2.500e-06 | _base/momentum=0.9000 | _timers/_fps=2516.2012 | _timers/batch_time=0.0218 | _timers/data_time=0.0090 | _timers/model_time=0.0127 | auc/_mean=0.5142 | auc/class_0=0.5142 | f1_score=0.4391 | jaccard=0.1887 | jaccard_targets=0.3615 | jaccard_targets_class_0=0.0081 | jaccard_targets_class_1=6.149e-06 | jaccard_targets_class_2=0.4120 | jaccard_targets_class_3=0.2719 | jaccard_targets_class_4=0.3902 | jaccard_targets_coarse=0.0919 | jaccard_targets_coarse_class_0=0.0571 | jaccard_targets_coarse_class_

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



39/100 * Epoch (train): 100% 315/315 [07:42<00:00,  1.47s/it, _timers/_fps=3051.651, f1_score=0.503, jaccard_targets=0.245, jaccard_targets_coarse=0.107, loss=0.776, loss_classification=0.600, loss_coarse=2.066]
39/100 * Epoch (valid): 100% 158/158 [06:48<00:00,  2.59s/it, _timers/_fps=1829.726, f1_score=0.581, jaccard_targets=0.262, jaccard_targets_coarse=0.009, loss=0.722, loss_classification=0.897, loss_coarse=2.005]
[2019-10-01 16:48:16,034] 
39/100 * Epoch 45 (train): _base/lr=2.500e-06 | _base/momentum=0.9000 | _timers/_fps=2599.3547 | _timers/batch_time=0.0202 | _timers/data_time=0.0082 | _timers/model_time=0.0119 | auc/_mean=0.5159 | auc/class_0=0.5159 | f1_score=0.4397 | jaccard=0.1988 | jaccard_targets=0.3790 | jaccard_targets_class_0=0.0077 | jaccard_targets_class_1=3.876e-05 | jaccard_targets_class_2=0.4312 | jaccard_targets_class_3=0.3332 | jaccard_targets_class_4=0.4066 | jaccard_targets_coarse=0.0920 | jaccard_targets_coarse_class_0=0.0593 | jaccard_targets_coarse_class_

Load the best checkpoint

In [15]:
# Restore the weights stored under 'model_state_dict' in the best checkpoint
# of this experiment, then move the model to the GPU and switch it to
# evaluation mode. The final expression is `model.eval()`, so the cell shows
# the model's repr as its output.
best_checkpoint_path = os.path.join(log_dir, 'checkpoints/best.pth')
model.load_state_dict(torch.load(best_checkpoint_path)['model_state_dict'])
model.cuda()
model.eval()

Unet(
  (encoder): ResNetEncoder(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Sequential(
      

In [16]:
# Sliding-window inference settings used by the validation prediction cell.
# target_shape: shape of the image the tiler is built for.
target_shape = (256, 1600)
# tile_step: step between consecutive tiles, applied along both axes.
tile_step = 56
# use_tiles_predictions: True -> predict tile by tile and merge the tiles;
# False -> run the model once on the whole image.
use_tiles_predictions = True

In [None]:
# Predict every validation image and collect the per-pixel sigmoid
# probabilities of the model's 'logits' output.
#
# Results (consumed by the following cells):
#   images_id   - image file names, in prediction order
#   predictions - one channels-last array per image, aligned with `images_id`
#
# NOTE(review): each stored map is presumably 256x1600x5 float32 (~8 MB), so a
# validation set of ~2.5k images needs on the order of 20 GB of RAM — consider
# storing thresholded / lower-precision masks if memory is tight.
# NOTE(review): cv2.imread returns channels in BGR order; confirm this matches
# what the training dataset feeds the model.
train_df, val_df = return_masks(train_df_path)
images_id = []
predictions = []

# Build the preprocessing pipeline once instead of once per image.
val_transforms = validation_augmentations()

if use_tiles_predictions:
    # Splits an image of `target_shape` into crop_size x crop_size tiles taken
    # every `tile_step` pixels.
    tiler = ImageSlicer(target_shape, tile_size=(crop_size, crop_size),
                        tile_step=(tile_step, tile_step), weight='mean')

for image_name in tqdm(val_df.index.values):
    image_path = os.path.join(data_folder, image_name)
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread signals a missing/unreadable file by returning None;
        # fail here with the offending path instead of deep inside the transforms.
        raise FileNotFoundError('Could not read image: {}'.format(image_path))
    image_processed = val_transforms(image=image)['image']

    if use_tiles_predictions:
        # Predict tile by tile and let the merger blend the tiles back together.
        merger = CudaTileMerger(tiler.target_shape, output_channels, tiler.weight)
        tiles = [tensor_from_rgb_image(tile) for tile in tiler.split(image_processed)]
        tile_loader = DataLoader(list(zip(tiles, tiler.crops)),
                                 batch_size=batch_size,
                                 pin_memory=True)
        for tiles_batch, coords_batch in tile_loader:
            tiles_batch = tiles_batch.float().cuda()
            with torch.no_grad():
                pred_batch = torch.sigmoid(model(tiles_batch)['logits'])
            merger.integrate_batch(pred_batch, coords_batch)
        # channels-first tensor -> channels-last numpy array
        merged_mask = np.moveaxis(merger.merge().cpu().numpy(), 0, -1)
    else:
        # Whole-image prediction: HWC -> 1xCxHxW, keep only the 'logits' head.
        model_input = np.expand_dims(image_processed.transpose((2, 0, 1)), 0)
        with torch.no_grad():
            merged_mask = torch.sigmoid(
                model(torch.from_numpy(model_input).cuda())['logits'])
        merged_mask = np.moveaxis(merged_mask[0].cpu().numpy(), 0, -1)

    predictions.append(merged_mask)
    images_id.append(image_name)


  0%|          | 0/2514 [00:00<?, ?it/s][A
  0%|          | 1/2514 [00:00<06:28,  6.47it/s][A
  0%|          | 2/2514 [00:00<06:24,  6.54it/s][A
  0%|          | 3/2514 [00:00<06:27,  6.49it/s][A
  0%|          | 4/2514 [00:00<06:24,  6.53it/s][A
  0%|          | 5/2514 [00:00<06:22,  6.57it/s][A
  0%|          | 6/2514 [00:00<06:20,  6.60it/s][A
  0%|          | 7/2514 [00:01<06:18,  6.62it/s][A
  0%|          | 8/2514 [00:01<06:28,  6.44it/s][A
  0%|          | 9/2514 [00:01<06:27,  6.46it/s][A
  0%|          | 10/2514 [00:01<06:24,  6.50it/s][A
  0%|          | 11/2514 [00:01<06:22,  6.55it/s][A
  0%|          | 12/2514 [00:01<06:20,  6.58it/s][A
  1%|          | 13/2514 [00:01<06:18,  6.60it/s][A
  1%|          | 14/2514 [00:02<06:19,  6.59it/s][A
  1%|          | 15/2514 [00:02<06:18,  6.60it/s][A
  1%|          | 16/2514 [00:02<06:18,  6.60it/s][A
  1%|          | 17/2514 [00:02<06:18,  6.59it/s][A
  1%|          | 18/2514 [00:02<06:17,  6.61it/s][A
  1%|     

  6%|▌         | 153/2514 [00:23<06:04,  6.48it/s][A
  6%|▌         | 154/2514 [00:23<06:01,  6.52it/s][A
  6%|▌         | 155/2514 [00:23<06:01,  6.53it/s][A
  6%|▌         | 156/2514 [00:23<05:58,  6.57it/s][A
  6%|▌         | 157/2514 [00:23<05:57,  6.60it/s][A
  6%|▋         | 158/2514 [00:23<05:57,  6.60it/s][A
  6%|▋         | 159/2514 [00:23<05:56,  6.61it/s][A
  6%|▋         | 160/2514 [00:24<05:56,  6.61it/s][A
  6%|▋         | 161/2514 [00:24<05:54,  6.63it/s][A
  6%|▋         | 162/2514 [00:24<05:55,  6.61it/s][A
  6%|▋         | 163/2514 [00:24<05:55,  6.62it/s][A
  7%|▋         | 164/2514 [00:24<05:55,  6.60it/s][A
  7%|▋         | 165/2514 [00:24<05:55,  6.60it/s][A
  7%|▋         | 166/2514 [00:25<05:56,  6.59it/s][A
  7%|▋         | 167/2514 [00:25<05:54,  6.62it/s][A
  7%|▋         | 168/2514 [00:25<05:53,  6.63it/s][A
  7%|▋         | 169/2514 [00:25<05:53,  6.64it/s][A
  7%|▋         | 170/2514 [00:25<05:54,  6.61it/s][A
  7%|▋         | 171/2514 [0

 12%|█▏        | 304/2514 [00:45<05:32,  6.64it/s][A
 12%|█▏        | 305/2514 [00:45<05:32,  6.65it/s][A
 12%|█▏        | 306/2514 [00:46<05:31,  6.66it/s][A
 12%|█▏        | 307/2514 [00:46<05:31,  6.65it/s][A
 12%|█▏        | 308/2514 [00:46<05:33,  6.61it/s][A
 12%|█▏        | 309/2514 [00:46<05:32,  6.63it/s][A
 12%|█▏        | 310/2514 [00:46<05:32,  6.64it/s][A
 12%|█▏        | 311/2514 [00:46<05:31,  6.64it/s][A
 12%|█▏        | 312/2514 [00:47<05:31,  6.64it/s][A
 12%|█▏        | 313/2514 [00:47<05:31,  6.65it/s][A
 12%|█▏        | 314/2514 [00:47<05:30,  6.65it/s][A
 13%|█▎        | 315/2514 [00:47<05:30,  6.66it/s][A
 13%|█▎        | 316/2514 [00:47<05:30,  6.65it/s][A
 13%|█▎        | 317/2514 [00:47<05:30,  6.65it/s][A
 13%|█▎        | 318/2514 [00:47<05:30,  6.63it/s][A
 13%|█▎        | 319/2514 [00:48<05:30,  6.64it/s][A
 13%|█▎        | 320/2514 [00:48<05:31,  6.62it/s][A
 13%|█▎        | 321/2514 [00:48<05:32,  6.60it/s][A
 13%|█▎        | 322/2514 [0

 18%|█▊        | 455/2514 [01:08<05:10,  6.63it/s][A
 18%|█▊        | 456/2514 [01:08<05:09,  6.65it/s][A
 18%|█▊        | 457/2514 [01:08<05:09,  6.65it/s][A
 18%|█▊        | 458/2514 [01:09<05:08,  6.66it/s][A
 18%|█▊        | 459/2514 [01:09<05:10,  6.62it/s][A
 18%|█▊        | 460/2514 [01:09<05:13,  6.55it/s][A
 18%|█▊        | 461/2514 [01:09<05:12,  6.57it/s][A
 18%|█▊        | 462/2514 [01:09<05:11,  6.59it/s][A
 18%|█▊        | 463/2514 [01:09<05:13,  6.54it/s][A
 18%|█▊        | 464/2514 [01:09<05:13,  6.55it/s][A
 18%|█▊        | 465/2514 [01:10<05:11,  6.58it/s][A
 19%|█▊        | 466/2514 [01:10<05:11,  6.57it/s][A
 19%|█▊        | 467/2514 [01:10<05:10,  6.60it/s][A
 19%|█▊        | 468/2514 [01:10<05:09,  6.61it/s][A
 19%|█▊        | 469/2514 [01:10<05:08,  6.63it/s][A
 19%|█▊        | 470/2514 [01:10<05:07,  6.64it/s][A
 19%|█▊        | 471/2514 [01:11<05:07,  6.64it/s][A
 19%|█▉        | 472/2514 [01:11<05:06,  6.65it/s][A
 19%|█▉        | 473/2514 [0

 24%|██▍       | 606/2514 [01:31<04:46,  6.66it/s][A
 24%|██▍       | 607/2514 [01:31<04:47,  6.64it/s][A
 24%|██▍       | 608/2514 [01:31<04:47,  6.64it/s][A
 24%|██▍       | 609/2514 [01:31<04:47,  6.63it/s][A
 24%|██▍       | 610/2514 [01:31<04:47,  6.63it/s][A
 24%|██▍       | 611/2514 [01:32<04:47,  6.62it/s][A
 24%|██▍       | 612/2514 [01:32<04:46,  6.64it/s][A
 24%|██▍       | 613/2514 [01:32<04:46,  6.65it/s][A
 24%|██▍       | 614/2514 [01:32<04:46,  6.63it/s][A
 24%|██▍       | 615/2514 [01:32<04:46,  6.64it/s][A
 25%|██▍       | 616/2514 [01:32<04:45,  6.65it/s][A
 25%|██▍       | 617/2514 [01:33<04:44,  6.67it/s][A
 25%|██▍       | 618/2514 [01:33<04:43,  6.68it/s][A
 25%|██▍       | 619/2514 [01:33<04:43,  6.68it/s][A
 25%|██▍       | 620/2514 [01:33<04:45,  6.64it/s][A
 25%|██▍       | 621/2514 [01:33<04:45,  6.63it/s][A
 25%|██▍       | 622/2514 [01:33<04:44,  6.65it/s][A
 25%|██▍       | 623/2514 [01:33<04:44,  6.65it/s][A
 25%|██▍       | 624/2514 [0

 30%|███       | 757/2514 [01:54<04:27,  6.56it/s][A
 30%|███       | 758/2514 [01:54<04:25,  6.62it/s][A
 30%|███       | 759/2514 [01:54<04:24,  6.64it/s][A
 30%|███       | 760/2514 [01:54<04:23,  6.66it/s][A
 30%|███       | 761/2514 [01:54<04:22,  6.67it/s][A
 30%|███       | 762/2514 [01:54<04:22,  6.67it/s][A
 30%|███       | 763/2514 [01:54<04:22,  6.66it/s][A
 30%|███       | 764/2514 [01:55<04:22,  6.68it/s][A
 30%|███       | 765/2514 [01:55<04:22,  6.65it/s][A
 30%|███       | 766/2514 [01:55<04:22,  6.66it/s][A
 31%|███       | 767/2514 [01:55<04:23,  6.63it/s][A
 31%|███       | 768/2514 [01:55<04:21,  6.67it/s][A
 31%|███       | 769/2514 [01:55<04:21,  6.67it/s][A
 31%|███       | 770/2514 [01:56<04:21,  6.68it/s][A
 31%|███       | 771/2514 [01:56<04:21,  6.67it/s][A
 31%|███       | 772/2514 [01:56<04:20,  6.68it/s][A
 31%|███       | 773/2514 [01:56<04:21,  6.66it/s][A
 31%|███       | 774/2514 [01:56<04:21,  6.64it/s][A
 31%|███       | 775/2514 [0

 36%|███▌      | 908/2514 [02:16<04:00,  6.67it/s][A
 36%|███▌      | 909/2514 [02:16<04:00,  6.67it/s][A
 36%|███▌      | 910/2514 [02:16<04:00,  6.67it/s][A
 36%|███▌      | 911/2514 [02:17<03:59,  6.69it/s][A
 36%|███▋      | 912/2514 [02:17<03:59,  6.69it/s][A
 36%|███▋      | 913/2514 [02:17<04:00,  6.67it/s][A
 36%|███▋      | 914/2514 [02:17<03:59,  6.67it/s][A
 36%|███▋      | 915/2514 [02:17<03:59,  6.69it/s][A
 36%|███▋      | 916/2514 [02:17<03:58,  6.71it/s][A
 36%|███▋      | 917/2514 [02:18<03:58,  6.71it/s][A
 37%|███▋      | 918/2514 [02:18<03:58,  6.70it/s][A
 37%|███▋      | 919/2514 [02:18<03:58,  6.68it/s][A
 37%|███▋      | 920/2514 [02:18<03:57,  6.70it/s][A
 37%|███▋      | 921/2514 [02:18<03:57,  6.72it/s][A
 37%|███▋      | 922/2514 [02:18<03:58,  6.69it/s][A
 37%|███▋      | 923/2514 [02:18<03:59,  6.65it/s][A
 37%|███▋      | 924/2514 [02:19<03:58,  6.65it/s][A
 37%|███▋      | 925/2514 [02:19<03:58,  6.66it/s][A
 37%|███▋      | 926/2514 [0

 42%|████▏     | 1058/2514 [02:39<03:38,  6.67it/s][A
 42%|████▏     | 1059/2514 [02:39<03:37,  6.68it/s][A
 42%|████▏     | 1060/2514 [02:39<03:36,  6.71it/s][A
 42%|████▏     | 1061/2514 [02:39<03:37,  6.70it/s][A
 42%|████▏     | 1062/2514 [02:39<03:36,  6.70it/s][A
 42%|████▏     | 1063/2514 [02:39<03:36,  6.71it/s][A
 42%|████▏     | 1064/2514 [02:40<03:36,  6.71it/s][A
 42%|████▏     | 1065/2514 [02:40<03:36,  6.70it/s][A
 42%|████▏     | 1066/2514 [02:40<03:35,  6.71it/s][A
 42%|████▏     | 1067/2514 [02:40<03:35,  6.70it/s][A
 42%|████▏     | 1068/2514 [02:40<03:35,  6.70it/s][A
 43%|████▎     | 1069/2514 [02:40<03:35,  6.69it/s][A
 43%|████▎     | 1070/2514 [02:40<03:35,  6.71it/s][A
 43%|████▎     | 1071/2514 [02:41<03:35,  6.70it/s][A
 43%|████▎     | 1072/2514 [02:41<03:34,  6.72it/s][A
 43%|████▎     | 1073/2514 [02:41<03:34,  6.71it/s][A
 43%|████▎     | 1074/2514 [02:41<03:35,  6.69it/s][A
 43%|████▎     | 1075/2514 [02:41<03:35,  6.69it/s][A
 43%|████▎

 48%|████▊     | 1206/2514 [03:01<03:15,  6.70it/s][A
 48%|████▊     | 1207/2514 [03:01<03:15,  6.68it/s][A
 48%|████▊     | 1208/2514 [03:01<03:16,  6.65it/s][A
 48%|████▊     | 1209/2514 [03:01<03:15,  6.67it/s][A
 48%|████▊     | 1210/2514 [03:01<03:16,  6.65it/s][A
 48%|████▊     | 1211/2514 [03:02<03:15,  6.67it/s][A
 48%|████▊     | 1212/2514 [03:02<03:15,  6.66it/s][A
 48%|████▊     | 1213/2514 [03:02<03:15,  6.65it/s][A
 48%|████▊     | 1214/2514 [03:02<03:14,  6.68it/s][A
 48%|████▊     | 1215/2514 [03:02<03:14,  6.67it/s][A
 48%|████▊     | 1216/2514 [03:02<03:14,  6.67it/s][A
 48%|████▊     | 1217/2514 [03:03<03:15,  6.63it/s][A
 48%|████▊     | 1218/2514 [03:03<03:14,  6.66it/s][A
 48%|████▊     | 1219/2514 [03:03<03:15,  6.63it/s][A
 49%|████▊     | 1220/2514 [03:03<03:14,  6.66it/s][A
 49%|████▊     | 1221/2514 [03:03<03:13,  6.67it/s][A
 49%|████▊     | 1222/2514 [03:03<03:13,  6.69it/s][A
 49%|████▊     | 1223/2514 [03:03<03:12,  6.69it/s][A
 49%|████▊

 54%|█████▍    | 1354/2514 [03:23<02:53,  6.69it/s][A
 54%|█████▍    | 1355/2514 [03:23<02:53,  6.70it/s][A
 54%|█████▍    | 1356/2514 [03:23<02:53,  6.69it/s][A
 54%|█████▍    | 1357/2514 [03:24<02:52,  6.70it/s][A
 54%|█████▍    | 1358/2514 [03:24<02:53,  6.67it/s][A
 54%|█████▍    | 1359/2514 [03:24<02:53,  6.64it/s][A
 54%|█████▍    | 1360/2514 [03:24<02:53,  6.66it/s][A
 54%|█████▍    | 1361/2514 [03:24<02:53,  6.65it/s][A
 54%|█████▍    | 1362/2514 [03:24<02:53,  6.66it/s][A
 54%|█████▍    | 1363/2514 [03:24<02:52,  6.68it/s][A
 54%|█████▍    | 1364/2514 [03:25<02:51,  6.69it/s][A
 54%|█████▍    | 1365/2514 [03:25<02:51,  6.70it/s][A
 54%|█████▍    | 1366/2514 [03:25<02:51,  6.69it/s][A
 54%|█████▍    | 1367/2514 [03:25<02:51,  6.69it/s][A
 54%|█████▍    | 1368/2514 [03:25<02:51,  6.69it/s][A
 54%|█████▍    | 1369/2514 [03:25<02:51,  6.70it/s][A
 54%|█████▍    | 1370/2514 [03:25<02:51,  6.67it/s][A
 55%|█████▍    | 1371/2514 [03:26<02:50,  6.69it/s][A
 55%|█████

In [None]:
# Look-up table: image file name -> predicted probability map.
dict_of_predictions = {
    image_key: predicted_map
    for image_key, predicted_map in zip(images_id, predictions)
}

Prepare GT masks

In [None]:
# Decode the run-length-encoded labels of every validation image into a dense
# (256, 1600, 4) float32 mask, one channel per defect column.
# `gt_masks` follows the row order of `val_df`.
#
# NOTE(review): runs are written to mask[pos:pos + length] with `pos` taken
# verbatim from the label. If the encoding is 1-based (as is common for RLE
# submissions), the mask is shifted by one pixel — confirm against the data
# description and the mask decoding used by the training dataset.
mask_height, mask_width, num_defect_types = 256, 1600, 4

gt_masks = []
for image_name in tqdm(val_df.index.values):
    # The first `num_defect_types` columns of the row hold the encoded labels.
    labels = val_df.loc[image_name, :].iloc[:num_defect_types]
    masks = np.zeros((mask_height, mask_width, num_defect_types),
                     dtype=np.float32)  # float32 is V.Imp
    for idx, label in enumerate(labels.values):
        # Missing labels are NaN floats. Testing for a non-empty string is
        # reliable, whereas an identity comparison with np.nan only works when
        # the value happens to be that exact object.
        if not isinstance(label, str) or not label.strip():
            continue
        tokens = label.split()
        starts = map(int, tokens[0::2])
        run_lengths = map(int, tokens[1::2])
        mask = np.zeros(mask_height * mask_width, dtype=np.uint8)
        for pos, le in zip(starts, run_lengths):
            mask[pos:(pos + le)] = 1
        # The flat mask is laid out column by column, hence Fortran order.
        masks[:, :, idx] = mask.reshape(mask_height, mask_width, order='F')
    gt_masks.append(masks)

In [None]:
# Look-up table: image file name -> ground-truth mask stack.
dict_of_gt_masks = {
    image_key: gt_stack
    for image_key, gt_stack in zip(images_id, gt_masks)
}

Calculate overall dice score + per-image dice score

In [None]:
# Dice score for every (image, defect type) pair of the validation set.
# A prediction is binarised at `thr_prediction`; if the resulting mask has
# fewer positive pixels than the defect type's entry in `min_area`, it is
# replaced by an empty mask before scoring.
min_area = [600, 600, 1000, 2000]
thr_prediction = 0.5

dice_preds = []          # one score per (image, defect type)
images_per_defect = []   # image name for each entry of `dice_preds`

for idx in tqdm(range(len(predictions))):
    for defect_type, smallest_allowed_area in enumerate(min_area):
        probability_map = predictions[idx][..., defect_type]
        mask_pred = (probability_map > thr_prediction).astype(int)
        mask_gt = gt_masks[idx][..., defect_type]
        too_small = mask_pred.sum() < smallest_allowed_area
        if too_small:
            mask_pred = np.zeros(mask_pred.shape)
        dice_preds.append(dice(mask_gt, mask_pred, empty_score=1.0))
        images_per_defect.append(images_id[idx])

print('DICE validation {}'.format(np.mean(dice_preds)))

In [None]:
# Per-image summary: minimum and mean dice over the defect types, with the
# lowest mean dice first.
dice_per_image = (
    pd.DataFrame({'image_id': images_per_defect, 'dice_per_defect': dice_preds})
    .groupby(['image_id'])
    .agg({'dice_per_defect': ['min', 'mean']})
)
# Flatten the two-level column index into names such as 'dice_per_defect_mean'.
dice_per_image.columns = [
    '_'.join(column_levels).strip()
    for column_levels in dice_per_image.columns.values
]
dice_per_image = dice_per_image.sort_values(by=['dice_per_defect_mean'])
dice_per_image.head()

Visualize bad examples

In [None]:
# Inspect one validation image: `idx` is a row position in `dice_per_image`,
# which is sorted by ascending mean dice (low positions = worst images).
# One figure row per defect type; columns show the ground truth, the raw
# prediction and the post-processed binary mask, each overlaid on the image.
idx = 91
image_name = dice_per_image.index.values[idx]
image = cv2.imread(os.path.join(data_folder, image_name))

# These do not depend on the defect type, so look them up once.
gt_mask = dict_of_gt_masks[image_name]
prediction = dict_of_predictions[image_name]

f, ax = plt.subplots(4, 3, figsize=(15, 5))
column_titles = ('Ground truth', 'Predicted probability', 'Post-processed mask')
for column, title in enumerate(column_titles):
    ax[0, column].set_title(title)

for defect_type in range(4):
    # Column 0: ground-truth mask.
    ax[defect_type, 0].imshow(image)
    ax[defect_type, 0].imshow(gt_mask[..., defect_type], alpha=0.4)

    # Column 1: raw prediction on a fixed 0..1 colour scale.
    ax[defect_type, 1].imshow(image)
    ax[defect_type, 1].imshow(prediction[..., defect_type], vmin=0, vmax=1.0, alpha=0.4)

    # Column 2: same post-processing as the dice cell — binarise at
    # `thr_prediction` (not a separate hard-coded value, so the picture cannot
    # drift from the reported score) and drop masks below the minimum area.
    mask_pred = (prediction[..., defect_type] > thr_prediction).astype(int)
    print('Mask size for defect class {} : {}'.format(defect_type+1, mask_pred.sum()))
    if mask_pred.sum() < min_area[defect_type]:
        mask_pred = np.zeros(mask_pred.shape)
    ax[defect_type, 2].imshow(image)
    ax[defect_type, 2].imshow(mask_pred, alpha=0.4)

    dice_gt_pr = dice(gt_mask[..., defect_type],
                      mask_pred,
                      empty_score=1.0)
    print('Defect class {}, dice {}'.format(defect_type+1, dice_gt_pr))