In [21]:
import torch
import torch.nn as nn
import pytorch_lightning as pl
import torchmetrics
from torchmetrics import Metric

In [22]:
class Discriminator(nn.Module):
    """Convolutional GAN discriminator that scores each image with a probability.

    The stack is:

    * a stride-2 conv stem with bias and no BatchNorm;
    * three stride-2 conv/BatchNorm/LeakyReLU stages, doubling the width each time;
    * a 4x4 valid conv down to a single channel, followed by a Sigmoid.

    By the layer arithmetic, a 64x64 input yields an output of shape
    (N, 1, 1, 1). Other input sizes are not checked here.

    Args:
        channels_img: number of channels in the input images.
        features_d: base channel width; the hidden widths are 1x, 2x, 4x and 8x this.
    """

    def __init__(self, channels_img, features_d):
        super().__init__()
        widths = [features_d * mult for mult in (1, 2, 4, 8)]

        # Stem: a plain strided conv with bias, no normalisation.
        layers = [
            nn.Conv2d(channels_img, widths[0], kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
        ]
        # Downsampling stages: (1x->2x), (2x->4x), (4x->8x) of features_d.
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            layers.append(self.block(c_in, c_out, 4, 2, 1))
        # Head: collapse to one channel and squash into a probability.
        layers += [
            nn.Conv2d(widths[-1], 1, kernel_size=4, stride=2, padding=0),
            nn.Sigmoid(),
        ]

        self.disc = nn.Sequential(*layers)

    def block(self, in_channels, out_channels, kernel_size, stride, padding):
        """Build one bias-free Conv2d -> BatchNorm2d -> LeakyReLU(0.2) stage.

        The conv carries no bias because the following BatchNorm adds its own shift.
        """
        conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False)
        return nn.Sequential(conv, nn.BatchNorm2d(out_channels), nn.LeakyReLU(0.2))

    def forward(self, x):
        """Return the Sigmoid outputs of the discriminator stack for the image batch ``x``."""
        scores = self.disc(x)
        return scores

In [23]:
class Generator(nn.Module):
    """Transposed-convolution GAN generator mapping latent noise to images.

    The stack is:

    * a stride-1 stage that expands the latent to ``features_g * 16`` channels;
    * three stride-2 stages, each halving the width;
    * a final stride-2 ConvTranspose2d to ``channels_img`` channels, then Tanh,
      so outputs lie in [-1, 1].

    By the layer arithmetic, a (N, z_dim, 1, 1) latent yields an output of
    shape (N, channels_img, 64, 64).

    Args:
        z_dim: number of latent input channels.
        channels_img: number of channels in the generated images.
        features_g: base channel width; the hidden widths are 16x, 8x, 4x and 2x this.
    """

    def __init__(self, z_dim, channels_img, features_g):
        super().__init__()
        widths = [features_g * mult for mult in (16, 8, 4, 2)]

        # Projection stage: kernel 4, stride 1, no padding (1x1 -> 4x4 spatially).
        stages = [self.block(z_dim, widths[0], 4, 1, 0)]
        # Upsampling stages: each doubles resolution while halving the width.
        stages.extend(
            self.block(c_in, c_out, 4, 2, 1) for c_in, c_out in zip(widths, widths[1:])
        )
        # Output head: biased transposed conv to image channels, squashed by Tanh.
        stages.append(
            nn.ConvTranspose2d(widths[-1], channels_img, kernel_size=4, stride=2, padding=1)
        )
        stages.append(nn.Tanh())

        self.gen = nn.Sequential(*stages)

    def block(self, in_channels, out_channels, kernel_size, stride, padding):
        """Build one bias-free ConvTranspose2d -> BatchNorm2d -> ReLU stage."""
        upconv = nn.ConvTranspose2d(
            in_channels, out_channels, kernel_size, stride, padding, bias=False
        )
        return nn.Sequential(upconv, nn.BatchNorm2d(out_channels), nn.ReLU())

    def forward(self, x):
        """Run latent batch ``x`` through the generator stack and return the Tanh output."""
        images = self.gen(x)
        return images

In [28]:
class Modelpt(pl.LightningModule):
    """LightningModule that trains a Generator/Discriminator pair with BCE loss.

    Uses two optimizers under automatic optimization. ``training_step``
    receives ``optimizer_idx``: 0 updates the generator, 1 updates the
    discriminator.

    NOTE(review): the ``optimizer_idx`` argument is the PyTorch Lightning 1.x
    multi-optimizer API. Lightning 2.x removed it in favour of manual
    optimization, so confirm the installed version.

    Args:
        z_dim: number of latent noise channels fed to the generator.
        channels_img: number of image channels produced by G and scored by D.
        features_g: base channel width of the generator.
        features_d: base channel width of the discriminator.
        batch_size: configured batch size, stored for reference. Noise is sized
            from the actual batch, so a short final batch still works.
        lr: Adam learning rate for both networks.
        betas: Adam beta coefficients for both networks.
    """

    def __init__(self, z_dim, channels_img, features_g, features_d, batch_size,
                 lr=0.0002, betas=(0.5, 0.999)):
        super().__init__()
        self.z_dim = z_dim
        self.batch_size = batch_size
        self.channels_img = channels_img
        self.features_g = features_g
        self.features_d = features_d
        self.lr = lr
        self.betas = betas
        self.D = Discriminator(channels_img, features_d)
        self.G = Generator(z_dim, channels_img, features_g)
        self.BCE_loss = nn.BCELoss()
        # Apply the N(0, 0.02) init. It was previously defined but never
        # called, and referenced an undefined global.
        self.init_weights()

    def init_weights(self, m=None):
        """Draw conv, transposed-conv and BatchNorm weights from N(0, 0.02).

        NOTE(review): the PyTorch DCGAN tutorial instead uses N(1, 0.02) for
        BatchNorm scales with zero bias. The original N(0, 0.02) choice is kept here.

        Args:
            m: a single module to initialise, so the method also works as
                ``self.apply(self.init_weights)``. When None (the default),
                every submodule of this model is initialised.
        """
        targets = self.modules() if m is None else (m,)
        for module in targets:
            if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)):
                # BatchNorm2d(affine=False) has no weight; skip it rather than crash.
                if module.weight is not None:
                    # nn.init functions already run under no_grad, so .data is unnecessary.
                    nn.init.normal_(module.weight, 0.0, 0.02)

    def forward(self, x):
        """Generate images by running latent noise ``x`` through the generator once.

        Presumably ``x`` is shaped (N, z_dim, 1, 1), matching the noise built in
        ``training_step``. Confirm this against any callers.
        """
        return self.G(x)

    def configure_optimizers(self):
        """Return (generator Adam, discriminator Adam).

        The tuple order defines ``optimizer_idx`` in ``training_step``.
        """
        optimizer_G = torch.optim.Adam(self.G.parameters(), lr=self.lr, betas=self.betas)
        optimizer_D = torch.optim.Adam(self.D.parameters(), lr=self.lr, betas=self.betas)
        return optimizer_G, optimizer_D

    def training_step(self, batch, batch_idx, optimizer_idx):
        """Compute the loss for the network selected by ``optimizer_idx``.

        Args:
            batch: ``(images, labels)`` pair; the labels are not used.
            batch_idx: index of this batch (unused).
            optimizer_idx: 0 for the generator step, 1 for the discriminator step.

        Returns:
            The scalar BCE loss for the selected network, or None for any
            other ``optimizer_idx``.
        """
        x, _ = batch
        # Lightning moves the batch to self.device already. Size the noise from
        # the real batch so that a short last batch keeps real/fake counts equal.
        noise = torch.randn(x.size(0), self.z_dim, 1, 1, device=self.device)
        x_fake = self.G(noise)

        # Generator: push D(G(z)) toward the "real" label.
        if optimizer_idx == 0:
            D_fake = self.D(x_fake).reshape(-1)
            return self.BCE_loss(D_fake, torch.ones_like(D_fake))

        # Discriminator: real -> 1, fake -> 0. Detach the fakes so this loss
        # does not backpropagate through the generator.
        if optimizer_idx == 1:
            D_real = self.D(x).reshape(-1)
            D_fake = self.D(x_fake.detach()).reshape(-1)
            loss_real = self.BCE_loss(D_real, torch.ones_like(D_real))
            loss_fake = self.BCE_loss(D_fake, torch.zeros_like(D_fake))
            return (loss_real + loss_fake) / 2
    
    