# Tutorial 7: Advanced Customization with Custom Models

Welcome to our final tutorial! You have now mastered the main workflows of `NeuralMI`, from simple estimates to rigorous, publication-ready analyses. But what happens when your research requires a model architecture that isn't built into the library? 

This tutorial is for the advanced user who wants maximum flexibility. We will show you how to define your own models using PyTorch and seamlessly integrate them into the `nmi.run` pipeline, ensuring that `NeuralMI` can grow with your research needs.

## 1. The Requirements for a Custom Model

To be compatible with the `NeuralMI` trainer, a custom critic must meet two simple requirements:

1.  It must inherit from **`nmi.models.BaseCritic`**.
2.  Its `forward` method must accept two arguments, `x` and `y`, and return a **tuple** containing:
    - `scores`: A `(batch_size, batch_size)` tensor of similarity scores.
    - `kl_loss`: A scalar tensor for any KL divergence loss. If not using a variational model, this should be `torch.tensor(0.0)`.
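The `scores` matrix follows the usual convention: entry `[i, j]` is the critic's score for `x_i` paired with `y_j`, and the diagonal holds the jointly sampled pairs. A common way to turn such a matrix into an MI lower bound is the InfoNCE objective. The sketch below is illustrative only. We do not claim it is the estimator `NeuralMI` uses internally, and `infonce_bits` is a hypothetical helper name.

```python
import math

import torch
import torch.nn.functional as F


def infonce_bits(scores: torch.Tensor) -> torch.Tensor:
    """InfoNCE lower bound on MI, in bits, from a (B, B) score matrix.

    Assumes scores[i, j] scores the pair (x_i, y_j) and that the
    diagonal holds the jointly sampled (positive) pairs.
    """
    batch_size = scores.shape[0]
    # Row-wise log-softmax: how strongly each x_i picks out its own y_i
    log_probs = F.log_softmax(scores, dim=1)
    nats = log_probs.diag().mean() + math.log(batch_size)
    return nats / math.log(2.0)  # convert nats to bits


# A critic that cannot tell pairs apart (constant scores) gives ~0 bits
print(infonce_bits(torch.zeros(128, 128)))
# A critic that scores true pairs far above the rest approaches log2(128) = 7 bits
print(infonce_bits(100.0 * torch.eye(128)))
```

Each log-softmax term is at most zero, so this bound can never exceed `log2(batch_size)`. With `batch_size=128` the ceiling is 7 bits, which is comfortably above the 2-bit ground truth used in this tutorial.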

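`NeuralMI` offers two ways to plug in your own architecture. You can supply a complete critic that meets these requirements (Method 1), or supply only an embedding class and let the library build one of its own critics around it (Method 2).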

```python
import torch
import torch.nn as nn
import numpy as np
import neural_mi as nmi
import seaborn as sns

sns.set_context("talk")
```

## 2. Method 1: Full Control with `custom_critic`

This method gives you complete control. You define the entire critic architecture from scratch and pass a pre-initialized **instance** of your model to `nmi.run`.

Let's build a simple custom critic that uses a linear embedding layer.

```python
# Define a simple embedding model (any subclass of BaseEmbedding works)
class LinearEmbedding(nmi.models.BaseEmbedding):
    def __init__(self, input_dim, embedding_dim):
        super().__init__()
        self.layer = nn.Linear(input_dim, embedding_dim)

    def forward(self, x):
        # Flatten each sample; reshape (unlike view) also handles non-contiguous input
        x_flat = x.reshape(x.shape[0], -1)
        return self.layer(x_flat)

# Define our custom critic that uses the embedding model
class MyCustomSeparableCritic(nmi.models.BaseCritic):
    def __init__(self, input_dim, embedding_dim):
        super().__init__()
        # One shared embedding for x and y: assumes both have input_dim features
        self.embedding_net = LinearEmbedding(input_dim, embedding_dim)

    def forward(self, x, y):
        x_embedded = self.embedding_net(x)
        y_embedded = self.embedding_net(y)
        # scores[i, j] compares x_i with y_j, giving a (batch_size, batch_size) matrix
        scores = torch.matmul(x_embedded, y_embedded.t())
        # Not a variational model, so the KL term is zero (on the same device as scores)
        return scores, torch.tensor(0.0, device=scores.device)

print("Custom critic class defined successfully!")
```

```text
Custom critic class defined successfully!
```
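`MyCustomSeparableCritic` reuses one `embedding_net` for both `x` and `y`. This ties their weights and only works when both variables have the same number of input features. If your `x` and `y` differ, for example spike counts versus a low-dimensional behavioral variable, a two-tower variant with separate embeddings is a small change. The sketch below subclasses `nn.Module` so it runs on its own. In the tutorial setting you would inherit from `nmi.models.BaseCritic` instead. `TwoTowerCritic` is a hypothetical name.

```python
import torch
import torch.nn as nn


class TwoTowerCritic(nn.Module):  # in practice: nmi.models.BaseCritic
    """Separable critic with independent embedding networks for x and y."""

    def __init__(self, x_dim: int, y_dim: int, embedding_dim: int):
        super().__init__()
        self.embed_x = nn.Linear(x_dim, embedding_dim)
        self.embed_y = nn.Linear(y_dim, embedding_dim)

    def forward(self, x: torch.Tensor, y: torch.Tensor):
        # Flatten any trailing dims (e.g. channels x window) into one feature axis
        zx = self.embed_x(x.reshape(x.shape[0], -1))
        zy = self.embed_y(y.reshape(y.shape[0], -1))
        scores = zx @ zy.T  # (batch, batch): scores[i, j] compares x_i with y_j
        kl_loss = torch.tensor(0.0, device=scores.device)  # not a variational model
        return scores, kl_loss


critic = TwoTowerCritic(x_dim=5, y_dim=3, embedding_dim=16)
scores, kl = critic(torch.randn(32, 5, 1), torch.randn(32, 3, 1))
print(scores.shape, kl.item())  # torch.Size([32, 32]) 0.0
```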


### Using the Custom Critic in `nmi.run`

Using our new model is simple: we instantiate our critic and pass the **instance** to the `custom_critic` argument. The library will then skip its internal model-building logic and use our model directly. Any model architecture parameters in `base_params` (like `embedding_dim`, `hidden_dim`, etc.) will be ignored.
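Because `nmi.run` uses the instance as-is, you can check the output contract quickly before launching a run and avoid a failed sweep. `check_critic_contract` below is a hypothetical helper, not part of `NeuralMI`. It only verifies the two requirements from Section 1: a `(batch_size, batch_size)` score matrix and a scalar KL term. `DotCritic` is a minimal stand-in; a real critic would inherit from `nmi.models.BaseCritic`.

```python
import torch
import torch.nn as nn


def check_critic_contract(critic: nn.Module, x: torch.Tensor, y: torch.Tensor) -> None:
    """Raise ValueError if critic(x, y) breaks the (scores, kl_loss) contract."""
    out = critic(x, y)
    if not (isinstance(out, tuple) and len(out) == 2):
        raise ValueError("forward must return a (scores, kl_loss) tuple")
    scores, kl_loss = out
    n = x.shape[0]
    if scores.shape != (n, n):
        raise ValueError(f"scores must have shape ({n}, {n}), got {tuple(scores.shape)}")
    if not (torch.is_tensor(kl_loss) and kl_loss.dim() == 0):
        raise ValueError("kl_loss must be a 0-dim (scalar) tensor")
    if kl_loss.device != scores.device:
        raise ValueError("kl_loss and scores should live on the same device")


class DotCritic(nn.Module):  # stand-in for a nmi.models.BaseCritic subclass
    def forward(self, x, y):
        scores = x.flatten(1) @ y.flatten(1).T
        return scores, torch.zeros((), device=scores.device)


check_critic_contract(DotCritic(), torch.randn(64, 5, 1), torch.randn(64, 5, 1))
print("contract OK")
```

The `(batch, 5, 1)` input shape is an assumption about how the processed data is laid out. Both critics in this tutorial flatten their input, so any shape with 5 features per sample behaves the same.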

```python
# --- Generate some simple data ---
x_raw, y_raw = nmi.datasets.generate_correlated_gaussians(n_samples=5000, dim=5, mi=2.0)

# --- Instantiate our model ---
my_critic_instance = MyCustomSeparableCritic(input_dim=5, embedding_dim=16)

# --- Define trainer parameters (no model architecture params needed) ---
base_params = {
    'n_epochs': 50, 'learning_rate': 1e-3, 'batch_size': 128,
    'patience': 10
}

# --- Run the estimation ---
results = nmi.run(
    x_data=x_raw.T, y_data=y_raw.T,
    mode='estimate',
    processor_type_x='continuous',
    processor_params_x={'window_size': 1},
    base_params=base_params,
    custom_critic=my_critic_instance, # Here is the magic!
    n_workers=1,
    random_seed=42
)

print("\n--- Results with custom_critic ---")
print("Ground Truth MI:  2.000 bits")
print(f"Estimated MI:     {results.mi_estimate:.3f} bits")
```

```text
2025-10-07 23:27:48 - neural_mi - INFO - Starting parameter sweep sequentially (n_workers=1)...
2025-10-07 23:27:50 - neural_mi - INFO - Parameter sweep finished.

--- Results with custom_critic ---
Ground Truth MI:  2.000 bits
Estimated MI:     1.588 bits
```


## 3. Method 2: Modular Control with `custom_embedding_cls`

Sometimes you don't need to reinvent the wheel. You might like the library's built-in `SeparableCritic`, but you just want to swap out the embedding model (e.g., use a Transformer instead of an MLP).

The `custom_embedding_cls` parameter is perfect for this. Instead of a model *instance*, you provide the **class** of your custom embedding model. The library will then handle instantiating it for you, using the architecture parameters from `base_params`.

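**Important:** For this to work, your custom embedding's `__init__` method must accept the standard parameters that the library provides: `input_dim`, `hidden_dim`, `embed_dim`, and `n_layers`. Note the naming: in the example below, the embedding size is set with the `embedding_dim` key in `base_params`, yet the constructor receives it as `embed_dim`.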
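A mismatched signature only surfaces when the library tries to build the model. You can catch it earlier by inspecting the constructor with Python's standard `inspect` module. `missing_builder_args` is a hypothetical helper, and the argument list it checks comes from the sentence above. The library may pass additional keyword arguments (the `CustomMLP` below also accepts `activation`), so treat this as a first check rather than a guarantee. Plain classes stand in for `BaseEmbedding` subclasses here.

```python
import inspect

# Argument names taken from the requirement stated in this tutorial
REQUIRED_ARGS = ("input_dim", "hidden_dim", "embed_dim", "n_layers")


def missing_builder_args(embedding_cls) -> list:
    """Return the required constructor arguments that embedding_cls cannot accept."""
    params = inspect.signature(embedding_cls.__init__).parameters
    # A **kwargs catch-all accepts any keyword argument
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return []
    return [name for name in REQUIRED_ARGS if name not in params]


class GoodEmbedding:
    def __init__(self, input_dim, hidden_dim, embed_dim, n_layers, activation="relu"):
        pass


class BadEmbedding:
    def __init__(self, input_dim, embedding_dim):  # wrong name, missing arguments
        pass


print(missing_builder_args(GoodEmbedding))  # []
print(missing_builder_args(BadEmbedding))   # ['hidden_dim', 'embed_dim', 'n_layers']
```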

```python
# Define a more complex custom embedding that is compatible with the library's builder
class CustomMLP(nmi.models.BaseEmbedding):
    # This __init__ signature matches the arguments the library's internal builder will provide
    def __init__(self, input_dim: int, hidden_dim: int, embed_dim: int, n_layers: int, activation: str = 'relu'):
        super().__init__()
        # `activation` is accepted in case the builder passes it, but this
        # example always uses ReLU; map the string to a module to honor it.

        # You can define any architecture you want inside
        layers = [nn.Linear(input_dim, hidden_dim), nn.ReLU()]
        for _ in range(n_layers - 1):
            layers.extend([nn.Linear(hidden_dim, hidden_dim), nn.ReLU()])
        self.network = nn.Sequential(*layers)
        self.output_layer = nn.Linear(hidden_dim, embed_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Flatten each sample to input_dim features (reshape tolerates non-contiguous input)
        return self.output_layer(self.network(x.reshape(x.shape[0], -1)))

print("CustomMLP class defined successfully!")
```

```text
CustomMLP class defined successfully!
```
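The same constructor signature also fits the Transformer swap mentioned at the start of this section. The sketch below is one possible design, not a library model. It flattens each sample to `input_dim` features, treats each feature as a token with a learned position embedding, runs a small `nn.TransformerEncoder`, and mean-pools into an `embed_dim` vector. The class name `TransformerEmbedding`, the head count, and the feed-forward width are our own choices. In the tutorial it would inherit from `nmi.models.BaseEmbedding` rather than `nn.Module`.

```python
import torch
import torch.nn as nn


class TransformerEmbedding(nn.Module):  # in practice: nmi.models.BaseEmbedding
    def __init__(self, input_dim: int, hidden_dim: int, embed_dim: int,
                 n_layers: int, activation: str = 'relu'):
        super().__init__()
        n_heads = 4 if hidden_dim % 4 == 0 else 1  # head count must divide hidden_dim
        self.token_proj = nn.Linear(1, hidden_dim)  # each scalar feature becomes one token
        # Learned position embedding, one vector per feature position
        self.pos_emb = nn.Parameter(0.02 * torch.randn(1, input_dim, hidden_dim))
        layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim, nhead=n_heads, dim_feedforward=2 * hidden_dim,
            activation=activation,  # TransformerEncoderLayer accepts 'relu' or 'gelu'
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(layer, num_layers=n_layers)
        self.output_layer = nn.Linear(hidden_dim, embed_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        tokens = x.reshape(x.shape[0], -1, 1)        # (batch, input_dim, 1)
        h = self.token_proj(tokens) + self.pos_emb    # (batch, input_dim, hidden_dim)
        h = self.encoder(h)
        return self.output_layer(h.mean(dim=1))      # (batch, embed_dim)


emb = TransformerEmbedding(input_dim=5, hidden_dim=32, embed_dim=16, n_layers=2)
print(emb(torch.randn(8, 5, 1)).shape)  # torch.Size([8, 16])
```

Such a class would be passed exactly like `CustomMLP`, via `custom_embedding_cls`. Self-attention cost grows quadratically with `input_dim`. Whether this design beats the MLP on low-dimensional Gaussian data is something to test, not assume.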


```python
# --- Define model and trainer parameters ---
# This time, we DO need to provide the architecture params, as the library will use them
# to instantiate our CustomMLP class.
base_params_cls = {
    'n_epochs': 50, 'learning_rate': 1e-3, 'batch_size': 128,
    'patience': 10, 'embedding_dim': 16, 'hidden_dim': 64, 'n_layers': 2,
    'critic_type': 'separable'
}

# --- Run the estimation ---
results_cls = nmi.run(
    x_data=x_raw.T, y_data=y_raw.T,
    mode='estimate',
    processor_type_x='continuous',
    processor_params_x={'window_size': 1},
    base_params=base_params_cls,
    custom_embedding_cls=CustomMLP, # Pass the CLASS here
    n_workers=1,
    random_seed=42
)

print("\n--- Results with custom_embedding_cls ---")
print("Ground Truth MI:  2.000 bits")
print(f"Estimated MI:     {results_cls.mi_estimate:.3f} bits")
```

```text
2025-10-07 23:27:54 - neural_mi - INFO - Starting parameter sweep sequentially (n_workers=1)...
2025-10-07 23:27:58 - neural_mi - INFO - Parameter sweep finished.

--- Results with custom_embedding_cls ---
Ground Truth MI:  2.000 bits
Estimated MI:     2.013 bits
```


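Success! This estimate (2.013 bits) is much closer to the 2.000-bit ground truth than the 1.588 bits from the linear custom critic, plausibly because the two-layer nonlinear MLP is more expressive than a single linear layer. Both numbers come from single runs, so treat the comparison as indicative rather than conclusive. This modular approach lets you use the library's tested critic architectures while still designing novel embedding models for your specific data.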

## 4. Conclusion

Congratulations! You have completed the `NeuralMI` learning path. You now have the skills to handle complex neural data, choose the right model architecture, perform scientifically rigorous analyses, and even extend the library with your own custom models.

The `custom_critic` and `custom_embedding_cls` features provide escape hatches for maximum flexibility, ensuring that `NeuralMI` can serve as the foundation for your analysis, no matter how specialized your research becomes.