In [74]:
import math

import numpy as np
import torch
import torch.nn.functional as F

# Graph Class Definition

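It should consist of:

- `nodes`: the embeddings of all nodes currently in the tree (the vertex set $V$)
- `node_types`: the type of every node in the existing tree, either leaf or internal
- `edge_source`: the source node of every existing edge
- `edge_dest`: the destination node of every existing edge
- `leaf`: the remaining species the model can still choose from for leaf nodes
- `owner_masks`: which tree of the batch each node belongs to
- `last_inserted_node`: the most recently added node of each tree
- `running`: which trees of the batch are still being generated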

**No branch length implemented for now!**

In [53]:
class Graph:
    """Mutable state of a batch of partially generated trees.

    All trees of the batch share one flat node/edge store; ``owner_masks``
    records which tree of the batch each node belongs to.

    Args:
        batch_size: number of trees generated in parallel. Passing ``None``
            returns an object with no attributes set (presumably so a caller can
            fill it in manually, e.g. when copying -- TODO confirm).
        embed_size: length of each node embedding vector.
        num_taxon: number of species available for leaf nodes.
        device: torch device on which all tensors are allocated.
    """

    def __init__(self, batch_size, embed_size, num_taxon, device):
        if batch_size is None:
            return
        self.batch_size = batch_size
        self.device = device

        # Embedding of every existing node, one row per node: (num_nodes, embed_size).
        self.nodes = torch.zeros(0, embed_size, dtype=torch.float32, device=device)
        # Type code (leaf or internal) of every existing node: (num_nodes,).
        self.node_types = torch.zeros(0, dtype=torch.uint8, device=device)
        # Index of the source node of edge i: (num_edges,).
        self.edge_source = torch.zeros(0, dtype=torch.long, device=device)
        # Index of the destination node of edge i: (num_edges,).
        self.edge_dest = torch.zeros(0, dtype=torch.long, device=device)
        # Embeddings of the species still available for leaf nodes: (num_taxon, embed_size).
        # Initialised to zeros; presumably filled with real species embeddings by the caller.
        self.leaf = torch.zeros(num_taxon, embed_size, dtype=torch.float32, device=device)

        # No edge features for now.
        # owner_masks[b, n] is set when node n belongs to tree b: (batch_size, num_nodes).
        # Stored as bool because torch.where and mask indexing require bool masks.
        self.owner_masks = torch.zeros(batch_size, 0, dtype=torch.bool, device=device)
        # Index (into self.nodes) of the most recently inserted node of each tree.
        self.last_inserted_node = torch.zeros(batch_size, dtype=torch.long, device=device)
        # running[b] is True while tree b is still being generated.
        self.running = torch.ones(batch_size, dtype=torch.bool, device=device)


# Model Definition

## Propagator:
This is the message-passing part of the model: it updates the node vectors based on the current graph.

If the tree has no nodes or no edges, the propagator skips message passing and returns the graph unchanged.

For the layers of this part of the model:

1. `message_node`: computes a message vector from each node embedding (embed_size -> message_size)
   - The output of this layer is one vector of size message_size for each existing node
   - For every edge, the vectors of its source and destination nodes are added together; this sum is fed into the next layer
2. `message_layer`: a Tanh followed by a linear layer, applied to the per-edge sums from `message_node` (message_size -> message_size)
3. `node_update_fn`: a Gated Recurrent Unit (GRU) cell that updates each node's embedding from the sum of the messages on its incident edges (message_size -> embed_size)

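We also include reverse message passing here. Each edge message is the sum of both endpoint vectors, which is symmetric. As a result, the reverse pass currently produces the same messages as the forward pass; only the dropout masks differ.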

The return matrix is the updated feature matrix of the existing tree.

**Not yet sure exactly how `_reset_parameters` works; will investigate once the model is runnable.**

In [54]:
class Propagator(torch.nn.Module):
    """One round of message passing over the edges of the current trees.

    For every edge, the linearly projected features of its two endpoints are
    summed and passed through ``message_layer`` (Tanh + Linear). Each node then
    sums the messages of all its incident edges, and a GRU cell uses that sum to
    update the node's embedding. Only nodes owned by running trees (or by the
    trees selected with ``mask_override``) take the updated value.

    Args:
        embed_size: size of the node embeddings.
        dropout: dropout probability applied to messages and aggregated inputs.
    """

    def __init__(self, embed_size, dropout):
        super().__init__()
        # Size of each message; the GRU maps messages back to embed_size.
        self.message_size = embed_size * 2
        self.node_update_fn = torch.nn.GRUCell(self.message_size, embed_size)

        # Per-node projection whose outputs are summed over the endpoints of each edge.
        self.message_node = torch.nn.Linear(embed_size, self.message_size, bias=False)
        # Second stage, applied to the per-edge sum.
        self.message_layer = torch.nn.Sequential(
            torch.nn.Tanh(),
            torch.nn.Linear(self.message_size, self.message_size)
        )
        self.dropout = torch.nn.Dropout(dropout)
        self._reset_parameters(embed_size)

    @staticmethod
    def _node_update_mask(graph: "Graph", mask_override: torch.ByteTensor):
        """Bool vector (num_nodes,) marking nodes owned by at least one selected tree.

        Trees are selected by ``mask_override`` if given, else by ``graph.running``.
        A uint8 selector is converted to bool, because uint8 mask indexing is
        deprecated in PyTorch.
        """
        selector = graph.running if mask_override is None else mask_override
        if selector.dtype == torch.uint8:
            selector = selector.bool()
        return graph.owner_masks[selector].sum(0) > 0

    def _edge_inputs(self, node_features, first, second):
        """Per-node GRU input from one pass over the edges (first[i], second[i]).

        The message of edge i is message_layer(f[first[i]] + f[second[i]]). It is
        added to both endpoints, so every node receives the sum of the messages of
        its incident edges. Dropout is applied to the messages and to the result.
        """
        messages = node_features.index_select(0, first) + node_features.index_select(0, second)
        messages = self.dropout(self.message_layer(messages))
        inputs = torch.zeros(node_features.shape[0], self.message_size,
                             device=node_features.device, dtype=node_features.dtype)
        inputs.index_add_(0, second, messages).index_add_(0, first, messages)
        return self.dropout(inputs)

    def forward(self, graph: "Graph", mask_override: torch.ByteTensor = None):
        """Replace ``graph.nodes`` with one round of message passing and return ``graph``.

        Args:
            graph: the batch of trees.
            mask_override: optional per-tree selector; when given it replaces
                ``graph.running`` in choosing which trees' nodes are updated.
        """
        # Nothing to propagate without nodes or edges.
        if graph.nodes.shape[0] == 0 or graph.edge_source.shape[0] == 0:
            return graph

        node_features = self.message_node(graph.nodes)

        # Forward pass over (source, dest) and reverse pass over (dest, source).
        # NOTE(review): the edge message is a symmetric sum of both endpoints, so the
        # reverse pass computes the same messages as the forward pass; only the
        # dropout masks differ. A direction-aware message would need its own layers.
        forward_inputs = self._edge_inputs(node_features, graph.edge_source, graph.edge_dest)
        reverse_inputs = self._edge_inputs(node_features, graph.edge_dest, graph.edge_source)

        updated_nodes = self.node_update_fn(forward_inputs + reverse_inputs, graph.nodes)
        # Only nodes of the selected trees take the new value; the rest keep theirs.
        update_mask = self._node_update_mask(graph, mask_override).unsqueeze(-1)
        graph.nodes = torch.where(update_mask, updated_nodes, graph.nodes)
        return graph

    def _reset_parameters(self, embed_size):
        msg_gain = torch.nn.init.calculate_gain("tanh")
        # FIXME: not sure if the fan-in should be embed_size * 3 or embed_size * 2
        xavier_init(self.message_node, msg_gain, embed_size * 3, self.message_size)
        xavier_init(self.message_layer[1], 1)

        self.node_update_fn.bias_hh.data.fill_(0)
        self.node_update_fn.bias_ih.data.fill_(0)
        # PyTorch lays out the GRU biases as (reset | update | new), so this sets the
        # reset-gate part of bias_hh to 1. NOTE(review): confirm this is the intended gate.
        self.node_update_fn.bias_hh.data[:embed_size].fill_(1)


## Multilayer Propagator:
This class simply runs the propagator several times. The number of iterations (`n_iter`) is a hyperparameter I choose, and each iteration has its own `Propagator` with separate parameters.

In [55]:
class MultilayerPropagator(torch.nn.Module):
    """Stack of ``n_iter`` independent Propagator layers applied one after another."""

    def __init__(self, embed_size, n_iter, dropout):
        super().__init__()
        # every round of message passing gets its own parameters
        layers = []
        for _ in range(n_iter):
            layers.append(Propagator(embed_size, dropout))
        self.propagators = torch.nn.ModuleList(layers)

    def forward(self, graph: Graph, *args, **kwargs):
        """Thread ``graph`` through every propagator in order.

        Extra positional/keyword arguments (e.g. a mask override) are passed
        unchanged to each layer.
        """
        current = graph
        for layer in self.propagators:
            current = layer(current, *args, **kwargs)
        return current

## Aggregator:
This part of the model takes the propagated node vectors and aggregates them into a single, higher-dimensional vector for each tree.

The output of this model will be used as inputs for all the decision making NN in the later part of the bigger model

For the layers of this model:

1. `aggregate`: map the embedded(propagated) node vectors to a higher dimension (embed_size -> aggregated_size)
2. `gated_sum`: gating vector for the gated sum (embed_size -> aggregated_size)

The two NNs above output two `num_node * aggregated_size` matrices: the gates $g_{v}$ and the features $h_{v}$. We take their element-wise (Hadamard) product and then sum the rows belonging to each tree, i.e. $\sum_{v\in V} g_{v} \odot h_{v}$.

In the end, it becomes one `1*aggregated_size` vector for each tree.

In [56]:
class Aggregator(torch.nn.Module):
    """Gated-sum readout producing one vector per tree of the batch.

    For tree b the output is the sum over the nodes v owned by b of
    sigmoid(W_g h_v + b_g) * (W_f h_v + b_f), followed by dropout.

    Args:
        embed_size: size of node embeddings.
        dropout: dropout probability applied to the readout.
        bias_if_empty: if True, a learned vector is returned for every tree when
            the graph has no nodes; otherwise zeros are returned.
        aggregated_size: size of the readout; defaults to ``embed_size * 2``.
    """

    def __init__(self, embed_size, dropout, bias_if_empty=False, aggregated_size=None):
        super().__init__()

        self.aggregated_size = embed_size * 2 if aggregated_size is None else aggregated_size
        # Feature branch: maps each node embedding to the readout size.
        self.aggregate = torch.nn.Linear(embed_size, self.aggregated_size)
        # Gate branch: per-feature weights in (0, 1).
        self.gated_sum = torch.nn.Sequential(
            torch.nn.Linear(embed_size, self.aggregated_size),
            torch.nn.Sigmoid()
        )

        self.bias_if_empty = (torch.nn.Parameter(torch.Tensor(1, self.aggregated_size))
                              if bias_if_empty else None)
        self.dropout = torch.nn.Dropout(dropout)

        self._reset_parameters()

    def forward(self, graph: "Graph"):
        """Return a (batch_size, aggregated_size) readout of ``graph``."""
        # Empty graph: learned bias (if enabled) or zeros for every tree.
        if graph.nodes.shape[0] == 0:
            if self.bias_if_empty is not None:
                return self.bias_if_empty.expand(graph.batch_size, -1)
            return torch.zeros(graph.batch_size, self.aggregated_size,
                               dtype=torch.float32, device=graph.device)

        gates = self.gated_sum(graph.nodes)
        features = self.aggregate(graph.nodes)

        # (batch, num_nodes) @ (num_nodes, aggregated_size): each row sums only the
        # gated features of the nodes that tree owns.
        owner = graph.owner_masks.to(features.dtype)
        return self.dropout(torch.mm(owner, features * gates))

    def _reset_parameters(self):
        xavier_init(self.aggregate, 1)
        xavier_init(self.gated_sum[0], 1)
        # Gate bias starts at 1, pushing the sigmoid gates towards open.
        self.gated_sum[0].bias.data.fill_(1)
        if self.bias_if_empty is not None:
            torch.nn.init.normal_(self.bias_if_empty)

## Add_Node

This part of the model adds a node to the graph, it works as follows:

1. Run the `propagator` first to obtain the updated node vectors.
2. `decision_aggregator` gives a single vector for the entire updated graph, which is passed into `node_decision`. The sampled output is "0" if no node should be added, "1" if a leaf node should be added, or "2" if an internal node should be added.
3. For a leaf node, compute a logit for each species we could choose from by combining the species embedding with the aggregated vector of the existing tree. Then sample a species from the softmax over these logits and remove the chosen one from the species set.
4. For an internal node, generate the new node vector by combining the learned `node_embedding` parameter with the aggregated vector of the existing graph.
5. We then add the new node to the graph and update the owner mask.

For the module of this part:

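1. `node_decision`: given the aggregated vector of the current graph, outputs logits for the three choices 0, 1, 2
2. `node_embedding`: a learned parameter vector of size `embed_size`, used as the starting point for new internal nodes. Its values looked very, very small when inspected. This may be uninitialised memory (`torch.Tensor(embed_size)` does not initialise its contents) or noise, or the parameter may be unnecessary.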
3. `init_1`: take `node_embedding` as input, output another vector with `embed_size`
4. `init_2`: take the aggregated graph vector as input, combine with output of `init_1` gives the embedding for the newly generated internal node
5. `leaf_decision_species`: Part of the decision of "which species to choose from" that comes from the species node embedding
6. `leaf_decision_tree`: Part of the decision of "which species to choose from" that comes from the existing tree

In [86]:
class Add_Node(torch.nn.Module):
    """Decide, for every running tree, whether to add a node and of which type.

    Decision codes sampled from ``node_decision``:
        0 (NO_NODE)       -- add nothing; the tree stops running
        1 (LEAF_NODE)     -- add a leaf whose embedding is a species taken from ``graph.leaf``
        2 (INTERNAL_NODE) -- add an internal node generated from the tree readout

    NOTE(review): the Aggregators below are built without an explicit size, so
    their readout has ``embed_size * 2`` features. ``aggregated_size`` must equal
    that, because ``node_decision``, ``init_2`` and ``leaf_decision_tree`` consume it.
    """

    NO_NODE = 0
    LEAF_NODE = 1
    INTERNAL_NODE = 2

    def __init__(self, embed_size, aggregated_size, propagate_steps, dropout):
        super().__init__()
        self.propagator = MultilayerPropagator(embed_size, propagate_steps, dropout)
        self.decision_aggregator = Aggregator(embed_size, dropout, bias_if_empty=True)
        self.generate_aggregator = Aggregator(embed_size, dropout, bias_if_empty=True)
        self.leaf_node_aggregator = Aggregator(embed_size, dropout, bias_if_empty=True)

        # Logits for the three decision codes above.
        self.node_decision = torch.nn.Linear(aggregated_size, 3)
        # Learned template vector for new internal nodes (initialised in _reset_parameters).
        self.node_embedding = torch.nn.Parameter(torch.Tensor(embed_size))
        # New internal node = init_1(node_embedding) + init_2(tree readout).
        self.init_1 = torch.nn.Linear(embed_size, embed_size)
        self.init_2 = torch.nn.Linear(aggregated_size, embed_size, bias=False)

        # Species logit = leaf_decision_species(species) + leaf_decision_tree(tree readout).
        self.leaf_decision_species = torch.nn.Linear(embed_size, 1)
        self.leaf_decision_tree = torch.nn.Linear(aggregated_size, 1, bias=False)

        self._reset_parameters(embed_size, aggregated_size)

    def forward(self, graph: "Graph"):
        """Sample a decision per tree, append the new nodes to ``graph`` and return it.

        Updates ``graph.running`` (trees that chose 0 stop), ``graph.nodes``,
        ``graph.node_types``, ``graph.owner_masks``, ``graph.last_inserted_node``
        and, for leaf insertions, removes the chosen species from ``graph.leaf``.
        """
        # Message passing before any decision.
        graph = self.propagator(graph)

        type_logits = self.node_decision(self.decision_aggregator(graph))
        forbidden = torch.zeros(3, dtype=torch.bool, device=type_logits.device)
        if graph.nodes.shape[0] == 0:
            # An empty tree must get at least one node.
            forbidden[self.NO_NODE] = True
        if graph.leaf.shape[0] == 0:
            # No species left to place as a leaf.
            forbidden[self.LEAF_NODE] = True
        type_logits = type_logits.masked_fill(forbidden, float("-inf"))

        selected_type = sample_softmax(type_logits)
        # A tree that chose "no node" stops running.
        graph.running = (selected_type != self.NO_NODE) & graph.running.bool()
        mask = graph.running
        if not mask.any():
            return graph

        # Feature of the node to insert for each tree (rows of stopped trees stay unused).
        new_feature = graph.nodes.new_zeros(graph.batch_size, self.node_embedding.shape[0])

        internal_rows = mask & (selected_type == self.INTERNAL_NODE)
        if internal_rows.any():
            init_feature = self.generate_aggregator(graph)
            internal_feature = self.init_1(self.node_embedding) + self.init_2(init_feature)
            new_feature = torch.where(internal_rows.unsqueeze(-1), internal_feature, new_feature)

        leaf_rows = mask & (selected_type == self.LEAF_NODE)
        if leaf_rows.any():
            species_readout = self.leaf_node_aggregator(graph)
            # (1, num_species) + (batch, 1) -> (batch, num_species)
            logits = (self.leaf_decision_species(graph.leaf).squeeze(-1).unsqueeze(0) +
                      self.leaf_decision_tree(species_readout))
            selected_species = sample_softmax(logits)
            new_feature = torch.where(leaf_rows.unsqueeze(-1), graph.leaf[selected_species], new_feature)
            # Remove the chosen species from the pool.
            # NOTE(review): graph.leaf is shared by the whole batch, so two trees may
            # draw the same species in one step; it is then removed only once.
            keep = torch.ones(graph.leaf.shape[0], dtype=torch.bool, device=graph.leaf.device)
            keep[selected_species[leaf_rows]] = False
            graph.leaf = graph.leaf[keep]

        # New nodes are appended after the existing ones, in ascending tree order.
        index_seq = torch.arange(int(mask.long().sum()), device=graph.device, dtype=torch.long) + \
            graph.nodes.shape[0]
        last_nodes = torch.zeros(graph.batch_size, device=graph.device, dtype=torch.long)
        last_nodes[mask] = index_seq
        graph.last_inserted_node = torch.where(mask, last_nodes, graph.last_inserted_node)

        # One owner-mask column per new node, with a 1 in the row of the tree owning it,
        # e.g. running = [0, 0, 1, 1] -> ids [2, 3] -> columns e_2, e_3 of a (batch, 2) mask.
        running_ids = mask.nonzero().squeeze(-1)
        new_owner = F.one_hot(running_ids, graph.batch_size).transpose(0, 1).to(graph.owner_masks.dtype)

        graph.nodes = torch.cat((graph.nodes, new_feature[mask]), dim=0)
        graph.node_types = torch.cat((graph.node_types, selected_type[mask].to(graph.node_types.dtype)))
        graph.owner_masks = torch.cat((graph.owner_masks, new_owner), dim=1)

        return graph

    def _reset_parameters(self, embed_size, aggregated_size):
        torch.nn.init.normal_(self.node_embedding)
        xavier_init(self.init_1, 1, embed_size + aggregated_size, embed_size)
        xavier_init(self.init_2, 1, embed_size + aggregated_size, embed_size)
        xavier_init(self.node_decision, 1)
        xavier_init(self.leaf_decision_species, 1, embed_size * 2, 1)
        xavier_init(self.leaf_decision_tree, 1, embed_size + aggregated_size, 1)
        

## Add_Edge

This part decides whether to add an edge and, if so, which existing node the last added node should be connected to.

The modules of this part work as follows:

1. `addedge_existing`: Part of the "add edge or not" decision that comes from the already existing graph, takes the aggregated vector as input
2. `addedge_new`: Part of the "add edge or not" decision that comes from the last added node (The new node), takes the node vector as input
3. `where_existing`: Part of the "where to add edge" decision that comes from the existing graph, takes all node vectors as input
4. `where_new`: Part of the "where to add edge" decision that comes from the new node.

To decide which node to connect the new node to, we get a vector of logits from the `where` modules, with one logit per node in the graph. We then sample a node from the softmax over these logits, restricted to nodes of the same tree. The sampling uses the Gumbel-max trick: add Gumbel noise and take the largest value.


In [87]:
class Add_Edge(torch.nn.Module):
    """Connect the most recently inserted node of each running tree to existing nodes.

    Repeatedly (at most ``n_max_edges`` times) samples whether to add an edge and,
    if so, which existing node of the same tree the new node connects to. Each
    edge is stored as (source = chosen existing node, dest = new node).

    NOTE(review): the Aggregator below is built without an explicit size, so its
    readout has ``embed_size * 2`` features, which must equal ``aggregated_size``.
    """

    def __init__(self, embed_size, aggregated_size, propagate_steps, dropout):
        super().__init__()
        self.n_max_edges = 2
        self.propagator = MultilayerPropagator(embed_size, propagate_steps, dropout)
        self.edge_decision_aggregator = Aggregator(embed_size, dropout)

        # "Add an edge?" logit = addedge_existing(readout) + addedge_new(new node).
        self.addedge_existing = torch.nn.Linear(aggregated_size, 1)
        self.addedge_new = torch.nn.Linear(embed_size, 1, bias=False)

        # "Connect to which node?" logit = where_existing(node) + where_new(new node).
        self.where_existing = torch.nn.Linear(embed_size, 1)
        self.where_new = torch.nn.Linear(embed_size, 1, bias=False)

        self._reset_parameters(embed_size, aggregated_size)

    def forward(self, graph: "Graph"):
        """Append up to ``n_max_edges`` edges per running tree and return ``graph``.

        Edges go into ``graph.edge_source`` (chosen node) and ``graph.edge_dest``
        (newest node). ``graph.running`` itself is not modified.
        """
        # Nothing to connect in an empty graph.
        if graph.nodes.shape[0] == 0:
            return graph

        # Local copy: a tree that stops adding edges here still keeps running overall.
        running = graph.running.bool()
        added_edges = 0
        batch_index = torch.arange(graph.batch_size, device=graph.device)
        # Embedding of each tree's newest node, captured once before the loop.
        # NOTE(review): it is not refreshed after the propagation rounds below -- confirm intended.
        new_node = graph.nodes.index_select(0, graph.last_inserted_node)
        # Candidate endpoints: nodes of the same tree, excluding the new node itself
        # (no self-loops). Cloned so graph.owner_masks is never modified.
        candidates = graph.owner_masks.bool().clone()
        candidates[batch_index, graph.last_inserted_node] = False

        while True:
            # A round of message passing restricted to the trees still adding edges.
            graph = self.propagator(graph, running)
            add_logit = (self.addedge_existing(self.edge_decision_aggregator(graph)) +
                         self.addedge_new(new_node)).squeeze(-1)
            add_or_not = sample_binary(add_logit)

            if added_edges >= self.n_max_edges:
                add_or_not = torch.zeros_like(add_or_not)

            # A tree with no remaining candidate node cannot add an edge.
            running = running & add_or_not & candidates.any(-1)
            if not running.any():
                break

            # (1, num_nodes) + (batch, 1) -> (batch, num_nodes)
            edge_logit = (self.where_existing(graph.nodes).squeeze(-1).unsqueeze(0) +
                          self.where_new(new_node))
            selected_node = masked_softmax(edge_logit, candidates)

            graph.edge_dest = torch.cat((graph.edge_dest, graph.last_inserted_node[running]), 0)
            graph.edge_source = torch.cat((graph.edge_source, selected_node[running]), 0)
            # Never connect the same pair of nodes twice.
            candidates[batch_index[running], selected_node[running]] = False
            added_edges += 1

        return graph

    def _reset_parameters(self, embed_size, aggregated_size):
        xavier_init(self.addedge_existing, 1, embed_size + aggregated_size, 1)
        xavier_init(self.addedge_new, 1, embed_size + aggregated_size, 1)
        xavier_init(self.where_existing, 1, embed_size * 2, 1)
        xavier_init(self.where_new, 1, embed_size * 2, 1)
            

## Tree_Generation


In [None]:
class Tree_Generation(torch.nn.Module):
    """Top-level generator that owns the node-adding and edge-adding modules.

    Both sub-modules receive a readout size of ``embed_size * 2``.

    Args:
        embed_size: size of node embeddings.
        propagate_steps: number of message-passing rounds per propagation.
        dropout: dropout probability used throughout.
    """

    def __init__(self, embed_size, propagate_steps, dropout=0.2):
        super().__init__()
        self.aggregated_size = embed_size * 2
        self.embed_size = embed_size

        self.add_edge = Add_Edge(embed_size, self.aggregated_size, propagate_steps, dropout)
        self.add_node = Add_Node(embed_size, self.aggregated_size, propagate_steps, dropout)

    def forward(self, *args, **kwargs):
        """Not implemented yet.

        TODO: write the generation loop (presumably alternating ``add_node`` and
        ``add_edge`` until no tree is running).
        """
        raise NotImplementedError("Tree_Generation.forward has not been implemented yet")


In [58]:
def xavier_init(layer, scale, n_inputs=None, n_outputs=None):
    """Xavier/Glorot-uniform initialisation of ``layer.weight``, in place.

    Weights are drawn from U(-a, a) with a = scale * sqrt(6 / (n_inputs + n_outputs)).
    ``n_inputs``/``n_outputs`` default to ``weight.shape[1]`` and ``weight.shape[0]``;
    callers may override them to use a different fan-in/fan-out. If the layer has
    a bias, it is re-drawn from a standard normal distribution.

    Args:
        layer: module with a 2-D ``weight`` and optionally a ``bias`` attribute.
        scale: gain multiplied into the uniform bound.
        n_inputs: fan-in override.
        n_outputs: fan-out override.
    """
    weight = layer.weight
    fan_in = weight.shape[1] if n_inputs is None else n_inputs
    fan_out = weight.shape[0] if n_outputs is None else n_outputs
    limit = scale * math.sqrt(6.0 / (fan_in + fan_out))

    with torch.no_grad():
        weight.uniform_(-limit, limit)
        # getattr: also accept modules that have no ``bias`` attribute at all.
        bias = getattr(layer, "bias", None)
        if bias is not None:
            torch.nn.init.normal_(bias)

In [59]:
def sample_softmax(tensor, dim=-1):
    """Draw one categorical sample per slice along ``dim`` with the Gumbel-max trick.

    ``tensor`` holds unnormalised logits. Entries equal to -inf are never chosen,
    unless every entry of the slice is -inf, in which case the result is not meaningful.

    Returns:
        LongTensor of sampled indices, with ``dim`` removed.
    """
    eps = 1e-20

    # Built-in gumbel softmax could end up with lots of NaNs, so the Gumbel(0, 1)
    # noise is drawn manually; eps guards both logs against log(0).
    gumbel = -torch.log(-torch.log(torch.rand_like(tensor) + eps) + eps)
    # argmax(logits + Gumbel noise) is an exact sample from softmax(logits) along
    # the same dim. A softmax in between is monotonic and therefore unnecessary.
    return torch.argmax(tensor + gumbel, dim=dim)

In [60]:
def sample_binary(tensor):
    """Element-wise Bernoulli sample with success probability sigmoid(tensor).

    Returns a bool tensor with the same shape as ``tensor``.
    """
    probabilities = torch.sigmoid(tensor)
    uniform = torch.rand_like(probabilities)
    return uniform < probabilities

In [61]:
def mask_softmax_input(tensor, mask):
    """Return ``tensor`` with every position where ``mask`` is false/zero set to -inf.

    ``mask`` may be bool or an integer 0/1 tensor (e.g. uint8 owner masks). It is
    converted to bool because ``torch.where`` requires a bool condition.
    """
    neg_inf = torch.full([1], float("-inf"), dtype=tensor.dtype, device=tensor.device)
    return torch.where(mask.bool(), tensor, neg_inf)

In [62]:
def masked_softmax(tensor, mask):
    """Sample one index per row from softmax(tensor), restricted to positions where ``mask`` is set.

    Masked-out positions are replaced by -inf with mask_softmax_input, then an index
    is drawn with sample_softmax along the last dimension.
    """
    return sample_softmax(mask_softmax_input(tensor, mask))

In [114]:
# Toy logits (3 samples x 3 classes) and class labels for a quick cross-entropy check.
t = torch.tensor([
    [0.1, 0.3, 0.2],
    [0.5, 0.3, 0.11],
    [0.02, 0.8, 0.13],
])
target = torch.tensor([1, 0, 1])

In [115]:
def remap_pad(t, pad_char, transform = lambda x: x+1):
    """Apply ``transform`` to every non-padding entry of ``t`` and map padding entries to 0.

    With the default transform, non-padding values are shifted up by one
    (presumably so that 0 can serve as the padding value).
    """
    is_real = t != pad_char
    transformed = transform(t)
    zero = torch.zeros(1, dtype=t.dtype, device=t.device)
    return torch.where(is_real, transformed, zero)
# Sanity check: mean cross-entropy of the toy logits `t` against the labels `target`.
F.cross_entropy(input=t, target=target.long())

tensor(0.8649)