In [1]:
'''
    Load RESNET model and split it
        1. layer by layer
        2. [TODO] vertically 
'''

'\n    Load RESNET model and split it\n        1. layer by layer\n        2. [TODO] vertically \n'

In [2]:
from source.core.engine import MoP
import source.core.run_partition as run_p
from os import environ
from source.utils.dataset import *
from source.utils.misc import *

import matplotlib.pyplot as plt
import matplotlib.image as mpimg

from source.models import resnet

import torch.nn.functional as F

import numpy as np

from source.utils import io
from source.utils import testers
from source.core import engine
import json
import itertools

from torchsummary import summary

import time

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# setup config
dataset='cifar10'
environ["config"] = f"config/{dataset}.yaml"

configs = run_p.main()

configs["device"] = "cpu"
configs['load_model'] = "cifar10-resnet18-kernel-npv2-pr0.75-lcm0.001.pt"
configs["num_partition"] = '4' #'resnet18-v2.yaml'

device :  
model :  resnet18
data_code :  cifar10
num_classes :  10
model_file :  test.pt
epochs :  0
batch_size :  128
optimizer :  sgd
lr_scheduler :  default
learning_rate :  0.01
seed :  1234
sparsity_type :  kernel
prune_ratio :  1
admm :  True
admm_epochs :  3
rho :  0.0001
multi_rho :  True
retrain_bs :  128
retrain_lr :  0.005
retrain_ep :  50
retrain_opt :  default
xentropy_weight :  1.0
warmup :  False
warmup_lr :  0.001
warmup_epochs :  10
mix_up :  True
alpha :  0.3
smooth :  False
smooth_eps :  0
save_last_model_only :  False
num_partition :  1
layer_type :  regular
bn_type :  masked
par_first_layer :  False
comm_outsize :  False
lambda_comm :  0
lambda_comp :  0
distill_model :  
distill_loss :  kl
distill_temp :  30
distill_alpha :  1


In [4]:
# load data and load or train model
model = get_model_from_code(configs).to(configs['device']) # grabs model architecture from ./source/models/escnet.py
print(model)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (shortcut): Sequential()
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=

In [5]:
'''
   1st section of resnet model 
'''
class ResnetBlockOne(nn.Module):
    def __init__(self, block, num_blocks, conv_layer, bn_layer, num_classes=10, num_filters=512, bn_partition=[1]*9):
        super(ResnetBlockOne, self).__init__()

        self.in_planes = 64
        self.conv_layer = conv_layer
        self.bn_layer = bn_layer
        self.shrink = num_filters/512
        self.bn_partition = bn_partition
        
        self.conv1 = conv_layer(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        num_bn = self.bn_partition.pop(0)
        self.bn1 = bn_layer(64) if num_bn==1 else bn_layer(64, num_bn)

        self.layer1 = self._make_layer(block, int(64*self.shrink),  num_blocks[0], stride=1)

    def _make_layer(self, block, planes, num_blocks, stride):
        # TODO: find better way to implement this method using inheretence and getting from ResNet class
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, self.conv_layer, self.bn_layer, stride, self.bn_partition.pop(0)))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)


    def forward(self, x):
        
        # override the the foward pass to only include the first modules 
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)

        return out

'''
   2nd section of resnet model 
'''
class ResnetBlockTwo(nn.Module):
    def __init__(self, block, num_blocks, conv_layer, bn_layer, num_classes=10, num_filters=512, bn_partition=[1]*9):
        super(ResnetBlockTwo, self).__init__()

        self.in_planes = 64
        self.conv_layer = conv_layer
        self.bn_layer = bn_layer
        self.shrink = num_filters/512
        self.bn_partition = bn_partition

        self.layer2 = self._make_layer(block, int(128*self.shrink), num_blocks[1], stride=2)
    
    def _make_layer(self, block, planes, num_blocks, stride):
        # TODO: find better way to implement this method using inheretence and getting from ResNet class
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, self.conv_layer, self.bn_layer, stride, self.bn_partition.pop(0)))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)


    def forward(self, x):
        out = self.layer2(x)
        return out

'''
    3rd section of resenet model 
'''
class ResnetBlockThree(nn.Module):
    def __init__(self, block, num_blocks, conv_layer, bn_layer, num_classes=10, num_filters=512, bn_partition=[1]*9):
        super(ResnetBlockThree, self).__init__()

        self.in_planes = 128
        self.conv_layer = conv_layer
        self.bn_layer = bn_layer
        self.shrink = num_filters/512
        self.bn_partition = bn_partition

        self.layer3 = self._make_layer(block, int(256*self.shrink), num_blocks[2], stride=2)
    
    def _make_layer(self, block, planes, num_blocks, stride):
        # TODO: find better way to implement this method using inheretence and getting from ResNet class
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, self.conv_layer, self.bn_layer, stride, self.bn_partition.pop(0)))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)


    def forward(self, x):
        out = self.layer3(x)
        return out

'''
    4-th section of resenet model 
'''
class ResnetBlockFour(nn.Module):
    def __init__(self, block, num_blocks, conv_layer, bn_layer, num_classes=10, num_filters=512, bn_partition=[1]*9):
        super(ResnetBlockFour, self).__init__()

        self.in_planes = 256
        self.conv_layer = conv_layer
        self.bn_layer = bn_layer
        self.shrink = num_filters/512
        self.bn_partition = bn_partition

        self.layer4 = self._make_layer(block, num_filters, num_blocks[3], stride=2)
        self.linear = nn.Linear(num_filters*block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        # TODO: find better way to implement this method using inheretence and getting from ResNet class
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, self.conv_layer, self.bn_layer, stride, self.bn_partition.pop(0)))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)


    def forward(self, x):
        out = self.layer4(x)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out

In [6]:
# get framework for each layer of resnet

# inputs can be found from looking at ./source/util/misc/get_model_from_code
num_classes=configs['num_classes']
bn_layers = get_bn_layers('regular') # basic block layer
conv_layers = get_layers(configs['layer_type'])

# resnet18 inputs from resnet.py
layer_1 =  ResnetBlockOne(resnet.BasicBlock, [2,2,2,2],conv_layers, bn_layers, num_classes=num_classes) # also includes bn1 and conv1 
layer_2 =  ResnetBlockTwo(resnet.BasicBlock, [2,2,2,2],conv_layers, bn_layers, num_classes=num_classes) # also includes bn1 and conv1 
layer_3 =  ResnetBlockThree(resnet.BasicBlock, [2,2,2,2],conv_layers, bn_layers, num_classes=num_classes) # also includes bn1 and conv1 
layer_4 =  ResnetBlockFour(resnet.BasicBlock, [2,2,2,2],conv_layers, bn_layers, num_classes=num_classes) # also includes bn1 and conv1 

print(layer_4)
#block2 = |

ResnetBlockFour(
  (layer4): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (shortcut): Sequential(
        (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (bn2): BatchNorm2d(512, eps=1e-05,

In [7]:
split_model = [layer_1, layer_2, layer_3, layer_4]

# load model params into dictionary
state_dict = torch.load(io.get_model_path("{}".format(configs["load_model"])), map_location=configs['device'])

# add params to split
for l in split_model:
    l = io.load_state_dict(l, 
                    state_dict['model_state_dict'] if 'model_state_dict' in state_dict 
                    else state_dict['state_dict'] if 'state_dict' in state_dict else state_dict,)


not found:  layer2.0.conv1.weight
not found:  layer2.0.conv2.weight
not found:  layer2.0.bn1.weight
not found:  layer2.0.bn1.bias
not found:  layer2.0.bn1.running_mean
not found:  layer2.0.bn1.running_var
not found:  layer2.0.bn1.num_batches_tracked
not found:  layer2.0.bn2.weight
not found:  layer2.0.bn2.bias
not found:  layer2.0.bn2.running_mean
not found:  layer2.0.bn2.running_var
not found:  layer2.0.bn2.num_batches_tracked
not found:  layer2.0.shortcut.0.weight
not found:  layer2.0.shortcut.1.weight
not found:  layer2.0.shortcut.1.bias
not found:  layer2.0.shortcut.1.running_mean
not found:  layer2.0.shortcut.1.running_var
not found:  layer2.0.shortcut.1.num_batches_tracked
not found:  layer2.1.conv1.weight
not found:  layer2.1.conv2.weight
not found:  layer2.1.bn1.weight
not found:  layer2.1.bn1.bias
not found:  layer2.1.bn1.running_mean
not found:  layer2.1.bn1.running_var
not found:  layer2.1.bn1.num_batches_tracked
not found:  layer2.1.bn2.weight
not found:  layer2.1.bn2.bias


In [8]:
# look at state dict keys
print(state_dict.keys())
for i in split_model:
    print(len(l.state_dict().keys()))



odict_keys(['conv1.weight', 'bn1.weight', 'bn1.bias', 'bn1.running_mean', 'bn1.running_var', 'bn1.num_batches_tracked', 'layer1.0.conv1.weight', 'layer1.0.conv2.weight', 'layer1.0.bn1.weight', 'layer1.0.bn1.bias', 'layer1.0.bn1.running_mean', 'layer1.0.bn1.running_var', 'layer1.0.bn1.num_batches_tracked', 'layer1.0.bn2.weight', 'layer1.0.bn2.bias', 'layer1.0.bn2.running_mean', 'layer1.0.bn2.running_var', 'layer1.0.bn2.num_batches_tracked', 'layer1.1.conv1.weight', 'layer1.1.conv2.weight', 'layer1.1.bn1.weight', 'layer1.1.bn1.bias', 'layer1.1.bn1.running_mean', 'layer1.1.bn1.running_var', 'layer1.1.bn1.num_batches_tracked', 'layer1.1.bn2.weight', 'layer1.1.bn2.bias', 'layer1.1.bn2.running_mean', 'layer1.1.bn2.running_var', 'layer1.1.bn2.num_batches_tracked', 'layer2.0.conv1.weight', 'layer2.0.conv2.weight', 'layer2.0.bn1.weight', 'layer2.0.bn1.bias', 'layer2.0.bn1.running_mean', 'layer2.0.bn1.running_var', 'layer2.0.bn1.num_batches_tracked', 'layer2.0.bn2.weight', 'layer2.0.bn2.bias', '

In [9]:
# load weights into full model
model = io.load_state_dict(model, 
                    state_dict['model_state_dict'] if 'model_state_dict' in state_dict 
                    else state_dict['state_dict'] if 'state_dict' in state_dict else state_dict,)


In [10]:
# compare outputs

input = torch.rand(1000, 3, 32, 32, device=torch.device(configs['device'])) # 1k images, 3 channels, 32x32 image (cifar100) 

# put models into eval mode and on device
model.eval()
model.to(configs['device'])
for l in split_model:
    l.eval()
    l.to(configs['device'])

# make inference 
with torch.no_grad():
        output_full = model(input)

        output_split = input
        for l in split_model:
                output_split = l(output_split)


match_count = (torch.argmax(output_split, axis=1) == torch.argmax(output_full, axis=1)).sum().item()
label_hist = output_full.sum(0)
print(f'Matches: {match_count}/{output_full.size(0)}')
print(f'histogram {label_hist}')

Matches: 1000/1000
histogram tensor([  -24.0950,  -911.8698,   302.2574,  -958.8383,   490.0737, -1314.1349,
         3054.4353,  -942.0951,  -227.5249, -1040.1658])


In [11]:
# go layer by layer to identify mismatch
split_model[0].conv1(input)

tensor([[[[-8.5278e-05, -3.6204e-04, -5.0258e-04,  ..., -6.3595e-04,
           -6.6297e-04, -6.9204e-04],
          [-2.2234e-04, -6.0189e-04, -8.2329e-04,  ..., -8.1841e-04,
           -7.1547e-04, -7.2304e-04],
          [-2.9183e-04, -9.4233e-04, -8.2087e-04,  ..., -5.8339e-04,
           -4.3400e-04, -7.1098e-04],
          ...,
          [-3.4770e-04, -5.7245e-04, -6.4811e-04,  ..., -3.8958e-04,
           -5.6954e-04, -5.9069e-04],
          [-1.7294e-04, -3.9085e-04, -7.1114e-04,  ..., -3.8511e-04,
           -4.2150e-04, -7.0768e-04],
          [-4.4009e-04, -5.8731e-04, -5.9930e-04,  ..., -3.8048e-04,
           -4.7991e-04, -5.4229e-04]],

         [[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
            0.0000e+00,  0.0000e+00],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
            0.0000e+00,  0.0000e+00],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
            0.0000e+00,  0.0000e+00],
          ...,
     

In [12]:
'''
    Test I/O Logic
'''

'\n    Test I/O Logic\n'

In [13]:
# SAVE

# make dir name 
time_stamp = time.strftime("%Y%m%d-%H%M%S")
if len(configs['load_model']) == 0:
    folder_name='{}-{}-{}-np{}-pr{}-lcm{}-{}'.format( 
                configs['data-code'], 
                configs['model'], 
                configs['sparsity-type'], 
                configs['num_partition'], 
                configs['prune-ratio'], 
                configs['lambda-comm'],
                time_stamp)
else:
    folder_name = '{}-{}'.format(configs['load_model'][:-3],time_stamp)

# make folder 
folder_path = os.path.join(os.getcwd(), 'assets', 'models',folder_name)
if not os.path.exists(folder_path):
    os.makedirs(folder_path)

# save weights
index = 0
for l in split_model:
    fpath = os.path.join(os.getcwd(), 'assets', 'models', folder_path, f'layer_model_{index}.pth')
    torch.save(l.state_dict(), fpath)
    index += 1

In [14]:
# LOAD

select = 0

layer_output_sizes = [(3,32,32), (64,32,32), (128,16,16), (256,8,8)]

model_path = os.path.join(os.getcwd(), 'assets', 'models')
filenames = os.listdir(model_path)

# get dirs
split_model_names = []
for filename in filenames: # loop through all the files and folders
    if os.path.isdir(os.path.join(model_path, filename)): # check whether the current object is a folder or not
        split_model_names.append(filename)

print('Split models:')
print(split_model_names)
print()

model_name = split_model_names[select] 
print(f'loading split model {model_name}')

index = 0
for l in split_model:
    layer_state_dict = torch.load(os.path.join(model_path, model_name, f'layer_model_{index}.pth'))
    l = io.load_state_dict(l, layer_state_dict)


    summary(l, layer_output_sizes[index],device=configs['device'])
    index += 1

Split models:
['cifar10-resnet18-kernel-npv2-pr0.75-lcm0.001-20240615-135016', 'cifar10-resnet18-kernel-npv2-pr0.75-lcm0.001-20240615-141807', 'cifar10-resnet18-kernel-npv2-pr0.75-lcm0.001-20240622-092931', 'cifar10-resnet18-kernel-npv2-pr0.75-lcm0.001-20240622-125523', 'cifar10-resnet18-kernel-npv2-pr0.75-lcm0.001-20240623-182335', 'cifar10-resnet18-kernel-npv2-pr0.75-lcm0.001-20240623-182910', 'cifar10-resnet18-kernel-npv2-pr0.75-lcm0.001-20240624-204135', 'cifar10-resnet18-kernel-npv2-pr0.75-lcm0.001-20240625-073155', 'cifar10-resnet18-kernel-npv2-pr0.75-lcm0.001-20240626-090056', 'cifar10-resnet18-kernel-npv2-pr0.75-lcm0.001-20240626-202301', 'cifar10-resnet18-kernel-npv2-pr0.75-lcm0.001-20240626-213627', 'cifar10-resnet18-kernel-npv2-pr0.75-lcm0.001-20240626-213724', 'cifar10-resnet18-kernel-npv2-pr0.75-lcm0.001-20240626-214353', 'cifar10-resnet18-kernel-npv2-pr0.75-lcm0.001-20240627-105309', 'cifar10-resnet18-kernel-npv2-pr0.75-lcm0.001-20240629-090932', 'cifar10-resnet18-kernel-

In [15]:
'''
    add partitions and communications to configs
'''

# gets random test input (with correct size)
input_var = engine.get_input_from_code(configs)
#print(input_var)

# Config partitions and prune_ratio
configs['num_partition'] = '4'#'./config/resnet18-v2.yaml'
configs = engine.partition_generator(configs, model)
            
# Compute output size of each layer
configs['partition'] = engine.featuremap_summary(model, configs['partition'], input_var)
        
# Setup communication costs
configs['comm_costs'] = engine.set_communication_cost(model, configs['partition'],)


Inference time per data is 33.909321ms.
conv1.weight 1024
layer1.0.conv1.weight 1024
layer1.0.conv2.weight 1024
layer1.1.conv1.weight 1024
layer1.1.conv2.weight 1024
layer2.0.conv1.weight 256
layer2.0.conv2.weight 256
layer2.0.shortcut.0.weight 256
layer2.1.conv1.weight 256
layer2.1.conv2.weight 256
layer3.0.conv1.weight 64
layer3.0.conv2.weight 64
layer3.0.shortcut.0.weight 64
layer3.1.conv1.weight 64
layer3.1.conv2.weight 64
layer4.0.conv1.weight 16
layer4.0.conv2.weight 16
layer4.0.shortcut.0.weight 16
layer4.1.conv1.weight 16
layer4.1.conv2.weight 16


In [16]:
'''
    inspect IO per layer 
'''

savepath=io.get_fig_path("{}".format('.'.join(configs["load_model"].split('.')[:-1])))
print(savepath)

# this function looks for model layers named in "confgis['partition']" (other layers are ignored)
# -> "conv" and shortcut layers are the only ones "split"
# -> total 20/49 (20/62?) layers are split
counter = 0
for name, W in model.named_parameters():
        if name in configs['partition']:
            #print(f'{counter} | {name}')
            counter +=1

# Plot model
layer_id = (0,1,2) # inspect these layers
testers.plot_layer(model, configs['partition'], layer_id=layer_id, savepath=savepath)

c:\Users\natet\Desktop\graduate school\thesis\CaP\assets\figs\cifar10-resnet18-kernel-npv2-pr0.75-lcm0.001
name: conv1.weight, img size: (3, 64) weight size: (64, 3)
name: layer1.0.conv1.weight, img size: (64, 64) weight size: (64, 64)


In [17]:
# split model 

# make copies of model per machine
num_machines = max(configs['partition']['bn_partition']) # TODO: double check this makes sense
model_machines = [model]*num_machines

module_names =  [module[0] for i, module in enumerate(model.named_modules())]
num_total_modules = len(module_names)

split_module_names = list(configs['partition'].keys())

print(module_names)

['', 'conv1', 'bn1', 'layer1', 'layer1.0', 'layer1.0.conv1', 'layer1.0.conv2', 'layer1.0.bn1', 'layer1.0.bn2', 'layer1.0.shortcut', 'layer1.1', 'layer1.1.conv1', 'layer1.1.conv2', 'layer1.1.bn1', 'layer1.1.bn2', 'layer1.1.shortcut', 'layer2', 'layer2.0', 'layer2.0.conv1', 'layer2.0.conv2', 'layer2.0.bn1', 'layer2.0.bn2', 'layer2.0.shortcut', 'layer2.0.shortcut.0', 'layer2.0.shortcut.1', 'layer2.1', 'layer2.1.conv1', 'layer2.1.conv2', 'layer2.1.bn1', 'layer2.1.bn2', 'layer2.1.shortcut', 'layer3', 'layer3.0', 'layer3.0.conv1', 'layer3.0.conv2', 'layer3.0.bn1', 'layer3.0.bn2', 'layer3.0.shortcut', 'layer3.0.shortcut.0', 'layer3.0.shortcut.1', 'layer3.1', 'layer3.1.conv1', 'layer3.1.conv2', 'layer3.1.bn1', 'layer3.1.bn2', 'layer3.1.shortcut', 'layer4', 'layer4.0', 'layer4.0.conv1', 'layer4.0.conv2', 'layer4.0.bn1', 'layer4.0.bn2', 'layer4.0.shortcut', 'layer4.0.shortcut.0', 'layer4.0.shortcut.1', 'layer4.1', 'layer4.1.conv1', 'layer4.1.conv2', 'layer4.1.bn1', 'layer4.1.bn2', 'layer4.1.shor

In [18]:
'''
    Setup datastrcuts to ID layers executed with 'extra' functionality
'''

# module numbering is based on module.named_parameters()
relu_modules = [2, 7,8,13,14,20,2428,29,35,3943,44,50,54,58,59] # execute relu on this  layer 


In [92]:
def get_nonzero_channels(atensor, dim=1):
    return torch.unique(torch.nonzero(atensor, as_tuple=True)[dim]) 


In [126]:
'''
    mock run through inference using split models 
'''

# TODO: reduce size of communicated tensors to only what is necessary 
# TODO: also check bias for nonzero
# TODO: come up with more general scheme to handle residual layers

# channel_id == INPUTS
# filter_id  == OUTPUTS

# setup input 
N_batch = 1000
input_tensor = torch.rand(N_batch, 3, 32, 32, device=torch.device(configs['device'])) # 1k images, 3 channels, 32x32 image (cifar100) 
input_size = (N_batch, 3, 32, 32)

# broadcast input_tensor to different machines
# TODO: find a better datastructure for this
#input = np.empty((num_machines, num_machines), dtype=torch.Tensor)
input = [None]*num_machines
input = [input[:] for i in range(num_machines)]
for imach in range(num_machines):
    input[imach][0] = input_tensor

# put models into eval mode and on device
model.eval()
model.to(configs['device'])

residual_input = {} # use this to keep track of inputs stored in machine memory for residule layers
add_bias = False # add bias for previous conv layer 

# make inference 
with torch.no_grad():
        # iterate through layers 1 module at a time 
        for imodule in range(2):#range(num_total_modules):

                # initialize output for ilayer
                #output = np.empty((num_machines, num_machines), dtype=torch.Tensor) # square list indexed as: output[destination/RX machine][origin/TX machine]
                # TODO: find a better datastructure for this 
                output = [None]*num_machines
                output = [output[:] for i in range(num_machines)]
                
                send_module_outputs = True

                print(f'Executing module {imodule}: {module_names[imodule]}')

                # iterate through each machine (done in parallel later)
                for imach in range(num_machines):
                        print(f'\tExecuting on machine {imach}')
                        
                        add_residual = False

                        # get the current module
                        # TODO: this is very bad for latency. Only load module if you have to 
                        curr_name, curr_module = next((x for i,x in enumerate(model.named_modules()) if i==imodule)) 

                        # update I/O if encounter split layer
                        # TODO: revist this implementation
                        split_param_name = curr_name + '.weight'
                        if split_param_name in split_module_names:

                                # skip if machine doesnt expect input
                                if len(configs['partition'][split_param_name]['channel_id'][imach]) == 0:
                                        print(f'\t\t-No input assigned to this machine. Skipping...')
                                        continue
                                
                                # TODO: reconsider implementation 
                                # What input channels does this machine compute?
                                input_channels = torch.tensor(configs['partition'][split_param_name]['channel_id'][imach],
                                        device=torch.device(configs['device']))
                                N_in = len(input_channels) # TODO: is this used?

                                # Where to send output (map of output channels to different machines)
                                output_channel_map = configs['partition'][split_param_name]['filter_id']

                        # combine inputs from machines
                        curr_input = False 
                        rx_count = 0
                        for i in range(num_machines):
                                if not input[imach][i] == None:
                                        if not torch.is_tensor(curr_input):
                                                curr_input = input[imach][i] # initialize curr_input with first input tensor 
                                        else:
                                                curr_input += input[imach][i]
                                        rx_count += 1
                        if add_bias:
                                # TODO: check if this works (this is not required for resnet18 because no bias on conv layers)
                                dummy, prev_module = next((x for i,x in enumerate(model.named_modules()) if i==imodule-1))
                                bias = prev_module.bias 
                                curr_input += bias/rx_count

                        # skip this machine+module if there is no input to compute 
                        if not torch.is_tensor(curr_input):
                                print('\t\t-No input sent to this machine. Skipping module')
                                continue

                        # reduce computation-- make vertically split layer 
                        # TODO: generalize this to more than conv layers 
                        if type(curr_module) == nn.Conv2d:
                                split_layer = nn.Conv2d(N_in,
                                                curr_module.weight.shape[0], # TODO does this need to be an int? (currently tensor)
                                                kernel_size= curr_module.kernel_size,
                                                stride=curr_module.stride,
                                                padding=curr_module.padding, 
                                                bias=False) # TODO: add bias during input collecting step on next layer 

                                # write parameters to split layer 
                                split_layer.weight = torch.nn.Parameter(curr_module.weight.index_select(1, input_channels))

                                # TODO: add support for splitting bias

                                # handle if this is the start of the residual layer
                                if 'shortcut' in curr_name:
                                        residual_input[str(imach)]['block_out'] = curr_input
                                        curr_input = residual_input[str(imach)]['block_in']

                        elif type(curr_module) == nn.BatchNorm2d:
                                split_layer = nn.BatchNorm2d(N_in, 
                                                curr_module.eps,
                                                momentum=curr_module.momentum, 
                                                affine=curr_module.affine, 
                                                track_running_stats=curr_module.track_running_stats)

                                # write parameters to split layer 
                                split_layer.weight = torch.nn.Parameter(curr_module.weight.index_select(0, input_channels))
                                split_layer.running_mean = torch.nn.Parameter(curr_module.running_mean.index_select(0, input_channels))
                                split_layer.running_var = torch.nn.Parameter(curr_module.running_var.index_select(0, input_channels))

                                if not curr_module.bias == None:
                                        split_layer.bias = torch.nn.Parameter(curr_module.bias.index_select(0, input_channels))
                                
                                
                                if 'shortcut' in curr_name:
                                        add_residual = True

                                # TODO: revise implementation to only compute necessary C_in to C_out 
                                # assume mach-Cout map from previous conv layer can be used as inputs for this bn layer
                                #input_channels = output_channel_map[imach]

                        elif type(curr_module) == nn.Linear:
                                # TODO: assumes there is a bias 
                                split_layer = nn.Linear(N_in, 
                                                curr_module.weight.shape[0])

                                # write parameters to split layer 
                                split_layer.weight = torch.nn.Parameter(curr_module.weight.index_select(1, input_channels))

                                # TODO: double check bias is applied correctly
                                if not curr_module.bias == None:
                                        split_layer.bias = curr_module.bias

                                # prep for linear layer
                                # TODO: assumes this always happens before linear layer 
                                # bn takes one in channel C_in_i and produces one out channel C_out_j. No communication is needed. 
                                curr_input = F.avg_pool2d(curr_input, 4)
                                curr_input = curr_input.view(curr_input.size(0), -1)

                        elif type(curr_module) == resnet.BasicBlock:
                                # save input for later 
                                residual_input[str(imach)] = {}
                                residual_input[str(imach)]['block_in'] = curr_input
                                print('\t\t-Saving input for later...')
                                send_module_outputs = False
                                continue
                        else:
                                print(f'\t\t-Skipping module {type(curr_module).__name__}')
                                send_module_outputs = False
                                continue

                        # look at which C_out need to be computed and sent
                        nonzero_Cout = torch.unique(torch.nonzero(split_layer.weight, as_tuple=True)[0]) # find nonzero dimensions in output channels

                        # eval and send to next machines 
                        #print(f'\t\t {curr_input.device}\t\t{input_channels.device}\n\t\t{split_layer.weight.device}\n\t\t{split_layer.bias}')
                        if type(curr_module) == nn.BatchNorm2d:
                                # TODO: lazy implementation. Every machine computes entire C_in to C_out but only need to compute C_in from previous layer
                                out_tensor = curr_module(curr_input)
                        else:
                                out_tensor = split_layer(curr_input.index_select(1, input_channels))

                        # debug
                        nonzero_out_tensor = torch.unique(torch.nonzero(out_tensor, as_tuple=True)[1])


                        # check if this is residual layer
                        if add_residual:
                                print('\t\t-adding residual')
                                out_tensor += residual_input[str(imach)]['block_out']

                                # erase stored 
                                residual_input[str(imach)] = {}

                        # apply ReLU after batch layers
                        if imodule in relu_modules:
                                print('\t\t-Applying ReLU')
                                F.relu(out_tensor)

                        # communicate
                        out_channel_array = torch.arange(out_tensor.shape[1])
                        for rx_mach in range(num_machines):
                                # only add to output if communication is necessary 

                                # Where does this machine send outputs? TODO: consider removing, this just maps C_out's to machine
                                output_channels = torch.tensor(output_channel_map[rx_mach],
                                        device=torch.device(configs['device']))

                                # TODO: is there a faster way to do this? Consider putting larger array 1st... just not sure which one that'd be
                                nonzero_out_channels = nonzero_Cout[torch.isin(nonzero_Cout, output_channels)]
                                if nonzero_out_channels.nelement() > 0:
                                        communication_mask = torch.isin(out_channel_array, nonzero_out_channels)

                                        # TODO: this is inefficient, redo. Probbably need to send a tensor and some info what output channels are being sent
                                        tmp_out = torch.zeros(out_tensor.shape) 
                                        tmp_out[:,communication_mask,:,:] = out_tensor[:,communication_mask,:,:]
                                        output[rx_mach][imach] = tmp_out

                # send to next layer  
                if send_module_outputs:      
                        input = output
                print(f'Finished execution of layer {imodule}')
                print()

# collect outputs -- assumes ends with Linear layer. Not sure how generalizable this is
need_to_init  = True
for rx_mach in range(num_machines):
        for tx_mach in range(num_machines):
                if not output[rx_mach][tx_mach] == None:
                        if need_to_init:
                                final_output = output[rx_mach][tx_mach]
                                need_to_init = False
                        else:
                                final_output += output[rx_mach][tx_mach]

Executing module 0: 
	Executing on machine 0
		-Skipping module ResNet
	Executing on machine 1
		-Skipping module ResNet
	Executing on machine 2
		-Skipping module ResNet
	Executing on machine 3
		-Skipping module ResNet
Finished execution of layer 0

Executing module 1: conv1
	Executing on machine 0
		-No input assigned to this machine. Skipping...
	Executing on machine 1
	Executing on machine 2
	Executing on machine 3
Finished execution of layer 1



In [127]:
'''Sanity Check C_out'''

for rx_mach in range(num_machines):
    
    need_to_init = True
    for tx_mach in range(num_machines):

        if not output[rx_mach][tx_mach] == None:
            if need_to_init:
                    final_output = output[rx_mach][tx_mach]
                    need_to_init = False
            else:
                    final_output += output[rx_mach][tx_mach]

    print(f'Machine {rx_mach}')
    print(get_nonzero_channels(final_output))     



Machine 0
tensor([ 0,  4,  8, 12, 14, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28,
        29, 31, 36, 38, 42, 45, 46, 47, 54, 56, 57, 62])
Machine 1
tensor([16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 31])
Machine 2
tensor([36, 38, 42, 45, 46, 47])
Machine 3
tensor([54, 56, 57, 62])


In [60]:
'''
    Compare I/O for one layer 
'''

def compare_IO(model, imodule, input, split_output):
    ''' This function is used for debugging. It computes 
    the output for a particular layer in the model using the input argument 
    and compares it to the split output'''

    curr_name, curr_module = next((x for i,x in enumerate(model.named_modules()) if i==imodule)) 

    full_output = curr_module(input)

    io_match = torch.all(torch.eq(split_output, full_output))

    return (io_match, full_output)

In [106]:
''' Single Layer Test'''

io_match, full_output= compare_IO(model, 1, input_tensor, final_output)

print(io_match)
diff_output = torch.abs(full_output - final_output)

print()
#plt.hist(diff_output.reshape((-1,)))
#plt.show()

tensor(False)



In [37]:
''' Outside framework test'''

imodule = 1
curr_name, curr_module = next((x for i,x in enumerate(model.named_modules()) if i==imodule)) 

full_output = curr_module(input_tensor)

split_output = torch.zeros(full_output.shape)
for imachine in range(num_machines):
    # get input channels
    input_channels = torch.tensor(configs['partition'][split_param_name]['channel_id'][imach],
                                        device=torch.device(configs['device']))
    N_in = len(input_channels) # TODO: is this used?

    # make split layer 
    split_layer = nn.Conv2d(N_in,
                    curr_module.weight.shape[0], # TODO does this need to be an int? (currently tensor)
                    kernel_size= curr_module.kernel_size,
                    stride=curr_module.stride,
                    padding=curr_module.padding, 
                    bias=False) # TODO: add bias during input collecting step on next layer 
    split_layer.weight = torch.nn.Parameter(curr_module.weight.index_select(1, input_channels))

    # eval and add
    split_output += split_layer(input_tensor.index_select(1, input_channels))

# compare 
io_match = torch.all(torch.eq(split_output, full_output))
diff_output = torch.abs(full_output - final_output)
print(io_match)
print(torch.max(diff_output))

tensor(False)
tensor(13.1018, grad_fn=<MaxBackward1>)


In [46]:
'''2D conv test'''

in_tensor = torch.arange(18, dtype=float).reshape([1,2,3,3])
print('in_tensor')
print(in_tensor)
print()

kernals = torch.randint(0,3,(1,2,2,2),  dtype=float)
kernals = torch.nn.Parameter(kernals)
print('kernals')
print(kernals)
print()

conv1 = nn.Conv2d(2,1, (2,2), padding=0, bias=False)
conv1.weight = kernals

out_tensor = conv1(in_tensor)
print('out tensor')
print(out_tensor)
print()


# verified this is the same as documentation 
# error computing 1st conv layer is probably from missing a computation?


in_tensor
tensor([[[[ 0.,  1.,  2.],
          [ 3.,  4.,  5.],
          [ 6.,  7.,  8.]],

         [[ 9., 10., 11.],
          [12., 13., 14.],
          [15., 16., 17.]]]], dtype=torch.float64)

kernals
Parameter containing:
tensor([[[[1., 2.],
          [0., 2.]],

         [[1., 0.],
          [2., 1.]]]], dtype=torch.float64, requires_grad=True)

out tensor
tensor([[[[56., 65.],
          [83., 92.]]]], dtype=torch.float64, grad_fn=<ConvolutionBackward0>)





In [45]:
'''
    Compare vertical+horizontal split model with full model
'''

# TODO: finish
match_count = (torch.argmax(final_output, axis=1) == torch.argmax(output_full, axis=1)).sum().item()
label_hist = output_full.sum(0)
print(f'Matches: {match_count}/{output_full.size(0)}')
print(f'histogram {label_hist}')

RuntimeError: The size of tensor a (32) must match the size of tensor b (1000) at non-singleton dimension 2

In [None]:
'''
    Test partial execution of a single layer
'''

imach = 1
ilayer= 0
input_channels = torch.tensor(configs['partition'][layer_names[ilayer]]['channel_id'][imach],
                              device=torch.device(configs['device']))
output_channels = torch.tensor(configs['partition'][layer_names[ilayer]]['filter_id'][imach],
                               device=torch.device(configs['device']))
print(f'Input channels {input_channels}')
print(f'Output channels {output_channels}')

# TODO: generalize this to more than conv layers 
a_layer = model.conv1
split_layer = nn.Conv2d(len(input_channels),
                    len(output_channels),
                    kernel_size= a_layer.kernel_size,
                    stride=a_layer.stride,
                    padding=a_layer.padding, 
                    bias=a_layer.bias)

split_layer.parameters = a_layer.weight.index_select(0, output_channels).index_select(1, input_channels)


NameError: name 'layer_names' is not defined

In [None]:
'''
    iterate using children method?
'''

# creates nested structure for each layer/block 
i = 0
for name,module in enumerate( model_machines[imach].children()):
    print(f'{i} - {name}')
    i += 1


0 - 0
1 - 1
2 - 2
3 - 3
4 - 4
5 - 5
6 - 6


In [None]:

'''
    iterate using named_modules?
'''

# - gets redundant structure
# - I think the easiest way to iterate through the whole model is using this list and skipping large layers/blocks of the model and only executing conv/bn/linear etc. "layers"
# - provides good indicator of when to save input for skipped layer 

# not sure how to handle sequential layers... going to poke at this here 
i = 0
for name, module in model.named_modules():
    print(f'{i} - name: {name}; type: {type(module).__name__}')
    i += 1


0 - name: ; type: ResNet
1 - name: conv1; type: Conv2d
2 - name: bn1; type: BatchNorm1d
3 - name: layer1; type: Sequential
4 - name: layer1.0; type: BasicBlock
5 - name: layer1.0.conv1; type: Conv2d
6 - name: layer1.0.conv2; type: Conv2d
7 - name: layer1.0.bn1; type: BatchNorm1d
8 - name: layer1.0.bn2; type: BatchNorm1d
9 - name: layer1.0.shortcut; type: Sequential
10 - name: layer1.1; type: BasicBlock
11 - name: layer1.1.conv1; type: Conv2d
12 - name: layer1.1.conv2; type: Conv2d
13 - name: layer1.1.bn1; type: BatchNorm1d
14 - name: layer1.1.bn2; type: BatchNorm1d
15 - name: layer1.1.shortcut; type: Sequential
16 - name: layer2; type: Sequential
17 - name: layer2.0; type: BasicBlock
18 - name: layer2.0.conv1; type: Conv2d
19 - name: layer2.0.conv2; type: Conv2d
20 - name: layer2.0.bn1; type: BatchNorm1d
21 - name: layer2.0.bn2; type: BatchNorm1d
22 - name: layer2.0.shortcut; type: Sequential
23 - name: layer2.0.shortcut.0; type: Conv2d
24 - name: layer2.0.shortcut.1; type: BatchNorm1d

In [None]:
def combine_inputs(num_machines, input):
    '''
        combines input tensors 

        Input:
            num_machines - number of total machines
            input - num_machines x num_machines list with inputs from previous layers collected from each machine. Indexed [destination][origin]

        Output:
            curr_input - a single tensor for this layer from the combined inputs
    '''
    curr_input = False 
    for i in range(num_machines):
            if not input[imach][i] == []:
                    if not torch.is_tensor(curr_input):
                            curr_input = input[imach][i] # initialize curr_input with first input tensor 
                    else:
                            curr_input += input[imach][i]
    
    return curr_input 

In [None]:
'''
    Look at one "layer"
'''

ilayer = 1
a_name, a_module = next((m for i, m in enumerate(model.named_modules()) if i==ilayer))


a_module.eval()
with torch.no_grad():
        a_output = a_module(curr_input)

        



torch.all(torch.eq(a_output, )) 


Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)

In [None]:
'''
    can you split certain operations across machines?
    - I think the math says yes but there is some error values a on the order of 1e-7 
    - could this be due to non-deterministic tensor operations? or is my math wrong?
'''

t = torch.rand(1, 3, 10,10)
all = torch.sum(t, 1, True)


kernel_size = 3

full_avg = F.avg_pool2d(all, kernel_size)
par_avg = torch.zeros(full_avg.shape)

for i in range(t.shape[1]):
    par_avg += F.avg_pool2d(t.index_select(1, torch.tensor(i)), kernel_size)

torch.abs(full_avg - par_avg)

tensor([[[[0.0000e+00, 1.1921e-07, 2.3842e-07],
          [1.1921e-07, 1.1921e-07, 2.3842e-07],
          [2.3842e-07, 2.3842e-07, 0.0000e+00]]]])

In [None]:
'''
    Com
'''