In [21]:
import os
from torch.utils.data import DataLoader

In [22]:

"""ResNet variants"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Conv2d, Module, Linear, BatchNorm2d, ReLU
from torch.nn.modules.utils import _pair

__all__ = ['ResNet', 'Bottleneck']

_url_format = 'https://s3.us-west-1.wasabisys.com/resnest/torch/{}-{}.pth'

_model_sha256 = {name: checksum for checksum, name in [
    ]}


def short_hash(name):
    if name not in _model_sha256:
        raise ValueError('Pretrained model for {name} is not available.'.format(name=name))
    return _model_sha256[name][:8]

resnest_model_urls = {name: _url_format.format(name, short_hash(name)) for
    name in _model_sha256.keys()
}

class DropBlock2D(object):
    def __init__(self, *args, **kwargs):
        raise NotImplementedError

class SplAtConv2d(Module):
    """Split-Attention Conv2d
    """
    def __init__(self, in_channels, channels, kernel_size, stride=(1, 1), padding=(0, 0),
                 dilation=(1, 1), groups=1, bias=True,
                 radix=2, reduction_factor=4,
                 rectify=False, rectify_avg=False, norm_layer=None,
                 dropblock_prob=0.0, **kwargs):
        super(SplAtConv2d, self).__init__()
        padding = _pair(padding)
        self.rectify = rectify and (padding[0] > 0 or padding[1] > 0)
        self.rectify_avg = rectify_avg
        inter_channels = max(in_channels*radix//reduction_factor, 32)
        self.radix = radix
        self.cardinality = groups
        self.channels = channels
        self.dropblock_prob = dropblock_prob
        if self.rectify:
            from rfconv import RFConv2d
            self.conv = RFConv2d(in_channels, channels*radix, kernel_size, stride, padding, dilation,
                                 groups=groups*radix, bias=bias, average_mode=rectify_avg, **kwargs)
        else:
            self.conv = Conv2d(in_channels, channels*radix, kernel_size, stride, padding, dilation,
                               groups=groups*radix, bias=bias, **kwargs)
        self.use_bn = norm_layer is not None
        if self.use_bn:
            self.bn0 = norm_layer(channels*radix)
        self.relu = ReLU(inplace=True)
        self.fc1 = Conv2d(channels, inter_channels, 1, groups=self.cardinality)
        if self.use_bn:
            self.bn1 = norm_layer(inter_channels)
        self.fc2 = Conv2d(inter_channels, channels*radix, 1, groups=self.cardinality)
        if dropblock_prob > 0.0:
            self.dropblock = DropBlock2D(dropblock_prob, 3)
        self.rsoftmax = rSoftMax(radix, groups)

    def forward(self, x):
        x = self.conv(x)
        if self.use_bn:
            x = self.bn0(x)
        if self.dropblock_prob > 0.0:
            x = self.dropblock(x)
        x = self.relu(x)

        batch, rchannel = x.shape[:2]
        if self.radix > 1:
            if torch.__version__ < '1.5':
                splited = torch.split(x, int(rchannel//self.radix), dim=1)
            else:
                splited = torch.split(x, rchannel//self.radix, dim=1)
            gap = sum(splited)
        else:
            gap = x
        gap = F.adaptive_avg_pool2d(gap, 1)
        gap = self.fc1(gap)

        if self.use_bn:
            gap = self.bn1(gap)
        gap = self.relu(gap)

        atten = self.fc2(gap)
        atten = self.rsoftmax(atten).view(batch, -1, 1, 1)

        if self.radix > 1:
            if torch.__version__ < '1.5':
                attens = torch.split(atten, int(rchannel//self.radix), dim=1)
            else:
                attens = torch.split(atten, rchannel//self.radix, dim=1)
            out = sum([att*split for (att, split) in zip(attens, splited)])
        else:
            out = atten * x
        return out.contiguous()

class rSoftMax(nn.Module):
    def __init__(self, radix, cardinality):
        super().__init__()
        self.radix = radix
        self.cardinality = cardinality

    def forward(self, x):
        batch = x.size(0)
        if self.radix > 1:
            x = x.view(batch, self.cardinality, self.radix, -1).transpose(1, 2)
            x = F.softmax(x, dim=1)
            x = x.reshape(batch, -1)
        else:
            x = torch.sigmoid(x)
        return x

class GlobalAvgPool2d(nn.Module):
    def __init__(self):
        """Global average pooling over the input's spatial dimensions"""
        super(GlobalAvgPool2d, self).__init__()

    def forward(self, inputs):
        return nn.functional.adaptive_avg_pool2d(inputs, 1).view(inputs.size(0), -1)

class Bottleneck(nn.Module):
    """ResNet Bottleneck
    """
    expansion = 4
    def __init__(self, inplanes, planes, stride=1, downsample=None,
                 radix=1, cardinality=1, bottleneck_width=64,
                 avd=False, avd_first=False, dilation=1, is_first=False,
                 rectified_conv=False, rectify_avg=False,
                 norm_layer=None, dropblock_prob=0.0, last_gamma=False):
        super(Bottleneck, self).__init__()
        group_width = int(planes * (bottleneck_width / 64.)) * cardinality
        self.conv1 = nn.Conv2d(inplanes, group_width, kernel_size=1, bias=False)
        self.bn1 = norm_layer(group_width)
        self.dropblock_prob = dropblock_prob
        self.radix = radix
        self.avd = avd and (stride > 1 or is_first)
        self.avd_first = avd_first

        if self.avd:
            self.avd_layer = nn.AvgPool2d(3, stride, padding=1)
            stride = 1

        if dropblock_prob > 0.0:
            self.dropblock1 = DropBlock2D(dropblock_prob, 3)
            if radix == 1:
                self.dropblock2 = DropBlock2D(dropblock_prob, 3)
            self.dropblock3 = DropBlock2D(dropblock_prob, 3)

        if radix >= 1:
            self.conv2 = SplAtConv2d(
                group_width, group_width, kernel_size=3,
                stride=stride, padding=dilation,
                dilation=dilation, groups=cardinality, bias=False,
                radix=radix, rectify=rectified_conv,
                rectify_avg=rectify_avg,
                norm_layer=norm_layer,
                dropblock_prob=dropblock_prob)
        elif rectified_conv:
            from rfconv import RFConv2d
            self.conv2 = RFConv2d(
                group_width, group_width, kernel_size=3, stride=stride,
                padding=dilation, dilation=dilation,
                groups=cardinality, bias=False,
                average_mode=rectify_avg)
            self.bn2 = norm_layer(group_width)
        else:
            self.conv2 = nn.Conv2d(
                group_width, group_width, kernel_size=3, stride=stride,
                padding=dilation, dilation=dilation,
                groups=cardinality, bias=False)
            self.bn2 = norm_layer(group_width)

        self.conv3 = nn.Conv2d(
            group_width, planes * 4, kernel_size=1, bias=False)
        self.bn3 = norm_layer(planes*4)

        if last_gamma:
            from torch.nn.init import zeros_
            zeros_(self.bn3.weight)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.dilation = dilation
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        if self.dropblock_prob > 0.0:
            out = self.dropblock1(out)
        out = self.relu(out)

        if self.avd and self.avd_first:
            out = self.avd_layer(out)

        out = self.conv2(out)
        if self.radix == 0:
            out = self.bn2(out)
            if self.dropblock_prob > 0.0:
                out = self.dropblock2(out)
            out = self.relu(out)

        if self.avd and not self.avd_first:
            out = self.avd_layer(out)

        out = self.conv3(out)
        out = self.bn3(out)
        if self.dropblock_prob > 0.0:
            out = self.dropblock3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out

class ResNeSt(nn.Module):
    """ResNet Variants

    Parameters
    ----------
    block : Block
        Class for the residual block. Options are BasicBlockV1, BottleneckV1.
    layers : list of int
        Numbers of layers in each block
    classes : int, default 1000
        Number of classification classes.
    dilated : bool, default False
        Applying dilation strategy to pretrained ResNet yielding a stride-8 model,
        typically used in Semantic Segmentation.
    norm_layer : object
        Normalization layer used in backbone network (default: :class:`mxnet.gluon.nn.BatchNorm`;
        for Synchronized Cross-GPU BachNormalization).

    Reference:

        - He, Kaiming, et al. "Deep residual learning for image recognition." Proceedings of the IEEE conference on computer vision and pattern recognition. 2016.

        - Yu, Fisher, and Vladlen Koltun. "Multi-scale context aggregation by dilated convolutions."
    """
    def __init__(self, block, layers, radix=1, groups=1, bottleneck_width=64,
                 num_classes=1000, dilated=False, dilation=1,
                 deep_stem=False, stem_width=64, avg_down=False,
                 rectified_conv=False, rectify_avg=False,
                 avd=False, avd_first=False,
                 final_drop=0.0, dropblock_prob=0,
                 last_gamma=False, norm_layer=nn.BatchNorm2d):
        self.cardinality = groups
        self.bottleneck_width = bottleneck_width
        # ResNet-D params
        self.inplanes = stem_width*2 if deep_stem else 64
        self.avg_down = avg_down
        self.last_gamma = last_gamma
        # ResNeSt params
        self.radix = radix
        self.avd = avd
        self.avd_first = avd_first

        super(ResNeSt, self).__init__()
        self.rectified_conv = rectified_conv
        self.rectify_avg = rectify_avg
        if rectified_conv:
            from rfconv import RFConv2d
            conv_layer = RFConv2d
        else:
            conv_layer = nn.Conv2d
        conv_kwargs = {'average_mode': rectify_avg} if rectified_conv else {}
        if deep_stem:
            self.conv1 = nn.Sequential(
                conv_layer(3, stem_width, kernel_size=3, stride=2, padding=1, bias=False, **conv_kwargs),
                norm_layer(stem_width),
                nn.ReLU(inplace=True),
                conv_layer(stem_width, stem_width, kernel_size=3, stride=1, padding=1, bias=False, **conv_kwargs),
                norm_layer(stem_width),
                nn.ReLU(inplace=True),
                conv_layer(stem_width, stem_width*2, kernel_size=3, stride=1, padding=1, bias=False, **conv_kwargs),
            )
        else:
            self.conv1 = conv_layer(3, 64, kernel_size=7, stride=2, padding=3,
                                   bias=False, **conv_kwargs)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0], norm_layer=norm_layer, is_first=False)
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2, norm_layer=norm_layer)
        if dilated or dilation == 4:
            self.layer3 = self._make_layer(block, 256, layers[2], stride=1,
                                           dilation=2, norm_layer=norm_layer,
                                           dropblock_prob=dropblock_prob)
            self.layer4 = self._make_layer(block, 512, layers[3], stride=1,
                                           dilation=4, norm_layer=norm_layer,
                                           dropblock_prob=dropblock_prob)
        elif dilation==2:
            self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                           dilation=1, norm_layer=norm_layer,
                                           dropblock_prob=dropblock_prob)
            self.layer4 = self._make_layer(block, 512, layers[3], stride=1,
                                           dilation=2, norm_layer=norm_layer,
                                           dropblock_prob=dropblock_prob)
        else:
            self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                           norm_layer=norm_layer,
                                           dropblock_prob=dropblock_prob)
            self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                           norm_layer=norm_layer,
                                           dropblock_prob=dropblock_prob)
        self.avgpool = GlobalAvgPool2d()
        self.drop = nn.Dropout(final_drop) if final_drop > 0.0 else None
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, norm_layer):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1, dilation=1, norm_layer=None,
                    dropblock_prob=0.0, is_first=True):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            down_layers = []
            if self.avg_down:
                if dilation == 1:
                    down_layers.append(nn.AvgPool2d(kernel_size=stride, stride=stride,
                                                    ceil_mode=True, count_include_pad=False))
                else:
                    down_layers.append(nn.AvgPool2d(kernel_size=1, stride=1,
                                                    ceil_mode=True, count_include_pad=False))
                down_layers.append(nn.Conv2d(self.inplanes, planes * block.expansion,
                                             kernel_size=1, stride=1, bias=False))
            else:
                down_layers.append(nn.Conv2d(self.inplanes, planes * block.expansion,
                                             kernel_size=1, stride=stride, bias=False))
            down_layers.append(norm_layer(planes * block.expansion))
            downsample = nn.Sequential(*down_layers)

        layers = []
        if dilation == 1 or dilation == 2:
            layers.append(block(self.inplanes, planes, stride, downsample=downsample,
                                radix=self.radix, cardinality=self.cardinality,
                                bottleneck_width=self.bottleneck_width,
                                avd=self.avd, avd_first=self.avd_first,
                                dilation=1, is_first=is_first, rectified_conv=self.rectified_conv,
                                rectify_avg=self.rectify_avg,
                                norm_layer=norm_layer, dropblock_prob=dropblock_prob,
                                last_gamma=self.last_gamma))
        elif dilation == 4:
            layers.append(block(self.inplanes, planes, stride, downsample=downsample,
                                radix=self.radix, cardinality=self.cardinality,
                                bottleneck_width=self.bottleneck_width,
                                avd=self.avd, avd_first=self.avd_first,
                                dilation=2, is_first=is_first, rectified_conv=self.rectified_conv,
                                rectify_avg=self.rectify_avg,
                                norm_layer=norm_layer, dropblock_prob=dropblock_prob,
                                last_gamma=self.last_gamma))
        else:
            raise RuntimeError("=> unknown dilation size: {}".format(dilation))

        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes,
                                radix=self.radix, cardinality=self.cardinality,
                                bottleneck_width=self.bottleneck_width,
                                avd=self.avd, avd_first=self.avd_first,
                                dilation=dilation, rectified_conv=self.rectified_conv,
                                rectify_avg=self.rectify_avg,
                                norm_layer=norm_layer, dropblock_prob=dropblock_prob,
                                last_gamma=self.last_gamma))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        if self.drop:
            x = self.drop(x)
        x = self.fc(x)

        return x

    def forward_feature(self, x, out_block_stage):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x0 = self.maxpool(x)

        x1 = self.layer1(x0)
        x2 = self.layer2(x1)
        x3 = self.layer3(x2)
        x4 = self.layer4(x3)

        if out_block_stage == 1: return [x], x1
        elif out_block_stage == 2: return [x1, x], x2
        elif out_block_stage == 3: return [x2, x1, x], x3
        elif out_block_stage == 4: return [x3, x2, x1, x], x4


def resnet50(pretrained=False, root='~/.encoding/models', **kwargs):
    """Constructs a ResNet-50 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
    if pretrained:
        model.load_state_dict(torch.hub.load_state_dict_from_url(
            resnest_model_urls['resnet50'], progress=True, check_hash=True))
    return model

def resnet101(pretrained=False, root='~/.encoding/models', **kwargs):
    """Constructs a ResNet-101 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
    if pretrained:
        model.load_state_dict(torch.hub.load_state_dict_from_url(
            resnest_model_urls['resnet101'], progress=True, check_hash=True))
    return model

def resnet152(pretrained=False, root='~/.encoding/models', **kwargs):
    """Constructs a ResNet-152 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
    if pretrained:
        model.load_state_dict(torch.hub.load_state_dict_from_url(
            resnest_model_urls['resnet152'], progress=True, check_hash=True))
    return model

In [23]:
import torch
import torch.utils.model_zoo as model_zoo


model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-f37072fd.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-0676ba61.pth',
    'res2net50_v1b_26w_4s': 'https://shanghuagao.oss-cn-beijing.aliyuncs.com/res2net/res2net50_v1b_26w_4s-3cf99910.pth',
    'res2net101_v1b_26w_4s': 'https://shanghuagao.oss-cn-beijing.aliyuncs.com/res2net/res2net101_v1b_26w_4s-0812c246.pth',
    'resnest50': 'https://github.com/zhanghang1989/ResNeSt/releases/download/weights_step1/528c19ca-resnest50.pth'
}

def IS2D_model(args) :
    return MFMSNet(args.num_classes,
                   args.scale_branches,
                   args.frequency_branches,
                   args.frequency_selection,
                   args.block_repetition,
                   args.min_channel,
                   args.min_resolution,
                   args.cnn_backbone)

def load_cnn_backbone_model(backbone_name, pretrained=False, **kwargs):
    if backbone_name=='resnest50':
        model = ResNeSt(Bottleneck, [3, 4, 6, 3],
                        radix=2, groups=1, bottleneck_width=64,
                        deep_stem=True, stem_width=32, avg_down=True,
                        avd=True, avd_first=False, **kwargs)
    else:
        print("Invalid backbone")
        sys.exit()

    if pretrained:
        if backbone_name == 'resnest50':
            model.load_state_dict(torch.load('/kaggle/input/resnest/resnest50-528c19ca.pth'))
        else:
            model.load_state_dict(model_zoo.load_url(model_urls[backbone_name]))

        print("Complete loading your pretrained backbone {}".format(backbone_name))
    return model

def model_to_device(args, model):
    model = model.to(args.device)

    return model

In [24]:
import math
import sys

import torch
import torch.nn as nn
import torch.nn.functional as F


def get_freq_indices(method):
    assert method in ['top1','top2','top4','top8','top16','top32',
                      'bot1','bot2','bot4','bot8','bot16','bot32',
                      'low1','low2','low4','low8','low16','low32']
    num_freq = int(method[3:])
    if 'top' in method:
        all_top_indices_x = [0,0,6,0,0,1,1,4,5,1,3,0,0,0,3,2,4,6,3,5,5,2,6,5,5,3,3,4,2,2,6,1]
        all_top_indices_y = [0,1,0,5,2,0,2,0,0,6,0,4,6,3,5,2,6,3,3,3,5,1,1,2,4,2,1,1,3,0,5,3]
        mapper_x = all_top_indices_x[:num_freq]
        mapper_y = all_top_indices_y[:num_freq]
    elif 'low' in method:
        all_low_indices_x = [0,0,1,1,0,2,2,1,2,0,3,4,0,1,3,0,1,2,3,4,5,0,1,2,3,4,5,6,1,2,3,4]
        all_low_indices_y = [0,1,0,1,2,0,1,2,2,3,0,0,4,3,1,5,4,3,2,1,0,6,5,4,3,2,1,0,6,5,4,3]
        mapper_x = all_low_indices_x[:num_freq]
        mapper_y = all_low_indices_y[:num_freq]
    elif 'bot' in method:
        all_bot_indices_x = [6,1,3,3,2,4,1,2,4,4,5,1,4,6,2,5,6,1,6,2,2,4,3,3,5,5,6,2,5,5,3,6]
        all_bot_indices_y = [6,4,4,6,6,3,1,4,4,5,6,5,2,2,5,1,4,3,5,0,3,1,1,2,4,2,1,1,5,3,3,3]
        mapper_x = all_bot_indices_x[:num_freq]
        mapper_y = all_bot_indices_y[:num_freq]
    else:
        raise NotImplementedError
    return mapper_x, mapper_y

class MultiFrequencyChannelAttention(nn.Module):
    def __init__(self,
                 in_channels,
                 dct_h, dct_w,
                 frequency_branches=16,
                 frequency_selection='top',
                 reduction=16):
        super(MultiFrequencyChannelAttention, self).__init__()

        assert frequency_branches in [1, 2, 4, 8, 16, 32]
        frequency_selection = frequency_selection + str(frequency_branches)

        self.num_freq = frequency_branches
        self.dct_h = dct_h
        self.dct_w = dct_w

        mapper_x, mapper_y = get_freq_indices(frequency_selection)
        self.num_split = len(mapper_x)
        mapper_x = [temp_x * (dct_h // 7) for temp_x in mapper_x]
        mapper_y = [temp_y * (dct_w // 7) for temp_y in mapper_y]

        assert len(mapper_x) == len(mapper_y)

        # fixed DCT init
        for freq_idx in range(frequency_branches):
            self.register_buffer('dct_weight_{}'.format(freq_idx), self.get_dct_filter(dct_h, dct_w, mapper_x[freq_idx], mapper_y[freq_idx], in_channels))

        self.fc = nn.Sequential(
            nn.Conv2d(in_channels, in_channels // reduction, kernel_size=1, stride=1, padding=0, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels // reduction, in_channels, kernel_size=1, stride=1, padding=0, bias=False))

        self.average_channel_pooling = nn.AdaptiveAvgPool2d(1)
        self.max_channel_pooling = nn.AdaptiveMaxPool2d(1)

    def forward(self, x):
        batch_size, C, H, W = x.shape

        x_pooled = x

        if H != self.dct_h or W != self.dct_w:
            x_pooled = torch.nn.functional.adaptive_avg_pool2d(x, (self.dct_h, self.dct_w))

        multi_spectral_feature_avg, multi_spectral_feature_max, multi_spectral_feature_min = 0, 0, 0
        for name, params in self.state_dict().items():
            if 'dct_weight' in name:
                x_pooled_spectral = x_pooled * params
                multi_spectral_feature_avg += self.average_channel_pooling(x_pooled_spectral)
                multi_spectral_feature_max += self.max_channel_pooling(x_pooled_spectral)
                multi_spectral_feature_min += -self.max_channel_pooling(-x_pooled_spectral)
        multi_spectral_feature_avg = multi_spectral_feature_avg / self.num_freq
        multi_spectral_feature_max = multi_spectral_feature_max / self.num_freq
        multi_spectral_feature_min = multi_spectral_feature_min / self.num_freq


        multi_spectral_avg_map = self.fc(multi_spectral_feature_avg).view(batch_size, C, 1, 1)
        multi_spectral_max_map = self.fc(multi_spectral_feature_max).view(batch_size, C, 1, 1)
        multi_spectral_min_map = self.fc(multi_spectral_feature_min).view(batch_size, C, 1, 1)

        multi_spectral_attention_map = F.sigmoid(multi_spectral_avg_map + multi_spectral_max_map + multi_spectral_min_map)

        return x * multi_spectral_attention_map.expand_as(x)

    def get_dct_filter(self, tile_size_x, tile_size_y, mapper_x, mapper_y, in_channels):
        dct_filter = torch.zeros(in_channels, tile_size_x, tile_size_y)

        for t_x in range(tile_size_x):
            for t_y in range(tile_size_y):
                dct_filter[:, t_x, t_y] = self.build_filter(t_x, mapper_x, tile_size_x) * self.build_filter(t_y, mapper_y, tile_size_y)

        return dct_filter

    def build_filter(self, pos, freq, POS):
        result = math.cos(math.pi * freq * (pos + 0.5) / POS) / math.sqrt(POS)
        if freq == 0:
            return result
        else:
            return result * math.sqrt(2)

class MFMSAttentionBlock(nn.Module):
    def __init__(self,
                 in_channels,
                 scale_branches=2,
                 frequency_branches=16,
                 frequency_selection='top',
                 block_repetition=1,
                 min_channel=64,
                 min_resolution=8,
                 groups=32):
        super(MFMSAttentionBlock, self).__init__()

        self.scale_branches = scale_branches
        self.frequency_branches = frequency_branches
        self.block_repetition = block_repetition
        self.min_channel = min_channel
        self.min_resolution = min_resolution

        self.multi_scale_branches = nn.ModuleList([])
        for scale_idx in range(scale_branches):
            inter_channel = in_channels // 2**scale_idx
            if inter_channel < self.min_channel: inter_channel = self.min_channel

            self.multi_scale_branches.append(nn.Sequential(
                nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1 + scale_idx, dilation=1 + scale_idx, groups=groups, bias=False),
                nn.BatchNorm2d(in_channels), nn.ReLU(inplace=True),
                nn.Conv2d(in_channels, inter_channel, kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(inter_channel), nn.ReLU(inplace=True)
            ))

        c2wh = dict([(32, 112), (64, 56), (128, 28), (256, 14), (512, 7)])
        self.multi_frequency_branches = nn.ModuleList([])
        self.multi_frequency_branches_conv1 = nn.ModuleList([])
        self.multi_frequency_branches_conv2 = nn.ModuleList([])
        self.alpha_list = nn.ParameterList([nn.Parameter(torch.ones(1)) for _ in range(scale_branches)])
        self.beta_list = nn.ParameterList([nn.Parameter(torch.ones(1)) for _ in range(scale_branches)])

        for scale_idx in range(scale_branches):
            inter_channel = in_channels // 2**scale_idx
            if inter_channel < self.min_channel: inter_channel = self.min_channel

            if frequency_branches > 0:
                self.multi_frequency_branches.append(
                    nn.Sequential(
                        MultiFrequencyChannelAttention(inter_channel, c2wh[in_channels], c2wh[in_channels], frequency_branches, frequency_selection)))
            self.multi_frequency_branches_conv1.append(
                nn.Sequential(
                    nn.Conv2d(inter_channel, 1, kernel_size=1, stride=1, padding=0, bias=False),
                    nn.Sigmoid()))
            self.multi_frequency_branches_conv2.append(
                nn.Sequential(
                    nn.Conv2d(inter_channel, in_channels, kernel_size=3, stride=1, padding=1, bias=False),
                    nn.BatchNorm2d(in_channels), nn.ReLU(inplace=True)))

    def forward(self, x):
        feature_aggregation = 0
        for scale_idx in range(self.scale_branches):
            feature = F.avg_pool2d(x, kernel_size=2 ** scale_idx, stride=2 ** scale_idx, padding=0) if int(x.shape[2] // 2 ** scale_idx) >= self.min_resolution else x
            feature = self.multi_scale_branches[scale_idx](feature)
            if self.frequency_branches > 0:
                feature = self.multi_frequency_branches[scale_idx](feature)
            spatial_attention_map = self.multi_frequency_branches_conv1[scale_idx](feature)
            feature = self.multi_frequency_branches_conv2[scale_idx](feature * (1 - spatial_attention_map) * self.alpha_list[scale_idx] + feature * spatial_attention_map * self.beta_list[scale_idx])
            feature_aggregation += F.interpolate(feature, size=None, scale_factor=2**scale_idx, mode='bilinear', align_corners=None) if (x.shape[2] != feature.shape[2]) or (x.shape[3] != feature.shape[3]) else feature
        feature_aggregation /= self.scale_branches
        feature_aggregation += x

        return feature_aggregation

class UpsampleBlock(nn.Module):
    def __init__(self,
                 in_channels,
                 out_channels,
                 skip_connection_channels,
                 scale_branches=2,
                 frequency_branches=16,
                 frequency_selection='top',
                 block_repetition=1,
                 min_channel=64,
                 min_resolution=8):
        super(UpsampleBlock, self).__init__()

        in_channels = in_channels + skip_connection_channels
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True))

        self.attention_layer = MFMSAttentionBlock(out_channels, scale_branches, frequency_branches, frequency_selection, block_repetition, min_channel, min_resolution)

        self.conv2 = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True))

    def forward(self, x, skip_connection=None):
        x = F.interpolate(x, size=None, scale_factor=2, mode='bilinear', align_corners=None)

        x = torch.cat([x, skip_connection], dim=1)
        x = self.conv1(x)
        x = self.attention_layer(x)
        x = self.conv2(x)

        return x

class CascadedSubDecoderBinary(nn.Module):
    def __init__(self,
                 in_channels,
                 num_classes,
                 scale_factor,
                 interpolation_mode='bilinear'):
        super(CascadedSubDecoderBinary, self).__init__()

        self.output_map_conv = nn.Conv2d(in_channels, num_classes, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0))
        self.output_distance_conv = nn.Conv2d(in_channels, num_classes, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0))
        self.output_boundary_conv = nn.Conv2d(in_channels, num_classes, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0))

        self.upsample = nn.Upsample(scale_factor=scale_factor, mode=interpolation_mode, align_corners=True)
        self.count = 0

    def forward(self, x):
        map = self.output_map_conv(x) # B, 1, H, W
        distance = self.output_distance_conv(x) * torch.sigmoid(map)
        boundary = self.output_boundary_conv(x) * torch.sigmoid(distance)

        boundary = self.upsample(boundary)
        distance = self.upsample(distance) + torch.sigmoid(boundary)
        map = self.upsample(map) + torch.sigmoid(distance)

        return map, distance, boundary

class MFMSNet(nn.Module):
    def __init__(self,
                 num_classes=1,
                 scale_branches=2,
                 frequency_branches=16,
                 frequency_selection='top',
                 block_repetition=1,
                 min_channel=64,
                 min_resolution=8,
                 cnn_backbone='resnet50'):
        super(MFMSNet, self).__init__()


        self.num_classes = num_classes

        self.feature_encoding = load_cnn_backbone_model(backbone_name=cnn_backbone, pretrained=True)
        
        self.in_channels = 2048
        self.skip_channel_list = [1024, 512, 256, 64]
        self.decoder_channel_list = [256, 128, 64, 32]

        self.feature_encoding.fc = nn.Identity()

        self.skip_channel_down_list = [64, 64, 64, 64]

        self.skip_connection1 = nn.Sequential(
            nn.Conv2d(self.skip_channel_list[0], self.skip_channel_down_list[0], kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(self.skip_channel_down_list[0]), nn.ReLU(inplace=True))
        self.skip_connection2 = nn.Sequential(
            nn.Conv2d(self.skip_channel_list[1], self.skip_channel_down_list[1], kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(self.skip_channel_down_list[1]), nn.ReLU(inplace=True))
        self.skip_connection3 = nn.Sequential(
            nn.Conv2d(self.skip_channel_list[2], self.skip_channel_down_list[2], kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(self.skip_channel_down_list[2]), nn.ReLU(inplace=True))
        self.skip_connection4 = nn.Sequential(
            nn.Conv2d(self.skip_channel_list[3], self.skip_channel_down_list[3], kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(self.skip_channel_down_list[3]), nn.ReLU(inplace=True))

        self.decoder_stage1 = UpsampleBlock(self.in_channels, self.decoder_channel_list[0], self.skip_channel_down_list[0], scale_branches, frequency_branches, frequency_selection, block_repetition, min_channel, min_resolution)
        self.decoder_stage2 = UpsampleBlock(self.decoder_channel_list[0], self.decoder_channel_list[1], self.skip_channel_down_list[1], scale_branches, frequency_branches, frequency_selection, block_repetition, min_channel, min_resolution)
        self.decoder_stage3 = UpsampleBlock(self.decoder_channel_list[1], self.decoder_channel_list[2], self.skip_channel_down_list[2], scale_branches, frequency_branches, frequency_selection, block_repetition, min_channel, min_resolution)
        self.decoder_stage4 = UpsampleBlock(self.decoder_channel_list[2], self.decoder_channel_list[3], self.skip_channel_down_list[3], scale_branches, frequency_branches, frequency_selection, block_repetition, min_channel, min_resolution)

        # Sub-Decoder
        self.sub_decoder_stage1 = CascadedSubDecoderBinary(self.decoder_channel_list[0], num_classes, scale_factor=16)
        self.sub_decoder_stage2 = CascadedSubDecoderBinary(self.decoder_channel_list[1], num_classes, scale_factor=8)
        self.sub_decoder_stage3 = CascadedSubDecoderBinary(self.decoder_channel_list[2], num_classes, scale_factor=4)
        self.sub_decoder_stage4 = CascadedSubDecoderBinary(self.decoder_channel_list[3], num_classes, scale_factor=2)

    def forward(self, x, mode='train'):
        if x.size()[1] == 1: x = x.repeat(1, 3, 1, 1)
        _, _, H, W = x.shape

        features, x = self.feature_encoding.forward_feature(x, out_block_stage=4)

        x1 = self.decoder_stage1(x, self.skip_connection1(features[0]))
        x2 = self.decoder_stage2(x1, self.skip_connection2(features[1]))
        x3 = self.decoder_stage3(x2, self.skip_connection3(features[2]))
        x4 = self.decoder_stage4(x3, self.skip_connection4(features[3]))
        if mode == 'train':
            map_output1, distance_output1, boundary_output1 = self.sub_decoder_stage1(x1)
            map_output2, distance_output2, boundary_output2 = self.sub_decoder_stage2(x2)
            map_output3, distance_output3, boundary_output3 = self.sub_decoder_stage3(x3)
            map_output4, distance_output4, boundary_output4 = self.sub_decoder_stage4(x4)


            return [map_output1, distance_output1, boundary_output1], \
                   [map_output2, distance_output2, boundary_output2], \
                   [map_output3, distance_output3, boundary_output3], \
                   [map_output4, distance_output4, boundary_output4]
        else:
            map, _, _ = self.sub_decoder_stage4(x4)

            return map

    def _calculate_criterion(self, y_pred, y_true):
        if isinstance(y_pred, tuple):
            y_pred = y_pred[0][0]  
        loss = self.structure_loss(y_pred, y_true)
        return loss

    import torch
    import torch.nn.functional as F

    def structure_loss(self, pred, mask):
        if not isinstance(pred, torch.Tensor) or not isinstance(mask, torch.Tensor):
            raise TypeError("Both pred and mask should be torch tensors")

        weit = 1 + 5 * torch.abs(F.avg_pool2d(mask, kernel_size=31, stride=1, padding=15) - mask)

        wbce = F.binary_cross_entropy_with_logits(pred, mask, reduction='none')

        wbce = (weit * wbce).sum(dim=(2, 3)) / weit.sum(dim=(2, 3))

        pred_prob = torch.sigmoid(pred)

        inter = ((pred_prob * mask) * weit).sum(dim=(2, 3))
        union = ((pred_prob + mask) * weit).sum(dim=(2, 3))

        wiou = 1 - (inter + 1) / (union - inter + 1)

        total_loss = (wbce + wiou).mean()  

        return total_loss


In [25]:
import scipy
import numpy as np
from skimage import morphology
from skimage.measure import label, regionprops
from scipy import ndimage
from scipy.ndimage import convolve, distance_transform_edt as bwdist
from sklearn.metrics import (accuracy_score, precision_score, recall_score, f1_score, jaccard_score, confusion_matrix)


metrics_list = ['DSC', 'IoU', 'WeightedF-Measure', 'S-Measure', 'E-Measure', 'MAE']

class BMIS_Metrics_Calculator(object):
    def __init__(self, metrics_list):
        super(BMIS_Metrics_Calculator).__init__()

        self.metrics_list = metrics_list

        self.smooth = 1e-5

        self.total_metrics_dict = dict()

        for metric in self.metrics_list:
            self.total_metrics_dict[metric] = list()

    def get_metrics_dict(self, y_pred, y_true):
        y_true = y_true.squeeze().detach().cpu().numpy()
        y_pred = (y_pred.squeeze().detach().cpu().numpy() >= 0.5).astype(np.int_)
        # y_pred = y_pred.squeeze().detach().cpu().numpy()

        y_true = np.asarray(y_true, np.float32)
        y_true /= (y_true.max() + 1e-8)
        y_true[y_true > 0.5] = 1; y_true[y_true != 1] = 0

        metrics_dict = dict()

        for metric in self.metrics_list:
            metrics_dict[metric] = 0
            result = self.get_metrics(metric, y_pred, y_true)
            if np.isnan(result): result = 1e-6
            metrics_dict[metric] = result

        return metrics_dict

    def get_metrics(self, metric, y_pred, y_true):
        if metric == 'Accuracy': return self.calculate_Accuracy(y_pred, y_true)
        elif metric == 'DSC': return self.calculate_DSC(y_pred, y_true)
        elif metric == 'Precision': return self.calculate_Precision(y_pred, y_true)
        elif metric == 'Recall': return self.calculate_Recall(y_pred, y_true)
        elif metric == 'Specificity': return self.calculate_Specificity(y_pred, y_true)
        elif metric == 'Jaccard': return self.calculate_Jaccard(y_pred, y_true)
        elif metric == 'IoU': return self.calculate_IoU(y_pred, y_true)
        elif metric == 'WeightedF-Measure':  return self.calculate_WeightedFMeasure(y_pred, y_true)
        elif metric == 'F-Measure': return self.calculate_FMeasure(y_pred, y_true)
        elif metric == 'S-Measure': return self.calculate_SMeasure(y_pred, y_true)
        elif metric == 'E-Measure': return self.calculate_EMeasure(y_pred, y_true)
        elif metric == 'MAE': return self.calculate_MAE(y_pred, y_true)

    def calculate_Accuracy(self, y_pred, y_true):
        y_true = np.asarray(y_true.flatten(), dtype=np.int64)
        y_pred = np.asarray(y_pred.flatten(), dtype=np.int64)

        return accuracy_score(y_true, y_pred)

    def calculate_Precision(self, y_pred, y_true):
        y_true = np.asarray(y_true.flatten(), dtype=np.int64)
        y_pred = np.asarray(y_pred.flatten(), dtype=np.int64)

        return precision_score(y_true, y_pred)

    def calculate_Recall(self, y_pred, y_true):
        y_true = np.asarray(y_true.flatten(), dtype=np.int64)
        y_pred = np.asarray(y_pred.flatten(), dtype=np.int64)

        return recall_score(y_true, y_pred)

    def calculate_Specificity(self, y_pred, y_true):
        y_true = np.asarray(y_true.flatten(), dtype=np.int64)
        y_pred = np.asarray(y_pred.flatten(), dtype=np.int64)

        cm = list(confusion_matrix(y_true, y_pred).ravel())

        if len(cm) == 1: cm += [0, 0, 0]

        tn, fp, fn, tp = cm
        specificity = tn / (tn+fp)

        return specificity

    def calculate_Jaccard(self, y_pred, y_true):
        y_true = np.asarray(y_true.flatten(), dtype=np.int64)
        y_pred = np.asarray(y_pred.flatten(), dtype=np.int64)

        return jaccard_score(y_true, y_pred)

    def calculate_DSC(self, y_pred, y_true):
        y_true = np.asarray(y_true.flatten(), dtype=np.int64)
        y_pred = np.asarray(y_pred.flatten(), dtype=np.int64)

        intersection = np.sum(y_true * y_pred)
        return (2. * intersection + self.smooth) / (np.sum(y_true) + np.sum(y_pred) + self.smooth)

    def calculate_IoU(self, y_pred, y_true):
        y_pred_f = y_pred > 0.5
        y_true_f = y_true > 0.5

        intersection_f = (y_pred_f & y_true_f).sum()
        union_f = (y_pred_f | y_true_f).sum()

        iou_f = (intersection_f + self.smooth) / (union_f + self.smooth)

        return iou_f

    def calculate_WeightedFMeasure(self, y_pred, y_true):
        Dst, Idxt = bwdist(y_true == 0, return_indices=True)

        E = np.abs(y_pred - y_true)
        Et = np.copy(E)
        Et[y_true == 0] = Et[Idxt[0][y_true == 0], Idxt[1][y_true == 0]]

        K = self.matlab_style_gauss2D((7, 7), sigma=5)
        EA = convolve(Et, weights=K, mode='constant', cval=0)
        MIN_E_EA = np.where(np.array(y_true, dtype=np.bool_) & (EA < E), EA, E)

        B = np.where(y_true == 0, 2 - np.exp(np.log(0.5) / 5 * Dst), np.ones_like(y_true))
        Ew = MIN_E_EA * B

        TPw = np.sum(y_true) - np.sum(Ew[y_true == 1])
        FPw = np.sum(Ew[y_true == 0])

        R = 1 - np.mean(Ew[np.array(y_true, dtype=np.bool_)])
        P = TPw / (1e-6 + TPw + FPw)

        Q = 2 * R * P / (1e-6 + R + P)

        return Q

    def calculate_FMeasure(self, y_pred, y_true):
        th = 2 * y_pred.mean()
        if th > 1:  th = 1
        binary = np.zeros_like(y_pred)
        binary[y_pred >= th] = 1
        hard_gt = np.zeros_like(y_true)
        hard_gt[y_true > 0.5] = 1
        tp = (binary * hard_gt).sum()
        if tp == 0:
            meanF = 0
        else:
            pre = tp / binary.sum()
            rec = tp / hard_gt.sum()
            meanF = 1.3 * pre * rec / (0.3 * pre + rec)

        return meanF

    def calculate_SMeasure(self, y_pred, y_true):
        y = np.mean(y_true)

        if y == 0:
            score = 1 - np.mean(y_pred)
        elif y == 1:
            score = np.mean(y_pred)
        else:
            score = 0.5 * self.object(y_pred, y_true) + 0.5 * self.region(y_pred, y_true)
        return score

    def calculate_EMeasure(self, y_pred, y_true):
        th = 2 * y_pred.mean()
        if th > 1:
            th = 1
        FM = np.zeros(y_true.shape)
        FM[y_pred >= th] = 1
        FM = np.array(FM,dtype=bool)
        GT = np.array(y_true,dtype=bool)
        dFM = np.double(FM)
        if (sum(sum(np.double(GT)))==0):
            enhanced_matrix = 1.0-dFM
        elif (sum(sum(np.double(~GT)))==0):
            enhanced_matrix = dFM
        else:
            dGT = np.double(GT)
            align_matrix = self.AlignmentTerm(dFM, dGT)
            enhanced_matrix = self.EnhancedAlignmentTerm(align_matrix)
        [w, h] = np.shape(GT)
        score = sum(sum(enhanced_matrix))/ (w * h - 1 + 1e-8)
        return score

    def calculate_MAE(self, y_pred, y_true):
        return np.mean(np.abs(y_pred - y_true))

    def matlab_style_gauss2D(self, shape=(7, 7), sigma=5):
        """
        2D gaussian mask - should give the same result as MATLAB's
        fspecial('gaussian',[shape],[sigma])
        """
        m, n = [(ss - 1.) / 2. for ss in shape]
        y, x = np.ogrid[-m:m + 1, -n:n + 1]
        h = np.exp(-(x * x + y * y) / (2. * sigma * sigma))
        h[h < np.finfo(h.dtype).eps * h.max()] = 0
        sumh = h.sum()
        if sumh != 0:
            h /= sumh
        return h

    def object(self, pred, gt):
        fg = pred * gt
        bg = (1 - pred) * (1 - gt)

        u = np.mean(gt)

        return u * self.s_object(fg, gt) + (1 - u) * self.s_object(bg, np.logical_not(gt))

    def s_object(self, in1, in2):
        in1 = np.array(in1, dtype=np.int64)
        in2 = np.array(in2, dtype=np.int64)

        x = np.mean(in1[in2])
        sigma_x = np.std(in1[in2])
        return 2 * x / (pow(x, 2) + 1 + sigma_x + 1e-8)

    def region(self, pred, gt):
        [y, x] = ndimage.center_of_mass(gt)
        y = int(round(y)) + 1
        x = int(round(x)) + 1
        [gt1, gt2, gt3, gt4, w1, w2, w3, w4] = self.divideGT(gt, x, y)
        pred1, pred2, pred3, pred4 = self.dividePred(pred, x, y)

        score1 = self.ssim(pred1, gt1)
        score2 = self.ssim(pred2, gt2)
        score3 = self.ssim(pred3, gt3)
        score4 = self.ssim(pred4, gt4)

        return w1 * score1 + w2 * score2 + w3 * score3 + w4 * score4

    def divideGT(self, gt, x, y):
        h, w = gt.shape
        area = h * w
        LT = gt[0:y, 0:x]
        RT = gt[0:y, x:w]
        LB = gt[y:h, 0:x]
        RB = gt[y:h, x:w]

        w1 = x * y / area
        w2 = y * (w - x) / area
        w3 = (h - y) * x / area
        w4 = (h - y) * (w - x) / area

        return LT, RT, LB, RB, w1, w2, w3, w4

    def dividePred(self, pred, x, y):
        h, w = pred.shape
        LT = pred[0:y, 0:x]
        RT = pred[0:y, x:w]
        LB = pred[y:h, 0:x]
        RB = pred[y:h, x:w]

        return LT, RT, LB, RB

    def ssim(self, in1, in2):
        in2 = np.float32(in2)
        h, w = in1.shape
        N = h * w

        x = np.mean(in1)
        y = np.mean(in2)
        sigma_x = np.var(in1)
        sigma_y = np.var(in2)
        sigma_xy = np.sum((in1 - x) * (in2 - y)) / (N - 1)

        alpha = 4 * x * y * sigma_xy
        beta = (x * x + y * y) * (sigma_x + sigma_y)

        if alpha != 0:
            score = alpha / (beta + 1e-8)
        elif alpha == 0 and beta == 0:
            score = 1
        else:
            score = 0

        return score

    def AlignmentTerm(self,dFM,dGT):
        mu_FM = np.mean(dFM)
        mu_GT = np.mean(dGT)
        align_FM = dFM - mu_FM
        align_GT = dGT - mu_GT
        align_Matrix = 2. * (align_GT * align_FM)/ (align_GT* align_GT + align_FM* align_FM + 1e-8)
        return align_Matrix

    def EnhancedAlignmentTerm(self,align_Matrix):
        enhanced = np.power(align_Matrix + 1,2) / 4
        return enhanced


In [26]:
import os
import sys


from torch.utils.data import DataLoader


class BaseSegmentationExperiment(object):
    def __init__(self, args):
        super(BaseSegmentationExperiment, self).__init__()

        self.args = args

        self.args.device = get_device()

        
        print("STEP1. Load {} Training Dataset Loader...".format(args.train_data_type))
        train_image_transform, train_target_transform = self.transform_generator()
        if args.train_data_type == 'T1N':
            train_dataset = Covid19CTScanDataset(args.train_dataset_dir, mode='train', transform=train_image_transform, target_transform=train_target_transform)
            self.train_loader = DataLoader(train_dataset, batch_size=12, shuffle=False, num_workers=args.num_workers, pin_memory=True)

        print("STEP2. Load MADGNet ...")
        self.model = IS2D_model(args).to(self.args.device)

        print("STEP2. Load MADGNet ...")
        self.model = IS2D_model(args)

    def fix_seed(self):
        random.seed(4321)
        np.random.seed(4321)
        torch.manual_seed(4321)
        torch.cuda.manual_seed(4321)
        torch.cuda.manual_seed_all(4321)

    
            

In [27]:
import os
import random
import sys

import torch
from torch.utils.data import Dataset
from sklearn.model_selection import train_test_split

import pandas as pd
from PIL import Image

class Covid19CTScanDataset(Dataset) :
    def __init__(self, dataset_dir, mode, transform=None, target_transform=None):
        super(Covid19CTScanDataset, self).__init__()
        self.image_folder = 'T1N'
        self.label_folder = 'SEGMENTED'
        self.mode = mode
        self.mapping_dir = '/kaggle/input/mapping-files'
        self.dataset_dir = '/kaggle/input/brats2021/datasetProcessed'
        self.transform = transform
        self.target_transform = target_transform
        self.frame = pd.read_csv(os.path.join(self.mapping_dir, '{}_frame.csv'.format(mode)))#[:10]

        print(len(self.frame))

        if mode == 'train':
            self.frame = pd.read_csv(os.path.join(self.mapping_dir, 'train_frame.csv'))
            train, val = train_test_split(self.frame, test_size=0.2, shuffle=False, random_state=4321)

            print(len(train))
            print(len(val))

            print(train)

    def __len__(self):
        return len(self.frame)

    def __getitem__(self, idx):
        image_path = os.path.join(self.dataset_dir, self.image_folder, self.frame.image_path[idx])
        label_path = os.path.join(self.dataset_dir, self.label_folder, self.frame.mask_path[idx])

        image = Image.open(image_path).convert('L')
        label = Image.open(label_path).convert('L')

        if self.transform:
            seed = random.randint(0, 2 ** 32)
            self._set_seed(seed); image = self.transform(image)
            self._set_seed(seed); label = self.target_transform(label)

        label[label >= 0.5] = 1; label[label < 0.5] = 0

        return image, label

    def _set_seed(self, seed):
        random.seed(seed)
        torch.manual_seed(seed)

In [28]:
from tqdm import tqdm  
import torch
import random
import numpy as np
from torch.utils.data import DataLoader
from sklearn.metrics import confusion_matrix
from sklearn.metrics import precision_score, recall_score, accuracy_score, jaccard_score

In [29]:
from types import SimpleNamespace
import torch
import torchvision.transforms as transforms
import numpy as np

def get_device():
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    print("You are using \"{}\" device.".format(device))
    return device

class BMISegmentationExperiment(object):
    def __init__(self, args):
        if isinstance(args, dict):
            args = SimpleNamespace(**args)

        super().__init__()

        self.args = args  
        self.args.device = get_device()  
        
        print("STEP1. Load {} Training Dataset Loader...".format(args.train_data_type))
        train_image_transform, train_target_transform = self.transform_generator()
       
        train_dataset = Covid19CTScanDataset(args.train_dataset_dir, mode='train', transform=train_image_transform, target_transform=train_target_transform)
        self.train_loader = DataLoader(train_dataset, batch_size=12, shuffle=False, num_workers=args.num_workers, pin_memory=True)

        print("STEP2. Load MADGNet ...")
        self.model = IS2D_model(args).to(self.args.device)


    def fix_seed(self):
        random.seed(4321)
        np.random.seed(4321)
        torch.manual_seed(4321)
        torch.cuda.manual_seed(4321)
        torch.cuda.manual_seed_all(4321)

    def forward(self, image, target, mode):
        image, target = image.to(self.args.device), target.to(self.args.device)

        with torch.cuda.amp.autocast(enabled=True):
            output = self.model(image, mode)
            loss = self.model._calculate_criterion(output, target)

        return loss, output, target

    def get_device():
        device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        print("You are using \"{}\" device.".format(device))

        return device

    def transform_generator(self):
        transform_list = [
            transforms.Resize((self.args.image_size, self.args.image_size)),
            transforms.ToTensor(),
        ]

        target_transform_list = [
            transforms.Resize((self.args.image_size, self.args.image_size)),
            transforms.ToTensor(),
        ]

        return transforms.Compose(transform_list), transforms.Compose(target_transform_list)

    def validate(self, val_loader):
        self.model.eval()
        total_loss = 0
        with torch.no_grad():
            for image, target in val_loader:
                loss, _, _ = self.forward(image, target, mode='val')
                total_loss += loss.item()

        avg_loss = total_loss / len(val_loader)
        return avg_loss

    def train(self):
        val_image_transform, val_target_transform = self.transform_generator()
        val_dataset = Covid19CTScanDataset(self.args.train_dataset_dir, mode='val', transform=val_image_transform,
                                           target_transform=val_target_transform)
        val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False, num_workers=self.args.num_workers,
                                pin_memory=True)

        best_val_loss = float('inf')
        optimizer = torch.optim.Adam(self.model.parameters(), lr=0.001)

        for epoch in range(self.args.final_epoch):
            self.model.train()
            total_train_loss = 0

            # Training loop
            for image, target in self.train_loader:
                loss, _, _ = self.forward(image, target, mode='train')
                total_train_loss += loss.item()

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            avg_train_loss = total_train_loss / len(self.train_loader)

            val_loss = self.validate(val_loader)
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                torch.save(self.model.state_dict(), 'best_model.pth')

            print(f"Epoch [{epoch + 1}/{self.args.final_epoch}], Train Loss: {avg_train_loss:.4f}, Validation Loss: {val_loss:.4f}")

        print("Training completed. Saving the final model weights to 'final_model.pth'.")
        torch.save(self.model.state_dict(), 'final_model.pth')

        print("Final model saved.")


In [30]:
import os
import sys

args = {
    "num_workers": 4,
    "data_path": "/kaggle/input/brats2021/datasetProcessed",
    "train_data_type": "T1C",
    "final_epoch": 50,
    "metric_list": ['DSC', 'IoU'],
    "metrics_list": ['DSC', 'IoU', 'WeightedF-Measure', 'S-Measure', 'E-Measure', 'MAE'],
    "num_channels": 1,
    "image_size": 352,
    "in_channels": 2048,
    "num_classes": 1,
    "scale_branches": 2,
    "frequency_branches": 16,
    "frequency_selection": 'top',
    "block_repetition": 1,
    "min_channel": 64,
    "min_resolution": 8,
    "groups": 32,
    "cnn_backbone": 'resnest50'
}

def IS2D_train(args):
    try:
        print("Started training for MRI Modality:", args["train_data_type"])
        
        if not args["data_path"] or not args["train_data_type"]:
            raise ValueError("Both 'data_path' and 'train_d self.args.device = get_device(ata_type' must be defined in the arguments.")
        
        args["train_dataset_dir"] = os.path.join(args["data_path"], args["train_data_type"])
        if not os.path.exists(args["train_dataset_dir"]):
            raise FileNotFoundError(f"The dataset directory '{args['train_dataset_dir']}' does not exist.")

        experiment = BMISegmentationExperiment(args)
        experiment.train()

    except KeyError as e:
        print(f"KeyError: Missing key in arguments - {str(e)}")
        sys.exit(1)
    except FileNotFoundError as e:
        print(f"FileNotFoundError: {str(e)}")
        sys.exit(1)
    except ValueError as e:
        print(f"ValueError: {str(e)}")
        sys.exit(1)
    except Exception as e:
        print(f"An unexpected error occurred: {str(e)}")
        sys.exit(1)

IS2D_train(args)


Started training for MRI Modality: T1C
You are using "cuda" device.
STEP1. Load T1C Training Dataset Loader...
1038
830
208
     Image_Id                     image_path                      mask_path
0           0  BraTS-GLI-00005-100-t1n_1.jpg  BraTS-GLI-00005-100-seg_1.jpg
1           1  BraTS-GLI-00005-100-t1n_2.jpg  BraTS-GLI-00005-100-seg_2.jpg
2           2  BraTS-GLI-00005-100-t1n_3.jpg  BraTS-GLI-00005-100-seg_3.jpg
3           3  BraTS-GLI-00005-101-t1n_1.jpg  BraTS-GLI-00005-101-seg_1.jpg
4           4  BraTS-GLI-00005-101-t1n_2.jpg  BraTS-GLI-00005-101-seg_2.jpg
..        ...                            ...                            ...
825       825  BraTS-GLI-02826-100-t1n_1.jpg  BraTS-GLI-02826-100-seg_1.jpg
826       826  BraTS-GLI-02826-100-t1n_2.jpg  BraTS-GLI-02826-100-seg_2.jpg
827       827  BraTS-GLI-02826-100-t1n_3.jpg  BraTS-GLI-02826-100-seg_3.jpg
828       828  BraTS-GLI-02826-101-t1n_1.jpg  BraTS-GLI-02826-101-seg_1.jpg
829       829  BraTS-GLI-02826-101-t1n_2

  model.load_state_dict(torch.load('/kaggle/input/resnest/resnest50-528c19ca.pth'))


Complete loading your pretrained backbone resnest50
44


  with torch.cuda.amp.autocast(enabled=True):


Epoch [1/50], Train Loss: 1.1515, Validation Loss: 1.7503
Epoch [2/50], Train Loss: 0.9767, Validation Loss: 1.7863
Epoch [3/50], Train Loss: 0.9474, Validation Loss: 1.8846
Epoch [4/50], Train Loss: 0.9306, Validation Loss: 1.4633
Epoch [5/50], Train Loss: 0.9152, Validation Loss: 1.5874
Epoch [6/50], Train Loss: 0.9104, Validation Loss: 1.4966
Epoch [7/50], Train Loss: 0.9008, Validation Loss: 1.6459
Epoch [8/50], Train Loss: 0.9042, Validation Loss: 1.6893
Epoch [9/50], Train Loss: 0.8910, Validation Loss: 1.7420
Epoch [10/50], Train Loss: 0.8856, Validation Loss: 1.6590
Epoch [11/50], Train Loss: 0.8792, Validation Loss: 1.7112
Epoch [12/50], Train Loss: 0.8796, Validation Loss: 1.6755
Epoch [13/50], Train Loss: 0.8780, Validation Loss: 1.6906
Epoch [14/50], Train Loss: 0.8856, Validation Loss: 1.7171
Epoch [15/50], Train Loss: 0.8867, Validation Loss: 1.7520
Epoch [16/50], Train Loss: 0.8697, Validation Loss: 1.6947
Epoch [17/50], Train Loss: 0.8657, Validation Loss: 1.7312
Epoch 