In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
import torch.optim as optim
from tqdm import tqdm
import wandb
from kaggle_secrets import UserSecretsClient
import torch.nn.utils.spectral_norm as spectral_norm
import os
user_secrets = UserSecretsClient()
secret_value_0 = user_secrets.get_secret("wandbpass")

In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [3]:
!pip install torch-fidelity

Collecting torch-fidelity
  Downloading torch_fidelity-0.3.0-py3-none-any.whl.metadata (2.0 kB)
Downloading torch_fidelity-0.3.0-py3-none-any.whl (37 kB)
Installing collected packages: torch-fidelity
Successfully installed torch-fidelity-0.3.0


In [4]:
wandb.login(key=secret_value_0)

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mhexager[0m ([33mhexager-manipal[0m). Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

In [20]:
class BottleneckResBlock(nn.Module):
    def __init__(self, in_channels, out_channels, downsample=False, upsample=False, spectral=False):
        super(BottleneckResBlock, self).__init__()
        mid_channels = in_channels // 4

        self.learned_shortcut = (in_channels != out_channels) or downsample or upsample
        self.downsample = downsample
        self.upsample = upsample

        def conv3x3(in_ch, out_ch):
            conv = nn.Conv2d(in_ch, out_ch, 3, 1, 1)
            return spectral_norm(conv) if spectral else conv

        def conv1x1(in_ch, out_ch):
            conv = nn.Conv2d(in_ch, out_ch, 1, 1, 0)
            return spectral_norm(conv) if spectral else conv

        self.conv1 = conv1x1(in_channels, mid_channels)
        self.bn1 = nn.BatchNorm2d(mid_channels)
        self.conv2 = conv3x3(mid_channels, mid_channels)
        self.bn2 = nn.BatchNorm2d(mid_channels)
        self.conv3 = conv1x1(mid_channels, out_channels)
        self.bn3 = nn.BatchNorm2d(out_channels)

        self.shortcut = nn.Sequential()
        if self.learned_shortcut:
            self.shortcut = conv1x1(in_channels, out_channels)

    def forward(self, x):
        out = self.conv1(F.relu(self.bn1(x)))
        if self.upsample:
            out = F.interpolate(out, scale_factor=2)
        out = self.conv2(F.relu(self.bn2(out)))
        out = self.conv3(F.relu(self.bn3(out)))
        if self.downsample:
            out = F.avg_pool2d(out, 2)

        shortcut = self.shortcut(x)
        if self.upsample:
            shortcut = F.interpolate(shortcut, scale_factor=2)
        if self.downsample:
            shortcut = F.avg_pool2d(shortcut, 2)

        return out + shortcut


In [21]:
class ConditionalBatchNorm2d(nn.Module):
    def __init__(self, num_features, embedding_dim):
        super(ConditionalBatchNorm2d, self).__init__()
        self.bn = nn.BatchNorm2d(num_features, affine=False)
        self.gamma = nn.Linear(embedding_dim, num_features)
        self.beta = nn.Linear(embedding_dim, num_features)

    def forward(self, x, y_embed):
        out = self.bn(x)
        gamma = self.gamma(y_embed).unsqueeze(2).unsqueeze(3)
        beta = self.beta(y_embed).unsqueeze(2).unsqueeze(3)
        out = gamma * out + beta
        return out


In [22]:
class ResBlockG(nn.Module):
    def __init__(self, in_channels, out_channels, embedding_dim):
        super(ResBlockG, self).__init__()
        self.cbn1 = ConditionalBatchNorm2d(in_channels, embedding_dim)
        self.cbn2 = ConditionalBatchNorm2d(out_channels, embedding_dim)
        self.relu = nn.ReLU(inplace=False)
        self.upsample = nn.Upsample(scale_factor=2)
        self.conv1 = nn.utils.spectral_norm(nn.Conv2d(in_channels, out_channels, 3, padding=1))
        self.conv2 = nn.utils.spectral_norm(nn.Conv2d(out_channels, out_channels, 3, padding=1))
        self.shortcut = nn.Sequential()
        if in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Upsample(scale_factor=2),
                nn.utils.spectral_norm(nn.Conv2d(in_channels, out_channels, 1))
            )
        else:
            self.shortcut = nn.Upsample(scale_factor=2)

    def forward(self, x, y_embed):
        out = self.cbn1(x, y_embed)
        out = self.relu(out)
        out = self.upsample(out)
        out = self.conv1(out)
        out = self.cbn2(out, y_embed)
        out = self.relu(out)
        out = self.conv2(out)
        shortcut = self.shortcut(x)
        return out + shortcut


In [23]:
class BigGANDeepLiteGenerator(nn.Module):
    def __init__(self, latent_dim, num_classes, embedding_dim=128, ch=64):
        super(BigGANDeepLiteGenerator, self).__init__()
        self.latent_dim = latent_dim
        self.num_classes = num_classes
        self.embedding_dim = embedding_dim
        self.init_size = 4
        self.project = nn.Linear(latent_dim, (ch * 16) * self.init_size * self.init_size)
        self.label_embedding = nn.Embedding(num_classes, embedding_dim)

        self.resblock1 = ResBlockG(ch * 16, ch * 8, embedding_dim)
        self.resblock2 = ResBlockG(ch * 8, ch * 4, embedding_dim)
        self.resblock3 = ResBlockG(ch * 4, ch * 2, embedding_dim)

        self.bn = nn.BatchNorm2d(ch * 2)
        self.relu = nn.ReLU(inplace=False)
        self.final_conv = nn.utils.spectral_norm(nn.Conv2d(ch * 2, 3, 3, padding=1))
        self.tanh = nn.Tanh()

    def forward(self, z, labels):
        y_embed = self.label_embedding(labels)
        out = self.project(z).view(z.size(0), -1, self.init_size, self.init_size)

        out = self.resblock1(out, y_embed)
        out = self.resblock2(out, y_embed)
        out = self.resblock3(out, y_embed)

        out = self.relu(self.bn(out))
        out = self.final_conv(out)
        out = self.tanh(out)
        return out


In [24]:

class ResBlockD(nn.Module):
    def __init__(self, in_channels, out_channels, downsample=True):
        super(ResBlockD, self).__init__()
        self.downsample = downsample
        self.learned_shortcut = (in_channels != out_channels) or downsample

        self.conv1 = spectral_norm(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1))
        self.conv2 = spectral_norm(nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1))
        self.activation = nn.ReLU(inplace=False)
        self.avgpool = nn.AvgPool2d(2)

        if self.learned_shortcut:
            self.shortcut = spectral_norm(nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0))

    def forward(self, x):
        residual = x

        out = self.activation(x)
        out = self.conv1(out)
        out = self.activation(out)
        out = self.conv2(out)
        if self.downsample:
            out = self.avgpool(out)

        if self.learned_shortcut:
            residual = self.shortcut(residual)
            if self.downsample:
                residual = self.avgpool(residual)

        return out + residual

In [25]:
class BigGANDeepLiteDiscriminator(nn.Module):
    def __init__(self, num_classes=10, channels=64):
        super(BigGANDeepLiteDiscriminator, self).__init__()
        self.block1 = ResBlockD(3, channels, downsample=True)
        self.block2 = ResBlockD(channels, channels * 2, downsample=True)
        self.block3 = ResBlockD(channels * 2, channels * 4, downsample=True)
        self.block4 = ResBlockD(channels * 4, channels * 8, downsample=True)
        self.block5 = ResBlockD(channels * 8, channels * 16, downsample=False)

        self.activation = nn.ReLU(inplace=False)
        self.linear = spectral_norm(nn.Linear(channels * 16, 1))

        # For projection discriminator (conditional)
        self.embed = spectral_norm(nn.Embedding(num_classes, channels * 16))

    def forward(self, x, y):
        out = self.block1(x)
        out = self.block2(out)
        out = self.block3(out)
        out = self.block4(out)
        out = self.block5(out)

        out = self.activation(out)
        out = torch.sum(out, dim=(2, 3))  # Global sum pooling

        output = self.linear(out)

        # Projection discriminator term
        y_embed = self.embed(y)
        proj = torch.sum(out * y_embed, dim=1, keepdim=True)

        return output + proj


In [11]:
data_root = '/kaggle/working/'
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
train_dataset = torchvision.datasets.CIFAR10(root=data_root, train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.CIFAR10(root=data_root, train=False, download=True, transform=transform)


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to /kaggle/working/cifar-10-python.tar.gz


100%|██████████| 170M/170M [00:11<00:00, 14.3MB/s] 


Extracting /kaggle/working/cifar-10-python.tar.gz to /kaggle/working/
Files already downloaded and verified


In [13]:
train_loader = DataLoader(
    train_dataset,
    batch_size=64,
    shuffle=True,
    pin_memory=True  # Optional: can help on GPU
)

In [14]:
run = wandb.init(
    entity="Hexager-manipal",
    # Set the wandb project where this run will be logged.
    project="Big-Gan-similar",
    # Track hyperparameters and run metadata.
    config={
        "learning_rate": 2e-4,
        "architecture": "BIG-GAN-deep-lite-simplified w/o enhancements",
        "dataset": "CIFAR-10",
        "epochs": 20
    },
)

[34m[1mwandb[0m: Currently logged in as: [33mhexager[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [15]:
class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer','dog', 'frog', 'horse', 'ship', 'truck']

In [16]:
def save_checkpoint(generator, discriminator, g_optimizer, d_optimizer,
                    epoch, step,  path="checkpoints", filename="last.pth"):
    os.makedirs(path, exist_ok=True)
    
    checkpoint = {
        "generator": generator.state_dict(),
        "discriminator": discriminator.state_dict(),
        "g_optimizer": g_optimizer.state_dict(),
        "d_optimizer": d_optimizer.state_dict(),
        "epoch": epoch,
        "step": step,
        #"best_fid": best_fid,
    }
    
    torch.save(checkpoint, os.path.join(path, filename))
    #if is_best:
        #torch.save(checkpoint, os.path.join(path, "best.pth"))


def load_checkpoint(generator, discriminator, g_optimizer, d_optimizer, path="checkpoints/last.pth"):
    checkpoint = torch.load(path)
    generator.load_state_dict(checkpoint["generator"])
    discriminator.load_state_dict(checkpoint["discriminator"])
    g_optimizer.load_state_dict(checkpoint["g_optimizer"])
    d_optimizer.load_state_dict(checkpoint["d_optimizer"])
    return checkpoint["epoch"], checkpoint["step"] #checkpoint["best_fid"]

In [26]:
import torch
import torch.nn.functional as F
from torchvision.utils import make_grid, save_image
from tqdm import tqdm

def hinge_discriminator_loss(D_real, D_fake):
    return torch.mean(F.relu(1. - D_real)) + torch.mean(F.relu(1. + D_fake))

def hinge_generator_loss(D_fake):
    return -torch.mean(D_fake)

def sample_latent(batch_size, z_dim, num_classes, device):
    z = torch.randn(batch_size, z_dim, device=device)
    y = torch.randint(0, num_classes, (batch_size,), device=device)
    return z, y

def train(generator, discriminator, dataloader, num_classes, z_dim=128, 
          epochs=100, lr_g=2e-4, lr_d=2e-4, device='cuda', 
          g_steps=1, d_steps=1, save_interval=5):
    log_interval = 2
    generator = generator.to(device)
    discriminator = discriminator.to(device)

    opt_g = torch.optim.Adam(generator.parameters(), lr=lr_g, betas=(0.0, 0.999))
    opt_d = torch.optim.Adam(discriminator.parameters(), lr=lr_d, betas=(0.0, 0.999))

    step = 0

    for epoch in range(epochs):
        pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}")
        for real_imgs, labels in pbar:
            real_imgs, labels = real_imgs.to(device), labels.to(device)
            batch_size = real_imgs.size(0)
            # Train Discriminator
            
            for _ in range(d_steps):
                z, y = sample_latent(batch_size, z_dim, num_classes, device)
                fake_imgs = generator(z, y).detach()
                
                D_real = discriminator(real_imgs, labels)
                D_fake = discriminator(fake_imgs, y)

                loss_d = hinge_discriminator_loss(D_real, D_fake)

                opt_d.zero_grad()
                loss_d.backward()
                opt_d.step()
                
            # Train Generator
            
            for _ in range(g_steps):
                z, y = sample_latent(batch_size, z_dim, num_classes, device)
                fake_imgs = generator(z, y)
                D_fake = discriminator(fake_imgs, y)

                loss_g = hinge_generator_loss(D_fake)

                opt_g.zero_grad()
                loss_g.backward()
                opt_g.step()

            pbar.set_postfix({"loss_d": loss_d.item(), "loss_g": loss_g.item()})
            step += 1
            if step % log_interval == 0:
                wandb.log({
                    "g_loss": loss_g.item(),
                    "d_loss": loss_d.item(),
                })

        save_checkpoint(generator, discriminator, opt_g, opt_d, epoch, step)
        # ------------------
        # Save Images
        # ------------------
        if epoch % 5 == 0:
            generator.eval()
            with torch.no_grad():
                n_classes = num_classes
                samples_per_class = 8
                fixed_labels = torch.arange(n_classes, device=device).repeat_interleave(samples_per_class)
                fixed_z = torch.randn(n_classes * samples_per_class, z_dim, device=device)
        
                fakes = generator(fixed_z, fixed_labels)
                grid = torchvision.utils.make_grid(fakes, nrow=samples_per_class, normalize=True, pad_value=1)
        
                wandb.log({
                    "Class-conditional Samples": [wandb.Image(grid, caption=f"Epoch {epoch} - One row per class")]
                })



In [36]:
generator = BigGANDeepLiteGenerator(latent_dim=128, num_classes=10)
discriminator = BigGANDeepLiteDiscriminator(num_classes=10)


In [28]:
train(
    generator=generator,
    discriminator=discriminator,
    dataloader=train_loader,
    num_classes=10,
    z_dim=128,
    epochs=20,
    lr_g=2e-4,
    lr_d=2e-4,
    device='cuda' if torch.cuda.is_available() else 'cpu',
    g_steps=1,
    d_steps=1,
    save_interval=5
)


Epoch 1/20: 100%|██████████| 782/782 [02:21<00:00,  5.54it/s, loss_d=2.07, loss_g=0.105]   
Epoch 2/20: 100%|██████████| 782/782 [02:21<00:00,  5.53it/s, loss_d=2.07, loss_g=-0.0128]  
Epoch 3/20: 100%|██████████| 782/782 [02:21<00:00,  5.53it/s, loss_d=1.94, loss_g=-0.0466] 
Epoch 4/20: 100%|██████████| 782/782 [02:21<00:00,  5.54it/s, loss_d=2.05, loss_g=0.187]    
Epoch 5/20: 100%|██████████| 782/782 [02:21<00:00,  5.54it/s, loss_d=2.15, loss_g=0.00628] 
Epoch 6/20: 100%|██████████| 782/782 [02:21<00:00,  5.54it/s, loss_d=2.01, loss_g=0.109]    
Epoch 7/20: 100%|██████████| 782/782 [02:21<00:00,  5.54it/s, loss_d=2.08, loss_g=0.0342]   
Epoch 8/20: 100%|██████████| 782/782 [02:21<00:00,  5.53it/s, loss_d=1.92, loss_g=0.00907]  
Epoch 9/20: 100%|██████████| 782/782 [02:21<00:00,  5.54it/s, loss_d=2, loss_g=-0.0451]     
Epoch 10/20: 100%|██████████| 782/782 [02:20<00:00,  5.56it/s, loss_d=2.08, loss_g=0.0433]   
Epoch 11/20: 100%|██████████| 782/782 [02:20<00:00,  5.56it/s, loss_d=1.

In [29]:
run.finish()

0,1
d_loss,▁█▅▆▆▆▇▇▆▆▆▇▆▆▅▇▆▇▆▇▆▇▆▆▂▆▅▆▆▆▇▆▇▆▆▆▆▆▆▆
epoch,▁▁▁▁▂▂▂▂▂▂▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇▇███
g_loss,▄▅▅▅▅▆▅▅▅▅▅▆▅▆▅▅█▃▂▃▁▁▃▃▃▃▃▃▃▅▄▄▄▄▄▄▅▅▅▅
step,▁▁▁▁▁▁▂▂▂▂▃▃▃▃▃▄▄▄▅▅▅▅▅▅▅▆▆▆▆▆▆▆▇▇▇▇▇▇██

0,1
d_loss,2.0707
epoch,19.0
g_loss,-0.06188
step,15640.0


In [68]:
import gc
gc.collect()
torch.cuda.empty_cache()


In [69]:
generator = BigGANDeepLiteGenerator(latent_dim=128, num_classes=10)
generator.load_state_dict(torch.load("/kaggle/working/checkpoints/last.pth")["generator"])
generator.cpu()

  generator.load_state_dict(torch.load("/kaggle/working/checkpoints/last.pth")["generator"])


BigGANDeepLiteGenerator(
  (project): Linear(in_features=128, out_features=16384, bias=True)
  (label_embedding): Embedding(10, 128)
  (resblock1): ResBlockG(
    (cbn1): ConditionalBatchNorm2d(
      (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=False, track_running_stats=True)
      (gamma): Linear(in_features=128, out_features=1024, bias=True)
      (beta): Linear(in_features=128, out_features=1024, bias=True)
    )
    (cbn2): ConditionalBatchNorm2d(
      (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=False, track_running_stats=True)
      (gamma): Linear(in_features=128, out_features=512, bias=True)
      (beta): Linear(in_features=128, out_features=512, bias=True)
    )
    (relu): ReLU()
    (upsample): Upsample(scale_factor=2.0, mode='nearest')
    (conv1): Conv2d(1024, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (shortcut): Sequential(
      (0): Upsample(scale_f

In [56]:
import torch
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import numpy as np
from scipy.linalg import sqrtm
from tqdm.notebook import tqdm
from torchvision.models import inception_v3, Inception_V3_Weights

# Set device
device = torch.device("cpu")

# Load pre-trained InceptionV3 model (for classification logits)
from torchvision.models import inception_v3
weights = Inception_V3_Weights.DEFAULT
inception_model = inception_v3(weights=weights, aux_logits=True).to(device)
inception_model.eval()

# Function to extract features
def get_inception_features(images, model):
    if model is None:
        return None
    up = torch.nn.Upsample(size=(299, 299), mode='bilinear', align_corners=False).to(device)

    def get_pred(x):
        if next(model.parameters()).device != x.device:
            x = x.to(next(model.parameters()).device)
        x = up(x)
        return model(x)

    features = []
    with torch.no_grad():
        for i in tqdm(range(0, len(images), 64)):
            batch = images[i:i + 32].to(device)
            pred = get_pred(batch)
            if pred is not None:
                features.append(pred.cpu().numpy())
    return np.concatenate(features, axis=0)

# FID calculation function
def calculate_fid(real_features, fake_features):
    mu1, sigma1 = real_features.mean(axis=0), np.cov(real_features, rowvar=False)
    mu2, sigma2 = fake_features.mean(axis=0), np.cov(fake_features, rowvar=False)
    ssdiff = np.sum((mu1 - mu2)**2.0)
    try:
        covmean = sqrtm(sigma1.dot(sigma2))
    except Exception as e:
        print(f"Error calculating covariance mean: {e}")
        return np.nan
    if np.iscomplexobj(covmean):
        covmean = covmean.real
    fid = ssdiff + np.trace(sigma1 + sigma2 - 2.0 * covmean)
    return fid

# Parameters for fake data
num_fake_images = 5000
latent_dim = 128
num_classes = 10  # For CIFAR-10

# Generate fake images using your generator
generator.to(device)
fixed_noise = torch.randn(num_fake_images, latent_dim).to(device)
fixed_labels = torch.randint(0, num_classes, (num_fake_images,)).to(device)

with torch.no_grad():
    generated_images = generator(fixed_noise, fixed_labels).detach().cpu()
    generated_images = (generated_images * 0.5 + 0.5).clamp(0, 1)  # Scale to [0,1]

# Load real CIFAR-10 images
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
real_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)

# Get a subset of real images (same size as fake)
indices = np.random.choice(len(real_dataset), num_fake_images, replace=False)
real_images_subset = [real_dataset[i][0] for i in indices]
real_images_tensor = torch.stack(real_images_subset)
real_images_tensor = (real_images_tensor * 0.5 + 0.5).clamp(0, 1)  # Denormalize to [0,1]

# Extract features
print("Calculating Inception features for real images...")
real_features = get_inception_features(real_images_tensor.to(device), inception_model)

print("Calculating Inception features for fake images...")
fake_features = get_inception_features(generated_images.to(device), inception_model)

# Compute FID
if real_features is not None and fake_features is not None:
    fid_score = calculate_fid(real_features, fake_features)
    print(f"FID Score: {fid_score:.4f}")
else:
    print("Could not calculate FID score due to an issue with the Inception model.")


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170M/170M [00:11<00:00, 14.2MB/s] 


Extracting ./data/cifar-10-python.tar.gz to ./data
Calculating Inception features for real images...


  0%|          | 0/79 [00:00<?, ?it/s]

Calculating Inception features for fake images...


  0%|          | 0/79 [00:00<?, ?it/s]

FID Score: 231.8393


In [71]:
from torchvision.datasets import CIFAR10
from torchvision import transforms
import torch

transform = transforms.Compose([
    transforms.ToTensor(),                   # [0,1]
    transforms.Normalize((0.5, 0.5, 0.5),     # back to [-1, 1]
                         (0.5, 0.5, 0.5))
])

real_dataset = CIFAR10(root="./data", train=True, transform=transform, download=True)
real_images = torch.stack([real_dataset[i][0] for i in range(5000)])  # shape: [5000, 3, 32, 32]
generator.eval().to("cuda")
latent_dim = 128
num_classes = 10
batch_size = 64
fakes = []

with torch.no_grad():
    for _ in range(5000 // batch_size):
        z = torch.randn(batch_size, latent_dim).to("cuda")
        y = torch.randint(0, num_classes, (batch_size,), device="cuda")
        out = generator(z, y)
        fakes.append(out.cpu())  # offload to CPU to save VRAM

fakes_tensor = torch.cat(fakes, dim=0)  # [5000, 3, 32, 32]

def denorm(x):
    return (x * 0.5 + 0.5).clamp(0, 1)

real_images = denorm(real_images)
fakes_tensor = denorm(fakes_tensor)

Files already downloaded and verified


In [73]:
from torchmetrics.image.fid import FrechetInceptionDistance

fid = FrechetInceptionDistance(feature=2048, normalize=True).to("cuda")

# Feed real images
for i in range(0, 5000, batch_size):
    real_batch = real_images[i:i+batch_size].to("cuda")
    fid.update(real_batch, real=True)

# Feed fake images
for i in range(0, fakes_tensor.size(0), batch_size):
    fake_batch = fakes_tensor[i:i+batch_size]
    if fake_batch.size(0) == 0:
        continue  # skip empty batches just in case
    fid.update(fake_batch.to("cuda"), real=False)


# Compute score
score = fid.compute().item()
print(f"✅ Final FID Score: {score:.2f}")


✅ Final FID Score: 122.45


In [74]:
from torchmetrics.image.inception import InceptionScore

# Normalize fake images to [0, 1] if not already
fake_imgs = fakes_tensor.clone()  # shape: [5000, 3, 32, 32]
fake_imgs = fake_imgs.clamp(0, 1)

# Create IS object
is_metric = InceptionScore(normalize=True, splits=10).to("cuda")

# Feed fake images in batches
batch_size = 64
for i in range(0, fake_imgs.size(0), batch_size):
    batch = fake_imgs[i:i+batch_size]
    if batch.size(0) == 0:
        continue
    is_metric.update(batch.to("cuda"))

# Compute IS
score, std = is_metric.compute()
print(f"✅ Inception Score: {score:.2f} ± {std:.2f}")



✅ Inception Score: 4.04 ± 0.16


In [None]:
generator.eval().to("cuda")
latent_dim = 128
num_classes = 10
batch_size = 64
fakes = []

with torch.no_grad():
    for _ in range(5000 // batch_size):
        z = torch.randn(batch_size, latent_dim).to("cuda")
        y = torch.randint(0, num_classes, (batch_size,), device="cuda")
        out = generator(z, y)
        fakes.append(out.cpu())  # offload to CPU to save VRAM

fakes_tensor = torch.cat(fakes, dim=0)  # [5000, 3, 32, 32]

In [None]:
import torch
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import numpy as np
from scipy.linalg import sqrtm
from tqdm.notebook import tqdm
from torchvision.models import inception_v3, Inception_V3_Weights
import torch.nn.functional as F

# Set device
device = torch.device("cpu")

# Load pre-trained InceptionV3 model (for classification logits)
from torchvision.models import inception_v3
weights = Inception_V3_Weights.DEFAULT
inception_model = inception_v3(weights=weights, aux_logits=False).to(device) # Set aux_logits to False for final layer output
inception_model.eval()

# Function to extract features (logits for IS)
def get_inception_logits(images, model):
    if model is None:
        return None
    up = torch.nn.Upsample(size=(299, 299), mode='bilinear', align_corners=False).to(device)

    def get_pred(x):
        if next(model.parameters()).device != x.device:
            x = x.to(next(model.parameters()).device)
        x = up(x)
        return model(x)

    logits = []
    with torch.no_grad():
        for i in tqdm(range(0, len(images), 64)):
            batch = images[i:i + 32].to(device)
            pred = get_pred(batch)
            if pred is not None:
                logits.append(pred.cpu().numpy())
    return np.concatenate(logits, axis=0)

# Inception Score calculation function
def calculate_is(logits):
    probs = F.softmax(torch.from_numpy(logits), dim=1).numpy()
    p_y = np.mean(probs, axis=0)
    scores = []
    for i in range(probs.shape[0]):
        scores.append(np.sum(probs[i] * np.log(probs[i] / p_y)))
    mean_score = np.mean(scores)
    std_score = np.std(scores)
    return np.exp(mean_score), np.exp(std_score)

# Parameters for fake data
num_fake_images = 5000
latent_dim = 128
num_classes = 10  # For CIFAR-10 (assuming your generator is for CIFAR-10 or similar)

# Assuming your generator is already defined and accessible as 'generator'
# and it takes 'fixed_noise' and 'fixed_labels' as input if it's a conditional generator

# Generate fake images using your generator
generator.to(device)
fixed_noise = torch.randn(num_fake_images, latent_dim).to(device)
fixed_labels = torch.randint(0, num_classes, (num_fake_images,)).to(device)

with torch.no_grad():
    generated_images = generator(fixed_noise, fixed_labels).detach().cpu()
    generated_images = (generated_images + 1) / 2.0 # Scale to [0,1] if generator outputs [-1, 1]. Adjust as needed.

# Extract Inception logits for fake images
print("Calculating Inception logits for fake images...")
fake_logits = get_inception_logits(generated_images.to(device), inception_model)

# Compute Inception Score
if fake_logits is not None:
    mean_is, std_is = calculate_is(fake_logits)
    print(f"Inception Score for fake images: {mean_is:.4f} +/- {std_is:.4f}")
else:
    print("Could not calculate Inception Score due to an issue with the Inception model.")

In [62]:
!pip install clean-fid

Collecting clean-fid
  Downloading clean_fid-0.1.35-py3-none-any.whl.metadata (36 kB)
Downloading clean_fid-0.1.35-py3-none-any.whl (26 kB)
Installing collected packages: clean-fid
Successfully installed clean-fid-0.1.35


In [66]:
from cleanfid import fid
fid_score = fid.compute_fid("output_fake", dataset_name="cifar10", dataset_split="test")
print("CleanFID Score:", fid_score)

compute FID of a folder with cifar10 statistics
downloading statistics to /usr/local/lib/python3.10/dist-packages/cleanfid/stats/cifar10_clean_test_1024.npz


HTTPError: HTTP Error 404: Not Found