<a href="https://colab.research.google.com/github/debashisdotchatterjee/SDSS-Sloan-Bayesian-Ml-2025/blob/main/SDSS_Sloan_Bayesian_%2B_Ml_2025.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
!pip install gpytorch

Collecting gpytorch
  Downloading gpytorch-1.14-py3-none-any.whl.metadata (8.0 kB)
Collecting jaxtyping (from gpytorch)
  Downloading jaxtyping-0.3.1-py3-none-any.whl.metadata (7.0 kB)
Collecting linear-operator>=0.6 (from gpytorch)
  Downloading linear_operator-0.6-py3-none-any.whl.metadata (15 kB)
Collecting wadler-lindig>=0.1.3 (from jaxtyping->gpytorch)
  Downloading wadler_lindig-0.1.5-py3-none-any.whl.metadata (17 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=2.0->linear-operator>=0.6->gpytorch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch>=2.0->linear-operator>=0.6->gpytorch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch>=2.0->linear-operator>=0.6->gpytorch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (

In [3]:
# ================================================================
#  Sloan Digital Sky Survey  –  Bayesian Deep‑Kernel Pipeline
#  (author: ChatGPT, 18‑Apr‑2025)
# ================================================================
import os, time, random
from pathlib import Path

import numpy  as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics      import confusion_matrix, classification_report

import torch, gpytorch
from torch import nn, optim
from torch.utils.data import DataLoader, TensorDataset

# ----------------------------------------------------------------
# CONFIGURATION – raise values for a full production run
LATENT_DIM   = 16      # size of embedding
NUM_INDUCING = 200     # sparse GP M (≥300 for highest fidelity)
N_EPOCHS     = 25      # 150–200 for production
BATCH_SIZE   = 512
LR           = 3e-3
SUBSAMPLE    = None    # e.g. 4000 for a <3‑min sanity run
SEED         = 0
# ----------------------------------------------------------------

torch.manual_seed(SEED);  np.random.seed(SEED);  random.seed(SEED)

CSV_PATH = Path('Skyserver_SQL2_27_2018 6_51_39 PM.csv')
OUT_DIR  = Path('sdss_bayes_output');  OUT_DIR.mkdir(exist_ok=True)

# ================================================================
# 1.  DATA  &  BASIC FEATURES
# ================================================================
df = pd.read_csv(CSV_PATH)

# colours
df['g_r'] = df['g'] - df['r']
df['r_i'] = df['r'] - df['i']
df['i_z'] = df['i'] - df['z']

# keep the three canonical label types
class_map = {'STAR': 0, 'GALAXY': 1, 'QSO': 2}
df = df[df['class'].isin(class_map)].copy()
df['y']  = df['class'].map(class_map)

if SUBSAMPLE:
    df = df.sample(SUBSAMPLE, random_state=SEED).reset_index(drop=True)

feat_cols   = ['u','g','r','i','z','g_r','r_i','i_z']
scaler      = StandardScaler().fit(df[feat_cols])
X_std       = scaler.transform(df[feat_cols])
y           = df['y'].values.astype(np.int64)
z_spec_full = df['redshift'].fillna(0.).values.astype(np.float32)

# stratified split
X_tr, X_te, y_tr, y_te, z_tr, z_te = train_test_split(
    X_std, y, z_spec_full, test_size=0.2, random_state=SEED, stratify=y
)

# torch datasets
train_dl = DataLoader(
    TensorDataset(torch.tensor(X_tr, dtype=torch.float32),
                  torch.tensor(y_tr, dtype=torch.long),
                  torch.tensor(z_tr, dtype=torch.float32)),
    batch_size=BATCH_SIZE, shuffle=True)

# ================================================================
# 2.  MODEL PARTS
# ================================================================
class EmbedNet(nn.Module):
    def __init__(self, d=LATENT_DIM):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(8, 64), nn.SiLU(),
            nn.Linear(64, 32), nn.SiLU(),
            nn.Linear(32, d))
    def forward(self, x): return self.net(x)

class SparseGP(gpytorch.models.ApproximateGP):
    """Single‑task sparse variational GP for photometric z."""
    def __init__(self, d=LATENT_DIM, m=NUM_INDUCING):
        inducing = torch.randn(m, d)
        q_dist   = gpytorch.variational.CholeskyVariationalDistribution(m)
        var_strat= gpytorch.variational.VariationalStrategy(
            self, inducing, q_dist, learn_inducing_locations=True)
        super().__init__(var_strat)
        self.mean_module  = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(
            gpytorch.kernels.RBFKernel())

    def forward(self, x):
        return gpytorch.distributions.MultivariateNormal(
            self.mean_module(x), self.covar_module(x))

class SDSSBayes(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = EmbedNet()
        self.gate  = nn.Linear(LATENT_DIM, 3)     # soft‑max head
        self.gp    = SparseGP()                   # single‑task z‑GP
    def forward(self, x):
        h = self.embed(x)
        return self.gate(h), self.gp(h), h

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model  = SDSSBayes().to(device)
lik_gp = gpytorch.likelihoods.GaussianLikelihood().to(device)

mll_gp = gpytorch.mlls.VariationalELBO(lik_gp, model.gp,
                                       num_data=(y_tr!=0).sum()).to(device)
optimiser = optim.Adam(model.parameters(), lr=LR)

# ================================================================
# 3.  TRAIN
# ================================================================
model.train();  lik_gp.train()
loss_hist = []

t0 = time.time()
for epoch in range(1, N_EPOCHS+1):
    tot = 0.0
    for xb, yb, zb in train_dl:
        xb, yb, zb = xb.to(device), yb.to(device), zb.to(device)

        logits, gp_all, _ = model(xb)
        loss_cls = nn.functional.cross_entropy(logits, yb)

        mask = (yb != 0)             # extragalactic subset
        if mask.sum():
            gp = model.gp(model.embed(xb[mask]))
            targ = zb[mask]
            loss_gp = -mll_gp(gp, targ)
        else:
            loss_gp = torch.tensor(0., device=device)

        loss = loss_cls + loss_gp
        optimiser.zero_grad();  loss.backward();  optimiser.step()
        tot += loss.item()

    loss_hist.append(tot)
    if epoch % 5 == 0 or epoch == N_EPOCHS:
        print(f'epoch {epoch:3d}/{N_EPOCHS}  total={tot:.3f}  cls={loss_cls.item():.3f}  gp={loss_gp.item():.3f}')

print(f'✓ training finished in {time.time()-t0:.1f}s')

# ================================================================
# 4.  EVALUATION & PLOTS
# ================================================================
model.eval();  lik_gp.eval()

# ---------- classification ------------
X_te_t = torch.tensor(X_te, dtype=torch.float32, device=device)
with torch.no_grad():
    logits_te = model.gate(model.embed(X_te_t))
    probs_te  = torch.softmax(logits_te, dim=1).cpu().numpy()
y_pred = probs_te.argmax(1)

cm   = confusion_matrix(y_te, y_pred, labels=[0,1,2])
rep  = classification_report(y_te, y_pred,
        target_names=['STAR','GALAXY','QSO'],
        output_dict=True, zero_division=0)
rep_df = pd.DataFrame(rep).T.round(3)
rep_df.to_csv(OUT_DIR/'classification_report.csv')

# ---------- figures -------------------
# 1. colour–colour
fig,ax = plt.subplots()
ax.scatter(df['g_r'], df['r_i'], s=6, alpha=.5)
ax.set_xlabel('g - r'); ax.set_ylabel('r - i')
ax.set_title('Colour–Colour Diagram')
fig.tight_layout(); fig.savefig(OUT_DIR/'colour_colour.png', dpi=300); plt.close(fig)

# 2. red‑shift density
fig,ax = plt.subplots()
for cls in ['GALAXY','QSO']:
    sub = df[df['class']==cls]['redshift']
    ax.hist(sub,bins=40,density=True,histtype='step',label=cls)
ax.set_xlabel('redshift'); ax.set_ylabel('density'); ax.legend()
fig.tight_layout(); fig.savefig(OUT_DIR/'redshift_hist.png', dpi=300); plt.close(fig)

# 3. training loss
fig,ax = plt.subplots()
ax.plot(loss_hist); ax.set_xlabel('epoch'); ax.set_ylabel('loss')
ax.set_title('Training Loss')
fig.tight_layout(); fig.savefig(OUT_DIR/'loss_curve.png', dpi=300); plt.close(fig)

# 4. confusion matrix
fig,ax=plt.subplots()
im=ax.imshow(cm, cmap='Blues'); ax.set_xticks(range(3)); ax.set_yticks(range(3))
ax.set_xticklabels(['STAR','GAL','QSO']); ax.set_yticklabels(['STAR','GAL','QSO'])
ax.set_xlabel('predicted'); ax.set_ylabel('true')
for i in range(3):
    for j in range(3):
        ax.text(j,i,cm[i,j],ha='center',va='center',
                color='white' if cm[i,j]>cm.max()*0.6 else 'black')
fig.tight_layout(); fig.savefig(OUT_DIR/'confusion_matrix.png', dpi=300); plt.close(fig)

# 5. z‑scatter (only extragalactic in test set)
mask_ex = y_te!=0
with torch.no_grad():
    mu = model.gp(model.embed(torch.tensor(X_te[mask_ex], dtype=torch.float32,
                                           device=device))).mean.cpu().numpy()
fig,ax = plt.subplots()
ax.scatter(z_te[mask_ex], mu, s=10, alpha=.6)
ax.plot([0,z_te.max()],[0,z_te.max()], ls='--', lw=1)
ax.set_xlabel('z (spec)'); ax.set_ylabel('ẑ (phot)')
ax.set_title('Photometric vs. Spectroscopic z')
fig.tight_layout(); fig.savefig(OUT_DIR/'zscatter.png', dpi=300); plt.close(fig)

# ================================================================
# 5.  CONSOLE PREVIEW
# ================================================================
print("\n==> files saved to", OUT_DIR)
for f in sorted(os.listdir(OUT_DIR)): print("  ", f)

print("\n--- SDSS head (10) ---")
print(df.head(10).to_string(index=False))

print("\n--- Classification report ---")
print(rep_df.to_string())


epoch   5/25  total=26.332  cls=0.349  gp=1.357
epoch  10/25  total=18.016  cls=0.201  gp=0.871
epoch  15/25  total=15.378  cls=0.176  gp=0.885
epoch  20/25  total=14.984  cls=0.212  gp=0.923
epoch  25/25  total=14.698  cls=0.132  gp=0.773
✓ training finished in 19.6s

==> files saved to sdss_bayes_output
   classification_report.csv
   colour_colour.png
   confusion_matrix.png
   loss_curve.png
   redshift_hist.png
   zscatter.png

--- SDSS head (10) ---
       objid         ra      dec        u        g        r        i        z  run  rerun  camcol  field    specobjid  class  redshift  plate   mjd  fiberid      g_r      r_i      i_z  y
1.237650e+18 183.531326 0.089693 19.47406 17.04240 15.94699 15.50342 15.22531  752    301       4    267 3.722360e+18   STAR -0.000009   3306 54922      491  1.09541  0.44357  0.27811  0
1.237650e+18 183.598370 0.135285 18.66280 17.21449 16.67637 16.48922 16.39150  752    301       4    267 3.638140e+17   STAR -0.000055    323 51615      541  0.53812 