# In [None]:
import importlib
import inspect
import math
import sys
from pathlib import Path
import numpy as np
from sklearn.datasets import make_moons
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import torch
from torch.utils.data import DataLoader, TensorDataset

# Cell 0
# Experiment: build MLP from repo (if available), otherwise fallback to local PyTorch MLP.
# Train on a toy dataset with different activation functions and plot loss/accuracy + decision boundaries.


import matplotlib.pyplot as plt

import torch.nn as nn
import torch.optim as optim

# Configuration
# Reproducibility: one seed drives both the NumPy and the PyTorch generators.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Use the GPU when PyTorch reports one as available, otherwise the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Toy-dataset and optimisation hyper-parameters.
DATASET_N_SAMPLES = 1000
TEST_SIZE = 0.2
BATCH_SIZE = 64
EPOCHS = 200
HIDDEN_SIZES = [64, 32]
LR = 1e-3

# (module name, class name) pairs probed, in this order, for a repository MLP.
CANDIDATES = [
    (module_name, "MLP")
    for module_name in (
        "mlp",
        "model",
        "models.mlp",
        "models",
        "network",
        "net",
        "mlp_model",
    )
]

def _find_activation_kwarg(cls):
    """Return the constructor keyword of ``cls`` that receives the activation.

    The names "activation", "act_fn", "nonlinearity" and "activation_fn" are
    looked up, in that order, in the constructor signature. Returns None when
    none of them is present or the signature cannot be introspected.
    """
    try:
        params = inspect.signature(cls).parameters
    except (TypeError, ValueError):
        # Raised by inspect.signature for unsupported or signature-less callables.
        return None
    for name in ("activation", "act_fn", "nonlinearity", "activation_fn"):
        if name in params:
            return name
    return None


def _instantiate_repo_mlp(cls, act_kw, input_dim, hidden_sizes, output_dim, activation):
    """Instantiate ``cls`` and report whether ``activation`` was handed to it.

    Returns a ``(model, activation_forwarded)`` tuple. An error raised by the
    last constructor pattern tried propagates to the caller.
    """
    if act_kw is not None:
        return cls(input_dim, hidden_sizes, output_dim, **{act_kw: activation}), True
    try:
        # Plain (input, hidden, output) constructor: the activation is NOT passed on.
        return cls(input_dim, hidden_sizes, output_dim), False
    except Exception:
        # Last resort: constructors that only take the hidden sizes.
        return cls(hidden_sizes, activation=activation), True


def _make_repo_factory(cls, act_kw):
    """Bind ``cls`` and ``act_kw`` into the factory handed back to callers."""

    def factory(input_dim, hidden_sizes, output_dim, activation):
        """Build a repository MLP; constructor errors propagate to the caller."""
        model, _ = _instantiate_repo_mlp(
            cls, act_kw, input_dim, hidden_sizes, output_dim, activation
        )
        return model

    return factory


def try_import_repo_mlp(candidates=None, probe_hidden_sizes=None):
    """
    Try to import an MLP implementation from the repository.

    Each ``(module_name, class_name)`` candidate is imported and test-built
    once with dummy arguments (input_dim=2, output_dim=2, ``nn.ReLU()``). The
    first candidate whose probe yields an ``nn.Module`` wins.

    Args:
        candidates: iterable of ``(module_name, class_name)`` pairs to probe;
            defaults to the module-level ``CANDIDATES``.
        probe_hidden_sizes: hidden sizes used for the probe instantiation;
            defaults to the module-level ``HIDDEN_SIZES``.

    Returns:
        A factory accepting ``(input_dim, hidden_sizes, output_dim, activation)``
        and returning the instantiated model, or None if no candidate is usable.
    """
    if candidates is None:
        candidates = CANDIDATES
    if probe_hidden_sizes is None:
        probe_hidden_sizes = HIDDEN_SIZES

    for module_name, class_name in candidates:
        try:
            mod = importlib.import_module(module_name)
        except Exception:
            # Importing arbitrary repo modules can fail in any way; treat as "not found".
            continue
        cls = getattr(mod, class_name, None)
        if cls is None and inspect.isclass(mod) and mod.__name__.lower().startswith("mlp"):
            # Kept from the original: covers an import that resolves to a class
            # rather than a module (presumably rare -- TODO confirm it is needed).
            cls = mod
        if cls is None:
            continue

        act_kw = _find_activation_kwarg(cls)
        try:
            probe, activation_forwarded = _instantiate_repo_mlp(
                cls, act_kw, 2, probe_hidden_sizes, 2, nn.ReLU()
            )
        except Exception:
            # Constructor does not match any supported pattern; keep searching.
            continue
        if not isinstance(probe, nn.Module):
            # Callers need an nn.Module (it is moved to a device and trained),
            # so anything else is rejected here instead of on every later call.
            continue

        print(f"Using repository MLP: {module_name}.{class_name}")
        if not activation_forwarded:
            # Previously silent: the model is built without the requested
            # activation, so runs would not actually differ by activation.
            print(
                f"Warning: {module_name}.{class_name} was built without an activation "
                "argument; the requested activation is not passed to it."
            )
        return _make_repo_factory(cls, act_kw)
    return None

# Local fallback MLP builder
def build_local_mlp(input_dim, hidden_sizes, output_dim, activation):
    """Build a plain feed-forward network as an ``nn.Sequential``.

    The layout is ``Linear -> activation`` for every entry of ``hidden_sizes``
    followed by a final ``Linear`` to ``output_dim`` (no output activation).

    Args:
        input_dim: number of input features.
        hidden_sizes: iterable with the width of each hidden layer.
        output_dim: number of output units.
        activation: either an ``nn.Module`` instance or a zero-argument
            callable (e.g. a class) that returns one.

    Returns:
        The assembled ``nn.Sequential`` model.
    """
    import copy  # function-scope import: only this builder needs it

    layers = []
    prev = input_dim
    for h in hidden_sizes:
        layers.append(nn.Linear(prev, h))
        if isinstance(activation, nn.Module):
            # Give every layer its own copy: reusing one instance would make
            # parameterised activations (e.g. nn.PReLU) share weights across layers.
            layers.append(copy.deepcopy(activation))
        else:
            layers.append(activation())
        prev = h
    layers.append(nn.Linear(prev, output_dim))
    return nn.Sequential(*layers)

# Activation name -> zero-argument constructor returning a fresh nn.Module.
ACTIVATIONS = {
    "relu": nn.ReLU,
    "tanh": nn.Tanh,
    "sigmoid": nn.Sigmoid,
    "leaky_relu": lambda: nn.LeakyReLU(0.1),  # lambda needed to fix the slope
    "elu": nn.ELU,
    "swish": nn.SiLU,  # PyTorch provides Swish (x * sigmoid(x)) under the name SiLU
}

# Resolve the repository MLP factory once; None means "always use the local MLP".
repo_factory = try_import_repo_mlp()


def make_model(act_name, input_dim=2, output_dim=2):
    """Build a model for the activation registered under ``act_name``.

    The repository factory is tried first (when one was found); if it raises
    or returns something that is not an ``nn.Module``, the reason is printed
    and the local MLP is used instead. The model is moved to ``DEVICE``.

    Args:
        act_name: key into ``ACTIVATIONS``.
        input_dim: number of input features (default 2).
        output_dim: number of output units (default 2).

    Returns:
        An ``nn.Module`` on ``DEVICE`` with hidden layers ``HIDDEN_SIZES``.

    Raises:
        KeyError: if ``act_name`` is not a key of ``ACTIVATIONS``.
    """
    activation_module = ACTIVATIONS[act_name]()
    if repo_factory is not None:
        try:
            candidate = repo_factory(input_dim, HIDDEN_SIZES, output_dim, activation_module)
            if isinstance(candidate, nn.Module):
                return candidate.to(DEVICE)
            reason = f"factory returned {type(candidate).__name__}, not an nn.Module"
        except Exception as exc:
            # Best-effort by design: a broken repo model must not stop the experiment.
            reason = repr(exc)
        # Say so explicitly, otherwise runs may silently mix repo and local models.
        print(f"Repository MLP not used for activation '{act_name}' ({reason}); using local MLP.")
    # fallback
    return build_local_mlp(input_dim, HIDDEN_SIZES, output_dim, activation_module).to(DEVICE)

# Prepare dataset
X_raw, y = make_moons(n_samples=DATASET_N_SAMPLES, noise=0.2, random_state=SEED)
X_train, X_val, y_train, y_val = train_test_split(
    X_raw, y, test_size=TEST_SIZE, random_state=SEED
)

# Fit the scaler on the training split only, so validation statistics do not
# leak into preprocessing; the validation split is transformed with the
# training mean/std.
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_val = scaler.transform(X_val)
# X keeps holding the full dataset in scaled space; the plotting code further
# down derives its axis limits from it.
X = scaler.transform(X_raw)

train_ds = TensorDataset(
    torch.tensor(X_train, dtype=torch.float32),
    torch.tensor(y_train, dtype=torch.long),
)
val_ds = TensorDataset(
    torch.tensor(X_val, dtype=torch.float32),
    torch.tensor(y_val, dtype=torch.long),
)
# Only the training loader is shuffled; validation order is kept fixed.
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False)

# Training utilities
def evaluate(model, loader, device=None):
    """Compute the mean cross-entropy loss and accuracy of ``model`` on ``loader``.

    The model is switched to eval mode (and left there) and gradients are
    disabled while iterating.

    Args:
        model: ``nn.Module`` mapping a batch of inputs to class logits.
        loader: iterable of ``(inputs, targets)`` batches.
        device: device the batches are moved to; defaults to ``DEVICE``.

    Returns:
        ``(mean_loss, accuracy)`` as Python floats, both averaged per sample.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    if device is None:
        device = DEVICE
    model.eval()
    loss_fn = nn.CrossEntropyLoss()
    total_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for xb, yb in loader:
            xb, yb = xb.to(device), yb.to(device)
            logits = model(xb)
            loss = loss_fn(logits, yb)
            # loss is a batch mean; weight by batch size so the final figure
            # is a per-sample mean even when the last batch is smaller.
            total_loss += loss.item() * xb.size(0)
            preds = logits.argmax(dim=1)
            correct += (preds == yb).sum().item()
            total += xb.size(0)
    if total == 0:
        # Previously surfaced as a bare ZeroDivisionError.
        raise ValueError("evaluate() received a loader that yielded no samples")
    return total_loss / total, correct / total

def train_one(model, train_loader, val_loader, epochs=EPOCHS, lr=LR):
    """Train ``model`` with Adam and cross-entropy, recording metrics per epoch.

    Args:
        model: network to optimise; it is moved to ``DEVICE`` first.
        train_loader: iterable of ``(inputs, targets)`` training batches.
        val_loader: iterable of validation batches, passed to ``evaluate``.
        epochs: number of passes over ``train_loader``.
        lr: Adam learning rate.

    Returns:
        Dict with the lists "train_loss", "val_loss", "train_acc" and
        "val_acc", each holding one value per epoch.
    """
    model = model.to(DEVICE)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    history = {"train_loss": [], "val_loss": [], "train_acc": [], "val_acc": []}

    for _ in range(epochs):
        # evaluate() leaves the model in eval mode, so re-enable training mode.
        model.train()
        loss_sum, n_correct, n_seen = 0.0, 0, 0
        for inputs, targets in train_loader:
            inputs = inputs.to(DEVICE)
            targets = targets.to(DEVICE)

            optimizer.zero_grad()
            outputs = model(inputs)
            batch_loss = criterion(outputs, targets)
            batch_loss.backward()
            optimizer.step()

            batch_size = inputs.size(0)
            # Weight the batch-mean loss by batch size to get a per-sample mean.
            loss_sum += batch_loss.item() * batch_size
            n_correct += (outputs.argmax(dim=1) == targets).sum().item()
            n_seen += batch_size

        epoch_train_loss = loss_sum / n_seen
        epoch_train_acc = n_correct / n_seen
        epoch_val_loss, epoch_val_acc = evaluate(model, val_loader)

        for key, value in (
            ("train_loss", epoch_train_loss),
            ("val_loss", epoch_val_loss),
            ("train_acc", epoch_train_acc),
            ("val_acc", epoch_val_acc),
        ):
            history[key].append(value)
    return history

# Run experiments: one freshly built model per activation, trained with the
# default epochs/lr and stored together with its metric history.
results = {}
for act_name in ACTIVATIONS:
    print(f"Training with activation: {act_name}")
    model = make_model(act_name)
    hist = train_one(model, train_loader, val_loader)
    results[act_name] = dict(model=model, history=hist)

# Plotting: loss curves (left panel) and validation accuracy (right panel).
fig_curves, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(10, 4))
for name, res in results.items():
    history = res["history"]
    # Solid line = training loss, dashed line = validation loss.
    ax_loss.plot(history["train_loss"], label=f"{name} train", alpha=0.6)
    ax_loss.plot(history["val_loss"], linestyle="--", label=f"{name} val", alpha=0.9)
    ax_acc.plot(history["val_acc"], label=name)

ax_loss.set_title("Loss curves")
ax_loss.set_xlabel("Epoch")
ax_loss.set_ylabel("Cross-entropy")
ax_loss.legend(fontsize="small", ncol=2)

ax_acc.set_title("Validation accuracy")
ax_acc.set_xlabel("Epoch")
ax_acc.set_ylabel("Accuracy")
ax_acc.legend(fontsize="small", ncol=2)

fig_curves.tight_layout()
plt.show()

# Decision boundary plots
# X is already standardized (it was passed through the scaler above), so a grid
# spanning its extent is in the same space the models were trained on and is
# fed to them as-is. Running scaler.transform on it again would evaluate the
# models at different coordinates than the ones the contours are drawn at.
xx_min, xx_max = X[:, 0].min() - 0.5, X[:, 0].max() + 0.5
yy_min, yy_max = X[:, 1].min() - 0.5, X[:, 1].max() + 0.5
xx, yy = np.meshgrid(np.linspace(xx_min, xx_max, 200), np.linspace(yy_min, yy_max, 200))
grid = np.c_[xx.ravel(), yy.ravel()]
grid_t = torch.tensor(grid, dtype=torch.float32).to(DEVICE)

n = len(results)
cols = 3
rows = math.ceil(n / cols)
fig_boundaries = plt.figure(figsize=(4 * cols, 4 * rows))
for i, (name, res) in enumerate(results.items(), start=1):
    model = res["model"]
    model.eval()
    with torch.no_grad():
        logits = model(grid_t)
        # Probability of class 1 at every grid point.
        probs = torch.softmax(logits, dim=1)[:, 1].cpu().numpy()
    Z = probs.reshape(xx.shape)
    plt.subplot(rows, cols, i)
    plt.contourf(xx, yy, Z, levels=20, cmap="RdYlBu", alpha=0.8)
    plt.scatter(X_train[:, 0], X_train[:, 1], c=y_train, edgecolor="k", cmap="bwr", s=20, alpha=0.6)
    plt.title(f"Decision boundary: {name}")
    plt.xlim(xx_min, xx_max)
    plt.ylim(yy_min, yy_max)

plt.tight_layout()

# Save figures to files for later review.
# This must happen before plt.show(): once a figure has been shown it may be
# closed, and a later plt.savefig would then write a new, empty figure.
out_dir = Path("mlp_activation_experiments")
out_dir.mkdir(exist_ok=True)
fig_boundaries.savefig(out_dir / "decision_boundaries.png", bbox_inches="tight")
plt.show()

# Summarize final validation accuracies
print("Final validation accuracies:")
for name, res in results.items():
    acc = res["history"]["val_acc"][-1]
    print(f"  {name}: {acc:.4f}")

# Note: the loss/accuracy figure above is only shown inline; to keep it as
# well, save it the same way before its plt.show() call.