# In[ ]:
"""
Unconditional diffusion model (DDPM) on MS2 spectra in combined_annotated.h5

- Trains a DDPM on ms2_lib (shape: [N, 1600]) from your H5 file
- Generates synthetic spectra and saves:
    synthetic_ms2.npy      (num_samples, 1600)
    synthetic_ms2.png      (plot of a few sampled spectra)
- Saves loss curves each epoch: loss_curve_epochX.png
- Saves a model checkpoint each epoch: diffusion_epochX.pt
- Verbose training logs so you can watch progress

Requires:
    pip install torch h5py numpy matplotlib
"""

import os
import math
import h5py
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt

# =========================
# USER SETTINGS
# =========================
H5_PATH       = r"F:\20251115\spectra_h5\combined_annotated.h5"
H5_DATASET    = "ms2_lib"       # change if your dataset name is different
OUT_DIR       = r"F:\20251115\spectra_h5\diffusion_out"

DEVICE        = "cuda" if torch.cuda.is_available() else "cpu"

BATCH_SIZE    = 256
NUM_EPOCHS    = 10              # start small; increase when things look good
LR            = 2e-4
# Diffusion steps (T). The default linear beta schedule (1e-4 -> 2e-2) needs
# on the order of 1000 steps to destroy the signal: with T=10 the last-step
# alpha_bar is ~0.90, so x_T is mostly data while sampling starts from pure
# N(0, I) noise. T does not change the cost of a training step (one random t
# per sample); it only sets the number of reverse steps when sampling.
NUM_TIMESTEPS = 1000
SPECTRUM_LEN  = 1600            # length of each MS2 vector

NUM_SAMPLES   = 32              # how many synthetic spectra to generate

os.makedirs(OUT_DIR, exist_ok=True)

# =========================
# DATASET
# =========================

class MS2H5Dataset(Dataset):
    """
    Map-style dataset over a 2-D H5 dataset of spectra, shape [N, SPECTRUM_LEN].

    The file is opened once in ``__init__``; rows are read from disk on demand
    in ``__getitem__`` rather than loading the whole array into memory. The
    open ``h5py.File`` handle lives on the instance, so prefer
    ``num_workers=0`` in a DataLoader (h5py handles are generally not meant to
    be shared across processes).

    Scaling, per spectrum:
        - divide by its max if that max is > 0 (-> [0, 1] for non-negative
          intensities; an all-zero spectrum is left unscaled)
        - then map [0, 1] -> [-1, 1] for diffusion

    Release the file with ``close()`` or by using the dataset in a ``with`` block.
    """

    def __init__(self, h5_path, dataset_name):
        super().__init__()
        self.h5_path = h5_path
        self.dataset_name = dataset_name
        self.h5 = h5py.File(self.h5_path, "r")
        try:
            self.ds = self.h5[self.dataset_name]
            self.length = self.ds.shape[0]
        except Exception:
            # e.g. a wrong dataset name: release the file handle before propagating.
            self.close()
            raise

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        if self.h5 is None:
            raise RuntimeError("MS2H5Dataset is closed")
        x = self.ds[idx].astype(np.float32)  # one row, read from disk
        # per-spectrum max scaling
        peak = x.max()
        if peak > 0:
            x = x / peak
        # map [0,1] -> [-1,1]
        x = x * 2.0 - 1.0
        return torch.from_numpy(x)

    def close(self):
        """Close the underlying H5 file; safe to call more than once."""
        if self.h5 is not None:
            self.h5.close()
            self.h5 = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False  # never suppress exceptions

# =========================
# DIFFUSION UTILITIES
# =========================

def make_beta_schedule(T, beta_start=1e-4, beta_end=2e-2, *, schedule="linear",
                       cosine_s=0.008, max_beta=0.999):
    """
    Return a 1-D tensor of ``T`` per-step noise variances (betas).

    Args:
        T: number of diffusion steps (>= 1).
        beta_start, beta_end: endpoints of the ``"linear"`` schedule. The
            defaults match the DDPM linear schedule used with T = 1000; with a
            much smaller T the product of (1 - beta) stays far from 0, so x_T
            is not close to pure noise.
        schedule: ``"linear"`` -> ``torch.linspace(beta_start, beta_end, T)``
            (unchanged default), or ``"cosine"`` ->
            alpha_bar(t) = cos^2(((t / T) + s) / (1 + s) * pi / 2), converted to
            betas and clipped at ``max_beta``; it reaches near-pure noise at
            any T.
        cosine_s: the small offset ``s`` of the cosine schedule.
        max_beta: upper clip for cosine betas (the last raw beta is ~1).

    Raises:
        ValueError: if T < 1 or ``schedule`` is not recognised.
    """
    if T < 1:
        raise ValueError(f"T must be >= 1, got {T}")
    if schedule == "linear":
        return torch.linspace(beta_start, beta_end, T)
    if schedule == "cosine":
        # Compute in float64: alpha_bar near t = T is tiny and ratios are sensitive.
        steps = torch.arange(T + 1, dtype=torch.float64) / T
        alpha_bar = torch.cos((steps + cosine_s) / (1.0 + cosine_s) * math.pi / 2) ** 2
        alpha_bar = alpha_bar / alpha_bar[0]
        betas = 1.0 - alpha_bar[1:] / alpha_bar[:-1]
        return betas.clamp(max=max_beta).to(torch.get_default_dtype())
    raise ValueError(f"unknown schedule {schedule!r}; expected 'linear' or 'cosine'")

class Diffusion:
    """
    DDPM forward (noising) and reverse (denoising) processes.

    Schedule tensors are 1-D of length ``T`` and live on ``device``. They are
    indexed with a (B,) tensor of integer timesteps and reshaped to
    (B, 1, ..., 1) so they broadcast against data of any rank (this script
    uses (B, D) spectra).

    Args:
        T: number of diffusion steps; valid timesteps are 0 .. T-1.
        device: device for the schedule tensors and for sampled noise.
        betas: optional length-T noise schedule; defaults to
            ``make_beta_schedule(T)``.
    """

    def __init__(self, T, device, betas=None):
        import warnings

        if T < 1:
            raise ValueError(f"T must be >= 1, got {T}")
        self.T = T
        self.device = device

        if betas is None:
            betas = make_beta_schedule(T)
        betas = torch.as_tensor(betas, device=device)          # (T,)
        if betas.shape != (T,):
            raise ValueError(f"betas must have shape ({T},), got {tuple(betas.shape)}")

        alphas = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        alphas_cumprod_prev = torch.cat(
            [torch.ones(1, device=device, dtype=alphas_cumprod.dtype), alphas_cumprod[:-1]],
            dim=0,
        )

        self.betas = betas
        self.alphas = alphas
        self.alphas_cumprod = alphas_cumprod
        self.alphas_cumprod_prev = alphas_cumprod_prev

        # forward process: x_t = sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * eps
        self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)

        # inverse of the forward process: x_0 = sqrt(1/abar_t) * x_t - sqrt(1/abar_t - 1) * eps
        self.sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod)
        self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod - 1.0)

        # per-step (non-cumulative) terms, kept for callers; the sampler uses the ones above
        self.sqrt_recip_alphas = torch.sqrt(1.0 / alphas)
        self.sqrt_recipm1_alphas = torch.sqrt(1.0 / alphas - 1.0)

        # posterior q(x_{t-1} | x_t, x_0) = N(coef1 * x_0 + coef2 * x_t, posterior_var)
        self.posterior_var = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
        self.posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)
        self.posterior_mean_coef2 = (
            torch.sqrt(alphas) * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
        )

        # p_sample_loop starts from x_T ~ N(0, I); that only matches training if
        # the forward process has (nearly) destroyed the signal by t = T-1.
        terminal_alpha_bar = float(alphas_cumprod[-1])
        if terminal_alpha_bar > 1e-2:
            warnings.warn(
                f"alpha_bar at the last timestep is {terminal_alpha_bar:.3f} (T={T}); "
                "x_T is far from pure noise, so sampling from N(0, I) is out of "
                "distribution. Increase T or use a steeper beta schedule.",
                RuntimeWarning,
                stacklevel=2,
            )

    @staticmethod
    def _extract(values, t, x):
        """Gather ``values[t]`` (shape (B,)) and reshape to (B, 1, ..., 1) to broadcast with ``x``."""
        return values[t].reshape(-1, *([1] * (x.dim() - 1)))

    def q_sample(self, x0, t, noise=None):
        """
        Forward diffusion: sample x_t from x_0 at time t.

        x0: (B, ...) clean data
        t:  (B,) integer timesteps in [0, T-1]
        noise: optional eps with the shape of x0; drawn from N(0, I) if None.
        """
        if noise is None:
            noise = torch.randn_like(x0)
        return (
            self._extract(self.sqrt_alphas_cumprod, t, x0) * x0
            + self._extract(self.sqrt_one_minus_alphas_cumprod, t, x0) * noise
        )

    def predict_x0(self, x_t, t, eps):
        """Invert ``q_sample``: x_0 = (x_t - sqrt(1 - abar_t) * eps) / sqrt(abar_t)."""
        return (
            self._extract(self.sqrt_recip_alphas_cumprod, t, x_t) * x_t
            - self._extract(self.sqrt_recipm1_alphas_cumprod, t, x_t) * eps
        )

    def q_posterior_mean(self, x0, x_t, t):
        """Mean of q(x_{t-1} | x_t, x_0)."""
        return (
            self._extract(self.posterior_mean_coef1, t, x_t) * x0
            + self._extract(self.posterior_mean_coef2, t, x_t) * x_t
        )

    def p_sample(self, model, x_t, t, clip_x0=True):
        """
        One reverse step p(x_{t-1} | x_t).

        The model predicts epsilon; it is converted to an x_0 estimate (clamped
        to the data range [-1, 1] when ``clip_x0``), which is plugged into the
        posterior mean. No noise is added when every t is 0; for mixed
        batches posterior_var is 0 at t = 0 anyway.
        """
        eps_theta = model(x_t, t)

        x0_hat = self.predict_x0(x_t, t, eps_theta)
        if clip_x0:
            x0_hat = x0_hat.clamp(-1.0, 1.0)

        posterior_mean = self.q_posterior_mean(x0_hat, x_t, t)
        if (t == 0).all():
            # no noise at final step
            return posterior_mean

        posterior_var_t = self._extract(self.posterior_var, t, x_t)
        noise = torch.randn_like(x_t)
        return posterior_mean + torch.sqrt(posterior_var_t) * noise

    def p_sample_loop(self, model, shape, clip_x0=True):
        """
        Sample from pure noise x_T ~ N(0, I), then reverse to x_0.

        The model runs in eval mode without gradients; its previous
        train/eval mode is restored afterwards.
        """
        was_training = model.training
        model.eval()
        x_t = torch.randn(shape, device=self.device)
        try:
            with torch.no_grad():
                for time_step in reversed(range(self.T)):
                    t_tensor = torch.full((shape[0],), time_step, device=self.device, dtype=torch.long)
                    x_t = self.p_sample(model, x_t, t_tensor, clip_x0=clip_x0)
        finally:
            model.train(was_training)
        return x_t

# =========================
# TIME EMBEDDING + MODEL
# =========================

class SinusoidalTimeEmbedding(nn.Module):
    """
    Classic transformer-style sinusoidal embedding of timestep t.

    The first ``dim // 2`` channels are sin(t * w_k) and the next ``dim // 2``
    are cos(t * w_k), with frequencies w_k = 10000 ** (-k / (dim//2 - 1)), i.e.
    geometrically spaced from 1 down to 1/10000. An odd ``dim`` gets one
    trailing zero channel so the output width is always exactly ``dim``.
    """
    def __init__(self, dim):
        super().__init__()
        if dim < 1:
            raise ValueError(f"dim must be >= 1, got {dim}")
        self.dim = dim

    def forward(self, t):
        """
        t: (B,) integer timesteps
        returns: (B, dim)
        """
        device = t.device
        half_dim = self.dim // 2
        # max(..., 1) avoids dividing by zero when half_dim == 1 (dim 2 or 3).
        log_step = math.log(10000) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=device) * -log_step)
        args = t.float().unsqueeze(1) * freqs.unsqueeze(0)
        emb = torch.cat([torch.sin(args), torch.cos(args)], dim=1)
        if self.dim % 2 == 1:
            # pad the last dimension with one zero column to reach exactly `dim`
            emb = nn.functional.pad(emb, (0, 1))
        return emb

class DenoiseMLP(nn.Module):
    """
    Simple MLP epsilon-predictor with time conditioning.

    The timestep is embedded (sinusoidal -> Linear -> SiLU) to ``hidden_dim``
    features, concatenated to the noisy input, and passed through
    ``num_layers`` Linear layers with SiLU between them (none after the last).

    Input:  x_t (B, D), t (B,)
    Output: eps_theta (B, D)

    Args:
        data_dim: D, the length of each input vector.
        time_dim: width of the sinusoidal timestep embedding.
        hidden_dim: width of the hidden layers and of the projected time embedding.
        num_layers: number of Linear layers in ``self.net`` (>= 1).
    """
    def __init__(self, data_dim, time_dim=256, hidden_dim=1024, num_layers=4):
        super().__init__()
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")
        self.time_mlp = nn.Sequential(
            SinusoidalTimeEmbedding(time_dim),
            nn.Linear(time_dim, hidden_dim),
            nn.SiLU(),
        )

        layers = []
        # The first layer sees [x, t_emb]; later layers see hidden_dim features.
        # Tracking the running width also makes num_layers == 1 map
        # (data_dim + hidden_dim) -> data_dim directly.
        in_dim = data_dim + hidden_dim
        for _ in range(num_layers - 1):
            layers.append(nn.Linear(in_dim, hidden_dim))
            layers.append(nn.SiLU())
            in_dim = hidden_dim
        layers.append(nn.Linear(in_dim, data_dim))
        self.net = nn.Sequential(*layers)

    def forward(self, x, t):
        # x: (B, D), t: (B,)
        t_emb = self.time_mlp(t)      # (B, hidden_dim)
        # concatenate time embedding with spectrum
        x_in = torch.cat([x, t_emb], dim=1)
        return self.net(x_in)

# =========================
# TRAINING (HIGH VERBOSITY)
# =========================

def train():
    """
    Train a DenoiseMLP noise-predictor on the spectra in H5_PATH / H5_DATASET.

    Every step draws one timestep per spectrum uniformly from
    [0, NUM_TIMESTEPS), noises the batch with ``diffusion.q_sample`` and
    regresses the injected noise with an MSE loss. After each epoch a
    checkpoint (``diffusion_epoch{N}.pt``, with model and optimizer state) and
    the running loss curve (``loss_curve_epoch{N}.png``) are written to OUT_DIR.
    The H5 file is closed on exit, including when training raises.

    Returns:
        (model, diffusion, loss_log): the trained model, the Diffusion helper
        used for training, and the list of per-step loss values.

    Raises:
        ValueError: if the dataset holds fewer than BATCH_SIZE spectra; with
            ``drop_last=True`` that would mean zero training steps per epoch.
    """
    dataset = MS2H5Dataset(H5_PATH, H5_DATASET)
    try:
        dataloader = DataLoader(
            dataset,
            batch_size=BATCH_SIZE,
            shuffle=True,
            num_workers=0,   # keep 0 for h5py safety
            drop_last=True,
        )
        steps_per_epoch = len(dataloader)
        if steps_per_epoch == 0:
            raise ValueError(
                f"Only {len(dataset)} spectra but BATCH_SIZE={BATCH_SIZE}; with "
                "drop_last=True no training step would run. Lower BATCH_SIZE."
            )

        model = DenoiseMLP(SPECTRUM_LEN).to(DEVICE)
        diffusion = Diffusion(NUM_TIMESTEPS, DEVICE)
        optimizer = torch.optim.AdamW(model.parameters(), lr=LR)

        loss_log = []     # per-step loss, for plotting
        global_step = 0

        print("\n========== START TRAINING ==========\n")
        print(f"Device: {DEVICE}")
        print(f"Total spectra: {len(dataset):,}")
        print(f"Batch size: {BATCH_SIZE}")
        print(f"Steps per epoch: {steps_per_epoch}")
        print(f"Timesteps (T): {NUM_TIMESTEPS}")
        print("--------------------------------------\n")

        for epoch in range(NUM_EPOCHS):
            print(f"\n===== Epoch {epoch+1}/{NUM_EPOCHS} =====")
            model.train()
            epoch_loss = []

            for batch_i, batch in enumerate(dataloader):
                batch = batch.to(DEVICE)

                t = torch.randint(0, NUM_TIMESTEPS, (batch.size(0),), device=DEVICE)
                noise = torch.randn_like(batch)
                x_t = diffusion.q_sample(batch, t, noise)
                loss = nn.functional.mse_loss(model(x_t, t), noise)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # One .item() per step: each call forces a device -> host sync.
                loss_value = loss.item()
                epoch_loss.append(loss_value)
                loss_log.append(loss_value)
                global_step += 1

                # ---- VERBOSITY: Every 50 steps ----
                if global_step % 50 == 0:
                    print(f"[Epoch {epoch+1} | Step {batch_i+1}/{steps_per_epoch} | "
                          f"Global {global_step}] Loss: {loss_value:.6f}")

                # ---- Extra: Every 250 steps ----
                if global_step % 250 == 0:
                    avg_last_250 = np.mean(loss_log[-250:])
                    print(f"    ↳ Rolling avg (last 250 steps): {avg_last_250:.6f}")

            # ---- End of epoch summary ----
            mean_epoch_loss = np.mean(epoch_loss)
            print(f"Epoch {epoch+1} done. Mean loss: {mean_epoch_loss:.6f}")

            # ---- Save checkpoint (optimizer state included so training can resume) ----
            ckpt_path = os.path.join(OUT_DIR, f"diffusion_epoch{epoch+1}.pt")
            torch.save({
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "epoch": epoch + 1,
                "num_timesteps": NUM_TIMESTEPS,
            }, ckpt_path)
            print(f"Saved checkpoint: {ckpt_path}")

            # ---- Save loss plot each epoch ----
            plt.figure(figsize=(8, 4))
            plt.plot(loss_log, alpha=0.8)
            plt.xlabel("Training step")
            plt.ylabel("Loss")
            plt.title("Diffusion Training Loss")
            plt.tight_layout()
            plt.savefig(os.path.join(OUT_DIR, f"loss_curve_epoch{epoch+1}.png"), dpi=150)
            plt.close()
            print("Saved loss plot")

        return model, diffusion, loss_log
    finally:
        # Release the H5 handle even if training fails or is interrupted.
        dataset.close()

# =========================
# SAMPLING & SAVING
# =========================

def sample_and_save(model, diffusion, num_samples=NUM_SAMPLES):
    """
    Draw ``num_samples`` spectra from the trained model and write them to OUT_DIR.

    Outputs:
        synthetic_ms2.npy -- float32 array of shape (num_samples, SPECTRUM_LEN);
                             each row is rescaled so its maximum is 1 (rows
                             whose maximum is 0 are left as they are).
        synthetic_ms2.png -- line plot of (up to) the first five spectra.
    """
    model.to(DEVICE)
    model.eval()

    generated = diffusion.p_sample_loop(model, shape=(num_samples, SPECTRUM_LEN))

    # The model works in [-1, 1]; clamp, then shift back to [0, 1] intensities.
    unit_range = (generated.clamp(-1, 1) + 1.0) / 2.0
    spectra = unit_range.cpu().numpy().astype(np.float32)

    # Peak-normalise each spectrum; a zero peak is replaced by 1 to avoid 0/0.
    peaks = spectra.max(axis=1, keepdims=True)
    peaks[peaks == 0] = 1.0
    spectra = spectra / peaks

    npy_path = os.path.join(OUT_DIR, "synthetic_ms2.npy")
    np.save(npy_path, spectra)
    print(f"Saved synthetic spectra to: {npy_path}")

    # Overlay the first few spectra for a quick visual check.
    png_path = os.path.join(OUT_DIR, "synthetic_ms2.png")
    plt.figure(figsize=(10, 6))
    for spectrum in spectra[:5]:
        plt.plot(spectrum, alpha=0.7)
    plt.xlabel("m/z bin index")
    plt.ylabel("normalized intensity")
    plt.title("Synthetic MS2 spectra from diffusion model")
    plt.tight_layout()
    plt.savefig(png_path, dpi=200)
    plt.close()
    print(f"Saved plot: {png_path}")

# =========================
# MAIN
# =========================

if __name__ == "__main__":
    # Train first, then draw NUM_SAMPLES spectra from the final model.
    # model / diffusion / loss_log stay bound as module-level names afterwards.
    model, diffusion, loss_log = train()
    sample_and_save(model, diffusion, num_samples=NUM_SAMPLES)
