In [None]:
import os
import math
import random
from typing import List

import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import classification_report, confusion_matrix

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler

In [None]:
def seed_everything(seed: int) -> None:
    """Seed the Python, NumPy and PyTorch random number generators.

    The notebook draws random numbers in several places (weighted sampling of
    training rows, weight initialisation, dropout). Seeding all three
    libraries makes a Restart & Run All produce repeatable numbers on the
    same machine.

    Args:
        seed: Integer seed shared by all three generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)


# Same value the train/val/test split passes as `random_state`.
SEED = 42
seed_everything(SEED)

# Prefer the GPU when one is visible to PyTorch, otherwise run on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)

Using device: cpu


In [None]:
# Location of the exported exoplanet table (Colab upload directory).
DATA_PATH = '/content/exoplanet_habitability.csv'

# Raise explicitly instead of using `assert`: assertions are removed when
# Python runs with -O, and the next cell already reports problems with `raise`.
if not os.path.exists(DATA_PATH):
    raise FileNotFoundError(f"CSV not found at {DATA_PATH}")

raw = pd.read_csv(DATA_PATH)
print('Raw shape:', raw.shape)
# Only the first 30 column names are shown to keep the output short.
print('Detected columns:', raw.columns[:30].tolist(), '...')

Raw shape: (5933, 320)
Detected columns: ['rowid', 'pl_name', 'hostname', 'pl_letter', 'hd_name', 'hip_name', 'tic_id', 'gaia_id', 'sy_snum', 'sy_pnum', 'sy_mnum', 'cb_flag', 'discoverymethod', 'disc_year', 'disc_refname', 'disc_pubdate', 'disc_locale', 'disc_facility', 'disc_telescope', 'disc_instrument', 'rv_flag', 'pul_flag', 'ptv_flag', 'tran_flag', 'ast_flag', 'obm_flag', 'micro_flag', 'etv_flag', 'ima_flag', 'dkin_flag'] ...


In [None]:
# Fail fast: the class label is derived from this column further down.
if not {'habitability'}.issubset(raw.columns):
    raise RuntimeError("Expected 'habitability' column in CSV")

# Raw label distribution; missing values are counted as their own row.
print('\nTarget value counts:')
print(raw.loc[:, 'habitability'].value_counts(dropna=False))


Target value counts:
habitability
0.0    5863
0.6      41
0.9      29
Name: count, dtype: int64


In [None]:
# Preferred encoding of the three known habitability scores as class ids.
label_map = {0.0: 0, 0.6: 1, 0.9: 2}
unique_targets = sorted(raw['habitability'].dropna().unique())
if not set(unique_targets) <= set(label_map):
    # A score outside label_map is present: number the observed values in
    # ascending order instead.
    mapping = dict(zip(unique_targets, range(len(unique_targets))))
    raw['target'] = raw['habitability'].map(mapping)
    print('Auto-mapped targets:', mapping)
else:
    raw['target'] = raw['habitability'].map(label_map)

print('\nMapped target distribution:')
print(raw['target'].value_counts())


Mapped target distribution:
target
0    5863
1      41
2      29
Name: count, dtype: int64


In [None]:
def select_relevant_columns(df: pd.DataFrame,
                            exclude_cols=('habitability', 'rowid')) -> pd.DataFrame:
    """Return a copy of ``df`` restricted to usable numeric feature columns plus ``target``.

    A column is kept when it is numeric, is not named in ``exclude_cols`` and
    does not start with one of the identifier / coordinate prefixes below.
    The ``target`` column is always kept when present.

    Args:
        df: Raw table, optionally already carrying a ``target`` column.
        exclude_cols: Exact column names (case-insensitive) that must never be
            used as features. The default removes ``habitability`` -- the
            numeric score that ``target`` is mapped from, which would
            otherwise leak the label into the feature matrix -- and the
            ``rowid`` row identifier. Pass ``()`` to disable this filter.

    Returns:
        A new DataFrame; the input frame is not modified.
    """
    # keep numeric columns and drop strings like names
    numeric_cols = df.select_dtypes(include=[np.number]).columns.tolist()
    # keep target even if its dtype is not numeric
    if 'target' in df.columns and 'target' not in numeric_cols:
        numeric_cols.append('target')
    # drop columns that are clearly IDs or coordinates if present
    to_drop_prefixes = ('pl_name', 'hostname', 'sy_', 'ra', 'dec')
    excluded = {str(name).lower() for name in exclude_cols}
    filtered = [
        c for c in numeric_cols
        if c == 'target'
        or (c.lower() not in excluded and not c.lower().startswith(to_drop_prefixes))
    ]
    return df[filtered].copy()

In [None]:
def engineer_features(df: pd.DataFrame) -> pd.DataFrame:
    """Add derived physical features to a copy of ``df``.

    Each feature is only added when its source columns exist:

    * ``orbital_velocity`` -- ``2 * pi * pl_orbsmax / (pl_orbper / 365.25)``,
      i.e. orbit circumference per year when the period is given in days
      (assumes ``pl_orbsmax`` is in AU and ``pl_orbper`` in days -- TODO confirm
      against the data dictionary). Rows with a missing input or a zero
      period stay NaN.
    * ``surface_gravity`` -- ``pl_bmasse / pl_rade ** 2``. Rows with a missing
      input or a non-positive radius stay NaN.
    * ``is_super_earth`` -- 1.0 when ``pl_bmasse >= 1``, 0.0 when it is below,
      and NaN when the mass is unknown, so that a missing mass is not
      silently recorded as "not a super-Earth".

    Args:
        df: Table of numeric planet / star columns.

    Returns:
        A new DataFrame with the extra columns; the input is not modified.
    """
    df = df.copy()
    if 'pl_orbsmax' in df.columns and 'pl_orbper' in df.columns:
        mask = df['pl_orbsmax'].notna() & df['pl_orbper'].notna() & (df['pl_orbper'] != 0)
        if mask.any():
            # Dividing the period by 365.25 expresses it in years.
            period_years = df.loc[mask, 'pl_orbper'] / 365.25
            df.loc[mask, 'orbital_velocity'] = 2 * math.pi * df.loc[mask, 'pl_orbsmax'] / period_years
    if 'pl_bmasse' in df.columns and 'pl_rade' in df.columns:
        mask = df['pl_bmasse'].notna() & df['pl_rade'].notna() & (df['pl_rade'] > 0)
        if mask.any():
            df.loc[mask, 'surface_gravity'] = df.loc[mask, 'pl_bmasse'] / (df.loc[mask, 'pl_rade'] ** 2)
    # Example binary feature
    if 'pl_bmasse' in df.columns:
        mass = df['pl_bmasse']
        # `NaN >= 1` evaluates to False, so blank out rows whose mass is unknown.
        df['is_super_earth'] = (mass >= 1).astype(float).where(mass.notna())
    return df


In [None]:
def preprocess_tabular(df: pd.DataFrame, min_nonnull_frac=0.6):
    """Turn the raw table into a dense numeric table ready for modelling.

    Steps, in order: column selection and feature engineering, removal of
    sparsely populated columns, median imputation, removal of constant columns.

    Args:
        df: Raw table (expected to contain ``target``).
        min_nonnull_frac: Minimum share of non-missing values a column needs
            to be kept. ``target`` is kept regardless.

    Returns:
        The processed DataFrame. Prints the constant columns it drops, if any.
    """
    table = engineer_features(select_relevant_columns(df))

    # Share of populated cells per column; max(1, ...) guards an empty frame.
    n_rows = max(1, len(table))
    retained = [
        col for col in table.columns
        if table[col].notna().sum() / n_rows >= min_nonnull_frac or col == 'target'
    ]
    table = table[retained].copy()

    # Fill the remaining gaps of every feature with that feature's median.
    feature_cols = [col for col in table.columns if col != 'target']
    for col in feature_cols:
        column = table[col]
        if column.isna().any():
            table[col] = column.fillna(column.median())

    # A feature with at most one distinct value carries no information.
    distinct_counts = table[feature_cols].nunique()
    low_var_cols = [col for col, k in distinct_counts.items() if k <= 1]
    if low_var_cols:
        print('Dropping low-variance cols:', low_var_cols)
        table = table.drop(columns=low_var_cols)

    return table

In [None]:
# Build the model-ready table; a column needs at least 50% non-missing values to survive.
proc = preprocess_tabular(raw, min_nonnull_frac=0.5)
print('Preprocessed shape:', proc.shape)

# Every column except the label is used as a feature.
# NOTE(review): any numeric column that encodes the label (for example the raw
# 'habitability' score) must be removed upstream, or it will end up in `features`.
features = [col for col in proc.columns if col != 'target']
X = proc[features].to_numpy(dtype=np.float32)
y = proc['target'].astype(int).to_numpy()

Dropping low-variance cols: ['pl_angseplim', 'pl_insollim', 'pl_tranmidlim', 'pl_ratrorlim', 'st_radlim', 'st_masslim', 'st_lumlim', 'st_denslim']
Preprocessed shape: (5933, 113)


In [None]:
# Split first, then scale. The scaler is fit on the training rows only so that
# the mean / variance of the validation and test rows never influence the
# transformation applied during training (no information leaks across splits).
# 70% train, then the remaining 30% is halved into validation and test;
# both splits are stratified so every class appears in every partition.
X_train, X_temp, y_train, y_temp = train_test_split(X, y, test_size=0.3, stratify=y, random_state=42)
X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp, test_size=0.5, stratify=y_temp, random_state=42)

scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
# Validation and test rows reuse the statistics learned from the training rows.
X_val = scaler.transform(X_val)
X_test = scaler.transform(X_test)

print('Train / Val / Test sizes:', len(X_train), len(X_val), len(X_test))

Train / Val / Test sizes: 4153 890 890


In [None]:
class TabularDataset(Dataset):
    """Dataset over a NumPy feature matrix and a NumPy label vector.

    Features are stored as a float32 tensor and labels as an int64 tensor.
    Indexing returns a ``(features, label)`` pair for the requested position.
    """
    def __init__(self, X, y):
        self.X = torch.from_numpy(X).to(torch.float32)
        self.y = torch.from_numpy(y).to(torch.int64)
    def __len__(self):
        # One sample per label.
        return self.y.shape[0]
    def __getitem__(self, idx):
        sample_features, sample_label = self.X[idx], self.y[idx]
        return sample_features, sample_label

# One dataset wrapper per split, in train / val / test order.
train_ds, val_ds, test_ds = (
    TabularDataset(split_X, split_y)
    for split_X, split_y in ((X_train, y_train), (X_val, y_val), (X_test, y_test))
)

In [None]:
# Inverse-frequency sampling: every training row gets weight 1 / (size of its
# class), and rows are drawn with replacement according to those weights.
class_sample_counts = np.bincount(y_train)
print('Class counts (train):', class_sample_counts)
weights = np.true_divide(1.0, class_sample_counts)
sample_weights = np.take(weights, y_train)
sampler = WeightedRandomSampler(
    weights=sample_weights,
    num_samples=len(sample_weights),
    replacement=True,
)

batch_size = 128
# Only the training loader uses the sampler; validation and test keep their natural order.
train_loader = DataLoader(train_ds, batch_size=batch_size, sampler=sampler)
val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False)

Class counts (train): [4104   29   20]


In [None]:
class SEBlock(nn.Module):
    """Channel gating block for ``(batch, channels)`` inputs.

    A two-layer bottleneck (``channels -> channels // reduction -> channels``,
    with the bottleneck at least one unit wide) produces a sigmoid gate per
    channel, and the input is multiplied element-wise by that gate.
    """
    def __init__(self, channels, reduction=4):
        super().__init__()
        bottleneck = max(1, channels // reduction)
        self.fc1 = nn.Linear(channels, bottleneck)
        self.fc2 = nn.Linear(bottleneck, channels)
    def forward(self, x):
        # x: (batch, channels); the gate has the same shape with values in (0, 1)
        gate = torch.sigmoid(self.fc2(F.relu(self.fc1(x))))
        return x * gate

In [None]:
class DenseBranch(nn.Module):
    """Fully connected feature extractor followed by SE channel gating.

    For every width in ``layer_sizes`` the branch stacks
    Linear -> BatchNorm1d -> ReLU -> Dropout; the output of the last stage is
    passed through an ``SEBlock`` of the same width.

    Args:
        input_dim: Number of input features.
        layer_sizes: Output width of each stage, in order.
        dropout: Dropout probability used in every stage.
    """
    def __init__(self, input_dim, layer_sizes: List[int], dropout=0.2):
        super().__init__()
        widths = [input_dim, *layer_sizes]
        stages = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stages += [
                nn.Linear(fan_in, fan_out),
                nn.BatchNorm1d(fan_out),
                nn.ReLU(),
                nn.Dropout(dropout),
            ]
        self.net = nn.Sequential(*stages)
        # The gate operates on the width of the final stage.
        self.se = SEBlock(widths[-1], reduction=4)
    def forward(self, x):
        return self.se(self.net(x))

In [None]:
class TransformerBranch(nn.Module):
    """Self-attention encoder that treats every scalar feature as one token.

    Each feature value is projected to an ``embed_dim`` vector, a learned
    per-feature embedding is added, and the token sequence goes through
    ``n_layers`` blocks of multi-head self-attention and a feed-forward
    network, each followed by a residual connection and LayerNorm. The
    tokens are mean-pooled into one vector per sample.

    Args:
        num_features: Number of input features (sequence length).
        embed_dim: Token embedding width; must be divisible by ``n_heads``.
        n_heads: Number of attention heads.
        n_layers: Number of attention + feed-forward blocks.
        dropout: Dropout applied inside the attention modules.
    """
    def __init__(self, num_features, embed_dim=64, n_heads=4, n_layers=2, dropout=0.1):
        super().__init__()
        self.num_features = num_features
        self.embed_dim = embed_dim
        # project each scalar feature -> embed_dim vector
        self.feature_proj = nn.Linear(1, embed_dim)
        # learned embedding identifying which feature a token came from
        self.pos_embed = nn.Parameter(torch.randn(num_features, embed_dim) * 0.01)
        self.layers = nn.ModuleList()
        for _ in range(n_layers):
            self.layers.append(nn.ModuleDict({
                'mha': nn.MultiheadAttention(embed_dim, num_heads=n_heads, dropout=dropout, batch_first=True),
                'ffn': nn.Sequential(nn.Linear(embed_dim, embed_dim*4), nn.ReLU(), nn.Linear(embed_dim*4, embed_dim))
            }))
        # two LayerNorms per block: index 2*i after attention, 2*i+1 after the FFN
        self.layernorms = nn.ModuleList([nn.LayerNorm(embed_dim) for _ in range(n_layers*2)])
    def forward(self, x):
        """Encode ``x`` of shape (batch, num_features) into (batch, embed_dim).

        Raises:
            ValueError: If the last dimension of ``x`` does not equal
                ``num_features``.
        """
        # Without this check an input with a single feature column would
        # broadcast against pos_embed and silently produce num_features tokens.
        if x.shape[-1] != self.num_features:
            raise ValueError(
                f"expected {self.num_features} features in the last dimension, got {x.shape[-1]}"
            )
        # to (batch, seq_len, 1), then (batch, seq_len, embed_dim)
        h = self.feature_proj(x.unsqueeze(-1))
        # add the per-feature embedding, broadcast over the batch
        h = h + self.pos_embed.unsqueeze(0)
        for i, layer in enumerate(self.layers):
            # MultiheadAttention with batch_first=True expects (batch, seq, embed)
            attn_out, _ = layer['mha'](h, h, h)
            h = self.layernorms[2*i](h + attn_out)
            ffn_out = layer['ffn'](h)
            h = self.layernorms[2*i+1](h + ffn_out)
        # aggregate across sequence (mean pooling) -> (batch, embed_dim)
        return h.mean(dim=1)


In [None]:
class MSMAClassifier(nn.Module):
    """Three-branch classifier: shallow MLP, deep MLP and a feature-token transformer.

    All branches receive the same ``(batch, input_dim)`` input. Their outputs
    (128, 128 and ``embed_dim`` wide) are concatenated and mapped to
    ``n_classes`` logits by a small MLP head.
    """
    def __init__(self, input_dim, n_classes=3, embed_dim=64):
        super().__init__()
        # three branches: shallow, deep, transformer
        self.branch_shallow = DenseBranch(input_dim, [128], dropout=0.2)
        self.branch_deep = DenseBranch(input_dim, [256, 128], dropout=0.3)
        self.branch_transformer = TransformerBranch(num_features=input_dim, embed_dim=embed_dim, n_heads=4, n_layers=2)
        # width of the concatenated branch outputs
        fused_width = 128 + 128 + embed_dim
        self.head = nn.Sequential(
            nn.Linear(fused_width, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, n_classes)
        )
    def forward(self, x):
        # x: (batch, input_dim)
        branch_outputs = []
        for branch in (self.branch_shallow, self.branch_deep, self.branch_transformer):
            z = branch(x)
            # flatten any extra trailing dims so every output is (batch, width)
            branch_outputs.append(z.view(z.size(0), -1) if z.dim() > 2 else z)
        fused = torch.cat(branch_outputs, dim=1)
        return self.head(fused)


In [None]:

def compute_class_weights(y):
    """Return balanced class weights ``n_samples / (n_classes * class_count)``.

    ``n_classes`` is the length of ``np.bincount(y)`` (largest label + 1).
    Classes with no samples are treated as having one sample so the division
    stays finite.

    Args:
        y: 1-D array of non-negative integer labels.

    Returns:
        A float32 tensor with one weight per class id.
    """
    per_class = np.bincount(y)
    n_samples, n_slots = len(y), len(per_class)
    safe_counts = np.maximum(per_class, 1)
    balanced = n_samples / (n_slots * safe_counts)
    return torch.tensor(balanced, dtype=torch.float32)

In [None]:
def train_one_epoch(model, loader, opt, criterion):
    """Run one optimisation pass over ``loader``.

    Puts the model in training mode and performs one optimizer step per
    batch. Batches are moved to the notebook-level ``device``.

    Returns:
        ``(mean_loss, accuracy)`` over all samples seen, where the mean loss
        weights every batch by its size.
    """
    model.train()
    loss_sum, n_correct, n_seen = 0.0, 0, 0
    for xb, yb in loader:
        xb, yb = xb.to(device), yb.to(device)
        batch_n = xb.size(0)

        opt.zero_grad()
        logits = model(xb)
        loss = criterion(logits, yb)
        loss.backward()
        opt.step()

        # criterion returns a batch mean, so scale it back to a batch sum
        loss_sum += loss.item() * batch_n
        n_correct += (logits.argmax(dim=1) == yb).sum().item()
        n_seen += batch_n
    return loss_sum / n_seen, n_correct / n_seen

In [None]:
def evaluate(model, loader, criterion):
    """Score ``model`` on ``loader`` without updating it.

    Puts the model in eval mode, disables gradient tracking and moves every
    batch to the notebook-level ``device``.

    Returns:
        ``(mean_loss, accuracy, y_true, y_pred)``; the last two are NumPy
        arrays of labels and argmax predictions concatenated over all batches.
    """
    model.eval()
    loss_sum, n_correct, n_seen = 0.0, 0, 0
    true_chunks, pred_chunks = [], []
    with torch.no_grad():
        for xb, yb in loader:
            xb, yb = xb.to(device), yb.to(device)
            batch_n = xb.size(0)
            logits = model(xb)
            batch_preds = logits.argmax(dim=1)

            # criterion returns a batch mean, so scale it back to a batch sum
            loss_sum += criterion(logits, yb).item() * batch_n
            n_correct += (batch_preds == yb).sum().item()
            n_seen += batch_n

            true_chunks.append(yb.cpu().numpy())
            pred_chunks.append(batch_preds.cpu().numpy())
    return (
        loss_sum / n_seen,
        n_correct / n_seen,
        np.concatenate(true_chunks),
        np.concatenate(pred_chunks),
    )


In [None]:
# Input width follows the feature matrix; the number of outputs follows the
# distinct labels present in the full label vector.
input_dim = X_train.shape[1]
n_classes = np.unique(y).size
model = MSMAClassifier(input_dim=input_dim, n_classes=n_classes, embed_dim=64)
model = model.to(device)

# Class-balanced cross-entropy, with weights derived from the training labels.
# NOTE(review): train_loader already rebalances batches with a
# WeightedRandomSampler, so the two corrections compound -- confirm this is intended.
class_weights = compute_class_weights(y_train).to(device)
criterion = nn.CrossEntropyLoss(weight=class_weights)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)

In [None]:
from tqdm import tqdm

# Training schedule: upper bound on epochs, and how many consecutive epochs
# without a better validation accuracy are tolerated before stopping.
n_epochs = 40
patience = 8

# Where the best-scoring weights are written.
save_path = 'msma_best.pt'

# Early-stopping bookkeeping, updated by the training loop.
best_val_acc = 0.0
patience_counter = 0

In [None]:
# Train with early stopping: checkpoint on every strict improvement of the
# validation accuracy and stop after `patience` epochs in a row without one.
for epoch in range(1, n_epochs+1):
    train_loss, train_acc = train_one_epoch(model, train_loader, opt, criterion)
    val_loss, val_acc, _, _ = evaluate(model, val_loader, criterion)
    print(f"Epoch {epoch:02d}: train_loss={train_loss:.4f}, train_acc={train_acc:.4f}  |  val_loss={val_loss:.4f}, val_acc={val_acc:.4f}")

    if not val_acc > best_val_acc:
        # No improvement this epoch.
        patience_counter += 1
        if patience_counter >= patience:
            print('Early stopping triggered')
            break
        continue

    # New best: remember the score, persist the weights, reset the counter.
    best_val_acc = val_acc
    torch.save(model.state_dict(), save_path)
    patience_counter = 0
    print('  -> saved new best model')

Epoch 01: train_loss=0.2553, train_acc=0.6078  |  val_loss=2.4282, val_acc=0.0112
  -> saved new best model
Epoch 02: train_loss=0.0114, train_acc=0.6805  |  val_loss=1.0181, val_acc=0.0112
Epoch 03: train_loss=0.0032, train_acc=0.8904  |  val_loss=0.4007, val_acc=0.8876
  -> saved new best model
Epoch 04: train_loss=0.0011, train_acc=0.9728  |  val_loss=0.3064, val_acc=0.9382
  -> saved new best model
Epoch 05: train_loss=0.0006, train_acc=0.9865  |  val_loss=0.2533, val_acc=0.9652
  -> saved new best model
Epoch 06: train_loss=0.0004, train_acc=0.9916  |  val_loss=0.2189, val_acc=0.9910
  -> saved new best model
Epoch 07: train_loss=0.0003, train_acc=0.9949  |  val_loss=0.2440, val_acc=0.9933
  -> saved new best model
Epoch 08: train_loss=0.0001, train_acc=0.9988  |  val_loss=0.2402, val_acc=0.9933
Epoch 09: train_loss=0.0001, train_acc=0.9974  |  val_loss=0.2208, val_acc=0.9978
  -> saved new best model
Epoch 10: train_loss=0.0001, train_acc=0.9986  |  val_loss=0.2324, val_acc=0.996

In [None]:
# Restore the best checkpoint before scoring the held-out test split.
# map_location remaps the stored tensors onto the current device, so a
# checkpoint written on a GPU runtime still loads on a CPU-only one.
model.load_state_dict(torch.load(save_path, map_location=device))
_test_loss, test_acc, y_true, y_pred = evaluate(model, test_loader, criterion)
print('\nFinal test accuracy:', test_acc)
# Per-class precision / recall are shown alongside the overall accuracy.
print('\nClassification report:')
print(classification_report(y_true, y_pred))
# Rows are true classes, columns are predicted classes.
print('\nConfusion matrix:')
print(confusion_matrix(y_true, y_pred))


Final test accuracy: 0.996629213483146

Classification report:
              precision    recall  f1-score   support

           0       1.00      1.00      1.00       880
           1       0.67      1.00      0.80         6
           2       1.00      0.75      0.86         4

    accuracy                           1.00       890
   macro avg       0.89      0.92      0.89       890
weighted avg       1.00      1.00      1.00       890


Confusion matrix:
[[878   2   0]
 [  0   6   0]
 [  0   1   3]]
