### Installation

In [None]:
# Switch the working directory to the project folder so that the relative
# paths used later in this notebook ('lib', 'config/classification.yml')
# resolve against it.
# NOTE(review): the path points into a Google Drive mount — assumes the drive
# has already been mounted in this Colab session; confirm.
%cd '/content/drive/MyDrive/Colab Notebooks/kaggle/severstal'
import sys
# Make modules in ./lib importable (presumably where `commons` and
# `catalyst_extension` live — confirm).
sys.path.append('lib')

In [None]:
!pip uninstall -y opencv-python-headless 
!pip install -q opencv-python-headless==4.1.2.30

In [None]:
!pip install -U -q albumentations
!pip install -q -U pyyaml
!pip install -q catalyst

### Import

In [None]:
import numpy as np
import pandas as pd
import random
import os
import cv2
from datetime import datetime

import catalyst
from catalyst.callbacks.metric import BatchMetricCallback
from catalyst.dl import CriterionCallback, MetricAggregationCallback
from catalyst.callbacks.metrics import MultilabelPrecisionRecallF1SupportCallback
from catalyst.callbacks.metrics import MultilabelAccuracyCallback
from catalyst.dl import BatchTransformCallback
from catalyst.dl import SupervisedRunner
from catalyst.callbacks import SchedulerCallback

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from commons import *
import yaml

from catalyst_extension import CustomCheckpointCallback

### Settings

In [None]:
def provide_reproducibility(seed=42):
    """Seed the random number generators used by this notebook.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch,
    and configures cuDNN so that it neither auto-tunes nor selects
    non-deterministic convolution algorithms.

    Args:
        seed: Integer seed applied to every generator. Defaults to 42, the
            value that used to be hard-coded.
    """
    # The cuDNN autotuner may pick different kernels from run to run.
    torch.backends.cudnn.benchmark = False
    # Disabling the autotuner alone is not enough: cuDNN must also be told to
    # restrict itself to deterministic algorithms.
    torch.backends.cudnn.deterministic = True
    # torch.use_deterministic_algorithms(True) stays disabled, as before:
    # with it enabled, operations lacking a deterministic implementation raise.
    random.seed(seed)
    np.random.seed(seed)
    # torch.manual_seed seeds the generators of all devices, CUDA included.
    torch.manual_seed(seed)


provide_reproducibility()

### Loaders

In [None]:
# Read the experiment configuration; the path is relative to the working
# directory.
# NOTE(review): yaml.FullLoader can build more than plain scalars/lists/dicts.
# If the config only holds plain data, yaml.safe_load would be the safer
# choice — confirm before switching.
with open('config/classification.yml') as f:
    config = yaml.load(f, Loader=yaml.FullLoader)
# The 'dataset' section is forwarded verbatim as keyword arguments.
# `create_loaders` is not defined in this notebook — presumably it comes from
# `from commons import *`; confirm. The result is later passed to
# `runner.train(loaders=...)`.
loaders = create_loaders(**config['dataset'])

### Model

In [None]:
from torch.nn import Linear
from torchvision.models import resnet50

# Start from a ResNet-50 with pretrained weights and swap its final
# fully-connected layer for a new 4-output head (one logit per class; the
# same class count is used by the metric callback further down).
model = resnet50(pretrained=True)
model.fc = Linear(model.fc.in_features, 4, bias=True)

### Configuration

#### Log directory

In [None]:
# Output directory read from the 'checkpoint' section of the config; it is
# used below both as the parent of the checkpoints folder and as the runner's
# `logdir`.
logdir = config['checkpoint']['dir']
# Bare expression so the notebook displays the resolved value.
logdir

#### Optimizer

In [None]:
from torch.optim import RAdam

# Split the parameters into two groups by name: everything whose name
# contains 'fc' (the replaced classification head) is "fresh", the rest is
# the pretrained "base".
base_params = []
fresh_params = []
for name, param in model.named_parameters():
    (fresh_params if 'fc' in name else base_params).append(param)

print('base parameters count', len(base_params))
print('fresh parameters count', len(fresh_params))

# Optimizer-wide defaults. Each parameter group below sets its own `lr`, so
# `lr` here only acts as the fallback value handed to the optimizer.
lr = 0.001
weight_decay = 0.0003

# The pretrained base gets a 10x smaller learning rate than the new head.
params = [
    dict(params=base_params, lr=0.0001),
    dict(params=fresh_params, lr=0.001),
]

# RAdam wrapped in Lookahead; the plateau scheduler is attached to the wrapper.
base_optimizer = RAdam(params, lr=lr, weight_decay=weight_decay)
optimizer = catalyst.contrib.optimizers.Lookahead(base_optimizer)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, factor=0.25, patience=10)

#### Catalyst

In [None]:
# Loss functions by key; the 'bce' key is referenced by the CriterionCallback
# below through `criterion_key`.
criterion = {
    'bce': nn.BCEWithLogitsLoss(),
}

# Callbacks handed to `runner.train`.
callbacks = [
    # Configured with scope "on_batch_end": applies torch.sigmoid to the
    # "logits" entry and stores the result under "outputs".
    BatchTransformCallback(
        scope="on_batch_end",
        transform=torch.sigmoid,
        input_key="logits",
        output_key="outputs",
    ),
    # Uses the 'bce' criterion. The three positional arguments are presumably
    # (input key, target key, metric key) — confirm against the installed
    # catalyst version. The raw logits are passed, not the sigmoid outputs,
    # which matches BCEWithLogitsLoss.
    CriterionCallback("logits", "mask", "loss", criterion_key="bce"),
    # Scheduler is keyed on the 'loss' metric of the 'valid' loader.
    SchedulerCallback(loader_key='valid', metric_key='loss'),
    # Multilabel precision/recall/F1 on the sigmoid "outputs" against "mask";
    # 4 matches the out_features of the model's replaced `fc` layer.
    MultilabelPrecisionRecallF1SupportCallback('outputs', 'mask', 4, 
                                               log_on_batch=False),
    # Project-local checkpoint callback (imported from `catalyst_extension`),
    # keyed on the 'f1/_macro' metric of the 'valid' loader; minimize=False,
    # i.e. the metric is not treated as something to minimise.
    CustomCheckpointCallback(
        logdir=f'{logdir}/checkpoints', loader_key='valid', metric_key='f1/_macro', 
        minimize=False)
]

### Training

In [None]:
# The runner is configured with "image" as its input key and "logits" as its
# output key — the same "logits" key the transform and criterion callbacks
# read.
runner = SupervisedRunner(input_key="image", output_key="logits")
# Launch training with the objects built in the cells above.
runner.train(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    loaders=loaders,
    callbacks=callbacks,
    logdir=logdir,
    num_epochs=50,
    # Value comes from the 'checkpoint' section of the config; presumably a
    # checkpoint path to continue from — confirm against the config file.
    resume=config['checkpoint']['resume'],
    verbose=True,
    # NOTE(review): presumably enables half-precision training — confirm the
    # flag is supported by the installed catalyst version.
    fp16=True
)