# TDE Classification with Lightcurve-Derived 1-NN
This notebook builds a lightweight nearest-neighbour classifier using summary statistics extracted from each lightcurve.

In [None]:
import csv
import math
from collections import defaultdict
from pathlib import Path

# Root directory: metadata logs live directly inside it and lightcurve files
# live in one sub-directory per split.
DATA_DIR = Path('Data')
# Filter labels matched against the 'Filter' column of each lightcurve row;
# per-filter features are emitted in this order.
FILTERS = ['u', 'g', 'r', 'i', 'z', 'y']
# Length of one feature vector: 13 scalar features (11 lightcurve summaries
# plus Z and EBV) followed by 4 statistics (count, mean, max, min) per filter.
# Derived from FILTERS so the two cannot drift apart.
FEATURE_LEN = 13 + 4 * len(FILTERS)


In [None]:
def load_metadata(path):
    """Read a CSV file and return all of its data rows.

    Args:
        path: Path-like object exposing ``open`` (e.g. ``pathlib.Path``).

    Returns:
        A list with one dict per data row, keyed by the CSV header names.
        All values are left as strings.
    """
    rows = []
    # newline='' lets the csv module handle line endings itself.
    with path.open(newline='') as handle:
        for record in csv.DictReader(handle):
            rows.append(record)
    return rows

def load_split_lightcurves(meta_rows, filename):
    """Collect the lightcurve rows of every object listed in the metadata.

    For each distinct ``split`` value in ``meta_rows`` the file
    ``DATA_DIR / split / filename`` is read once, and rows whose
    ``object_id`` appears in the metadata for that split are kept.

    Args:
        meta_rows: Iterable of dicts with at least ``split`` and
            ``object_id`` keys.
        filename: Name of the lightcurve CSV inside each split directory.

    Returns:
        Nested mapping ``split -> object_id -> list of CSV row dicts``.
        Objects without any matching row map to an empty list.
    """
    by_split = defaultdict(lambda: defaultdict(list))

    # Register every wanted object up front so membership can be tested
    # while streaming the (potentially much larger) lightcurve files.
    for meta in meta_rows:
        by_split[meta['split']][meta['object_id']] = []

    for split_name, wanted in by_split.items():
        csv_path = DATA_DIR / split_name / filename
        with csv_path.open(newline='') as handle:
            for record in csv.DictReader(handle):
                key = record['object_id']
                if key not in wanted:
                    continue
                wanted[key].append(record)
    return by_split

def lightcurve_features(meta_rows, split_data, filters=None):
    """Build one fixed-length feature vector per object.

    Vector layout (``11 + 2 + 4 * len(filters)`` entries):

    * 11 lightcurve summaries: number of flux points, mean / max / min /
      range of flux, fraction of positive and of negative flux values,
      mean flux error, mean observation time, time of the maximum flux and
      time of the minimum flux (first occurrence of each extreme);
    * 2 metadata values: ``Z`` and ``EBV`` (``0.0`` when blank);
    * for every filter, in order: count, mean, max and min of its fluxes
      (zeros when the filter has no flux values).

    Rows with a blank ``Flux`` are ignored for flux and time statistics;
    rows with a blank ``Flux_err`` are ignored for the mean error.

    Objects with no usable flux values get zeros for all lightcurve
    summaries but still carry their ``Z`` / ``EBV`` metadata, so the
    vector is built the same way for every object.

    Args:
        meta_rows: Iterable of dicts with ``object_id``, ``split``, ``Z``
            and ``EBV`` keys.
        split_data: Mapping ``split -> object_id -> list of lightcurve row
            dicts`` with ``Flux``, ``Flux_err``, ``Time (MJD)`` and
            ``Filter`` keys.
        filters: Optional sequence of filter labels; defaults to the
            module-level ``FILTERS``.

    Returns:
        Dict mapping ``object_id`` to its feature list.
    """
    if filters is None:
        filters = FILTERS

    features = {}
    for row in meta_rows:
        object_id = row['object_id']
        lc = split_data[row['split']][object_id]

        # Single pass over the lightcurve: parse each flux once and bucket
        # it by filter instead of re-scanning the rows for every filter.
        fluxes = []
        times = []
        flux_by_filter = {}
        for entry in lc:
            if not entry['Flux']:
                continue
            flux = float(entry['Flux'])
            fluxes.append(flux)
            times.append(float(entry['Time (MJD)']))
            flux_by_filter.setdefault(entry['Filter'], []).append(flux)
        errs = [float(entry['Flux_err']) for entry in lc if entry['Flux_err']]

        n = len(fluxes)
        if n:
            max_flux = max(fluxes)
            min_flux = min(fluxes)
            summary = [
                n,
                sum(fluxes) / n,
                max_flux,
                min_flux,
                max_flux - min_flux,
                sum(1 for v in fluxes if v > 0) / n,
                sum(1 for v in fluxes if v < 0) / n,
                sum(errs) / len(errs) if errs else 0.0,
                sum(times) / n,
                # list.index returns the first occurrence of the extreme.
                times[fluxes.index(max_flux)],
                times[fluxes.index(min_flux)],
            ]
        else:
            summary = [0.0] * 11

        # Metadata does not depend on the lightcurve, so it is kept even
        # when the object has no usable flux values.
        metadata = [
            float(row['Z']) if row['Z'] else 0.0,
            float(row['EBV']) if row['EBV'] else 0.0,
        ]

        per_filter = []
        for flt in filters:
            flt_fluxes = flux_by_filter.get(flt, [])
            count = len(flt_fluxes)
            if count:
                per_filter.extend([
                    count,
                    sum(flt_fluxes) / count,
                    max(flt_fluxes),
                    min(flt_fluxes),
                ])
            else:
                per_filter.extend([count, 0.0, 0.0, 0.0])

        features[object_id] = summary + metadata + per_filter
    return features

def standardize(vectors):
    """Z-score every feature column across a set of vectors.

    The standard deviation is the population form (sum of squared
    deviations divided by the number of vectors). A column with zero
    variance gets a scale of ``1.0`` so it maps to all zeros instead of
    dividing by zero.

    Args:
        vectors: Dict mapping an identifier to a list of numbers. The
            column count is taken from the first vector.

    Returns:
        ``(scaled, means, stds)`` where ``scaled`` has the same keys as
        ``vectors`` and ``means`` / ``stds`` are per-column lists that can
        be reused to transform other data. An empty input yields
        ``({}, [], [])``.
    """
    # Nothing to scale: return an empty result rather than letting
    # next(iter(...)) leak a bare StopIteration.
    if not vectors:
        return {}, [], []

    length = len(next(iter(vectors.values())))
    n = len(vectors)

    means = [0.0] * length
    for vec in vectors.values():
        for idx, val in enumerate(vec):
            means[idx] += val
    means = [total / n for total in means]

    # Accumulate squared deviations first; converted to std-devs below.
    stds = [0.0] * length
    for vec in vectors.values():
        for idx, val in enumerate(vec):
            diff = val - means[idx]
            stds[idx] += diff * diff
    stds = [math.sqrt(total / n) if total > 0 else 1.0 for total in stds]

    scaled = {
        obj: [(val - means[idx]) / stds[idx] for idx, val in enumerate(vec)]
        for obj, vec in vectors.items()
    }
    return scaled, means, stds

def knn_predict(test_vectors, train_vectors, train_labels):
    """Label each test vector with the label of its nearest training vector.

    Distance is Euclidean (``math.dist``). Ties go to the training vector
    seen first, and the label defaults to ``0`` when no training vector is
    strictly closer than infinity (for example an empty training set).

    Args:
        test_vectors: Dict mapping object id to a feature sequence.
        train_vectors: Dict mapping training id to a feature sequence.
        train_labels: Dict mapping training id to its label.

    Returns:
        Dict mapping each test object id to the predicted label.
    """
    reference = list(train_vectors.items())
    predictions = {}
    for object_id, query in test_vectors.items():
        nearest_label = 0
        nearest_dist = float('inf')
        for train_id, candidate in reference:
            gap = math.dist(query, candidate)
            if not gap < nearest_dist:
                continue
            nearest_dist = gap
            nearest_label = train_labels[train_id]
            # An exact match cannot be beaten, so stop scanning.
            if gap == 0:
                break
        predictions[object_id] = nearest_label
    return predictions


In [None]:
# Metadata tables, plus a lookup of training rows by object id.
train_meta = load_metadata(DATA_DIR / 'train_log.csv')
test_meta = load_metadata(DATA_DIR / 'test_log.csv')
train_lookup = {}
for meta_row in train_meta:
    train_lookup[meta_row['object_id']] = meta_row

# Lightcurve rows grouped by split and object.
train_lightcurves = load_split_lightcurves(train_meta, 'train_full_lightcurves.csv')
test_lightcurves = load_split_lightcurves(test_meta, 'test_full_lightcurves.csv')

# Training features are standardized; the fitted means/stds are kept so the
# test features can be put on the same scale.
train_features = lightcurve_features(train_meta, train_lightcurves)
train_vectors, feature_means, feature_stds = standardize(train_features)
train_labels = {}
for obj in train_vectors:
    train_labels[obj] = int(train_lookup[obj]['target'])

# Test features are scaled with the training statistics, not their own.
test_features = lightcurve_features(test_meta, test_lightcurves)
test_vectors = {}
for obj, vec in test_features.items():
    test_vectors[obj] = [
        (val - feature_means[idx]) / feature_stds[idx]
        for idx, val in enumerate(vec)
    ]

len(train_vectors), len(test_vectors)


In [None]:
# Leave-one-out estimate for the 1-NN model: each training object is
# classified by its nearest neighbour among all the *other* training objects.
train_items = list(train_vectors.items())
correct = 0
for object_id, vec in train_items:
    best_label = 0
    best_dist = float('inf')
    for other_id, other_vec in train_items:
        if other_id != object_id:
            dist = math.dist(vec, other_vec)
            if dist < best_dist:
                best_dist = dist
                best_label = train_labels[other_id]
                # A zero distance cannot be improved on; stop searching.
                if dist == 0:
                    break
    # bool is added as 0/1, so this counts the hits.
    correct += best_label == train_labels[object_id]
loo_accuracy = correct / len(train_items)
loo_accuracy


In [None]:
# Predict every test object and write the result as a two-column CSV, in the
# same order as the test metadata.
predictions = knn_predict(test_vectors, train_vectors, train_labels)
submission_path = Path('submission.csv')
with submission_path.open('w', newline='') as handle:
    writer = csv.writer(handle)
    writer.writerow(['object_id', 'prediction'])
    writer.writerows(
        [row['object_id'], predictions[row['object_id']]] for row in test_meta
    )
submission_path
