# Triplet Image Autoencoder

In [None]:
# All imports
import os
from pathlib import Path
import numpy as np
import pandas as pd
from astropy.io import fits
from tqdm.auto import tqdm
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from torch.utils.data import DataLoader
from sklearn.decomposition import PCA
import plotly.graph_objects as go

In [None]:
# Input files: the cutout manifest and the directory holding the FITS cutouts.
MANIFEST_CSV = "ztf_training_triplets_maxmag_manifest.csv"
TRIPLET_DIR = Path("ztf_training_triplets_maxmag")

# Model / training hyperparameters.
IMAGE_SIZE = 64  # every cutout is resized to IMAGE_SIZE x IMAGE_SIZE
BATCH_SIZE = 32
NUM_EPOCHS = 200
LR = 1e-3
RANDOM_SEED = 42

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

# Seed the global NumPy and PyTorch RNGs for reproducible runs.
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)


# Build triplet table from manifest
def build_triplet_table(manifest_csv: str) -> pd.DataFrame:
    """
    Turn the long-format cutout manifest into one row per triplet.

    The manifest is pivoted on (objectId, candid, fid, jd) so that the
    'science', 'difference' and 'template' file paths become columns. Only
    rows whose three paths are all strings pointing at existing files are
    kept.

    Args:
        manifest_csv: Path to the manifest CSV. It must provide the columns
            objectId, candid, fid, jd, cutout_type and file_path.

    Returns:
        DataFrame with the four key columns plus one path column per cutout
        type, re-indexed from 0.
    """
    cutout_types = ["science", "difference", "template"]
    manifest = pd.read_csv(manifest_csv)

    wide = manifest.pivot_table(
        index=["objectId", "candid", "fid", "jd"],
        columns="cutout_type",
        values="file_path",
        aggfunc="first",
    )
    wide = wide.reset_index()

    # A cutout type missing from the whole manifest still gets a (NaN) column,
    # so the completeness check below can index it.
    for cutout in cutout_types:
        if cutout not in wide.columns:
            wide[cutout] = np.nan

    # Usable only if every path is a real string (NaN is a float) and the file
    # exists on disk. Relative paths resolve against the working directory.
    def _is_complete(row) -> bool:
        for cutout in cutout_types:
            candidate = row[cutout]
            if not (isinstance(candidate, str) and Path(candidate).exists()):
                return False
        return True

    keep = wide.apply(_is_complete, axis=1)
    wide = wide[keep].reset_index(drop=True)

    print(f"Triplets available (objectId+candid rows with all 3 cutouts): {len(wide)}")
    return wide


# Data and pre-processing
def get_first_image_data(path: str) -> np.ndarray:
    """
    Return the data of the first HDU in a FITS file that holds any data.

    The file is opened without memory mapping, so the array is fully loaded
    before the file is closed.

    Args:
        path: Path to the FITS file.

    Returns:
        The HDU data converted to a float32 NumPy array.

    Raises:
        FileNotFoundError: If ``path`` does not exist.
        ValueError: If no HDU in the file has data.
    """
    fits_path = Path(path)
    if not fits_path.exists():
        raise FileNotFoundError(fits_path)

    with fits.open(fits_path, memmap=False) as hdul:
        data = next((hdu.data for hdu in hdul if hdu.data is not None), None)

    if data is None:
        raise ValueError(f"No data in FITS file: {fits_path}")

    return np.array(data, dtype=np.float32)


def preprocess_single_cutout(arr: np.ndarray, target_size: int = IMAGE_SIZE) -> torch.Tensor:
    """
    Standardize one cutout and resize it to a square single-channel tensor.

    Steps:
      1. Compute mean and standard deviation in float64 over finite pixels only.
      2. Express every pixel as "number of std from the image mean". A zero or
         non-finite std is replaced by 1.0.
      3. Set pixels that were non-finite in the input to 0.
      4. Bilinearly resize to (target_size, target_size).

    If no pixel is finite, the image becomes all zeros before resizing.

    Args:
        arr: 2-D image array of any numeric dtype; it is cast to float32.
        target_size: Output height and width.

    Returns:
        float32 tensor of shape (1, target_size, target_size).
    """
    img = np.asarray(arr, dtype=np.float32)
    valid = np.isfinite(img)

    if not valid.any():
        # Nothing usable in the cutout: fall back to an all-zero image.
        img = np.zeros_like(img, dtype=np.float32)
    else:
        # float64 keeps the statistics and the scaling numerically stable.
        good = img[valid].astype(np.float64)
        mu = float(good.mean())
        sigma = float(good.std())
        if sigma == 0.0 or not np.isfinite(sigma):
            sigma = 1.0
        scaled = (img.astype(np.float64) - mu) / sigma
        scaled[~valid] = 0.0
        img = scaled.astype(np.float32)

    # (H, W) -> (1, 1, H, W) as F.interpolate expects, then drop the batch axis.
    batched = torch.from_numpy(img)[None, None]
    resized = F.interpolate(
        batched,
        size=(target_size, target_size),
        mode="bilinear",
        align_corners=False,
    )
    return resized[0]


def load_triplet_as_tensor(sci_path: str, diff_path: str, ref_path: str, target_size: int = IMAGE_SIZE) -> torch.Tensor:
    """
    Read a science / difference / reference cutout triplet into one tensor.

    All three FITS files are read with get_first_image_data first. Each is
    then standardized and resized with preprocess_single_cutout. The three
    single-channel results are concatenated along the channel axis.

    Returns:
        Tensor of shape (3, target_size, target_size), channels ordered
        (SCI, DIFF, REF).
    """
    raw_images = [get_first_image_data(p) for p in (sci_path, diff_path, ref_path)]
    channels = [preprocess_single_cutout(img, target_size) for img in raw_images]
    return torch.cat(channels, dim=0)


# Dataset
class TripletCutoutDataset(Dataset):
    """
    Map-style dataset that yields autoencoder (input, target) pairs.

    Each item is ``(x, x)``, where ``x`` is the (3, H, W) SCI/DIFF/REF tensor
    for one row of the triplet table. Input and target are the same tensor.

    Args:
        df_triplets: Triplet table with 'science', 'difference' and
            'template' path columns.
        image_size: Side length each cutout is resized to.
    """

    def __init__(self, df_triplets: pd.DataFrame, image_size: int = IMAGE_SIZE):
        # Private copy with a clean 0..N-1 index so iloc positions are stable.
        table = df_triplets.reset_index(drop=True)
        self.df = table.copy()
        self.image_size = image_size

    def __len__(self) -> int:
        return self.df.shape[0]

    def __getitem__(self, idx):
        record = self.df.iloc[idx]
        sample = load_triplet_as_tensor(
            record["science"],
            record["difference"],
            record["template"],
            self.image_size,
        )
        # Reconstruction objective: the target is the input itself.
        return sample, sample



# Convolutional Autoencoder Architecture 
class ConvAutoencoder(nn.Module):
    """
    3-layer convolutional autoencoder for 3-channel (SCI, DIFF, REF) cutouts.

      Encoder: 3  -> 32 -> 64 -> 128 channels, stride 2 each (H -> H/8).
      Decoder: 128 -> 64 -> 32 -> 3 via ConvTranspose2d, stride 2 each (H/8 -> H).

    The last decoder layer has no activation, so reconstructions are
    unbounded real values.

    Args:
        image_size: Side length of the square inputs. Must be a positive
            multiple of 8. Each stride-2 conv maps H to ceil(H/2) and each
            transposed conv exactly doubles it, so only multiples of 8 come
            back at their original size. Any other size would give
            reconstructions whose shape does not match the input.

    Raises:
        ValueError: If ``image_size`` is not a positive multiple of 8.
    """

    def __init__(self, image_size: int = IMAGE_SIZE):
        super().__init__()

        # Fail fast here instead of with a shape mismatch in the loss later.
        if image_size <= 0 or image_size % 8 != 0:
            raise ValueError(
                f"image_size must be a positive multiple of 8 "
                f"(three stride-2 stages), got {image_size}"
            )
        self.image_size = image_size

        # Encoder: 3 conv layers, stride 2, ReLU after each
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1),   # 3 x H x W -> 32 x H/2 x W/2
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),  # 32 -> 64, H/2 -> H/4
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1), # 64 -> 128, H/4 -> H/8
            nn.ReLU(inplace=True),
        )

        # Decoder: mirror with ConvTranspose2d. output_padding=1 makes each
        # layer exactly double the spatial size.
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2,
                               padding=1, output_padding=1),  # 128 -> 64, H/8 -> H/4
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2,
                               padding=1, output_padding=1),  # 64 -> 32, H/4 -> H/2
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(32, 3, kernel_size=3, stride=2,
                               padding=1, output_padding=1),  # 32 -> 3, H/2 -> H
        )

    def forward(self, x):
        """Encode ``x`` of shape (B, 3, H, W) and return its reconstruction of the same shape."""
        z = self.encoder(x)
        return self.decoder(z)



# Train/val/test split by objectId
def make_splits_by_object(df_triplets: pd.DataFrame, random_seed: int = RANDOM_SEED):
    """
    Partition triplets into train / val / test (80% / 10% / 10%) by object.

    Unique objectIds, compared as strings, are shuffled with a seeded
    ``np.random.RandomState`` and cut into consecutive chunks. All triplets
    of a given object therefore land in exactly one split. The train and val
    object counts are floored; the test split receives the remainder.

    Returns:
        (train_df, val_df, test_df): copies of the matching rows of df_triplets.
    """
    shuffler = np.random.RandomState(random_seed)

    id_str = df_triplets["objectId"].astype(str)
    unique_ids = id_str.unique()
    shuffler.shuffle(unique_ids)

    n_objects = len(unique_ids)
    cut_train = int(0.8 * n_objects)
    cut_val = cut_train + int(0.1 * n_objects)

    train_ids, val_ids, test_ids = (
        set(chunk) for chunk in np.split(unique_ids, [cut_train, cut_val])
    )

    def _rows_for(ids):
        return df_triplets[id_str.isin(ids)].copy()

    train_df = _rows_for(train_ids)
    val_df = _rows_for(val_ids)
    test_df = _rows_for(test_ids)

    print(f"Total objects with triplets: {n_objects}")
    print(f"  Train objects: {len(train_ids)} | Val objects: {len(val_ids)} | Test objects: {len(test_ids)}")
    print(f"  Train triplets: {len(train_df)} | Val triplets: {len(val_df)} | Test triplets: {len(test_df)}")

    return train_df, val_df, test_df


# Training loop
def _mean_loss_over_loader(model, loader, criterion) -> float:
    """
    Return the sample-weighted mean of ``criterion`` over all batches in ``loader``.

    Runs with the model in eval mode and gradients disabled. Returns NaN
    when the loader yields no samples, instead of dividing by zero.
    """
    model.eval()
    loss_sum = 0.0
    n_samples = 0
    with torch.no_grad():
        for x, y in loader:
            x = x.to(DEVICE, non_blocking=True)
            y = y.to(DEVICE, non_blocking=True)
            loss = criterion(model(x), y)
            batch_size = x.size(0)
            # criterion returns a batch mean; re-weight so a short last batch counts correctly.
            loss_sum += loss.item() * batch_size
            n_samples += batch_size
    return loss_sum / n_samples if n_samples else float("nan")


def train_autoencoder():
    """
    Train ConvAutoencoder on the object-level train split and report the test MSE.

    Steps:
      1. Build the triplet table from MANIFEST_CSV and split it by objectId.
      2. Train for NUM_EPOCHS epochs with Adam and MSE loss.
      3. Save the weights with the lowest finite validation MSE to
         'best_triplet_autoencoder.pt'. If no finite validation MSE is ever
         seen, save the final-epoch weights instead, so the checkpoint always
         comes from this run.
      4. Write the per-epoch history to 'triplet_autoencoder_train_history.csv'.
      5. Reload the checkpoint and print the test MSE (NaN if the test split is empty).

    Raises:
        ValueError: If the training split contains no triplets.
    """
    checkpoint_path = "best_triplet_autoencoder.pt"
    history_path = "triplet_autoencoder_train_history.csv"

    # Build full triplet table and split by object
    df_triplets = build_triplet_table(MANIFEST_CSV)
    train_df, val_df, test_df = make_splits_by_object(df_triplets, RANDOM_SEED)

    # Datasets and loaders
    train_ds = TripletCutoutDataset(train_df, image_size=IMAGE_SIZE)
    val_ds   = TripletCutoutDataset(val_df,   image_size=IMAGE_SIZE)
    test_ds  = TripletCutoutDataset(test_df,  image_size=IMAGE_SIZE)

    if len(train_ds) == 0:
        raise ValueError("No training triplets available; cannot train the autoencoder.")

    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,  num_workers=0)
    val_loader   = DataLoader(val_ds,   batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
    test_loader  = DataLoader(test_ds,  batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

    # Model, optimizer, loss
    model = ConvAutoencoder(image_size=IMAGE_SIZE).to(DEVICE)
    optimizer = torch.optim.Adam(model.parameters(), lr=LR)
    criterion = nn.MSELoss()

    best_val_loss = np.inf
    saved_checkpoint = False
    history = {"epoch": [], "train_loss": [], "val_loss": []}

    for epoch in range(1, NUM_EPOCHS + 1):
        # Train
        model.train()
        train_loss_sum = 0.0
        n_train_samples = 0

        pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{NUM_EPOCHS}", leave=False)
        for x, y in pbar:
            x = x.to(DEVICE, non_blocking=True)
            y = y.to(DEVICE, non_blocking=True)

            optimizer.zero_grad()
            recon = model(x)
            loss = criterion(recon, y)
            loss.backward()
            optimizer.step()

            batch_size = x.size(0)
            train_loss_sum += loss.item() * batch_size
            n_train_samples += batch_size

            running = train_loss_sum / n_train_samples
            pbar.set_postfix(train_mse=f"{running:.5f}")

        train_loss = train_loss_sum / n_train_samples

        # Validate (NaN when the val split is empty)
        val_loss = _mean_loss_over_loader(model, val_loader, criterion)

        history["epoch"].append(epoch)
        history["train_loss"].append(train_loss)
        history["val_loss"].append(val_loss)

        print(f"[Epoch {epoch:03d}/{NUM_EPOCHS}] Train MSE: {train_loss:.6f} | Val MSE: {val_loss:.6f}", flush=True)

        # Save best model. A non-finite val loss never counts as an improvement.
        if np.isfinite(val_loss) and val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), checkpoint_path)
            saved_checkpoint = True

    if not saved_checkpoint:
        # Without this, the load below would silently pick up a stale
        # checkpoint from an earlier run, or fail if none exists.
        print(
            "Warning: no finite validation loss was observed; "
            "saving final-epoch weights as the checkpoint.",
            flush=True,
        )
        torch.save(model.state_dict(), checkpoint_path)

    # Save training history
    hist_df = pd.DataFrame(history)
    hist_df.to_csv(history_path, index=False)
    print(f"\nTraining history saved to: {history_path}")

    # Final test evaluation with the checkpointed weights
    model.load_state_dict(torch.load(checkpoint_path, map_location=DEVICE))
    test_loss = _mean_loss_over_loader(model, test_loader, criterion)
    print(f"\nFinal test MSE (best model): {test_loss:.6f}")


# Entry point: train the autoencoder when this module (or notebook cell) runs as __main__.
if __name__ == "__main__":
    train_autoencoder()

# Evaluate the test set in PCA space


In [None]:
def analyze_ae_pca_anomalies_test_only():
    """
    Run AE and 3D PCA anomaly analysis on the test split only.

    Steps:
      1. Rebuild the triplet table and recover the object-level test split
         with make_splits_by_object(..., RANDOM_SEED).
      2. Run the saved autoencoder ('best_triplet_autoencoder.pt') over the
         test triplets to get a per-triplet reconstruction MSE.
      3. Fit a 3-component PCA on the flattened test inputs and compute each
         triplet's PCA reconstruction MSE.
      4. Flag the top 1% of triplets by AE MSE and by PCA MSE (>= the 0.99
         quantile). Write a per-triplet table to
         'triplet_ae_pca_anomalies_test.csv'.
      5. Write an interactive 3D scatter of the PCA coordinates to
         'triplet_ae_pca_3d_anomalies_test.html' and open it.

    Raises:
        ValueError: If the test split has fewer than 3 triplets, which a
            3-component PCA cannot be fit on.
    """
    # Build full triplet table and recover test split
    df_triplets = build_triplet_table(MANIFEST_CSV)
    _, _, test_df = make_splits_by_object(df_triplets, RANDOM_SEED)

    test_df = test_df.reset_index(drop=True)
    N = len(test_df)

    n_components = 3
    if N < n_components:
        # PCA needs at least n_components samples. Fail with a clear message
        # instead of an opaque error from np.concatenate / sklearn.
        raise ValueError(
            f"Need at least {n_components} test triplets for a "
            f"{n_components}-component PCA, got {N}."
        )

    # Dataset & loader. No shuffling, so output row i matches test_df row i.
    test_ds = TripletCutoutDataset(test_df, image_size=IMAGE_SIZE)
    test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE,
                             shuffle=False, num_workers=0)

    # Load trained AE
    model = ConvAutoencoder(image_size=IMAGE_SIZE).to(DEVICE)
    model.load_state_dict(torch.load("best_triplet_autoencoder.pt", map_location=DEVICE))
    model.eval()

    # AE forward pass on test set
    all_X_flat = []
    all_ae_mse = []

    with torch.no_grad():
        for x, _ in tqdm(test_loader, desc="AE forward on test", leave=False):
            x = x.to(DEVICE, non_blocking=True)

            recon = model(x)

            # Per-sample MSE: mean over channels and pixels
            loss_per_pixel = F.mse_loss(recon, x, reduction="none")
            loss_per_sample = loss_per_pixel.view(loss_per_pixel.size(0), -1).mean(dim=1)
            all_ae_mse.append(loss_per_sample.cpu().numpy())

            # Flatten inputs (C*H*W features per triplet) for PCA
            x_cpu = x.cpu().numpy()
            all_X_flat.append(x_cpu.reshape(x_cpu.shape[0], -1))

    X = np.concatenate(all_X_flat, axis=0)
    ae_mse = np.concatenate(all_ae_mse, axis=0)
    assert X.shape[0] == N == ae_mse.shape[0]

    print("Finished AE forward pass on test set.")
    print(f"X shape for PCA (test only): {X.shape}")
    print(f"AE MSE range (test): {ae_mse.min():.4e} – {ae_mse.max():.4e}")

    # 3D PCA on flattened test inputs
    pca = PCA(n_components=n_components, random_state=RANDOM_SEED)
    pcs_3d = pca.fit_transform(X)

    # PCA reconstruction error in the original feature space
    X_recon = pca.inverse_transform(pcs_3d)
    pca_mse = np.mean((X - X_recon) ** 2, axis=1)

    # Define top 1% anomalies for AE and PCA
    ae_thresh = np.quantile(ae_mse, 0.99)
    pca_thresh = np.quantile(pca_mse, 0.99)

    is_ae_top1   = ae_mse >= ae_thresh
    is_pca_top1  = pca_mse >= pca_thresh
    is_both_top1 = is_ae_top1 & is_pca_top1

    ae_only_top1  = is_ae_top1 & ~is_pca_top1
    pca_only_top1 = is_pca_top1 & ~is_ae_top1
    normal_mask   = ~(is_ae_top1 | is_pca_top1)

    # Save test anomaly table
    df_out = test_df.copy()
    df_out["ae_mse"] = ae_mse
    df_out["pca_mse"] = pca_mse
    # "is_ae_top35" is a legacy, misleading column name kept for existing
    # consumers of this CSV. It holds the top-1% AE flag, same as "is_ae_top1".
    df_out["is_ae_top35"] = is_ae_top1
    df_out["is_ae_top1"] = is_ae_top1
    df_out["is_pca_top1"] = is_pca_top1
    df_out["is_both_top1"] = is_both_top1

    df_out.to_csv("triplet_ae_pca_anomalies_test.csv", index=False)
    print("Saved test anomaly table to triplet_ae_pca_anomalies_test.csv")

    # 3D Plotly scatter for test set only
    fig = go.Figure()

    def add_trace(mask, name, color, opacity=0.9):
        # Skip empty groups so the legend only lists populated categories.
        if np.any(mask):
            fig.add_trace(
                go.Scatter3d(
                    x=pcs_3d[mask, 0],
                    y=pcs_3d[mask, 1],
                    z=pcs_3d[mask, 2],
                    mode="markers",
                    name=name,
                    marker=dict(
                        size=6,
                        color=color,
                        opacity=opacity,
                    ),
                )
            )

    add_trace(normal_mask,   "Normal (non-top-1%, test)", "lightblue", opacity=0.4)
    add_trace(ae_only_top1,  "AE top 1% only (test)",      "red",       opacity=0.9)
    add_trace(pca_only_top1, "PCA top 1% only (test)",     "orange",    opacity=0.9)
    add_trace(is_both_top1,  "Top 1% in AE & PCA (test)",  "yellow",    opacity=1.0)

    fig.update_layout(
        title="3D PCA of ZTF triplet cutouts (TEST SET ONLY)<br>"
              "Top 1% anomalies from AE and PCA",
        scene=dict(
            xaxis_title="PC1",
            yaxis_title="PC2",
            zaxis_title="PC3",
        ),
        legend=dict(itemsizing="constant"),
        margin=dict(l=0, r=0, b=0, t=40),
    )

    html_path = "triplet_ae_pca_3d_anomalies_test.html"
    fig.write_html(html_path, auto_open=True)


# Entry point: run the test-split AE/PCA anomaly analysis when executed as __main__.
if __name__ == "__main__":
    analyze_ae_pca_anomalies_test_only()

## Top AE anomalies from the test set (top 3.5% by reconstruction error)

In [None]:
def export_top_ae_anomalies_training_rows_test(
    training_csv_path: str = "training_data.csv",
    output_csv_path: str = "tripletimage_AE_top10anomalies.csv",
):
    """
    Export training_data.csv rows for the objects the AE finds most anomalous in the test split.

    Steps:
      1. Compute the autoencoder's per-triplet reconstruction MSE on the
         test split only.
      2. Select the top 3.5% of test triplets (MSE >= the 0.965 quantile).
      3. Aggregate the selected triplets' MSE per objectId (mean, max,
         median, count).
      4. Match those objectIds, compared as strings, against the rows of
         ``training_csv_path``.
      5. Save every matching row, with the per-object loss columns appended,
         to ``output_csv_path``.

    Note: despite the default output filename, the selection is the top
    3.5% of test triplets, not a fixed top 10.

    Args:
        training_csv_path: CSV with an 'objectId' column to pull rows from.
        output_csv_path: Destination CSV.

    Returns None. It returns early, after printing a message, if the test
    split is empty or no triplet reaches the threshold.
    """

    # Build full triplet table and recover test split
    df_triplets = build_triplet_table(MANIFEST_CSV)
    _, _, test_df = make_splits_by_object(df_triplets, RANDOM_SEED)

    test_df = test_df.reset_index(drop=True)
    N = len(test_df)
    print(f"Number of test triplets: {N}")
    if N == 0:
        # np.concatenate / np.quantile below would fail on an empty split.
        print("Test split is empty; nothing to export.")
        return

    # Dataset and loader. No shuffling, so ae_mse[i] matches test_df row i.
    test_ds = TripletCutoutDataset(test_df, image_size=IMAGE_SIZE)
    test_loader = DataLoader(
        test_ds,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=0,
    )

    # Load trained AE
    model = ConvAutoencoder(image_size=IMAGE_SIZE).to(DEVICE)
    model.load_state_dict(torch.load("best_triplet_autoencoder.pt", map_location=DEVICE))
    model.eval()

    # AE forward pass on TEST set to get per-sample MSE
    all_ae_mse = []

    with torch.no_grad():
        for x, _ in tqdm(test_loader, desc="AE forward on TEST (for losses)", leave=False):
            x = x.to(DEVICE, non_blocking=True)
            recon = model(x)

            # Per-sample MSE = mean over channels and pixels
            loss_per_pixel = F.mse_loss(recon, x, reduction="none")
            loss_per_sample = loss_per_pixel.view(loss_per_pixel.size(0), -1).mean(dim=1)
            all_ae_mse.append(loss_per_sample.cpu().numpy())

    ae_mse = np.concatenate(all_ae_mse, axis=0)  # shape (N,)
    assert ae_mse.shape[0] == N
    print(f"AE MSE range (TEST): {ae_mse.min():.4e} – {ae_mse.max():.4e}")

    # Identify top 3.5% AE anomalies in the TEST set
    ae_thresh = np.quantile(ae_mse, 0.965)
    is_ae_top35 = ae_mse >= ae_thresh
    n_top = int(np.count_nonzero(is_ae_top35))

    print(f"Top 3.5% AE anomalies (TEST): {n_top} / {N}")
    if n_top == 0:
        # With finite losses the max always reaches the quantile. This branch
        # is hit e.g. when ae_mse contains NaN, which makes the threshold NaN.
        print("No top 3.5% anomalies found in TEST (unexpected). Exiting.")
        return

    test_with_loss = test_df.copy()
    test_with_loss["ae_mse"] = ae_mse
    test_top = test_with_loss.loc[is_ae_top35].copy()

    # Object-level aggregation. Grouping by the str-cast "objectId" Series
    # gives an "objectId" column after reset_index. It is renamed to the
    # string join key used below.
    obj_loss_summary = (
        test_top.groupby(test_top["objectId"].astype(str))["ae_mse"]
        .agg(
            ae_mse_mean_top35_test="mean",
            ae_mse_max_top35_test="max",
            ae_mse_median_top35_test="median",
            n_triplets_top35_test="count",
        )
        .reset_index()
        .rename(columns={"objectId": "objectId_str"})
    )

    top_obj_ids = obj_loss_summary["objectId_str"].unique()

    # Load training_data.csv and match on objectId
    train_df_full = pd.read_csv(training_csv_path)
    print(f"Loaded training data from {training_csv_path} with {len(train_df_full)} rows")

    # Compare objectId as string on both sides
    train_df_full["objectId_str"] = train_df_full["objectId"].astype(str)

    matched_df = train_df_full[train_df_full["objectId_str"].isin(top_obj_ids)].copy()
    print(f"Rows in training_data.csv matching TEST AE top-3.5% objectIds: {len(matched_df)}")

    missing_ids = set(top_obj_ids) - set(train_df_full["objectId_str"].unique())
    if missing_ids:
        print(f"Warning: {len(missing_ids)} AE-top-3.5% TEST objectIds not found in training_data.csv")

    # Attach the per-object loss summary; each object has exactly one summary row.
    matched_df = matched_df.merge(
        obj_loss_summary,
        on="objectId_str",
        how="left",
        validate="many_to_one",
    )

    matched_df.drop(columns=["objectId_str"], inplace=True)

    matched_df.to_csv(output_csv_path, index=False)
    print(f"Saved test AE top-3.5% anomaly rows (with AE losses) to: {output_csv_path}")

# Entry point: export the AE top-3.5% test anomalies when executed as __main__.
if __name__ == "__main__":
    export_top_ae_anomalies_training_rows_test()