In [24]:
from typing import List, Dict
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
%matplotlib inline
from torchvision import datasets
import torchvision.transforms as transforms
import numpy as np
# Use the first CUDA device when PyTorch reports one is available; otherwise use the CPU.
dev = 'cuda:0' if torch.cuda.is_available() else 'cpu'
print("Running on:", dev)
device = torch.device(dev)

Running on: cpu


# Quivers and quiver representations

In [49]:
class quiver:
    """A finite quiver (directed graph) described by a vertex list and an edge list.

    Every edge is a pair ``(source, target)``: index 0 is the source vertex and
    index 1 is the target vertex.
    """

    def __init__(self, vertices : List, edges : List):
        """Validate and store the vertices and edges.

        Args:
            vertices: list of distinct, hashable vertex labels.
            edges: list of ``(source, target)`` pairs; both entries of every
                pair must appear in ``vertices``.

        Raises:
            AssertionError: if a vertex is repeated, an edge is not a pair, or
                an edge endpoint is not one of the vertices.
        """
        assert len(set(vertices)) == len(vertices), "Repeated vertices"
        for e in edges:
            assert isinstance(e, (tuple, list)) and len(e) == 2, \
                "Each edge must be a (source, target) pair"
            assert e[0] in vertices and e[1] in vertices, \
                "Edge endpoints must be vertices of the quiver"

        self.vertices = vertices
        self.edges = edges
        # Separate class for edges? vertices?

    def check_acyclic(self):
        """Return True if the quiver has no oriented cycle, False otherwise.

        Uses Kahn's algorithm: repeatedly remove vertices with no remaining
        incoming edges. The quiver is acyclic exactly when every vertex can be
        removed this way. A self-loop counts as a cycle.
        """
        in_degree = {v: 0 for v in self.vertices}
        successors = {v: [] for v in self.vertices}
        for source, target in self.edges:
            in_degree[target] += 1
            successors[source].append(target)

        # Start from the sources (vertices with no incoming edge).
        ready = [v for v in self.vertices if in_degree[v] == 0]
        removed = 0
        while ready:
            v = ready.pop()
            removed += 1
            for w in successors[v]:
                in_degree[w] -= 1
                if in_degree[w] == 0:
                    ready.append(w)
        return removed == len(self.vertices)

    def get_incoming(self, vertex):
        """Return the list of edges whose target is ``vertex``.

        The incoming neighbours can be recovered as
        ``[e[0] for e in self.get_incoming(vertex)]``.

        Raises:
            AssertionError: if ``vertex`` is not a vertex of the quiver.
        """
        assert vertex in self.vertices, "No such vertex found"
        # Edges are (source, target) pairs, so the target sits at index 1.
        return [e for e in self.edges if e[1] == vertex]
        

In [50]:
# Example quiver with a skip connection (the a -> c edge) and no bias vertex.

vertex_list = list('abc')
edge_list = [
    ('a', 'b'),
    ('a', 'c'),
    ('b', 'c'),
]

Q = quiver(vertices=vertex_list, edges=edge_list)

In [51]:
# Show the vertex list and edge list stored on Q.
print(Q.vertices, Q.edges)

['a', 'b', 'c'] [('a', 'b'), ('a', 'c'), ('b', 'c')]


In [52]:
# Dimension vector: assigns an integer dimension to each vertex label.
dim_vector = dict(a=2, b=4, c=2)
# Its keys must be exactly the vertex labels.
assert set(dim_vector) == set(vertex_list), "Inappropriate dimension vector"

In [56]:
class quiver_rep:
    """A representation of a quiver: a dimension per vertex and a matrix per edge.

    For an edge ``(source, target)`` the matrix must have shape
    ``(dims[target], dims[source])``.
    """

    def __init__(self, Q: "quiver", dims: Dict, matrices: Dict):
        """Store the data and validate it against the quiver.

        Args:
            Q: object exposing ``vertices`` and ``edges`` (pairs
                ``(source, target)``).
            dims: dict mapping every vertex of ``Q`` to a non-negative int.
            matrices: dict mapping every edge of ``Q`` to a ``numpy.ndarray``
                of shape ``(dims[target], dims[source])``.

        Raises:
            AssertionError: if ``dims`` or ``matrices`` do not match ``Q``.
        """
        self.Q = Q
        self.dims = dims
        self.matrices = matrices

        # Check the dimension vector
        assert len(dims) == len(Q.vertices), "Inappropriate dimension vector"
        for v in dims:
            assert v in Q.vertices, "Inappropriate dim vector"
            assert isinstance(dims[v], int) and dims[v] >=0, "Inappropriate dim vector"

        # Check the matrices
        assert len(matrices) == len(Q.edges), "Matrices error"
        for e in matrices:
            assert e in Q.edges, "Matrices error"
            # np.array is a factory function; the array type is np.ndarray.
            assert isinstance(matrices[e], np.ndarray), "Matrices error"
            # Edges are (source, target) pairs: index 0 is the source, 1 the target.
            source, target = e
            assert np.shape(matrices[e]) == (dims[target], dims[source]), "Matrices error"

    def compute_reduced_rep(self):
        """Placeholder: not implemented yet; currently returns None."""
        return None
        

# Quiver Neural Networks

In [54]:
class RadAct(nn.Module):
    """Radial activation: rescale each vector by ``eta(|x| + shift) / |x|``.

    The norm is taken over the last dimension, so every vector keeps its
    direction and only its length is changed.
    """

    def __init__(self, eta = F.relu, eps = 1e-6):
        """
        Args:
            eta: callable applied elementwise to the (shifted) norms.
            eps: lower bound used for the norm in the denominator, so that
                (near-)zero vectors do not cause a division by zero.
        """
        super().__init__()
        self.eta = eta
        self.shift = 0
        self.eps = eps
        # Add internal bias/shift later

    def forward(self,x):
        """Apply the radial activation.

        Args:
            x: tensor whose last dimension holds the vector entries,
                e.g. [Batch x Channel].

        Returns:
            Tensor of the same shape as ``x``.
        """
        # keepdim=True keeps a trailing axis of size 1 so the scale broadcasts over x.
        r = torch.linalg.norm(x, dim=-1, keepdim=True)
        # Clamp the denominator per row, out of place: one tiny row no longer
        # perturbs the other rows of the batch, the norm tensor used by autograd
        # is left untouched, and an empty batch is handled.
        scalar = self.eta(r + self.shift) / torch.clamp(r, min=self.eps)
        return x * scalar

In [55]:
class QuiverNN(nn.Module):
    """Skeleton of a neural network attached to a quiver and a dimension vector.

    Only the constructor stores state so far; the weight/bias import-export
    methods and the network transformations are placeholders returning None.
    """

    def __init__(self, eta:float , Q: "quiver", dims: Dict ):
        """Store the activation parameter, the quiver and the dimension vector.

        Args:
            eta: stored as ``self.eta``.
            Q: the underlying quiver.
            dims: dimension vector for ``Q``.
        """
        super().__init__()
        self.eta = eta
        self.Q = Q
        self.dims = dims
        # Assert statement to check that dims is a dimension vector for Q

        # Reduced dimension vector
        # self.dims_red = [self.dims[0]]


    def forward(self, x):
        """Pass ``x`` through the linear layers and activations, then the output layer.

        NOTE(review): ``self.layers``, ``self.act_fns`` and ``self.output_layer``
        are not created in ``__init__``, so this raises AttributeError until
        they are defined.
        """
        h = x
        for lin,act in zip(self.layers[:-1], self.act_fns):
            h = act(lin(h))
        return self.output_layer(h)

    def set_weights(self, new_weights: "quiver_rep"):
        """Placeholder: not implemented yet; returns None."""
        return None

    def set_activation_biases(self, new_biases: List[float]):
        """Placeholder: not implemented yet; returns None."""
        return None

    def export_weights(self) -> "quiver_rep":
        """Placeholder: not implemented yet; returns None."""
        return None

    def export_activation_biases(self) -> List[float]:
        """Placeholder: not implemented yet; returns None."""
        return None

    def export_reduced_weights(self) -> "quiver_rep":
        """Placeholder: not implemented yet; returns None."""
        return None

    def transformed_network(self):
        """Placeholder: not implemented yet; returns None."""
        return None

    def reduced_network(self):
        """Placeholder: not implemented yet; returns None."""
        return None

# Scraps

In [13]:
class vertex:
    """Placeholder vertex type; it stores no data yet."""

    def __init__(self):
        # Nothing to initialise.
        pass

In [6]:
# Scratch check: instantiate the placeholder vertex class.
a = vertex()