In [37]:
import torch 
import torch.nn as nn 

## Discriminator & Generator Classes

In [76]:
class Discriminator(nn.Module):
    """DCGAN discriminator: scores each image of a batch with a value in [0, 1].

    The layer comments assume the intended N x img_channels x 64 x 64 input,
    for which the output has shape (N, 1, 1, 1).

    Args:
        img_channels: number of channels in the input images.
        features: base channel width; the down-sampling blocks widen it to
            features * 2, * 4 and * 8.
    """

    def __init__(self, img_channels, features):
        super().__init__()
        # Stem (no BatchNorm): 64x64 -> 32x32, img_channels -> features.
        layers = [
            nn.Conv2d(img_channels, features, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
        ]
        # Each block halves the resolution and doubles the channels:
        # 32x32 -> 16x16 -> 8x8 -> 4x4.
        for in_mult, out_mult in ((1, 2), (2, 4), (4, 8)):
            layers.append(self.block(features * in_mult, features * out_mult, 4, 2, 1))
        # Head: an unpadded 4x4 kernel collapses the 4x4 map to 1x1, and the
        # Sigmoid squashes the score into [0, 1].
        layers.extend([
            nn.Conv2d(features * 8, 1, kernel_size=4, stride=2, padding=0),
            nn.Sigmoid(),
        ])
        self.disc = nn.Sequential(*layers)

    def block(self, in_channels, out_channels, kernal_size, stride, padding):
        """Down-sampling unit: bias-free Conv2d -> BatchNorm2d -> LeakyReLU(0.2).

        The conv carries no bias because the following BatchNorm supplies its
        own shift. The ``kernal_size`` spelling is kept so keyword callers work.
        """
        conv = nn.Conv2d(in_channels, out_channels, kernal_size, stride, padding, bias=False)
        return nn.Sequential(conv, nn.BatchNorm2d(out_channels), nn.LeakyReLU(0.2))

    def forward(self, x):
        """Run the image batch ``x`` through the discriminator stack."""
        scores = self.disc(x)
        return scores

In [68]:
class Generator(nn.Module):
    """DCGAN generator: turns a noise batch into images with pixels in [-1, 1].

    For noise of shape (N, noise_channels, 1, 1) the output has shape
    (N, img_channels, 64, 64).

    Args:
        noise_channels: channel count of the latent noise input.
        img_channels: channel count of the generated images.
        features: base channel width; the stages narrow from features * 16
            down to features * 2 before the output layer.
    """

    def __init__(self, noise_channels, img_channels, features):
        super().__init__()
        # Projection: 1x1 noise -> 4x4 map (kernel 4, stride 1, no padding).
        stages = [self.block(noise_channels, features * 16, 4, 1, 0)]
        # Each stride-2 block doubles the resolution and halves the channels:
        # 4x4 -> 8x8 -> 16x16 -> 32x32.
        for in_mult, out_mult in ((16, 8), (8, 4), (4, 2)):
            stages.append(self.block(features * in_mult, features * out_mult, 4, 2, 1))
        # Output layer: 32x32 -> 64x64 with img_channels maps; Tanh bounds pixels.
        stages.append(
            nn.ConvTranspose2d(features * 2, img_channels, kernel_size=4, stride=2, padding=1)
        )
        stages.append(nn.Tanh())
        self.gen = nn.Sequential(*stages)

    def block(self, in_channels, out_channels, kernel_size, stride, padding):
        """Up-sampling unit: bias-free ConvTranspose2d -> BatchNorm2d -> ReLU.

        The transposed conv has no bias because BatchNorm provides the shift.
        """
        upsample = nn.ConvTranspose2d(
            in_channels, out_channels, kernel_size, stride, padding, bias=False
        )
        return nn.Sequential(upsample, nn.BatchNorm2d(out_channels), nn.ReLU())

    def forward(self, x):
        """Run the noise batch ``x`` through the generator stack."""
        images = self.gen(x)
        return images

### Initializes weights according to the DCGAN paper

In [80]:
def initialize_weights(model, mean=0.0, std=0.02):
    """Re-initialise Conv2d / ConvTranspose2d / BatchNorm2d weights in place.

    Follows the DCGAN paper: every weight is drawn from a normal distribution
    with mean 0 and standard deviation 0.02 (the defaults). Biases are left
    untouched.

    NOTE: the PyTorch DCGAN tutorial instead centres BatchNorm weights at 1.0;
    this function applies the same ``mean`` to every matched module.

    Args:
        model: an ``nn.Module``; every sub-module from ``model.modules()`` is visited.
        mean: mean of the normal distribution (default 0.0).
        std: standard deviation of the normal distribution (default 0.02).
    """
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)):
            # BatchNorm2d(affine=False) has no learnable weight (m.weight is None);
            # skip it instead of failing with AttributeError.
            if m.weight is None:
                continue
            # nn.init.* already runs under torch.no_grad(), so `.data` is not needed.
            nn.init.normal_(m.weight, mean=mean, std=std)

## Test

In [48]:
def test():
    """Smoke-test the DCGAN models on a small random batch.

    Checks that the discriminator maps N x 3 x 64 x 64 images to N x 1 x 1 x 1
    scores and that the generator maps N x noise_dim x 1 x 1 noise to
    N x 3 x 64 x 64 images. Raises AssertionError on a shape mismatch.
    """
    N, in_channels, H, W = 8, 3, 64, 64
    noise_dim = 100
    x = torch.randn((N, in_channels, H, W))
    disc = Discriminator(in_channels, 8)
    gen = Generator(noise_dim, in_channels, 8)
    z = torch.randn((N, noise_dim, 1, 1))
    # Only output shapes are checked, so no autograd graph is needed.
    with torch.no_grad():
        assert disc(x).shape == (N, 1, 1, 1), "Discriminator test failed"
        assert gen(z).shape == (N, in_channels, H, W), "Generator test failed"
    print("Success")


# Kept at module level: if this guard sat inside test(), running the cell would
# never execute the test, and calling test() would make it recurse forever.
if __name__ == "__main__":
    test()