In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader, Subset
import torch.optim as optim
from tqdm import tqdm
import wandb
from kaggle_secrets import UserSecretsClient
import torch.nn.utils.spectral_norm as spectral_norm
import os
import numpy as np
# Read the Weights & Biases API key from Kaggle's secret store so that it is
# never hard-coded in the notebook; the value is handed to wandb.login() later.
user_secrets = UserSecretsClient()
secret_value_0 = user_secrets.get_secret("wandbpass")

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
!pip install torch-fidelity

In [None]:
# Authenticate this session with Weights & Biases using the key read from the Kaggle secrets.
wandb.login(key=secret_value_0)

In [None]:
class BottleneckResBlock(nn.Module):
    """Pre-activation bottleneck residual block (1x1 -> 3x3 -> 1x1).

    Each convolution is preceded by BatchNorm + ReLU. The block can optionally
    double (``upsample``) or halve (``downsample``) the spatial resolution, and
    every convolution can be wrapped in spectral normalisation.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels of the output feature map.
        downsample: If True, 2x average-pool both the residual branch and the shortcut.
        upsample: If True, 2x interpolate both the residual branch and the shortcut.
        spectral: If True, wrap every convolution in ``spectral_norm``.
    """

    def __init__(self, in_channels, out_channels, downsample=False, upsample=False, spectral=False):
        super(BottleneckResBlock, self).__init__()
        # Bottleneck width; clamped to 1 so that in_channels < 4 does not
        # produce a zero-channel convolution.
        mid_channels = max(in_channels // 4, 1)

        # A 1x1 projection on the shortcut is used whenever the channel count
        # or the spatial resolution changes.
        self.learned_shortcut = (in_channels != out_channels) or downsample or upsample
        self.downsample = downsample
        self.upsample = upsample

        def conv3x3(in_ch, out_ch):
            conv = nn.Conv2d(in_ch, out_ch, 3, 1, 1)
            return spectral_norm(conv) if spectral else conv

        def conv1x1(in_ch, out_ch):
            conv = nn.Conv2d(in_ch, out_ch, 1, 1, 0)
            return spectral_norm(conv) if spectral else conv

        # Each BatchNorm is sized for the tensor it actually normalises in
        # forward(), i.e. the *input* of the convolution that follows it:
        #   bn1 -> x (in_channels), bn2 / bn3 -> bottleneck features (mid_channels).
        # Previously bn1 / bn3 were sized with the conv *output* channels, which
        # made forward() fail with a channel-mismatch error.
        self.conv1 = conv1x1(in_channels, mid_channels)
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.conv2 = conv3x3(mid_channels, mid_channels)
        self.bn2 = nn.BatchNorm2d(mid_channels)
        self.conv3 = conv1x1(mid_channels, out_channels)
        self.bn3 = nn.BatchNorm2d(mid_channels)

        # Identity shortcut unless a learned projection is required.
        self.shortcut = nn.Sequential()
        if self.learned_shortcut:
            self.shortcut = conv1x1(in_channels, out_channels)

    def forward(self, x):
        """Return ``residual(x) + shortcut(x)`` with the configured resampling applied to both."""
        out = self.conv1(F.relu(self.bn1(x)))
        if self.upsample:
            out = F.interpolate(out, scale_factor=2)
        out = self.conv2(F.relu(self.bn2(out)))
        out = self.conv3(F.relu(self.bn3(out)))
        if self.downsample:
            out = F.avg_pool2d(out, 2)

        # The shortcut receives the same resampling so both branches line up.
        shortcut = self.shortcut(x)
        if self.upsample:
            shortcut = F.interpolate(shortcut, scale_factor=2)
        if self.downsample:
            shortcut = F.avg_pool2d(shortcut, 2)

        return out + shortcut


In [None]:
class ConditionalBatchNorm2d(nn.Module):
    """Batch normalisation whose scale and shift are predicted from an embedding.

    The feature map is normalised by a non-affine ``BatchNorm2d``; two linear
    layers then map the conditioning vector to a per-channel gain and bias.

    Args:
        num_features: Number of channels of the feature map being normalised.
        embedding_dim: Size of the conditioning vector passed to ``forward``.
    """

    def __init__(self, num_features, embedding_dim):
        super().__init__()
        self.bn = nn.BatchNorm2d(num_features, affine=False)
        self.gamma = nn.Linear(embedding_dim, num_features)
        self.beta = nn.Linear(embedding_dim, num_features)

    def forward(self, x, y_embed):
        """Normalise ``x`` and modulate it with the gain/bias derived from ``y_embed``."""
        normalized = self.bn(x)
        # Append two singleton axes so the per-channel values broadcast over H and W.
        scale = self.gamma(y_embed)[:, :, None, None]
        shift = self.beta(y_embed)[:, :, None, None]
        return scale * normalized + shift


In [None]:
class ResBlockG(nn.Module):
    """Generator residual block: conditional BN -> ReLU -> 2x upsample -> conv, twice-normalised.

    The residual branch doubles the spatial size; the shortcut is upsampled the
    same way and, when the channel count changes, projected by a spectrally
    normalised 1x1 convolution.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels of the produced feature map.
        embedding_dim: Size of the class embedding fed to the conditional batch norms.
    """

    def __init__(self, in_channels, out_channels, embedding_dim):
        super().__init__()
        self.cbn1 = ConditionalBatchNorm2d(in_channels, embedding_dim)
        self.cbn2 = ConditionalBatchNorm2d(out_channels, embedding_dim)
        self.relu = nn.ReLU(inplace=False)
        self.upsample = nn.Upsample(scale_factor=2)
        self.conv1 = nn.utils.spectral_norm(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        )
        self.conv2 = nn.utils.spectral_norm(
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        )

        if in_channels == out_channels:
            # Same width: the skip path only needs to match the new resolution.
            self.shortcut = nn.Upsample(scale_factor=2)
        else:
            # Width changes: upsample, then project with a 1x1 convolution.
            self.shortcut = nn.Sequential(
                nn.Upsample(scale_factor=2),
                nn.utils.spectral_norm(nn.Conv2d(in_channels, out_channels, kernel_size=1)),
            )

    def forward(self, x, y_embed):
        """Return the upsampled residual output plus the upsampled shortcut of ``x``."""
        h = self.relu(self.cbn1(x, y_embed))
        h = self.conv1(self.upsample(h))
        h = self.relu(self.cbn2(h, y_embed))
        h = self.conv2(h)
        return h + self.shortcut(x)


In [None]:
class BigGANDeepLiteGenerator(nn.Module):
    """Class-conditional generator built from three upsampling ``ResBlockG`` stages.

    The latent vector is projected to a ``(ch * 16, 4, 4)`` feature map, passed
    through three residual blocks that each double the resolution (4 -> 8 -> 16
    -> 32), and mapped to a 3-channel image squashed by ``tanh``. The class label
    only enters through the conditional batch norms inside the residual blocks.

    Args:
        latent_dim: Size of the latent vector ``z``.
        num_classes: Number of class labels for the embedding table.
        embedding_dim: Size of the learned class embedding.
        ch: Base channel multiplier.
    """

    def __init__(self, latent_dim, num_classes, embedding_dim=128, ch=64):
        super().__init__()
        self.latent_dim = latent_dim
        self.num_classes = num_classes
        self.embedding_dim = embedding_dim
        self.init_size = 4

        base_width = ch * 16
        self.project = nn.Linear(latent_dim, base_width * self.init_size * self.init_size)
        self.label_embedding = nn.Embedding(num_classes, embedding_dim)

        self.resblock1 = ResBlockG(base_width, ch * 8, embedding_dim)
        self.resblock2 = ResBlockG(ch * 8, ch * 4, embedding_dim)
        self.resblock3 = ResBlockG(ch * 4, ch * 2, embedding_dim)

        self.bn = nn.BatchNorm2d(ch * 2)
        self.relu = nn.ReLU(inplace=False)
        self.final_conv = nn.utils.spectral_norm(
            nn.Conv2d(ch * 2, 3, kernel_size=3, padding=1)
        )
        self.tanh = nn.Tanh()

    def forward(self, z, labels):
        """Generate one image per ``(z, label)`` pair; outputs lie in [-1, 1] (tanh)."""
        class_embed = self.label_embedding(labels)
        # Reshape the flat projection into a (batch, C, 4, 4) feature map.
        feat = self.project(z).view(z.size(0), -1, self.init_size, self.init_size)

        for block in (self.resblock1, self.resblock2, self.resblock3):
            feat = block(feat, class_embed)

        feat = self.final_conv(self.relu(self.bn(feat)))
        return self.tanh(feat)


In [None]:

class ResBlockD(nn.Module):
    """Discriminator residual block: (ReLU -> 3x3 conv) twice, optional 2x average pooling.

    A spectrally normalised 1x1 convolution is used on the shortcut whenever
    the channel count changes or the block downsamples; otherwise the input is
    added back unchanged.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels of the produced feature map.
        downsample: If True, halve the spatial resolution of both branches.
    """

    def __init__(self, in_channels, out_channels, downsample=True):
        super().__init__()
        self.downsample = downsample
        self.learned_shortcut = in_channels != out_channels or downsample

        self.conv1 = spectral_norm(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        )
        self.conv2 = spectral_norm(
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        )
        self.activation = nn.ReLU(inplace=False)
        self.avgpool = nn.AvgPool2d(2)

        # The projection only exists when it is needed; forward() checks the flag.
        if self.learned_shortcut:
            self.shortcut = spectral_norm(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0)
            )

    def forward(self, x):
        """Return the residual branch output added to the (possibly projected/pooled) input."""
        h = self.conv1(self.activation(x))
        h = self.conv2(self.activation(h))
        if self.downsample:
            h = self.avgpool(h)

        if not self.learned_shortcut:
            # Identity skip connection.
            return h + x

        skip = self.shortcut(x)
        if self.downsample:
            skip = self.avgpool(skip)
        return h + skip

In [None]:
class BigGANDeepLiteDiscriminator(nn.Module):
    """Projection discriminator made of five ``ResBlockD`` stages.

    The first four blocks downsample; the fifth only widens the features. The
    activated feature map is sum-pooled, scored by a linear layer, and the
    inner product with a class embedding is added to that score. When
    ``use_augself`` is set, an extra linear head predicts which augmentation
    was applied to the input.

    Args:
        num_classes: Number of class labels for the projection embedding.
        channels: Base channel multiplier.
        use_augself: Whether to build the augmentation-classification head.
        num_aug_types: Number of augmentation classes predicted by that head.
    """

    def __init__(self, num_classes=10, channels=64, use_augself=False, num_aug_types=4):
        super().__init__()
        self.use_augself = use_augself
        feat_dim = channels * 16

        self.block1 = ResBlockD(3, channels, downsample=True)
        self.block2 = ResBlockD(channels, channels * 2, downsample=True)
        self.block3 = ResBlockD(channels * 2, channels * 4, downsample=True)
        self.block4 = ResBlockD(channels * 4, channels * 8, downsample=True)
        self.block5 = ResBlockD(channels * 8, feat_dim, downsample=False)

        self.activation = nn.ReLU(inplace=False)
        self.linear = spectral_norm(nn.Linear(feat_dim, 1))
        self.embed = spectral_norm(nn.Embedding(num_classes, feat_dim))

        # The augmentation head is created only on request.
        if self.use_augself:
            self.aug_classifier = nn.Linear(feat_dim, num_aug_types)

    def forward(self, x, y, return_aug_logits=False):
        """Score images ``x`` under labels ``y``.

        Returns the ``(batch, 1)`` realness score, or ``(score, aug_logits)``
        when both ``self.use_augself`` and ``return_aug_logits`` are true.
        """
        feat = x
        for block in (self.block1, self.block2, self.block3, self.block4, self.block5):
            feat = block(feat)

        # Global sum pooling over the spatial axes.
        pooled = self.activation(feat).sum(dim=(2, 3))

        unconditional = self.linear(pooled)
        # Projection term: inner product between features and the class embedding.
        conditional = (pooled * self.embed(y)).sum(dim=1, keepdim=True)
        score = unconditional + conditional

        if self.use_augself and return_aug_logits:
            return score, self.aug_classifier(pooled)
        return score


In [None]:
data_root = '/kaggle/working/'

# ToTensor yields values in [0, 1]; normalising with mean = std = 0.5 maps them
# to [-1, 1], the range produced by the generator's tanh output.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# Full CIFAR-10 train / test splits (downloaded into data_root when missing).
train_dataset = torchvision.datasets.CIFAR10(
    root=data_root,
    train=True,
    download=True,
    transform=transform,
)
test_dataset = torchvision.datasets.CIFAR10(
    root=data_root,
    train=False,
    download=True,
    transform=transform,
)
def get_partial_cifar10(split='train', percentage=0.1, seed=42, transform=None, root='./data'):
    """Return a class-balanced random subset of CIFAR-10.

    The same fraction of images is drawn, without replacement, from every
    class present in the chosen split.

    Args:
        split: ``'train'`` selects the training split; any other value selects the test split.
        percentage: Fraction (between 0 and 1 inclusive) of each class to keep.
        seed: Seed of the private random generator used for the selection.
        transform: Transform forwarded to ``torchvision.datasets.CIFAR10``.
        root: Directory where the dataset is stored / downloaded.

    Returns:
        A ``torch.utils.data.Subset`` over the selected indices.

    Raises:
        ValueError: If ``percentage`` lies outside ``[0, 1]``.
    """
    if not 0.0 <= percentage <= 1.0:
        raise ValueError(f"percentage must be within [0, 1], got {percentage!r}")

    # Load the full CIFAR-10 split
    dataset = torchvision.datasets.CIFAR10(
        root=root,
        train=(split == 'train'),
        transform=transform,
        download=True
    )

    # A private RandomState yields the same draws as np.random.seed(seed) followed
    # by np.random.choice, but leaves the global NumPy RNG untouched.
    rng = np.random.RandomState(seed)
    targets = np.asarray(dataset.targets)
    selected_indices = []

    # np.unique returns the labels sorted, so classes are visited in ascending order.
    for cls in np.unique(targets):
        cls_indices = np.flatnonzero(targets == cls)
        num_samples = int(len(cls_indices) * percentage)
        chosen = rng.choice(cls_indices, num_samples, replace=False)
        selected_indices.extend(chosen.tolist())

    return Subset(dataset, selected_indices)
# Class-balanced 10% and 20% subsets of the CIFAR-10 training split.
train_dataset_10 = get_partial_cifar10(split='train', percentage=0.1, transform=transform)
train_loader_10 = DataLoader(train_dataset_10, batch_size=128, shuffle=True, num_workers=2)
train_dataset_20 = get_partial_cifar10(split='train', percentage=0.2, transform=transform)
# Fixed: this loader previously wrapped train_dataset_10, so the "20%" loader
# silently served only the 10% subset.
train_loader_20 = DataLoader(train_dataset_20, batch_size=128, shuffle=True, num_workers=2)

In [None]:
# Loader over the full CIFAR-10 training split.
train_loader = DataLoader(
    train_dataset,
    batch_size=128,
    shuffle=True,
    num_workers=2,
    pin_memory=True,  # Optional: can help on GPU
)

In [None]:
run = wandb.init(
    entity="Hexager-manipal",
    # Set the wandb project where this run will be logged.
    project="Big-Gan-similar",
    # Track hyperparameters and run metadata.
    config={
        "learning_rate": 2e-4,
        # NOTE(review): the train() call in this notebook enables pseudo-augmentation
        # and AugSelf, which may contradict "w/o enhancements" - confirm the label.
        "architecture": "BIG-GAN-deep-lite-simplified w/o enhancements",
        "dataset": "CIFAR-10",
        # Matches the epochs=50 passed to train() in this notebook (was logged as 20).
        "epochs": 50
    },
)

In [None]:
# Augmentation id -> transform. The id doubles as the target class of the
# discriminator's augmentation-classification head.
AUG_TYPES = dict(enumerate((
    transforms.RandomRotation(degrees=15),      # id 0
    transforms.ColorJitter(0.4, 0.4, 0.4),      # id 1
    transforms.RandomHorizontalFlip(p=1.0),     # id 2
    transforms.RandomGrayscale(p=1.0),          # id 3
)))
NUM_AUGS = len(AUG_TYPES)

def apply_aug_batch(batch, aug_labels):
    """Apply, per image, the augmentation selected by ``aug_labels``.

    Images in this notebook live in [-1, 1] (``Normalize(0.5, 0.5)`` for real
    data, ``tanh`` for generated data), whereas torchvision's colour transforms
    treat float tensors as [0, 1] images and clamp to that range. Each image is
    therefore mapped to [0, 1], augmented, and mapped back to [-1, 1];
    previously ``ColorJitter`` clipped every negative pixel to zero.

    Args:
        batch: Image tensor of shape ``(N, C, H, W)`` with values in [-1, 1].
        aug_labels: Integer tensor of length ``N`` holding keys of ``AUG_TYPES``.

    Returns:
        A tensor shaped like ``batch`` on the same device, with values in [-1, 1].
    """
    # One host transfer for all labels instead of one .item() sync per image.
    aug_ids = aug_labels.tolist()
    if len(aug_ids) == 0:
        # torch.stack rejects an empty list; an empty batch has nothing to augment.
        return batch.clone()

    augmented = []
    for aug_id, img in zip(aug_ids, batch):
        img01 = (img.cpu() * 0.5 + 0.5).clamp(0.0, 1.0)
        out01 = AUG_TYPES[aug_id](img01)
        augmented.append((out01 * 2.0 - 1.0).to(img.device))
    return torch.stack(augmented)


In [None]:
# Human-readable names for the ten labels.
# NOTE(review): presumably listed in CIFAR-10 label-index order (0..9) - confirm against dataset.classes.
class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer','dog', 'frog', 'horse', 'ship', 'truck']

In [None]:
def save_checkpoint(generator, discriminator, g_optimizer, d_optimizer,
                    epoch, step,  path="checkpoints", filename="last.pth"):
    """Write the state of both networks and both optimizers to ``path/filename``.

    The destination directory is created when missing. The saved dictionary has
    the keys ``generator``, ``discriminator``, ``g_optimizer``, ``d_optimizer``
    (each a ``state_dict``), plus ``epoch`` and ``step``.
    """
    os.makedirs(path, exist_ok=True)
    destination = os.path.join(path, filename)

    stateful = {
        "generator": generator,
        "discriminator": discriminator,
        "g_optimizer": g_optimizer,
        "d_optimizer": d_optimizer,
    }
    checkpoint = {key: obj.state_dict() for key, obj in stateful.items()}
    checkpoint.update(epoch=epoch, step=step)

    torch.save(checkpoint, destination)


def load_checkpoint(generator, discriminator, g_optimizer, d_optimizer,
                    path="checkpoints/last.pth", map_location="cpu"):
    """Restore networks and optimizers from a file written by ``save_checkpoint``.

    Args:
        generator: Module that receives the ``"generator"`` state, or None to skip it.
        discriminator: Module that receives the ``"discriminator"`` state, or None to skip it.
        g_optimizer: Optimizer that receives the ``"g_optimizer"`` state, or None to skip it.
        d_optimizer: Optimizer that receives the ``"d_optimizer"`` state, or None to skip it.
        path: Location of the checkpoint file.
        map_location: Forwarded to ``torch.load``. The "cpu" default lets a
            checkpoint written on a GPU be opened on a machine without one;
            ``load_state_dict`` then copies the values onto whatever device the
            target modules / optimizer parameters already live on.

    Returns:
        Tuple ``(epoch, step)`` stored in the checkpoint.
    """
    checkpoint = torch.load(path, map_location=map_location)

    # Passing None for a target skips it, e.g. to restore only the generator for evaluation.
    targets = (
        ("generator", generator),
        ("discriminator", discriminator),
        ("g_optimizer", g_optimizer),
        ("d_optimizer", d_optimizer),
    )
    for key, target in targets:
        if target is not None:
            target.load_state_dict(checkpoint[key])

    return checkpoint["epoch"], checkpoint["step"]

In [None]:
import torch
import torch.nn.functional as F
from torchvision.utils import make_grid, save_image
from tqdm import tqdm

def hinge_discriminator_loss(D_real, D_fake):
    """Hinge loss for the discriminator: mean(relu(1 - D_real)) + mean(relu(1 + D_fake))."""
    real_term = F.relu(1. - D_real).mean()
    fake_term = F.relu(1. + D_fake).mean()
    return real_term + fake_term

def hinge_generator_loss(D_fake):
    """Hinge loss for the generator: the negated mean discriminator score on fakes."""
    mean_score = D_fake.mean()
    return -mean_score

def sample_latent(batch_size, z_dim, num_classes, device):
    """Draw standard-normal latents of shape (batch_size, z_dim) and uniform labels in [0, num_classes)."""
    noise = torch.randn((batch_size, z_dim), device=device)
    classes = torch.randint(low=0, high=num_classes, size=(batch_size,), device=device)
    return noise, classes

def train(
    generator,
    discriminator,
    dataloader,
    num_classes,
    z_dim=128,
    epochs=100,
    lr_g=2e-4,
    lr_d=2e-4,
    device='cuda',
    g_steps=1,
    d_steps=1,
    save_interval=5,
    use_pseudo_aug=False,
    t_threshold=0.6,
    p_step=0.01,
    adjust_every=4,
    use_augself=True,
    lambda_aug=1.0,
    apa_start_epoch=5,
):
    """Train a conditional GAN with hinge losses, optional APA and optional AugSelf.

    Args:
        generator: Module called as ``generator(z, labels)``.
        discriminator: Module called as ``discriminator(images, labels, return_aug_logits=...)``.
        dataloader: Iterable yielding ``(images, labels)`` batches.
        num_classes: Number of class labels to sample for fake images.
        z_dim: Size of the latent vector.
        epochs: Number of passes over ``dataloader``.
        lr_g: Adam learning rate of the generator.
        lr_d: Adam learning rate of the discriminator.
        device: Device on which both networks and all batches are placed.
        g_steps: Generator updates per batch.
        d_steps: Discriminator updates per batch.
        save_interval: Sample images are logged every ``save_interval`` epochs
            (a falsy value disables image logging). A checkpoint is written
            after every epoch regardless.
        use_pseudo_aug: Enable adaptive pseudo augmentation (APA): generated
            images are presented to the discriminator as "real" with an
            adaptive probability.
        t_threshold: APA threshold on the mean sign of the real logits.
        p_step: Amount by which the APA probability is raised or lowered.
        adjust_every: Number of discriminator updates between APA adjustments.
        use_augself: Add an auxiliary loss asking the discriminator to classify
            which augmentation was applied to the real batch.
        lambda_aug: Weight of that auxiliary cross-entropy loss.
        apa_start_epoch: First (0-based) epoch in which APA is active; the
            default of 5 matches the previous hard-coded ``epoch > 4``.

    Raises:
        ValueError: If ``use_augself`` is requested but the discriminator
            explicitly reports ``use_augself=False`` (it would not return
            augmentation logits).
    """
    # Fail early with a clear message instead of an unpacking error mid-training.
    # Objects without the attribute (e.g. wrappers) are given the benefit of the doubt.
    if use_augself and not getattr(discriminator, "use_augself", True):
        raise ValueError(
            "use_augself=True requires a discriminator built with use_augself=True"
        )

    # - Pseudo-Augmentation state -
    pseudo_aug_prob = 0.0           # current probability of swapping a real image for a fake
    d_iter = 0                      # counter of D updates while APA is active

    generator = generator.to(device)
    discriminator = discriminator.to(device)
    opt_g = torch.optim.Adam(generator.parameters(), lr=lr_g, betas=(0.0, 0.999))
    opt_d = torch.optim.Adam(discriminator.parameters(), lr=lr_d, betas=(0.0, 0.999))
    step = 0

    for epoch in range(epochs):
        apa_active = use_pseudo_aug and epoch >= apa_start_epoch
        pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}")
        for real_imgs, labels in pbar:
            real_imgs, labels = real_imgs.to(device), labels.to(device)
            bs = real_imgs.size(0)

            # - Train Discriminator -
            for _ in range(d_steps):
                # 1) generate fakes; no generator graph is needed for the D update
                z, y = sample_latent(bs, z_dim, num_classes, device)
                with torch.no_grad():
                    fake_imgs = generator(z, y)

                # 2) mix pseudo-real samples into the real batch (only if APA is active)
                if apa_active and pseudo_aug_prob > 0:
                    mask = torch.rand(bs, 1, 1, 1, device=device) < pseudo_aug_prob
                    real_mixed = torch.where(mask, fake_imgs, real_imgs)
                    # A swapped-in fake must keep the label it was generated with;
                    # otherwise the conditional discriminator sees mismatched pairs.
                    real_labels = torch.where(mask.view(bs), y, labels)
                else:
                    real_mixed = real_imgs
                    real_labels = labels

                if use_augself:
                    aug_labels = torch.randint(0, NUM_AUGS, (bs,), device=device)
                    aug_imgs = apply_aug_batch(real_mixed, aug_labels)
                    # Only the augmentation logits are used; the realness score is discarded.
                    _, aug_logits = discriminator(aug_imgs, real_labels, return_aug_logits=True)
                    # Clamp to keep one bad batch from dominating the D update.
                    loss_aug = torch.clamp(F.cross_entropy(aug_logits, aug_labels), max=5.0)
                else:
                    loss_aug = torch.tensor(0.0, device=device)

                # 3) forward
                D_real = discriminator(real_mixed, real_labels, return_aug_logits=False)
                D_fake = discriminator(fake_imgs, y, return_aug_logits=False)

                loss_adv = hinge_discriminator_loss(D_real, D_fake)
                loss_d = loss_adv + lambda_aug * loss_aug

                # 4) backward + step
                opt_d.zero_grad()
                loss_d.backward()
                opt_d.step()

                # 5) adapt pseudo_aug_prob if APA is active
                if apa_active:
                    d_iter += 1
                    if d_iter % adjust_every == 0:
                        # Mean sign of the real logits: high values mean D is
                        # too confident on real data, so more deception is applied.
                        lambda_r = torch.sign(D_real).mean().item()
                        if lambda_r > t_threshold:
                            pseudo_aug_prob = min(pseudo_aug_prob + p_step, 1.0)
                        else:
                            pseudo_aug_prob = max(pseudo_aug_prob - p_step, 0.0)

            # - Train Generator -
            for _ in range(g_steps):
                z, y = sample_latent(bs, z_dim, num_classes, device)
                fake_imgs = generator(z, y)
                D_fake = discriminator(fake_imgs, y)
                loss_g = hinge_generator_loss(D_fake)

                opt_g.zero_grad()
                loss_g.backward()
                opt_g.step()

            # - Logging -
            postfix = {
                "d_loss": loss_d.item(),
                "g_loss": loss_g.item(),
            }
            if apa_active:
                postfix["pseudo_aug_prob"] = pseudo_aug_prob
            pbar.set_postfix(postfix)
            step += 1
            if step % 50 == 0:
                log_dict = {
                    "d_loss": loss_d.item(),
                    "g_loss": loss_g.item(),
                    "epoch": epoch,
                    "step": step
                }
                if apa_active:
                    log_dict["pseudo_aug_prob"] = pseudo_aug_prob
                if use_augself:
                    # Log a plain float rather than a tensor that still holds its graph.
                    log_dict["aug_loss"] = loss_aug.item()
                wandb.log(log_dict)

        # Checkpoint after every epoch.
        save_checkpoint(generator, discriminator, opt_g, opt_d, epoch, step)

        # ------------------
        # Save Images
        # ------------------
        if save_interval and epoch % save_interval == 0:
            generator.eval()
            with torch.no_grad():
                z = torch.randn(64, z_dim, device=device)
                sample_labels = torch.randint(0, num_classes, (64,), device=device)
                fakes = generator(z, sample_labels)
                grid = torchvision.utils.make_grid(fakes, nrow=8, normalize=True)
                wandb.log({"Generated Images": [wandb.Image(grid, caption=f"Epoch {epoch}")]})
            generator.train()



In [None]:
# Single switch for AugSelf: it is passed both to the discriminator constructor
# (to build the augmentation head) and to train() (to add the auxiliary loss).
use_augself = True

In [None]:
# Build fresh networks; the discriminator gets the augmentation head only when use_augself is set.
generator = BigGANDeepLiteGenerator(latent_dim=128, num_classes=10)
discriminator = BigGANDeepLiteDiscriminator(num_classes=10, use_augself=use_augself)

In [None]:
# Pass the detected device explicitly: train() defaults to 'cuda', which fails
# on a CPU-only session even though `device` was computed for that case.
train(
    generator,
    discriminator,
    train_loader_20,
    num_classes=10,
    epochs=50,
    device=device,
    use_pseudo_aug=True,
    use_augself=use_augself,
)

In [None]:
# Close the Weights & Biases run started by wandb.init().
run.finish()

In [None]:
# Collect unreachable Python objects, then ask PyTorch to release its unused
# cached GPU memory before the evaluation cells run.
import gc
gc.collect()
torch.cuda.empty_cache()


In [None]:
# Rebuild the generator and restore the weights of the last checkpoint.
# map_location="cpu" lets a checkpoint written on a GPU be read on any machine;
# load_state_dict copies the values into the freshly built (CPU) module.
generator = BigGANDeepLiteGenerator(latent_dim=128, num_classes=10)
generator.load_state_dict(
    torch.load("/kaggle/working/checkpoints/last.pth", map_location="cpu")["generator"]
)


In [None]:
from torchvision.datasets import CIFAR10
from torchvision import transforms
import torch

transform = transforms.Compose([
    transforms.ToTensor(),                   # [0,1]
    transforms.Normalize((0.5, 0.5, 0.5),     # maps [0, 1] to [-1, 1]
                         (0.5, 0.5, 0.5))
])

# Number of real and generated images compared by the metrics below.
num_eval_images = 5000

real_dataset = CIFAR10(root="./data", train=True, transform=transform, download=True)
real_images = torch.stack([real_dataset[i][0] for i in range(num_eval_images)])  # shape: [5000, 3, 32, 32]
generator.eval().to(device)
latent_dim = 128
num_classes = 10
batch_size = 64
fakes = []

with torch.no_grad():
    # Step through the target count so the final, smaller batch is generated too;
    # `num_eval_images // batch_size` full batches would yield only 4992 images.
    for start in range(0, num_eval_images, batch_size):
        current = min(batch_size, num_eval_images - start)
        z = torch.randn(current, latent_dim, device=device)
        y = torch.randint(0, num_classes, (current,), device=device)
        out = generator(z, y)
        fakes.append(out.cpu())  # offload to CPU to save VRAM

fakes_tensor = torch.cat(fakes, dim=0)  # [5000, 3, 32, 32]

def denorm(x):
    """Map values from [-1, 1] to [0, 1] (x * 0.5 + 0.5), clamping anything outside [0, 1]."""
    rescaled = x * 0.5 + 0.5
    return torch.clamp(rescaled, min=0, max=1)

# Bring both image sets to [0, 1]; the metrics below are created with normalize=True.
# Note: these names are rebound in place, so run this cell's statements once per
# fresh sampling (denorm applied twice would shift the values again).
real_images = denorm(real_images)
fakes_tensor = denorm(fakes_tensor)

In [None]:
from torchmetrics.image.fid import FrechetInceptionDistance

# normalize=True: inputs are float images in [0, 1].
fid = FrechetInceptionDistance(feature=2048, normalize=True).to("cuda")

# Feed real images. The loop is bounded by the tensor's own length; it used to
# run to a hard-coded 10000 and relied on skipping the resulting empty slices.
for i in range(0, real_images.size(0), batch_size):
    real_batch = real_images[i:i+batch_size].to("cuda")
    fid.update(real_batch, real=True)

# Feed fake images
for i in range(0, fakes_tensor.size(0), batch_size):
    fake_batch = fakes_tensor[i:i+batch_size]
    fid.update(fake_batch.to("cuda"), real=False)


# Compute score
score = fid.compute().item()
print(f"✅ Final FID Score: {score:.2f}")


In [None]:
from torchmetrics.image.inception import InceptionScore

# Work on a clamped copy so the generated images are guaranteed to lie in [0, 1].
fake_imgs = fakes_tensor.clone().clamp(0, 1)  # shape: [5000, 3, 32, 32]

# Inception Score over 10 splits; normalize=True means float inputs in [0, 1].
is_metric = InceptionScore(normalize=True, splits=10).to("cuda")

# Stream the images through the metric in mini-batches to bound GPU memory.
batch_size = 64
for batch in fake_imgs.split(batch_size):
    is_metric.update(batch.to("cuda"))

# compute() returns the mean and the standard deviation across splits.
score, std = is_metric.compute()
print(f"✅ Inception Score: {score:.2f} ± {std:.2f}")