In [1]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

import torch
import torch.nn as nn
from torchvision.models.resnet import ResNet
from torchvision.models.resnet import BasicBlock
from torchvision.models.resnet import Bottleneck
from torchvision import models
import segmentation_models_pytorch as smp

from segmentation_models_pytorch.encoders._base import EncoderMixin
import segmentation_models_pytorch.encoders as smp_enc

from torchvision.models.resnet import ResNet
from copy import deepcopy

import torchvision as tv

from layers_2D import RotConv, Vector2Magnitude, VectorBatchNorm, VectorMaxPool, VectorUpsampling

In [2]:
# Define network
class Net(nn.Module):
    """Network built from three RotConv stages and a 1x1-convolution head.

    The whole model lives in ``self.main`` (an ``nn.Sequential``). The last
    layer is ``nn.Conv2d(128, 10, 1)``, so the network emits 10 channels per
    sample.
    """

    def __init__(self):
        super().__init__()

        # Vector-valued part: the first two stages are RotConv -> VectorMaxPool(2)
        # -> VectorBatchNorm; the third RotConv (padding 1 instead of 9 // 2) is
        # followed by Vector2Magnitude.
        vector_layers = [
            RotConv(1, 6, [9, 9], 1, 9 // 2, n_angles=17, mode=1),
            VectorMaxPool(2),
            VectorBatchNorm(6),
            RotConv(6, 16, [9, 9], 1, 9 // 2, n_angles=17, mode=2),
            VectorMaxPool(2),
            VectorBatchNorm(16),
            RotConv(16, 32, [9, 9], 1, 1, n_angles=17, mode=2),
            Vector2Magnitude(),
        ]

        # Head made of 1x1 convolutions, labelled FC1 / FC2 by the author.
        head_layers = [
            nn.Conv2d(32, 128, 1),  # FC1
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Dropout2d(0.7),
            nn.Conv2d(128, 10, 1),  # FC2
        ]

        self.main = nn.Sequential(*vector_layers, *head_layers)

    def forward(self, x):
        """Run ``self.main`` on ``x`` and return the result viewed as 2-D.

        The output of ``self.main`` is viewed as ``(size(0), size(1))``, which
        only succeeds when all remaining dimensions hold a single element in
        total (the step-by-step cells of this notebook reach 1x1 from an
        ``(8, 1, 28, 28)`` input).
        """
        out = self.main(x)
        n_samples, n_channels = out.size(0), out.size(1)
        return out.view(n_samples, n_channels)

## MNIST

In [4]:
# Random stand-in batch: 8 samples, 1 channel, 28x28 pixels.
x = torch.randn((8, 1, 28, 28))

### Layer 1

In [5]:
# Stage 1: RotConv (1 -> 6 channels, 9x9 kernel, padding 9 // 2), then
# VectorMaxPool(2) and VectorBatchNorm(6). After each layer, print the shape of
# the first element of the output (x[0]).
stage1_layers = (
    RotConv(1, 6, [9, 9], 1, 9 // 2, n_angles=17, mode=1),
    VectorMaxPool(2),
    VectorBatchNorm(6),
)
for layer in stage1_layers:
    x = layer(x)
    print(x[0].shape)

torch.Size([8, 6, 28, 28])
torch.Size([8, 6, 14, 14])
torch.Size([8, 6, 14, 14])


### Layer 2

In [6]:
# Stage 2: RotConv (6 -> 16 channels, mode=2), then VectorMaxPool(2) and
# VectorBatchNorm(16). The shape of x[0] is printed after every layer.
stage2_layers = (
    RotConv(6, 16, [9, 9], 1, 9 // 2, n_angles=17, mode=2),
    VectorMaxPool(2),
    VectorBatchNorm(16),
)
for layer in stage2_layers:
    x = layer(x)
    print(x[0].shape)

torch.Size([8, 16, 14, 14])
torch.Size([8, 16, 7, 7])
torch.Size([8, 16, 7, 7])


### Layer 3

In [7]:
# Stage 3: 9x9 RotConv with padding 1 (16 -> 32 channels); the recorded output
# shows the spatial size collapsing to 1x1 here.
stage3_conv = RotConv(16, 32, [9, 9], 1, 1, n_angles=17, mode=2)
x = stage3_conv(x)
print(x[0].shape)

torch.Size([8, 32, 1, 1])


In [8]:
# Apply Vector2Magnitude; from here on x is indexed as a plain tensor
# (x.shape rather than x[0].shape).
to_magnitude = Vector2Magnitude()
x = to_magnitude(x)
print(x.shape)

torch.Size([8, 32, 1, 1])


### Layer 4

In [9]:
# FC1 block: 1x1 convolution (32 -> 128 channels), batch norm, ReLU.
# Print the tensor shape after each layer.
fc1_layers = (
    nn.Conv2d(32, 128, 1),  # FC1
    nn.BatchNorm2d(128),
    nn.ReLU(),
)
for layer in fc1_layers:
    x = layer(x)
    print(x.shape)

torch.Size([8, 128, 1, 1])
torch.Size([8, 128, 1, 1])
torch.Size([8, 128, 1, 1])


In [10]:
# Channel dropout (p=0.7) followed by the final 1x1 convolution (128 -> 10).
# The dropout is applied before the convolution is created, matching the
# order in which random numbers are drawn.
dropout = nn.Dropout2d(0.7)
x = dropout(x)
print(x.shape)

fc2 = nn.Conv2d(128, 10, 1)  # FC2
x = fc2(x)
print(x.shape)

torch.Size([8, 128, 1, 1])
torch.Size([8, 10, 1, 1])


In [11]:
# Same reshape as Net.forward: drop the trailing 1x1 dimensions and show the shape.
x.view(x.size(0), x.size(1)).shape

torch.Size([8, 10])

## Segmentation example

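Constructor signature for reference (apparently `RotConv.__init__` from `layers_2D` — confirm against that module):

```python
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
             padding=0, dilation=1, n_angles=8, mode=1):
```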

### Layer 1

In [3]:
# Base channel count for this example.
N = 6

# Random stand-in batch: 8 samples, 3 channels, 256x256 pixels.
x = torch.randn((8, 3, 256, 256))

# Stage 1: RotConv (3 -> N channels), VectorMaxPool(2), VectorBatchNorm(N).
# The shape of x[0] is printed after every layer.
seg_stage1_layers = (
    RotConv(3, N, [9, 9], 1, 9 // 2, n_angles=6, mode=1),
    VectorMaxPool(2),
    VectorBatchNorm(N),
)
for layer in seg_stage1_layers:
    x = layer(x)
    print(x[0].shape)

torch.Size([8, 6, 256, 256])
torch.Size([8, 6, 128, 128])
torch.Size([8, 6, 128, 128])


In [4]:
# Stage 2: RotConv (N -> 2*N channels), VectorMaxPool(2), VectorBatchNorm(2*N).
# The shape of x[0] is printed after every layer.
seg_stage2_layers = (
    RotConv(N, 2 * N, [9, 9], 1, 9 // 2, n_angles=6, mode=2),
    VectorMaxPool(2),
    VectorBatchNorm(2 * N),
)
for layer in seg_stage2_layers:
    x = layer(x)
    print(x[0].shape)

torch.Size([8, 12, 128, 128])
torch.Size([8, 12, 64, 64])
torch.Size([8, 12, 64, 64])


In [5]:
# Stage 3: RotConv (2*N -> 3*N channels), VectorMaxPool(2), VectorBatchNorm(3*N).
# The shape of x[0] is printed after every layer.
seg_stage3_layers = (
    RotConv(2 * N, 3 * N, [9, 9], 1, 9 // 2, n_angles=6, mode=2),
    VectorMaxPool(2),
    VectorBatchNorm(3 * N),
)
for layer in seg_stage3_layers:
    x = layer(x)
    print(x[0].shape)

torch.Size([8, 18, 64, 64])
torch.Size([8, 18, 32, 32])
torch.Size([8, 18, 32, 32])


In [6]:
# Stage 4: RotConv only (3*N -> 4*N channels); no pooling or normalisation follows.
seg_stage4_conv = RotConv(3 * N, 4 * N, [9, 9], 1, 9 // 2, n_angles=6, mode=2)
x = seg_stage4_conv(x)
print(x[0].shape)

torch.Size([8, 24, 32, 32])


In [7]:
# Upsample with VectorUpsampling(size=256); the recorded output shows 256x256.
upsample = VectorUpsampling(size=256)
x = upsample(x)
print(x[0].shape)

# Apply Vector2Magnitude; afterwards x is indexed as a plain tensor.
to_magnitude = Vector2Magnitude()
x = to_magnitude(x)
print(x.shape)

torch.Size([8, 24, 256, 256])
torch.Size([8, 24, 256, 256])




In [22]:
# Display the shape of the final tensor.
x.size()

torch.Size([8, 24, 256, 256])