In [1]:
from typing import List, Dict, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
%matplotlib inline
from torchvision import datasets
import torchvision.transforms as transforms
import numpy as np
# Prefer the first CUDA device when one is visible; otherwise stay on the CPU.
dev = 'cuda:0' if torch.cuda.is_available() else 'cpu'
device = torch.device(dev)
print("Running on:", dev)

Running on: cuda:0


# Quivers and quiver representations

The quiver class is designed to define an acyclic quiver with no double edges. The initial quiver is not required to have a bias vertex; such a vertex can be added with the ```add_bias()``` method.

The input list of vertices is meant to be a list of strings, one for each vertex. It is best to avoid the label ```bias``` among the vertices. 

The input list of edges is meant to be a list of tuples ```e = (e[0], e[1])```, where ```e[0]``` is the source and ```e[1]``` is the target. No double edges or loops are allowed.

In [2]:
class quiver:
    """Finite quiver (directed graph) without loops or double edges.

    Vertices are given as a list of distinct strings; edges are given as a list of
    ``(source, target)`` pairs whose endpoints are vertices.

    Attributes:
        vertices: List of vertex labels in the order supplied by the caller
            (``add_bias`` prepends ``'bias'``).
        edges: List of ``(source, target)`` pairs.
        sources: Set of vertices with no incoming edge, computed at construction.
        sinks: Set of vertices with no outgoing edge, computed at construction.
    """
    def __init__(self, vertices : List[str], edges : List[Tuple[str]]):
        # Copy the inputs so that add_bias() never mutates the caller's lists.
        self.vertices = list(vertices)
        self.edges = list(edges)

        vertex_set = set(self.vertices)
        assert len(vertex_set) == len(self.vertices), "Repeated vertex found"

        # First entry of each pair is the source, second is the target
        for e in self.edges:
            assert len(e) == 2, "Each edge must be a (source, target) pair"
            assert e[0] in vertex_set and e[1] in vertex_set, "Edge endpoint is not a vertex"
            assert e[0] != e[1], "Loops are not allowed"
        assert len({(e[0], e[1]) for e in self.edges}) == len(self.edges), "Double edges are not allowed"

        # Sources have no incoming edge, sinks have no outgoing edge
        self.sources = vertex_set - {e[1] for e in self.edges}
        self.sinks = vertex_set - {e[0] for e in self.edges}

    def check_acyclic(self):
        """Return True if the quiver has no oriented cycle, False otherwise."""
        # Kahn's algorithm: repeatedly remove vertices that have no remaining
        # incoming edge; a cycle exists exactly when some vertex is never removed.
        indegree = {v: 0 for v in self.vertices}
        outgoing = {v: [] for v in self.vertices}
        for e in self.edges:
            indegree[e[1]] += 1
            outgoing[e[0]].append(e[1])

        ready = [v for v in self.vertices if indegree[v] == 0]
        removed = 0
        while ready:
            v = ready.pop()
            removed += 1
            for t in outgoing[v]:
                indegree[t] -= 1
                if indegree[t] == 0:
                    ready.append(t)
        return removed == len(self.vertices)

    def check_top_order(self):
        """Return True if every edge goes from an earlier to a later vertex of ``self.vertices``."""
        indices = {v: i for i, v in enumerate(self.vertices)}
        return all(indices[e[0]] < indices[e[1]] for e in self.edges)

    def get_incoming(self, vertex):
        """Return the list of edges whose target is ``vertex`` (in edge-list order)."""
        assert vertex in self.vertices, "No such vertex found"
        return [e for e in self.edges if e[1] == vertex]
        # Can get the incoming neighbors as [e[0] for e in self.get_incoming(vertex)]

    def add_bias(self):
        """Prepend a ``'bias'`` vertex with an edge to every non-source vertex.

        Prepending keeps a topological order topological.  ``self.sources`` and
        ``self.sinks`` are left untouched, so ``'bias'`` is never listed as a source.
        Calling the method when a ``'bias'`` vertex is already present does nothing.
        """
        if 'bias' in self.vertices:
            return
        self.edges.extend(('bias', v) for v in self.vertices if v not in self.sources)
        self.vertices = ['bias'] + self.vertices
        return

    #########
    # Check if a vertex is a sink (Don't really need this any more)
    def is_sink(self, vertex):
        """Return True if no edge starts at ``vertex``."""
        assert vertex in self.vertices, "No such vertex found"
        return all(e[0] != vertex for e in self.edges)
    


In [3]:
class quiver_rep:
    """Quiver representation: a dimension for each vertex and a matrix for each edge.

    The matrix attached to an edge ``(s, t)`` has shape ``(dims[t], dims[s])``.

    Attributes:
        quiver: The underlying quiver.
        dims: Dictionary mapping each vertex to a non-negative integer dimension.
        matrices: Dictionary mapping each edge ``(s, t)`` to a NumPy array.
        dims_red: Reduced dimension vector; only present after ``comp_dims_red`` has run.
    """
    def __init__(self, quiver: quiver, dims: Dict[str,int], matrices: Dict[str, np.array]):
        self.quiver = quiver
        self.dims = dims
        self.matrices = matrices

        # Check the dimension vector
        assert len(dims) == len(self.quiver.vertices), "Inappropriate dimension vector"
        for v in dims:
            assert v in self.quiver.vertices, "Inappropriate dim vector"
            # Zero is allowed; NumPy integer types are accepted alongside int
            assert isinstance(dims[v], (int, np.integer)) and dims[v] >= 0, \
                "Dimension needs to be a non-negative integer"

        if 'bias' in dims:
            assert dims['bias'] == 1, "Dimension at bias needs to be 1"

        # Check the matrices: one per edge, shaped (target dim, source dim)
        assert len(matrices) == len(self.quiver.edges), "Matrices error"
        for e in matrices:
            assert e in self.quiver.edges, "Matrices error"
            assert isinstance(matrices[e], np.ndarray), "Matrices error" # May need fixing
            assert np.shape(matrices[e]) == (dims[e[1]], dims[e[0]]), "Dimension error"

    def comp_dims_red(self) -> Dict:
        """Compute, store in ``self.dims_red`` and return the reduced dimension vector.

        The bias, the sources and the sinks keep their dimension.  Any other vertex
        gets the minimum of its own dimension and the sum of the reduced dimensions
        at the sources of its incoming edges.  Vertices must be in topological order.
        """
        assert self.quiver.check_top_order(), "Order of the vertices is not topological"

        dims_red = {}
        for i in self.quiver.vertices:
            if i == 'bias' or i in self.quiver.sources or i in self.quiver.sinks:
                dims_red[i] = self.dims[i]
            else:
                incoming_total = sum(dims_red[e[0]] for e in self.quiver.get_incoming(i))
                dims_red[i] = min(self.dims[i], incoming_total)

        self.dims_red = dims_red
        return dims_red

    # Auxiliary function
    def padzeros(self, M, newrows, newcols = None):
        """Pad the 2-D array ``M`` with zeros below/right up to ``(newrows, newcols)``.

        ``newcols`` defaults to the current number of columns of ``M``.
        """
        oldrows, oldcols = M.shape
        if newcols is None:
            newcols = oldcols
        return np.pad(M,((0,newrows-oldrows),(0,newcols-oldcols)),mode="constant")

    def padzeros_to_dim(self, new_dims):
        """Return a new representation with every matrix zero-padded to ``new_dims``."""
        new_matrices = {
            e: self.padzeros(M, new_dims[e[1]], new_dims[e[0]])
            for e, M in self.matrices.items()
        }
        return quiver_rep(self.quiver,new_dims,new_matrices)

    # QR dimensional reduction algorithm
    def QRDimRed(self, verbose : bool = False ):
        """Run the QR dimensional reduction.

        Args:
            verbose: If True, print the edge list and the vertex list first.

        Returns:
            A pair ``(Q, V)``.  ``Q`` maps each vertex to a square matrix (the identity
            at vertices without incoming edges and at sinks, the complete QR factor at
            the remaining vertices).  ``V`` is the ``quiver_rep`` with the reduced
            dimension vector.

        Raises:
            AssertionError: If the vertices are not in topological order, or if the
                final consistency check between ``V`` and the transformed matrices fails.
        """
        dims = self.dims
        matrices = self.matrices
        quiver = self.quiver

        # Check that vertices are in a topological order
        assert quiver.check_top_order(), "Order of the vertices is not topological"

        # Compute the reduced dimension vector
        dims_red = self.comp_dims_red()

        # Q = dictionary mapping each vertex to an orthogonal matrix
        Q = {}

        # Vmatrices = matrices of the reduced representation V, mapping each edge to a matrix
        Vmatrices = {}

        if verbose:
            print(quiver.edges)
            print(quiver.vertices)

        for i in quiver.vertices:
            incoming = quiver.get_incoming(i)

            # Case of a source vertex
            if incoming == []:
                Q[i] = np.eye(dims[i])

            # Case of a hidden vertex
            elif i not in quiver.sinks:

                # Matrix to be QR-decomposed: the transformed incoming weights,
                # restricted to the reduced part of each source, side by side
                M = np.hstack([matrices[e] @ Q[e[0]][:, :dims_red[e[0]]] for e in incoming])

                Q[i], R = np.linalg.qr(M, mode="complete")

                # Case of reduction: keep only the first dims_red[i] rows of R
                if dims_red[i] < dims[i]:
                    R = R[:dims_red[i]]

                # Extract V_e from R_i for all incoming edges e; the column blocks
                # of R appear in the same order in which M was assembled
                col = 0
                for e in incoming:
                    width = dims_red[e[0]]
                    Vmatrices[e] = R[:, col:col + width]
                    col += width

            # Case of a sink (no reduction)
            else:
                Q[i] = np.eye(dims[i])
                for e in incoming:
                    # Transform weights on incoming edges
                    Vmatrices[e] = matrices[e] @ Q[e[0]][:, :dims_red[e[0]]]

        # Make V into a representation
        V = quiver_rep(quiver, dims_red, Vmatrices)

        # Verify that V is a subrepresentation of Q^{-1} W
        for e in quiver.edges:
            Qi = Q[e[0]]
            Qj = Q[e[1]]
            diff = (np.transpose(Qj) @ matrices[e] @ Qi[:,:dims_red[e[0]]]
                    - self.padzeros(Vmatrices[e], dims[e[1]]))
            # np.max raises on an empty array, which occurs when a dimension is 0
            max_diff = np.max(np.abs(diff)) if diff.size else 0.0
            assert max_diff < 1e-10, "Error in the algorithm"

        return Q, V

    def reduced_representation(self, verbose : bool = False ):
        """Return only the reduced representation ``V`` computed by ``QRDimRed``."""
        return self.QRDimRed(verbose)[1]

    def transformed_representation(self):
        """Return the representation with matrices ``Q[t]^T @ W_e @ Q[s]`` for each edge ``e = (s, t)``."""
        Q, _ = self.QRDimRed()
        transformed_mat_dict = {
            e: np.transpose(Q[e[1]]) @ self.matrices[e] @ Q[e[0]]
            for e in self.quiver.edges
        }
        return quiver_rep(self.quiver, self.dims, transformed_mat_dict)

    def Q_act(self,Q):
        """Return the representation with matrices ``Q[t] @ W_e @ Q[s]^T`` for each edge ``e = (s, t)``."""
        new_matrices = {
            w: Q[w[1]] @ self.matrices[w] @ np.transpose(Q[w[0]])
            for w in self.matrices
        }
        return quiver_rep(self.quiver, self.dims, new_matrices)

### Linear Feedforward Function

In [4]:
# Linear feedforward function, ignoring biases
# Might turn out not to be super necessary

def lin_ff(W : quiver_rep) -> np.array:
    """Matrix of the linear feedforward map of ``W``, ignoring edges that start at ``'bias'``.

    The columns are indexed by the concatenation of the source spaces (sources taken
    in vertex order); the rows are the sink spaces stacked in vertex order, so the
    result does not depend on set iteration order.

    Args:
        W: Representation whose quiver has its vertices in topological order.

    Returns:
        The stacked matrix, or an empty 1-D array when the quiver has no sinks.

    Raises:
        AssertionError: If the vertices are not in topological order.
    """
    dims = W.dims
    matrices = W.matrices
    quiver = W.quiver

    assert quiver.check_top_order(), "Order of the vertices is not topological"

    # partial[v] = matrix from the inputs seen so far to the space at v
    partial = {}
    input_dim = 0

    for i in quiver.vertices:

        # Case of a source
        if i in quiver.sources:

            # Matrices already defined do not depend on the new input block
            for j in partial:
                partial[j] = np.hstack((partial[j], np.zeros((dims[j],dims[i]))))

            # The new matrix projects onto the new input block
            partial[i] = np.hstack((np.zeros((dims[i], input_dim)), np.eye(dims[i])))
            input_dim += dims[i]

        # Case of a hidden vertex or a sink
        else:
            A = np.zeros((dims[i], input_dim))
            for e in quiver.get_incoming(i):
                # Ignore biases for now
                if e[0] != 'bias':
                    A += matrices[e] @ partial[e[0]]
            partial[i] = A

    # Stack the sink matrices in vertex order (quiver.sinks is a set, whose
    # iteration order is not guaranteed to be stable between runs)
    sink_blocks = [partial[i] for i in quiver.vertices if i in quiver.sinks]
    if not sink_blocks:
        return np.array([])
    return np.vstack(sink_blocks)

# Dimensional reduction algorithm

In [5]:
# Auxiliary function

def padzeros(M,newrows,newcols = None):
    """Pad the 2-D array ``M`` with zeros below and to the right.

    Args:
        M: Array with exactly two axes.
        newrows: Number of rows of the result.
        newcols: Number of columns of the result; defaults to the columns of ``M``.

    Returns:
        A new array of shape ``(newrows, newcols)`` with ``M`` in its top-left corner.
    """
    oldrows, oldcols = M.shape
    # Identity test: `== None` would be evaluated elementwise on array-like arguments
    if newcols is None:
        newcols = oldcols
    return np.pad(M,((0,newrows-oldrows),(0,newcols-oldcols)),mode="constant")

In [6]:
# QR dimensional reduction algorithm
# This is incorporated in the quiver_rep class

def QRDimRed(W : quiver_rep, verbose : bool = False ):
    """Function form of the QR dimensional reduction; delegates to ``W.QRDimRed``.

    The algorithm lives in the ``quiver_rep`` class; this wrapper is kept so that
    existing calls of the form ``QRDimRed(rep)`` keep working without a second
    copy of the implementation.

    Args:
        W: Representation to reduce.
        verbose: Passed through to ``W.QRDimRed``.

    Returns:
        The pair ``(Q, V)`` returned by ``W.QRDimRed(verbose)``.
    """
    return W.QRDimRed(verbose)

# Example 1

In [7]:
# EXAMPLE 1
# Quiver with a skip connection (a -> c) and, initially, no bias vertex
vertex_list = ['a', 'b', 'c', 'd']
edge_list = [('a', 'b'), ('a','c'), ('b','c'), ('c', 'd')]
quiv_ex = quiver(vertex_list, edge_list)

# Query every vertex before the bias is attached
print(*(quiv_ex.get_incoming(v) for v in vertex_list))
print(*(quiv_ex.is_sink(v) for v in vertex_list))
print(quiv_ex.check_top_order())
print(quiv_ex.sources)

# Attach the bias vertex and show the enlarged quiver
quiv_ex.add_bias()
print(quiv_ex.vertices)
print(quiv_ex.edges)

[] [('a', 'b')] [('a', 'c'), ('b', 'c')] [('c', 'd')]
False False False True
True
{'a'}
['bias', 'a', 'b', 'c', 'd']
[('a', 'b'), ('a', 'c'), ('b', 'c'), ('c', 'd'), ('bias', 'b'), ('bias', 'c'), ('bias', 'd')]


In [8]:
# Random representation of the quiver with bias built in the previous cell
dim_vector = dict(bias=1, a=2, b=4, c=8, d=2)

# One uniformly random matrix per edge, of shape (dim at target, dim at source).
# quiv_ex.edges lists the four original edges followed by the three bias edges.
maps = {
    (src, tgt): np.random.rand(dim_vector[tgt], dim_vector[src])
    for (src, tgt) in quiv_ex.edges
}

ex_rep = quiver_rep(quiv_ex, dim_vector, maps)
print(ex_rep.comp_dims_red())

{'bias': 1, 'a': 2, 'b': 3, 'c': 6, 'd': 2}


In [9]:
# Run the QR reduction on Example 1 and inspect the shapes it produces

Q_ex, V_ex = QRDimRed(ex_rep)
print(ex_rep.comp_dims_red())
print([Q_ex[v].shape for v in Q_ex])
print([V_ex.matrices[e].shape for e in V_ex.matrices])

{'bias': 1, 'a': 2, 'b': 3, 'c': 6, 'd': 2}
[(1, 1), (2, 2), (4, 4), (8, 8), (2, 2)]
[(3, 2), (3, 1), (6, 2), (6, 3), (6, 1), (2, 6), (2, 1)]


In [10]:
# Linear feedforward matrix of the original representation (lin_ff skips bias edges)
lin_ff(ex_rep)

array([[6.4775832 , 4.65005482],
       [5.73522075, 5.0397263 ]])

In [11]:
# The same matrix by hand: sum over the two paths a -> c -> d and a -> b -> c -> d
ex_rep.matrices[('c','d')] @ ( ex_rep.matrices[('a','c')] + ex_rep.matrices[('b','c')] @ ex_rep.matrices[('a','b')] )

array([[6.4775832 , 4.65005482],
       [5.73522075, 5.0397263 ]])

In [12]:
# Linear feedforward matrix of the reduced representation, to compare with the two cells above
lin_ff(V_ex)

array([[6.4775832 , 4.65005482],
       [5.73522075, 5.0397263 ]])

# Example 2

In [13]:
# EXAMPLE 2
# Quiver with two sources (a, b) and two sinks (d, e)
vertex_list2 = ['a', 'b', 'c', 'd', 'e']
edge_list2 = [('a', 'c'), ('b','c'), ('c','d'), ('c', 'e')]
quiv_ex2 = quiver(vertex_list2, edge_list2)

# Query every vertex before the bias is attached
print(*(quiv_ex2.get_incoming(v) for v in vertex_list2))
print(*(quiv_ex2.is_sink(v) for v in vertex_list2))
print(quiv_ex2.check_top_order())
print(quiv_ex2.sources)

# Attach the bias vertex and show the enlarged quiver
quiv_ex2.add_bias()
print(quiv_ex2.vertices)
print(quiv_ex2.edges)

[] [] [('a', 'c'), ('b', 'c')] [('c', 'd')] [('c', 'e')]
False False False True True
True
{'a', 'b'}
['bias', 'a', 'b', 'c', 'd', 'e']
[('a', 'c'), ('b', 'c'), ('c', 'd'), ('c', 'e'), ('bias', 'c'), ('bias', 'd'), ('bias', 'e')]


In [14]:
# Random representation of the two-source / two-sink quiver with bias
dim_vector2 = dict(bias=1, a=1, b=2, c=8, d=2, e=6)

# One uniformly random matrix per edge, of shape (dim at target, dim at source).
# quiv_ex2.edges lists the four original edges followed by the three bias edges.
maps2 = {
    (src, tgt): np.random.rand(dim_vector2[tgt], dim_vector2[src])
    for (src, tgt) in quiv_ex2.edges
}

ex_rep2 = quiver_rep(quiv_ex2, dim_vector2, maps2)
print(ex_rep2.comp_dims_red())

{'bias': 1, 'a': 1, 'b': 2, 'c': 4, 'd': 2, 'e': 6}


In [15]:
# Run the QR reduction on Example 2 and inspect the shapes it produces

Q_ex2, V_ex2 = QRDimRed(ex_rep2)
print(ex_rep2.comp_dims_red())
print([Q_ex2[v].shape for v in Q_ex2])
print([V_ex2.matrices[e].shape for e in V_ex2.matrices])

{'bias': 1, 'a': 1, 'b': 2, 'c': 4, 'd': 2, 'e': 6}
[(1, 1), (1, 1), (2, 2), (8, 8), (2, 2), (6, 6)]
[(4, 1), (4, 2), (4, 1), (2, 4), (2, 1), (6, 4), (6, 1)]


In [16]:
# Linear feedforward matrix of the original representation (lin_ff skips bias edges)
lin_ff(ex_rep2)

array([[1.83777321, 2.49704088, 2.35649069],
       [0.99597447, 2.63640189, 1.83602389],
       [1.52937084, 2.1005179 , 2.03322776],
       [1.40336098, 1.87197742, 2.30612105],
       [1.07614485, 1.75081518, 1.31076907],
       [0.71693728, 1.29812791, 1.11900536],
       [1.52257671, 2.4979547 , 2.31852296],
       [1.15139054, 1.60311214, 1.53733591]])

In [17]:
# Linear feedforward matrix of the reduced representation, to compare with the cell above
lin_ff(V_ex2)

array([[1.83777321, 2.49704088, 2.35649069],
       [0.99597447, 2.63640189, 1.83602389],
       [1.52937084, 2.1005179 , 2.03322776],
       [1.40336098, 1.87197742, 2.30612105],
       [1.07614485, 1.75081518, 1.31076907],
       [0.71693728, 1.29812791, 1.11900536],
       [1.52257671, 2.4979547 , 2.31852296],
       [1.15139054, 1.60311214, 1.53733591]])

# Example 3

In [18]:

# EXAMPLE 3
# Diamond-shaped quiver a -> {b, c} -> d -> e, without a bias vertex.
# NOTE: this cell reuses (and overwrites) the names introduced in Example 2.
vertex_list2 = ['a', 'b', 'c', 'd', 'e']
edge_list2 = [('a', 'b'), ('a','c'), ('b','d'), ('c', 'd'),('d','e')]
quiv_ex2 = quiver(vertex_list2, edge_list2)

dim_vector2 = dict(a=2, b=3, c=3, d=6, e=2)

# One uniformly random matrix per edge, of shape (dim at target, dim at source)
maps2 = {
    (src, tgt): np.random.rand(dim_vector2[tgt], dim_vector2[src])
    for (src, tgt) in quiv_ex2.edges
}

ex_rep2 = quiver_rep(quiv_ex2, dim_vector2, maps2)
print(ex_rep2.comp_dims_red())

Q_ex2, V_ex2 = QRDimRed(ex_rep2)

{'a': 2, 'b': 2, 'c': 2, 'd': 4, 'e': 2}


In [19]:
# Show every matrix of the reduced representation, one edge at a time
for e in V_ex2.quiver.edges:
    print(f"\n{e}\n{V_ex2.matrices[e]}")


('a', 'b')
[[-0.92070007 -0.68384949]
 [ 0.          0.52108646]]

('a', 'c')
[[-1.27158292 -0.87712477]
 [ 0.          0.40379505]]

('b', 'd')
[[ 1.95040582 -1.14579538]
 [ 0.         -0.66934579]
 [ 0.          0.        ]
 [ 0.          0.        ]]

('c', 'd')
[[ 1.93236873 -0.06835251]
 [ 0.36948703  0.24830295]
 [-1.0949733   0.32624576]
 [ 0.         -0.67090186]]

('d', 'e')
[[-0.9829389  -0.0192903   0.41830236  0.09901009]
 [-0.67489109 -0.07671589  0.18676567 -0.39561193]]


# Quiver Neural Networks

In [20]:
class RadAct(nn.Module):
    """Radial activation.

    Each vector along the last axis is multiplied by ``eta(r + shift) / r``, where
    ``r`` is its Euclidean norm, so only the length of the vector is changed.

    Args:
        eta: Function applied to the (shifted) norms; defaults to ``F.relu``.
    """
    def __init__(self, eta = F.relu):
        super().__init__()
        self.eta = eta
        self.shift = 0
        # Add internal bias/shift later

    def forward(self,x):
        # x: [Batch x Channel]
        # Radial activations
        r = torch.linalg.norm(x, dim=-1)
        # Floor the norms out of place.  An in-place `r += eps` modifies a tensor
        # that autograd may need for the backward pass of the norm, and it also
        # shifts every norm in the batch as soon as a single one is tiny.
        r = torch.clamp(r, min=1e-6)
        scalar = self.eta(r + self.shift) / r
        return x * scalar.unsqueeze(-1)

In [21]:
class QuiverNN(nn.Module):
    """Neural network whose architecture is a quiver.

    Every edge ``(s, t)`` carries a bias-free ``nn.Linear(dims[s], dims[t])`` and every
    non-source vertex applies the radial activation ``RadAct(eta)``.  A vertex labelled
    ``'bias'`` holds the constant 1, so the edges leaving it act as bias terms.

    Args:
        eta: Function handed to ``RadAct`` (for example ``F.relu``).
        quiver: Quiver with its vertices in topological order.
        dims: Dictionary mapping every vertex of the quiver to its dimension.
    """

    def __init__(self, eta, quiver: quiver, dims: Dict[str,int] ):
        super().__init__()
        self.eta = eta
        self.quiver = quiver
        self.dims = dims
        self.matrices = nn.ModuleDict()

        assert quiver.check_top_order(), "Vertices not in topological order."
        for v in quiver.vertices:
            assert v in dims, "dims is not a dimension vector for the quiver"

        # Linear layer for each edge
        for e in quiver.edges:
            self.matrices[self.edge_tup_to_str(e)] = nn.Linear(self.dims[e[0]], self.dims[e[1]], bias = False)

        # Radial activations
        self.act = RadAct(self.eta)

    # Encode each edge as a string
    def edge_tup_to_str(self, e : Tuple[str]):
        """Return the ``ModuleDict`` key ``"<source> to <target>"`` of the edge ``e``."""
        assert e in self.quiver.edges
        (t,h) = e
        return t + " to " + h

    # Extract the (source, target) pair from each string encoding an edge
    def edge_str_to_tup(self, e_str : str):
        """Inverse of ``edge_tup_to_str``; vertex labels must not contain whitespace."""
        t, _, h = e_str.split()
        assert (t,h) in self.quiver.edges
        return (t,h)

    # Get the matrix of an edge
    def get_matrix(self, e):
        """Return the weight parameter of the linear layer on the edge ``e``."""
        return self.matrices[self.edge_tup_to_str(e)].weight

    # The feedforward function
    def forward(self, x, non_linear = True):
        """Feed a batch through the network.

        Args:
            x: Dictionary mapping every source vertex to a tensor whose first axis
                is the batch axis.  All sources need the same batch size.
            non_linear: If False, the radial activations are skipped.

        Returns:
            Dictionary mapping every sink to its output tensor.
        """
        sources = self.quiver.sources

        # Initialize Data Flow with the inputs at the sources
        h = {}
        first = None
        for v in self.quiver.vertices:
            if v not in sources:
                continue
            # A 'bias' vertex listed among the sources may be left out of x
            if v == 'bias' and v not in x:
                continue
            h[v] = x[v]
            if first is None:
                first = x[v]
            else:
                assert x[v].shape[0] == first.shape[0], "All source inputs need the same batch size"
        assert first is not None, "At least one source input is required"
        batch_size = first.shape[0]

        # Bias: the constant 1 for every sample.  It must not be recomputed in the
        # loop below: the bias vertex has no incoming edges, so recomputing it
        # would replace the constant by zeros and switch off every bias edge.
        if 'bias' not in h:
            h['bias'] = torch.ones(batch_size, 1, dtype=first.dtype, device=first.device)

        for v in self.quiver.vertices:
            if v in sources or v == 'bias':
                continue

            # Non-source vertices: sum the contributions of the incoming edges
            total = torch.zeros(batch_size, self.dims[v], dtype=first.dtype, device=first.device)
            for e in self.quiver.get_incoming(v):
                e_lin = self.matrices[self.edge_tup_to_str(e)]
                total = total + e_lin(h[e[0]])
            h[v] = self.act(total) if non_linear else total

        # Feedforward function output, sinks listed in vertex order
        return {v: h[v] for v in self.quiver.vertices if v in self.quiver.sinks}

    def set_weights(self, new_weights: quiver_rep):
        """Replace every edge weight by the matching matrix of ``new_weights``.

        The matrices are converted to ``torch.float`` and placed on the device of
        the weight they replace; ``self.dims`` is set to ``new_weights.dims``.
        """
        assert new_weights.quiver == self.quiver, "weights have different quiver"

        for e in self.quiver.edges:
            layer = self.matrices[self.edge_tup_to_str(e)]
            layer.weight = torch.nn.Parameter(
                torch.tensor(new_weights.matrices[e], dtype = torch.float, device = layer.weight.device))
            # Keep the layer's advertised sizes in step with the new weight
            layer.out_features, layer.in_features = layer.weight.shape
        self.dims = new_weights.dims

    def export_weights(self) -> quiver_rep:
        """Return the current weights as a ``quiver_rep`` of NumPy arrays."""
        quiver_rep_matrix_dict = {
            e: self.get_matrix(e).detach().cpu().numpy()
            for e in self.quiver.edges
        }
        return quiver_rep(self.quiver, self.dims, quiver_rep_matrix_dict)

    def export_reduced_weights(self) -> quiver_rep:
        """Return the reduced representation of the current weights."""
        return self.export_weights().reduced_representation()

    def transformed_network(self):
        """Return a new network carrying the transformed representation of the weights."""
        rep_transformed = self.export_weights().transformed_representation()
        net_trans = QuiverNN(self.eta, self.quiver, self.dims)
        net_trans.set_weights(rep_transformed)
        #net_trans.set_activation_biases(self.export_activation_biases())
        return net_trans

    def reduced_network(self):
        """Return a new, smaller network carrying the reduced representation of the weights."""
        reduced_rep = self.export_reduced_weights()
        net_reduced = QuiverNN(self.eta, self.quiver, reduced_rep.dims)
        net_reduced.set_weights(reduced_rep)
        #net_reduced.set_activation_biases(self.export_activation_biases())
        return net_reduced

    def Q_act(self,Q):
        """Replace the weights in place by the result of ``quiver_rep.Q_act(Q)``."""
        self.set_weights(self.export_weights().Q_act(Q))

    # Might not need these:
    def export_activation_biases(self) -> List[float]:
        """Placeholder; not implemented, returns None."""
        return None

    def set_activation_biases(self, new_biases: List[float]):
        """Placeholder; not implemented, does nothing."""
        return None

## Example

In [22]:
# Two-source / two-sink quiver with bias, as in Example 2
vertex_list2 = ['a', 'b', 'c', 'd', 'e']
edge_list2 = [('a', 'c'), ('b','c'), ('c','d'), ('c', 'e')]
dim_vector2 = dict(bias=1, a=1, b=2, c=8, d=2, e=6)
quiv_ex2 = quiver(vertex_list2, edge_list2)
quiv_ex2.add_bias()

quiverNN = QuiverNN(eta = F.relu, quiver = quiv_ex2, dims = dim_vector2)

# Random representation of this quiver: one matrix per edge, of shape
# (dim at target, dim at source); bias edges come last in quiv_ex2.edges
maps2 = {
    (src, tgt): np.random.rand(dim_vector2[tgt], dim_vector2[src])
    for (src, tgt) in quiv_ex2.edges
}

ex_rep2 = quiver_rep(quiv_ex2, dim_vector2, maps2)
print(ex_rep2.comp_dims_red())

# Load the random representation into the network
quiverNN.set_weights(ex_rep2)

{'bias': 1, 'a': 1, 'b': 2, 'c': 4, 'd': 2, 'e': 6}


In [23]:
# One sample (batch size 1) for each of the two sources
x = {'a': torch.tensor([[1.0]]),'b': torch.tensor([[1.0,1.0]])}
quiverNN(x)

{'d': tensor([[7.2447, 6.3284]], grad_fn=<MulBackward0>),
 'e': tensor([[5.5508, 4.9974, 6.8940, 7.4432, 6.1963, 5.1831]],
        grad_fn=<MulBackward0>)}

In [24]:
# Reduced dimension vector computed from the weights currently held by the network
quiverNN.export_weights().comp_dims_red()

{'bias': 1, 'a': 1, 'b': 2, 'c': 4, 'd': 2, 'e': 6}

In [25]:
# Build the reduced network and display its layers
red_quiverNN = quiverNN.reduced_network()
red_quiverNN

QuiverNN(
  (matrices): ModuleDict(
    (a to c): Linear(in_features=1, out_features=4, bias=False)
    (b to c): Linear(in_features=2, out_features=4, bias=False)
    (c to d): Linear(in_features=4, out_features=2, bias=False)
    (c to e): Linear(in_features=4, out_features=6, bias=False)
    (bias to c): Linear(in_features=1, out_features=4, bias=False)
    (bias to d): Linear(in_features=1, out_features=2, bias=False)
    (bias to e): Linear(in_features=1, out_features=6, bias=False)
  )
  (act): RadAct()
)

In [26]:
# Evaluate the reduced network on the same input as the full network above.
# NOTE(review): red_quiverNN was already built in the previous cell; this rebuilds it.
red_quiverNN = quiverNN.reduced_network()
x = {'a': torch.tensor([[1.0]]),'b': torch.tensor([[1.0,1.0]])}
red_quiverNN(x)

{'d': tensor([[7.2447, 6.3284]], grad_fn=<MulBackward0>),
 'e': tensor([[5.5508, 4.9974, 6.8940, 7.4432, 6.1963, 5.1831]],
        grad_fn=<MulBackward0>)}

# Test Theorem

#### Setup

In [27]:
# W: current weights of quiverNN as a quiver_rep
W = quiverNN.export_weights()
# T_net: network with the transformed weights (per edge (s, t): Q[t]^T @ W_e @ Q[s])
T_net = quiverNN.transformed_network()
# Q: per-vertex matrices of the QR reduction; V: the reduced representation
Q,V = quiverNN.export_weights().QRDimRed()
# T: weights of T_net as a quiver_rep
T = T_net.export_weights()

In [28]:
def single_mask(red_rows, red_cols, orig_rows, orig_cols):
    """Build an ``orig_rows x orig_cols`` mask of ones with one block of zeros.

    The zero block covers the rows from ``red_rows`` onwards within the first
    ``red_cols`` columns; every other entry is 1.
    """
    assert red_rows <= orig_rows
    assert red_cols <= orig_cols
    mask = torch.ones((orig_rows, orig_cols))
    # Zero the lower-left block in a single slice assignment
    mask[red_rows:, :red_cols] = 0
    return mask

In [29]:
def gamma(model, x, y, lr, verbose = False):
    """Perform one gradient-descent step on the summed squared error, in place.

    Args:
        model: Network exposing ``quiver.edges`` and ``get_matrix(edge)``; calling
            ``model(x)`` must return a dictionary of outputs.
        x: Input passed to ``model``.
        y: Dictionary of targets; the loss sums ``(model(x)[k] - y[k])**2`` over its keys.
        lr: Learning rate.
        verbose: If True, print the loss computed before the update.

    Returns:
        The same ``model`` object, with updated weights.
    """
    params = [model.get_matrix(e) for e in model.quiver.edges]

    # backward() adds to .grad, so gradients left over from an earlier step must be
    # cleared first; otherwise repeated calls step along a sum of old gradients.
    for p in params:
        p.grad = None

    y_hat = model(x)

    loss = torch.tensor(0.0)
    for k in y:
        loss = loss + (y_hat[k] - y[k]).pow(2).sum()
    loss.backward()

    if verbose:
        print(loss.item())

    with torch.no_grad():
        for p in params:
            # A weight that does not influence the loss receives no gradient
            if p.grad is not None:
                p -= lr * p.grad

    return model


In [30]:
def gammaPGD(model, x, y, lr, verbose = False):
    """Perform one projected gradient-descent step on the summed squared error, in place.

    The gradient of every edge ``(s, t)`` is multiplied entrywise by
    ``single_mask(red_dims[t], red_dims[s], dims[t], dims[s])`` before the update,
    where ``red_dims`` is the dimension vector of the model's reduced weights.

    Args:
        model: Network exposing ``quiver.edges``, ``get_matrix(edge)``, ``dims`` and
            ``export_reduced_weights()``; ``model(x)`` must return a dictionary.
        x: Input passed to ``model``.
        y: Dictionary of targets; the loss sums ``(model(x)[k] - y[k])**2`` over its keys.
        lr: Learning rate.
        verbose: If True, print the loss computed before the update.

    Returns:
        The same ``model`` object, with updated weights.
    """
    params = {e: model.get_matrix(e) for e in model.quiver.edges}

    # backward() adds to .grad, so gradients left over from an earlier step must be
    # cleared first; otherwise repeated calls step along a sum of old gradients.
    for p in params.values():
        p.grad = None

    y_hat = model(x)

    dims = model.dims
    red_dims = model.export_reduced_weights().dims

    loss = torch.tensor(0.0)
    for k in y:
        loss = loss + (y_hat[k] - y[k]).pow(2).sum()
    loss.backward()

    if verbose:
        print(loss.item())

    with torch.no_grad():
        for e, p in params.items():
            # A weight that does not influence the loss receives no gradient
            if p.grad is None:
                continue
            mask_grad = single_mask(red_dims[e[1]],red_dims[e[0]],dims[e[1]],dims[e[0]])
            # Match the dtype and device of the gradient before masking it
            p -= lr * (mask_grad.to(p.grad) * p.grad)

    return model
    

In [31]:
# Independent copies of the three networks, so that each experiment below starts
# from the same weights: a fresh instance is built and the state dict is loaded.

# Copy of the original network, for plain gradient descent
gd_quiverNN = type(quiverNN)(quiverNN.eta, quiverNN.quiver, quiverNN.dims) # get a new instance
gd_quiverNN.load_state_dict(quiverNN.state_dict())

# Copy of the original network, for projected gradient descent
pgd_quiverNN = type(quiverNN)(quiverNN.eta, quiverNN.quiver, quiverNN.dims) # get a new instance
pgd_quiverNN.load_state_dict(quiverNN.state_dict())

# Copy of the reduced network, for plain gradient descent
gd_red_quiverNN = type(red_quiverNN)(red_quiverNN.eta, red_quiverNN.quiver, red_quiverNN.dims) # get a new instance
gd_red_quiverNN.load_state_dict(red_quiverNN.state_dict())

# Copy of the transformed network, for projected gradient descent
pgd_T_quiverNN = type(T_net)(T_net.eta, T_net.quiver, T_net.dims) # get a new instance
pgd_T_quiverNN.load_state_dict(T_net.state_dict())

# Copy of the transformed network, for plain gradient descent
gd_T_quiverNN = type(T_net)(T_net.eta, T_net.quiver, T_net.dims) # get a new instance
gd_T_quiverNN.load_state_dict(T_net.state_dict())

<All keys matched successfully>

In [32]:
# Learning rate and a single training sample: inputs at the sources a, b and
# all-ones targets at the sinks d, e
lr = 0.001
x = {'a': torch.tensor([[1.0]]),'b': torch.tensor([[1.0,1.0]])}
y = {'d': torch.tensor([[1.0,1.0]]), 'e': torch.tensor([[1.0,1.0,1.0,1.0,1.0,1.0]])}

#### Loss(Gamma_Proj(T)) = Loss(Gamma_Red(V))

In [33]:
# Three steps on each network with verbose=True, so the two printed loss
# sequences can be compared line by line
print("\n V (reduced network) under (reduced) gradient descent.")
for i in range(3):
    gamma(gd_red_quiverNN, x, y, lr, True);

print("\n T=Q^{-1}W (interpolating network) under projected gradient descent.")
for i in range(3):
    gammaPGD(pgd_T_quiverNN, x, y, lr, True);


 V (reduced network) under (reduced) gradient descent.
224.8317413330078
165.90267944335938
86.0519027709961

 T=Q^{-1}W (interpolating network) under projected gradient descent.
224.8317413330078
165.90267944335938
86.0519027709961


#### Gamma(W) = Q Gamma(T)

In [34]:
# Act with Q on the transformed weights T and compare with the original weights W
QT = T.Q_act(Q)

# For \gamma^0: total absolute entrywise difference over all edges
total_diff = sum(np.sum(np.abs(QT.matrices[k] - W.matrices[k])) for k in W.matrices)
print(total_diff)

1.0844607967949726e-06


In [35]:
# For \gamma^1

# One plain gradient-descent step on the original network ...
gamma(gd_quiverNN, x, y, lr)
gammaW = gd_quiverNN.export_weights()

# ... and one on the transformed network
gamma(gd_T_quiverNN, x, y, lr)
gammaT = gd_T_quiverNN.export_weights()

# Act with Q on the updated transformed weights
QgammaT = gammaT.Q_act(Q)

# Total absolute entrywise difference between the two results over all edges
total_diff = 0.0
for k in W.matrices:
    total_diff += np.sum(np.abs(QgammaT.matrices[k] - gammaW.matrices[k]))
print(total_diff)


1.4850112807739124e-06



#### Gamma_Proj(T) - T = pad( Gamma(V) - V )   (zero-padded to the dimensions of T)

In [36]:
# Weights of the transformed network after projected gradient descent
pgd_T = pgd_T_quiverNN.export_weights()

# Reduced weights after and before gradient descent, zero-padded to the dimensions of T
gd_V_pad = gd_red_quiverNN.export_weights().padzeros_to_dim(T.dims)
V_pad = V.padzeros_to_dim(T.dims)

# Total absolute entrywise difference between the two updates.  The absolute value
# is needed: signed differences of opposite sign would cancel in the sum and
# could hide a real discrepancy.
total_diff = 0.0
for k in T.matrices:
    total_diff += np.sum(np.abs(pgd_T.matrices[k] - T.matrices[k]-(gd_V_pad.matrices[k] - V_pad.matrices[k])))
print(total_diff)

3.4935624976251134e-08
