In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
import os
os.environ['HOME_DIR'] = 'drive/MyDrive/hidden-networks'
# !pip install -r $HOME_DIR/requirements.txt

import sys
sys.path.append(os.path.join('/content', os.environ['HOME_DIR']))

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import CosineAnnealingLR
import torch.autograd as autograd
import collections

from supermask_pruning import GetSubnet, SupermaskConv, SupermaskLinear
from supermask_pruning import train, test

class ArgClass:
    def __init__(self, args):
        self.setattrs(**args)
        
    def setattrs(self, **kwargs):
        for name, val in kwargs.items():
            setattr(self, name, val)

In [4]:
class Net(nn.Module):
    def __init__(self, args, input_channels, image_size, num_labels):
        super().__init__()
        
        sparsities = getattr(args, "sparsity", [{"sparsity": 1.0}, {"sparsity": 1.0}, {"sparsity": 1.0}, {"sparsity": 1.0}, {"sparsity": 1.0}])
        self.conv1 = SupermaskConv(input_channels, 64, 3, 1, bias=args.bias, init=args.init, **sparsities[0])
        self.conv2 = SupermaskConv(64, 64, 3, 1, bias=args.bias, init=args.init, **sparsities[1])
        s = (image_size - 4) * (image_size - 4) * 64 // 4
        self.fc1 = SupermaskLinear(s, 256, bias=args.bias, init=args.init, **sparsities[2])
        self.fc2 = SupermaskLinear(256, 256, bias=args.bias, init=args.init, **sparsities[3])
        self.fc3 = SupermaskLinear(256, num_labels, bias=args.bias, init=args.init, **sparsities[4])
        self.args = args

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.max_pool2d(x, 2)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        x = F.relu(x)
        x = self.fc3(x)
        output = F.log_softmax(x, dim=1)
        return output
    
    def get_extra_state(self):
        return self.args
      
    def set_extra_state(self, state):
        self.args = state

In [5]:
# The main function runs the full training loop on a dataset of your choice
def main(model_args, train_args, base_model=None, trial=None):
    args = ArgClass(model_args)
    train_args = ArgClass(train_args)
    dataset = args.dataset

    use_cuda = not args.no_cuda and torch.cuda.is_available()

    torch.manual_seed(args.seed)

    device = torch.device("cuda" if use_cuda else "cpu")
    print(f"Using device {device}")

    transform = None
    if dataset == "MNIST":
        transform = transforms.Compose([transforms.ToTensor(), 
                                        transforms.Normalize((0.1307,), (0.3081,))
                                        ])
        train_transform = transform
        input_channels, image_size, num_labels = 1, 28, 10
    elif dataset == "CIFAR10":
        train_transform = transforms.Compose([transforms.RandomCrop(32, padding=4),
                                              transforms.RandomHorizontalFlip(),
                                              transforms.ToTensor(),
                                              transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
                                              ])
        transform = transforms.Compose([transforms.ToTensor(),
                                        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
                                        ])
        input_channels, image_size, num_labels = 3, 32, 10
    else:
        raise ValueError("Only supported datasets are CIFAR10 and MNIST currently.")

    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
    train_loader = torch.utils.data.DataLoader(
        getattr(datasets, dataset)(os.path.join(train_args.data, dataset), 
                                   train=True, download=True, transform=transform),
        batch_size=args.batch_size, shuffle=True, **kwargs)
    train_augmented_loader = torch.utils.data.DataLoader(
        getattr(datasets, dataset)(os.path.join(train_args.data, dataset), 
                                   train=True, transform=train_transform),
        batch_size=args.batch_size, shuffle=True, **kwargs)
    test_loader = torch.utils.data.DataLoader(
        getattr(datasets, dataset)(os.path.join(train_args.data, dataset), 
                                   train=False, transform=transform),
        batch_size=train_args.test_batch_size, shuffle=True, **kwargs)

    model = Net(args, input_channels, image_size, num_labels).to(device)

    if getattr(args, "copy_layers", None) is not None:
        if (bool(args.copy_layers) ^ (base_model is not None)):
            raise ValueError("copy_layers arg must be None or [] if base_model is not specified")
        if base_model is not None and args.copy_layers:
            for layer in args.copy_layers:
                model.load_state_dict(getattr(base_model, layer).state_dict(prefix=f"{layer}."), strict=False)
            
    # NOTE: only pass the parameters where p.requires_grad == True to the optimizer! Important!
    optimizer = getattr(optim, args.optimizer)(
        [p for p in model.parameters() if p.requires_grad],
        **args.optim_kwargs,
    )
    assert isinstance(args.epochs, list) or isinstance(args.epochs, int)
    num_epochs, check_freeze = (args.epochs, False) if isinstance(args.epochs, int) else (max(args.epochs), True)
    criterion = nn.CrossEntropyLoss().to(device)
    scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs) if args.scheduler else None 

    for epoch in range(1, num_epochs + 1):
        if check_freeze:
            for freeze_at_epoch, child in zip(args.epochs, model.children()):
                if freeze_at_epoch == epoch - 1:
                    child.freeze()
                    print(f"Freezing {child} before epoch {epoch}")

        train(model, train_args.log_interval, device, train_augmented_loader, optimizer, criterion, epoch)
        if (train_args.train_eval_interval and epoch % train_args.train_eval_interval == 0) or (train_args.eval_on_last and epoch == args.epochs):
            train_acc, train_loss = test(model, device, criterion, train_loader, name="Train")
            if trial:
                trial.set_user_attr('train_acc', {**trial.user_attrs.get('train_acc', {}), **{epoch: train_acc}})
                trial.set_user_attr('train_loss', {**trial.user_attrs.get('train_loss', {}), **{epoch: train_loss}})
        if (train_args.test_eval_interval and epoch % train_args.test_eval_interval == 0) or (train_args.eval_on_last and epoch == args.epochs):
            test_acc, test_loss = test(model, device, criterion, test_loader, name="Test")
            if trial:
                trial.set_user_attr('test_acc', {**trial.user_attrs.get('test_acc', {}), **{epoch: test_acc}})
                trial.set_user_attr('test_loss', {**trial.user_attrs.get('test_loss', {}), **{epoch: test_loss}})
                trial.report(test_acc, epoch-1)
                if trial.should_prune():
                    raise optuna.exceptions.TrialPruned()

        if scheduler:
            scheduler.step()

    if args.save_name is not None:
        torch.save(model.state_dict(), os.path.join(os.environ['HOME_DIR'], \
                                                    "trained_networks", args.save_name))
    
    return model, device, train_loader, test_loader, criterion

def get_prune_mask(layer, sparsity):
    with torch.no_grad():
        return GetSubnet.apply(layer.scores.abs(), sparsity)

def get_sign(weight):
  if weight > 0:
    return 1
  elif weight < 0:
    return -1
  else:
    return 0

def process_weight(weight):
  return abs(weight).item(), get_sign(weight)

def featurize_fc(weights, masks, sparsity, layer):
  weights = torch.transpose(weights, 0, 1)
  masks = torch.transpose(masks, 0, 1)
  weights_padded = F.pad(weights, (1,1,1,1), "constant", 0)
  data_fc = []
  for input in range(1, weights_padded.shape[0] - 1):
    for output in range(1, weights_padded.shape[1] - 1):
      mag_0, sign_0 = process_weight(weights_padded[input][output])
      mag_1, sign_1 = process_weight(weights_padded[input-1][output])
      mag_2, sign_2 = process_weight(weights_padded[input+1][output])
      mag_3, sign_3 = process_weight(weights_padded[input][output-1])
      mag_4, sign_4 = process_weight(weights_padded[input][output+1])
      include = masks[input-1][output-1].item()
      data_fc.append([input - 1, output - 1, mag_0, mag_1, mag_2, mag_3, mag_4, sign_0, sign_1, sign_2, sign_3, sign_4, sparsity, "fc"+layer, include])
  return data_fc


def featurize_conv(weights, masks, sparsity, layer):
  weights_padded =  F.pad(weights, (1,1,1,1), "constant", 0)
  data_conv = []
  for channel_num, channel in enumerate(weights_padded):
    for row in range(1, channel.shape[0] - 1):
      for col in range(1, channel.shape[1] - 1):
        mag_0, sign_0 = process_weight(channel[row][col])
        mag_1, sign_1 = process_weight(channel[row-1][col-1])
        mag_2, sign_2 = process_weight(channel[row-1][col])
        mag_3, sign_3 = process_weight(channel[row-1][col+1])
        mag_4, sign_4 = process_weight(channel[row][col-1])
        mag_5, sign_5 = process_weight(channel[row][col+1])
        mag_6, sign_6 = process_weight(channel[row+1][col-1])
        mag_7, sign_7 = process_weight(channel[row+1][col])
        mag_8, sign_8 = process_weight(channel[row+1][col+1])
        include = masks[channel_num][row-1][col-1].item()
        data_conv.append([channel_num, row - 1, col - 1, mag_0, mag_1, mag_2, mag_3, mag_4, mag_5, mag_6, mag_7, mag_8, sign_0, sign_1, sign_2, sign_3, sign_4, sign_5, sign_6, sign_7, sign_8, sparsity, "conv"+layer, include])
  return data_conv

#dimension of input by output, out_channels by in_channels by kernel
def conv2_predictors(fc1_weights, conv2_masks, conv2_output_dim):
  data = torch.zeros([len(fc1_weights), len(fc1_weights[0])], dtype=torch.int32)
  flat_length = conv2_output_dim ** 2
  for i in range(len(conv2_masks)):
    shape = conv2_masks[i].shape
    #can replace pruned_count with some other function
    pruned_count = (shape[0] * shape[1] * shape[2]) - torch.count_nonzero(conv2_masks[i])
    data[i*flat_length:(i+1)*flat_length] = pruned_count

  return data

#dimension of input by output
def fc2_predictors(fc1_weights, fc2_masks):
  data = torch.zeros([len(fc1_weights), len(fc1_weights[0])], dtype=torch.int32)
  for j in range(len(fc1_weights[0])):
    pruned_count = len(fc2_masks[0]) - torch.count_nonzero(fc2_masks[j])
    data[:, j] = pruned_count

  return data

def make_pruned_df(conv2_mat, fc2_mat):
  i_list = [] 
  j_list = [] 
  c2_list = []
  fc2_list = []

  for i in range(len(conv2_mat)):
    for j in range(len(conv2_mat[0])):
      i_list += [i]
      j_list += [j]
      c2_list += [conv2_mat[i][j].item()]
      fc2_list += [fc2_mat[i][j].item()]
  
  d = {'i': i_list, 'j': j_list, 'conv2_pruned_count': c2_list, 'fc2_pruned_count': fc2_list}
  return pd.DataFrame(data=d)

In [38]:
train_args = {
  "test_batch_size": 1000, # input batch size for testing (default: 1000)
  'data': '../data', # Location to store data (e.g. MNIST)
  'log_interval': 1000000, # how many batches to wait before logging training status
  'train_eval_interval': 20, # epoch interval at which to print training accuracy
  'test_eval_interval': 20, # epoch interval at which to print test accuracy
  'eval_on_last': True
}

args = {
  "dataset": "CIFAR10",
  "init": "kaiming_normal",
  "batch_size": 64, # input batch size for training (default: 64)
  "epochs": 100, # number of epochs to train (default: 14)
  "optimizer": "SGD",
  "optim_kwargs": {"lr": 0.1, "momentum": 0.9, "weight_decay": 0.0005},
  "scheduler": True, # False for Adam, True for SGD, does CosineAnnealing
  'no_cuda': False, # disables CUDA training
  'seed': 1000, # random seed (default: 1)
  'save_name': None, # "simple20_rs2", # For Saving the current Model, None if not saving
  'sparsity': [{"sparsity": 0.5}, {"sparsity": 0.5}, {"sparsity": 0.5}, {"sparsity": 0.5}, {"sparsity": 0.5}], # 'how sparse is each layer'
  'copy_layers': [], # ['conv1', 'conv2', 'fc2'],
  'bias': False
}

In [39]:
model, device, train_loader, test_loader, criterion = main(args, train_args, trial=None)

Using device cuda
Files already downloaded and verified

Train set: Average loss: 0.0243, Accuracy: 22781/50000 (46%)


Test set: Average loss: 0.0015, Accuracy: 4604/10000 (46%)


Train set: Average loss: 0.0227, Accuracy: 24856/50000 (50%)


Test set: Average loss: 0.0014, Accuracy: 5008/10000 (50%)


Train set: Average loss: 0.0205, Accuracy: 27050/50000 (54%)


Test set: Average loss: 0.0013, Accuracy: 5335/10000 (53%)


Train set: Average loss: 0.0157, Accuracy: 32649/50000 (65%)


Test set: Average loss: 0.0010, Accuracy: 6442/10000 (64%)


Train set: Average loss: 0.0124, Accuracy: 36677/50000 (73%)


Test set: Average loss: 0.0009, Accuracy: 7127/10000 (71%)



In [40]:
import pickle

sparsity = 0.5
conv1_masks = get_prune_mask(model.conv1, sparsity)
conv1_weights = model.conv1.weight
conv2_masks = get_prune_mask(model.conv2, sparsity)
conv2_weights = model.conv2.weight
fc1_masks = get_prune_mask(model.fc1, sparsity)
fc1_weights = model.fc1.weight
fc2_masks = get_prune_mask(model.fc2, sparsity)
fc2_weights = model.fc2.weight

conv1_masks = conv1_masks.squeeze()
conv1_weights = conv1_weights.squeeze()
conv2_masks = conv2_masks.squeeze()
conv2_weights = conv2_weights.squeeze()

data = {}
data["conv1"] = torch.stack((conv1_masks, conv1_weights))
data["conv2"] = torch.stack((conv2_masks, conv2_weights))
data["fc1"] = torch.stack((fc1_masks, fc1_weights))
data["fc2"] = torch.stack((fc2_masks, fc2_weights))

def write_pickle(path, d):
  with open(path,'wb+') as f:
      return pickle.dump(d, f, protocol = pickle.HIGHEST_PROTOCOL)

write_pickle('./drive/MyDrive/hidden-networks/dataset/conv1_s50_kaiming.pkl', data['conv1'])
write_pickle('./drive/MyDrive/hidden-networks/dataset/conv2_s50_kaiming.pkl', data['conv2'])
write_pickle('./drive/MyDrive/hidden-networks/dataset/fc1_s50_kaiming.pkl', data['fc1'])
write_pickle('./drive/MyDrive/hidden-networks/dataset/fc2_s50_kaiming.pkl', data['fc2'])

fc1_data = featurize_fc(fc1_weights, fc1_masks, sparsity, "1")

KeyboardInterrupt: ignored

In [15]:
import pandas as pd

fc_df = pd.DataFrame(fc1_data, columns = ['i', 'j', 'mag_0', 'mag_1', 'mag_2', 'mag_3', 'mag_4', 'sign_0', 'sign_1', 'sign_2', 'sign_3', 'sign_4', 'sparsity', 'layer', 'include'])

d1 = conv2_predictors(fc1_weights.T, conv2_masks, 12)
d2 = fc2_predictors(fc1_weights.T, fc2_masks.T)
pruned_fc_data = make_pruned_df(d1, d2)

final_data = pd.merge(fc_df, pruned_fc_data, how='inner', left_on=['i','j'], right_on = ['i','j'])
# with open("./drive/MyDrive/hidden-networks/dataset/fc1_pruned_data.csv", "a+", newline="") as f:
#   writer = csv.writer(f)
#   writer.writerows(final_data)

In [16]:
final_data.to_csv('./drive/MyDrive/hidden-networks/dataset/fc1_pruned_data.csv')

In [17]:
final_data.head()

Unnamed: 0,i,j,mag_0,mag_1,mag_2,mag_3,mag_4,sign_0,sign_1,sign_2,sign_3,sign_4,sparsity,layer,include,conv2_pruned_count,fc2_pruned_count
0,0,0,0.008929,0.0,0.008929,0.0,0.008929,1,0,1,0,-1,0.5,fc1,1.0,280,126
1,0,1,0.008929,0.0,0.008929,0.008929,0.008929,-1,0,-1,1,1,0.5,fc1,0.0,280,118
2,0,2,0.008929,0.0,0.008929,0.008929,0.008929,1,0,-1,-1,-1,0.5,fc1,1.0,280,126
3,0,3,0.008929,0.0,0.008929,0.008929,0.008929,-1,0,1,1,1,0.5,fc1,1.0,280,136
4,0,4,0.008929,0.0,0.008929,0.008929,0.008929,1,0,1,-1,-1,0.5,fc1,1.0,280,116


In [19]:
!pip install positional_encodings
from positional_encodings import PositionalEncoding1D, PositionalEncoding2D, PositionalEncoding3D

p_enc_2d = PositionalEncoding2D(4)
y = torch.zeros((1,12,12,4))
mapper = p_enc_2d(y).squeeze()

final_data['pos1'] = final_data.apply(lambda row: mapper[int((row.i % 144) / 12)][(row.i % 144) % 12][0].item(), axis=1)
final_data['pos2'] = final_data.apply(lambda row: mapper[int((row.i % 144) / 12)][(row.i % 144) % 12][1].item(), axis=1)
final_data['pos3'] = final_data.apply(lambda row: mapper[int((row.i % 144) / 12)][(row.i % 144) % 12][2].item(), axis=1)
final_data['pos4'] = final_data.apply(lambda row: mapper[int((row.i % 144) / 12)][(row.i % 144) % 12][3].item(), axis=1)

Collecting positional_encodings
  Downloading positional_encodings-5.0.0-py3-none-any.whl (7.3 kB)
Collecting tf-estimator-nightly==2.8.0.dev2021122109
  Downloading tf_estimator_nightly-2.8.0.dev2021122109-py2.py3-none-any.whl (462 kB)
[K     |████████████████████████████████| 462 kB 5.3 MB/s 
Installing collected packages: tf-estimator-nightly, positional-encodings
Successfully installed positional-encodings-5.0.0 tf-estimator-nightly-2.8.0.dev2021122109


In [22]:
def conv2_norm(fc1_weights, conv2_weights, conv2_output_dim):
  data = torch.zeros([len(fc1_weights), len(fc1_weights[0])], dtype=torch.float32)
  flat_length = conv2_output_dim ** 2
  for i in range(len(conv2_weights)):
    #can replace pruned_count with some other function
    norm = torch.linalg.norm(conv2_weights[i])
    data[i*flat_length:(i+1)*flat_length] = norm

  return data

#dimension of input by output
def fc2_norm(fc1_weights, fc2_weights):
  data = torch.zeros([len(fc1_weights), len(fc1_weights[0])], dtype=torch.float32)
  for j in range(len(fc1_weights[0])):
    norm = torch.norm(fc2_weights[j])
    data[:, j] = norm

  return data

d1 = conv2_norm(fc1_weights.T, conv2_weights, 12)
d2 = fc2_norm(fc1_weights.T, fc2_weights.T)
norm_fc_data = make_pruned_df(d1, d2)
final_data = pd.merge(final_data, norm_fc_data, how='inner', left_on=['i','j'], right_on = ['i','j'])

In [25]:
final_data["fc2_pruned_count"] = final_data['fc2_pruned_count_x']
final_data["conv2_pruned_count"] = final_data['conv2_pruned_count_x']
final_data["fc2_norm"] = final_data['fc2_pruned_count_y']
final_data["conv2_norm"] = final_data['conv2_pruned_count_y']

del final_data['fc2_pruned_count_x'], final_data['conv2_pruned_count_x'], final_data['fc2_pruned_count_y'], final_data['conv2_pruned_count_y']

In [26]:
final_data.to_csv('./drive/MyDrive/hidden-networks/dataset/fc1_pruned_data.csv')

In [31]:
final_data.columns

Index(['i', 'j', 'mag_0', 'mag_1', 'mag_2', 'mag_3', 'mag_4', 'sign_0',
       'sign_1', 'sign_2', 'sign_3', 'sign_4', 'sparsity', 'layer', 'include',
       'pos1', 'pos2', 'pos3', 'pos4', 'fc2_pruned_count',
       'conv2_pruned_count', 'fc2_norm', 'conv2_norm'],
      dtype='object')

In [33]:
import statsmodels.formula.api as smf

log_reg = smf.logit("include ~ sign_0 + sign_1 + sign_2 + sign_3 + sign_4 + conv2_pruned_count + fc2_pruned_count + pos1 + pos2 + pos3 + pos4 + conv2_norm + fc2_norm", data=final_data).fit()

Optimization terminated successfully.
         Current function value: 0.693055
         Iterations 5


In [34]:
log_reg.summary()

0,1,2,3
Dep. Variable:,include,No. Observations:,3211264.0
Model:,Logit,Df Residuals:,3211251.0
Method:,MLE,Df Model:,12.0
Date:,"Mon, 25 Apr 2022",Pseudo R-squ.:,0.0001325
Time:,20:02:44,Log-Likelihood:,-2225600.0
converged:,True,LL-Null:,-2225900.0
Covariance Type:,nonrobust,LLR p-value:,1.396e-118

0,1,2,3,4,5,6
,coef,std err,z,P>|z|,[0.025,0.975]
Intercept,0.1518,7.49e+04,2.03e-06,1.000,-1.47e+05,1.47e+05
sign_0,-0.0180,0.001,-16.125,0.000,-0.020,-0.016
sign_1,-0.0002,0.001,-0.172,0.864,-0.002,0.002
sign_2,0.0003,0.001,0.304,0.761,-0.002,0.003
sign_3,0.0006,0.001,0.523,0.601,-0.002,0.003
sign_4,0.0010,0.001,0.863,0.388,-0.001,0.003
conv2_pruned_count,-0.0001,9.18e-05,-1.215,0.224,-0.000,6.84e-05
fc2_pruned_count,-0.0025,0.000,-16.748,0.000,-0.003,-0.002
pos1,-0.0037,0.002,-2.311,0.021,-0.007,-0.001
