# In this notebook, we apply SMOTE to our filtered databank; the resampled output is later used with our object-feature autoencoder

In [1]:
# All imports 
import numpy as np
import pandas as pd 
import matplotlib.pyplot as plt
import os, io, gzip, glob, random, csv, json, requests
from astropy.io import fits
from imblearn.over_sampling import SMOTE
import lasair

In [2]:
# Load the databank twice from the same file:
#   * df_str keeps every cell as the literal text found in the CSV
#     (dtype=str, no NaN conversion);
#   * df is the normally parsed frame used for numeric work.
# Both frames are filtered with the same mask below so their rows stay aligned.
df_str = pd.read_csv("training_data.csv", dtype=str, keep_default_na=False)
df = pd.read_csv("training_data.csv")

FILTER_COLS = [
    "dmdt_g_err", "dmdt_r_err",
    "mag_g02", "mag_g08", "mag_g28",
    "mag_r02", "mag_r08", "mag_r28",
]

# Coerce the filter columns to numbers (unparseable cells become NaN) and
# turn +/-inf into NaN as well, so that they fail the range checks below.
df[FILTER_COLS] = (
    df[FILTER_COLS]
    .apply(pd.to_numeric, errors="coerce")
    .replace([np.inf, -np.inf], np.nan)
)

# Inclusive (low, high) window every kept row must satisfy, per column.
# This gets rid of non-detections and large errors.
accepted_ranges = {
    "dmdt_g_err": (-1, 1),
    "dmdt_r_err": (-1, 1),
    "mag_g02": (0, 30),
    "mag_g08": (0, 30),
    "mag_g28": (0, 30),
    "mag_r02": (0, 30),
    "mag_r08": (0, 30),
    "mag_r28": (0, 30),
}
range_checks = [
    df[col_name].between(low, high, inclusive="both")
    for col_name, (low, high) in accepted_ranges.items()
]

# A row survives only if it passes every single range check.
mask = range_checks[0]
for check in range_checks[1:]:
    mask = mask & check

df, df_str = (
    frame.loc[mask].reset_index(drop=True) for frame in (df, df_str)
)

# Name of the column that holds the class label of each row.
LABEL_COL = "source_label"
# Numeric value that, when found in a feature column, marks the entry as 'unknown'.
SENTINEL_UNKNOWN = -999
# Threshold on the resampled unknown-indicator: synthetic entries whose
# indicator exceeds it are written back out as the sentinel.
UNK_EPS = 1e-6

# Normalizing labels
def _norm_label(s):
    if pd.isna(s):
        return s
    t = str(s).strip().lower().replace(" ", "")
    if t in {"snia","sn-ia","sn_ia"}: return "SN Ia"
    if t in {"snii","sn-ii","sn_ii","sniip","sniil"}: return "SN II"
    if t in {"snib/c","snibc","sn-ib/c","sn_ib/c"}: return "SN Ib/c"
    if t in {"exotic","other","odd"}: return "Exotic"
    return str(s)
    
# Canonicalise the labels once and store the same Series in both the parsed
# frame and the raw-text frame; `y` is the label vector used from here on.
y = df[LABEL_COL].apply(_norm_label)
for frame in (df, df_str):
    frame[LABEL_COL] = y

# Pick the feature columns SMOTE will operate on: every non-label column that
# is already numeric, plus any column where at least 95% of the entries can be
# parsed as numbers (such a column is converted in place, failures -> NaN).
candidate_cols = [name for name in df.columns if name != LABEL_COL]
num_cols = []
for name in candidate_cols:
    if not pd.api.types.is_numeric_dtype(df[name]):
        coerced = pd.to_numeric(df[name], errors="coerce")
        parsed_fraction = coerced.notna().mean()
        if not (parsed_fraction >= 0.95):
            # Too few numeric entries: leave the column out of SMOTE.
            continue
        df[name] = coerced
    num_cols.append(name)

if not num_cols:
    raise ValueError("No numeric columns detected for SMOTE. Please check the dataset.")


# Assemble the matrix handed to SMOTE while remembering where the sentinel
# (-999, 'unknown') was. Every numeric feature f contributes two adjacent
# columns:
#   f__val : the value, with the sentinel (and any NaN) replaced by the
#            column median computed without sentinels/NaNs
#   f__unk : 1.0 where f was the sentinel, 0.0 elsewhere

X_num = df[num_cols].copy()

# True wherever a cell holds the sentinel
unknown_mask = X_num.eq(SENTINEL_UNKNOWN)

# Per-column medians, with sentinels blanked out so they do not bias the result
X_for_median = X_num.where(~unknown_mask, np.nan)
medians = X_for_median.median(numeric_only=True)

X_smote_parts = []
smote_feature_names = []
for name in num_cols:
    was_unknown = unknown_mask[name]
    fill_val = medians.get(name, np.nan)

    # SMOTE cannot handle NaN. Sentinels take the median; leftover NaNs take
    # the median too, or 0.0 when no median exists (column had no usable value).
    nan_replacement = 0.0 if pd.isna(fill_val) else fill_val
    values = X_num[name].mask(was_unknown, fill_val).fillna(nan_replacement)

    # 1.0 if unknown else 0.0
    indicator = was_unknown.astype(float)

    X_smote_parts.append(values.to_numpy()[:, np.newaxis])
    X_smote_parts.append(indicator.to_numpy()[:, np.newaxis])
    smote_feature_names.extend([f"{name}__val", f"{name}__unk"])

X_smote = np.concatenate(X_smote_parts, axis=1)

# Decide how many rows each minority class should have after resampling:
# the size of the "SN Ia" class, but never fewer than the class already has.
# Classes with no rows at all are left out of the plan.
counts = y.value_counts()
n_target = counts.get("SN Ia", 0)
if n_target == 0:
    raise ValueError("No 'SN Ia' examples found; cannot set target size.")

desired = {
    minority: max(n_target, counts.get(minority, 0))
    for minority in ("SN II", "SN Ib/c")
    if counts.get(minority, 0) != 0
}

if not desired:
    # No class was selected for oversampling: write the filtered data back out
    # unchanged, with every row flagged as non-synthetic.
    out = df_str.assign(is_synthetic="False")
    out.to_csv("ClassImbalanced_FinalTrainingSet.csv", index=False)
else:
    # Refuse to run when a class to be oversampled has fewer than
    # k_neighbors + 1 rows.
    k_neighbors = 8
    min_class_size = k_neighbors + 1
    for cls in desired:
        n_cls = (y == cls).sum()
        if n_cls >= min_class_size:
            continue
        raise ValueError(
            f"Class '{cls}' has only {n_cls} samples; SMOTE with k={k_neighbors} "
            f"requires at least {k_neighbors+1}. Reduce k or gather more samples."
        )

    # Oversample
    smote = SMOTE(
        random_state=42,
        k_neighbors=k_neighbors,
        sampling_strategy=desired,
    )
    X_res, y_res = smote.fit_resample(X_smote, y.values)

    # The first n_orig rows of the resampled output are treated as the
    # original rows; everything after them is treated as synthetic.
    n_orig, n_res = len(df), len(y_res)
    n_syn = n_res - n_orig
    if n_syn < 0:
        raise RuntimeError("SMOTE produced fewer samples than originals; unexpected state.")

    # Original rows are emitted from the raw-text frame
    originals = df_str.assign(is_synthetic="False")

    # Skeleton for the synthetic rows: every column starts out as an empty
    # string, then label and flag are filled in.
    X_syn = X_res[n_orig:]
    syn = pd.DataFrame(dict.fromkeys(df_str.columns, ""), index=range(n_syn))
    syn[LABEL_COL] = y_res[n_orig:]
    syn["is_synthetic"] = "True"
    # Preserve original numeric formatting style
    def count_decimals(s):
        """Return how many decimal places are needed to write *s* in fixed-point form.

        Parameters
        ----------
        s : object
            A numeric literal (normally the raw text of a CSV cell). It is
            converted with ``str`` and surrounding whitespace is ignored.

        Returns
        -------
        int or None
            ``None`` when *s* is ``None`` or blank; otherwise the number of
            digits after the decimal point. For scientific notation the
            exponent is taken into account, e.g. ``"1.5e-05"`` -> 6 and
            ``"1.25e+03"`` -> 0, because the result is used to build a
            ``"{:.Nf}"`` format and the mantissa digits alone would truncate
            small values to zero.
        """
        if s is None:
            return None
        text = str(s).strip()
        if text == "":
            return None

        # Lower-case before splitting so that "E" and "e" exponents are
        # handled alike.
        mantissa, _, exponent = text.lower().partition("e")
        frac_digits = len(mantissa.rpartition(".")[2]) if "." in mantissa else 0
        if not exponent:
            return frac_digits

        try:
            shift = int(exponent)
        except ValueError:
            # Not a real exponent (e.g. stray text): fall back to the
            # mantissa digits only.
            return frac_digits

        # A negative exponent pushes digits to the right of the point, a
        # positive one pulls them to the left; never report a negative count.
        return max(frac_digits - shift, 0)

    # For every numeric column, estimate the number of decimals its values are
    # usually written with (median over the raw CSV text), falling back to 0
    # when no usable value is found. Entries equal to the 'unknown' sentinel
    # and non-finite entries are skipped: they are not real measurements, and
    # counting them as zero-decimal samples would make a mostly-unknown column
    # look integer-valued, so its known synthetic values would be rounded to
    # whole numbers.
    typical_decimals = {}
    sentinel_value = float(SENTINEL_UNKNOWN)
    for c in num_cols:
        dec_counts = []
        for val in df_str[c].tolist():
            if val is None or val == "":
                continue
            # Only the parse may fail here; anything that is not a number is
            # ignored on purpose, other errors are allowed to surface.
            try:
                parsed = float(str(val).replace(",", ""))
            except ValueError:
                continue
            if not np.isfinite(parsed) or parsed == sentinel_value:
                continue
            d = count_decimals(val)
            if d is not None:
                dec_counts.append(d)
        typical_decimals[c] = int(np.median(dec_counts)) if dec_counts else 0

    # Turn the synthetic part of the SMOTE output back into text, one numeric
    # feature at a time. Feature number j occupies columns 2*j (value) and
    # 2*j + 1 (unknown indicator) of X_syn.
    for col_pos, col_name in enumerate(num_cols):
        value_channel = X_syn[:, 2 * col_pos]
        flag_channel = X_syn[:, 2 * col_pos + 1]

        # Any indicator above the threshold makes the entry 'unknown'
        unknown_rows = flag_channel > UNK_EPS
        known_rows = ~unknown_rows

        n_decimals = typical_decimals.get(col_name, 0)
        known_numbers = value_channel[known_rows]
        if n_decimals > 0:
            # Fixed-point text with the column's usual number of decimals
            template = "{:." + str(n_decimals) + "f}"
            rendered = np.array(
                [template.format(x) if pd.notna(x) else "" for x in known_numbers],
                dtype=object,
            )
        else:
            # Integer-like column: round and print without a decimal point
            rounded = pd.Series(np.round(known_numbers, 0))
            rendered = rounded.astype("Int64").astype(str).to_numpy()

        column_text = np.empty(n_syn, dtype=object)
        column_text[unknown_rows] = str(SENTINEL_UNKNOWN)
        column_text[known_rows] = rendered
        syn[col_name] = column_text

    # Originals first, synthetic rows after them, then write the result
    out = pd.concat([originals, syn], ignore_index=True)
    out.to_csv("ClassImbalanced_FinalTrainingSet.csv", index=False)