In [1]:
# Mount Google Drive; force_remount=True remounts even if Drive is already
# mounted (e.g. when this cell is re-run).
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT, force_remount=True)

Mounted at /content/drive


In [2]:
import os
os.environ['HOME_DIR'] = 'drive/MyDrive/hidden-networks'
!pip install -r $HOME_DIR/requirements.txt

import sys
sys.path.append(os.path.join('/content', os.environ['HOME_DIR']))



In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import CosineAnnealingLR
import torch.autograd as autograd
import collections

from penalized_supermask_pruning import GetSubnet, SupermaskConv, SupermaskLinear
from penalized_supermask_pruning import train, test

class ArgClass:
    """Minimal namespace that exposes every key of a mapping as an attribute.

    Used to turn the plain-dict experiment configs into ``args.lr``-style access.
    """

    def __init__(self, args):
        # Construction and later updates share the same code path.
        self.setattrs(**args)

    def setattrs(self, **kwargs):
        """Add or overwrite one instance attribute per keyword argument."""
        vars(self).update(kwargs)

In [4]:
class Net(nn.Module):
    """Two-conv / two-linear network built entirely from supermask layers.

    ``args.sparsity`` (optional) gives one sparsity value per layer in the order
    conv1, conv2, fc1, fc2; it defaults to 1.0 for every layer. All layers are
    created with ``bias=False``. ``args`` is kept on the module and round-trips
    through the state dict via ``get_extra_state`` / ``set_extra_state``.
    """

    def __init__(self, args, input_channels, image_size, num_labels):
        super().__init__()
        sparsities = getattr(args, "sparsity", [1.0, 1.0, 1.0, 1.0])
        self.conv1 = SupermaskConv(input_channels, 32, 3, 1, bias=False, sparsity=sparsities[0])
        self.conv2 = SupermaskConv(32, 64, 3, 1, bias=False, sparsity=sparsities[1])
        s = self._flat_features(image_size, channels=64)
        self.fc1 = SupermaskLinear(s, 128, bias=False, sparsity=sparsities[2])
        self.fc2 = SupermaskLinear(128, num_labels, bias=False, sparsity=sparsities[3])
        # NOTE(review): flag consumed inside SupermaskLinear; semantics not visible here.
        self.fc1.calculate_subscores = True
        self.args = args

    @staticmethod
    def _flat_features(image_size, channels=64):
        """Width of the flattened feature vector that ``forward`` feeds into fc1.

        Assumes SupermaskConv follows nn.Conv2d's (in, out, kernel_size, stride)
        signature with default padding 0, as the layer calls above imply. Each
        3x3 stride-1 conv then trims 2 pixels per spatial side (-4 in total), and
        ``F.max_pool2d(x, 2)`` halves the result with floor rounding. Floor
        division is applied before squaring, so odd sizes are also handled.
        """
        side = (image_size - 4) // 2
        return channels * side * side

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.max_pool2d(x, 2)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        # Log-probabilities over the num_labels classes.
        output = F.log_softmax(x, dim=1)
        return output

    def get_extra_state(self):
        """Store the config object in the state dict so saved weights carry their args."""
        return self.args

    def set_extra_state(self, state):
        """Restore the config object saved by ``get_extra_state``."""
        self.args = state

In [5]:
# The main function runs the full training loop on a dataset of your choice
def main(model_args, train_args, base_model=None):
    """Train a supermask ``Net`` on MNIST or CIFAR10.

    Args:
        model_args: dict of settings that define the model and its optimisation
            (dataset, batch_size, epochs, lr, momentum, wd, no_cuda, seed,
            save_name, sparsity, copy_layers, freeze_layers, penalty).
            ``epochs`` is either an int (train that many epochs) or a list with
            one entry per child layer. In the list case, training runs for
            ``max(epochs)`` epochs and child ``i`` has ``.freeze()`` called once
            ``epochs[i]`` epochs have completed.
        train_args: dict of settings that do not affect the model (data,
            test_batch_size, log_interval, train_eval_interval,
            test_eval_interval, and optionally eval_on_last).
        base_model: optional model whose ``copy_layers`` are loaded into the new
            model before training.

    Returns:
        ``(model, device, train_loader, test_loader, criterion)``.

    Raises:
        ValueError: for an unsupported dataset, or when ``copy_layers`` is given
            (not None) but its emptiness disagrees with whether ``base_model``
            was passed.
    """
    args = ArgClass(model_args)
    train_args = ArgClass(train_args)
    dataset = args.dataset

    use_cuda = not args.no_cuda and torch.cuda.is_available()

    torch.manual_seed(args.seed)

    device = torch.device("cuda" if use_cuda else "cpu")
    print(f"Using device {device}")

    if dataset == "MNIST":
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ])
        input_channels, image_size, num_labels = 1, 28, 10
    elif dataset == "CIFAR10":
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])
        input_channels, image_size, num_labels = 3, 32, 10
    else:
        raise ValueError("Only supported datasets are CIFAR10 and MNIST currently.")

    dataset_cls = getattr(datasets, dataset)
    data_root = os.path.join(train_args.data, dataset)
    loader_kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
    train_loader = torch.utils.data.DataLoader(
        dataset_cls(data_root, train=True, download=True, transform=transform),
        batch_size=args.batch_size, shuffle=True, **loader_kwargs)
    # Evaluation is order-independent, so the test set is not shuffled. This keeps
    # evaluation deterministic and avoids consuming global RNG state for a permutation.
    test_loader = torch.utils.data.DataLoader(
        dataset_cls(data_root, train=False, transform=transform),
        batch_size=train_args.test_batch_size, shuffle=False, **loader_kwargs)

    model = Net(args, input_channels, image_size, num_labels).to(device)

    copy_layers = getattr(args, "copy_layers", None)
    if copy_layers is not None:
        # A non-empty copy list and a base model must be given together.
        if bool(copy_layers) != (base_model is not None):
            raise ValueError("copy_layers must be non-empty when base_model is given, "
                             "and None or [] when it is not")
        for layer in copy_layers:
            # strict=False: the source state dict only covers this one layer.
            model.load_state_dict(getattr(base_model, layer).state_dict(prefix=f"{layer}."), strict=False)

    for layer_name in getattr(args, "freeze_layers", None) or []:
        getattr(model, layer_name).freeze()

    # NOTE: only pass the parameters where p.requires_grad == True to the optimizer! Important!
    # NOTE(review): layers frozen later via the per-layer `epochs` schedule stay in the
    # optimizer. If their .grad is zeroed rather than set to None, SGD momentum and
    # weight decay will keep changing them. Confirm freeze()/train() handle this.
    optimizer = optim.SGD(
        [p for p in model.parameters() if p.requires_grad],
        lr=args.lr,
        momentum=args.momentum,
        weight_decay=args.wd,
    )

    assert isinstance(args.epochs, (list, int))
    if isinstance(args.epochs, int):
        num_epochs, check_freeze = args.epochs, False
    else:
        num_epochs, check_freeze = max(args.epochs), True

    criterion = nn.CrossEntropyLoss().to(device)
    scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs)
    # 0 is the baseline value used by the penalty sweep.
    penalty = getattr(args, "penalty", 0)
    eval_on_last = getattr(train_args, "eval_on_last", False)
    for epoch in range(1, num_epochs + 1):
        if check_freeze:
            # Child i is frozen once args.epochs[i] epochs have completed.
            for freeze_at_epoch, child in zip(args.epochs, model.children()):
                if freeze_at_epoch == epoch - 1:
                    child.freeze()
                    print(f"Freezing {child} before epoch {epoch}")

        train(model, train_args.log_interval, device, train_loader, optimizer, criterion, epoch, penalty=penalty)
        # Compare with num_epochs, not args.epochs (which may be a list), so that
        # eval_on_last also fires for per-layer epoch schedules.
        is_last = eval_on_last and epoch == num_epochs
        if (train_args.train_eval_interval and epoch % train_args.train_eval_interval == 0) or is_last:
            test(model, device, criterion, train_loader, name="Train")
        if (train_args.test_eval_interval and epoch % train_args.test_eval_interval == 0) or is_last:
            test(model, device, criterion, test_loader, name="Test")
        scheduler.step()

    save_name = getattr(args, "save_name", None)
    if save_name is not None:
        torch.save(model.state_dict(),
                   os.path.join(os.environ['HOME_DIR'], "trained_networks", save_name))

    return model, device, train_loader, test_loader, criterion

def get_prune_mask(layer, sparsity):
    """Return the subnet mask ``GetSubnet`` selects from ``|layer.scores|`` at ``sparsity``.

    Computed under ``torch.no_grad()``, so the result is detached from autograd.
    NOTE(review): the exact mask semantics (e.g. which fraction is kept) live in
    GetSubnet, which is not visible here.
    """
    with torch.no_grad():
        score_magnitudes = torch.abs(layer.scores)
        mask = GetSubnet.apply(score_magnitudes, sparsity)
    return mask

In [8]:
# Settings that only control data location, logging and evaluation; they do not affect the model.
train_args = dict(
    test_batch_size=1000,    # input batch size for testing (default: 1000)
    data='../data',          # where datasets are stored/downloaded (e.g. MNIST)
    log_interval=500,        # how many batches to wait before logging training status
    train_eval_interval=5,   # epoch interval at which to print training accuracy
    test_eval_interval=5,    # epoch interval at which to print test accuracy
    eval_on_last=True,       # also evaluate after the final epoch
)

# Settings that define the model and its optimisation.
args = dict(
    dataset="CIFAR10",
    batch_size=64,                  # input batch size for training (default: 64)
    epochs=[40, 40, 40, 40],        # per-layer schedule: train max(epochs) epochs; layer i is frozen after epochs[i] (see main)
    lr=0.1,                         # learning rate (default: 0.1)
    momentum=0.9,                   # momentum (default: 0.9)
    wd=0.0005,                      # weight decay (default: 0.0005)
    no_cuda=False,                  # disables CUDA training
    seed=1,                         # random seed; overwritten per run below
    save_name=None,                 # e.g. "simple20_rs2"; None disables saving
    sparsity=[0.5, 0.5, 0.5, 0.5],  # sparsity of conv1, conv2, fc1, fc2
    copy_layers=[],                 # e.g. ['conv1', 'conv2', 'fc2'] (requires a base_model)
    freeze_layers=[],
    penalty=0,                      # overwritten per run below
)

# Sweep the score penalty. Each penalty gets its own results file, rewritten after every seed.
penalties = [0, .1, .5, 1, 5, 10, 50]
for penalty in penalties:
    name_of_experiment = f"penalized_scores_v1_{penalty}"
    train_results, test_results = [], []
    for rs in range(100, 101):
        args.update(seed=rs, penalty=penalty)

        trained_model, device, train_loader, test_loader, criterion = main(args, train_args)
        train_acc, train_loss = test(trained_model, device, criterion, train_loader, name="Train")
        test_acc, test_loss = test(trained_model, device, criterion, test_loader)
        train_results.append((train_acc, train_loss))
        test_results.append((test_acc, test_loss))

        results_path = os.path.join(os.environ["HOME_DIR"], "results",
                                    f"{name_of_experiment}_{args['dataset']}.pt")
        torch.save((train_args, args, train_results, test_results), results_path)

Using device cuda
Files already downloaded and verified

Train set: Average loss: 0.0180, Accuracy: 30573/50000 (61%)


Test set: Average loss: 0.0013, Accuracy: 5632/10000 (56%)


Train set: Average loss: 0.0184, Accuracy: 29620/50000 (59%)


Test set: Average loss: 0.0013, Accuracy: 5498/10000 (55%)


Train set: Average loss: 0.0181, Accuracy: 30188/50000 (60%)


Test set: Average loss: 0.0013, Accuracy: 5641/10000 (56%)


Train set: Average loss: 0.0170, Accuracy: 31619/50000 (63%)


Test set: Average loss: 0.0012, Accuracy: 5850/10000 (58%)


Train set: Average loss: 0.0154, Accuracy: 33110/50000 (66%)


Test set: Average loss: 0.0012, Accuracy: 6005/10000 (60%)


Train set: Average loss: 0.0133, Accuracy: 35790/50000 (72%)


Test set: Average loss: 0.0011, Accuracy: 6344/10000 (63%)


Train set: Average loss: 0.0111, Accuracy: 38957/50000 (78%)


Test set: Average loss: 0.0010, Accuracy: 6607/10000 (66%)


Train set: Average loss: 0.0099, Accuracy: 40714/50000 (81%)


Test set: Av

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f6eea475710>
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1358, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1341, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.7/multiprocessing/process.py", line 151, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process



Test set: Average loss: 0.0012, Accuracy: 5865/10000 (59%)


Train set: Average loss: 0.0130, Accuracy: 36229/50000 (72%)



Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f6eea475710>
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1358, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1341, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.7/multiprocessing/process.py", line 151, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process



Test set: Average loss: 0.0010, Accuracy: 6408/10000 (64%)


Train set: Average loss: 0.0110, Accuracy: 39095/50000 (78%)



Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f6eea475710>
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1358, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1341, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.7/multiprocessing/process.py", line 151, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process



Test set: Average loss: 0.0010, Accuracy: 6580/10000 (66%)


Train set: Average loss: 0.0097, Accuracy: 40927/50000 (82%)



Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f6eea475710>
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1358, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1341, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.7/multiprocessing/process.py", line 151, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process



Test set: Average loss: 0.0009, Accuracy: 6777/10000 (68%)


Train set: Average loss: 0.0097, Accuracy: 40927/50000 (82%)



Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f6eea475710>
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1358, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1341, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.7/multiprocessing/process.py", line 151, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process



Test set: Average loss: 0.0009, Accuracy: 6777/10000 (68%)

Using device cuda
Files already downloaded and verified

Train set: Average loss: 0.0189, Accuracy: 29381/50000 (59%)



Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f6eea475710>
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1358, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1341, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.7/multiprocessing/process.py", line 151, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process



Test set: Average loss: 0.0013, Accuracy: 5449/10000 (54%)


Train set: Average loss: 0.0188, Accuracy: 29068/50000 (58%)



Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f6eea475710>
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1358, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1341, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.7/multiprocessing/process.py", line 151, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process



Test set: Average loss: 0.0013, Accuracy: 5449/10000 (54%)


Train set: Average loss: 0.0192, Accuracy: 28713/50000 (57%)



Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f6eea475710>
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1358, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1341, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.7/multiprocessing/process.py", line 151, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process



Test set: Average loss: 0.0013, Accuracy: 5352/10000 (54%)


Train set: Average loss: 0.0171, Accuracy: 31492/50000 (63%)



Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f6eea475710>
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1358, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1341, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.7/multiprocessing/process.py", line 151, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process



Test set: Average loss: 0.0012, Accuracy: 5789/10000 (58%)


Train set: Average loss: 0.0156, Accuracy: 32840/50000 (66%)



Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f6eea475710>
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1358, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1341, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.7/multiprocessing/process.py", line 151, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process



Test set: Average loss: 0.0012, Accuracy: 5960/10000 (60%)


Train set: Average loss: 0.0128, Accuracy: 36468/50000 (73%)



Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f6eea475710>
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1358, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1341, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.7/multiprocessing/process.py", line 151, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process



Test set: Average loss: 0.0010, Accuracy: 6428/10000 (64%)


Train set: Average loss: 0.0105, Accuracy: 39813/50000 (80%)



Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f6eea475710>
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1358, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1341, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.7/multiprocessing/process.py", line 151, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process



Test set: Average loss: 0.0010, Accuracy: 6699/10000 (67%)


Train set: Average loss: 0.0097, Accuracy: 40840/50000 (82%)



Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f6eea475710>
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1358, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1341, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.7/multiprocessing/process.py", line 151, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process



Test set: Average loss: 0.0009, Accuracy: 6811/10000 (68%)


Train set: Average loss: 0.0097, Accuracy: 40840/50000 (82%)



Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f6eea475710>
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1358, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1341, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.7/multiprocessing/process.py", line 151, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process



Test set: Average loss: 0.0009, Accuracy: 6811/10000 (68%)

Using device cuda
Files already downloaded and verified

Train set: Average loss: 0.0193, Accuracy: 28302/50000 (57%)



Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f6eea475710>
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1358, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1341, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.7/multiprocessing/process.py", line 151, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process



Test set: Average loss: 0.0013, Accuracy: 5214/10000 (52%)


Train set: Average loss: 0.0185, Accuracy: 29720/50000 (59%)



Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f6eea475710>
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1358, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1341, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.7/multiprocessing/process.py", line 151, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process



Test set: Average loss: 0.0013, Accuracy: 5522/10000 (55%)


Train set: Average loss: 0.0184, Accuracy: 29501/50000 (59%)



Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f6eea475710>
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1358, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1341, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.7/multiprocessing/process.py", line 151, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process



Test set: Average loss: 0.0013, Accuracy: 5503/10000 (55%)


Train set: Average loss: 0.0166, Accuracy: 32031/50000 (64%)



Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f6eea475710>
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1358, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1341, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.7/multiprocessing/process.py", line 151, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process



Test set: Average loss: 0.0012, Accuracy: 5985/10000 (60%)


Train set: Average loss: 0.0154, Accuracy: 33024/50000 (66%)



Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f6eea475710>
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1358, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1341, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.7/multiprocessing/process.py", line 151, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process



Test set: Average loss: 0.0012, Accuracy: 5929/10000 (59%)


Train set: Average loss: 0.0129, Accuracy: 36280/50000 (73%)



Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f6eea475710>
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1358, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1341, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.7/multiprocessing/process.py", line 151, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process



Test set: Average loss: 0.0010, Accuracy: 6442/10000 (64%)


Train set: Average loss: 0.0109, Accuracy: 39218/50000 (78%)



Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f6eea475710>
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1358, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1341, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.7/multiprocessing/process.py", line 151, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process



Test set: Average loss: 0.0010, Accuracy: 6645/10000 (66%)


Train set: Average loss: 0.0100, Accuracy: 40477/50000 (81%)



Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f6eea475710>
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1358, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1341, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.7/multiprocessing/process.py", line 151, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process



Test set: Average loss: 0.0009, Accuracy: 6766/10000 (68%)


Train set: Average loss: 0.0100, Accuracy: 40477/50000 (81%)



Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f6eea475710>
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1358, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1341, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.7/multiprocessing/process.py", line 151, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process



Test set: Average loss: 0.0009, Accuracy: 6766/10000 (68%)

Using device cuda
Files already downloaded and verified

Train set: Average loss: 0.0186, Accuracy: 29449/50000 (59%)



Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f6eea475710>
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1358, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1341, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.7/multiprocessing/process.py", line 151, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process



Test set: Average loss: 0.0013, Accuracy: 5505/10000 (55%)


Train set: Average loss: 0.0191, Accuracy: 28787/50000 (58%)



Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f6eea475710>
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1358, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1341, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.7/multiprocessing/process.py", line 151, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process



Test set: Average loss: 0.0013, Accuracy: 5428/10000 (54%)


Train set: Average loss: 0.0184, Accuracy: 29553/50000 (59%)



Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f6eea475710>
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1358, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1341, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.7/multiprocessing/process.py", line 151, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process



Test set: Average loss: 0.0013, Accuracy: 5544/10000 (55%)


Train set: Average loss: 0.0165, Accuracy: 31897/50000 (64%)



Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f6eea475710>
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1358, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1341, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.7/multiprocessing/process.py", line 151, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process



Test set: Average loss: 0.0012, Accuracy: 5967/10000 (60%)


Train set: Average loss: 0.0161, Accuracy: 32016/50000 (64%)



Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f6eea475710>
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1358, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1341, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.7/multiprocessing/process.py", line 151, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process



Test set: Average loss: 0.0012, Accuracy: 5792/10000 (58%)


Train set: Average loss: 0.0135, Accuracy: 35627/50000 (71%)



Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f6eea475710>
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1358, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1341, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.7/multiprocessing/process.py", line 151, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process



Test set: Average loss: 0.0011, Accuracy: 6349/10000 (63%)


Train set: Average loss: 0.0114, Accuracy: 38623/50000 (77%)



Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f6eea475710>
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1358, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1341, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.7/multiprocessing/process.py", line 151, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process



Test set: Average loss: 0.0010, Accuracy: 6600/10000 (66%)


Train set: Average loss: 0.0103, Accuracy: 40181/50000 (80%)



Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f6eea475710>
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1358, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1341, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.7/multiprocessing/process.py", line 151, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process



Test set: Average loss: 0.0009, Accuracy: 6764/10000 (68%)


Train set: Average loss: 0.0103, Accuracy: 40181/50000 (80%)



Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f6eea475710>
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1358, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1341, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.7/multiprocessing/process.py", line 151, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process



Test set: Average loss: 0.0009, Accuracy: 6764/10000 (68%)



In [None]:
# Sanity check: a tensor created without an explicit device lands on the default device.
torch.zeros(20).device

device(type='cpu')

In [None]:
# torch.clone yields an independent copy: mutating the clone in place leaves the
# original Parameter unchanged (the clone stays attached to autograd via grad_fn).
x = nn.Parameter(torch.tensor([5.0, 5.0]))
y = x.clone()
print(x)
print(y)
y.add_(5)
print(x)
print(y)

Parameter containing:
tensor([5., 5.], requires_grad=True)
tensor([5., 5.], grad_fn=<CloneBackward0>)
Parameter containing:
tensor([5., 5.], requires_grad=True)
tensor([10., 10.], grad_fn=<AddBackward0>)


In [None]:
# Shape of the trained conv2 weight (Net builds conv2 as SupermaskConv(32, 64, 3, 1)).
# NOTE(review): `trained_model` is assigned by the training cell further down, so this
# cell only works after that cell has been run — reorder before sharing.
trained_model.conv2.weight.size()

torch.Size([64, 32, 3, 3])

In [None]:
# Arguments that do not affect the model at all (data location, logging / evaluation cadence).
train_args = {
    "test_batch_size": 1000, # input batch size for testing (default: 1000)
    'data': '../data', # Location to store data (e.g. MNIST)
    'log_interval': 1000, # how many batches to wait before logging training status
    'train_eval_interval': 5, # epoch interval at which to print training accuracy
    'test_eval_interval': 5, # epoch interval at which to print test accuracy
}

# Arguments that define the model and its optimisation run.
args = {
  "batch_size": 64, # input batch size for training (default: 64)
  "epochs": 14, # number of epochs to train (default: 14)
  "lr": 0.1, # learning rate (default: 0.1)
  "momentum": 0.9, # Momentum (default: 0.9)
  'wd': 0.0005, # Weight decay (default: 0.0005)
  'no_cuda': False, # disables CUDA training
  'seed': 1, # random seed (default: 1)
  'save_name': None, # "simple20_rs2", # For Saving the current Model, None if not saving
  # Per-layer sparsity, in the order Net.__init__ consumes it: [conv1, conv2, fc1, fc2].
  'sparsity': [1.0, 1.0, 0.01, 0.1],
  'copy_layers': [], # ['conv1', 'conv2', 'fc2'],
  'freeze_layers': []
}

# NOTE(review): `main` is defined elsewhere in this notebook; later cells read
# `trained_model.conv2`, so it presumably returns the trained Net — confirm.
trained_model = main(args, train_args)

In [None]:
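# Experiment notes (kept as comments so this cell does not raise a SyntaxError under Run All).
# NOTE(review): the counts appear to be correctly classified MNIST training images out of
# 60000; "2, 20" presumably refers to model2 / model20 plus the copied layers, "rs" is
# presumably the random seed, and the lists are per-layer sparsities — confirm.
# 2, 20: 59737 / 60000
# 2, 20, conv1: 59796 / 60000
# 2, 20, conv1, conv2: 59896 / 60000
# 2, 20, conv1, conv2, fc1: 54630 / 60000
# 2, 20, conv1, conv2, fc2: 59871
#
# -----
#
# rs=1, [0.5, 0.1, 0.01, 0.1], 14 epochs, no freeze: 58461
# rs=1, [1.0, 1.0, 0.01, 0.1], 14 epochs, no freeze: 57124
# rs=1, [0.5, 0.1, 0.01, 0.1], 14 epochs, freeze conv1, conv2: 51234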


In [None]:
# Re-evaluate a previously trained model (model20) on the MNIST test split.
main_model = model20

use_cuda = not main_model.args.no_cuda and torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

# Net requires (args, input_channels, image_size, num_labels); MNIST is 1x28x28 with 10
# classes. Pass the trained model's args rather than None: Net.__init__ reads
# args.sparsity to configure each layer, and load_state_dict cannot change that afterwards.
model = Net(main_model.args, 1, 28, 10).to(device)
model.load_state_dict(main_model.state_dict())

# NOTE(review): `data` and `test_batch_size` are defined in train_args in the training cell;
# this assumes main() merged them into the model's args — confirm.
kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(
        os.path.join(model.args.data, 'mnist'), train=False, download=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ])),
    # Aggregate test metrics do not depend on sample order, so no shuffling is needed.
    batch_size=model.args.test_batch_size, shuffle=False, **kwargs)

# Net.forward already returns log_softmax, so NLLLoss is the matching criterion.
# CrossEntropyLoss would re-apply log_softmax; that is numerically a no-op here but misleading.
test(model, device, nn.NLLLoss().to(device), test_loader, name="Test")

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ../data/mnist/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting ../data/mnist/MNIST/raw/train-images-idx3-ubyte.gz to ../data/mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ../data/mnist/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting ../data/mnist/MNIST/raw/train-labels-idx1-ubyte.gz to ../data/mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ../data/mnist/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting ../data/mnist/MNIST/raw/t10k-images-idx3-ubyte.gz to ../data/mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ../data/mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting ../data/mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz to ../data/mnist/MNIST/raw



(11.07, 0.00570564204826951)

In [None]:
# Aggregate model2's masked conv1 kernels over all filters to see where surviving weight sits.
model = model2
conv1_masked = get_prune_mask(model.conv1, model.args.sparsity) * model.conv1.weight
conv1_masked.squeeze(1).sum(dim=0)

tensor([[0.9824, 0.0000, 0.0000],
        [1.3175, 0.4464, 0.0000],
        [0.0000, 1.0300, 1.9097]])

In [None]:
# Print every conv1 filter of `model` (set to model2 in the cell above) that still has at
# least one non-zero weight after masking.
# Compute the masked weights once, rather than re-running get_prune_mask for every filter.
masked_conv1 = get_prune_mask(model.conv1, model.args.sparsity) * model.conv1.weight
for i in range(masked_conv1.shape[0]):
    temp = masked_conv1[i].squeeze(0)
    # Check for any surviving entry; `temp.sum() != 0` would skip filters whose kept
    # weights happen to cancel out.
    if (temp != 0).any():
        print(i, temp)

6 tensor([[0.9824, 0.0000, -0.0000],
        [1.3175, 0.4464, 0.0000],
        [0.0000, 1.0300, 0.0000]])
11 tensor([[-0.0000, -0.0000, 0.0000],
        [-0.0000, -0.0000, -0.0000],
        [-0.0000, 0.0000, 0.7069]])
25 tensor([[-0.0000, -0.0000, 0.0000],
        [-0.0000, -0.0000, -0.0000],
        [0.0000, 0.0000, 1.2028]])


In [None]:
# For each conv2 output channel of model2, list the input-channel kernels that keep at
# least one non-zero weight after masking.
model = model2
# Compute the masked weights once, instead of once per (i, j) pair (64 * 32 mask evaluations).
masked_conv2 = get_prune_mask(model.conv2, model.args.sparsity) * model.conv2.weight
out_channels, in_channels = masked_conv2.shape[:2]
for i in range(out_channels):
    print(f"---{i}---")
    for j in range(in_channels):
        temp = masked_conv2[i, j]
        # `(temp != 0).any()` rather than `temp.sum() != 0`, which misses kernels whose
        # surviving weights cancel out.
        if (temp != 0).any():
            print(i, j, temp)

In [None]:
# Layer by layer, print: kept weights in model2, kept weights in model20, and how many
# positions model2 keeps that model20 does not.
for layer2, layer20 in zip(model2.children(), model20.children()):
    mask2 = get_prune_mask(layer2, model2.args.sparsity)
    mask20 = get_prune_mask(layer20, model20.args.sparsity)
    kept_only_in_2 = (mask2 == 1) & (mask20 != 1)
    print(mask2.sum(), mask20.sum(), kept_only_in_2.sum())

tensor(6.) tensor(58.) tensor(0)
tensor(369.) tensor(3687.) tensor(197)
tensor(23593.) tensor(235930.) tensor(13019)
tensor(26.) tensor(256.) tensor(9)


[SupermaskConv(1, 32, kernel_size=(3, 3), stride=(1, 1), bias=False),
 SupermaskConv(32, 64, kernel_size=(3, 3), stride=(1, 1), bias=False),
 SupermaskLinear(in_features=9216, out_features=128, bias=False),
 SupermaskLinear(in_features=128, out_features=10, bias=False)]

In [None]:
# Load a saved checkpoint. map_location="cpu" lets a checkpoint written from a GPU run load in
# a CPU-only session; load_state_dict later copies the tensors onto the model's own device.
# NOTE(review): torch.load unpickles the file (the '_extra_state' entry is a Python object,
# not a tensor), so only load checkpoints from trusted sources.
state_dict = torch.load(
    os.path.join(os.environ["HOME_DIR"], 'trained_networks', "30_fc2_sparsity_0.6.pt"),
    map_location="cpu",
)

In [None]:
# Net stores its args as extra state; read back the per-layer sparsity it was trained with.
saved_args = state_dict['_extra_state']
saved_args.sparsity

[1.0, 0.4, 0.1, 0.6]

In [None]:
# Rebuild an MNIST-shaped Net (1 input channel, 28x28 images, 10 labels) from the checkpoint.
# Pass the saved args instead of None: Net.__init__ reads args.sparsity to configure each layer,
# whereas load_state_dict only restores `args` (via set_extra_state) after the layers exist,
# so Net(None, ...) would silently fall back to sparsity [1.0, 1.0, 1.0, 1.0].
temp_model = Net(state_dict['_extra_state'], 1, 28, 10)
temp_model.load_state_dict(state_dict)

<All keys matched successfully>