In [1]:
import torch
import torch.nn as nn

In [10]:
class CNNBlock(nn.Module):
  def __init__(self, in_channel, out_channel, stride=2):
    super().__init__()
    self.conv = nn.Sequential(
        nn.Conv2d(in_channel, out_channel, 4, stride, bias=False, padding_mode="reflect"),
        nn.BatchNorm2d(out_channel),
        nn.LeakyReLU(0.2)

    )

  def forward(self, x):
    return self.conv(x)


# x, y <- concatenate these along the channels

class Discriminator(nn.Module):
  def __init__(self, in_channels=3, features=[64, 128, 256, 512]): # 256 input -> 30x30 output
    super().__init__()
    self.initial = nn.Sequential(
        nn.Conv2d(in_channels * 2, features[0], kernel_size=4, stride=2, padding=1, padding_mode="reflect"),
        nn.LeakyReLU(0.2)
    )
    layers = []
    in_channels = features[0]
    for feature in features[1:]:
      layers.append(
          CNNBlock(in_channels, feature, stride=1 if feature == features[-1] else 2)
      )
      in_channels = feature

    layers.append(
        nn.Conv2d(
            in_channels, 1, kernel_size=4, stride=1, padding=1, padding_mode="reflect"
        )
    )

    self.model = nn.Sequential(*layers)

  def forward(self, x, y):
    x = torch.cat([x,y], dim=1)
    x = self.initial(x)
    return self.model(x)


def test():
  x = torch.randn((1, 3, 256, 256))
  y = torch.randn((1, 3, 256, 256))
  model = Discriminator()
  preds = model(x,y)
  print(preds.shape)

test()

torch.Size([1, 1, 26, 26])


# UNET


In [None]:
class Block(nn.Module):
  def __init__(self, in_channels, out_channels, down=True, act="relu", use_dropout=False):
    super().__init__()
    self.conv = nn.Sequential(
        nn.Conv2d(in_channels, out_channels, 4,2,1,bias=False,padding_mode="reflect"),
        if down
        else nn.ConvTranspose2d(in_channels, out_channels, 4, 2, 1, bias=False),
        nn.BatchNorm2d(out_channels),
        nn.ReLU() if act == "relu" else nn.LeakyReLU(0.2)
    )