In [1]:
import os
import sys

# Add the directory three levels up to the import path so the `src.*` imports
# in the following cells resolve. The path is relative to the working
# directory — presumably the repository root when the kernel is started from
# this notebook's folder (TODO confirm).
sys.path.append("../../../")
# NOTE(review): presumably silences the Hugging Face `tokenizers` warning that
# appears when multi-worker DataLoaders fork — confirm it is still needed.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

In [3]:
import copy
import torch
from datetime import datetime
from src.utils.helper import Config, color_print
from src.utils.load import load_data
from src.utils.load import save_checkpoint
from src.utils.load import load_model, load_data, save_checkpoint
from src.models.evaluate import evaluate_model, get_sparsity, get_similarity
from src.utils.sampling import SamplingDataset
from src.pruning.prune_head import head_importance_prunning
from src.pruning.prune import *

In [4]:
# Model / training hyper-parameters.
# `batch_size` is deliberately not set here: the only value ever consumed
# (by `load_data`) comes from the experiment-configuration cell below, which
# silently overwrote the `batch_size = 100` that used to live in this cell.
input_size = 28 * 28  # flattened 28x28 image; equals the 784 inputs of SimpleDNN.fc1
num_classes = 10  # number of classes; also drives the per-concern pruning loop
num_epochs = 5  # only referenced by the (currently commented-out) training loop
learning_rate = 0.001  # step size handed to the Adam optimiser

In [5]:
# Image preprocessing pipeline: convert to a tensor, then normalise the single
# channel with mean 0.5 and std 0.5.
# NOTE(review): no later visible cell references `transform`; the data is
# obtained through `load_data` instead.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

In [6]:
# Experiment configuration for the pruning run.
name = "MNIST"  # dataset key handed to Config
# Prefer the first GPU, but fall back to the CPU so the notebook still runs on
# machines without CUDA instead of failing once the device is actually used.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
checkpoint = None  # NOTE(review): not referenced by any visible cell
batch_size = 16  # batch size passed to load_data
num_workers = 4  # worker-process count passed to load_data
num_samples = 128  # sample count passed to SamplingDataset and get_similarity
ci_ratio = 0.3  # NOTE(review): not referenced by any visible cell
seed = 44  # forwarded to get_similarity

In [7]:
class SimpleDNN(nn.Module):
    """Fully connected classifier for flattened 784-feature inputs.

    Layer widths are 784 -> 512 -> 256 -> 128 -> 64 -> 32 -> 10. Each of the
    first five linear layers is followed by batch normalisation, ReLU and
    dropout (p=0.5); the sixth linear layer produces the raw class logits.
    """

    def __init__(self):
        super().__init__()
        # Attribute names and creation order are load-bearing: they define the
        # state-dict keys (and the printed module tree), so keep them stable.
        self.fc1 = nn.Linear(784, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 128)
        self.fc4 = nn.Linear(128, 64)
        self.fc5 = nn.Linear(64, 32)
        self.fc6 = nn.Linear(32, 10)

        # Shared, parameter-free activation and regularisation modules.
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)

        # One batch-norm per hidden layer, sized to that layer's output width.
        self.bn1 = nn.BatchNorm1d(512)
        self.bn2 = nn.BatchNorm1d(256)
        self.bn3 = nn.BatchNorm1d(128)
        self.bn4 = nn.BatchNorm1d(64)
        self.bn5 = nn.BatchNorm1d(32)

    def forward(self, x, output_hidden_states=False):
        """Run the network on a batch.

        Args:
            x: Input tensor; every dimension after the first (batch) one is
                flattened, and the result must have 784 features per sample.
            output_hidden_states: When true, also return the intermediate
                representations.

        Returns:
            A dict with key ``"logits"`` (output of ``fc6``). When
            ``output_hidden_states`` is true it additionally has
            ``"hidden_states"``: a list with the batch-normalised (pre-ReLU)
            output of each of the five hidden layers, followed by the logits.
        """
        activations = x.view(x.size(0), -1)
        collected = []

        # Look the layers up at call time so the pairs always reflect the
        # modules currently attached to this instance.
        hidden_blocks = (
            (self.fc1, self.bn1),
            (self.fc2, self.bn2),
            (self.fc3, self.bn3),
            (self.fc4, self.bn4),
            (self.fc5, self.bn5),
        )
        for linear, norm in hidden_blocks:
            activations = norm(linear(activations))
            if output_hidden_states:
                # Recorded before the non-linearity and dropout are applied.
                collected.append(activations)
            activations = self.dropout(self.relu(activations))

        logits = self.fc6(activations)
        if not output_hidden_states:
            return {"logits": logits}

        collected.append(logits)
        return {"logits": logits, "hidden_states": collected}

In [8]:
# Instantiate the network with freshly initialised (untrained) weights.
model = SimpleDNN()

In [9]:
# Display the module tree to check the layer sizes.
model

SimpleDNN(
  (fc1): Linear(in_features=784, out_features=512, bias=True)
  (fc2): Linear(in_features=512, out_features=256, bias=True)
  (fc3): Linear(in_features=256, out_features=128, bias=True)
  (fc4): Linear(in_features=128, out_features=64, bias=True)
  (fc5): Linear(in_features=64, out_features=32, bias=True)
  (fc6): Linear(in_features=32, out_features=10, bias=True)
  (relu): ReLU()
  (dropout): Dropout(p=0.5, inplace=False)
  (bn1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bn2): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bn3): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bn4): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bn5): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)

In [10]:
# Loss and optimiser for the (currently commented-out) training loop.
criterion = nn.CrossEntropyLoss()
# NOTE(review): the optimiser captures the parameters of the `model` instance
# that exists at this point. A later cell rebinds `model` to a new SimpleDNN,
# so this optimiser would not update that instance. Harmless while the
# training loop stays disabled.
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
# Project configuration object built from the dataset name and target device.
config = Config(name, device)

In [11]:
# Load the train / validation / test loaders for the dataset described by
# `config`. With `do_cache=True` the recorded output shows the three splits
# being read back from cached .pkl files.
train_dataloader, valid_dataloader, test_dataloader = load_data(
    config,
    batch_size=batch_size,
    num_workers=num_workers,
    do_cache=True,
)

Loading cached dataset MNIST.
train.pkl is loaded from cache.
valid.pkl is loaded from cache.
test.pkl is loaded from cache.
The dataset MNIST is loaded
{'dataset_name': 'MNIST', 'path': 'ylecun/mnist', 'config_name': 'mnist', 'features': {'first_column': 'image', 'second_column': 'label'}, 'cache_dir': 'Datasets/MNIST', 'task_type': 'image_classification'}


In [12]:
# for epoch in range(num_epochs):
#     for i, batch in enumerate(train_dataloader):
#         images = batch["image"].float()
#         labels = batch["labels"]
#         # Forward pass
#         outputs = model(images)
#         logits = outputs["logits"]
#         loss = criterion(logits, labels)

#         # Backward and optimize
#         optimizer.zero_grad()
#         loss.backward()
#         optimizer.step()

#         if (i + 1) % 100 == 0:
#             print(
#                 f"Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_dataloader)}], Loss: {loss.item():.4f}"
#             )

In [13]:
# torch.save(model.state_dict(), "Models/MNIST/model.pt")

In [14]:
# Re-create the network so the checkpoint is loaded into a clean instance;
# this replaces the `model` object built earlier in the notebook.
model = SimpleDNN()

In [15]:
# Restore the trained weights from the checkpoint.
# `weights_only=True` restricts unpickling to tensors and plain containers, so
# a tampered checkpoint file cannot execute arbitrary code on load.
# `map_location="cpu"` lets a checkpoint saved from a GPU load on CPU-only
# machines; `load_state_dict` copies the values into the model's own
# parameters regardless of where they were loaded.
model.load_state_dict(
    torch.load("Models/MNIST/model.pt", map_location="cpu", weights_only=True)
)

<All keys matched successfully>

In [16]:
# Baseline: score the unpruned, checkpoint-restored model on the test split.
# NOTE(review): `model.eval()` is never called in this notebook — presumably
# `evaluate_model` switches Dropout/BatchNorm to inference mode itself; confirm.
result = evaluate_model(model, config, test_dataloader)




Loss: 0.1288
Precision: 0.9671, Recall: 0.9669, F1-Score: 0.9669
              precision    recall  f1-score   support
           0       0.98      0.99      0.98       980
           1       0.98      0.99      0.98      1135
           2       0.97      0.97      0.97      1032
           3       0.95      0.98      0.96      1010
           4       0.97      0.96      0.96       982
           5       0.96      0.96      0.96       892
           6       0.97      0.97      0.97       958
           7       0.97      0.96      0.96      1028
           8       0.97      0.96      0.96       974
           9       0.95      0.94      0.95      1009
    accuracy                           0.97     10000
   macro avg       0.97      0.97      0.97     10000
weighted avg       0.97      0.97      0.97     10000



In [17]:
# Per-concern pruning experiment. For every class index ("concern"), prune a
# fresh copy of the trained model using a positive and a negative sample set,
# then report test metrics, the achieved sparsity, and similarity to the
# unpruned model.
for concern in range(num_classes):
    # Each iteration works on private copies of the loaders.
    # NOTE(review): presumably because sampling / similarity code consumes or
    # mutates the loader it is given — confirm whether the copies are needed.
    train = copy.deepcopy(train_dataloader)
    valid = copy.deepcopy(valid_dataloader)

    # The two sample sets differ only in the boolean flag.
    # NOTE(review): the flag presumably selects in-concern (True) versus
    # out-of-concern (False) samples, and the positional `4` is passed through
    # unchanged — confirm both against SamplingDataset's signature.
    positive_samples = SamplingDataset(
        train,
        concern,
        num_samples,
        num_classes,
        True,
        4,
        device=device,
        resample=False,
    )
    negative_samples = SamplingDataset(
        train,
        concern,
        num_samples,
        num_classes,
        False,
        4,
        device=device,
        resample=False,
    )
    # No third "all classes" sample set is built: nothing in this loop
    # consumes one, so constructing it every iteration was wasted work.

    # Prune a copy so the reference `model` stays intact for the comparison
    # below. `module` is intentionally left bound after the loop.
    module = copy.deepcopy(model)

    prune_concern_identification(
        module,
        config,
        positive_samples,
        negative_samples,
        include_layers=None,
        exclude_layers=None,
        sparsity_ratio=0.6,
    )

    print(f"Evaluate the pruned model {concern}")
    result = evaluate_model(module, config, test_dataloader, verbose=True)

    # get_sparsity returns (overall_ratio, per_parameter_dict); inside a loop
    # the bare call displayed nothing, so surface the overall figure explicitly.
    overall_sparsity, _ = get_sparsity(module)
    print(f"Overall sparsity: {overall_sparsity:.4f}")

    # Compare representations of the original and pruned models.
    get_similarity(
        model, module, valid, concern, num_samples, num_classes, config, seed=seed
    )

Evaluate the pruned model 0





Loss: 3.4924
Precision: 0.0098, Recall: 0.1000, F1-Score: 0.0179
              precision    recall  f1-score   support
           0       0.10      1.00      0.18       980
           1       0.00      0.00      0.00      1135
           2       0.00      0.00      0.00      1032
           3       0.00      0.00      0.00      1010
           4       0.00      0.00      0.00       982
           5       0.00      0.00      0.00       892
           6       0.00      0.00      0.00       958
           7       0.00      0.00      0.00      1028
           8       0.00      0.00      0.00       974
           9       0.00      0.00      0.00      1009
    accuracy                           0.10     10000
   macro avg       0.01      0.10      0.02     10000
weighted avg       0.01      0.10      0.02     10000
adding eps to diagonal and taking inverse
taking square root
dot products...
trying to take final svd
computed everything!
adding eps to diagonal and taking inverse
taking square 




Loss: 2.6011
Precision: 0.0114, Recall: 0.1000, F1-Score: 0.0204
              precision    recall  f1-score   support
           0       0.00      0.00      0.00       980
           1       0.11      1.00      0.20      1135
           2       0.00      0.00      0.00      1032
           3       0.00      0.00      0.00      1010
           4       0.00      0.00      0.00       982
           5       0.00      0.00      0.00       892
           6       0.00      0.00      0.00       958
           7       0.00      0.00      0.00      1028
           8       0.00      0.00      0.00       974
           9       0.00      0.00      0.00      1009
    accuracy                           0.11     10000
   macro avg       0.01      0.10      0.02     10000
weighted avg       0.01      0.11      0.02     10000
adding eps to diagonal and taking inverse
taking square root
dot products...
trying to take final svd
computed everything!
adding eps to diagonal and taking inverse
taking square 




Loss: 2.2966
Precision: 0.1104, Recall: 0.1057, F1-Score: 0.0297
              precision    recall  f1-score   support
           0       0.00      0.00      0.00       980
           1       0.00      0.00      0.00      1135
           2       0.10      1.00      0.19      1032
           3       1.00      0.06      0.11      1010
           4       0.00      0.00      0.00       982
           5       0.00      0.00      0.00       892
           6       0.00      0.00      0.00       958
           7       0.00      0.00      0.00      1028
           8       0.00      0.00      0.00       974
           9       0.00      0.00      0.00      1009
    accuracy                           0.11     10000
   macro avg       0.11      0.11      0.03     10000
weighted avg       0.11      0.11      0.03     10000
adding eps to diagonal and taking inverse
taking square root
dot products...
trying to take final svd
computed everything!
adding eps to diagonal and taking inverse
taking square 




Loss: 2.2320
Precision: 0.0101, Recall: 0.1000, F1-Score: 0.0183
              precision    recall  f1-score   support
           0       0.00      0.00      0.00       980
           1       0.00      0.00      0.00      1135
           2       0.00      0.00      0.00      1032
           3       0.10      1.00      0.18      1010
           4       0.00      0.00      0.00       982
           5       0.00      0.00      0.00       892
           6       0.00      0.00      0.00       958
           7       0.00      0.00      0.00      1028
           8       0.00      0.00      0.00       974
           9       0.00      0.00      0.00      1009
    accuracy                           0.10     10000
   macro avg       0.01      0.10      0.02     10000
weighted avg       0.01      0.10      0.02     10000
adding eps to diagonal and taking inverse
taking square root
dot products...
trying to take final svd
computed everything!
adding eps to diagonal and taking inverse
taking square 




Loss: 2.3102
Precision: 0.0098, Recall: 0.1000, F1-Score: 0.0179
              precision    recall  f1-score   support
           0       0.00      0.00      0.00       980
           1       0.00      0.00      0.00      1135
           2       0.00      0.00      0.00      1032
           3       0.00      0.00      0.00      1010
           4       0.10      1.00      0.18       982
           5       0.00      0.00      0.00       892
           6       0.00      0.00      0.00       958
           7       0.00      0.00      0.00      1028
           8       0.00      0.00      0.00       974
           9       0.00      0.00      0.00      1009
    accuracy                           0.10     10000
   macro avg       0.01      0.10      0.02     10000
weighted avg       0.01      0.10      0.02     10000
adding eps to diagonal and taking inverse
taking square root
dot products...
trying to take final svd
computed everything!
adding eps to diagonal and taking inverse
taking square 




Loss: 2.1643
Precision: 0.0632, Recall: 0.1310, F1-Score: 0.0566
              precision    recall  f1-score   support
           0       0.00      0.00      0.00       980
           1       0.00      0.00      0.00      1135
           2       0.00      0.00      0.00      1032
           3       0.54      0.31      0.39      1010
           4       0.00      0.00      0.00       982
           5       0.09      1.00      0.17       892
           6       0.00      0.00      0.00       958
           7       0.00      0.00      0.00      1028
           8       0.00      0.00      0.00       974
           9       0.00      0.00      0.00      1009
    accuracy                           0.12     10000
   macro avg       0.06      0.13      0.06     10000
weighted avg       0.06      0.12      0.06     10000
adding eps to diagonal and taking inverse
taking square root
dot products...
trying to take final svd
computed everything!
adding eps to diagonal and taking inverse
taking square 




Loss: 3.0075
Precision: 0.0096, Recall: 0.1000, F1-Score: 0.0175
              precision    recall  f1-score   support
           0       0.00      0.00      0.00       980
           1       0.00      0.00      0.00      1135
           2       0.00      0.00      0.00      1032
           3       0.00      0.00      0.00      1010
           4       0.00      0.00      0.00       982
           5       0.00      0.00      0.00       892
           6       0.10      1.00      0.17       958
           7       0.00      0.00      0.00      1028
           8       0.00      0.00      0.00       974
           9       0.00      0.00      0.00      1009
    accuracy                           0.10     10000
   macro avg       0.01      0.10      0.02     10000
weighted avg       0.01      0.10      0.02     10000



Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fdb03dab100>
Traceback (most recent call last):
  File "/home/jieungkim/.cache/pypoetry/virtualenvs/decomposetransformer-UESb9BbT-py3.12/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1477, in __del__
    self._shutdown_workers()
  File "/home/jieungkim/.cache/pypoetry/virtualenvs/decomposetransformer-UESb9BbT-py3.12/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1441, in _shutdown_workers
    w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
  File "/home/jieungkim/anaconda3/lib/python3.12/multiprocessing/process.py", line 149, in join
    res = self._popen.wait(timeout)
          ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jieungkim/anaconda3/lib/python3.12/multiprocessing/popen_fork.py", line 40, in wait
    if not wait([self.sentinel], timeout):
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jieungkim/anaconda3/lib/python3.12/multiprocessing/connection.py", line 1

adding eps to diagonal and taking inverse
taking square root
dot products...
trying to take final svd
computed everything!
adding eps to diagonal and taking inverse
taking square root
dot products...
trying to take final svd
computed everything!
CCA coefficients mean concern: (np.float64(0.7069097370402727), np.float64(0.7069097370402727))
CCA coefficients mean non-concern: (np.float64(0.7374416959573262), np.float64(0.7374416959573262))
Linear CKA concern: 0.9930596769399709
Linear CKA non-concern: 0.4101843684364933
Kernel CKA concern: 0.9884769139228012
Kernel CKA non-concern: 0.43903815945594454
Evaluate the pruned model 7





Loss: 2.2333
Precision: 0.0676, Recall: 0.1786, F1-Score: 0.0866
              precision    recall  f1-score   support
           0       0.00      0.00      0.00       980
           1       0.00      0.00      0.00      1135
           2       0.00      0.00      0.00      1032
           3       0.00      0.00      0.00      1010
           4       0.00      0.00      0.00       982
           5       0.00      0.00      0.00       892
           6       0.00      0.00      0.00       958
           7       0.12      1.00      0.21      1028
           8       0.00      0.00      0.00       974
           9       0.56      0.79      0.65      1009
    accuracy                           0.18     10000
   macro avg       0.07      0.18      0.09     10000
weighted avg       0.07      0.18      0.09     10000
adding eps to diagonal and taking inverse
taking square root
dot products...
trying to take final svd
computed everything!
adding eps to diagonal and taking inverse
taking square 




Loss: 1.9645
Precision: 0.1103, Recall: 0.1487, F1-Score: 0.0842
              precision    recall  f1-score   support
           0       0.00      0.00      0.00       980
           1       1.00      0.49      0.66      1135
           2       0.00      0.00      0.00      1032
           3       0.00      0.00      0.00      1010
           4       0.00      0.00      0.00       982
           5       0.00      0.00      0.00       892
           6       0.00      0.00      0.00       958
           7       0.00      0.00      0.00      1028
           8       0.10      1.00      0.19       974
           9       0.00      0.00      0.00      1009
    accuracy                           0.15     10000
   macro avg       0.11      0.15      0.08     10000
weighted avg       0.12      0.15      0.09     10000



Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fdb03dab100>
Traceback (most recent call last):
  File "/home/jieungkim/.cache/pypoetry/virtualenvs/decomposetransformer-UESb9BbT-py3.12/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1477, in __del__
    self._shutdown_workers()
  File "/home/jieungkim/.cache/pypoetry/virtualenvs/decomposetransformer-UESb9BbT-py3.12/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1441, in _shutdown_workers
    w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
  File "/home/jieungkim/anaconda3/lib/python3.12/multiprocessing/process.py", line 149, in join
    res = self._popen.wait(timeout)
          ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jieungkim/anaconda3/lib/python3.12/multiprocessing/popen_fork.py", line 40, in wait
    if not wait([self.sentinel], timeout):
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jieungkim/anaconda3/lib/python3.12/multiprocessing/connection.py", line 1

adding eps to diagonal and taking inverse
taking square root
dot products...
trying to take final svd
computed everything!
adding eps to diagonal and taking inverse
taking square root
dot products...
trying to take final svd
computed everything!
CCA coefficients mean concern: (np.float64(0.6342453844606138), np.float64(0.6342453844606138))
CCA coefficients mean non-concern: (np.float64(0.7460644851832636), np.float64(0.7460644851832636))
Linear CKA concern: 0.9848326566348916
Linear CKA non-concern: 0.401630118266487
Kernel CKA concern: 0.9605298981831273
Kernel CKA non-concern: 0.6051276716388426
Evaluate the pruned model 9





Loss: 2.6264
Precision: 0.0101, Recall: 0.1000, F1-Score: 0.0183
              precision    recall  f1-score   support
           0       0.00      0.00      0.00       980
           1       0.00      0.00      0.00      1135
           2       0.00      0.00      0.00      1032
           3       0.00      0.00      0.00      1010
           4       0.00      0.00      0.00       982
           5       0.00      0.00      0.00       892
           6       0.00      0.00      0.00       958
           7       0.00      0.00      0.00      1028
           8       0.00      0.00      0.00       974
           9       0.10      1.00      0.18      1009
    accuracy                           0.10     10000
   macro avg       0.01      0.10      0.02     10000
weighted avg       0.01      0.10      0.02     10000



Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fdb03dab100>
Traceback (most recent call last):
  File "/home/jieungkim/.cache/pypoetry/virtualenvs/decomposetransformer-UESb9BbT-py3.12/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1477, in __del__
    self._shutdown_workers()
  File "/home/jieungkim/.cache/pypoetry/virtualenvs/decomposetransformer-UESb9BbT-py3.12/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1441, in _shutdown_workers
    w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
  File "/home/jieungkim/anaconda3/lib/python3.12/multiprocessing/process.py", line 149, in join
    res = self._popen.wait(timeout)
          ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jieungkim/anaconda3/lib/python3.12/multiprocessing/popen_fork.py", line 40, in wait
    if not wait([self.sentinel], timeout):
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jieungkim/anaconda3/lib/python3.12/multiprocessing/connection.py", line 1

adding eps to diagonal and taking inverse
taking square root
dot products...
trying to take final svd
computed everything!
adding eps to diagonal and taking inverse
taking square root
dot products...
trying to take final svd
computed everything!
CCA coefficients mean concern: (np.float64(0.738604051675104), np.float64(0.738604051675104))
CCA coefficients mean non-concern: (np.float64(0.7545242451951308), np.float64(0.7545242451951308))
Linear CKA concern: 0.9992644832875293
Linear CKA non-concern: 0.3142990830237377
Kernel CKA concern: 0.9981984227828664
Kernel CKA non-concern: 0.5005892493493059


In [18]:
# Sparsity of `module` as left behind by the pruning loop, i.e. the copy
# pruned in its final iteration (concern == num_classes - 1). The displayed
# value is a tuple: the overall ratio and a per-parameter breakdown.
get_sparsity(module)

(0.29879162982611257,
 {'fc1.weight': 0.2997448979591837,
  'fc1.bias': 0.0,
  'fc2.weight': 0.298828125,
  'fc2.bias': 0.0,
  'fc3.weight': 0.296875,
  'fc3.bias': 0.0,
  'fc4.weight': 0.296875,
  'fc4.bias': 0.0,
  'fc5.weight': 0.296875,
  'fc5.bias': 0.0,
  'fc6.weight': 0.28125,
  'fc6.bias': 0.0})