In [52]:
import torch
from torch import nn, einsum
import torch.nn.functional as F

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

In [53]:
# helper methods


def group_dict_by_key(cond, d):
    """Split ``d`` into two dicts according to a predicate on its keys.

    Args:
        cond: callable applied once to each key; its result is used as a bool.
        d: mapping to split.

    Returns:
        A 2-tuple ``(matching, non_matching)``: items whose key satisfies ``cond``
        go to the first dict, all others to the second. Iteration order of ``d``
        is preserved in both.
    """
    matching, non_matching = {}, {}
    for key in d.keys():
        bucket = matching if cond(key) else non_matching
        bucket[key] = d[key]
    return matching, non_matching


def group_by_key_prefix_and_remove_prefix(prefix, d):
    """Pull out the items of ``d`` whose (string) keys start with ``prefix``.

    Returns:
        A 2-tuple ``(stripped, rest)``: ``stripped`` maps every prefixed key, with
        the prefix removed, to its value; ``rest`` holds the remaining items.
    """
    prefixed, rest = group_dict_by_key(lambda key: key.startswith(prefix), d)
    cut = len(prefix)
    stripped = {key[cut:]: value for key, value in prefixed.items()}
    return stripped, rest

In [54]:
# classes


class LayerNorm(nn.Module):  # layernorm, but done in the channel dimension #1
    """Layer normalization across the channel axis (dim 1).

    Each position is normalized with the mean and the biased (population)
    standard deviation taken over dim 1, then scaled by ``g`` and shifted by
    ``b``. ``g``/``b`` have shape ``(1, dim, 1, 1)``, so ``x`` is expected to be
    ``(B, C, H, W)`` with ``C == dim``.

    Note: ``eps`` is added to the standard deviation, not to the variance as in
    ``torch.nn.LayerNorm``.
    """

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
        self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))

    def forward(self, x):
        variance = x.var(dim=1, unbiased=False, keepdim=True)
        average = x.mean(dim=1, keepdim=True)
        denominator = variance.sqrt() + self.eps
        normalized = (x - average) / denominator
        return normalized * self.g + self.b

In [55]:
class PreNorm(nn.Module):
    """Apply a channel ``LayerNorm`` to the input before calling ``fn``.

    ``forward(x, **kwargs)`` returns ``fn(norm(x), **kwargs)``; no residual
    connection is added here (the caller is responsible for that).
    """

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

class FeedForward(nn.Module):
    """Position-wise MLP built from 1x1 convolutions.

    Expands ``dim`` channels to ``dim * mult``, applies GELU and dropout, then
    projects back to ``dim`` channels followed by a second dropout. Spatial
    size is unchanged.
    """

    def __init__(self, dim, mult = 4, dropout = 0.):
        super().__init__()
        hidden = dim * mult
        stages = [
            nn.Conv2d(dim, hidden, 1),   # expand
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Conv2d(hidden, dim, 1),   # project back
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        return self.net(x)


In [56]:
class DepthWiseConv2d(nn.Module):
    """Depthwise-separable conv: depthwise conv -> BatchNorm2d -> 1x1 pointwise conv.

    Args:
        dim_in: input channels; the depthwise conv uses ``groups=dim_in`` so each
            channel is filtered independently.
        dim_out: output channels of the 1x1 pointwise conv.
        kernel_size, padding, stride: applied to the depthwise conv only.
        bias: whether both convolutions have a bias term.
    """

    def __init__(self, dim_in, dim_out, kernel_size, padding, stride, bias = True):
        super().__init__()
        depthwise = nn.Conv2d(
            dim_in, dim_in,
            kernel_size = kernel_size,
            padding = padding,
            groups = dim_in,
            stride = stride,
            bias = bias,
        )
        norm = nn.BatchNorm2d(dim_in)
        pointwise = nn.Conv2d(dim_in, dim_out, kernel_size = 1, bias = bias)
        self.net = nn.Sequential(depthwise, norm, pointwise)

    def forward(self, x):
        return self.net(x)



In [57]:
class Attention(nn.Module):
    """Multi-head self-attention with convolutional q/k/v projections.

    Queries come from a depthwise-separable projection with stride 1; keys and
    values come from one with stride ``kv_proj_stride``, so the key/value token
    grid may be coarser than the query grid. The result has ``dim`` channels and
    the input's spatial size.

    Args:
        dim: input/output channels.
        proj_kernel: kernel size of the projections (padding ``proj_kernel // 2``).
        kv_proj_stride: stride of the key/value projection.
        heads: number of attention heads.
        dim_head: channels per head; scores are scaled by ``dim_head ** -0.5``.
        dropout: dropout applied after the output 1x1 projection.
    """

    def __init__(self, dim, proj_kernel, kv_proj_stride, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = heads * dim_head
        same_pad = proj_kernel // 2
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim = -1)

        self.to_q = DepthWiseConv2d(dim, inner_dim, proj_kernel, padding = same_pad, stride = 1, bias = False)
        self.to_kv = DepthWiseConv2d(dim, 2 * inner_dim, proj_kernel, padding = same_pad, stride = kv_proj_stride, bias = False)

        self.to_out = nn.Sequential(
            nn.Conv2d(inner_dim, dim, 1),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        # Only the width is needed: it lets the flattened query grid be folded back.
        _, _, _, width = x.shape
        n_heads = self.heads

        queries = self.to_q(x)
        keys, values = self.to_kv(x).chunk(2, dim = 1)

        def split_heads(t):
            # (b, heads*d, x, y) -> (b*heads, x*y, d)
            return rearrange(t, 'b (h d) x y -> (b h) (x y) d', h = n_heads)

        queries, keys, values = (split_heads(t) for t in (queries, keys, values))

        scores = einsum('b i d, b j d -> b i j', queries, keys) * self.scale
        weights = self.attend(scores)
        mixed = einsum('b i j, b j d -> b i d', weights, values)

        merged = rearrange(mixed, '(b h) (x y) d -> b (h d) x y', h = n_heads, y = width)
        return self.to_out(merged)



In [58]:
class Transformer(nn.Module):
    """Stack of ``depth`` pre-norm blocks: attention then feed-forward, each residual.

    Every sub-layer is wrapped in ``PreNorm`` and added back onto its input, so
    the channel count and spatial size are preserved.
    """

    def __init__(self, dim, proj_kernel, kv_proj_stride, depth, heads, dim_head = 64, mlp_mult = 4, dropout = 0.):
        super().__init__()

        def make_block():
            attention = Attention(dim, proj_kernel = proj_kernel, kv_proj_stride = kv_proj_stride,
                                  heads = heads, dim_head = dim_head, dropout = dropout)
            wrapped_attention = PreNorm(dim, attention)
            feed_forward = FeedForward(dim, mlp_mult, dropout = dropout)
            wrapped_feed_forward = PreNorm(dim, feed_forward)
            return nn.ModuleList([wrapped_attention, wrapped_feed_forward])

        self.layers = nn.ModuleList([make_block() for _ in range(depth)])

    def forward(self, x):
        for attention, feed_forward in self.layers:
            x = x + attention(x)
            x = x + feed_forward(x)
        return x


In [59]:
class DoubleConv(nn.Module):
    """Two rounds of 3x3 conv (padding 1) -> BatchNorm2d -> ReLU.

    Maps ``in_channels`` to ``out_channels``; spatial size is preserved.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        for source_channels in (in_channels, out_channels):
            stages += [
                nn.Conv2d(source_channels, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        return self.double_conv(x)

In [60]:
class UpBlock(nn.Module):
    """Upsample the input by 2x, then refine it with a ``DoubleConv``.

    Args:
        in_channels: channels of the input (the upsampling step keeps them).
        out_channels: channels produced by the ``DoubleConv``.
        up_sample_mode: ``'conv_transpose'`` (2x2 stride-2 transposed conv) or
            ``'bilinear'`` (``nn.Upsample``, ``align_corners=True``).

    Raises:
        ValueError: for any other ``up_sample_mode``.
    """

    def __init__(self, in_channels, out_channels, up_sample_mode):
        super().__init__()
        if up_sample_mode == 'bilinear':
            self.up_sample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        elif up_sample_mode == 'conv_transpose':
            self.up_sample = nn.ConvTranspose2d(in_channels, in_channels, kernel_size=2, stride=2)
        else:
            raise ValueError("Unsupported `up_sample_mode` (can take one of `conv_transpose` or `bilinear`)")
        self.double_conv = DoubleConv(in_channels, out_channels)

    def forward(self, down_input):
        upsampled = self.up_sample(down_input)
        return self.double_conv(upsampled)

class UpBlockskip(nn.Module):
    """Concatenate a skip tensor, upsample by 2x, then refine with a ``DoubleConv``.

    Concatenation happens *before* upsampling, so ``down_input`` and
    ``skip_input`` must share spatial size, and ``in_channels`` must equal the
    sum of their channel counts.

    Args:
        in_channels: channels after concatenation.
        out_channels: channels produced by the ``DoubleConv``.
        up_sample_mode: ``'conv_transpose'`` or ``'bilinear'``.

    Raises:
        ValueError: for any other ``up_sample_mode``.
    """

    def __init__(self, in_channels, out_channels, up_sample_mode):
        super().__init__()
        if up_sample_mode == 'bilinear':
            self.up_sample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        elif up_sample_mode == 'conv_transpose':
            self.up_sample = nn.ConvTranspose2d(in_channels, in_channels, kernel_size=2, stride=2)
        else:
            raise ValueError("Unsupported `up_sample_mode` (can take one of `conv_transpose` or `bilinear`)")
        self.double_conv = DoubleConv(in_channels, out_channels)

    def forward(self, down_input, skip_input):
        merged = torch.cat((down_input, skip_input), dim=1)
        return self.double_conv(self.up_sample(merged))


In [61]:
class CvT(nn.Module):
    """Segmentation net: a 3-stage CvT encoder with a convolutional upsampling decoder.

    Encoder: each stage is a strided conv embedding -> channel LayerNorm ->
    Transformer. Bottleneck: DoubleConv to 512 channels. Decoder: three
    upsampling blocks; the last two concatenate the stage-2 and stage-1
    embeddings (taken right after their conv embedding, before normalization)
    as skip connections, followed by a channel LayerNorm each, and a final 1x1
    conv to ``num_classes`` logits.

    With the default configuration (stride 2 per stage, odd kernels) an input
    whose H and W are multiples of 8 yields logits of the same spatial size.
    NOTE(review): non-default ``sN_emb_stride`` values generally break the skip
    concatenations; only stride-2 stages are shape-consistent with the decoder.

    All arguments are keyword-only. ``sN_*`` configure encoder stage N;
    ``channels`` is the number of input image channels (default 3);
    ``up_sample_mode`` is forwarded to the decoder blocks.
    """

    def __init__(
            self,
            *,
            num_classes=2, s1_emb_dim=32, s1_emb_kernel=5, s1_emb_stride=2,
            s1_proj_kernel=3, s1_kv_proj_stride=2, s1_heads=1, s1_depth=1, s1_mlp_mult=4,

            s2_emb_dim=32, s2_emb_kernel=3, s2_emb_stride=2, s2_proj_kernel=3, s2_kv_proj_stride=2,
            s2_heads=1, s2_depth=2, s2_mlp_mult=4,

            s3_emb_dim=32, s3_emb_kernel=3, s3_emb_stride=2, s3_proj_kernel=3,
            s3_kv_proj_stride=2, s3_heads=6, s3_depth=10, s3_mlp_mult=4, dropout=0.,
            up_sample_mode='bilinear', channels=3):
        super().__init__()

        # --- Attention (encoder) path ---
        self.att_convB1 = nn.Conv2d(channels,
                                    s1_emb_dim,
                                    kernel_size=s1_emb_kernel,
                                    padding=(s1_emb_kernel // 2),
                                    stride=s1_emb_stride)
        self.att_normB1 = LayerNorm(s1_emb_dim)
        self.att_tranB1 = Transformer(dim=s1_emb_dim,
                                      proj_kernel=s1_proj_kernel,
                                      kv_proj_stride=s1_kv_proj_stride,
                                      depth=s1_depth,
                                      heads=s1_heads,
                                      mlp_mult=s1_mlp_mult,
                                      dropout=dropout)

        self.att_convB2 = nn.Conv2d(s1_emb_dim,
                                    s2_emb_dim,
                                    kernel_size=s2_emb_kernel,
                                    padding=(s2_emb_kernel // 2),
                                    stride=s2_emb_stride)
        self.att_normB2 = LayerNorm(s2_emb_dim)
        self.att_tranB2 = Transformer(dim=s2_emb_dim,
                                      proj_kernel=s2_proj_kernel,
                                      kv_proj_stride=s2_kv_proj_stride,
                                      depth=s2_depth,
                                      heads=s2_heads,
                                      mlp_mult=s2_mlp_mult,
                                      dropout=dropout)

        self.att_convB3 = nn.Conv2d(s2_emb_dim,
                                    s3_emb_dim,
                                    kernel_size=s3_emb_kernel,
                                    padding=(s3_emb_kernel // 2),
                                    stride=s3_emb_stride)
        self.att_normB3 = LayerNorm(s3_emb_dim)
        self.att_tranB3 = Transformer(dim=s3_emb_dim,
                                      proj_kernel=s3_proj_kernel,
                                      kv_proj_stride=s3_kv_proj_stride,
                                      depth=s3_depth,
                                      heads=s3_heads,
                                      mlp_mult=s3_mlp_mult,
                                      dropout=dropout)

        bottleneck_dim = 512
        self.double_conv = DoubleConv(s3_emb_dim, bottleneck_dim)

        # --- Upsampling (decoder) path ---
        # The skip tensors carry s2_emb_dim / s1_emb_dim channels, so the
        # concatenated widths must follow those settings (previously hard-coded 32).
        self.up_convB4 = UpBlock(bottleneck_dim, 128, up_sample_mode)
        self.up_convB3 = UpBlockskip(128 + s2_emb_dim, 64, up_sample_mode)
        self.norm1     = LayerNorm(64)
        self.up_convB2 = UpBlockskip(64 + s1_emb_dim, 32, up_sample_mode)
        self.norm2     = LayerNorm(32)

        # Output match: 1x1 conv to per-class logits.
        self.conv_last = nn.Conv2d(32, num_classes, kernel_size=1)

    def forward(self, x):
        # Stage 1: keep the raw conv embedding as the highest-resolution skip.
        x = self.att_convB1(x)
        skip_l1 = x
        x = self.att_normB1(x)
        x = self.att_tranB1(x)

        # Stage 2: same pattern, second skip.
        x = self.att_convB2(x)
        skip_l2 = x
        x = self.att_normB2(x)
        x = self.att_tranB2(x)

        # Stage 3: no skip is taken from the deepest stage.
        x = self.att_convB3(x)
        x = self.att_normB3(x)
        x = self.att_tranB3(x)

        x = self.double_conv(x)

        # Decoder: each skip is concatenated before that block's upsampling.
        x = self.up_convB4(x)
        x = self.up_convB3(x, skip_l2)
        x = self.norm1(x)
        x = self.up_convB2(x, skip_l1)
        x = self.norm2(x)
        return self.conv_last(x)

In [62]:
from torchinfo import summary

# Build the model with its default hyper-parameters (2 classes, bilinear upsampling).
model = CvT(num_classes=2, up_sample_mode='bilinear')

# Uncomment to print a per-layer summary for a single 256x256 RGB image:
# summary(model.to('cpu'), (1, 3, 256, 256), col_width=16, verbose=0)

Layer (type:depth-idx)                                       Output Shape     Param #
CvT                                                          --               --
├─Transformer: 1                                             --               --
│    └─ModuleList: 2-1                                       --               --
│    │    └─ModuleList: 3-1                                  --               17,408
├─Transformer: 1                                             --               --
│    └─ModuleList: 2-2                                       --               --
│    │    └─ModuleList: 3-2                                  --               33,792
│    │    └─ModuleList: 3-3                                  --               33,792
├─Transformer: 1                                             --               --
│    └─ModuleList: 2-3                                       --               --
│    │    └─ModuleList: 3-4                                  --               41,984
│    │ 

In [None]:
class FocalLoss(nn.Module):
    """Multi-class focal loss: ``-alpha_t * (1 - p_t) ** gamma * log(p_t)``.

    Args:
        gamma: focusing exponent; ``gamma=0`` reduces to (alpha-weighted)
            cross-entropy.
        alpha: optional per-class weights. A float/int ``a`` becomes ``[a, 1 - a]``
            (two-class case); a list/tuple is used as given, one entry per class;
            a tensor is used as-is.
        size_average: if True return the mean over all samples, else the sum.

    ``forward(input, target)``: ``input`` holds raw logits of shape ``(N, C)`` or
    ``(N, C, *spatial)``; ``target`` holds integer class indices with
    ``N * prod(spatial)`` elements (any shape; it is flattened).
    """

    def __init__(self, gamma=0, alpha=None, size_average=True):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha
        # Python 3 has no `long`; referencing it raised NameError on every construction.
        if isinstance(alpha, (float, int)):
            self.alpha = torch.Tensor([alpha, 1 - alpha])
        elif isinstance(alpha, (list, tuple)):
            self.alpha = torch.Tensor(list(alpha))
        self.size_average = size_average

    def forward(self, input, target):
        if input.dim() > 2:
            # Treat every spatial position as a separate sample.
            input = input.reshape(input.size(0), input.size(1), -1)  # N,C,H,W => N,C,H*W
            input = input.transpose(1, 2)                            # N,C,H*W => N,H*W,C
            input = input.reshape(-1, input.size(2))                 # N,H*W,C => N*H*W,C
        # gather() needs int64 indices.
        target = target.reshape(-1, 1).long()

        logpt = F.log_softmax(input, dim=1)
        logpt = logpt.gather(1, target).view(-1)  # log p_t of the true class
        # Not detached, so the (1 - p_t)**gamma modulating factor also receives
        # gradient (standard focal-loss derivative).
        pt = logpt.exp()

        if self.alpha is not None:
            # Match device/dtype locally instead of mutating self.alpha.
            alpha = self.alpha.to(device=input.device, dtype=input.dtype)
            at = alpha.gather(0, target.view(-1))
            logpt = logpt * at

        loss = -1 * (1 - pt) ** self.gamma * logpt
        return loss.mean() if self.size_average else loss.sum()

In [None]:
def ce_loss(true, logits, weights=None, ignore=255):
    """Computes the weighted multi-class cross-entropy loss.

    Args:
        true: integer class labels of shape [B, 1, H, W] or [B, H, W].
        logits: a tensor of shape [B, C, H, W]. Corresponds to
            the raw output or logits of the model.
        weights: optional per-class weights of shape [C,] (tensor or sequence).
            None means all classes are weighted equally.
        ignore: label value excluded from the loss and from the averaging.

    Returns:
        The weighted multi-class cross-entropy loss (``F.cross_entropy`` with
        the default 'mean' reduction).
    """
    target = true.long()
    # F.cross_entropy expects [B, H, W] targets; drop a singleton channel axis.
    if target.dim() == logits.dim() and target.size(1) == 1:
        target = target.squeeze(1)
    if weights is not None:
        # The weight tensor must be float and live on the same device as the logits.
        weights = torch.as_tensor(weights, dtype=torch.float32, device=logits.device)
    return F.cross_entropy(
        logits.float(),
        target,
        ignore_index=ignore,
        weight=weights,
    )

In [None]:
#PyTorch
class DiceBCELoss(nn.Module):
    """Binary cross-entropy plus soft Dice loss, both on sigmoid probabilities.

    ``inputs`` are raw logits (a sigmoid is applied inside) and ``targets`` are
    values in [0, 1] with the same number of elements. Both are flattened, so
    the Dice term is computed over the whole batch at once. ``weight`` and
    ``size_average`` are accepted but unused.
    """

    def __init__(self, weight=None, size_average=True):
        super(DiceBCELoss, self).__init__()
        self.__name__ = 'DiceBCELoss'

    def forward(self, inputs, targets, smooth=1):
        # Remove this sigmoid if the model already ends with one.
        probs = torch.sigmoid(inputs).view(-1)
        flat_targets = targets.view(-1)

        overlap = (probs * flat_targets).sum()
        denominator = probs.sum() + flat_targets.sum() + smooth
        dice = 1 - (2. * overlap + smooth) / denominator

        bce = F.binary_cross_entropy(probs, flat_targets, reduction='mean')
        return bce + dice
    
class IoULoss(nn.Module):
    """Soft Jaccard (IoU) loss on sigmoid probabilities: ``1 - (I + s) / (U + s)``.

    ``inputs`` are raw logits (a sigmoid is applied inside); ``targets`` must
    have the same number of elements. Both are flattened, so the IoU is taken
    over the whole batch. ``weight`` and ``size_average`` are accepted but unused.
    """

    def __init__(self, weight=None, size_average=True):
        super(IoULoss, self).__init__()
        self.__name__ = 'IoULoss'

    def forward(self, inputs, targets, smooth=1):
        # Remove this sigmoid if the model already ends with one.
        probs = torch.sigmoid(inputs).view(-1)
        flat_targets = targets.view(-1)

        # Soft intersection plays the role of the true-positive count;
        # union = |A| + |B| - |A*B|.
        intersection = (probs * flat_targets).sum()
        union = (probs + flat_targets).sum() - intersection

        return 1 - (intersection + smooth) / (union + smooth)



In [None]:
class WCrossEntropy(nn.Module):
    """Class-weighted cross-entropy for one-hot encoded targets.

    Args:
        weight: optional per-class weights (list, tuple, numpy array or tensor),
            one entry per channel of ``inputs``. Stored as a float32 tensor.
        size_average: accepted for API compatibility; unused.
        DEVICE: device the weight tensor is placed on at construction. At call
            time the weights are additionally moved to the device/dtype of
            ``inputs`` if they differ.

    ``forward(inputs, targets)``: ``inputs`` are raw logits ``[B, C, H, W]``;
    ``targets`` are one-hot maps ``[B, C, H, W]`` reduced with argmax over the
    channel axis. Returns the weighted-mean cross-entropy.
    """

    def __init__(self, weight=None, size_average=True, DEVICE='cuda'):
        super(WCrossEntropy, self).__init__()
        self.__name__ = 'WCEntropy'
        self.weight = None
        # `is not None` rather than `!= None`: for numpy arrays `!=` is elementwise
        # and its truth value is ambiguous. float32 so integer weights are accepted.
        if weight is not None:
            self.weight = torch.as_tensor(weight, dtype=torch.float32).to(DEVICE)

    def forward(self, inputs, targets):
        # From tensor [B, C, H, W] to tensor [B, H, W] of class indices.
        targets = torch.argmax(targets, dim=1)
        weight = self.weight
        if weight is not None:
            weight = weight.to(device=inputs.device, dtype=inputs.dtype)
        # Functional form: no need to rebuild an nn.CrossEntropyLoss per call.
        return F.cross_entropy(inputs, targets, weight=weight)

In [None]:
# class_weigths = [2.0, 8.0]
# inputs = torch.FloatTensor(2, 2, 5, 5)
# logits = torch.FloatTensor(2, 2, 5, 5)
# weigths = torch.FloatTensor(class_weigths)

# print('inputs:', inputs.shape)
# print('weigths', weigths.shape)
# print('logits ', logits.shape)

# logits = torch.tensor(logits, dtype=torch.long)
# #logits = logits.long()
# #logits
# #inputs = inputs.view(-1)
# #print('inputs:', inputs.shape)

# # loss = nn.CrossEntropyLoss()
# # loss(inputs, logits)

# print(logits[0, 0, 0])

In [None]:
# one_hot = torch.tensor(
#                     np.array([
#                     [[[1, 1, 1, 0, 0], [0, 0, 0, 0, 0]],    
#                     [[0, 0, 0, 0, 0], [1, 1, 1, 0, 0]],
#                     [[0, 0, 0, 1, 1], [0, 0, 0, 1, 1]],],
#                     [[[1, 1, 1, 0, 0], [0, 0, 0, 0, 0]],    
#                     [[0, 0, 0, 0, 0], [1, 1, 1, 0, 0]],
#                     [[0, 0, 0, 1, 1], [0, 0, 0, 1, 1]],]
#                    ])
# )
# print(one_hot.shape)
# print(one_hot)

# x = torch.argmax(one_hot, dim = 1)

# print(x.shape)
# print(x)


In [None]:
# import os, cv2
# import pandas as pd
# import matplotlib.pyplot as plt

# def one_hot_encode(label, label_values):
#     semantic_map = []
#     for colour in label_values:
#         equality = np.equal(label, colour)
#         class_map = np.all(equality, axis = -1)
#         semantic_map.append(class_map)
#     semantic_map = np.stack(semantic_map, axis=-1)

#     return semantic_map


# DATA_DIR = './datasets/tiff/'
# masks_dir = os.path.join(DATA_DIR, 'train_labels')
# class_dict = pd.read_csv("./datasets/label_class_dict.csv")
# class_names = class_dict['name'].tolist()
# class_rgb_values = class_dict[['r','g','b']].values.tolist()

# select_classes = ['background', 'building']
# select_class_indices = [class_names.index(cls.lower()) for cls in select_classes]
# select_class_rgb_values =  np.array(class_rgb_values)[select_class_indices]

# mask_paths  = [os.path.join(masks_dir, image_id) for image_id in sorted(os.listdir(masks_dir))]
# freq = np.zeros((1,2))
# for img in mask_paths:
#     print(img)
#     mask  = cv2.cvtColor(cv2.imread(img), cv2.COLOR_BGR2RGB)   
#     mask = one_hot_encode(mask, class_rgb_values).astype('float')
#     mask = np.argmax(mask , axis = -1)
#     #print(mask)
#     #plt.imshow(mask, cmap='gray')
#     unique, counts = np.unique(mask, return_counts=True)    
#     freq += counts
# print(freq)



In [None]:
# import matplotlib.pyplot as plt

# select_classes = ['background', 'building']
# freqs = freq[0]  #Error in prior cell, line 36
# total_pixels = np.sum(freqs)
# freq_prop = freqs/total_pixels


# class_p = {select_classes[0]:freq_prop[0], select_classes[1]:freq_prop[1]}

# import seaborn as sns
# sns.set_theme(style="whitegrid")
# ax = sns.barplot(x=list(class_p.keys()), y=list(class_p.values()))
# plt.show()


# print(f'total_pixels:{total_pixels:.0f}')
# print(f'{select_classes[0]}:{freqs[0]:.0f}({freq_prop[0]*100:.1f}%) {select_classes[1]}:{freqs[1]:.0f}({freq_prop[1]*100:.1f}%)')
# ws  = total_pixels / (len(select_classes) * freqs)
# wsn = ws / np.linalg.norm(ws)
# print(f'weigth_{select_classes[0]}:{ws[0]:.4f} weigth_{select_classes[1]}:{ws[1]:.4f}')
# print(f'Nweigth_{select_classes[0]}:{wsn[0]:.4f} Nweigth_{select_classes[1]}:{wsn[1]:.4f}')
