In [None]:
from fastai.vision.all import *
from pathlib import Path
from timm import create_model
from fastai.callback.fp16 import *
from sklearn.model_selection import train_test_split

In [None]:
SEED = 2021

def seed_everything(seed):
    "Seed the Python, NumPy and PyTorch (CPU and current CUDA device) RNGs with `seed` and request deterministic cuDNN kernels"
    os.environ['PYTHONHASHSEED'] = str(seed)
    # Every RNG used by the notebook receives the same seed.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seed_fn(seed)
    torch.backends.cudnn.deterministic = True
# Seed every RNG once, before any DataLoaders or models are built further down.
seed_everything(SEED)
# NOTE(review): cuDNN autotuning is switched on right after deterministic kernels were
# requested in `seed_everything`. Benchmark mode may select different algorithms from run
# to run, so bit-exact reproducibility is presumably not guaranteed - confirm whether
# speed or reproducibility is the priority here.
torch.backends.cudnn.benchmark = True

In [None]:
# Dataset root; images are read from `path/'train_images'` and `get_mask` maps that
# folder name to `masks`.
YOUR_PATH = "./data"
path = Path(YOUR_PATH)

# Class names handed to `MaskBlock`; presumably list position == pixel value in the
# mask PNGs - TODO confirm against the mask files.
codes = ["Background", "Divots", "Cracks", "Scratches", "Ablation"]

def get_mask(fn):
    """Load the segmentation mask that belongs to the image file `fn`.

    The mask path is derived from the image path: every occurrence of `train_images`
    becomes `masks`, and a `.jpg` file extension becomes `.png`. Only the *suffix* is
    rewritten, so a folder or file stem that merely contains the letters "jpg" is no
    longer mangled.
    """
    mask_fn = Path(str(fn).replace('train_images', 'masks'))
    if mask_fn.suffix == '.jpg':
        mask_fn = mask_fn.with_suffix('.png')
    return PILMask.create(mask_fn)

def get_grayscale(fn):
    """Read image `fn` and return one 2-D channel as a NumPy array.

    For a multi-channel image the first channel is returned. An image that already
    decodes to a 2-D array is returned as is.
    """
    # The context manager closes the file as soon as the pixel data has been copied out.
    with Image.open(fn) as img:
        pixels = np.array(img)
    # A single-channel file decodes to (H, W); `[..., 0]` there would return its first column.
    return pixels if pixels.ndim == 2 else pixels[..., 0]

In [None]:
# (mean, std) pair that is unpacked into `Normalize.from_stats` by the DataBlocks below.
# Presumably measured on the Severstal steel-defect images - TODO confirm the source.
severstal_stats=([0.343], [0.197])

In [None]:
def get_images_with_defects(path):
    "Return a list of the image files under `path` whose mask has at least one non-zero pixel"
    def _has_defect(img_fn):
        # A mask whose maximum is 0 contains no labelled pixel at all.
        return np.max(np.array(get_mask(img_fn))) > 0

    return [img_fn for img_fn in get_image_files(path) if _has_defect(img_fn)]

In [None]:
def get_dls_gray(size, batch_size=8):
    "DataLoaders over every image file in `path/'train_images'`, with inputs loaded through `get_grayscale`"
    # Batch-level augmentation (incl. vertical flips) at `size`, followed by normalisation.
    batch_tfms = [*aug_transforms(size=size, flip_vert=True),
                  Normalize.from_stats(*severstal_stats)]
    gray_block = DataBlock(blocks=(ImageBlock, MaskBlock(codes)),
                           get_items=get_image_files,
                           get_x=get_grayscale,
                           get_y=get_mask,
                           batch_tfms=batch_tfms)
    return gray_block.dataloaders(path/'train_images', batch_size=batch_size)

def get_dls_rgb(size, batch_size=8):
    "DataLoaders over the images in `path/'train_images'` that `get_images_with_defects` keeps"
    # Batch-level augmentation (incl. vertical flips) at `size`, followed by normalisation.
    batch_tfms = [*aug_transforms(size=size, flip_vert=True),
                  Normalize.from_stats(*severstal_stats)]
    rgb_block = DataBlock(blocks=(ImageBlock, MaskBlock(codes)),
                          get_items=get_images_with_defects,
                          get_y=get_mask,
                          batch_tfms=batch_tfms)
    return rgb_block.dataloaders(path/'train_images', batch_size=batch_size)



In [None]:
def StratMaskSplitter(test_size=0.2, random_state=None, train_size=None, shuffle=True):
    "Create a splitter that makes a random train/valid split of `items`, stratified by the least frequent pixel value of each item's mask"

    def _rarest_value(mask):
        # Pixel value with the fewest occurrences in this mask.
        values, freqs = tensor(mask).unique(return_counts=True)
        return values[freqs.argmin()]

    def _inner(o, **kwargs):
        # One stratification label per item; every mask is loaded once to compute it.
        strat_labels = tensor([_rarest_value(get_mask(item)) for item in o])
        item_idxs = range_of(o)
        train_idxs, valid_idxs = train_test_split(item_idxs, test_size=test_size, train_size=train_size,
                                                  random_state=random_state, shuffle=shuffle,
                                                  stratify=strat_labels)
        return L(train_idxs), L(valid_idxs)

    return _inner

In [None]:
# Same pipeline as `get_dls_rgb`, plus the stratified splitter defined above.
# The (256, 512) size is written out here; the `sz` constant with the same value is
# only defined in a later cell.
dblock = DataBlock(blocks=(ImageBlock, MaskBlock(codes)),
                     get_items=get_images_with_defects,
                     get_y = get_mask,
                     splitter = StratMaskSplitter(),
                     batch_tfms=[*aug_transforms(size=(256,512),flip_vert=True), Normalize.from_stats(*severstal_stats)])

In [None]:
# Build the DataLoaders from the stratified DataBlock, 6 items per batch.
dls_strat = dblock.dataloaders(path/'train_images', batch_size=6)

In [None]:
# Visual sanity check of one batch from the stratified DataLoaders.
dls_strat.show_batch()

In [None]:
sz = (256, 512) # As per 1st place solution
bs = 6       # 1st place solution proposes 12 or 24 with grad accumulation at 24 samples

# Alternative DataLoaders; neither helper sets a splitter, so they do not use the
# stratified split of `dls_strat`.
dls_gray = get_dls_gray(size=sz, batch_size=bs)
dls_rgb = get_dls_rgb(size=sz, batch_size=bs)

### Hacky TIMM encoders

In [None]:
def _update_first_layer(model, n_in, pretrained):
    """Change first layer based on number of input channels.

    Does nothing when `n_in` is 3. Otherwise the first layer of `model` (which must be
    an `nn.Conv2d` with 3 input channels) is replaced in place by a new `nn.Conv2d`
    with `n_in` input channels and otherwise identical settings. When `pretrained` is
    true the new layer is initialised from the old one.
    """
    if n_in == 3: return
    # These helpers are private to fastai, so `from fastai.vision.all import *` does not
    # bring them into the notebook namespace; import them explicitly where they are used.
    from fastai.vision.learner import _get_first_layer, _load_pretrained_weights
    first_layer, parent, name = _get_first_layer(model)
    assert isinstance(first_layer, nn.Conv2d), f'Change of input channels only supported with Conv2d, found {first_layer.__class__.__name__}'
    assert first_layer.in_channels == 3, f'Unexpected number of input channels, found {first_layer.in_channels} while expecting 3'
    # Copy every constructor argument except the input-channel count from the old layer.
    params = {attr:getattr(first_layer, attr) for attr in 'out_channels kernel_size stride padding dilation groups padding_mode'.split()}
    params['bias'] = first_layer.bias is not None
    params['in_channels'] = n_in
    new_layer = nn.Conv2d(**params)
    if pretrained:
        _load_pretrained_weights(new_layer, first_layer)
    setattr(parent, name, new_layer)
    
def _add_norm(dls, meta, pretrained):
    "Append `Normalize` with `meta['stats']` to `dls.after_batch` when `pretrained`, stats exist and no `Normalize` is present yet"
    stats = meta.get('stats') if pretrained else None
    if stats is None: return
    # Leave the pipeline alone when a Normalize transform is already in it.
    existing_norm = dls.after_batch.fs.filter(risinstance(Normalize))
    if existing_norm: return
    dls.add_tfms([Normalize.from_stats(*stats)],'after_batch')
    
def create_timm_body(arch:str, pretrained=True, cut=None, n_in=3):
    """Creates a body from any model in the `timm` library.

    `arch` is the timm model name. The model is created without classifier
    (`num_classes=0`) and without global pooling (`global_pool=''`), and its first layer
    is adapted to `n_in` input channels. `cut` is either an int (keep the first `cut`
    children), a callable applied to the model, or None, in which case the index of the
    last child for which `has_pool_type` is true is used.

    Raises `NameError` when `cut` is neither None, an int nor a callable.
    """
    model = create_model(arch, pretrained=pretrained, num_classes=0, global_pool='')
    _update_first_layer(model, n_in, pretrained)
    if cut is None:
        ll = list(enumerate(model.children()))
        cut = next(i for i,o in reversed(ll) if has_pool_type(o))
    if isinstance(cut, int): return nn.Sequential(*list(model.children())[:cut])
    elif callable(cut): return cut(model)
    # `NamedError` is not a defined name; raise the built-in exception with the intended message.
    else: raise NameError("cut must be either integer or function")

def unet_learner_timm(dls, arch, normalize=True, n_out=None, n_in=3, pretrained=True, config=None,
                 # learner args
                 loss_func=None, opt_func=Adam, lr=defaults.lr, splitter=None, cbs=None, metrics=None, path=None,
                 model_dir='models', wd=None, wd_bn_bias=False, train_bn=True, moms=(0.95,0.85,0.95),
                 # other model args
                 **kwargs):
    """Build a unet learner from `dls` and `arch`.

    `arch` is passed to `create_timm_body`, so it is a timm model name. `n_out` defaults
    to the number of classes inferred from `dls`. `pretrained` controls both the encoder
    weights and whether the learner is frozen on return. Extra `kwargs` go to
    `DynamicUnet`; the remaining arguments are forwarded to `Learner`.
    """

    if config:
        warnings.warn('config param is deprecated. Pass your args directly to unet_learner.')
        kwargs = {**config, **kwargs}
    _default_meta    = {'cut':None, 'split':default_split}

    # `model_meta` is looked up with `arch` as key; anything not found uses the defaults above.
    meta = model_meta.get(arch, _default_meta)
    if normalize: _add_norm(dls, meta, pretrained)

    n_out = ifnone(n_out, get_c(dls))
    assert n_out, "`n_out` is not defined, and could not be inferred from data, set `dls.c` or pass `n_out`"
    # Spatial size of one real batch, needed by DynamicUnet to size its decoder.
    img_size = dls.one_batch()[0].shape[-2:]
    assert img_size, "image size could not be inferred from data"
    # Forward `pretrained` so that `pretrained=False` really yields a randomly initialised encoder.
    timm_body = create_timm_body(arch, pretrained=pretrained, n_in=n_in)
    model = DynamicUnet(timm_body, n_out, img_size, **kwargs)

    splitter=ifnone(splitter, meta['split'])
    learn = Learner(dls=dls, model=model, loss_func=loss_func, opt_func=opt_func, lr=lr, splitter=splitter, cbs=cbs,
                   metrics=metrics, path=path, model_dir=model_dir, wd=wd, wd_bn_bias=wd_bn_bias, train_bn=train_bn,
                   moms=moms)
    if pretrained: learn.freeze()
    # keep track of args for loggers
    store_attr('arch,normalize,n_out,pretrained', self=learn, **kwargs)
    return learn

### Loss func

In [None]:
class DiceLoss:
    """Dice loss for segmentation.

    `axis` is the dimension of `pred` that holds the per-class scores; `smooth` is added
    to numerator and denominator of the Dice ratio. Calling the object returns
    `1 - (2*inter + smooth) / (union + smooth)`, computed over all elements of the
    batch at once.
    """
    def __init__(self, axis=1, smooth=1):
        store_attr()
    def __call__(self, pred, targ):
        # One-hot encode along the same axis that carries the class scores in `pred`,
        # so that a non-default `axis` stays aligned with the softmax below.
        targ = self._one_hot(targ, pred.shape[self.axis], axis=self.axis)
        pred, targ = flatten_check(self.activation(pred), targ)
        inter = (pred*targ).sum()
        union = (pred+targ).sum()
        return 1 - (2. * inter + self.smooth)/(union + self.smooth)
    @staticmethod
    def _one_hot(x, classes, axis=1):
        "Creates one binary mask per class, stacked along `axis`"
        return torch.stack([torch.where(x==c, 1, 0) for c in range(classes)], dim=axis)
    def activation(self, x): return F.softmax(x, dim=self.axis)
    def decodes(self, x):    return x.argmax(dim=self.axis)

In [None]:
class CustomComboLoss:
    "Weighted sum of focal loss (weight `alpha`) and Dice loss (weight `1-alpha`)"
    def __init__(self, axis=1, smooth=1., alpha=0.75):
        store_attr()
        self.focal_loss = FocalLossFlat(axis=axis)
        self.dice_loss = DiceLoss(axis, smooth)

    def __call__(self, pred, targ):
        focal_term = self.alpha * self.focal_loss(pred, targ)
        dice_term = (1-self.alpha) * self.dice_loss(pred, targ)
        return focal_term + dice_term

    def activation(self, x):
        "Class probabilities along `axis`"
        return F.softmax(x, dim=self.axis)

    def decodes(self, x):
        "Index of the highest score along `axis`"
        return x.argmax(dim=self.axis)

### Training

In [None]:

# U-Net on a ResNet-50 encoder with an equal mix of focal and Dice loss, mixed precision,
# and a GradientAccumulation(n_acc=24) callback attached to the learner itself.
learn = unet_learner(
    dls_strat, resnet50,
    loss_func=CustomComboLoss(alpha=0.5),
    metrics=[DiceMulti],
    opt_func=ranger,
    act_cls=Mish,
    cbs=[GradientAccumulation(n_acc=24)],
    self_attention=True,
).to_fp16()

In [None]:
# Learning-rate range test before the first training run.
learn.lr_find()

In [None]:
# First training run: 10 epochs with `slice(1e-4)` as learning rate (runs before the
# `learn.unfreeze()` cell further down).
learn.fit_flat_cos(10, slice(1e-4))

In [None]:
# Unfreeze the whole model, then rerun the learning-rate range test for the second run.
learn.unfreeze()
learn.lr_find()

In [None]:
# Show the model summary of the unfrozen learner.
learn.summary()

In [None]:
lr = 1e-5
lrs = slice(lr/100, lr)
# `learn` was created with GradientAccumulation(n_acc=24) in its own `cbs`, and that
# callback stays attached for every fit. Passing another instance here would register it
# a second time for this run, so only the plateau scheduler is added.
learn.fit_flat_cos(10, lrs, cbs=[ReduceLROnPlateau(factor=4, patience=2, min_lr=1e-9)])

In [None]:
# Export the trained learner to a pickle file with a dated name.
learn.export("UNET_R50_29.07.2021.pkl")

In [None]:
learn.loss_func = CustomComboLoss2()

In [None]:
class CustomComboLoss2:
    "Dice and Focal combined"
    def __init__(self, axis=1, smooth=1., alpha=0.8):
        store_attr()
        self.ce_loss = CrossEntropyLossFlat(axis=axis, weight=tensor(0.5,2.0,2.0,1.0,1.5))
        self.dice_loss =  DiceLoss(axis, smooth)       
    def __call__(self, pred, targ):
        return (self.alpha * self.ce_loss(pred, targ)) + ((1-self.alpha) * self.dice_loss(pred, targ))
    def decodes(self, x):    return x.argmax(dim=self.axis)
    def activation(self, x): return F.softmax(x, dim=self.axis)

In [None]:
# Inspect the optimizer's current hyper-parameters.
learn.opt.hypers

In [None]:
# Learning-rate range test restricted to 1e-10 ... 1e-3.
learn.lr_find(start_lr=1e-10, end_lr=1e-3)

In [None]:
# Range test with the default bounds, kept for comparison with the restricted run above.
learn.lr_find() #old

In [None]:
# Fresh learner without an explicit `loss_func`.
# NOTE(review): the chained assignment also rebinds `learn`, so the trained learner from
# the cells above is no longer reachable under that name - confirm this is intended.
learn_new = learn = unet_learner(dls_strat, resnet50, metrics=[DiceMulti], opt_func=ranger, act_cls=Mish, cbs=[GradientAccumulation(n_acc=24)],self_attention = True).to_fp16()

In [None]:
# Show which loss function the learner ends up with when none is passed explicitly.
learn_new.loss_func

### Adam16 optimizer to help with NaN (not needed, Ranger works)

In [None]:
import math
from torch.optim.optimizer import Optimizer

# This version of Adam keeps an fp32 copy of the parameters and
# does all of the parameter updates in fp32, while still doing the
# forwards and backwards passes using fp16 (i.e. fp16 copies of the
# parameters and fp16 activations).
#
# The fp32 master copies are created on the same device as the parameter
# they shadow, so the model should already sit on its target device when
# the optimizer is constructed.
class Adam16(Optimizer):
    """Adam that updates fp32 master weights and writes fp16 copies back to the model.

    Arguments:
        params: iterable of parameters, or of param-group dicts with a 'params' key.
        lr: learning rate.
        betas: coefficients of the running averages of the gradient and its square.
        eps: term added to the denominator.
        weight_decay: L2 penalty added to the gradient (0 disables it).
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=0):
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay)
        params = list(params)
        super(Adam16, self).__init__(params, defaults)

        # Mirror `self.param_groups` one-to-one so `step` can zip both structures.
        # Built from the normalised groups, so plain parameter lists and lists of
        # param-group dicts are both handled. `copy=True` keeps the master weights
        # independent of the model tensors even when those are already fp32.
        self.fp32_param_groups = [
            {'params': [p.detach().to(dtype=torch.float32, copy=True) for p in group['params']]}
            for group in self.param_groups
        ]

    def step(self, closure=None):
        """Performs a single optimization step.
        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group, fp32_group in zip(self.param_groups, self.fp32_param_groups):
            beta1, beta2 = group['betas']
            for p, fp32_p in zip(group['params'], fp32_group['params']):
                if p.grad is None:
                    continue

                # All arithmetic below happens in fp32.
                grad = p.grad.data.float()
                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(grad)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(grad)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

                state['step'] += 1

                if group['weight_decay'] != 0:
                    grad = grad.add(fp32_p, alpha=group['weight_decay'])

                # Decay the first and second moment running average coefficient.
                # Keyword `alpha` / `value` replace the deprecated positional-scalar overloads.
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                denom = exp_avg_sq.sqrt().add_(group['eps'])

                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']
                step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1

                # Update the fp32 master weight, then hand an fp16 copy back to the model.
                fp32_p.addcdiv_(exp_avg, denom, value=-step_size)
                p.data = fp32_p.half()

        return loss

### Jarvis.ai TIMM encoder implementation (WIP)

In [None]:
class Encoder(Module):
    "Wraps a pretrained `timm` model created with `features_only=True`"
    def __init__(self, model_name='efficientnet_b0'):
        backbone = create_model(model_name, pretrained=True, features_only=True)
        self.encoder = backbone

    def forward(self,x):
        "Run `x` through the wrapped model and return its output unchanged"
        features = self.encoder(x)
        return features

    def __getitem__(self,i):
        "Return entry `i` of the wrapped model's `blocks`"
        return self.encoder.blocks[i]
    
class UnetBlock(Module):
    """Decoder stage: 2x bilinear upsampling, two `ConvLayer`s, then an optional attention layer.

    `in_channels -> channels -> out_channels` are the channel counts of the two
    convolutions, `act` is the activation class passed to both as `act_cls`, and `attn`,
    when given, is called with `out_channels` to build the final layer.
    """
    def __init__(self, in_channels, channels, out_channels, act=nn.ReLU, attn=None):
        self.act = act
        # Assign the layers as attributes so `forward` can use them.
        self.conv1 = ConvLayer(in_channels, channels, act_cls=self.act)
        self.conv2 = ConvLayer(channels, out_channels, act_cls=self.act)
        self.attn_layer = attn(out_channels) if attn else noop

    def forward(self,x):
        x=F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True)
        x=self.conv1(x)
        x=self.conv2(x)
        x=self.attn_layer(x)
        return x

    
def UnetDecoder(Module):
    def __init__(self, fs=32, expansion = 4, n_out=1, hypercol=False, attn=None, act = ReLU):
        center_ch = 512*expansion
        decoder5_ch = center_ch + (256 * expansion)
        channels=512
        self.hypercol = hypercol
        self.center = nn.Sequential(ConvLayer(center_ch,center_ch, act_cls=act), ConvLayer(center_ch,center_ch//2,act_cls=act))
        self.decoder5 = UnetBlock(decoder5_ch, channels, fs, act, attn)
        self.decoder4 = UnetBlock(256*expansion+fs, 256, fs, act, attn)
        self.decoder3 = UnetBlock(128*expansion+fs, 128, fs, act, attn)
        self.decoder2 = UnetBlock(64*expansion+fs, 64, fs, act, attn)
        self.decoder1 = UnetBlock(fs, fs, fs, act, attn)
        if hypercol:
            self.logit = nn.Sequential(ConvLayer(fs*5,fs*2), ConvLayer(fs*2,fs),nn.Conv2d(fs,n_out,kernel_size=1))
        else:
            self.logit = nn.Sequential(ConvLayer(fs,fs//2), ConvLayer(fs//2,fs//2),nn.Conv2d(fs//2,n_out,kernel_size=1))
    def forward(self, features):
        e1,e2,e3,e4,e5 = features
        f = self.center(e5)
        d5 = self.decoder5(torch.cat([f,e5],1))
        d4 = self.decoder4(torch.cat([d5,e4],1))
        d3 = self.decoder3(torch.cat([d4,e3],1))
        d2 = self.decoder2(torch.cat([d3,e2],1))
        d1 = self.decoder1(d2)