In [27]:
'''
    Load RESNET model and split it
        1. layer by layer
        2. [TODO] vertically 
'''

'\n    Load RESNET model and split it\n        1. layer by layer\n        2. [TODO] vertically \n'

In [28]:
from source.core.engine import MoP
import source.core.run_partition as run_p
from os import environ
from source.utils.dataset import *
from source.utils.misc import *

import matplotlib.pyplot as plt
import matplotlib.image as mpimg

from source.models import resnet

import torch.nn.functional as F

import numpy as np

from source.utils import io
import json

from torchsummary import summary

import time

In [29]:
# setup config
dataset='cifar10'
environ["config"] = f"config/{dataset}.yaml"

configs = run_p.main()

configs["device"] = "cuda:0"
configs['load_model'] = "cifar10-resnet18-kernel-npv2-pr0.75-lcm0.001.pt"
configs["num_partition"] = 'resnet18-v2.yaml'

device :  
model :  resnet18
data_code :  cifar10
num_classes :  10
model_file :  test.pt
epochs :  0
batch_size :  128
optimizer :  sgd
lr_scheduler :  default
learning_rate :  0.01
seed :  1234
sparsity_type :  kernel
prune_ratio :  1
admm :  True
admm_epochs :  3
rho :  0.0001
multi_rho :  True
retrain_bs :  128
retrain_lr :  0.005
retrain_ep :  50
retrain_opt :  default
xentropy_weight :  1.0
warmup :  False
warmup_lr :  0.001
warmup_epochs :  10
mix_up :  True
alpha :  0.3
smooth :  False
smooth_eps :  0
save_last_model_only :  False
num_partition :  1
layer_type :  regular
bn_type :  masked
par_first_layer :  False
comm_outsize :  False
lambda_comm :  0
lambda_comp :  0
distill_model :  
distill_loss :  kl
distill_temp :  30
distill_alpha :  1


In [30]:
# load data and load or train model
model = get_model_from_code(configs).to(configs['device']) # grabs model architecture from ./source/models/escnet.py
print(model)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (shortcut): Sequential()
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=

In [31]:
'''
   1st section of resnet model 
'''
class ResnetBlockOne(nn.Module):
    def __init__(self, block, num_blocks, conv_layer, bn_layer, num_classes=10, num_filters=512, bn_partition=[1]*9):
        super(ResnetBlockOne, self).__init__()

        self.in_planes = 64
        self.conv_layer = conv_layer
        self.bn_layer = bn_layer
        self.shrink = num_filters/512
        self.bn_partition = bn_partition
        
        self.conv1 = conv_layer(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        num_bn = self.bn_partition.pop(0)
        self.bn1 = bn_layer(64) if num_bn==1 else bn_layer(64, num_bn)

        self.layer1 = self._make_layer(block, int(64*self.shrink),  num_blocks[0], stride=1)

    def _make_layer(self, block, planes, num_blocks, stride):
        # TODO: find better way to implement this method using inheretence and getting from ResNet class
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, self.conv_layer, self.bn_layer, stride, self.bn_partition.pop(0)))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)


    def forward(self, x):
        
        # override the the foward pass to only include the first modules 
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)

        return out

'''
   2nd section of resnet model 
'''
class ResnetBlockTwo(nn.Module):
    def __init__(self, block, num_blocks, conv_layer, bn_layer, num_classes=10, num_filters=512, bn_partition=[1]*9):
        super(ResnetBlockTwo, self).__init__()

        self.in_planes = 64
        self.conv_layer = conv_layer
        self.bn_layer = bn_layer
        self.shrink = num_filters/512
        self.bn_partition = bn_partition

        self.layer2 = self._make_layer(block, int(128*self.shrink), num_blocks[1], stride=2)
    
    def _make_layer(self, block, planes, num_blocks, stride):
        # TODO: find better way to implement this method using inheretence and getting from ResNet class
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, self.conv_layer, self.bn_layer, stride, self.bn_partition.pop(0)))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)


    def forward(self, x):
        out = self.layer2(x)
        return out

'''
    3rd section of resenet model 
'''
class ResnetBlockThree(nn.Module):
    def __init__(self, block, num_blocks, conv_layer, bn_layer, num_classes=10, num_filters=512, bn_partition=[1]*9):
        super(ResnetBlockThree, self).__init__()

        self.in_planes = 128
        self.conv_layer = conv_layer
        self.bn_layer = bn_layer
        self.shrink = num_filters/512
        self.bn_partition = bn_partition

        self.layer3 = self._make_layer(block, int(256*self.shrink), num_blocks[2], stride=2)
    
    def _make_layer(self, block, planes, num_blocks, stride):
        # TODO: find better way to implement this method using inheretence and getting from ResNet class
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, self.conv_layer, self.bn_layer, stride, self.bn_partition.pop(0)))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)


    def forward(self, x):
        out = self.layer3(x)
        return out

'''
    4-th section of resenet model 
'''
class ResnetBlockFour(nn.Module):
    def __init__(self, block, num_blocks, conv_layer, bn_layer, num_classes=10, num_filters=512, bn_partition=[1]*9):
        super(ResnetBlockFour, self).__init__()

        self.in_planes = 256
        self.conv_layer = conv_layer
        self.bn_layer = bn_layer
        self.shrink = num_filters/512
        self.bn_partition = bn_partition

        self.layer4 = self._make_layer(block, num_filters, num_blocks[3], stride=2)
        self.linear = nn.Linear(num_filters*block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        # TODO: find better way to implement this method using inheretence and getting from ResNet class
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, self.conv_layer, self.bn_layer, stride, self.bn_partition.pop(0)))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)


    def forward(self, x):
        out = self.layer4(x)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out

In [32]:
# get framework for each layer of resnet

# inputs can be found from looking at ./source/util/misc/get_model_from_code
num_classes=configs['num_classes']
bn_layers = get_bn_layers('regular') # basic block layer
conv_layers = get_layers(configs['layer_type'])

# resnet18 inputs from resnet.py
layer_1 =  ResnetBlockOne(resnet.BasicBlock, [2,2,2,2],conv_layers, bn_layers, num_classes=num_classes) # also includes bn1 and conv1 
layer_2 =  ResnetBlockTwo(resnet.BasicBlock, [2,2,2,2],conv_layers, bn_layers, num_classes=num_classes) # also includes bn1 and conv1 
layer_3 =  ResnetBlockThree(resnet.BasicBlock, [2,2,2,2],conv_layers, bn_layers, num_classes=num_classes) # also includes bn1 and conv1 
layer_4 =  ResnetBlockFour(resnet.BasicBlock, [2,2,2,2],conv_layers, bn_layers, num_classes=num_classes) # also includes bn1 and conv1 

print(layer_4)
#block2 = |

ResnetBlockFour(
  (layer4): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (shortcut): Sequential(
        (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (bn2): BatchNorm2d(512, eps=1e-05,

In [33]:
split_model = [layer_1, layer_2, layer_3, layer_4]

# load model params into dictionary
state_dict = torch.load(io.get_model_path("{}".format(configs["load_model"])), map_location=configs['device'])

# add params to split
for l in split_model:
    l = io.load_state_dict(l, 
                    state_dict['model_state_dict'] if 'model_state_dict' in state_dict 
                    else state_dict['state_dict'] if 'state_dict' in state_dict else state_dict,)


not found:  layer2.0.conv1.weight
not found:  layer2.0.conv2.weight
not found:  layer2.0.bn1.weight
not found:  layer2.0.bn1.bias
not found:  layer2.0.bn1.running_mean
not found:  layer2.0.bn1.running_var
not found:  layer2.0.bn1.num_batches_tracked
not found:  layer2.0.bn2.weight
not found:  layer2.0.bn2.bias
not found:  layer2.0.bn2.running_mean
not found:  layer2.0.bn2.running_var
not found:  layer2.0.bn2.num_batches_tracked
not found:  layer2.0.shortcut.0.weight
not found:  layer2.0.shortcut.1.weight
not found:  layer2.0.shortcut.1.bias
not found:  layer2.0.shortcut.1.running_mean
not found:  layer2.0.shortcut.1.running_var
not found:  layer2.0.shortcut.1.num_batches_tracked
not found:  layer2.1.conv1.weight
not found:  layer2.1.conv2.weight
not found:  layer2.1.bn1.weight
not found:  layer2.1.bn1.bias
not found:  layer2.1.bn1.running_mean
not found:  layer2.1.bn1.running_var
not found:  layer2.1.bn1.num_batches_tracked
not found:  layer2.1.bn2.weight
not found:  layer2.1.bn2.bias


In [34]:
# look at state dict keys
print(state_dict.keys())
for i in split_model:
    print(len(l.state_dict().keys()))



odict_keys(['conv1.weight', 'bn1.weight', 'bn1.bias', 'bn1.running_mean', 'bn1.running_var', 'bn1.num_batches_tracked', 'layer1.0.conv1.weight', 'layer1.0.conv2.weight', 'layer1.0.bn1.weight', 'layer1.0.bn1.bias', 'layer1.0.bn1.running_mean', 'layer1.0.bn1.running_var', 'layer1.0.bn1.num_batches_tracked', 'layer1.0.bn2.weight', 'layer1.0.bn2.bias', 'layer1.0.bn2.running_mean', 'layer1.0.bn2.running_var', 'layer1.0.bn2.num_batches_tracked', 'layer1.1.conv1.weight', 'layer1.1.conv2.weight', 'layer1.1.bn1.weight', 'layer1.1.bn1.bias', 'layer1.1.bn1.running_mean', 'layer1.1.bn1.running_var', 'layer1.1.bn1.num_batches_tracked', 'layer1.1.bn2.weight', 'layer1.1.bn2.bias', 'layer1.1.bn2.running_mean', 'layer1.1.bn2.running_var', 'layer1.1.bn2.num_batches_tracked', 'layer2.0.conv1.weight', 'layer2.0.conv2.weight', 'layer2.0.bn1.weight', 'layer2.0.bn1.bias', 'layer2.0.bn1.running_mean', 'layer2.0.bn1.running_var', 'layer2.0.bn1.num_batches_tracked', 'layer2.0.bn2.weight', 'layer2.0.bn2.bias', '

In [35]:
# load weights into full model
model = io.load_state_dict(model, 
                    state_dict['model_state_dict'] if 'model_state_dict' in state_dict 
                    else state_dict['state_dict'] if 'state_dict' in state_dict else state_dict,)


In [36]:
# compare outputs

input = torch.rand(1000, 3, 32, 32, device=torch.device(configs['device'])) # 1k images, 3 channels, 32x32 image (cifar100) 

# put models into eval mode and on device
model.eval()
model.to(configs['device'])
for l in split_model:
    l.eval()
    l.to(configs['device'])

# make inference 
with torch.no_grad():
        output_full = model(input)

        output_split = input
        for l in split_model:
                output_split = l(output_split)


match_count = (torch.argmax(output_split, axis=1) == torch.argmax(output_full, axis=1)).sum().item()
label_hist = output_full.sum(0)
print(f'Matches: {match_count}/{output_full.size(0)}')
print(f'histogram {label_hist}')

Matches: 1000/1000
histogram tensor([  -34.6444,  -919.2313,   322.4745,  -964.0338,   461.6558, -1301.9825,
         3073.6328,  -932.3021,  -233.5381, -1052.3689], device='cuda:0')


In [37]:
# go layer by layer to identify mismatch
split_model[0].conv1(input)

tensor([[[[-1.3076e-04, -3.7590e-04, -7.2394e-04,  ..., -4.0057e-04,
           -5.7252e-04, -5.8638e-04],
          [-3.1505e-04, -6.5126e-04, -1.0406e-03,  ..., -6.3951e-04,
           -6.6972e-04, -8.7693e-04],
          [-4.7651e-04, -8.7004e-04, -7.1680e-04,  ..., -8.7867e-04,
           -8.2595e-04, -5.9736e-04],
          ...,
          [-4.7025e-04, -8.1106e-04, -7.3174e-04,  ..., -7.1544e-04,
           -6.5507e-04, -6.4034e-04],
          [-4.4429e-04, -4.0446e-04, -1.0031e-03,  ..., -6.2399e-04,
           -7.6651e-04, -6.1660e-04],
          [-3.6429e-04, -5.9499e-04, -6.1644e-04,  ..., -6.4783e-04,
           -5.2922e-04, -4.7095e-04]],

         [[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
            0.0000e+00,  0.0000e+00],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
            0.0000e+00,  0.0000e+00],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
            0.0000e+00,  0.0000e+00],
          ...,
     

In [38]:
'''
    Test I/O Logic
'''

'\n    Test I/O Logic\n'

In [39]:
# SAVE

# make dir name 
time_stamp = time.strftime("%Y%m%d-%H%M%S")
if len(configs['load_model']) == 0:
    folder_name='{}-{}-{}-np{}-pr{}-lcm{}-{}'.format( 
                configs['data-code'], 
                configs['model'], 
                configs['sparsity-type'], 
                configs['num_partition'], 
                configs['prune-ratio'], 
                configs['lambda-comm'],
                time_stamp)
else:
    folder_name = '{}-{}'.format(configs['load_model'][:-3],time_stamp)

# make folder 
folder_path = os.path.join(os.getcwd(), 'assets', 'models',folder_name)
if not os.path.exists(folder_path):
    os.makedirs(folder_path)

# save weights
index = 0
for l in split_model:
    fpath = os.path.join(os.getcwd(), 'assets', 'models', folder_path, f'layer_model_{index}.pth')
    torch.save(l.state_dict(), fpath)
    index += 1

In [50]:
# LOAD

select = 0

layer_output_sizes = [(3,32,32), (64,32,32), (128,16,16), (256,8,8)]

model_path = os.path.join(os.getcwd(), 'assets', 'models')
filenames = os.listdir(model_path)

# get dirs
split_model_names = []
for filename in filenames: # loop through all the files and folders
    if os.path.isdir(os.path.join(model_path, filename)): # check whether the current object is a folder or not
        split_model_names.append(filename)

print('Split models:')
print(split_model_names)
print()

model_name = split_model_names[select] 
print(f'loading split model {model_name}')

index = 0
for l in split_model:
    layer_state_dict = torch.load(os.path.join(model_path, model_name, f'layer_model_{index}.pth'))
    l = io.load_state_dict(l, layer_state_dict)


    summary(l, layer_output_sizes[index])
    index += 1

Split models:
['cifar10-resnet18-kernel-npv2-pr0.75-lcm0.001-20240615-135016', 'cifar10-resnet18-kernel-npv2-pr0.75-lcm0.001-20240615-141807']

loading split model cifar10-resnet18-kernel-npv2-pr0.75-lcm0.001-20240615-135016
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 64, 32, 32]           1,728
       BatchNorm2d-2           [-1, 64, 32, 32]             128
            Conv2d-3           [-1, 64, 32, 32]          36,864
       BatchNorm2d-4           [-1, 64, 32, 32]             128
            Conv2d-5           [-1, 64, 32, 32]          36,864
       BatchNorm2d-6           [-1, 64, 32, 32]             128
        BasicBlock-7           [-1, 64, 32, 32]               0
            Conv2d-8           [-1, 64, 32, 32]          36,864
       BatchNorm2d-9           [-1, 64, 32, 32]             128
           Conv2d-10           [-1, 64, 32, 32]          36,864
      