In [12]:
# wgan_gp_curve_discriminator.py
# PyTorch 1D WGAN-GP for curves + using the discriminator as a realism classifier

import os
import math
import random
import numpy as np
import pandas as pd
from typing import Tuple, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

In [13]:
# -----------------------------
# 0) Utils: seeds & device
# -----------------------------
def set_seed(seed=42):
    """Seed Python's `random`, NumPy and PyTorch (CPU and all CUDA devices) with `seed`."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

# Global compute device: CUDA when torch reports it as available, otherwise CPU.
# The training and scoring helpers further down read this module-level name directly.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Seed every RNG once so the synthetic data, splits and training draws are repeatable.
set_seed(1234)

In [14]:
# -----------------------------------------
# 1) (Demo) Synthetic curves: mixtures of sinusoids
#    Replace this block with your real dataset loading.
# -----------------------------------------
def synth_curves(n, T=256, noise_std=0.05):
    """
    Generate `n` in-distribution demo curves sampled at `T` points on [0, 1].

    Each curve is the sum of 2 or 3 sinusoids (random amplitude, frequency and
    phase), each multiplied by a slow cosine envelope in [0.2, 1.0], plus
    Gaussian noise with standard deviation `noise_std`.

    Draws from the global NumPy RNG. Returns a float32 array of shape (n, T).
    """
    grid = np.linspace(0, 1, T)
    curves = []
    for _ in range(n):
        n_components = np.random.choice([2, 3])
        signal = np.zeros_like(grid)
        for _ in range(n_components):
            # Draw order matters for reproducibility under a fixed seed.
            amplitude = np.random.uniform(0.5, 1.5)
            frequency = np.random.uniform(1.0, 6.0)
            phase = np.random.uniform(0, 2*np.pi)
            env_frequency = np.random.uniform(0.5, 1.5)
            env_phase = np.random.uniform(0, 2*np.pi)
            envelope = 0.6 + 0.4*np.cos(2*np.pi*env_frequency*grid + env_phase)
            signal += amplitude * np.sin(2*np.pi*frequency*grid + phase) * envelope
        signal += np.random.normal(0, noise_std, size=T)
        curves.append(signal.astype(np.float32))
    return np.stack(curves, axis=0)  # (n, T)

def synth_ood_curves(n, T=256, noise_std=0.1):
    """
    Generate `n` out-of-distribution demo curves sampled at `T` points on [0, 1].

    Each curve is a square wave (sign of a sinusoid) with 2 to 5 additive
    rectangular spikes (width 2-7 samples, height 2-4) and Gaussian noise with
    standard deviation `noise_std`.

    Draws from the global NumPy RNG. Returns a float32 array of shape (n, T).
    """
    grid = np.linspace(0, 1, T)
    curves = []
    for _ in range(n):
        # Square-ish carrier; frequency is drawn before phase.
        frequency = np.random.uniform(2, 5)
        phase = np.random.uniform(0, 2*np.pi)
        wave = np.sign(np.sin(2*np.pi*frequency*grid + phase))
        # Random spikes; a spike that starts near the end is clipped by slicing.
        n_spikes = np.random.randint(2, 6)
        for _ in range(n_spikes):
            start = np.random.randint(0, T)
            width = np.random.randint(2, 8)
            wave[start:start+width] += np.random.uniform(2.0, 4.0)
        wave += np.random.normal(0, noise_std, size=T)
        curves.append(wave.astype(np.float32))
    return np.stack(curves, axis=0)


In [15]:
# -----------------------------------------
# 2) Dataset & normalization
# -----------------------------------------
class CurveDataset(Dataset):
    """
    Map-style dataset of normalised 1-D curves, stored as float32 with shape (N, 1, T).

    Attributes:
        X:    normalised curves, float32, shape (N, 1, T) (channel axis added for Conv1d).
        mean: normalisation offset that was applied (as passed in, or computed from X).
        std:  normalisation scale that was applied (as passed in, or computed from X).
        N, T: number of curves and number of samples per curve.
    """

    def __init__(self, X: np.ndarray, mean: Optional[np.ndarray]=None, std: Optional[np.ndarray]=None):
        """
        X shape: (N, T). We store as (N, 1, T) for Conv1d.
        If BOTH mean and std are given, apply them; otherwise compute
        per-timepoint stats from X (std gets +1e-6 to avoid division by zero).
        """
        assert X.ndim == 2
        raw = X.astype(np.float32)  # astype copies, so the caller's array is never modified
        self.N, self.T = raw.shape
        if mean is None or std is None:
            self.mean = raw.mean(axis=0, keepdims=True)
            self.std = raw.std(axis=0, keepdims=True) + 1e-6
        else:
            self.mean = mean
            self.std = std
        normed = (raw - self.mean) / self.std
        # float64 (or Python-float) stats promote the result to float64; cast back so
        # __getitem__ always yields float32 tensors that match the model weights.
        normed = normed.astype(np.float32, copy=False)
        self.X = normed[:, None, :]  # (N, 1, T)

    def __len__(self):
        """Number of curves."""
        return self.N

    def __getitem__(self, idx):
        """Return curve `idx` as a float32 tensor of shape (1, T) (shares memory with self.X)."""
        x = self.X[idx]
        return torch.from_numpy(x)


def compute_norm_stats(X: np.ndarray):
    """
    Per-timepoint normalisation statistics for curves X of shape (N, T).

    Returns (mean, std), each float32 with shape (1, T); 1e-6 is added to std
    before the cast so that constant columns do not divide by zero.
    """
    mu = X.mean(axis=0, keepdims=True).astype(np.float32)
    sigma = (X.std(axis=0, keepdims=True) + 1e-6).astype(np.float32)
    return mu, sigma

In [16]:
# -----------------------------------------
# 3) Models: 1D Generator & Critic
# -----------------------------------------
class Generator1D(nn.Module):
    """
    1-D generator: noise (B, z_dim) -> curve (B, 1, T).

    A linear layer projects the noise to a (base_ch, T // 16) map, four
    stride-2 transposed convolutions upsample it 16x, and a final Conv1d
    reduces to a single channel. No output activation is applied.

    Raises:
        ValueError: if T is not a positive multiple of 16 (the output length
            would silently be 16 * (T // 16) instead of T), or if base_ch < 8
            (the last hidden layer would have zero channels).
    """

    def __init__(self, z_dim=64, T=256, base_ch=64):
        super().__init__()
        if T < 16 or T % 16 != 0:
            raise ValueError(f"T must be a positive multiple of 16 (got T={T})")
        if base_ch < 8:
            raise ValueError(f"base_ch must be >= 8 (got base_ch={base_ch})")

        # Project noise to a low-res temporal map, then upsample with ConvTranspose1d
        self.init_T = T // 16  # 16x upsampling total
        self.fc = nn.Linear(z_dim, base_ch * self.init_T)

        # kernel=4, stride=2, padding=1 doubles the length exactly at every stage.
        self.net = nn.Sequential(
            nn.ConvTranspose1d(base_ch, base_ch, kernel_size=4, stride=2, padding=1),  # x2
            nn.LeakyReLU(0.2, True),
            nn.ConvTranspose1d(base_ch, base_ch//2, kernel_size=4, stride=2, padding=1),  # x4
            nn.LeakyReLU(0.2, True),
            nn.ConvTranspose1d(base_ch//2, base_ch//4, kernel_size=4, stride=2, padding=1),  # x8
            nn.LeakyReLU(0.2, True),
            nn.ConvTranspose1d(base_ch//4, base_ch//8, kernel_size=4, stride=2, padding=1),  # x16 -> T
            nn.LeakyReLU(0.2, True),
            nn.Conv1d(base_ch//8, 1, kernel_size=3, padding=1),
        )

    def forward(self, z):
        """z: (B, z_dim) noise. Returns generated curves of shape (B, 1, T)."""
        x = self.fc(z)  # (B, base_ch * init_T)
        # reshape to (B, C, T0); C is inferred and equals base_ch
        x = x.view(x.shape[0], -1, self.init_T)
        return self.net(x)  # (B, 1, T)


class Critic1D(nn.Module):
    """
    1-D Wasserstein critic: curve (B, 1, T) -> unbounded scalar score (B,).

    A stack of Conv1d + LeakyReLU layers downsamples the curve 16x; the
    flattened features are mapped to an `embed_dim` embedding and then to a
    single score. There is no normalisation layer and no output activation.
    """

    def __init__(self, T=256, base_ch=64, embed_dim=128):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv1d(1, base_ch, kernel_size=7, padding=3),
            nn.LeakyReLU(0.2, True),
            nn.Conv1d(base_ch, base_ch, kernel_size=5, stride=2, padding=2),  # /2
            nn.LeakyReLU(0.2, True),
            nn.Conv1d(base_ch, base_ch*2, kernel_size=5, stride=2, padding=2),  # /4
            nn.LeakyReLU(0.2, True),
            nn.Conv1d(base_ch*2, base_ch*4, kernel_size=5, stride=2, padding=2),  # /8
            nn.LeakyReLU(0.2, True),
            nn.Conv1d(base_ch*4, base_ch*4, kernel_size=5, stride=2, padding=2),  # /16
            nn.LeakyReLU(0.2, True),
        )
        # Each stride-2 conv (kernel 5, padding 2) maps length L to ceil(L / 2), so four
        # of them give ceil(T / 16). Floor division is only right when T % 16 == 0.
        feat_T = (T + 15) // 16
        self.flat = nn.Linear(base_ch*4*feat_T, embed_dim)
        self.out = nn.Linear(embed_dim, 1)  # Wasserstein score

    def forward(self, x, return_embed=False):
        """
        x: (B, 1, T) curves.

        Returns the score tensor of shape (B,). With return_embed=True returns
        (scores, embedding), where embedding is the (B, embed_dim) output of
        `self.flat` taken BEFORE the LeakyReLU that feeds the score head.
        """
        h = self.features(x)
        h = h.flatten(1)
        e = self.flat(h)
        s = self.out(F.leaky_relu(e, 0.2))
        if return_embed:
            return s.squeeze(1), e
        return s.squeeze(1)

In [17]:
# -----------------------------------------
# 4) WGAN-GP training step helpers
# -----------------------------------------
def gradient_penalty(critic, real, fake):
    """
    Two-sided WGAN-GP gradient penalty.

    Samples one interpolation weight per example, evaluates `critic` on the
    interpolates between `real` and `fake`, and returns the batch mean of
    (||d critic / d input||_2 - 1)^2.

    Args:
        critic: callable mapping a batch shaped like `real` to one score per example.
        real, fake: tensors of identical shape (B, ...); any number of trailing dims.

    Returns:
        Scalar tensor. The gradient is taken with create_graph=True so the
        penalty itself can be back-propagated into the critic's parameters.
    """
    B = real.size(0)
    # One epsilon per example, broadcast over every non-batch dimension. A fixed
    # (B, 1, 1) shape would mis-broadcast for inputs that are not 3-D.
    eps_shape = [B] + [1] * (real.dim() - 1)
    eps = torch.rand(*eps_shape, device=real.device)
    inter = eps * real + (1 - eps) * fake
    inter.requires_grad_(True)
    scores = critic(inter)
    # retain_graph defaults to create_graph, so the graph is kept for d_loss.backward().
    grads = torch.autograd.grad(
        outputs=scores,
        inputs=inter,
        grad_outputs=torch.ones_like(scores),
        create_graph=True,
    )[0]
    grads = grads.reshape(B, -1)  # reshape: the gradient need not be contiguous
    gp = ((grads.norm(2, dim=1) - 1.0) ** 2).mean()
    return gp

In [18]:
# -----------------------------------------
# 5) Training
# -----------------------------------------
def train_wgan_gp(
    train_loader,
    T=256,
    z_dim=64,
    g_lr=1e-4,
    d_lr=1e-4,
    lambda_gp=10.0,
    n_epochs=200,
    n_critic=5,
    print_every=50
):
    """
    Train a Generator1D / Critic1D pair with the WGAN-GP objective.

    For every batch from `train_loader` the critic is updated `n_critic` times
    (fresh noise each time, same real batch) and the generator once. Both use
    Adam with betas=(0.5, 0.9). Models live on the module-level `device`.

    Args:
        train_loader: iterable yielding real batches of shape (B, 1, T).
        T: curve length, forwarded to both networks.
        z_dim: generator noise dimension.
        g_lr, d_lr: generator / critic learning rates.
        lambda_gp: weight of the gradient penalty in the critic loss.
        n_epochs: number of passes over `train_loader`.
        n_critic: critic updates per generator update (must be >= 1).
        print_every: log the last batch's losses every this many epochs (and at epoch 1).

    Returns:
        (G, D): the trained generator and critic (left in training mode).

    Raises:
        ValueError: if n_critic < 1, or if `train_loader` yields no batches
            (e.g. drop_last=True with fewer samples than one batch).
    """
    if n_critic < 1:
        raise ValueError(f"n_critic must be >= 1 (got {n_critic})")

    G = Generator1D(z_dim=z_dim, T=T).to(device)
    D = Critic1D(T=T).to(device)

    g_opt = torch.optim.Adam(G.parameters(), lr=g_lr, betas=(0.5, 0.9))
    d_opt = torch.optim.Adam(D.parameters(), lr=d_lr, betas=(0.5, 0.9))

    for epoch in range(1, n_epochs+1):
        d_loss = g_loss = None  # reset so an empty epoch is detected below
        for real in train_loader:
            real = real.to(device)  # (B, 1, T)

            # 1) Update Critic n_critic times
            for _ in range(n_critic):
                z = torch.randn(real.size(0), z_dim, device=device)
                fake = G(z).detach()  # detached: the critic step must not touch G
                d_real = D(real)
                d_fake = D(fake)
                gp = gradient_penalty(D, real, fake)

                # Critic maximises D(real) - D(fake); the penalty is added on top.
                d_loss = -(d_real.mean() - d_fake.mean()) + lambda_gp * gp

                d_opt.zero_grad(set_to_none=True)
                d_loss.backward()
                d_opt.step()

            # 2) Update Generator once
            z = torch.randn(real.size(0), z_dim, device=device)
            fake = G(z)
            g_loss = -D(fake).mean()

            g_opt.zero_grad(set_to_none=True)
            g_loss.backward()
            g_opt.step()

        if g_loss is None:
            raise ValueError(
                "train_loader yielded no batches; check the dataset size, "
                "batch_size and drop_last"
            )

        if epoch % print_every == 0 or epoch == 1:
            print(f"[Epoch {epoch:04d}] D_loss={d_loss.item():.4f} | G_loss={g_loss.item():.4f}")

    return G, D

In [19]:
# -----------------------------------------
# 6) Scoring, thresholding, evaluation
# -----------------------------------------
@torch.no_grad()
def critic_scores(D, loader):
    """
    Run critic `D` over every batch in `loader` (moved to the module-level
    `device`) and return all scores as one NumPy array of shape (N,), in
    loader order. Gradients are disabled.
    """
    per_batch = [D(batch.to(device)).cpu().numpy() for batch in loader]  # each (B,)
    return np.concatenate(per_batch, axis=0)

@torch.no_grad()
def critic_embeddings(D, loader):
    """
    Run critic `D` with return_embed=True over every batch in `loader` (moved
    to the module-level `device`) and return the embeddings stacked into one
    NumPy array of shape (N, embed_dim), in loader order. Gradients are disabled.
    """
    chunks = []
    for batch in loader:
        embedding = D(batch.to(device), return_embed=True)[1]  # drop the score
        chunks.append(embedding.cpu().numpy())
    return np.concatenate(chunks, axis=0)

def choose_threshold_from_id(scores_id, fpr_target=0.05):
    """
    Return the `fpr_target` quantile of the in-distribution scores.

    Scores are treated as "higher is more real", so rejecting everything below
    this value rejects roughly a fraction `fpr_target` of in-distribution samples.
    """
    # lower tail is "less real" for WGAN critics if trained as above
    return np.quantile(scores_id, fpr_target)

def auroc(y_true, y_score):
    """
    Area under the ROC curve via the Mann-Whitney U statistic.

    Computes P(score_pos > score_neg) + 0.5 * P(score_pos == score_neg) over
    all (positive, negative) pairs.

    Args:
        y_true: array-like of labels; 1 marks positives, 0 marks negatives
            (any other value is ignored).
        y_score: array-like of scores, same length; higher means "more positive".
            Scores are assumed to contain no NaN.

    Returns:
        Float in [0, 1]. If either class is empty the result is 0.0 (the
        denominator carries a 1e-12 guard instead of raising).
    """
    labels = np.asarray(y_true)
    scores = np.asarray(y_score)
    pos = scores[labels == 1]
    neg_sorted = np.sort(scores[labels == 0])
    # For each positive, count negatives strictly below it and negatives tied with it
    # using binary search on the sorted negatives: O((P + N) log N) instead of a
    # Python loop doing O(P * N) comparisons.
    n_below = np.searchsorted(neg_sorted, pos, side="left")
    n_below_or_equal = np.searchsorted(neg_sorted, pos, side="right")
    n_tied = n_below_or_equal - n_below
    total = np.sum(n_below + 0.5 * n_tied)
    return total / (len(pos) * len(neg_sorted) + 1e-12)

def train_test_split(X, ratio_train=0.8):
    """
    Randomly split the rows of X into two disjoint parts.

    Row order is shuffled with the global NumPy RNG; the first
    int(N * ratio_train) shuffled rows form the first part and the remainder
    the second. Returns (X_first, X_second).
    """
    n_rows = X.shape[0]
    order = np.arange(n_rows)
    np.random.shuffle(order)
    cut = int(n_rows * ratio_train)
    return X[order[:cut]], X[order[cut:]]

In [20]:
# -----------------------------------------
# 7) Mahalanobis score on critic embeddings (optional)
# -----------------------------------------
class GaussianEmbed:
    """
    Single Gaussian fitted to a set of vectors, used to score new vectors by
    squared Mahalanobis distance.
    """

    def __init__(self):
        # Both are populated by fit(); maha() must not be called before that.
        self.mean = None
        self.cov_inv = None

    def fit(self, E):  # E: (N, d)
        """Estimate the mean (1, d) and the inverse of the ridge-regularised covariance (d, d)."""
        n_features = E.shape[1]
        self.mean = E.mean(axis=0, keepdims=True)
        # 1e-5 on the diagonal keeps the covariance invertible.
        regularised_cov = np.cov(E.T) + 1e-5*np.eye(n_features)
        self.cov_inv = np.linalg.inv(regularised_cov)

    def maha(self, E):  # smaller is more ID
        """Squared Mahalanobis distance of every row of E (N, d) to the fitted Gaussian; shape (N,)."""
        centred = E - self.mean
        return np.einsum('ij,jk,ik->i', centred, self.cov_inv, centred)

In [None]:
# -----------------------------------------
# 8) Glue: Putting it all together
# -----------------------------------------
# End-to-end run on the skin-reflectance data: load, split, normalise, train the
# WGAN-GP, pick ID-only thresholds for the critic score and for the Mahalanobis
# score in critic-embedding space, then save everything to a checkpoint.
if __name__ == "__main__":
    # --- Config ---
    # T is passed to the generator/critic and must equal the number of columns of
    # X_id built below (the notebook later displays X_id.shape == (15255, 128)).
    T = 128
    ratio_train = 0.8
    batch_size = 128
    z_dim = 32

    # --- Data (replace with your arrays of shape (N, T)) ---
    # Paths are relative to the kernel's working directory.
    skin_reflectance = pd.read_csv('../data/resample_skin_reflectance.csv').to_numpy()
    derivative = pd.read_csv('../data/resample_derivative.csv').to_numpy()
    # Each row becomes one "curve": the columns of the first file followed by the
    # columns of the second (concatenation along the feature axis).
    X_id = np.concatenate((skin_reflectance, derivative), axis=1)

    # Random split driven by the global NumPy RNG (seeded earlier via set_seed).
    X_train_id, X_val_id = train_test_split(X_id, ratio_train=ratio_train)
    

    # --- Normalize using training stats only ---
    mean, std = compute_norm_stats(X_train_id)
    train_ds = CurveDataset(X_train_id, mean, std)
    val_id_ds = CurveDataset(X_val_id, mean, std)

    # drop_last=True keeps every training batch the same size; it also means the
    # final partial batch is skipped whenever this loader is iterated.
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, drop_last=True)
    val_id_loader = DataLoader(val_id_ds, batch_size=batch_size, shuffle=False)

    # --- Train WGAN-GP ---
    G, D = train_wgan_gp(
        train_loader,
        T=T,
        z_dim=z_dim,
        g_lr=1e-4,
        d_lr=1e-4,
        lambda_gp=10.0,
        n_epochs=200,       # for real data you may go longer; tune as needed
        n_critic=5,
        print_every=50
    )

    # --- Freeze critic and score validation splits ---
    # Critic1D contains only Conv1d / LeakyReLU / Linear layers, so eval() does not
    # change its outputs; gradients are disabled inside critic_scores itself.
    D.eval()
    s_val_id = critic_scores(D, val_id_loader)

    # --- Threshold selection ---
    # Option A: target FPR on ID only
    tau = choose_threshold_from_id(s_val_id, fpr_target=0.05)

    print(f"Chosen threshold tau (5% ID FPR): {tau:.3f}")

    # The three headings below are empty placeholders: this cell builds no OOD or
    # test split, so no AUROC / accuracy is computed for the skin data.
    # Option B (if OOD labels are available): simple AUROC

    # --- Test evaluation ---

    # Predictions based on tau

    # --- Optional: Mahalanobis on embeddings ---
    # train_loader shuffles and drops the last partial batch, so E_train_id is in
    # random order and omits a few samples; only its mean/covariance are used.
    E_train_id = critic_embeddings(D, train_loader)
    E_val_id = critic_embeddings(D, val_id_loader)
    ge = GaussianEmbed()
    ge.fit(E_train_id)
    m_val_id = -ge.maha(E_val_id)  # higher is more ID, so take negative distance
    tau_emb = choose_threshold_from_id(m_val_id, fpr_target=0.05)
    print(f"Chosen threshold tau_emb (5% ID FPR in embed): {tau_emb:.3f}")
    

    # --- Save for later inference ---
    # NOTE(review): "ge" stores the GaussianEmbed instance itself and "mean"/"std"
    # are NumPy arrays, so the checkpoint is a general pickle rather than pure
    # tensors; loading presumably needs the class definition available and full
    # (non weights-only) unpickling -- confirm against the torch version in use.
    ckpt = {
        "G": G.state_dict(),
        "D": D.state_dict(),
        "ge": ge,
        "mean": mean,
        "std": std,
        "tau": tau,
        "tau_emb": tau_emb
    }
    os.makedirs("checkpoints", exist_ok=True)
    torch.save(ckpt, "checkpoints/skin_wgan_gp.pt")
    print("Saved checkpoints/skin_wgan_gp.pt")

[Epoch 0001] D_loss=-12.0021 | G_loss=2.1986
[Epoch 0050] D_loss=-1.1060 | G_loss=3.0608
[Epoch 0100] D_loss=-1.1692 | G_loss=1.5935
[Epoch 0150] D_loss=-1.0049 | G_loss=2.5962
[Epoch 0200] D_loss=-1.6637 | G_loss=1.7923
Chosen threshold tau (5% ID FPR): -7.119
Chosen threshold tau_emb (5% ID FPR in embed): -306.201
Saved checkpoints/skin_wgan_gp.pt


In [None]:
# -----------------------------------------
# 8) Glue: Putting it all together
# -----------------------------------------
# End-to-end demo on synthetic curves: generate ID / OOD data, train the WGAN-GP,
# then evaluate the critic score and the Mahalanobis score in critic-embedding
# space as OOD detectors (AUROC, accuracy at a 5% ID-FPR threshold).
if __name__ == "__main__":
    # --- Config ---
    T = 256
    N_train = 4000
    N_val = 800
    N_test = 800
    batch_size = 128
    z_dim = 64

    # --- Data (replace with your arrays of shape (N, T)) ---
    X_train_id = synth_curves(N_train, T=T)
    X_val_id = synth_curves(N_val, T=T)
    X_test_id = synth_curves(N_test, T=T)

    # OOD splits for evaluation (if you don't have them, skip and use ID quantiles only)
    X_val_ood = synth_ood_curves(N_val, T=T)
    X_test_ood = synth_ood_curves(N_test, T=T)

    # --- Normalize using training stats only ---
    mean, std = compute_norm_stats(X_train_id)
    train_ds = CurveDataset(X_train_id, mean, std)
    val_id_ds = CurveDataset(X_val_id, mean, std)
    test_id_ds = CurveDataset(X_test_id, mean, std)
    val_ood_ds = CurveDataset(X_val_ood, mean, std)
    test_ood_ds = CurveDataset(X_test_ood, mean, std)

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, drop_last=True)
    val_id_loader = DataLoader(val_id_ds, batch_size=batch_size, shuffle=False)
    test_id_loader = DataLoader(test_id_ds, batch_size=batch_size, shuffle=False)
    val_ood_loader = DataLoader(val_ood_ds, batch_size=batch_size, shuffle=False)
    test_ood_loader = DataLoader(test_ood_ds, batch_size=batch_size, shuffle=False)

    # --- Train WGAN-GP ---
    G, D = train_wgan_gp(
        train_loader,
        T=T,
        z_dim=z_dim,
        g_lr=1e-4,
        d_lr=1e-4,
        lambda_gp=10.0,
        n_epochs=400,       # for real data you may go longer; tune as needed
        n_critic=5,
        print_every=50
    )

    # --- Freeze critic and score validation splits ---
    D.eval()
    s_val_id = critic_scores(D, val_id_loader)
    s_val_ood = critic_scores(D, val_ood_loader)

    # --- Threshold selection ---
    # Option A: target FPR on ID only
    tau = choose_threshold_from_id(s_val_id, fpr_target=0.05)

    # Option B (if OOD labels are available): simple AUROC
    # Label convention everywhere below: 1 = in-distribution, 0 = OOD, ID rows first.
    y_val = np.concatenate([np.ones_like(s_val_id), np.zeros_like(s_val_ood)])
    s_val = np.concatenate([s_val_id, s_val_ood])
    roc = auroc(y_val, s_val)
    print(f"Validation AUROC (critic score): {roc:.3f}")
    print(f"Chosen threshold tau (5% ID FPR): {tau:.3f}")

    # --- Test evaluation ---
    s_test_id = critic_scores(D, test_id_loader)
    s_test_ood = critic_scores(D, test_ood_loader)
    y_test = np.concatenate([np.ones_like(s_test_id), np.zeros_like(s_test_ood)])
    s_test = np.concatenate([s_test_id, s_test_ood])

    test_auc = auroc(y_test, s_test)
    print(f"Test AUROC (critic score): {test_auc:.3f}")

    # Predictions based on tau
    y_pred_test = (s_test >= tau).astype(np.int32)
    acc = (y_pred_test == y_test).mean()
    print(f"Test accuracy @tau: {acc:.3f}")

    # --- Optional: Mahalanobis on embeddings ---
    # The Gaussian must be fitted on embeddings from THIS critic. E_train_id was
    # previously never computed in this cell, so the fit either raised NameError or
    # silently reused embeddings left in the kernel by another run.
    # (train_loader shuffles and drops the last partial batch; only the mean and
    # covariance of the embeddings are used, so order does not matter.)
    E_train_id = critic_embeddings(D, train_loader)
    E_val_id = critic_embeddings(D, val_id_loader)
    E_val_ood = critic_embeddings(D, val_ood_loader)
    ge = GaussianEmbed()
    ge.fit(E_train_id)
    m_val_id = -ge.maha(E_val_id)  # higher is more ID, so take negative distance
    m_val_ood = -ge.maha(E_val_ood)
    print(f"Validation AUROC (Mahalanobis in critic embed): {auroc(np.concatenate([np.ones_like(m_val_id), np.zeros_like(m_val_ood)]), np.concatenate([m_val_id, m_val_ood])):.3f}")

    E_test_id = critic_embeddings(D, test_id_loader)
    E_test_ood = critic_embeddings(D, test_ood_loader)
    m_test_id = -ge.maha(E_test_id)
    m_test_ood = -ge.maha(E_test_ood)
    print(f"Test AUROC (Mahalanobis): {auroc(np.concatenate([np.ones_like(m_test_id), np.zeros_like(m_test_ood)]), np.concatenate([m_test_id, m_test_ood])):.3f}")

    # --- Save for later inference ---
    # "mean" / "std" are NumPy arrays; RealismScorer reads "mean", "std", "tau" and "D".
    ckpt = {
        "G": G.state_dict(),
        "D": D.state_dict(),
        "mean": mean,
        "std": std,
        "tau": tau
    }
    os.makedirs("checkpoints", exist_ok=True)
    torch.save(ckpt, "checkpoints/curve_wgan_gp.pt")
    print("Saved checkpoints/curve_wgan_gp.pt")

  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


[Epoch 0001] D_loss=-18.6135 | G_loss=0.1249
[Epoch 0050] D_loss=-3.0907 | G_loss=-16.0091
[Epoch 0100] D_loss=-2.5250 | G_loss=-4.0485
[Epoch 0150] D_loss=-2.1766 | G_loss=-3.3152
[Epoch 0200] D_loss=-2.0812 | G_loss=0.9161
[Epoch 0250] D_loss=-2.2007 | G_loss=0.7205
[Epoch 0300] D_loss=-2.7123 | G_loss=4.5101
[Epoch 0350] D_loss=-2.8464 | G_loss=-0.4621
[Epoch 0400] D_loss=-2.3055 | G_loss=0.5632
Validation AUROC (critic score): 0.955
Chosen threshold tau (5% ID FPR): -1.578
Test AUROC (critic score): 0.948
Test accuracy @tau: 0.879
Validation AUROC (Mahalanobis in critic embed): 1.000
Test AUROC (Mahalanobis): 1.000
Saved checkpoints/curve_wgan_gp.pt


In [27]:
# -----------------------------------------
# 9) Inference helper
# -----------------------------------------
class RealismScorer:
    """
    Inference wrapper around a saved critic.

    Loads the critic weights, the normalisation statistics and the decision
    threshold from a checkpoint dict with keys "mean", "std", "tau" and "D",
    and scores single raw curves of length T.
    """

    def __init__(self, ckpt_path: str, T=256):
        """
        ckpt_path: file written by torch.save in the training cells.
        T: curve length the critic was built with (must match the checkpoint).
        """
        # NOTE(review): the checkpoint holds NumPy arrays next to the state dicts;
        # confirm torch.load's default unpickling mode accepts that in the torch
        # version in use.
        state = torch.load(ckpt_path, map_location=device)
        self.mean, self.std = state["mean"], state["std"]
        self.tau = float(state["tau"])
        self.T = T

        critic = Critic1D(T=T).to(device).eval()
        critic.load_state_dict(state["D"])
        self.D = critic

    @torch.no_grad()
    def score(self, x_np: np.ndarray) -> float:
        """
        x_np: shape (T,) raw curve. Resampled already.
        Returns: critic score (higher => more real).
        """
        assert x_np.ndim == 1 and x_np.shape[0] == self.T
        normalised = (x_np[None, :] - self.mean) / self.std
        batch = torch.from_numpy(normalised.astype(np.float32))[:, None, :]  # (1,1,T)
        return self.D(batch.to(device)).item()

    def is_real(self, x_np: np.ndarray) -> bool:
        """True when the critic score of `x_np` reaches the stored threshold tau."""
        return self.score(x_np) >= self.tau

In [30]:
# --- Optional: Mahalanobis on embeddings ---
# Relies on kernel state from the synthetic-data glue cell: D, the five loaders and
# y_test are not defined here.
E_train = critic_embeddings(D, train_loader)
E_val_id = critic_embeddings(D, val_id_loader)
E_val_ood = critic_embeddings(D, val_ood_loader)
ge = GaussianEmbed()
# Fit on the training embeddings (not on E_val_id) so the validation split stays
# held out for choosing the threshold.
ge.fit(E_train)
m_val_id = -ge.maha(E_val_id)  # higher is more ID, so take negative distance
m_val_ood = -ge.maha(E_val_ood)
# Threshold = 5% quantile of the ID validation scores (ID-only, no OOD labels used).
tau_emb = choose_threshold_from_id(m_val_id, fpr_target=0.05)
print(f"Validation AUROC (Mahalanobis in critic embed): {auroc(np.concatenate([np.ones_like(m_val_id), np.zeros_like(m_val_ood)]), np.concatenate([m_val_id, m_val_ood])):.3f}")

E_test_id = critic_embeddings(D, test_id_loader)
E_test_ood = critic_embeddings(D, test_ood_loader)
m_test_id = -ge.maha(E_test_id)
m_test_ood = -ge.maha(E_test_ood)
print(f"Test AUROC (Mahalanobis): {auroc(np.concatenate([np.ones_like(m_test_id), np.zeros_like(m_test_ood)]), np.concatenate([m_test_id, m_test_ood])):.3f}")

# Predictions based on tau_emb; m_test is ordered [ID, OOD], the same order in
# which y_test was built.
m_test = np.concatenate([m_test_id, m_test_ood])
y_pred_test = (m_test >= tau_emb).astype(np.int32)
acc = (y_pred_test == y_test).mean()
print(f"Test accuracy @tau: {acc:.3f}")

Validation AUROC (Mahalanobis in critic embed): 1.000
Test AUROC (Mahalanobis): 1.000
Test accuracy @tau: 0.974


In [54]:
# --- Baseline: Mahalanobis directly on the raw curves (no critic involved) ---
# The Gaussian is fitted on X_train_id itself, not on critic embeddings, so the
# printed labels say "raw curves" to keep this baseline distinguishable from the
# embedding-space results.
# Relies on kernel state: X_*_id / X_*_ood and y_test come from the synthetic glue
# cell. Note that `ge` and `tau_emb` are overwritten here.
ge = GaussianEmbed()
ge.fit(X_train_id)
m_val_id = -ge.maha(X_val_id)  # higher is more ID, so take negative distance
m_val_ood = -ge.maha(X_val_ood)
tau_emb = choose_threshold_from_id(m_val_id, fpr_target=0.05)
print(f"Validation AUROC (Mahalanobis on raw curves): {auroc(np.concatenate([np.ones_like(m_val_id), np.zeros_like(m_val_ood)]), np.concatenate([m_val_id, m_val_ood])):.3f}")

m_test_id = -ge.maha(X_test_id)
m_test_ood = -ge.maha(X_test_ood)
print(f"Test AUROC (Mahalanobis on raw curves): {auroc(np.concatenate([np.ones_like(m_test_id), np.zeros_like(m_test_ood)]), np.concatenate([m_test_id, m_test_ood])):.3f}")

# Predictions based on tau_emb; m_test is ordered [ID, OOD] like y_test.
m_test = np.concatenate([m_test_id, m_test_ood])
y_pred_test = (m_test >= tau_emb).astype(np.int32)
acc = (y_pred_test == y_test).mean()
print(f"Test accuracy @tau: {acc:.3f}")

Validation AUROC (Mahalanobis in critic embed): 1.000
Test AUROC (Mahalanobis): 1.000
Test accuracy @tau: 0.975


In [55]:
# get current directory
# The relative '../data/...' paths used for loading are resolved against this directory.
os.getcwd()

'g:\\Phd\\Colloborator\\Tina Lasisi\\SkinSpectrum\\analysis'

In [62]:
# Load the two CSV tables as NumPy arrays (same files as in the skin-data glue cell).
skin_reflectance = pd.read_csv('../data/resample_skin_reflectance.csv').to_numpy()
derivative = pd.read_csv('../data/resample_derivative.csv').to_numpy()

In [65]:
# Join the two tables column-wise: one row per sample, reflectance columns first.
X_id = np.concatenate((skin_reflectance, derivative), axis=1)

In [66]:
# (n_samples, n_columns); the second value is the curve length T the models must be built with.
X_id.shape

(15255, 128)