In [None]:
#### ANOMALY IMPUTATION USING DIFFUSION MODEL ####


In [None]:
import os
import math
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt

# ------------------------------
# 0) MOUNT + PATHS
# ------------------------------
# Colab-only: mount Google Drive so the input CSVs and all outputs live on Drive.
from google.colab import drive
drive.mount('/content/drive')

# Name of the output sub-folder; also used as the tag in every output CSV/plot filename.
ANOM_FOLDER  = "AnomalyDiffusion"
# Folder holding the per-residence "<residence>_Fridge_15minutes.csv" input files.
BASE_IN_DIR  = "/content/drive/MyDrive/Masters_IndependentStudy/Original_Datasets_Fridge"
OUT_DIR      = f"/content/drive/MyDrive/Masters_IndependentStudy/{ANOM_FOLDER}"
os.makedirs(OUT_DIR, exist_ok=True)

# The diffusion model is trained on this single residence and then applied to every residence.
TRAIN_RESIDENCE = "REFIT_House01"
TRAIN_CSV_PATH  = os.path.join(BASE_IN_DIR, f"{TRAIN_RESIDENCE}_Fridge_15minutes.csv")

# ------------------------------
# 1) RESIDENCES + ANOMALY DATES
# ------------------------------
# Residences processed (in this order) during inference. Each one needs an input CSV
# named "<residence>_Fridge_15minutes.csv" in BASE_IN_DIR and an ANOMALY_DATES entry;
# residences missing either are skipped.
RESIDENCES = [
    "REFIT_House01",
    "REFIT_House02",
    "REFIT_House03",
    "REFIT_House05",
    "REFIT_House07",
    "REFIT_House09",
    "REFIT_House15",
    "UKDALE_House01",
    "UKDALE_House02",
    "UKDALE_House05",
    "AMPds2_House01",
    "GREEND_House00",
    "GREEND_House01",
    "GREEND_House03",
]

# Day (YYYY-MM-DD) whose values are replaced by each generated anomaly, per residence.
# The string is compared with str(date) of the parsed timestamps, and the day must be a
# complete 96-sample day; otherwise the residence is skipped.
ANOMALY_DATES = {
    "REFIT_House01":  "2015-03-09",
    "REFIT_House02":  "2015-01-26",
    "REFIT_House03":  "2015-02-02",
    "REFIT_House05":  "2015-03-02",
    "REFIT_House07":  "2015-03-09",
    "REFIT_House09":  "2015-03-23",
    "REFIT_House15":  "2015-03-16",
    "UKDALE_House01": "2016-06-06",
    "UKDALE_House02": "2013-09-16",
    "UKDALE_House05": "2014-10-20",
    "AMPds2_House01": "2013-11-11",
    "GREEND_House00": "2014-08-18",
    "GREEND_House01": "2014-09-15",
    "GREEND_House03": "2014-09-22",
}

# Anomaly types generated and saved (CSV + plots) for every residence.
# Each name must be a key of ANOM_TYPES.
SAVE_ANOM_TYPES = [
    "stepchange",
    "multistepchange",
    "mirror",
    "repeating",
    "stuckmax",
    "stuckmin",
    "powercycling",
]

# ------------------------------
# 2) SETTINGS
# ------------------------------
# Number of diffusion timesteps: the length of the beta schedule and of the reverse loop at inference.
NUM_STEPS   = 80
# Full passes over the synthetic training set.
NUM_EPOCHS  = 50
# Channel width of the first UNet1D level (the deeper levels use 2x and 4x).
BASE_CH     = 64

# Training-set size: K_NORMAL windows with no injected anomaly, PER_TYPE_HARD windows for
# each type in HARD_TYPES, and PER_TYPE_EASY windows for every other anomaly type.
K_NORMAL       = 220
PER_TYPE_HARD  = 520
PER_TYPE_EASY  = 320

BATCH_SIZE  = 256
NUM_WORKERS = 2   # DataLoader worker processes

# Weights of the auxiliary losses added to the masked noise-prediction loss:
# masked x0 reconstruction (MSE) and masked first-difference consistency (L1).
LAMBDA_X0   = 0.5
LAMBDA_GRAD = 0.55

# Classifier-free guidance: the per-sample probability of zeroing the class map during
# training, and the guidance weight used at inference.
CFG_DROP_PROB  = 0.15
CFG_GUIDE_W    = 3.0
# Guidance weight passed for "mirror". NOTE(review): at inference the mirror window is
# overwritten by the hard mirror constraint, so this value does not change the output.
MIRROR_GUIDE_W = 3.0

PRINT_EVERY_BATCHES = 25   # log the training loss every N batches (and on batch 1)

# Plot saving settings (savefig DPI for the full-day and zoomed plots)
PLOT_DPI_FULL = 200
PLOT_DPI_ZOOM = 220

# ------------------------------
# 3) DEVICE
# ------------------------------
# Prefer the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)

# ------------------------------
# 4) CLASSES
# ------------------------------
# Class ids used for conditioning: id 0 is "normal" (no injected anomaly) and the
# remaining ids are the injectable anomaly types, numbered in declaration order.
ANOM_TYPES = {
    name: class_id
    for class_id, name in enumerate((
        "normal",
        "stepchange",
        "multistepchange",
        "mirror",
        "repeating",
        "stuckmax",
        "stuckmin",
        "powercycling",
    ))
}
# Number of class-map channels fed to the model (one per class).
NUM_ANOM_TYPES = len(ANOM_TYPES)
# Types that receive PER_TYPE_HARD training samples instead of PER_TYPE_EASY.
HARD_TYPES = {"stepchange", "multistepchange", "mirror", "repeating"}

# ------------------------------
# 5) COLUMN DETECTION + LOADING
# ------------------------------
def detect_columns(df: pd.DataFrame):
    """Guess which columns of ``df`` hold the timestamp and the active power.

    Timestamp: the first column whose lower-cased name contains "time", "date"
    or "stamp"; otherwise the first column of the frame.

    Power: the first column whose lower-cased name contains "active" together
    with "power", "p_" or a trailing "p", or is exactly "active_power", "p" or
    "p_active". If none matches, the first numeric column other than the
    timestamp column is used.

    Returns:
        (ts_col, p_col): the two column labels.

    Raises:
        ValueError: if no power column can be identified.
    """
    def _looks_like_time(name: str) -> bool:
        return any(token in name for token in ("time", "date", "stamp"))

    def _looks_like_power(name: str) -> bool:
        if "active" in name and ("power" in name or "p_" in name or name.endswith("p")):
            return True
        return name in ("active_power", "p", "p_active")

    time_like = [col for col in df.columns if _looks_like_time(col.lower())]
    power_like = [col for col in df.columns if _looks_like_power(col.lower())]

    ts_col = time_like[0] if time_like else df.columns[0]
    if power_like:
        return ts_col, power_like[0]

    # Fallback: first numeric column that is not the timestamp.
    numeric = [
        col for col in df.columns
        if col != ts_col and pd.api.types.is_numeric_dtype(df[col])
    ]
    if not numeric:
        raise ValueError("Could not find a numeric power column. Rename your active power column to 'active_power'.")
    return ts_col, numeric[0]

def load_two_cols(csv_path: str):
    """Read a CSV and return a clean two-column (timestamp, active_power) frame.

    The source columns are chosen by ``detect_columns``. Values that cannot be
    parsed become NaT/NaN, those rows are dropped, and the rest are sorted by time.

    Returns:
        (df, ts_col, p_col): the cleaned frame, with columns ``timestamp`` and
        ``active_power`` and a fresh RangeIndex, plus the original names of the
        two detected columns.
    """
    raw = pd.read_csv(csv_path)
    ts_col, p_col = detect_columns(raw)
    tidy = raw[[ts_col, p_col]].copy().rename(
        columns={ts_col: "timestamp", p_col: "active_power"}
    )
    tidy["timestamp"] = pd.to_datetime(tidy["timestamp"], errors="coerce")
    tidy["active_power"] = pd.to_numeric(tidy["active_power"], errors="coerce")
    tidy = (
        tidy.dropna(subset=["timestamp", "active_power"])
        .sort_values("timestamp")
        .reset_index(drop=True)
    )
    return tidy, ts_col, p_col

# ------------------------------
# 6) BUILD WINDOWS OF FULL DAYS BASED ON 15MIN INTERVALS (96/day)
# ------------------------------
def build_daily_windows(df: pd.DataFrame, expected_per_day=96):
    """Split a (timestamp, active_power) frame into complete calendar days.

    A day is kept only when it has exactly ``expected_per_day`` rows (96 for
    15-minute data). Incomplete or over-full days are dropped.

    Returns:
        (days_raw, days_ts, day_keys): a float32 array of shape
        (n_days, expected_per_day), the matching timestamps in the same shape,
        and the list of day keys as ``str(date)`` ("YYYY-MM-DD").
        All three are ``None`` when no complete day exists.
    """
    with_day = df.assign(date=df["timestamp"].dt.date)
    powers, stamps, keys = [], [], []
    for day, rows in with_day.groupby("date"):
        if len(rows) != expected_per_day:
            continue
        rows = rows.sort_values("timestamp")
        powers.append(rows["active_power"].to_numpy(dtype=np.float32))
        stamps.append(rows["timestamp"].to_numpy())
        keys.append(str(day))
    if not powers:
        return None, None, None
    return np.stack(powers), np.stack(stamps), keys

# ------------------------------
# 7) SCALE EACH DAY TO [0,1] (per-day min-max normalization)
# ------------------------------
def scale_day_01(day: np.ndarray):
    """Min-max scale one day to [0, 1].

    Returns:
        (scaled, dmin, dmax): the float32 scaled array and the original min/max
        as Python floats, which are needed to map generated values back to raw
        units. A (near-)constant day, with range below 1e-6, maps to all zeros.
    """
    lo, hi = float(np.min(day)), float(np.max(day))
    span = hi - lo
    if span < 1e-6:
        return np.zeros_like(day, dtype=np.float32), lo, hi
    scaled = (day - lo) / span
    return scaled.astype(np.float32), lo, hi

def scale_days_01(days_raw: np.ndarray):
    """Apply ``scale_day_01`` to every day.

    Returns:
        (days_01, meta): a float32 array of scaled days, and a float32 array
        of shape (n_days, 2) holding each day's original (min, max).
    """
    per_day = [scale_day_01(day) for day in days_raw]
    days_01 = np.stack([scaled for scaled, _, _ in per_day]).astype(np.float32)
    meta = np.array([(lo, hi) for _, lo, hi in per_day], dtype=np.float32)
    return days_01, meta

# ------------------------------
# 8) ANOMALY REGION
# - located at the middle third
# - add jitter
# ------------------------------
def anomaly_region_fixed(L=96):
    """Return ``(s, e)`` spanning the middle third of a length-``L`` window (``e`` exclusive)."""
    return L // 3, (2 * L) // 3

def anomaly_region_jitter(L=96, jitter=6, rng=None):
    """Return the middle-third window shifted by a random offset.

    The offset is drawn uniformly from ``[-jitter, jitter]``. The window keeps
    the middle third's length when it fits and is clamped to stay inside
    ``[0, L]`` with at least 2 samples.

    Args:
        L: window length.
        jitter: maximum absolute shift of the start index.
        rng: ``numpy.random.Generator``; a fresh unseeded one is used when None.

    Returns:
        (s, e) as Python ints, ``e`` exclusive.
    """
    rng = np.random.default_rng() if rng is None else rng
    fixed_s, fixed_e = anomaly_region_fixed(L)
    width = fixed_e - fixed_s
    offset = int(rng.integers(-jitter, jitter + 1))
    start = int(np.clip(fixed_s + offset, 0, L - 2))
    end = start + width
    if end > L:
        # Slide back so the window ends exactly at L.
        start, end = L - width, L
    start = int(np.clip(start, 0, L - 2))
    end = int(np.clip(end, start + 2, L))
    return start, end

# ------------------------------
# 9) TEMPLATE INJECTORS
# These are the anomalies: StepChange, MultiStepChange, Mirror, Repeating, StuckMAX, StuckMIN, PowerCycling.
# ------------------------------
def inject_stepchange(day01: np.ndarray, rng, s, e):
    """Replace ``[s, e)`` with a single low-to-high step.

    A split point is drawn in ``[s+2, e-2)``. Samples before it are set to a
    low level (0.10-0.35) and the rest to a high level (0.80-0.98); each level
    gets an extra +-0.02 jitter and is clipped to [0, 1]. Requires ``e - s > 4``.

    Returns:
        (modified copy as float32, s, e); ``day01`` itself is not modified.
    """
    result = day01.copy()
    split = int(rng.integers(s + 2, e - 2))
    low_level = float(rng.uniform(0.10, 0.35))
    high_level = float(rng.uniform(0.80, 0.98))
    low_value = np.clip(low_level + rng.uniform(-0.02, 0.02), 0.0, 1.0)
    high_value = np.clip(high_level + rng.uniform(-0.02, 0.02), 0.0, 1.0)
    result[s:split] = low_value
    result[split:e] = high_value
    return result.astype(np.float32), s, e

def inject_multistepchange(day01: np.ndarray, rng, s, e):
    """Replace ``[s, e)`` with a four-level ascending staircase.

    Three distinct interior cut points split the window into four runs. The
    levels start at 0.08-0.25 and rise by three increments of 0.15-0.28. They
    are rescaled so the top level is at most 0.98, then jittered by +-0.015
    and clipped to [0, 1]. Requires ``e - s >= 5``.

    Returns:
        (modified copy as float32, s, e); ``day01`` itself is not modified.
    """
    result = day01.copy()
    width = e - s
    inner_cuts = np.sort(rng.choice(np.arange(1, width - 1), size=3, replace=False))
    bounds = np.concatenate([[0], inner_cuts, [width]]).astype(int)

    start_level = float(rng.uniform(0.08, 0.25))
    steps = rng.uniform(0.15, 0.28, size=3).astype(np.float32)
    levels = [start_level]
    for step in steps:
        levels.append(levels[-1] + step)

    peak = max(levels)
    if peak > 0.98:
        shrink = 0.98 / peak
        levels = [lvl * shrink for lvl in levels]
    levels = [float(np.clip(lvl + rng.uniform(-0.015, 0.015), 0.0, 1.0)) for lvl in levels]

    for level, lo, hi in zip(levels, bounds[:-1], bounds[1:]):
        result[s + lo:s + hi] = level
    return result.astype(np.float32), s, e

def inject_mirror(day01: np.ndarray, rng, s, e):
    """Time-reverse the samples in ``[s, e)`` (clipped to [0, 1]); ``rng`` is unused.

    Returns:
        (modified copy as float32, s, e); ``day01`` itself is not modified.
    """
    result = day01.copy()
    result[s:e] = np.clip(np.flip(day01[s:e]), 0.0, 1.0)
    return result.astype(np.float32), s, e

def inject_repeating(day01: np.ndarray, rng, s, e):
    """Fill ``[s, e)`` with 4-7 back-to-back copies of one sine period.

    All copies share the same centre (0.45-0.55), amplitude (0.35-0.48) and
    phase. Each copy spans roughly ``(e - s) / repeats`` samples, and the values
    are clipped to [0, 1].

    Returns:
        (modified copy as float32, s, e); ``day01`` itself is not modified.
    """
    result = day01.copy()
    width = e - s
    n_reps = int(rng.integers(4, 8))
    bounds = np.linspace(0, width, n_reps + 1).round().astype(int)
    center = float(rng.uniform(0.45, 0.55))
    amp = float(rng.uniform(0.35, 0.48))
    phase = float(rng.uniform(0, 2 * np.pi))

    for lo, hi in zip(bounds[:-1], bounds[1:]):
        n = hi - lo
        if n <= 0:
            continue
        t = np.linspace(0, 2 * np.pi, n, endpoint=False).astype(np.float32)
        result[s + lo:s + hi] = np.clip(center + amp * np.sin(t + phase), 0.0, 1.0)
    return result.astype(np.float32), s, e

def inject_stuckmax(day01: np.ndarray, rng, s, e):
    """Pin ``[s, e)`` to 1.0 (the scaled daily maximum); ``rng`` is unused.

    Returns:
        (modified copy as float32, s, e); ``day01`` itself is not modified.
    """
    stuck = day01.copy()
    stuck[s:e] = 1.0
    return stuck.astype(np.float32), s, e

def inject_stuckmin(day01: np.ndarray, rng, s, e):
    """Pin ``[s, e)`` to 0.0 (the scaled daily minimum); ``rng`` is unused.

    Returns:
        (modified copy as float32, s, e); ``day01`` itself is not modified.
    """
    stuck = day01.copy()
    stuck[s:e] = 0.0
    return stuck.astype(np.float32), s, e

def inject_powercycling(day01: np.ndarray, rng, s, e):
    """Fill ``[s, e)`` with a square wave toggling between a high and a low level.

    The high level is 0.90-1.00 and the low level 0.00-0.10. The wave starts
    high and toggles every ``run`` samples, where ``run`` is 35-65% of
    ``period = max(2, (e - s) // cycles)`` for a random ``cycles`` in [4, 10].

    NOTE(review): ``run`` is used as a symmetric half-period, not as an on-time
    inside ``period``, so the 35-65% draw acts as a frequency jitter rather
    than a duty cycle. Confirm this is intended.

    Returns:
        (modified copy as float32, s, e); ``day01`` itself is not modified.
    """
    result = day01.copy()
    width = e - s
    on_level = float(rng.uniform(0.90, 1.00))
    off_level = float(rng.uniform(0.00, 0.10))
    n_cycles = int(rng.integers(4, 11))
    period = max(2, width // n_cycles)
    run = max(1, int(period * rng.uniform(0.35, 0.65)))
    positions = np.arange(width)
    result[s:e] = np.where((positions // run) % 2 == 0, on_level, off_level)
    return result.astype(np.float32), s, e

# Dispatch table from anomaly name to injector. Every injector is called as
# fn(day01, rng, s, e) and returns (modified float32 copy, s, e).
INJECT_FUNCS = dict(
    stepchange=inject_stepchange,
    multistepchange=inject_multistepchange,
    mirror=inject_mirror,
    repeating=inject_repeating,
    stuckmax=inject_stuckmax,
    stuckmin=inject_stuckmin,
    powercycling=inject_powercycling,
)

# ------------------------------
# 10) DIFFUSION SCHEDULE (linear beta schedule and its cumulative products)
# ------------------------------
def make_schedule(num_steps, device):
    """Build a linear DDPM noise schedule.

    Returns:
        (betas, alphas, abar): 1-D tensors of length ``num_steps`` on ``device``.
        ``betas`` go linearly from 1e-4 to 0.02, ``alphas = 1 - betas``, and
        ``abar`` is the cumulative product of ``alphas``.

    NOTE(review): with num_steps=80 the final ``abar`` is only about 0.45, so
    x_T still carries a large share of the signal rather than being pure noise.
    """
    beta = torch.linspace(1e-4, 0.02, num_steps, device=device)
    alpha = 1.0 - beta
    return beta, alpha, alpha.cumprod(dim=0)

# ------------------------------
# 11) MODEL - ARCHITECTURE
# ------------------------------
# Sinusoidal timestep embedding: maps each integer diffusion step t to a
# `dim`-dimensional vector (64-D by default) of sines and cosines at
# geometrically spaced frequencies, so the network knows how noisy its input is.
# With this file's schedule (abar decreases with t), larger t means more noise:
#   large t  -> mostly noise
#   middle t -> some noise
#   small t  -> very little noise
def time_embed(t, dim=64):
    half = dim // 2
    exponent = torch.arange(half, device=t.device) * (-math.log(10000) / (half - 1))
    freqs = exponent.exp()
    angles = t.float()[:, None] * freqs[None, :]
    return torch.cat((angles.sin(), angles.cos()), dim=-1)

# UNet-style 1-D convolutional network: an encoder (c1 -> c2 -> c3) widens the
# channels, a decoder (d1 -> d2) narrows them again, and additive skip
# connections feed encoder features into the decoder. Every conv uses kernel 3,
# padding 1 and stride 1, and there is no pooling, so the sequence length is
# never reduced.
class UNet1D(nn.Module):
    """Predict the diffusion noise of a window from [x_t, mask, class map].

    Args:
        base: channel width of the first level (the levels use base, 2*base, 4*base).
        in_ch: number of input channels. The default, 2 + NUM_ANOM_TYPES, matches
            the callers' ``torch.cat([xt, m, c_map], dim=1)``: one signal channel,
            one mask channel and one class-map channel per class.
    """
    def __init__(self, base=64, in_ch=2 + NUM_ANOM_TYPES):
        super().__init__()
        # Projects the 64-D sinusoidal time embedding to `base` channels; the 64
        # must match time_embed's default `dim`.
        self.mlp = nn.Sequential(nn.Linear(64, base), nn.SiLU(), nn.Linear(base, base))
        self.c1  = nn.Conv1d(in_ch, base, 3, padding=1)
        self.c2  = nn.Conv1d(base, base*2, 3, padding=1)
        self.c3  = nn.Conv1d(base*2, base*4, 3, padding=1)
        self.d1  = nn.Conv1d(base*4, base*2, 3, padding=1)
        self.d2  = nn.Conv1d(base*2, base, 3, padding=1)
        self.out = nn.Conv1d(base, 1, 3, padding=1)  # single-channel noise estimate
        self.act = nn.SiLU()

    def forward(self, x, t):
        # x: (B, in_ch, T); t: (B,) integer timesteps. The embedding becomes
        # (B, base, 1) and broadcasts over the time axis.
        emb = self.mlp(time_embed(t)).unsqueeze(-1)
        h1 = self.act(self.c1(x)) + emb  # time conditioning is injected only at the first level
        h2 = self.act(self.c2(h1))
        h3 = self.act(self.c3(h2))
        d1 = self.act(self.d1(h3))
        d2 = self.act(self.d2(d1 + h2))  # skip from the 2*base encoder level
        return self.out(d2 + h1)  # skip from the base encoder level

# ------------------------------
# 12) DATASET + UTILS
# ------------------------------
# Torch Dataset over pre-built numpy arrays of signals, masks and class maps.
class DiffDataset(Dataset):
    """Serve ``(x, mask, class_map)`` training triples as float tensors.

    ``X[i]`` and ``M[i]`` gain a leading channel axis, giving (1, 96) for the
    windows built in this file. ``C[i]`` is returned as-is, giving (K, 96).
    """

    def __init__(self, X, M, C):
        self.X, self.M, self.C = X, M, C

    def __len__(self):
        return len(self.X)

    def __getitem__(self, i):
        signal = torch.from_numpy(self.X[i]).float()[None, :]
        mask = torch.from_numpy(self.M[i]).float()[None, :]
        cond = torch.from_numpy(self.C[i]).float()
        return signal, mask, cond

# Class-conditioning map "gated" by the mask: one row per class, where only the
# row of `cls_id` is non-zero and it carries the mask, so the class signal is
# present only inside the masked region.
def make_c_map_mask_gated(cls_id: int, mask_1d: np.ndarray, num_types=None):
    """Build a (num_types, T) float32 class map from a 1-D mask of length T.

    Args:
        cls_id: row to fill (a value of ``ANOM_TYPES``).
        mask_1d: 1-D mask (any array-like; extra size-1 axes are flattened).
            Its length sets T, which was previously hard-coded to 96.
        num_types: number of class rows; defaults to ``NUM_ANOM_TYPES``.

    Returns:
        float32 array that is zero everywhere except row ``cls_id``, which
        equals the mask.
    """
    if num_types is None:
        num_types = NUM_ANOM_TYPES
    mask_row = np.asarray(mask_1d, dtype=np.float32).reshape(-1)
    c_map = np.zeros((num_types, mask_row.shape[0]), dtype=np.float32)
    c_map[cls_id, :] = mask_row
    return c_map

# First differences of the signal, paired with a mask that keeps only the
# differences whose two endpoints both lie inside the masked region.
def masked_grad(x, m):
    """Return ``(dx, mm)`` for 3-D (B, C, T) inputs, both of length T-1 on the last axis.

    ``dx[..., k] = x[..., k+1] - x[..., k]`` and ``mm[..., k] = m[..., k+1] * m[..., k]``.
    """
    head = slice(1, None)
    tail = slice(None, -1)
    dx = x[:, :, head] - x[:, :, tail]
    pair_mask = m[:, :, head] * m[:, :, tail]
    return dx, pair_mask

# ------------------------------
# 13) HARD MIRROR CONSTRAINT - this is necessary because Mirror is not easily derived using the diffusion process.
# - Hard mirror constraint (non-diffusive anomaly)
# ------------------------------
@torch.no_grad()
def apply_hard_mirror(xt, x_clean01, s, e):
    """Write the time-reversed ``x_clean01[:, :, s:e]`` into ``xt[:, :, s:e]`` in place.

    Both tensors are indexed as 3-D (B, C, T). Returns ``xt`` itself, now modified.
    """
    mirrored = x_clean01[:, :, s:e].flip(dims=(-1,))
    xt[:, :, s:e] = mirrored
    return xt

# ------------------------------
# 14) INFERENCE: DDPM + Conditioning (Anomaly Classes)
# ------------------------------
@torch.no_grad()
def impute_localized_ddpm_cfg(model, x_clean01, cls_id, s, e, betas, alphas, abar, num_steps, guide_w=3.0):
    """Generate an anomaly of class ``cls_id`` inside the window ``[s, e)``.

    Runs the DDPM reverse process with classifier-free guidance (CFG). Only the
    samples in ``[s, e)`` are generated; every other sample is re-pinned to
    ``x_clean01`` after each step, so the output equals the input outside the
    window.

    Args:
        model: noise predictor called as ``model(inp, t)`` with
            ``inp = cat([x_t, mask, class_map], dim=1)``.
        x_clean01: (B, 1, T) min-max scaled signal. A single channel is expected
            because it is concatenated with a 1-channel mask.
        cls_id: anomaly class id (a value of ``ANOM_TYPES``).
        s, e: start (inclusive) and end (exclusive) of the window to generate.
        betas, alphas, abar: schedule tensors from ``make_schedule``.
        num_steps: number of reverse steps (indices into the schedule).
        guide_w: CFG weight; 1.0 uses the conditional prediction as-is, and
            larger values push harder toward the requested class.

    Returns:
        (B, 1, T) tensor on ``x_clean01``'s device.
    """
    model.eval()
    B, _, T = x_clean01.shape
    # Build helper tensors on the input's device, not on the global `device`,
    # so the function also works when called with tensors placed elsewhere.
    dev = x_clean01.device

    m = torch.zeros((B, 1, T), device=dev, dtype=torch.float32)
    m[:, :, s:e] = 1.0
    in_window = m.bool()

    # Mirror is a non-diffusive anomaly. Re-imposing the hard mirror after every
    # reverse step (including the last one) always ends at exactly "clean outside,
    # time-reversed clean inside". Return that directly instead of spending
    # 2 * num_steps model calls whose results would be discarded.
    if cls_id == ANOM_TYPES["mirror"]:
        return apply_hard_mirror(x_clean01.clone(), x_clean01, s, e)

    # Conditional class map: only row `cls_id` is non-zero, gated to the window.
    # The all-zero map is the "unconditional" input (training zeroes the class
    # map with probability CFG_DROP_PROB).
    c_cond = torch.zeros((B, NUM_ANOM_TYPES, T), device=dev, dtype=torch.float32)
    c_cond[:, cls_id, :] = m[:, 0, :]
    c_uncond = torch.zeros_like(c_cond)

    # Start from Gaussian noise inside the window and the clean signal outside.
    xt = torch.where(in_window, torch.randn_like(x_clean01), x_clean01)

    # NOTE(review): training noises the whole window (context included) to level t,
    # whereas here the context is the clean signal at every step. Also, with
    # NUM_STEPS=80, abar[-1] is about 0.45, so pure noise is not what the model saw
    # at t = num_steps - 1. A RePaint-style step (re-noising the context to level
    # t-1) may match training better. Confirm before changing.

    # Inference runs the timesteps backwards, from num_steps - 1 down to 0.
    for i in reversed(range(num_steps)):
        t = torch.full((B,), i, device=dev, dtype=torch.long)

        # Classifier-free guidance: the same model predicts the noise with and
        # without the class map, and we extrapolate along their difference.
        eps_c = model(torch.cat([xt, m, c_cond], dim=1), t)
        eps_u = model(torch.cat([xt, m, c_uncond], dim=1), t)
        eps = eps_u + guide_w * (eps_c - eps_u)

        beta_t, alpha_t, abar_t = betas[i], alphas[i], abar[i]

        # DDPM posterior mean with sigma_t^2 = beta_t; no noise on the final step.
        mean = (1 / torch.sqrt(alpha_t)) * (xt - (beta_t / torch.sqrt(1 - abar_t + 1e-8)) * eps)
        if i > 0:
            xt = mean + torch.sqrt(beta_t) * torch.randn_like(xt)
        else:
            xt = mean

        # Locality constraint: keep generated values only inside the window.
        xt = torch.where(in_window, xt, x_clean01)

    return xt

# ------------------------------
# 15) PLOT SAVING
# ------------------------------
def save_plots(residence, anomaly_type, target_date, s, e,
               orig_raw, gen_raw, out_dir, folder_tag):
    """Save and display two comparison plots of the original vs. generated day.

    * ``<stem>_FULL.png``: the whole day, with ``[s, e)`` shaded grey.
    * ``<stem>_ZOOM.png``: only the samples ``[s:e]``.

    ``<stem>`` is ``"{residence}_Fridge_{anomaly_type}_{folder_tag}"``. Both
    files go to ``out_dir``.

    Returns:
        (full_path, zoom_path) of the written PNG files.
    """
    stem = f"{residence}_Fridge_{anomaly_type}_{folder_tag}"

    def _render(path, height, dpi, orig, gen, orig_label, gen_label, title, shade):
        # Draw one original-vs-generated figure, save it, show it, then free it.
        fig, ax = plt.subplots(figsize=(12, height))
        ax.plot(orig, label=orig_label, alpha=0.75)
        ax.plot(gen, label=gen_label, alpha=0.95)
        if shade:
            ax.axvspan(s, e, color="gray", alpha=0.18)
        ax.set_title(title)
        ax.legend(loc="best")
        fig.tight_layout()
        fig.savefig(path, dpi=dpi)
        plt.show()  # display inline in the notebook
        plt.close(fig)
        return path

    full_path = _render(
        path=os.path.join(out_dir, f"{stem}_FULL.png"),
        height=4,
        dpi=PLOT_DPI_FULL,
        orig=orig_raw,
        gen=gen_raw,
        orig_label="original (raw units)",
        gen_label=f"generated {anomaly_type} (raw units)",
        title=f"{stem} | day={target_date}",
        shade=True,
    )
    zoom_path = _render(
        path=os.path.join(out_dir, f"{stem}_ZOOM.png"),
        height=3,
        dpi=PLOT_DPI_ZOOM,
        orig=orig_raw[s:e],
        gen=gen_raw[s:e],
        orig_label="original zoom",
        gen_label=f"{anomaly_type} zoom",
        title=f"{stem}_ZOOM | [{s}:{e}] | {target_date}",
        shade=False,
    )
    return full_path, zoom_path

# ============================================================
# A) TRAIN ON REFIT_House01 ONLY - All other datasets use only this trained model.
# ============================================================
# Fail fast if the training CSV is missing.
if not os.path.exists(TRAIN_CSV_PATH):
    raise FileNotFoundError(f"Training file not found: {TRAIN_CSV_PATH}")

print("\n" + "="*80)
print("TRAINING ONLY ON:", TRAIN_CSV_PATH)
print("="*80)

# Load the (timestamp, active_power) series, cleaned and time-sorted by load_two_cols.
train_df, ts_col, p_col = load_two_cols(TRAIN_CSV_PATH)
print("Detected columns -> timestamp:", ts_col, " | active_power:", p_col)
print("Rows:", len(train_df), "Range:", train_df["timestamp"].min(), "to", train_df["timestamp"].max())

# Keep only calendar days with exactly 96 samples (15-minute data); at least 10 are required.
train_days_raw, _, _ = build_daily_windows(train_df, expected_per_day=96)
if train_days_raw is None or len(train_days_raw) < 10:
    raise ValueError(f"Too few full 96-sample days found in {TRAIN_RESIDENCE}. Found: {0 if train_days_raw is None else len(train_days_raw)}")

# Min-max scale each day independently to [0, 1]; the per-day (min, max) are not needed for training.
train_days_01, _ = scale_days_01(train_days_raw)
print("Usable training days:", len(train_days_01))

# Build training set
# Seeded NumPy generator, so the sampled days, mask jitter and injected anomalies are
# reproducible (torch randomness is not seeded here).
rng = np.random.default_rng(0)
X_list, M_list, C_list = [], [], []

# NORMAL - masked windows with no injected anomaly, so the model learns that a masked
# region may contain no anomaly at all. Days are sampled with replacement.
idxs = rng.choice(len(train_days_01), size=K_NORMAL, replace=True)
for idx in idxs:
    x = train_days_01[idx].copy()
    s_j, e_j = anomaly_region_jitter(96, jitter=6, rng=rng)
    mask = np.zeros(96, dtype=np.float32); mask[s_j:e_j] = 1.0
    X_list.append(x.astype(np.float32))
    M_list.append(mask)
    C_list.append(make_c_map_mask_gated(ANOM_TYPES["normal"], mask))

# ANOMALIES - one batch of examples per anomaly type; HARD_TYPES get PER_TYPE_HARD
# samples and the rest PER_TYPE_EASY. The mask covers exactly the injected window.
for name, cls_id in ANOM_TYPES.items():
    if name == "normal":
        continue
    k = PER_TYPE_HARD if name in HARD_TYPES else PER_TYPE_EASY
    idxs = rng.choice(len(train_days_01), size=k, replace=True)
    for idx in idxs:
        s_j, e_j = anomaly_region_jitter(96, jitter=6, rng=rng)

        ### Anomalies are Inserted here (injectors return a modified copy)
        inj, s_j, e_j = INJECT_FUNCS[name](train_days_01[idx], rng, s_j, e_j)
        mask = np.zeros(96, dtype=np.float32); mask[s_j:e_j] = 1.0
        X_list.append(inj.astype(np.float32))
        M_list.append(mask)
        C_list.append(make_c_map_mask_gated(cls_id, mask))

# Shapes: X and M are (N, 96); C is (N, NUM_ANOM_TYPES, 96).
X_np = np.stack(X_list).astype(np.float32)
M_np = np.stack(M_list).astype(np.float32)
C_np = np.stack(C_list).astype(np.float32)
print("Train shapes:", X_np.shape, M_np.shape, C_np.shape)

# Wrap the training arrays in a Dataset and a shuffled DataLoader.
dataset = DiffDataset(X_np, M_np, C_np)
dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=(device.type == "cuda"),  # pinned host memory only helps host->GPU copies
)

# Create the UNet1D noise predictor, its Adam optimizer (lr=1e-3), and the diffusion
# schedule that is shared by training and inference.
model = UNet1D(base=BASE_CH).to(device)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
betas, alphas, abar = make_schedule(NUM_STEPS, device=device)

# Loop through epochs and calculate the losses
for epoch in range(NUM_EPOCHS):
    model.train()
    total = 0.0
    for bi, (x0, m, c_map) in enumerate(dataloader, start=1):
        x0 = x0.to(device, dtype=torch.float32)
        m  = m.to(device, dtype=torch.float32)
        c_map = c_map.to(device, dtype=torch.float32)

        # Classifier-free guidance training: with probability CFG_DROP_PROB, zero the
        # whole class map of a sample so the model also learns the unconditional case.
        if CFG_DROP_PROB > 0:
            drop = (torch.rand((x0.size(0), 1, 1), device=device) < CFG_DROP_PROB).float()
            c_map = c_map * (1.0 - drop)

        B = x0.size(0)

        ### FORWARD DIFFUSION PROCESS
        # Sample one timestep per example and noise the whole window (context included):
        # x_t = sqrt(abar_t) * x0 + sqrt(1 - abar_t) * eps.
        # NOTE(review): at inference the context outside the mask is the clean signal
        # rather than noised to level t; confirm this train/inference mismatch is intended.
        t = torch.randint(0, NUM_STEPS, (B,), device=device, dtype=torch.long)
        eps = torch.randn_like(x0)
        abar_t = abar[t].view(B, 1, 1)
        xt = torch.sqrt(abar_t) * x0 + torch.sqrt(1 - abar_t) * eps

        ### Noise prediction objective (core DDPM loss), MSE averaged over masked positions only
        inp = torch.cat([xt, m, c_map], dim=1)
        eps_pred = model(inp, t)
        loss_eps = (((eps_pred - eps) ** 2) * m).sum() / (m.sum() + 1e-8)

        ### Clean-signal reconstruction from predicted noise (inverse of the forward equation)
        x0_pred = (xt - torch.sqrt(1 - abar_t) * eps_pred) / (torch.sqrt(abar_t) + 1e-8)

        ### Masked reconstruction loss (local fidelity)
        loss_x0 = (((x0_pred - x0) ** 2) * m).sum() / (m.sum() + 1e-8)

        ### Gradient-consistency loss (shape realism): L1 between first differences inside the mask
        dxp, mm = masked_grad(x0_pred, m)
        dx0, _  = masked_grad(x0, m)
        loss_g = (torch.abs(dxp - dx0) * mm).sum() / (mm.sum() + 1e-8)

        ### Total training objective (final equation)
        loss = loss_eps + LAMBDA_X0 * loss_x0 + LAMBDA_GRAD * loss_g

        opt.zero_grad(set_to_none=True)
        loss.backward()
        opt.step()

        total += loss.item()

        if (bi % PRINT_EVERY_BATCHES) == 0 or bi == 1:
            print(f"  [Train] Epoch {epoch+1}/{NUM_EPOCHS} | Batch {bi}/{len(dataloader)} | loss={loss.item():.6f}")

    print(f"[Epoch Done] Epoch {epoch+1}/{NUM_EPOCHS} | avg_loss={total/len(dataloader):.6f}")

print("\nTRAINING DONE. Starting inference on all residences...")

# ============================================================
# B) INFERENCE ON ALL RESIDENCES + SAVE CSV + SAVE PLOTS
# ============================================================
# Every residence is imputed on the same fixed (non-jittered) middle-third window.
s_fix, e_fix = anomaly_region_fixed(96)

# Loop through all residences
for residence in RESIDENCES:
    csv_path = os.path.join(BASE_IN_DIR, f"{residence}_Fridge_15minutes.csv")
    target_date = ANOMALY_DATES.get(residence, None)

    print("\n" + "-"*80)
    print("INFER:", residence, "| date:", target_date)
    print("-"*80)

    if target_date is None:
        print("[SKIP] Missing anomaly date.")
        continue
    if not os.path.exists(csv_path):
        print("[SKIP] Missing file:", csv_path)
        continue

    df, ts_col, p_col = load_two_cols(csv_path)
    print("Detected columns -> timestamp:", ts_col, " | active_power:", p_col)
    print("Rows:", len(df), "Range:", df["timestamp"].min(), "to", df["timestamp"].max())

    # Keep a full copy so the output CSV contains every row kept by load_two_cols
    # (rows with an unparseable timestamp or power value were already dropped there).
    df_full = df.copy()
    # Same "YYYY-MM-DD" format as the str(date) day keys from build_daily_windows.
    df_full["date_str"] = df_full["timestamp"].dt.date.astype(str)

    days_raw, days_ts, day_keys = build_daily_windows(df, expected_per_day=96)
    if days_raw is None or len(days_raw) < 1:
        print("[SKIP] No full 96-sample days found.")
        continue

    if target_date not in day_keys:
        print("[SKIP] Target day not found as a complete 96-sample day.")
        continue

    infer_idx = day_keys.index(target_date)
    infer_ts = days_ts[infer_idx]
    infer_day_raw = days_raw[infer_idx]

    # Scale the target day on its own; (dmin, dmax) map the generated values back to raw units.
    infer_day_01, dmin, dmax = scale_day_01(infer_day_raw)
    x_clean01 = torch.tensor(infer_day_01, dtype=torch.float32, device=device).view(1, 1, 96)

    print(f"[OK] Using fixed mask [{s_fix}:{e_fix}] and saving to {OUT_DIR}")

    # For each anomaly type, run the trained model
    for anomaly_type in SAVE_ANOM_TYPES:
        cls_id = ANOM_TYPES[anomaly_type]
        w = MIRROR_GUIDE_W if anomaly_type == "mirror" else CFG_GUIDE_W

        gen01 = impute_localized_ddpm_cfg(
            model=model,
            x_clean01=x_clean01,
            cls_id=cls_id,
            s=s_fix, e=e_fix,
            betas=betas, alphas=alphas, abar=abar,
            num_steps=NUM_STEPS,
            guide_w=w
        ).detach().cpu().numpy().reshape(-1)

        gen01 = np.clip(gen01, 0.0, 1.0)
        gen = gen01 * (dmax - dmin) + dmin  # unscale to raw units

        # ---- SAVE CSV (ALL kept timestamps; modify ONLY the anomaly day) ----
        df_out = df_full.copy()
        mask_day = (df_out["date_str"] == target_date)
        n_day_rows = int(mask_day.sum())

        # The generated day is written back positionally, so the row count must match.
        # A mismatch would make the assignment below raise (it previously only warned
        # and then crashed), so skip this anomaly type instead.
        if n_day_rows != len(gen):
            print(f"  [SKIP] {residence} {anomaly_type}: expected {len(gen)} rows on {target_date}, found {n_day_rows}")
            continue

        # Write the generated anomaly day back into the full series (both are time-sorted).
        df_out.loc[mask_day, "active_power"] = gen.astype(np.float32)

        out_csv = os.path.join(OUT_DIR, f"{residence}_Fridge_{anomaly_type}_{ANOM_FOLDER}.csv")
        df_out[["timestamp", "active_power"]].to_csv(out_csv, index=False)

        # ---- SAVE PLOTS (FULL + ZOOM) ----
        out_full, out_zoom = save_plots(
            residence=residence,
            anomaly_type=anomaly_type,
            target_date=target_date,
            s=s_fix, e=e_fix,
            orig_raw=infer_day_raw,
            gen_raw=gen,
            out_dir=OUT_DIR,
            folder_tag=ANOM_FOLDER
        )

        print(f"  [SAVED] {residence} | {anomaly_type}")
        print(f"    CSV : {out_csv}")
        print(f"    FULL: {out_full}")
        print(f"    ZOOM: {out_zoom}")

print("\nALL DONE.")
print("Saved folder:", OUT_DIR)
