In [38]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

import torch
import torch.nn as nn
from torchvision.models.resnet import ResNet
from torchvision.models.resnet import BasicBlock
from torchvision.models.resnet import Bottleneck
from torchvision import models
import segmentation_models_pytorch as smp

from segmentation_models_pytorch.encoders._base import EncoderMixin
import segmentation_models_pytorch.encoders as smp_enc

from torchvision.models.resnet import ResNet
from copy import deepcopy

import torchvision as tv

from layers_2D import RotConv, Vector2Magnitude, VectorBatchNorm, VectorMaxPool, VectorUpsampling

In [64]:
model = smp.Unet()

In [65]:
model

Unet(
  (encoder): ResNetEncoder(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track

# Building resnet34

In [2]:
from segmentation_models_pytorch.encoders import resnet_encoders

#### test basic block

In [8]:
BasicBlock(64,64,1)

BasicBlock(
  (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)

In [23]:
fake_img = torch.rand((8,64,64,64))
gg = nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(fake_img)
gg = nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(gg)
gg = nn.ReLU()(gg)
gg = nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(gg)
gg = nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(gg)

In [24]:
gg.shape

torch.Size([8, 64, 64, 64])

#### Test RotBasicBlock

In [28]:
temp = RotBasicBlock(64,64,1)
temp

RotBasicBlock(
  (conv1): RotConv()
  (v2m): Vector2Magnitude()
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (conv2): RotConv()
  (v2m2): Vector2Magnitude()
  (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)

In [62]:
fake_img = torch.rand((8,64,64,64))
guh = RotConv(64, 64, (3,3) , padding=(1,1), n_angles=8, mode=1)(fake_img)
guh = VectorMaxPool(2)(guh)
guh = VectorBatchNorm(64)(guh)
guh = VectorUpsampling(64)(guh)
guh = Vector2Magnitude()(guh)
guh = nn.ReLU()(guh)

In [63]:
guh.shape

torch.Size([8, 64, 64, 64])

In [179]:
guh = VectorBatchNorm(64)(guh)

In [182]:
Vector2Magnitude()(guh).shape

torch.Size([8, 64, 62, 62])

In [140]:
test = temp.conv1(fake_img)

In [141]:
RotBasicBlock(64,64,1).v2m(test).shape

torch.Size([8, 64, 62, 62])

In [130]:
test = RotBasicBlock(64,64,1).bn1(test)

torch.Size([8, 64, 62, 62])

In [26]:
class RotBasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                    base_width=64, dilation=1, norm_layer=None):
        super(RotBasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1

        self.conv1 = RotConv(inplanes, planes, kernel_size=(3,3), stride=stride)
        self.v2m = Vector2Magnitude()
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = RotConv(planes, planes, kernel_size=(3,3), stride=(1,1))
        self.v2m2 = Vector2Magnitude()
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x
        out = self.conv1(x)
        out = self.v2m(out)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.v2m2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out

In [173]:
class ResNetEncoder(nn.Module, EncoderMixin):

    def __init__(self,depth=5,**kwargs):
        super().__init__()
        self._depth = depth
        self._out_channels: List[int] = [3,64,64,128,256,512]
        self._in_channels = 3
        self.block = BasicBlock
        self.rotblock = RotBasicBlock
        self.inplanes = 64
        self.layers: List[int] = [3,4,6,3]
        

        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        
        # self.layer1 = self._make_layer(self.block, 64, self.layers[0])
        
        self.rot_layer1 = self._make_rot_layer(BasicBlock, 64, self.layers[0])
        self.layer2 = self._make_layer(self.block, 128, self.layers[1], stride=2)
        self.layer3 = self._make_layer(self.block, 256, self.layers[2], stride=2)
        self.layer4 = self._make_layer(self.block, 512, self.layers[3], stride=2)
    
        # self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # self.fc = nn.Linear(512 , 1)

    def get_stages(self):
        return [
            nn.Identity(),
            nn.Sequential(self.conv1, self.bn1, self.relu),
            nn.Sequential(self.maxpool, self.rot_layer1),
            self.layer2,
            self.layer3,
            self.layer4,
        ]

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None  
   
        if stride != 1 or self.inplanes != planes:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes, 1, stride, bias=False),
                nn.BatchNorm2d(planes),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        
        self.inplanes = planes
        
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)
    
    def _make_rot_layer(self, block, planes, blocks, stride=1):
        downsample = None  
   
        if stride != 1 or self.inplanes != planes:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes, 1, stride),
                nn.BatchNorm2d(planes),
            )

        layers = []
        layers.append(self.rotblock(self.inplanes, planes, stride, downsample))
        
        self.inplanes = planes
        
        for _ in range(1, blocks):
            layers.append(self.rotblock(self.inplanes, planes))

        return nn.Sequential(*layers)
    
    def forward(self, x):
        stages = self.get_stages()

        features = []
        for i in range(self._depth + 1):
            # print(stages[i])
            # print(x.shape)

            x = stages[i](x)
            features.append(x)

        return features

In [108]:
smp.encoders.encoders['my_awesome_encoder']

{'encoder': __main__.ResNetEncoder, 'pretrained_settings': {}, 'params': {}}

In [107]:
smp.encoders.encoders["my_awesome_encoder"] = {
    "encoder": ResNetEncoder,
    "pretrained_settings": {
    },
    'params': {}
}

In [113]:
model = smp.Unet(encoder_name='my_awesome_encoder', encoder_weights=None)

In [116]:
BasicBlock(64,64)

BasicBlock(
  (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)

In [111]:
model

Unet(
  (encoder): ResNetEncoder(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (rot_layer1): Sequential(
      (0): RotBasicBlock(
        (conv1): RotConv()
        (bn1): VectorBatchNorm()
        (v2m): Vector2Magnitude()
        (relu): ReLU(inplace=True)
        (conv2): RotConv()
        (bn2): VectorBatchNorm()
        (v2m2): Vector2Magnitude()
      )
      (1): RotBasicBlock(
        (conv1): RotConv()
        (bn1): VectorBatchNorm()
        (v2m): Vector2Magnitude()
        (relu): ReLU(inplace=True)
        (conv2): RotConv()
        (bn2): VectorBatchNorm()
        (v2m2): Vector2Magnitude()
      )
      (2): RotBasicBlock(
        (conv1): RotConv()
        (bn1): VectorBatchNorm()
        (v2m): Vector2Magnitude(

### References

In [46]:
class ResNet(nn.Module):
    def __init__(self, block, layers, num_classes=1):
        super().__init__()
        
        self.inplanes = 64

        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 , num_classes)


    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None  
   
        if stride != 1 or self.inplanes != planes:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes, 1, stride, bias=False),
                nn.BatchNorm2d(planes),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        
        self.inplanes = planes
        
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)
    
    
    def forward(self, x):
        x = self.conv1(x)           # 224x224
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)         # 112x112

        x = self.layer1(x)          # 56x56
        x = self.layer2(x)          # 28x28
        x = self.layer3(x)          # 14x14
        x = self.layer4(x)          # 7x7

        x = self.avgpool(x)         # 1x1
        x = torch.flatten(x, 1)     # remove 1 X 1 grid and make vector of tensor shape 
        x = self.fc(x)

        return x

In [47]:
def resnet34():
    layers = [3,4,6,3]
    model = ResNet_RotEq(BasicBlock, layers)
    return model

In [48]:
model = resnet34()

In [51]:
model.layer1

Sequential(
  (0): BasicBlock(
    (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (1): BasicBlock(
    (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (2): BasicBlock(
    (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, mome