<a href="https://colab.research.google.com/github/jahnavi-maddhuri/JahnaviMaddhuri-DukeXAI/blob/main/mechanistic_interpreter.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Interpreting an MLP that Attempts String Reversal
Multilayer perceptrons (MLPs) are feed-forward neural networks made of fully connected layers stacked on top of each other: an input layer, an output layer, and at least one hidden layer in between. With such models, interpretability often goes out the window, and the user has little idea what is happening behind the scenes of a prediction.

In this notebook I try to demystify the black box to improve the explainability and interpretability of these models. I start with a very simple use case and a small model, so that the explanation can later be extended to more complex models.

In the example below, my MLP has one input layer, a single hidden layer, and one output layer, keeping it simple. The MLP's objective is to take one-hot-encoded strings and produce them in reverse order. I fit the model on synthetic data for the simple use case of reversing strings of length three drawn from a vocabulary of size three.

To better understand what multilayer perceptrons are and how to interpret their hidden layers and activation functions, and to develop and fine-tune the scripts, I often used ChatGPT 5.1. Feel free to read the conversation here: https://chatgpt.com/c/69167c70-ee64-8327-9be2-f76bafb68287

In [1]:
# Necessary library imports for the NN
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import random
import numpy as np

### Data Generation
First, I generate synthetic data below so the model can train and test on it. For simplicity, the strings have length three and use an alphabet of three characters: a, b and c. The generated X holds 2500 records with each letter one-hot encoded, forming a tensor with dimensions (2500 x 3 x 3). The target y holds each reversed string as an array of character indices, forming a tensor with dimensions (2500 x 3).

For training and testing, I split this dataset into 2000 training samples and 500 testing samples.
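
To make the encoding concrete, here is a minimal sketch for a single string. It uses NumPy rather than PyTorch, and the helper name `encode_example` is hypothetical (it is not part of the notebook); it only mirrors the scheme described above.

```python
import numpy as np

alphabet = "abc"
char_to_idx = {ch: i for i, ch in enumerate(alphabet)}

def encode_example(s):
    """Hypothetical helper: one-hot matrix of s and index vector of reversed s."""
    x = np.zeros((len(s), len(alphabet)), dtype=np.float32)
    for t, ch in enumerate(s):
        x[t, char_to_idx[ch]] = 1.0          # one row per position, one 1 per row
    y = np.array([char_to_idx[ch] for ch in reversed(s)], dtype=np.int64)
    return x, y

x, y = encode_example("bac")
print(x)   # 3 x 3 one-hot matrix for "bac"
print(y)   # indices of the reversed string "cab"
```

Each row of `x` contains exactly one 1, and `y` stores class indices rather than one-hot vectors, which is the target format cross-entropy losses typically expect.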

In [2]:
def generate_string_dataset(n_train=2000, n_test=500, seq_length=3, alphabet="abc"):
    """
    Generates synthetic data for string reversal.

    X: one-hot encoded original strings
       shape: (n_samples, seq_length, vocab_size), dtype=float32
    y: reversed strings as integer indices
       shape: (n_samples, seq_length), dtype=long
    """
    n_samples = n_train + n_test
    vocab = list(alphabet)
    vocab_size = len(vocab)
    char_to_idx = {ch: i for i, ch in enumerate(vocab)}

    X = torch.zeros(n_samples, seq_length, vocab_size, dtype=torch.float32)
    y = torch.zeros(n_samples, seq_length, dtype=torch.long)

    for n in range(n_samples):
        # random string of exactly seq_length
        s = ''.join(random.choice(vocab) for _ in range(seq_length))

        # one-hot encode original string into X[n]
        for t, ch in enumerate(s):
            X[n, t, char_to_idx[ch]] = 1.0

        # reversed string → indices
        s_rev = s[::-1]
        for t, ch in enumerate(s_rev):
            y[n, t] = char_to_idx[ch]

    # Optional: shuffle before splitting
    perm = torch.randperm(n_samples)
    X = X[perm]
    y = y[perm]

    # Train/test split
    X_train = X[:n_train]
    y_train = y[:n_train]
    X_test  = X[n_train:n_train + n_test]
    y_test  = y[n_train:n_train + n_test]

    return X_train, y_train, X_test, y_test, vocab, char_to_idx

In [19]:
X_train, y_train, X_test, y_test, vocab, char_to_idx = generate_string_dataset()
seq_length = X_train.size(1)
vocab_size = X_train.size(2)
print(f'Dimensions of X (train):{X_train.shape}')  # (2000, 3, 3)
print(f'Dimensions of y (train):{y_train.shape}') # (2000, 3)

Dimensions of X (train):torch.Size([2000, 3, 3])
Dimensions of y (train):torch.Size([2000, 3])
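
Because the alphabet and the string length are so small, it can help to count how many distinct strings exist at all. The sketch below is standalone and assumes only the settings stated above (alphabet "abc", length three, 2000 uniformly sampled training strings); `p_missing` is an illustrative name for the chance that one particular string never appears in the training set.

```python
from itertools import product

alphabet = "abc"
seq_length = 3
n_train = 2000

# Enumerate every possible string of length seq_length over the alphabet
all_strings = ["".join(p) for p in product(alphabet, repeat=seq_length)]
n_distinct = len(all_strings)

# Probability that a given string is absent from n_train independent uniform draws
p_missing = (1 - 1 / n_distinct) ** n_train

print(n_distinct)
print(p_missing)
```

With so few distinct strings, the training split almost surely contains every possible input, which is worth remembering when reading the test metrics later on.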


### Build my String Reversal MLP
This simple model has the architecture (1) Input Layer, (2) Hidden Layer, (3) Output Layer. The input layer takes the one-hot encodings of the strings, flattened to a vector of length 9. I set the hidden layer to size 10, meaning there are 10 internal neurons, or features, that the model uses to learn the reversal behavior. Each hidden neuron can detect a slightly different pattern, because its weights are tuned separately during training.

Below, I create the MLP, train and test the model, producing evaluation metrics at the end of the process.
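
As a rough sketch of the architecture described above, the forward pass can be written out in NumPy. The weights here are random placeholders (an assumption for illustration, not trained values), so only the shapes and the parameter count are meaningful.

```python
import numpy as np

seq_length, vocab_size, hidden_dim = 3, 3, 10
input_dim = seq_length * vocab_size  # flattened one-hot input: 9 values

rng = np.random.default_rng(0)
# Placeholder weights, NOT the trained model's parameters
W1 = rng.normal(size=(hidden_dim, input_dim))
b1 = np.zeros(hidden_dim)
W2 = rng.normal(size=(input_dim, hidden_dim))
b2 = np.zeros(input_dim)

def forward(x):
    """x: (batch, seq_length, vocab_size) -> logits of the same shape, hidden (batch, hidden_dim)."""
    x_flat = x.reshape(x.shape[0], -1)
    h = np.maximum(0.0, x_flat @ W1.T + b1)   # ReLU hidden layer
    logits = (h @ W2.T + b2).reshape(-1, seq_length, vocab_size)
    return logits, h

# One-hot input for the string "bac"
x = np.zeros((1, seq_length, vocab_size))
x[0, [0, 1, 2], [1, 0, 2]] = 1.0

logits, h = forward(x)
n_params = W1.size + b1.size + W2.size + b2.size
print(logits.shape, h.shape, n_params)
```

Counting weights and biases for both layers gives 9*10 + 10 + 10*9 + 9 = 199 parameters, small enough that every weight could in principle be inspected by hand.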

In [13]:
class StringReverseMLP(nn.Module):
    def __init__(self, seq_length, vocab_size, hidden_dim=10):
        super().__init__()
        self.seq_length = seq_length
        self.vocab_size = vocab_size

        # Input is flattened one-hot: seq_length * vocab_size
        input_dim = seq_length * vocab_size

        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, seq_length * vocab_size)

    def forward(self, x):
        """
        x: (batch_size, seq_length, vocab_size) one-hot
        returns:
          logits: (batch_size, seq_length, vocab_size)
          h:      (batch_size, hidden_dim)
        """
        batch_size = x.size(0)
        x_flat = x.view(batch_size, -1)  # (batch_size, seq_length * vocab_size)

        h = F.relu(self.fc1(x_flat))
        logits_flat = self.fc2(h)  # (batch_size, seq_length * vocab_size)
        logits = logits_flat.view(batch_size, self.seq_length, self.vocab_size)

        return logits, h

In [14]:
# Instantiate model (one hidden layer with 10 units)
model = StringReverseMLP(seq_length=seq_length,
                         vocab_size=vocab_size,
                         hidden_dim=10)
print(model)

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()

losses = []
for epoch in range(100):
    model.train()

    # Forward
    logits, h = model(X_train)           # (N, L, V)
    N, L, V = logits.shape

    # Flatten for CrossEntropyLoss
    logits_flat = logits.view(N * L, V)  # (N*L, V)
    targets_flat = y_train.view(N * L)   # (N*L,)

    loss = criterion(logits_flat, targets_flat)

    # Backprop
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    losses.append(loss.item())

print("Final training loss:", losses[-1])


StringReverseMLP(
  (fc1): Linear(in_features=9, out_features=10, bias=True)
  (fc2): Linear(in_features=10, out_features=9, bias=True)
)
Final training loss: 0.06964609026908875
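
The training cell above flattens the logits from (N, L, V) to (N*L, V) so that every character position is treated as its own 3-way classification, and `nn.CrossEntropyLoss` with its default mean reduction averages over all of them. The NumPy function below is a sketch of that computation (the name `cross_entropy` and the toy targets are illustrative), and it also gives a reference point: a model that outputs uniform logits scores ln(3), about 1.10.

```python
import numpy as np

def cross_entropy(logits, targets):
    """Mean negative log-likelihood. logits: (N, L, V), targets: (N, L) class indices."""
    N, L, V = logits.shape
    flat_logits = logits.reshape(N * L, V)
    flat_targets = targets.reshape(N * L)
    # Log-softmax with the usual max-shift for numerical stability
    shifted = flat_logits - flat_logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(N * L), flat_targets].mean()

# Toy targets (made up for illustration) and a model that predicts uniformly
targets = np.array([[2, 0, 1], [0, 0, 2], [1, 1, 1], [2, 1, 1]])
uniform_logits = np.zeros((4, 3, 3))

chance_loss = cross_entropy(uniform_logits, targets)
print(chance_loss, np.log(3))
```

Against that baseline, the final training loss of roughly 0.07 reported above indicates that the model places most of the probability mass on the correct character at each position.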


In [16]:
model.eval()
with torch.no_grad():
    logits_val, h_val = model(X_test)
    preds_idx = logits_val.argmax(dim=-1)  # (N, L)

    # Per-position accuracy
    per_pos_acc = (preds_idx == y_test).float().mean().item()

    # Exact string accuracy (all positions correct)
    exact_match = (preds_idx == y_test).all(dim=1).float().mean().item()

print(f"Per-position accuracy: {per_pos_acc:.2f}")
print(f"Exact string accuracy: {exact_match:.2f}")


Per-position accuracy: 1.00
Exact string accuracy: 1.00


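Wow! The MLP correctly reverses every test string of length 3 over the 3-character vocabulary. Since the task has only 27 possible input strings, this shows that the model has fit the whole input space rather than generalized to unseen strings.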
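
The two metrics reported above answer different questions, and they only coincide when the model is perfect. Here is a small sketch with made-up predictions (the arrays are illustrative, not model output) showing how they diverge:

```python
import numpy as np

# Hypothetical predictions and targets for four strings of length three
preds = np.array([[2, 0, 1], [0, 0, 1], [1, 1, 1], [2, 2, 0]])
targets = np.array([[2, 0, 1], [0, 0, 2], [1, 1, 1], [2, 1, 1]])

correct = preds == targets                  # (N, L) boolean matrix
per_pos_acc = correct.mean()                # fraction of characters that match
exact_match = correct.all(axis=1).mean()    # fraction of strings fully correct

print(per_pos_acc, exact_match)
```

Here 9 of 12 characters are right (0.75) but only 2 of 4 strings are fully right (0.5), so exact-match accuracy is the stricter of the two.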

# Explanations
While this accuracy is extremely high and the model appears to work well, we still don't really understand what's happening under the hood. Let's dive into a specific example from our test set to see what the model predicted and how it arrived at that prediction.

In [18]:
def decode_indices(idx_tensor, vocab):
    return ''.join(vocab[int(i)] for i in idx_tensor)

sample_idx = 0
x_sample = X_test[sample_idx:sample_idx+1]  # (1, L, V)
y_sample = y_test[sample_idx]               # (L,)

with torch.no_grad():
    logits_sample, hidden_sample = model(x_sample)
    probs_sample = torch.softmax(logits_sample, dim=-1)  # (1, L, V)
    pred_idx_sample = logits_sample.argmax(dim=-1).squeeze(0)  # (L,)

orig_idx = X_test[sample_idx].argmax(dim=-1)  # original string indices (L,)
orig_str = decode_indices(orig_idx, vocab)
true_rev_str = decode_indices(y_sample, vocab)
pred_rev_str = decode_indices(pred_idx_sample, vocab)

print("Original string:      ", orig_str)
print("True reversed string: ", true_rev_str)
print("Predicted reversed:   ", pred_rev_str)
print("\nHidden layer activations:")
print(hidden_sample)       # (1, 10)
print("\nOutput logits:")
print(logits_sample)
print("\nOutput probabilities (softmax):")
print(probs_sample)


Original string:       bac
True reversed string:  cab
Predicted reversed:    cab

Hidden layer activations:
tensor([[2.3032, 0.0000, 1.0295, 0.0000, 0.0000, 3.0154, 0.0625, 0.0000, 0.2466,
         0.0000]])

Output logits:
tensor([[[-2.0811, -2.9473,  2.0437],
         [ 3.3705, -0.9398, -3.9562],
         [-0.5727,  3.9886,  0.7585]]])

Output probabilities (softmax):
tensor([[[1.5803e-02, 6.6461e-03, 9.7755e-01],
         [9.8611e-01, 1.3242e-02, 6.4856e-04],
         [9.9509e-03, 9.5238e-01, 3.7670e-02]]])
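
The link between the printed logits and the printed probabilities can be checked by hand. The sketch below copies the logits from the output above into NumPy, applies a softmax over the vocabulary axis, and decodes the argmax at each position; the variable names are illustrative.

```python
import numpy as np

vocab = ["a", "b", "c"]

# Logits copied from the printed output above, shape (L, V)
logits = np.array([
    [-2.0811, -2.9473,  2.0437],
    [ 3.3705, -0.9398, -3.9562],
    [-0.5727,  3.9886,  0.7585],
])

# Numerically stable softmax over the vocabulary axis
shifted = logits - logits.max(axis=-1, keepdims=True)
probs = np.exp(shifted) / np.exp(shifted).sum(axis=-1, keepdims=True)

# Most likely character at each output position
pred_idx = probs.argmax(axis=-1)
pred_str = "".join(vocab[i] for i in pred_idx)

print(np.round(probs, 4))
print(pred_str)
```

Each row of `probs` sums to 1 and reproduces the printed softmax values up to rounding. The third position is the least confident of the three, with about 0.95 on "b" and most of the remainder on "c".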
