In [1]:
import torch
from dataclasses import dataclass


In [17]:
from torch import nn
from torch.nn.functional import relu

@dataclass
class Config:
    n_features: int
    n_hidden: int

  # We optimize n_instances models in a single training loop
  # to let us sweep over sparsity or importance curves
  # efficiently.



class DataGenerator:

    def __init__(self, n_features, feature_probability, importance, device):
        self.n_features = n_features
        self.device = device

        if feature_probability is None:
            feature_probability = torch.ones(())
        self.feature_probability = feature_probability
        
        if importance is None:
            importance = torch.ones(())
    
        self.importance = importance.to(device)
        self.feature_probability = feature_probability.to(device)
    

    def generate_batch(self, n_batch):
        feat = torch.rand((n_batch, self.n_features), device=self.device)
        batch = torch.where(
            torch.rand((n_batch, self.n_features), device=self.device) <= self.feature_probability,
            feat,
            torch.zeros((), device=self.device),
        )
        return batch


class Model(nn.Module):
    def __init__(
        self, 
        config, 
        device='cuda'
    ):
        super().__init__()
        self.config = config
        self.W = nn.Parameter(
            torch.empty((config.n_features, config.n_hidden), device=device)
        )
        nn.init.xavier_normal_(self.W)
        self.b = nn.Parameter(torch.zeros((config.n_features), device=device))


    
    def forward(self, features):
        # features: [..., n_features]
        # W: [n_features, n_hidden]
        hidden = features @ self.W
        out = (hidden @ self.W.T) + self.b
        return relu(out)

  

In [18]:
n_features = 24
n_hidden = 8
device = 'cuda:0'

probs = torch.ones((24,)) * .2

config = Config(n_features, n_hidden)
generator = DataGenerator(n_features, probs, None, device)


In [19]:
model = Model(config, device)

In [20]:
batch = generator.generate_batch(100)
model(batch)

tensor([[0.3108, 0.0932, 0.0000,  ..., 0.3805, 0.0000, 0.0000],
        [0.1109, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0774, 0.0638,  ..., 0.0000, 0.0094, 0.0000],
        ...,
        [0.0077, 0.0000, 0.3032,  ..., 0.1546, 0.1707, 0.0857],
        [0.0000, 0.0000, 0.0109,  ..., 0.0000, 0.0403, 0.0240],
        [0.0022, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]],
       device='cuda:0', grad_fn=<ReluBackward0>)