In [1]:
from typing import List, Dict, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
%matplotlib inline
from torchvision import datasets
import torchvision.transforms as transforms
import numpy as np
dev = 'cpu'
if torch.cuda.is_available():
    dev = 'cuda:0'
print("Running on:",dev)
device = torch.device(dev)

Running on: cpu


# Quivers and quiver representations

The quiver class is designed to define an acyclic quiver with no double edges. The initial quiver is not required to have a bias vertex; such a vertex can be added with the ```add_bias()``` method.

The input list of vertices is meant to be a list of strings, one for each vertex. It is best to avoid the label ```bias``` among the vertices. 

The input list of edges is meant to be a tuple ```e = (e[0], e[1])``` where ```e[0]``` is the source and ```e[1]``` is the target. No double edges or loops are allowed.

In [2]:
class quiver:
    """Quiver class. Vertices are given as a list of strings. Edges are given as a list of pairs.
    
    Attributes and Methods.
    """
    def __init__(self, vertices : List[str], edges : List[Tuple[str]]):
        self.vertices = vertices
        # Add assert to check no repeated vertices
        # E.g. assert len(set(vertices)) == len(vertices)
        # Assert there is no bias initially, or if there is, it has the desired properties
        
        self.edges = edges
        # Add assert to check that edges is a list of pairs
        # First entry of the pair is the sourse, second is the target
        # Source and target of each edge should be in the vertex set      
        # Separate class for edges? vertices?
        
        # Get the sources
        sources = set(self.vertices)
        for e in self.edges:
            sources.discard(e[1])
        self.sources = sources
                
    
    # Check that the quiver is acyclic
    def check_acyclic(self):
        None
        # One way: find all sources, do depth-first search
        
    # Check that the vertices are in topological order    
    def check_top_order(self):
        indices = {}
        for i,v in enumerate(self.vertices):
            indices[v] = i
        for e in self.edges:
            if indices[e[0]] == indices[e[1]]:
                return False
            if indices[e[0]] > indices[e[1]]:
                return False
        return True
        
    # Get the incoming edges for every vertex
    def get_incoming(self, vertex):
        assert vertex in self.vertices, "No such vertex found"
        return [e for e in self.edges if e[1] == vertex]
        # Can get the incoming neighbors as [e[1] for e in self.get_incoming(vertex)]
        
    # Check if a vertex is a sink
    def is_sink(self, vertex):
        assert vertex in self.vertices, "No such vertex found"
        return all([e[0] != vertex for e in self.edges])
    
    # Add a bias vertex. Considering alternatives to this ... 
    def add_bias(self):
        # Add bias vertex. This will not disturb the topological order.
        for v in self.vertices:
            if v not in self.sources:
                self.edges.append(('bias', v))
        self.vertices = ['bias'] + self.vertices
        return


In [3]:
class quiver_rep:
    """Quiver representation class. Input a quiver with dimension vector and a matrix for each edge."""
    def __init__(self, quiver: quiver, dims: Dict[str,int], matrices: Dict[str, np.array]):
        self.quiver = quiver
        self.dims = dims
        self.matrices = matrices
        
        # Check the dimension vector
        assert len(dims) == len(self.quiver.vertices), "Inappropriate dimension vector"
        for v in dims:
            assert v in self.quiver.vertices, "Inappropriate dim vector"
            assert isinstance(dims[v], int) and dims[v] >=0, "Dimension needs to be a positive integer"
        assert dims['bias'] == 1, "Dimension at bias needs to be 1"

            
        # Check the matrices
        assert len(matrices) == len(self.quiver.edges), "Matrices error"
        for e in matrices:
            assert e in self.quiver.edges, "Matrices error"
            assert isinstance(matrices[e], np.ndarray), "Matrices error" # May need fixing
            assert np.shape(matrices[e]) == (dims[e[1]], dims[e[0]]), "Dimension error"
            
            
    # Compute the reduced dimension vector
    def comp_dims_red(self) -> Dict:
        
        assert self.quiver.check_top_order(), "Order of the vertices is not topological"

        dims_red = {}
        for i in self.quiver.vertices:
            if i == 'bias' or i in self.quiver.sources or self.quiver.is_sink(i):
                dims_red[i] = self.dims[i]
            else:
                incoming = self.quiver.get_incoming(i)
                dims_red[i] = min(self.dims[i], sum([dims_red[e[0]] for e in incoming]) )

                    
        self.dims_red = dims_red        
        return dims_red
        

In [4]:
# Linear feedforward function, ignoring biases
# Might turn out not to be super necessary

def lin_ff(W : quiver_rep) -> np.array:
    dims = W.dims
    matrices = W.matrices
    quiver = W.quiver
    vertices = quiver.vertices
    edges = quiver.edges
    
    assert quiver.check_top_order(), "Order of the vertices is not topological"
    
    # Dictionary for partial feedforward functions:
    partial = {}
    result = np.array([])
    input_dim = 0
    # input_dim = sum( [dims[i] for i in quiver.sources])
    
    for i in quiver.vertices:
        incoming = quiver.get_incoming(i)
        
        # Case of a source
        if incoming == [] and i != 'bias':
            
            # Update matrices already defined
            for j in partial:
                partial[j] = np.hstack((partial[j], np.zeros((dims[j],dims[i]))))
                
            # Define the new matrix as a projection matrix
            if input_dim == 0:
                partial[i] = np.eye(dims[i])
            else:
                partial[i] = np.hstack((np.zeros((dims[i], input_dim)), np.eye(dims[i])))
            input_dim += dims[i]

        else:
            
            # Compute the matrix to be added
            A = np.zeros((dims[i], input_dim))
            for e in incoming:
                # Ignore biases for now
                if e[0] != 'bias':
                    A += matrices[e] @ partial[e[0]]
            partial[i] = A
            
            # Case of a sink: add rows to the final matrix
            if quiver.is_sink(i):
                if np.shape(result) == (0,):
                    result = partial[i]
                else:
                    result = np.vstack((result, partial[i]))
                        
    return result

# Dimensional reduction algorithm

In [5]:
# Auxiliary function

def padzeros(M,newrows,newcols = None):
    oldrows, oldcols = M.shape
    if newcols == None:
        newcols = oldcols
    return np.pad(M,((0,newrows-oldrows),(0,newcols-oldcols)),mode="constant")

In [6]:
# QR dimensional reduction algorithm
# This will eventually be incorporated in the quiver_rep class

def QRDimRed(W : quiver_rep):
    dims = W.dims
    matrices = W.matrices
    quiver = W.quiver
    vertices = quiver.vertices
    edges = quiver.edges
    
    # Compute the reduced dimension vector
    dims_red = W.comp_dims_red()
    # print(dims, dims_red)
    
    # Check that vertices are in a topological order
    assert quiver.check_top_order(), "Order of the vertices is not topological"
    
    # Q = dictionary mapping each vertex to an orthogonal matrix
    Q = {}
    
    # Vmatrices = matrices of the reduced representation V, mapping each edge to a matrix
    Vmatrices = {}
    
    print(quiver.edges)
    print(quiver.vertices)
    
    for i in vertices:
        incoming = quiver.get_incoming(i)
        
        # Case of a source vertex
        if incoming == []:
            Q[i] = np.eye(dims[i])
            
        else:
            
            # Compute the matrix to be QR-reduced
            M = np.array([])
            for e in incoming:
                # Transform weights on incoming edges
                Qj = Q[e[0]]
                Me = matrices[e] @ Qj[:,:dims_red[e[0]]]
                if np.shape(M) == (0,):
                    M = Me
                else:
                    M = np.hstack((M,Me))
            
            # Case of reduction 
            if dims_red[i] < dims[i]: 
                Q_cur, R = np.linalg.qr(M, mode="complete")
                R = R[:dims_red[i]]
                
            # Case of no reduction for non-sinks
            elif not quiver.is_sink(i):
                Q_cur, R = np.linalg.qr(M, mode="complete")
                
            # Case of a sink (no reduction)
            else:
                Q_cur = np.eye(dims[i])
                R = M
                
            # Process and add to the dictionaries
            Q[i] = Q_cur
            for e in incoming:                       
                # Extract V_e from R_i for all incoming edges e
                Vmatrices[e] = R[:,:dims_red[e[0]]]
                R = R[:,dims_red[e[0]]:]
    
    # Make V into a representation
    V = quiver_rep(quiver, dims_red, Vmatrices)

    # Verify that V is a subrepresentation of Q^{-1} W  
    for e in quiver.edges:
        Qi = Q[e[0]]
        Qj = Q[e[1]]
        max_diff = np.max(np.abs(np.transpose(Qj) @ matrices[e] @ Qi[:,:dims_red[e[0]]] 
                     - padzeros(Vmatrices[e], dims[e[1]])))
        assert max_diff < 1e-10, "Error in the algorithm"

    return Q, V

# Example 1

In [7]:
# EXAMPLE
# Quiver with skip connections and no bias

vertex_list = ['a', 'b', 'c', 'd']
edge_list = [('a', 'b'), ('a','c'), ('b','c'), ('c', 'd')]

quiv_ex = quiver(vertex_list, edge_list)
# print(Q.vertices, Q.edges)

# Test the methods
print(quiv_ex.get_incoming('a'), quiv_ex.get_incoming('b'), quiv_ex.get_incoming('c'), quiv_ex.get_incoming('d'))
print(quiv_ex.is_sink('a'), quiv_ex.is_sink('b'), quiv_ex.is_sink('c'), quiv_ex.is_sink('d'))
print(quiv_ex.check_top_order())
print(quiv_ex.sources)
quiv_ex.add_bias()
print(quiv_ex.vertices)
print(quiv_ex.edges)

[] [('a', 'b')] [('a', 'c'), ('b', 'c')] [('c', 'd')]
False False False True
True
{'a'}
['bias', 'a', 'b', 'c', 'd']
[('a', 'b'), ('a', 'c'), ('b', 'c'), ('c', 'd'), ('bias', 'b'), ('bias', 'c'), ('bias', 'd')]


In [8]:
# Representation of this quiver

dim_vector = {'bias' : 1, 'a': 2, 'b': 4, 'c': 8, 'd': 2 }

maps = {('a', 'b') : np.random.rand(4, 2), 
        ('a', 'c') : np.random.rand(8, 2), 
        ('b', 'c') : np.random.rand(8, 4),
        ('c', 'd') : np.random.rand(2, 8),
        ('bias', 'b') : np.random.rand(4, 1),
        ('bias', 'c') : np.random.rand(8, 1),
        ('bias', 'd') : np.random.rand(2, 1)}

ex_rep = quiver_rep(quiv_ex, dim_vector, maps)
print(ex_rep.comp_dims_red())

{'bias': 1, 'a': 2, 'b': 3, 'c': 6, 'd': 2}


In [9]:
# Check the algorithm

Q_ex ,V_ex = QRDimRed(ex_rep)
print(ex_rep.comp_dims_red())
print([np.shape(Q_ex[i]) for i in Q_ex])
print([np.shape(V_ex.matrices[e]) for e in V_ex.matrices])

[('a', 'b'), ('a', 'c'), ('b', 'c'), ('c', 'd'), ('bias', 'b'), ('bias', 'c'), ('bias', 'd')]
['bias', 'a', 'b', 'c', 'd']
{'bias': 1, 'a': 2, 'b': 3, 'c': 6, 'd': 2}
[(1, 1), (2, 2), (4, 4), (8, 8), (2, 2)]
[(3, 2), (3, 1), (6, 2), (6, 3), (6, 1), (2, 6), (2, 1)]


In [10]:
lin_ff(ex_rep)

array([[6.79219992, 3.3533425 ],
       [8.87269454, 4.27084873]])

In [11]:
ex_rep.matrices[('c','d')] @ ( ex_rep.matrices[('a','c')] + ex_rep.matrices[('b','c')] @ ex_rep.matrices[('a','b')] )

array([[6.79219992, 3.3533425 ],
       [8.87269454, 4.27084873]])

In [12]:
lin_ff(V_ex)

array([[6.79219992, 3.3533425 ],
       [8.87269454, 4.27084873]])

# Example 2

In [13]:
# EXAMPLE
# Quiver with multiple inputs and outputs

vertex_list2 = ['a', 'b', 'c', 'd', 'e']
edge_list2 = [('a', 'c'), ('b','c'), ('c','d'), ('c', 'e')]

quiv_ex2 = quiver(vertex_list2, edge_list2)
# print(Q.vertices, Q.edges)

# Test the methods
print(quiv_ex2.get_incoming('a'), quiv_ex2.get_incoming('b'), quiv_ex2.get_incoming('c'), 
      quiv_ex2.get_incoming('d'), quiv_ex2.get_incoming('e'))
print(quiv_ex2.is_sink('a'), quiv_ex2.is_sink('b'), quiv_ex2.is_sink('c'), 
      quiv_ex2.is_sink('d'), quiv_ex2.is_sink('e'))
print(quiv_ex2.check_top_order())
print(quiv_ex2.sources)
quiv_ex2.add_bias()
print(quiv_ex2.vertices)
print(quiv_ex2.edges)

[] [] [('a', 'c'), ('b', 'c')] [('c', 'd')] [('c', 'e')]
False False False True True
True
{'b', 'a'}
['bias', 'a', 'b', 'c', 'd', 'e']
[('a', 'c'), ('b', 'c'), ('c', 'd'), ('c', 'e'), ('bias', 'c'), ('bias', 'd'), ('bias', 'e')]


In [14]:
# Representation of this quiver

dim_vector2 = {'bias' : 1, 'a': 1, 'b': 2, 'c': 8, 'd': 2 , 'e': 6}

maps2 = {('a', 'c') : np.random.rand(8, 1), 
        ('b', 'c') : np.random.rand(8, 2), 
        ('c', 'd') : np.random.rand(2, 8),
        ('c', 'e') : np.random.rand(6, 8),
        ('bias', 'c') : np.random.rand(8, 1),
        ('bias', 'd') : np.random.rand(2, 1),
        ('bias', 'e') : np.random.rand(6, 1)}

ex_rep2 = quiver_rep(quiv_ex2, dim_vector2, maps2)
print(ex_rep2.comp_dims_red())

{'bias': 1, 'a': 1, 'b': 2, 'c': 4, 'd': 2, 'e': 6}


In [15]:
# Check the algorithm

Q_ex2 ,V_ex2 = QRDimRed(ex_rep2)
print(ex_rep2.comp_dims_red())
print([np.shape(Q_ex2[i]) for i in Q_ex2])
print([np.shape(V_ex2.matrices[e]) for e in V_ex2.matrices])

[('a', 'c'), ('b', 'c'), ('c', 'd'), ('c', 'e'), ('bias', 'c'), ('bias', 'd'), ('bias', 'e')]
['bias', 'a', 'b', 'c', 'd', 'e']
{'bias': 1, 'a': 1, 'b': 2, 'c': 4, 'd': 2, 'e': 6}
[(1, 1), (1, 1), (2, 2), (8, 8), (2, 2), (6, 6)]
[(4, 1), (4, 2), (4, 1), (2, 4), (2, 1), (6, 4), (6, 1)]


In [16]:
lin_ff(ex_rep2)

array([[1.87700539, 1.6172634 , 2.45880627],
       [2.17175284, 1.82903726, 2.9235503 ],
       [2.39143943, 2.24588898, 3.45241048],
       [1.77890854, 1.74271559, 1.59900845],
       [2.04102758, 1.85888085, 1.8802813 ],
       [2.31927294, 2.15726015, 2.97002895],
       [2.20074782, 2.39043941, 3.40847452],
       [1.53417415, 1.80139388, 2.33877951]])

In [17]:
lin_ff(V_ex2)

array([[1.87700539, 1.6172634 , 2.45880627],
       [2.17175284, 1.82903726, 2.9235503 ],
       [2.39143943, 2.24588898, 3.45241048],
       [1.77890854, 1.74271559, 1.59900845],
       [2.04102758, 1.85888085, 1.8802813 ],
       [2.31927294, 2.15726015, 2.97002895],
       [2.20074782, 2.39043941, 3.40847452],
       [1.53417415, 1.80139388, 2.33877951]])

# Quiver Neural Networks

In [18]:
class RadAct(nn.Module):
    def __init__(self, eta = F.relu):
        super().__init__()
        self.eta = eta
        self.shift = 0 
        # Add internal bias/shift later
        
    def forward(self,x):
        # x: [Batch x Channel]
        # Radial activations
        r = torch.linalg.norm(x, dim=-1) 
        if torch.min(r) < 1e-6:
            r += 1e-6
        scalar = self.eta(r + self.shift) / r
        return x * scalar.unsqueeze(-1)   

In [19]:
class QuiverNN(nn.Module):
    
    def __init__(self, eta: float , Q: quiver, dims: Dict[str,int] ):
        super().__init__()
        self.eta = eta
        self.Q = Q
        sefl.dims = dims
        # Assert statement to check that dims is a dimension vector for Q
        
        # Reduced dimension vector
        


    def forward(self, x):
        h = x
        for lin,act in zip(self.layers[:-1], self.act_fns):
            h = act(lin(h))
        return self.output_layer(h)
    
    def set_weights(self, new_weights: quiver_rep):
        None 
    
    def set_activation_biases(self, new_biases: List[float]):    
        None

    def export_weights(self) -> quiver_rep:
        None
    
    def export_activation_biases(self) -> List[float]:
        None
    
    def export_reduced_weights(self) -> quiver_rep:
        None
    
    def transformed_network(self):
        None
        
    def reduced_network(self):
        None

In [20]:
quiver_rep?

# Scraps

In [21]:
class vertex:
    def __init__(self):
        None

In [22]:
a = vertex()

In [23]:
# Troubleshooting

W1 = np.random.rand(4, 2)
print(type(W1))
print(isinstance(W1, np.ndarray))

<class 'numpy.ndarray'>
True


In [24]:
a = np.random.rand(4,2)
b = np.random.rand(4,4)
a, b

(array([[0.70277045, 0.85277962],
        [0.24535073, 0.09595361],
        [0.06375216, 0.48481909],
        [0.06149499, 0.72172909]]),
 array([[0.13440801, 0.78495063, 0.58392161, 0.32813079],
        [0.2320391 , 0.92621226, 0.30614295, 0.86373322],
        [0.62560492, 0.49726806, 0.19195408, 0.47647759],
        [0.32803341, 0.37474466, 0.9560282 , 0.52778952]]))

In [25]:
c = np.hstack((a,b))
np.shape(c)

(4, 6)

In [26]:
b

array([[0.13440801, 0.78495063, 0.58392161, 0.32813079],
       [0.2320391 , 0.92621226, 0.30614295, 0.86373322],
       [0.62560492, 0.49726806, 0.19195408, 0.47647759],
       [0.32803341, 0.37474466, 0.9560282 , 0.52778952]])

In [27]:
b[:,:2]

array([[0.13440801, 0.78495063],
       [0.2320391 , 0.92621226],
       [0.62560492, 0.49726806],
       [0.32803341, 0.37474466]])

In [28]:
b[:,2:]

array([[0.58392161, 0.32813079],
       [0.30614295, 0.86373322],
       [0.19195408, 0.47647759],
       [0.9560282 , 0.52778952]])

In [29]:
np.shape(b[:,:2])

(4, 2)

In [30]:
np.shape(np.array([]))

(0,)