In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
import numpy as np
import math

import torch.nn as nn
import torch.nn.functional as F

from torchsummary import summary


In [2]:
class TransitionBlock(nn.Module):
    def __init__(self, in_channels, out_channels, dropout=0.2):
        super(TransitionBlock, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.leakly_relu = nn.LeakyReLU(inplace=True)
        self.conv1 = nn.Conv2d(in_channels, out_channels, 1, stride = 1, padding=0, bias=False)
        self.dropout_prob = dropout
        self.avgpool = nn.AvgPool2d(2, stride = 2)
        
    def foward(self, x):
        out = self.conv1(self.leaky_relu(self.bn1(x)))
        out = F.dropout(out, p=self.dropout_prob, inplace=False, training = self.training)
        out = self.avgpool(out)
        return out

        #return F.avg_pool2d(out, 2)
    
    

In [7]:

class TransitionBlock_test(nn.Module):
    def __init__(self, in_channels, out_channels, dropout=0.2):
        super(TransitionBlock_test, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.leaky_relu = nn.LeakyReLU(inplace=True)
        self.conv1 = nn.Conv2d(in_channels, out_channels, 1, stride=1, padding=0, bias=False)
        self.dropout_prob = dropout
        self.avgpool = nn.AvgPool2d(2, stride = 2)
        
    def forward(self, input):
        out = self.conv1(self.leaky_relu(self.bn1(input)))
        out = F.dropout(out, p=self.dropout_prob, inplace=False, training=self.training)
        out = self.avgpool(out)
        return out
        

In [8]:
# tests
T = TransitionBlock(3,3)
T_test = TransitionBlock_test(3,3)



In [9]:
summary(T, (3, 224, 224))

NotImplementedError: 

In [10]:
summary(T_test,(3,224,224))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
       BatchNorm2d-1          [-1, 3, 224, 224]               6
         LeakyReLU-2          [-1, 3, 224, 224]               0
            Conv2d-3          [-1, 3, 224, 224]               9
         AvgPool2d-4          [-1, 3, 112, 112]               0
Total params: 15
Trainable params: 15
Non-trainable params: 0
----------------------------------------------------------------
