<a href="https://colab.research.google.com/github/jbaremoney/goldenTicket/blob/main/explicit_router.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install medmnist --q
!pip install -U torchvision --q
!pip install sympy==1.13.3 --q # was getting weird dependency error

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/115.9 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m115.9/115.9 kB[0m [31m5.9 MB/s[0m eta [36m0:00:00[0m
[?25h

In [2]:
from tqdm import tqdm
import numpy as np
import torch
print(torch.__version__)
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torchvision.transforms as transforms
from torchvision.transforms import v2
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
import random
import torch.autograd as autograd
import math
import medmnist
from medmnist import INFO, Evaluator
from matplotlib import pyplot as plt
from itertools import combinations_with_replacement

2.8.0+cu126


In [3]:
!nvidia-smi

Thu Sep 25 21:53:29 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla T4                       Off |   00000000:00:04.0 Off |                    0 |
| N/A   50C    P8             10W /   70W |       0MiB /  15360MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [4]:
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)

print(f"Using {device} device")

Using cuda device


In [5]:
def getTrainingDataLoaders(dataset_name, download=True, BATCH_SIZE=128):
    """
    Handling data preprocessing & loading.

    Args:
      dataset_name (str): name of dataset to be loaded
      download (bool): Whether or not you'll download the dataset locally
      BATCH_SIZE (int): batch size to be used during training

    Returns:
      info (dict): dictionary directly from medmnist.INFO containing metadata
      task (str): string indicating type of task ie 'binary-class', 'multi-class'
      n_classes (int): int indicating number of classes in dataset
      train_loader (DataLoader): provides iterator over training dataset, provides batches
      train_loader_at_eval (DataLoader): evaluation version of train_loader, double batch size
      test_loader (DataLoader): test version of train_loader, similar to train_loader_at_eval

    """
    data_flag = dataset_name

    info = INFO[data_flag]
    task = info['task']
    n_channels = info['n_channels']
    n_classes = len(info['label'])

    DataClass = getattr(medmnist, info['python_class'])

    # RGBtransform = transforms.Lambda(lambda x: x.repeat(3, 1, 1) if x.size(0)==1 else x)

    # preprocessing
    if n_channels == 3:
        data_transform = v2.Compose([
            v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)]),
            v2.Normalize(mean=[.5], std=[.5])
        ])

    if n_channels == 1:
        data_transform = v2.Compose([
            v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)]),
            v2.RGB(),
            v2.Normalize(mean=[.5], std=[.5])
        ])

    # load the data
    train_dataset = DataClass(split='train', transform=data_transform, download=download)
    test_dataset = DataClass(split='test', transform=data_transform, download=download)

    pil_dataset = DataClass(split='train', download=download)

    # encapsulate data into dataloader form
    train_loader = data.DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    train_loader_at_eval = data.DataLoader(dataset=train_dataset, batch_size=2*BATCH_SIZE, shuffle=False)
    test_loader = data.DataLoader(dataset=test_dataset, batch_size=2*BATCH_SIZE, shuffle=False)
    return info, task, n_classes, train_loader, train_loader_at_eval, test_loader

In [6]:
import torch
from torch.utils.data import DataLoader, ConcatDataset, Dataset, Sampler
from torchvision.transforms import v2
import random
import math

# --- your PaddedMedMNIST as-is ---
TARGET_LENGTH = 14

class PaddedMedMNIST(Dataset):
    def __init__(self, dataset, target_length=TARGET_LENGTH):
        self.dataset = dataset
        self.target_length = target_length

    def __getitem__(self, idx):
        img, label = self.dataset[idx]
        label = torch.tensor(label, dtype=torch.float32)
        padded = torch.zeros(self.target_length, dtype=torch.float32)
        padded[:label.numel()] = label
        return img, padded

    def __len__(self):
        return len(self.dataset)


# --- Balanced batch sampler ---
class BalancedBatchSampler(Sampler):
    """
    Yields batches with (approximately) equal counts from each sub-dataset inside a ConcatDataset.

    Args:
        concat_ds: ConcatDataset([...])
        batch_size: total batch size
        strategy: 'upsample' (with replacement so every batch is balanced across entire epoch)
                  or 'min' (stop when any dataset can't fill its share; no replacement)
        drop_last: drop final incomplete batch (recommended True)
        generator: optional torch.Generator for reproducibility
    """
    def __init__(self, concat_ds: ConcatDataset, batch_size: int,
                 strategy: str = "upsample", drop_last: bool = True,
                 generator: torch.Generator | None = None):
        assert isinstance(concat_ds, ConcatDataset), "BalancedBatchSampler requires a ConcatDataset"
        self.concat_ds = concat_ds
        self.batch_size = int(batch_size)
        self.K = len(concat_ds.datasets)
        assert self.K >= 2, "Need at least 2 datasets to balance."
        assert self.batch_size >= self.K, "batch_size must be >= number of datasets"

        # Per-dataset shares sum to batch_size (distribute any remainder to the first few datasets).
        base = self.batch_size // self.K
        extra = self.batch_size % self.K
        self.shares = [base + (1 if i < extra else 0) for i in range(self.K)]

        self.lengths = [len(ds) for ds in concat_ds.datasets]
        self.offsets = []
        running = 0
        for L in self.lengths:
            self.offsets.append(running)
            running += L

        self.strategy = strategy
        self.drop_last = drop_last
        self.gen = generator

        # Precompute how many *balanced* batches we can make without replacement (for 'min').
        self.max_full_batches_min = min(
            (L // s) if s > 0 else 0
            for L, s in zip(self.lengths, self.shares)
        )

        # For 'upsample', define epoch length as number of batches ≈ total_len / batch_size
        total_len = sum(self.lengths)
        self.num_batches_upsample = max(1, total_len // self.batch_size)

    def __len__(self):
        if self.strategy == "min":
            return self.max_full_batches_min
        else:  # 'upsample'
            return self.num_batches_upsample

    def _randperm(self, n):
        if self.gen is None:
            return torch.randperm(n)
        return torch.randperm(n, generator=self.gen)

    def __iter__(self):
        # Build per-dataset index pools (local indices)
        pools = []
        for k, L in enumerate(self.lengths):
            order = self._randperm(L).tolist()
            pools.append(order)

        if self.strategy == "min":
            num_batches = self.max_full_batches_min
            for _ in range(num_batches):
                batch = []
                for k in range(self.K):
                    take = self.shares[k]
                    # pop 'take' elements from the pool (no replacement)
                    chosen_local = pools[k][:take]
                    pools[k] = pools[k][take:]
                    # map to global indices via offset
                    off = self.offsets[k]
                    batch.extend([off + i for i in chosen_local])
                yield batch
            # If not dropping last and any remainder exists (rare with this scheme), we could add a small final batch.
            # But by design with 'min' we usually keep batches uniform and drop incomplete ones.
            if not self.drop_last:
                # Attempt to form one last (possibly smaller) balanced-ish batch
                leftovers = []
                for k in range(self.K):
                    take = min(self.shares[k], len(pools[k]))
                    off = self.offsets[k]
                    leftovers.extend([off + i for i in pools[k][:take]])
                if len(leftovers) > 0:
                    yield leftovers

        else:  # 'upsample' with replacement from small datasets
            num_batches = self.num_batches_upsample
            for _ in range(num_batches):
                batch = []
                for k in range(self.K):
                    take = self.shares[k]
                    off = self.offsets[k]
                    pool = pools[k]
                    if len(pool) < take:
                        # Re-shuffle / top-up this pool
                        pool.extend(self._randperm(self.lengths[k]).tolist())
                    chosen_local = pool[:take]
                    del pool[:take]
                    batch.extend([off + i for i in chosen_local])
                yield batch


# --- Your loader builder, swapped to use the BalancedBatchSampler ---
def get_combined_medmnist_loader(dataset_names, batch_size=128, download=True, train=True,
                                 strategy: str = "upsample", drop_last: bool = True,
                                 num_workers: int = 4, pin_memory: bool = True):
    all_datasets = []

    for name in dataset_names:
        info = INFO[name]
        DataClass = getattr(medmnist, info['python_class'])
        n_channels = info['n_channels']

        if n_channels == 3:
            transform = v2.Compose([
                v2.ToImage(), v2.ToDtype(torch.float32, scale=True),
                v2.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5])
            ])
        else:
            transform = v2.Compose([
                v2.ToImage(), v2.ToDtype(torch.float32, scale=True),
                v2.RGB(),  # force 3 channels
                v2.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5])
            ])

        split = 'train' if train else 'test'
        raw_dataset = DataClass(split=split, transform=transform, download=download)
        padded_dataset = PaddedMedMNIST(raw_dataset)
        all_datasets.append(padded_dataset)

    combined_dataset = ConcatDataset(all_datasets)
    batch_sampler = BalancedBatchSampler(
        combined_dataset, batch_size=batch_size, strategy=strategy, drop_last=drop_last
    )

    # IMPORTANT: when using batch_sampler, do not also pass batch_size/shuffle/sampler
    loader = DataLoader(
        combined_dataset,
        batch_sampler=batch_sampler,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )
    return loader

In [None]:
# Define a training function that returns a list of the losses during training.
def trainit(model,
            NUM_EPOCHS,
            train_loader,
            optimizer,
            task,
            n_classes,
            return_losses=False,
            no_progress=False):
  # define loss function
  if task == "multi-label, binary-class":
      criterion = nn.BCEWithLogitsLoss()
  else:
      criterion = nn.CrossEntropyLoss()
  if return_losses:
      losses = []
  # iterate over epochs for training run
  for epoch in range(NUM_EPOCHS):
      model.train()
      if no_progress:
        loader = train_loader
      else:
        loader=tqdm(train_loader)
      for inputs, targets in loader:
          inputs  = inputs.to(device, non_blocking=True)
          targets = targets.to(device, non_blocking=True)
          # forward + backward + optimize
          optimizer.zero_grad()
          outputs = model(inputs)[:,0:n_classes]
          if task == 'multi-label, binary-class':
              targets = targets.to(torch.float32)
              loss = criterion(outputs, targets)
          else:
              targets = targets.squeeze(1)
              loss = criterion(outputs, targets)
          loss.backward()
          optimizer.step()
          if return_losses:
              losses.append(loss.item())
  if return_losses:
      return losses

# Define an evaluation function
def test(split,
         model,
         train_loader_at_eval,
         test_loader,
         task,
         n_classes,
         data_flag,
         return_metrics=False):
    # define loss function
    if task == "multi-label, binary-class":
        criterion = nn.BCEWithLogitsLoss()
    else:
        criterion = nn.CrossEntropyLoss()
    model.eval()
    y_true = torch.tensor([])
    y_score = torch.tensor([])

    data_loader = train_loader_at_eval if split == 'train' else test_loader

    with torch.no_grad():
        for inputs, targets in data_loader:
            # inputs  = inputs.to(device, non_blocking=True)
            # targets = targets.to(device, non_blocking=True)
            outputs = model(inputs)[:,0:n_classes]

            if task == 'multi-label, binary-class':
                targets = targets.to(torch.float32)
                outputs = outputs.softmax(dim=-1)
            else:
                targets = targets.squeeze(1)
                outputs = outputs.softmax(dim=-1)
                targets = targets.float().resize_(len(targets), 1)

            y_true = torch.cat((y_true, targets), 0)
            y_score = torch.cat((y_score, outputs), 0)

        y_true = y_true.numpy()
        y_score = y_score.detach().numpy()

        evaluator = Evaluator(data_flag, split)
        metrics = evaluator.evaluate(y_score)

        print('%s  auc: %.3f  acc:%.3f' % (split, *metrics))

        if return_metrics:
          return metrics

In [7]:
def signed_kaiming_constant_(tensor, a=0, mode='fan_in', nonlinearity='relu', k=0.5, sparsity=0):

    fan = nn.init._calculate_correct_fan(tensor, mode)  # calculating correct fan, depends on shape and type of nn
    gain = nn.init.calculate_gain(nonlinearity, a)
    std = (gain / math.sqrt(fan))
    # scale by (1/sqrt(k))
    if k != 0:
        std *= (1 / math.sqrt(k))

    with torch.no_grad():
        tensor.uniform_(-std, std)
        if sparsity > 0:
            mask = (torch.rand_like(tensor) > sparsity).float()  # Keeps (1 - sparsity)% weights

            tensor *= mask

        return tensor


class GetSubnet(autograd.Function):

    @staticmethod
    def forward(ctx, scores, k):

        # Get the subnetwork by sorting the scores and using the top k%
        out = scores.clone()
        _, idx = scores.flatten().sort() # idx will be the indices in sorted order
        j = int((1-k) * scores.numel())  # how many we're getting rid of

        # flat_out and out access the same memory.
        flat_out = out.flatten() # flattened scores

        # idx[:j] is arr of indices corresponding to j lowest popup scores
        flat_out[idx[:j]] = 0 # vectorized assigning them to 0
        flat_out[idx[j:]] = 1 # vectorized assigning to 1

        return out

    @staticmethod
    def backward(ctx, grad):

        # send the gradient g straight-through on the backward pass.
        return grad, None


class LinearSubnetDynamicMask(nn.Linear):
    def __init__(self, in_features, out_features, bias=True, k=0.5, init=signed_kaiming_constant_, **kwargs):
        super().__init__(in_features, out_features, bias if isinstance(bias, bool) else True, **kwargs)

        self.k = k

        # Initialize weights
        if init == signed_kaiming_constant_:
            init(self.weight, k=k)
        else:
            init(self.weight)

        self.weight.requires_grad_(False)
        if self.bias is not None:
            self.bias.requires_grad_(False)

    def forward(self, x, popups, popups_extra, bias_popups, bias_popups_extra):
        adj = GetSubnet.apply(
            torch.cat((popups.abs(),popups_extra.abs()),dim=-1), self.k
        )[:, :self.weight.shape[-1]]
        bias_adj = GetSubnet.apply(
            torch.cat((bias_popups.abs(),
                       bias_popups_extra.abs()), dim=-1), self.k
        )[:self.bias.shape[-1]]

        w = self.weight * adj
        b = self.bias * bias_adj
        return F.linear(x, w, b)


class Network(nn.Module):
    def __init__(self, layer_sizes, k=0.5, init=signed_kaiming_constant_, popups={}):
        super().__init__()
        self.flatten = nn.Flatten()

        self.layers = nn.ModuleList()
        for i, (in_f, out_f) in enumerate(layer_sizes):
            self.layers.append(LinearSubnet(in_f, out_f,k=k,init=init))
            if i < len(layer_sizes) - 1:
                self.layers.append(nn.ReLU())

    def apply_mask(self, scores,k):
      """
      scores must be same length as number of layers,
      each scores matrix must be same shape as respective layer
      """
      for i in range(len(scores)):
        self.linear_relu_stack[i*2] = self.linear_relu_stack[i*2] * GetSubnet.forward(scores[i], k)


    def forward(self, x):
        if x.shape[1:] != (3, 28, 28):
          print(x.shape)
          x.unsqueeze_(0)
          x = x.repeat(3, 1, 1)
        x = self.flatten(x)
        for layer in self.layers:
                x = layer(x)
        return x

In [15]:
import copy
def get_bin_mask(ctx, scores, k):

        # Get the subnetwork by sorting the scores and using the top k%
        out = scores.clone()
        _, idx = scores.flatten().sort()
        j = int((1-k) * scores.numel())

        # flat_out and out access the same memory.
        flat_out = out.flatten()
        flat_out[idx[:j]] = 0
        flat_out[idx[j:]] = 1

        return out # returns binary mask

class ClassicNetwork(nn.Module):
    def __init__(self, layer_sizes):
        super().__init__()

        self.flatten = nn.Flatten()

        self.linear_relu_stack = nn.Sequential(
            *[z for l in layer_sizes
              for z in [nn.Linear(l[0], l[1]), nn.ReLU()]][:-1]
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits


    def apply_mask(self, whole_mask):
      """
      whole mask must be same length as number of linear layers,
      each mask matrix must be same shape as respective layer

      assumes ReLu is every other in list
      """
      linear_layers = [layer for layer in self.linear_relu_stack if isinstance(layer, nn.Linear)]

      if len(linear_layers) != len(whole_mask):
          raise ValueError(f"Expected {len(linear_layers)} masks, got {len(whole_mask)}")

      for layer, mask in zip(linear_layers, whole_mask):
          # Apply mask to weights (in-place)
          layer.weight.data *= mask

      return self



    def clone(self):
      return copy.deepcopy(self)


# Set up signed Kaiming initialization.
def signed_kaiming_constant_(tensor, a=0, mode='fan_in', nonlinearity='relu', k=0.5, sparsity=0):

    fan = nn.init._calculate_correct_fan(tensor, mode)  # calculating correct fan, depends on shape and type of nn
    gain = nn.init.calculate_gain(nonlinearity, a)
    std = (gain / math.sqrt(fan))
    # scale by (1/sqrt(k))
    if k != 0:
        std *= (1 / math.sqrt(k))

    with torch.no_grad():
        tensor.uniform_(-std, std)
        if sparsity > 0:
            mask = (torch.rand_like(tensor) > sparsity).float()  # Keeps (1 - sparsity)% weights

            tensor *= mask

        return tensor


In [19]:
# just doing real numbers
in_dim = 1
out_dim = 1
w = 28
d = 10


classic_layer_sizes = [[in_dim, w]]
for _ in range(d):
    classic_layer_sizes.append([w, w])
classic_layer_sizes.append([w, out_dim])

main_net = ClassicNetwork(classic_layer_sizes)
print(main_net)

num_points = 10 # number of points for context

# generating 100 whole masks to mask entire network
whole_masks_list = []

for i in range(100):
  whole_mask = [] # list of masks for each layer

  for layer in main_net.linear_relu_stack:
    if isinstance(layer, nn.Linear):

      p = random.random()  # random probability in [0, 1]

      layer_mask = (torch.rand_like(layer.weight) > p).float()
      whole_mask.append(layer_mask)

  whole_masks_list.append(whole_mask)

# for each whole mask, generate 10 input output pairs ie (network * mask)(input) = output
# let input be a integer
whole_masks_context_list = []
for mask in whole_masks_list:
  #inputs = [torch.randn(main_net.input_dim) for _ in range(10)]# generates random tensor of input shape
  # use batching instead
  inputs = torch.randn(num_points, in_dim) # random tensor size [batch_size, input_dim]
  # each row is an input vector, ie with input dim 1 we have column vector

  # generate masked network just once
  masked_net = main_net.clone().apply_mask(mask)



  # generate outputs
  outputs = masked_net(inputs) # runs batch

  io_list = list(zip(inputs.detach(), outputs.detach())) # x,y pairs
  whole_masks_context_list.append(io_list)
print(whole_masks_context_list)




ClassicNetwork(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=1, out_features=28, bias=True)
    (1): ReLU()
    (2): Linear(in_features=28, out_features=28, bias=True)
    (3): ReLU()
    (4): Linear(in_features=28, out_features=28, bias=True)
    (5): ReLU()
    (6): Linear(in_features=28, out_features=28, bias=True)
    (7): ReLU()
    (8): Linear(in_features=28, out_features=28, bias=True)
    (9): ReLU()
    (10): Linear(in_features=28, out_features=28, bias=True)
    (11): ReLU()
    (12): Linear(in_features=28, out_features=28, bias=True)
    (13): ReLU()
    (14): Linear(in_features=28, out_features=28, bias=True)
    (15): ReLU()
    (16): Linear(in_features=28, out_features=28, bias=True)
    (17): ReLU()
    (18): Linear(in_features=28, out_features=28, bias=True)
    (19): ReLU()
    (20): Linear(in_features=28, out_features=28, bias=True)
    (21): ReLU()
    (22): Linear(in_features=28, out_features=1, bias=Tr

So now each whole_mask has its corresponding context. That is
For mask_i , context_i is defined [(x_1, masked_nn(x_1)), ... (x_n, masked_nn(x_n))]. So each mask has n context points.

Now we make an autoencoder to encode/decode whole masks.

In [20]:
EMBED_DIM = 32

def prepare_masks_for_embed(whole_mask_list):
      """whole_mask_list is list of whole masks, whole mask is list of per layer masks"""
      flattened_masks = []
      for whole_mask in whole_mask_list:
          flat = torch.cat([mask.flatten() for mask in whole_mask], dim=0)  # shape [D] where D = total number of weights in network
          flattened_masks.append(flat)

      return torch.stack(flattened_masks)  # shape [num_masks, D]

class WholeMaskAE(nn.Module):
    """
    Encode masks for entire network ie whole masks,
    inputs generated by prepare_mask_for_embed(whole_mask_list)
    """
    def __init__(self, D, embed_dim=32):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(D, max(128, embed_dim*2)),
            nn.ReLU(),
            nn.Linear(max(128, embed_dim*2), embed_dim),
        )
        self.decoder = nn.Sequential(
            nn.Linear(embed_dim, max(128, embed_dim*2)),
            nn.ReLU(),
            nn.Linear(max(128, embed_dim*2), D)   # raw logits for masks
        )

    def forward(self, x):
        z = self.encoder(x)
        y = self.decoder(z)
        return y

    def encode(self, x):  # x: [B, D]
        return self.encoder(x)

    def decode(self, z):  # z: [B, E]
        return self.decoder(z)

Training the autoencoder

In [23]:
prepared_mask_tensor = prepare_masks_for_embed(whole_masks_list)  # shape [num_masks, D] where D is total number of weights in network

D = prepared_mask_tensor.shape[1]

mask_ae = WholeMaskAE(D, EMBED_DIM)

criterion = nn.MSELoss()
optimizer = torch.optim.Adam(mask_ae.parameters(), lr=1e-3)

for epoch in range(1000):
    optimizer.zero_grad()
    recon = mask_ae(prepared_mask_tensor)  # AE tries to reconstruct input
    loss = criterion(recon, prepared_mask_tensor)
    loss.backward()
    optimizer.step()
    if (epoch + 1) % 10 == 0:
        print(f"Epoch {epoch+1}: Loss = {loss.item():.4f}")

Epoch 10: Loss = 0.3609
Epoch 20: Loss = 0.2541
Epoch 30: Loss = 0.2446
Epoch 40: Loss = 0.2414
Epoch 50: Loss = 0.2398
Epoch 60: Loss = 0.2377
Epoch 70: Loss = 0.2339
Epoch 80: Loss = 0.2273
Epoch 90: Loss = 0.2188
Epoch 100: Loss = 0.2101
Epoch 110: Loss = 0.2022
Epoch 120: Loss = 0.1933
Epoch 130: Loss = 0.1858
Epoch 140: Loss = 0.1785
Epoch 150: Loss = 0.1724
Epoch 160: Loss = 0.1680
Epoch 170: Loss = 0.1647
Epoch 180: Loss = 0.1616
Epoch 190: Loss = 0.1583
Epoch 200: Loss = 0.1556
Epoch 210: Loss = 0.1525
Epoch 220: Loss = 0.1501
Epoch 230: Loss = 0.1478
Epoch 240: Loss = 0.1457
Epoch 250: Loss = 0.1437
Epoch 260: Loss = 0.1418
Epoch 270: Loss = 0.1394
Epoch 280: Loss = 0.1372
Epoch 290: Loss = 0.1350
Epoch 300: Loss = 0.1327
Epoch 310: Loss = 0.1303
Epoch 320: Loss = 0.1279
Epoch 330: Loss = 0.1255
Epoch 340: Loss = 0.1230
Epoch 350: Loss = 0.1207
Epoch 360: Loss = 0.1182
Epoch 370: Loss = 0.1157
Epoch 380: Loss = 0.1131
Epoch 390: Loss = 0.1107
Epoch 400: Loss = 0.1078
Epoch 410

In [31]:

training_data = []
for i in range(len(prepared_mask_tensor)):
  context = whole_masks_context_list[i]

  flattened_context = torch.cat([torch.cat((j[0], j[1])) for j in context])

  encoded_mask = mask_ae.encode(prepared_mask_tensor[i].unsqueeze(0)).squeeze(0).detach() #break comp graph from ae

  training_data.append((flattened_context, encoded_mask))

print(training_data)

# training data is pairs of (flattened_context, autoencoder(respective mask))

# so we are basically saying "here's what this function should look like"
# and "give me the vector that will get decoded to be the mask to approximate the function"





[(tensor([-0.1589,  0.0911, -1.8219,  0.0911,  0.9176,  0.0911, -0.3174,  0.0911,
         0.0772,  0.0911,  1.5879,  0.0911,  0.0347,  0.0911,  0.6774,  0.0911,
        -0.3319,  0.0911,  2.0134,  0.0911]), tensor([  5.0334,   5.3987,  -9.4939,  -5.3220,   3.1807,   2.0935,   3.0375,
         -2.1661,  -7.1258,  -2.2957, -13.2378,  13.0512,  -8.1087,   2.2995,
          5.2444,   6.7891,   1.0956,   4.8833,   8.1026, -10.7583,  11.7555,
          2.1992,   2.5882,   1.4380,   4.0920,   5.6199,  13.8963,   8.4806,
         -2.1295,  14.0702,   1.6270,  12.5810])), (tensor([ 0.1815,  0.1235, -1.4175,  0.1235,  0.4602,  0.1235, -2.0951,  0.1235,
         1.1751,  0.1235,  0.7389,  0.1235,  0.1005,  0.1235,  0.9139,  0.1235,
         0.1667,  0.1235,  0.6655,  0.1235]), tensor([ 3.7202,  9.0669, -6.4105, -2.4829, -0.5054, -0.6486,  6.4054,  3.1014,
        -7.4661, -1.9161, -3.4143,  2.5878, -6.7323, -0.5869,  4.3324,  0.9137,
         0.3167,  1.7482,  6.8083, -3.2109, 11.8243,  5.2103, 

Now we can train the router with input shape being len(flattened_context) and output shape being encode_dim

In [32]:
from torch.utils.data import TensorDataset, DataLoader

# preparing the training data
contexts = torch.stack([pair[0] for pair in training_data]) # shape = [100, len(flattened_context)]
targets = torch.stack([pair[1] for pair in training_data]) # shape = [100, encode_dim]

dataset = TensorDataset(contexts, targets)
loader = DataLoader(dataset, batch_size=16, shuffle=True)  # 16 samples per batch


# defining the router
in_dim = len(training_data[0][0]) # length of flattened context
w = 28
out_dim = EMBED_DIM # gets decoded into mask

d = 10 # depth

layer_sizes = [[in_dim, w]]
for _ in range(d):
    layer_sizes.append([w, w])
layer_sizes.append([w, out_dim])

Router = ClassicNetwork(layer_sizes)

criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(Router.parameters(), lr=1e-3)


Now we can train it, minimizing Router(context) - AE(mask that generated the context)  ... such that router gets better at producing embedded masks that might have generated this context

In [34]:
num_epochs = 500

for epoch in range(num_epochs):
    epoch_loss = 0.0
    for batch_context, batch_target in loader:
        optimizer.zero_grad()

        # Forward pass
        pred = Router(batch_context)  # shape [batch_size, encode_dim], all predicted y for batch

        # Compute loss
        loss = criterion(pred, batch_target)

        # Backward pass
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss/len(loader):.4f}")

Epoch 1/500, Loss: 16.4615
Epoch 2/500, Loss: 15.8240
Epoch 3/500, Loss: 16.2293
Epoch 4/500, Loss: 15.9058
Epoch 5/500, Loss: 15.7197
Epoch 6/500, Loss: 15.8705
Epoch 7/500, Loss: 15.8098
Epoch 8/500, Loss: 16.0492
Epoch 9/500, Loss: 15.9528
Epoch 10/500, Loss: 16.2993
Epoch 11/500, Loss: 15.4374
Epoch 12/500, Loss: 16.1634
Epoch 13/500, Loss: 16.0046
Epoch 14/500, Loss: 15.9766
Epoch 15/500, Loss: 15.7543
Epoch 16/500, Loss: 16.2283
Epoch 17/500, Loss: 15.8279
Epoch 18/500, Loss: 15.6752
Epoch 19/500, Loss: 15.2433
Epoch 20/500, Loss: 15.8695
Epoch 21/500, Loss: 16.0852
Epoch 22/500, Loss: 15.9280
Epoch 23/500, Loss: 16.2254
Epoch 24/500, Loss: 15.7834
Epoch 25/500, Loss: 15.7102
Epoch 26/500, Loss: 15.7469
Epoch 27/500, Loss: 16.1146
Epoch 28/500, Loss: 15.6475
Epoch 29/500, Loss: 16.0112
Epoch 30/500, Loss: 16.1642
Epoch 31/500, Loss: 16.1263
Epoch 32/500, Loss: 15.2042
Epoch 33/500, Loss: 15.8985
Epoch 34/500, Loss: 15.6214
Epoch 35/500, Loss: 15.9498
Epoch 36/500, Loss: 15.8693
E

Now to test, just generate more random whole masks, generate context for them, pass only context to network, and see