In [1]:
import os
import sys

# Disable HuggingFace `tokenizers` internal parallelism (commonly done to avoid
# fork warnings when DataLoader workers are used), and make the repository root,
# three directories up, importable so the `src.*` packages below resolve.
os.environ.update({"TOKENIZERS_PARALLELISM": "false"})
sys.path.append("../../../")

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

In [3]:
import copy
import torch
from datetime import datetime
from src.utils.helper import Config, color_print
from src.utils.load import save_checkpoint
from src.utils.load import load_model, load_data, save_checkpoint
from src.models.evaluate import evaluate_model, get_sparsity, get_similarity
from src.utils.sampling import SamplingDataset
from src.pruning.prune_head import head_importance_prunning
from src.pruning.prune import *

In [4]:
# Hyper-parameters. MNIST images are 28x28 pixels with 10 digit classes.
# learning_rate configures the Adam optimizer; num_epochs is only referenced by
# the training code.
# NOTE: batch_size is reassigned to 16 in the configuration cell further down,
# and that later value is the one passed to load_data.
input_size, num_classes = 28 * 28, 10
num_epochs, batch_size, learning_rate = 5, 100, 0.001

In [5]:
# Per-image preprocessing: convert to a tensor, then normalize with mean 0.5 and
# std 0.5 (maps values in [0, 1] to [-1, 1]).
# NOTE(review): `transform` is not passed to load_data below, so it appears to
# have no effect on the data actually used in this notebook.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])

In [6]:
# Experiment configuration.
name = "MNIST"
# Fall back to the CPU so the notebook still runs on machines without CUDA.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
checkpoint = None  # not referenced by any visible cell
batch_size = 16  # overrides the earlier batch_size = 100; this is the value load_data uses
num_workers = 4
num_samples = 128  # forwarded to SamplingDataset and get_similarity
ci_ratio = 0.3  # concern-identification pruning ratio
# NOTE: seed is only forwarded to get_similarity; torch/numpy RNGs are not seeded here.
seed = 44

In [7]:
class SimpleDNN(nn.Module):
    """Fully connected classifier: N x (Linear -> BatchNorm1d -> ReLU -> Dropout), then a Linear head.

    Submodules are registered as ``fc1 .. fc{N+1}``, ``relu``, ``dropout`` and
    ``bn1 .. bn{N}``, in that order. With the default arguments the architecture
    (784-512-256-128-64-32-10), the parameter names and the registration order
    are identical to the original hard-coded definition, so existing state_dict
    checkpoints still load.

    Args:
        input_size: Number of features after flattening each input (default 784 = 28 * 28).
        hidden_sizes: Widths of the hidden layers; one BatchNorm1d is created per hidden layer.
        num_classes: Number of output logits.
        dropout_p: Dropout probability applied after every hidden ReLU.
    """

    def __init__(
        self,
        input_size=784,
        hidden_sizes=(512, 256, 128, 64, 32),
        num_classes=10,
        dropout_p=0.5,
    ):
        super().__init__()
        hidden_sizes = tuple(hidden_sizes)
        self._num_hidden = len(hidden_sizes)
        widths = [input_size, *hidden_sizes, num_classes]

        # Linear layers first (fc1..fc{N+1}), matching the original registration
        # and weight-initialization order.
        for idx, (fan_in, fan_out) in enumerate(zip(widths[:-1], widths[1:]), start=1):
            setattr(self, f"fc{idx}", nn.Linear(fan_in, fan_out))

        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout_p)

        for idx, width in enumerate(hidden_sizes, start=1):
            setattr(self, f"bn{idx}", nn.BatchNorm1d(width))

    def forward(self, x, output_hidden_states=False):
        """Run the network.

        Args:
            x: Input batch. Every dimension after the first is flattened, so it
                must hold ``input_size`` elements per sample.
            output_hidden_states: If True, also return the post-BatchNorm,
                pre-ReLU output of every hidden layer, followed by the logits.

        Returns:
            ``{"logits": Tensor}``, plus ``"hidden_states": list[Tensor]`` when
            ``output_hidden_states`` is True.
        """
        hidden_states = []
        # flatten (unlike .view) also copes with non-contiguous or empty batches.
        x = torch.flatten(x, start_dim=1)

        for idx in range(1, self._num_hidden + 1):
            x = getattr(self, f"bn{idx}")(getattr(self, f"fc{idx}")(x))
            if output_hidden_states:
                # Recorded after BatchNorm but before the activation.
                hidden_states.append(x)
            x = self.dropout(self.relu(x))

        x = getattr(self, f"fc{self._num_hidden + 1}")(x)

        if output_hidden_states:
            hidden_states.append(x)
            return {"logits": x, "hidden_states": hidden_states}
        return {"logits": x}

In [8]:
# Instantiate the network. The Adam optimizer created later is bound to this
# instance's parameters.
model = SimpleDNN()

In [9]:
# Show the module tree (last expression in the cell, so Jupyter renders its repr).
model

SimpleDNN(
  (fc1): Linear(in_features=784, out_features=512, bias=True)
  (fc2): Linear(in_features=512, out_features=256, bias=True)
  (fc3): Linear(in_features=256, out_features=128, bias=True)
  (fc4): Linear(in_features=128, out_features=64, bias=True)
  (fc5): Linear(in_features=64, out_features=32, bias=True)
  (fc6): Linear(in_features=32, out_features=10, bias=True)
  (relu): ReLU()
  (dropout): Dropout(p=0.5, inplace=False)
  (bn1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bn2): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bn3): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bn4): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bn5): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)

In [10]:
# `config` bundles the dataset name and device for the src.* helpers; the loss
# and optimizer are only needed for training.
config = Config(name, device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)

In [11]:
# Build the train/valid/test DataLoaders, reusing the on-disk cache when present.
_loader_options = {"batch_size": batch_size, "num_workers": num_workers, "do_cache": True}
train_dataloader, valid_dataloader, test_dataloader = load_data(config, **_loader_options)

Loading cached dataset MNIST.
train.pkl is loaded from cache.
valid.pkl is loaded from cache.
test.pkl is loaded from cache.
The dataset MNIST is loaded
{'dataset_name': 'MNIST', 'path': 'ylecun/mnist', 'config_name': 'mnist', 'features': {'first_column': 'image', 'second_column': 'label'}, 'cache_dir': 'Datasets/MNIST', 'task_type': 'image_classification'}


In [12]:
def train_model(net, dataloader, loss_fn, opt, epochs, log_every=100):
    """Train `net` in place with a standard supervised loop.

    Not called in this notebook (the trained weights are loaded from disk below);
    run e.g. ``train_model(model, train_dataloader, criterion, optimizer, num_epochs)``
    to retrain from scratch.

    Args:
        net: Module whose forward returns a dict with a ``"logits"`` entry.
        dataloader: Iterable of batches with ``"image"`` and ``"labels"`` entries;
            must also support ``len()`` (used in the progress message).
        loss_fn: Callable ``(logits, labels) -> scalar loss tensor``.
        opt: Optimizer over ``net``'s parameters.
        epochs: Number of passes over ``dataloader``.
        log_every: Print a progress line every this many steps.

    Returns:
        List of per-step loss values, as Python floats.
    """
    net.train()  # make sure BatchNorm/Dropout are in training mode
    history = []
    for epoch in range(epochs):
        for step, batch in enumerate(dataloader, start=1):
            images = batch["image"].float()
            labels = batch["labels"]

            # Forward pass
            logits = net(images)["logits"]
            loss = loss_fn(logits, labels)

            # Backward and optimize
            opt.zero_grad()
            loss.backward()
            opt.step()

            history.append(loss.item())
            if step % log_every == 0:
                print(
                    f"Epoch [{epoch+1}/{epochs}], Step [{step}/{len(dataloader)}], Loss: {loss.item():.4f}"
                )
    return history

In [13]:
# torch.save(model.state_dict(), "Models/MNIST/model.pt")

In [14]:
# Fresh instance to receive the saved weights in the next cell. This replaces the
# earlier `model`; note that `optimizer` still references the earlier instance's
# parameters.
model = SimpleDNN()

In [15]:
# Restore the trained weights. map_location="cpu" lets a checkpoint saved from a
# GPU run load on a CPU-only machine; the freshly built model lives on the CPU anyway.
# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
model.load_state_dict(torch.load("Models/MNIST/model.pt", map_location="cpu"))

<All keys matched successfully>

In [16]:
# Baseline: evaluate the unpruned model on the test split (metrics printed below).
result = evaluate_model(model, config, test_dataloader)




Loss: 0.1288
Precision: 0.9671, Recall: 0.9669, F1-Score: 0.9669
              precision    recall  f1-score   support
           0       0.98      0.99      0.98       980
           1       0.98      0.99      0.98      1135
           2       0.97      0.97      0.97      1032
           3       0.95      0.98      0.96      1010
           4       0.97      0.96      0.96       982
           5       0.96      0.96      0.96       892
           6       0.97      0.97      0.97       958
           7       0.97      0.96      0.96      1028
           8       0.97      0.96      0.96       974
           9       0.95      0.94      0.95      1009
    accuracy                           0.97     10000
   macro avg       0.97      0.97      0.97     10000
weighted avg       0.97      0.97      0.97     10000



In [17]:
def _concern_samples(loader, target_concern, positive):
    """Build a SamplingDataset of `num_samples` examples for `target_concern`.

    `positive` is forwarded as SamplingDataset's fifth positional argument: True
    for the positive set and False for the negative set, per how the loop below
    names them. The literal 4 is forwarded unchanged.
    NOTE(review): the meaning of that 4 is not visible here; confirm in
    src.utils.sampling.
    """
    return SamplingDataset(
        loader,
        target_concern,
        num_samples,
        num_classes,
        positive,
        4,
        device=device,
        resample=False,
    )


# Evaluation result of each pruned model, keyed by concern (class index).
pruned_results = {}
# NOTE(review): MNIST has num_classes = 10 concerns; only 0-4 are processed here.
for concern in range(5):
    # Fresh copies each iteration so sampling cannot disturb the original loaders.
    train = copy.deepcopy(train_dataloader)
    valid = copy.deepcopy(valid_dataloader)
    positive_samples = _concern_samples(train, concern, True)
    negative_samples = _concern_samples(train, concern, False)
    # 200 is not a valid MNIST label (presumably a sentinel). all_samples is not
    # used below; it is kept so SamplingDataset is invoked in the same sequence
    # as before, in case sampling consumes shared RNG state.
    all_samples = _concern_samples(train, 200, False)

    # Prune a copy so every concern starts from the same trained weights.
    module = copy.deepcopy(model)

    prune_concern_identification(
        module,
        config,
        positive_samples,
        negative_samples,
        include_layers=None,
        exclude_layers=None,
        sparsity_ratio=ci_ratio,
    )

    print(f"Evaluate the pruned model {concern}")
    result = evaluate_model(module, config, test_dataloader, verbose=True)
    pruned_results[concern] = result

    # get_sparsity returns (overall_sparsity, per-parameter dict). Inside a loop
    # the bare call's value was silently discarded, so print it explicitly.
    overall_sparsity, _ = get_sparsity(module)
    print(f"Sparsity of the pruned model {concern}: {overall_sparsity:.4f}")

    get_similarity(
        model, module, valid, concern, num_samples, num_classes, config, seed=seed
    )

Evaluate the pruned model 0





Loss: 0.2942
Precision: 0.9611, Recall: 0.9614, F1-Score: 0.9610
              precision    recall  f1-score   support
           0       0.93      0.99      0.96       980
           1       0.98      0.98      0.98      1135
           2       0.97      0.95      0.96      1032
           3       0.97      0.96      0.96      1010
           4       0.95      0.97      0.96       982
           5       0.93      0.97      0.95       892
           6       0.97      0.97      0.97       958
           7       0.97      0.95      0.96      1028
           8       0.97      0.95      0.96       974
           9       0.96      0.92      0.94      1009
    accuracy                           0.96     10000
   macro avg       0.96      0.96      0.96     10000
weighted avg       0.96      0.96      0.96     10000
adding eps to diagonal and taking inverse
taking square root
dot products...
trying to take final svd
computed everything!
adding eps to diagonal and taking inverse
taking square 




Loss: 0.3881
Precision: 0.9497, Recall: 0.9479, F1-Score: 0.9484
              precision    recall  f1-score   support
           0       0.98      0.97      0.98       980
           1       0.93      0.99      0.96      1135
           2       0.98      0.93      0.95      1032
           3       0.93      0.95      0.94      1010
           4       0.96      0.95      0.95       982
           5       0.94      0.92      0.93       892
           6       0.98      0.95      0.96       958
           7       0.97      0.91      0.94      1028
           8       0.92      0.97      0.94       974
           9       0.91      0.93      0.92      1009
    accuracy                           0.95     10000
   macro avg       0.95      0.95      0.95     10000
weighted avg       0.95      0.95      0.95     10000
adding eps to diagonal and taking inverse
taking square root
dot products...
trying to take final svd
computed everything!
adding eps to diagonal and taking inverse
taking square 




Loss: 0.3175
Precision: 0.9600, Recall: 0.9589, F1-Score: 0.9590
              precision    recall  f1-score   support
           0       0.98      0.99      0.98       980
           1       0.98      0.99      0.99      1135
           2       0.90      0.98      0.94      1032
           3       0.94      0.97      0.96      1010
           4       0.96      0.96      0.96       982
           5       0.98      0.94      0.96       892
           6       0.97      0.97      0.97       958
           7       0.98      0.90      0.94      1028
           8       0.94      0.97      0.95       974
           9       0.96      0.91      0.94      1009
    accuracy                           0.96     10000
   macro avg       0.96      0.96      0.96     10000
weighted avg       0.96      0.96      0.96     10000
adding eps to diagonal and taking inverse
taking square root
dot products...
trying to take final svd
computed everything!
adding eps to diagonal and taking inverse
taking square 




Loss: 0.2991
Precision: 0.9614, Recall: 0.9608, F1-Score: 0.9609
              precision    recall  f1-score   support
           0       0.98      0.99      0.98       980
           1       0.98      0.99      0.99      1135
           2       0.97      0.97      0.97      1032
           3       0.91      0.98      0.94      1010
           4       0.96      0.96      0.96       982
           5       0.94      0.96      0.95       892
           6       0.98      0.95      0.96       958
           7       0.96      0.95      0.96      1028
           8       0.97      0.95      0.96       974
           9       0.96      0.92      0.94      1009
    accuracy                           0.96     10000
   macro avg       0.96      0.96      0.96     10000
weighted avg       0.96      0.96      0.96     10000
adding eps to diagonal and taking inverse
taking square root
dot products...
trying to take final svd
computed everything!
adding eps to diagonal and taking inverse
taking square 




Loss: 0.4052
Precision: 0.9543, Recall: 0.9522, F1-Score: 0.9524
              precision    recall  f1-score   support
           0       0.99      0.98      0.98       980
           1       0.98      0.99      0.98      1135
           2       0.99      0.92      0.95      1032
           3       0.96      0.97      0.96      1010
           4       0.91      0.97      0.94       982
           5       0.98      0.94      0.96       892
           6       0.98      0.95      0.97       958
           7       0.97      0.88      0.92      1028
           8       0.92      0.97      0.94       974
           9       0.87      0.96      0.91      1009
    accuracy                           0.95     10000
   macro avg       0.95      0.95      0.95     10000
weighted avg       0.95      0.95      0.95     10000
adding eps to diagonal and taking inverse
taking square root
dot products...
trying to take final svd
computed everything!
adding eps to diagonal and taking inverse
taking square 

In [18]:
# `module` is the pruned copy left over from the last iteration of the loop above.
# Displays (overall sparsity, per-parameter sparsity dict).
get_sparsity(module)

(0.29879162982611257,
 {'fc1.weight': 0.2997448979591837,
  'fc1.bias': 0.0,
  'fc2.weight': 0.298828125,
  'fc2.bias': 0.0,
  'fc3.weight': 0.296875,
  'fc3.bias': 0.0,
  'fc4.weight': 0.296875,
  'fc4.bias': 0.0,
  'fc5.weight': 0.296875,
  'fc5.bias': 0.0,
  'fc6.weight': 0.28125,
  'fc6.bias': 0.0})