In [1]:
from typing import List, Dict
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
%matplotlib inline
from torchvision import datasets
import torchvision.transforms as transforms
import numpy as np
# Prefer the first CUDA device when one is visible; otherwise stay on the CPU.
dev = 'cuda:0' if torch.cuda.is_available() else 'cpu'
print("Running on:",dev)
device = torch.device(dev)

Running on: cpu


# Quivers and quiver representations

In [2]:
class quiver:
    """A finite quiver (directed graph): an ordered vertex list plus an edge list.

    Each edge is indexed as ``e[0]`` (source) and ``e[1]`` (target).  The
    order of ``vertices`` is meaningful: ``check_top_order`` tests whether it
    is a topological order of the quiver.
    """

    def __init__(self, vertices : List, edges : List):
        """Store the vertex and edge lists as given (no copy, no validation).

        Args:
            vertices: vertex labels.  They are used as dictionary keys in
                ``check_top_order``, so they must be hashable.
            edges: edges, each with the source at index 0 and the target at
                index 1.
        """
        self.vertices = vertices
        # TODO: add assert to check no repeated vertices
        # E.g. assert len(set(vertices)) == len(vertices)

        self.edges = edges
        # TODO: add assert to check that edges is a list of pairs
        # First entry of the pair is the source, second is the target
        # Source and target of each edge should be in the vertex set
        # Separate class for edges? vertices?

    def check_acyclic(self):
        """Placeholder, not implemented yet: always returns None."""
        # One way: find all sources, do depth-first search
        return None

    def check_top_order(self):
        """Return True iff the order of ``self.vertices`` is topological.

        Every edge must point from a vertex to one that appears strictly
        later in ``self.vertices``.  The comparison is strict, so a self-loop
        (which makes the quiver cyclic) fails the check.

        Raises:
            KeyError: if an edge endpoint is not in ``self.vertices``.
        """
        position = {v: i for i, v in enumerate(self.vertices)}
        return all(position[e[0]] < position[e[1]] for e in self.edges)

    def get_incoming(self, vertex):
        """Return the list of edges whose target is ``vertex``.

        The sources of those edges (the incoming neighbours) can be obtained
        as ``[e[0] for e in self.get_incoming(vertex)]``.

        Raises:
            AssertionError: if ``vertex`` is not a vertex of the quiver.
        """
        assert vertex in self.vertices, "No such vertex found"
        return [e for e in self.edges if e[1] == vertex]

    def is_sink(self, vertex):
        """Return True iff no edge has ``vertex`` as its source.

        Raises:
            AssertionError: if ``vertex`` is not a vertex of the quiver.
        """
        assert vertex in self.vertices, "No such vertex found"
        return all(e[0] != vertex for e in self.edges)

        

In [3]:
class quiver_rep:
    """A representation of a quiver: one dimension per vertex, one matrix per edge.

    For an edge ``e`` the matrix ``matrices[e]`` must have shape
    ``(dims[e[1]], dims[e[0]])``, i.e. (target dimension, source dimension).
    """

    def __init__(self, quiver: "quiver", dims: Dict, matrices: Dict):
        """Store the data and validate it against ``quiver``.

        Args:
            quiver: the underlying quiver; must provide ``vertices``, ``edges``,
                ``check_top_order``, ``get_incoming`` and ``is_sink``.
            dims: maps every vertex to a non-negative ``int`` dimension.
            matrices: maps every edge to a ``numpy.ndarray`` of the shape
                described in the class docstring.

        Raises:
            AssertionError: if ``dims`` or ``matrices`` do not match the quiver.
        """
        self.Q = quiver
        self.dims = dims
        self.matrices = matrices

        # All checks use the quiver that was passed in, never a notebook
        # global, so the class works for any quiver and on a fresh kernel.

        # Check the dimension vector
        assert len(dims) == len(quiver.vertices), "Inappropriate dimension vector"
        for v in dims:
            assert v in quiver.vertices, "Inappropriate dim vector"
            assert isinstance(dims[v], int) and dims[v] >=0, "Inappropriate dim vector"

        # Check the matrices
        assert len(matrices) == len(quiver.edges), "Matrices error"
        for e in matrices:
            assert e in quiver.edges, "Matrices error"
            assert isinstance(matrices[e], np.ndarray), "Matrices error" # May need fixing
            assert np.shape(matrices[e]) == (dims[e[1]], dims[e[0]]), "Dimension error"

    def dims_red(self) -> Dict:
        """Compute the reduced dimension vector.

        Vertices are processed in their stored order.  A sink, or a vertex
        with no incoming edges, keeps its dimension.  Any other vertex gets
        the minimum of its own dimension and the sum of the reduced
        dimensions of the sources of its incoming edges.

        Returns:
            A dict mapping each vertex to its reduced dimension.

        Raises:
            AssertionError: if the vertex order is not topological.
        """
        assert self.Q.check_top_order(), "Order of the vertices is not topological"

        d_red = {}
        for v in self.Q.vertices:
            incoming = self.Q.get_incoming(v)
            if self.Q.is_sink(v) or not incoming:
                d_red[v] = self.dims[v]
            else:
                # The topological order guarantees every source was already processed
                d_red[v] = min(self.dims[v], sum(d_red[e[0]] for e in incoming))

        return d_red
        

In [4]:
# EXAMPLE
# Quiver with skip connections and no bias:
# the chain a -> b -> c -> d together with the skip edge a -> c

vertex_list = ['a', 'b', 'c', 'd']
edge_list = [
    ('a', 'b'),
    ('a', 'c'),
    ('b', 'c'),
    ('c', 'd'),
]

Q = quiver(vertices=vertex_list, edges=edge_list)
# print(Q.vertices, Q.edges)

In [5]:
# Test the methods: one output line per method, one entry per vertex

print(*[Q.get_incoming(v) for v in ('a', 'b', 'c', 'd')])
print(*[Q.is_sink(v) for v in ('a', 'b', 'c', 'd')])
print(Q.check_top_order())

[] [('a', 'b')] [('a', 'c'), ('b', 'c')] [('c', 'd')]
False False False True
True


In [6]:
# Representation of this quiver

dim_vector = {'a': 2, 'b': 4, 'c': 8, 'd': 2 }

# Seeded generator so that Restart & Run All reproduces the same matrices
rng = np.random.default_rng(0)

# One matrix per edge, of shape (dimension of target, dimension of source)
maps = {('a', 'b') : rng.random((4, 2)),
        ('a', 'c') : rng.random((8, 2)),
        ('b', 'c') : rng.random((8, 4)),
        ('c', 'd') : rng.random((2, 8))}

ex_rep = quiver_rep(Q, dim_vector, maps)
print(ex_rep.dims_red())

{'a': 2, 'b': 2, 'c': 4, 'd': 2}


In [7]:
# Troubleshooting: confirm that np.random.rand yields a plain numpy.ndarray,
# which is what the isinstance check in quiver_rep expects

W1 = np.random.rand(4, 2)
print(type(W1), isinstance(W1, np.ndarray), sep="\n")

<class 'numpy.ndarray'>
True


# Dimensional reduction algorithm

In [8]:
# This will eventually be incorporated in the quiver_rep class

def QRDimRed(W : quiver_rep) -> quiver_rep:
    """Skeleton of the QR-based dimensional reduction of a quiver representation.

    Work in progress: the function walks the vertices in their stored order,
    prints which case of the algorithm applies at each vertex, and returns
    None.  The matrix computations are so far only sketched in comments.

    Args:
        W: the representation to reduce.  The vertex order of its quiver
            must be topological (asserted).
    """
    # Unpack the representation
    full_dims = W.dims
    weights = W.matrices
    base_quiver = W.Q
    vertex_order = base_quiver.vertices
    edge_list = base_quiver.edges

    reduced_dims = W.dims_red()
    # print(full_dims, reduced_dims)

    assert base_quiver.check_top_order(), "Order of the vertices is not topological"

    # Planned outputs:
    #   Q maps each vertex to an orthogonal matrix
    #   V maps each edge to a matrix; together these form a representation
    Q = {}
    V = {}

    print(base_quiver.edges)
    print(base_quiver.vertices)

    for v in vertex_order:
        in_edges = base_quiver.get_incoming(v)

        # Guard clause: a source has no incoming weights to transform
        if not in_edges:
            print(v, " is a source vertex")
            Q_cur = np.eye(full_dims[v])
            continue

        M = np.array([])
        for e in in_edges:
            # Transform weights on incoming edges
            # M_e = W_e @ Q_j @ inc_j
            # Extend M = [M M_e]
            pass

        needs_reduction = reduced_dims[v] < full_dims[v]
        if needs_reduction:
            print("Reduce at vertex ", v)
            # Q_cur,R_cur = np.linalg.qr(M, mode="complete")
        elif base_quiver.is_sink(v):
            print(v, " is a sink")
            Q_cur = np.eye(full_dims[v])
            # R_cur = M
        else:
            print(v, " is not a sink")
            # Q_cur,R_cur = np.linalg.qr(M)

        # Process and add to the dictionaries
        # Q[v] = Q_cur
        for e in in_edges:
            # Extract V_e from R_v for all incoming edges e
            # V[e] = V_e
            pass

    # return Q, V
    return None


In [9]:
# Back to the running example: trace the reduction skeleton on ex_rep.
# QRDimRed prints the quiver's edges and vertices, then one line per vertex.

QRDimRed(ex_rep)

[('a', 'b'), ('a', 'c'), ('b', 'c'), ('c', 'd')]
['a', 'b', 'c', 'd']
a  is a source vertex
Reduce at vertex  b
Reduce at vertex  c
d  is a sink


# Quiver Neural Networks

In [10]:
class RadAct(nn.Module):
    """Radial activation: rescale each vector by ``eta(|x| + shift) / |x|``.

    The norm is taken over the last dimension, so the direction of every
    vector is preserved and only its length is passed through ``eta``.

    Args:
        eta: scalar activation applied to the norms (default ``F.relu``).
        eps: lower bound used for the norm in the denominator, so that
            (near-)zero vectors do not cause a division by zero.
    """

    def __init__(self, eta = F.relu, eps = 1e-6):
        super().__init__()
        self.eta = eta
        self.eps = eps
        self.shift = 0
        # Add internal bias/shift later

    def forward(self,x):
        """Apply the radial activation; the output has the same shape as ``x``."""
        # x: [Batch x Channel] (any leading dimensions work; the norm is over the last one)
        r = torch.linalg.norm(x, dim=-1, keepdim=True)
        # Clamp the denominator per sample and out of place:
        #  - a near-zero sample no longer shifts the norms of the other samples,
        #  - the norm output is not modified in place, which autograd may
        #    need intact for the backward pass,
        #  - an empty batch is handled (no global min over zero elements).
        scalar = self.eta(r + self.shift) / torch.clamp(r, min=self.eps)
        return x * scalar

In [11]:
class QuiverNN(nn.Module):
    """Skeleton of a neural network whose architecture is described by a quiver.

    Work in progress: only the constructor stores its arguments.  ``forward``
    relies on attributes that are not created yet, and the remaining methods
    are placeholders that return None.
    """

    def __init__(self, eta:float , Q: "quiver", dims: Dict ):
        """Store the activation parameter, the quiver and the dimension vector.

        Args:
            eta: stored as ``self.eta``.
                NOTE(review): annotated as float, whereas RadAct takes a
                callable named eta -- confirm the intended type.
            Q: the quiver describing the architecture.
            dims: dimension vector for ``Q`` (not validated yet).
        """
        super().__init__()
        self.eta = eta
        self.Q = Q
        self.dims = dims
        # Assert statement to check that dims is a dimension vector for Q

        # Reduced dimension vector
        # self.dims_red = [self.dims[0]]


    def forward(self, x):
        """Feed ``x`` through the hidden layers and then the output layer.

        NOTE(review): ``self.layers``, ``self.act_fns`` and
        ``self.output_layer`` are not defined in ``__init__``, so calling this
        raises AttributeError until they are created.
        """
        h = x
        for lin,act in zip(self.layers[:-1], self.act_fns):
            h = act(lin(h))
        return self.output_layer(h)

    def set_weights(self, new_weights: "quiver_rep"):
        """Placeholder, not implemented yet: returns None."""
        pass

    def set_activation_biases(self, new_biases: List[float]):
        """Placeholder, not implemented yet: returns None."""
        pass

    def export_weights(self) -> "quiver_rep":
        """Placeholder, not implemented yet: returns None."""
        pass

    def export_activation_biases(self) -> List[float]:
        """Placeholder, not implemented yet: returns None."""
        pass

    def export_reduced_weights(self) -> "quiver_rep":
        """Placeholder, not implemented yet: returns None."""
        pass

    def transformed_network(self):
        """Placeholder, not implemented yet: returns None."""
        pass

    def reduced_network(self):
        """Placeholder, not implemented yet: returns None."""
        pass

# Scraps

In [12]:
class vertex:
    """Placeholder vertex type; it carries no data yet."""

    def __init__(self):
        pass

In [13]:
# Scratch: instantiate the placeholder vertex class
a = vertex()