The architecture of LLaVA contains the following layers for projection:

In [None]:
import torch
import torch.nn as nn
import re

def build_vision_projector(config, delay_load=False, **kwargs):
    """Construct the module that maps vision features into the LM hidden size.

    The projector flavour is read from ``config.mm_projector_type`` and
    defaults to ``'linear'`` when the attribute is absent:

    * ``'linear'``      -> one ``nn.Linear(mm_hidden_size, hidden_size)``.
    * ``'mlp<N>x_gelu'`` -> ``N`` linear layers with a GELU between each pair;
      the first maps ``mm_hidden_size -> hidden_size``, the rest keep
      ``hidden_size``.
    * ``'identity'``    -> an ``IdentityMap`` instance.

    Args:
        config: Object exposing ``mm_hidden_size`` and ``hidden_size`` (only
            read for the linear/MLP flavours) and optionally
            ``mm_projector_type``.
        delay_load: Accepted but not used by this function.
        **kwargs: Accepted but not used by this function.

    Returns:
        The projector module.

    Raises:
        ValueError: If the projector type matches none of the flavours above.
    """
    kind = getattr(config, 'mm_projector_type', 'linear')

    if kind == 'linear':
        return nn.Linear(config.mm_hidden_size, config.hidden_size)

    matched = re.match(r'^mlp(\d+)x_gelu$', kind)
    if matched is not None:
        depth = int(matched.group(1))
        stack = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
        # Every layer after the first is preceded by a GELU non-linearity.
        for _ in range(depth - 1):
            stack += [nn.GELU(), nn.Linear(config.hidden_size, config.hidden_size)]
        return nn.Sequential(*stack)

    if kind != 'identity':
        raise ValueError(f'Unknown projector type: {kind}')

    # NOTE(review): IdentityMap is not defined in the visible part of this
    # file; it presumably has to be provided elsewhere -- confirm.
    return IdentityMap()

In [None]:
!pip install ptflops

Collecting ptflops
  Downloading ptflops-0.7.4-py3-none-any.whl.metadata (9.4 kB)
Downloading ptflops-0.7.4-py3-none-any.whl (19 kB)
Installing collected packages: ptflops
Successfully installed ptflops-0.7.4


Here we create a network similar to the projection layer to test pruning:

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F
import numpy as np
from ptflops import get_model_complexity_info

class MLP(nn.Module):
    """Fully connected classifier used as a stand-in for the projection layer.

    Layout: ``Linear(input_size, hidden_size)``, then ``n_layers - 1``
    repetitions of ``GELU -> Linear(hidden_size, hidden_size)``, then a final
    ``Linear(hidden_size, output_size)``. That is ``n_layers + 1`` linear
    layers in total; the last hidden layer feeds the output layer directly,
    with no activation in between.
    """

    def __init__(self, input_size, hidden_size, output_size, n_layers):
        """Assemble the layer stack into ``self.model`` (an ``nn.Sequential``)."""
        super().__init__()
        stack = [nn.Linear(input_size, hidden_size)]
        remaining = n_layers - 1
        while remaining > 0:
            stack.extend((nn.GELU(), nn.Linear(hidden_size, hidden_size)))
            remaining -= 1
        stack.append(nn.Linear(hidden_size, output_size))
        self.model = nn.Sequential(*stack)

    def forward(self, x):
        """Flatten everything after the batch dimension and run the stack."""
        batch = x.size(0)
        flattened = x.view(batch, -1)
        return self.model(flattened)

def evaluate(model, test_loader):
    """Compute and print the top-1 accuracy of ``model`` over ``test_loader``.

    The model is switched to eval mode (and left in it). Each batch is moved
    to the device the model's parameters live on, so the function works for
    models on the CPU as well as on a GPU.

    Args:
        model: Classifier returning per-class scores of shape
            ``(batch, num_classes)``.
        test_loader: Iterable yielding ``(inputs, labels)`` tensor pairs.

    Returns:
        Accuracy as a percentage in ``[0, 100]``.

    Raises:
        ValueError: If ``test_loader`` yields no samples.
    """
    model.eval()

    # Follow the model's device instead of assuming CUDA is available.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            predicted = model(inputs).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    if total == 0:
        raise ValueError("test_loader yielded no samples; accuracy is undefined")

    accuracy = 100 * correct / total
    print(f"Test Accuracy: {accuracy:.2f}%")
    return accuracy


def finetune(model, train_loader, criterion, optimizer, epochs=5):
    """Train ``model`` in place and print the mean loss and accuracy per epoch.

    The model is switched to train mode once, before the first epoch. Each
    batch is moved to the device the model's parameters live on, so the
    function works for models on the CPU as well as on a GPU.

    Args:
        model: Classifier returning per-class scores of shape
            ``(batch, num_classes)``.
        train_loader: Sized iterable yielding ``(inputs, labels)`` tensor
            pairs; ``len(train_loader)`` is used to average the loss.
        criterion: Loss callable taking ``(outputs, labels)``.
        optimizer: Optimizer already bound to ``model``'s parameters.
        epochs: Number of full passes over ``train_loader``.

    Raises:
        ValueError: If ``train_loader`` yields no samples.
    """
    model.train()

    # Follow the model's device instead of assuming CUDA is available.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    for epoch in range(epochs):
        running_loss = 0.0
        correct = 0
        total = 0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            # argmax yields integer class ids, so no graph is kept alive here.
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        if total == 0:
            raise ValueError("train_loader yielded no samples; nothing to train on")

        print(f"Epoch {epoch+1}, Loss: {running_loss/len(train_loader)}, Accuracy: {100 * correct / total}%")


# Preprocessing: convert images to tensors, then normalize the single channel
# with mean 0.5 and std 0.5.
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])

# MNIST train/test splits, downloaded into ./data on first use.
trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)

train_loader = DataLoader(trainset, batch_size=64, shuffle=True)
test_loader = DataLoader(testset, batch_size=64, shuffle=False)


# Model dimensions: flattened 28x28 image in, 10 class scores out.
input_size = 28 * 28
hidden_size = 128
output_size = 10
n_layers = 5


# Place the model on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = MLP(input_size, hidden_size, output_size, n_layers).to(device)

optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

print("Training the model before pruning:")
finetune(model, train_loader, criterion, optimizer, epochs=10)

# Saves the whole module object (not just its state_dict); the pruning cell
# reloads this file to start from the trained, unpruned weights.
torch.save(model, "model.pth")

print("Calculating FLOPs before pruning:")
# ptflops profiles one input of shape (1, 28, 28). The recorded output is in
# "KMac" units, i.e. multiply-accumulates, even though the label says FLOPs.
macs, params = get_model_complexity_info(model, (1, 28, 28), as_strings=True, print_per_layer_stat=True)
print(f"FLOPs before pruning: {macs}, Parameters: {params}")


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9.91M/9.91M [00:00<00:00, 43.5MB/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28.9k/28.9k [00:00<00:00, 1.28MB/s]


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1.65M/1.65M [00:00<00:00, 10.9MB/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4.54k/4.54k [00:00<00:00, 3.57MB/s]


Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw

Training the model before pruning:
Epoch 1, Loss: 0.3561573775369984, Accuracy: 88.36166666666666%
Epoch 2, Loss: 0.15687885793295306, Accuracy: 95.16333333333333%
Epoch 3, Loss: 0.12062729426772832, Accuracy: 96.28833333333333%
Epoch 4, Loss: 0.10285982079823801, Accuracy: 96.795%
Epoch 5, Loss: 0.0874423110019366, Accuracy: 97.28833333333333%
Epoch 6, Loss: 0.08068304431641093, Accuracy: 97.50833333333334%
Epoch 7, Loss: 0.06948303363173267, Accuracy: 97.85666666666667%
Epoch 8, Loss: 0.06760811422876309, Accuracy: 97.895%
Epoch 9, Loss: 0.060989562736346085, Accuracy: 98.135%
Epoch 10, Loss: 0.05614889428228773, Accuracy: 98.245%
Calculating FLOPs before pruning:
MLP(
  167.82 k, 100.000% Params, 168.33 KMac, 99.697% MACs, 
  (model): Sequential(
    167.82 k, 100.000% Params, 168.33 KMac, 99.697% MACs, 
    (0): Linear(100.48 k, 59.874% Params, 100.48 KMac, 59.511% MACs, in_features=784, out_features=128, bi

In [None]:
import torch.nn.utils.prune as prune
# "model.pth" holds a fully pickled module (written with torch.save(model, ...)),
# so it cannot be read in weights-only mode, which newer PyTorch releases use
# by default. weights_only=False unpickles arbitrary objects: only load
# checkpoint files you created yourself or otherwise trust.
pmodel = torch.load("model.pth", weights_only=False)

def prune_model(model, pruning_amount=0.1):
    """Apply L1 unstructured pruning to every Conv1d and Linear weight.

    Walks ``model.named_modules()`` and, for each ``nn.Conv1d`` or
    ``nn.Linear`` found, calls ``prune.l1_unstructured`` on its ``weight``
    with ``amount=pruning_amount``, printing one line per pruned layer.

    Args:
        model: Module to prune; it is modified in place.
        pruning_amount: Passed straight through as ``amount`` to
            ``prune.l1_unstructured``.

    Returns:
        The same ``model`` object that was passed in.
    """
    for layer_name, layer in model.named_modules():
        if isinstance(layer, nn.Conv1d):
            message = f"Pruned Conv1d layer: {layer_name}"
        elif isinstance(layer, nn.Linear):
            message = f"Pruned Linear layer: {layer_name}"
        else:
            # Containers, activations, etc. are left untouched.
            continue

        prune.l1_unstructured(layer, name='weight', amount=pruning_amount)
        print(message)

    return model

# Prune 90% of each Conv1d/Linear weight of the reloaded model via prune_model.
print("Pruning the model:")
pmodel = prune_model(pmodel, pruning_amount=0.9)

# Re-profile the pruned model with the same (1, 28, 28) input used before pruning.
# NOTE(review): the recorded output shows the same parameter and MAC totals as
# before pruning (167.82 k / 168.33 KMac) -- presumably because the layer shapes
# are unchanged by this kind of pruning; confirm. The recorded units are
# "KMac", i.e. multiply-accumulates, although the label below says FLOPs.
print("Calculating FLOPs after pruning:")
macs_after_pruning, params_after_pruning = get_model_complexity_info(pmodel, (1, 28, 28), as_strings=True, print_per_layer_stat=True)
print(f"FLOPs after pruning: {macs_after_pruning}, Parameters: {params_after_pruning}")

# Test accuracy straight after pruning, before any fine-tuning of the pruned model.
print("Evaluating the pruned model on the test set:")
test_accuracy_before_finetune = evaluate(pmodel, test_loader)

  pmodel = torch.load("model.pth")


Pruning the model:
Pruned Linear layer: model.0
Pruned Linear layer: model.2
Pruned Linear layer: model.4
Pruned Linear layer: model.6
Pruned Linear layer: model.8
Pruned Linear layer: model.9
Calculating FLOPs after pruning:
MLP(
  167.82 k, 100.000% Params, 168.33 KMac, 99.697% MACs, 
  (model): Sequential(
    167.82 k, 100.000% Params, 168.33 KMac, 99.697% MACs, 
    (0): Linear(100.48 k, 59.874% Params, 100.48 KMac, 59.511% MACs, in_features=784, out_features=128, bias=True)
    (1): GELU(0, 0.000% Params, 128.0 Mac, 0.076% MACs, approximate='none')
    (2): Linear(16.51 k, 9.839% Params, 16.51 KMac, 9.780% MACs, in_features=128, out_features=128, bias=True)
    (3): GELU(0, 0.000% Params, 128.0 Mac, 0.076% MACs, approximate='none')
    (4): Linear(16.51 k, 9.839% Params, 16.51 KMac, 9.780% MACs, in_features=128, out_features=128, bias=True)
    (5): GELU(0, 0.000% Params, 128.0 Mac, 0.076% MACs, approximate='none')
    (6): Linear(16.51 k, 9.839% Params, 16.51 KMac, 9.780% MACs, 