In [1]:
import time
from tqdm import tqdm
import numpy as np
import scipy as sp
import torch
import torch.nn.functional as F
from torch import nn
from torchvision import models, transforms

from networkAlignmentAnalysis import utils
from networkAlignmentAnalysis import datasets

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print('using device: ', DEVICE)

using device:  cuda


In [2]:
# The LAYER_REGISTRY contains meta parameters for each type of layer used in alignment networks
# each layer type is associated with a few features, including:
# name (string): just for completeness, will only be used for plotting
# layer-handle (lambda method): takes as input a registered layer and returns the part of that layer
#                               to perform alignment methods on. For example, if a registered layer
#                               is layer=nn.Sequential(nn.Linear(10,10), nn.Dropout()), then the layer
#                               handle should be: lambda layer: layer[0])
# alignment_method (callable): the method used to measure alignment for a particular layer
# Note: as of writing this, I only have nn.Linear and nn.Conv2d here, but this will start to be more
# useful and meaningful when reusing typical combinations of layers as a single registered "layer" 
# that include things like dropout, pooling, nonlinearities, etc.
REGISTRY_REQUIREMENTS = ['name', 'layer_handle', 'alignment_method', 'ignore']
LAYER_REGISTRY = {
    nn.Linear: {
        'name': 'linear', 
        'layer_handle': lambda layer:layer, 
        'alignment_method': utils.alignment_linear,
        'ignore': False,
        },

    nn.Conv2d: {
        'name': 'conv2d', 
        'layer_handle': lambda layer:layer, 
        'alignment_method': utils.alignment_convolutional,
        'ignore': False,
        },
}

def default_metaprms_ignore(name):
    """convenience method for named metaparameters to be ignored"""
    metaparameters = {
        'name': name,
        'layer_handle': None,
        'alignment_method': None,
        'ignore': True
    }
    return metaparameters

def default_metaprms_linear(index, name='linear'):
    """convenience method for named metaparameters in a linear layer packaged in a sequential"""
    metaparameters = {
        'name': name,
        'layer_handle': lambda layer: layer[index],
        'alignment_method': utils.alignment_linear,
        'ignore': False,
    }
    return metaparameters

def default_metaprms_conv2d(index, name='conv2d'):
    """convenience method for named metaparameters in a conv2d layer packaged in a sequential"""
    metaparameters = {
        'name': name,
        'layer_handle': lambda layer: layer[index],
        'alignment_method': utils.alignment_convolutional,
        'ignore': False,
    }
    return metaparameters

def check_metaparameters(metaparameters, throw=True):
    """validate whether metaparameters is a dictionary containing the required keys for an alignment network"""
    if not all([required in metaparameters for required in REGISTRY_REQUIREMENTS]):
        if throw:
            raise ValueError(f"metaparameters are missing required keys, it requires all of the following: {REGISTRY_REQUIREMENTS}")
        return False
    return True

def check_registry():
    for layer_type, metaparameters in LAYER_REGISTRY:
        if not check_metaparameters(metaparameters, throw=False):
            raise ValueError(f"Layer type: {layer_type} from the `LAYER_REGISTRY` is missing metaparameters. "
                            "It requires all of the following: {REGISTRY_REQUIREMENTS}")

class AlignmentNetwork(nn.Module):
    """
    This is the base class for a neural network used for alignment-related experiments. 

    The point of all the wrangling of standard torch workflows in this class is to make 
    it easy to perform all the alignment-related computations for networks with different
    architectures without having to rewrite similar code over and over again. In this way,
    the user only needs to add a layer type to the **LAYER_REGISTRY** in this file and
    then alignment methods can be automatically applied. 

    The forward method of **AlignmentNetwork** passes the input (*x*) through each registered
    layer of the network in order of it's registration. If hidden activations are requested, then
    the output of each registered layer is saved. The alignment methods are applied to the 
    hidden activation at the output of layer L-1 and the weights of layer L. 

    Note: some shape wrangling (like that which happens between a convolutional layer and a
    linear layer are often treated as a nn.Module layer), but these don't require alignment-
    related processing. To use these, set 'ignore' of the metaparameters to True. Alternatively,
    you can append them to the last component of a layer.

    A layer in the layer_registry should have the following properties:
    1. Be a child of the nn.Module class with a forward method
    2. Have at most one "relevant" processing stage with weights for measuring alignment
    """
    def __init__(self):
        super().__init__() # register it as a nn.Module
        self.layers = nn.ModuleList() # a list of all modules in the forward pass
        self.metaparameters = [] # list of dictionaries containing metaparameters for each layer
        self.hidden = [] # list of tensors containing hidden activations

    def register_layer(self, layer, verbose=True, **kwargs):
        """
        register_layer adds a **layer** to the network's module list and it's associated metaparameters
        for determining what kind of aligment-related processing is done on the layer

        by default, the layer is used as a key to lookup the metaparameters from the **LAYER_REGISTRY**. 
        kwargs can update keys in the metaparameters. If the layer class is not registered, then all 
        metaparameters must be provided as kwargs.
         
        Required kwargs are: name, layer_handle, alignment_method, ignore, ...
        """
        if not isinstance(layer, nn.Module):
            raise TypeError(f"provided layer is of type: {type(layer)}, but only nn.Module objects are permitted!")
        
        metaparameters = LAYER_REGISTRY.get(type(layer), {})
        for metaprms in REGISTRY_REQUIREMENTS:
            # for each possible entry in layer metaparameters, check if it's provided, not none, then update it
            if metaprms in kwargs and kwargs[metaprms] is not None:
                metaparameters[metaprms]=kwargs[metaprms]
        
        # check whether metaparameters contain the correct keys
        check_metaparameters(metaparameters, throw=True)
        
        # add layer to network
        self.layers.append(layer)
        self.metaparameters.append(metaparameters)

        if verbose:
            print(f"Added a {metaparameters['name']} layer to the network")


    def forward(self, x, store_hidden=False):
        """standard forward pass of all layers with option of storing hidden activations (and output)"""
        self.hidden = [] # always reset so as to not keep a previous forward pass accidentally
        for layer, metaprms in zip(self.layers, self.metaparameters):
            x = layer(x) # pass through next layer
            if store_hidden and not metaprms['ignore']: 
                self.hidden.append(x)
        return x
    
    @torch.no_grad()
    def get_activations(self, x=None, precomputed=False):
        """convenience method for getting list of intermediate activations throughout the network"""
        if not precomputed and x is not None:
            _ = self.forward(x, store_hidden=True)
        else:
            raise ValueError("x needs to be provided if precomputed is False")
        return self.hidden
    
    @torch.no_grad()
    def get_alignment_layers(self):
        """convenience method for retrieving registered layers for alignment measurements throughout the network"""
        layers = []
        for layer, metaprms in zip(self.layers, self.metaparameters):
            if not metaprms['ignore']:
                layers.append(metaprms['layer_handle'](layer))
        return layers
    
    @torch.no_grad()
    def get_alignment_metaparameters(self):
        """convenience method for retrieving registered layers for alignment measurements throughout the network"""
        metaparameters = []
        for metaprms in self.metaparameters:
            if not metaprms['ignore']:
                metaparameters.append(metaprms)
        return metaparameters
    
    @torch.no_grad()
    def get_alignment_weights(self):
        """convenience method for retrieving registered weights for alignment measurements throughout the network"""
        return [layer.weight for layer in self.get_alignment_layers()]
    
    @torch.no_grad()
    def compare_weights(self, weights):
        current_weights = self.get_alignment_weights()
        delta_weights = []
        for iw, cw in zip(weights, current_weights):
            delta_weights.append(torch.norm(cw.flatten(1), iw.flatten(1), dim=1))
        return delta_weights
    
    @torch.no_grad()
    def measure_alignment(self, x, precomputed=False, method='alignment'):
        activations = self.get_activations(x=x, precomputed=precomputed)
        alignment = []
        for activation, layer, metaprms in zip(activation, self.get_alignment_layers(), self.get_alignment_metaparameters()):
            alignment.append(metaprms['alignment_method'](activation, layer, method=method))
        return alignment


class MLP(AlignmentNetwork):
    def __init__(self, verbose=True):
        super().__init__()

        layer1 = nn.Sequential(nn.Linear(784, 100), nn.ReLU())
        layer2 = nn.Sequential(nn.Dropout(), nn.Linear(100, 100), nn.ReLU())
        layer3 = nn.Sequential(nn.Dropout(), nn.Linear(100, 50), nn.ReLU())
        layer4 = nn.Sequential(nn.Dropout(), nn.Linear(50, 10))

        self.register_layer(layer1, **default_metaprms_linear(0), verbose=verbose)
        self.register_layer(layer2, **default_metaprms_linear(1), verbose=verbose)
        self.register_layer(layer3, **default_metaprms_linear(1), verbose=verbose)
        self.register_layer(layer4, **default_metaprms_linear(1), verbose=verbose)
    
    
    

In [3]:
net = MLP()

Added a linear layer to the network
Added a linear layer to the network
Added a linear layer to the network
Added a linear layer to the network


In [6]:
preprocess = transforms.Compose([
            transforms.ToTensor(), # first, convert image to PyTorch tensor
            transforms.Normalize((0.1307,), (0.3081,)), # normalize inputs
            transforms.Lambda(torch.flatten), # convert to vectors
        ])
trainloader, testloader = datasets.downloadMNIST(preprocess=preprocess)

In [8]:
net.to(DEVICE)
    
# Prepare Training Functions 
loss_function = nn.CrossEntropyLoss() # Note: this automatically applies softmax...
optimizer = torch.optim.SGD(net.parameters(), lr=1e-3)
# optimizer = torch.optim.Adadelta(net.parameters())

# Preallocate summary variables  
iterations = 100
numTrainingSteps = len(trainloader)*iterations
trackLoss = torch.zeros(numTrainingSteps)
trackAccuracy = torch.zeros(numTrainingSteps)
alignFull = []
deltaWeights = []

init_weights = net.get_alignment_weights()

# Train Network & Measure Integration
t = time.time()
for epoch in tqdm(range(0, iterations)): 
    # Set current loss value
    currentLoss = 0.0
    numBatches = 0
    currentCorrect = 0
    currentAttempted = 0

    for idx,batch in enumerate(trainloader):
        cidx = epoch*len(trainloader) + idx
        
        images, label = batch
        images = images.to(DEVICE)
        label = label.to(DEVICE)

        # Zero the gradients
        optimizer.zero_grad()

        # Perform forward pass
        outputs = net(images)

        # Perform backward pass & optimization
        loss = loss_function(outputs, label)
        loss.backward()
        optimizer.step()
        
        # Measure Integration
        alignFull.append(net.measure_alignment(images, precomputed=True, method='alignment'))
        
        # Measure Change in Weights
        deltaWeights.append(net.compare_weights(init_weights))

        # Track Loss and Accuracy
        trackLoss[cidx] = loss.item()
        trackAccuracy[cidx] = 100*torch.sum(torch.argmax(outputs,axis=1)==label)/images.shape[0]

    # Print statistics for each epoch
    print('Loss in epoch %3d: %.3f, Accuracy: %.2f%%.' % (epoch, loss.item(), 100*torch.sum(torch.argmax(outputs,axis=1)==label)/images.shape[0]))


  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:22<?, ?it/s]


ValueError: x needs to be provided if precomputed is False

In [8]:
# the BaseNetwork needs alignment and other methods
# keep working on MLP here
# all the models I've written are in the models/models.py module, will break them out into different modules as I write them!
# will probably have to keep updating and copy BaseNetwork and MLP to the appropriate modules, but wanted to "seed" them with the target organization