In [None]:
import sys
import os
sys.path.append('../../') # Add the root directory to sys.path

In [None]:
import torch
from torch import nn
from torch.optim.lr_scheduler import OneCycleLR
import pandas as pd
import numpy as np
from tqdm.notebook import tqdm

from models.starnet import StarNet
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score
from sklearn.model_selection import StratifiedKFold

In [None]:
# Relative path to the parquet file read by `pd.read_parquet` below.
# NOTE: despite the `_dir` suffix this is a file path, not a directory.
data_dir = '../../data/Gaia DR3/train.parquet'

In [None]:
# Read the whole parquet table into memory, then preview its first five rows.
data = pd.read_parquet(path=data_dir)
data.head()

In [None]:
# Build the working frame `df` by dropping four columns from the raw table;
# `data` itself is left untouched so this cell can be re-run safely.
dropped_columns = ['teff_gspphot', 'logg_gspphot', 'mh_gspphot', 'spectraltype_esphs']
df = data.drop(dropped_columns, axis='columns')
df.head()

In [None]:
# Print a summary of the working frame: column dtypes, non-null counts and memory usage.
df.info()

In [None]:
# Re-run this cell to see random examples of different spectra.

# Random sample from 'M' category (massive star).
# The sampled one-row frame is read directly instead of feeding its index
# labels to `.iloc` (which is positional and only matched by accident when
# `df` carried a default 0..N-1 index).
row_ms = df[df['Cat'] == 'M'].sample(n=1)
sample_ms = row_ms.index
flux_ms = row_ms['flux'].values[0]
object_id_ms = row_ms['source_id'].values[0]

# Random sample from 'LM' category (low-mass star).
row_lm = df[df['Cat'] == 'LM'].sample(n=1)
sample_lm = row_lm.index
flux_lm = row_lm['flux'].values[0]
object_id_lm = row_lm['source_id'].values[0]

fig, (ax_ms, ax_lm) = plt.subplots(1, 2, figsize=(12, 6))

# The spectra are plotted against their position in the flux array (no
# wavelength grid is passed to `plot`), so the x axis is a bin index, and the
# plotted quantity is the `flux` column.
# Plot for 'M' category (massive star)
ax_ms.plot(flux_ms)
ax_ms.set_title(f"Massive Star ({object_id_ms})")
ax_ms.set_xlabel('Spectral bin index')
ax_ms.set_ylabel('Flux')

# Plot for 'LM' category (low-mass star)
ax_lm.plot(flux_lm)
ax_lm.set_title(f"Low-Mass Star ({object_id_lm})")
ax_lm.set_xlabel('Spectral bin index')
ax_lm.set_ylabel('Flux')

fig.tight_layout()
plt.show()


In [None]:
# Basic dataset statistics used by the following cells.
num_samples = df.shape[0]

# Positional access to the first spectrum: `df['flux'][0]` is a *label* lookup
# and raises KeyError whenever the index does not contain the label 0.
spectrum_width = len(df['flux'].iloc[0])

# Count each class once; a class that is absent is reported as 0 instead of
# raising KeyError.
class_counts = df['Cat'].value_counts()
num_samples_lm = class_counts.get('LM', 0)
num_samples_m = class_counts.get('M', 0)
num_classes = df['Cat'].nunique()

print("Number of total spectral samples:", num_samples)
print("Number of bins in each spectra:", spectrum_width)
print("In the dataset, we have", num_samples_lm, "spectra for low mass stars and", num_samples_m, "spectra for high mass stars.")

In [None]:
# Stack the per-row flux arrays into one 2-D matrix (one spectrum per row).
X = np.vstack(df['flux'])

# Encode categories as a float column vector: 'M' -> 1, 'LM' -> 0.
# Any other category is rejected explicitly rather than failing later inside
# a string-to-float conversion.
categories = df['Cat'].to_numpy()
unexpected = set(pd.unique(df['Cat'])) - {'M', 'LM'}
if unexpected:
    raise ValueError(f"Unexpected values in 'Cat': {sorted(map(str, unexpected))}")
# float32 keeps the targets in the same precision as X; shape is (num_samples, 1).
y = torch.from_numpy((categories == 'M').astype(np.float32)).unsqueeze(1)

# L2 normalization, applied per spectrum (axis=1). Without `axis`, the norm is
# a single Frobenius norm of the whole matrix, which merely rescales every
# spectrum by one dataset-size-dependent constant.
row_norms = np.linalg.norm(X, axis=1, keepdims=True)
row_norms[row_norms == 0] = 1.0  # leave all-zero spectra as zeros instead of producing NaN
X = torch.from_numpy(X / row_norms).float()

In [None]:
def init_weights(m):
    """Re-initialise a single module in place; intended for ``model.apply``.

    Conv1d and Linear modules get Kaiming-uniform weights (ReLU gain) and,
    when they have a bias, a zero bias. Every other module type is left
    untouched.
    """
    if not isinstance(m, (nn.Conv1d, nn.Linear)):
        return

    nn.init.kaiming_uniform_(m.weight, nonlinearity='relu')

    has_bias = m.bias is not None
    if has_bias:
        nn.init.constant_(m.bias, 0)

In [None]:
def fit_model(model, x_train, y_train, x_val, y_val, prt_steps = 1, verbose=True,
              epochs=200, learning_rate=1e-4, batch_size=64, max_lr=1e-2, device=None):
    """Train ``model`` on one train/validation split and record per-epoch metrics.

    The model's weights are re-initialised with ``init_weights`` before
    training, so every call starts from scratch for the layers that helper
    covers. Training uses AdamW with a one-cycle learning-rate schedule and
    ``BCEWithLogitsLoss``; the validation set is evaluated in a single forward
    pass at the end of every epoch.

    Parameters
    ----------
    model : torch.nn.Module
        Network to train; it is moved to ``device`` in place.
    x_train, x_val : torch.Tensor
        Input samples. A channel axis is inserted at position 1 before the
        forward pass (presumably turning (N, num_fluxes) into
        (N, 1, num_fluxes) -- confirm against the model definition).
    y_train, y_val : torch.Tensor
        Binary targets with the same shape as the model output.
    prt_steps : int
        Print metrics every ``prt_steps`` epochs (only when ``verbose``).
    verbose : bool
        Whether to print metrics at all.
    epochs, batch_size : int
        Number of epochs and mini-batch size.
    learning_rate : float
        Learning rate handed to AdamW. Note that ``OneCycleLR`` takes over the
        optimizer's learning rate, so the schedule is driven by ``max_lr``.
    max_lr : float
        Peak learning rate of the one-cycle schedule.
    device : str or torch.device, optional
        Target device. Defaults to CUDA when available, otherwise CPU.

    Returns
    -------
    tuple of (list, list, list)
        Per-epoch training loss, validation loss and validation accuracy.

    Raises
    ------
    ValueError
        If ``x_train`` is empty.
    """
    num_train = len(x_train)
    if num_train == 0:
        raise ValueError("x_train is empty; cannot fit the model.")

    # Fall back to CPU instead of crashing on machines without a GPU.
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # initialize weights
    model.apply(init_weights)

    # model components
    criterion = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

    # move everything to the target device
    model.to(device)
    x_train = x_train.to(device)
    y_train = y_train.to(device)
    x_val = x_val.to(device)
    y_val = y_val.to(device)

    # metrics
    training_losses, validation_losses = [], []
    accuracy = []

    # lr cycling
    # The scheduler is stepped once per mini-batch, so it must be told the real
    # number of batches per epoch (including the final partial batch). Using
    # floor division here made the schedule shorter than the training loop and
    # caused OneCycleLR to raise once its total step budget was exhausted.
    batch_starts = range(0, num_train, batch_size)
    scheduler = OneCycleLR(optimizer, max_lr=max_lr, steps_per_epoch=len(batch_starts), epochs=epochs)

    for epoch in tqdm(range(epochs), desc='Epochs', dynamic_ncols=True):

        report = verbose and (epoch+1) % prt_steps == 0

        model.train()
        running_loss = 0

        for start in batch_starts:

            x_batch = x_train[start:start+batch_size]
            y_batch = y_train[start:start+batch_size]

            output = model(x_batch.unsqueeze(1))
            loss = criterion(output, y_batch)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()

            # weight by batch size so the partial last batch is averaged correctly
            running_loss += loss.item() * x_batch.size(0)

        train_loss = running_loss / num_train
        training_losses.append(train_loss)
        if report:
            print(f'Train loss: {train_loss:.4f}', end='\r')

        model.eval()
        with torch.no_grad():
            output = model(x_val.unsqueeze(1))
            val_loss = criterion(output, y_val).item()

            # threshold the sigmoid probabilities into hard 0/1 predictions
            pred = torch.round(torch.sigmoid(output)).cpu().numpy().astype(float)

        epoch_acc = accuracy_score(y_val.cpu().numpy(), pred)

        validation_losses.append(val_loss)
        accuracy.append(epoch_acc)

        if report:
            print(f'Train loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Accuracy: {epoch_acc:.4f}', end='\r')

    return training_losses, validation_losses, accuracy

In [None]:
# 5-fold stratified cross-validation; per-epoch curves are averaged over folds.
kfold = StratifiedKFold(n_splits=5)

training_losses, validation_losses, accuracy_scores = [], [], []

params = {
    'num_fluxes':spectrum_width, 
    'filter_length':3, 
    'pool_length':4,
    'num_filters':[4,16],
    'num_hidden':[256,128],
    'num_labels':y.size(1)
}

# Instantiated once here only to display the architecture.
model = StarNet(**params)
print(model)

# scikit-learn expects 1-D class labels for stratification; y is a column vector.
y_strat = y.squeeze(1).numpy()

for fold, (train_idx, val_idx) in enumerate(kfold.split(X, y_strat)):
    
    print(f"\nFitting fold {fold+1}")

    # Fresh model per fold: `fit_model` only re-initialises Conv1d/Linear
    # weights, so reusing one instance could carry any other module state
    # from the previous fold into this one.
    model = StarNet(**params)

    tr_loss, val_loss, acc = fit_model(model, X[train_idx], y[train_idx], X[val_idx], y[val_idx])
    training_losses.append(tr_loss)
    validation_losses.append(val_loss)
    accuracy_scores.append(acc)

# Average each per-epoch curve across the folds (axis 0 = fold).
training_losses = np.mean(training_losses, axis=0)
validation_losses = np.mean(validation_losses, axis=0)
accuracy_scores = np.mean(accuracy_scores, axis=0)

In [None]:
# Two stacked panels: fold-averaged validation accuracy on top, losses below.
fig, axes = plt.subplots(nrows=2, ncols=1, figsize=(15, 14))
ax1, ax2 = axes

# Top panel: accuracy per epoch.
ax1.plot(accuracy_scores, label='Validation Accuracy')
ax1.set(xlabel='Epochs', ylabel='Accuracy', title='Accuracy')
ax1.legend()

# Bottom panel: training and validation loss on shared axes.
loss_curves = (
    (training_losses, 'Training Loss'),
    (validation_losses, 'Validation Loss'),
)
for curve, curve_label in loss_curves:
    ax2.plot(curve, label=curve_label)
ax2.set(xlabel='Epochs', ylabel='Loss', title='Training and Validation Loss')
ax2.legend()

fig.tight_layout()
plt.show()