<a href="https://colab.research.google.com/github/fekonrad/CombOptLayer/blob/main/Demo/COptLayer_Warcraft_Demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Using the COptLayer
First, let's clone the COptLayer repository:

In [None]:
# Clone the repository into the current working directory; presumably this is what
# lets `import CombOptLayer` in the next cell resolve the cloned folder.
!git clone https://github.com/fekonrad/CombOptLayer.git
# NOTE(review): each `!` command runs in its own subshell, so `!cd` does not change
# the notebook's working directory (that would need `%cd`). This line has no effect,
# and staying in the parent directory is what the `import CombOptLayer` below relies on.
!cd CombOptLayer

Cloning into 'CombOptLayer'...
remote: Enumerating objects: 96, done.[K
remote: Counting objects: 100% (96/96), done.[K
remote: Compressing objects: 100% (78/78), done.[K
remote: Total 96 (delta 31), reused 25 (delta 3), pack-reused 0 (from 0)[K
Receiving objects: 100% (96/96), 40.63 KiB | 5.80 MiB/s, done.
Resolving deltas: 100% (31/31), done.


... and import all necessary libraries:

In [None]:
import numpy as np
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F

import matplotlib.pyplot as plt

import CombOptLayer
from CombOptLayer import COptLayer
from CombOptLayer.losses import PerturbedLoss

# The Task
...

This is a short demo of how to use the COptLayer for the toy problem of finding shortest paths on Warcraft maps. To illustrate what the maps, graphs and paths look like, see the plots below:


In [None]:
class WarcraftPaths(Dataset):
    """Warcraft shortest-path dataset backed by three ``.npy`` files.

    Every array is loaded eagerly into memory as a float32 tensor. The map array
    is permuted with ``(0, 3, 1, 2)``, i.e. it is assumed to be stored
    channels-last ``(N, H, W, C)`` on disk and is exposed channels-first
    ``(N, C, H, W)`` so it can be fed directly to ``nn.Conv2d``.

    :param map_path: path to the ``.npy`` file with the map images.
    :param cost_path: path to the ``.npy`` file with the true vertex costs.
    :param paths_path: path to the ``.npy`` file with the shortest-path masks.
    """

    def __init__(self, map_path, cost_path, paths_path):
        super().__init__()
        self.maps = self._load_float(map_path).permute(0, 3, 1, 2)
        self.costs = self._load_float(cost_path)
        self.shortest_paths = self._load_float(paths_path)

    @staticmethod
    def _load_float(path):
        """Load a ``.npy`` file and return its contents as a float32 tensor."""
        return torch.tensor(np.load(path), dtype=torch.float32)

    def __len__(self):
        # Number of samples = size of the leading (sample) dimension.
        return len(self.maps)

    def __getitem__(self, item):
        """Return the ``(map, costs, shortest_path)`` triple at index ``item``."""
        columns = (self.maps, self.costs, self.shortest_paths)
        return tuple(column[item] for column in columns)

In [None]:
# TODO: Load and plot a sample map

## The Model (CNN)
We implement a very basic CNN to estimate the vertex costs of the map. Since this problem is not very complex, a relatively small model would suffice.

In [None]:
class VertexWeightCNN(nn.Module):
    """Small fully-convolutional network that predicts a grid of non-negative vertex weights.

    The network has three stages of (conv -> ReLU -> conv -> ReLU -> 2x2 max-pool),
    so each spatial dimension is downsampled by a factor of 8 (a 96x96 map becomes a
    12x12 grid of vertex weights). A final single-channel convolution followed by
    Softplus produces the weights, which keeps them non-negative.
    """

    def __init__(self):
        super().__init__()
        # 'same' padding (stride 1) keeps the spatial size unchanged in every conv;
        # only the max-pooling steps in forward() downsample.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding='same')
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding='same')
        self.conv3 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=2, padding='same')
        self.conv4 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=2, padding='same')
        self.conv5 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=2, padding='same')
        self.conv6 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=2, padding='same')
        self.final_layer = nn.Conv2d(in_channels=32, out_channels=1, kernel_size=2, padding='same')

    def forward(self, x):
        """
        :param x: torch.tensor of shape (b, 3, h, w)      (in our case h=w=96)
        :return: torch.tensor of shape (b, h // 8, w // 8) with non-negative entries
                 (in our case 12x12)
        """
        # Stage 1: full resolution.
        x = F.relu(self.conv1(x))               # (b, 16, 96, 96)
        x = F.relu(self.conv2(x))               # (b, 32, 96, 96)
        x = F.max_pool2d(x, kernel_size=2)      # (b, 32, 48, 48)

        # Stage 2: half resolution.
        x = F.relu(self.conv3(x))               # (b, 32, 48, 48)
        x = F.relu(self.conv4(x))               # (b, 32, 48, 48)
        x = F.max_pool2d(x, kernel_size=2)      # (b, 32, 24, 24)

        # Stage 3: quarter resolution.
        x = F.relu(self.conv5(x))               # (b, 32, 24, 24)
        x = F.relu(self.conv6(x))               # (b, 32, 24, 24)
        x = F.max_pool2d(x, kernel_size=2)      # (b, 32, 12, 12)

        # Softplus makes the weights non-negative; squeeze(1) drops the single channel.
        return F.softplus(self.final_layer(x)).squeeze(1)   # (b, 12, 12)

## The Combinatorial Solver (Dijkstra)
Now we just have to implement our solver for finding shortest paths (given vertex weights) on 2D grids, where the possible moves at each point are up, down, left, right and diagonal.

We have to make sure that the input of our solver is compatible with what our statistical model (the CNN) returns as outputs, i.e. in this case a `torch.tensor` of shape `(b, h, w)` (the CNN squeezes out its single channel dimension).

**Remark:**
The COptLayer has to solve multiple instances of the combinatorial problem at once. It is therefore advisable to implement the solver with parallelizable (batched) operations instead of solving the instances sequentially. This means either sticking to operations that `torch` already implements or writing your own parallelized methods, e.g. using CUDA.

In [None]:
def parallel_dijkstra(vertex_weights: torch.Tensor, max_iterations: int = None):
    """
    Batched shortest-path solver on 8-connected 2D grids with vertex weights.

    All instances in the batch are solved simultaneously with Bellman-Ford-style
    relaxation. In every sweep, each cell tries to improve its distance via each of
    its 8 neighbours, using only whole-tensor operations (no per-instance Python
    loop). Paths run from the top-left cell (0, 0) to the bottom-right cell
    (h-1, w-1). The cost of a path is the sum of the weights of all vertices on it,
    including both endpoints.

    :param vertex_weights: Tensor of shape (b, h, w) representing vertex weights.
        Weights should be non-negative. With negative weights the relaxation may not
        converge and the returned path is only a best effort.
    :param max_iterations: Maximum number of relaxation sweeps. If None, set to h * w,
        which is enough for convergence with non-negative weights (a shortest path
        visits at most h * w vertices). The loop stops early once nothing changes.
    :return: float32 tensor of shape (b, h, w) on the same device as the input,
        with 1s on the path and 0s elsewhere.
    :raises ValueError: if vertex_weights is not 3-dimensional or the grid is empty.
    """
    if vertex_weights.dim() != 3:
        raise ValueError(
            f"vertex_weights must have shape (b, h, w), got {tuple(vertex_weights.shape)}"
        )
    b, h, w = vertex_weights.shape
    if h == 0 or w == 0:
        raise ValueError("vertex_weights must describe a non-empty grid (h, w > 0)")
    device = vertex_weights.device

    # The solver's output is piecewise constant in the weights, so there is no point
    # in recording the (many) relaxation ops in the autograd graph.
    weights = vertex_weights.detach()
    if not weights.is_floating_point():
        weights = weights.float()

    inf = float('inf')
    distances = torch.full((b, h, w), inf, dtype=weights.dtype, device=device)
    distances[:, 0, 0] = weights[:, 0, 0]

    # Predecessor of every cell, stored as a flat index i * w + j (-1 = none).
    predecessors = torch.full((b, h, w), -1, dtype=torch.long, device=device)

    rows = torch.arange(h, device=device).view(1, h, 1)
    cols = torch.arange(w, device=device).view(1, 1, w)

    # (dx, dy) offsets of the 8 neighbours (i + dx, j + dy) of cell (i, j).
    shifts = [(-1, 0), (1, 0), (0, -1), (0, 1),
              (-1, -1), (-1, 1), (1, -1), (1, 1)]

    if max_iterations is None:
        max_iterations = h * w

    for _ in range(max_iterations):
        updated = False
        # Snapshot of this sweep's distances, padded with +inf so that neighbours
        # outside the grid can never win a relaxation.
        padded = F.pad(distances, (1, 1, 1, 1), mode='constant', value=inf)

        for dx, dy in shifts:
            # padded[:, 1 + i + dx, 1 + j + dy] == distances[:, i + dx, j + dy]
            neighbour_dist = padded[:, 1 + dx:1 + dx + h, 1 + dy:1 + dy + w]
            candidate = neighbour_dist + weights

            mask = candidate < distances
            mask[:, 0, 0] = False  # the start cell is the root of every path
            if mask.any():
                distances = torch.where(mask, candidate, distances)
                neighbour_flat = ((rows + dx) * w + (cols + dy)).expand(b, h, w)
                predecessors = torch.where(mask, neighbour_flat, predecessors)
                updated = True

        if not updated:
            break

    # Backtracking: follow predecessors from the bottom-right corner to the start.
    path_tensor = torch.zeros((b, h, w), dtype=torch.float32, device=device)
    flat_path = path_tensor.view(b, h * w)  # shares storage with path_tensor
    flat_pred = predecessors.reshape(b, h * w)
    batch_idx = torch.arange(b, device=device)
    current = torch.full((b,), h * w - 1, dtype=torch.long, device=device)

    # A simple path has at most h * w vertices. The bound also guards against
    # predecessor cycles that negative weights could create.
    for _ in range(h * w):
        flat_path[batch_idx, current] = 1.0
        prev = flat_pred[batch_idx, current]
        reached_start = prev < 0
        if bool(reached_start.all()):
            break
        # Instances that already reached the start stay there.
        current = torch.where(reached_start, current, prev)

    return path_tensor

Now we have everything we need to build the model!
Here we use the PerturbedLoss with our implemented solver to train the CNN.

*Note:*
Technically, the vertex weights should be non-negative to guarantee convergence of the solver. Thus, one could/should use the "Multiplicative Perturbation" discussed in the paper (reference here). Here we simply use the additive perturbation (which might lead to some weights becoming negative!) and the experiments still seem to work fine.

In [None]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = VertexWeightCNN().to(DEVICE)
# Batched grid shortest-path solver defined earlier in this notebook.
solver = parallel_dijkstra
# NOTE(review): objective='min' presumably tells PerturbedLoss that the solver
# minimises path cost; confirm against the CombOptLayer documentation.
loss_fn = PerturbedLoss(solver, objective='min', num_samples=10, smoothing=1.0)

The training routine now works like any other training routine in torch!

In [None]:
# TODO: Implement training routine and monitor sample map.
epochs = 10
lr = 1e-3
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

# TODO: Maybe fix the paths here.
# NOTE(review): these are the *test* split files, so the model is trained on the
# test set; confirm whether separate training files should be used instead.
data = WarcraftPaths("warcraft-maps-shortest-paths/test_maps.npy",
                     "warcraft-maps-shortest-paths/test_vertex_weights.npy",
                     "warcraft-maps-shortest-paths/test_shortest_paths.npy")

dataloader = DataLoader(data, batch_size=16, shuffle=True)
steps = 0
loss_vals = []

for _ in range(epochs):
    for maps, weights, paths in dataloader:
        maps, weights, paths = maps.to(DEVICE), weights.to(DEVICE), paths.to(DEVICE)

        optimizer.zero_grad()
        vertex_weight_pred = model(maps)
        loss = loss_fn(vertex_weight_pred, paths)
        loss_val = loss.item()
        loss_vals.append(loss_val)
        print(f"Loss after {steps} Steps: {loss_val}")
        loss.backward()
        optimizer.step()
        steps += 1

    # Monitor the first sample of the epoch's last batch. The predicted weights and
    # the predicted path come from the same forward pass. Plotting needs no
    # gradients, so the solver is run once per epoch under no_grad instead of on
    # every training step, and no extra forward pass is made.
    with torch.no_grad():
        weights_pred = vertex_weight_pred.detach()
        paths_pred = solver(weights_pred)

    fig, ax = plt.subplots(ncols=2, nrows=2)
    ax[0, 0].imshow(weights[0].cpu().detach().numpy())
    ax[0, 0].set_title("True weights")
    ax[0, 1].imshow(weights_pred[0].cpu().detach().numpy())
    ax[0, 1].set_title("Predicted weights")
    ax[1, 0].imshow(paths[0].cpu().detach().numpy())
    ax[1, 0].set_title("True path")
    ax[1, 1].imshow(paths_pred[0].cpu().detach().numpy())
    ax[1, 1].set_title("Predicted path")
    plt.show()

plt.plot(loss_vals)
plt.xlabel("Steps")
plt.ylabel("Perturbed Loss")
plt.show()


# Results
...