# Building The StarTrack Model - 2

In [14]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from torch.utils.data import DataLoader
from sklearn.manifold import TSNE
from sklearn.preprocessing import LabelEncoder

In [3]:
device = "cuda" if torch.cuda.is_available() else "cpu"
device

'cuda'

In [20]:
tabular_data = pd.read_csv('../data/processed_data.csv')
spectrogram_data = np.load('../data/startrack_spectrograms.npz', allow_pickle=True)

In [21]:
columns_to_drop = ["Unnamed: 0", "CLASS", "SUBCLASS", "SUBCLASS_CLEAN", "SPECTRAL_GROUP"]
labels_tabular = tabular_data["SPECTRAL_GROUP"]
tabular_data = tabular_data.drop(columns=columns_to_drop)
pd.set_option('display.max_columns', None)
tabular_data.head()

Unnamed: 0,PLATE,MJD,FIBERID,ELODIE_SPTYPE,ELODIE_TEFF,ELODIE_LOGG,ELODIE_FEH,Z,Z_ERR,ZWARNING,VDISP,VDISP_ERR,SN_MEDIAN_ALL,RCHI2,DOF,SPECTROFLUX,url,SNR_Bin,High_Quality,Mean_Flux,Flux_to_Noise,u_flux,g_flux,r_flux,i_flux,z_flux
0,6051,56093,932,M1V,3705.0,4.8,0.6,-0.000485,1.4e-05,0,0.0,0.0,19.021759,0.880377,4542,"[19.371477127075195, 31.882356643676758, 42.20...",https://data.sdss.org/sas/dr17/eboss/spectro/r...,Moderate,False,37.168239,1.953985,0.269059,0.268631,0.195929,0.125402,0.047316
1,4287,55483,210,M1V,3705.0,4.8,0.6,-0.000195,1.5e-05,0,0.0,0.0,22.828102,1.034391,4487,"[20.819482803344727, 33.85621643066406, 44.854...",https://data.sdss.org/sas/dr17/eboss/spectro/r...,Good,True,40.272045,1.764143,0.313935,0.310882,0.236147,0.178934,0.105151
2,4569,55631,889,M1V,3705.0,4.8,0.6,0.000728,1.1e-05,0,0.0,0.0,28.743233,0.948585,4505,"[31.04962921142578, 56.28861999511719, 76.5466...",https://data.sdss.org/sas/dr17/eboss/spectro/r...,Good,True,67.842227,2.360285,0.630979,0.791056,0.716503,0.637338,0.519138
3,5770,56014,694,M1V,3705.0,4.8,0.6,0.000277,1.6e-05,0,0.0,0.0,23.706001,0.875567,4568,"[19.188594818115234, 33.665863037109375, 45.82...",https://data.sdss.org/sas/dr17/eboss/spectro/r...,Good,True,40.974394,1.72844,0.263392,0.306808,0.250842,0.196867,0.137755
4,3784,55269,200,O9.5Ib...,9899.0,2.924,0.09,0.000616,8.1e-05,0,0.0,0.0,2.716969,0.984394,4395,"[1.5342345237731934, 2.4548885822296143, 3.740...",https://data.sdss.org/sas/dr17/eboss/spectro/r...,Very Low,False,2.968928,1.092735,-0.283738,-0.361274,-0.387023,-0.406692,-0.432778


In [37]:
tabular_data.shape

(26743, 26)

In [22]:
labels.head()

0    F
1    F
2    F
3    F
4    F
Name: SPECTRAL_GROUP, dtype: object

In [29]:
spectrograms = spectrogram_data["spectrograms"] 
labels_spect = spectrogram_data["labels"] 
spectrograms.shape

(24889, 1024)

# Training The LSTM (Tabular Data) branch

In [33]:
le_tabular = LabelEncoder()
labels_encoded_tabular = le_tabular.fit_transform(labels)
X = spectrograms
X = X.reshape((X.shape[0], X.shape[1], 1))
y = labels
print("X shape:", X.shape)
print(f'Total samples: {X.shape[0]}')
print(f'Unique classes: {le_tabular.classes_}')

X shape: (24889, 1024, 1)
Total samples: 24889
Unique classes: ['A' 'B' 'CV' 'F' 'G' 'K' 'L' 'M' 'WD']


In [34]:
# Method taken from my EchoScope project, check it out here: https://github.com/blueskinlizard/EchoScope/tree/main
def split_sets(X, y):
    X_train_val, X_test, y_train_val, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y)

    X_train, X_val, y_train, y_val = train_test_split(
        X_train_val, y_train_val, test_size=0.25, random_state=42, stratify=y_train_val)
    
    print(f'Train shape: {X_train.shape}, {y_train.shape}')
    print(f'Validation shape: {X_val.shape}, {y_val.shape}')
    print(f'Test shape: {X_test.shape}, {y_test.shape}')
    return X_train, y_train, X_val, y_val, X_test, y_test

In [35]:
X_train, y_train, X_val, y_val, X_test, y_test = split_sets(X, y)

ValueError: Found input variables with inconsistent numbers of samples: [24889, 26743]

In [8]:
# Dataloader needs (B, 1024, 1) shape
class StarTrack_LSTM(nn.Module):
    def __init__(self, input_size=1, hidden_size=256, output_size=9, num_layers=3):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            bidirectional=True,
            dropout=0.5, # Lower/Increase based off testing results
            batch_first=True
        )
        self.bn = nn.BatchNorm1d(hidden_size*2)
        self.fc1 = nn.Linear(hidden_size*2, hidden_size)
        self.fc2 = nn.Linear(hidden_size, output_size)
        self.dropout = nn.Dropout(0.5)
        
    def forward(self, x):
        x, _ = self.lstm(x)
        x = x[:, -1, :]
        x = self.bn(x)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x
    
    def extract_features(self, x):
        x, _ = self.lstm(x)
        x = x[:, -1, :]
        x = self.bn(x)
        x = F.relu(self.fc1(x))
        return x
        
# Number of classes after grouping = 9

In [11]:
class StarTrack_Dense(nn.Module):
    def __init__(self, input_size=22, hidden_sizes=[128, 64], output_size=9):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_sizes[0])
        self.bn1 = nn.BatchNorm1d(hidden_sizes[0])
        self.dropout1 = nn.Dropout(0.3)

        self.fc2 = nn.Linear(hidden_sizes[0], hidden_sizes[1])
        self.bn2 = nn.BatchNorm1d(hidden_sizes[1])
        self.dropout2 = nn.Dropout(0.3)

        self.out = nn.Linear(hidden_sizes[1], output_size)

    def forward(self, x):
        x = F.relu(self.bn1(self.fc1(x)))
        x = self.dropout1(x)

        x = F.relu(self.bn2(self.fc2(x)))
        x = self.dropout2(x)

        x = self.out(x)
        return x
    def extract_features(self, x):
        x = F.relu(self.bn1(self.fc1(x)))
        x = self.dropout1(x)
        x = F.relu(self.bn2(self.fc2(x)))
        return x

In [None]:
class StarTrack_Fusion(nn.Module):
    def __init__(self, lstm_model: nn.Module, dense_model: nn.Module,
                 fused_hidden_dim=512, num_classes=9): 
        super().__init__()
        self.lstm_branch = lstm_model
        self.dense_branch = dense_model

        for param in self.lstm_branch.parameters():
            param.requires_grad = False
        for param in self.dense_branch.parameters():
            param.requires_grad = False
            
        for param in self.lstm_branch.fc1.parameters():
            param.requires_grad = True
        for param in self.dense_branch.fc1.parameters():
            param.requires_grad = True

        lstm_output_dim = 256
        dense_output_dim = 64
        
        self.cross_attn = nn.MultiheadAttention(embed_dim=lstm_output_dim, num_heads=4)
        self.fusion = nn.Sequential(
            nn.Linear(lstm_output_dim + dense_output_dim, fused_hidden_dim),
            nn.GELU(),  
            nn.BatchNorm1d(fused_hidden_dim),
            nn.Linear(fused_hidden_dim, fused_hidden_dim//2),
            nn.GELU(),
            nn.Dropout(0.3), 
            nn.Linear(fused_hidden_dim//2, num_classes)
        )
        
        self.lstm_classifier = nn.Linear(lstm_output_dim, num_classes)
        self.dense_classifier = nn.Linear(dense_output_dim, num_classes)
        self.dense_proj = nn.Linear(dense_output_dim, lstm_output_dim)

    def forward(self, lstm_input, dense_input):
            lstm_feat = self.lstm_branch.extract_features(lstm_input)
            dense_feat = self.dense_branch.extract_features(dense_input)
        
            lstm_q = lstm_feat.unsqueeze(0)  
            dense_kv = dense_proj.unsqueeze(0)  

            dense_proj = self.dense_proj(dense_feat) 
           
            attn_out, _ = self.cross_attn(query=lstm_q, key=dense_kv, value=dense_kv)
            attn_out = attn_out.squeeze(0)
        
            lstm_feat = lstm_feat + attn_out      
            fused = torch.cat([lstm_feat, dense_feat], dim=1) 
            out = self.fusion(fused)
        
            lstm_out = self.lstm_classifier(lstm_feat)
            dense_out = self.dense_classifier(dense_feat)
            return out, lstm_out, dense_out
