In [11]:
import torch
from torch import nn
from torch.nn import functional as F
import torchvision
import torchvision.transforms.v2
from torchinfo import summary
from tqdm import tqdm
from ema_pytorch import EMA

In [12]:
train_dataset = torchvision.datasets.MNIST(
    root="./data",
    train=True,
    download=True,
    transform=torchvision.transforms.v2.Compose(
        [
            torchvision.transforms.Resize((32, 32)),
            torchvision.transforms.ToTensor(),
        ]
    ),
)
val_dataset = torchvision.datasets.MNIST(
    root="./data",
    train=False,
    download=True,
    transform=torchvision.transforms.v2.Compose(
        [
            torchvision.transforms.Resize((32, 32)),
            torchvision.transforms.ToTensor(),
        ]
    ),
)

train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset, batch_size=64, shuffle=True, num_workers=2, persistent_workers=True
)
val_loader = torch.utils.data.DataLoader(
    dataset=val_dataset, batch_size=64, shuffle=False
)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 25663496.51it/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 13958941.44it/s]


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 18737276.81it/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 10613107.95it/s]

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw






In [135]:
class DepthwiseSeparableConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=1):
        super(DepthwiseSeparableConv2d, self).__init__()
        self.depthwise = nn.Conv2d(in_channels, in_channels, kernel_size, stride, padding, groups=in_channels)
        self.pointwise = nn.Conv2d(in_channels, out_channels, 1)

    def forward(self, x):
        x = self.depthwise(x)
        x = self.pointwise(x)
        return x

class DownBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=1):
        super(DownBlock, self).__init__()
        self.block = nn.Sequential(
            nn.PixelShuffle(2),
            DepthwiseSeparableConv2d(in_channels // 4, out_channels, kernel_size, stride, padding),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2),
        )

    def forward(self, x):
        return self.block(x)


class UpBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=1):
        super(UpBlock, self).__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Upsample(scale_factor=2),
        )

    def forward(self, x):
        return self.block(x)


class Generator(nn.Module):
    def __init__(self, latent_dim, img_shape):
        super(Generator, self).__init__()
        self.img_shape = img_shape
        self.model = nn.Sequential(
            UpBlock(latent_dim, latent_dim * 2, 3, 1, 1),
            UpBlock(latent_dim *2, latent_dim * 4, 3, 1, 1),
            UpBlock(latent_dim * 4, latent_dim * 8, 3, 1, 1),
            UpBlock(latent_dim * 8, latent_dim * 16, 3, 1, 1),
            UpBlock(latent_dim * 16, 1, 3, 1, 1),
            nn.Sigmoid(),
        )
        self.class_emb = nn.Embedding(11, latent_dim)

    def forward(self, z, labels):
        class_emb = self.class_emb(labels)
        z = (z + class_emb).view(z.size(0), -1, 1, 1)
        img = self.model(z)
        return img


class Discriminator(nn.Module):
    def __init__(self, img_shape):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(1, 64, 3, 2, 1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, 3, 2, 1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, 3, 2, 1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 512, 3, 2, 1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Flatten(),
            nn.Linear(512 * 2 * 2, 1),
            nn.Sigmoid(),
        )

        self.embedding = nn.Embedding(11, 1 * 32 * 32)

    def forward(self, img, labels):
        class_emb = self.embedding(labels)
        class_emb = class_emb.view(img.size(0), 1, 32, 32)
        img = img + class_emb
        validity = self.model(img)
        return validity

In [136]:
device = "mps"

generator = Generator(latent_dim=16, img_shape=(1, 32, 32)).to(device)
discriminator = Discriminator(img_shape=(1, 32, 32)).to(device)

print(
    summary(
        generator,
        input_data=(
            torch.randn(7, 512).to(device),
            torch.randint(0, 10, (7,)).to(device),
        ),
        device=device,
    ), end="\n\n\n\n"
)
print(
    summary(
        discriminator,
        input_data=(
            torch.randn(7, 1, 32, 32).to(device),
            torch.randint(0, 10, (7,)).to(device),
        ),
        device=device,
    )
)

optim_g = torch.optim.AdamW(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optim_d = torch.optim.AdamW(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

RuntimeError: Failed to run torchinfo. See above stack traces for more details. Executed layers up to: [Embedding: 1]

In [137]:
# train
for epoch in range(10):
    for i, (x, y) in enumerate(pbar := tqdm(train_loader)):
        x, y = x.to(device), y.to(device)
        batch_size = x.shape[0]

        # Adversarial ground truths
        valid = torch.ones(batch_size, 1, device=device)
        fake = torch.zeros(batch_size, 1, device=device)

        # -----------------
        #  Train Generator
        # -----------------

        optim_g.zero_grad()

        # Sample noise as generator input
        z = torch.randn(batch_size, 512, device=device)

        # Generate a batch of images
        gen_imgs = generator(z, y)

        # Loss measures generator's ability to fool the discriminator
        g_loss = F.binary_cross_entropy(discriminator(gen_imgs, y), valid)

        g_loss.backward()
        optim_g.step()

        # ---------------------
        #  Train Discriminator
        # ---------------------

        optim_d.zero_grad()

        # Measure discriminator's ability to classify real from generated samples
        real_loss = F.binary_cross_entropy(discriminator(x, y), valid)
        fake_loss = F.binary_cross_entropy(discriminator(gen_imgs.detach(), y), fake)
        d_loss = (real_loss + fake_loss) / 2

        d_loss.backward()
        optim_d.step()

        pbar.set_postfix_str(f"d_loss: {d_loss.item()}, g_loss: {g_loss.item()}")

        if i % 100 == 0:
            generate()

  0%|          | 0/938 [00:00<?, ?it/s]


RuntimeError: The size of tensor a (512) must match the size of tensor b (16) at non-singleton dimension 1

In [133]:
#samplefrom generator
@torch.no_grad()
def generate():
    generator.eval()
    z = torch.randn(11 * 8, 512).to(device)
    y = torch.arange(11).repeat(8).to(device)
    gen_imgs = generator(z, y).detach().cpu()
    grid = torchvision.utils.make_grid(gen_imgs, nrow=11)
    torchvision.utils.save_image(grid, "gen.png")
    generator.train()

generate()