In [1]:
import torch.nn as nn
import torch

In [2]:
class DilatedBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1, dilation=2):
        super(DilatedBlock, self).__init__()
        
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=5, stride=stride, padding=dilation, dilation=dilation, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.tanh = nn.Tanh()
        
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=5, padding=dilation, dilation=dilation, bias=False)
        self.bn2 = nn.GroupNorm(32, out_channels)  
        
        self.skip = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.skip = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )
            
    def forward(self, x):
        out = self.tanh(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.skip(x)
        out = self.tanh(out)
        return out

In [3]:
class ModifiedResNet(nn.Module):
    def __init__(self, num_classes=1000):
        super(ModifiedResNet, self).__init__()
        
        self.layer1 = nn.Conv2d(3, 128, kernel_size=7, stride=2, padding=3)
        self.layer2 = DilatedBlock(128, 256)
        self.layer3 = DilatedBlock(256, 512, stride=2, dilation=2)
        self.layer4 = DilatedBlock(512, 512, stride=1, dilation=4)
        self.layer5 = DilatedBlock(512, 1024, stride=1, dilation=4)
        
        self.globalpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(1024, num_classes)
        
    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.layer5(x)
        x = self.globalpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

In [4]:
model = ModifiedResNet()
print(model)
scripted_model = torch.jit.script(model)
scripted_model.save("model_Modified_ResNet_scripted.pt")

ModifiedResNet(
  (layer1): Conv2d(3, 128, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3))
  (layer2): DilatedBlock(
    (conv1): Conv2d(128, 256, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), dilation=(2, 2), bias=False)
    (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (tanh): Tanh()
    (conv2): Conv2d(256, 256, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), dilation=(2, 2), bias=False)
    (bn2): GroupNorm(32, 256, eps=1e-05, affine=True)
    (skip): Sequential(
      (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer3): DilatedBlock(
    (conv1): Conv2d(256, 512, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), dilation=(2, 2), bias=False)
    (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (tanh): Tanh()
    (conv2): Conv2d(512, 512, kernel_size=(5, 5