In [49]:
import os
import pickle
import torch
import random
import numpy as np
import pandas as pd
from PIL import Image
from collections import defaultdict
from torch.utils.data import Dataset
import torchvision.transforms as transforms
from nltk.tokenize import RegexpTokenizer

In [29]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
from torchvision.models import Inception_V3_Weights

class InceptionV3(nn.Module):
    """Pretrained InceptionV3 network returning feature maps"""
    
    BLOCK_INDEX_BY_DIM = {
        64: 0,   
        192: 1,  
        768: 2,  
        2048: 3  
    }

    def __init__(self,
                 output_blocks=[3],
                 resize_input=True,
                 normalize_input=True,
                 requires_grad=False,
                 use_pretrained=True):
        super(InceptionV3, self).__init__()
        
        self.resize_input = resize_input
        self.normalize_input = normalize_input
        self.output_blocks = sorted(output_blocks)
        self.last_needed_block = max(output_blocks)
        
        assert self.last_needed_block <= 3, 'Last possible output block index is 3'
        
        # Load InceptionV3 with proper weight handling
        weights = Inception_V3_Weights.IMAGENET1K_V1 if use_pretrained else None
        
        # Load model with compatible parameters
        kwargs = {
            'weights': weights,
            'transform_input': False,  # We'll handle normalization ourselves
            'aux_logits': True  # Must be True for pretrained weights
        }
        
        try:
            inception = models.inception_v3(**kwargs)
        except Exception as e:
            print(f"Error loading pretrained model: {e}")
            inception = models.inception_v3(weights=None, aux_logits=False)
        
        # Remove aux logits branch if present
        if hasattr(inception, 'AuxLogits'):
            inception.AuxLogits = None
        
        # Freeze model if not training
        if not requires_grad:
            for param in inception.parameters():
                param.requires_grad = False
        
        # Build feature extraction blocks
        self.blocks = nn.ModuleList()
        
        # Block 0: input to maxpool1
        block0 = [
            inception.Conv2d_1a_3x3,
            inception.Conv2d_2a_3x3,
            inception.Conv2d_2b_3x3,
            nn.MaxPool2d(kernel_size=3, stride=2)
        ]
        self.blocks.append(nn.Sequential(*block0))
        
        # Block 1: maxpool1 to maxpool2
        if self.last_needed_block >= 1:
            block1 = [
                inception.Conv2d_3b_1x1,
                inception.Conv2d_4a_3x3,
                nn.MaxPool2d(kernel_size=3, stride=2)
            ]
            self.blocks.append(nn.Sequential(*block1))
        
        # Block 2: maxpool2 to aux classifier
        if self.last_needed_block >= 2:
            block2 = [
                inception.Mixed_5b,
                inception.Mixed_5c,
                inception.Mixed_5d,
                inception.Mixed_6a,
                inception.Mixed_6b,
                inception.Mixed_6c,
                inception.Mixed_6d,
                inception.Mixed_6e,
            ]
            self.blocks.append(nn.Sequential(*block2))
        
        # Block 3: aux classifier to final avgpool
        if self.last_needed_block >= 3:
            block3 = [
                inception.Mixed_7a,
                inception.Mixed_7b,
                inception.Mixed_7c,
                nn.AdaptiveAvgPool2d(output_size=(1, 1))
            ]
            self.blocks.append(nn.Sequential(*block3))
        
        # Store normalization parameters
        self.register_buffer('mean', torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer('std', torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    def forward(self, x):
        """Extract features from input tensor"""
        features = []
        
        # Resize if needed
        if self.resize_input:
            x = F.interpolate(x, size=(299, 299), mode='bilinear', align_corners=False)
        
        # Normalize if needed
        if self.normalize_input:
            x = (x - self.mean.to(x.device)) / self.std.to(x.device)
        
        # Extract features from each block
        for block_idx, block in enumerate(self.blocks):
            x = block(x)
            if block_idx in self.output_blocks:
                features.append(x)
            if block_idx == self.last_needed_block:
                break
                
        return features

****Configurations required****

In [30]:
import argparse

def parse_args():
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # Distributed Training
    parser.add_argument('--world_size', default=-1, type=int, help='Number of nodes for distributed training')
    parser.add_argument('--rank', default=-1, type=int, help='Node rank for distributed training')
    parser.add_argument('--local_rank', default=0, type=int, help='Node rank for distributed training')
    parser.add_argument('--dist_url', default='env://', type=str, help='URL for distributed training')
    parser.add_argument('--dist_backend', default='nccl', type=str, help='Distributed backend')

    # General Settings
    parser.add_argument('--seed', default=12345, type=int, help='Seed for initializing training')
    parser.add_argument('--gpu', default='', type=str, help='GPU to use (leave blank for CPU only)')

    # Dataset and Paths
    parser.add_argument('--dataset', type=str, default='birds', help='Dataset type')
    parser.add_argument('--data_path', type=str, default='/kaggle/input/dataset-cub/CUB-200-2011', help='Base path')
    parser.add_argument('--image_dir', type=str, default='./output_birds', help='Path to save images')
    parser.add_argument('--bbox_file', type=str, default='/kaggle/input/dataset-cub/CUB-200-2011/bounding_boxes.txt', help='Bounding box file')
    parser.add_argument('--text_dir', type=str, default='/kaggle/input/dataset-birds/birds/birds/text_c10', help='Path to text descriptions')
    parser.add_argument('--train_filenames', type=str, default='/kaggle/input/dataset-cub/CUB-200-2011/train/filenames.pickle', help='Train filenames pickle')
    parser.add_argument('--test_filenames', type=str, default='/kaggle/input/dataset-cub/CUB-200-2011/test/filenames.pickle', help='Test filenames pickle')
    parser.add_argument('--c_dim', type=int, default=128, help='Text embedding dimension')
    parser.add_argument('--text_path', type=str, default="/kaggle/input/dataset-birds/birds/text_c10",
                        help='Path to the text_c10 folder containing raw captions')

    # Training Parameters
    parser.add_argument('--STAGE', type=int, default=1, help='Training stage')
    parser.add_argument('--epoch', type=int, default=200, help='Number of epochs')
    parser.add_argument('--batch_size', type=int, default=16, help='Batch size to use')
    parser.add_argument('--gener_batch_size', type=int, default=32, help='Batch size for generator')
    parser.add_argument('--dis_batch_size', type=int, default=32, help='Batch size for discriminator')
    parser.add_argument('--num_workers', type=int, default=4, help='Number of CPU threads for DataLoader')

    # Model Parameters
    parser.add_argument('--image_size', type=int, default=32, help='Size of image for discriminator input')
    parser.add_argument('--initial_size', type=int, default=8, help='Initial size for generator')
    parser.add_argument('--patch_size', type=int, default=4, help='Patch size for generated image')
    parser.add_argument('--num_classes', type=int, default=1, help='Number of classes for discriminator')
    parser.add_argument('--lr_gen', type=float, default=0.0001, help='Learning rate for generator')
    parser.add_argument('--lr_dis', type=float, default=0.0001, help='Learning rate for discriminator')
    parser.add_argument('--weight_decay', type=float, default=1e-3, help='Weight decay')
    parser.add_argument('--z_dim', type=int, default=128, help='Latent dimension')
    parser.add_argument('--n_critic', type=int, default=5, help='Number of critic updates per generator update')
    parser.add_argument('--max_iter', type=int, default=250000, help='Maximum iterations')

    # Optimization
    parser.add_argument('--optim', type=str, default="Adam", help='Optimizer')
    parser.add_argument('--loss', type=str, default="wgangp-mode", help='Loss function')
    parser.add_argument('--lr_decay', action='store_true', help='Enable learning rate decay')
    parser.add_argument('--beta1', type=float, default=0.0, help='Beta1 for Adam optimizer')
    parser.add_argument('--beta2', type=float, default=0.99, help='Beta2 for Adam optimizer')

    # Conditioning and Validation
    parser.add_argument('--Iscondtion', action='store_true', help='Use text conditioning')
    parser.add_argument('--Isval', action='store_true', help='Enable validation')

    # Pretrained Models
    parser.add_argument('--NET_G', type=str, default='', help='Path to Generator weights')
    parser.add_argument('--NET_D', type=str, default='', help='Path to Discriminator weights')
    parser.add_argument('--STAGE1_G', type=str, default='', help='Path to Stage 1 Generator weights')
    parser.add_argument('--STAGE1_D', type=str, default='', help='Path to Stage 1 Discriminator weights')

    # Snapshot & Normalization
    parser.add_argument('--SNAPSHOT_INTERVAL', type=int, default=5000, help='Interval for saving snapshots')
    parser.add_argument('--g_norm', type=str, default="ln", help='Generator Normalization')
    parser.add_argument('--g_act', type=str, default="gelu", help='Generator Activation Layer')
    parser.add_argument('--d_act', type=str, default="gelu", help='Discriminator Activation Layer')
    parser.add_argument('--d_norm', type=str, default="ln", help='Discriminator Normalization')

    # Text Encoding
    parser.add_argument('--RNN_TYPE', type=str, default='LSTM', help='Text encoding type')
    parser.add_argument('--WORDS_NUM', type=int, default=18, help='Number of words in text input')
    parser.add_argument('--CAPTIONS_PER_IMAGE', type=int, default=5, help='Captions per image')
    parser.add_argument('--EMBEDDING_DIM', type=int, default=256, help='Embedding dimension')
    parser.add_argument('--CONDITION_DIM', type=int, default=256, help='Condition dimension')
    # Evaluation Paths
    parser.add_argument('--dims', type=int, default=2048, choices=list(InceptionV3.BLOCK_INDEX_BY_DIM),
                        help='Dimensionality of Inception features to use.')

    args, unknown = parser.parse_known_args()
    return args

args = parse_args()

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np

class RNN_ENCODER(nn.Module):
    def __init__(self, ntoken, cfg):
        super(RNN_ENCODER, self).__init__()
        self.ntoken = ntoken
        self.ninput = cfg.TEXT.EMBEDDING_DIM
        self.drop_prob = cfg.TEXT.DROP_PROB
        self.nhidden = cfg.TEXT.HIDDEN_DIM
        self.nlayers = cfg.TEXT.NUM_LAYERS
        self.bidirectional = cfg.TEXT.BIDIRECTIONAL
        self.rnn_type = cfg.TEXT.RNN_TYPE.upper()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.define_module()

    def define_module(self):
        self.encoder = nn.Embedding(self.ntoken, self.ninput)
        
        rnn_cls = nn.LSTM if self.rnn_type == 'LSTM' else nn.GRU
        self.rnn = rnn_cls(
            self.ninput, self.nhidden, self.nlayers, batch_first=True,
            dropout=self.drop_prob, bidirectional=self.bidirectional
        )
        
        self.init_weights()

    def init_weights(self):
        initrange = 0.1
        self.encoder.weight.data.uniform_(-initrange, initrange)

    def forward(self, captions, cap_lens, hidden=None):
        embeddings = self.encoder(captions)
        packed = nn.utils.rnn.pack_padded_sequence(embeddings, cap_lens, batch_first=True, enforce_sorted=True)
        output, hidden = self.rnn(packed, hidden)
        output, _ = nn.utils.rnn.pad_packed_sequence(output, batch_first=True)

        if self.bidirectional:
            hidden = hidden.view(self.nlayers, 2, -1, self.nhidden)
            hidden = hidden[:, -1].contiguous().view(self.nlayers, -1, self.nhidden * 2)
        else:
            hidden = hidden.view(self.nlayers, -1, self.nhidden)

        return output, hidden.squeeze(0)


In [3]:
def custom_collate_fn(batch):
    # Filter out the invalid (empty) images in the batch
    batch = [item for item in batch if item != torch.tensor([])]

    # If the batch is empty after filtering, handle accordingly (e.g., return an empty batch)
    if not batch:
        return torch.tensor([]), torch.tensor([])

    # Standard collate for non-empty items
    return torch.stack([item[0] for item in batch]), torch.stack([item[1] for item in batch])


In [32]:
import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset
from torch.utils.data.distributed import DistributedSampler

class ImageDataset(object):
    def __init__(self, args, cur_img_size=None, is_distributed=False):
        self.imsize = [cur_img_size * (2 ** i) for i in range(args.STAGE)]
        img_size = self.imsize[args.STAGE - 1]

        if args.dataset.lower() == 'birds':  # CUB-200 Dataset
            transform = transforms.Compose([
                transforms.Resize(int(img_size * 76 / 64)),  # Resize first
                transforms.RandomCrop(img_size),  # Then crop
                transforms.RandomHorizontalFlip(),  # Data augmentation
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
            ])

            transform_test = transforms.Compose([
                transforms.Resize(int(img_size * 76 / 64)),
                transforms.CenterCrop(img_size),  # No random cropping for test
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # Matching training normalization
            ])

            # Load full dataset first
            train_dataset = BIRDS(data_dir=args.data_path, split='train', imsize=img_size, transform=transform)
            test_dataset = BIRDS(data_dir=args.data_path, split='test', imsize=img_size, transform=transform)
            
            # Handling Distributed Training
            if is_distributed:
                train_sampler = DistributedSampler(train_dataset)
                test_sampler = DistributedSampler(test_dataset)
                self.train = DataLoader(
                    train_dataset,
                    batch_size=args.dis_batch_size,
                    shuffle=False,  # Sampler handles shuffling
                    num_workers=args.num_workers,
                    pin_memory=True,
                    drop_last=True,
                    collate_fn=custom_collate_fn,
                    sampler=train_sampler
                )
                self.test = DataLoader(
                    test_dataset,
                    batch_size=args.dis_batch_size,
                    shuffle=False,
                    num_workers=args.num_workers,
                    pin_memory=True,
                    drop_last=True,
                    collate_fn=custom_collate_fn,
                    sampler=test_sampler
                )
            else:
                # Standard non-distributed DataLoader
                self.train = DataLoader(
                    train_dataset,
                    batch_size=args.dis_batch_size,
                    shuffle=True,
                    num_workers=args.num_workers,
                    pin_memory=True,
                    drop_last=True,
                    collate_fn=custom_collate_fn
                )
                
                self.test = DataLoader(
                    test_dataset,
                    batch_size=args.dis_batch_size,
                    shuffle=False,
                    num_workers=args.num_workers,
                    pin_memory=True,
                    drop_last=True,
                    collate_fn=custom_collate_fn
                )

        else:
            raise NotImplementedError(f"Dataset '{args.dataset}' is not supported yet.")


In [33]:
import os
import pickle
import torch
import numpy as np
import pandas as pd
from PIL import Image
from collections import defaultdict
from nltk.tokenize import RegexpTokenizer
from torch.utils.data import Dataset

BASE_DIR = "/kaggle/input/dataset-cub"  # Change this to your dataset folder
TRAIN_EMBEDDINGS_PATH = os.path.join(BASE_DIR, "CUB-200-2011/train/char-CNN-RNN-embeddings.pickle")
TEST_EMBEDDINGS_PATH = os.path.join(BASE_DIR, "CUB-200-2011/test/char-CNN-RNN-embeddings.pickle")

# Load configuration
args = cfg.parse_args()

def get_imgs(img_path, transform=None, bbox=None):
    img = Image.open(img_path).convert('RGB')
    if bbox is not None:
        r = int(np.maximum(bbox[2], bbox[3]) * 0.65)
        center_x = int((2 * bbox[0] + bbox[2]) / 2)
        center_y = int((2 * bbox[1] + bbox[3]) / 2)
        y1 = max(0, center_y - r)
        y2 = min(img.height, center_y + r)
        x1 = max(0, center_x - r)
        x2 = min(img.width, center_x + r)
        img = img.crop([x1, y1, x2, y2])
    return transform(img) if transform else img

class BIRDS(Dataset):
    def __init__(self, data_dir=BASE_DIR, split='train', imsize=64, transform=None):
        self.data_dir = data_dir
        self.transform = transform
        self.split = split
        self.img_size = imsize
        self.embeddings_num = args.CAPTIONS_PER_IMAGE
        
        self.filenames = self.load_filenames(split)
        self.bbox = self.load_bbox()
        self.sent_emb = self.load_embedding(split) if args.Iscondtion else None
    
    def load_filenames(self, split):
        if split == 'train':
            filepath = args.train_filenames  # Use train filenames file
        elif split == 'test':
            filepath = args.test_filenames   # Use test filenames file
        else:
            raise ValueError(f"Invalid dataset split: {split}")
        
        if not os.path.exists(filepath):
            raise FileNotFoundError(f"File not found: {filepath}")
        
        with open(filepath, 'rb') as f:
            return pickle.load(f)

    
    def load_bbox(self):
        bbox_path = os.path.join(self.data_dir, "bounding_boxes.txt")
        df_bounding_boxes = pd.read_csv(bbox_path, delim_whitespace=True, header=None).astype(int)
        
        filepath = os.path.join(self.data_dir, "images.txt")
        df_filenames = pd.read_csv(filepath, delim_whitespace=True, header=None)
        filenames = df_filenames[1].tolist()
        
        return {img_file[:-4]: df_bounding_boxes.iloc[i][1:].tolist() for i, img_file in enumerate(filenames)}
    
    def load_embedding(self, split):
        emb_path = TRAIN_EMBEDDINGS_PATH if split == 'train' else TEST_EMBEDDINGS_PATH
        with open(emb_path, 'rb') as f:
            return pickle.load(f)
    
    def __getitem__(self, index):
        key = self.filenames[index]
        bbox = self.bbox.get(key, None)
        img_path = os.path.join(self.data_dir, f"images/{key}.jpg")
        img = get_imgs(img_path, self.transform, bbox)
    
        if args.Iscondtion:
            sent_ix = np.random.randint(0, self.embeddings_num)
            sent_emb = self.sent_emb[index][sent_ix]
            return img, sent_emb  # ✅ Always return a tuple
        else:
            return img, torch.tensor([])  # ✅ Return an empty tensor for text_emb if not using conditions

    
    def __len__(self):
        return len(self.filenames)


In [41]:
import os
import torch
import numpy as np
from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
from torch.utils.data import Dataset
from scipy import linalg
from torch.nn.functional import adaptive_avg_pool2d
import torchvision.transforms as transforms
from PIL import Image

args = cfg.parse_args()

# ✅ Ensure paths are properly set
args.path1 = "/kaggle/input/dataset-cub/CUB-200-2011/images"   # ✅ Path for real images
args.path2 = "/kaggle/working/test_outputs"   # ✅ Path for generated images

def get_activations(args, images, model, dims, device, num_img, verbose=False):
        model.eval()
        pred_arr = []
    
        with torch.no_grad():
            for batch in images:
                if isinstance(batch, (list, tuple)):
                    batch = batch[0]  # In case dataset returns (img, label)
                batch = batch.to(device)
    
                pred = model(batch)[0]
                pred = adaptive_avg_pool2d(pred, output_size=(1, 1))
    
                activations = pred.cpu().data.numpy().reshape(batch.size(0), -1)
                pred_arr.append(activations)
    
                if sum(x.shape[0] for x in pred_arr) >= num_img:
                    break
    
        pred_arr = np.concatenate(pred_arr, axis=0)[:num_img]
        return pred_arr



def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """Computes Frechet Inception Distance (FID)."""
    mu1, mu2 = np.atleast_1d(mu1), np.atleast_1d(mu2)
    sigma1, sigma2 = np.atleast_2d(sigma1), np.atleast_2d(sigma2)

    diff = mu1 - mu2
    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)

    if not np.isfinite(covmean).all():
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))

    if np.iscomplexobj(covmean):
        covmean = covmean.real

    return (diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(covmean))

def calculate_activation_statistics(args, images, model, dims=2048, device='cuda', num_img=5000):
    """Compute activations, mean and covariance matrix."""
    act = get_activations(args, images, model, dims, device, num_img)
    return np.mean(act, axis=0), np.cov(act, rowvar=False)

def calculate_fid(args, test_loader, device=None, dims=2048, num_img=5000):
    """Compute FID score between real and generated images."""
    if device:
        device = torch.device(device)
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    print(f"✅ Using device: {device}")

    # ✅ Load InceptionV3 model
    block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
    model = InceptionV3([block_idx]).to(device)

    # ✅ Load real images (test set)
    transform = transforms.Compose([
        transforms.Resize((299, 299)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    
    real_dataset = Sample_evaluate(args.path1, transform=transform)
    real_loader = torch.utils.data.DataLoader(
        real_dataset, batch_size=args.gener_batch_size, shuffle=False, num_workers=4, pin_memory=True, drop_last=True
    )

    # ✅ Load generated images
    gen_dataset = Sample_evaluate(args.path2, transform=transform)
    gen_loader = torch.utils.data.DataLoader(
        gen_dataset, batch_size=args.gener_batch_size, shuffle=False, num_workers=4, pin_memory=True, drop_last=True
    )

    print("✅ Computing activations for real images...")
    mu1, sigma1 = calculate_activation_statistics(args, real_loader, model, dims, device, num_img)

    print("✅ Computing activations for generated images...")
    mu2, sigma2 = calculate_activation_statistics(args, gen_loader, model, dims, device, num_img)

    fid_value = calculate_frechet_distance(mu1, sigma1, mu2, sigma2)
    print(f"✅ Computed FID Score: {fid_value:.4f}")

    return fid_value

class Sample_evaluate(Dataset):
    def __init__(self, path, transform=None):
        """Dataset for loading images recursively from subdirectories."""
        self.transform = transform
        self.data_dir = path

        # Recursively find all image paths
        self.image_paths = []
        for root, _, files in os.walk(self.data_dir):
            for file in files:
                if file.lower().endswith(('.jpg', '.jpeg', '.png')):
                    self.image_paths.append(os.path.join(root, file))

        if not self.image_paths:
            raise RuntimeError(f"No valid image files found in directory: {self.data_dir}")

    def get_imgs(self, img_path, transform=None, bbox=None):
        try:
            img = Image.open(img_path).convert('RGB')

            if bbox is not None:
                r = int(np.maximum(bbox[2], bbox[3]) * 0.65)
                center_x = int((2 * bbox[0] + bbox[2]) / 2)
                center_y = int((2 * bbox[1] + bbox[3]) / 2)
                y1 = max(0, center_y - r)
                y2 = min(img.height, center_y + r)
                x1 = max(0, center_x - r)
                x2 = min(img.width, center_x + r)
                img = img.crop([x1, y1, x2, y2])

            return transform(img) if transform else img

        except Exception as e:
            print(f"⚠️ Warning: Failed to load image at '{img_path}' due to: {e}")
            return None

    def __getitem__(self, index):
        for _ in range(len(self.image_paths)):
            img_path = self.image_paths[index]
            img = self.get_imgs(img_path, self.transform)
            if img is not None:
                return img
            else:
                index = (index + 1) % len(self.image_paths)
        raise RuntimeError("❌ All images in dataset are invalid or failed to load.")

    def __len__(self):
        return len(self.image_paths)


# ✅ Run FID calculation
if __name__ == "__main__":
    fid_score = calculate_fid(args, test_loader=None, device="cuda", dims=2048, num_img=5000)
    print(f"🔥 Final FID Score: {fid_score:.4f}")

✅ Using device: cuda
✅ Computing activations for real images...
✅ Computing activations for generated images...
✅ Computed FID Score: 153.7975
🔥 Final FID Score: 153.7975


In [None]:
import matplotlib.pyplot as plt

def show_images(dataset, n=5):
    plt.figure(figsize=(15,3))
    for i in range(n):
        img = dataset[i].permute(1,2,0).numpy()
        img = img * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406]) # Unnormalize
        plt.subplot(1,n,i+1)
        plt.imshow(np.clip(img,0,1))
        plt.axis('off')

show_images(real_dataset)  # Your real images
show_images(gen_dataset)   # Your generated images

In [8]:
import numpy as np
import warnings
import os
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torchvision.transforms as transforms
from torchvision.models import inception_v3
from scipy.stats import entropy

warnings.filterwarnings("ignore")

# Define arguments manually (since Kaggle doesn't support config.cfg)
class Args:
    def __init__(self):
        self.gener_batch_size = 32  # Change as per available memory

args = Args()

# Inception Score Calculation
def inception_score(imgs, Evaluate_loader, cuda=True, batch_size=32, resize=False, splits=1):
    """Computes the Inception Score (IS) of the generated images"""
    N = len(imgs)
    assert batch_size > 0 and N > batch_size

    device = torch.device("cuda" if cuda and torch.cuda.is_available() else "cpu")

    # Load Inception model
    inception_model = inception_v3(weights='IMAGENET1K_V1', transform_input=False).to(device)
    inception_model.eval()

    up = nn.Upsample(size=(299, 299), mode='bilinear').to(device)
    
    def get_pred(x):
        if resize:
            x = up(x)
        x = inception_model(x)
        return F.softmax(x, dim=1).detach().cpu().numpy()

    preds = np.zeros((N, 1000))

    for i, batch in enumerate(Evaluate_loader):
        batch = batch.to(device)
        batchv = Variable(batch)
        batch_size_i = batch.size(0)
        preds[i * batch_size: i * batch_size + batch_size_i] = get_pred(batchv)

    # Compute mean KL-divergence
    split_scores = []
    for k in range(splits):
        part = preds[k * (N // splits): (k+1) * (N // splits), :]
        py = np.mean(part, axis=0)
        scores = [entropy(pyx, py) for pyx in part]
        split_scores.append(np.exp(np.mean(scores)))

    return np.mean(split_scores)

# Custom Dataset
class Sample_evaluate(Dataset):
    def __init__(self, path, transform=None):
        self.transform = transform
        self.data_dir = path
        self.images_names = os.listdir(self.data_dir)

    def get_imgs(self, img_path, transform=None):
        img = Image.open(img_path).convert('RGB')
        if transform:
            img = transform(img)
        return img

    def __getitem__(self, index):
        key = self.images_names[index]
        img_name = os.path.join(self.data_dir, key)
        return self.get_imgs(img_name, self.transform)

    def __len__(self):
        return len(self.images_names)

# Ignore Label Dataset
class IgnoreLabelDataset(Dataset):
    def __init__(self, orig):
        self.orig = orig

    def __getitem__(self, index):
        return self.orig[index]

    def __len__(self):
        return len(self.orig)

# Function to calculate IS
def calculate_IS(path):
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    
    dataset = Sample_evaluate(path, transform=transform)
    Evaluate_loader = DataLoader(dataset, batch_size=args.gener_batch_size, shuffle=False, num_workers=2, pin_memory=True, drop_last=True)
    
    print("Calculating Inception Score...")
    IS_score = inception_score(IgnoreLabelDataset(dataset), Evaluate_loader, cuda=True, batch_size=args.gener_batch_size, resize=True, splits=10)
    
    return IS_score


In [9]:
import numpy as np
import warnings
import os
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torchvision.transforms as transforms
from torchvision.models import inception_v3, Inception_V3_Weights
from scipy.stats import entropy

warnings.filterwarnings("ignore")

class Args:
    def __init__(self):
        self.gener_batch_size = 32  # Adjust based on available memory

args = Args()

class SampleDataset(Dataset):
    def __init__(self, path, transform=None):
        self.transform = transform or transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                               std=[0.229, 0.224, 0.225])
        ])
        self.image_paths = [os.path.join(path, f) for f in os.listdir(path) 
                          if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
        
    def __getitem__(self, index):
        img = Image.open(self.image_paths[index]).convert('RGB')
        return self.transform(img)
    
    def __len__(self):
        return len(self.image_paths)

def inception_score(dataset, batch_size=32, splits=10, cuda=True):
    """Computes Inception Score for a given dataset"""
    N = len(dataset)
    device = torch.device("cuda" if cuda and torch.cuda.is_available() else "cpu")
    
    # Load inception model
    model = inception_v3(weights=Inception_V3_Weights.IMAGENET1K_V1, 
                        transform_input=False).to(device)
    model.eval()
    
    # Create dataloader
    dataloader = DataLoader(dataset, batch_size=batch_size)
    
    # Get predictions
    preds = []
    with torch.no_grad():
        for batch in dataloader:
            batch = batch.to(device)
            outputs = model(batch)
            preds.append(F.softmax(outputs, dim=1).cpu().numpy())
    
    preds = np.concatenate(preds, axis=0)
    
    # Compute IS
    scores = []
    for k in range(splits):
        part = preds[k * (N // splits): (k+1) * (N // splits), :]
        py = np.mean(part, axis=0)
        scores.append(np.exp(np.mean([entropy(p, py) for p in part])))
    
    return np.mean(scores), np.std(scores)

def calculate_inception_score(image_dir):
    transform = transforms.Compose([
        transforms.Resize(299),
        transforms.CenterCrop(299),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                           std=[0.229, 0.224, 0.225])
    ])
    
    dataset = SampleDataset(image_dir, transform=transform)
    mean, std = inception_score(dataset, batch_size=args.gener_batch_size)
    print(f"Inception Score: {mean:.2f} ± {std:.2f}")
    return mean

In [10]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import models
from torch.autograd import Variable
import torch.utils.data
from torchvision.models.inception import inception_v3
import numpy as np
from scipy.stats import entropy
args = cfg.parse_args()

try:
    from torchvision.models.utils import load_state_dict_from_url
except ImportError:
    from torch.utils.model_zoo import load_url as load_state_dict_from_url



FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth'  # noqa: E501


class InceptionV3(nn.Module):
    DEFAULT_BLOCK_INDEX = 3

    # Maps feature dimensionality to their output blocks indices
    BLOCK_INDEX_BY_DIM = {
        64: 0,   # First max pooling features
        192: 1,  # Second max pooling featurs
        768: 2,  # Pre-aux classifier features
        2048: 3  # Final average pooling features
    }

    def __init__(self,
                 output_blocks=(DEFAULT_BLOCK_INDEX,),
                 resize_input=True,
                 normalize_input=True,
                 requires_grad=False,
                 use_fid_inception=True):
     
        super(InceptionV3, self).__init__()

        self.resize_input = resize_input
        self.normalize_input = normalize_input
        self.output_blocks = sorted(output_blocks)
        self.last_needed_block = max(output_blocks)

        assert self.last_needed_block <= 3, \
            'Last possible output block index is 3'

        self.blocks = nn.ModuleList()

        if use_fid_inception:
            inception = fid_inception_v3()
        else:
            inception = _inception_v3(pretrained=True)

        # Block 0: input to maxpool1
        block0 = [
            inception.Conv2d_1a_3x3,
            inception.Conv2d_2a_3x3,
            inception.Conv2d_2b_3x3,
            nn.MaxPool2d(kernel_size=3, stride=2)
        ]
        self.blocks.append(nn.Sequential(*block0))

        # Block 1: maxpool1 to maxpool2
        if self.last_needed_block >= 1:
            block1 = [
                inception.Conv2d_3b_1x1,
                inception.Conv2d_4a_3x3,
                nn.MaxPool2d(kernel_size=3, stride=2)
            ]
            self.blocks.append(nn.Sequential(*block1))

        # Block 2: maxpool2 to aux classifier
        if self.last_needed_block >= 2:
            block2 = [
                inception.Mixed_5b,
                inception.Mixed_5c,
                inception.Mixed_5d,
                inception.Mixed_6a,
                inception.Mixed_6b,
                inception.Mixed_6c,
                inception.Mixed_6d,
                inception.Mixed_6e,
            ]
            self.blocks.append(nn.Sequential(*block2))


        # Block 3: aux classifier to final avgpool
        if self.last_needed_block >= 3:
            block3 = [
                inception.Mixed_7a,
                inception.Mixed_7b,
                inception.Mixed_7c,
                nn.AdaptiveAvgPool2d(output_size=(1, 1))
            ]
            self.blocks.append(nn.Sequential(*block3))

        for param in self.parameters():
            param.requires_grad = requires_grad

    def forward(self, inp):
       
        outp = []
        x = inp

        if self.resize_input:
            x = F.interpolate(x,
                              size=(299, 299),
                              mode='bilinear',
                              align_corners=False)

        if self.normalize_input:
            x = 2 * x - 1  # Scale from range (0, 1) to range (-1, 1)

        for idx, block in enumerate(self.blocks):
            x = block(x)
            if idx in self.output_blocks:
                outp.append(x)

            if idx == self.last_needed_block:
                break

        return outp


def fid_inception_v3():
    inception = models.inception_v3(num_classes=1008,
                              aux_logits=False,
                              pretrained=False)
    inception.Mixed_5b = FIDInceptionA(192, pool_features=32)
    inception.Mixed_5c = FIDInceptionA(256, pool_features=64)
    inception.Mixed_5d = FIDInceptionA(288, pool_features=64)
    inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128)
    inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160)
    inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160)
    inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192)
    inception.Mixed_7b = FIDInceptionE_1(1280)
    inception.Mixed_7c = FIDInceptionE_2(2048)

    #state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True)
    state_dict = \
        torch.load(args.inception,map_location=lambda storage, loc: storage)
    inception.load_state_dict(state_dict)
    return inception

class FIDInceptionA(torchvision.models.inception.InceptionA):
    """InceptionA block patched for FID computation"""
    def __init__(self, in_channels, pool_features):
        super(FIDInceptionA, self).__init__(in_channels, pool_features)

    def forward(self, x):
        branch1x1 = self.branch1x1(x)

        branch5x5 = self.branch5x5_1(x)
        branch5x5 = self.branch5x5_2(branch5x5)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)

        # Patch: Tensorflow's average pool does not use the padded zero's in
        # its average calculation
        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
                                   count_include_pad=False)
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
        return torch.cat(outputs, 1)


class FIDInceptionC(torchvision.models.inception.InceptionC):
    """InceptionC block patched for FID computation"""
    def __init__(self, in_channels, channels_7x7):
        super(FIDInceptionC, self).__init__(in_channels, channels_7x7)

    def forward(self, x):
        branch1x1 = self.branch1x1(x)

        branch7x7 = self.branch7x7_1(x)
        branch7x7 = self.branch7x7_2(branch7x7)
        branch7x7 = self.branch7x7_3(branch7x7)

        branch7x7dbl = self.branch7x7dbl_1(x)
        branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)

        # Patch: Tensorflow's average pool does not use the padded zero's in
        # its average calculation
        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
                                   count_include_pad=False)
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
        return torch.cat(outputs, 1)

class FIDInceptionE_1(torchvision.models.inception.InceptionE):
    """First InceptionE block patched for FID computation"""
    def __init__(self, in_channels):
        super(FIDInceptionE_1, self).__init__(in_channels)

    def forward(self, x):
        branch1x1 = self.branch1x1(x)

        branch3x3 = self.branch3x3_1(x)
        branch3x3 = [
            self.branch3x3_2a(branch3x3),
            self.branch3x3_2b(branch3x3),
        ]
        branch3x3 = torch.cat(branch3x3, 1)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = [
            self.branch3x3dbl_3a(branch3x3dbl),
            self.branch3x3dbl_3b(branch3x3dbl),
        ]
        branch3x3dbl = torch.cat(branch3x3dbl, 1)

        # Patch: Tensorflow's average pool does not use the padded zero's in
        # its average calculation
        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
                                   count_include_pad=False)
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
        return torch.cat(outputs, 1)


class FIDInceptionE_2(torchvision.models.inception.InceptionE):
    """Second InceptionE block patched for FID computation"""
    def __init__(self, in_channels):
        super(FIDInceptionE_2, self).__init__(in_channels)

    def forward(self, x):
        branch1x1 = self.branch1x1(x)

        branch3x3 = self.branch3x3_1(x)
        branch3x3 = [
            self.branch3x3_2a(branch3x3),
            self.branch3x3_2b(branch3x3),
        ]
        branch3x3 = torch.cat(branch3x3, 1)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = [
            self.branch3x3dbl_3a(branch3x3dbl),
            self.branch3x3dbl_3b(branch3x3dbl),
        ]
        branch3x3dbl = torch.cat(branch3x3dbl, 1)

        # Patch: The FID Inception model uses max pooling instead of average
        # pooling. This is likely an error in this specific Inception
        # implementation, as other Inception models use average pooling here
        # (which matches the description in the paper).
        branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
        return torch.cat(outputs, 1)

def inception_score(imgs,Evaluate_loader, cuda=True, batch_size=32, resize=False, splits=1):
    """Computes the inception score of the generated images imgs
    imgs -- Torch dataset of (3xHxW) numpy images normalized in the range [-1, 1]
    cuda -- whether or not to run on GPU
    batch_size -- batch size for feeding into Inception v3
    splits -- number of splits
    """
    N = len(imgs)

    assert batch_size > 0
    assert N > batch_size

    # Set up dtype
    if cuda:
        dtype = torch.cuda.FloatTensor
    else:
        if torch.cuda.is_available():
            print("WARNING: You have a CUDA device, so you should probably set cuda=True")
        dtype = torch.FloatTensor

    # Set up dataloader
    dataloader = Evaluate_loader

    # Load inception model
    inception_model = inception_v3(pretrained=True, transform_input=False).type(dtype)
    inception_model.eval();
    up = nn.Upsample(size=(299, 299), mode='bilinear').type(dtype)
    def get_pred(x):
        if resize:
            x = up(x)
        x = inception_model(x)
        preds = F.softmax(x, dim=1).data.cpu().numpy()
        # Ensure no NaN or Inf values
        if np.any(np.isnan(preds)) or np.any(np.isinf(preds)):
            print("Warning: NaN or Inf in prediction!")
            preds = np.zeros_like(preds)
        return preds


    # Get predictions
    preds = np.zeros((N, 1000))

    for i, batch in enumerate(dataloader):
        batch = batch.type(dtype)
        batchv = Variable(batch)
        batch_size_i = batch.size()[0]

        preds[i*batch_size:i*batch_size + batch_size_i] = get_pred(batchv)

    # Now compute the mean kl-div
    split_scores = []

    for k in range(splits):
        part = preds[k * (N // splits): (k+1) * (N // splits), :]
        py = np.mean(part, axis=0)
        scores = []
        for i in range(part.shape[0]):
            pyx = part[i, :]
            scores.append(entropy(pyx, py))
        split_scores.append(np.exp(np.mean(scores)))

    return np.mean(split_scores), np.std(split_scores)


if __name__ == '__main__':
    class IgnoreLabelDataset(torch.utils.data.Dataset):
        def __init__(self, orig):
            self.orig = orig

        def __getitem__(self, index):
            return self.orig[index][0]

        def __len__(self):
            return len(self.orig)

    import torchvision.datasets as dset
    import torchvision.transforms as transforms


    transform = transforms.Compose([
        transforms.Resize((128,128)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    data_dir = "/kaggle/working/test_outputs"
    dataset = Sample_evaluate(data_dir, transform=transform)  # Pass the correct path to the images

    # Create DataLoader for evaluation
    Evaluate_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=8, shuffle=True, num_workers=0, pin_memory=True, drop_last=True
    )
    
    # Wrap dataset for Inception Score computation
    IgnoreLabelDataset(dataset)
    
    # Compute Inception Score
    print("Calculating Inception Score...")
    print(inception_score(IgnoreLabelDataset(dataset), Evaluate_loader, cuda=True, batch_size=8, resize=True, splits=10))

In [11]:
import os
import numpy as np
from scipy.stats import entropy
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision import models, transforms
from torchvision.models.inception import inception_v3
from PIL import Image

# Configuration class
class Args:
    def __init__(self):
        self.gener_batch_size = 32
        self.inception = None  # Path to FID weights if needed

args = Args()

# Dataset class that handles subdirectories
class Sample_evaluate(Dataset):
    def __init__(self, root_dir, transform=None):
        self.transform = transform or transforms.Compose([
            transforms.Resize((299, 299)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                               std=[0.229, 0.224, 0.225])
        ])
        self.image_paths = []
        
        # Recursively find all images in subdirectories
        for dirpath, _, filenames in os.walk(root_dir):
            for filename in filenames:
                if filename.lower().endswith(('.jpg', '.jpeg', '.png')):
                    self.image_paths.append(os.path.join(dirpath, filename))
        
    def __getitem__(self, index):
        try:
            img = Image.open(self.image_paths[index]).convert('RGB')
            if self.transform is not None:
                img = self.transform(img)
            return img
        except Exception as e:
            print(f"Error loading image {self.image_paths[index]}: {e}")
            # Return blank image if loading fails
            return torch.zeros(3, 299, 299)
        
    def __len__(self):
        return len(self.image_paths)

# Dataset wrapper to ignore labels
class IgnoreLabelDataset(Dataset):
    def __init__(self, orig):
        self.orig = orig

    def __getitem__(self, index):
        return self.orig[index]

    def __len__(self):
        return len(self.orig)

# FID-specific InceptionV3 implementation
class FIDInceptionA(torchvision.models.inception.InceptionA):
    """InceptionA block patched for FID computation"""
    def __init__(self, in_channels, pool_features):
        super(FIDInceptionA, self).__init__(in_channels, pool_features)

    def forward(self, x):
        branch1x1 = self.branch1x1(x)
        branch5x5 = self.branch5x5_1(x)
        branch5x5 = self.branch5x5_2(branch5x5)
        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
                                 count_include_pad=False)
        branch_pool = self.branch_pool(branch_pool)
        outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
        return torch.cat(outputs, 1)

class FIDInceptionC(torchvision.models.inception.InceptionC):
    """InceptionC block patched for FID computation"""
    def __init__(self, in_channels, channels_7x7):
        super(FIDInceptionC, self).__init__(in_channels, channels_7x7)

    def forward(self, x):
        branch1x1 = self.branch1x1(x)
        branch7x7 = self.branch7x7_1(x)
        branch7x7 = self.branch7x7_2(branch7x7)
        branch7x7 = self.branch7x7_3(branch7x7)
        branch7x7dbl = self.branch7x7dbl_1(x)
        branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)
        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
                                 count_include_pad=False)
        branch_pool = self.branch_pool(branch_pool)
        outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
        return torch.cat(outputs, 1)

class FIDInceptionE_1(torchvision.models.inception.InceptionE):
    """First InceptionE block patched for FID computation"""
    def __init__(self, in_channels):
        super(FIDInceptionE_1, self).__init__(in_channels)

    def forward(self, x):
        branch1x1 = self.branch1x1(x)
        branch3x3 = self.branch3x3_1(x)
        branch3x3 = [
            self.branch3x3_2a(branch3x3),
            self.branch3x3_2b(branch3x3),
        ]
        branch3x3 = torch.cat(branch3x3, 1)
        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = [
            self.branch3x3dbl_3a(branch3x3dbl),
            self.branch3x3dbl_3b(branch3x3dbl),
        ]
        branch3x3dbl = torch.cat(branch3x3dbl, 1)
        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
                                 count_include_pad=False)
        branch_pool = self.branch_pool(branch_pool)
        outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
        return torch.cat(outputs, 1)

class FIDInceptionE_2(torchvision.models.inception.InceptionE):
    """Second InceptionE block patched for FID computation"""
    def __init__(self, in_channels):
        super(FIDInceptionE_2, self).__init__(in_channels)

    def forward(self, x):
        branch1x1 = self.branch1x1(x)
        branch3x3 = self.branch3x3_1(x)
        branch3x3 = [
            self.branch3x3_2a(branch3x3),
            self.branch3x3_2b(branch3x3),
        ]
        branch3x3 = torch.cat(branch3x3, 1)
        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = [
            self.branch3x3dbl_3a(branch3x3dbl),
            self.branch3x3dbl_3b(branch3x3dbl),
        ]
        branch3x3dbl = torch.cat(branch3x3dbl, 1)
        branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
        branch_pool = self.branch_pool(branch_pool)
        outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
        return torch.cat(outputs, 1)

def fid_inception_v3():
    """Build pretrained Inception model for FID computation"""
    inception = models.inception_v3(num_classes=1008,
                                  aux_logits=False,
                                  pretrained=False)
    inception.Mixed_5b = FIDInceptionA(192, pool_features=32)
    inception.Mixed_5c = FIDInceptionA(256, pool_features=64)
    inception.Mixed_5d = FIDInceptionA(288, pool_features=64)
    inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128)
    inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160)
    inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160)
    inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192)
    inception.Mixed_7b = FIDInceptionE_1(1280)
    inception.Mixed_7c = FIDInceptionE_2(2048)

    if args.inception:
        state_dict = torch.load(args.inception, 
                              map_location=lambda storage, loc: storage)
    else:
        # Load default pretrained weights
        state_dict = models.inception_v3(weights=models.Inception_V3_Weights.IMAGENET1K_V1).state_dict()
    
    inception.load_state_dict(state_dict)
    return inception

def inception_score(imgs, cuda=True, batch_size=32, resize=False, splits=10):
    """Computes the inception score of the generated images imgs
    
    Args:
        imgs: Torch dataset of (3xHxW) numpy images normalized in [-1, 1]
        cuda: whether or not to run on GPU
        batch_size: batch size for feeding into Inception v3
        splits: number of splits
    Returns:
        Mean and std of inception score
    """
    N = len(imgs)
    
    if N < batch_size:
        print(f"Warning: Increasing batch size from {batch_size} to {N} because not enough samples")
        batch_size = N

    # Set up device
    device = torch.device("cuda" if cuda and torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Load inception model
    inception_model = inception_v3(pretrained=True, transform_input=False).to(device)
    inception_model.eval()
    
    # Upsampler if needed
    up = nn.Upsample(size=(299, 299), mode='bilinear', align_corners=False).to(device)

    def get_pred(x):
        if resize:
            x = up(x)
        x = inception_model(x)
        return F.softmax(x, dim=1).data.cpu().numpy()

    # Get predictions
    preds = np.zeros((N, 1000))
    dataloader = DataLoader(imgs, batch_size=batch_size)

    for i, batch in enumerate(dataloader):
        batch = batch.to(device)
        batch_size_i = batch.size()[0]

        with torch.no_grad():
            preds[i*batch_size:i*batch_size + batch_size_i] = get_pred(batch)

    # Now compute the mean kl-div
    split_scores = []
    eps = 1e-16  # Small constant for numerical stability

    for k in range(splits):
        part = preds[k * (N // splits): (k+1) * (N // splits), :]
        py = np.mean(part, axis=0)
        scores = []
        
        for i in range(part.shape[0]):
            pyx = part[i, :]
            # Add epsilon to avoid log(0)
            pyx = np.maximum(pyx, eps)
            py = np.maximum(py, eps)
            scores.append(entropy(pyx, py))
        
        # Check for invalid values before exponentiation
        if len(scores) > 0:
            mean_score = np.mean(scores)
            if not np.isnan(mean_score) and not np.isinf(mean_score):
                split_scores.append(np.exp(mean_score))
            else:
                print(f"Warning: Invalid score in split {k}")

    if not split_scores:  # If all splits had issues
        return float('nan'), float('nan')
    
    return np.mean(split_scores), np.std(split_scores)

if __name__ == '__main__':
    # Set up transforms - must match Inception's expected input
    transform = transforms.Compose([
        transforms.Resize((299, 299)),  # Inception expects 299x299
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],  # ImageNet stats
                           std=[0.229, 0.224, 0.225])
    ])
    
    # Path to your images (with subdirectories for classes)
    data_dir = "/kaggle/working/test_outputs"
    
    # Create dataset
    dataset = Sample_evaluate(data_dir, transform=transform)
    
    # Wrap dataset to ignore labels
    ignore_label_dataset = IgnoreLabelDataset(dataset)
    
    # Check if we have enough images
    if len(ignore_label_dataset) == 0:
        print("Error: No images found in directory")
    else:
        print(f"Found {len(ignore_label_dataset)} images")
        
        # Compute Inception Score
        print("Calculating Inception Score...")
        mean, std = inception_score(
            ignore_label_dataset,
            cuda=True,
            batch_size=args.gener_batch_size,
            resize=False,  # We already resized to 299x299
            splits=10
        )
        
        print(f"Inception Score: {mean:.2f} ± {std:.2f}")

Error: No images found in directory


In [12]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from torchvision.models import Inception_V3_Weights
from scipy.stats import entropy
from PIL import Image

class Args:
    def __init__(self):
        self.gener_batch_size = 32
        self.inception = None  # Path to FID weights if needed

args = Args()

class Sample_evaluate(Dataset):
    def __init__(self, root_dir, transform=None):
        self.transform = transform or transforms.Compose([
            transforms.Resize((299, 299)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                               std=[0.229, 0.224, 0.225])
        ])
        self.image_paths = []
        
        # Walk through all subdirectories to find images
        for subdir, _, files in os.walk(root_dir):
            for file in files:
                if file.lower().endswith(('.jpg', '.jpeg', '.png')):
                    self.image_paths.append(os.path.join(subdir, file))
        
    def __getitem__(self, index):
        try:
            img = Image.open(self.image_paths[index]).convert('RGB')
            return self.transform(img)
        except Exception as e:
            print(f"Error loading image {self.image_paths[index]}: {e}")
            # Return blank image if loading fails
            return torch.zeros(3, 299, 299)
        
    def __len__(self):
        return len(self.image_paths)

class IgnoreLabelDataset(Dataset):
    def __init__(self, orig):
        self.orig = orig

    def __getitem__(self, index):
        return self.orig[index]

    def __len__(self):
        return len(self.orig)

def inception_score(dataset, batch_size=32, splits=10, cuda=True):
    """Compute Inception Score for a dataset"""
    N = len(dataset)
    device = torch.device("cuda" if cuda and torch.cuda.is_available() else "cpu")
    
    # Load inception model
    model = models.inception_v3(weights=Inception_V3_Weights.IMAGENET1K_V1,
                               transform_input=False).to(device)
    model.eval()
    
    # Get predictions
    dataloader = DataLoader(dataset, batch_size=batch_size)
    preds = []
    
    with torch.no_grad():
        for batch in dataloader:
            batch = batch.to(device)
            outputs = model(batch)
            probs = F.softmax(outputs, dim=1)
            preds.append(probs.cpu().numpy())
    
    preds = np.concatenate(preds, axis=0)
    
    # Calculate IS with numerical stability
    scores = []
    eps = 1e-16  # Small constant to avoid log(0)
    
    for k in range(splits):
        part = preds[k * (N // splits): (k+1) * (N // splits), :]
        py = np.mean(part, axis=0) + eps
        part = part + eps
        
        # Calculate KL divergence
        kl = part * (np.log(part) - np.log(py))
        kl = np.sum(kl, axis=1)
        scores.append(np.exp(np.mean(kl)))
    
    return np.mean(scores), np.std(scores)

def calculate_inception_score(image_dir):
    """Main function to calculate Inception Score"""
    transform = transforms.Compose([
        transforms.Resize((299, 299)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                           std=[0.229, 0.224, 0.225])
    ])
    
    try:
        dataset = Sample_evaluate(image_dir, transform=transform)
        if len(dataset) == 0:
            raise ValueError("No valid images found in directory")
        
        # Wrap dataset to ignore labels (if any)
        dataset = IgnoreLabelDataset(dataset)
        
        mean, std = inception_score(
            dataset,
            batch_size=args.gener_batch_size,
            splits=10,
            cuda=True
        )
        print(f"Inception Score: {mean:.2f} ± {std:.2f}")
        return mean
        
    except Exception as e:
        print(f"Error calculating Inception Score: {e}")
        return float('nan')

if __name__ == "__main__":
    # Point to the root directory containing class subdirectories
    image_dir = "/kaggle/input/dataset-cub/CUB-200-2011/images"  
    score = calculate_inception_score(image_dir)
    print(f"Final Inception Score: {score:.2f}")

In [13]:
import torch
import torch.nn as nn
import numpy as np

def compute_discriminator_loss(args, netD, real_imgs, fake_imgs, conditions=None, dis_mode='full'):
    device = real_imgs.device
    criterion = nn.BCELoss()
    batch_size = real_imgs.size(0)

    real_label = torch.full((batch_size,), 1., dtype=torch.float, device=device)
    fake_label = torch.full((batch_size,), 0., dtype=torch.float, device=device)

    if getattr(args, "Iscondtion", False):  # Ensure args.Iscondtion exists
        # Real pairs
        real_logits = netD(real_imgs, conditions, dis_mode).view(-1)
        errD_real = criterion(real_logits, real_label)

        # Wrong pairs (shifted conditions)
        wrong_logits = netD(real_imgs[:batch_size - 1], conditions[1:]).view(-1)
        errD_wrong = criterion(wrong_logits, fake_label[1:])

        # Fake pairs
        fake_logits = netD(fake_imgs, conditions, dis_mode).view(-1)
        errD_fake = criterion(fake_logits, fake_label)

        errD = errD_real + (errD_fake + errD_wrong) * 0.5
        return errD, errD_real, errD_wrong, errD_fake
    else:
        real_logits = netD(real_imgs, dis_mode).view(-1)
        errD_real = criterion(real_logits, real_label)

        fake_logits = netD(fake_imgs, dis_mode).view(-1)
        errD_fake = criterion(fake_logits, fake_label)

        errD = errD_real + errD_fake
        return errD, errD_real, errD_fake

def compute_generator_loss(args, netD, fake_imgs, conditions=None, dis_mode='full'):
    device = fake_imgs.device
    real_label = torch.full((fake_imgs.shape[0],), 1., dtype=torch.float, device=device)
    criterion = nn.BCELoss()

    if getattr(args, "Iscondtion", False):  # Conditional
        cond = conditions.detach()
        fake_logits = netD(fake_imgs, cond, dis_mode).view(-1)
    else:
        fake_logits = netD(fake_imgs, dis_mode).view(-1)

    errG = criterion(fake_logits, real_label)
    return errG


In [14]:
import torch
from torch import nn
import math
import warnings
from itertools import repeat
from collections.abc import Iterable  # Replaces deprecated torch._six

# DropPath (Stochastic Depth)
def drop_path(x, drop_prob: float = 0., training: bool = False):
    """Drop paths (Stochastic Depth) per sample (for residual blocks)."""
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # Work with different tensor dims
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # Binarize
    return x.div(keep_prob) * random_tensor

class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) for residual blocks."""
    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)

# Function to handle different tuple sizes
def _ntuple(n):
    def parse(x):
        if isinstance(x, Iterable):
            return x
        return tuple(repeat(x, n))
    return parse

to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)

# Truncated Normal Distribution
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    """Fills tensor with values from a truncated normal distribution."""
    def norm_cdf(x):
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn("Mean is more than 2 std from [a, b]. The distribution may be incorrect.", stacklevel=2)

    with torch.no_grad():
        l, u = norm_cdf((a - mean) / std), norm_cdf((b - mean) / std)
        tensor.uniform_(2 * l - 1, 2 * u - 1).erfinv_()
        tensor.mul_(std * math.sqrt(2.)).add_(mean)
        tensor.clamp_(min=a, max=b)
        return tensor

def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    """Truncated normal distribution initialization."""
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)

# Example usage (Colab-ready)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
x = torch.randn(10, 10).to(device)
drop_layer = DropPath(0.5).to(device)

print("Original Tensor:\n", x)
print("After DropPath:\n", drop_layer(x))


Original Tensor:
 tensor([[-0.1655,  0.2034,  0.7755,  1.2473,  0.5719,  0.8826,  0.7003, -0.3900,
          0.7372, -2.3263],
        [ 0.4776, -0.6503,  0.9010, -0.5001, -1.0648, -1.3095, -0.0838,  0.0549,
          0.3487, -1.1700],
        [-0.1974,  0.7314, -1.0342,  1.2728,  0.4853, -1.4157,  0.8394, -0.5128,
         -0.3210, -0.5561],
        [ 0.5385, -1.6113, -1.0079,  1.1051, -0.1755, -0.2883,  1.6926, -1.1681,
         -0.5770,  1.1492],
        [-0.0103,  0.6808, -0.4248,  0.4011, -0.7257, -0.1550, -0.2028, -0.3032,
         -0.6097, -0.4261],
        [-1.7541, -0.3349,  0.4381,  0.4129, -0.0070,  1.7640,  1.5498,  0.1098,
          0.0137, -0.1646],
        [-0.1620,  0.2261,  0.3850,  0.3997, -0.4631,  0.6551, -0.4844,  1.0146,
         -0.6451, -0.0915],
        [-0.3202,  1.7779, -0.7696,  0.0527,  0.8011, -0.7848,  0.8314,  1.6104,
          0.8489, -0.3973],
        [-1.4742,  0.2968, -1.5962,  0.9027,  0.7020,  0.8075,  0.7069, -0.2810,
          1.8428,  0.6330],
 

In [15]:
import math
import torch
import torch.nn as nn
from pdb import set_trace as stx


def UpSampling(x, H, W):
    B, N, C = x.size()
    
    # Check if the number of channels is divisible by 4 (as PixelShuffle requires this)
    assert C % 4 == 0, "Number of channels must be divisible by 4 for PixelShuffle"
    
    # Check if the input size matches the expected height and width
    assert N == H * W, "The number of tokens (N) must match the height * width (H * W)"
    
    # Permute and reshape the tensor to match PixelShuffle input format
    x = x.permute(0, 2, 1).view(-1, C, H, W)
    
    # Apply PixelShuffle for upsampling
    x = nn.PixelShuffle(2)(x)  # This upsamples spatial dimensions by a factor of 2
    
    # Get the new height and width
    _, C, H, W = x.size()  # Update the H, W after PixelShuffle

    # Reshape and permute the tensor back to match the original format
    x = x.view(-1, C, H * W).permute(0, 2, 1)
    
    return x, H, W

class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = CustomAct(act_layer)
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class matmul(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x1, x2):
        x = x1 @ x2
        return x


class WindowAttention(nn.Module):
    r""" Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both of shifted and non-shifted window.

    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):

        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """
        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        """
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)

            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x



In [16]:
class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., window_size=16):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.mat = matmul()
        self.window_size = window_size
        if self.window_size != 0:
            self.relative_position_bias_table = nn.Parameter(
                torch.zeros((2 * window_size - 1) * (2 * window_size - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH

            # get pair-wise relative position index for each token inside the window
            coords_h = torch.arange(window_size)
            coords_w = torch.arange(window_size)
            coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
            coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
            relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
            relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
            relative_coords[:, :, 0] += window_size - 1  # shift to start from 0
            relative_coords[:, :, 1] += window_size - 1
            relative_coords[:, :, 0] *= 2 * window_size - 1
            relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
            self.register_buffer("relative_position_index", relative_position_index)

            trunc_normal_(self.relative_position_bias_table, std=.02)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
        attn = (self.mat(q, k.transpose(-2, -1))) * self.scale
        if self.window_size != 0:
            relative_position_bias = self.relative_position_bias_table[
                self.relative_position_index.view(-1).clone()].view(
                self.window_size * self.window_size, self.window_size * self.window_size, -1)  # Wh*Ww,Wh*Ww,nH
            relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
            attn = attn + relative_position_bias.unsqueeze(0)

        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        x = self.mat(attn, v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


In [17]:
class Block(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, window_size=16):
        super().__init__()
        self.norm1 = CustomNorm(norm_layer, dim)
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop,
            window_size=window_size)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = CustomNorm(norm_layer, dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x):
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


class SwinBlock(nn.Module):
    def __init__(self, dim, input_resolution, num_heads, window_size=16, shift_size=8,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        if min(self.input_resolution) <= self.window_size:
            # if window size is larger than input resolution, we don't partition windows
            self.shift_size = 0
            self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        if self.shift_size > 0:
            # calculate attention mask for SW-MSA
            H, W = self.input_resolution

            img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
            h_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            w_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt += 1
            # Unpack the returned tuple
            mask_windows, B, N = window_partition(img_mask, self.window_size)
            
            # Apply view operation only on the tensor
            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)

            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        else:
            attn_mask = None
        self.register_buffer("attn_mask", attn_mask)
    def forward(self, x):
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)
        # cyclic shift
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x
        # partition windows
        x_windows, B, N = window_partition(shifted_x, self.window_size)  # ✅ Unpack the tuple correctly
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # ✅ Now works correctly

        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C
        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C
        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x
        x = x.view(B, H * W, C)
        # FFN
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


In [18]:
class TransformerEncoder(nn.Module):
    def __init__(self, depth, dim, heads=4, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, window_size=16):
        super().__init__()
        self.depth = depth
        models = [Block(
            dim=dim,
            num_heads=heads,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            drop=drop,
            attn_drop=attn_drop,
            drop_path=drop_path,
            act_layer=act_layer,
            norm_layer=norm_layer,
            window_size=window_size
        ) for i in range(depth)]
        self.block = nn.Sequential(*models)

    def forward(self, x):
        x = self.block(x)
        return x


class SwinTransformerEncoder(nn.Module):
    def __init__(self, depth, dim, input_resolution, heads=4, window_size=16, shift_size=8, mlp_ratio=4.,
                 qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.depth = depth
        models = [SwinBlock(
            dim=dim,
            input_resolution=input_resolution,
            num_heads=heads,
            window_size=window_size,
            shift_size=0 if (i % 2 == 0) else window_size // 2,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            drop=drop,
            attn_drop=attn_drop,
            drop_path=drop_path,
            act_layer=act_layer,
            norm_layer=norm_layer,
        ) for i in range(depth)]
        self.block = nn.Sequential(*models)

    def forward(self, x):
        x = self.block(x)
        return x


def bicubic_upsample(x, H, W):
    B, N, C = x.size()
    assert N == H * W
    x = x.permute(0, 2, 1)
    x = x.view(-1, C, H, W)
    x = nn.functional.interpolate(x, scale_factor=2, mode='bicubic')
    B, C, H, W = x.size()
    x = x.view(-1, C, H * W)
    x = x.permute(0, 2, 1)
    return x, H, W


def window_partition(x, window_size):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size
    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    x = x.permute(0, 2, 3, 1)  # Convert shape to (B, H, W, C)
    B, H, W, C = x.shape

    # ✅ Ensure H and W are valid for partitioning
    pad_h = (window_size - H % window_size) % window_size
    pad_w = (window_size - W % window_size) % window_size

    x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h), mode="constant", value=0)  # Pad with zeros
    H, W = x.shape[1], x.shape[2]  # Update new H and W

    # ✅ Ensure valid shape
    assert H > 0 and W > 0, f"Invalid H={H}, W={W} after padding"
    assert H % window_size == 0 and W % window_size == 0, "H or W is not divisible by window_size"

    # ✅ Partition the window
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    x = x.permute(0, 1, 3, 5, 2, 4).contiguous().view(-1, C, window_size, window_size)

    return x, B, H // window_size  # Return the partitioned tensor and dimensions


def window_reverse(x, window_size, H, W):
    """
    Args:
        x: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of original image
        W (int): Width of original image
    Returns:
        x: (B, H, W, C)
    """
    BNN, C, H_W, W_W = x.shape  # Extract batch info

    N = H // window_size  # Number of windows per row/column

    # ✅ Correct reshaping
    x = x.view(BNN // (N * N), N, N, C, window_size, window_size)
    x = x.permute(0, 3, 1, 4, 2, 5).contiguous()
    x = x.view(-1, C, H, W)  # Restore original shape

    return x



class CustomNorm(nn.Module):
    def __init__(self, norm_layer, dim):
        super().__init__()
        self.norm_type = norm_layer
        self.norm = nn.LayerNorm(dim)
    def forward(self, x):
        return self.norm(x)

class CustomAct(nn.Module):
    def __init__(self, act_layer):
        super().__init__()
        self.act_layer = nn.GELU()
    def forward(self, x):
        return self.act_layer(x)



def conv3x3(in_planes, out_planes, stride=1):
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)



In [19]:

class Generator64(nn.Module):

    def __init__(self, args, initial_size=8, dim=1024, heads=4, mlp_ratio=4, drop_rate=0.,
                 window_size=16,depth=[5,4,4,4]):
        super(Generator64, self).__init__()

        self.initial_size = initial_size
        self.dim = dim
        self.args = args
        self.window_size = window_size
        self.c_dim = args.CONDITION_DIM
        self.z_dim = args.z_dim
        self.heads = heads
        self.mlp_ratio = mlp_ratio
        self.droprate_rate = drop_rate
     
        if args.Iscondtion:
            self.mlp = nn.Linear(self.c_dim + self.z_dim, (self.initial_size ** 2) * self.dim)
        else:
            self.mlp = nn.Linear(self.z_dim, (self.initial_size ** 2) * self.dim)

        self.positional_embedding_1 = nn.Parameter(torch.zeros(1, (self.initial_size ** 2), self.dim))
        self.positional_embedding_2 = nn.Parameter(torch.zeros(1, (self.initial_size * 2) ** 2, self.dim // 4))
        self.positional_embedding_3 = nn.Parameter(torch.zeros(1, (self.initial_size * 4) ** 2, self.dim // 16))
        self.positional_embedding_4 = nn.Parameter(torch.zeros(1, 16, self.window_size ** 2, self.dim // 64))

        self.TransformerEncoder_encoder1 = TransformerEncoder(depth[0], dim=self.dim, heads=self.heads,
                                                              mlp_ratio=self.mlp_ratio, qkv_bias=False,
                                                              qk_scale=None, drop=drop_rate, attn_drop=0.,
                                                              drop_path=0., act_layer=args.g_act,
                                                              norm_layer=args.g_norm,
                                                              window_size=8)
        self.TransformerEncoder_encoder2 = TransformerEncoder(depth[1], dim=self.dim // 4, heads=self.heads,
                                                              mlp_ratio=self.mlp_ratio, qkv_bias=False,
                                                              qk_scale=None, drop=drop_rate, attn_drop=0.,
                                                              drop_path=0., act_layer=args.g_act,
                                                              norm_layer=args.g_norm,
                                                              window_size=16)
        self.TransformerEncoder_encoder3 = TransformerEncoder(depth[2], dim=self.dim // 16, heads=self.heads,
                                                              mlp_ratio=self.mlp_ratio, qkv_bias=False,
                                                              qk_scale=None, drop=drop_rate, attn_drop=0.,
                                                              drop_path=0., act_layer=args.g_act,
                                                              norm_layer=args.g_norm,
                                                              window_size=32)
        self.TransformerEncoder_encoder4 = SwinTransformerEncoder(depth[3], input_resolution=(64, 64),
                                                              dim=self.dim // 64, heads=self.heads,
                                                              window_size=self.window_size,
                                                              shift_size=self.window_size//2,
                                                              mlp_ratio=4., qkv_bias=False,
                                                              qk_scale=None, drop=0., attn_drop=0.,
                                                              drop_path=0., act_layer=nn.GELU,
                                                              norm_layer=nn.LayerNorm)
        self.norm = nn.LayerNorm(16)
        self.linear = nn.Sequential(nn.Conv2d(self.dim // 64, 3, 1, 1, 0))
        
    def forward(self, z_code, sent_emb=None, output_dir="output", step=0):
        if self.args.Iscondtion:
            c_z_code = torch.cat((z_code, sent_emb), 1)
        else:
            c_z_code = z_code

        x = self.mlp(c_z_code).view(-1, self.initial_size ** 2, self.dim)
        x = x + self.positional_embedding_1
        H, W = self.initial_size, self.initial_size
        x = self.TransformerEncoder_encoder1(x)

    

        x, H, W = UpSampling(x, H, W)
        x = x + self.positional_embedding_2
        x = self.TransformerEncoder_encoder2(x)

    

        x, H, W = UpSampling(x, H, W)
        x = x + self.positional_embedding_3
        x = self.TransformerEncoder_encoder3(x)


        x, H, W = UpSampling(x, H, W)
        x = self.TransformerEncoder_encoder4(x)
        x = self.norm(x)
        x = self.linear(x.permute(0, 2, 1).view(-1, self.dim // 64, H, W))

        return x


class Generator128(nn.Module):
    def __init__(self, args, dim=1024, heads=4, mlp_ratio=4, H=16, W=16, drop_rate=0.,depth=[5,4,4,4]):
        super(Generator128, self).__init__()
        self.conv_dim = 256
        self.CONDITION_DIM = args.CONDITION_DIM
        self.args = args
        self.window_size = 16
        self.shift_size = self.window_size // 2
        self.dim = dim
        self.H = H
        self.W = W
        self.heads = heads
        self.mlp_ratio = mlp_ratio
        self.droprate_rate = drop_rate
        self.positional_embedding_0 = nn.Parameter(torch.zeros(1, (4 * 4) ** 2, self.dim))
        self.positional_embedding_1 = nn.Parameter(torch.zeros(1, (4 * 8) ** 2, self.dim // 4))
        self.positional_embedding_2 = nn.Parameter(torch.zeros(1, 16, self.window_size ** 2, self.dim // 16))
        self.positional_embedding_3 = nn.Parameter(torch.zeros(1, (16*8) ** 2, self.dim // 64))

        self.TransformerEncoder_encoder0 = TransformerEncoder(depth[0], dim=self.dim, heads=self.heads,
                                                              mlp_ratio=4., qkv_bias=False,
                                                              qk_scale=None, drop=0., attn_drop=0.,
                                                              drop_path=0., act_layer=nn.GELU,
                                                              norm_layer=nn.LayerNorm,
                                                              window_size=16)

        self.TransformerEncoder_encoder1 = TransformerEncoder(depth[1], dim=self.dim // 4, heads=self.heads,
                                                              mlp_ratio=4., qkv_bias=False,
                                                              qk_scale=None, drop=0., attn_drop=0.,
                                                              drop_path=0., act_layer=nn.GELU,
                                                              norm_layer=nn.LayerNorm,
                                                              window_size=32)

        self.TransformerEncoder_encoder2 = SwinTransformerEncoder(depth[2], input_resolution=(64, 64),
                                                                  dim=self.dim // 16, heads=self.heads,
                                                                  window_size=self.window_size,
                                                                  shift_size=self.shift_size,
                                                                  mlp_ratio=4., qkv_bias=False,
                                                                  qk_scale=None, drop=0., attn_drop=0.,
                                                                  drop_path=0., act_layer=nn.GELU,
                                                                  norm_layer=nn.LayerNorm)

        self.TransformerEncoder_encoder3 = SwinTransformerEncoder(depth[3], input_resolution=(128, 128),
                                                                  dim=self.dim // 64, heads=self.heads,
                                                                  window_size=self.window_size,
                                                                  shift_size=self.shift_size,
                                                                  mlp_ratio=4., qkv_bias=False,
                                                                  qk_scale=None, drop=0., attn_drop=0.,
                                                                  drop_path=0., act_layer=nn.GELU,
                                                                  norm_layer=nn.LayerNorm
                                                                  )

        self.deconv = nn.Sequential(nn.Conv2d(16, 3, 1, 1, 0))
        self.norm = nn.LayerNorm(16)

        self.encoder = nn.Sequential(
            conv3x3(3, self.conv_dim),
            nn.GELU(),
            nn.Conv2d(self.conv_dim, self.conv_dim * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(self.conv_dim * 2),
            nn.GELU(),
            nn.Conv2d(self.conv_dim * 2, self.conv_dim * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(self.conv_dim * 4),
            nn.GELU(),
        )
        self.hr_joint = nn.Sequential(
            conv3x3(self.CONDITION_DIM + self.conv_dim * 4, self.conv_dim * 4),
            nn.BatchNorm2d(self.conv_dim * 4),
            nn.GELU())

    def cat(self, fake_img_feature, sent_emb):
        c_code = sent_emb.view(-1, self.CONDITION_DIM, 1, 1)
        c_code = c_code.repeat(1, 1, fake_img_feature.shape[2], fake_img_feature.shape[2])
        i_c_code = torch.cat([fake_img_feature, c_code], 1)
        return i_c_code

    def forward(self, stageI_images, sent_emb=None):

        if self.args.Iscondtion:
            
            fake_img_feature = self.encoder(stageI_images)
            i_c_code = self.cat(fake_img_feature, sent_emb)
            fake_img1_feature = self.hr_joint(i_c_code)
        else:
            fake_img1_feature = self.encoder(stageI_images)

        x = fake_img1_feature.view(-1, self.dim, self.H * self.W).permute(0, 2, 1)
        x = x + self.positional_embedding_0
        x = self.TransformerEncoder_encoder0(x)

        x, H, W = UpSampling(x, self.H, self.W)
        x = x + self.positional_embedding_1
        B, _, C = x.size()
        x = self.TransformerEncoder_encoder1(x)

        x, H, W = UpSampling(x, H, W)
        B, _, C = x.size()
        #x = x + self.positional_embedding_2
        x = self.TransformerEncoder_encoder2(x)

        x, H, W = UpSampling(x, H, W)
        B, _, C = x.size()
        #x = x + self.positional_embedding_3
        x = self.TransformerEncoder_encoder3(x)
        x=self.norm(x)
        x = x.permute(0, 2, 1).view(-1, C, H, W)
        fake_image = self.deconv(x)
        
        return fake_image

class Generator256(nn.Module):
    def __init__(self, args, dim=1024, heads=4, mlp_ratio=4, H=32, W=32, drop_rate=0., depth=[5,4,4,4,4]):
        super(Generator256, self).__init__()
        self.args = args
        self.dim = dim
        self.H = H
        self.W = W
        self.CONDITION_DIM = args.CONDITION_DIM if hasattr(args, 'CONDITION_DIM') else 128
        self.z_dim = args.z_dim if hasattr(args, 'z_dim') else 100
        self.Iscondtion = args.Iscondtion if hasattr(args, 'Iscondtion') else True
        
        # Initialize all layers from your checkpoint
        self.positional_embedding_1 = nn.Parameter(torch.zeros(1, H*W, dim))
        self.positional_embedding_2 = nn.Parameter(torch.zeros(1, (H*2)*(W*2), dim//4))
        self.positional_embedding_3 = nn.Parameter(torch.zeros(1, (H*4)*(W*4), dim//16))
        self.positional_embedding_4 = nn.Parameter(torch.zeros(1, (H*8)*(W*8), dim//64))
        
        # MLP layer to project z_code + sent_emb to initial dimensions
        if self.Iscondtion:
            self.mlp = nn.Linear(self.z_dim + self.CONDITION_DIM, (H * W) * dim)
        else:
            self.mlp = nn.Linear(self.z_dim, (H * W) * dim)
        
        # Transformer blocks
        self.TransformerEncoder_encoder1 = TransformerEncoder(depth[0], dim=dim, heads=heads,
                                                             mlp_ratio=mlp_ratio, drop_rate=drop_rate)
        self.TransformerEncoder_encoder2 = TransformerEncoder(depth[1], dim=dim//4, heads=heads,
                                                             mlp_ratio=mlp_ratio, drop_rate=drop_rate)
        self.TransformerEncoder_encoder3 = SwinTransformerEncoder(depth[2], dim=dim//16, heads=heads,
                                                                window_size=16, drop_rate=drop_rate)
        self.TransformerEncoder_encoder4 = SwinTransformerEncoder(depth[3], dim=dim//64, heads=heads,
                                                                window_size=16, drop_rate=drop_rate)
        
        # Final layers
        self.norm = nn.LayerNorm(dim//64)
        self.linear = nn.Sequential(
            nn.Conv2d(dim//64, 3, kernel_size=1, stride=1, padding=0)
        )

    def forward(self, z_code, sent_emb=None):
        if self.Iscondtion:
            assert sent_emb is not None, "Conditional generation requires sent_emb"
            z = torch.cat([z_code, sent_emb], dim=1)
        else:
            z = z_code
            
        # Project to initial dimensions
        x = self.mlp(z).view(-1, self.H*self.W, self.dim)
        x = x + self.positional_embedding_1
        
        # Transformer blocks with upsampling
        x = self.TransformerEncoder_encoder1(x)
        x, H, W = UpSampling(x, self.H, self.W)
        x = x + self.positional_embedding_2
        
        x = self.TransformerEncoder_encoder2(x)
        x, H, W = UpSampling(x, H, W)
        x = x + self.positional_embedding_3
        
        x = self.TransformerEncoder_encoder3(x)
        x, H, W = UpSampling(x, H, W)
        x = x + self.positional_embedding_4
        
        x = self.TransformerEncoder_encoder4(x)
        
        # Final processing
        x = self.norm(x)
        x = x.permute(0, 2, 1).view(-1, self.dim//64, H, W)
        x = self.linear(x)
        
        return torch.tanh(x)  # Normalize to [-1, 1]

# Update your generate_images_from_text function
def generate_images_from_text(generator, text_descriptions, text_encoder, vocab, device, save_dir, n_images=1, z_dim=100):
    generator.eval()
    text_encoder.eval()
    os.makedirs(save_dir, exist_ok=True)
    
    for i, text in enumerate(text_descriptions):
        print(f"\nGenerating image for: '{text}'")
        
        with torch.no_grad():
            # Encode text
            sent_emb = encode_text(text, text_encoder, device, vocab)
            
            # Generate random noise
            z_code = torch.randn(1, z_dim).to(device)
            
            # Generate multiple images
            for j in range(n_images):
                fake_imgs = generator(z_code, sent_emb)
                
                # Save image
                save_path = os.path.join(save_dir, f"text2img_{i}_{j}.png")
                save_image(fake_imgs, save_path, normalize=True, range=(-1, 1))
                print(f"✅ Saved: {save_path}")
                
                # Display image
                display(Image.open(save_path))

In [20]:
import torch
import torch.nn as nn

class FilterModule(nn.Module):
    def __init__(self, generator, discriminator, num_candidates=5):
        super(FilterModule, self).__init__()
        self.generator = generator  # G0
        self.discriminator = discriminator  # D0
        self.num_candidates = num_candidates  # Number of images to generate

    def forward(self, text_embeddings, noise):
        batch_size = text_embeddings.size(0)

        # Efficiently expand text embeddings and noise for candidates
        text_embeddings_expanded = text_embeddings.unsqueeze(1).expand(-1, self.num_candidates, -1).contiguous()
        noise_expanded = noise.unsqueeze(1).expand(-1, self.num_candidates, -1).contiguous()

        # Flatten the expanded embeddings and noise
        text_embeddings_expanded = text_embeddings_expanded.view(-1, text_embeddings.size(1))
        noise_expanded = noise_expanded.view(-1, noise.size(1))

        # Generate candidate images
        candidate_images = self.generator(noise_expanded, text_embeddings_expanded)  # (batch_size * num_candidates, C, H, W)

        # Get scores for the candidate images
        scores = self.discriminator(candidate_images, text_embeddings_expanded)  # (batch_size * num_candidates, 1)

        # Reshape scores and candidate images for easier manipulation
        scores = scores.view(batch_size, self.num_candidates)  # (batch_size, num_candidates)
        candidate_images = candidate_images.view(batch_size, self.num_candidates, *candidate_images.shape[1:])

        # Select the best image based on the highest score
        best_indices = scores.argmax(dim=1)  # Best index for each batch
        best_images = candidate_images.gather(1, best_indices.view(-1, 1, 1).expand(-1, -1, candidate_images.size(2), candidate_images.size(3)))

        return best_images  # Shape: (batch_size, C, H, W)


In [21]:
import torch
import torch.nn as nn
from pdb import set_trace as stx

def window_partition(x, window_size):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size
    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    x = x.permute(0, 2, 3, 1)  # Convert shape to (B, H, W, C)
    B, H, W, C = x.shape

    # ✅ Ensure H and W are valid for partitioning
    pad_h = (window_size - H % window_size) % window_size
    pad_w = (window_size - W % window_size) % window_size

    x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h), mode="constant", value=0)  # Pad with zeros
    H, W = x.shape[1], x.shape[2]  # Update new H and W

    # ✅ Ensure valid shape
    assert H > 0 and W > 0, f"Invalid H={H}, W={W} after padding"
    assert H % window_size == 0 and W % window_size == 0, "H or W is not divisible by window_size"

    # ✅ Partition the window
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    x = x.permute(0, 1, 3, 5, 2, 4).contiguous().view(-1, C, window_size, window_size)

    return x, B, H // window_size  # Return the partitioned tensor and dimensions


def window_reverse(x, window_size, H, W):
    """
    Args:
        x: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of original image
        W (int): Width of original image
    Returns:
        x: (B, H, W, C)
    """
    BNN, C, H_W, W_W = x.shape  # Extract batch info

    N = H // window_size  # Number of windows per row/column

    # ✅ Correct reshaping
    x = x.view(BNN // (N * N), N, N, C, window_size, window_size)
    x = x.permute(0, 3, 1, 4, 2, 5).contiguous()
    x = x.view(-1, C, H, W)  # Restore original shape

    return x

def conv3x3(in_planes, out_planes, stride=1):
    "3x3 convolution with padding"
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)
def Block3x3_leakRelu(in_planes, out_planes):
    block = nn.Sequential(
        conv3x3(in_planes, out_planes),
        nn.BatchNorm2d(out_planes),
        nn.LeakyReLU(0.2, inplace=True)
    )
    return block

def downBlock(in_planes, out_planes):
    block = nn.Sequential(
        nn.Conv2d(in_planes, out_planes, 4, 2, 1, bias=False),
        nn.BatchNorm2d(out_planes),
        nn.LeakyReLU(0.2, inplace=True)
    )
    return block

def encode_image_by_16times(ndf):
    encode_img = nn.Sequential(
        nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
        nn.BatchNorm2d(ndf * 2),
        nn.LeakyReLU(0.2, inplace=True),
        nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
        nn.BatchNorm2d(ndf * 4),
        nn.LeakyReLU(0.2, inplace=True),
        nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
        nn.BatchNorm2d(ndf * 8),
        nn.LeakyReLU(0.2, inplace=True)
    )
    return encode_img
def encode_image_by_16times2(ndf):
    encode_img = nn.Sequential(
        # nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
        # nn.BatchNorm2d(ndf * 2),
        # nn.LeakyReLU(0.2, inplace=True),
        nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
        nn.BatchNorm2d(ndf * 4),
        nn.LeakyReLU(0.2, inplace=True),
        nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
        nn.BatchNorm2d(ndf * 8),
        nn.LeakyReLU(0.2, inplace=True)
    )
    return encode_img


In [22]:

class Discriminator64(nn.Module):
    def __init__(self,args):
        super(Discriminator64, self).__init__()
        self.df_dim = 64
        self.args=args
        self.ef_dim = args.CONDITION_DIM
        ndf, nef = self.df_dim, self.ef_dim
        self.window_size=16
        self.fixed_size=32
        self.conv1=nn.Sequential(nn.Conv2d(3, ndf, 4, 2, 1, bias=False),nn.LeakyReLU(0.2, inplace=True))
        self.encode_img = encode_image_by_16times(ndf)
        self.jointConv = Block3x3_leakRelu(ndf * 8 + nef, ndf * 8)
        self.outlogits = nn.Sequential(
            nn.Conv2d(ndf * 8, 1, kernel_size=4, stride=4),
            nn.Sigmoid())

    def sent_process(self,img_embedding ,cond):
        cond = cond.view(-1, self.ef_dim, 1, 1)
        cond = cond.repeat(1, 1, 4, 4)
        h_c_code = torch.cat((img_embedding, cond), 1)
        return h_c_code

    def forward(self,image,cond=None,mode='full'):
        
        if mode == 'full':
            x_code = self.conv1(image)
        else:
            x_code,B,N=window_partition(image,self.window_size)
            x_code = self.conv1(x_code)
            x_code=window_reverse(x_code,self.fixed_size,B,N)

        img_embedding = self.encode_img(x_code)
        if self.args.Iscondtion:
            img_embedding=self.sent_process(img_embedding,cond)
            h_c_code=self.jointConv(img_embedding)
        else:
            h_c_code=img_embedding
        out=self.outlogits(h_c_code)
        return out


class Discriminator128(nn.Module):
    def __init__(self,args):
        super(Discriminator128, self).__init__()
        self.df_dim = 64
        self.ef_dim = args.CONDITION_DIM
        ndf, nef = self.df_dim, self.ef_dim
        self.args=args
        self.window_size=16
        self.fixed_size=32
        self.conv1=nn.Sequential(nn.Conv2d(3, ndf, 4, 2, 1, bias=False),
                                nn.LeakyReLU(0.2, inplace=True),
                                nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
                                nn.BatchNorm2d(ndf * 2),
                                nn.LeakyReLU(0.2, inplace=True))
        self.img_code_s16 = encode_image_by_16times2(ndf)
        self.img_code_s32 = downBlock(ndf * 8, ndf * 16)
        self.img_code_s32_1 = Block3x3_leakRelu(ndf * 16, ndf * 8)
        self.jointConv = Block3x3_leakRelu(ndf * 8 + nef, ndf * 8)
        self.outlogits = nn.Sequential(
            nn.Conv2d(ndf * 8, 1, kernel_size=3, stride=1, padding=1),
            nn.AdaptiveAvgPool2d((1, 1)),  # 🔥 Ensures final output is (1,1)
            nn.Sigmoid()
        )



    def sent_process(self, img_embedding, cond):
        cond = cond.view(-1, self.ef_dim, 1, 1)
        cond = cond.repeat(1, 1, 4, 4)
        h_c_code = torch.cat((img_embedding, cond), 1)
        return h_c_code

    def forward(self,image,cond=None,mode='full'):
        if mode=='full':
            x_code = self.conv1(image)
        else:
            x_code,B,N=window_partition(image,self.window_size)
            x_code = self.conv1(x_code)
            x_code=window_reverse(x_code,self.fixed_size,B,N)

        x_code = self.img_code_s16(x_code)
        x_code = self.img_code_s32(x_code)
        x_code = self.img_code_s32_1(x_code)
        if self.args.Iscondtion:
            img_embedding = self.sent_process(x_code, cond)
            h_c_code = self.jointConv(img_embedding)
        else:
            h_c_code=x_code

        out=self.outlogits(h_c_code)

        return out

class Discriminator256(nn.Module):
    def __init__(self, args):
        super(Discriminator256, self).__init__()
        self.df_dim = 64
        self.ef_dim = args.CONDITION_DIM
        self.args = args
        ndf, nef = self.df_dim, self.ef_dim
        self.window_size = 16
        self.fixed_size = 32

        # Initial convolution layers
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, ndf, 4, 2, 1, bias=False),  # 256 -> 128
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),  # 128 -> 64
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),  # 64 -> 32
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
        )

        # Further downsampling to 16x16, 8x8, 4x4
        self.img_code_s16 = downBlock(ndf * 4, ndf * 8)   # 32 -> 16
        self.img_code_s8 = downBlock(ndf * 8, ndf * 16)   # 16 -> 8
        self.img_code_s4 = downBlock(ndf * 16, ndf * 32)  # 8 -> 4
        self.refine = Block3x3_leakRelu(ndf * 32, ndf * 16)  # keep depth for jointConv

        # Joint conditioning
        self.jointConv = Block3x3_leakRelu(ndf * 16 + nef, ndf * 8)

        self.outlogits = nn.Sequential(
            nn.Conv2d(ndf * 8, 1, kernel_size=4, stride=4),  # 4x4 -> 1x1
            nn.Sigmoid()
        )

    def sent_process(self, img_embedding, cond):
        cond = cond.view(-1, self.ef_dim, 1, 1)
        cond = cond.repeat(1, 1, 4, 4)  # Match feature spatial size
        h_c_code = torch.cat((img_embedding, cond), 1)
        return h_c_code

    def forward(self, image, cond=None, mode='full'):
        if mode == 'full':
            x_code = self.conv1(image)
        else:
            x_code, B, N = window_partition(image, self.window_size)
            x_code = self.conv1(x_code)
            x_code = window_reverse(x_code, self.fixed_size, B, N)

        x_code = self.img_code_s16(x_code)
        x_code = self.img_code_s8(x_code)
        x_code = self.img_code_s4(x_code)
        x_code = self.refine(x_code)

        if self.args.Iscondtion:
            x_code = self.sent_process(x_code, cond)
            h_c_code = self.jointConv(x_code)
        else:
            h_c_code = x_code

        out = self.outlogits(h_c_code)
        return out


In [23]:
# Perceptual Loss (using VGG features)
class PerceptualLoss(nn.Module):
    def __init__(self, model):
        super(PerceptualLoss, self).__init__()
        self.model = model
        self.criterion = nn.MSELoss()

    def forward(self, generated, target):
        generated_features = self.model(generated)
        target_features = self.model(target)
        return self.criterion(generated_features, target_features)

# WGAN-GP Loss for the Discriminator
def wgan_gp_loss(D_real, D_fake, real_images, fake_images, lambda_gp=10):
    # Compute gradient penalty
    epsilon = torch.rand(real_images.size(0), 1, 1, 1).to(real_images.device)
    interpolates = epsilon * real_images + (1 - epsilon) * fake_images
    interpolates.requires_grad_(True)
    
    D_interpolates = D(interpolates)
    gradients = torch.autograd.grad(outputs=D_interpolates, inputs=interpolates,
                                   grad_outputs=torch.ones(D_interpolates.size()).to(real_images.device),
                                   create_graph=True, retain_graph=True)[0]
    gradients = gradients.view(gradients.size(0), -1)
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    
    # WGAN-GP loss
    loss = -D_real.mean() + D_fake.mean() + lambda_gp * gradient_penalty
    return loss


In [24]:
batch_size = 8
img_size = 64  # Change to 128 for Discriminator128
channels = 3
condition_dim = 128

# Initialize models
D64 = Discriminator64(args)  # ✅ Fixed
D128 = Discriminator128(args)  # ✅ Fixed


# Create random fake images and text embeddings
fake_images = torch.randn(batch_size, channels, img_size, img_size)  # (B, 3, H, W)
text_embeddings = torch.randn(batch_size, condition_dim)  # (B, 128)

# Forward pass
output_64 = D64(fake_images, text_embeddings)  # Shape: (B, 1, 1, 1)
output_128 = D128(fake_images, text_embeddings)  # Shape: (B, 1, 1, 1)

print(f"D64 Output Shape: {output_64.shape}")  # Expected: (B, 1, 1, 1)
print(f"D128 Output Shape: {output_128.shape}")  # Expected: (B, 1, 1, 1)


AttributeError: 'Args' object has no attribute 'CONDITION_DIM'

In [25]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import time
from torchvision.utils import  save_image
import os
import numpy as np



def noise(n_samples, z_dim, device):
    return torch.randn(n_samples, z_dim).to(device)


class LinearLrDecay(object):
    def __init__(self, optimizer, start_lr, end_lr, decay_start_step, decay_end_step):

        assert start_lr > end_lr
        self.optimizer = optimizer
        self.delta = (start_lr - end_lr) / (decay_end_step - decay_start_step)
        self.decay_start_step = decay_start_step
        self.decay_end_step = decay_end_step
        self.start_lr = start_lr
        self.end_lr = end_lr
    def step(self, current_step):
        if current_step <= self.decay_start_step:
            lr = self.start_lr
        elif current_step >= self.decay_end_step:
            lr = self.end_lr
        else:
            lr = self.start_lr - self.delta * (current_step - self.decay_start_step)
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = lr
        return lr


def inits_weight(m):
    if type(m) == nn.Linear:
        nn.init.xavier_uniform_(m.weight.data, 1.)


def noise(imgs, latent_dim):
    return torch.FloatTensor(np.random.normal(0, 1, (imgs.shape[0], latent_dim)))


def gener_noise(gener_batch_size, latent_dim):
    return torch.FloatTensor(np.random.normal(0, 1, (gener_batch_size, latent_dim)))


def save_checkpoint(states, is_best, output_dir,
                    filename='checkpoint.pth'):
    torch.save(states, os.path.join(output_dir, filename))
    if is_best:
        torch.save(states, os.path.join(output_dir, 'checkpoint_best.pth'))

def save_model(netG, netD, stage, epoch, model_dir, dataset_name, FID_value, IS_value):
    torch.save(
        netG.state_dict(),
        '%s/model/%s_stage_%d_netG_epoch_%d_fid_%2f_is_%2f.pth' % (
        model_dir, dataset_name, stage, epoch, FID_value, IS_value))
    torch.save(
        netD.state_dict(),
        '%s/model/%s_stage_%d_netD_epoch_last.pth' % (model_dir, dataset_name, stage))
    print('Save G/D models')


def pre_gen_imgs(test_emb,pre_generator, pre_discriminator,device='cuda:0'):
    expend_size = 10
    fake_img_list = []
    test_emb = test_emb.to(device)
    for i in range(test_emb.shape[0]):
        test_emb_expend = test_emb[i].unsqueeze(0).repeat((expend_size, 1))
        noise = torch.cuda.FloatTensor(np.random.normal(0, 1, (expend_size, args.z_dim)))
        with torch.no_grad():
            fake_img = pre_generator(noise, test_emb_expend)
        # save_images(fake_imgs)
        score = pre_discriminator(fake_img, test_emb_expend)
        index = torch.argmax(score)
        fake_img_list.append(fake_img[index])
    fake_imgs = torch.stack(fake_img_list, 0)
    return fake_imgs

def mk_img(args,generator,test_loader,pre_generator=None,pre_discriminator=None,num_img=30000,batch_size=args.gener_batch_size,device='cuda:0'):
    with torch.no_grad():
        if args.STAGE==1:
            label=True
            id=0
            generator = generator.eval()
            fp=os.path.join(args.image_dir,'evaluate_images')
            if(not os.path.exists(fp)):
                os.mkdir(fp)
            print('sampling images...')
            while label:
                for index,(_,test_emb) in enumerate(test_loader):
                    noise = torch.cuda.FloatTensor(np.random.normal(0, 1, (batch_size, args.z_dim)))
                    if args.Iscondtion :
                        test_emb=test_emb.to(device)
                        gen_imgs= generator(noise, test_emb)
                    else:
                        gen_imgs = generator(noise)
                    for i in range(gen_imgs.shape[0]):
                        save_name = '%s/stage-%d-%d.png' % (fp, args.STAGE, id)
                        save_image(gen_imgs[i], save_name, nrow=1, normalize=True, scale_each=True)
                        id += 1
                        if id==num_img:
                            print('finished')
                            return fp

        else:
            id = 0
            print('sampling images...')
            fp = os.path.join(args.image_dir, 'evaluate_images')
            if(not os.path.exists(fp)):
                os.mkdir(fp)
            for index ,(_,test_emb) in enumerate(test_loader) :
                print(index)
                if args.Iscondtion:
                    test_emb = test_emb.to(device)
                    stageI_imgs = pre_gen_imgs(test_emb,pre_generator, pre_discriminator,device=device)
                    gen_imgs = generator(stageI_imgs, test_emb)
                else:
                    stageI_imgs = pre_gen_imgs(test_emb,pre_generator, pre_discriminator,device=device)
                    gen_imgs = generator(stageI_imgs)
                for i in range(gen_imgs.shape[0]):
                    save_name = '%s/stage-%d-%d.png' % (fp, args.STAGE, id)
                    save_image(gen_imgs[i], save_name, nrow=1, normalize=True, scale_each=True)
                    id += 1
                    if id == num_img:
                        print('finished')
                        return fp

In [None]:
from IPython import display
from PIL import Image
import torch
import os
import warnings
from torchvision.utils import save_image

def show_image(image_path):
    img = Image.open(image_path)
    display.display(img)

warnings.filterwarnings("ignore")

# Load Configuration
args = cfg.parse_args()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using Device:", device)

os.makedirs("/kaggle/working/models/", exist_ok=True)
os.makedirs("/kaggle/working/output_images/", exist_ok=True)

# ✅ Load Networks
def load_network_stageI(args, device):
    netG = Generator64(args).to(device)
    netD = Discriminator64(args).to(device)
    return netG, netD

def load_network_stageII(args, device):
    netG = Generator128(args).to(device)
    netD = Discriminator128(args).to(device)
    pre_generator = Generator64(args).to(device)
    pre_discriminator = Discriminator64(args).to(device)
    return netG, netD, pre_generator, pre_discriminator

def load_network_stageIII(args, device):
    netG = Generator256(args).to(device)
    netD = Discriminator256(args).to(device)
    pre_generator = Generator128(args).to(device)
    pre_discriminator = Discriminator128(args).to(device)
    return netG, netD, pre_generator, pre_discriminator

def load_weight(netG, netD, pre_generator=None, pre_discriminator=None):
    netG.apply(inits_weight)
    netD.apply(inits_weight)

    if args.NET_G != '':
        state_dict = torch.load(args.NET_G, map_location=lambda storage, loc: storage)
        netG.load_state_dict(state_dict)
        print('Loaded Generator from:', args.NET_G)

    if args.NET_D != '':
        state_dict = torch.load(args.NET_D, map_location=lambda storage, loc: storage)
        netD.load_state_dict(state_dict)
        print('Loaded Discriminator from:', args.NET_D)

    if args.STAGE == 1:
        return netG, netD

    if args.STAGE in [2, 3]:
        if args.STAGE1_G != '' and args.STAGE1_D != '':
            state_dict = torch.load(args.STAGE1_G, map_location=lambda storage, loc: storage)
            state_dict2 = torch.load(args.STAGE1_D, map_location=lambda storage, loc: storage)
            pre_generator.load_state_dict(state_dict)
            pre_discriminator.load_state_dict(state_dict2)
            pre_generator.eval()
            pre_discriminator.eval()
            print('Loaded Stage 1 Generator from:', args.STAGE1_G)
        else:
            print("Please provide Stage 1 Generator and Discriminator paths.")
            return netG, netD

        return netG, netD, pre_generator, pre_discriminator

def define_optimizers(args, generator, discriminator):
    optim_gen = torch.optim.Adam(generator.parameters(), lr=args.lr_gen, betas=(args.beta1, args.beta2))
    optim_dis = torch.optim.Adam(discriminator.parameters(), lr=args.lr_dis, betas=(args.beta1, args.beta2))
    return optim_gen, optim_dis

def save_checkpoint(generator, discriminator, optim_gen, optim_dis, epoch, stage, pre_generator=None, pre_discriminator=None):
    checkpoint = {
        'epoch': epoch + 1,
        'generator_state': generator.state_dict(),
        'discriminator_state': discriminator.state_dict(),
        'optimizer_gen_state': optim_gen.state_dict(),
        'optimizer_dis_state': optim_dis.state_dict(),
        'stage': stage
    }
    if stage >= 2 and pre_generator and pre_discriminator:
        checkpoint['pre_generator_state'] = pre_generator.state_dict()
        checkpoint['pre_discriminator_state'] = pre_discriminator.state_dict()

    save_path = f"/kaggle/working/models/ctgan_checkpoint_epoch_{epoch}.pth"
    torch.save(checkpoint, save_path)
    print(f"✅ Checkpoint saved at {save_path}")

def train(generator, discriminator, optim_gen, optim_dis, train_loader, test_loader,
          pre_generator=None, pre_discriminator=None, device='cuda', stage=1,
          save_dir="output_images", start_epoch=0):

    generator.train()
    discriminator.train()
    os.makedirs(save_dir, exist_ok=True)

    print(f"🚀 Starting training from epoch {start_epoch}")

    for epoch in range(start_epoch, args.epoch):
        epoch_folder = os.path.join(save_dir, f"epoch_{epoch+1}")
        os.makedirs(epoch_folder, exist_ok=True)

        print(f"Epoch {epoch+1}/{args.epoch}")

        for batch_idx, (real_imgs, text_emb) in enumerate(train_loader):
            real_imgs, text_emb = real_imgs.to(device), text_emb.to(device)

            optim_dis.zero_grad()
            with torch.no_grad():
                if stage == 1:
                    noise = torch.randn((real_imgs.size(0), args.z_dim), device=device)
                    fake_imgs = generator(noise, text_emb)
                elif stage == 2:
                    filtered_imgs = pre_gen_imgs(text_emb, pre_generator, pre_discriminator, device=device)
                    fake_imgs = generator(filtered_imgs, text_emb)
                elif stage == 3:
                    intermediate_imgs = pre_gen_imgs(text_emb, pre_generator, pre_discriminator, device=device)
                    fake_imgs = generator(intermediate_imgs, text_emb)

            d_out = compute_discriminator_loss(args, discriminator, real_imgs, fake_imgs, text_emb)
            if len(d_out) == 4:
                d_loss, errD_real, errD_wrong, errD_fake = d_out
            else:
                d_loss, errD_real, errD_fake = d_out
                errD_wrong = torch.tensor(0.0)
            d_loss.backward()
            optim_dis.step()

            optim_gen.zero_grad()
            if stage == 1:
                noise = torch.randn((real_imgs.size(0), args.z_dim), device=device)
                gen_imgs = generator(noise, text_emb)
            elif stage == 2:
                gen_imgs = generator(filtered_imgs, text_emb)
            elif stage == 3:
                gen_imgs = generator(intermediate_imgs, text_emb)

            g_loss = compute_generator_loss(args, discriminator, gen_imgs, text_emb)
            g_loss.backward()
            optim_gen.step()

            batch_img_path = os.path.join(epoch_folder, f"batch_{batch_idx}.png")
            save_image(gen_imgs[:16], batch_img_path, normalize=True)

            if batch_idx % 50 == 0:
                show_image(batch_img_path)

            if batch_idx % 100 == 0:
                print(f"[Epoch {epoch+1}] [Batch {batch_idx}/{len(train_loader)}] [D loss: {d_loss.item()}] [G loss: {g_loss.item()}")

        final_epoch_image = os.path.join(epoch_folder, f"epoch_{epoch+1}_final.png")
        save_image(gen_imgs[:16], final_epoch_image, normalize=True)
        print(f"✅ Saved final epoch image: {final_epoch_image}")
        show_image(final_epoch_image)

        # Save model if it improves
        # if FID_score < best_FID:
        #     torch.save(generator.state_dict(), "/kaggle/working/models/best_generator.pth")
        #     best_FID = FID_score
        #     print("Updated Best Generator Model!")

        # if IS_score > best_IS:
        #     torch.save(generator.state_dict(), "/kaggle/working/models/best_generator.pth")
        #     best_IS = IS_score
        #     print("Updated Best Generator Model!")

        if (epoch + 1) % 50 == 0 or (epoch + 1) == args.epoch:
            save_checkpoint(generator, discriminator, optim_gen, optim_dis, epoch + 1, stage, pre_generator, pre_discriminator)

def main():
    print("Initializing Training...")

    if args.STAGE == 1:
        generator, discriminator = load_network_stageI(args, device)
        pre_generator, pre_discriminator = None, None
    elif args.STAGE == 2:
        generator, discriminator, pre_generator, pre_discriminator = load_network_stageII(args, device)
    elif args.STAGE == 3:
        generator, discriminator, pre_generator, pre_discriminator = load_network_stageIII(args, device)
    else:
        raise ValueError("Invalid STAGE. Must be 1, 2 or 3.")

    dataset = ImageDataset(args, cur_img_size=64)
    train_loader = dataset.train
    test_loader = dataset.test

    # Load weights for generator and discriminator
    if args.STAGE == 1:
        result = load_weight(generator, discriminator)
        if len(result) == 2:
            generator, discriminator = result
        else:
            raise ValueError(f"Unexpected return values from load_weight(): {result}")
    else:
        result = load_weight(generator, discriminator, pre_generator, pre_discriminator)
        if len(result) == 4:
            generator, discriminator, pre_generator, pre_discriminator = result
        else:
            raise ValueError(f"Unexpected return values from load_weight(): {result}")

    optim_gen, optim_dis = define_optimizers(args, generator, discriminator)

    # ✅ Resume from checkpoint if available
    checkpoint_path = "your_checkpoint_path"  # update this path
    start_epoch = 101

    if os.path.exists(checkpoint_path):
        checkpoint = torch.load(checkpoint_path, map_location=device)
        generator.load_state_dict(checkpoint['generator_state'])
        discriminator.load_state_dict(checkpoint['discriminator_state'])
        optim_gen.load_state_dict(checkpoint['optimizer_gen_state'])
        optim_dis.load_state_dict(checkpoint['optimizer_dis_state'])
        start_epoch = checkpoint['epoch']
        print(f"✅ Resuming from checkpoint at epoch {start_epoch}")

        # Also load pre_generator and pre_discriminator for stage 2/3
        if args.STAGE >= 2:
            pre_generator.load_state_dict(checkpoint['pre_generator_state'])
            pre_discriminator.load_state_dict(checkpoint['pre_discriminator_state'])
            pre_generator.eval()
            pre_discriminator.eval()
    else:
        print("⚠️ Checkpoint not found. Starting from scratch.")

    target_epoch = args.epoch

    print(f"Starting training from epoch {start_epoch} to {target_epoch}")
    train(generator, discriminator, optim_gen, optim_dis, train_loader, test_loader,
          pre_generator, pre_discriminator, device=device, stage=args.STAGE,
          save_dir="output_images", start_epoch=start_epoch)

    print("✅ Training Completed!")

if __name__ == "__main__":
    main()


In [None]:


# 🧪 Setup
args = cfg.parse_args()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("🧪 Testing on device:", device)

checkpoint_path = "your_checkpoint_path"
output_dir = "your_output_images path"
os.makedirs(output_dir, exist_ok=True)

dataset = ImageDataset(args, cur_img_size=64)
test_loader = dataset.test

# Load model
if args.STAGE == 1:
    generator, _ = load_network_stageI(args, device)
else:
    generator, _, pre_generator, pre_discriminator = load_network_stageII(args, device)

# Load weights from checkpoint
if os.path.exists(checkpoint_path):
    checkpoint = torch.load(checkpoint_path, map_location=device)
    generator.load_state_dict(checkpoint['generator_state'])
    print(f"✅ Loaded generator from checkpoint: {checkpoint_path}")
else:
    raise FileNotFoundError("Checkpoint not found.")

generator.eval()
if args.STAGE == 2:
    pre_generator.load_state_dict(checkpoint['pre_generator_state'])
    pre_discriminator.load_state_dict(checkpoint['pre_discriminator_state'])
    pre_generator.eval()
    pre_discriminator.eval()

# Inference
def generate_images():
    print("🎨 Generating images from test embeddings...")
    with torch.no_grad():
        for idx, (real_img, emb) in enumerate(test_loader):
            emb = emb.to(device)
            batch_size = emb.size(0)

            if args.STAGE == 1:
                noise = torch.randn((batch_size, args.z_dim), device=device)
                fake_imgs = generator(noise, emb)
            else:
                filtered = pre_gen_imgs(emb, pre_generator, pre_discriminator, device=device)
                fake_imgs = generator(filtered, emb)

            save_path = os.path.join(output_dir, f"sample_{idx}.png")
            save_image(fake_imgs[:16], save_path, normalize=True)
            print(f"✅ Saved test image: {save_path}")
            show_image(save_path)

            # if idx >= 4:  # Just generate 5 batches for quick test
            #     break

def show_image(path):
    img = Image.open(path)
    display.display(img)

# Run test
generate_images()
print("✅ Test image generation complete!")
