In [2]:
from torch.utils.data import Dataset, DataLoader, random_split
from collections import defaultdict
import astropy
from astropy.io import fits
from astropy.table import Table, hstack
from astropy.utils.metadata import MergeConflictWarning
import glob
import torch
import random 
import os 
import subprocess
import numpy as np
import warnings
from sklearn.model_selection import train_test_split
import pandas as pd 

In [3]:
# Silence astropy's MergeConflictWarning, which would otherwise be emitted
# whenever tables with differing metadata are merged.
warnings.simplefilter('ignore', MergeConflictWarning)

# Get the repo root (assumes script is inside STARDUSTAI/) so the data
# directory resolves regardless of the notebook's working directory.
repo_root = subprocess.check_output(["git", "rev-parse", "--show-toplevel"], text=True).strip()
base_dir = os.path.join(repo_root, "data/full/")

# List all FITS files one directory level below base_dir. glob's ordering
# depends on the filesystem, so sort first to get a stable starting order.
file_paths = sorted(glob.glob(os.path.join(base_dir, "*/*.fits")))

# If no FITS files are found, raise an error
if not file_paths:
    raise ValueError("No FITS files found in 'data/full/'")

# Shuffle the file paths with a dedicated, seeded generator so that a
# Restart & Run All reproduces the same order (and the same splits later)
# without touching the global `random` state.
SHUFFLE_SEED = 42
random.Random(SHUFFLE_SEED).shuffle(file_paths)


In [4]:
# Custom PyTorch dataset for lazy loading
class FitsDataset(Dataset):
    """Lazily load spectrum FITS files as (features, class, subclass) samples.

    Only the file paths are kept in memory; each file is read from disk when
    its sample is requested. Supports ``dataset[i]`` (and therefore
    ``DataLoader``) as well as the legacy ``dataset.getitem(i)`` call.

    Args:
        file_paths: Sequence of paths to FITS files. Each file must provide
            the columns read in ``__getitem__`` from HDU 1 and HDU 2.
    """

    def __init__(self, file_paths):
        self.file_paths = file_paths  # Store file paths
        self.class_categories = ['STAR', 'GALAXY', 'QSO']
        # Position in this list is the one-hot index of the subclass.
        # 'STARFORMING' / 'STARBURST' are spelled in upper case because that
        # is how they appear in the raw SUBCLASS column; a mixed-case spelling
        # never matches and silently sends those objects to 'nan'.
        self.subclass_categories = [
            'nan', 'STARFORMING', 'STARBURST', 'AGN',
            'O', 'OB', 'B6', 'B9', 'A0', 'A0p',
            'F2', 'F5', 'F9', 'G0', 'G2', 'G5',
            'K1', 'K3', 'K5', 'K7', 'M0V', 'M2V',
            'M1', 'M2', 'M3', 'M4', 'M5', 'M6', 'M7', 'M8',
            'L0', 'L1', 'L2', 'L3', 'L4', 'L5', 'L5.5', 'L9', 'T2',
            'Carbon', 'Carbon_lines', 'CarbonWD', 'CV', 'BROADLINE',
        ]
        self.plate_quality_tags = {'bad': 0,  'marginal': 1, 'good': 2, 'nan': np.nan}

    def __len__(self):
        """Return the number of FITS files (one sample per file)."""
        return len(self.file_paths)  # Total number of files

    @staticmethod
    def _one_hot(label, categories):
        """Return a float array of zeros with a 1 at ``label``'s position."""
        encoded = np.zeros(len(categories))
        encoded[categories.index(label)] = 1
        return encoded

    def __getitem__(self, idx):
        """Read file ``idx`` and return its tensors.

        Args:
            idx: Integer position into ``file_paths``.

        Returns:
            Tuple ``(features, class_one_hot, subclass_one_hot)`` where
            ``features`` is a float32 tensor with one row per table row and
            one column per numeric feature, and the two labels are one-hot
            ``torch.long`` vectors over ``class_categories`` and
            ``subclass_categories`` respectively.

        Raises:
            ValueError: If the file's CLASS is not in ``class_categories``.
        """
        file_path = self.file_paths[idx]

        # Read FITS file as Astropy Table
        dat1 = Table.read(file_path, format='fits', hdu=1)
        dat1 = dat1['flux', 'loglam', 'ivar', 'model']
        dat2 = Table.read(file_path, format='fits', hdu=2)
        dat2 = dat2['PLATEQUALITY', 'PLATESN2', 'PLATE', 'TILE', 'MJD', 'FIBERID', 'CLASS', "SUBCLASS", 'Z', 'Z_ERR', 'SN_MEDIAN', 'SN_MEDIAN_ALL', 'ZWARNING' , 'RCHI2']
        data = hstack([dat1, dat2])  # Merge HDUs
        sn_median_values = np.vstack(data['SN_MEDIAN'])  # Shape: (n_rows, 5)

        # Add new columns for each filter
        data['SN_MEDIAN_UV'] = sn_median_values[:, 0]  # Ultraviolet
        data['SN_MEDIAN_G'] = sn_median_values[:, 1]   # Green
        data['SN_MEDIAN_R'] = sn_median_values[:, 2]   # Red
        data['SN_MEDIAN_NIR'] = sn_median_values[:, 3] # Near-Infrared
        data['SN_MEDIAN_IR'] = sn_median_values[:, 4]  # Infrared

        # The multi-dimensional column has been expanded above, so drop it
        data.remove_column('SN_MEDIAN')

        # Convert Astropy Table to Pandas DataFrame
        df = data.to_pandas()

        # Map PLATEQUALITY to numerical values, then propagate the first
        # row's value to every row whose quality is missing.
        plate_quality = df['PLATEQUALITY'].astype(str).map(self.plate_quality_tags)
        first_value = plate_quality.iloc[0]
        if not np.isnan(first_value):
            plate_quality = plate_quality.fillna(first_value)
        df['PLATEQUALITY'] = plate_quality

        # one hot encode class (the label is taken from the first row)
        class_label = str(df['CLASS'].iloc[0]).strip()
        if class_label not in self.class_categories:
            raise ValueError(
                f"Unknown CLASS {class_label!r} in {file_path}; "
                f"expected one of {self.class_categories}"
            )
        class_one_hot = self._one_hot(class_label, self.class_categories)

        # one hot encode subclass; anything not listed falls back to 'nan'
        # NOTE(review): the raw value is matched verbatim, so a subclass
        # carrying extra text after the type would also land in 'nan' —
        # confirm whether it should be normalised first.
        subclass_label = str(df['SUBCLASS'].iloc[0]).strip()
        if subclass_label not in self.subclass_categories:
            subclass_label = 'nan'
        subclass_one_hot = self._one_hot(subclass_label, self.subclass_categories)

        # Drop class and subclass columns and keep everything else;
        # remaining missing values become 0.
        features = df.drop(columns=['CLASS', 'SUBCLASS']).fillna(0).astype(np.float32)

        features_tensor = torch.tensor(features.values, dtype=torch.float32)
        class_label_tensor = torch.tensor(class_one_hot, dtype=torch.long)
        subclass_label_tensor = torch.tensor(subclass_one_hot, dtype=torch.long)

        return features_tensor, class_label_tensor, subclass_label_tensor

    def getitem(self, idx):
        """Backward-compatible alias for ``dataset[idx]``."""
        return self.__getitem__(idx)

In [5]:
# Wrap the shuffled file list in the lazy dataset; nothing is read from disk
# until a sample is requested.
dataset = FitsDataset(file_paths)

# Smoke check: load the sample at position 3 and show row 4 of its feature
# tensor (element 0 of the returned tuple).
print(dataset.getitem(idx=3)[0][4])

tensor([7.4049, 3.5555, 0.0350, 0.9253, 2.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.4804, 1.5487, 2.5069,
        2.4406, 1.6370])


In [None]:
def data_stats(file_paths):
    """Print class, subclass and ZWARNING summaries for a set of FITS files.

    For every file the first row of HDU 2 supplies the CLASS and SUBCLASS
    labels, and every row of HDU 2 contributes to the ZWARNING tally. Files
    with fewer than three HDUs are reported and skipped.

    Args:
        file_paths: Iterable of paths to FITS files.

    Returns:
        None. All results are printed to stdout.

    Note:
        Reads the module-level ``SUBCLASS_CATEGORIES`` list, which must be
        defined before this function is called with a non-empty input.
    """
    class_counts = defaultdict(int)
    subclass_counts = defaultdict(int)
    zwarning_zero = 0
    total_objects = 0

    for file_path in file_paths:
        with fits.open(file_path) as hdul:
            if len(hdul) <= 2:
                print(f"Skipping {file_path}: HDU 2 does not exist.")
                continue
            hdu2 = hdul[2].data
            main_class = str(hdu2['CLASS'][0]).strip().upper()
            raw_subclass = str(hdu2['SUBCLASS'][0]).strip()

            # Split subclass name to only include main title: drop anything
            # after '(' or '/', then keep the first two characters.
            # NOTE(review): for stars this also truncates non-spectral-type
            # names to two letters — confirm that is intended.
            clean_subclass = raw_subclass.split('(')[0].split('/')[0].strip()[:2]
            if main_class == 'STAR' and len(clean_subclass) == 2:
                subclass = clean_subclass
            else:
                subclass = raw_subclass if raw_subclass in SUBCLASS_CATEGORIES else 'nan'

            class_counts[main_class] += 1
            subclass_counts[subclass] += 1

            # ZWARNING
            zwarnings = hdu2['ZWARNING']
            zwarning_zero += int((zwarnings == 0).sum())
            total_objects += len(zwarnings)

    # With no readable HDU 2 anywhere (empty input, or every file skipped)
    # the percentage computations below would divide by zero.
    if total_objects == 0:
        print("\nNo objects found; nothing to summarise.")
        return

    print("\nClass Distribution:")
    total_files = sum(class_counts.values())  # hoisted out of the loop
    for cls, count in class_counts.items():
        print(f"{cls}: {count} ({count/total_files:.1%})")

    print("\nSubclass Distribution:")
    for subcls, count in sorted(subclass_counts.items(), key=lambda x: -x[1]):
        print(f"{subcls}: {count}")

    # TODO: only the header is printed; no SNR statistics are computed yet.
    print("\nSNR Distribution")

    # Print ZWARNING
    print("\nZWARNING Distribution:")
    zwarning_nonzero = total_objects - zwarning_zero
    print(f"ZWARNING=0: {zwarning_zero} ({zwarning_zero/total_objects:.1%})")
    print(f"ZWARNING≠0: {zwarning_nonzero} ({zwarning_nonzero/total_objects:.1%})")

In [7]:
# Subclass labels that data_stats() reports verbatim for non-stellar objects;
# anything outside this list is counted under 'nan'.
SUBCLASS_CATEGORIES = [
    # missing / unrecognised
    'nan',
    # galaxy and quasar subclasses
    'STARFORMING', 'STARBURST', 'AGN',
    # stellar types, hot to cool
    'O', 'OB',
    'B6', 'B9',
    'A0', 'A0p',
    'F2', 'F5', 'F9',
    'G0', 'G2', 'G5',
    'K1', 'K3', 'K5', 'K7',
    'M0V', 'M2V',
    'M1', 'M2', 'M3', 'M4', 'M5', 'M6', 'M7', 'M8',
    'L0', 'L1', 'L2', 'L3', 'L4', 'L5', 'L5.5', 'L9',
    'T2',
    # carbon stars, cataclysmic variables, broad-line objects
    'Carbon', 'Carbon_lines', 'CarbonWD',
    'CV',
    'BROADLINE',
]

# Summarise the whole file list (prints to stdout).
data_stats(file_paths)

        Use textwrap.indent() instead. [astropy.io.fits.hdu.hdulist]
    Header size is not multiple of 2880: 35070
There may be extra bytes after the last HDU or the file is corrupted. [astropy.io.fits.hdu.hdulist]


Skipping C:/Users/haris/Desktop/EngSci Y3S2/StarDust/StarDustAI\data/full\10227\spec-10227-58224-0419.fits: HDU 2 does not exist.

Class Distribution:
QSO: 7785 (51.0%)
STAR: 2135 (14.0%)
GALAXY: 5351 (35.0%)

Subclass Distribution:
nan: 8501
BROADLINE: 4303
M5: 451
F3: 290
STARBURST: 251
M4: 208
WD: 123
M6: 82
F0: 81
CV: 79
M1: 79
F8: 77
STARFORMING: 71
K3: 63
G0: 60
G8: 58
K5: 55
K0: 55
Ld: 42
A2: 34
G4: 28
F2: 26
M3: 25
F6: 22
Ca: 18
K1: 15
A4: 13
A1: 11
B5: 11
AGN: 10
sd: 9
K4: 9
O8: 8
B9: 8
G9: 8
A9: 8
G5: 8
B8: 7
G1: 6
B3: 5
B2: 5
M0: 5
A5: 4
A3: 4
O9: 4
F9: 3
G3: 3
M2: 3
M7: 3
K2: 3
B0: 3
A8: 3
M8: 2
B7: 2
F5: 2
B1: 1
Am: 1
A0: 1
B6: 1

