In [12]:
'''
    Load RESNET model and split it
        1. layer by layer
        2. [TODO] vertically 
'''

'\n    Load RESNET model and split it\n        1. layer by layer\n        2. [TODO] vertically \n'

In [2]:
from source.core.engine import MoP
import source.core.run_partition as run_p
from os import environ
from source.utils.dataset import *
from source.utils.misc import *
from split_network import *

import matplotlib.pyplot as plt
import matplotlib.image as mpimg

from source.models import resnet

import torch.nn.functional as F

import numpy as np

from source.utils import io
from source.utils import testers
from source.core import engine
import json
import itertools

from torchsummary import summary

import time

In [3]:
# setup config
# The dataset name selects the YAML file; its path is published through the
# `config` environment variable before run_p.main() is called.
# presumably run_p.main() reads that variable to build `configs` — verify in source/core/run_partition.py
dataset='cifar10'
environ["config"] = f"config/{dataset}.yaml"

configs = run_p.main()

# Notebook-local overrides applied on top of whatever run_p.main() returned.
configs["device"] = "cpu"
configs['load_model'] = "cifar10-resnet18-kernel-npv2-pr0.75-lcm0.001.pt"
# NOTE: the partition count is stored as the string '4' (not an int); the trailing
# comment records a YAML file name that was used as an alternative value.
configs["num_partition"] = '4' #'resnet18-v2.yaml'

device :  
model :  resnet18
data_code :  cifar10
num_classes :  10
model_file :  test.pt
epochs :  0
batch_size :  128
optimizer :  sgd
lr_scheduler :  default
learning_rate :  0.01
seed :  1234
sparsity_type :  kernel
prune_ratio :  1
admm :  True
admm_epochs :  300
rho :  0.0001
multi_rho :  True
retrain_bs :  128
retrain_lr :  0.005
retrain_ep :  50
retrain_opt :  default
xentropy_weight :  1.0
warmup :  False
warmup_lr :  0.001
warmup_epochs :  10
mix_up :  True
alpha :  0.3
smooth :  False
smooth_eps :  0
save_last_model_only :  False
num_partition :  1
layer_type :  regular
bn_type :  masked
par_first_layer :  False
comm_outsize :  False
lambda_comm :  0
lambda_comp :  0
distill_model :  
distill_loss :  kl
distill_temp :  30
distill_alpha :  1


In [4]:
# Build the model described by `configs`, move it to the configured device and print it.
# No weights are loaded and no data is read in this cell.
model = get_model_from_code(configs).to(configs['device']) # NOTE(review): this comment used to point at ./source/models/escnet.py, but the printed model is a ResNet — confirm which file defines it
print(model)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (shortcut): Sequential()
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=

In [6]:
'''
   1st section of resnet model 
'''
class ResnetBlockOne(nn.Module):
    """First horizontal section of the split ResNet: stem (conv1 + bn1 + ReLU) then layer1.

    Args:
        block: residual block class. It is called as
            ``block(in_planes, planes, conv_layer, bn_layer, stride, num_bn)``
            and must expose an ``expansion`` attribute.
        num_blocks: per-stage block counts; only ``num_blocks[0]`` is used here.
        conv_layer: convolution factory, called with ``nn.Conv2d``-style arguments.
        bn_layer: batch-norm factory, called as ``bn_layer(64)`` when the
            partition entry is 1 and ``bn_layer(64, num_bn)`` otherwise.
        num_classes: not used by this section; kept so all sections share a signature.
        num_filters: ``num_filters / 512`` scales the width of layer1.
        bn_partition: list consumed from the front with ``pop(0)``: one entry
            for bn1, then one per block. A list supplied by the caller is
            mutated in place. ``None`` (the default) means a fresh ``[1] * 9``.
    """

    def __init__(self, block, num_blocks, conv_layer, bn_layer, num_classes=10, num_filters=512, bn_partition=None):
        super(ResnetBlockOne, self).__init__()

        # The default list must be created per call: pop(0) below drains it, and a
        # list literal in the signature would be shared by every instantiation.
        if bn_partition is None:
            bn_partition = [1]*9

        self.in_planes = 64
        self.conv_layer = conv_layer
        self.bn_layer = bn_layer
        self.shrink = num_filters/512
        self.bn_partition = bn_partition

        self.conv1 = conv_layer(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        num_bn = self.bn_partition.pop(0)
        self.bn1 = bn_layer(64) if num_bn==1 else bn_layer(64, num_bn)

        self.layer1 = self._make_layer(block, int(64*self.shrink),  num_blocks[0], stride=1)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` blocks; only the first one uses ``stride``, the rest use 1."""
        # TODO: find better way to implement this method using inheritance and getting from ResNet class
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            # Each block consumes the next entry of the partition list.
            layers.append(block(self.in_planes, planes, self.conv_layer, self.bn_layer, stride, self.bn_partition.pop(0)))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)


    def forward(self, x):
        """Apply conv1 -> bn1 -> ReLU -> layer1 and return the resulting feature map."""
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)

        return out

'''
   2nd section of resnet model 
'''
class ResnetBlockTwo(nn.Module):
    """Second horizontal section of the split ResNet: layer2 only.

    Args:
        block: residual block class. It is called as
            ``block(in_planes, planes, conv_layer, bn_layer, stride, num_bn)``
            and must expose an ``expansion`` attribute.
        num_blocks: per-stage block counts; only ``num_blocks[1]`` is used here.
        conv_layer: convolution factory forwarded to every block.
        bn_layer: batch-norm factory forwarded to every block.
        num_classes: not used by this section; kept so all sections share a signature.
        num_filters: ``num_filters / 512`` scales the width of layer2.
        bn_partition: list consumed from the front with ``pop(0)``, one entry
            per block. A list supplied by the caller is mutated in place.
            ``None`` (the default) means a fresh ``[1] * 9``.
    """

    def __init__(self, block, num_blocks, conv_layer, bn_layer, num_classes=10, num_filters=512, bn_partition=None):
        super(ResnetBlockTwo, self).__init__()

        # The default list must be created per call: pop(0) below drains it, and a
        # list literal in the signature would be shared by every instantiation.
        if bn_partition is None:
            bn_partition = [1]*9

        # Input width is hard-coded to 64 and is not scaled by `shrink`.
        self.in_planes = 64
        self.conv_layer = conv_layer
        self.bn_layer = bn_layer
        self.shrink = num_filters/512
        self.bn_partition = bn_partition

        self.layer2 = self._make_layer(block, int(128*self.shrink), num_blocks[1], stride=2)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` blocks; only the first one uses ``stride``, the rest use 1."""
        # TODO: find better way to implement this method using inheritance and getting from ResNet class
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            # Each block consumes the next entry of the partition list.
            layers.append(block(self.in_planes, planes, self.conv_layer, self.bn_layer, stride, self.bn_partition.pop(0)))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)


    def forward(self, x):
        """Run ``x`` through layer2 and return the result."""
        out = self.layer2(x)
        return out

'''
    3rd section of resenet model 
'''
class ResnetBlockThree(nn.Module):
    """Third horizontal section of the split ResNet: layer3 only.

    Args:
        block: residual block class. It is called as
            ``block(in_planes, planes, conv_layer, bn_layer, stride, num_bn)``
            and must expose an ``expansion`` attribute.
        num_blocks: per-stage block counts; only ``num_blocks[2]`` is used here.
        conv_layer: convolution factory forwarded to every block.
        bn_layer: batch-norm factory forwarded to every block.
        num_classes: not used by this section; kept so all sections share a signature.
        num_filters: ``num_filters / 512`` scales the width of layer3.
        bn_partition: list consumed from the front with ``pop(0)``, one entry
            per block. A list supplied by the caller is mutated in place.
            ``None`` (the default) means a fresh ``[1] * 9``.
    """

    def __init__(self, block, num_blocks, conv_layer, bn_layer, num_classes=10, num_filters=512, bn_partition=None):
        super(ResnetBlockThree, self).__init__()

        # The default list must be created per call: pop(0) below drains it, and a
        # list literal in the signature would be shared by every instantiation.
        if bn_partition is None:
            bn_partition = [1]*9

        # Input width is hard-coded to 128 and is not scaled by `shrink`.
        self.in_planes = 128
        self.conv_layer = conv_layer
        self.bn_layer = bn_layer
        self.shrink = num_filters/512
        self.bn_partition = bn_partition

        self.layer3 = self._make_layer(block, int(256*self.shrink), num_blocks[2], stride=2)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` blocks; only the first one uses ``stride``, the rest use 1."""
        # TODO: find better way to implement this method using inheritance and getting from ResNet class
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            # Each block consumes the next entry of the partition list.
            layers.append(block(self.in_planes, planes, self.conv_layer, self.bn_layer, stride, self.bn_partition.pop(0)))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)


    def forward(self, x):
        """Run ``x`` through layer3 and return the result."""
        out = self.layer3(x)
        return out

'''
    4-th section of resenet model 
'''
class ResnetBlockFour(nn.Module):
    """Fourth horizontal section of the split ResNet: layer4, 4x4 average pool, linear classifier.

    Args:
        block: residual block class. It is called as
            ``block(in_planes, planes, conv_layer, bn_layer, stride, num_bn)``
            and must expose an ``expansion`` attribute.
        num_blocks: per-stage block counts; only ``num_blocks[3]`` is used here.
        conv_layer: convolution factory forwarded to every block.
        bn_layer: batch-norm factory forwarded to every block.
        num_classes: output size of the final linear layer.
        num_filters: width of layer4; the classifier takes
            ``num_filters * block.expansion`` input features.
        bn_partition: list consumed from the front with ``pop(0)``, one entry
            per block. A list supplied by the caller is mutated in place.
            ``None`` (the default) means a fresh ``[1] * 9``.
    """

    def __init__(self, block, num_blocks, conv_layer, bn_layer, num_classes=10, num_filters=512, bn_partition=None):
        super(ResnetBlockFour, self).__init__()

        # The default list must be created per call: pop(0) below drains it, and a
        # list literal in the signature would be shared by every instantiation.
        if bn_partition is None:
            bn_partition = [1]*9

        # Input width is hard-coded to 256 and is not scaled by `shrink`.
        self.in_planes = 256
        self.conv_layer = conv_layer
        self.bn_layer = bn_layer
        self.shrink = num_filters/512
        self.bn_partition = bn_partition

        self.layer4 = self._make_layer(block, num_filters, num_blocks[3], stride=2)
        self.linear = nn.Linear(num_filters*block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` blocks; only the first one uses ``stride``, the rest use 1."""
        # TODO: find better way to implement this method using inheritance and getting from ResNet class
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            # Each block consumes the next entry of the partition list.
            layers.append(block(self.in_planes, planes, self.conv_layer, self.bn_layer, stride, self.bn_partition.pop(0)))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)


    def forward(self, x):
        """Apply layer4, average-pool with a 4x4 window, flatten per sample and classify."""
        out = self.layer4(x)
        out = F.avg_pool2d(out, 4)
        # Flatten everything except the batch dimension for the linear layer.
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out

In [7]:
# Build one module per horizontal section of ResNet and load the checkpoint into each.

# The constructor inputs mirror what ./source/util/misc/get_model_from_code passes.
num_classes = configs['num_classes']
bn_layers = get_bn_layers('regular')  # batch-norm factory handed to every block
conv_layers = get_layers(configs['layer_type'])

# resnet18 uses BasicBlock with [2,2,2,2] blocks per stage (see resnet.py)
layer_1 = ResnetBlockOne(resnet.BasicBlock, [2,2,2,2], conv_layers, bn_layers, num_classes=num_classes)    # stem (conv1 + bn1) and layer1
layer_2 = ResnetBlockTwo(resnet.BasicBlock, [2,2,2,2], conv_layers, bn_layers, num_classes=num_classes)    # layer2
layer_3 = ResnetBlockThree(resnet.BasicBlock, [2,2,2,2], conv_layers, bn_layers, num_classes=num_classes)  # layer3
layer_4 = ResnetBlockFour(resnet.BasicBlock, [2,2,2,2], conv_layers, bn_layers, num_classes=num_classes)   # layer4 + linear

#print(layer_4)

# Read the checkpoint named in the config onto the configured device.
state_dict = torch.load(io.get_model_path("{}".format(configs["load_model"])), map_location=configs['device'])

# The checkpoint may wrap the weights under 'model_state_dict' or 'state_dict',
# or be the weight mapping itself; pick whichever is present, in that order.
if 'model_state_dict' in state_dict:
    checkpoint_weights = state_dict['model_state_dict']
elif 'state_dict' in state_dict:
    checkpoint_weights = state_dict['state_dict']
else:
    checkpoint_weights = state_dict

# Load the (full-model) weights into every section.
# NOTE(review): rebinding the loop variable does not change `split_model`; this only
# works if io.load_state_dict updates the module it is given — confirm.
split_model = [layer_1, layer_2, layer_3, layer_4]
for l in split_model:
    l = io.load_state_dict(l, checkpoint_weights)

not found:  layer2.0.conv1.weight
not found:  layer2.0.conv2.weight
not found:  layer2.0.bn1.weight
not found:  layer2.0.bn1.bias
not found:  layer2.0.bn1.running_mean
not found:  layer2.0.bn1.running_var
not found:  layer2.0.bn1.num_batches_tracked
not found:  layer2.0.bn2.weight
not found:  layer2.0.bn2.bias
not found:  layer2.0.bn2.running_mean
not found:  layer2.0.bn2.running_var
not found:  layer2.0.bn2.num_batches_tracked
not found:  layer2.0.shortcut.0.weight
not found:  layer2.0.shortcut.1.weight
not found:  layer2.0.shortcut.1.bias
not found:  layer2.0.shortcut.1.running_mean
not found:  layer2.0.shortcut.1.running_var
not found:  layer2.0.shortcut.1.num_batches_tracked
not found:  layer2.1.conv1.weight
not found:  layer2.1.conv2.weight
not found:  layer2.1.bn1.weight
not found:  layer2.1.bn1.bias
not found:  layer2.1.bn1.running_mean
not found:  layer2.1.bn1.running_var
not found:  layer2.1.bn1.num_batches_tracked
not found:  layer2.1.bn2.weight
not found:  layer2.1.bn2.bias


In [8]:
# Load the checkpoint into the full (unsplit) model.
# The weights may sit under 'model_state_dict', under 'state_dict', or be the mapping itself.
if 'model_state_dict' in state_dict:
    full_model_weights = state_dict['model_state_dict']
elif 'state_dict' in state_dict:
    full_model_weights = state_dict['state_dict']
else:
    full_model_weights = state_dict
model = io.load_state_dict(model, full_model_weights)


In [11]:
'''
    add partitions and communications to configs
'''

# Random test input with the shape the configured model expects.
input_var = engine.get_input_from_code(configs)
#print(input_var)

# Generate the partition (and prune ratio) for 4 partitions.
# The value is a string; a YAML path such as './config/resnet18-v2.yaml' was the alternative.
configs['num_partition'] = '4'
configs = engine.partition_generator(configs, model)

# Annotate the partition with per-layer feature-map sizes.
configs['partition'] = engine.featuremap_summary(model, configs['partition'], input_var)

# Derive communication costs from the partition.
configs['comm_costs'] = engine.set_communication_cost(model, configs['partition'],)


# General parameters for the split run.

# One list slot per machine. Every slot refers to the SAME model object (no copy is made).
num_machines = max(configs['partition']['bn_partition']) # TODO: double check this makes sense
model_machines = [model for _ in range(num_machines)]

# Names in model.named_modules() order; a module's index in this list is its "imodule".
module_names = [name for name, _ in model.named_modules()]
num_total_modules = len(module_names)

# Parameter names that have an entry in the partition.
split_module_names = list(configs['partition'].keys())

print(module_names)


Inference time per data is 32.937527ms.
conv1.weight 1024
layer1.0.conv1.weight 1024
layer1.0.conv2.weight 1024
layer1.1.conv1.weight 1024
layer1.1.conv2.weight 1024
layer2.0.conv1.weight 256
layer2.0.conv2.weight 256
layer2.0.shortcut.0.weight 256
layer2.1.conv1.weight 256
layer2.1.conv2.weight 256
layer3.0.conv1.weight 64
layer3.0.conv2.weight 64
layer3.0.shortcut.0.weight 64
layer3.1.conv1.weight 64
layer3.1.conv2.weight 64
layer4.0.conv1.weight 16
layer4.0.conv2.weight 16
layer4.0.shortcut.0.weight 16
layer4.1.conv1.weight 16
layer4.1.conv2.weight 16
['', 'conv1', 'bn1', 'layer1', 'layer1.0', 'layer1.0.conv1', 'layer1.0.conv2', 'layer1.0.bn1', 'layer1.0.bn2', 'layer1.0.shortcut', 'layer1.1', 'layer1.1.conv1', 'layer1.1.conv2', 'layer1.1.bn1', 'layer1.1.bn2', 'layer1.1.shortcut', 'layer2', 'layer2.0', 'layer2.0.conv1', 'layer2.0.conv2', 'layer2.0.bn1', 'layer2.0.bn2', 'layer2.0.shortcut', 'layer2.0.shortcut.0', 'layer2.0.shortcut.1', 'layer2.1', 'layer2.1.conv1', 'layer2.1.conv2', '

In [13]:
'''
    Setup datastrcuts to ID layers executed with 'extra' functionality
'''

# Indexes are positions in model.named_modules() (the same order as `module_names`).
# ReLU is applied to the output of each listed module.
# Every entry must be a valid index into `module_names`; values that run two indexes
# together (e.g. 2428 for 24, 28) silently disable ReLU for those modules.
# presumably 24 / 39 / 54 are the shortcut batch-norm layers — verify against module_names
relu_modules = [2, 7, 8, 13, 14, 20, 24, 28, 29, 35, 39, 43, 44, 50, 54, 58, 59]

# Indexes of every module whose name contains 'conv'.
split_module_indexes = [i for i, name in enumerate(module_names) if 'conv' in name]

print(split_module_indexes)


[1, 5, 6, 11, 12, 18, 19, 26, 27, 33, 34, 41, 42, 48, 49, 56, 57]


In [14]:
def get_nonzero_channels(atensor, dim=1):
    """Return the sorted, de-duplicated indices along `dim` where `atensor` has a nonzero entry."""
    # One index tensor per dimension; keep only the requested dimension.
    per_dim_indices = atensor.nonzero(as_tuple=True)
    return per_dim_indices[dim].unique()


In [15]:
def compare_tensors(t1, t2, dim=1, rshape=(1,64,-1)):
    """Return the maximum of |t1 - t2| along `dim`, after reshaping the difference to `rshape`."""
    reshaped_abs_diff = (t1 - t2).abs().reshape(rshape)
    # max over a dimension yields (values, indices); only the values are returned.
    largest_values, _ = reshaped_abs_diff.max(dim)
    return largest_values
    

In [20]:
''' Prep Inout for Mock Run'''

# TODO: reduce size of communicated tensors to only what is necessary
# TODO: also check bias for nonzero
# TODO: come up with more general scheme to handle residual layers

# Terminology used by the partition dictionary:
#   channel_id == INPUTS
#   filter_id  == OUTPUTS

# Random input: N_batch images, 3 channels, 32x32 pixels.
N_batch = 1
input_tensor = torch.rand(N_batch, 3, 32, 32, device=torch.device(configs['device']))

# `input` is a num_machines x num_machines grid of None (note: shadows the builtin).
# Every machine gets the same input tensor object in column 0.
# TODO: find a better datastructure for this
#input = np.empty((num_machines, num_machines), dtype=torch.Tensor)
input = [[None]*num_machines for _ in range(num_machines)]
for imach in range(num_machines):
    input[imach][0] = input_tensor

# Lookup tables provided by project helpers for this model.
relu_layers,avg_pool_layers = get_functional_layers(configs['model'])
residual_block_start, residual_connection_start, residual_block_end = get_residual_block_indexes(model)

# Inference mode, on the configured device.
model.eval()
model.to(configs['device'])

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (shortcut): Sequential()
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=

In [21]:
'''
    mock run through inference using split models 
'''

# Overview: walk model.named_modules() in order. For every module, each machine
#   1. sums the tensors it received from all machines,
#   2. builds a reduced ("split") copy of the module restricted to its own input channels,
#   3. evaluates it, and
#   4. routes the nonzero output channels to the machines that own them.
# input[rx][tx] / output[rx][tx] hold the tensor sent from machine tx to machine rx (None = nothing sent).

residual_input = {} # per-machine tensors held back for residual connections, keyed by str(machine index)
add_bias = False # add bias for previous conv layer 

# make inference 
with torch.no_grad():
        # iterate through layers 1 module at a time 
        for imodule in range(num_total_modules): # 16 <=> layer_1 block 

                # index 0 of named_modules() is the root module itself (name ''), not a layer
                if imodule in [0]:
                        continue

                # initialize output for ilayer
                #output = np.empty((num_machines, num_machines), dtype=torch.Tensor) # square list indexed as: output[destination/RX machine][origin/TX machine]
                # TODO: find a better datastructure for this 
                output = [None]*num_machines
                output = [output[:] for i in range(num_machines)]
                
                # Set to False when a machine skips the module because its type is not handled;
                # `input` is then carried over unchanged to the next module.
                send_module_outputs = True

                print(f'Executing module {imodule}: {module_names[imodule]}')

                # iterate through each machine (done in parallel later)
                for imach in range(num_machines):
                        print(f'\tExecuting on machine {imach}')
                        
                        add_residual = False

                        # combine inputs from machines
                        # The sum is built out of place: `input` is reused untouched whenever a
                        # module is skipped, so mutating its tensors here would re-add the other
                        # machines' contributions on the next pass.
                        curr_input = False 
                        rx_count = 0
                        for i in range(num_machines):
                                if input[imach][i] is not None:
                                        if not torch.is_tensor(curr_input):
                                                curr_input = input[imach][i] # initialize curr_input with first input tensor 
                                        else:
                                                curr_input = curr_input + input[imach][i]
                                        rx_count += 1
                        if add_bias:
                                # TODO: check if this works (this is not required for resnet18 because no bias on conv layers)
                                # NOTE(review): a 1-D bias broadcasts against the LAST dimension of a 4-D
                                # tensor; it likely needs reshaping to (1, C, 1, 1) — confirm before enabling
                                dummy, prev_module = next((x for i,x in enumerate(model.named_modules()) if i==imodule-1))
                                bias = prev_module.bias 
                                curr_input = curr_input + bias/rx_count

                        # skip this machine+module if there is no input to compute 
                        if not torch.is_tensor(curr_input):
                                print('\t\t-No input sent to this machine. Skipping module')
                                continue

                        # debug
                        print(f'\t\t received input channels {get_nonzero_channels(curr_input)}')

                        if imodule in residual_block_start:
                                # save input for later 
                                residual_input[str(imach)] = {}
                                residual_input[str(imach)]['block_in'] = curr_input
                                print('\t\t-Saving input for later...')
                        elif imodule in residual_connection_start:
                                # swap tensors: keep the main-path result, continue with the saved block input
                                residual_input[str(imach)]['block_out'] = curr_input
                                curr_input = residual_input[str(imach)]['block_in'] 
                                print('\t\t-Saving current input. Swapping for input saved from start of block')

                        # get the current module
                        # TODO: is this very bad for latency? Only load module if you have to 
                        curr_name, curr_module = next((x for i,x in enumerate(model.named_modules()) if i==imodule)) 

                        # update communication I/O for this layer  
                        # TODO: revisit this implementation
                        # NOTE(review): if none of the branches below match (e.g. a conv whose weight is
                        # not in the partition), input_channels / output_channel_map / N_in keep the
                        # values from a previous module — confirm that cannot happen for this model
                        split_param_name = curr_name + '.weight'
                        if split_param_name in split_module_names:

                                # skip if machine doesnt expect input
                                if len(configs['partition'][split_param_name]['channel_id'][imach]) == 0:
                                        print(f'\t\t-No input assigned to this machine. Skipping...')
                                        continue
                                
                                # TODO: reconsider implementation 
                                # What input channels does this machine compute?
                                input_channels = torch.tensor(configs['partition'][split_param_name]['channel_id'][imach],
                                        device=torch.device(configs['device']))
                                N_in = len(input_channels) # number of input channels of the split layer built below

                                # Where to send output (map of output channels to different machines)
                                output_channel_map = configs['partition'][split_param_name]['filter_id']
                        elif type(curr_module) == nn.BatchNorm2d:
                                # TODO: address the following assumptions:
                                #       - assume all BN layers have C_in divisable by num_machines
                                #       - assume C_in are evenly split in sequential order WARNING THIS WILL BREAK WHEN WE START TO DO ASSIGN WEIGHTS TO DIFF MACHINES
                                N_Cin = curr_module.num_features
                                Cin_per_machine = N_Cin/num_machines
                                if Cin_per_machine % 1 > 0:
                                        print(f'ERROR: UNEXPECTED NUMBER OF I/O FOR BATCH NORMAL MODULE {imodule}')
                                Cin_per_machine = int(Cin_per_machine)
                                input_channels = np.arange(Cin_per_machine) + imach*Cin_per_machine
                                # size the split layer from this module, not from the last conv that ran
                                N_in = Cin_per_machine
                                # batch norm is per channel: each machine keeps its own channels
                                output_channel_map = [None]*num_machines
                                for i in range(num_machines):
                                        if i == imach:
                                                output_channel_map[i] = input_channels
                                        else:
                                                output_channel_map[i] = np.array([])
                                input_channels = torch.tensor(input_channels, device=torch.device(configs['device']))
                        elif type(curr_module) == nn.Linear and imodule == num_total_modules-1:
                                # if final layer output all goes to machine 0 
                                # TODO: find better way to handle this. Also will we encounter Linear layers not at the end of the model
                                N_Cin = curr_module.in_features
                                Cin_per_machine = N_Cin/num_machines
                                if Cin_per_machine % 1 > 0:
                                        print(f'ERROR: UNEXPECTED NUMBER OF I/O FOR LINEAR MODULE {imodule}')
                                Cin_per_machine = int(Cin_per_machine)
                                input_channels = np.arange(Cin_per_machine) + imach*Cin_per_machine
                                # size the split layer from this module, not from the last conv that ran
                                N_in = Cin_per_machine
                                N_Cout = curr_module.out_features 
                                output_channel_map = [None]*num_machines
                                for i in range(num_machines):
                                        if i == 0:
                                                output_channel_map[i] = np.arange(N_Cout) 
                                        else:
                                                output_channel_map[i] = np.array([])
                                input_channels = torch.tensor(input_channels, device=torch.device(configs['device']))


                        # reduce computation-- make vertically split layer 
                        # TODO: generalize this to more than conv layers 
                        if type(curr_module) == nn.Conv2d:
                                split_layer = nn.Conv2d(N_in,
                                                curr_module.weight.shape[0], # TODO does this need to be an int? (currently tensor)
                                                kernel_size= curr_module.kernel_size,
                                                stride=curr_module.stride,
                                                padding=curr_module.padding, 
                                                bias=False) # TODO: add bias during input collecting step on next layer 

                                # write parameters to split layer (keep only this machine's input channels)
                                split_layer.weight = torch.nn.Parameter(curr_module.weight.index_select(1, input_channels))

                                # TODO: add support for splitting bias

                        elif type(curr_module) == nn.BatchNorm2d:
                                split_layer = nn.BatchNorm2d(N_in, 
                                                curr_module.eps,
                                                momentum=curr_module.momentum, 
                                                affine=curr_module.affine, 
                                                track_running_stats=curr_module.track_running_stats)

                                # write parameters to split layer 
                                split_layer.weight = torch.nn.Parameter(curr_module.weight.index_select(0, input_channels))
                                split_layer.running_mean = torch.nn.Parameter(curr_module.running_mean.index_select(0, input_channels))
                                split_layer.running_var = torch.nn.Parameter(curr_module.running_var.index_select(0, input_channels))

                                if curr_module.bias is not None:
                                        split_layer.bias = torch.nn.Parameter(curr_module.bias.index_select(0, input_channels))

                                # TODO: revise implementation to only compute necessary C_in to C_out 
                                # assume mach-Cout map from previous conv layer can be used as inputs for this bn layer
                                #input_channels = output_channel_map[imach]

                        elif type(curr_module) == nn.Linear:
                                # TODO: assumes there is a bias 
                                split_layer = nn.Linear(N_in, 
                                                curr_module.weight.shape[0])

                                # write parameters to split layer 
                                split_layer.weight = torch.nn.Parameter(curr_module.weight.index_select(1, input_channels))

                                # TODO: double check bias is applied correctly
                                if curr_module.bias is not None:
                                        split_layer.bias = curr_module.bias

                                # prep for linear layer: 4x4 average pool, then flatten per sample
                                # TODO: assumes this always happens before linear layer 
                                curr_input = F.avg_pool2d(curr_input, 4)
                                curr_input = curr_input.view(curr_input.size(0), -1)

                        else:
                                # containers and other module types are not executed directly
                                print(f'\t\t-Skipping module {type(curr_module).__name__}')
                                send_module_outputs = False
                                continue
                        
                        # make sure layer is in eval mode
                        # TODO: if you set model.eval() can we skip this, also only required for bn layers? Maybe 
                        split_layer.eval()

                        # eval split
                        out_tensor = split_layer(curr_input.index_select(1, input_channels))
                        if type(curr_module) == nn.BatchNorm2d:
                                # scatter the reduced result back into a full-width tensor
                                # (same shape, dtype and device as the input)
                                tmp_out_tensor = torch.zeros_like(curr_input)
                                tmp_out_tensor[:,input_channels.numpy(),:,:] = out_tensor
                                out_tensor = tmp_out_tensor


                        print(f'\t\t Output tensor shape : {out_tensor.shape}')

                        # debug
                        nonzero_out_tensor = torch.unique(torch.nonzero(out_tensor, as_tuple=True)[1])


                        # check if this is residual layer
                        # .get(): a machine that never stored anything for this block has no entry
                        if imodule in residual_block_end and 'block_out' in residual_input.get(str(imach), {}): # TODO: does this conditional make sense?
                                print('\t\t-adding residual')
                                out_tensor += residual_input[str(imach)]['block_out']

                                # erase stored 
                                residual_input[str(imach)] = {}

                        # apply ReLU after batch layers
                        if imodule in relu_modules:
                                print('\t\t-Applying ReLU')
                                out_tensor = F.relu(out_tensor)

                        # look at which C_out need to be computed and sent
                        #nonzero_Cout = torch.unique(torch.nonzero(split_layer.weight, as_tuple=True)[0]) # find nonzero dimensions in output channels
                        nonzero_Cout = get_nonzero_channels(out_tensor)

                        # communicate
                        out_channel_array = torch.arange(out_tensor.shape[1])
                        for rx_mach in range(num_machines):
                                # only add to output if communication is necessary 

                                # Get output channels for current rx machine? TODO: consider removing, this just maps C_out's to machine
                                output_channels = torch.tensor(output_channel_map[rx_mach],
                                        device=torch.device(configs['device']))

                                # TODO: is there a faster way to do this? Consider putting larger array 1st... just not sure which one that'd be
                                nonzero_out_channels = nonzero_Cout[torch.isin(nonzero_Cout, output_channels)]
                                if nonzero_out_channels.nelement() > 0:
                                        communication_mask = torch.isin(out_channel_array, nonzero_out_channels)

                                        # TODO: this is inefficient, redo. Probably need to send a tensor and some info what output channels are being sent
                                        # full-size tensor that is zero except for the channels being sent
                                        tmp_out = torch.zeros_like(out_tensor)
                                        if imodule == num_total_modules-1:
                                                # last module: output is indexed as 2-D (batch, features)
                                                tmp_out[:,communication_mask] = out_tensor[:,communication_mask]
                                        else:
                                                tmp_out[:,communication_mask,:,:] = out_tensor[:,communication_mask,:,:]
                                        output[rx_mach][imach] = tmp_out

                                        # debug
                                        print(f'\t\t sending C_out {nonzero_out_channels} to machine {rx_mach}')

                # send to next layer  
                if send_module_outputs:      
                        input = output
                print(f'Finished execution of layer {imodule}')
                print()

# collect outputs -- assumes ends with Linear layer. Not sure how generalizable this is
# if loop stops on module that doesnt calculate anything use input struct 
if send_module_outputs:
        tmp_output = output
else:
        tmp_output = input 

# Sum every tensor in the rx/tx grid. `final_output` stays None when the grid is empty,
# so later cells see an explicit "nothing computed" value rather than an unbound name.
final_output = None
need_to_init  = True
for rx_mach in range(num_machines):
        for tx_mach in range(num_machines):
                # identity test: `== None` on a tensor goes through tensor comparison
                if tmp_output[rx_mach][tx_mach] is not None:
                        if need_to_init:
                                final_output = tmp_output[rx_mach][tx_mach]
                                need_to_init = False
                        else:
                                # out-of-place add on purpose: `+=` would write into the tensor still
                                # stored in tmp_output (the first one picked up above)
                                final_output = final_output + tmp_output[rx_mach][tx_mach] 
                                nz_channels = get_nonzero_channels(final_output)
                                #print(f'({rx_mach},{tx_mach}) {nz_channels}')

if final_output is None:
        print('WARNING: no machine produced an output; final_output is None')

print()

Executing module 1: conv1
	Executing on machine 0
		 received input channels tensor([0, 1, 2])
		-No input assigned to this machine. Skipping...
	Executing on machine 1
		 received input channels tensor([0, 1, 2])
		 Output tensor shape : torch.Size([1, 64, 32, 32])
		 sending C_out tensor([16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 31]) to machine 1
		 sending C_out tensor([56, 62]) to machine 3
	Executing on machine 2
		 received input channels tensor([0, 1, 2])
		 Output tensor shape : torch.Size([1, 64, 32, 32])
		 sending C_out tensor([ 0, 12]) to machine 0
		 sending C_out tensor([38, 42]) to machine 2
		 sending C_out tensor([54, 57]) to machine 3
	Executing on machine 3
		 received input channels tensor([0, 1, 2])
		 Output tensor shape : torch.Size([1, 64, 32, 32])
		 sending C_out tensor([ 4,  8, 14]) to machine 0
		 sending C_out tensor([16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 31]) to machine 1
		 sending C_out tensor([36, 42, 45, 46, 47]) to mac

In [18]:
'''
    Compare I/O for one layer 
'''

def compare_IO(model, module_list, input, split_output):
    ''' This function is used for debugging. It pushes `input` through the
    selected modules of `model`, one after another, and compares the result
    to the split output.

    Input:
        model - object exposing named_modules(); only consulted when
                module_list is non-empty
        module_list - positions (in model.named_modules() order) of the modules
                to apply, in the order they should be applied
        input - value fed to the first selected module
        split_output - tensor the computed output is compared against

    Output:
        (io_match, full_output) - io_match is torch.all(torch.eq(...)) of the
                two outputs (exact, element-wise equality); full_output is the
                value produced by the last selected module (or `input` itself
                when module_list is empty)

    Raises:
        IndexError - if a position in module_list does not exist in the model

    Side effects: every selected module is switched to eval mode, and its name
    is printed. F.relu is applied after a module whenever its position is in
    the notebook-level `relu_modules` collection.
    '''
    # Enumerate the model once (lazily) instead of re-walking it per position.
    modules_by_position = None
    activation = input

    for imodule in module_list:
        if modules_by_position is None:
            modules_by_position = dict(enumerate(model.named_modules()))
        if imodule not in modules_by_position:
            raise IndexError(
                f'module position {imodule} is out of range: '
                f'model has {len(modules_by_position)} named modules')
        curr_name, curr_module = modules_by_position[imodule]
        print(curr_name)

        curr_module.eval()

        activation = curr_module(activation)
        if imodule in relu_modules:
            print('Apllying ReLU')
            activation = F.relu(activation)
    full_output = activation

    io_match = torch.all(torch.eq(split_output, full_output))

    return (io_match, full_output)

In [19]:
''' Single Layer Test'''
# Run the listed module positions of the full model on input_tensor and compare
# the result against final_output from the split execution.
io_match, full_output = compare_IO(model, [1,2,5,6,7,8, 9,11,12, 13,14], input_tensor, final_output) #

# module 9 sequential seems to be messing stuff up


#print(io_match)
diff_output = torch.abs(full_output - final_output)

# largest absolute difference for each batch element
print(torch.max(torch.reshape(diff_output, (N_batch, -1)), dim=1)[0])
#plt.hist(diff_output.reshape((-1,)))
#plt.show()

# Largest absolute difference for each (sample, output channel). The leading
# two sizes are read from the tensor instead of being hard-coded to (1, 64);
# diff_output is already non-negative, so no second abs() is needed.
n_samples, n_cout = diff_output.shape[0], diff_output.shape[1]
max_by_Cout = torch.max(diff_output.reshape((n_samples, n_cout, -1)), dim=2)

print()
print(max_by_Cout[0])
# channels reported by get_nonzero_channels for the per-channel max difference
diff_Cout = get_nonzero_channels(max_by_Cout[0])
print(diff_Cout)


# get C_out with zero and non-zero diff
nonzero_Cout = get_nonzero_channels(full_output)
is_failing = torch.isin(nonzero_Cout, diff_Cout)
failing_Cout = nonzero_Cout[is_failing]
passing_Cout = nonzero_Cout[~is_failing]
print()
print(f'failing Cout = {failing_Cout}  (len = {len(failing_Cout)})')
print(f'passing Cout = {passing_Cout}  (len = {len(passing_Cout)})')

NameError: name 'final_output' is not defined

In [None]:
''' 
    Check Model output 1 horizontal block at a time 
'''

# imodule = 16

# Show the block, then evaluate it on the same input the split run used.
print(layer_1)
layer_1_output = layer_1(input_tensor)

# Element-wise absolute deviation from the split result.
diff_output = (layer_1_output - final_output).abs()

# Flatten everything but the batch axis and report the per-sample maximum
# (printed as the full (values, indices) result of max over dim 1).
per_sample_diff = diff_output.reshape(N_batch, -1)
print(per_sample_diff.max(dim=1))



ResnetBlockOne(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (shortcut): Sequential()
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1,

In [None]:
'''
    Test entire model : vertically split vs full
'''

# Reference forward pass of the unsplit model, without tracking gradients.
with torch.no_grad():
    output_full = model(input_tensor)

# TODO: finish
# Count the samples whose arg-max over dim 1 agrees between split and full runs.
split_labels = final_output.argmax(dim=1)
full_labels = output_full.argmax(dim=1)
match_count = torch.eq(split_labels, full_labels).sum().item()

# NOTE(review): this is the sum of the model outputs over dim 0, not a count of
# predicted labels, even though it is printed as a "histogram" -- confirm intent.
label_hist = torch.sum(output_full, dim=0)

n_total = output_full.shape[0]
print(f'Matches: {match_count}/{n_total}')
print(f'histogram {label_hist}')

Matches: 0/1
histogram tensor([ 0.3117, -0.8496,  0.7471, -0.8106,  0.5718, -1.5329,  3.0551, -1.4684,
        -0.2197, -0.9103])


In [None]:
def combine_inputs(num_machines, input, imach=None):
    '''
        combines (sums) the input tensors destined for one machine

        Input:
            num_machines - number of total machines
            input - num_machines x num_machines list with inputs from previous layers collected from each machine. Indexed [destination][origin]
            imach - index of the destination machine whose inputs are combined.
                    Optional only for backward compatibility: when omitted, the
                    notebook-level variable `imach` is used, as before.

        Output:
            curr_input - a single tensor for this layer from the combined inputs,
                         or False when no origin supplied a tensor

        Raises:
            NameError - if imach is omitted and no notebook-level `imach` exists
    '''
    if imach is None:
        # Legacy call style: the destination index used to be read from hidden
        # notebook state. Prefer passing imach explicitly.
        if 'imach' not in globals():
            raise NameError("name 'imach' is not defined; pass imach=<destination machine> explicitly")
        imach = globals()['imach']

    curr_input = False
    for origin in range(num_machines):
        contribution = input[imach][origin]
        # Anything that is not a tensor ([] or None placeholders) means
        # "nothing was sent from this origin".
        if not torch.is_tensor(contribution):
            continue
        if not torch.is_tensor(curr_input):
            curr_input = contribution  # initialize curr_input with first input tensor
        else:
            # Out-of-place add: curr_input may alias an entry of `input`, and
            # `+=` would silently modify the caller's stored tensor.
            curr_input = curr_input + contribution

    return curr_input

In [None]:
'''
    Conv1 layer test
'''

# DIFFERENCE SHOULD BE 0 NOT 1E-7
# NOTE(review): the split path adds three single-channel convolutions while the
# full path accumulates all input channels inside one convolution. A ~1e-7 gap
# is presumably float rounding from the different accumulation order -- confirm.

N_in = 1

def _conv1_single_channel(channel):
    '''Build a 1-input-channel copy of model.conv1 restricted to `channel` and
    apply it to the same channel of input_tensor. Returns (module, output).'''
    select = torch.tensor([channel])
    split = nn.Conv2d(N_in,
                model.conv1.weight.shape[0], # shape[0] is already a plain int, no cast needed
                kernel_size=model.conv1.kernel_size,
                stride=model.conv1.stride,
                padding=model.conv1.padding,
                bias=False) # TODO: add bias during input collecting step on next layer
    split.weight = torch.nn.Parameter(model.conv1.weight.index_select(1, select))
    return split, split(input_tensor.index_select(1, select))

# One independent module per input channel. Previously split_2 and split_3 were
# aliases of split_1, so all three names ended up holding the channel-2 weights.
split_1, out_split1 = _conv1_single_channel(0)
split_2, out_split2 = _conv1_single_channel(1)
split_3, out_split3 = _conv1_single_channel(2)

split_out = torch.add(torch.add(out_split1, out_split2), out_split3)
full_out = model.conv1(input_tensor)

diff_output = torch.abs(full_out - split_out)
max_diff = torch.max(diff_output)
# Setting a `sci_mode` attribute on a tensor does not affect printing; format
# the scalar in scientific notation explicitly instead.
print(f'{max_diff.item():e}')

In [None]:
def get_split_out_from_module(input_tensor, model, module_select, configs, num_machines, module_names):
    '''
        Splits one conv layer along its input channels, one piece per machine,
        and evaluates every piece on its share of input_tensor.

        assumes module is conv layer

        Input:
            input_tensor - input to the layer; channels are selected along dim 1
            model - model whose named_modules() contains the layer
            module_select - position of the layer in model.named_modules()
                            (also used to index module_names)
            configs - configuration dict; the channel assignment is read from
                      configs['partition'][<module name> + '.weight']['channel_id']
            num_machines - number of total machines
            module_names - module names, indexable by module_select

        Output:
            (split_modules, split_outputs) - two lists of length num_machines.
            Entry i holds the bias-free nn.Conv2d built for machine i and its
            output. Both entries stay None for a machine that has no input
            channels assigned.

        Raises:
            IndexError - if module_select is not a valid position in the model
    '''

    input_channel_map = configs['partition'][module_names[module_select] + '.weight']['channel_id']

    modules_by_position = dict(enumerate(model.named_modules()))
    if module_select not in modules_by_position:
        raise IndexError(
            f'module position {module_select} is out of range: '
            f'model has {len(modules_by_position)} named modules')
    module = modules_by_position[module_select][1]

    split_modules = [None]*num_machines
    split_outputs = [None]*num_machines
    for imachine in range(num_machines):

        # index_select needs an integer index; forcing long also keeps an empty
        # assignment from turning into a float tensor.
        input_channels = torch.as_tensor(input_channel_map[imachine], dtype=torch.long)
        N_in = input_channels.numel()
        if N_in == 0:
            # No input channels assigned to this machine: nothing to compute.
            continue

        split_modules[imachine] = nn.Conv2d(N_in,
                    module.weight.shape[0], # shape[0] is already a plain int, no cast needed
                    kernel_size=module.kernel_size,
                    stride=module.stride,
                    padding=module.padding,
                    bias=False) # TODO: add bias during input collecting step on next layer
        # Replace the freshly initialized weights with the matching slice of
        # the original layer's weights.
        split_modules[imachine].weight = torch.nn.Parameter(module.weight.index_select(1, input_channels))
        split_outputs[imachine] = split_modules[imachine](input_tensor.index_select(1, input_channels))

    return (split_modules, split_outputs)