In [None]:
class Block(nn.Module):
    """Bottleneck residual block: 1x1 reduce -> 3x3 (strided) -> 1x1 expand.

    The main path maps ``input_chan`` channels to ``output_chan * expansion``
    channels. Its output is added to the (optionally downsampled) input, and
    the sum goes through a final ReLU.

    Args:
        input_chan: number of channels of the incoming tensor.
        output_chan: width of the two inner convolutions. The block outputs
            ``output_chan * expansion`` channels.
        downsample: optional module applied to the input before the residual
            addition. It must map the input to the main path's output shape.
            It is required whenever ``stride != 1`` or
            ``input_chan != output_chan * expansion``.
        stride: stride of the 3x3 convolution (spatial downsampling factor).
    """

    # Ratio between the block's output channels and ``output_chan``. It is a
    # class attribute so callers can read ``Block.expansion`` without an
    # instance. ``self.expansion`` still works.
    expansion = 4

    def __init__(self, input_chan, output_chan, downsample=None, stride=1):
        super(Block, self).__init__()
        # 1x1 conv: reduce input_chan -> output_chan, spatial size unchanged.
        self.conv1 = nn.Conv2d(input_chan, output_chan, kernel_size=1, stride=1, padding=0)
        self.batch_n1 = nn.BatchNorm2d(output_chan)

        # 3x3 conv on conv1's output (output_chan channels). padding=1 keeps
        # the spatial size at stride 1, and stride > 1 downsamples it.
        self.conv2 = nn.Conv2d(output_chan, output_chan, kernel_size=3, stride=stride, padding=1)
        self.batch_n2 = nn.BatchNorm2d(output_chan)

        # 1x1 conv: expand output_chan -> output_chan * expansion.
        self.conv3 = nn.Conv2d(output_chan, output_chan * self.expansion, kernel_size=1, stride=1, padding=0)
        self.batch_n3 = nn.BatchNorm2d(output_chan * self.expansion)
        self.relu = nn.ReLU()
        self.downsample = downsample

    def forward(self, x):
        """Apply the block to ``x``.

        Assumes ``x`` is (N, input_chan, H, W). TODO confirm with callers.
        """
        identity = x

        out = self.relu(self.batch_n1(self.conv1(x)))
        out = self.relu(self.batch_n2(self.conv2(out)))
        # No ReLU before the addition: the residual sum is activated below.
        out = self.batch_n3(self.conv3(out))

        if self.downsample is not None:
            # Match the identity's channels / spatial size to the main path.
            identity = self.downsample(identity)

        return self.relu(out + identity)

class ResNet(nn.Module):
    """ResNet classifier: stem, four residual stages, global pooling, linear head.

    Args:
        Block: residual block class. It is called as
            ``Block(in_chan, out_chan, downsample, stride)`` for the first block
            of a stage and ``Block(in_chan, out_chan)`` for the rest, and it
            must output ``out_chan * Block.expansion`` channels. If the class
            defines no ``expansion`` attribute, 4 (bottleneck) is assumed.
        layers: number of blocks in each of the four stages (entries 0-3 are used).
        channels: number of channels of the input image.
        num_output: size of the final linear layer's output.
    """

    def __init__(self, Block, layers, channels, num_output):
        super(ResNet, self).__init__()
        # Channel expansion of the block class, used for the classifier width.
        self.expansion = getattr(Block, "expansion", 4)
        # Running channel count of the tensor flowing between stages. It is
        # updated by _make_layer.
        self.input_chan = 64
        self.conv1 = nn.Conv2d(channels, 64, kernel_size=7, stride=2, padding=3)
        self.batch_n1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()
        self.max_pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Residual stages. Stages 2-4 halve the spatial size via stride 2.
        self.layer1 = self._make_layer(Block, layers[0], output_chan=64, stride=1)
        self.layer2 = self._make_layer(Block, layers[1], output_chan=128, stride=2)
        self.layer3 = self._make_layer(Block, layers[2], output_chan=256, stride=2)
        self.layer4 = self._make_layer(Block, layers[3], output_chan=512, stride=2)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # layer4 outputs 512 * expansion channels.
        self.fc = nn.Linear(512 * self.expansion, num_output)

    def forward(self, x):
        """Return logits of shape (N, num_output).

        Assumes ``x`` is (N, channels, H, W). TODO confirm with callers.
        """
        x = self.max_pool(self.relu(self.batch_n1(self.conv1(x))))

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        # (N, C, 1, 1) -> (N, C). Unlike reshape(N, -1), this also works for N == 0.
        x = x.flatten(1)
        return self.fc(x)

    def _make_layer(self, Block, num_residual_blocks, output_chan, stride):
        """Build one stage of residual blocks and advance ``self.input_chan``.

        The first block applies ``stride`` and, when the shape changes, gets a
        1x1 conv + BatchNorm projection for its identity path. Note that one
        block is always created, even if ``num_residual_blocks`` is 0 or less.
        """
        expansion = getattr(Block, "expansion", 4)
        stage_out = output_chan * expansion

        downsample = None
        if stride != 1 or self.input_chan != stage_out:
            downsample = nn.Sequential(
                nn.Conv2d(self.input_chan, stage_out, kernel_size=1, stride=stride),
                nn.BatchNorm2d(stage_out),
            )

        layers = [Block(self.input_chan, output_chan, downsample, stride)]
        self.input_chan = stage_out

        for _ in range(num_residual_blocks - 1):
            layers.append(Block(self.input_chan, output_chan))

        return nn.Sequential(*layers)
