In [1]:
import os
import sys

# Extend the import search path three directories up; the `src.*` imports in a
# later cell presumably rely on this (repository layout not visible here — TODO confirm).
sys.path.append("../../../")
# Exported to the process environment before the heavier libraries are imported.
# NOTE(review): looks like the Hugging Face tokenizers parallelism switch —
# confirm it is still needed for this MNIST notebook.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

In [3]:
import copy
import torch
from datetime import datetime
from src.utils.helper import Config, color_print
from src.utils.load import save_checkpoint
from src.utils.load import load_model, load_data, save_checkpoint
from src.models.evaluate import evaluate_model, get_sparsity, get_similarity
from src.utils.sampling import SamplingDataset
from src.pruning.prune_head import head_importance_prunning
from src.pruning.prune import *

In [4]:
# Hyper-parameters read by the (currently commented-out) training loop.
num_epochs = 5
learning_rate = 0.001
# NOTE(review): batch_size is reassigned to 16 in the later configuration cell,
# before the data loaders are built, so this value is effectively overridden.
batch_size = 100

# Problem dimensions: 28 * 28 = 784 matches the in_features of SimpleDNN.fc1,
# and 10 matches the out_features of SimpleDNN.fc6.
input_size = 28 * 28
num_classes = 10

In [5]:
# Image preprocessing pipeline: convert to a tensor, then normalise the single
# channel with mean 0.5 and std 0.5.
# NOTE(review): `transform` is not passed to load_data in the visible cells —
# confirm whether it is still used anywhere.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

In [6]:
# Experiment configuration shared by the cells below.
name = "MNIST"  # dataset identifier handed to Config
# Prefer the first CUDA device, but fall back to the CPU so the notebook still
# runs on machines without a GPU instead of failing at the first device transfer.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
checkpoint = None
batch_size = 16  # supersedes the batch_size = 100 set in the hyper-parameter cell
num_workers = 4
num_samples = 128  # sample count handed to SamplingDataset / get_similarity
ci_ratio = 0.3  # NOTE(review): not read in the visible cells — confirm it is still needed
# NOTE(review): seed is only forwarded to get_similarity; no global RNG is seeded here.
seed = 44

In [7]:
class SimpleDNN(nn.Module):
    """Fully connected classifier: 784 -> 512 -> 256 -> 128 -> 64 -> 32 -> 10.

    Each of the five hidden stages applies Linear -> BatchNorm1d -> ReLU ->
    Dropout(p=0.5); the final ``fc6`` layer produces the 10 class logits.

    The submodule names (``fc1``..``fc6``, ``bn1``..``bn5``) are the keys of the
    state dict, so they must stay unchanged for saved checkpoints to load.
    """

    def __init__(self):
        super().__init__()
        # Linear layers are created in the original order so that seeded
        # initialisation produces the same weights as before.
        self.fc1 = nn.Linear(784, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 128)
        self.fc4 = nn.Linear(128, 64)
        self.fc5 = nn.Linear(64, 32)
        self.fc6 = nn.Linear(32, 10)

        # Stateless activation and dropout modules are shared by every stage.
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)

        # One batch-norm layer per hidden stage, sized to that stage's width.
        self.bn1 = nn.BatchNorm1d(512)
        self.bn2 = nn.BatchNorm1d(256)
        self.bn3 = nn.BatchNorm1d(128)
        self.bn4 = nn.BatchNorm1d(64)
        self.bn5 = nn.BatchNorm1d(32)

    def forward(self, x, output_hidden_states=False):
        """Run the network on a batch of inputs.

        Args:
            x: Tensor whose first dimension is the batch; the remaining
                dimensions are flattened and must total 784 elements.
            output_hidden_states: When true, also return the intermediate
                activations.

        Returns:
            A dict with key ``"logits"`` (shape ``(batch, 10)``). When
            ``output_hidden_states`` is true it additionally holds
            ``"hidden_states"``: a list of six tensors — the batch-normalised,
            pre-ReLU output of each hidden stage, followed by the logits.
        """
        hidden_states = []
        # reshape (rather than view) also accepts non-contiguous inputs such as
        # transposed or sliced image batches; for contiguous inputs it is a view.
        x = x.reshape(x.size(0), -1)

        hidden_stages = (
            (self.fc1, self.bn1),
            (self.fc2, self.bn2),
            (self.fc3, self.bn3),
            (self.fc4, self.bn4),
            (self.fc5, self.bn5),
        )
        for linear, norm in hidden_stages:
            x = norm(linear(x))
            # The hidden state is recorded before the non-linearity and dropout.
            if output_hidden_states:
                hidden_states.append(x)
            x = self.relu(x)
            x = self.dropout(x)

        x = self.fc6(x)
        if output_hidden_states:
            # The logits double as the last hidden state.
            hidden_states.append(x)
            return {"logits": x, "hidden_states": hidden_states}
        return {"logits": x}

In [8]:
# Fresh SimpleDNN with PyTorch's default parameter initialisation (no weights loaded yet).
model = SimpleDNN()

In [9]:
# Bare expression: the cell output is the module's repr (its list of submodules).
model

SimpleDNN(
  (fc1): Linear(in_features=784, out_features=512, bias=True)
  (fc2): Linear(in_features=512, out_features=256, bias=True)
  (fc3): Linear(in_features=256, out_features=128, bias=True)
  (fc4): Linear(in_features=128, out_features=64, bias=True)
  (fc5): Linear(in_features=64, out_features=32, bias=True)
  (fc6): Linear(in_features=32, out_features=10, bias=True)
  (relu): ReLU()
  (dropout): Dropout(p=0.5, inplace=False)
  (bn1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bn2): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bn3): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bn4): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bn5): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)

In [10]:
# Loss and optimiser; in the visible cells only the commented-out training
# loop reads them.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)

# Project-level configuration object built from the dataset name and device.
config = Config(name, device)

In [11]:
# Build the train / validation / test loaders through the project loader.
# With do_cache=True the recorded run reports the splits as loaded from cache.
(train_dataloader, valid_dataloader, test_dataloader) = load_data(
    config, batch_size=batch_size, num_workers=num_workers, do_cache=True
)

Loading cached dataset MNIST.
train.pkl is loaded from cache.
valid.pkl is loaded from cache.
test.pkl is loaded from cache.
The dataset MNIST is loaded
{'dataset_name': 'MNIST', 'path': 'ylecun/mnist', 'config_name': 'mnist', 'features': {'first_column': 'image', 'second_column': 'label'}, 'cache_dir': 'Datasets/MNIST', 'task_type': 'image_classification'}


In [12]:
# Training loop, kept for reference but disabled: the notebook loads trained
# weights from "Models/MNIST/model.pt" in a later cell instead of retraining.
# NOTE(review): before re-enabling, check that (a) model.train() is called,
# (b) the batches and the model are on the same device, and (c) the label key
# matches the loader — the dataset info printed by load_data names the column
# 'label', while this loop reads batch["labels"] (the loader may rename it).
# for epoch in range(num_epochs):
#     for i, batch in enumerate(train_dataloader):
#         images = batch["image"].float()
#         labels = batch["labels"]
#         # Forward pass
#         outputs = model(images)
#         logits = outputs["logits"]
#         loss = criterion(logits, labels)

#         # Backward and optimize
#         optimizer.zero_grad()
#         loss.backward()
#         optimizer.step()

#         if (i + 1) % 100 == 0:
#             print(
#                 f"Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_dataloader)}], Loss: {loss.item():.4f}"
#             )

In [13]:
# torch.save(model.state_dict(), "Models/MNIST/model.pt")

In [14]:
# Re-create the network so the checkpoint loaded in the next cell goes into a clean instance.
model = SimpleDNN()

In [15]:
# Restore the trained weights into `model`.
# map_location="cpu" deserialises every tensor onto the CPU first, so a
# checkpoint saved from a CUDA model also loads on a machine without a GPU;
# load_state_dict then copies the values into the model's existing parameters.
# NOTE(review): torch.load unpickles the file — only load checkpoints you trust
# (recent PyTorch versions offer weights_only=True for plain state dicts).
model.load_state_dict(torch.load("Models/MNIST/model.pt", map_location="cpu"))

<All keys matched successfully>

In [16]:
# Baseline: evaluate the unpruned model on the test split; the recorded output
# shows the loss plus a per-class precision/recall/F1 report.
result = evaluate_model(model, config, test_dataloader)




Loss: 0.1288
Precision: 0.9671, Recall: 0.9669, F1-Score: 0.9669
              precision    recall  f1-score   support
           0       0.98      0.99      0.98       980
           1       0.98      0.99      0.98      1135
           2       0.97      0.97      0.97      1032
           3       0.95      0.98      0.96      1010
           4       0.97      0.96      0.96       982
           5       0.96      0.96      0.96       892
           6       0.97      0.97      0.97       958
           7       0.97      0.96      0.96      1028
           8       0.97      0.96      0.96       974
           9       0.95      0.94      0.95      1009
    accuracy                           0.97     10000
   macro avg       0.97      0.97      0.97     10000
weighted avg       0.97      0.97      0.97     10000



In [17]:
# For each concern (class index), prune a copy of the trained model using samples
# of that concern, evaluate the pruned copy and compare it with the original.
# NOTE(review): num_classes is 10, so range(5) only covers classes 0-4 — confirm
# that skipping classes 5-9 is intentional.
for concern in range(5):
    # Every iteration works on its own deep copies of the loaders
    # (presumably because SamplingDataset consumes or mutates them — TODO confirm).
    train = copy.deepcopy(train_dataloader)
    valid = copy.deepcopy(valid_dataloader)
    # Samples for this concern. The positional True/False flag presumably selects
    # samples of / not of the concern class, and the literal 4 is an unnamed
    # positional argument — verify both against SamplingDataset's signature.
    positive_samples = SamplingDataset(
        train,
        concern,
        num_samples,
        num_classes,
        True,
        4,
        device=device,
        resample=False,
    )
    # NOTE(review): negative_samples is never read in this loop.
    negative_samples = SamplingDataset(
        train,
        concern,
        num_samples,
        num_classes,
        False,
        4,
        device=device,
        resample=False,
    )
    # NOTE(review): all_samples is never read either, and its concern argument
    # (200) is not one of the class indices used above — confirm or remove.
    all_samples = SamplingDataset(
        train,
        200,
        num_samples,
        num_classes,
        False,
        4,
        device=device,
        resample=False,
    )

    # Prune a copy so the reference `model` stays intact for the comparison below.
    module = copy.deepcopy(model)

    # prune_wanda receives only the positive samples.
    # NOTE(review): sparsity_ratio repeats the value of ci_ratio (0.3) from the
    # configuration cell; use the constant if they are meant to be the same knob.
    prune_wanda(
        module,
        config,
        positive_samples,
        sparsity_ratio=0.3,
        include_layers=None,
        exclude_layers=None,
    )

    print(f"Evaluate the pruned model {concern}")
    # `result` is overwritten on every iteration and not read afterwards in this cell.
    result = evaluate_model(module, config, test_dataloader, verbose=True)
    # The return value is discarded; presumably get_sparsity reports by printing — TODO confirm.
    get_sparsity(module)

    # Compare the original and the pruned model on the validation copy for this concern.
    get_similarity(
        model, module, valid, concern, num_samples, num_classes, config, seed=seed
    )

Evaluate the pruned model 0





Loss: 1.9286
Precision: 0.3174, Recall: 0.2617, F1-Score: 0.2069
              precision    recall  f1-score   support
           0       0.13      1.00      0.23       980
           1       0.00      0.00      0.00      1135
           2       1.00      0.10      0.18      1032
           3       0.98      0.41      0.58      1010
           4       0.00      0.00      0.00       982
           5       0.51      0.46      0.49       892
           6       0.00      0.00      0.00       958
           7       0.00      0.00      0.00      1028
           8       0.55      0.65      0.59       974
           9       0.00      0.00      0.00      1009
    accuracy                           0.25     10000
   macro avg       0.32      0.26      0.21     10000
weighted avg       0.31      0.25      0.20     10000
adding eps to diagonal and taking inverse
taking square root
dot products...
trying to take final svd
computed everything!
adding eps to diagonal and taking inverse
taking square 




Loss: 2.2328
Precision: 0.1992, Recall: 0.1531, F1-Score: 0.1028
              precision    recall  f1-score   support
           0       0.00      0.00      0.00       980
           1       0.12      1.00      0.21      1135
           2       0.00      0.00      0.00      1032
           3       0.00      0.00      0.00      1010
           4       0.00      0.00      0.00       982
           5       0.00      0.00      0.00       892
           6       1.00      0.36      0.53       958
           7       0.00      0.00      0.00      1028
           8       0.87      0.17      0.28       974
           9       0.00      0.00      0.00      1009
    accuracy                           0.16     10000
   macro avg       0.20      0.15      0.10     10000
weighted avg       0.19      0.16      0.10     10000
adding eps to diagonal and taking inverse
taking square root
dot products...
trying to take final svd
computed everything!
adding eps to diagonal and taking inverse
taking square 




Loss: 2.9500
Precision: 0.0826, Recall: 0.1644, F1-Score: 0.0880
              precision    recall  f1-score   support
           0       0.00      0.00      0.00       980
           1       0.00      0.00      0.00      1135
           2       0.11      1.00      0.20      1032
           3       0.71      0.64      0.68      1010
           4       0.00      0.00      0.00       982
           5       0.00      0.00      0.00       892
           6       0.00      0.00      0.00       958
           7       0.00      0.00      0.00      1028
           8       0.00      0.00      0.00       974
           9       0.00      0.00      0.00      1009
    accuracy                           0.17     10000
   macro avg       0.08      0.16      0.09     10000
weighted avg       0.08      0.17      0.09     10000
adding eps to diagonal and taking inverse
taking square root
dot products...
trying to take final svd
computed everything!
adding eps to diagonal and taking inverse
taking square 




Loss: 2.2158
Precision: 0.5103, Recall: 0.1242, F1-Score: 0.0614
              precision    recall  f1-score   support
           0       1.00      0.01      0.01       980
           1       0.00      0.00      0.00      1135
           2       1.00      0.06      0.10      1032
           3       0.10      1.00      0.19      1010
           4       0.00      0.00      0.00       982
           5       1.00      0.00      0.00       892
           6       1.00      0.00      0.01       958
           7       0.00      0.00      0.00      1028
           8       1.00      0.18      0.30       974
           9       0.00      0.00      0.00      1009
    accuracy                           0.12     10000
   macro avg       0.51      0.12      0.06     10000
weighted avg       0.49      0.12      0.06     10000
adding eps to diagonal and taking inverse
taking square root
dot products...
trying to take final svd
computed everything!
adding eps to diagonal and taking inverse
taking square 




Loss: 2.7934
Precision: 0.0099, Recall: 0.1000, F1-Score: 0.0180
              precision    recall  f1-score   support
           0       0.00      0.00      0.00       980
           1       0.00      0.00      0.00      1135
           2       0.00      0.00      0.00      1032
           3       0.00      0.00      0.00      1010
           4       0.10      1.00      0.18       982
           5       0.00      0.00      0.00       892
           6       0.00      0.00      0.00       958
           7       0.00      0.00      0.00      1028
           8       0.00      0.00      0.00       974
           9       0.00      0.00      0.00      1009
    accuracy                           0.10     10000
   macro avg       0.01      0.10      0.02     10000
weighted avg       0.01      0.10      0.02     10000
adding eps to diagonal and taking inverse
taking square root
dot products...
trying to take final svd
computed everything!
adding eps to diagonal and taking inverse
taking square 