# Building The StarTrack Model - 2

In [60]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from torch.utils.data import DataLoader, TensorDataset, Dataset
from sklearn.manifold import TSNE
from sklearn.preprocessing import LabelEncoder
from sklearn.utils.class_weight import compute_class_weight
import tqdm

In [31]:
device = "cuda" if torch.cuda.is_available() else "cpu"
device

'cuda'

In [32]:
tabular_data = pd.read_csv('../data/processed_data.csv')
spectrogram_data = np.load('../data/startrack_spectrograms.npz', allow_pickle=True)

In [33]:
columns_to_drop = ["Unnamed: 0", "Unnamed: 0.1", "Unnamed: 0.3", "Unnamed: 0.2",
                   "CLASS", "SUBCLASS", "SUBCLASS_CLEAN", "SPECTRAL_GROUP", 
                   "url", "filename", "PLATE", "MJD", "FIBERID", "ELODIE_SPTYPE" ]
labels_tabular = tabular_data["SPECTRAL_GROUP"]
tabular_data = tabular_data.drop(columns=columns_to_drop)
pd.set_option('display.max_columns', None)
bin_dict = {'Very Low': 0, 'Low': 1, 'Moderate': 2, 'Good': 3, 'High': 4, 'Very High': 5}
tabular_data["SNR_Bin"] = tabular_data["SNR_Bin"].map(bin_dict)
qual_dict = {False: 0, True: 1} # There's probably a better way to do this not going to lie
tabular_data["High_Quality"] = tabular_data["High_Quality"].map(qual_dict)
tabular_data.head()

Unnamed: 0,ELODIE_TEFF,ELODIE_LOGG,ELODIE_FEH,Z,Z_ERR,ZWARNING,VDISP,VDISP_ERR,SN_MEDIAN_ALL,RCHI2,DOF,SNR_Bin,High_Quality,Mean_Flux,Flux_to_Noise,u_flux,g_flux,r_flux,i_flux,z_flux,flux_mean,flux_std,flux_min,flux_max,flux_median,flux_p25,flux_p75
0,3705.0,4.8,0.6,-0.000485,1.4e-05,0,0.0,0.0,19.021759,0.880377,4542,2.0,0,37.168239,1.953985,0.269059,0.268631,0.195929,0.125402,0.047316,37.168239,10.325851,19.371477,46.528843,42.201336,31.882357,45.857182
1,3705.0,4.8,0.6,-0.000195,1.5e-05,0,0.0,0.0,22.828102,1.034391,4487,3.0,1,40.272045,1.764143,0.313935,0.310882,0.236147,0.178934,0.105151,40.272045,11.563604,20.819483,51.774773,44.854801,33.856216,50.054951
2,3705.0,4.8,0.6,0.000728,1.1e-05,0,0.0,0.0,28.743233,0.948585,4505,3.0,1,67.842227,2.360285,0.630979,0.791056,0.716503,0.637338,0.519138,67.842227,21.69888,31.049629,89.325089,76.546692,56.28862,86.001106
3,3705.0,4.8,0.6,0.000277,1.6e-05,0,0.0,0.0,23.706001,0.875567,4568,3.0,1,40.974394,1.72844,0.263392,0.306808,0.250842,0.196867,0.137755,40.974394,13.044562,19.188595,54.732037,45.824268,33.665863,51.461208
4,9899.0,2.924,0.09,0.000616,8.1e-05,0,0.0,0.0,2.716969,0.984394,4395,0.0,0,2.968928,1.092735,-0.283738,-0.361274,-0.387023,-0.406692,-0.432778,2.968928,0.92436,1.534235,4.132542,2.982392,2.454889,3.740581


In [34]:
tabular_data.shape

(24782, 27)

In [35]:
labels_tabular.head()

0    F
1    F
2    F
3    F
4    F
Name: SPECTRAL_GROUP, dtype: object

In [36]:
spectrograms = spectrogram_data["spectrograms"] 
labels_spect = spectrogram_data["labels"] 
spectrograms.shape

(24782, 1024)

In [38]:
# Method taken from my EchoScope project, check it out here: https://github.com/blueskinlizard/EchoScope/tree/main
def split_sets(X, y):
    X_train_val, X_test, y_train_val, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y)

    X_train, X_val, y_train, y_val = train_test_split(
        X_train_val, y_train_val, test_size=0.25, random_state=42, stratify=y_train_val)
    
    print(f'Train shape: {X_train.shape}, {y_train.shape}')
    print(f'Validation shape: {X_val.shape}, {y_val.shape}')
    print(f'Test shape: {X_test.shape}, {y_test.shape}')
    return X_train, y_train, X_val, y_val, X_test, y_test

In [58]:
# Method also taken from EchoScope
def train_model(model, criterion, optimizer, epochs, train_loader, val_loader, patience=5, scheduler=None):
    best_val_loss = float('inf')
    best_model_state = None
    no_improve_epochs = 0
    
    train_losses = []
    val_losses = []

    for epoch in range(1, epochs + 1):
        model.train()
        train_loss = 0.0
        train_bar = tqdm(train_loader, desc=f"Epoch {epoch}/{epochs} [Train]", leave=False)
        for X_batch, y_batch in train_bar:
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)
            optimizer.zero_grad()
            outputs = model(X_batch)
            loss = criterion(outputs, y_batch)
            loss.backward()
            optimizer.step()
            train_loss += loss.item() * X_batch.size(0)
            train_bar.set_postfix(loss=loss.item())
        train_loss /= len(train_loader.dataset)

        model.eval()
        val_loss = 0.0
        val_bar = tqdm(val_loader, desc=f"Epoch {epoch}/{epochs} [Val]", leave=False)
        with torch.no_grad():
            for X_batch, y_batch in val_bar:
                X_batch, y_batch = X_batch.to(device), y_batch.to(device)
                outputs = model(X_batch)
                loss = criterion(outputs, y_batch)
                val_loss += loss.item() * X_batch.size(0)
                val_bar.set_postfix(loss=loss.item())
        val_loss /= len(val_loader.dataset)

        train_losses.append(train_loss)
        val_losses.append(val_loss)

        print(f"Epoch {epoch}/{epochs} | Train Loss: {train_loss:.6f} | Val Loss: {val_loss:.6f}")
        writer.add_scalar("Loss/Train", train_loss, epoch)
        writer.add_scalar("Loss/Val", val_loss, epoch)
        writer.add_scalar("LR", optimizer.param_groups[0]['lr'], epoch)

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_model_state = copy.deepcopy(model.state_dict())
            no_improve_epochs = 0
        else:
            no_improve_epochs += 1
            if no_improve_epochs >= patience:
                print(f"Early stopping at epoch {epoch} (best val loss: {best_val_loss:.6f})")
                break
        if scheduler:
            scheduler.step(val_loss)
            
    model.load_state_dict(best_model_state)
    return train_losses, val_losses

In [59]:
# Method also taken from EchoScope
def set_seed(seed=1):
    torch.manual_seed(seed)
    np.random.seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
set_seed()

# Training The Dense (Tabular Data) branch

In [37]:
le_tabular = LabelEncoder()
labels_encoded_tabular = le_tabular.fit_transform(labels_tabular)
X = tabular_data
y = labels_encoded_tabular
print("X shape:", X.shape)
print(f'Total samples: {X.shape[0]}')
print(f'Unique classes: {le_tabular.classes_}')

X shape: (24782, 27)
Total samples: 24782
Unique classes: ['A' 'B' 'CV' 'F' 'G' 'K' 'L' 'M' 'WD']


In [39]:
X_table_train, y_table_train, X_table_val, y_table_val, X_table_test, y_table_test = split_sets(X, y)

Train shape: (14868, 27), (14868,)
Validation shape: (4957, 27), (4957,)
Test shape: (4957, 27), (4957,)


In [50]:
X_table_train_fusion = X_table_train.copy()
X_table_val_fusion = X_table_val.copy()
X_table_test_fusion = X_table_test.copy()

y_table_train_fusion = y_table_train.copy()
y_table_val_fusion = y_table_val.copy()
y_table_test_fusion = y_table_test.copy()

In [41]:
train_dataset_table_np = torch.utils.data.TensorDataset(
    torch.tensor(X_table_train.values, dtype=torch.float32),
    torch.tensor(y_table_train, dtype=torch.long)
)

val_dataset_table_np = torch.utils.data.TensorDataset(
    torch.tensor(X_table_val.values, dtype=torch.float32),
    torch.tensor(y_table_val, dtype=torch.long)
)

test_dataset_table_np = torch.utils.data.TensorDataset(
    torch.tensor(X_table_test.values, dtype=torch.float32),
    torch.tensor(y_table_test, dtype=torch.long)
)

In [61]:
train_loader_tabular = DataLoader(train_dataset_table_np, batch_size=64, shuffle=True)
val_loader_tabular = DataLoader(val_dataset_table_np, batch_size=64, shuffle=True)
test_loader_tabular = DataLoader(test_dataset_table_np, batch_size=64, shuffle=True)

In [43]:
class StarTrack_Dense(nn.Module):
    def __init__(self, input_size=22, hidden_sizes=[128, 64], output_size=9):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_sizes[0])
        self.bn1 = nn.BatchNorm1d(hidden_sizes[0])
        self.dropout1 = nn.Dropout(0.3)

        self.fc2 = nn.Linear(hidden_sizes[0], hidden_sizes[1])
        self.bn2 = nn.BatchNorm1d(hidden_sizes[1])
        self.dropout2 = nn.Dropout(0.3)

        self.out = nn.Linear(hidden_sizes[1], output_size)

    def forward(self, x):
        x = F.relu(self.bn1(self.fc1(x)))
        x = self.dropout1(x)

        x = F.relu(self.bn2(self.fc2(x)))
        x = self.dropout2(x)

        x = self.out(x)
        return x
    def extract_features(self, x):
        x = F.relu(self.bn1(self.fc1(x)))
        x = self.dropout1(x)
        x = F.relu(self.bn2(self.fc2(x)))
        return x

In [55]:
StarTrack_Dense_V1 = StarTrack_Dense()
StarTrack_Dense_V1 = StarTrack_Dense_V1 .to(device)

In [62]:
class_weights_dense = compute_class_weight('balanced', classes=np.unique(y_table_train), y=y_table_train)
class_weights_dense = torch.tensor(class_weights, dtype=torch.float).to(device)
loss_fn = nn.CrossEntropyLoss(weight=class_weights_dense)
optimizer_dense = torch.optim.Adam(StarTrack_Dense_V1.parameters(), lr=1e-4, weight_decay=1e-4)

  class_weights_dense = torch.tensor(class_weights, dtype=torch.float).to(device)


In [None]:
train_losses_dense, val_losses_dense = train_model(model=StarTrack_Dense_V1, 
                                       criterion=loss_fn, 
                                       optimizer=optimizer, 
                                       epochs=15, 
                                       train_loader=train_loader_tabular, 
                                       val_loader=val_loader_tabular)

# Training The LSTM (Spectrogram) Branch

In [44]:
le_spect = LabelEncoder()
labels_encoded_spect = le_spect.fit_transform(labels_spect)
X = spectrograms
X = X.reshape((X.shape[0], X.shape[1], 1))
y = labels_encoded_spect
print("X shape:", X.shape)
print(f'Total samples: {X.shape[0]}')
print(f'Unique classes: {le_spect.classes_}')

X shape: (24782, 1024, 1)
Total samples: 24782
Unique classes: ['A' 'B' 'CV' 'F' 'G' 'K' 'L' 'M' 'WD']


In [47]:
X_spect_train, y_spect_train, X_spect_val, y_spect_val, X_spect_test, y_spect_test = split_sets(X, y)

Train shape: (14868, 1024, 1), (14868,)
Validation shape: (4957, 1024, 1), (4957,)
Test shape: (4957, 1024, 1), (4957,)


In [48]:
X_spect_train_fusion = X_spect_train.copy()
X_spect_val_fusion = X_spect_val.copy()
X_spect_test_fusion = X_spect_test.copy()

y_spect_train_fusion = y_spect_train.copy()
y_spect_val_fusion = y_spect_val.copy()
y_spect_test_fusion = y_spect_test.copy()

In [68]:
train_dataset_spect = TensorDataset(torch.tensor(X_spect_train, dtype=torch.float32), torch.tensor(y_spect_train, dtype=torch.long))
val_dataset_spect = TensorDataset(torch.tensor(X_spect_val, dtype=torch.float32), torch.tensor(y_spect_val, dtype=torch.long))
test_dataset_spect = TensorDataset(torch.tensor(X_spect_test, dtype=torch.float32), torch.tensor(y_spect_test, dtype=torch.long))

train_loader_spect = DataLoader(train_dataset_time, batch_size=64, shuffle=True)
val_loader_spect = DataLoader(val_dataset_time, batch_size=64, shuffle=False)
test_loader_spect = DataLoader(test_dataset_time, batch_size=64, shuffle=False)

In [65]:
# Dataloader needs (B, 1024, 1) shape
class StarTrack_LSTM(nn.Module):
    def __init__(self, input_size=1, hidden_size=256, output_size=9, num_layers=3):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            bidirectional=True,
            dropout=0.5, # Lower/Increase based off testing results
            batch_first=True
        )
        self.bn = nn.BatchNorm1d(hidden_size*2)
        self.fc1 = nn.Linear(hidden_size*2, hidden_size)
        self.fc2 = nn.Linear(hidden_size, output_size)
        self.dropout = nn.Dropout(0.5)
        
    def forward(self, x):
        x, _ = self.lstm(x)
        x = x[:, -1, :]
        x = self.bn(x)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x
    
    def extract_features(self, x):
        x, _ = self.lstm(x)
        x = x[:, -1, :]
        x = self.bn(x)
        x = F.relu(self.fc1(x))
        return x
        
# Number of classes after grouping = 9

In [66]:
StarTrack_LSTM_V1 = StarTrack_LSTM()
StarTrack_LSTM_V1 = StarTrack_LSTM_V1 .to(device)

In [67]:
class_weights_LSTM = compute_class_weight('balanced', classes=np.unique(y_spect_train), y=y_spect_train)
class_weights_LSTM = torch.tensor(class_weights, dtype=torch.float).to(device)
loss_fn = nn.CrossEntropyLoss(weight=class_weights_LSTM)
optimizer_LSTM = torch.optim.Adam(StarTrack_LSTM_V1.parameters(), lr=1e-4, weight_decay=1e-4)

  class_weights_LSTM = torch.tensor(class_weights, dtype=torch.float).to(device)


In [None]:
train_losses_LSTM, val_losses_LSTM= train_model(model=StarTrack_LSTM_V1, 
                                       criterion=loss_fn, 
                                       optimizer=optimizer_LSTM, 
                                       epochs=15, 
                                       train_loader=train_loader_spect, 
                                       val_loader=val_loader_spect)

# Training The Fusion Branch

In [None]:
class StarTrack_Fusion(nn.Module):
    def __init__(self, lstm_model: nn.Module, dense_model: nn.Module,
                 fused_hidden_dim=512, num_classes=9): 
        super().__init__()
        self.lstm_branch = lstm_model
        self.dense_branch = dense_model

        for param in self.lstm_branch.parameters():
            param.requires_grad = False
        for param in self.dense_branch.parameters():
            param.requires_grad = False
            
        for param in self.lstm_branch.fc1.parameters():
            param.requires_grad = True
        for param in self.dense_branch.fc1.parameters():
            param.requires_grad = True

        lstm_output_dim = 256
        dense_output_dim = 64
        
        self.cross_attn = nn.MultiheadAttention(embed_dim=lstm_output_dim, num_heads=4)
        self.fusion = nn.Sequential(
            nn.Linear(lstm_output_dim + dense_output_dim, fused_hidden_dim),
            nn.GELU(),  
            nn.BatchNorm1d(fused_hidden_dim),
            nn.Linear(fused_hidden_dim, fused_hidden_dim//2),
            nn.GELU(),
            nn.Dropout(0.3), 
            nn.Linear(fused_hidden_dim//2, num_classes)
        )
        
        self.lstm_classifier = nn.Linear(lstm_output_dim, num_classes)
        self.dense_classifier = nn.Linear(dense_output_dim, num_classes)
        self.dense_proj = nn.Linear(dense_output_dim, lstm_output_dim)

    def forward(self, lstm_input, dense_input):
            lstm_feat = self.lstm_branch.extract_features(lstm_input)
            dense_feat = self.dense_branch.extract_features(dense_input)
        
            lstm_q = lstm_feat.unsqueeze(0)  
            dense_kv = dense_proj.unsqueeze(0)  

            dense_proj = self.dense_proj(dense_feat) 
           
            attn_out, _ = self.cross_attn(query=lstm_q, key=dense_kv, value=dense_kv)
            attn_out = attn_out.squeeze(0)
        
            lstm_feat = lstm_feat + attn_out      
            fused = torch.cat([lstm_feat, dense_feat], dim=1) 
            out = self.fusion(fused)
        
            lstm_out = self.lstm_classifier(lstm_feat)
            dense_out = self.dense_classifier(dense_feat)
            return out, lstm_out, dense_out
