In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data import DataLoader
%matplotlib inline



In [None]:
DATASET = "CIFAR10"
EPOCHS = 100
BATCH_SIZE = 64
LATENT_DIM = 100
MESSAGE_WIDTH = 16
GENERATOR_NUM_FILTERS = 256
DISCRIMINATOR_NUM_FILTERS = 128
DECODER_NUM_FILTERS = 64
SEED = 0

# Seed torch's global RNG so weight init, latent/message draws and DataLoader
# shuffling repeat across Restart & Run All. NOTE(review): GPU kernels may still
# be nondeterministic unless deterministic algorithms are also enabled.
torch.manual_seed(SEED)

# Prefer CUDA, then Apple MPS, then fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print("Using device:", device)

In [None]:
# ToTensor yields pixels in [0, 1]; Normalize(0.5, 0.5) per channel rescales
# them to [-1, 1], the same range as the generator's Tanh output.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5] * 3, [0.5] * 3),
])

# Resolve the dataset class from the DATASET config constant instead of
# hard-coding CIFAR10 here. NOTE(review): assumes the class accepts the
# CIFAR-style (root, train, download, transform) arguments.
dataset_cls = getattr(datasets, DATASET)
trainset = dataset_cls(root='./data', train=True, download=True, transform=transform)

# The models are built for 3x32x32 inputs and the normalization above assumes
# 3 channels; fail fast instead of training on mis-shaped data.
sample_img = trainset[0][0]
assert tuple(sample_img.shape) == (3, 32, 32), f"expected 3x32x32 images, got {tuple(sample_img.shape)}"

dataloader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True)


In [None]:
class Generator(nn.Module):
    """Map a noise vector plus a bit-message to an image.

    ``z`` and ``msg`` are concatenated along dim 1 and projected by a linear
    layer to a ``base_filters x s x s`` feature map, where ``s = img_shape[1] // 4``.
    Two 2x upsample + conv stages then produce an ``img_shape[0]``-channel image
    of side ``4 * s``. That equals the input height only when it is a multiple of 4,
    and the output is always square. A final Tanh bounds pixels to [-1, 1].

    Args:
        latent_dim: size of the noise vector ``z``.
        msg_width: number of message entries concatenated onto ``z``.
        img_shape: ``(channels, height, width)``; only channels and height are used.
        base_filters: channel count of the initial feature map.
    """

    def __init__(self, latent_dim, msg_width, img_shape, base_filters=128):
        super().__init__()
        # Remembered so forward() reshapes with this instance's width rather
        # than a notebook-level constant.
        self.base_filters = base_filters
        self.init_size = img_shape[1] // 4
        self.l1 = nn.Sequential(nn.Linear(latent_dim + msg_width, base_filters * self.init_size ** 2))
        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(base_filters),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(base_filters, base_filters, 3, 1, 1),
            # NOTE(review): BatchNorm2d's second positional argument is `eps`, so
            # this is eps=0.8 (not momentum). Kept to preserve behaviour.
            nn.BatchNorm2d(base_filters, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(base_filters, base_filters // 2, 3, 1, 1),
            nn.BatchNorm2d(base_filters // 2, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(base_filters // 2, img_shape[0], 3, 1, 1),
            nn.Tanh(),
        )

    def forward(self, z, msg):
        """Generate images from ``z`` (N, latent_dim) and ``msg`` (N, msg_width).

        Returns an ``(N, channels, 4 * init_size, 4 * init_size)`` tensor in [-1, 1].
        """
        x = torch.cat((z, msg), dim=1)
        out = self.l1(x)
        # Bug fix: previously used the global GENERATOR_NUM_FILTERS, which breaks
        # (view shape mismatch / NameError) whenever base_filters differs from it.
        out = out.view(out.size(0), self.base_filters, self.init_size, self.init_size)
        return self.conv_blocks(out)


class Discriminator(nn.Module):
    """Convolutional critic mapping an image batch to one raw score per image.

    The network has three stride-2 3x3 convolutions with ``base_filters``, ``2x`` and ``4x``
    channels. Each is followed by LeakyReLU(0.2), and the first two also by
    Dropout(0.25). The output is flattened and a linear layer produces a single
    unbounded output, with no sigmoid.

    Args:
        img_shape: ``(channels, height, width)`` of the input images.
        base_filters: channel count of the first convolution.
    """

    def __init__(self, img_shape, base_filters=64):
        super().__init__()
        widths = (base_filters, base_filters * 2, base_filters * 4)
        layers = []
        in_channels = img_shape[0]
        for depth, out_channels in enumerate(widths):
            layers.append(nn.Conv2d(in_channels, out_channels, 3, 2, 1))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            # Every block except the last is followed by dropout.
            if depth < len(widths) - 1:
                layers.append(nn.Dropout(0.25))
            in_channels = out_channels
        layers.append(nn.Flatten())
        self.model = nn.Sequential(*layers)

        # Size the linear head from a dummy pass instead of hand-deriving the
        # spatial size left after the strided convolutions.
        with torch.no_grad():
            probe = torch.zeros(1, *img_shape)
            flat_dim = self.model(probe).shape[1]
        self.fc = nn.Linear(flat_dim, 1)

    def forward(self, img):
        """Return an ``(N, 1)`` tensor of scores for an ``(N, C, H, W)`` batch."""
        features = self.model(img)
        return self.fc(features)


class Decoder(nn.Module):
    """Predict the embedded message bits from an image.

    The network has two stride-2 3x3 convolutions (``base_filters`` then ``2 * base_filters``
    channels) with LeakyReLU(0.2). It then flattens and applies a 256-unit hidden
    layer and a final linear layer with Sigmoid, giving per-bit values in (0, 1).

    Args:
        img_shape: ``(channels, height, width)`` of the input images.
        msg_width: number of message bits to predict.
        base_filters: channels of the first convolution. Defaults to 64, the
            previously hard-coded value, so existing callers are unchanged.
    """

    def __init__(self, img_shape, msg_width, base_filters=64):
        super().__init__()
        conv_layers = [
            nn.Conv2d(img_shape[0], base_filters, 3, 2, 1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(base_filters, base_filters * 2, 3, 2, 1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Flatten(),
        ]
        # Measure the flattened size with a dummy pass. Each stride-2/pad-1 conv
        # maps a side n to ceil(n / 2), so the old `128 * (H // 4) * (W // 4)`
        # was only correct when H and W are multiples of 4.
        with torch.no_grad():
            flat_dim = nn.Sequential(*conv_layers)(torch.zeros(1, *img_shape)).shape[1]
        # Single Sequential with the same layer order as before, so state_dict
        # keys (model.0, model.2, model.5, model.7) are unchanged.
        self.model = nn.Sequential(
            *conv_layers,
            nn.Linear(flat_dim, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, msg_width),
            nn.Sigmoid(),
        )

    def forward(self, img):
        """Return an ``(N, msg_width)`` tensor of bit probabilities for ``(N, C, H, W)`` input."""
        return self.model(img)


In [None]:
# All three networks are built for 32x32 RGB images.
IMG_SHAPE = (3, 32, 32)

generator = Generator(LATENT_DIM, MESSAGE_WIDTH, IMG_SHAPE, GENERATOR_NUM_FILTERS).to(device)
discriminator = Discriminator(IMG_SHAPE, DISCRIMINATOR_NUM_FILTERS).to(device)
decoder = Decoder(IMG_SHAPE, MESSAGE_WIDTH).to(device)

# The decoder ends in a Sigmoid, so plain BCE on probabilities is the matching loss.
decoder_loss = nn.BCELoss()

# One shared set of Adam hyperparameters for every network.
ADAM_KWARGS = dict(lr=0.0002, betas=(0.5, 0.999))
gen_optim = optim.Adam(generator.parameters(), **ADAM_KWARGS)
disc_optim = optim.Adam(discriminator.parameters(), **ADAM_KWARGS)
dec_opt = optim.Adam(decoder.parameters(), **ADAM_KWARGS)


In [None]:
def _plot_generator_samples(n_rows=5):
    """Show an ``n_rows x n_rows`` grid of generator samples for random messages.

    The generator is put in eval mode while sampling, so its BatchNorm layers use
    their running statistics instead of this one sampling batch's statistics and
    do not update them. Afterwards it is restored to its previous mode.
    """
    n_samples = n_rows * n_rows
    was_training = generator.training
    generator.eval()
    try:
        with torch.no_grad():
            z = torch.randn(n_samples, LATENT_DIM, device=device)
            msg = torch.randint(0, 2, (n_samples, MESSAGE_WIDTH), device=device, dtype=torch.float)
            samples = generator(z, msg).cpu()
    finally:
        generator.train(was_training)

    # [-1, 1] -> [0, 1], then NCHW -> NHWC as imshow expects. permute keeps this a
    # torch op rather than relying on np.transpose's fallback for tensors.
    grid = ((samples + 1) / 2).clamp(0, 1).permute(0, 2, 3, 1).numpy()

    fig, axes = plt.subplots(n_rows, n_rows, figsize=(6, 6), squeeze=False)
    for ax, img in zip(axes.flatten(), grid):
        ax.imshow(img)
        ax.axis("off")
    plt.tight_layout()
    plt.show()
    plt.close(fig)  # don't keep one open figure per epoch


def train():
    """Train the generator, discriminator and message decoder for ``EPOCHS`` epochs.

    Each batch does two updates:

    * one discriminator step on a hinge loss;
    * one joint generator + decoder step on ``0.5 * (adversarial + message BCE)``.

    Every 100 batches, the losses and the message-recovery accuracy on the current
    generated batch are printed. After each epoch a 5x5 sample grid is shown.

    Uses the notebook-level models, optimizers, ``decoder_loss``, ``dataloader``
    and config constants; returns nothing.
    """
    # Modules start in train mode, but an eval-mode sampling call elsewhere
    # could have changed that.
    generator.train()
    discriminator.train()
    decoder.train()

    n_batches = len(dataloader)
    for epoch in range(EPOCHS):
        for i, (imgs, _labels) in enumerate(dataloader):
            imgs = imgs.to(device)
            batch_size = imgs.size(0)

            # --- Train Discriminator (hinge loss) ---
            disc_optim.zero_grad()

            z = torch.randn(batch_size, LATENT_DIM, device=device)
            msgs = torch.randint(0, 2, (batch_size, MESSAGE_WIDTH), device=device, dtype=torch.float)
            fake_imgs = generator(z, msgs).detach()  # no gradient to the generator here

            real_pred = discriminator(imgs)
            fake_pred = discriminator(fake_imgs)

            real_loss = torch.mean(F.relu(1.0 - real_pred))
            fake_loss = torch.mean(F.relu(1.0 + fake_pred))
            disc_loss = real_loss + fake_loss
            disc_loss.backward()
            disc_optim.step()

            # --- Train Generator + Decoder ---
            gen_optim.zero_grad()
            dec_opt.zero_grad()

            z = torch.randn(batch_size, LATENT_DIM, device=device)
            msgs = torch.randint(0, 2, (batch_size, MESSAGE_WIDTH), device=device, dtype=torch.float)
            gen_imgs = generator(z, msgs)
            validity = discriminator(gen_imgs)
            decoded = decoder(gen_imgs)

            gen_adv_loss = -torch.mean(validity)
            gen_aux_loss = decoder_loss(decoded, msgs)
            gen_loss = 0.5 * (gen_adv_loss + gen_aux_loss)

            # This backward also writes gradients into the discriminator's
            # parameters. They are cleared by disc_optim.zero_grad() at the start
            # of the next batch and never applied.
            gen_loss.backward()
            gen_optim.step()
            dec_opt.step()

            # --- Metrics: computed only when logged, avoiding per-batch .item() host copies ---
            if (i + 1) % 100 == 0:
                with torch.no_grad():
                    bit_matches = (decoded > 0.5).float() == msgs
                    bitwise_acc = bit_matches.float().mean().item()
                    full_recovery = bit_matches.all(dim=1).float().mean().item()
                # 1-based counters so the last epoch prints as EPOCHS/EPOCHS.
                print(f"[Epoch {epoch + 1}/{EPOCHS}] [Batch {i + 1}/{n_batches}] "
                      f"[D loss: {disc_loss.item():.3f}] [G loss: {gen_loss.item():.3f}] "
                      f"(Adv: {gen_adv_loss.item():.3f}, Dec: {gen_aux_loss.item():.3f}) "
                      f"[Bit acc: {bitwise_acc:.3f}] [Full rec: {full_recovery:.3f}]")

        _plot_generator_samples()



In [None]:
def show_samples(n_rows=5):
    """Display an ``n_rows x n_rows`` grid of generator samples for random messages.

    Args:
        n_rows: side length of the grid; ``n_rows ** 2`` images are generated.
            Defaults to 5, the previous fixed 5x5 layout.

    The generator is sampled in eval mode, so BatchNorm uses its running statistics
    and does not update them. It is then restored to the mode it was in before.
    """
    n_samples = n_rows * n_rows
    was_training = generator.training
    generator.eval()
    try:
        with torch.no_grad():
            z = torch.randn(n_samples, LATENT_DIM, device=device)
            msg = torch.randint(0, 2, (n_samples, MESSAGE_WIDTH), device=device, dtype=torch.float)
            samples = generator(z, msg).cpu()
    finally:
        generator.train(was_training)

    # [-1, 1] -> [0, 1], then NCHW -> NHWC for imshow. permute keeps this a torch
    # op instead of relying on np.transpose's fallback for non-NumPy objects.
    grid = ((samples + 1) / 2).clamp(0, 1).permute(0, 2, 3, 1).numpy()

    fig, axes = plt.subplots(n_rows, n_rows, figsize=(6, 6), squeeze=False)
    for ax, img in zip(axes.flatten(), grid):
        ax.imshow(img)
        ax.axis("off")
    plt.tight_layout()
    plt.show()
    plt.close(fig)


In [None]:
# Run the full training loop (EPOCHS epochs), then display one more grid of samples.
train()
show_samples()
