In [1]:
# test_phase_1_updated.ipynb

import torch
import os

# --- Troubleshooting Step ---
# If you are on macOS, run this in your terminal before starting jupyter:
# export PYTORCH_ENABLE_MPS_FALLBACK=1
# This can improve stability of the MPS backend.
os.environ.setdefault('PYTORCH_ENABLE_MPS_FALLBACK', '1')

# --- Import our new library components ---
from neural_mi.utils.helpers import get_device
from neural_mi.models.embeddings import MLP
from neural_mi.models.critics import SeparableCritic
from neural_mi.estimators import bounds

# --- Control Flag for Testing ---
# Set to True to force the test to run on the CPU.
# This helps isolate library issues from MPS-specific hardware issues.
test_on_cpu = False

# --- Determine Device ---
if test_on_cpu:
    device = 'cpu'
else:
    device = get_device()
    
print(f"Using device: {device}")

Using device: mps


In [2]:
# --- 1. Define Hyperparameters ---
BATCH_SIZE = 64
X_DIM = 100
Y_DIM = 10
HIDDEN_DIM = 128
EMBED_DIM = 32
N_LAYERS = 2

In [None]:
# --- 2. Create Dummy Data ---
x_sample = torch.randn(BATCH_SIZE, X_DIM).to(device)
y_sample = torch.randn(BATCH_SIZE, Y_DIM).to(device)
print("Dummy data shapes:")
print(f"X: {x_sample.shape} on {x_sample.device}")
print(f"Y: {y_sample.shape} on {y_sample.device}")
print("-" * 20)

In [4]:
# --- 3. Build the Model (Critic) ---
# Ensure models are moved to the correct device
embedding_x = MLP(input_dim=X_DIM, hidden_dim=HIDDEN_DIM, embed_dim=EMBED_DIM, n_layers=N_LAYERS).to(device)
embedding_y = MLP(input_dim=Y_DIM, hidden_dim=HIDDEN_DIM, embed_dim=EMBED_DIM, n_layers=N_LAYERS).to(device)
critic = SeparableCritic(embedding_net_x=embedding_x, embedding_net_y=embedding_y)
print("Model built successfully.")
print("-" * 20)

Model built successfully.
--------------------


In [5]:
# --- 4. Get Scores Matrix ---
scores = critic(x_sample, y_sample)
print(f"Scores matrix shape: {scores.shape} on {scores.device}")
assert scores.shape == (BATCH_SIZE, BATCH_SIZE)
print("-" * 20)

Scores matrix shape: torch.Size([64, 64]) on cpu
--------------------


In [6]:
# --- 5. Run All MI Estimators ---
# These functions should now correctly infer the device from the 'scores' tensor
print("Testing MI Bounds:")
try:
    mi_infonce = bounds.infonce_lower_bound(scores)
    print(f"  InfoNCE MI = {mi_infonce.item():.4f}")

    mi_nwj = bounds.nwj_lower_bound(scores)
    print(f"  NWJ MI     = {mi_nwj.item():.4f}")
    
    mi_tuba = bounds.tuba_lower_bound(scores)
    print(f"  TUBA MI    = {mi_tuba.item():.4f}")

    mi_js = bounds.js_fgan_lower_bound(scores)
    print(f"  JS-fGAN MI = {mi_js.item():.4f}")

    mi_smile = bounds.smile_lower_bound(scores)
    print(f"  SMILE MI   = {mi_smile.item():.4f}")
    
    print("\n✅ All estimators ran successfully!")
except Exception as e:
    print(f"\n❌ An error occurred: {e}")

Testing MI Bounds:
  InfoNCE MI = -0.0002
  NWJ MI     = -0.3105
  TUBA MI    = -0.0090
  JS-fGAN MI = -1.3902
  SMILE MI   = -0.0038

✅ All estimators ran successfully!
