# Getting Started

## Wrapping a PyTorch model
Create a simple PyTorch model.

In [2]:
import torch
import numpy as np
import torch.nn as nn

class SimpleModel(nn.Module):
    """
    Small fully connected regression network.

    Three hidden stages of Linear -> ReLU -> Dropout(0.2) are followed by a
    final Linear projection to the output dimension.

    Args:
        input_dim: Number of input features.
        output_dim: Number of output values.
        hidden_dim: Width of every hidden layer (default 64).
    """
    def __init__(self, input_dim, output_dim, hidden_dim = 64):
        super().__init__()

        # Build the three hidden stages in order; only the first one sees
        # `input_dim`, the remaining ones map hidden_dim -> hidden_dim.
        stages = []
        fan_in = input_dim
        for _ in range(3):
            stages.append(nn.Linear(fan_in, hidden_dim))
            stages.append(nn.ReLU())
            stages.append(nn.Dropout(0.2))
            fan_in = hidden_dim

        # Output projection (no activation: this is a regression head).
        stages.append(nn.Linear(hidden_dim, output_dim))

        self.mlp = nn.Sequential(*stages)

    def forward(self, x):
        """Run `x` through the MLP and return its output."""
        return self.mlp(x)

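Train the model on synthetic data: the target is the function below of two of the five input features, plus a small amount of Gaussian noise.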

$$
y = x_0^2 + 3 \sin(x_4) - 4
$$

In [3]:
# Make the dataset: 1,000 samples of 5 features, each drawn uniformly from [0, 1).
# A single (5, 1000) draw transposed gives the same (1000, 5) layout as stacking
# five separate 1,000-element draws.
x = np.random.uniform(0, 1, size=(5, 1_000)).T

# The target uses only features 0 and 4; the other three columns do not enter it.
y = x[:, 0]**2 + 3*np.sin(x[:, 4]) - 4

# Additive Gaussian noise whose standard deviation is 5% of the clean target's.
# The scale is computed once (not once per sample) and all draws are made in
# one vectorized call.
noise_scale = 0.05*np.std(y)
noise = np.random.normal(0, noise_scale, size=len(y))
y = y + noise

In [4]:
# Set up training

import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split

def train_model(model, dataloader, opt, criterion, epochs = 100):
    """
    Run a plain supervised training loop over `dataloader`.

    For every batch the model output is scored with `criterion`, gradients
    are reset and back-propagated, and `opt` takes one step. Every fifth
    epoch the mean per-batch loss is printed.

    Args:
        model: PyTorch model to train
        dataloader: Iterable of (inputs, targets) batches; must support len()
        opt: Optimizer bound to the model's parameters
        criterion: Loss function called as criterion(prediction, target)
        epochs: Number of passes over `dataloader`

    Returns:
        tuple: (model, history) where history[i] is the SUM of the batch
        losses seen in epoch i (not the mean that is printed).
    """
    history = []
    for epoch_num in range(1, epochs + 1):
        batch_losses = []

        for inputs, targets in dataloader:
            # Forward pass
            outputs = model(inputs)
            batch_loss = criterion(outputs, targets)

            # Backward pass: clear stale gradients, back-propagate, update
            opt.zero_grad()
            batch_loss.backward()
            opt.step()

            batch_losses.append(batch_loss.item())

        # Left-to-right float sum starting from 0.0
        total = sum(batch_losses, 0.0)
        history.append(total)

        if epoch_num % 5 == 0:
            mean_loss = total / len(dataloader)
            print(f'Epoch [{epoch_num}/{epochs}], Avg Loss: {mean_loss:.6f}')
    return model, history

# Network, loss and optimiser
model = SimpleModel(input_dim=x.shape[1], output_dim=1)
criterion = nn.MSELoss()
opt = optim.Adam(model.parameters(), lr=0.001)

# Hold out 20% of the samples; only the training portion is used here.
# The target is reshaped to a column so it matches the model's (batch, 1) output.
X_train, _, y_train, _ = train_test_split(
    x,
    y.reshape(-1, 1),
    test_size=0.2,
    random_state=290402,
)

# Wrap the training arrays as float32 tensors and serve them in shuffled
# mini-batches of 32.
dataset = TensorDataset(
    torch.tensor(X_train, dtype=torch.float32),
    torch.tensor(y_train, dtype=torch.float32),
)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

In [5]:
# Fit the model for 20 epochs, then write its learned parameters
# (the state dict only, not the full module) to disk.

model, losses = train_model(model, dataloader, opt, criterion, epochs=20)
torch.save(model.state_dict(), 'model_weights.pth')

Epoch [5/20], Avg Loss: 0.296784
Epoch [10/20], Avg Loss: 0.158174
Epoch [15/20], Avg Loss: 0.128454
Epoch [20/20], Avg Loss: 0.107239


Now, create a new model with the same architecture as the previous model, but wrap the MLP in `MLP_SR`.

In [7]:
from interpretsr import mlp_sr

class SimpleModel_SR(nn.Module):
    """
    Same network as ``SimpleModel``, with the MLP wrapped in ``mlp_sr.MLP_SR``.

    The layer stack (three Linear -> ReLU -> Dropout(0.2) stages followed by
    a Linear output layer) is built exactly as in ``SimpleModel`` and then
    passed to the wrapper, which is stored as ``self.mlp``.

    Args:
        input_dim: Number of input features.
        output_dim: Number of output values.
        hidden_dim: Width of every hidden layer (default 64).
    """
    def __init__(self, input_dim, output_dim, hidden_dim = 64):
        # Zero-argument super() resolves against this class. Passing an
        # unrelated class (e.g. SimpleModel) to super() raises TypeError,
        # because instances of this class are not instances of it.
        super().__init__()
        mlp = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim, output_dim)
        )
        # Only the wrapped module is kept as an attribute.
        # NOTE(review): how MLP_SR exposes the inner layers (and therefore the
        # resulting state_dict key names) is defined by interpretsr — confirm
        # before loading weights saved from SimpleModel.
        self.mlp = mlp_sr.MLP_SR(mlp)

    def forward(self, x):
        """Run `x` through the wrapped MLP and return its output."""
        x = self.mlp(x)
        return x