A transformer-based classifier for TESS light curves.
```
pip install -r requirements.txt
pip install -e .
```

Classify your FITS files using our pre-trained weights:
```
astrafier predict --fits-dir /path/to/fits --output-dir ./predictions
```

Output: `predictions.csv` with columns `tic` and the class probabilities.
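To consume the predictions in Python, something like the following works. This is a sketch only: the exact probability-column layout is an assumption (one column per class), and the frame below is a stand-in built in memory rather than a real `predictions.csv`.

```python
import pandas as pd

# Stand-in for predictions.csv: a `tic` column plus one assumed
# probability column per class (column names are illustrative).
df = pd.DataFrame({
    "tic": [123, 456],
    "ECLIPSE": [0.91, 0.05],
    "CONTACT_ROT": [0.04, 0.88],
    "SOLARLIKE": [0.05, 0.07],
})

prob_cols = [c for c in df.columns if c != "tic"]
df["predicted"] = df[prob_cols].idxmax(axis=1)   # highest-probability class
df["confidence"] = df[prob_cols].max(axis=1)     # its probability
print(df[["tic", "predicted", "confidence"]])
```

For a real run, replace the stand-in frame with `pd.read_csv("predictions/predictions.csv")` and adjust `prob_cols` to the actual columns.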
Download the paper's training data from HuggingFace (includes TESS and Kepler light curves):
```python
from huggingface_hub import hf_hub_download

for f in ["train.safetensors", "train.json", "test.safetensors", "test.json"]:
    hf_hub_download("paulg9/astrafier", f, local_dir="./data")
```

Fine-tune from our checkpoint:
```
astrafier train --train-pt data/train.safetensors --test-pt data/test.safetensors
```

Or train from scratch:
```
astrafier train --train-pt data/train.safetensors --test-pt data/test.safetensors --no-load-from-hf
```

From CSV — splitting and preprocessing are handled automatically:
```
astrafier train --train-csv data.csv
```

Your CSV needs `label` and `path` columns. If you include a `TIC` column, we split by TIC to prevent data leakage (the same star never appears in both train and test). Without it, we split by row.
| label | path | TIC |
|---|---|---|
| ECLIPSE | /data/tic123.fits | 123 |
| CONTACT_ROT | /data/tic456.fits | 456 |
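The TIC-grouped split can be sketched as follows. This is a minimal illustration of the idea, not the package's implementation; `split_by_tic` is a hypothetical helper.

```python
import random

def split_by_tic(rows, test_frac=0.2, seed=0):
    """Assign whole stars (TICs) to train or test, so no TIC spans both."""
    tics = sorted({r["TIC"] for r in rows})
    rng = random.Random(seed)
    rng.shuffle(tics)
    n_test = max(1, int(len(tics) * test_frac))
    test_tics = set(tics[:n_test])
    train = [r for r in rows if r["TIC"] not in test_tics]
    test = [r for r in rows if r["TIC"] in test_tics]
    return train, test

rows = [
    {"label": "ECLIPSE", "path": "/data/tic123_s1.fits", "TIC": 123},
    {"label": "ECLIPSE", "path": "/data/tic123_s2.fits", "TIC": 123},
    {"label": "CONTACT_ROT", "path": "/data/tic456.fits", "TIC": 456},
]
train, test = split_by_tic(rows, test_frac=0.5)
# Every row of a given TIC lands on the same side of the split.
```

A row-level split would let two sectors of the same star straddle the boundary, inflating test scores; grouping by TIC closes that leak.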
Save preprocessed tensors for faster reruns:
```
astrafier train --train-csv data.csv --save-preprocessed ./tensors
```

From tensors — if you have your own preprocessing:
```
astrafier train --train-pt train.pt --test-pt test.pt
```

Expected `.pt` structure:
```python
torch.save({
    "flux": flux_tensor,     # (N, seq_len) float32
    "time": time_tensor,     # (N, seq_len) float32
    "labels": label_tensor,  # (N, num_classes) float32 one-hot
    "mask": mask_tensor,     # (N, seq_len) bool, True = valid
    "label_map": {"ECLIPSE": 0, ...},
}, "train.pt")
```

Predict with your own checkpoint:

```
astrafier predict --fits-dir /path/to/fits --checkpoint model.ckpt --no-load-from-hf
```

Data source:
- `--train-csv`: From CSV with FITS paths (we split + preprocess)
- `--train-pt` / `--test-pt`: From pre-processed tensors
Model weights:
- `--load-from-hf`: Fine-tune from our checkpoint (default)
- `--no-load-from-hf`: Train from scratch
Other flags:
- `--batch-size`: Default 128
- `--max-epochs`: Default 250
- `--precision`: Use `32-true` if bf16 is not supported
- `--save-preprocessed`: Save tensors when using `--train-csv`
- `--mc-dropout` / `--mc-samples`: Monte Carlo dropout for uncertainty
Run `astrafier train --help` for all options.
The model classifies into 8 categories:
- APERIODIC
- CONTACT_ROT
- DSCT_BCEP
- ECLIPSE
- GDOR_SPB
- INSTRUMENT/JUNK
- RRLYR_CEPH
- SOLARLIKE
Our TESS preprocessing pipeline (used by `--train-csv` and `astrafier preprocess`):
- Filter by QUALITY flags (keep 0, 64, 256, 1024, 2048, 8192)
- Remove NaNs and 10σ outliers
- Subtract Gaussian-filtered trend (σ=61)
- Standardize (median=0, std=1)
- Pad/truncate to sequence length (default: 1171)
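The steps above can be sketched in NumPy/SciPy as follows. This is illustrative only: the `preprocess` helper is hypothetical, and the package's actual pipeline (in `src/astrafier/data/loading.py`) may differ in details such as filter boundary modes and the order of operations.

```python
import numpy as np
from scipy.ndimage import gaussian_filter1d

GOOD_QUALITY = {0, 64, 256, 1024, 2048, 8192}  # QUALITY flags to keep
SEQ_LEN = 1171                                 # default sequence length

def preprocess(time, flux, quality, seq_len=SEQ_LEN):
    time, flux, quality = map(np.asarray, (time, flux, quality))
    # 1. Keep only cadences whose QUALITY flag is in the allow-list.
    keep = np.isin(quality, list(GOOD_QUALITY))
    time, flux = time[keep], flux[keep]
    # 2. Drop NaNs, then 10-sigma outliers.
    finite = np.isfinite(flux)
    time, flux = time[finite], flux[finite]
    inliers = np.abs(flux - flux.mean()) < 10 * flux.std()
    time, flux = time[inliers], flux[inliers]
    # 3. Subtract a Gaussian-smoothed trend (sigma = 61 cadences).
    flux = flux - gaussian_filter1d(flux, sigma=61)
    # 4. Standardize: zero median, unit standard deviation.
    flux = (flux - np.median(flux)) / flux.std()
    # 5. Pad or truncate to seq_len, recording validity in a mask.
    n = min(len(flux), seq_len)
    out_flux = np.zeros(seq_len, dtype=np.float32)
    out_time = np.zeros(seq_len, dtype=np.float32)
    mask = np.zeros(seq_len, dtype=bool)
    out_flux[:n], out_time[:n], mask[:n] = flux[:n], time[:n], True
    return out_flux, out_time, mask

# Run on a synthetic light curve longer than SEQ_LEN.
rng = np.random.default_rng(0)
t = np.arange(2000, dtype=float)
q = np.zeros(2000, dtype=int)
f = np.sin(t / 50.0) + rng.normal(0.0, 0.1, 2000)
flux_out, time_out, mask = preprocess(t, f, q)
```

The returned `(flux, time, mask)` triple matches the per-sample shapes in the expected `.pt` structure above.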
For more control, use the individual commands:
```
# Split CSV by TIC (prevents data leakage)
astrafier split --csv data.csv --output-dir ./splits

# Preprocess to tensors
astrafier preprocess --csv splits/train.csv --output train.pt
astrafier preprocess --csv splits/test.csv --output test.pt

# Train
astrafier train --train-pt train.pt --test-pt test.pt
```

Project layout:

```
src/astrafier/
├── commands/       # CLI: train, predict, preprocess, split
├── data/
│   └── loading.py  # FITS preprocessing
└── models/         # ASTRAFier architecture
```

Run the test suite:

```
pytest
```