In [None]:
# load data
# The input .npz must provide 'traces' (used below as [N, R, T]), 'xyz'
# (per-neuron coordinates) and 'labels'.
import numpy as np

DATA_PATH = "./"  # TODO: point this at the input .npz file ("./" is only a placeholder)

# np.load on an .npz returns an NpzFile that holds the file open; the context
# manager closes it once the arrays have been read into memory.
with np.load(DATA_PATH) as data:
    traces = data['traces']
    coords_np = data['xyz']
    labels = data['labels']

# Fail early with a readable message rather than deep inside training.
assert traces.ndim == 3, f"expected traces as [N, R, T], got shape {traces.shape}"
assert coords_np.shape[0] == traces.shape[0], "xyz must have one row per neuron in traces"

In [None]:
# func matrix
# Neuron-by-neuron Pearson correlation of the repeat-averaged traces:
# traces [N, R, T] -> trial_avg [N, T] -> func_mat_np [N, N].
trial_avg = traces.mean(axis=1)

# A neuron whose averaged trace is constant has zero variance, so corrcoef
# divides 0/0 and returns NaN for its entire row/column. Silence that warning
# here and handle it explicitly below.
with np.errstate(invalid="ignore", divide="ignore"):
    func_mat_np = np.corrcoef(trial_avg)

# Zero-variance neurons are exactly those with a NaN on the diagonal.
n_constant = int(np.isnan(np.diag(np.atleast_2d(func_mat_np))).sum())
if n_constant:
    print(f"[warn] {n_constant} neuron(s) have a constant trial-averaged trace; "
          "their correlations are set to 0")
# Treat undefined correlations as 0 so NaNs never reach the functional prior.
func_mat_np = np.nan_to_num(func_mat_np, nan=0.0)
print(func_mat_np.shape)

In [None]:
# dist matrix
# Pairwise Euclidean distances between neuron coordinates, [N, N].
# Printing the full matrix floods the notebook for realistic N, so only a
# summary is shown. The off-diagonal range is presumably a useful reference
# for the training cell's `r_max` (assumes the same units — confirm).
from scipy.spatial.distance import cdist

dist = cdist(coords_np, coords_np, metric='euclidean')

off_diag = dist[~np.eye(dist.shape[0], dtype=bool)]
print(f"[info] dist: shape={dist.shape}")
if off_diag.size:  # empty when there is only one neuron
    print(f"[info] pairwise distance: min={off_diag.min():.4g} "
          f"median={np.median(off_diag):.4g} max={off_diag.max():.4g}")

In [None]:
# Use every recording for training; this notebook makes no held-out split.
X_train = traces

In [None]:
# Toggle for the spatial/functional priors: when False, the training cell
# passes coords=None and func_mat=None to the trainer.
Spatial = True

In [None]:
import numpy as np
import torch

from model.train import train_trace_with_custom_pairs

# ============ 0) data preparation ============
def to_float_tensor(x):
    """Convert a numpy array or torch tensor to a float32 ``torch.Tensor``.

    Args:
        x: ``np.ndarray`` or ``torch.Tensor``.

    Returns:
        A float32 tensor. For numpy input the result shares memory with ``x``
        when ``x`` is already float32 and has no negative strides.

    Raises:
        TypeError: if ``x`` is neither a numpy array nor a torch tensor.
    """
    if isinstance(x, np.ndarray):
        # torch.from_numpy rejects arrays with negative strides (e.g. views
        # created by x[::-1]); a copy is C-ordered and therefore accepted.
        if any(stride < 0 for stride in x.strides):
            x = x.copy()
        return torch.from_numpy(x).float()
    if isinstance(x, torch.Tensor):
        return x.float()
    raise TypeError("X_train / X_test must be numpy.ndarray or torch.Tensor")

# Convert the raw recordings once; this CPU tensor is what the trainer receives.
X_train_t = to_float_tensor(X_train)  # [N, R, T]
N, R, T = X_train_t.shape
print(f"[info] X_train: N={N}, R={R}, T={T}")

# Seed the NumPy and torch RNGs before any stochastic step.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

# Prefer the GPU when one is visible.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print("[info] device:", device)

# ============ 1) spatial/functional prior ============
func_vec = None         # optional per-neuron prior, torch.Tensor [N]; not used in this run

if Spatial:
    coords   = to_float_tensor(coords_np)
    func_mat = to_float_tensor(func_mat_np)
    # Priors must be indexed by the same neurons as X_train.
    assert coords.shape[0] == N, f"coords has {coords.shape[0]} rows, expected N={N}"
    assert tuple(func_mat.shape) == (N, N), (
        f"func_mat has shape {tuple(func_mat.shape)}, expected ({N}, {N})"
    )
else:
    coords = None
    func_mat = None

# ============ 2) training ============
k = max(1, R // 2)

epochs     = 1100
batch_size = 200
lr         = 1e-3
proj_mode  = "large"

model, history = train_trace_with_custom_pairs(
    X=X_train_t,             # keep the data on CPU; the DataLoader moves each batch to `device`
    T=T,
    coords=coords,           # None or [N, 2/3]
    func_vec=func_vec,       # None or [N]
    func_mat=func_mat,       # None or [N, N]
    epochs=epochs,
    batch_size=batch_size,
    lr=lr,
    k=k,
    r_max=0.5,               # spatial constraint
    f_pos_th=0.18,           # functional threshold
    same_sign=True,          # only relevant when func_vec is provided
    proj_mode=proj_mode,
    device=device,
    grad_clip=5.0,
    log_every=5,
    return_history=True,
)

print(f"[done] training finished. logged {len(history['loss'])} loss points.")

In [None]:
# ============ 3) inference ============
# training set
# Switch to evaluation mode before extracting embeddings. The trainer may
# return the model in train mode, which would leave any dropout/batch-norm
# layers behaving stochastically here.
model.eval()
with torch.no_grad():
    Xm_train = X_train_t.mean(dim=1).to(device)  # repeat-averaged traces, [N, T]
    h_train, u_train = model(Xm_train)           # expected h:[N,128], u:[N,2] — confirm against the model
    h_train = h_train.cpu().numpy()
    u_train = u_train.cpu().numpy()

# One embedding row per neuron is required for the save step.
assert h_train.shape[0] == u_train.shape[0] == N

print("[done] embeddings ready:",
      f"h_train {h_train.shape}, u_train {u_train.shape}")

In [None]:
# Save
# Use an explicit file name: np.savez appends ".npz" to any path that lacks it,
# so a bare directory such as "./" would silently write a hidden "./.npz" file.
SAVE_PATH = "./embeddings.npz"  # TODO: choose the output location
np.savez(SAVE_PATH, emb=u_train, coords=coords_np, labels=labels)
print(f"[done] saved embeddings to {SAVE_PATH}")