In [3]:
from torchvision.models.resnet import ResNet, BasicBlock

In [5]:
ResNet(block=BasicBlock, layers=18)

TypeError: 'int' object is not subscriptable

In [33]:
import torch.nn as nn
class resnet32wider(nn.Module):
    def __init__(self, input_channels, latent_dim, activation):
        super(resnet32wider, self).__init__()
        
        if activation == 'Swish':
            act = Swish()
        elif activation == 'LeakyReLU':
            act = nn.LeakyReLU(inplace=True)
        else:
            act = getattr(nn, activation)()
            
        self.conv1 = nn.Conv2d(input_channels, 128, kernel_size=3, padding=1, bias=True)
        self.res1 = BasicBlock(in_c=128, out_c=128, stride=2)
        self.res2 = BasicBlock(in_c=128, out_c=128, stride=1)
        self.res3 = BasicBlock(in_c=128, out_c=128, stride=1)
        self.res4 = BasicBlock(in_c=128, out_c=256, stride=2)
        self.res5 = BasicBlock(in_c=256, out_c=256, stride=1)
        self.res6 = BasicBlock(in_c=256, out_c=256, stride=1)
        self.res7 = BasicBlock(in_c=256, out_c=512, stride=2)
        self.res8 = BasicBlock(in_c=512, out_c=512, stride=1)
        self.res9 = BasicBlock(in_c=512, out_c=512, stride=1)
        self.fc1 = nn.Sequential(
                    nn.Linear(512*4*4, 512),
                    act)
        
        self.nss1_net = nn.Linear(512, latent_dim)
        self.nss2_net = nn.Linear(512, latent_dim)
                    
        self.flatten = nn.Flatten()
            
        
    def forward(self, x):
        h0 = self.conv1(x)
        h1 = self.res1(h0)
        h2 = self.res2(h1)
        h3 = self.res3(h2)
        h4 = self.res4(h3)
        h5 = self.res5(h4)
        h6 = self.res6(h5)
        h7 = self.res7(h6)
        h8 = self.res8(h7)
        h9 = self.res9(h8)
        h10 = self.fc1(self.flatten(h9))
        nss1 = self.nss1_net(h10)
        nss2 = self.nss2_net(h10)
        return nss1, -nss2**2
        
        
class BasicBlock(nn.Module):
    """
    basic block module
        stride   -- stride of the 1st cnn in the 1st block in a group
        bn_flag -- whether do batch normalization
    """
    def __init__(self, in_c, out_c, stride, activation=nn.LeakyReLU(inplace=True)):
        super(BasicBlock, self).__init__()

        self.activation = activation
        self.c1 = nn.Conv2d(in_c, out_c, kernel_size=3, stride=1, padding=1, bias=True)
        self.c2 = nn.Conv2d(out_c, out_c, kernel_size=3, stride=stride, padding=1, bias=True)
                    
        if in_c != out_c:
            self.shortcut = nn.Conv2d(in_c, out_c, kernel_size=1, stride=stride, bias=True)
        else:
            if stride != 1:
                self.shortcut = nn.AvgPool2d(kernel_size=2)
            else:
                self.shortcut = nn.Identity()

    def forward(self, x):
        h1 = self.activation(self.c1(x))
        h2 = self.activation(self.c2(h1))
        out = self.activation(h2 + self.shortcut(x))
        
        return out  

In [34]:
resnet = resnet32wider()

In [35]:
import torch
x = torch.randn((128, 3, 32, 32))
o = resnet(x)

In [36]:
o.shape

torch.Size([128, 512, 4, 4])