jeraud/TESS-Transformer

ASTRAFier

A transformer-based classifier for TESS light curves.

Installation

pip install -r requirements.txt
pip install -e .

Quick Start

Predict with Pre-trained Model

Classify your FITS files using our pre-trained weights:

astrafier predict --fits-dir /path/to/fits --output-dir ./predictions

Output: predictions.csv with columns tic and probabilities.

Train on Our Data

Download the paper's training data from Hugging Face (includes TESS and Kepler light curves):

from huggingface_hub import hf_hub_download

for f in ["train.safetensors", "train.json", "test.safetensors", "test.json"]:
    hf_hub_download("paulg9/astrafier", f, local_dir="./data")

Fine-tune from our checkpoint (the default, --load-from-hf):

astrafier train --train-pt data/train.safetensors --test-pt data/test.safetensors

Or train from scratch:

astrafier train --train-pt data/train.safetensors --test-pt data/test.safetensors --no-load-from-hf

Train on Your Data

From CSV — splitting and preprocessing are handled automatically:

astrafier train --train-csv data.csv

Your CSV needs label and path columns. If you include a TIC column, we split by TIC to prevent data leakage (the same star appearing in both train and test). Without it, we split by row, so light curves of the same star can land on both sides of the split.

label        path               TIC
ECLIPSE      /data/tic123.fits  123
CONTACT_ROT  /data/tic456.fits  456
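
To illustrate what splitting by TIC means, here is a minimal sketch of a group-aware split. It is not the code behind astrafier train or astrafier split: the function name split_by_group, the 20% test fraction, and the seed are assumptions made for the example. The idea is that whole TICs, not individual rows, are assigned to one side.

```python
import random
from collections import defaultdict


def split_by_group(rows, group_key="TIC", test_fraction=0.2, seed=0):
    """Split rows so that every row sharing a group value lands on one side.

    Hypothetical helper: the test fraction and seed are illustrative defaults,
    not values taken from the astrafier CLI.
    """
    # Collect rows per star.
    groups = defaultdict(list)
    for row in rows:
        groups[row[group_key]].append(row)

    # Shuffle the star identifiers, not the rows, for a reproducible split.
    ids = sorted(groups)
    random.Random(seed).shuffle(ids)

    # Reserve at least one star for the test set.
    n_test = max(1, round(len(ids) * test_fraction))
    test = [row for tic in ids[:n_test] for row in groups[tic]]
    train = [row for tic in ids[n_test:] for row in groups[tic]]
    return train, test


# Ten stars with two light curves each.
rows = [
    {"label": "ECLIPSE", "path": f"/data/tic{t}_{s}.fits", "TIC": str(t)}
    for t in range(10)
    for s in range(2)
]
train, test = split_by_group(rows)
shared = {r["TIC"] for r in train} & {r["TIC"] for r in test}
print(len(train), len(test), shared)
```

A row-level shuffle on the same data would almost certainly place the two light curves of some star on opposite sides, which is the leakage the TIC column avoids.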

Save preprocessed tensors for faster reruns:

astrafier train --train-csv data.csv --save-preprocessed ./tensors

From tensors — if you have your own preprocessing:

astrafier train --train-pt train.pt --test-pt test.pt

Expected .pt structure:

torch.save({
    "flux": flux_tensor,      # (N, seq_len) float32
    "time": time_tensor,      # (N, seq_len) float32
    "labels": label_tensor,   # (N, num_classes) float32 one-hot
    "mask": mask_tensor,      # (N, seq_len) bool, True = valid
    "label_map": {"ECLIPSE": 0, ...},
}, "train.pt")
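
If your labels start out as strings, the one-hot rows for the labels entry can be built from the label map as in the sketch below. The helper one_hot_labels and the two-class map are illustrative assumptions, and the sketch assumes the map's indices run from 0 to num_classes - 1 without gaps.

```python
def one_hot_labels(labels, label_map):
    """Return an (N, num_classes) nested list of floats, one row per label.

    Hypothetical helper; assumes label_map values are 0..num_classes-1.
    """
    num_classes = len(label_map)
    rows = []
    for name in labels:
        row = [0.0] * num_classes
        row[label_map[name]] = 1.0  # a KeyError here flags an unknown label
        rows.append(row)
    return rows


label_map = {"ECLIPSE": 0, "CONTACT_ROT": 1}  # illustrative two-class map
rows = one_hot_labels(["ECLIPSE", "CONTACT_ROT", "ECLIPSE"], label_map)
print(rows)  # [[1.0, 0.0], [0.0, 1.0], [1.0, 0.0]]
```

Passing the result through torch.tensor(rows, dtype=torch.float32) gives a float32 tensor of shape (N, num_classes), matching the layout described above.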

Predict with Your Checkpoint

astrafier predict --fits-dir /path/to/fits --checkpoint model.ckpt --no-load-from-hf

Training Options

Data source:

  • --train-csv: From CSV with FITS paths (we split and preprocess)
  • --train-pt / --test-pt: From pre-processed tensors

Model weights:

  • --load-from-hf: Fine-tune from our checkpoint (default)
  • --no-load-from-hf: Train from scratch

Other flags:

  • --batch-size: Default 128
  • --max-epochs: Default 250
  • --precision: Use 32-true if bf16 is not supported
  • --save-preprocessed: Save tensors when using --train-csv
  • --mc-dropout / --mc-samples: Monte Carlo dropout for uncertainty

Run astrafier train --help for all options.
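
For readers unfamiliar with Monte Carlo dropout: the model is run several times with dropout left active, and the spread of the resulting predictions serves as an uncertainty estimate. The sketch below shows only a generic aggregation step for one light curve. How astrafier combines its samples is not documented here, so the function name summarize_mc_samples and the choice of predictive entropy as the uncertainty measure are assumptions.

```python
import math


def summarize_mc_samples(samples):
    """Aggregate T stochastic forward passes for a single light curve.

    samples: list of T probability vectors, each of length num_classes.
    Returns (mean probabilities, predictive entropy in nats, argmax class).
    Hypothetical helper, not part of the astrafier API.
    """
    t = len(samples)
    num_classes = len(samples[0])

    # Average the class probabilities over the T passes.
    mean = [sum(s[c] for s in samples) / t for c in range(num_classes)]

    # Entropy of the averaged distribution: 0 when one class has all the mass.
    entropy = -sum(p * math.log(p) for p in mean if p > 0)

    predicted = max(range(num_classes), key=mean.__getitem__)
    return mean, entropy, predicted


# Two passes that agree on the class but differ in confidence.
mean, entropy, predicted = summarize_mc_samples([[0.8, 0.2], [0.6, 0.4]])
print(mean, round(entropy, 4), predicted)
```

Higher entropy means the averaged prediction is spread over more classes, which is a reasonable signal for sending a light curve to manual review.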

Class Labels

The model classifies into 8 categories:

  • APERIODIC
  • CONTACT_ROT
  • DSCT_BCEP
  • ECLIPSE
  • GDOR_SPB
  • INSTRUMENT/JUNK
  • RRLYR_CEPH
  • SOLARLIKE

Preprocessing

Our TESS preprocessing pipeline (used by --train-csv and astrafier preprocess):

  1. Filter by QUALITY flags (keep 0, 64, 256, 1024, 2048, 8192)
  2. Remove NaNs and 10σ outliers
  3. Subtract Gaussian-filtered trend (σ=61)
  4. Standardize (median=0, std=1)
  5. Pad/truncate to sequence length (default: 1171)
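
The five steps above can be sketched in plain Python as follows. This is an illustration, not the code in data/loading.py. Several details are assumptions: the QUALITY filter is read as exact membership in the listed values, outliers are measured from the median in a single pass, the Gaussian kernel is truncated at 4σ and renormalised at the edges, and padding uses zeros with a False mask.

```python
import math
import statistics

ALLOWED_QUALITY = {0, 64, 256, 1024, 2048, 8192}  # values listed in step 1


def gaussian_smooth(values, sigma):
    """Gaussian-weighted moving average (assumed edge rule: renormalise)."""
    radius = int(4 * sigma)
    kernel = [math.exp(-0.5 * (k / sigma) ** 2) for k in range(-radius, radius + 1)]
    n = len(values)
    out = []
    for i in range(n):
        lo, hi = max(0, i - radius), min(n, i + radius + 1)
        weights = kernel[lo - i + radius : hi - i + radius]
        total = sum(w * v for w, v in zip(weights, values[lo:hi]))
        out.append(total / sum(weights))
    return out


def preprocess(flux, quality, seq_len=1171, sigma=61.0, clip=10.0):
    """Hypothetical re-implementation of the five listed steps."""
    # Steps 1 and 2: keep allowed QUALITY values and drop NaNs.
    kept = [f for f, q in zip(flux, quality)
            if q in ALLOWED_QUALITY and not math.isnan(f)]
    if not kept:
        return [0.0] * seq_len, [False] * seq_len

    # Step 2: drop points more than `clip` standard deviations from the median.
    med, std = statistics.median(kept), statistics.pstdev(kept)
    if std > 0:
        kept = [f for f in kept if abs(f - med) <= clip * std]

    # Step 3: subtract the Gaussian-filtered trend.
    trend = gaussian_smooth(kept, sigma)
    detrended = [f - t for f, t in zip(kept, trend)]

    # Step 4: standardize to median 0 and standard deviation 1.
    med, std = statistics.median(detrended), statistics.pstdev(detrended)
    scaled = [(f - med) / std if std > 0 else 0.0 for f in detrended]

    # Step 5: truncate or zero-pad; the mask is True for real samples.
    scaled = scaled[:seq_len]
    n_pad = seq_len - len(scaled)
    mask = [True] * len(scaled) + [False] * n_pad
    return scaled + [0.0] * n_pad, mask


# A short synthetic light curve with one NaN and one rejected QUALITY value.
flux = [1.0 + 0.01 * math.sin(i / 5.0) for i in range(200)]
flux[10] = float("nan")
quality = [0] * 200
quality[20] = 4
out, mask = preprocess(flux, quality, seq_len=256, sigma=5.0)
print(len(out), sum(mask))
```

In practice the same steps would normally be vectorised with NumPy and a library Gaussian filter; the standard-library version only keeps the sketch free of dependencies.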

Building Blocks

For more control, use the individual commands:

# Split CSV by TIC (prevents data leakage)
astrafier split --csv data.csv --output-dir ./splits

# Preprocess to tensors
astrafier preprocess --csv splits/train.csv --output train.pt
astrafier preprocess --csv splits/test.csv --output test.pt

# Train
astrafier train --train-pt train.pt --test-pt test.pt

Project Layout

src/astrafier/
├── commands/       # CLI: train, predict, preprocess, split
├── data/
│   └── loading.py  # FITS preprocessing
└── models/         # ASTRAFier architecture

Testing

pytest
