# 03b_cnn_train_mps — Train Two‑Branch 1D‑CNN (PyTorch + MPS)
_Generated 2025-10-04_

In [None]:
import os, json, numpy as np, torch
from torch.utils.data import random_split, DataLoader
from app.models.cnn1d import make_model
from app.data.fold import Item, LightCurveViewsDataset
from app.trainers.cnn1d_trainer import train
from app.calibration.calibrate import run_and_save
# Pick the compute device: use Apple's Metal (MPS) backend when this PyTorch
# build reports it as available, otherwise fall back to the CPU.
if torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'
device  # bare expression so the notebook cell displays the chosen device

In [None]:
def synth_curve(n=2048, period=2.5, depth=0.0015, duration=0.08, t0=0.5, noise=0.0005, seed=0):
    """Generate a synthetic light curve with periodic box-shaped dips.

    The time axis spans five periods (``0`` to ``period*5``) sampled at ``n``
    evenly spaced points. The baseline flux is 1.0; every sample whose phase,
    measured from ``t0``, lies in ``[0, duration/period)`` is lowered by
    ``depth``. Gaussian noise with standard deviation ``noise`` is then added.

    Args:
        n: Number of samples.
        period: Dip repetition period (same units as the time axis).
        depth: Flux decrement applied inside each dip.
        duration: Length of each dip (same units as ``period``).
        t0: Time at which a dip *begins* (phase zero).
        noise: Standard deviation of the additive Gaussian noise.
        seed: Seed for ``np.random.default_rng``; makes the curve reproducible.

    Returns:
        Tuple ``(t, y)`` of 1-D arrays of length ``n``: times and fluxes.
    """
    # NOTE(review): the dip covers [t0, t0 + duration), i.e. it starts at t0
    # rather than being centred on it. If downstream folding treats t0 as the
    # mid-transit time the box sits duration/2 late -- confirm the convention.
    times = np.linspace(0, period*5, n)
    cycle_fraction = np.mod((times - t0) / period, 1.0)
    inside_dip = cycle_fraction < duration/period

    clean_flux = np.where(inside_dip, 1.0 - depth, 1.0)
    jitter = np.random.default_rng(seed).normal(0, noise, size=clean_flux.shape)
    return times, clean_flux + jitter
# Build a balanced synthetic dataset of 200 curves: even indices are positives
# (curve contains dips, label 1), odd indices are negatives (flat curve, label 0).
#
# The ephemeris is defined once so the values used to generate the curves and
# the values stored on each Item cannot drift apart.
PERIOD, T0, DURATION = 2.5, 0.5, 0.08

items = []
for i in range(200):
    lbl = 1 if i % 2 == 0 else 0
    if lbl:
        t, y = synth_curve(period=PERIOD, t0=T0, duration=DURATION, seed=i)
    else:
        # Negatives come from the same generator with depth=0, so they share
        # the positives' time grid and noise level. Previously negatives used
        # sigma=0.001 versus 0.0005 for positives, which let a classifier
        # separate the classes by noise amplitude alone instead of by the dip.
        t, y = synth_curve(period=PERIOD, t0=T0, duration=DURATION, depth=0.0, seed=i)
    items.append(Item(time=t, flux=y, period=PERIOD, t0=T0, duration=DURATION, label=lbl))

ds = LightCurveViewsDataset(items)
len(ds)  # bare expression so the notebook cell displays the dataset size

In [None]:
# Hold out 20% of the dataset for validation (80/20 split by item count).
n_total = len(ds)
n_train = int(n_total * 0.8)
n_val = n_total - n_train

# Seed the split explicitly so train/validation membership is reproducible
# across runs; without a generator random_split draws from the global RNG and
# the validation set (reused later for calibration) changes every execution.
split_rng = torch.Generator().manual_seed(0)
train_ds, val_ds = random_split(ds, [n_train, n_val], generator=split_rng)

model = make_model()
# Hyperparameters are passed straight through to the project trainer; their
# exact semantics (e.g. what `patience` and `workdir` control) live in
# app.trainers.cnn1d_trainer.
metrics = train(model, train_ds, val_ds, device=device, batch_size=64, lr=1e-3, max_epochs=10, patience=4, workdir='.')
metrics  # bare expression so the notebook cell displays the returned metrics

In [None]:
import numpy as np, torch
from torch.utils.data import DataLoader
from app.models.cnn1d import make_model
from app.calibration.calibrate import run_and_save
# Rebuild the model, load the saved weights, and switch to inference mode.
# NOTE(review): training above was invoked with workdir='.'; confirm the
# trainer actually writes its checkpoint to 'artifacts/cnn1d.pt'.
mdl = make_model()
mdl.load_state_dict(torch.load('artifacts/cnn1d.pt', map_location=device))
mdl.to(device)
mdl.eval()

# Score the validation split in a fixed order and collect probabilities/labels.
val_loader = DataLoader(val_ds, batch_size=64, shuffle=False)
probs = []
ys = []
with torch.no_grad():
    for G, L, Y in val_loader:
        # The loader's collate step presumably already yields tensors, so
        # torch.tensor(...) would copy-construct and emit a UserWarning.
        # as_tensor reuses/casts instead; the float32 cast happens on the host
        # before the transfer so no float64 data is ever sent to the device.
        G = torch.as_tensor(G, dtype=torch.float32).to(device)
        L = torch.as_tensor(L, dtype=torch.float32).to(device)
        # The model is called with both inputs; squeeze(1) drops the size-1
        # second dimension of its output so each batch yields a 1-D vector.
        logits = mdl(G, L).squeeze(1)
        probs.append(torch.sigmoid(logits).cpu().numpy())
        # Convert labels to NumPy explicitly rather than relying on
        # np.concatenate to coerce a list of tensors.
        ys.append(torch.as_tensor(Y).cpu().numpy())
probs = np.concatenate(probs)
ys = np.concatenate(ys)

# Fit the calibrator on validation scores; outputs go under 'artifacts'.
cal_info = run_and_save(ys, probs, out_dir='artifacts', method='isotonic')
cal_info  # bare expression so the notebook cell displays the calibration info