In [None]:
# Cell 1: Install required packages
# These are IPython shell escapes ("!"), so they only work inside a notebook.
# fvcore is used below for parameter/FLOP utilities and torch-fidelity for the
# Inception Score / FID calculation. None of these packages is version-pinned.
!pip install datasets huggingface_hub torch torchvision validators transformers fvcore bitsandbytes torch-fidelity einops

# Cell 2: Install GigaGAN package
# Pinned to an exact release; the GigaGAN constructor arguments used in the
# teacher/student cells are written against this version.
!pip install gigagan-pytorch==0.2.20

# Cell 3: Import libraries, Setup PROFILING_DIR and Logger
import os
import torch
from datasets import load_dataset
from __future__ import annotations # Should be at the very top if used
from gigagan_pytorch import GigaGAN
from huggingface_hub import login
from PIL import Image, ImageFile
import requests
from io import BytesIO
import numpy as np
from torch.utils.data import DataLoader
from torchvision import transforms
import validators
import gc
import time
import json
from pathlib import Path
import gdown
import torch.nn as nn
# from torch.cuda.amp import autocast, GradScaler # Deprecated
from torch.amp import autocast, GradScaler # Corrected import
import bitsandbytes as bnb
from torchvision.utils import save_image
import torchvision.transforms as T
from tqdm import tqdm
import glob
import torch_fidelity
import torch.nn.functional as F # For KD loss

# Import profiling tools
import torch.profiler as profiler
from torch.profiler import profile, record_function, ProfilerActivity
from fvcore.nn import FlopCountAnalysis, flop_count_table, parameter_count

# Output directory: logs, checkpoints, profiler traces and generated images
# produced by the later cells are all written underneath it.
PROFILING_DIR = "profiling_results_kd"
os.makedirs(PROFILING_DIR, exist_ok=True)

# Logging: every record goes both to a log file inside PROFILING_DIR and to
# the console stream.
import logging

_log_file_path = os.path.join(PROFILING_DIR, 'gigagan_kd_profiling.log')
_log_handlers = [
    logging.FileHandler(_log_file_path),
    logging.StreamHandler(),
]
logging.basicConfig(
    handlers=_log_handlers,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
logger.info("Knowledge Distillation GigaGAN Notebook Started")

# Cell 4: One time function execution (commented out)
# ... (same as in original notebook) ...

# Cell 5: Image loading utilities
# Let PIL decode images whose file data is truncated instead of raising.
ImageFile.LOAD_TRUNCATED_IMAGES = True


def safe_getitem(self, index):
    """Fault-tolerant replacement for a dataset's ``__getitem__``.

    Reads ``self.paths``, ``self.transform`` and ``self.image_size``.
    Tries to open and fully decode the image at ``index``; when that raises
    an OSError/IOError/SyntaxError it moves on to the next path (wrapping
    around the end of ``self.paths``), for at most five attempts in total.

    Returns:
        ``self.transform(img)`` for the first image that loads (converted to
        RGB when needed), or an all-zero tensor of shape
        ``(3, image_size, image_size)`` when no attempt succeeds.
    """
    max_attempts = 5
    current_index = index
    attempts_left = max_attempts

    while attempts_left > 0:
        attempts_left -= 1
        try:
            path = self.paths[current_index]
            img = Image.open(path)
            img.load()  # force a full decode so corrupt files fail here
            if img.mode != 'RGB':
                img = img.convert('RGB')
            return self.transform(img)
        except (OSError, IOError, SyntaxError) as e:
            logger.warning(f"Error loading image {path}: {e}. Trying next.")
            current_index = (current_index + 1) % len(self.paths)
            if current_index == index:
                # Wrapped all the way around: every path has been tried.
                logger.error("Looped all images.")
                break

    logger.error(f"Failed {max_attempts} loads, returning blank.")
    return torch.zeros(3, self.image_size, self.image_size)

def patch_image_dataset(dataset):
    """Install ``safe_getitem`` as ``__getitem__`` and return ``dataset``.

    The assignment is made on ``dataset.__class__``, so it affects every
    instance of that class, not only the object passed in. The PIL
    truncated-image flag is (re)enabled as well.
    """
    ImageFile.LOAD_TRUNCATED_IMAGES = True
    dataset_cls = dataset.__class__
    dataset_cls.__getitem__ = safe_getitem
    return dataset

# Cell 6: Unzip dataset (if applicable)
dataset_zip_path = '/content/final_dataset.zip'
dataset_extract_path = '/content/final_dataset'
# "10047.png" acts as the marker file: if it is already present in the target
# directory the archive is treated as extracted and is not unpacked again.
_dataset_marker_path = os.path.join(dataset_extract_path, "10047.png")

if not os.path.exists(dataset_zip_path):
    logger.warning(f"Dataset zip {dataset_zip_path} not found.")
elif os.path.exists(_dataset_marker_path):
    logger.info(f"Dataset seems unzipped in {dataset_extract_path}.")
else:
    # Extract only inside this guarded branch. The stdlib zipfile module is
    # used instead of a shell escape, so re-running the cell is idempotent
    # and the success message is logged only after extraction has happened.
    import zipfile

    logger.info(f"Unzipping {dataset_zip_path} to {dataset_extract_path}...")
    os.makedirs(dataset_extract_path, exist_ok=True)
    try:
        with zipfile.ZipFile(dataset_zip_path) as dataset_archive:
            dataset_archive.extractall(dataset_extract_path)
        logger.info("Dataset unzipped.")
    except (zipfile.BadZipFile, OSError) as unzip_err:
        # Best effort: the dataset-loading cell falls back to a dummy dataset
        # when the directory is missing or empty.
        logger.error(f"Failed to unzip {dataset_zip_path}: {unzip_err}", exc_info=True)

# Cell 7: Load Dataset
from gigagan_pytorch import ImageDataset

# Defaults so both names exist even if everything below fails.
dataset = None
dataloader = None

DATASET_PATH = '/content/final_dataset'
IMAGE_SIZE = 256
BATCH_SIZE = 4  # Default batch size; the KD cell builds its own dataloader.

try:
    # Reject a missing or empty directory before constructing the dataset.
    if not (os.path.isdir(DATASET_PATH) and os.listdir(DATASET_PATH)):
        raise FileNotFoundError(f"Dataset dir {DATASET_PATH} missing/empty.")

    dataset = ImageDataset(folder=DATASET_PATH, image_size=IMAGE_SIZE)
    if len(dataset) == 0:
        raise ValueError("Dataset empty post-init.")

    # Swap in the fault-tolerant __getitem__ so unreadable files are skipped.
    dataset = patch_image_dataset(dataset)

    # General-purpose dataloader; the KD cell creates a separate one.
    dataloader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=2,
        pin_memory=True,
    )
    logger.info(f"Loaded dataset: {len(dataset)} images. Default dataloader batch: {BATCH_SIZE}.")
except Exception as load_err:
    logger.error(f"Failed to load dataset from {DATASET_PATH}: {load_err}", exc_info=True)
    logger.warning("Creating dummy dataset.")

    class DummyDataset(torch.utils.data.Dataset):
        """Fallback dataset that yields identical solid-gray RGB images."""

        def __init__(self, size=100, image_size=256):
            self.size = size
            self.image_size = image_size
            self.transform = transforms.Compose([transforms.ToTensor()])

        def __len__(self):
            return self.size

        def __getitem__(self, idx):
            gray_image = Image.new('RGB', (self.image_size, self.image_size), color='gray')
            return self.transform(gray_image)

    dataset = DummyDataset(image_size=IMAGE_SIZE)
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE)
    logger.info(f"Using dummy dataset: {len(dataset)} images.")


# Cell 8: Teacher Model Setup (from Original GigaGAN)
logger.info("--- Setting up Teacher Model for Knowledge Distillation ---")
teacher_model = None
teacher_ckpt_path = '/content/model-32.ckpt'  # Original, pre-trained model checkpoint

# Teacher architecture; it has to match whatever produced the checkpoint.
teacher_generator_config = {
    "style_network": {"dim": 64, "depth": 4},
    "dim": 32,
    "image_size": 256,
    "input_image_size": 64,
    "unconditional": True,
    "flash_attn": False,
}
teacher_discriminator_config = {
    "dim_capacity": 16,
    "dim_max": 512,
    "image_size": 256,
    "num_skip_layers_excite": 4,
    "multiscale_input_resolutions": (128,),
    "unconditional": True,
}
# AMP flag for the teacher; assumed to match how model-32.ckpt was saved.
teacher_model_amp_setting = False

try:
    logger.info(f"Initializing Teacher GigaGAN with AMP={teacher_model_amp_setting}...")
    teacher_model = GigaGAN(
        train_upsampler=True,
        generator=teacher_generator_config,
        discriminator=teacher_discriminator_config,
        amp=teacher_model_amp_setting,
    ).cuda()
    logger.info("Teacher GigaGAN structure initialized.")

    # Without a checkpoint the teacher keeps its freshly initialised weights.
    if not os.path.exists(teacher_ckpt_path):
        logger.warning(f"Teacher checkpoint not found at {teacher_ckpt_path}. KD will use initialized teacher weights.")
    else:
        teacher_model.load(teacher_ckpt_path)
        logger.info(f"Loaded teacher model checkpoint from {teacher_ckpt_path}")

    # The teacher is a fixed reference: eval mode, no gradients on any weight.
    teacher_model.eval()
    for param in teacher_model.parameters():
        param.requires_grad = False
    logger.info("Teacher model weights frozen and set to eval mode.")

except Exception as teacher_setup_err:
    logger.error(f"Major error setting up teacher model: {teacher_setup_err}", exc_info=True)
    teacher_model = None


# Cell 9: Student Model Definition
logger.info("--- Defining Student Model for Knowledge Distillation ---")
gan_student = None
if teacher_model is None:
    logger.error("Teacher model not available. Cannot define student model.")
else:
    try:
        # Smaller than the teacher (style network, base dim, discriminator
        # capacity) while keeping the same input/output image sizes.
        student_generator_config = {
            "style_network": {"dim": 32, "depth": 3},
            "dim": 16,
            "image_size": 256,
            "input_image_size": 64,
            "unconditional": True,
            "flash_attn": False,
        }
        student_discriminator_config = {
            "dim_capacity": 8,
            "dim_max": 256,
            "image_size": 256,
            "num_skip_layers_excite": 2,
            "multiscale_input_resolutions": (128,),
            "unconditional": True,
        }
        # AMP flag for the student during KD; should be True if the teacher's
        # AMP setting was True.
        student_initial_amp = False

        logger.info(f"Initializing Student GigaGAN with AMP={student_initial_amp}...")
        gan_student = GigaGAN(
            train_upsampler=True,  # Match teacher
            generator=student_generator_config,
            discriminator=student_discriminator_config,
            amp=student_initial_amp,
        ).cuda()
        logger.info(f"Student GigaGAN model structure initialized (untrained).")
    except Exception as student_def_err:
        logger.error(f"Failed to define student model: {student_def_err}", exc_info=True)
        gan_student = None


# Cell 10: Knowledge Distillation Training (Conceptual)
# This cell uses placeholder GAN losses and focuses on the distillation aspect.

logger.info("--- Conceptual Knowledge Distillation Training Setup ---")

# Train only when the earlier cells produced a teacher, a student and a dataset.
if 'teacher_model' in locals() and teacher_model is not None and \
   'gan_student' in locals() and gan_student is not None and \
   'dataset' in locals() and dataset is not None:

    logger.info("Teacher, Student, and Dataset found. Setting up KD parameters.")
    logger.warning("NOTE: This loop uses PLACEHOLDER GAN losses and is conceptual.")
    logger.warning("Actual KD training requires implementing real GAN loss calculations.")

    # --- Hyperparameters ---
    KD_TRAINING_STEPS = 100  # Reduced for quick test
    KD_BATCH_SIZE = 4        # Batch size for KD training
    KD_LR_G = 1e-4
    KD_LR_D = 1e-4
    LAMBDA_KD = 10.0         # Weight of the distillation term in the generator loss
    KD_GRAD_ACCUM_EVERY = 1
    KD_LOG_EVERY = 20
    KD_SAVE_EVERY = 50

    # Dedicated dataloader so KD can use its own batch size.
    kd_dataloader = DataLoader(dataset, batch_size=KD_BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True)
    logger.info(f"Created KD Dataloader with batch size {KD_BATCH_SIZE}.")

    student_optimizer_G = torch.optim.Adam(gan_student.G.parameters(), lr=KD_LR_G, betas=(0.9, 0.99))
    student_optimizer_D = torch.optim.Adam(gan_student.D.parameters(), lr=KD_LR_D, betas=(0.9, 0.99))

    # Grad scalers / autocast are active only when the student runs in fp16.
    is_student_amp_fp16_kd = (gan_student.accelerator.state.mixed_precision == 'fp16')
    logger.info(f"Student model KD training AMP (fp16) status: {is_student_amp_fp16_kd}")
    student_scaler_G = torch.amp.GradScaler(device='cuda', enabled=is_student_amp_fp16_kd)
    student_scaler_D = torch.amp.GradScaler(device='cuda', enabled=is_student_amp_fp16_kd)

    gan_student.train()   # Student in training mode
    teacher_model.eval()  # Teacher stays in eval

    logger.info(f"Starting conceptual KD training for {KD_TRAINING_STEPS} steps...")

    # Both generators take a low-resolution image plus a style noise vector.
    student_gen_input_size = gan_student.unwrapped_G.input_image_size
    resize_transform_student_kd = T.Resize((student_gen_input_size, student_gen_input_size), antialias=True)

    style_dim_student_kd = gan_student.unwrapped_G.style_network.dim
    style_dim_teacher_kd = teacher_model.unwrapped_G.style_network.dim

    teacher_gen_input_size_kd = teacher_model.unwrapped_G.input_image_size
    resize_transform_teacher_kd = None
    dataloader_image_size = IMAGE_SIZE  # Size the dataset was configured with
    if teacher_gen_input_size_kd != dataloader_image_size:
        resize_transform_teacher_kd = T.Resize((teacher_gen_input_size_kd, teacher_gen_input_size_kd), antialias=True)

    kd_dataloader_iter = iter(kd_dataloader)
    current_device_kd = gan_student.accelerator.device
    teacher_device_kd = teacher_model.accelerator.device
    is_teacher_amp_fp16_kd = (teacher_model.accelerator.state.mixed_precision == 'fp16')

    # `step` is pre-initialised so the post-loop log line is valid even when
    # KD_TRAINING_STEPS is 0; the flag records whether the loop ran to the end.
    kd_finished_all_steps = False
    step = -1
    for step in range(KD_TRAINING_STEPS):
        try:
            # Fetch the next batch, restarting the iterator once it is exhausted.
            try:
                real_images_full_res = next(kd_dataloader_iter)
                if isinstance(real_images_full_res, (list, tuple)):
                    real_images_full_res = real_images_full_res[0]
            except StopIteration:
                logger.info(f"KD Dataloader exhausted at step {step}. Resetting.")
                kd_dataloader_iter = iter(kd_dataloader)
                real_images_full_res = next(kd_dataloader_iter)
                if isinstance(real_images_full_res, (list, tuple)):
                    real_images_full_res = real_images_full_res[0]

            real_images_full_res = real_images_full_res.to(current_device_kd)
            current_batch_size_dynamic = real_images_full_res.size(0)
            if current_batch_size_dynamic == 0:
                logger.warning("KD: Fetched empty batch.")
                continue

            # Student's low-resolution input
            lowres_input_for_student_kd = resize_transform_student_kd(real_images_full_res)

            # Teacher's low-resolution input
            if resize_transform_teacher_kd:
                lowres_input_for_teacher_kd = resize_transform_teacher_kd(real_images_full_res)
            elif teacher_gen_input_size_kd == real_images_full_res.shape[-1]:
                lowres_input_for_teacher_kd = real_images_full_res.clone()
            else:
                lowres_input_for_teacher_kd = lowres_input_for_student_kd.clone()

            # Independent noise per model because the style dimensions differ.
            noise_batch_student_kd = torch.randn(current_batch_size_dynamic, style_dim_student_kd, device=current_device_kd)
            noise_batch_teacher_kd = torch.randn(current_batch_size_dynamic, style_dim_teacher_kd, device=teacher_device_kd)

            student_input_processed_G_kd = lowres_input_for_student_kd
            teacher_input_processed_G_kd = lowres_input_for_teacher_kd.to(teacher_device_kd)

            # --- Update Student Discriminator (Conceptual) ---
            student_optimizer_D.zero_grad(set_to_none=True)
            for _ in range(KD_GRAD_ACCUM_EVERY):
                # Placeholder loss: a constant that is not connected to any
                # discriminator parameter, so this update changes nothing.
                d_loss_placeholder = torch.tensor(1.0, device=current_device_kd, requires_grad=True)
                d_loss_accum = d_loss_placeholder / KD_GRAD_ACCUM_EVERY
                student_scaler_D.scale(d_loss_accum).backward()
            student_scaler_D.unscale_(student_optimizer_D)
            student_scaler_D.step(student_optimizer_D)
            student_scaler_D.update()

            # --- Update Student Generator (Conceptual GAN Loss + Real Distillation Loss) ---
            student_optimizer_G.zero_grad(set_to_none=True)
            for _ in range(KD_GRAD_ACCUM_EVERY):
                with torch.amp.autocast(device_type=current_device_kd.type, dtype=torch.float16, enabled=is_student_amp_fp16_kd):
                    # Student output (carries gradients)
                    fake_images_student = gan_student.G(
                        student_input_processed_G_kd,
                        noise=noise_batch_student_kd,  # G takes noise as keyword
                    )

                    # Placeholder GAN loss for the generator (constant, no grad)
                    g_loss_gan_placeholder = torch.tensor(0.5, device=current_device_kd)

                    # Teacher output is a fixed target: no grad, detached.
                    with torch.no_grad():
                        with torch.amp.autocast(device_type=teacher_device_kd.type, dtype=torch.float16, enabled=is_teacher_amp_fp16_kd):
                            fake_images_teacher = teacher_model.G(
                                teacher_input_processed_G_kd,
                                noise=noise_batch_teacher_kd,
                            ).detach()

                    # Real distillation loss: pixel-wise MSE computed in fp32.
                    distill_loss = F.mse_loss(fake_images_student.float(), fake_images_teacher.to(fake_images_student.device).float())

                    # Only distill_loss contributes gradients to the total.
                    total_g_loss = g_loss_gan_placeholder + (LAMBDA_KD * distill_loss)
                    total_g_loss_accum = total_g_loss / KD_GRAD_ACCUM_EVERY

                student_scaler_G.scale(total_g_loss_accum).backward()
            student_scaler_G.unscale_(student_optimizer_G)
            student_scaler_G.step(student_optimizer_G)
            student_scaler_G.update()

            if step % KD_LOG_EVERY == 0:
                # Log placeholder and real loss values
                d_loss_item = d_loss_placeholder.item()
                g_loss_gan_item = g_loss_gan_placeholder.item()
                distill_loss_item = distill_loss.item() if torch.is_tensor(distill_loss) else float('nan')
                total_g_loss_item = total_g_loss.item() if torch.is_tensor(total_g_loss) else float('nan')
                logger.info(f"KD Step [{step}/{KD_TRAINING_STEPS}] | D Loss (PH): {d_loss_item:.4f} | G Loss (PH): {g_loss_gan_item:.4f} | Distill Loss: {distill_loss_item:.4f} | Total G Loss: {total_g_loss_item:.4f}")

            # Periodic checkpoints (also on the last step), named by step.
            if step > 0 and (step % KD_SAVE_EVERY == 0 or step == KD_TRAINING_STEPS - 1):
                student_save_path = os.path.join(PROFILING_DIR, f'student_model_step_{step}.ckpt')
                try:
                    gan_student.save(student_save_path)
                    logger.info(f"Student checkpoint saved to {student_save_path}")
                except Exception as save_e:
                    logger.error(f"Failed to save student checkpoint at step {step}: {save_e}")

        except Exception as train_err:
            logger.error(f"Error during conceptual KD step {step}: {train_err}", exc_info=True)
            if step > 0:
                student_error_save_path = os.path.join(PROFILING_DIR, f'student_model_error_step_{step}.ckpt')
                try:
                    gan_student.save(student_error_save_path)
                    logger.info(f"Saved error checkpoint: {student_error_save_path}")
                except Exception as save_err:
                    logger.error(f"Could not save error checkpoint: {save_err}")
            break
    else:
        # for/else: reached only when the loop was never left via `break`.
        kd_finished_all_steps = KD_TRAINING_STEPS > 0

    logger.info(f"Conceptual KD training finished or stopped at step {step}.")
    student_final_ckpt_path = os.path.join(PROFILING_DIR, 'student_model_final.ckpt')
    # Always write the final checkpoint after a complete run. The profiling
    # cell looks for exactly this file name, and the per-step checkpoints use
    # different names, so tying this save to KD_TRAINING_STEPS % KD_SAVE_EVERY
    # would skip it for the configured 100/50 values.
    if kd_finished_all_steps:
        try:
            gan_student.save(student_final_ckpt_path)
            logger.info(f"Final student model saved to {student_final_ckpt_path}")
        except Exception as e:
            logger.error(f"Could not save final student model: {e}")

    logger.warning("This was conceptual training. Load a properly trained student checkpoint before profiling.")
else:
    # Report which prerequisite is missing.
    if 'teacher_model' not in locals() or teacher_model is None:
        logger.error("KD Training cannot start: Teacher model is missing/invalid.")
    elif 'gan_student' not in locals() or gan_student is None:
        logger.error("KD Training cannot start: Student model is missing/invalid.")
    elif 'dataset' not in locals() or dataset is None:
        logger.error("KD Training cannot start: Dataset is missing/invalid.")
    else:  # Should not happen if one of the above is true
        logger.error("KD Training cannot start: Unknown reason (Teacher, Student, or Dataset missing/invalid).")


# Cell 11: Trained Student GigaGAN Model Analysis and Profiling
# NOTE: When run directly after Cell 10, this cell profiles the conceptually trained student model.
# Purpose of this cell: obtain a student model (from a checkpoint file, or the
# in-memory one as a fallback), record parameter counts / size to JSON, then
# profile one generation call with torch.profiler and save the trace and image.
logger.info("--- Trained Student Model Analysis and Profiling ---")
# This cell assumes gan_student is either still in memory from KD or loaded from a checkpoint.
# For a clean run, it's better to explicitly load the desired student checkpoint.

gan_student_profiling = None # Use a new variable for clarity
# Point to the checkpoint saved from the *conceptual* training if you want to profile that.
# Otherwise, point to a checkpoint from a *real* training run.
# NOTE(review): the KD cell writes this file only when its final-save condition
# holds; when the file is absent the in-memory model is used instead (see below).
student_checkpoint_to_profile = os.path.join(PROFILING_DIR, 'student_model_final.ckpt')

if os.path.exists(student_checkpoint_to_profile):
    logger.info(f"Loading student model from {student_checkpoint_to_profile} for profiling...")
    try:
        # The student configs are defined by the student-definition cell; without
        # them the architecture cannot be rebuilt for loading.
        if 'student_generator_config' not in globals() or 'student_discriminator_config' not in globals():
            raise NameError("Student model configurations not found. Cannot reload student model.")
        student_load_amp = False # Match the amp setting used when saving
        gan_student_profiling = GigaGAN(
            train_upsampler=True, generator=student_generator_config,
            discriminator=student_discriminator_config, amp=student_load_amp
        ).cuda()
        gan_student_profiling.load(student_checkpoint_to_profile)
        logger.info("Student model loaded successfully for profiling.")
        gan_student_profiling.eval()
    except Exception as e_load_stud:
        logger.error(f"Failed to load student model {student_checkpoint_to_profile} for profiling: {e_load_stud}", exc_info=True)
        gan_student_profiling = None
else:
    # No checkpoint on disk: fall back to the model object left by the KD cell.
    logger.warning(f"Student checkpoint {student_checkpoint_to_profile} not found. Profiling will be skipped or use in-memory model if available.")
    if 'gan_student' in locals() and gan_student is not None:
        logger.info("Using in-memory 'gan_student' (potentially only conceptually trained) for profiling.")
        gan_student_profiling = gan_student
        gan_student_profiling.eval()
    else:
        gan_student_profiling = None

if gan_student_profiling is not None:
    current_device_stud = gan_student_profiling.accelerator.device
    architecture_stats_student = {"model_type": "student_kd_conceptual"} # Note conceptual
    try:
        # Parameter counts for generator and discriminator separately.
        gen_params_s = sum(p.numel() for p in gan_student_profiling.G.parameters())
        disc_params_s = sum(p.numel() for p in gan_student_profiling.D.parameters()); total_params_s = gen_params_s + disc_params_s
        # Size in MiB computed from parameters only (numel * bytes per element).
        model_size_mb_s = sum(p.numel()*p.element_size() for p in gan_student_profiling.parameters())/ (1024*1024)
        architecture_stats_student.update({
            # NOTE(review): "A100" is a hard-coded label for any CUDA device; the
            # actual GPU model is not queried here.
            "gpu_type": "A100" if current_device_stud.type == 'cuda' else "CPU",
            "amp_setting": gan_student_profiling.accelerator.state.mixed_precision,
            "generator_parameters": gen_params_s, "discriminator_parameters": disc_params_s, "total_parameters": total_params_s,
            "model_size_mb": model_size_mb_s
        })
        logger.info(f"[Student KD Conceptual] G Params: {gen_params_s:,}, D Params: {disc_params_s:,}, Total: {total_params_s:,}, Size: {model_size_mb_s:.2f}MB")
    except Exception as e_arch_stud: logger.error(f"Student arch analysis error: {e_arch_stud}", exc_info=True)

    # FLOP counting is not performed; the JSON records it as 'Skipped'.
    logger.warning("Skipping FLOPs for student generator.")
    architecture_stats_student['generator_gflops'] = 'Skipped'
    stats_path_s = os.path.join(PROFILING_DIR, 'architecture_stats_student_kd_conceptual_a100.json') # Note conceptual
    with open(stats_path_s, 'w') as f: json.dump(architecture_stats_student, f, indent=4)
    logger.info(f"Saved student arch stats to {stats_path_s}")

    # --- Profile Generation (Student) ---
    logger.info("--- Profiling Conceptually Trained Student Model Generation ---")
    try:
        input_size_s = gan_student_profiling.unwrapped_G.input_image_size
        style_dim_s = gan_student_profiling.unwrapped_G.style_network.dim
        # Profiling inputs are random tensors (batch of 1), not real images.
        gen_input_lowres_s = torch.randn(1, 3, input_size_s, input_size_s).to(current_device_stud)
        gen_input_noise_s = torch.randn(1, style_dim_s).to(current_device_stud)
        output_dir_student_imgs = os.path.join(PROFILING_DIR, "generated_images_student_kd_conceptual") # Note conceptual
        os.makedirs(output_dir_student_imgs, exist_ok=True)

        # Cast the inputs to half precision only when the model runs in fp16.
        is_student_model_amp_fp16_prof = (gan_student_profiling.accelerator.state.mixed_precision == 'fp16')
        if is_student_model_amp_fp16_prof:
            gen_input_lowres_s, gen_input_noise_s = gen_input_lowres_s.half(), gen_input_noise_s.half()
        else:
            gen_input_lowres_s, gen_input_noise_s = gen_input_lowres_s.float(), gen_input_noise_s.float()

        # Warm-up: one un-profiled generation whose result is discarded.
        logger.info(f"Warm-up for student profiling (Model AMP: {is_student_model_amp_fp16_prof}, Input: {gen_input_lowres_s.dtype})...")
        with torch.inference_mode():
            with torch.amp.autocast(device_type=current_device_stud.type, enabled=is_student_model_amp_fp16_prof):
                _ = gan_student_profiling.generate(lowres_image=gen_input_lowres_s.clone(), noise=gen_input_noise_s.clone())
        if current_device_stud.type == 'cuda': torch.cuda.synchronize()

        logger.info("Starting student profiler...")
        # CPU activity is always recorded; CUDA activity only on a CUDA device.
        activities_s = [ProfilerActivity.CPU]
        if current_device_stud.type == 'cuda': activities_s.append(ProfilerActivity.CUDA)
        with profile(activities=activities_s, record_shapes=True, profile_memory=True, with_stack=True) as prof_s:
            with record_function("student_model_inference"):
                with torch.inference_mode():
                    with torch.amp.autocast(device_type=current_device_stud.type, enabled=is_student_model_amp_fp16_prof):
                        images_s = gan_student_profiling.generate(lowres_image=gen_input_lowres_s, noise=gen_input_noise_s)
        if current_device_stud.type == 'cuda': torch.cuda.synchronize()
        logger.info("Student profiling complete.")
        # Move to CPU as float32 and clamp into [0, 1] before writing the PNG.
        images_s = images_s.float().cpu().clamp(0., 1.)
        save_image(images_s, os.path.join(output_dir_student_imgs, "image_profiled_student_conceptual_a100.png")) # Note conceptual
        logger.info(f"Saved profiled student image.")
        # Print the 15 most expensive ops and export a Chrome-trace JSON file.
        sort_key_s = "cuda_time_total" if current_device_stud.type == 'cuda' else "cpu_time_total"
        print(prof_s.key_averages().table(sort_by=sort_key_s, row_limit=15))
        profiler_output_path_s = os.path.join(PROFILING_DIR, 'profiler_student_conceptual_trace_a100.json') # Note conceptual
        prof_s.export_chrome_trace(profiler_output_path_s)
        logger.info(f"Student profiler trace saved: {profiler_output_path_s}")
    except Exception as e_prof_stud: logger.error(f"Student profiling error: {e_prof_stud}", exc_info=True)
    finally:
        # Drop the profiling tensors/profiler (if they were created) and free
        # cached GPU memory before the metrics cell runs.
        if 'images_s' in locals(): del images_s;
        if 'gen_input_lowres_s' in locals(): del gen_input_lowres_s;
        if 'gen_input_noise_s' in locals(): del gen_input_noise_s;
        if 'prof_s' in locals(): del prof_s;
        if torch.cuda.is_available(): torch.cuda.empty_cache();
        gc.collect()
else:
    logger.error("Trained student model not available for profiling.")


# Cell 12: Metrics Calculation (Trained Student Model)
logger.info("--- GAN Evaluation Metrics Calculation (Conceptually Trained Student Model) ---")
# Re-define helper functions if this notebook is run standalone
def calculate_metrics_from_directories(generated_dir, real_dir=None, batch_size=50, prefix=""):
    """Compute Inception Score (and FID when possible) for a folder of images.

    Args:
        generated_dir: Directory holding the generated images.
        real_dir: Optional directory of real images. FID is computed only when
            this is given and the directory exists and contains images.
        batch_size: Batch size forwarded to ``torch_fidelity.calculate_metrics``.
        prefix: Text prepended to every log message.

    Returns:
        A dict with ``inception_score_mean`` and ``inception_score_std`` (plus
        ``frechet_inception_distance`` when FID was computed), or ``None`` when
        ``generated_dir`` is invalid/empty or the metric calculation raised.
    """
    def _has_images(directory):
        # True when `directory` exists and at least one file matches the pattern.
        return os.path.isdir(directory) and bool(glob.glob(os.path.join(directory, "*.[pj][np]g")))

    if not _has_images(generated_dir):
        logger.error(f"{prefix}Generated images directory is invalid or empty: {generated_dir}")
        return None

    if real_dir and not _has_images(real_dir):
        logger.warning(f"{prefix}Real images directory is invalid or empty: {real_dir}. Skipping FID.")
        real_dir = None

    try:
        metrics_to_calc = {'isc': True, 'fid': bool(real_dir)}
        logger.info(f"{prefix}Calculating metrics: {metrics_to_calc} for {generated_dir}...")
        metrics_output = torch_fidelity.calculate_metrics(
            input1=generated_dir,
            input2=real_dir,
            isc=metrics_to_calc['isc'],
            fid=metrics_to_calc['fid'],
            batch_size=batch_size,
            cuda=torch.cuda.is_available(),
            save_cpu_ram=True,
            verbose=False,
        )

        nan = float('nan')
        scores = {
            'inception_score_mean': metrics_output.get('inception_score_mean', nan),
            'inception_score_std': metrics_output.get('inception_score_std', nan),
        }
        logger.info(f"{prefix}Inception Score: {scores['inception_score_mean']:.4f} ± {scores['inception_score_std']:.4f}")

        if real_dir:
            scores['frechet_inception_distance'] = metrics_output.get('frechet_inception_distance', nan)
            logger.info(f"{prefix}FID Score: {scores['frechet_inception_distance']:.4f}")
        return scores
    except Exception as e:
        logger.error(f"{prefix}Error during metric calculation for {generated_dir}: {e}", exc_info=True)
        return None

def generate_images_for_metrics(gan_model, output_dir, num_images=100, batch_size=4, prefix=""):
    """Generate images with ``gan_model`` and save them as PNG files.

    Files matching ``*.*g`` in ``output_dir`` are removed first. Inputs to the
    generator are random tensors (low-res image and noise), cast to half
    precision when the model's accelerator reports fp16 mixed precision.
    A batch that raises is logged and skipped.

    Args:
        gan_model: Model exposing ``eval()``, ``generate(lowres_image=, noise=)``,
            ``accelerator`` and ``G`` (optionally ``unwrapped_G``).
        output_dir: Directory that receives ``metric_image_XXXXX.png`` files.
        num_images: Number of images to write.
        batch_size: Images generated per call.
        prefix: Text prepended to every log message.

    Returns:
        ``output_dir`` when at least one image was written, otherwise ``None``
        (also ``None`` when ``gan_model`` is ``None``).
    """
    if gan_model is None:
        logger.error(f"{prefix}GAN model is None.")
        return None

    os.makedirs(output_dir, exist_ok=True)
    logger.info(f"{prefix}Clearing existing images in {output_dir}...")
    stale_files = glob.glob(os.path.join(output_dir, "*.*g"))
    for stale_file in stale_files:
        os.remove(stale_file)

    gan_model.eval()
    num_batches = (num_images + batch_size - 1) // batch_size  # ceiling division
    image_count = 0
    device = gan_model.accelerator.device
    is_model_amp_fp16 = (gan_model.accelerator.state.mixed_precision == 'fp16')

    # Read generator input sizes, with fallbacks when an attribute is absent.
    unwrapped_G = getattr(gan_model, 'unwrapped_G', gan_model.G)
    input_size = getattr(unwrapped_G, 'input_image_size', 64)
    style_network = getattr(unwrapped_G, 'style_network', None)
    style_dim = getattr(style_network, 'dim', 512)

    input_dtype = torch.float16 if is_model_amp_fp16 else torch.float32

    logger.info(f"{prefix}Generating {num_images} images for metrics in {output_dir} (Model AMP: {is_model_amp_fp16})...")
    pbar = tqdm(range(num_batches), desc=f"{prefix}Generating for {output_dir}")
    for batch_idx in pbar:
        current_batch_size = min(batch_size, num_images - image_count)
        if current_batch_size <= 0:
            break

        lowres_images = torch.randn(current_batch_size, 3, input_size, input_size, device=device)
        noise = torch.randn(current_batch_size, style_dim, device=device)
        lowres_images = lowres_images.to(input_dtype)
        noise = noise.to(input_dtype)

        images_gen = None
        try:
            with torch.inference_mode(), torch.amp.autocast(device_type=device.type, enabled=is_model_amp_fp16):
                images_gen = gan_model.generate(lowres_image=lowres_images, noise=noise)
            images_gen = images_gen.float().cpu().clamp(0., 1.)
            for img_tensor in images_gen:
                if image_count >= num_images:
                    break
                target_path = os.path.join(output_dir, f'metric_image_{image_count:05d}.png')
                save_image(img_tensor, target_path)
                image_count += 1
            pbar.set_postfix({"images_generated": image_count})
        except Exception as e_gen:
            logger.error(f"{prefix}Error gen batch {batch_idx}: {e_gen}", exc_info=True)
            continue
        finally:
            # Release this batch's tensors before the next iteration.
            del lowres_images, noise
            images_gen = None
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        if image_count >= num_images:
            break

    logger.info(f"{prefix}Generated {image_count} images to {output_dir}")
    if image_count > 0:
        return output_dir
    return None

metrics_student_results = None
# gan_student_profiling is the student model prepared by the profiling cell
# (conceptually trained in this notebook).
if gan_student_profiling is not None:
    metrics_output_dir_student = os.path.join(PROFILING_DIR, 'generated_images_for_metrics_student_kd_conceptual')  # Note conceptual
    gen_dir_stud = generate_images_for_metrics(
        gan_student_profiling,
        metrics_output_dir_student,
        num_images=100,  # Reduced num_images
        batch_size=8,
        prefix="[Student KD Conceptual Metrics] ",
    )
    if gen_dir_stud:
        logger.info("Calculating metrics for CONCEPTUALLY TRAINED STUDENT (KD) model...")
        # real_dir=None: Inception Score only, no FID.
        metrics_student_results = calculate_metrics_from_directories(
            gen_dir_stud, real_dir=None, batch_size=16, prefix="[Student KD Conceptual Metrics] ")
else:
    logger.error("Conceptually trained student model (gan_student_profiling) not available for metrics calculation.")

if metrics_student_results:
    # Default to NaN rather than a string: a str default cannot be rendered
    # with the ":.4f" format spec and would raise ValueError if a key is missing.
    final_is_mean = metrics_student_results.get('inception_score_mean', float('nan'))
    final_is_std = metrics_student_results.get('inception_score_std', float('nan'))
    logger.info(f"Final Student (KD Conceptual) Model Inception Score: {final_is_mean:.4f} ± {final_is_std:.4f}")
    # FID is present only when a real-image directory was supplied.
    final_fid = metrics_student_results.get('frechet_inception_distance')
    if final_fid is not None:
        logger.info(f"Final Student (KD Conceptual) Model FID Score: {final_fid:.4f}")

logger.info("Knowledge Distillation GigaGAN Notebook Finished")
# Drop the model references and release cached GPU memory.
if 'teacher_model' in locals(): del teacher_model
if 'gan_student' in locals(): del gan_student
if 'gan_student_profiling' in locals(): del gan_student_profiling
if torch.cuda.is_available(): torch.cuda.empty_cache()
gc.collect()


Collecting datasets
  Downloading datasets-3.6.0-py3-none-any.whl.metadata (19 kB)
Collecting validators
  Downloading validators-0.35.0-py3-none-any.whl.metadata (3.9 kB)
Collecting fvcore
  Downloading fvcore-0.1.5.post20221221.tar.gz (50 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m50.2/50.2 kB[0m [31m2.3 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting bitsandbytes
  Downloading bitsandbytes-0.45.5-py3-none-manylinux_2_24_x86_64.whl.metadata (5.0 kB)
Collecting torch-fidelity
  Downloading torch_fidelity-0.3.0-py3-none-any.whl.metadata (2.0 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Collect





Generator: 11.73M
Discriminator: 8.06M






-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg       CPU Mem  Self CPU Mem      CUDA Mem  Self CUDA Mem    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                student_model_inference         0.00%       0.000us         0.00%       0.000us       0.000us     101.487ms       361.26%     101.487ms     101.487ms           0 b           0 b           0 b           0 

[Student KD Conceptual Metrics] Generating for profiling_results_kd/generated_images_for_metrics_student_kd_conceptual:  92%|█████████▏| 12/13 [00:04<00:00,  2.72it/s, images_generated=100]
Downloading: "https://github.com/toshas/torch-fidelity/releases/download/v0.2.0/weights-inception-2015-12-05-6726825d.pth" to /root/.cache/torch/hub/checkpoints/weights-inception-2015-12-05-6726825d.pth
100%|██████████| 91.2M/91.2M [00:00<00:00, 263MB/s]
  img = torch.ByteTensor(torch.ByteStorage.from_buffer(img.tobytes())).view(height, width, 3)


84729