In [1]:
%load_ext autoreload
%autoreload 2

In [9]:
from typing import Tuple, Union

import pytorch_lightning as pl
import torch
from torch import Tensor
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.utils.data import random_split
from torchvision import transforms
from torchvision.datasets import MNIST

In [8]:
from super_image_resolution.modules.patch_discriminator import PatchDiscriminator
from super_image_resolution.modules.unet import UNet

In [None]:
class PCGAN(pl.LightningModule):
    """Conditional GAN pairing a ``UNet`` generator with a ``PatchDiscriminator``.

    The generator maps an input image ``x`` to a prediction ``y_hat``; the
    discriminator scores ``y`` (real) and ``y_hat`` (fake). The generator is
    trained with a weighted sum of an L1 reconstruction loss and an
    adversarial loss; the discriminator is trained with the adversarial loss
    on real and fake samples.

    Two optimizers are used, so the module runs in Lightning's *manual
    optimization* mode: ``training_step`` performs both the generator and the
    discriminator update itself.

    Args:
        in_channels: Number of channels of the generator input.
        out_channels: Number of channels of the generator output; also used
            as the discriminator's input channel count.
        gen_norm: Normalisation name forwarded to ``UNet`` as ``norm``.
        gen_act: Activation name forwarded to ``UNet`` as ``activation``.
        gen_final_act: Final activation name forwarded to ``UNet`` as
            ``final_act``.
        dis_norm: Normalisation name forwarded to ``PatchDiscriminator``.
        dis_act: Activation name forwarded to ``PatchDiscriminator``.
        dis_final_act: Final activation name forwarded to
            ``PatchDiscriminator``. NOTE(review): the adversarial criterion
            is ``BCEWithLogitsLoss``, which expects raw logits - confirm the
            chosen final activation is compatible.
        dis_depth: Discriminator depth.
        dis_kernel_size: Discriminator convolution kernel size.
        dis_stride: Discriminator convolution stride.
        dis_padding: Discriminator convolution padding.
        gen_loss_weight: Multiplier applied to the L1 reconstruction loss.
        dis_loss_weight: Multiplier applied to every adversarial loss term.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        gen_norm: str,
        gen_act: str,
        gen_final_act: str,
        dis_norm: str,
        dis_act: str,
        dis_final_act: str,
        dis_depth: int,
        dis_kernel_size: Union[int, Tuple[int, int]],
        dis_stride: Union[int, Tuple[int, int]],
        dis_padding: Union[int, Tuple[int, int]],
        gen_loss_weight: float,
        dis_loss_weight: float):
        super().__init__()

        # Two optimizers (generator + discriminator) are stepped by hand in
        # ``training_step``; Lightning must not step them automatically.
        self.automatic_optimization = False

        # Keyword names for UNet are kept exactly as the original call used
        # them (``final_act``) - TODO confirm against the UNet signature.
        self.gen = UNet(
            in_channels=in_channels,
            out_channels=out_channels,
            norm=gen_norm,
            activation=gen_act,
            final_act=gen_final_act)

        # The discriminator judges generator outputs, hence ``out_channels``.
        self.dis = PatchDiscriminator(
            in_channels=out_channels,
            depth=dis_depth,
            kernel_size=dis_kernel_size,
            stride=dis_stride,
            padding=dis_padding,
            activation=dis_act,
            norm=dis_norm,
            final_activation=dis_final_act)

        self.reconst_criterion = nn.L1Loss()
        self.adv_criterion = nn.BCEWithLogitsLoss()

        self.gen_loss_weight = gen_loss_weight
        self.dis_loss_weight = dis_loss_weight

    def forward(self, x: Tensor) -> Tensor:
        """Run the generator on ``x`` and return its output."""
        return self.gen(x)

    def configure_optimizers(self):
        """Return one Adam optimizer per network (generator first) and no schedulers."""
        optim_gen = torch.optim.Adam(self.gen.parameters(), lr=1e-3)
        optim_dis = torch.optim.Adam(self.dis.parameters(), lr=1e-3)
        return [optim_gen, optim_dis], []

    def generator_step(self, x: Tensor, y: Tensor, mode: str) -> Tensor:
        """Compute the generator objective for one batch.

        Args:
            x: Generator input batch.
            y: Target batch the generator output is compared against.
            mode: Tag inserted into the logged metric names (e.g. ``"train"``).

        Returns:
            Weighted L1 reconstruction loss plus weighted adversarial loss.
        """
        y_hat = self.gen(x)
        reconst_loss = self.gen_loss_weight * self.reconst_criterion(y_hat, y)

        self.log(f"gen_{mode}_reconst_loss", reconst_loss.item())

        # The generator wants the discriminator to label its output as real.
        # ``ones_like`` already matches the device/dtype of ``dis_hat_out``.
        dis_hat_out = self.dis(y_hat)
        adv_loss = self.dis_loss_weight * self.adv_criterion(
            dis_hat_out,
            torch.ones_like(dis_hat_out))

        self.log(f"gen_{mode}_adv_loss", adv_loss.item(), on_step=True, sync_dist=True, prog_bar=True)
        return reconst_loss + adv_loss

    def discriminator_step(self, x: Tensor, y: Tensor, mode: str) -> Tensor:
        """Compute the discriminator objective for one batch.

        Args:
            x: Generator input batch, used to produce the fake sample.
            y: Real target batch.
            mode: Tag inserted into the logged metric name (e.g. ``"train"``).

        Returns:
            Mean of the weighted adversarial losses on the fake sample
            (target 0) and on the real sample (target 1).
        """
        # Detach so that no gradient flows back into the generator.
        y_hat = self.gen(x).detach()
        dis_hat_out = self.dis(y_hat)
        adv_hat_loss = self.dis_loss_weight * self.adv_criterion(
            dis_hat_out, torch.zeros_like(dis_hat_out))

        dis_out = self.dis(y)
        adv_loss = self.dis_loss_weight * self.adv_criterion(
            dis_out, torch.ones_like(dis_out))

        # Average the real and fake terms so the discriminator loss stays on
        # the same scale as a single adversarial term.
        dis_loss = 0.5 * (adv_hat_loss + adv_loss)

        self.log(f"dis_{mode}_loss", dis_loss.item(), on_step=True, sync_dist=True, prog_bar=True)
        return dis_loss

    def training_step(self, train_batch, batch_idx):
        """Perform one generator update followed by one discriminator update.

        Uses manual optimization: the optimizers are fetched in the order
        returned by ``configure_optimizers`` and stepped explicitly.
        """
        x, y = train_batch
        optim_gen, optim_dis = self.optimizers()

        # --- generator update ---
        gen_loss = self.generator_step(x, y, "train")
        optim_gen.zero_grad()
        self.manual_backward(gen_loss)
        optim_gen.step()

        # --- discriminator update ---
        # zero_grad() also discards the gradients the generator's backward
        # pass accumulated in the discriminator parameters.
        dis_loss = self.discriminator_step(x, y, "train")
        optim_dis.zero_grad()
        self.manual_backward(dis_loss)
        optim_dis.step()

    def validation_step(self, val_batch, batch_idx):
        """Evaluate both objectives on a validation batch and log them.

        ``val_loss`` is logged as the generator objective so that it can be
        used as a monitoring key.
        """
        x, y = val_batch
        gen_loss = self.generator_step(x, y, "val")
        self.discriminator_step(x, y, "val")
        self.log('val_loss', gen_loss)

# data
# MNIST training split, downloaded on demand into the path given by the first
# argument ('') and converted to tensors by ToTensor().
dataset = MNIST('', train=True, download=True, transform=transforms.ToTensor())
# Random 55,000 / 5,000 train/validation split.
# NOTE(review): no random seed or generator is passed, so the split differs
# between runs.
mnist_train, mnist_val = random_split(dataset, [55000, 5000])

# Loaders use batch_size=32 and otherwise default DataLoader settings.
train_loader = DataLoader(mnist_train, batch_size=32)
val_loader = DataLoader(mnist_val, batch_size=32)

# model
# NOTE(review): LitAutoEncoder is neither defined nor imported anywhere in the
# visible notebook, so this line raises NameError on a fresh kernel. It looks
# like a leftover from an autoencoder example - TODO replace with a configured
# PCGAN and a dataset that yields (input image, target image) pairs.
model = LitAutoEncoder()

# training
# NOTE(review): gpus=4 and num_nodes=8 hard-code a multi-node GPU setup, and
# limit_train_batches=0.5 restricts training to half of the training batches
# per epoch - confirm these are intended for this environment.
trainer = pl.Trainer(gpus=4, num_nodes=8, precision=16, limit_train_batches=0.5)
trainer.fit(model, train_loader, val_loader)
    


In [2]:
import torch

	pip install pytorch-lightning
  
	pip install pytorch-lightning
  
	pip install pytorch-lightning
  
	pip install pytorch-lightning
  
	pip install pytorch-lightning
  
	pip install pytorch-lightning
  
	pip install pytorch-lightning
  

In [4]:
# Stand-alone patch discriminator for RGB input, used by the cells below
# to inspect its structure and output shape.
dis = PatchDiscriminator(
    in_channels=3, depth=5,
    kernel_size=4, stride=2, padding=1,
    norm="batchnorm",
    activation="lrelu", final_activation="lrelu",
)

In [5]:
# Echo the module so its repr (the layer tree) is rendered as the cell output.
dis

PatchDiscriminator(
  (out): ModuleList(
    (0): ModuleList(
      (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.01)
    )
    (1): ModuleList(
      (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.01)
    )
    (2): ModuleList(
      (0): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.01)
    )
    (3): ModuleList(
      (0): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.01)
    )
    

In [6]:
# Smoke test: feed one random 3-channel 256x256 tensor through the
# discriminator and display the shape of its output.
dis(torch.randn(1,3,256,256)).shape

torch.Size([1, 1, 4, 4])