In [1]:
import torch
import torch.nn.functional as F
from matplotlib import pyplot as plt
from torch import nn
from torchinfo import summary
from torchvision import utils as vutils
from tqdm import tqdm

from digitdreamer import mnistdm

In [2]:
# Select which data pipeline to build. Flip FINE_TUNE to True for the
# fine-tuning pass instead of commenting/uncommenting constructor calls.
FINE_TUNE = False

if FINE_TUNE:
    # Degenerate (identity) ranges for every transform parameter.
    # NOTE(review): presumably this disables the random affine augmentation
    # applied by MNISTDataModule — confirm against digitdreamer.mnistdm.
    dm = mnistdm.MNISTDataModule(
        data_dir="../data",
        degrees=(0, 0),
        translate=(0, 0),
        scale=(1, 1),
        shear=(0, 0),
    )
else:
    # Default data module configuration (initial training).
    dm = mnistdm.MNISTDataModule("../data")

dm.setup()
train_loader = dm.train_dataloader()
test_loader = dm.test_dataloader()

In [None]:
from digitdreamer import Autoencoder, autoencoder, modules

# Pick the best available accelerator rather than assuming Apple MPS, so the
# notebook also runs on CUDA or CPU-only machines. MPS is still preferred
# whenever it is available, matching the previous behaviour there.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Autoencoder and its optimizer; the discriminator gets a separate optimizer.
model = Autoencoder().to(device)
ae_optimizer = torch.optim.AdamW(model.parameters(), lr=6e-4)
# Trailing expression so the notebook renders the layer summary.
summary(model, input_size=(1, 1, 32, 32), device=device, depth=2)

In [None]:
class Discriminator(nn.Module):
    """Convolutional real-vs-reconstruction critic for single-channel images.

    The network is three stages, each a ``down_block`` followed by a
    ``Block`` at the stage's output width (1 -> 8 -> 16 -> 32 channels),
    then a 1x1 convolution down to one channel and a sigmoid, so every
    output value lies in (0, 1).

    NOTE(review): presumably each ``down_block`` reduces spatial resolution,
    making the output a coarse map of per-region scores — confirm in
    ``digitdreamer.autoencoder``.
    """

    def __init__(self) -> None:
        super().__init__()
        layers = []
        # Layers are created in the same order as the flat listing they
        # replace, so nn.Sequential indices (and state-dict keys) match.
        for in_channels, out_channels in ((1, 8), (8, 16), (16, 32)):
            layers.extend(autoencoder.down_block(in_channels, out_channels))
            layers.append(modules.Block(out_channels, out_channels))
        layers.append(nn.Conv2d(32, 1, kernel_size=1))
        layers.append(nn.Sigmoid())
        self.blocks = nn.Sequential(*layers)

    def forward(self, x):
        """Return the sigmoid-activated score map for the image batch ``x``."""
        return self.blocks(x)


# Build the discriminator on the training device, with its own AdamW
# optimizer so it can be stepped independently of the autoencoder.
discriminator = Discriminator()
discriminator = discriminator.to(device)
d_optimizer = torch.optim.AdamW(params=discriminator.parameters(), lr=6e-4)
# Trailing expression so the notebook renders the layer summary.
summary(discriminator, input_size=(1, 1, 32, 32), depth=2, device=device)

In [None]:
# Adversarial autoencoder training.
#
# Each batch performs two optimisation steps:
#   1. discriminator: learn to score real images as 1 and reconstructions as 0;
#   2. autoencoder: minimise reconstruction BCE plus a small adversarial term
#      that rewards reconstructions the discriminator scores as real.
# After every epoch the autoencoder is evaluated on the test loader; those
# numbers are shown in the progress bar during the following epoch.
val_loss = 0
val_psnr = 0

for epoch in range(10):
    model.train()
    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}")
    for x, _ in pbar:
        x = x.to(device)

        # --- Discriminator step ---
        # Reconstructions are produced without a graph: only the
        # discriminator's weights are updated here.
        with torch.no_grad():
            x_hat = model(x)

        d_optimizer.zero_grad()
        real = discriminator(x)
        fake = discriminator(x_hat)
        d_loss = (
            F.binary_cross_entropy(fake, torch.zeros_like(fake))
            + F.binary_cross_entropy(real, torch.ones_like(real))
        ) / 2
        d_loss.backward()
        d_optimizer.step()

        # --- Autoencoder step ---
        ae_optimizer.zero_grad()
        x_hat = model(x)
        # x_hat must NOT be detached here: the adversarial loss has to
        # back-propagate through the discriminator into the autoencoder,
        # otherwise it is a constant and has no effect on training.
        # Gradients this leaves on the discriminator's parameters are
        # discarded by d_optimizer.zero_grad() at the next iteration.
        d_fake = discriminator(x_hat)
        # assumes x and x_hat both lie in [0, 1], as BCE requires — TODO confirm
        bce_loss = F.binary_cross_entropy(x_hat, x)
        adv_loss = F.binary_cross_entropy(d_fake, torch.ones_like(d_fake)) * 1e-2

        ae_loss = bce_loss + adv_loss
        ae_loss.backward()
        ae_optimizer.step()

        # Logging only: no autograd graph needed for the PSNR estimate.
        with torch.no_grad():
            mse_loss = F.mse_loss(x_hat, x)

        pbar.set_postfix_str(
            f"loss: {ae_loss:.4f}, psnr: {10 * torch.log10(1 / mse_loss):.2f}, bce: {bce_loss:.2f}, adv: {adv_loss:.2f}, real: {real.mean():.2f}, fake: {fake.mean():.2f}, val_loss: {val_loss:.4f}, val_psnr: {val_psnr:.2f}",
        )

    # --- Validation: mean per-batch MSE and PSNR (peak value 1) ---
    model.eval()
    val_loss = 0
    val_psnr = 0
    with torch.no_grad():
        for x, _ in test_loader:
            x = x.to(device)
            x_hat = model(x)
            batch_mse = F.mse_loss(x_hat, x)  # computed once, used for both metrics
            val_loss += batch_mse
            val_psnr += 10 * torch.log10(1 / batch_mse)

    val_loss /= len(test_loader)
    val_psnr /= len(test_loader)

In [None]:
# Visual sanity check: the first 64 images of one test batch next to their
# reconstructions, each tiled into an 8-per-row grid.
with torch.no_grad():
    model.eval()
    x, _ = next(iter(test_loader))
    x = x.to(device)
    x_hat = model(x)

fig, (ax_original, ax_reconstructed) = plt.subplots(1, 2, figsize=(8, 8))
for ax, images, title in (
    (ax_original, x, "Original"),
    (ax_reconstructed, x_hat, "Reconstructed"),
):
    grid = vutils.make_grid(images[:64], nrow=8, normalize=True)
    # make_grid returns (C, H, W); imshow expects (H, W, C).
    ax.imshow(grid.cpu().numpy().transpose(1, 2, 0))
    ax.axis("off")
    ax.set_title(title)

fig.tight_layout()
plt.show()

In [12]:
# Persist the weights of each trained component to its own checkpoint file.
for component, checkpoint_path in (
    (model.encoder, "../models/encoder.pth"),
    (model.decoder, "../models/decoder.pth"),
    (discriminator, "../models/discriminator.pth"),
):
    torch.save(component.state_dict(), checkpoint_path)

In [None]:
# Restore previously saved weights into the already-constructed modules.
# map_location=device remaps tensors that were saved from a different
# accelerator, so the checkpoints load on any machine; weights_only=True
# restricts unpickling to tensor data.
model.encoder.load_state_dict(
    torch.load("../models/encoder.pth", map_location=device, weights_only=True),
)
model.decoder.load_state_dict(
    torch.load("../models/decoder.pth", map_location=device, weights_only=True),
)
# Trailing expression so the notebook displays the key-matching result.
discriminator.load_state_dict(
    torch.load("../models/discriminator.pth", map_location=device, weights_only=True),
)

In [10]:
# Persist the fine-tuned weights under "ft-" prefixed names so the
# checkpoints written by the earlier save cell are not overwritten.
encoder_state = model.encoder.state_dict()
decoder_state = model.decoder.state_dict()
discriminator_state = discriminator.state_dict()

torch.save(encoder_state, "../models/ft-encoder.pth")
torch.save(decoder_state, "../models/ft-decoder.pth")
torch.save(discriminator_state, "../models/ft-discriminator.pth")

In [None]:
# generate latents
def encode_loader(loader):
    """Encode every batch of ``loader`` with ``model.encoder``.

    Args:
        loader: iterable yielding ``(images, labels)`` batches.

    Returns:
        ``(latents, classes)``: the encoder outputs and the labels, each
        concatenated along dim 0 and moved to the CPU.
    """
    latent_batches = []
    class_batches = []
    for x, y in tqdm(loader):
        latent_batches.append(model.encoder(x.to(device)).cpu())
        class_batches.append(y.cpu())
    return torch.cat(latent_batches), torch.cat(class_batches)


# Put the model in eval mode explicitly so this cell does not depend on which
# cell ran last; a freshly constructed or freshly loaded model is in train mode.
# NOTE(review): this matters only if the encoder has mode-dependent layers
# (dropout, batch norm) — confirm in digitdreamer.
model.eval()
with torch.no_grad():
    latents, classes = encode_loader(train_loader)
    train_data = torch.utils.data.TensorDataset(latents, classes)
    # The whole TensorDataset object is pickled, not just the tensors.
    torch.save(train_data, "../data/train_data.pth")

    latents, classes = encode_loader(test_loader)
    test_data = torch.utils.data.TensorDataset(latents, classes)
    torch.save(test_data, "../data/test_data.pth")