# Demo: Exoplanet Habitability Prediction

In this notebook, users can input planetary and stellar parameters and get a habitability score. No sliders—just fields with specified valid ranges. We’ll:

1. Load our scaler, model, and threshold.  
2. Display input fields with min/max hints.  
3. On button click, scale the inputs, run the model, and show the probability & classification.  

In [2]:
# 0) Ensure imports can find src/
import os, sys
project_root = os.path.abspath(os.path.join(os.getcwd(), ".."))
if project_root not in sys.path:
    sys.path.insert(0, project_root)

# 1) Standard imports & load artifacts
import json, joblib, numpy as np, torch
import ipywidgets as widgets
from IPython.display import display, Markdown
from src.model import SimpleHabitabilityNet

ART = os.path.join("..","data","artifacts")
scaler    = joblib.load(f"{ART}/scaler.joblib")
threshold = json.load(open(f"{ART}/config.json"))["threshold"]
model     = SimpleHabitabilityNet(input_dim=8)
model.load_state_dict(torch.load(f"{ART}/model.pth", map_location="cpu"))
model.eval()

SimpleHabitabilityNet(
  (net): Sequential(
    (0): Linear(in_features=8, out_features=32, bias=True)
    (1): ReLU(inplace=True)
    (2): Dropout(p=0.2, inplace=False)
    (3): Linear(in_features=32, out_features=16, bias=True)
    (4): ReLU(inplace=True)
    (5): Dropout(p=0.2, inplace=False)
    (6): Linear(in_features=16, out_features=1, bias=True)
  )
)

In [3]:
# 2) Build the input fields, button, and output area
fields = dict(
    pl_rade    = widgets.BoundedFloatText(1.0, min=0.1, max=100,   step=0.1, description='pl_rade:'),
    pl_bmasse  = widgets.BoundedFloatText(1.0, min=0.1, max=10000, step=0.1, description='pl_bmasse:'),
    pl_orbsmax = widgets.BoundedFloatText(1.0, min=0.001, max=1000,step=0.001, description='pl_orbsmax:'),
    pl_orbeccen= widgets.BoundedFloatText(0.0, min=0.0, max=1.0,   step=0.01, description='pl_orbeccen:'),
    pl_insol   = widgets.BoundedFloatText(1.0, min=0.0, max=1e5,   step=1,   description='pl_insol:'),
    st_teff    = widgets.BoundedIntText   (5778, min=2000, max=50000,step=10,  description='st_teff:'),
    st_rad     = widgets.BoundedFloatText(1.0, min=0.01, max=100,  step=0.01, description='st_rad:'),
    st_mass    = widgets.BoundedFloatText(1.0, min=0.01, max=100,  step=0.01, description='st_mass:'),
)
button = widgets.Button(description="Predict")
output = widgets.Output()

In [4]:
# 3) Define the callback—and bind it *before* we display anything
def on_predict(_):
    vals = np.array([[fld.value for fld in fields.values()]])
    Xs   = scaler.transform(vals)
    with torch.no_grad():
        logits = model(torch.from_numpy(Xs).float())
        prob   = torch.sigmoid(logits).item()
    verdict = "🪐 Habitable" if prob >= threshold else "✖ Non-habitable"
    with output:
        output.clear_output()
        display(Markdown(
            f"**Probability:** {prob:.3f}  \n"
            f"**Prediction (@{threshold}):** {verdict}"
        ))

button.on_click(on_predict)


In [5]:
# 4) Now display the form and the live button
for fld in fields.values():
    display(fld)
display(button, output)


BoundedFloatText(value=1.0, description='pl_rade:', min=0.1, step=0.1)

BoundedFloatText(value=1.0, description='pl_bmasse:', max=10000.0, min=0.1, step=0.1)

BoundedFloatText(value=1.0, description='pl_orbsmax:', max=1000.0, min=0.001, step=0.001)

BoundedFloatText(value=0.0, description='pl_orbeccen:', max=1.0, step=0.01)

BoundedFloatText(value=1.0, description='pl_insol:', max=100000.0, step=1.0)

BoundedIntText(value=5778, description='st_teff:', max=50000, min=2000, step=10)

BoundedFloatText(value=1.0, description='st_rad:', min=0.01, step=0.01)

BoundedFloatText(value=1.0, description='st_mass:', min=0.01, step=0.01)

Button(description='Predict', style=ButtonStyle())

Output()

## How to Use

1. Adjust any of the eight fields within their allowed ranges.  
2. Click **Predict**—the result shows immediately below.  
3. Change values and re-click to rerun the prediction.
