In [None]:
%load_ext autoreload
%autoreload
from factory import *
import torch
import os 
import sys
from catalyst.dl.callbacks import CriterionCallback, EarlyStoppingCallback, OptimizerCallback, CriterionAggregatorCallback, F1ScoreCallback, AUCCallback
from catalyst.dl.runner import SupervisedRunner
from pytorch_toolbelt import losses as L
from pytorch_toolbelt.inference import tta
import collections
from pytorch_toolbelt.utils.catalyst import * 
from metrics import *
import matplotlib.pyplot as plt
from viz_utils import *
from tqdm import tqdm
import cv2
#import segmentation_models_pytorch as smp
from catalyst.contrib.optimizers import RAdam, Lookahead
%matplotlib inline

In [None]:
#!pip install --force-reinstall catalyst

In [None]:
# --- Experiment naming -----------------------------------------------------
encoder_name = 'resnet34'  # alternatives: densenet169, resnet50, se_resnet50
base_exp_name = f'{encoder_name}_multihead'
log_dir = f'logs/{base_exp_name}/'

# --- Data locations --------------------------------------------------------
sample_submission_path = 'data/sample_submission.csv'
train_df_path = 'data/train.csv'
data_folder = "data/train_images/"
test_data_folder = "data/test_images/"

# --- Loader / optimisation settings ----------------------------------------
batch_size, batch_size_val = 16, 8
accumulation_steps = 5
num_workers = 24

# --- Training schedule (epochs) ---------------------------------------------
num_epochs_with_frozen_encoder = 5
num_epochs = 60

# --- Model heads / inference -----------------------------------------------
tta_type = None
output_channels = 4        # segmentation output channels
output_channels_class = 4  # classification head outputs

In [None]:
import shutil

# Start the experiment from an empty log directory. The path comes from
# `log_dir` (config cell) so it stays correct when `encoder_name` changes,
# and a missing directory (e.g. first run) is not treated as an error.
if os.path.isdir(log_dir):
    shutil.rmtree(log_dir)

In [None]:
def compute_boundary_mask(mask: np.ndarray, kernel_size: int = 9) -> np.ndarray:
    """Return a band of pixels surrounding, but excluding, the foreground of `mask`.

    The steps are:
      1. Dilate the foreground with a ``kernel_size x kernel_size`` square.
      2. Fill holes and remove the original foreground, which leaves a ring.
      3. Dilate the ring once more with the same square.
      4. Remove the foreground again.

    Args:
        mask: 2-D array; any non-zero value is treated as foreground.
        kernel_size: side length of the square structuring element (default 9).

    Returns:
        uint8 array with the same shape as `mask`: 1 on the boundary band, 0 elsewhere.
    """
    # Work in bool so that `~` is a logical NOT whatever the input dtype
    # (on uint8 it would be a bitwise NOT: ~1 == 254).
    foreground = np.asarray(mask).astype(bool)
    square = np.ones((kernel_size, kernel_size), dtype=bool)

    dilated = binary_fill_holes(binary_dilation(foreground, structure=square))
    ring = dilated & ~foreground

    # Second dilation with the same square window. Note: cv2.dilate expects the
    # structuring-element array itself; passing a shape tuple such as (9, 9)
    # does not give a 9x9 window.
    ring = binary_dilation(ring, structure=square)
    ring &= ~foreground
    return ring.astype(np.uint8)



def make_mask(row_id, df, height=256, width=1600, num_classes=4):
    '''Given a row index, return image_id and mask (height, width, num_classes) from `df`.

    Expected layout of `df`:
      * indexed by image id;
      * the first `num_classes` columns hold, per class, a run-length
        encoded mask ("start length start length ...");
      * a non-string missing value (NaN) marks a class that is absent.

    Returns:
        (fname, masks): the row's index label, and a float32 array of shape
        (height, width, num_classes) with 1.0 on defect pixels (channel c
        corresponds to class c + 1).
    '''
    row = df.iloc[row_id]
    fname = row.name
    # Positional selection: the row's labels mix ints (class ids) and strings.
    labels = row.iloc[:num_classes]
    masks = np.zeros((height, width, num_classes), dtype=np.float32)  # float32 is V.Imp

    for channel, label in enumerate(labels.values):
        # Absent classes are NaN after the pivot; test by type, not by
        # identity with np.nan (which pandas does not guarantee).
        if not isinstance(label, str):
            continue
        tokens = label.split()
        starts = map(int, tokens[0::2])
        lengths = map(int, tokens[1::2])
        flat = np.zeros(height * width, dtype=np.uint8)
        for start, length in zip(starts, lengths):
            # RLE start positions are 1-based (pixel 1 is the top-left pixel).
            flat[start - 1:start - 1 + length] = 1
        # Pixels are numbered top-to-bottom, then left-to-right: Fortran order.
        masks[:, :, channel] = flat.reshape(height, width, order='F')
    return fname, masks


class SteelDatasetMultiV3(Dataset):
    """Steel-defect dataset producing targets for a multi-head model.

    Each item is a dict with:
      * ``'features'``: the (transformed) image, converted with ``tensor_from_rgb_image``.
      * ``'targets'``: float mask tensor with one channel per class column.
        An extra "any class" channel is appended when ``prepare_full``, and a
        boundary channel when ``prepare_edges``.
      * ``'image_id'``: the image file name (index label of ``df``).
      * ``'coarse_targets'`` (only if ``prepare_coarse``): the mask downscaled
        4x in height and width.
      * ``'classification_labels'`` (only if ``prepare_class``): a float32 0/1
        vector with one entry per mask channel, set to 1 when that channel has
        any positive pixel.

    Args:
        df: frame indexed by image id whose rows are decoded by ``make_mask``.
        data_folder: directory containing the image files.
        transforms: callable taking ``image=`` / ``mask=`` keywords and
            returning a dict with ``'image'`` and ``'mask'``, or None.
        phase: stored on the instance; not read by this class.
    """

    def __init__(self, df, data_folder, transforms, phase, prepare_coarse = False, prepare_edges = False, prepare_class = False, prepare_full = False):
        self.df = df
        self.root = data_folder
        self.phase = phase
        self.transforms = transforms
        self.fnames = self.df.index.tolist()
        self.prepare_coarse = prepare_coarse
        self.prepare_edges = prepare_edges
        self.prepare_class = prepare_class
        self.prepare_full = prepare_full

    @staticmethod
    def _downscale(array):
        """Bilinearly resize `array` to a quarter of its height and width."""
        return cv2.resize(array,
                          dsize=(array.shape[1] // 4, array.shape[0] // 4),
                          interpolation=cv2.INTER_LINEAR)

    def __getitem__(self, idx):
        image_id, mask = make_mask(idx, self.df)
        image_path = os.path.join(self.root, image_id)
        img = cv2.imread(image_path)
        if img is None:
            # cv2.imread returns None (no exception) for missing/unreadable files.
            raise FileNotFoundError(f"Could not read image '{image_path}'")
        if self.transforms is not None:
            augmented = self.transforms(image=img, mask=mask)
            img = augmented['image']
            mask = augmented['mask']
        # Cast unconditionally so the mask dtype does not depend on whether
        # transforms were supplied.
        mask = mask.astype(np.uint8)

        if self.prepare_full:
            any_defect = (mask.sum(axis=2) > 0).astype(np.uint8)
            mask = np.concatenate([mask, any_defect[:, :, None]], axis=2)

        # Downscaled before the edge channel is appended; the coarse edge
        # channel is added separately below.
        coarse_mask = self._downscale(mask) if self.prepare_coarse else None

        if self.prepare_edges:
            any_defect = (mask.sum(axis=2) > 0).astype(np.uint8)
            edges = compute_boundary_mask(any_defect).astype(np.uint8)[:, :, None]
            mask = np.concatenate([mask, edges], axis=2)
            if self.prepare_coarse:
                # Resizing a single-channel image yields a 2-D array; restore the channel axis.
                coarse_edges = self._downscale(edges)[:, :, None]
                coarse_mask = np.concatenate([coarse_mask, coarse_edges], axis=2)

        data = {'features': tensor_from_rgb_image(img),
                'targets': tensor_from_mask_image(mask).float(),
                'image_id': image_id}
        if self.prepare_coarse:
            data['coarse_targets'] = tensor_from_mask_image(coarse_mask).float()
        if self.prepare_class:
            # One presence flag per mask channel (including any channels
            # appended above). float32 matches the other `.float()` targets.
            data['classification_labels'] = mask.any(axis=(0, 1)).astype(np.float32)
        return data

    def __len__(self):
        return len(self.fnames)
    
def provider(
    data_folder,
    df_path,
    phase,
    transforms,
    batch_size=8,
    num_workers=4,
    prepare_coarse = False,
    prepare_edges = False,
    prepare_class = False,
    prepare_full = False,
    test_size=0.2,
    random_state=69,
):
    '''Returns dataloader for the model training.

    Steps:
      1. Read the long-format CSV at `df_path`. It has an 'ImageId_ClassId'
         column of the form '<image>_<class>' and an 'EncodedPixels' column.
      2. Pivot it to one row per image with one column per class.
      3. Add a 'defects' column: the number of classes with a non-missing RLE.
      4. Make a train/validation split stratified by 'defects'.
      5. Wrap the split selected by `phase` in SteelDatasetMultiV3.

    Args:
        phase: "train" selects the training split (shuffled); any other value
            selects the validation split (not shuffled).
        test_size: fraction of images held out for validation (default 0.2).
        random_state: seed for the split (default 69). Keep it identical
            between the train and validation calls so the splits are disjoint.
    '''
    df = pd.read_csv(df_path)
    # https://www.kaggle.com/amanooo/defect-detection-starter-u-net
    # Split on the last '_' only, so an image id containing '_' stays intact.
    df['ImageId'], df['ClassId'] = zip(*df['ImageId_ClassId'].str.rsplit('_', n=1))
    df['ClassId'] = df['ClassId'].astype(int)
    df = df.pivot(index='ImageId', columns='ClassId', values='EncodedPixels')
    df['defects'] = df.count(axis=1)

    train_df, val_df = train_test_split(df,
                                        test_size=test_size,
                                        stratify=df["defects"],
                                        random_state=random_state)
    split_df = train_df if phase == "train" else val_df
    image_dataset = SteelDatasetMultiV3(split_df, data_folder,
                                        transforms, phase,
                                        prepare_coarse, prepare_edges,
                                        prepare_class, prepare_full)
    return DataLoader(
        image_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=True,
        shuffle=(phase == 'train'),  # only the training split is shuffled
    )

def light_augmentations():
    """Light training pipeline, applied in this order:

    1. Random horizontal and vertical flips.
    2. At most one of cutout / Gaussian noise.
    3. At most one photometric transform.
    4. Normalization.
    """
    # D4-style flips (Transpose / RandomRotate90 are not used here).
    flips = A.Compose([
        A.HorizontalFlip(),
        A.VerticalFlip(),
    ])
    # Spatial-preserving corruptions.
    occlusion_or_noise = A.OneOf([
        A.Cutout(),
        A.GaussNoise(),
    ])
    photometric = A.OneOf([
        A.RandomBrightnessContrast(),
        A.CLAHE(),
        A.HueSaturationValue(),
        A.RGBShift(),
        A.RandomGamma(),
    ])
    stages = [flips, occlusion_or_noise, photometric, A.Normalize()]
    return A.Compose(stages)


def medium_augmentations():
    """Medium training pipeline, applied in this order:

    1. At most one blur / sharpen transform.
    2. Random vertical and horizontal flips.
    3. At most one of cutout / Gaussian noise.
    4. At most one photometric transform.
    5. Normalization.
    """
    # Occasional blur or sharpening.
    blur_or_sharpen = A.OneOf([
        A.GaussianBlur(),
        A.MotionBlur(),
        A.IAASharpen(),
    ])
    # D4-style flips (Transpose / RandomRotate90 are not used here).
    flips = A.Compose([
        A.VerticalFlip(),
        A.HorizontalFlip(),
    ])
    # Spatial-preserving corruptions.
    occlusion_or_noise = A.OneOf([
        A.Cutout(),
        A.GaussNoise(),
    ])
    photometric = A.OneOf([
        A.RandomBrightnessContrast(),
        A.CLAHE(),
        A.HueSaturationValue(),
        A.RGBShift(),
        A.RandomGamma(),
    ])
    stages = [blur_or_sharpen, flips, occlusion_or_noise, photometric, A.Normalize()]
    return A.Compose(stages)


def hard_augmentations():
    """Hard training pipeline, applied in this order:

    1. At most one blur / noise transform (NoOp is one of the choices).
    2. Random vertical and horizontal flips.
    3. Cutout.
    4. At most one photometric transform (NoOp is one of the choices).
    5. Normalization.
    """
    # Occasional blur or noise.
    blur_or_noise = A.OneOf([
        A.GaussianBlur(),
        A.GaussNoise(),
        A.IAAAdditiveGaussianNoise(),
        A.NoOp(),
    ])
    # D4-style flips (Transpose / RandomRotate90 are not used here).
    flips = A.Compose([
        A.VerticalFlip(),
        A.HorizontalFlip(),
    ])
    cutout = A.Cutout()
    # Spatial-preserving photometric changes.
    photometric = A.OneOf([
        A.RandomBrightnessContrast(),
        A.CLAHE(),
        A.HueSaturationValue(),
        A.RGBShift(),
        A.RandomGamma(),
        A.NoOp(),
    ])
    stages = [blur_or_noise, flips, cutout, photometric, A.Normalize()]
    return A.Compose(stages)


def validation_augmentations():
    """Validation pipeline: normalization only (always applied)."""
    only_normalize = [A.Normalize()]
    return A.Compose(only_normalize, p=1.0)

In [None]:
# Make the locally vendored copy of segmentation_models.pytorch importable.
package_path = './segmentation_models.pytorch/'
if package_path not in sys.path:  # re-running the cell must not add duplicates
    sys.path.append(package_path)
import segmentation_models_pytorch_local as smp

In [None]:
# Settings shared by the train and validation loaders: multi-head targets
# with a coarse mask and per-class labels, no edge / combined channels.
common_loader_kwargs = dict(
    data_folder=data_folder,
    df_path=train_df_path,
    num_workers=num_workers,
    prepare_coarse=True,
    prepare_edges=False,
    prepare_class=True,
    prepare_full=False,
)
dataloader_train = provider(phase='train',
                            transforms=light_augmentations(),
                            batch_size=batch_size,
                            **common_loader_kwargs)
dataloader_val = provider(phase='val',
                          transforms=validation_augmentations(),
                          batch_size=batch_size_val,
                          **common_loader_kwargs)

In [None]:
# Fetch one training batch (a dict) to inspect its keys; `data_b` is reused by
# the plotting cells below.
data_b = next(iter(dataloader_train))
print(data_b.keys())

In [None]:
# Full-resolution target masks (channels 0-3 = classes 1-4) of sample `idx`.
f, ax = plt.subplots(2, 2)
idx = 0
for channel, axis in enumerate(ax.flat):
    axis.imshow(data_b['targets'][idx][channel, :, :].numpy())
    axis.set_title(f'targets: class {channel + 1}')
plt.tight_layout()

In [None]:
# Coarse (4x downscaled) target masks (channels 0-3 = classes 1-4) of sample `idx`.
f, ax = plt.subplots(2, 2)
idx = 0
for channel, axis in enumerate(ax.flat):
    axis.imshow(data_b['coarse_targets'][idx][channel, :, :].numpy())
    axis.set_title(f'coarse targets: class {channel + 1}')
plt.tight_layout()

In [None]:
# U-Net with an extra classification head (output_channels_class outputs).
model = smp.Unet(
    encoder_name=encoder_name,
    decoder_use_batchnorm=True,
    classes=output_channels,
    num_classes_classification=output_channels_class,
)

# Criteria keyed by the `criterion_key` names used in the training callbacks.
# The full-resolution and coarse segmentation losses share one instance.
loss_f_segmentation = get_loss('bce_dice')
losses = {
    'loss_f_segmentation': loss_f_segmentation,
    'coarse_loss_f_segmentation': loss_f_segmentation,
    'loss_f_classification': L.BinaryFocalLoss(),
}

# RAdam wrapped in Lookahead.
optimizer = RAdam(model.parameters(), lr=3e-4)
optimizer_Lookahead = Lookahead(optimizer)
# Halve the LR at epochs 10, 15 and 20.
# NOTE(review): this scheduler is not passed to any runner.train call in this
# notebook, so as written the milestones have no effect — confirm intent.
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer_Lookahead,
    milestones=[10, 15, 20],
    gamma=0.5,
)

In [None]:
# Loader order: "train" first, then "valid".
loaders = collections.OrderedDict([
    ("train", dataloader_train),
    ("valid", dataloader_val),
])
# The model is fed batch['features']. output_key / input_target_key are None;
# presumably the model returns a dict whose keys ('logits', 'coarse_logits',
# 'classification_logits') the callbacks read directly — verify against smp_local.
runner = SupervisedRunner(input_key='features', output_key=None, input_target_key=None)

In [None]:
# Phase 1: freeze every parameter of `model.encoder`; all other parameters stay trainable.
for encoder_param in model.encoder.parameters():
    encoder_param.requires_grad_(False)

In [None]:
# Phase 1: train with the encoder frozen.
# The optimizer steps on the aggregated 'combined_loss' (segmentation +
# coarse segmentation + 10x classification) so every head receives gradients.
# Stepping on 'loss_segmentation' alone would log the coarse and
# classification losses while contributing no gradient from them.
# NOTE(review): `scheduler` (model cell) is not passed to runner.train, so
# its MultiStepLR milestones are not applied — confirm this is intended.
runner.train(
    model=model,
    criterion=losses,
    optimizer=optimizer_Lookahead,
    callbacks=[
        # --- losses -------------------------------------------------------
        CriterionCallback(input_key="targets",
                          output_key="logits",
                          prefix="loss_segmentation",
                          criterion_key='loss_f_segmentation',
                          multiplier=1.0),

        CriterionCallback(input_key="coarse_targets",
                          output_key="coarse_logits",
                          prefix="coarse_loss_segmentation",
                          criterion_key='coarse_loss_f_segmentation',
                          multiplier=1.0),

        CriterionCallback(input_key="classification_labels",
                          output_key="classification_logits",
                          prefix="classification_loss",
                          criterion_key='loss_f_classification',
                          multiplier=10.0),

        CriterionAggregatorCallback(prefix='combined_loss',
                                    loss_keys=['loss_segmentation',
                                               'coarse_loss_segmentation',
                                               'classification_loss'
                                              ],
                                    loss_aggregate_fn='sum'),

        # --- optimization (back-propagate the total loss) ------------------
        OptimizerCallback(accumulation_steps=accumulation_steps,
                          loss_key='combined_loss'),

        # --- metrics ------------------------------------------------------
        IoUMetricsCallback(mode='multilabel',
                           output_key="logits",
                           input_key='targets',
                           metric="dice",
                           prefix='dice',
                           nan_score_on_empty=False),

        IoUMetricsCallback(mode='multilabel',
                           output_key="coarse_logits",
                           input_key='coarse_targets',
                           metric="dice",
                           prefix='coarse_dice',
                           nan_score_on_empty=False),

        DiceScoreCallback(mode='binary',
                          output_key="logits",
                          input_key='targets',
                          prefix='total_dice',
                          nan_score_on_empty=False),
        F1ScoreCallback(input_key='classification_labels',
                        output_key='classification_logits',
                        prefix='f1_score',),
        AUCCallback(num_classes=output_channels_class,
                    input_key='classification_labels',
                    output_key='classification_logits')
    ],
    loaders=loaders,
    logdir=log_dir,
    main_metric='combined_loss',
    num_epochs=num_epochs_with_frozen_encoder,
    verbose=True
)

In [None]:
# Phase 2: make the encoder trainable again, then restore the weights from
# the best checkpoint that the previous training run wrote under `log_dir`.
for encoder_param in model.encoder.parameters():
    encoder_param.requires_grad_(True)
model.load_state_dict(
    torch.load(os.path.join(log_dir, 'checkpoints/best.pth'))['model_state_dict']
)

In [None]:
# Phase 2: train the whole network (encoder unfrozen).
# The optimizer steps on the aggregated 'combined_loss' (segmentation +
# coarse segmentation + 10x classification) so every head receives gradients.
# Stepping on 'loss_segmentation' alone would log the coarse and
# classification losses while contributing no gradient from them.
# NOTE(review): `scheduler` (model cell) is not passed to runner.train, so
# its MultiStepLR milestones are not applied — confirm this is intended.
runner.train(
    model=model,
    criterion=losses,
    optimizer=optimizer_Lookahead,
    callbacks=[
        # --- losses -------------------------------------------------------
        CriterionCallback(input_key="targets",
                          output_key="logits",
                          prefix="loss_segmentation",
                          criterion_key='loss_f_segmentation',
                          multiplier=1.0),

        CriterionCallback(input_key="coarse_targets",
                          output_key="coarse_logits",
                          prefix="coarse_loss_segmentation",
                          criterion_key='coarse_loss_f_segmentation',
                          multiplier=1.0),

        CriterionCallback(input_key="classification_labels",
                          output_key="classification_logits",
                          prefix="classification_loss",
                          criterion_key='loss_f_classification',
                          multiplier=10.0),

        CriterionAggregatorCallback(prefix='combined_loss',
                                    loss_keys=['loss_segmentation',
                                               'coarse_loss_segmentation',
                                               'classification_loss'
                                              ],
                                    loss_aggregate_fn='sum'),

        # --- optimization (back-propagate the total loss) ------------------
        OptimizerCallback(accumulation_steps=accumulation_steps,
                          loss_key='combined_loss'),

        # --- metrics ------------------------------------------------------
        IoUMetricsCallback(mode='multilabel',
                           output_key="logits",
                           input_key='targets',
                           metric="dice",
                           prefix='dice',
                           nan_score_on_empty=False),

        IoUMetricsCallback(mode='multilabel',
                           output_key="coarse_logits",
                           input_key='coarse_targets',
                           metric="dice",
                           prefix='coarse_dice',
                           nan_score_on_empty=False),

        DiceScoreCallback(mode='binary',
                          output_key="logits",
                          input_key='targets',
                          prefix='total_dice',
                          nan_score_on_empty=False),
        F1ScoreCallback(input_key='classification_labels',
                        output_key='classification_logits',
                        prefix='f1_score',),
        AUCCallback(num_classes=output_channels_class,
                    input_key='classification_labels',
                    output_key='classification_logits')
    ],
    loaders=loaders,
    logdir=log_dir,
    num_epochs=num_epochs,
    main_metric='combined_loss',
    #resume=f"{log_dir}/checkpoints/best.pth",
    verbose=True
)

In [None]:
# Rebuild the loaders for the final phase: hard augmentations on train,
# otherwise the same multi-head target settings as before.
common_loader_kwargs = dict(
    data_folder=data_folder,
    df_path=train_df_path,
    num_workers=num_workers,
    prepare_coarse=True,
    prepare_edges=False,
    prepare_class=True,
    prepare_full=False,
)
dataloader_train = provider(phase='train',
                            transforms=hard_augmentations(),
                            batch_size=batch_size,
                            **common_loader_kwargs)
dataloader_val = provider(phase='val',
                          transforms=validation_augmentations(),
                          batch_size=batch_size_val,
                          **common_loader_kwargs)
loaders = collections.OrderedDict([
    ("train", dataloader_train),
    ("valid", dataloader_val),
])

In [None]:
# Restore the best weights saved under `log_dir` before the final phase.
best_checkpoint_path = os.path.join(log_dir, 'checkpoints/best.pth')
model.load_state_dict(torch.load(best_checkpoint_path)['model_state_dict'])

In [None]:
# Phase 3: continue training with hard augmentations, resuming from the best
# checkpoint under `log_dir`.
# The optimizer steps on the aggregated 'combined_loss' (segmentation +
# coarse segmentation + 10x classification) so every head receives gradients.
# Stepping on 'loss_segmentation' alone would log the coarse and
# classification losses while contributing no gradient from them.
# NOTE(review): `scheduler` (model cell) is not passed to runner.train, so
# its MultiStepLR milestones are not applied — confirm this is intended.
runner.train(
    model=model,
    criterion=losses,
    optimizer=optimizer_Lookahead,
    callbacks=[
        # --- losses -------------------------------------------------------
        CriterionCallback(input_key="targets",
                          output_key="logits",
                          prefix="loss_segmentation",
                          criterion_key='loss_f_segmentation',
                          multiplier=1.0),

        CriterionCallback(input_key="coarse_targets",
                          output_key="coarse_logits",
                          prefix="coarse_loss_segmentation",
                          criterion_key='coarse_loss_f_segmentation',
                          multiplier=1.0),

        CriterionCallback(input_key="classification_labels",
                          output_key="classification_logits",
                          prefix="classification_loss",
                          criterion_key='loss_f_classification',
                          multiplier=10.0),

        CriterionAggregatorCallback(prefix='combined_loss',
                                    loss_keys=['loss_segmentation',
                                               'coarse_loss_segmentation',
                                               'classification_loss'
                                              ],
                                    loss_aggregate_fn='sum'),

        # --- optimization (back-propagate the total loss) ------------------
        OptimizerCallback(accumulation_steps=accumulation_steps,
                          loss_key='combined_loss'),

        # --- metrics ------------------------------------------------------
        IoUMetricsCallback(mode='multilabel',
                           output_key="logits",
                           input_key='targets',
                           metric="dice",
                           prefix='dice',
                           nan_score_on_empty=False),

        IoUMetricsCallback(mode='multilabel',
                           output_key="coarse_logits",
                           input_key='coarse_targets',
                           metric="dice",
                           prefix='coarse_dice',
                           nan_score_on_empty=False),

        DiceScoreCallback(mode='binary',
                          output_key="logits",
                          input_key='targets',
                          prefix='total_dice',
                          nan_score_on_empty=False),
        F1ScoreCallback(input_key='classification_labels',
                        output_key='classification_logits',
                        prefix='f1_score',),
        AUCCallback(num_classes=output_channels_class,
                    input_key='classification_labels',
                    output_key='classification_logits')
    ],
    loaders=loaders,
    logdir=log_dir,
    num_epochs=num_epochs,
    main_metric='combined_loss',
    resume=f"{log_dir}/checkpoints/best.pth",
    verbose=True
)