# Shortest-Path

#### Fix path and import caldera

In [1]:
from os.path import join, isfile
import os
import sys

def find_pkg(name: str, depth: int):
    """Look for the package ``name`` in ancestors of the working directory.

    Tests for ``<ancestor>/<name>/__init__.py`` where the ancestor is
    ``depth`` levels up, then ``depth - 1`` levels up, and so on down to a
    single level up. The working directory itself is never tested.

    :param name: directory name of the package to look for
    :param depth: largest number of ``..`` steps to try
    :return: the list of ``'..'`` components leading to the first matching
        ancestor, or ``None`` when nothing matches (including ``depth <= 0``)
    """
    # Walk from the farthest ancestor towards the nearest one.
    for levels in range(depth, 0, -1):
        ups = ['..'] * levels
        if isfile(join(*ups, name, '__init__.py')):
            return ups
    return None

def find_and_ins_syspath(name: str, depth: int):
    """Put the directory that contains package ``name`` at the front of ``sys.path``.

    The directory is located with ``find_pkg`` and is inserted as the relative
    path that ``find_pkg`` produced. Nothing is inserted when that exact path
    string is already present in ``sys.path``.

    :param name: directory name of the package to look for
    :param depth: largest number of parent levels to search
    :raises RuntimeError: if ``find_pkg`` does not locate the package
    """
    located = find_pkg(name, depth)
    if located is None:
        raise RuntimeError("Could not find {}. Try increasing depth.".format(name))

    parent_dir = join(*located)
    if parent_dir in sys.path:
        return
    sys.path.insert(0, parent_dir)

# Try a normal import first; if caldera is not importable, search up to 3
# parent directories for a `caldera/__init__.py` and add the directory that
# contains it to sys.path.
try:
    import caldera
except ImportError:
    find_and_ins_syspath('caldera', 3)

# Import again so that `caldera` is bound when the fallback branch ran; when
# the first import already succeeded this simply rebinds the same module.
import caldera

## Data Generation

Generate and visualize data

In [1]:
pass  # TODO: generate data here; the loader cell below reads `nx_graphs` and `eval_graphs`, which no visible cell defines

### Creating Dataset

Generate dataset

In [None]:
pass  # TODO: placeholder; no dataset is created in this cell yet

## Graph Network

We are going to build a graph network to handle the data we just created.

We are going to use a flexible `encoder -> core[x] -> decoder` architecture for this problem. The architecture consists of 4 main networks: the **encoder**, **core**, **decoder**, and **output_transform** networks. The **encoder** encodes graph data inputs into arbitrary shapes. The **core** is the central graph message processing network. The **decoder** decodes encoded data. Finally, the **output_transform** transforms decoded data for final output.

Setting up the network looks like the following:

``` python
class Network(torch.nn.Module):

    def __init__(...):
        super().__init__()
    
        self.config = {...}

        self.encoder = ...
        self.core = ...
        self.decoder = ...
        self.out_transform = ...  

    def forward(self, data, steps, save_all: bool = False):
        """The encoder -> core -> decode loop"""
        encoded = self.encoder(data) # encode data
        
        outputs = []
        for _ in range(steps):
            latent = self.core(encoded)
            encoded = self.decoder(latent)
            outputs.append(self.out_transform(latent))
        return outputs
```

### Flex Modules and Flex Dimensions

Setting up this network with the correct dimensions can become tricky, so we introduce a new module, the `Flex` module, which can resolve unknown dimensions at runtime. To make a module a `Flex` module, we just call `Flex` with any `torch.nn.Module`, as in `Flex(torch.nn.Linear)` or `Flex(MyAwesomeModule)`. To initialize the module with unknown dimensions, you use the flexible dimension object `Flex.d` in places where the dimension is to be resolved at runtime, as in `Flex(torch.nn.Linear)(Flex.d(), 10)`.

In [2]:
from caldera.blocks import Flex
import torch

# Wrap torch.nn.Linear so that its dimensions may be given as Flex.d() placeholders.
FlexLinear = Flex(torch.nn.Linear)

# Plain Linear layer: both dimensions are fixed at construction time.
linear0 = torch.nn.Linear(3, 10)
# Flex layer: the first dimension is a Flex.d() placeholder, the second is 10.
flex_linear0 = FlexLinear(Flex.d(), 10)
# Print both modules so their representations can be compared.
print(linear0)
print(flex_linear0)

ModuleNotFoundError: No module named 'caldera'

Notice that the FlexBlock indicates it is currently unresolved. To resolve it, we need to provide it with a data example; after that, you'll see the module is now resolved.

In [3]:
# One all-zero example of shape (1, 10), used only to trigger resolution.
example = torch.zeros((1, 10))

# Calling the Flex module on example data is what resolves its Flex.d() placeholder.
flex_linear0(example)

# Print again to show the module's representation after it has been called.
print(flex_linear0)

NameError: name 'torch' is not defined

### Aggregators

Aggregators are layers that indicate how data is processed and aggregated between neighbors.

### Final Network

Build the final network

In [12]:
from caldera.blocks import NodeBlock, EdgeBlock, GlobalBlock
from caldera.blocks import AggregatingNodeBlock, AggregatingEdgeBlock, AggregatingGlobalBlock
from caldera.blocks import MultiAggregator
from caldera.blocks import Flex
from caldera.models import GraphCore, GraphEncoder
from caldera.defaults import CalderaDefaults as defaults

In [13]:
import torch
from caldera.defaults import CalderaDefaults as defaults
from caldera.blocks import Flex, NodeBlock, EdgeBlock, GlobalBlock, MLP, AggregatingEdgeBlock, AggregatingNodeBlock, \
    MultiAggregator, AggregatingGlobalBlock
from caldera.models import GraphEncoder, GraphCore
from caldera.data import GraphBatch


class Network(torch.nn.Module):
    """Graph network following the ``encoder -> core[x] -> decoder`` layout.

    Four sub-networks are built in the constructor:

    * ``encoder``: per-edge / per-node / per-global Flex MLPs.
    * ``core``: a ``GraphCore`` of aggregating edge, node and global blocks.
    * ``decoder``: built by the same builder as the encoder (a separate
      instance).
    * ``output_transform``: a fixed head of ``Linear(-> 1)`` layers, followed
      by ``Sigmoid`` on edges and nodes (no activation on the global output).

    All layers use ``Flex.d()`` for their input size, so the network has to be
    called once on example data before its parameters exist.

    NOTE(review): ``out_sizes`` and ``out_activation`` are stored in
    ``self.config`` and are honoured by :meth:`_init_out_transform`, but the
    constructor installs the fixed sigmoid head instead. Previously the
    configurable head was built and then immediately overwritten by the fixed
    one; only the fixed one is built now. Switch the constructor to
    ``self._init_out_transform()`` if the configurable head is wanted.
    """

    def __init__(
            self,
            latent_sizes=(32, 32, 32),
            out_sizes=(1, 1, 1),
            latent_depths=(1, 1, 1),
            dropout: float = None,
            pass_global_to_edge: bool = True,
            pass_global_to_node: bool = True,
            activation=defaults.activation,
            out_activation=defaults.activation,
            edge_to_node_aggregators=("add", "max", "mean", "min"),
            edge_to_global_aggregators=("add", "max", "mean", "min"),
            node_to_global_aggregators=("add", "max", "mean", "min"),
            aggregator_activation=defaults.activation,
    ):
        """Store the configuration and build the sub-networks.

        :param latent_sizes: ``(edge, node, global)`` latent layer sizes
        :param out_sizes: ``(edge, node, global)`` output sizes; only read by
            :meth:`_init_out_transform`
        :param latent_depths: ``(edge, node, global)`` number of latent layers
            used by the core MLPs
        :param dropout: forwarded as ``dropout=`` to every MLP
        :param pass_global_to_edge: forwarded to ``GraphCore``
        :param pass_global_to_node: forwarded to ``GraphCore``
        :param activation: stored in ``self.config['activation']``; not read
            anywhere else in this class
        :param out_activation: activation class instantiated by
            :meth:`_init_out_transform`
        :param edge_to_node_aggregators: aggregator names for the node block
        :param edge_to_global_aggregators: aggregator names for the global
            block's edge aggregator
        :param node_to_global_aggregators: aggregator names for the global
            block's node aggregator
        :param aggregator_activation: forwarded as ``activation=`` to every
            ``MultiAggregator``
        """
        super().__init__()
        self.config = {
            "sizes": {
                'latent': {
                    "edge": latent_sizes[0],
                    "node": latent_sizes[1],
                    "global": latent_sizes[2],
                    "edge_depth": latent_depths[0],
                    "node_depth": latent_depths[1],
                    "global_depth": latent_depths[2],
                },
                'out': {
                    'edge': out_sizes[0],
                    'node': out_sizes[1],
                    'global': out_sizes[2],
                    'activation': out_activation,
                }
            },
            'activation': activation,
            "dropout": dropout,
            "node_block_aggregator": edge_to_node_aggregators,
            "global_block_to_node_aggregator": node_to_global_aggregators,
            "global_block_to_edge_aggregator": edge_to_global_aggregators,
            "aggregator_activation": aggregator_activation,
            "pass_global_to_edge": pass_global_to_edge,
            "pass_global_to_node": pass_global_to_node,
        }

        ###########################
        # sub-networks
        ###########################

        self.encoder = self._init_encoder()
        self.core = self._init_core()
        # The decoder is a second, independent instance of the encoder layout.
        self.decoder = self._init_encoder()

        # Fixed output head. Edge and node outputs go through a Sigmoid, so
        # they lie in [0, 1] (the notebook trains node outputs with BCELoss).
        self.output_transform = GraphEncoder(
            EdgeBlock(
                torch.nn.Sequential(
                    Flex(torch.nn.Linear)(Flex.d(), 1), torch.nn.Sigmoid()
                )
            ),
            NodeBlock(
                torch.nn.Sequential(
                    Flex(torch.nn.Linear)(Flex.d(), 1), torch.nn.Sigmoid()
                )
            ),
            GlobalBlock(Flex(torch.nn.Linear)(Flex.d(), 1)),
        )

    def _init_encoder(self):
        """Build a ``GraphEncoder`` of Flex MLPs sized by the latent config."""
        latent = self.config['sizes']['latent']
        dropout = self.config['dropout']
        return GraphEncoder(
            EdgeBlock(Flex(MLP)(Flex.d(), latent['edge'], dropout=dropout)),
            NodeBlock(Flex(MLP)(Flex.d(), latent['node'], dropout=dropout)),
            GlobalBlock(Flex(MLP)(Flex.d(), latent['global'], dropout=dropout)),
        )

    def _init_core(self):
        """Build the ``GraphCore`` message-passing network.

        Each block's MLP receives ``<depth>`` copies of the block's latent
        size as its layer sizes, with ``layer_norm=True``.
        """
        latent = self.config['sizes']['latent']
        dropout = self.config['dropout']
        aggregator_activation = self.config["aggregator_activation"]

        edge_layers = [latent['edge']] * latent['edge_depth']
        node_layers = [latent['node']] * latent['node_depth']
        global_layers = [latent['global']] * latent['global_depth']

        return GraphCore(
            AggregatingEdgeBlock(
                torch.nn.Sequential(
                    Flex(MLP)(Flex.d(), *edge_layers, dropout=dropout, layer_norm=True),
                )
            ),
            AggregatingNodeBlock(
                torch.nn.Sequential(
                    Flex(MLP)(Flex.d(), *node_layers, dropout=dropout, layer_norm=True),
                ),
                Flex(MultiAggregator)(
                    Flex.d(),
                    self.config["node_block_aggregator"],
                    activation=aggregator_activation,
                ),
            ),
            AggregatingGlobalBlock(
                torch.nn.Sequential(
                    Flex(MLP)(
                        Flex.d(), *global_layers, dropout=dropout, layer_norm=True
                    ),
                ),
                edge_aggregator=Flex(MultiAggregator)(
                    Flex.d(),
                    self.config["global_block_to_edge_aggregator"],
                    activation=aggregator_activation,
                ),
                node_aggregator=Flex(MultiAggregator)(
                    Flex.d(),
                    self.config["global_block_to_node_aggregator"],
                    activation=aggregator_activation,
                ),
            ),
            pass_global_to_edge=self.config["pass_global_to_edge"],
            pass_global_to_node=self.config["pass_global_to_node"],
        )

    def _init_out_transform(self):
        """Build an output head from ``config['sizes']['out']``.

        Each of the edge, node and global blocks is a Flex ``Linear`` to the
        configured output size followed by a fresh instance of the configured
        output activation. Not called by the constructor (see class docstring).
        """
        out = self.config['sizes']['out']
        return GraphEncoder(
            EdgeBlock(
                torch.nn.Sequential(
                    Flex(torch.nn.Linear)(Flex.d(), out['edge']),
                    out['activation']()
                )
            ),
            NodeBlock(
                torch.nn.Sequential(
                    Flex(torch.nn.Linear)(Flex.d(), out['node']),
                    out['activation']()
                )
            ),
            GlobalBlock(
                torch.nn.Sequential(
                    Flex(torch.nn.Linear)(Flex.d(), out['global']),
                    out['activation']()
                )
            )
        )

    @staticmethod
    def _apply_block(block, data):
        """Run ``block`` on ``data`` and re-wrap the result as a ``GraphBatch``.

        The block returns ``(e, x, g)`` while ``GraphBatch`` takes
        ``(x, e, g, ...)``; connectivity and index tensors are copied from
        ``data`` unchanged.
        """
        e, x, g = block(data)
        return GraphBatch(x, e, g, data.edges, data.node_idx, data.edge_idx)

    def _forward_encode(self, data):
        """Apply the encoder to ``data``."""
        return self._apply_block(self.encoder, data)

    def _forward_decode(self, data):
        """Apply the decoder to ``data``."""
        return self._apply_block(self.decoder, data)

    def _forward_core(self, latent0, data):
        """Run one core step on ``data`` concatenated with ``latent0``.

        Edge, node and global attributes of ``latent0`` and ``data`` are
        concatenated along ``dim=1`` before being passed to the core.
        """
        e = torch.cat([latent0.e, data.e], dim=1)
        x = torch.cat([latent0.x, data.x], dim=1)
        g = torch.cat([latent0.g, data.g], dim=1)
        merged = GraphBatch(x, e, g, data.edges, data.node_idx, data.edge_idx)
        return self._apply_block(self.core, merged)

    def _forward_out(self, data):
        """Apply the output transform to ``data``."""
        return self._apply_block(self.output_transform, data)

    def forward(self, data, steps, save_all: bool = False):
        """Encode ``data`` and run ``steps`` rounds of core -> decode -> output.

        :param data: input graph batch
        :param steps: number of processing rounds
        :param save_all: keep the output of every round instead of only the
            last one
        :return: list of output ``GraphBatch`` objects: one per round when
            ``save_all`` is true, otherwise only the last round's output. The
            list is empty when ``steps`` is 0.
        """
        data = self._forward_encode(data)
        # The encoded input is fed back into the core at every step.
        latent0 = data

        outputs = []
        for _ in range(steps):
            data = self._forward_core(latent0, data)
            data = self._forward_decode(data)
            out_data = self._forward_out(data)
            if save_all:
                outputs.append(out_data)
            else:
                outputs = [out_data]
        return outputs
    
    

Provide example to resolve Flex modules

## Training

### Create loaders for training

In [15]:
# GraphData / GraphDataLoader are not imported by any earlier visible cell, so
# import them here to keep this cell runnable on a fresh kernel.
from caldera.data import GraphData, GraphDataLoader


def graphs_to_datalist(graphs, feature_key):
    """Convert networkx graphs to a list of ``GraphData``.

    :param graphs: iterable of networkx graphs
    :param feature_key: passed to ``GraphData.from_networkx`` as ``feature_key``
    :return: one ``GraphData`` per input graph, in order
    """
    return [GraphData.from_networkx(g, feature_key=feature_key) for g in graphs]


# NOTE(review): `nx_graphs` and `eval_graphs` must come from the data
# generation section; no visible cell defines them yet.
input_datalist = graphs_to_datalist(nx_graphs, '_features')
target_datalist = graphs_to_datalist(nx_graphs, '_target')
eval_input_datalist = graphs_to_datalist(eval_graphs, '_features')
eval_target_datalist = graphs_to_datalist(eval_graphs, '_target')

# Training batches of 512 graphs; evaluation uses the whole eval set as one batch.
loader = GraphDataLoader(input_datalist, target_datalist, batch_size=512)
eval_loader = GraphDataLoader(eval_input_datalist, eval_target_datalist, batch_size=len(eval_input_datalist))

In [16]:
import wandb
from tqdm.auto import tqdm

# Replace the `...` placeholders with real values before running.
wandb.init(
    project=...,
    tags=...,
    group=...
)

# Number of core processing steps per forward pass.
NUM_STEPS = 10
# Evaluate on the held-out loader every this many epochs.
EVAL_EVERY = 10

# get device
if torch.cuda.is_available():
    print("cuda available")
    cuda_device = torch.cuda.current_device()
    device = 'cuda:' + str(cuda_device)
else:
    device = 'cpu'

# initialize network
network = Network()

# resolve: one forward pass on a single batch so the Flex modules get their
# dimensions before the parameters are handed to the optimizer
for input_batch, _ in loader:
    network(input_batch, NUM_STEPS)
    break

# send to device
network.to(device, non_blocking=True)
loss_fn = torch.nn.BCELoss()
optimizer = torch.optim.AdamW(network.parameters())

# TODO: loss should use all output to accomplish goal in min steps
# training loop
num_epochs = 1000
for epoch in tqdm(range(num_epochs)):
    # Evaluation below switches to eval mode, so re-enable train mode each epoch.
    network.train()
    running_loss = torch.tensor(0., device=device)
    for input_batch, target_batch in loader:
        input_batch = input_batch.to(device)
        target_batch = target_batch.to(device)

        # With save_all=False the returned list holds only the final step's output.
        output = network(input_batch, NUM_STEPS)[0]
        x, y = output.x, target_batch.x
        loss = loss_fn(x.flatten(), y[:, 0].flatten())
        loss.backward()

        # Accumulate without tracking gradients.
        running_loss += loss.detach()

        optimizer.step()
        optimizer.zero_grad()

    wandb.log({'train_loss': running_loss.detach().cpu().item()}, step=epoch)

    # Evaluate on every EVAL_EVERY-th epoch. The previous `if epoch % 10:`
    # was true for every epoch EXCEPT multiples of 10.
    if epoch % EVAL_EVERY == 0:
        network.eval()
        running_eval_loss = 0.
        # No gradients are needed for evaluation.
        with torch.no_grad():
            for eval_input_batch, eval_target_batch in eval_loader:
                eval_input_batch = eval_input_batch.to(device)
                eval_target_batch = eval_target_batch.to(device)

                eval_out = network(eval_input_batch, NUM_STEPS)[0]
                x, y = eval_out.x, eval_target_batch.x
                eval_loss = loss_fn(x.flatten(), y[:, 0].flatten())
                running_eval_loss += eval_loss.cpu().item()
        wandb.log({"eval_loss": running_eval_loss}, step=epoch)

cuda available


HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))


