In [4]:
from my_utils.model_helpers import init_model, get_layer_names, get_nice_layer_names, count_nonzero_weights
from my_utils.sparsify import sparsify_weights

In [17]:
resnet = init_model('vgg16', regime="trained")
resnet_s = sparsify_weights(resnet, 'vgg16', sparsity_k=0.9,)
z, t = count_nonzero_weights(resnet_s)

z/t

Layers to sparsify: all (n=15)


0.9029607169160082

In [7]:
sparsify_weights(convnext, 'convnext_b', target_layer_range="all")

Layers to sparsify: all (n=112)


ModelWithInputLayer(
  (input_layer): Identity()
  (model): ConvNeXt(
    (features): Sequential(
      (0): Conv2dNormActivation(
        (0): Conv2d(3, 128, kernel_size=(4, 4), stride=(4, 4))
        (1): LayerNorm2d((128,), eps=1e-06, elementwise_affine=True)
      )
      (1): Sequential(
        (0): CNBlock(
          (block): Sequential(
            (0): Conv2d(128, 128, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=128)
            (1): Permute()
            (2): LayerNorm((128,), eps=1e-06, elementwise_affine=True)
            (3): Linear(in_features=128, out_features=512, bias=True)
            (4): GELU(approximate='none')
            (5): Linear(in_features=512, out_features=128, bias=True)
            (6): Permute()
          )
          (stochastic_depth): StochasticDepth(p=0.0, mode=row)
        )
        (1): CNBlock(
          (block): Sequential(
            (0): Conv2d(128, 128, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=128)
            (1

In [4]:
convnext

ModelWithInputLayer(
  (input_layer): Identity()
  (model): ConvNeXt(
    (features): Sequential(
      (0): Conv2dNormActivation(
        (0): Conv2d(3, 128, kernel_size=(4, 4), stride=(4, 4))
        (1): LayerNorm2d((128,), eps=1e-06, elementwise_affine=True)
      )
      (1): Sequential(
        (0): CNBlock(
          (block): Sequential(
            (0): Conv2d(128, 128, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=128)
            (1): Permute()
            (2): LayerNorm((128,), eps=1e-06, elementwise_affine=True)
            (3): Linear(in_features=128, out_features=512, bias=True)
            (4): GELU(approximate='none')
            (5): Linear(in_features=512, out_features=128, bias=True)
            (6): Permute()
          )
          (stochastic_depth): StochasticDepth(p=0.0, mode=row)
        )
        (1): CNBlock(
          (block): Sequential(
            (0): Conv2d(128, 128, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=128)
            (1

In [34]:
print(vit.model.encoder.layers.encoder_layer_1.self_attention.in_proj_weight.shape)
print(vit.model.encoder.layers.encoder_layer_1.self_attention.out_proj.weight.shape)

torch.Size([2304, 768])
torch.Size([768, 768])


In [6]:
ln = get_layer_names(convnext)
get_nice_layer_names(convnext, ln)
for name, module in convnext.named_modules():
    if len(list(module.children())) == 0:
        print(f"{name} ||||  {str(module).split("(")[0]}")

input_layer ||||  Identity
model.features.0.0 ||||  Conv2d
model.features.0.1 ||||  LayerNorm2d
model.features.1.0.block.0 ||||  Conv2d
model.features.1.0.block.1 ||||  Permute
model.features.1.0.block.2 ||||  LayerNorm
model.features.1.0.block.3 ||||  Linear
model.features.1.0.block.4 ||||  GELU
model.features.1.0.block.5 ||||  Linear
model.features.1.0.block.6 ||||  Permute
model.features.1.0.stochastic_depth ||||  StochasticDepth
model.features.1.1.block.0 ||||  Conv2d
model.features.1.1.block.1 ||||  Permute
model.features.1.1.block.2 ||||  LayerNorm
model.features.1.1.block.3 ||||  Linear
model.features.1.1.block.4 ||||  GELU
model.features.1.1.block.5 ||||  Linear
model.features.1.1.block.6 ||||  Permute
model.features.1.1.stochastic_depth ||||  StochasticDepth
model.features.1.2.block.0 ||||  Conv2d
model.features.1.2.block.1 ||||  Permute
model.features.1.2.block.2 ||||  LayerNorm
model.features.1.2.block.3 ||||  Linear
model.features.1.2.block.4 ||||  GELU
model.features.1.2.b

In [3]:
import torch
import random
import numpy as np
from torch import nn
from torchvision import models

def sparsify_vit_weights(model, sparsity_k=0.5, layers="all", layer_type="all"):

    assert 0.0 <= sparsity_k <= 1.0, "sparsity_k must be between 0 and 1."
    assert layer_type in ["all", "attn_in", "attn_out", "mlp"], "layer_type must be one of 'all', 'attn_in', 'attn_out', 'mlp'."

    layer_names = []
    mlp_flag = layer_type in ["all", "mlp"]
    attn_in_flag = layer_type in ["all", "attn_in"]
    attn_out_flag = layer_type in ["all", "attn_out"]


    # for now only "all" layers supported
    for name, module in model.named_modules():
        if isinstance(module, (nn.Linear)):
            # mlp layers
            if mlp_flag and 'mlp' in name:
                layer_names.append(name)
            
            # attn_out layers
            if attn_out_flag and 'out_proj' in name:
                layer_names.append(name) 
        
        # attn_in layers
        if attn_in_flag and isinstance(module, nn.MultiheadAttention):
            layer_names.append(name) # needs special attention later to get in_proj

    layer_names = layer_names  # exclude the final classification layer
    num_layers = len(layer_names)

    layers_to_sparsify = layer_names  # all layers TODO: add early/late options
    return layer_names
    modules = dict(model.named_modules())
    for name in layers_to_sparsify:
        module = modules[name]
        W = None
        if isinstance(module, (nn.Linear)):
            # mlp layers and attn_out layers
            W = module.weight
            
        
        # attn_in layers
        elif isinstance(module, nn.MultiheadAttention):
            W = module.in_proj_weight
            
        if W is not None:
            with torch.no_grad():
                W_flat = W.view(-1)
                num_weights_to_keep = round(sparsity_k * W_flat.numel())
                indices = torch.randperm(W_flat.numel())[:num_weights_to_keep]
                mask = torch.zeros_like(W_flat)
                mask[indices] = 1
                W_sparse = (W_flat * mask).view_as(W)
                
                if isinstance(module, (nn.Linear)):
                    module.weight.copy_(W_sparse)
                elif isinstance(module, nn.MultiheadAttention):
                    module.in_proj_weight.copy_(W_sparse)

            # Verify sparsity to be within tolerance (5%)
            actual_nonzero = torch.sum(W != 0).item()
            actual_sparsity = actual_nonzero / W.numel() # fraction of weights kept
            print(f'Layer {name}: expected sparsity {sparsity_k}, actual sparsity {actual_sparsity}')
            assert abs(actual_sparsity - sparsity_k) < 0.05, f"Sparsity check failed for layer {name}: expected {sparsity_k}, got {actual_sparsity}"




In [None]:
vit = init_model('vit_b_16', regime="trained")
ln = sparsify_vit_weights(vit, layer_type="all")
print(f"All layers: {len(ln)}")
vit = init_model('vit_b_16', regime="trained")
ln1 = sparsify_vit_weights(vit, layer_type="attn_in")
print(f"AttnIn layers: {len(ln1)}")
vit = init_model('vit_b_16', regime="trained")
ln2 = sparsify_vit_weights(vit, layer_type="attn_out")
print(f"AttnOut layers: {len(ln2)}")
vit = init_model('vit_b_16', regime="trained")
ln3 = sparsify_vit_weights(vit, layer_type="mlp")
print(f"MLP layers: {len(ln3)}")

In [5]:
net = init_model('vit_b_16', regime="trained")
nets = sparsify_weights(net, arch='vit_b_16', target_layer_range="middle")

['model.conv_proj', 'model.encoder.layers.encoder_layer_0.self_attention', 'model.encoder.layers.encoder_layer_0.self_attention.out_proj', 'model.encoder.layers.encoder_layer_0.mlp.0', 'model.encoder.layers.encoder_layer_0.mlp.3', 'model.encoder.layers.encoder_layer_1.self_attention', 'model.encoder.layers.encoder_layer_1.self_attention.out_proj', 'model.encoder.layers.encoder_layer_1.mlp.0', 'model.encoder.layers.encoder_layer_1.mlp.3', 'model.encoder.layers.encoder_layer_2.self_attention', 'model.encoder.layers.encoder_layer_2.self_attention.out_proj', 'model.encoder.layers.encoder_layer_2.mlp.0', 'model.encoder.layers.encoder_layer_2.mlp.3', 'model.encoder.layers.encoder_layer_3.self_attention', 'model.encoder.layers.encoder_layer_3.self_attention.out_proj', 'model.encoder.layers.encoder_layer_3.mlp.0', 'model.encoder.layers.encoder_layer_3.mlp.3', 'model.encoder.layers.encoder_layer_4.self_attention', 'model.encoder.layers.encoder_layer_4.self_attention.out_proj', 'model.encoder.la

In [11]:
import timm 

timms = timm.list_models(pretrained=True)
timms_moco = [t for t in timms if "moco" in t
timms_moco

False