In [47]:
import torch
import torch.nn as nn

from torch import Tensor
from torch import optim

In [19]:
!pip install einops

Collecting einops
  Downloading einops-0.7.0-py3-none-any.whl (44 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/44.6 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m[90m━━━━━━━━━━━━[0m [32m30.7/44.6 kB[0m [31m810.6 kB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.6/44.6 kB[0m [31m749.4 kB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: einops
Successfully installed einops-0.7.0


In [20]:
from einops.layers.pytorch import rearrange

In [5]:
class ConvBlock(nn.Module):
    """Two stacked Conv2d -> BatchNorm2d -> ReLU stages.

    The first stage maps ``c_in`` channels to ``c_out``; the second keeps
    ``c_out`` channels. Every convolution is 3x3 with stride 1 and padding 1,
    so height and width are left unchanged.
    """

    def __init__(self, c_in, c_out):
        super().__init__()

        first_stage = self.get_conv_block(c_in, c_out)
        second_stage = self.get_conv_block(c_out, c_out)
        self.conv_seq = nn.Sequential(first_stage, second_stage)

    def get_conv_block(self, c_in, c_out):
        """Build a single Conv2d -> BatchNorm2d -> ReLU stage."""
        layers = [
            nn.Conv2d(c_in, c_out, kernel_size=(3, 3), stride=1, padding=1),
            nn.BatchNorm2d(num_features=c_out),
            nn.ReLU(),
        ]
        return nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through both stages and return the result."""
        out = self.conv_seq(x)
        return out


In [6]:
class EncoderBlock(nn.Module):
    """A ConvBlock followed by 2x2 max pooling.

    ``forward`` hands back both the pre-pooling feature map and the pooled
    one, so a caller can keep the former as a skip connection.
    """

    def __init__(self, c_in, c_out):
        super().__init__()

        self.conv_block = ConvBlock(c_in, c_out)
        self.maxpool = nn.MaxPool2d(kernel_size=2)

    def forward(self, x):
        """Return ``(features, pooled)``: ConvBlock output and its max-pooled version."""
        features = self.conv_block(x)
        return features, self.maxpool(features)


In [12]:
class DecoderBlock(nn.Module):
    """Upsample, concatenate a skip tensor on the channel axis, then apply a ConvBlock.

    The stride-2 transposed convolution maps ``c_in`` channels to
    ``c_in // 2``. The ConvBlock that follows is built for ``c_in`` input
    channels, so ``residual`` has to supply the remaining channels.
    """

    def __init__(self, c_in, c_out):
        super().__init__()

        half_channels = c_in // 2
        self.upconv = nn.ConvTranspose2d(c_in, half_channels, kernel_size=(2, 2), stride=2)
        self.conv_block = ConvBlock(c_in, c_out)

    def forward(self, x, residual):
        """Upsample ``x``, join it with ``residual`` along dim 1, and convolve."""
        upsampled = self.upconv(x)
        merged = torch.cat([upsampled, residual], dim=1)
        return self.conv_block(merged)


In [38]:
class NoiseBlock(nn.Module):  # (128, 1, 1) -> (512, 16, 16)
    """Expand a 128-channel 1x1 map into a 512-channel 16x16 map.

    Three transposed convolutions are applied back to back:
    4x4 kernel (1x1 -> 4x4), then two 2x2 stride-2 kernels (4 -> 8 -> 16).
    """

    def __init__(self):
        super().__init__()

        self.deconv1 = nn.ConvTranspose2d(128, 256, kernel_size=(4, 4))
        self.deconv2 = nn.ConvTranspose2d(256, 256, kernel_size=(2, 2), stride=2)
        self.deconv3 = nn.ConvTranspose2d(256, 512, kernel_size=(2, 2), stride=2)

    def forward(self, x):
        """Apply the three transposed convolutions in order."""
        out = x
        for stage in (self.deconv1, self.deconv2, self.deconv3):
            out = stage(out)
        return out


In [34]:
class PokeGen(nn.Module):
    """Label-conditioned U-Net style generator.

    A class label is embedded into a 128-d vector, expanded by ``NoiseBlock``
    to a (512, 16, 16) map and added to the U-Net bottleneck. The U-Net itself
    encodes the input mask with three encoder blocks and decodes it with three
    decoder blocks that consume the encoder skip connections.

    Args:
        num_poke: number of distinct labels (size of the embedding table).
    """

    def __init__(
        self,
        num_poke: int,
    ):
        super().__init__()
        self.noise_emb = nn.Embedding(num_poke, 128)
        self.noise_block = NoiseBlock()

        # Comments read (channels, spatial size) of the pooled / upsampled map.
        self.ue1 = EncoderBlock(3, 64)    # (3, 128)  -> (64, 64)
        self.ue2 = EncoderBlock(64, 128)  # (64, 64)  -> (128, 32)
        self.ue3 = EncoderBlock(128, 256) # (128, 32) -> (256, 16)

        self.bridge = ConvBlock(256, 512) # (256, 16) -> (512, 16)

        self.ud1 = DecoderBlock(512, 256) # (512, 16) -> (256, 32)
        self.ud2 = DecoderBlock(256, 128) # (256, 32) -> (128, 64)
        self.ud3 = DecoderBlock(128, 64)  # (128, 64) -> (64, 128)

        # Plain convolution before Tanh. A ConvBlock here would end in
        # BatchNorm + ReLU, making the Tanh input non-negative and limiting
        # the generator output to [0, 1) instead of the full (-1, 1) range.
        self.final = nn.Sequential(
            nn.Conv2d(64, 3, kernel_size=3, padding=1),
            nn.Tanh()
        )

    def forward(
        self,
        label: Tensor,
        mask: Tensor
    ) -> Tensor:
        """
        Args:
            label: (B) integer tensor containing the labels of each pokemon
            mask: (B, 3, 128, 128) binary mask tensor

        Returns:
            (B, 3, 128, 128) tensor with values in (-1, 1).
        """
        noise = self.noise_emb(label) # (B, 128)
        # Add two trailing singleton axes: (B, 128) -> (B, 128, 1, 1).
        noise = noise[:, :, None, None]
        n = self.noise_block(noise)   # (B, 512, 16, 16)

        # r* are the pre-pooling skip tensors, x* the pooled feature maps.
        r1, x1 = self.ue1(mask)
        r2, x2 = self.ue2(x1)
        r3, x3 = self.ue3(x2)

        # Inject the label-dependent map at the bottleneck.
        x4 = self.bridge(x3) + n

        x5 = self.ud1(x4, r3)
        x6 = self.ud2(x5, r2)
        x7 = self.ud3(x6, r1)

        output = self.final(x7)

        return output

In [39]:
# Smoke test: a single generator forward pass must preserve the input resolution.
label = torch.randint(0, 10, (10,))
# Binary {0, 1} mask, matching what PokeGen.forward documents as its input.
mask = torch.randint(0, 2, (10, 3, 128, 128)).float()

model = PokeGen(10)
with torch.no_grad():  # shape check only; no autograd graph is needed
    output = model(label, mask)

assert output.shape == (10, 3, 128, 128)
print(output.shape)

torch.Size([10, 3, 128, 128])


In [40]:
class Discriminator(nn.Module):
    """Convolutional real/fake classifier for (B, 3, 128, 128) images.

    Four encoder blocks halve the resolution each time (128 -> 8); the pooled
    map is flattened and passed through a two-layer MLP with a sigmoid output.
    """

    def __init__(self):
        super().__init__()

        # Comments read (channels, spatial size) of the pooled map.
        self.ue1 = EncoderBlock(3, 64)    # (3, 128)  -> (64, 64)
        self.ue2 = EncoderBlock(64, 128)  # (64, 64)  -> (128, 32)
        self.ue3 = EncoderBlock(128, 256) # (128, 32) -> (256, 16)
        self.ue4 = EncoderBlock(256, 256) # (256, 16) -> (256, 8)

        self.fc1 = nn.Linear(256 * 8 * 8, 256)
        # Without a non-linearity between them, fc1 and fc2 would collapse
        # into a single affine map and the hidden layer would add nothing.
        self.act = nn.ReLU()
        self.fc2 = nn.Linear(256, 1)
        self.sig = nn.Sigmoid()

    def forward(self, x):
        """Return a (B, 1) tensor of probabilities in (0, 1) for input images ``x``."""
        # Only the pooled outputs are used; the skip tensors are discarded.
        _, x1 = self.ue1(x)
        _, x2 = self.ue2(x1)
        _, x3 = self.ue3(x2)
        _, x4 = self.ue4(x3)

        x5 = torch.flatten(x4, start_dim=1)  # (B, 256, 8, 8) -> (B, 256 * 8 * 8)
        x6 = self.act(self.fc1(x5))
        x7 = self.fc2(x6)

        return self.sig(x7)


In [41]:
# Smoke test: the discriminator must emit one sigmoid score per image.
x = torch.randn((10, 3, 128, 128))
model = Discriminator()
with torch.no_grad():  # shape check only; no autograd graph is needed
    output = model(x)

assert output.shape == (10, 1)
print(output.shape)

torch.Size([10, 1])


In [43]:
# Runtime and training configuration.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Seed the default torch generator so weight initialisation and the random
# labels drawn during training are repeatable across kernel restarts.
seed = 0
torch.manual_seed(seed)

num_poke = 2      # number of distinct labels fed to the generator embedding

epoch = 1         # number of passes over the dataloader
batch_size = 4
lr = 2e-4         # learning rate

In [44]:
# Build the generator and the discriminator, then move each onto `device`.
poke_gen = PokeGen(num_poke)
poke_gen = poke_gen.to(device)

disc = Discriminator()
disc = disc.to(device)

In [46]:
# Switch both networks to training mode. The trailing ';' stops the notebook
# from echoing the module repr that .train() returns.
poke_gen.train()
disc.train();




In [49]:
# Pass the configured learning rate explicitly; without it Adam falls back to
# its own default and the `lr` setting has no effect.
opt_poke_gen = optim.Adam(poke_gen.parameters(), lr=lr)
opt_disc = optim.Adam(disc.parameters(), lr=lr)

# Binary cross-entropy on the discriminator's sigmoid output.
crit = nn.BCELoss()

In [None]:
# Pixel-wise reconstruction loss for the generator. nn.L1Loss is a module:
# it is constructed once and then called on (input, target).
crit_l1 = nn.L1Loss()

for e in range(epoch):
    poke_gen.train()
    disc.train()

    for image, mask, label in dataloader:
        cur_batch_size = image.size(0)

        image = image.to(device)
        mask = mask.to(device)
        label = label.to(device)

        # BCE targets, one scalar per sample: shape (B,).
        label_real = torch.full((cur_batch_size,), 1.0, device=device)
        label_fake = torch.full((cur_batch_size,), 0.0, device=device)

        # train gen
        poke_gen.zero_grad()

        # NOTE(review): the generator is conditioned on a random label while
        # the L1 term below compares against the real `image` of this batch;
        # confirm whether `label` should be used here instead.
        gen_label = torch.randint(0, num_poke, (cur_batch_size,), device=device)

        fake_image = poke_gen(gen_label, mask)
        # disc returns (B, 1); flatten to (B,) because BCELoss requires the
        # input and target shapes to match.
        disc_output = disc(fake_image).view(-1)

        loss_gen = crit(disc_output, label_real) + crit_l1(fake_image, image)
        loss_gen.backward()
        opt_poke_gen.step()

        # train disc
        # Also clears the discriminator gradients accumulated by the
        # generator's backward pass above.
        disc.zero_grad()

        # Discriminator.forward takes the image only.
        disc_output = disc(image).view(-1)