In [None]:
import os
import sys
import torch

Add system path

In [None]:
# Make the project's `src` package importable from this notebook.
# NOTE(review): absolute, machine-specific path - adjust when running elsewhere.
PROJECT_ROOT = "/home/iwawiwi/research/22/lipreading-lightning/"

# Guard the append so that re-running this cell does not pile up duplicates.
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)
print(sys.path)

Prepare model and data class

In [None]:
from src.models.lrw_module import LRWLitModule
from src.datamodules.components.lrw_dataset import LRWDataset

Model path

In [None]:
# Checkpoint file handed to `LRWLitModule.load_from_checkpoint` further down.
# Relative path, so it is looked up in the notebook's working directory.
BIG_MODEL = "lrw_model.pth"

Load model

In [None]:
# Restore the Lightning module from the checkpoint file.
# `map_location="cpu"` lets a checkpoint that was written from a GPU run load on
# a machine without CUDA; the network is moved to `device` before evaluation.
big_model = LRWLitModule.load_from_checkpoint(BIG_MODEL, map_location="cpu")

In [None]:
# Print the loaded module's string representation to inspect what was restored.
print(big_model)

In [None]:
# Keep a direct handle on the `net` attribute of the Lightning module; the
# evaluation, thresholding and pruning cells below all operate on this object.
net = big_model.net

Evaluate net performance on test dataset

In [None]:
# Put the network in evaluation mode before measuring accuracy
# (assumes `net` is a torch.nn.Module, where this switches layers such as
# dropout and batch norm to their inference behaviour).
# As the last expression of the cell, the returned value is also displayed.
net.eval()

In [None]:
# Dataset root on the author's machine (presumably the cropped LRW clips,
# going by the directory name).
DATA_PATH = "/home/iwawiwi/research/22/lipreading-lightning/data/lrw_cropped"

# Build the "test" phase of the dataset and report how many samples it holds.
test_set = LRWDataset(DATA_PATH, phase="test")
num_test_samples = len(test_set)
print(num_test_samples)

In [None]:
# Batches of 16 in dataset order, loaded in the main process (no workers).
testloader = torch.utils.data.DataLoader(
    test_set,
    batch_size=16,
    shuffle=False,
    num_workers=0,
)

# Prefer the first CUDA device; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [None]:
# Move the weights to the selected device (GPU when available).
net.to(device)

# Count correct top-1 predictions over the whole test loader. Gradients are
# not needed for inference, so autograd tracking is switched off.
correct, total = 0, 0
with torch.no_grad():
    for batch in testloader:
        videos = batch["video"].to(device)
        targets = batch["label"].long().to(device)
        scores = net(videos)
        # Index of the highest score along dim 1 is the predicted class.
        predicted = scores.argmax(dim=1)
        total += targets.size(0)
        correct += (predicted == targets).sum().item()

print("Accuracy of the network on the test images: %.2f %%" % (100 * correct / total))
    

Overall accuracy: $84.52\%$

Import pytorch pruning module

In [None]:
import torch.nn.utils.prune as prune
import torch.nn.functional as F 

Inspect module

In [None]:
# Dump every (name, tensor) pair the network registers as a parameter.
named_params = list(net.named_parameters())
print(named_params)

In [None]:
# List the dotted name of every sub-module registered in the network.
for module_name, _submodule in net.named_modules():
    print(module_name)

For a sequential module (an instance of `torch.nn.Sequential`), a numeric entry in the name is the index of the sub-module within the sequence.

Do global pruning on ```net```

In [None]:
# Magnitude threshold: weights with |w| <= t are set to exactly zero.
t = 1e-6

# Zero out near-zero weights and measure the resulting sparsity.
#
# The comparison has to use the absolute value: a signed `param > t` test
# would also wipe every negative weight and destroy the trained model.
zero_count = 0    # weights that are exactly zero after thresholding
weight_count = 0  # all trainable weights visited
with torch.no_grad():
    for name, param in net.named_parameters():
        # Frozen or empty tensors are left untouched; skipping empty ones
        # also avoids a division by zero in the per-tensor report below.
        if not param.requires_grad or param.numel() == 0:
            continue

        keep = param.abs() > t
        # In-place update keeps the same Parameter object registered on the module.
        param.masked_fill_(~keep, 0)

        n_total = param.numel()
        n_nonzero = int(keep.sum())
        zero_count += n_total - n_nonzero
        weight_count += n_total

        # Per tensor: name, size, surviving weights, fraction of zeros.
        print(name, n_total, n_nonzero, "\n", float(n_total - n_nonzero) / n_total)

# Global sparsity: fraction of all visited weights that are zero
# (a single ratio, not a sum of per-tensor ratios).
sparsity = zero_count / weight_count if weight_count else 0.0
print("global sparsity:", sparsity)


In [None]:
# Sub-modules of a sequential container are reached by integer index
# (e.g. `frontend3D[0]`, `layer1[0]`).
# Each entry is a (module, parameter_name) pair. The commented-out entries are
# further candidate layers that are currently excluded from pruning.
parameters_to_prune = (
    #(net.video_cnn.frontend3D[0], 'weight'),
    #(net.video_cnn.resnet18.layer1[0].conv1, "weight"),
    #(net.video_cnn.resnet18.layer1[0].conv2, "weight"),
    #(net.video_cnn.resnet18.layer1[1].conv1, "weight"),
    #(net.video_cnn.resnet18.layer1[1].conv2, "weight"),
    #(net.video_cnn.resnet18.layer2[0].conv1, "weight"),
    #(net.video_cnn.resnet18.layer2[0].conv2, "weight"),
    #(net.video_cnn.resnet18.layer2[1].conv1, "weight"),
    #(net.video_cnn.resnet18.layer2[1].conv2, "weight"),
    (net.v_cls, "weight"),
)

In [None]:
# Global L1-magnitude pruning over every tensor listed in `parameters_to_prune`.
# `amount=0` is an integer, i.e. an absolute count of zero weights to remove:
# the call attaches the pruning re-parameterization without masking anything.
prune.global_unstructured(
    parameters=parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0,
)

In [None]:
# Dump every (name, tensor) pair the network registers as a buffer.
named_bufs = list(net.named_buffers())
print(named_bufs)

Check forward pre-hooks

In [None]:
# Buffer names: each pruned parameter contributes a `<name>_mask` entry.
print(dict(net.named_buffers()).keys())

# Forward pre-hooks: pruning registers one on each pruned module, which
# recomputes the parameter from `<name>_orig` and `<name>_mask` before forward.
for pruned_module, _param_name in parameters_to_prune:
    print(pruned_module._forward_pre_hooks)

Remove re-parameterization

In [None]:
# Make the pruning permanent: `prune.remove` drops the `_orig`/`_mask` pair
# and leaves the masked tensor as a plain parameter again.
for pruned_module, param_name in parameters_to_prune:
    prune.remove(module=pruned_module, name=param_name)

Evaluate model after pruning

In [None]:
# Re-assert evaluation mode before the second accuracy measurement.
# As the last expression of the cell, the returned value is also displayed.
net.eval()

In [None]:
# Same device placement as for the first accuracy measurement.
net.to(device)

# Top-1 accuracy over the test loader, inference only (no autograd tracking).
correct = 0
total = 0
with torch.no_grad():
    for sample in testloader:
        clips, targets = sample["video"].to(device), sample["label"].long().to(device)
        # torch.max over dim 1 returns (values, indices); the indices are the classes.
        predicted = torch.max(net(clips), 1)[1]
        total += targets.shape[0]
        correct += int(predicted.eq(targets).sum())

accuracy = 100 * correct / total
print("Accuracy of the network on the test images: %.2f %%" % accuracy)

Model accuracy after pruning: $0.20\%$