In [1]:
# Colab-only setup: mount the user's Google Drive at /content/drive so the
# project files referenced by later cells are reachable from the notebook.
# NOTE(review): force_remount=True presumably re-mounts even when a mount
# already exists - confirm against the google.colab documentation.
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


In [2]:
import os
# Project root, expressed relative to the notebook's working directory. It is
# exported as an environment variable so the shell line below can expand it.
os.environ['HOME_DIR'] = 'drive/MyDrive/hidden-networks'
# Install the project's Python dependencies listed in its requirements file.
!pip install -r $HOME_DIR/requirements.txt

import sys
# Make the project's modules importable by adding its absolute path under
# /content to the module search path.
sys.path.append(os.path.join('/content', os.environ['HOME_DIR']))



In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import CosineAnnealingLR
import torch.autograd as autograd
import collections

from simple_mnist_example import GetSubnet, SupermaskConv, SupermaskLinear
from simple_mnist_example import train, test, set_args

# Names of every hyper-parameter / run option that `main` expects to receive.
arg_list = (
    "batch_size",
    "test_batch_size",
    "epochs",
    "lr",
    "momentum",
    'wd',
    'no_cuda',
    'seed',
    'log_interval',
    'save_name',
    'data',
    'sparsity',
)

# Immutable record type with one field per entry of `arg_list`.
ArgClass = collections.namedtuple('ArgClass', [field for field in arg_list])

In [4]:
class Net(nn.Module):
    """Small MNIST classifier assembled from supermask layers.

    The network is two ``SupermaskConv`` layers followed by two
    ``SupermaskLinear`` layers, all created without bias terms. The run
    arguments are kept on the module and exposed through the extra-state
    hooks.
    """

    def __init__(self, args):
        """Create the four layers and remember ``args``.

        Args:
            args: run configuration object; stored as ``self.args`` and not
                otherwise used by this class.
        """
        super().__init__()
        # Layer creation order is kept as conv1, conv2, fc1, fc2.
        self.conv1 = SupermaskConv(1, 32, 3, 1, bias=False)
        self.conv2 = SupermaskConv(32, 64, 3, 1, bias=False)
        # NOTE(review): 9216 presumably equals 64 * 12 * 12 for 28x28 inputs
        # after two 3x3 convolutions and a 2x2 max-pool - confirm.
        self.fc1 = SupermaskLinear(9216, 128, bias=False)
        self.fc2 = SupermaskLinear(128, 10, bias=False)
        self.args = args

    def forward(self, x):
        """Return the log-softmax (over dim 1) of the network output for ``x``."""
        hidden = F.relu(self.conv1(x))
        pooled = F.max_pool2d(self.conv2(hidden), 2)
        # Keep dim 0 and flatten everything after it.
        flat = torch.flatten(pooled, 1)
        hidden = F.relu(self.fc1(flat))
        return F.log_softmax(self.fc2(hidden), dim=1)

    def get_extra_state(self):
        """Return the stored run arguments as this module's extra state."""
        return self.args

    def set_extra_state(self, state):
        """Replace the stored run arguments with ``state``."""
        self.args = state

In [5]:
# The main function runs the full training loop on the MNIST dataset

def main(args):
    """Train a ``Net`` on MNIST and return the trained model.

    Args:
        args: mapping whose keys match the fields of ``ArgClass``; it is
            converted to an ``ArgClass`` instance before use.

    Returns:
        The ``Net`` instance after ``args.epochs`` epochs of training.

    Side effects:
        Downloads MNIST under ``<args.data>/mnist`` if needed, prints the
        chosen device and the trainable parameters, and, when
        ``args.save_name`` is not None, saves the model's ``state_dict``
        under ``$HOME_DIR/trained_networks``.
    """
    args = ArgClass(**args)

    # CUDA is used only when it is both allowed and available.
    if args.no_cuda:
        use_cuda = False
    else:
        use_cuda = torch.cuda.is_available()

    torch.manual_seed(args.seed)

    device = torch.device("cuda" if use_cuda else "cpu")
    print(f"Using device {device}")

    # Extra DataLoader options are only supplied for CUDA runs.
    if use_cuda:
        kwargs = {'num_workers': 1, 'pin_memory': True}
    else:
        kwargs = {}

    mnist_root = os.path.join(args.data, 'mnist')
    # Same preprocessing for both splits: tensor conversion + normalisation.
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])

    train_set = datasets.MNIST(mnist_root, train=True, download=True,
                               transform=preprocess)
    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=args.batch_size, shuffle=True, **kwargs)

    test_set = datasets.MNIST(mnist_root, train=False, transform=preprocess)
    test_loader = torch.utils.data.DataLoader(
        test_set, batch_size=args.test_batch_size, shuffle=True, **kwargs)

    model = Net(args).to(device)

    # NOTE: only parameters with requires_grad == True go to the optimizer.
    optimizer = optim.SGD(
        [p for p in model.parameters() if p.requires_grad],
        lr=args.lr,
        momentum=args.momentum,
        weight_decay=args.wd,
    )
    print( [p for p in model.parameters() if p.requires_grad])

    criterion = nn.CrossEntropyLoss().to(device)
    scheduler = CosineAnnealingLR(optimizer, T_max=args.epochs)

    epoch = 1
    while epoch <= args.epochs:
        train(model, device, train_loader, optimizer, criterion, epoch)
        # Evaluate on both splits after every epoch.
        test(model, device, criterion, train_loader, name="Train")
        test(model, device, criterion, test_loader, name="Test")
        scheduler.step()
        epoch += 1

    if args.save_name is not None:
        save_path = os.path.join(os.environ['HOME_DIR'], "trained_networks",
                                 args.save_name)
        torch.save(model.state_dict(), save_path)

    return model

def get_prune_mask(layer, sparsity):
    """Apply ``GetSubnet`` to the absolute values of ``layer.scores``.

    Args:
        layer: object exposing a ``scores`` tensor attribute.
        sparsity: value forwarded unchanged as the second argument of
            ``GetSubnet.apply``.

    Returns:
        Whatever ``GetSubnet.apply`` returns, computed with gradient
        tracking disabled.
        NOTE(review): presumably a 0/1 mask shaped like ``layer.scores`` -
        confirm against ``simple_mnist_example.GetSubnet``.
    """
    with torch.no_grad():
        score_magnitudes = layer.scores.abs()
        mask = GetSubnet.apply(score_magnitudes, sparsity)
    return mask


def get_sign(weight):
  """Return the sign of a scalar: 1 if positive, -1 if negative, else 0."""
  if weight > 0:
    return 1
  if weight < 0:
    return -1
  # Neither comparison held (e.g. the value is zero).
  return 0

def process_weight(weight):
  """Split a scalar weight into ``(magnitude, sign)``.

  The magnitude is ``abs(weight).item()`` and the sign comes from
  ``get_sign(weight)``.
  """
  magnitude = abs(weight).item()
  sign = get_sign(weight)
  return magnitude, sign

def featurize_fc(weights, masks, sparsity, layer):
  """Build one feature row per entry of a 2-D weight matrix.

  ``weights`` and ``masks`` are transposed (dims 0 and 1 swapped) first, and
  the transposed weights are zero-padded by one element on every side so
  that every entry has four neighbours.

  Each row is
  ``[i, j, mag_0..mag_4, sign_0..sign_4, sparsity, "fc" + layer, include]``
  where index 0 is the entry itself, 1 and 2 are the neighbours at ``i - 1``
  and ``i + 1``, 3 and 4 are the neighbours at ``j - 1`` and ``j + 1``, and
  ``include`` is the mask value at ``(i, j)``. Indices refer to the
  transposed matrices.

  Args:
    weights: 2-D weight tensor.
    masks: tensor with the same layout as ``weights``.
    sparsity: value copied into every row.
    layer: string appended to ``"fc"`` to form the layer label.

  Returns:
    List of feature rows, ordered by ``i`` then ``j``.
  """
  weights_t = torch.transpose(weights, 0, 1)
  masks_t = torch.transpose(masks, 0, 1)
  padded = F.pad(weights_t, (1,1,1,1), "constant", 0)

  # (row offset, column offset): centre first, then the four neighbours.
  neighbour_offsets = ((0, 0), (-1, 0), (1, 0), (0, -1), (0, 1))

  rows = []
  for i in range(padded.shape[0] - 2):
    for j in range(padded.shape[1] - 2):
      # +1 converts an unpadded index into the padded matrix.
      features = [process_weight(padded[i + 1 + di][j + 1 + dj])
                  for di, dj in neighbour_offsets]
      magnitudes = [mag for mag, _ in features]
      signs = [sign for _, sign in features]
      include = masks_t[i][j].item()
      rows.append([i, j, *magnitudes, *signs, sparsity, "fc"+layer, include])
  return rows


def featurize_conv(weights, masks, sparsity, layer):
  """Build one feature row per kernel entry of a stack of 2-D kernels.

  The last two dimensions of ``weights`` are zero-padded by one element on
  every side so that every entry has eight neighbours.

  Each row is
  ``[channel, row, col, mag_0..mag_8, sign_0..sign_8, sparsity,
  "conv" + layer, include]`` where index 0 is the entry itself and 1..8 are
  its neighbours in row-major order (top-left to bottom-right, skipping the
  centre). ``include`` is ``masks[channel][row][col]``.

  Args:
    weights: weight tensor; iterating over dim 0 must yield 2-D kernels.
    masks: tensor indexed as ``masks[channel][row][col]``.
    sparsity: value copied into every row.
    layer: string appended to ``"conv"`` to form the layer label.

  Returns:
    List of feature rows, ordered by channel, then row, then column.
  """
  padded = F.pad(weights, (1,1,1,1), "constant", 0)

  # (row offset, column offset): centre first, then neighbours row by row.
  neighbour_offsets = (
      (0, 0),
      (-1, -1), (-1, 0), (-1, 1),
      (0, -1), (0, 1),
      (1, -1), (1, 0), (1, 1),
  )

  records = []
  for channel_num, kernel in enumerate(padded):
    for r in range(kernel.shape[0] - 2):
      for c in range(kernel.shape[1] - 2):
        # +1 converts an unpadded index into the padded kernel.
        features = [process_weight(kernel[r + 1 + dr][c + 1 + dc])
                    for dr, dc in neighbour_offsets]
        magnitudes = [mag for mag, _ in features]
        signs = [sign for _, sign in features]
        include = masks[channel_num][r][c].item()
        records.append([channel_num, r, c, *magnitudes, *signs,
                        sparsity, "conv"+layer, include])
  return records

#dimension of input by output, out_channels by in_channels by kernel
def conv2_predictors(fc1_weights, conv2_masks, conv2_output_dim):
  """Spread each conv filter's pruned-weight count over a block of rows.

  For filter ``k`` the number of zero entries in ``conv2_masks[k]`` is
  written to rows ``k * conv2_output_dim**2`` up to (not including)
  ``(k + 1) * conv2_output_dim**2`` of the result, across all columns.

  Args:
    fc1_weights: 2-D array-like; only its two lengths are used, to size the
      result.
    conv2_masks: sequence of mask tensors, one per filter; the size of each
      is taken as the product of its first three dimensions.
    conv2_output_dim: the row block assigned to each filter has
      ``conv2_output_dim ** 2`` rows.

  Returns:
    ``torch.int32`` tensor with the same two lengths as ``fc1_weights``.
  """
  n_rows = len(fc1_weights)
  n_cols = len(fc1_weights[0])
  predictors = torch.zeros([n_rows, n_cols], dtype=torch.int32)

  rows_per_filter = conv2_output_dim ** 2
  for filter_idx, filter_mask in enumerate(conv2_masks):
    dims = filter_mask.shape
    total_weights = dims[0] * dims[1] * dims[2]
    # can replace pruned_count with some other function
    pruned_count = total_weights - torch.count_nonzero(filter_mask)
    start = filter_idx * rows_per_filter
    stop = start + rows_per_filter
    predictors[start:stop] = pruned_count

  return predictors

#dimension of input by output
def fc2_predictors(fc1_weights, fc2_masks):
  """Fill each column with the pruned-weight count of one ``fc2_masks`` row.

  Column ``j`` of the result holds ``len(fc2_masks[0])`` minus the number of
  non-zero entries in ``fc2_masks[j]``, repeated down every row.

  Args:
    fc1_weights: 2-D array-like; only its two lengths are used, to size the
      result.
    fc2_masks: 2-D mask tensor with one row per column of the result.

  Returns:
    ``torch.int32`` tensor with the same two lengths as ``fc1_weights``.
  """
  n_rows = len(fc1_weights)
  n_cols = len(fc1_weights[0])
  predictors = torch.zeros([n_rows, n_cols], dtype=torch.int32)

  for col in range(n_cols):
    kept = torch.count_nonzero(fc2_masks[col])
    pruned_count = len(fc2_masks[0]) - kept
    predictors[:, col] = pruned_count

  return predictors

def make_pruned_df(conv2_mat, fc2_mat):
  """Flatten two equally shaped 2-D matrices into a long-format DataFrame.

  Args:
    conv2_mat: 2-D tensor-like whose elements expose ``.item()``; it
      determines the set of ``(i, j)`` positions.
    fc2_mat: 2-D tensor-like indexed at the same positions.

  Returns:
    DataFrame with columns ``i``, ``j``, ``conv2_pruned_count`` and
    ``fc2_pruned_count``: one row per ``(i, j)``, ordered by ``i`` then ``j``.
  """
  positions = [
      (i, j)
      for i in range(len(conv2_mat))
      for j in range(len(conv2_mat[0]))
  ]

  d = {
      'i': [i for i, _ in positions],
      'j': [j for _, j in positions],
      'conv2_pruned_count': [conv2_mat[i][j].item() for i, j in positions],
      'fc2_pruned_count': [fc2_mat[i][j].item() for i, j in positions],
  }
  return pd.DataFrame(data=d)


In [6]:
import csv
from tqdm import tqdm
import pickle
import math
import pandas as pd

In [None]:
# Run configuration for this experiment; `sparsity` is also reused below when
# extracting the per-layer masks.
sparsity = .20
args = dict(
  batch_size=64,          # input batch size for training (default: 64)
  test_batch_size=1000,   # input batch size for testing (default: 1000)
  epochs=14,              # number of epochs to train (default: 14)
  lr=0.1,                 # learning rate (default: 0.1)
  momentum=0.9,           # Momentum (default: 0.9)
  wd=0.0005,              # Weight decay (default: 0.0005)
  no_cuda=False,          # disables CUDA training
  seed=1,                 # random seed (default: 1)
  log_interval=10000,     # how many batches to wait before logging training status
  save_name=None,         # For Saving the current Model, None if not saving
  data='../data',         # Location to store data
  sparsity=sparsity,      # 'how sparse is each layer'
)

model = main(args)

# Mask and weight tensors per layer. The conv tensors are squeezed to drop
# any size-1 dimensions; the fully connected ones are used as-is.
conv1_masks = get_prune_mask(model.conv1, sparsity).squeeze()
conv1_weights = model.conv1.weight.squeeze()
conv2_masks = get_prune_mask(model.conv2, sparsity).squeeze()
conv2_weights = model.conv2.weight.squeeze()
fc1_masks = get_prune_mask(model.fc1, sparsity)
fc1_weights = model.fc1.weight
fc2_masks = get_prune_mask(model.fc2, sparsity)
fc2_weights = model.fc2.weight

# For every layer, stack (mask, weight) along a new leading dimension.
data = {
  "conv1": torch.stack((conv1_masks, conv1_weights)),
  "conv2": torch.stack((conv2_masks, conv2_weights)),
  "fc1": torch.stack((fc1_masks, fc1_weights)),
  "fc2": torch.stack((fc2_masks, fc2_weights)),
}
def write_pickle(path, d):
  """Serialise ``d`` to ``path`` using the highest pickle protocol.

  The file is opened in ``'wb+'`` mode, so existing content is truncated.
  Returns ``None``.
  """
  with open(path, 'wb+') as handle:
    pickler = pickle.Pickler(handle, protocol=pickle.HIGHEST_PROTOCOL)
    return pickler.dump(d)

# Persist the stacked (mask, weight) tensors of every layer.
write_pickle('./drive/MyDrive/hidden-networks/dataset/conv1_s20.pkl', data['conv1'])
write_pickle('./drive/MyDrive/hidden-networks/dataset/conv2_s20.pkl', data['conv2'])
write_pickle('./drive/MyDrive/hidden-networks/dataset/fc1_s20.pkl', data['fc1'])
write_pickle('./drive/MyDrive/hidden-networks/dataset/fc2_s20.pkl', data['fc2'])

# Per-weight features of fc1 (magnitude/sign of each weight and its
# neighbours, plus whether the weight is kept by the mask).
fc1_data = featurize_fc(fc1_weights, fc1_masks, sparsity, "1")
fc_df = pd.DataFrame(fc1_data, columns = ['i', 'j', 'mag_0', 'mag_1', 'mag_2', 'mag_3', 'mag_4', 'sign_0', 'sign_1', 'sign_2', 'sign_3', 'sign_4', 'sparsity', 'layer', 'include'])

# Pruned-weight counts of the neighbouring layers, laid out like fc1.T.
# NOTE(review): 12 is presumably the spatial size of conv2's pooled output
# (so each filter covers 12 * 12 rows of fc1's input) - confirm.
d1 = conv2_predictors(fc1_weights.T, conv2_masks, 12)
d2 = fc2_predictors(fc1_weights.T, fc2_masks.T)
pruned_fc_data = make_pruned_df(d1, d2)

final_data = pd.merge(fc_df, pruned_fc_data, how='inner', left_on=['i','j'], right_on = ['i','j'])

# Append the merged rows to the CSV. Iterating a DataFrame yields its column
# names, so passing it to csv.writer.writerows() only wrote the column names
# (one character per field) and none of the data; DataFrame.to_csv writes the
# actual rows. The header is emitted only when the file is new or empty, so
# repeated runs keep appending to a single well-formed table.
# NOTE: a file already produced by the previous writerows() code contains
# malformed lines and should be deleted before re-running this cell.
csv_path = "./drive/MyDrive/hidden-networks/dataset/fc1_pruned_data.csv"
write_header = not os.path.exists(csv_path) or os.path.getsize(csv_path) == 0
final_data.to_csv(csv_path, mode="a", header=write_header, index=False)

Using device cuda
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ../data/mnist/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting ../data/mnist/MNIST/raw/train-images-idx3-ubyte.gz to ../data/mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ../data/mnist/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting ../data/mnist/MNIST/raw/train-labels-idx1-ubyte.gz to ../data/mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ../data/mnist/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting ../data/mnist/MNIST/raw/t10k-images-idx3-ubyte.gz to ../data/mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ../data/mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting ../data/mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz to ../data/mnist/MNIST/raw


Train set: Average loss: 0.0892, Accuracy: 6290/60000 (10%)


Test set: Average loss: 0.0057, Accuracy: 1107/10000 (11%)


Train set: Average loss: 0.0892, Accuracy: 6290/60000 (10%)


Test set: Average loss: 0.0057, Accuracy: 1107/10000 (11%)


Train set: Average loss: 0.0892, Accuracy: 6290/60000 (10%)


Test set: Average loss: 0.0057, Accuracy: 1107/10000 (11%)


Train set: Average loss: 0.0892, Accuracy: 6290/60000 (10%)


Test set: Average loss: 0.0057, Accuracy: 1107/10000 (11%)


Train set: Average loss: 0.0892, Accuracy: 6290/60000 (10%)


Test set: Average loss: 0.0057, Accuracy: 1107/10000 (11%)


Train set: Average loss: 0.0892, Accuracy: 6290/60000 (10%)


Test set: Average loss: 0.0057, Accuracy: 1107/10000 (11%)


Train set: Average loss: 0.0892, Accuracy: 6290/60000 (10%)


Test set: Average loss: 0.0057, Accuracy: 1107/10000 (11%)


Train set: Average loss: 0.0892, Accuracy: 6290/6