In [3]:
import gc
from typing import List

import numpy as np
import pandas as pd
from scipy import stats

import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib.animation as animation
from IPython.display import HTML, display

from sklearn.datasets import fetch_covtype
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, classification_report, log_loss
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, MinMaxScaler

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch import Tensor
from torch.optim import Optimizer
from torch.autograd import Function
import torch.utils.benchmark as benchmark
import torch._dynamo
import torch._inductor.metrics as metrics

from torchinfo import summary

In [None]:
class CustomActivation(nn.Module):
    """Learnable per-feature piecewise-linear activation.

    Every feature has ``num_control_points`` control points placed on the
    integer grid ``0 .. num_control_points - 1``. An input batch is min-max
    rescaled per feature onto that grid and the output is the linear
    interpolation between the two neighbouring control-point values.

    The control-point values are not stored directly; they are derived from
    ``r_weight`` and ``l_weight`` by :meth:`get_interp_tensor`.

    Args:
        num_features: Size of the feature dimension of the input.
        num_control_points: Number of grid points per feature (must be >= 2).
        eps: Added to the per-feature range ``max - min`` to avoid division by zero.

    Raises:
        ValueError: If ``num_control_points`` is smaller than 2 (there would be
            no segment to interpolate on).
    """

    def __init__(self, num_features, num_control_points=1000, eps=1e-6):
        super().__init__()
        if num_control_points < 2:
            raise ValueError(
                f"num_control_points must be at least 2, got {num_control_points}"
            )

        # Per-feature input range. Stays None until the first forward pass
        # (or until a checkpoint containing the statistics is loaded).
        self.register_buffer("mins", None)
        self.register_buffer("maxs", None)

        self.eps = eps
        self.num_control_points = num_control_points

        self.r_weight = nn.Parameter(torch.zeros(num_features, num_control_points))  # (num_features, num_control_points)
        self.l_weight = nn.Parameter(torch.zeros(num_features, num_control_points))  # (num_features, num_control_points)

        # Grid positions 0 .. num_control_points - 1.
        self.register_buffer("local_bias", torch.arange(num_control_points))  # (num_control_points,)
        # Start row of each feature inside the flattened table returned by get_interp_tensor().
        self.register_buffer("feature_offset", torch.arange(num_features).view(1, -1) * self.num_control_points)  # (1, num_features)

    def forward(self, x):
        """Apply the activation.

        In training mode (or whenever no statistics are stored yet) the
        per-feature min/max are taken from ``x`` itself and remembered; in eval
        mode the remembered values are used.

        Args:
            x: Tensor of shape ``(batch_size, num_features)``.

        Returns:
            Tensor of shape ``(batch_size, num_features)``.
        """
        # x: (batch_size, num_features)

        if self.training or self.mins is None or self.maxs is None:
            mins = x.amin(dim=0, keepdim=True)  # (1, num_features)
            maxs = x.amax(dim=0, keepdim=True)  # (1, num_features)
            # Keep only detached copies in the buffers: storing the live tensors
            # would pin this batch's autograd graph and turn the buffers into
            # non-leaf tensors (which breaks copy.deepcopy and makes a later
            # eval-mode backward walk a stale graph). The current batch still
            # uses the non-detached values, so gradients are unchanged.
            self.mins = mins.detach()
            self.maxs = maxs.detach()
        else:
            mins, maxs = self.mins, self.maxs

        x = (x - mins) / (maxs - mins + self.eps) * (self.num_control_points - 1)  # (batch_size, num_features)

        # TODO: may change to feature-major order (num_features, batch_size) since that may help with memory access patterns (improved locality)

        # Clamp to the last valid segment start so that lower + 1 is always a valid grid point.
        lower_indices_float = x.floor().clamp(0, self.num_control_points - 2)  # (batch_size, num_features)
        lower_indices = lower_indices_float.long() + self.feature_offset  # (batch_size, num_features)

        indices = torch.stack((lower_indices, lower_indices + 1), dim=-1)  # (batch_size, num_features, 2)
        vals = F.embedding(indices, self.get_interp_tensor())  # (batch_size, num_features, 2, 1)

        lower_val, upper_val = vals.squeeze(-1).unbind(-1)  # each: (batch_size, num_features)
        # Inputs outside the stored [min, max] range give a weight outside [0, 1],
        # i.e. the first/last segment is extended linearly.
        return torch.lerp(lower_val, upper_val, x - lower_indices_float)  # (batch_size, num_features)

    def get_interp_tensor(self):
        """Build the flattened table of control-point values.

        With ``r = r_weight[f]``, ``l = l_weight[f]`` and grid position ``j``,
        the value of control point ``j`` of feature ``f`` is::

            sum_i r[i] * max(j - i, 0)  +  sum_i l[i] * max(i - j, 0)

        computed with cumulative sums instead of an explicit double loop.

        Returns:
            Tensor of shape ``(num_features * num_control_points, 1)``; row
            ``f * num_control_points + j`` holds control point ``j`` of feature ``f``.
        """
        cs_r_weight = torch.cumsum(self.r_weight, dim=1)  # (num_features, num_control_points)
        cs_l_weight = torch.cumsum(self.l_weight, dim=1)  # (num_features, num_control_points)

        cs_r_weight_bias_prod = torch.cumsum(self.r_weight * self.local_bias, dim=1)  # type: ignore (num_features, num_control_points)
        cs_l_weight_bias_prod = torch.cumsum(self.l_weight * self.local_bias, dim=1)  # type: ignore (num_features, num_control_points)

        # sum_{i <= j} r[i] * (j - i)
        r_interp = (self.local_bias * cs_r_weight - cs_r_weight_bias_prod)  # type: ignore (num_features, num_control_points)
        # sum_{i > j} l[i] * (i - j); "total - prefix" yields the suffix sums.
        l_interp = (cs_l_weight_bias_prod[..., -1:] - cs_l_weight_bias_prod) - self.local_bias * (cs_l_weight[..., -1:] - cs_l_weight)  # type: ignore (num_features, num_control_points)
        return (r_interp + l_interp).view(-1, 1)  # (num_features * num_control_points, 1)

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        """Allow checkpoints that contain ``mins``/``maxs`` to load into a fresh module.

        The two buffers start out as None, and the default loader ignores None
        buffers, so a checkpoint saved after a forward pass would otherwise be
        rejected for containing unexpected keys. Placeholders of the stored
        shape are allocated here; the default logic then copies the values in.
        """
        for name in ("mins", "maxs"):
            saved = state_dict.get(prefix + name)
            if saved is not None and getattr(self, name) is None:
                placeholder = torch.empty(
                    saved.shape, dtype=self.r_weight.dtype, device=self.r_weight.device
                )
                setattr(self, name, placeholder)
        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict,
            missing_keys, unexpected_keys, error_msgs,
        )
        