In [17]:
'''
    Load RESNET model and split it
        1. layer by layer
        2. [TODO] vertically 
'''

'\n    Load RESNET model and split it\n        1. layer by layer\n        2. [TODO] vertically \n'

In [18]:
from source.core.engine import MoP
import source.core.run_partition as run_p
from os import environ
from source.utils.dataset import *
from source.utils.misc import *

import matplotlib.pyplot as plt
import matplotlib.image as mpimg

from source.models import resnet

import torch.nn.functional as F

import numpy as np

from source.utils import io
from source.utils import testers
from source.core import engine
import json

from torchsummary import summary

import time

In [19]:
# setup config
dataset='cifar10'
environ["config"] = f"config/{dataset}.yaml"

configs = run_p.main()

configs["device"] = "cpu"
configs['load_model'] = "cifar10-resnet18-kernel-npv2-pr0.75-lcm0.001.pt"
configs["num_partition"] = 'resnet18-v2.yaml'

device :  
model :  resnet18
data_code :  cifar10
num_classes :  10
model_file :  test.pt
epochs :  0
batch_size :  128
optimizer :  sgd
lr_scheduler :  default
learning_rate :  0.01
seed :  1234
sparsity_type :  kernel
prune_ratio :  1
admm :  True
admm_epochs :  3
rho :  0.0001
multi_rho :  True
retrain_bs :  128
retrain_lr :  0.005
retrain_ep :  50
retrain_opt :  default
xentropy_weight :  1.0
warmup :  False
warmup_lr :  0.001
warmup_epochs :  10
mix_up :  True
alpha :  0.3
smooth :  False
smooth_eps :  0
save_last_model_only :  False
num_partition :  1
layer_type :  regular
bn_type :  masked
par_first_layer :  False
comm_outsize :  False
lambda_comm :  0
lambda_comp :  0
distill_model :  
distill_loss :  kl
distill_temp :  30
distill_alpha :  1


In [20]:
# load data and load or train model
model = get_model_from_code(configs).to(configs['device']) # grabs model architecture from ./source/models/escnet.py
print(model)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (shortcut): Sequential()
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=

In [21]:
'''
   1st section of resnet model 
'''
class ResnetBlockOne(nn.Module):
    def __init__(self, block, num_blocks, conv_layer, bn_layer, num_classes=10, num_filters=512, bn_partition=[1]*9):
        super(ResnetBlockOne, self).__init__()

        self.in_planes = 64
        self.conv_layer = conv_layer
        self.bn_layer = bn_layer
        self.shrink = num_filters/512
        self.bn_partition = bn_partition
        
        self.conv1 = conv_layer(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        num_bn = self.bn_partition.pop(0)
        self.bn1 = bn_layer(64) if num_bn==1 else bn_layer(64, num_bn)

        self.layer1 = self._make_layer(block, int(64*self.shrink),  num_blocks[0], stride=1)

    def _make_layer(self, block, planes, num_blocks, stride):
        # TODO: find better way to implement this method using inheretence and getting from ResNet class
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, self.conv_layer, self.bn_layer, stride, self.bn_partition.pop(0)))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)


    def forward(self, x):
        
        # override the the foward pass to only include the first modules 
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)

        return out

'''
   2nd section of resnet model 
'''
class ResnetBlockTwo(nn.Module):
    def __init__(self, block, num_blocks, conv_layer, bn_layer, num_classes=10, num_filters=512, bn_partition=[1]*9):
        super(ResnetBlockTwo, self).__init__()

        self.in_planes = 64
        self.conv_layer = conv_layer
        self.bn_layer = bn_layer
        self.shrink = num_filters/512
        self.bn_partition = bn_partition

        self.layer2 = self._make_layer(block, int(128*self.shrink), num_blocks[1], stride=2)
    
    def _make_layer(self, block, planes, num_blocks, stride):
        # TODO: find better way to implement this method using inheretence and getting from ResNet class
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, self.conv_layer, self.bn_layer, stride, self.bn_partition.pop(0)))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)


    def forward(self, x):
        out = self.layer2(x)
        return out

'''
    3rd section of resenet model 
'''
class ResnetBlockThree(nn.Module):
    def __init__(self, block, num_blocks, conv_layer, bn_layer, num_classes=10, num_filters=512, bn_partition=[1]*9):
        super(ResnetBlockThree, self).__init__()

        self.in_planes = 128
        self.conv_layer = conv_layer
        self.bn_layer = bn_layer
        self.shrink = num_filters/512
        self.bn_partition = bn_partition

        self.layer3 = self._make_layer(block, int(256*self.shrink), num_blocks[2], stride=2)
    
    def _make_layer(self, block, planes, num_blocks, stride):
        # TODO: find better way to implement this method using inheretence and getting from ResNet class
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, self.conv_layer, self.bn_layer, stride, self.bn_partition.pop(0)))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)


    def forward(self, x):
        out = self.layer3(x)
        return out

'''
    4-th section of resenet model 
'''
class ResnetBlockFour(nn.Module):
    def __init__(self, block, num_blocks, conv_layer, bn_layer, num_classes=10, num_filters=512, bn_partition=[1]*9):
        super(ResnetBlockFour, self).__init__()

        self.in_planes = 256
        self.conv_layer = conv_layer
        self.bn_layer = bn_layer
        self.shrink = num_filters/512
        self.bn_partition = bn_partition

        self.layer4 = self._make_layer(block, num_filters, num_blocks[3], stride=2)
        self.linear = nn.Linear(num_filters*block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        # TODO: find better way to implement this method using inheretence and getting from ResNet class
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, self.conv_layer, self.bn_layer, stride, self.bn_partition.pop(0)))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)


    def forward(self, x):
        out = self.layer4(x)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out

In [22]:
# get framework for each layer of resnet

# inputs can be found from looking at ./source/util/misc/get_model_from_code
num_classes=configs['num_classes']
bn_layers = get_bn_layers('regular') # basic block layer
conv_layers = get_layers(configs['layer_type'])

# resnet18 inputs from resnet.py
layer_1 =  ResnetBlockOne(resnet.BasicBlock, [2,2,2,2],conv_layers, bn_layers, num_classes=num_classes) # also includes bn1 and conv1 
layer_2 =  ResnetBlockTwo(resnet.BasicBlock, [2,2,2,2],conv_layers, bn_layers, num_classes=num_classes) # also includes bn1 and conv1 
layer_3 =  ResnetBlockThree(resnet.BasicBlock, [2,2,2,2],conv_layers, bn_layers, num_classes=num_classes) # also includes bn1 and conv1 
layer_4 =  ResnetBlockFour(resnet.BasicBlock, [2,2,2,2],conv_layers, bn_layers, num_classes=num_classes) # also includes bn1 and conv1 

print(layer_4)
#block2 = |

ResnetBlockFour(
  (layer4): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (shortcut): Sequential(
        (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (bn2): BatchNorm2d(512, eps=1e-05,

In [23]:
split_model = [layer_1, layer_2, layer_3, layer_4]

# load model params into dictionary
state_dict = torch.load(io.get_model_path("{}".format(configs["load_model"])), map_location=configs['device'])

# add params to split
for l in split_model:
    l = io.load_state_dict(l, 
                    state_dict['model_state_dict'] if 'model_state_dict' in state_dict 
                    else state_dict['state_dict'] if 'state_dict' in state_dict else state_dict,)


not found:  layer2.0.conv1.weight
not found:  layer2.0.conv2.weight
not found:  layer2.0.bn1.weight
not found:  layer2.0.bn1.bias
not found:  layer2.0.bn1.running_mean
not found:  layer2.0.bn1.running_var
not found:  layer2.0.bn1.num_batches_tracked
not found:  layer2.0.bn2.weight
not found:  layer2.0.bn2.bias
not found:  layer2.0.bn2.running_mean
not found:  layer2.0.bn2.running_var
not found:  layer2.0.bn2.num_batches_tracked
not found:  layer2.0.shortcut.0.weight
not found:  layer2.0.shortcut.1.weight
not found:  layer2.0.shortcut.1.bias
not found:  layer2.0.shortcut.1.running_mean
not found:  layer2.0.shortcut.1.running_var
not found:  layer2.0.shortcut.1.num_batches_tracked
not found:  layer2.1.conv1.weight
not found:  layer2.1.conv2.weight
not found:  layer2.1.bn1.weight
not found:  layer2.1.bn1.bias
not found:  layer2.1.bn1.running_mean
not found:  layer2.1.bn1.running_var
not found:  layer2.1.bn1.num_batches_tracked
not found:  layer2.1.bn2.weight
not found:  layer2.1.bn2.bias


In [24]:
# look at state dict keys
print(state_dict.keys())
for i in split_model:
    print(len(l.state_dict().keys()))



odict_keys(['conv1.weight', 'bn1.weight', 'bn1.bias', 'bn1.running_mean', 'bn1.running_var', 'bn1.num_batches_tracked', 'layer1.0.conv1.weight', 'layer1.0.conv2.weight', 'layer1.0.bn1.weight', 'layer1.0.bn1.bias', 'layer1.0.bn1.running_mean', 'layer1.0.bn1.running_var', 'layer1.0.bn1.num_batches_tracked', 'layer1.0.bn2.weight', 'layer1.0.bn2.bias', 'layer1.0.bn2.running_mean', 'layer1.0.bn2.running_var', 'layer1.0.bn2.num_batches_tracked', 'layer1.1.conv1.weight', 'layer1.1.conv2.weight', 'layer1.1.bn1.weight', 'layer1.1.bn1.bias', 'layer1.1.bn1.running_mean', 'layer1.1.bn1.running_var', 'layer1.1.bn1.num_batches_tracked', 'layer1.1.bn2.weight', 'layer1.1.bn2.bias', 'layer1.1.bn2.running_mean', 'layer1.1.bn2.running_var', 'layer1.1.bn2.num_batches_tracked', 'layer2.0.conv1.weight', 'layer2.0.conv2.weight', 'layer2.0.bn1.weight', 'layer2.0.bn1.bias', 'layer2.0.bn1.running_mean', 'layer2.0.bn1.running_var', 'layer2.0.bn1.num_batches_tracked', 'layer2.0.bn2.weight', 'layer2.0.bn2.bias', '

In [25]:
# load weights into full model
model = io.load_state_dict(model, 
                    state_dict['model_state_dict'] if 'model_state_dict' in state_dict 
                    else state_dict['state_dict'] if 'state_dict' in state_dict else state_dict,)


In [26]:
# compare outputs

input = torch.rand(1000, 3, 32, 32, device=torch.device(configs['device'])) # 1k images, 3 channels, 32x32 image (cifar100) 

# put models into eval mode and on device
model.eval()
model.to(configs['device'])
for l in split_model:
    l.eval()
    l.to(configs['device'])

# make inference 
with torch.no_grad():
        output_full = model(input)

        output_split = input
        for l in split_model:
                output_split = l(output_split)


match_count = (torch.argmax(output_split, axis=1) == torch.argmax(output_full, axis=1)).sum().item()
label_hist = output_full.sum(0)
print(f'Matches: {match_count}/{output_full.size(0)}')
print(f'histogram {label_hist}')

Matches: 1000/1000
histogram tensor([  -33.9977,  -915.5308,   284.7191,  -976.8938,   468.4694, -1295.0208,
         3070.4282,  -915.8573,  -226.5788, -1049.0566])


In [27]:
# go layer by layer to identify mismatch
split_model[0].conv1(input)

tensor([[[[-5.6931e-05, -5.0791e-04, -2.9262e-04,  ..., -3.6080e-04,
           -3.1026e-04, -3.0171e-04],
          [-4.1799e-04, -3.0245e-04, -8.5406e-04,  ..., -5.4843e-04,
           -2.6983e-04, -6.8433e-04],
          [-1.7801e-04, -7.5514e-04, -6.7573e-04,  ..., -5.9406e-04,
           -7.4440e-04, -6.8391e-04],
          ...,
          [-4.0061e-04, -8.3287e-04, -6.1543e-04,  ..., -3.6693e-04,
           -3.4874e-04, -5.0168e-04],
          [-6.1675e-04, -4.5475e-04, -9.5080e-04,  ..., -5.4160e-04,
           -5.3737e-04, -5.7436e-04],
          [-4.0809e-04, -6.8612e-04, -5.1417e-04,  ..., -4.2185e-04,
           -3.6707e-04, -3.8267e-04]],

         [[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
            0.0000e+00,  0.0000e+00],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
            0.0000e+00,  0.0000e+00],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
            0.0000e+00,  0.0000e+00],
          ...,
     

In [28]:
'''
    Test I/O Logic
'''

'\n    Test I/O Logic\n'

In [29]:
# SAVE

# make dir name 
time_stamp = time.strftime("%Y%m%d-%H%M%S")
if len(configs['load_model']) == 0:
    folder_name='{}-{}-{}-np{}-pr{}-lcm{}-{}'.format( 
                configs['data-code'], 
                configs['model'], 
                configs['sparsity-type'], 
                configs['num_partition'], 
                configs['prune-ratio'], 
                configs['lambda-comm'],
                time_stamp)
else:
    folder_name = '{}-{}'.format(configs['load_model'][:-3],time_stamp)

# make folder 
folder_path = os.path.join(os.getcwd(), 'assets', 'models',folder_name)
if not os.path.exists(folder_path):
    os.makedirs(folder_path)

# save weights
index = 0
for l in split_model:
    fpath = os.path.join(os.getcwd(), 'assets', 'models', folder_path, f'layer_model_{index}.pth')
    torch.save(l.state_dict(), fpath)
    index += 1

In [30]:
# LOAD

select = 0

layer_output_sizes = [(3,32,32), (64,32,32), (128,16,16), (256,8,8)]

model_path = os.path.join(os.getcwd(), 'assets', 'models')
filenames = os.listdir(model_path)

# get dirs
split_model_names = []
for filename in filenames: # loop through all the files and folders
    if os.path.isdir(os.path.join(model_path, filename)): # check whether the current object is a folder or not
        split_model_names.append(filename)

print('Split models:')
print(split_model_names)
print()

model_name = split_model_names[select] 
print(f'loading split model {model_name}')

index = 0
for l in split_model:
    layer_state_dict = torch.load(os.path.join(model_path, model_name, f'layer_model_{index}.pth'))
    l = io.load_state_dict(l, layer_state_dict)


    summary(l, layer_output_sizes[index],device=configs['device'])
    index += 1

Split models:
['cifar10-resnet18-kernel-npv2-pr0.75-lcm0.001-20240615-135016', 'cifar10-resnet18-kernel-npv2-pr0.75-lcm0.001-20240615-141807', 'cifar10-resnet18-kernel-npv2-pr0.75-lcm0.001-20240622-092931', 'cifar10-resnet18-kernel-npv2-pr0.75-lcm0.001-20240622-125523', 'cifar10-resnet18-kernel-npv2-pr0.75-lcm0.001-20240623-182335', 'cifar10-resnet18-kernel-npv2-pr0.75-lcm0.001-20240623-182910', 'cifar10-resnet18-kernel-npv2-pr0.75-lcm0.001-20240624-204135', 'cifar10-resnet18-kernel-npv2-pr0.75-lcm0.001-20240625-073155', 'cifar10-resnet18-kernel-npv2-pr0.75-lcm0.001-20240626-090056', 'cifar10-resnet18-kernel-npv2-pr0.75-lcm0.001-20240626-202301', 'cifar10-resnet18-kernel-npv2-pr0.75-lcm0.001-20240626-213627', 'cifar10-resnet18-kernel-npv2-pr0.75-lcm0.001-20240626-213724', 'cifar10-resnet18-kernel-npv2-pr0.75-lcm0.001-20240626-214353']

loading split model cifar10-resnet18-kernel-npv2-pr0.75-lcm0.001-20240615-135016
----------------------------------------------------------------
      

In [31]:
'''
    add partitions and communications to configs
'''

# gets random test input (with correct size)
input_var = engine.get_input_from_code(configs)
#print(input_var)

# Config partitions and prune_ratio
configs['num_partition'] = './config/resnet18-v2.yaml'
configs = engine.partition_generator(configs, model)
            
# Compute output size of each layer
configs['partition'] = engine.featuremap_summary(model, configs['partition'], input_var)
        
# Setup communication costs
configs['comm_costs'] = engine.set_communication_cost(model, configs['partition'],)


{'inputs': [[1, 1, 1, 1, 1]], 'conv1.weight': [[1, 1, 1, 1]], 'layer1.0.conv1.weight': [[1, 1, 1, 1]], 'layer1.0.conv2.weight': [[1, 1, 1, 1]], 'layer1.1.conv1.weight': [[1, 1, 1, 1]], 'layer1.1.conv2.weight': [[1, 1, 1, 1]], 'layer2.0.conv1.weight': [[1, 1, 1, 1]], 'layer2.0.conv2.weight': [[1, 1, 1, 1]], 'layer2.0.shortcut.0.weight': [[1, 1, 1, 1]], 'layer2.1.conv1.weight': [[1, 1, 1, 1]], 'layer2.1.conv2.weight': [[1, 1, 1, 1]], 'layer3.0.conv1.weight': [[1, 1, 1, 1]], 'layer3.0.conv2.weight': [[1, 1, 1, 1]], 'layer3.0.shortcut.0.weight': [[1, 1, 1, 1]], 'layer3.1.conv1.weight': [[1, 1, 1, 1]], 'layer3.1.conv2.weight': [[1, 1, 1, 1]], 'layer4.0.conv1.weight': [[1, 1, 1, 1]], 'layer4.0.conv2.weight': [[1, 1, 1, 1]], 'layer4.0.shortcut.0.weight': [[1, 1, 1, 1]], 'layer4.1.conv1.weight': [[1, 1, 1, 1]], 'layer4.1.conv2.weight': [[1, 1, 1, 1]]} {'conv1.weight': [[0, 1, 1, 1], [1, 0, 1, 1], [1, 1, 0, 1], [1, 1, 1, 0]], 'layer1.0.conv1.weight': [[0, 1, 1, 1], [1, 0, 1, 1], [1, 1, 0, 1], [

In [32]:
'''
    inspect IO per layer 
'''

savepath=io.get_fig_path("{}".format('.'.join(configs["load_model"].split('.')[:-1])))
print(savepath)

# this function looks for model layers named in "confgis['partition']" (other layers are ignored)
# -> "conv" and shortcut layers are the only ones "split"
# -> total 20/49 (20/62?) layers are split
counter = 0
for name, W in model.named_parameters():
        if name in configs['partition']:
            #print(f'{counter} | {name}')
            counter +=1

# Plot model
layer_id = (0,1,2) # inspect these layers
testers.plot_layer(model, configs['partition'], layer_id=layer_id, savepath=savepath)

c:\Users\natet\Desktop\graduate school\thesis\CaP\assets\figs\cifar10-resnet18-kernel-npv2-pr0.75-lcm0.001
name: conv1.weight, img size: (3, 64) weight size: (64, 3)
name: layer1.0.conv1.weight, img size: (64, 64) weight size: (64, 64)


In [33]:
# split model 

# make copies of model per machine
num_machines = max(configs['partition']['bn_partition']) # TODO: double check this makes sense
model_machines = [model]*num_machines

layer_names =  [name for name, module in model.named_parameters()]
num_tot_layers = len(layer_names)

split_layer_names = list(configs['partition'].keys())

In [34]:
# setup input 
N_batch = 1000
input_tensor = torch.rand(N_batch, 3, 32, 32, device=torch.device(configs['device'])) # 1k images, 3 channels, 32x32 image (cifar100) 
input_size = (N_batch, 3, 32, 32)

# broadcast input_tensor to different machines
input = [[]]*num_machines
input = [input]*num_machines
for imach in range(num_machines):
    input[imach][0] = input_tensor

# put models into eval mode and on device
model.eval()
model.to(configs['device'])
for l in model_machines:
    l.eval()
    l.to(configs['device'])

In [36]:
'''
    mock run through inference using split models 
'''

# TODO: reduce size of communicated tensors to only what is necessary 
# TODO: also check bias for nonzero

# channel_id == INPUTS
# filter_id  == OUTPUTS

# make inference 
with torch.no_grad():
        # iterate through layers in model 1 by 1 
        for ilayer in range(num_tot_layers):

                # initialize output for ilayer
                output = [[]]*num_machines
                output = [output]*num_machines # square list indexed as: output[destination/RX machine][origin/TX machine]
                
                print(f'Executing layer {ilayer}: {layer_names[ilayer]}')

                # iterate through each machine (done in parallel later)
                for imach in range(num_machines):
                        print(f'\tExecuting on machine {imach}')

                        # skip machine when there is no input
                        if layer_names[ilayer] in split_layer_names:
                                # update I/O if encounter split layer
                                # TODO: revist this implementation. Havent thought through if this will work for fully connected layers inbetween split layers yet

                                if len(configs['partition'][layer_names[ilayer]]['channel_id'][imach]) == 0:
                                        print(f'\t\t no input found. Skipping...')
                                        continue

                                # What input channels does this machine compute?
                                input_channels = torch.tensor(configs['partition'][layer_names[ilayer]]['channel_id'][imach],
                                        device=torch.device(configs['device']))
                                N_in = len(input_channels) # TODO: is this used?

                                # Where to send output (map of output channels to different machines)
                                output_channel_map = configs['partition'][layer_names[ilayer]]['filter_id']

                        # get the current layer
                        curr_layer = next((x for i,x in enumerate( model_machines[imach].children()) if i==ilayer)) 

                        # 
                        curr_input = torch.zeros(input_size, device=torch.device(configs['device']), dtype=torch.float32)
                        for i in range(num_machines):
                                if not input[imach][i] == []:
                                        curr_input += input[imach][i]

                        # reduce computation-- make vertically split layer 
                        # TODO: generalize this to more than conv layers 
                        if type(curr_layer) == nn.Conv2d:
                                split_layer = nn.Conv2d(N_in,
                                                curr_layer.weight.shape[0], # TODO does this need to be an int? (currently tensor)
                                                kernel_size= curr_layer.kernel_size,
                                                stride=curr_layer.stride,
                                                padding=curr_layer.padding, 
                                                bias=curr_layer.bias)
                                needs_comms = True

                                # write parameters to split layer 
                                split_layer.weight = torch.nn.Parameter(curr_layer.weight.index_select(1, input_channels))

                                # TODO: add support for splitting bias

                        elif type(curr_layer) == nn.BatchNorm2d:
                                split_layer = nn.BatchNorm2d(N_in, 
                                                curr_layer.eps,
                                                momentum=curr_layer.momentum, 
                                                affine=curr_layer.affine, 
                                                track_running_stats=curr_layer.track_running_stats)
                                needs_comms = False

                                # write parameters to split layer 
                                split_layer.weight = torch.nn.Parameter(curr_layer.weight.index_select(0, input_channels))
                                split_layer.running_mean = torch.nn.Parameter(curr_layer.running_mean.index_select(0, input_channels))
                                split_layer.running_var = torch.nn.Parameter(curr_layer.running_var.index_select(0, input_channels))

                                if not curr_layer.bias == None:
                                        split_layer.bias = torch.nn.Parameter(curr_layer.bias.index_select(0, input_channels))
                        elif type(curr_layer) == nn.Linear:
                                # TODO: assumes there is a bias 
                                split_layer = nn.Linear(N_in, 
                                                curr_layer.weight.shape[0])
                                needs_comms = False

                                # write parameters to split layer 
                                split_layer.weight = torch.nn.Parameter(curr_layer.weight.index_select(1, input_channels))

                                # TODO: double check bias is applied correctly
                                if not curr_layer.bias == None:
                                        split_layer.bias = curr_layer.bias
                        else:
                                print(f'Unrecognized layer type {curr_layer}')

                        # look at which C_out need to be computed and sent
                        if curr_layer == nn.BatchNorm2d:
                                nonzero_Cout = torch.unique(torch.nonzero(curr_layer.weight, as_tuple=True)[0])
                                in_channel_dimension = 0
                        else:
                                nonzero_Cout = torch.unique(torch.nonzero(split_layer.weight, as_tuple=True)[0]) # find nonzero dimensions in output channels
                                in_channel_dimension = 1

                        # eval and send to next machines 
                        print(f'\t\t {curr_input.device}\n\t\t{input_channels.device}\n\t\t{split_layer.weight.device}\n\t\t{split_layer.bias}')
                        out_tensor = split_layer(curr_input.index_select(in_channel_dimension, input_channels))
                        for rx_mech in range(num_machines):
                                # only add to output if communication is necessary 

                                # Where does this machine send outputs? TODO: consider removing, this just maps C_out's to machine
                                output_channels = torch.tensor(output_channel_map[rx_mech],
                                        device=torch.device(configs['device']))

                                # TODO: is there a faster way to do this? Consider putting larger array 1st... just not sure which one that'd be
                                is_comms = torch.any(torch.isin(nonzero_Cout, output_channels))
                                if is_comms.item():
                                        output[rx_mech][imach] = out_tensor

                # update input size. TODO: is this necessary? 
                input_size = tuple(out_tensor.shape)

                # send to next layer        
                input = output
                print(f'Finished execution of layer {ilayer}')
                print()


# collect outputs -- assumes ends with Linear layer. Not sure how generalizable this is
final_output = torch.zeros(input_size, dtype=torch.float32)
for tx_mech in range(num_machines):
        for rx_mech in range(num_machines):
                if not input[imach][i] == []:
                        final_output += output[rx_mech][tx_mech]

Executing layer 0: conv1.weight
	Executing on machine 0
		 no input found. Skipping...
	Executing on machine 1
		 cpu
		cpu
		cpu
		None
	Executing on machine 2
		 no input found. Skipping...
	Executing on machine 3
		 cpu
		cpu
		cpu
		None


RuntimeError: INDICES element is out of DATA bounds, id=1 axis_dim=1

In [None]:
'''
    Compare vertical+horizontal split model with full model
'''

# TODO: finish
match_count = (torch.argmax(final_output, axis=1) == torch.argmax(output_full, axis=1)).sum().item()
label_hist = output_full.sum(0)
print(f'Matches: {match_count}/{output_full.size(0)}')
print(f'histogram {label_hist}')

NameError: name 'final_output' is not defined

In [None]:
'''
    Test partial execution of a single layer
'''

imach = 1
ilayer= 0
input_channels = torch.tensor(configs['partition'][layer_names[ilayer]]['channel_id'][imach],
                              device=torch.device(configs['device']))
output_channels = torch.tensor(configs['partition'][layer_names[ilayer]]['filter_id'][imach],
                               device=torch.device(configs['device']))
print(f'Input channels {input_channels}')
print(f'Output channels {output_channels}')

# TODO: generalize this to more than conv layers 
a_layer = model.conv1
split_layer = nn.Conv2d(len(input_channels),
                    len(output_channels),
                    kernel_size= a_layer.kernel_size,
                    stride=a_layer.stride,
                    padding=a_layer.padding, 
                    bias=a_layer.bias)

split_layer.parameters = a_layer.weight.index_select(0, output_channels).index_select(1, input_channels)


Input channels tensor([0], device='cuda:0', dtype=torch.int32)
Output channels tensor([16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       device='cuda:0', dtype=torch.int32)
