<a href="https://colab.research.google.com/github/joe-singh/Superconducting-Diffusion/blob/main/SuperconductingDiffusion.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Torchvision for FID

In [None]:
!pip install torchvision
!pip install torch-fidelity

## Mount Google Drive

In [None]:
from google.colab import drive
# Mount Google Drive at /content/drive; the timestream CSV loaded further down
# is read from a path under this mount point.
drive.mount('/content/drive')

## Set Random Seeds

In [None]:
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np
import random
import os
import matplotlib.pyplot as plt
from scipy.stats import norm
import scipy.stats as stats
from scipy.signal import welch

RANDOM_SEED = 35

# ------------------------------------------------
# 0. Pin every RNG so that runs differ only in the noise source
# ------------------------------------------------

def set_seed(seed=1234):
    """Seed the Python, NumPy and PyTorch (CPU + all CUDA devices) RNGs.

    Also switches cuDNN to deterministic mode with autotuning disabled,
    trading some speed for run-to-run reproducibility.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Deterministic cuDNN kernels only; no benchmark-based kernel selection.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

set_seed(RANDOM_SEED)

## Load Physical Timestream, visualise in histogram

In [None]:

# Path of the recorded timestream CSV on the mounted Google Drive.
# NOTE(review): the file name suggests 67,108,864 samples at 25 MHz - confirm against the acquisition log.
timestream = "/content/drive/MyDrive/SnoopyCooldownLogs/250227_Uninsulated_Coil_TuneThrough0/250228/Timestreams/BP2_25MHz_67108864samples.csv"

# Lines starting with "#" are skipped as comments; the result is forced to float64.
x = np.loadtxt(timestream, comments="#")
x = np.asarray(x, dtype=np.float64)

from scipy.signal import decimate

# Optional 10x FIR decimation, currently disabled.
#x = decimate(x, 10, ftype="fir", zero_phase=True)

#x = load_labone_scope_trace(fname)

print("Loaded", x.size, "samples.")
print("Mean:", x.mean(), "Std:", x.std())

# Left panel: density histogram of all samples with a Gaussian overlay.
# Right panel: QQ plot of a random subsample against a normal distribution.
fig, ax = plt.subplots(1,2, figsize=(10,5))
ax[0].hist(x, bins=40, density=True, alpha=0.6, label="Data")

# fit a normal distribution
# (moment matching: the curve uses the sample mean and standard deviation)
mu, sigma = x.mean(), x.std()
xx = np.linspace(mu-4*sigma, mu+4*sigma, 400)
ax[0].plot(xx, norm.pdf(xx, mu, sigma), 'r--', label="Gaussian fit")

ax[0].set_title("Histogram of noise samples")
ax[0].set_xlabel("Value (V)")
ax[0].set_ylabel("Probability density")
ax[0].legend()

# Cap the QQ plot at 10 million points, drawn without replacement with a seeded
# generator so the subsample is the same on every run.
n = min(len(x), 10_000_000)
rng = np.random.default_rng(RANDOM_SEED)
xs = rng.choice(x, size=n, replace=False)
stats.probplot(xs, dist="norm", plot=ax[1])
ax[1].set_title("QQ Plot vs Gaussian")


## Bandpassing + Reshaping

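This cell filters the physical noise (a high-pass followed by a narrow band-pass around the resonator frequency) and mixes it with synthetic white Gaussian noise. The crucial parameter $p$, defined here, is the fraction of the mixed signal's variance that comes from the band-passed resonator noise: $y = \sqrt{1-p}\,w + \sqrt{p}\,r$, where $w$ and $r$ are the standardised white and resonator components. The same $p$ is used throughout the diffusion process later.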

In [None]:
import numpy as np
from scipy.signal import butter, sosfiltfilt, detrend
from scipy.signal import welch

# Build the "engineered" noise: high-pass and band-pass the recorded timestream `x`,
# then mix the result with seeded white Gaussian noise. Defines the globals
# `r`, `w`, `y` and the mixing fraction `p` that later cells read.
fs = 25e6 # Sampling Frequency

x0 = np.asarray(x, dtype=np.float64)

# High-pass filter
f_hp = 1e3   # example cutoff: 1 kHz
sos_hp = butter(
    N=6,
    Wn=f_hp,
    btype="highpass",
    fs=fs,
    output="sos"
)

# highpass
# (sosfiltfilt runs the filter forward and backward, i.e. zero phase)
r = sosfiltfilt(sos_hp, x0)

# 1) isolate resonator band (tight!)
# NOTE(review): a 10 Hz band at fs = 25 MHz is an extremely small relative bandwidth
# for a 6th-order Butterworth design - verify the filter response and stability.
f0 = 2.47794e6
bw = 10
f_low, f_high = f0 - bw/2, f0 + bw/2

sos = butter(6, [f_low, f_high], btype="bandpass", fs=fs, output="sos")
# bandpass
r = sosfiltfilt(sos, r)

# trim edges
# (drops 0.1 s worth of samples, int(0.1 * fs), from each end of the filtered signal)
trim = int(0.1 * fs)
r = r[trim:-trim]

# 2) make synthetic flat background
# (seeded white Gaussian samples, same length as the trimmed signal)
rng = np.random.default_rng(RANDOM_SEED)

w = rng.standard_normal(len(r))


# zero-mean + unit-variance
r_u = (r - r.mean()) / r.std(ddof=0)
w_u = (w - w.mean()) / w.std(ddof=0)   # ~1 already, but do it for exactness

############################
#  Noise Mixing Parameter
############################

p = 0.025 # <-- fraction of variance from the bandpassed physical signal (0..1)

# Weighted sum of the two standardised components with weights sqrt(1 - p) and sqrt(p).
y = np.sqrt(1 - p) * w_u + np.sqrt(p) * r_u


## Save engineered noise to .npy file, visualise saved noise

In [None]:
def save_noise(array, fname):
    """Standardise ``array`` and write it to ``fname`` with ``np.save``.

    The data is shifted to zero mean and divided by ``std + 1e-8`` (the small
    constant guards against division by zero for constant input) before saving.

    Args:
        array: Array-like of numeric samples; converted with ``np.asarray``.
        fname: Destination path. ``np.save`` appends ``.npy`` if it is missing.
    """
    values = np.asarray(array)
    values = values - values.mean()
    values = values / (values.std() + 1e-8)
    np.save(fname, values)
    print("Saved", fname)

# Output file for the standardised noise; the training config reads this same file name.
out = "physical_noise_unitvar.npy"

# y if using the bandpassed superconducting array, x if using default array
array_to_use = y

save_noise(array_to_use, out)

# Load physical noise
# (read back from disk so the plots show exactly what was saved)
phys_noise = np.load("physical_noise_unitvar.npy")

# Create equivalent digital noise
# (same length, drawn from NumPy's global RNG, which set_seed seeded earlier)
dig_noise = np.random.randn(len(phys_noise))

print(f"Physical Std: {np.std(phys_noise):.4f}")
print(f"Digital Std:  {np.std(dig_noise):.4f}")

# Plot histograms
# (both restricted to [-5, 5] with 100 bins so the two densities are directly comparable)
plt.figure(figsize=(10, 6))
plt.hist(dig_noise, bins=100, alpha=0.5, label='Digital (Ideal)', density=True, range=(-5, 5))
plt.hist(phys_noise, bins=100, alpha=0.5, label='Physical (Resonator)', density=True, range=(-5, 5))
plt.yscale('log') # Log scale reveals the tails!
plt.legend()
plt.title("Noise Distribution Check (Log Scale)")
plt.show()


## PSD Visualisation of Engineered Noise

In [None]:
from scipy.signal import welch, sosfreqz
import matplotlib.pyplot as plt
import numpy as np

fs = 25e6  # sampling frequency [Hz]; same value as in the filtering cell

# Welch PSD of the reference digital noise (kept for the optional overlay below)
f, Pxx = welch(dig_noise, fs=fs, window="hann", nperseg=2**18, detrend="constant", scaling="density")

# Welch PSD of the engineered (saved and reloaded) noise
f, Pyy = welch(phys_noise, fs=fs, window="hann", nperseg=2**18, detrend="constant", scaling="density")

plt.figure(figsize=(8,5))
# "%g" prints p with all its significant digits; the former "%.2f" rounded
# p = 0.025 to "0.03", so the title misreported the mixing fraction.
plt.title('Engineered Noise, p = %g' % p)
#plt.semilogy(f, Pxx, c='b',label="Digital")
plt.semilogy(f, Pyy, c='r', label='Engineered Physical')
#plt.xlim(2.47e6, 2.48e6)
plt.xlabel("Frequency [Hz]", fontsize=15)
plt.ylabel("PSD [1/Hz]", fontsize=15)
plt.grid(True, which="both")
plt.semilogx()  # no data arguments: only switches the x axis to log scale (log-log plot overall)
plt.legend()
plt.show()

## Visualisation of Noise in Image Format

In [None]:
import numpy as np
import matplotlib.pyplot as plt

noise = np.load("physical_noise_unitvar.npy")

# take first 32*32 samples and view them as one single-channel 32x32 "image"
img = noise[:32*32].reshape(32, 32)

# i.i.d. Gaussian reference of the same size (does not depend on p)
dig = np.random.randn(32*32).reshape(32, 32)

# centred 2-D spectrum of the physical-noise image
fft = np.fft.fftshift(np.fft.fft2(img))

# "{p:g}" prints p with all its significant digits; the former "%.2f" rounded
# p = 0.025 to "0.03".
panels = [
    (img, "gray", rf"Physical, $p$ = {p:g}"),
    (dig, "gray", "Digital (i.i.d. Gaussian)"),
    (np.log(np.abs(fft)+1e-6), "inferno", rf"Physical log |FFT|, $p$ = {p:g}"),
]

fig_img, axes_img = plt.subplots(1, 3, figsize=(8, 3))
for panel_ax, (panel_data, panel_cmap, panel_title) in zip(axes_img, panels):
    panel_ax.imshow(panel_data, cmap=panel_cmap)
    panel_ax.set_title(panel_title)
    panel_ax.axis("off")

# tight_layout must run before show(); calling it afterwards acted on a new,
# empty figure and displayed a blank plot.
fig_img.tight_layout()
plt.show()



# Diffusion With Physical Noise

First initialise wandb

In [None]:
import wandb
from google.colab import userdata

# Read the W&B API key from the Colab secret named 'WANDB_API_KEY' instead of
# hard-coding it in the notebook, then authenticate this session with it.
api_key = userdata.get('WANDB_API_KEY')
wandb.login(key=api_key)

The cell below contains a full diffusion pipeline using HuggingFace's Diffusers package. It has the option to run CIFAR-10 and CelebA.

The `USE_PHYSICAL` flag controls whether physical noise is used. The key parameter is the value $p$ set in the noise band-passing cell above.

Config parameters such as the number of epochs, the number of inference steps and the LR schedule are set in the `config` dictionary at the bottom of this cell.

In [None]:
import os
import math
import shutil
import random
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
from torchvision import transforms
from torchvision.datasets import CIFAR10, CelebA
import torchvision.utils as vutils

from tqdm.auto import tqdm

from diffusers import UNet2DModel, DDPMScheduler
from diffusers.optimization import get_cosine_schedule_with_warmup
from diffusers.training_utils import EMAModel
from accelerate import Accelerator

from PIL import Image
from torch_fidelity import calculate_metrics

##############################
#          Params
##############################

# True: draw training/sampling noise from the saved physical-noise file;
# False: use torch's Gaussian generator. `tag` only affects the run name.
USE_PHYSICAL = True
tag = "physical" if USE_PHYSICAL else "digital"


# Dataset switch; path, resolution and number of sampling steps follow from it.
#dataset = "celeba64"
dataset = "cifar10"
data_path = "./data" if dataset == "celeba64" else "./cifar10_data"
image_size = 64 if dataset == "celeba64" else 32
inference_steps = 10 if dataset == "celeba64" else 50

# RANDOM_SEED and p come from earlier cells (seed cell and noise-mixing cell),
# so those cells must have been run before this one.
run_name = f"run1_diffusers_{dataset}_{tag}_raw_noise_seed_{RANDOM_SEED}"

if USE_PHYSICAL:
  run_name += f"_p_{p}"


# --------------------------
# 0) Reproducibility
# --------------------------
def set_seed(seed: int):
    """Seed the Python, NumPy and PyTorch (CPU + all CUDA devices) RNGs.

    This definition replaces the ``set_seed`` from the seed cell at the top of
    the notebook, so it applies the same cuDNN settings; without them training
    would run with non-deterministic cuDNN kernels.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Deterministic cuDNN kernels, no autotuning (slower but reproducible).
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


# ------------------------------------
# 1) Physical noise RNG and normaliser
# ------------------------------------
class PhysicalNoiseRNG:
    """Serve noise tensors cut from a pre-recorded noise array stored as ``.npy``.

    The file is loaded once, flattened to 1-D, cast to float32 and moved to
    ``device``. Each ``get`` call returns a randomly positioned *contiguous*
    window of that buffer, so the sample ordering of the recording is kept.

    Args:
        path: Path of the ``.npy`` file holding the noise samples.
        device: Torch device on which the buffer (and every returned tensor) lives.
        seed: Seed of the NumPy generator that picks the window start positions.
    """

    def __init__(self, path: str, device: str = "cuda", seed: int = RANDOM_SEED):
        # Flatten so that slicing in get() always runs over individual samples,
        # even if the file happens to hold a multi-dimensional array.
        noise = np.load(path).astype(np.float32).reshape(-1)
        self.noise = torch.from_numpy(noise).to(device)
        self.N = self.noise.numel()
        self.device = device
        self.rng = np.random.default_rng(seed)

    def get(self, shape):
        """Return a tensor of ``shape`` filled from one contiguous noise window.

        Args:
            shape: Target shape (int, tuple/list of ints, or ``torch.Size``).

        Returns:
            A view into the internal buffer reshaped to ``shape``. It shares
            memory with the buffer, so callers must not modify it in place.

        Raises:
            ValueError: If a dimension is negative or more samples are
                requested than the buffer holds.
        """
        if isinstance(shape, int):
            shape = (shape,)
        shape = tuple(int(dim) for dim in shape)
        if any(dim < 0 for dim in shape):
            raise ValueError(f"Shape dimensions must be non-negative, got {shape}.")

        # Plain integer product; no need to build a tensor just to count elements.
        num = math.prod(shape)

        if num > self.N:
            raise ValueError(f"Requested {num} samples, but noise buffer has {self.N}.")

        # Random contiguous window, preserve ordering
        # (upper bound is exclusive, hence the +1 so the last valid start is reachable)
        start = int(self.rng.integers(0, self.N - num + 1))
        out = self.noise[start:start + num]

        return out.view(shape)

"""
def local_unit_normalize(x: torch.Tensor, eps: float = 1e-6):

  Normalize each sample in a batch independently to mean 0, std 1.
  x: [B,C,H,W]

  dims = (1, 2, 3)
  x = x - x.mean(dim=dims, keepdim=True)
  x = x / (x.std(dim=dims, keepdim=True) + eps)
  return x
"""



# --------------------------
# 2) Dataset: images only
# --------------------------
class CIFAR10ImagesOnly(CIFAR10):
    """CIFAR-10 variant whose items are the (transformed) image alone, without the label."""

    def __getitem__(self, index):
        sample = super().__getitem__(index)
        image, _label = sample
        return image

class CelebAImagesOnly(CelebA):
    """CelebA variant whose items are the (transformed) image alone, without the target."""

    def __getitem__(self, index):
        # The parent class yields an (image, target) pair; keep only the image.
        sample = super().__getitem__(index)
        image, _target = sample
        return image

# CelebA ref directory

def build_celeba_fid_ref_dir(
    data_root: str = "./data",
    out_dir: str = "./fid_ref_celeba64_train_10k",
    image_size: int = 64,
    n_images: int = 10_000,
    batch_size: int = 128,
    num_workers: int = 4,
):
    """Write the first ``n_images`` CelebA training images as PNGs for use as an FID reference.

    Images are resized to ``image_size`` x ``image_size`` and saved as
    ``real_00000.png``, ``real_00001.png``, ... in ``out_dir`` (created if needed).

    Args:
        data_root: Root directory of the CelebA dataset (downloaded if missing).
        out_dir: Directory that receives the PNG files.
        image_size: Side length of the square output images.
        n_images: Maximum number of images to export.
        batch_size: DataLoader batch size used while exporting.
        num_workers: DataLoader worker count.

    Returns:
        ``out_dir`` as a string.
    """
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    tfm = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),  # [0,1]
    ])

    ds = CelebA(
        root=data_root,
        split="train",
        download=True,
        transform=tfm,
    )

    # The subset already caps the export at n_images, so the loop below needs
    # no early-exit bookkeeping.
    subset = Subset(ds, range(min(n_images, len(ds))))

    loader = DataLoader(
        subset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
    )

    idx = 0
    for imgs, _ in loader:
        imgs_uint8 = (
            imgs.clamp(0, 1)
            .mul(255)
            .round()  # round instead of truncating so float round-off cannot drop a grey level
            .to(torch.uint8)
            .permute(0, 2, 3, 1)
            .cpu()
            .numpy()
        )

        for frame in imgs_uint8:
            Image.fromarray(frame).save(out_dir / f"real_{idx:05d}.png")
            idx += 1

    # "\u2714" is the check mark the original (mis-encoded) message intended.
    print(f"\u2714 Built CelebA FID reference at {out_dir}")
    return str(out_dir)



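# --------------------------
# 3) Sampling (DDPM reverse)
#    - physical noise (when enabled) is used for the initial sample only;
#      the noise added at each reverse step is drawn by scheduler.step
#      from the seeded torch Generator
# --------------------------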
@torch.no_grad()
def ddpm_sample(
    unet,
    scheduler: DDPMScheduler,
    device: str,
    n_samples: int,
    batch_size: int,
    image_size: int,
    inference_steps: int,
    use_physical: bool,
    phys_rng: PhysicalNoiseRNG | None,
    seed: int,
):
    """Generate ``n_samples`` images by running the scheduler's reverse process.

    The model is put in eval mode and the scheduler is reconfigured to
    ``inference_steps`` steps. Each batch starts from ``phys_rng.get(shape)``
    when ``use_physical`` is true, otherwise from ``torch.randn`` driven by a
    generator seeded with ``seed``; that same generator is handed to every
    ``scheduler.step`` call.

    Returns:
        CPU tensor of shape ``[n_samples, 3, image_size, image_size]`` holding the
        final ``prev_sample`` of each batch (not clamped).
    """
    unet.eval()
    generator = torch.Generator(device=device).manual_seed(seed)
    scheduler.set_timesteps(inference_steps, device=device)

    finished = []
    total_batches = -(-n_samples // batch_size)  # ceiling division

    for batch_index in tqdm(range(total_batches), desc="Sampling", leave=False):
        remaining = n_samples - batch_index * batch_size
        current = min(batch_size, remaining)
        latent_shape = (current, 3, image_size, image_size)

        # Starting point of the reverse chain
        if not use_physical:
            latents = torch.randn(latent_shape, device=device, generator=generator)
        else:
            assert phys_rng is not None
            latents = phys_rng.get(latent_shape)

        # Walk the scheduler's timesteps, replacing the sample at every step
        for step in scheduler.timesteps:
            step_ids = torch.full((current,), step, device=device, dtype=torch.long)
            predicted = unet(latents, step_ids).sample
            latents = scheduler.step(predicted, step, latents, generator=generator).prev_sample

        finished.append(latents.detach().cpu())

    return torch.cat(finished, dim=0)[:n_samples]

@torch.no_grad()
def ddpm_trajectory(
    unet,
    scheduler: DDPMScheduler,
    device: str,
    image_size: int,
    inference_steps: int,
    use_physical: bool,
    phys_rng: PhysicalNoiseRNG | None,
    seed: int,
    n_samples: int = 4,
    n_frames: int = 8,
):
    """Run one reverse-diffusion pass and keep snapshots along the way.

    Snapshots are taken at roughly evenly spaced step indices (``n_frames`` of
    them, duplicates collapsed) and always after the final step.

    Returns:
        List of CPU tensors, each ``[n_samples, 3, H, W]``, in the order they
        were captured.
    """
    unet.eval()
    generator = torch.Generator(device=device).manual_seed(seed)

    scheduler.set_timesteps(inference_steps, device=device)
    steps = list(scheduler.timesteps)
    last_index = len(steps) - 1

    # Indices at which to snapshot; the final step is always included.
    snapshot_at = set(np.linspace(0, last_index, n_frames, dtype=int).tolist())
    snapshot_at.add(last_index)

    latent_shape = (n_samples, 3, image_size, image_size)
    if not use_physical:
        latents = torch.randn(latent_shape, device=device, generator=generator)
    else:
        assert phys_rng is not None
        latents = phys_rng.get(latent_shape)

    snapshots = []
    for position, step in enumerate(steps):
        step_ids = torch.full((n_samples,), step, device=device, dtype=torch.long)
        predicted = unet(latents, step_ids).sample
        latents = scheduler.step(predicted, step, latents, generator=generator).prev_sample

        if position in snapshot_at:
            snapshots.append(latents.detach().cpu().clone())

    return snapshots



def tensor_to_uint8_images(x):
    """Convert images from ``[-1, 1]`` floats to uint8 pixels in channel-last layout.

    Args:
        x: Tensor of shape ``[N, 3, H, W]`` with values nominally in ``[-1, 1]``;
           values outside that range are clamped.

    Returns:
        Contiguous uint8 tensor of shape ``[N, H, W, 3]`` with values in ``[0, 255]``.
    """
    x = (x / 2 + 0.5).clamp(0, 1)
    # Round to the nearest grey level; a bare cast truncates and shifts every
    # pixel downwards by half a level on average.
    x = (x * 255).round().to(torch.uint8)
    x = x.permute(0, 2, 3, 1).contiguous()
    return x


def save_image_dir_uint8_hwc(imgs_uint8_hwc: torch.Tensor, out_dir: Path, prefix="gen", ext: str = "png"):
    """Write a batch of uint8 HWC images to ``out_dir`` as ``{prefix}_{index:05d}.{ext}``.

    The default format is lossless PNG: these files are compared against a
    lossless reference when computing FID, and JPEG artefacts in only one of
    the two sets would bias the metric. Pass ``ext="jpg"`` to get the previous
    JPEG output (quality 95).

    Args:
        imgs_uint8_hwc: uint8 tensor of shape ``[N, H, W, C]``.
        out_dir: Target directory (created if missing).
        prefix: File name prefix.
        ext: File extension, which also selects the image format.
    """
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    ext = ext.lower().lstrip(".")
    # The quality setting is only passed for JPEG output.
    save_kwargs = {"quality": 95} if ext in {"jpg", "jpeg"} else {}

    # Convert once instead of once per image.
    frames = imgs_uint8_hwc.cpu().numpy()
    for i, frame in enumerate(frames):
        Image.fromarray(frame).save(out_dir / f"{prefix}_{i:05d}.{ext}", **save_kwargs)


# --------------------------
# 4) FID-only eval (torch-fidelity)
#    - deletes generated files afterwards
# --------------------------
@torch.no_grad()
def fid_eval_torch_fidelity(
    unet,
    scheduler,
    device: str,
    out_root: str,
    epoch: int,
    image_size: int,
    inference_steps: int,
    fid_ref: str = "cifar10-train",
    n_samples: int = 10_000,
    batch_size: int = 256,
    seed: int = 0,
    use_physical: bool = False,
    phys_rng: PhysicalNoiseRNG | None = None,
    cache: bool = False,              # keep False to avoid the notebook rerun cache-unpickling drama
):
    """Generate samples, compute their FID against ``fid_ref`` and clean up.

    Samples are written to ``<out_root>/fid_tmp/epoch_XXX`` and that directory
    is removed again before returning, including when sampling or the metric
    computation raises.

    Args:
        unet, scheduler, device, image_size, inference_steps, seed,
        use_physical, phys_rng, n_samples, batch_size: forwarded to ``ddpm_sample``.
        out_root: Directory under which the temporary sample folder is created.
        epoch: Epoch number, used only to name the temporary folder.
        fid_ref: Second input of ``calculate_metrics`` (a registered dataset
            name or a directory of reference images).
        cache: Forwarded to ``calculate_metrics``.

    Returns:
        The Frechet Inception Distance as a float.
    """
    out_root = Path(out_root)
    gen_dir = out_root / "fid_tmp" / f"epoch_{epoch:03d}"
    # Start from a clean folder so stale images from an aborted run are not scored.
    if gen_dir.exists():
        shutil.rmtree(gen_dir)

    try:
        samples = ddpm_sample(
            unet=unet,
            scheduler=scheduler,
            device=device,
            n_samples=n_samples,
            batch_size=batch_size,
            image_size=image_size,
            inference_steps=inference_steps,
            use_physical=use_physical,
            phys_rng=phys_rng,
            seed=seed,
        )

        imgs_uint8 = tensor_to_uint8_images(samples)
        save_image_dir_uint8_hwc(imgs_uint8, gen_dir, prefix="gen")

        # FID only
        metrics = calculate_metrics(
            input1=str(gen_dir),
            input2=fid_ref,
            cuda=torch.cuda.is_available(),
            fid=True,
            isc=False,
            kid=False,
            prc=False,
            verbose=False,
            cache=cache,
        )
        return float(metrics["frechet_inception_distance"])
    finally:
        # cleanup, also on failure, so generated images never pile up on disk
        shutil.rmtree(gen_dir, ignore_errors=True)


# --------------------------
# 5) Training loop (diffusers + EMA + accelerate)
# --------------------------
def _build_train_dataset(config, transform):
    """Create the image-only training dataset selected by ``config["dataset_name"]``.

    Returns:
        ``(dataset, name)`` where ``name`` is the lower-cased dataset name.

    Raises:
        ValueError: If the dataset name is not recognised.
    """
    name = config.get("dataset_name", "cifar10").lower()

    if name == "cifar10":
        dataset = CIFAR10ImagesOnly(
            root=config["dataset_path"],
            train=True,
            transform=transform,
            download=True,
        )
    elif name in {"celeba64", "celeba"}:
        dataset = CelebAImagesOnly(
            root=config["dataset_path"],
            split="train",
            transform=transform,
            download=True,
        )
    else:
        raise ValueError(f"Unknown dataset_name: {config['dataset_name']}")

    return dataset, name


def _resolve_fid_reference(config, dataset_name):
    """Return the FID reference for ``dataset_name`` (lower-cased).

    CIFAR-10 uses torch-fidelity's registered ``"cifar10-train"`` input. Any
    other dataset uses a directory of 10,000 CelebA training images under the
    run's output directory, which is built on first use.
    """
    if dataset_name == "cifar10":
        return "cifar10-train"

    fid_ref_dir = os.path.join(config["output_dir"], "fid_ref_celeba64_train_10k")
    if not os.path.isdir(fid_ref_dir) or len(os.listdir(fid_ref_dir)) < 10_000:
        build_celeba_fid_ref_dir(
            data_root=config["dataset_path"],
            out_dir=fid_ref_dir,
            image_size=config["image_size"],
            n_images=10_000,
            batch_size=128,
            num_workers=config["num_workers"],
        )
    return fid_ref_dir


def _evaluate_and_log(
    accelerator,
    unet,
    noise_scheduler,
    config,
    dataset_name,
    phys_rng_eval,
    epoch,
    global_step,
):
    """Produce and log the sample grid, the FID score and a denoising trajectory.

    ``unet`` is expected to already carry the weights to evaluate (the caller
    swaps EMA weights in and out). ``epoch`` is zero-based.
    """
    device = str(accelerator.device)
    samples_dir = os.path.join(config["output_dir"], "samples")
    use_physical = config["use_physical_noise"]

    # quick sample grid for sanity
    grid_n = config["num_save_samples"]
    grid = ddpm_sample(
        unet=unet,
        scheduler=noise_scheduler,
        device=device,
        n_samples=grid_n,
        batch_size=min(grid_n, 64),
        image_size=config["image_size"],
        inference_steps=config["sample_inference_steps"],
        use_physical=use_physical,
        phys_rng=phys_rng_eval,
        seed=config["seed"] + epoch + 123,
    )
    grid = (grid / 2 + 0.5).clamp(0, 1)
    grid_img = vutils.make_grid(grid, nrow=int(math.sqrt(grid_n)))
    vutils.save_image(grid_img, os.path.join(samples_dir, f"grid_epoch{epoch+1:03d}.png"))
    accelerator.log(
        {"samples_grid": wandb.Image(grid_img, caption=f"Epoch {epoch+1}")},
        step=global_step,
    )

    # FID(10k)
    fid = fid_eval_torch_fidelity(
        unet=unet,
        scheduler=noise_scheduler,
        device=device,
        out_root=config["output_dir"],
        epoch=epoch + 1,
        image_size=config["image_size"],
        inference_steps=config["fid_inference_steps"],
        fid_ref=_resolve_fid_reference(config, dataset_name),
        n_samples=config["fid_n_generated"],
        batch_size=config["fid_gen_batch_size"],
        seed=config["seed"] + 10_000 + epoch,
        use_physical=use_physical,
        phys_rng=phys_rng_eval,
        cache=False,  # keep it stable across notebook reruns
    )
    print(f"[Epoch {epoch+1}] FID(10k) = {fid:.3f}")
    accelerator.log({"fid10k": fid, "epoch": epoch + 1}, step=global_step)

    # ---- Denoising trajectory (log to W&B) ----
    traj_frames = ddpm_trajectory(
        unet=unet,
        scheduler=noise_scheduler,
        device=device,
        image_size=config["image_size"],
        inference_steps=config["sample_inference_steps"],  # or a separate traj_steps config
        use_physical=use_physical,
        phys_rng=phys_rng_eval,
        seed=config["seed"] + 999 + epoch,
        n_samples=1,
        n_frames=6,
    )

    # One row, one column per captured frame; frames are [N,3,H,W] in [-1,1].
    vis = torch.cat([(frame / 2 + 0.5).clamp(0, 1) for frame in traj_frames], dim=0)
    traj_grid = vutils.make_grid(vis, nrow=len(traj_frames))

    traj_path = os.path.join(samples_dir, f"traj_epoch{epoch+1:03d}.png")
    vutils.save_image(traj_grid, traj_path)
    accelerator.log(
        {"denoising_traj": wandb.Image(traj_grid, caption=f"Epoch {epoch+1} denoising trajectory")},
        step=global_step,
    )


def train_diffusers_with_custom_noise(config):
    """Train a ``UNet2DModel`` DDPM with either Gaussian or recorded physical noise.

    Uses Accelerate for mixed precision / gradient accumulation / W&B logging
    and, when ``config["ema_decay"] > 0``, an EMA copy of the weights that is
    used for evaluation and for the final saved model.

    After the first epoch and every ``config["eval_epochs"]`` epochs the main
    process saves a checkpoint, logs a sample grid, an FID score and a
    denoising trajectory. At the end the model and scheduler are saved to
    ``<output_dir>/final``.

    Args:
        config: Dict of settings; see the ``config`` dictionary defined below
            this function for the full list of keys that are read.

    Raises:
        ValueError: On an unknown ``dataset_name`` or ``prediction_type``.
    """
    set_seed(config["seed"])

    accelerator = Accelerator(
        mixed_precision=config["mixed_precision"],
        gradient_accumulation_steps=config["gradient_accumulation_steps"],
        log_with="wandb",
        project_dir=os.path.join(config["output_dir"], "logs"),
    )

    if accelerator.is_main_process:
        accelerator.init_trackers(
            project_name=config["wandb_project"],
            config=config,
            init_kwargs={"wandb": {"name": config.get("wandb_run_name", "run")}},
        )
        os.makedirs(config["output_dir"], exist_ok=True)
        os.makedirs(os.path.join(config["output_dir"], "samples"), exist_ok=True)

    # data
    transform_train = transforms.Compose([
        transforms.Resize((config["image_size"], config["image_size"])),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),  # [-1,1]
    ])

    # `ds` is the lower-cased dataset name; it is reused for the FID reference
    # so that dataset construction and evaluation can never disagree.
    dataset, ds = _build_train_dataset(config, transform_train)

    dataloader = DataLoader(
        dataset,
        batch_size=config["train_batch_size"],
        shuffle=True,
        num_workers=config["num_workers"],
        pin_memory=True,
        drop_last=True,
    )

    # model: four resolution levels, attention only at the lowest one
    model = UNet2DModel(
        sample_size=config["image_size"],
        in_channels=3,
        out_channels=3,
        layers_per_block=2,
        block_out_channels=(128, 256, 512, 512),
        down_block_types=(
            "DownBlock2D",
            "DownBlock2D",
            "DownBlock2D",
            "AttnDownBlock2D",
        ),
        up_block_types=(
            "AttnUpBlock2D",
            "UpBlock2D",
            "UpBlock2D",
            "UpBlock2D",
        ),
    )

    noise_scheduler = DDPMScheduler(
        num_train_timesteps=config["num_train_timesteps"],
        beta_schedule=config["beta_schedule"],
        prediction_type=config["prediction_type"],
    )

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config["learning_rate"],
        betas=(config["adam_beta1"], config["adam_beta2"]),
        weight_decay=config["adam_weight_decay"],
        eps=config["adam_epsilon"],
    )

    lr_scheduler = get_cosine_schedule_with_warmup(
        optimizer=optimizer,
        num_warmup_steps=config["lr_warmup_steps"],
        num_training_steps=(len(dataloader) * config["num_epochs"]),
    )

    # prepare
    model, optimizer, dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, dataloader, lr_scheduler
    )

    # EMA of the model weights (used for evaluation and the final save)
    ema = None
    if config.get("ema_decay", 0) > 0:
        ema = EMAModel(
            model.parameters(),
            decay=config["ema_decay"],
            use_ema_warmup=True,
            inv_gamma=1.0,
            power=3/4,
        )
        ema.to(accelerator.device)
        accelerator.register_for_checkpointing(ema)

    # physical RNGs (train + eval) on the right device
    phys_rng_train = None
    phys_rng_eval = None
    if config["use_physical_noise"]:
        # different seeds so training and evaluation pick different windows
        phys_rng_train = PhysicalNoiseRNG(config["physical_noise_path"],
                                          seed=config["seed"]+1,
                                          device=str(accelerator.device))
        phys_rng_eval = PhysicalNoiseRNG(config["physical_noise_path"],
                                         seed=config["seed"]+2,
                                         device=str(accelerator.device))

    global_step = 0

    for epoch in range(config["num_epochs"]):
        model.train()

        running_loss = 0.0
        n_batches = 0

        progress = tqdm(
            dataloader,
            disable=not accelerator.is_local_main_process,
            desc=f"Epoch {epoch+1}/{config['num_epochs']}",
        )

        for batch in progress:
            clean = batch  # [B,3,H,W] in [-1,1]
            bsz = clean.shape[0]

            # timesteps
            timesteps = torch.randint(
                0, noise_scheduler.config.num_train_timesteps, (bsz,),
                device=clean.device
            ).long()

            # CUSTOM NOISE HERE:
            if config["use_physical_noise"]:
                noise = phys_rng_train.get(clean.shape)
            else:
                noise = torch.randn_like(clean)

            # standard scheduler add_noise
            noisy = noise_scheduler.add_noise(clean, noise, timesteps)

            with accelerator.accumulate(model):
                pred = model(noisy, timesteps).sample

                if config["prediction_type"] == "epsilon":
                    target = noise
                elif config["prediction_type"] == "v_prediction":
                    target = noise_scheduler.get_velocity(clean, noise, timesteps)
                else:
                    raise ValueError(f"Unknown prediction_type: {config['prediction_type']}")

                loss = F.mse_loss(pred, target)

                accelerator.backward(loss)
                # Clip only on steps where gradients are synchronised, i.e. when
                # an optimizer update really happens under gradient accumulation.
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(model.parameters(), 1.0)

                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            # Update the EMA once per real optimizer update, not once per micro-batch.
            if accelerator.sync_gradients and accelerator.is_main_process and ema is not None:
                ema.step(model.parameters())

            running_loss += float(loss.detach().item())
            n_batches += 1

            logs = {"loss": float(loss.detach().item()), "lr": float(lr_scheduler.get_last_lr()[0])}
            progress.set_postfix(**logs)
            accelerator.log(logs, step=global_step)
            global_step += 1

        avg_loss = running_loss / max(1, n_batches)
        if accelerator.is_main_process:
            accelerator.log({"avg_loss": avg_loss, "epoch": epoch + 1}, step=global_step)

        # ---- Eval / checkpoint ----
        do_eval = (accelerator.is_main_process and (
            (epoch + 1) % config["eval_epochs"] == 0 or (epoch == 0)
        ))
        if do_eval:
            # checkpoint training state (before EMA weights are swapped in)
            ckpt_dir = os.path.join(config["output_dir"], f"checkpoint-epoch-{epoch+1}")
            accelerator.save_state(ckpt_dir)

            # swap EMA weights in for eval (then restore)
            if ema is not None:
                ema.store(model.parameters())
                ema.copy_to(model.parameters())

            try:
                _evaluate_and_log(
                    accelerator=accelerator,
                    unet=accelerator.unwrap_model(model),
                    noise_scheduler=noise_scheduler,
                    config=config,
                    dataset_name=ds,
                    phys_rng_eval=phys_rng_eval,
                    epoch=epoch,
                    global_step=global_step,
                )
            finally:
                # restore training weights even if evaluation failed, so a
                # resumed loop never continues from the EMA weights by accident
                if ema is not None:
                    ema.restore(model.parameters())

    # final save (EMA weights if enabled)
    if accelerator.is_main_process:
        if ema is not None:
            ema.copy_to(model.parameters())

        # save diffusers-style checkpoint
        save_dir = os.path.join(config["output_dir"], "final")
        os.makedirs(save_dir, exist_ok=True)
        accelerator.unwrap_model(model).save_pretrained(save_dir)
        noise_scheduler.save_pretrained(save_dir)

    accelerator.end_training()


# --------------------------
# 6) Example config (match your hardcore defaults)
# --------------------------
# Settings dictionary read by train_diffusers_with_custom_noise. Several values
# (dataset, data_path, image_size, inference_steps, run_name, USE_PHYSICAL)
# come from the "Params" section near the top of this cell.
config = {
    # data
    "dataset_name": dataset,
    "dataset_path": data_path,
    "image_size": image_size,
    "num_workers": 4,

    # model/scheduler
    "num_train_timesteps": 1000,
    "beta_schedule": "linear",
    "prediction_type": "epsilon",

    # training
    "num_epochs": 50,
    "train_batch_size": 128,
    "gradient_accumulation_steps": 1,
    "learning_rate": 1e-4,
    "lr_warmup_steps": 500,
    "adam_beta1": 0.95,
    "adam_beta2": 0.999,
    "adam_weight_decay": 1e-6,
    "adam_epsilon": 1e-8,

    # EMA
    # (a value > 0 enables the EMA copy of the weights)
    "ema_decay": 0.9999,

    # logging
    "mixed_precision": "fp16",
    "wandb_project": f"{dataset}_diffusion_superconductor_engineered",
    "wandb_run_name": run_name,
    "output_dir": "./" + run_name,

    # physical noise
    # (file written by the "save engineered noise" cell)
    "use_physical_noise": USE_PHYSICAL,
    "physical_noise_path": "./physical_noise_unitvar.npy",

    # eval
    # (evaluation runs after the first epoch and then every `eval_epochs` epochs)
    "eval_epochs": 10,
    "num_save_samples": 36,
    "sample_inference_steps": inference_steps,  # for the saved sample grid

    # FID
    "fid_n_generated": 10_000,
    "fid_inference_steps": inference_steps,     # DDPM steps for FID sampling
    "fid_gen_batch_size": 256,     # can differ from training batch size

    # seed
    "seed": RANDOM_SEED,
}


# Start training when the cell is executed.
if __name__ == "__main__":
    train_diffusers_with_custom_noise(config)


In [None]:
import torch
# Ask PyTorch to release cached GPU memory that is no longer in use (run after training).
torch.cuda.empty_cache()

In [None]:
# Sanity check: per-window statistics of the saved noise. Each window has the
# size of one 3x32x32 noise "image", i.e. what one call of the noise sampler
# hands to the model, so this shows how far individual windows deviate from
# zero mean / unit standard deviation.
N = len(phys_noise)
L = 3*32*32  # one noise "image"
window_rng = np.random.default_rng(0)  # own name: do not clobber the global `rng`

stds = []
means = []
for _ in range(2000):
    # Upper bound is exclusive, so N - L + 1 makes the last valid start reachable.
    start = window_rng.integers(0, N - L + 1)
    window = phys_noise[start:start + L]
    means.append(window.mean())
    stds.append(window.std())

# mean / spread / min / max of the window standard deviations, then of the window means
print(np.mean(stds), np.std(stds), np.min(stds), np.max(stds))
print(np.mean(means), np.std(means), np.min(means), np.max(means))
