<a href="https://colab.research.google.com/github/jbaremoney/goldenTicket/blob/main/flow_network.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install medmnist --q

In [None]:
from tqdm import tqdm
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torchvision.transforms as transforms
from torchvision.transforms import v2
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
import random
import torch.autograd as autograd
import math
import medmnist
from medmnist import INFO, Evaluator
from matplotlib import pyplot as plt
from itertools import combinations_with_replacement

In [None]:
!nvidia-smi

/bin/bash: line 1: nvidia-smi: command not found


In [None]:
# Pick the best available accelerator: CUDA first, then Apple MPS, then CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

print(f"Using {device} device")

Using cpu device


In [None]:
def getTrainingDataLoaders(dataset_name, download=True, BATCH_SIZE=128):
    """
    Build the MedMNIST data loaders used for training and evaluation.

    Every image is converted to a float32 tensor (scaled by ToDtype),
    single-channel images are expanded to three channels with v2.RGB(), and
    the result is normalised with mean 0.5 / std 0.5.

    Args:
      dataset_name (str): key into medmnist.INFO, e.g. 'breastmnist'
      download (bool): forwarded to the MedMNIST dataset class
      BATCH_SIZE (int): batch size for the shuffled training loader; the two
        evaluation loaders use 2 * BATCH_SIZE

    Returns:
      info (dict): the medmnist.INFO entry for the dataset
      task (str): info['task'], e.g. 'binary-class', 'multi-class'
      n_classes (int): number of entries in info['label']
      train_loader (DataLoader): shuffled loader over the training split
      train_loader_at_eval (DataLoader): unshuffled loader over the training split
      test_loader (DataLoader): unshuffled loader over the test split

    Raises:
      ValueError: if the dataset reports a channel count other than 1 or 3
        (previously this surfaced as an UnboundLocalError further down).
    """
    info = INFO[dataset_name]
    task = info['task']
    n_channels = info['n_channels']
    n_classes = len(info['label'])

    DataClass = getattr(medmnist, info['python_class'])

    if n_channels not in (1, 3):
        raise ValueError(
            f"Unsupported number of channels for {dataset_name!r}: {n_channels} (expected 1 or 3)"
        )

    # preprocessing: identical pipeline for both cases, plus a grayscale -> RGB
    # expansion when the dataset is single-channel
    steps = [v2.ToImage(), v2.ToDtype(torch.float32, scale=True)]
    if n_channels == 1:
        steps.append(v2.RGB())
    steps.append(v2.Normalize(mean=[.5], std=[.5]))
    data_transform = v2.Compose(steps)

    # load the data (the unused, untransformed copy of the training split that
    # used to be loaded here has been dropped)
    train_dataset = DataClass(split='train', transform=data_transform, download=download)
    test_dataset = DataClass(split='test', transform=data_transform, download=download)

    # encapsulate data into dataloader form
    train_loader = data.DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    train_loader_at_eval = data.DataLoader(dataset=train_dataset, batch_size=2*BATCH_SIZE, shuffle=False)
    test_loader = data.DataLoader(dataset=test_dataset, batch_size=2*BATCH_SIZE, shuffle=False)
    return info, task, n_classes, train_loader, train_loader_at_eval, test_loader

In [None]:
class ClassicNetwork(nn.Module):
    """Plain fully-connected baseline: Linear layers separated by ReLU.

    Args:
        layer_sizes: iterable of (in_features, out_features) pairs, one per
            Linear layer. No activation follows the last Linear.
        bias: forwarded to every nn.Linear.
    """

    def __init__(self, layer_sizes,bias=True):
        super().__init__()

        self.flatten = nn.Flatten()

        stack = []
        for position, dims in enumerate(layer_sizes):
            if position > 0:
                # an activation sits between consecutive Linear layers only
                stack.append(nn.ReLU())
            stack.append(nn.Linear(dims[0], dims[1], bias=bias))
        self.linear_relu_stack = nn.Sequential(*stack)

    def forward(self, x):
        """Flatten everything after the batch dimension and return the logits."""
        return self.linear_relu_stack(self.flatten(x))



In [None]:
# Set up signed Kaiming initialization.
def signed_kaiming_constant_(tensor, a=0, mode='fan_in', nonlinearity='relu', k=0.5, sparsity=0):
    """Fill `tensor` in place with Kaiming-scaled uniform values and return it.

    The bound is gain / sqrt(fan), additionally multiplied by 1 / sqrt(k)
    when k is non-zero. Values are drawn uniformly from [-bound, bound].
    When sparsity > 0, each entry is then zeroed wherever an independent
    uniform draw is not greater than `sparsity`.

    Args:
        tensor: tensor to initialise (modified in place).
        a: negative-slope argument forwarded to nn.init.calculate_gain.
        mode: 'fan_in' or 'fan_out', forwarded to the fan computation.
        nonlinearity: name forwarded to nn.init.calculate_gain.
        k: keep ratio used to rescale the bound; 0 disables the rescaling.
        sparsity: threshold for the random zeroing step; 0 disables it.

    Returns:
        The same tensor object that was passed in.
    """
    fan = nn.init._calculate_correct_fan(tensor, mode)  # depends on tensor shape and mode
    gain = nn.init.calculate_gain(nonlinearity, a)
    bound = gain / math.sqrt(fan)
    if k != 0:
        # scale by (1/sqrt(k))
        bound = bound * (1 / math.sqrt(k))

    with torch.no_grad():
        tensor.uniform_(-bound, bound)
        if sparsity > 0:
            keep = torch.rand_like(tensor).gt(sparsity).float()
            tensor.mul_(keep)

    return tensor

In [None]:
class GetSubnet(autograd.Function):
    """Binarise a score tensor into a top-k mask with a straight-through gradient.

    forward(scores, k) returns a tensor shaped like `scores` holding 1 at the
    positions of the largest k fraction of entries (ranked over the whole
    tensor) and 0 everywhere else. backward passes the incoming gradient
    through to `scores` unchanged; `k` receives no gradient.
    """

    @staticmethod
    def forward(ctx, scores, k):
        # Rank every entry of the tensor together, smallest first.
        flat_scores = scores.flatten()
        _, idx = flat_scores.sort()
        # Number of lowest-ranked entries that get masked out.
        j = int((1 - k) * scores.numel())

        # Build the mask in a fresh flat buffer rather than writing through
        # `clone().flatten()`: flatten() only aliases its input when that
        # input is contiguous, so for non-contiguous scores the writes were
        # lost and the raw scores were returned.
        flat_mask = torch.zeros_like(flat_scores)
        flat_mask[idx[j:]] = 1

        return flat_mask.reshape(scores.shape)

    @staticmethod
    def backward(ctx, grad):

        # send the gradient g straight-through on the backward pass.
        return grad, None

In [None]:
class LinearSubnet(nn.Linear):
    """Linear layer with frozen weights whose active subnetwork is learned.

    The weight and bias are initialised once and frozen. The trainable
    parameters are "popup scores": for both weight and bias there is a score
    tensor plus an equally shaped "extra" score tensor. On every forward pass
    the absolute values of both are concatenated, GetSubnet keeps the top
    `k` fraction of that doubled tensor, and the first half of the resulting
    0/1 mask is applied to the frozen weight (resp. bias).

    Args:
        in_features, out_features: layer shape, as for nn.Linear.
        bias: bool forwarded to nn.Linear; any non-bool value is treated as True.
        k: keep ratio handed to GetSubnet (and to the default initialiser).
        init: in-place initialiser for the weight; called as init(weight, k=k)
            when it is signed_kaiming_constant_, otherwise as init(weight).
        **kwargs: forwarded to nn.Linear.
    """

    def __init__(self, in_features, out_features, bias=True, k=0.5, init=signed_kaiming_constant_, **kwargs):
        super().__init__(in_features, out_features, bias if isinstance(bias, bool) else True, **kwargs)

        self.k = k
        self.popup_scores = nn.Parameter(torch.randn(out_features, in_features), requires_grad=True)
        self.bias_popup_scores = nn.Parameter(torch.randn(out_features), requires_grad=True)
        self.popup_scores_extra = nn.Parameter(torch.randn(out_features, in_features), requires_grad=True)
        self.bias_popup_scores_extra = nn.Parameter(torch.randn(out_features), requires_grad=True)

        # Snapshots of the starting scores. They are detached so they carry no
        # autograd history, and registered as non-persistent buffers so they
        # follow .to(device) without adding keys to the state_dict.
        self.register_buffer('initial_popup_scores', self.popup_scores.detach().clone(), persistent=False)
        self.register_buffer('initial_bias_popup_scores', self.bias_popup_scores.detach().clone(), persistent=False)

        # Initialize weights
        if init is signed_kaiming_constant_:
            init(self.weight, k=k)
        else:
            init(self.weight)

        self.weight.requires_grad_(False)
        if self.bias is not None:
            self.bias.requires_grad_(False)

    def return_to_initial_popup_scores(self):
        """Restore popup_scores and bias_popup_scores to their initial values.

        The values are copied in place so the Parameter objects keep their
        identity: an optimizer built before the reset keeps updating the live
        parameters. Stale gradients on the two restored parameters are cleared.
        The "extra" score tensors are not touched.
        """
        with torch.no_grad():
            self.popup_scores.copy_(self.initial_popup_scores)
            self.bias_popup_scores.copy_(self.initial_bias_popup_scores)
        self.popup_scores.grad = None
        self.bias_popup_scores.grad = None
        print('Popup Scores Returned to Initial Values')

    def forward(self, x):
        """Apply the masked linear map to x."""
        # Top-k over the doubled score tensor, then keep the half that lines
        # up with the real weight.
        adj = GetSubnet.apply(
            torch.cat((self.popup_scores.abs(),self.popup_scores_extra.abs()),dim=-1), self.k
        )[:, :self.weight.shape[-1]]
        w = self.weight * adj

        if self.bias is None:
            # Layer was built with bias=False: nothing to mask.
            b = None
        else:
            bias_adj = GetSubnet.apply(
                torch.cat((self.bias_popup_scores.abs(),self.bias_popup_scores_extra.abs()), dim=-1), self.k
            )[:self.bias.shape[-1]]
            b = self.bias * bias_adj

        return F.linear(x, w, b)

class Network(nn.Module):
    """MLP built from LinearSubnet layers separated by ReLU.

    Args:
        layer_sizes: sequence of (in_features, out_features) pairs, one per
            LinearSubnet layer. No activation follows the last layer.
        k: keep ratio forwarded to every LinearSubnet.
        init: weight initialiser forwarded to every LinearSubnet.
    """

    def __init__(self, layer_sizes, k=0.5, init=signed_kaiming_constant_):
        super().__init__()
        self.flatten = nn.Flatten()

        self.layers = nn.ModuleList()
        for i, (in_f, out_f) in enumerate(layer_sizes):
            self.layers.append(LinearSubnet(in_f, out_f,k=k,init=init))
            if i < len(layer_sizes) - 1:
                self.layers.append(nn.ReLU())

    def forward(self, x):
        """Flatten a batch and run it through the layers.

        A 3-D input is treated as a batch of single-channel images
        (B, H, W). A single-channel batch (B, 1, H, W) is replicated to three
        channels when that is what makes its flattened size match the first
        layer's in_features. The caller's tensor is never modified in place.
        Any other shape is flattened as is.
        """
        # The previous grayscale branch used unsqueeze_(0) + repeat(3, 1, 1),
        # which mutated the caller's tensor and raised on every 4-D batch.
        expected = self.layers[0].in_features if len(self.layers) > 0 else None
        if x.dim() == 3:
            x = x.unsqueeze(1)
        if (
            x.dim() == 4
            and x.shape[1] == 1
            and expected is not None
            and 3 * x.shape[1:].numel() == expected
        ):
            x = x.repeat(1, 3, 1, 1)

        x = self.flatten(x)
        for layer in self.layers:
            x = layer(x)
        return x

In [None]:
class MaskAE(nn.Module):
    """Small MLP autoencoder for flattened doubled masks of length D.

    Encoder and decoder are both two-layer MLPs with one ReLU; the hidden
    width is max(128, 2 * embed_dim). The decoder output is left as raw
    logits.
    """

    def __init__(self, D, embed_dim=32):
        super().__init__()
        hidden = max(128, embed_dim*2)

        # D -> hidden -> embed_dim
        self.encoder = nn.Sequential(
            nn.Linear(D, hidden),
            nn.ReLU(),
            nn.Linear(hidden, embed_dim),
        )
        # embed_dim -> hidden -> D (raw logits for masks)
        self.decoder = nn.Sequential(
            nn.Linear(embed_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, D),
        )

    def forward(self, x):
        """Encode then decode x; returns logits with last dimension D."""
        return self.decoder(self.encoder(x))

    def encode(self, x):  # x: [B, D]
        """Map masks to embeddings with last dimension embed_dim."""
        return self.encoder(x)

    def decode(self, z):  # z: [B, E]
        """Map embeddings back to mask logits with last dimension D."""
        return self.decoder(z)

# linear subnet that passes info about its masking
class LinearSubnetFlow(nn.Linear):
    """LinearSubnet variant that also emits an embedding of its own mask.

    Weight and bias are frozen; the trainable popup scores (plus their
    "extra" halves) select a top-`k` subnetwork through GetSubnet on every
    forward pass. The full doubled 0/1 mask (weight part [out, 2*in]
    flattened, followed by the bias part [2*out]) is encoded by the supplied
    autoencoder, and the embedding is returned alongside the layer output.
    An embedding coming from the previous layer can optionally be mixed into
    the input through `mask_adapter`.

    Args:
        in_features, out_features: layer shape, as for nn.Linear.
        bias: bool forwarded to nn.Linear; any non-bool value is treated as True.
        k: keep ratio handed to GetSubnet (and to the default initialiser).
        init: in-place initialiser for the weight; called as init(weight, k=k)
            when it is signed_kaiming_constant_, otherwise as init(weight).
        mask_emb_dim: size of the previous layer's mask embedding. When given,
            a Linear(in_features + mask_emb_dim, in_features) adapter is
            created; when None the layer has no adapter.
        **kwargs: forwarded to nn.Linear.
    """

    def __init__(self, in_features, out_features, bias=True, k=0.5, init=signed_kaiming_constant_,
                 mask_emb_dim=None, **kwargs):
        super().__init__(in_features, out_features, bias if isinstance(bias, bool) else True, **kwargs)

        self.k = k
        self.popup_scores = nn.Parameter(torch.randn(out_features, in_features), requires_grad=True)
        self.bias_popup_scores = nn.Parameter(torch.randn(out_features), requires_grad=True)
        self.popup_scores_extra = nn.Parameter(torch.randn(out_features, in_features), requires_grad=True)
        self.bias_popup_scores_extra = nn.Parameter(torch.randn(out_features), requires_grad=True)

        # Snapshots of the starting scores: detached (no autograd history) and
        # registered as non-persistent buffers so they follow .to(device)
        # without adding keys to the state_dict.
        self.register_buffer('initial_popup_scores', self.popup_scores.detach().clone(), persistent=False)
        self.register_buffer('initial_bias_popup_scores', self.bias_popup_scores.detach().clone(), persistent=False)

        # Initialize weights
        if init is signed_kaiming_constant_:
            init(self.weight, k=k)
        else:
            init(self.weight)

        self.weight.requires_grad_(False)
        if self.bias is not None:
            self.bias.requires_grad_(False)

        # forward() referenced self.mask_adapter without it ever being
        # created; define it explicitly (None means "no conditioning").
        self.mask_adapter = None
        if mask_emb_dim is not None:
            self.mask_adapter = nn.Linear(in_features + mask_emb_dim, in_features)

    def return_to_initial_popup_scores(self):
        """Restore popup_scores and bias_popup_scores to their initial values.

        Values are copied in place so the Parameter objects keep their
        identity for any optimizer built earlier; stale gradients on the two
        restored parameters are cleared. The "extra" scores are not touched.
        """
        with torch.no_grad():
            self.popup_scores.copy_(self.initial_popup_scores)
            self.bias_popup_scores.copy_(self.initial_bias_popup_scores)
        self.popup_scores.grad = None
        self.bias_popup_scores.grad = None
        print('Popup Scores Returned to Initial Values')

    def forward(self, x, AE: MaskAE, prev_mask_emb=None):
        """Run the masked linear map and embed this layer's mask.

        Args:
            x: input whose last dimension is in_features.
            AE: autoencoder instance whose encode() accepts this layer's
                flattened doubled mask; None skips the encoding.
            prev_mask_emb: optional embedding from the previous layer, either
                1-D (shared by the batch) or already batched like x.

        Returns:
            (output, mask_emb); mask_emb is None when AE is None.

        Raises:
            RuntimeError: if prev_mask_emb is given but the layer has no
                mask_adapter.
        """
        if prev_mask_emb is not None:
            if self.mask_adapter is None:
                raise RuntimeError(
                    "prev_mask_emb was given but this layer has no mask_adapter; "
                    "construct it with mask_emb_dim set"
                )
            # expand mask_emb to match x's batch dimension
            if prev_mask_emb.dim() == 1:
                prev_mask_emb = prev_mask_emb.unsqueeze(0).expand(x.size(0), -1)
            x = self.mask_adapter(torch.cat((x, prev_mask_emb), dim=-1))

        # Full doubled masks; the first half of each lines up with the real
        # weight / bias.
        full_adj = GetSubnet.apply(
            torch.cat((self.popup_scores.abs(),self.popup_scores_extra.abs()),dim=-1), self.k
        )
        full_bias_adj = GetSubnet.apply(
            torch.cat((self.bias_popup_scores.abs(),self.bias_popup_scores_extra.abs()), dim=-1), self.k
        )

        w = self.weight * full_adj[:, :self.weight.shape[-1]]
        # The bias mask is always computed (its scores always exist) so the
        # encoded vector keeps the same length whether or not a bias is used.
        b = None if self.bias is None else self.bias * full_bias_adj[:self.bias.shape[-1]]

        # Encode with the AE *instance* (the original called encode on the
        # MaskAE class) and feed it the flattened doubled mask that MaskAE is
        # documented to take, not the masked weight matrix.
        # NOTE(review): assumes AE was built with D = out*(2*in) + 2*out - confirm.
        mask_emb = None
        if AE is not None:
            mask_emb = AE.encode(torch.cat((full_adj.flatten(), full_bias_adj.flatten())))

        return F.linear(x, w, b), mask_emb


class FlowNetwork(nn.Module):
    """MLP of LinearSubnetFlow layers that hands mask embeddings downstream.

    Args:
        layer_sizes: sequence of (in_features, out_features) pairs, one per
            LinearSubnetFlow layer. No activation follows the last layer.
        AE: mask autoencoder(s) passed to the flow layers: None, a single
            module used for every flow layer, or a list/tuple/ModuleList
            with one entry per flow layer. They are registered as submodules,
            so they follow .to(device) and their trainable parameters show up
            in parameters().
        k: keep ratio forwarded to every LinearSubnetFlow.
        init: weight initialiser forwarded to every LinearSubnetFlow.
        mask_emb_dim: embedding size(s) produced by the autoencoders, an int
            or one int per flow layer. When given, every flow layer after the
            first gets a `mask_adapter` Linear that consumes the previous
            layer's embedding. When None no adapters are created and the
            embeddings are not fed forward.

    Raises:
        ValueError: if AE is a sequence whose length differs from layer_sizes.
    """

    def __init__(self, layer_sizes, AE=None, k=0.5, init=signed_kaiming_constant_, mask_emb_dim=None):
        super().__init__()
        self.flatten = nn.Flatten()

        self.layers = nn.ModuleList()
        for i, (in_f, out_f) in enumerate(layer_sizes):
            self.layers.append(LinearSubnetFlow(in_f, out_f,k=k,init=init))
            if i < len(layer_sizes) - 1:
                self.layers.append(nn.ReLU())

        flow_layers = [m for m in self.layers if isinstance(m, LinearSubnetFlow)]

        # AE used to be accepted and then dropped; keep it so forward() can
        # hand the right autoencoder to each flow layer.
        if AE is None:
            per_layer_aes = []
        elif isinstance(AE, (list, tuple, nn.ModuleList)):
            per_layer_aes = list(AE)
            if len(per_layer_aes) != len(flow_layers):
                raise ValueError(
                    f"expected {len(flow_layers)} autoencoders (one per layer), got {len(per_layer_aes)}"
                )
        else:
            per_layer_aes = [AE] * len(flow_layers)
        self.mask_autoencoders = nn.ModuleList(per_layer_aes)

        # Attach the adapters the flow layers read as `self.mask_adapter`.
        if mask_emb_dim is not None:
            if isinstance(mask_emb_dim, int):
                emb_dims = [mask_emb_dim] * len(flow_layers)
            else:
                emb_dims = list(mask_emb_dim)
            # layer i+1 consumes the embedding emitted by layer i
            for prev_dim, layer in zip(emb_dims, flow_layers[1:]):
                layer.mask_adapter = nn.Linear(layer.in_features + prev_dim, layer.in_features)

    def forward(self, x):
        """Flatten a batch and run it through the layers; returns the logits.

        A 3-D input is treated as a batch of single-channel images
        (B, H, W). A single-channel batch (B, 1, H, W) is replicated to three
        channels when that is what makes its flattened size match the first
        layer's in_features. The caller's tensor is never modified in place.
        """
        # The previous grayscale branch used unsqueeze_(0) + repeat(3, 1, 1),
        # which mutated the caller's tensor and raised on every 4-D batch.
        expected = self.layers[0].in_features if len(self.layers) > 0 else None
        if x.dim() == 3:
            x = x.unsqueeze(1)
        if (
            x.dim() == 4
            and x.shape[1] == 1
            and expected is not None
            and 3 * x.shape[1:].numel() == expected
        ):
            x = x.repeat(1, 3, 1, 1)

        x = self.flatten(x)

        prev_mask_emb = None
        flow_index = 0
        for layer in self.layers:
            if isinstance(layer, LinearSubnetFlow):
                ae = self.mask_autoencoders[flow_index] if len(self.mask_autoencoders) > 0 else None
                # Only feed the previous embedding to layers that can consume it.
                has_adapter = getattr(layer, 'mask_adapter', None) is not None
                # Flow layers take (x, AE, prev_mask_emb) and return (out, mask_emb).
                x, prev_mask_emb = layer(x, ae, prev_mask_emb if has_adapter else None)
                flow_index += 1
            else:
                x = layer(x)
        return x

In [None]:
#training mask autoencoder
class MaskAE(nn.Module):
    """Small MLP autoencoder for flattened doubled masks of length D.

    Same architecture as the MaskAE defined in the earlier cell, but with a
    default embedding size of 64; running this cell rebinds the name.
    """

    @staticmethod
    def _two_layer_mlp(n_in, n_hidden, n_out):
        """Linear -> ReLU -> Linear."""
        return nn.Sequential(
            nn.Linear(n_in, n_hidden),
            nn.ReLU(),
            nn.Linear(n_hidden, n_out),
        )

    def __init__(self, D, embed_dim=64):
        super().__init__()
        width = max(128, embed_dim*2)
        self.encoder = self._two_layer_mlp(D, width, embed_dim)
        self.decoder = self._two_layer_mlp(embed_dim, width, D)   # raw logits for masks

    def forward(self, x):
        """Reconstruct mask logits from x by encoding and then decoding."""
        embedding = self.encoder(x)
        return self.decoder(embedding)

    def encode(self, x):  # x: [B, D]
        """Return the embedding of x."""
        return self.encoder(x)

    def decode(self, z):  # z: [B, E]
        """Return mask logits for the embedding z."""
        return self.decoder(z)

class MasksOnlyDataset(data.Dataset):
    """Dataset that yields each mask twice, as (input, target), for autoencoding.

    The base class is reached through the `torch.utils.data as data` alias
    the notebook already imports; the bare name `Dataset` was never imported
    and raised NameError when the cell ran.

    Args:
        targets: indexable collection of masks, e.g. a [N, D] tensor of {0,1}.
    """

    def __init__(self, targets):  # targets: [N, D] of {0,1}
        self.targets = targets

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, i):
        t = self.targets[i]
        return t, t  # (input=mask, target=mask)

def train_autoencoder_on_masks(targets,
                               D,            # flattened doubled length = out_f*(2*in_f) + 2*out_f
                               in_f, out_f,  # layer shape
                               k,            # keep ratio used by GetSubnet
                               embed_dim=64,
                               epochs=20,
                               batch=256,
                               lr=1e-3,
                               device='cpu',
                               aux_logit_loss_weight=0.0):
    """
    Train a MaskAE to reconstruct binary doubled masks of one layer.

    Each target row is laid out as the weight mask [out_f, 2*in_f] flattened,
    followed by the bias mask [2*out_f]. The decoder logits are split the same
    way, passed through abs() and binarised with GetSubnet(k); the loss is the
    MSE between that binarised prediction and the target.

    Args:
        targets: [N, D] tensor of {0,1} masks (any dtype; cast to float32).
        D: flattened doubled mask length; must equal out_f*(2*in_f) + 2*out_f.
        in_f, out_f: shape of the layer the masks belong to.
        k: keep ratio used by GetSubnet.
        embed_dim: embedding size of the autoencoder.
        epochs, batch, lr: training-loop settings (Adam optimizer).
        device: device the autoencoder and batches are moved to.
        aux_logit_loss_weight: when > 0, adds this weight times
            MSE(sigmoid(logits), target) to the loss.

    Returns:
        (ae, ae.decoder): the trained autoencoder in eval mode, and its
        decoder with requires_grad disabled on every parameter.

    Raises:
        ValueError: if D is inconsistent with in_f / out_f.
    """
    D_w = out_f * (2 * in_f)
    D_b = 2 * out_f
    if D != D_w + D_b:
        raise ValueError(
            f"D={D} does not match out_f*(2*in_f) + 2*out_f = {D_w + D_b}"
        )

    ds = MasksOnlyDataset(targets)  # (mask, mask)
    loader = DataLoader(ds, batch_size=batch, shuffle=True)

    ae = MaskAE(D, embed_dim).to(device)
    opt = torch.optim.Adam(ae.parameters(), lr=lr)

    ae.train()
    for _ in range(epochs):
        for m_in, m_tgt in loader:
            # Cast explicitly: masks stored as bool/int would otherwise fail
            # in the Linear layers and in mse_loss.
            m_in = m_in.to(device=device, dtype=torch.float32)    # [B, D], {0,1}
            m_tgt = m_tgt.to(device=device, dtype=torch.float32)  # [B, D], {0,1}

            logits = ae(m_in)  # [B, D]

            # Split & reshape to doubled shapes, mirror router path
            w_logits = logits[:, :D_w].reshape(-1, out_f, 2*in_f).abs()
            b_logits = logits[:, D_w:].reshape(-1, 2*out_f).abs()

            # STE top-k to binarize, one sample at a time. GetSubnet ranks
            # every entry of the tensor it receives, so calling it on the
            # whole batch took the top-k across the batch instead of giving
            # each mask its own k fraction of ones.
            w_mask_hat = torch.stack([GetSubnet.apply(w, k) for w in w_logits])  # [B, out, 2*in]
            b_mask_hat = torch.stack([GetSubnet.apply(b, k) for b in b_logits])  # [B, 2*out]

            pred_mask = torch.cat([w_mask_hat.reshape(logits.size(0), -1),
                                   b_mask_hat.reshape(logits.size(0), -1)], dim=1)  # [B, D]

            # Primary loss: match binary masks after GetSubnet
            loss = F.mse_loss(pred_mask, m_tgt)

            # Optional tiny auxiliary term (can help stabilize training a bit)
            if aux_logit_loss_weight > 0.0:
                loss = loss + aux_logit_loss_weight * F.mse_loss(torch.sigmoid(logits), m_tgt)

            opt.zero_grad()
            loss.backward()
            opt.step()

    # freeze decoder head
    for p in ae.decoder.parameters():
        p.requires_grad = False
    ae.eval()
    return ae, ae.decoder

def train_autoencoders_for_all_layers(datasets_per_layer,
                                      layer_sizes,
                                      embed_dims,
                                      k=0.5,
                                      epochs=20,
                                      batch=256,
                                      lr=1e-3,
                                      device='cpu',
                                      aux_logit_loss_weight=0.0):
    """
    Train one mask autoencoder per layer.

    Args:
        datasets_per_layer: one dataset per layer; each must expose `.targets`
            as a [N, D_i] tensor of doubled masks.
        layer_sizes: (in_f, out_f) pairs, aligned with datasets_per_layer.
        embed_dims: embedding size per layer, or a single int used for all.
        k, epochs, batch, lr, device, aux_logit_loss_weight: forwarded
            unchanged to train_autoencoder_on_masks.

    Returns: (decoders_frozen, full_AEs), two lists with one entry per layer.
    """
    per_layer_dims = embed_dims
    if isinstance(embed_dims, int):
        per_layer_dims = [embed_dims] * len(layer_sizes)

    # settings shared by every layer's training run
    shared = dict(
        k=k,
        epochs=epochs,
        batch=batch,
        lr=lr,
        device=device,
        aux_logit_loss_weight=aux_logit_loss_weight,
    )

    trained = []
    for shape, layer_ds, dim in zip(layer_sizes, datasets_per_layer, per_layer_dims):
        in_f, out_f = shape
        mask_len = out_f * (2 * in_f) + 2 * out_f
        layer_masks = layer_ds.targets.to(device)  # [N, D], {0,1}
        trained.append(
            train_autoencoder_on_masks(
                targets=layer_masks,
                D=mask_len,
                in_f=in_f,
                out_f=out_f,
                embed_dim=dim,
                **shared,
            )
        )

    aes = [pair[0] for pair in trained]
    decoders = [pair[1] for pair in trained]  # frozen
    return decoders, aes

In [None]:
# Define a training function that returns a list of the losses during training.
def trainit(model,
            NUM_EPOCHS,
            train_loader,
            optimizer,
            task,
            n_classes,
            return_losses=False,
            no_progress=False):
    """Train `model` for NUM_EPOCHS passes over `train_loader`.

    Uses BCEWithLogitsLoss when task is "multi-label, binary-class" and
    CrossEntropyLoss otherwise. Only the first `n_classes` output columns of
    the model enter the loss. Batches are moved to the notebook-level `device`.

    Args:
        model: network to train (put into train mode at the start of each epoch).
        NUM_EPOCHS: number of epochs.
        train_loader: iterable of (inputs, targets) batches.
        optimizer: optimizer stepped once per batch.
        task: task string selecting the loss and the target preprocessing.
        n_classes: number of leading output columns to keep.
        return_losses: when True, return the list of per-batch loss values.
        no_progress: when True, skip the tqdm progress bar.

    Returns:
        List of per-batch losses if return_losses is True, otherwise None.
    """
    is_multilabel = task == "multi-label, binary-class"
    criterion = nn.BCEWithLogitsLoss() if is_multilabel else nn.CrossEntropyLoss()
    losses = [] if return_losses else None

    for epoch in range(NUM_EPOCHS):
        model.train()
        batches = train_loader if no_progress else tqdm(train_loader)
        for inputs, targets in batches:
            inputs = inputs.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)

            # forward + backward + optimize
            optimizer.zero_grad()
            outputs = model(inputs)[:, 0:n_classes]
            if task == 'multi-label, binary-class':
                targets = targets.to(torch.float32)
            else:
                targets = targets.squeeze(1)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            if return_losses:
                losses.append(loss.item())

    if return_losses:
        return losses

# Define an evaluation function
def test(split,
         model,
         train_loader_at_eval,
         test_loader,
         task,
         n_classes,
         data_flag,
         return_metrics=False):
    """Score `model` on one split with the MedMNIST Evaluator and print AUC / accuracy.

    Args:
        split: 'train' evaluates on train_loader_at_eval; any other value
            evaluates on test_loader. Also forwarded to Evaluator.
        model: network to evaluate (put into eval mode).
        train_loader_at_eval: unshuffled loader over the training split.
        test_loader: unshuffled loader over the test split.
        task: task string; kept for interface compatibility. Scores are
            softmax-normalised for every task.
        n_classes: number of leading output columns to keep.
        data_flag: dataset name forwarded to Evaluator.
        return_metrics: when True, return the Evaluator's metrics.

    Returns:
        The metrics object if return_metrics is True, otherwise None.
    """
    model.eval()

    data_loader = train_loader_at_eval if split == 'train' else test_loader

    # Collect per-batch scores and concatenate once at the end.
    score_batches = []
    with torch.no_grad():
        # Labels from the loader are ignored: Evaluator is called with scores only.
        for inputs, _ in data_loader:
            # The model lives on the notebook-level `device` (as in trainit),
            # so the batch has to be moved there; this was commented out and
            # failed whenever the model was not on the CPU.
            inputs = inputs.to(device, non_blocking=True)
            outputs = model(inputs)[:, 0:n_classes]

            # NOTE(review): softmax is applied for the multi-label task too,
            # although training uses BCEWithLogitsLoss - confirm this is intended.
            outputs = outputs.softmax(dim=-1)

            # Bring scores back to the CPU so they can become a numpy array.
            score_batches.append(outputs.cpu())

    y_score = torch.cat(score_batches, 0) if score_batches else torch.tensor([])
    y_score = y_score.detach().numpy()

    evaluator = Evaluator(data_flag, split)
    metrics = evaluator.evaluate(y_score)

    print('%s  auc: %.3f  acc:%.3f' % (split, *metrics))

    if return_metrics:
        return metrics

In [None]:
# Compare training-loss curves of the masked network, the flow-masked network
# and a classical dense network on one MedMNIST dataset.
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)

print(f"Using {device} device")
data_set_name = 'breastmnist'
info, task, n_classes, train_loader, train_loader_at_eval, test_loader = getTrainingDataLoaders(data_set_name)
layer_sizes=[[3*28*28, 256],[256,256],[256, n_classes]]
NUM_EPOCHS = 1000

# Masked ("router") network. The task and class count come from the dataset
# metadata instead of the previously hard-coded 'multi-class' / 14.
model = Network(layer_sizes=layer_sizes,k=0.5)
model.to(device)
router_train_score=[]

losses=trainit(model,NUM_EPOCHS,train_loader,optim.Adam(model.parameters()),task=task,n_classes=n_classes, return_losses=True)
router_train_score.append(losses)

# Flow-masked network. FlowNetwork's signature requires AE and none is
# trained in this notebook, so it is passed explicitly as None.
# TODO: pass the per-layer autoencoders once they have been trained.
flow_model = FlowNetwork(layer_sizes=layer_sizes,AE=None,k=0.5)
flow_model.to(device)
flow_train_score=[]
# The optimizer must own flow_model's parameters; it was built from `model`,
# so flow_model would never have been updated.
losses=trainit(flow_model,NUM_EPOCHS,train_loader,optim.Adam(flow_model.parameters()),task=task,n_classes=n_classes, return_losses=True)
flow_train_score.append(losses)


# Classical dense baseline.
classicmodel = ClassicNetwork(layer_sizes=layer_sizes)
classicmodel.to(device)
classic_train_score=[]
losses=trainit(classicmodel,NUM_EPOCHS,train_loader,optim.Adam(classicmodel.parameters()),task=task,n_classes=n_classes, return_losses=True)
classic_train_score.append(losses)

# One point per training batch; the x axis is hidden.
plt.plot(router_train_score[0], label='Masking')
plt.plot(flow_train_score[0], label='Flow Masking')
plt.plot(classic_train_score[0], label='Classical')
plt.legend()
ax = plt.gca()
ax.get_xaxis().set_visible(False)
# plt.yscale('log')


Using cpu device


100%|██████████| 560k/560k [00:01<00:00, 481kB/s]
100%|██████████| 5/5 [00:01<00:00,  3.41it/s]
100%|██████████| 5/5 [00:01<00:00,  3.05it/s]
100%|██████████| 5/5 [00:01<00:00,  2.92it/s]
100%|██████████| 5/5 [00:01<00:00,  3.78it/s]
100%|██████████| 5/5 [00:01<00:00,  3.79it/s]
100%|██████████| 5/5 [00:01<00:00,  3.75it/s]
100%|██████████| 5/5 [00:01<00:00,  3.81it/s]
100%|██████████| 5/5 [00:01<00:00,  3.80it/s]
100%|██████████| 5/5 [00:01<00:00,  3.75it/s]
100%|██████████| 5/5 [00:01<00:00,  3.71it/s]
100%|██████████| 5/5 [00:01<00:00,  2.90it/s]
100%|██████████| 5/5 [00:01<00:00,  2.89it/s]
100%|██████████| 5/5 [00:01<00:00,  3.74it/s]
100%|██████████| 5/5 [00:01<00:00,  3.69it/s]
100%|██████████| 5/5 [00:01<00:00,  3.77it/s]
100%|██████████| 5/5 [00:01<00:00,  3.78it/s]
100%|██████████| 5/5 [00:01<00:00,  3.75it/s]
100%|██████████| 5/5 [00:01<00:00,  3.74it/s]
100%|██████████| 5/5 [00:01<00:00,  3.80it/s]
100%|██████████| 5/5 [00:01<00:00,  2.88it/s]
100%|██████████| 5/5 [00:01<00

KeyboardInterrupt: 