In [1]:
import gc
from typing import List

import numpy as np
import pandas as pd
from scipy import stats

import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib.animation as animation
from IPython.display import HTML, display

from sklearn.datasets import fetch_covtype
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, classification_report, log_loss
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, MinMaxScaler

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch import Tensor
from torch.optim import Optimizer
from torch.autograd import Function
import torch.utils.benchmark as benchmark
import torch._dynamo
import torch._inductor.metrics as metrics

from torchinfo import summary

In [12]:
import time

def benchmark(model, x, y, optimizer, loss_fn, iterations, device):
    """
    Time the forward and backward pass of a PyTorch model and print the averages.

    Each measured iteration runs a full training step (zero_grad, forward,
    loss, backward, optimizer step), but only the ``model(x)`` call and the
    ``loss.backward()`` call are timed. Five untimed warm-up steps run first.
    The model's parameters are updated by both the warm-up and measured steps.

    Note: this function shares its name with the ``torch.utils.benchmark``
    module alias imported at the top of the notebook and replaces that name
    once this cell has run.

    Args:
        model (nn.Module): The model to benchmark.
        x (torch.Tensor): The input data.
        y (torch.Tensor): The target data.
        optimizer (torch.optim.Optimizer): The optimizer for the backward pass.
        loss_fn: The loss function, called as ``loss_fn(output, y)``.
        iterations (int): The number of timed iterations; must be positive.
        device (str | torch.device): Device the tensors live on, e.g. 'cpu',
            'cuda', 'cuda:0' or a ``torch.device``.

    Raises:
        ValueError: If ``iterations`` is not positive.
    """
    if iterations <= 0:
        raise ValueError(f"iterations must be a positive integer, got {iterations}")

    # Accept 'cuda', 'cuda:0' and torch.device objects alike.
    device_name = str(device)
    use_cuda = device_name.startswith('cuda')

    def _sync():
        # CUDA kernels are launched asynchronously; wait for them so that
        # wall-clock timestamps bracket the work actually being measured.
        if use_cuda:
            torch.cuda.synchronize()

    forward_times = []
    backward_times = []

    # Warm-up iterations to handle initial CUDA overhead, etc.
    for _ in range(5):
        optimizer.zero_grad()
        output = model(x)
        loss = loss_fn(output, y)
        loss.backward()
        optimizer.step()

    # Ensure all CUDA operations are synchronized before starting the timer
    _sync()

    # --- Benchmarking Loop ---
    for _ in range(iterations):
        optimizer.zero_grad()

        # Drain zero_grad and the previous iteration's optimizer.step() so
        # their queued kernels are not billed to the forward pass.
        _sync()

        # Benchmark Forward Pass
        start_time = time.perf_counter()
        output = model(x)
        _sync()
        forward_times.append(time.perf_counter() - start_time)

        loss = loss_fn(output, y)
        # Drain the loss kernels so they are not billed to the backward pass.
        _sync()

        # Benchmark Backward Pass
        start_time = time.perf_counter()
        loss.backward()
        _sync()
        backward_times.append(time.perf_counter() - start_time)

        optimizer.step()

    # --- Report Results ---
    avg_forward = sum(forward_times) / iterations * 1000  # Convert to ms
    avg_backward = sum(backward_times) / iterations * 1000  # Convert to ms

    print("-" * 50)
    print(f"Benchmarking Results on '{device_name.upper()}'")
    print(f"Iterations: {iterations}")
    print(f"Average Forward Pass Time:  {avg_forward:.4f} ms")
    print(f"Average Backward Pass Time: {avg_backward:.4f} ms")
    print("-" * 50)

In [28]:
# Benchmark configuration: everything a reader might want to tune lives here.
SEED = 0
BATCH_SIZE = 1000
NUM_FEATURES = 10
NUM_CONTROL_POINTS = 1000
ITERATIONS = 100
LEARNING_RATE = 1e-3
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

print(f"Using device: {DEVICE}")

# Seed the generator so the benchmark inputs are identical after every
# kernel restart; otherwise each run times the models on different data.
torch.manual_seed(SEED)

# Create random input and target tensors (drawn on CPU, then moved, so the
# values do not depend on which device is selected)
x_input = torch.randn(BATCH_SIZE, NUM_FEATURES).to(DEVICE)
y_target = torch.randn(BATCH_SIZE, NUM_FEATURES).to(DEVICE)

Using device: cpu


In [29]:
class CustomActivation(nn.Module):
    """
    Learnable per-feature piecewise-linear activation evaluated by table lookup.

    Each feature owns ``num_control_points`` knots placed at the integers
    ``0 .. num_control_points - 1``. The input is rescaled per feature so that
    the batch minimum maps to knot 0 and the batch maximum maps (just below)
    the last knot, and the output is the linear interpolation between the two
    neighbouring knot values.

    The value at knot ``k`` of a feature is
    ``sum_i r_weight[i] * relu(k - i) + sum_i l_weight[i] * relu(i - k)``,
    computed for all knots at once with cumulative sums in
    ``get_interp_tensor``.

    Args:
        num_features (int): Size of the last dimension of the input.
        num_control_points (int): Number of knots per feature (at least 2).
        eps (float): Added to the per-feature range to avoid division by zero.

    Raises:
        ValueError: If ``num_control_points`` is smaller than 2.
    """

    def __init__(self, num_features, num_control_points=1000, eps=1e-6):
        super().__init__()
        if num_control_points < 2:
            # Interpolation always reads knots `lower` and `lower + 1`.
            raise ValueError(
                f"num_control_points must be at least 2, got {num_control_points}"
            )

        # Per-feature input range; filled in by the first forward pass.
        self.register_buffer("mins", None)
        self.register_buffer("maxs", None)

        self.eps = eps
        self.num_control_points = num_control_points

        self.r_weight = nn.Parameter(torch.zeros(num_features, num_control_points))  # (num_features, num_control_points)
        self.l_weight = nn.Parameter(torch.zeros(num_features, num_control_points))  # (num_features, num_control_points)

        # Knot positions 0 .. num_control_points - 1.
        self.register_buffer("local_bias", torch.arange(num_control_points))  # (num_control_points,)
        # Start row of each feature inside the flattened interpolation table.
        self.register_buffer("feature_offset", torch.arange(num_features).view(1, -1) * self.num_control_points)  # (1, num_features)

    def forward(self, x):
        """
        Apply the activation.

        Args:
            x (torch.Tensor): Input of shape (batch_size, num_features).

        Returns:
            torch.Tensor: Output of shape (batch_size, num_features).
        """
        # In training mode (or before any range has been recorded) the range
        # comes from the current batch; otherwise the stored range is reused.
        if self.training or self.mins is None or self.maxs is None:
            mins = x.amin(dim=0, keepdim=True)  # (1, num_features)
            maxs = x.amax(dim=0, keepdim=True)  # (1, num_features)
            # Store detached copies: keeping the autograd history in a buffer
            # would pin the whole graph of this batch in memory and break
            # copy.deepcopy of the module. The live (non-detached) statistics
            # are still used below, so gradients in this pass are unchanged.
            self.mins = mins.detach()
            self.maxs = maxs.detach()
        else:
            mins, maxs = self.mins, self.maxs

        # Rescale to knot coordinates.
        x = (x - mins) / (maxs - mins + self.eps) * (self.num_control_points - 1)  # (batch_size, num_features)

        # TODO: may change to feature-major order (num_features, batch_size) since that may help with memory access patterns (improved locality)

        # Clamp so that `lower + 1` is always a valid knot of the same feature;
        # inputs outside the recorded range are extrapolated from the edge segment.
        lower_indices_float = x.floor().clamp(0, self.num_control_points - 2)  # (batch_size, num_features)
        lower_indices = lower_indices_float.long() + self.feature_offset  # (batch_size, num_features)

        indices = torch.stack((lower_indices, lower_indices + 1), dim=-1)  # (batch_size, num_features, 2)
        vals = F.embedding(indices, self.get_interp_tensor())  # (batch_size, num_features, 2, 1)

        lower_val, upper_val = vals.squeeze(-1).unbind(-1)  # each: (batch_size, num_features)
        return torch.lerp(lower_val, upper_val, x - lower_indices_float)  # (batch_size, num_features)

    def get_interp_tensor(self):
        """
        Build the flattened table of knot values for every feature.

        Returns:
            torch.Tensor: Shape (num_features * num_control_points, 1); row
            ``f * num_control_points + k`` holds the value of feature ``f`` at knot ``k``.
        """
        cs_r_weight = torch.cumsum(self.r_weight, dim=1)  # (num_features, num_control_points)
        cs_l_weight = torch.cumsum(self.l_weight, dim=1)  # (num_features, num_control_points)

        cs_r_weight_bias_prod = torch.cumsum(self.r_weight * self.local_bias, dim=1)  # type: ignore (num_features, num_control_points)
        cs_l_weight_bias_prod = torch.cumsum(self.l_weight * self.local_bias, dim=1)  # type: ignore (num_features, num_control_points)

        # sum_{i <= k} r_i * (k - i)
        r_interp = (self.local_bias * cs_r_weight - cs_r_weight_bias_prod)  # type: ignore (num_features, num_control_points)
        # sum_{i > k} l_i * (i - k), using (total - prefix) as the suffix sum
        l_interp = (cs_l_weight_bias_prod[..., -1:] - cs_l_weight_bias_prod) - self.local_bias * (cs_l_weight[..., -1:] - cs_l_weight)  # type: ignore (num_features, num_control_points)
        return (r_interp + l_interp).view(-1, 1)  # (num_features * num_control_points, 1)
        

In [30]:
# --- Setup ---
# Build the activation under test, place it on the selected device and
# switch it to training mode
model = CustomActivation(num_features=NUM_FEATURES, num_control_points=NUM_CONTROL_POINTS)
model = model.to(DEVICE)
model.train()

# Mean-squared-error objective optimised with Adam
loss_function = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# --- Run Benchmark ---
benchmark(
    model,
    x_input,
    y_target,
    optimizer,
    loss_function,
    ITERATIONS,
    DEVICE,
)

--------------------------------------------------
Benchmarking Results on 'CPU'
Iterations: 100
Average Forward Pass Time:  0.3904 ms
Average Backward Pass Time: 0.5963 ms
--------------------------------------------------


In [None]:
class CustomActivation(nn.Module):
    """
    Learnable per-feature activation built from a dense sum of shifted ReLUs.

    For every feature the output is
    ``sum_c pos_weight[c] * relu(x + b_c) + sum_c neg_weight[c] * relu(-(x + b_c))``
    where the shifts ``b_c`` are ``num_control_points`` evenly spaced values
    in [-5, 5].

    Args:
        num_features (int): Size of the last dimension of the input.
        num_control_points (int): Number of shifted ReLU pairs per feature.
    """

    def __init__(self, num_features, num_control_points):
        super().__init__()
        # Materialise one row of shifts per feature. An `.expand()`ed view
        # would alias a single row of memory, which rejects in-place writes
        # such as the `copy_` performed by `load_state_dict`.
        shifts = torch.linspace(-5, 5, num_control_points)
        self.register_buffer('local_bias', shifts.repeat(num_features, 1))  # (num_features, num_control_points)

        self.pos_weight = nn.Parameter(torch.zeros(num_features, num_control_points))
        self.neg_weight = nn.Parameter(torch.zeros(num_features, num_control_points))

        # The three attributes below are not read by forward().
        self.global_weight = nn.Parameter(torch.zeros(1, num_features))
        self.global_bias = nn.Parameter(torch.zeros(1, num_features))

        self.x_nonlinear_mean = None

    def forward(self, x):
        """
        Apply the activation.

        Args:
            x (torch.Tensor): Input whose last dimension is ``num_features``.

        Returns:
            torch.Tensor: Tensor with the same shape as ``x``.
        """
        # Broadcast every input value against all shifts of its feature.
        x_shifted = x.unsqueeze(-1) + self.local_bias  # (..., num_features, num_control_points)
        x_nonlinear = (F.relu(x_shifted) * self.pos_weight).sum(dim=-1) + (F.relu(-x_shifted) * self.neg_weight).sum(dim=-1)
        return x_nonlinear

In [32]:
# --- Setup ---
# Create the activation, move it to the selected device, enable training mode
activation_kwargs = dict(num_features=NUM_FEATURES, num_control_points=NUM_CONTROL_POINTS)
model = CustomActivation(**activation_kwargs).to(DEVICE)
model.train()  # training mode

# Loss and optimizer used for the timed training steps
loss_function = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# --- Run Benchmark ---
benchmark(
    model, x_input, y_target,
    optimizer, loss_function,
    ITERATIONS, DEVICE,
)

--------------------------------------------------
Benchmarking Results on 'CPU'
Iterations: 100
Average Forward Pass Time:  12.1262 ms
Average Backward Pass Time: 5.7455 ms
--------------------------------------------------
