In [2]:
import numpy as np
import scipy as sp
import torch
import torch.nn.functional as F
from torch import nn
import tools.alignmentAnalysisTools as aat
from torchvision import models, transforms

In [10]:
class net(nn.Module):
    """
    Layer-list variant of the 2-conv / 2-linear CNN (same layer sizes as CNN2P2).

    The network is stored as an ordered list of ``nn.Sequential`` stages so that
    ``forward`` can record every stage's output in ``self.activations``.
    ``layerRegistration`` maps each weight layer's attribute name to whether it is a
    feedforward (linear) layer; ``getNetworkWeights(onlyFF=True)`` uses it to keep
    only those layers.
    """
    def __init__(self):
        super().__init__()

        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)
        self.maxPool = nn.MaxPool2d(kernel_size=3)
        # fc1 expects 64*2*2 = 256 features, i.e. inputs whose spatial size is 2x2
        # after conv2 + maxPool (e.g. 28x28 images).
        self.fc1 = nn.Linear(256, 256) #4608, 512
        self.o = nn.Linear(256, 10) #512, 10

        # Weight layers in network order -> True if the layer is feedforward (linear).
        # getNetworkWeights relies on this insertion order.
        self.layerRegistration = {
            'conv1':False,
            'conv2':False,
            'fc1':True,
            'o':True
        }

        # NOTE: a plain Python list, not nn.ModuleList. The parameterised modules are
        # still registered through the attributes above (so .parameters(), .to() and
        # state_dict see them); the wrapping Sequentials only add parameter-free layers.
        self.layers = [
            nn.Sequential(self.conv1, nn.ReLU()),
            nn.Sequential(self.conv2, nn.ReLU(), self.maxPool),
            nn.Sequential(nn.Flatten(start_dim=1), self.fc1, nn.ReLU()),
            nn.Sequential(self.o),
        ]

        self.numLayers = len(self.layers)

    def forward(self, x):
        """Run each stage in order, storing every stage's output in ``self.activations``."""
        self.activations = [None]*self.numLayers
        for idx, layer in enumerate(self.layers):
            x = layer(x)
            self.activations[idx]=x
        return x

    def getNetworkWeights(self,onlyFF=False):
        """
        Return detached copies of each registered layer's weight tensor, in network order.

        Args:
            onlyFF: if True, keep only layers marked feedforward in ``layerRegistration``.

        Returns:
            list of tensors: [conv1, conv2, fc1, o] weights, or [fc1, o] when ``onlyFF``.
        """
        netWeights = []
        for layerName, isFeedforward in self.layerRegistration.items():
            if onlyFF and not isFeedforward:
                continue
            netWeights.append(getattr(self, layerName).weight.data.clone().detach())
        return netWeights

In [3]:
class CNN2P2(nn.Module):
    """
    CNN with 2 convolutional layers, a max pooling stage, and 2 feedforward layers.

    Hidden activations default to ReLU and can be chosen separately for the
    convolutional stages (``convActivation``) and the hidden linear stage
    (``linearActivation``). The output activation is the identity, because the
    network is meant to be trained with CrossEntropyLoss.

    fc1 expects 64*2*2 = 256 features, i.e. inputs whose spatial size is 2x2 after
    conv2 + maxPool (e.g. 1x28x28 images).
    """
    def __init__(self,convActivation=F.relu,linearActivation=F.relu):
        super().__init__()
        self.numLayers = 4
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)
        self.maxPool = nn.MaxPool2d(kernel_size=3)
        self.fc1 = nn.Linear(256, 256) #4608, 512
        self.o = nn.Linear(256, 10) #512, 10

        self.convActivation=convActivation
        self.linearActivation=linearActivation

    def forward(self, x):
        """Run the network, caching each stage's output on ``self`` as c1, c2, f1, out."""
        self.c1 = self.convActivation(self.conv1(x))
        self.c2 = self.maxPool(self.convActivation(self.conv2(self.c1)))
        self.f1 = self.linearActivation(self.fc1(torch.flatten(self.c2,1)))
        self.out = self.o(self.f1)
        return self.out

    def getDropout(self):
        """This architecture has no dropout; returns None so generic callers can save/restore it."""
        return None

    def setDropout(self,dropout):
        """No-op: this architecture has no dropout layers."""
        return None

    def getActivations(self,x):
        """Run a forward pass on ``x`` and return the stage outputs [c1, c2, f1, out]."""
        self.forward(x)
        return [self.c1, self.c2, self.f1, self.out]

    def getNetworkWeights(self,onlyFF=False):
        """
        Return detached copies of the layer weights in network order.

        Args:
            onlyFF: if True, return only the feedforward weights [fc1, o];
                otherwise [conv1, conv2, fc1, o].
        """
        netWeights = []
        if not onlyFF:
            netWeights.append(self.conv1.weight.data.clone().detach())
            netWeights.append(self.conv2.weight.data.clone().detach())
        netWeights.append(self.fc1.weight.data.clone().detach())
        netWeights.append(self.o.weight.data.clone().detach())
        return netWeights

    def compareNetworkWeights(self, initWeights):
        """
        For each layer, return the L2 norm of every output unit's weight change.

        ``initWeights`` should be a list as returned by ``getNetworkWeights()`` (all
        layers). Each weight tensor is flattened from dim 1, so the result per layer
        has one entry per output unit / filter.
        """
        currWeights = self.getNetworkWeights()
        deltaWeights = []
        for iw,cw in zip(initWeights,currWeights):
            iw = torch.flatten(iw,1)
            cw = torch.flatten(cw,1)
            deltaWeights.append(torch.norm(cw-iw,dim=1))
        return deltaWeights

    def measureSimilarity(self,x):
        """
        Per-layer similarity between each layer's input and its weights.

        Uses ``aat.similarityConvLayer`` (averaged over dim 1) for the conv layers and
        ``aat.similarityLinearLayer`` for the linear layers.
        """
        activations = self.getActivations(x)
        similarity = []
        similarity.append(torch.mean(aat.similarityConvLayer(x, self.conv1),axis=1))
        similarity.append(torch.mean(aat.similarityConvLayer(activations[0], self.conv2),axis=1))
        similarity.append(aat.similarityLinearLayer(torch.flatten(activations[1],1), self.fc1))
        similarity.append(aat.similarityLinearLayer(activations[2], self.o))
        return similarity

    def measureAlignment(self,x):
        """
        Per-layer alignment between each layer's input and its weights.

        Uses ``aat.alignmentConvLayer`` (averaged over dim 1) for the conv layers and
        ``aat.alignmentLinearLayer`` for the linear layers.
        """
        activations = self.getActivations(x)
        alignment = []
        alignment.append(torch.mean(aat.alignmentConvLayer(x, self.conv1),axis=1))
        alignment.append(torch.mean(aat.alignmentConvLayer(activations[0], self.conv2),axis=1))
        alignment.append(aat.alignmentLinearLayer(torch.flatten(activations[1],1), self.fc1))
        alignment.append(aat.alignmentLinearLayer(activations[2], self.o))
        return alignment

    def manualShape(self,evals,evecs,DEVICE,evalTransform=None):
        """
        Reshape the feedforward weights (fc1, o) toward their inputs' eigenvectors.

        Each weight row is re-expressed in the eigenbasis of its layer input, every
        eigen-component is scaled by ``evalTransform(fraction of variance)``, and the
        row is renormalized to its original norm. The result is written into fc1 and o.

        Args:
            evals, evecs: per-feedforward-layer eigenvalues/eigenvectors, e.g. from
                ``inputEigenfeatures(..., onlyFF=True)``.
            DEVICE: device the new weights are moved to.
            evalTransform: maps fraction-of-variance to a keep fraction in [0, 1]
                (identity by default).
        """
        if evalTransform is None: evalTransform = lambda x:x

        sbetas = [] # produce signed betas
        netweights = self.getNetworkWeights(onlyFF=True)
        for evc,nw in zip(evecs,netweights):
            nw = nw / torch.norm(nw,dim=1,keepdim=True)
            sbetas.append(nw.cpu() @ evc)

        ffLayers = [2,3]
        shapedWeights = [[] for _ in range(len(ffLayers))]
        for idx in range(len(ffLayers)):
            assert np.all(evals[idx]>=0), "Found negative eigenvalues..."
            cFractionVariance = evals[idx]/np.sum(evals[idx]) # compute fraction of variance explained by each eigenvector
            cKeepFraction = evalTransform(cFractionVariance).astype(cFractionVariance.dtype) # make sure the datatype doesn't change, otherwise pytorch einsum will be unhappy
            assert np.all(cKeepFraction>=0), "Found negative transformed keep fractions. This means the transform function has an improper form."
            assert np.all(cKeepFraction<=1), "Found keep fractions greater than 1. This is bad practice, design the evalTransform function to have a domain and range within [0,1]"
            weightNorms = torch.norm(netweights[idx],dim=1,keepdim=True) # measure norm of weights (this will be invariant to the change)
            evecComposition = torch.einsum('oi,xi->oxi',sbetas[idx],torch.tensor(evecs[idx])) # create tensor composed of each eigenvector scaled to its contribution in each weight vector
            newComposition = torch.einsum('oxi,i->ox',evecComposition,torch.tensor(cKeepFraction)).to(DEVICE) # scale eigenvectors based on their keep fraction (by default scale them by their variance)
            shapedWeights[idx] = newComposition / torch.norm(newComposition,dim=1,keepdim=True) * weightNorms

        # Assign new weights to network
        self.fc1.weight.data = shapedWeights[0]
        self.o.weight.data = shapedWeights[1]

    @staticmethod
    def _dropUnits(activation, idx):
        """Zero units/channels ``idx`` along dim 1 of ``activation`` (in place) and rescale."""
        # NOTE(review): the result is multiplied by (1 - fraction dropped), which further
        # *reduces* activity; inverted dropout would divide instead. Preserved as-is -- confirm intent.
        # The write is in place on an activation output; with the default ReLU, backpropagating
        # through it may raise an autograd error, so this is intended for evaluation.
        fracDropout = len(idx)/activation.shape[1]
        activation[:,idx] = 0
        return activation * (1 - fracDropout)

    @staticmethod
    def targetedDropout(net,x,idx=None,layer=None,returnFull=False):
        """
        Forward pass of ``net`` with units ``idx`` dropped in a single hidden layer.

        Args:
            net: a CNN2P2 instance.
            x: input batch.
            idx: indices (along dim 1) of the units/channels to drop.
            layer: which hidden stage to drop in: 0 (c1), 1 (c2) or 2 (f1).
            returnFull: if True return (c1, c2, f1, out), otherwise just out.
        """
        assert layer is not None and 0 <= layer <= 2, "dropout only works on first three layers"
        assert idx is not None, "idx must be provided"
        c1 = net.convActivation(net.conv1(x))
        if layer==0: c1 = CNN2P2._dropUnits(c1, idx)
        c2 = net.maxPool(net.convActivation(net.conv2(c1)))
        if layer==1: c2 = CNN2P2._dropUnits(c2, idx)
        f1 = net.linearActivation(net.fc1(torch.flatten(c2,1)))
        if layer==2: f1 = CNN2P2._dropUnits(f1, idx)
        out = net.o(f1)
        if returnFull: return c1,c2,f1,out
        return out

    @staticmethod
    def mlTargetedDropout(net,x,idx,layer,returnFull=False):
        """
        Forward pass of ``net`` with targeted dropout in several hidden layers at once.

        Args:
            net: a CNN2P2 instance.
            x: input batch.
            idx: tuple of index collections; ``idx[i]`` are the units dropped in ``layer[i]``.
            layer: tuple of distinct hidden-stage numbers in {0, 1, 2}.
            returnFull: if True return (c1, c2, f1, out), otherwise just out.
        """
        assert type(idx) is tuple and type(layer) is tuple, "idx and layer need to be tuples"
        assert len(idx)==len(layer), "idx and layer need to have the same length"
        npLayer = np.array(layer)
        assert len(npLayer)==len(np.unique(npLayer)), "layer must not have any repeated elements"
        assert np.all((npLayer>=0) & (npLayer<=2)), "dropout only works on first three layers"
        # Map each target layer to the units dropped there (idx[i] belongs to layer[i]).
        # A tuple cannot be indexed with a boolean mask, so look entries up by layer number.
        dropIndex = dict(zip(layer, idx))

        # Do forward pass with targeted dropout
        c1 = net.convActivation(net.conv1(x))
        if 0 in dropIndex: c1 = CNN2P2._dropUnits(c1, dropIndex[0])
        c2 = net.maxPool(net.convActivation(net.conv2(c1)))
        if 1 in dropIndex: c2 = CNN2P2._dropUnits(c2, dropIndex[1])
        f1 = net.linearActivation(net.fc1(torch.flatten(c2,1)))
        if 2 in dropIndex: f1 = CNN2P2._dropUnits(f1, dropIndex[2])
        out = net.o(f1)
        if returnFull: return c1,c2,f1,out
        return out

    @staticmethod
    def inputEigenfeatures(net, dataloader, onlyFF=True, DEVICE=None):
        """
        Eigen-decompose the covariance of the input to each (feedforward) layer.

        Args:
            net: the network; its dropout is disabled while measuring and then restored.
            dataloader: iterable of (images, labels) batches; labels are ignored.
            onlyFF: if True, only the inputs to fc1 and o; otherwise also the inputs to
                conv1 and conv2 (flattened).
            DEVICE: device for the forward passes (cuda if available by default).

        Returns:
            (eigenvalues, eigenvectors): per-layer lists, sorted by descending eigenvalue,
            with eigenvalues beyond the covariance's numerical rank set to 0.
        """
        # Handle DEVICE if not provided
        if DEVICE is None: DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

        # Measure Activations (without dropout) for all images. No gradients are needed,
        # and keeping autograd graphs for the whole dataset would waste memory.
        storeDropout = net.getDropout()
        net.setDropout(0) # no dropout for measuring eigenfeatures
        allimages = []
        activations = []
        try:
            with torch.no_grad():
                for images, _ in dataloader:
                    allimages.append(images)
                    activations.append(net.getActivations(images.to(DEVICE)))
        finally:
            net.setDropout(storeDropout)

        # Consolidate variable structure
        allinputs = []
        if not onlyFF:
            # Only add inputs to convolutional layers if onlyFF switch is off
            allinputs.append(torch.flatten(torch.cat(allimages,dim=0).detach().cpu(),1)) # inputs to first convolutional layer
            allinputs.append(torch.flatten(torch.cat([cact[0] for cact in activations],dim=0).detach().cpu(),1)) # inputs to second convolutional layer
        allinputs.append(torch.flatten(torch.cat([cact[1] for cact in activations],dim=0).detach().cpu(),1)) # inputs to first feedforward layer
        allinputs.append(torch.cat([cact[2] for cact in activations],dim=0).detach().cpu()) # inputs to last feedforward layer

        # Measure eigenfeatures for input to each layer
        eigenvalues = []
        eigenvectors = []
        for ai in allinputs:
            # Covariance matrix is positive semidefinite, but numerical errors can produce negative eigenvalues
            ccov = torch.cov(ai.T)
            crank = int(torch.linalg.matrix_rank(ccov))
            w,v = sp.linalg.eigh(ccov.numpy())
            widx = np.argsort(w)[::-1]
            w = w[widx]
            v = v[:,widx]
            # Automatically set eigenvalues to 0 when they are numerical errors!
            w[crank:]=0
            eigenvalues.append(w)
            eigenvectors.append(v)

        return eigenvalues, eigenvectors

    @staticmethod
    def measureEigenFeatures(net, dataloader, onlyFF=True, DEVICE=None):
        """
        Project each layer's normalized weight rows onto its input eigenvectors.

        Returns:
            (beta, eigenvalues, eigenvectors) where ``beta[k]`` holds the absolute
            projection of every normalized weight row of layer k on each eigenvector.
        """
        eigenvalues,eigenvectors = CNN2P2.inputEigenfeatures(net, dataloader, onlyFF=onlyFF, DEVICE=DEVICE)

        # Measure dot product of weights on eigenvectors for each layer
        beta = []
        netweights = net.getNetworkWeights(onlyFF=onlyFF)
        for evc,nw in zip(eigenvectors,netweights):
            nw = nw / torch.norm(nw,dim=1,keepdim=True)
            beta.append(torch.abs(nw.cpu() @ evc))

        return beta, eigenvalues, eigenvectors

    @staticmethod
    def avgFromFull(full):
        """
        Average a per-epoch, per-layer history into a (numLayers, numEpochs) CPU tensor.

        ``full[epoch][layer]`` is reduced with ``torch.mean``.
        """
        numEpochs = len(full)
        numLayers = len(full[0])
        avgFull = torch.zeros((numLayers,numEpochs))
        for layer in range(numLayers):
            avgFull[layer,:] = torch.tensor([torch.mean(f[layer]) for f in full])
        return avgFull.cpu()

    @staticmethod
    def layerFromFull(full,layer,dim=1):
        """
        Stack one layer's per-epoch values along a new trailing axis (``dim`` 1 or 2).

        ``full[epoch][layer]`` must be 1-D for ``dim=1`` and 2-D for ``dim=2``.

        Raises:
            ValueError: for any other ``dim``.
        """
        if dim==1:
            return torch.cat([f[layer][:,None] for f in full],dim=dim).cpu()
        elif dim==2:
            return torch.cat([f[layer][:,:,None] for f in full],dim=dim).cpu()
        else:
            raise ValueError("Haven't coded layerFromFull for dimensions other than 1 or 2!")