In [1]:
import torch
import numpy as np
import torch.nn as nn

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
class BaseModel():
    """Inference helpers meant to be mixed into an ``nn.Module`` subclass.

    The host class must be callable on a batch tensor and provide ``eval()`` and
    ``parameters()``. Both helpers put the model in eval mode, run without
    gradient tracking, apply a softmax over dim 1 of the model output and
    return a NumPy array with that dim moved to the last axis (the output is
    permuted with ``(0, 2, 3, 1)``, so it must be 4-D).
    """

    def _inference_device(self):
        """Return the device of the model's first parameter (CPU if it has none).

        Using the model's own device keeps inputs and weights together without
        depending on a notebook-level ``device`` variable.
        """
        first_param = next(iter(self.parameters()), None)
        if first_param is None:
            return torch.device("cpu")
        return first_param.device

    def predict(self, dataloader):
        """Return the softmax predictions for every batch yielded by ``dataloader``.

        Args:
            dataloader: iterable of ``(images, target)`` pairs; the target is ignored.

        Returns:
            NumPy array with all batches concatenated along axis 0.

        Raises:
            RuntimeError: from ``torch.cat`` if ``dataloader`` yields no batches.
        """
        self.eval()
        device = self._inference_device()

        batches = []
        with torch.no_grad():
            # NOTE(review): `tqdm` is not imported in the cells visible here;
            # confirm it is defined before this method is called.
            for images, _ in tqdm(dataloader):
                outputs = self(images.to(device))
                probabilities = torch.nn.functional.softmax(outputs, dim=1)
                # Move each batch to host memory right away so accelerator
                # memory use does not grow with the size of the dataset.
                batches.append(probabilities.permute(0, 2, 3, 1).cpu())

        return torch.cat(batches, dim=0).numpy()

    def predict_image(self, image):
        """Return the softmax prediction for a single, un-batched image tensor.

        Args:
            image: tensor without a batch dimension; one is added at dim 0.

        Returns:
            NumPy array with a leading batch axis of size 1.
        """
        self.eval()
        device = self._inference_device()

        with torch.no_grad():
            batch = image.unsqueeze(0).to(device)
            probabilities = torch.nn.functional.softmax(self(batch), dim=1)

        return np.transpose(probabilities.cpu().numpy(), (0, 2, 3, 1))

In [10]:
class SimpleDecoderBlock(nn.Module):
    """Decoder stage built from parameter-free bilinear upsampling plus convolutions.

    The input is upsampled by 2, projected to ``d_out`` channels with a 1x1
    convolution and, when a skip tensor is supplied, concatenated with it and
    fused by a 3x3 convolution. A final 3x3 convolution is always applied.
    Every convolution is followed by ReLU.

    Args:
        d_in: number of channels of the incoming feature map.
        d_out: number of channels produced by the block; the fusion convolution
            expects ``2 * d_out`` channels after concatenation.
    """

    def __init__(self, d_in, d_out):
        super().__init__()

        self.upconv = nn.UpsamplingBilinear2d(scale_factor=2)
        self.conv_1 = nn.Conv2d(in_channels=d_in, out_channels=d_out, kernel_size=1, stride=1)
        self.relu = nn.ReLU()
        self.conv_2 = nn.Conv2d(
            in_channels=2 * d_out, out_channels=d_out, kernel_size=3, stride=1, padding="same"
        )
        self.conv_3 = nn.Conv2d(
            in_channels=d_out, out_channels=d_out, kernel_size=3, stride=1, padding="same"
        )

    def forward(self, inp, a):
        """Upsample ``inp`` and optionally fuse it with the skip tensor ``a``.

        Args:
            inp: feature map to upsample.
            a: skip tensor concatenated in front of the upsampled features along
                dim 1, or ``None`` to skip the fusion convolution.
        """
        features = self.relu(self.conv_1(self.upconv(inp)))

        if a is None:
            return self.relu(self.conv_3(features))

        merged = torch.cat([a, features], axis=1)
        fused = self.relu(self.conv_2(merged))
        return self.relu(self.conv_3(fused))


        

In [5]:
import torch
import torchvision

from torch import nn

class DecoderBlock(nn.Module):
    """Decoder stage that upsamples by 2 with a learned transposed convolution.

    After upsampling, the features are optionally concatenated with a skip
    tensor and fused by a 3x3 convolution; a final 3x3 convolution is always
    applied. Every convolution is followed by ReLU.

    Args:
        d_in: number of channels of the incoming feature map.
        d_out: number of channels produced by the block; the fusion convolution
            expects ``2 * d_out`` channels after concatenation.
    """

    def __init__(self, d_in, d_out):
        super().__init__()
        # kernel 3 / stride 2 / padding 1 alone produces 2H-1 x 2W-1, which cannot
        # be concatenated with a skip tensor of size 2H x 2W. output_padding=1
        # makes the output exactly twice the input size and adds no parameters.
        self.upconv = nn.ConvTranspose2d(d_in, d_out, 3, 2, padding=1, output_padding=1)
        self.conv_1 = nn.Conv2d(d_out*2, d_out, 3, 1, padding=1)
        self.relu = nn.ReLU()
        self.conv_2 = nn.Conv2d(d_out, d_out, 3, 1, padding=1)

    def forward(self, inp, a):
        """Upsample ``inp`` and optionally fuse it with the skip tensor ``a``.

        Args:
            inp: feature map to upsample.
            a: skip tensor concatenated in front of the upsampled features along
                dim 1, or ``None`` to skip the fusion convolution.
        """
        x = self.relu(self.upconv(inp))

        if a is not None:
            x = torch.cat([a, x], dim=1)
            x = self.relu(self.conv_1(x))

        return self.relu(self.conv_2(x))


class Decoder(nn.Module):
    """Stack of decoder blocks followed by a 1x1 convolution to ``num_classes`` channels.

    Args:
        d_in: number of channels of the tensor fed to the first block.
        filters: output channel count of each successive block. May be empty, in
            which case the 1x1 output convolution is applied directly to the input.
        num_classes: number of output channels.
        simple: build ``SimpleDecoderBlock`` stages instead of ``DecoderBlock``.
        sigmoid: apply a sigmoid to the output of the final convolution.
    """

    def __init__(self, d_in, filters, num_classes, simple=False, sigmoid=False):
        super().__init__()

        stages = []
        channels = d_in
        for width in filters:
            if simple:
                stage = SimpleDecoderBlock(channels, width)
            else:
                stage = DecoderBlock(channels, width)
            stages.append(stage)
            channels = width

        # `channels` is the width of the last stage, or `d_in` when `filters` is
        # empty, so the head no longer depends on a leaked loop variable.
        self.output = nn.Conv2d(channels, num_classes, 1, 1)
        self.decoder_blocks = nn.ModuleList(stages)
        self.sig = nn.Sigmoid() if sigmoid else None

    def forward(self, inputs, activations):
        """Run the blocks in order, pairing each with one entry of ``activations``.

        Args:
            inputs: tensor fed to the first block.
            activations: one skip tensor (or ``None``) per block, in block order.
                Blocks and activations are paired with ``zip``, so iteration
                stops at the shorter of the two.

        Returns:
            Output of the 1x1 convolution, passed through a sigmoid if enabled.
        """
        x = inputs
        for block, skip in zip(self.decoder_blocks, activations):
            x = block(x, skip)

        output = self.output(x)
        if self.sig is not None:
            output = self.sig(output)

        return output
        

        


In [6]:
import torch
from torch import nn

class SpatialAttention(nn.Module):
    """Position attention over the flattened spatial grid with a learned residual gate.

    Three 1x1 convolutions produce tensors ``b``, ``c`` and ``d``. The pairwise
    energies ``c_j . b_i`` between spatial positions are turned into attention
    weights, and each output position receives a weighted sum of ``d`` over all
    positions, scaled by ``alpha`` and added to the input. ``alpha`` starts at
    0, so the module is the identity at initialisation.

    Args:
        in_channels: channel count of the input; the output has the same shape.
    """

    def __init__(self, in_channels):
        super(SpatialAttention, self).__init__()

        self.C = in_channels

        self.alpha = nn.Parameter(torch.tensor(0.0))

        self.conv1 = nn.Conv2d(in_channels=self.C, out_channels=self.C, kernel_size=1, stride=1)
        self.conv2 = nn.Conv2d(in_channels=self.C, out_channels=self.C, kernel_size=1, stride=1)
        self.conv3 = nn.Conv2d(in_channels=self.C, out_channels=self.C, kernel_size=1, stride=1)

    def forward(self, x):
        """Return ``x + alpha * attention(x)`` for a 4-D input ``x``."""
        batch, _, H, W = x.shape
        N = H * W

        b = self.conv1(x).view(batch, self.C, N)
        c = self.conv2(x).view(batch, self.C, N)
        d = self.conv3(x).view(batch, self.C, N)

        # energy[:, j, i] = c_j . b_i, shape (batch, N, N)
        energy = torch.bmm(c.transpose(1, 2), b)
        # Normalise over the source positions i (last dim) so the weights used
        # for each output position j sum to 1. Dim 1 of the batched tensor is
        # the output-position axis and would leave those weights unnormalised.
        S = torch.softmax(energy, dim=2)

        # out[:, :, j] = sum_i d[:, :, i] * S[:, j, i]
        out = torch.bmm(d, S.transpose(1, 2)).view(batch, self.C, H, W)

        return x + self.alpha * out


class ChannelAttention(nn.Module):
    """Channel attention with a learned residual gate and no convolutions.

    The input is flattened to ``(batch, C, H*W)``, channel-to-channel
    similarities are computed from it, and each channel is replaced by a
    softmax-weighted mix of all channels. The mix is scaled by ``beta`` and
    added to the input. ``beta`` starts at 0, so the module is the identity at
    initialisation.

    Args:
        in_channels: channel count of the input; the output has the same shape.
    """

    def __init__(self, in_channels):
        super(ChannelAttention, self).__init__()
        self.beta = nn.Parameter(torch.tensor(0.0))

        self.C = in_channels

    def forward(self, x):
        """Return ``x + beta * attention(x)`` for a 4-D input ``x``."""
        height, width = x.shape[2], x.shape[3]
        flat = x.view(-1, self.C, height * width)

        # Gram matrix of channel descriptors, shape (batch, C, C).
        gram = torch.bmm(flat, flat.transpose(1, 2))
        # The Gram matrix is symmetric, so softmax over dim 1 followed by a
        # transpose gives weights that are normalised along the last dim.
        weights = torch.softmax(gram, dim=1).transpose(1, 2)

        mixed = torch.bmm(weights, flat) * self.beta
        return x + mixed.view(-1, self.C, height, width)

class DualAttention(nn.Module):
    """Sum of a spatial-attention branch and a channel-attention branch.

    Each branch applies its attention module to the input and then a 1x1
    convolution; the two results are added.

    Args:
        in_channels: channel count of the input; the output has the same shape.
    """

    def __init__(self, in_channels):
        super(DualAttention, self).__init__()
        self.C = in_channels

        self.conv1 = nn.Conv2d(self.C, self.C, 1)
        self.conv2 = nn.Conv2d(self.C, self.C, 1)

        self.sam = SpatialAttention(in_channels)
        self.cam = ChannelAttention(in_channels)

    def forward(self, x):
        """Return ``conv1(sam(x)) + conv2(cam(x))``."""
        spatial_branch = self.conv1(self.sam(x))
        # The second branch must go through the channel attention module; it
        # previously reused `self.sam`, leaving `self.cam` unused.
        channel_branch = self.conv2(self.cam(x))

        return spatial_branch + channel_branch








In [7]:
class CSA(nn.Module):
    """Two-way split attention block with a residual connection.

    The input is projected by a 1x1 convolution and split into two halves along
    the channel dim. Each half goes through a 3x3 convolution; the second half
    is additionally fused with the first by another 3x3 convolution. A gating
    path (1x1 conv, batch norm, ReLU, 1x1 conv) is computed from the
    concatenated halves, split in two and passed through a channel softmax to
    weight the corresponding half. The weighted halves are concatenated,
    projected by a 1x1 convolution and added to the input.

    Args:
        in_channels: channel count of the input; must be even so it can be
            split into two equal halves. The output has the same shape.

    Raises:
        ValueError: if ``in_channels`` is not a positive even number.
    """

    def __init__(self, in_channels):

        super(CSA, self).__init__()

        if in_channels < 2 or in_channels % 2 != 0:
            raise ValueError("in_channels must be a positive even number")

        self.C = in_channels
        self.C_by_2 = in_channels // 2
        self.conv_1 = nn.Conv2d(in_channels, in_channels, 1)
        self.conv_3x3_1 = nn.Conv2d(in_channels=self.C_by_2, out_channels=self.C_by_2, kernel_size=3, stride=1, padding="same")
        self.conv_3x3_2 = nn.Conv2d(in_channels=self.C_by_2, out_channels=self.C_by_2, kernel_size=3, stride=1, padding="same")
        self.conv_3x3_3 = nn.Conv2d(in_channels=in_channels, out_channels=self.C_by_2, kernel_size=3, stride=1, padding="same")

        self.group_1 = nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=1, stride=1)
        self.bn = nn.BatchNorm2d(in_channels)
        self.relu = nn.ReLU()
        self.group_2 = nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=1, stride=1)
        self.softmax = nn.Softmax(dim=1)
        self.final_conv = nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=1, stride=1)

    def forward(self, input):
        """Return the attention-weighted features added to ``input`` (4-D tensor)."""
        H = input.shape[2]
        W = input.shape[3]

        F = self.conv_1(input)
        F_1, F_2 = F.split(self.C_by_2, dim=1)

        F_1 = self.conv_3x3_1(F_1)
        F_2 = self.conv_3x3_2(F_2)
        F_2 = self.conv_3x3_3(torch.cat([F_1, F_2], dim=1))

        F = torch.cat([F_1, F_2], dim=1)

        # NOTE(review): pooling to (H, W) keeps the spatial size unchanged. The
        # original comment called this "Global average pooling"; confirm
        # whether an output size of 1 was intended.
        F = nn.AdaptiveAvgPool2d((H, W))(F)

        F = self.group_1(F)
        F = self.bn(F)
        F = self.relu(F)
        F = self.group_2(F)

        F_1_s, F_2_s = F.split(self.C_by_2, dim=1)

        # The attention weights come from the gating path. They were previously
        # recomputed from F_1 / F_2, which discarded the whole
        # group_1 -> bn -> relu -> group_2 computation.
        F_1_s = self.softmax(F_1_s)
        F_2_s = self.softmax(F_2_s)

        F_1_final = F_1 * F_1_s
        F_2_final = F_2 * F_2_s

        F_final = torch.cat([F_1_final, F_2_final], dim=1)
        F_final = self.final_conv(F_final)

        return F_final + input

In [8]:
from torchvision.models import resnet50, ResNet50_Weights

# Module-level ResNet-50 loaded with the default torchvision weights.
# ResNetUNet.__init__ creates its own local `resnet_model`, so the class does
# not read this variable.
resnet_model = resnet50(weights=ResNet50_Weights.DEFAULT)

class ResNetUNet(nn.Module, BaseModel):
    """U-Net style network with a frozen ResNet-50 encoder and a ``Decoder`` head.

    The encoder is the first seven children of a torchvision ``resnet50`` with
    all parameters frozen. Three intermediate encoder outputs are captured with
    forward hooks and passed to the decoder as skip tensors.

    Args:
        num_classes: number of output channels of the decoder.
        simple: forwarded to ``Decoder`` (selects ``SimpleDecoderBlock``).
        sigmoid: forwarded to ``Decoder`` (sigmoid on the output).
        attention: 0/False for none, 1 for ``CSA``, 2 for ``DualAttention``.
            When enabled, one attention module is applied to the encoder
            output and one to the deepest skip tensor.

    Raises:
        ValueError: if ``attention`` is not 0, 1 or 2.
    """

    def __init__(self, num_classes, simple=False, sigmoid=False, attention=False):

        super().__init__()

        # Validate before building the backbone. `__init__` cannot return a
        # value, so the previous `return -1` surfaced as an unrelated TypeError.
        if attention not in (0, 1, 2):
            raise ValueError("Attention can only be 0, 1, or 2")

        self.activations = [None]
        resnet_model = resnet50(weights=ResNet50_Weights.DEFAULT)

        self.resnet_backbone = nn.Sequential(*(list(resnet_model.children())[0:7]))
        for param in self.resnet_backbone.parameters():
            param.requires_grad = False

        filters = [512, 256, 64, 64]
        self.decoder = Decoder(1024, filters, num_classes, simple, sigmoid)
        self.attention = attention
        if attention == 1:
            self.attention = CSA(1024)
            self.attention_2 = CSA(filters[0])
        elif attention == 2:
            self.attention = DualAttention(1024)
            self.attention_2 = DualAttention(filters[0])

    def getActivations(self):
        """Return a forward hook that appends a module's output to ``self.activations``."""
        def hook(model, input, output):
            self.activations.append(output)
        return hook

    def forward(self, input):
        """Encode ``input``, optionally apply attention, and decode with skip tensors.

        ``self.activations`` is reset on every call. It starts with ``None`` and
        the hooks append in execution order, so the reversed list hands the
        decoder the deepest skip first and ``None`` for its last block.
        """
        self.activations = [None]

        # Hooked modules: backbone[2], the last entry hooked in backbone[4] and
        # the last block of backbone[5].
        # NOTE(review): presumably the stem ReLU and the ends of the first two
        # residual stages; verify against the backbone repr.
        hook_handles = []
        try:
            hook_handles.append(self.resnet_backbone[2].register_forward_hook(self.getActivations()))
            hook_handles.append(self.resnet_backbone[4][2].register_forward_hook(self.getActivations()))
            hook_handles.append(self.resnet_backbone[5][-1].register_forward_hook(self.getActivations()))

            resnet_output = self.resnet_backbone(input)
        finally:
            # Always detach the hooks, even if the backbone raises, so they do
            # not accumulate and fire several times on later calls.
            for handle in hook_handles:
                handle.remove()

        if self.attention:
            resnet_output = self.attention(resnet_output)
            self.activations[-1] = self.attention_2(self.activations[-1])

        final_output = self.decoder(resnet_output, self.activations[::-1])

        return final_output

In [11]:
# Positional arguments: num_classes=5, simple=True, sigmoid=False, attention=2
# (attention == 2 selects DualAttention).
m = ResNetUNet(5, True, False, 2)

In [20]:
# Rebinds the module-level `resnet_model` to a freshly loaded ResNet-50; the next cell saves it.
resnet_model = resnet50(weights=ResNet50_Weights.DEFAULT)

In [21]:
# Serialises the whole module object (not just a state_dict) to "res.pt" in the working directory.
torch.save(resnet_model, "res.pt")

In [13]:
# Serialises the whole ResNetUNet object (not just a state_dict) to "./test_res.pt".
# NOTE(review): loading a fully pickled module presumably requires the class
# definitions from this notebook to be importable; confirm before relying on this file.
torch.save(m, "./test_res.pt")

In [15]:
# Wraps the direct child modules of `m` in a Sequential, in registration order.
# The custom forward logic of `m` (hooks, skip tensors) is not part of `k`.
k = nn.Sequential(*list(m.children()))


In [18]:
# Display the module tree of `k`.
k

Sequential(
  (0): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (4): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Sequential(
          (0): Conv2d(64, 256