In [49]:
import torch.nn as nn
import torch.nn.functional as F

class DiscriminatorBlock(nn.Module):
    """Residual block for a GAN discriminator.

    Residual branch: activation -> conv1 -> activation -> conv2, followed by
    2x2 average pooling when ``downsample`` is True.
    Shortcut branch: identity, or a 1x1 convolution when the channel count
    changes or ``downsample`` is True, followed by the same 2x2 average
    pooling when ``downsample`` is True.
    ``forward`` returns the sum of the two branches.

    Args:
        in_channels: number of channels of the input tensor.
        out_channels: number of channels of the output tensor.
        hidden_channels: channels between ``conv1`` and ``conv2``; defaults
            to ``out_channels``.
        kernel_size, padding: forwarded to ``conv1`` and ``conv2``. They
            should preserve spatial size (``2 * padding == kernel_size - 1``,
            as with the defaults 3 / 1) so both branches can be added.
        activation: callable applied before each convolution of the
            residual branch (default ``F.relu``).
        downsample: if True, both branches are pooled with ``AvgPool2d(2)``,
            halving (with floor) the spatial resolution.
    """

    def __init__(self, in_channels, out_channels, hidden_channels=None, kernel_size=3, padding=1, activation=F.relu, downsample=False):
        super(DiscriminatorBlock, self).__init__()

        self.activation = activation
        self.downsample = downsample
        if downsample:
            # Use this cell's ``nn`` alias rather than ``torch.nn`` so the
            # cell does not rely on ``torch`` being imported by a later cell.
            self.downsampler = nn.AvgPool2d(2)

        # The 1x1 projection is used whenever channels change or we downsample.
        self.learnable_shortcut = in_channels != out_channels or downsample
        hidden_channels = out_channels if hidden_channels is None else hidden_channels

        self.conv1 = nn.Conv2d(in_channels, hidden_channels, kernel_size=kernel_size, padding=padding)
        self.conv2 = nn.Conv2d(hidden_channels, out_channels, kernel_size=kernel_size, padding=padding)

        if self.learnable_shortcut:
            self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        """Return ``residual(x) + shortcut(x)``.

        ``x`` is passed to ``nn.Conv2d``, so it is expected to be laid out
        as (N, C, H, W) with ``C == in_channels``. The result has
        ``out_channels`` channels and, when ``downsample`` is True, half the
        spatial size.
        """
        # residual branch
        r = self.activation(x)
        r = self.conv1(r)
        r = self.activation(r)
        r = self.conv2(r)
        if self.downsample:
            r = self.downsampler(r)

        # shortcut branch: project first, then pool the *projected* tensor.
        # Pooling the raw ``x`` here would throw the projection away and
        # leave the shortcut with ``in_channels`` channels.
        x_sc = x
        if self.learnable_shortcut:
            x_sc = self.shortcut(x_sc)
        if self.downsample:
            x_sc = self.downsampler(x_sc)

        return r + x_sc
        

In [50]:
import torch
from torch.autograd import Variable

# Smoke test: a block that changes channels (128 -> 10) without downsampling.
# Seed first so the randomly initialised weights, and hence the output
# shown below, are reproducible on Restart & Run All.
torch.manual_seed(0)
g = DiscriminatorBlock(128, 10)

# ``Variable`` is kept for the (pre-0.4) PyTorch this notebook ran on;
# on modern PyTorch it simply returns the tensor.
x = Variable(torch.ones(8, 128, 32, 32))

out = g(x)
assert out.size() == (8, 10, 32, 32), out.size()
out

Variable containing:
(0 ,0 ,.,.) = 
  0.3806  0.2822  0.2624  ...   0.2624  0.2031  0.1583
  0.3281  0.1689  0.1032  ...   0.1032  0.1422  0.0535
  0.4125  0.2509  0.1468  ...   0.1468  0.1384  0.0138
           ...             ⋱             ...          
  0.4125  0.2509  0.1468  ...   0.1468  0.1384  0.0138
  0.4074  0.2634  0.1733  ...   0.1733  0.1745  0.0586
  0.4496  0.3430  0.2534  ...   0.2534  0.2520  0.1361

(0 ,1 ,.,.) = 
 -0.2374 -0.2460 -0.2473  ...  -0.2473 -0.2402 -0.2825
  0.0017  0.0105  0.0162  ...   0.0162 -0.1137 -0.2512
 -0.0270 -0.0535 -0.0459  ...  -0.0459 -0.1483 -0.2835
           ...             ⋱             ...          
 -0.0270 -0.0535 -0.0459  ...  -0.0459 -0.1483 -0.2835
 -0.0526 -0.0806 -0.0605  ...  -0.0605 -0.1644 -0.2743
 -0.0954 -0.0894 -0.0895  ...  -0.0895 -0.2065 -0.2851

(0 ,2 ,.,.) = 
  0.9773  0.9335  0.8777  ...   0.8777  0.8931  0.7753
  0.9012  0.9221  0.8512  ...   0.8512  0.9536  0.8577
  0.9469  1.0165  0.9278  ...   0.9278  0.9945  0.88