# Imports

In [None]:
import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
import einops
from tqdm.notebook import tqdm
from dataclasses import dataclass

# Model definition

In [None]:
@dataclass
class ModelsConfig:
    """Shape of the grid of toy ReLU models that are trained side by side.

    ReLUModels allocates an independent weight matrix and bias for every
    (model, density, importance) combination described here.
    """

    # Independent replicas per grid cell; averaging over them reduces noise (as in the paper).
    n_models: int = 10
    # Number of grid points along the feature-density axis (the paper sweeps 50).
    n_densities: int = 50
    # Number of grid points along the relative-importance axis (the paper sweeps 50).
    n_importances: int = 50
    # "Toy model of the toy model": 2 features -> 1 hidden dimension -> 2 features.
    d_feature: int = 2
    d_model: int = 1
    # Device on which ReLUModels allocates its parameters.
    device: str = "cuda"

In [None]:
class ReLUModels(nn.Module):
    """A stacked sweep of independent tied-weight toy models.

    Every (model, density, importance) grid cell owns its own weight matrix
    ``W[m, d, i]`` of shape ``(d_feature, d_model)`` and bias ``b[m, d, i]`` of
    shape ``(d_feature,)``. Each cell computes ``ReLU(x W W^T + b)``: ``W`` maps
    the features down to the hidden space, and the same ``W`` (transposed) maps
    them back up. All cells are evaluated together in one batched einsum.
    """

    def __init__(self, cfg: "ModelsConfig"):
        super().__init__()

        self.cfg = cfg
        n_models = cfg.n_models
        n_densities = cfg.n_densities
        n_importances = cfg.n_importances
        d_feature = cfg.d_feature
        d_model = cfg.d_model
        device = cfg.device

        self.W = nn.Parameter(torch.empty(n_models, n_densities, n_importances, d_feature, d_model, device=device))
        # Xavier (Glorot) normal initialization, computed for each cell's own
        # d_feature -> d_model map: std = sqrt(2 / (fan_in + fan_out)).
        # nn.init.xavier_normal_ must not be applied to this stacked 5-D tensor
        # directly. It would take fan_in/fan_out from dims 1 and 0 (the density
        # and model sweep axes), each multiplied by the product of all remaining
        # dims. That yields a std orders of magnitude too small: ~0.018 instead
        # of ~0.82 with the default config.
        xavier_std = (2.0 / (d_feature + d_model)) ** 0.5
        nn.init.normal_(self.W, mean=0.0, std=xavier_std)
        self.b = nn.Parameter(torch.zeros(n_models, n_densities, n_importances, d_feature, device=device))

    def forward(self, x):
        """Reconstruct ``x`` through the d_model bottleneck, independently per grid cell.

        Args:
            x: tensor of shape (n_models, n_densities, n_importances, batch, d_feature).

        Returns:
            Tensor of the same shape as ``x``: ReLU(x W W^T + b).
        """
        # Project down: (..., batch, d_feature) -> (..., batch, d_model).
        h = einops.einsum(
            x, self.W,
            "model density importance batch d_feature, model density importance d_feature d_model -> model density importance batch d_model"
        )
        # Project back up with the same (tied) weights and add the per-cell bias,
        # then apply the ReLU. unsqueeze inserts a batch axis so b broadcasts over samples.
        out = F.relu(
            einops.einsum(
                h, self.W,
                "model density importance batch d_model, model density importance d_feature d_model -> model density importance batch d_feature"
            ) + self.b.unsqueeze(-2)
        )
        return out

# Data generation

In [None]:
def generate_batch(cfg, batch_size, densities, device='cuda'):
    """Sample one batch of sparse synthetic features for every grid cell.

    Each feature value is drawn uniformly from [0, 1). It is kept only when an
    independent uniform draw falls below its cell's density; otherwise it is zeroed.

    Args:
        cfg: config providing n_models, n_densities, n_importances and d_feature.
        batch_size: number of samples per grid cell.
        densities: 1-D tensor of keep probabilities, broadcast along the density
            axis, so it must have cfg.n_densities elements (or exactly one) and
            live on ``device``.
        device: device on which the random tensors are created.

    Returns:
        Tensor of shape (n_models, n_densities, n_importances, batch_size, d_feature).
    """
    shape = (cfg.n_models, cfg.n_densities, cfg.n_importances, batch_size, cfg.d_feature)
    values = torch.rand(shape, device=device)
    keep_draws = torch.rand(shape, device=device)
    # Put the densities explicitly on axis 1 (the density axis) of the 5-D batch.
    keep = keep_draws < densities.view(1, -1, 1, 1, 1)
    return values * keep

# Loss function

In [None]:
class ImportanceWeightedMSELoss(nn.Module):
    """Squared error weighted per feature by importance, averaged over batch and features.

    ``forward(input, target, importances)`` expects ``input`` and ``target`` with
    trailing axes ``(n_importances, batch, d_feature)`` and ``importances`` of
    shape ``(n_importances, d_feature)``. It returns a tensor with the trailing
    ``(batch, d_feature)`` axes removed, i.e. one loss per leading index.
    """

    def forward(self, input, target, importances):
        # Insert a singleton batch axis so each importance row applies to every
        # sample: (n_importances, d_feature) -> (n_importances, 1, d_feature).
        weights = importances.unsqueeze(-2)
        weighted = (target - input).pow(2) * weights
        # Average over the batch and feature axes, keeping everything in front.
        return weighted.mean(dim=(-2, -1))

# Training loop

In [None]:
@dataclass
class TrainingArgs:
    """Optimization settings consumed by train()."""

    # Learning rate handed to the optimizer.
    lr: float = 0.001
    # Number of optimization steps; every step draws a fresh batch.
    n_epochs: int = 5_000
    # Intended spacing, in epochs, between progress log lines.
    log_interval: int = 150
    # Samples drawn per grid cell on each step.
    batch_size: int = 1_024

In [None]:
def train(
    models: ReLUModels,
    training_args: TrainingArgs,
    optimizer: type[optim.Optimizer],
    loss_fn: type[nn.Module],
    densities,
    importances,
    device='cuda',
):
    """Jointly train every model of the sweep to reconstruct its input.

    Args:
        models: the batched ReLUModels; its parameters are updated in place.
        training_args: learning rate, number of epochs, batch size and logging interval.
        optimizer: optimizer *class* (e.g. ``optim.AdamW``). It is instantiated
            here with ``models.parameters()`` and ``training_args.lr``.
        loss_fn: loss *class*. It is instantiated here with no arguments and
            called as ``loss_fn(prediction, target, importances)``.
        densities: one feature density per density grid cell, forwarded to generate_batch.
        importances: 1-D tensor with one relative importance of the second
            feature per importance grid cell.
        device: device on which batches and the importance weights are created.
    """
    optimizer = optimizer(models.parameters(), lr=training_args.lr)
    loss_fn = loss_fn()

    # Build (n_importances, 2) weights: column 0 is the first feature's weight
    # (always 1) and column 1 is the second feature's importance relative to it.
    importances = torch.stack((torch.ones(models.cfg.n_importances, device=device), importances), dim=1)

    log_interval = training_args.log_interval
    for epoch in tqdm(range(1, training_args.n_epochs + 1)):
        batch = generate_batch(models.cfg, training_args.batch_size, densities, device)
        # The models reconstruct their input, so the batch is also the target.
        # Average the per-cell losses over models, densities and importances.
        loss = loss_fn(models(batch), batch, importances).mean()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Log every `log_interval` epochs; a non-positive interval disables
        # logging rather than raising ZeroDivisionError.
        if log_interval > 0 and epoch % log_interval == 0:
            print(f"Epoch [{epoch}/{training_args.n_epochs}]: loss = {loss.item():.6f}")

# Initializing model

In [None]:
# Seed before constructing the models so the random weight initialization is
# reproducible as well; a seed set only after this cell would not affect it.
torch.manual_seed(42)
models_cfg = ModelsConfig()

models = ReLUModels(models_cfg)

# Training model

In [None]:
# Run on the GPU: tqdm's early estimate put CPU training at more than 7 hours.
device = "cuda"

In [None]:
# Both features' density is log-spaced from 0.01 to 1
densities = 10 ** torch.linspace(-2, 0, 50, device=device)
# Relative importance of the second feature is log-spaced from 0.1 to 10
importances = 10 ** torch.linspace(-1, 1, 50, device=device)

In [None]:
# Seed for reproducibility
torch.manual_seed(42)

In [None]:
# Train the whole sweep with the default hyperparameters, the AdamW optimizer
# and the importance-weighted reconstruction loss.
training_args = TrainingArgs()
train(
    models,
    training_args,
    optim.AdamW,
    ImportanceWeightedMSELoss,
    densities,
    importances,
    device,
)
print("Training complete.")

In [None]:
import os

from google.colab import drive

# Persist the trained weights to Google Drive so they outlive the Colab session.
drive.mount('/content/drive')
save_path = '/content/drive/My Drive/toy-models-superpos/relus-across-densities-and-importances.pth'
# torch.save does not create missing parent directories, so make sure the folder exists.
os.makedirs(os.path.dirname(save_path), exist_ok=True)
torch.save(models.state_dict(), save_path)

Mounted at /content/drive


# Plotting results

In [None]:
from google.colab import drive

# Reload previously saved weights from Google Drive (e.g. in a fresh session).
drive.mount('/content/drive')
# weights_only=True limits unpickling to tensors and plain containers, which is
# all a state_dict needs. It avoids executing arbitrary pickled code, and recent
# PyTorch versions warn when the argument is left unset.
models.load_state_dict(torch.load('/content/drive/My Drive/toy-models-superpos/relus-across-densities-and-importances.pth', weights_only=True))

Mounted at /content/drive


  models.load_state_dict(torch.load('/content/drive/My Drive/toy-models-superpos/relus-across-densities-and-importances.pth'))


<All keys matched successfully>

In [None]:
# TODO: write the visualize() function.
# NOTE(review): `visualize` is not defined in any of the cells above, so running
# this cell raises NameError until it is implemented.
visualize(models, densities, importances)