## Introduction
This notebook develops the algorithm used to position sub-plots on a grid.

The user specifies only the span (number of rows and columns occupied) of each sub-plot and the total number of columns in the grid, and the algorithm determines where in the grid to place each sub-plot.

The grid operates by the following rules:
* Each sub-plot spans an integer number of columns (default=minimum=1) and an integer number of rows (default=minimum=1) on the grid
* The total number of grid columns is fixed beforehand
* The number of grid rows grows until all sub-plots have been accommodated
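The first two rules constrain the input before any placement happens. The sketch below shows one way to check them in plain Python. It assumes spans are `(rowspan, colspan)` tuples, as in the rest of this notebook; `validate_spans` is a hypothetical helper and is not part of the code developed here. No row limit is checked, because the third rule lets the grid grow downwards.

```python
from typing import List, Tuple

def validate_spans(spans: List[Tuple[int, int]], num_cols: int) -> None:
    """Hypothetical check of the grid rules: every span is at least 1x1,
    and no sub-plot is wider than the fixed number of grid columns."""
    for i, (rowspan, colspan) in enumerate(spans):
        if rowspan < 1 or colspan < 1:
            raise ValueError(f'Sub-plot {i} has span ({rowspan}, {colspan}); both must be >= 1')
        if colspan > num_cols:
            raise ValueError(f'Sub-plot {i} ({colspan} columns) is wider than the grid ({num_cols})')

validate_spans([(1, 2), (2, 1), (1, 1), (1, 1)], num_cols=3)  # passes silently
```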

The layout algorithm is a simple, greedy one, but it consists of several steps, which justifies developing and executing each step individually in a notebook environment before putting them together.

After this calculation, the position and span of each sub-plot must be passed to the `plotly` objects in two different formats:
* The `specs` dictionary passed to `make_subplots`
* The `row` and `col` arguments of the `add_trace` method
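As a rough illustration of the two formats, the sketch below converts 0-based top-left positions and `(rowspan, colspan)` spans into a `specs` grid and a list of 1-based `(row, col)` pairs. It uses plain dicts so it runs without plotly; `to_plotly_args` is a hypothetical name. The input reproduces the hand-written example in the next section, where cells covered by another sub-plot's span are `None`.

```python
from typing import Dict, List, Optional, Tuple

Pos = Tuple[int, int]

def to_plotly_args(num_rows: int, num_cols: int,
                   positions: List[Pos], spans: List[Pos]):
    """Sketch: build (1) a `specs` grid for make_subplots and
    (2) 1-based (row, col) pairs for add_trace."""
    # Every cell that is not a sub-plot's top-left corner stays None
    specs: List[List[Optional[Dict[str, int]]]] = [
        [None for _ in range(num_cols)] for _ in range(num_rows)]
    for (r, c), (rs, cs) in zip(positions, spans):
        specs[r][c] = dict(rowspan=rs, colspan=cs)
    # plotly's row/col arguments are 1-based
    row_cols = [(r + 1, c + 1) for r, c in positions]
    return specs, row_cols

specs, row_cols = to_plotly_args(2, 3, [(0, 0), (0, 2), (1, 0), (1, 1)],
                                 [(1, 2), (2, 1), (1, 1), (1, 1)])
print(row_cols)  # [(1, 1), (1, 3), (2, 1), (2, 2)]
```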

## Imports

In [1]:
from typing import List, Tuple, Mapping
from plotly import graph_objects as go
from plotly.subplots import make_subplots
import torch
from torch import Tensor

## Plotly sub-plot construction

### Exploratory code
* Adapted from example [here](https://plotly.com/python/subplots/#multiple-custom-sized-subplots)
  - See linked example for more detail on `specs` argument 
* The `renderer='notebook_connected'` argument loads the ~4 MB plotly.js source from a CDN instead of embedding it in the notebook DOM, reducing the size of this notebook when saved

In [2]:
specs = [[dict(rowspan=1, colspan=2), None, dict(rowspan=2, colspan=1)],
         [dict(rowspan=1, colspan=1), dict(rowspan=1, colspan=1), None]]
fig = make_subplots(rows=2, cols=3, specs=specs)
fig.add_trace(go.Scatter(x=[0, 1], y=[0, 1]), row=1, col=1)
fig.add_trace(go.Scatter(x=[0, 1], y=[0, 1]), row=1, col=3)
fig.add_trace(go.Scatter(x=[0, 1], y=[0, 1]), row=2, col=1)
fig.add_trace(go.Scatter(x=[0, 1], y=[0, 1]), row=2, col=2)
fig.update_layout(height=300, width=600, margin=dict(l=8, r=8, t=8, b=8))
fig.show(renderer='notebook_connected')

### As a function
* `positions` is a list of tuples containing the row and column indices of the top-left grid location of each sub-plot
* Pass `positions` as 0-based (as it is easier to work with in general), but convert it to 1-based when passing to `fig.add_trace`

In [3]:
def show_subplots(num_rows, num_cols, specs, positions):
    fig = make_subplots(rows=num_rows, cols=num_cols, specs=specs)
    for pos in positions:
        fig.add_trace(go.Scatter(x=[0, 1], y=[0, 1]), row=pos[0]+1, col=pos[1]+1)
    fig.update_layout(height=300, width=600, margin=dict(l=8, r=8, t=8, b=8))
    fig.show(renderer='notebook_connected')

In [4]:
num_rows, num_cols = 2,3
specs = [[dict(rowspan=1, colspan=2), None, dict(rowspan=2, colspan=1)],
         [dict(rowspan=1, colspan=1), dict(rowspan=1, colspan=1), None]]
positions = [(0,0), (0,2), (1,0), (1,1)]
show_subplots(num_rows,num_cols,specs,positions)

## Deriving specs dictionary from positions and sub-plot spans


### Exploratory code

* NB: initialize `specs` using a nested list comprehension
  - Caveat: if initialized as `[[None]*num_cols]*num_rows`, every row is a reference to the same list object, so assigning a cell in one row also changes the cell in the same column of every other row
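The caveat can be demonstrated directly. The inner `[None] * num_cols` is harmless because `None` is immutable; the problem is the outer `* num_rows`, which repeats a reference to one inner list rather than copying it.

```python
num_rows, num_cols = 2, 3

# Pitfall: the outer `*` repeats a reference to the SAME inner list
aliased = [[None] * num_cols] * num_rows
aliased[0][1] = 'x'
print(aliased)                   # [[None, 'x', None], [None, 'x', None]]
print(aliased[0] is aliased[1])  # True

# Nested comprehension: each row is a distinct list object
independent = [[None for _ in range(num_cols)] for _ in range(num_rows)]
independent[0][1] = 'x'
print(independent)               # [[None, 'x', None], [None, None, None]]
```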

In [5]:
positions = [(0,0), (0,2), (1,0), (1,1)]
spans = [(1,2), (2,1), (1,1), (1,1)]

specs = [[None for _ in range(num_cols)] for _ in range(num_rows)]
for pos,span in zip(positions,spans):
    specs[pos[0]][pos[1]] = dict(rowspan=span[0], colspan=span[1])

show_subplots(num_rows,num_cols,specs,positions)

### As a function

In [6]:
def calc_specs(num_cols, num_rows, positions, spans):
    specs = [[None for _ in range(num_cols)] for _ in range(num_rows)]
    for pos,span in zip(positions,spans):
        specs[pos[0]][pos[1]] = dict(rowspan=span[0], colspan=span[1])
    return specs

## Layout: determining sub-plot positions
Inputs
* Ordered list of sub-plots, each with row and column spans defined
* Number of grid columns

Outputs
* Row and column positions of each subplot
* Number of grid rows

### Approach

* Keep track of occupied grid cells using a matrix the same size as the grid
  * Cells in the matrix have values equal to the index of the subplot that occupies the corresponding cell in the grid
  * Unoccupied cells are indicated by a value of `-1` in the matrix (any negative value is treated as empty)
  * The matrix has a fixed width, whereas the height is resized dynamically as subplots are added
* Keep track of the next available grid cell
  * This variable starts at the top-left corner of the grid, increments towards the right and wraps downwards
  * If possible, the next sub-plot will be placed with its top-left corner at this cell. Otherwise, the variable will be incremented as above
  * If empty cells are left behind when sub-plots are placed, they are not revisited. This ensures that the order of the sub-plots remains as specified
* Continue until all sub-plots have been placed
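The approach above can be sketched end-to-end without `torch`, using a list of lists (with `-1` for empty cells) in place of the tensor. This is a minimal, assumption-laden sketch rather than the implementation developed below; `place_subplots_py` is a hypothetical name. It follows the same rules: a row-major scan that never rewinds, growing the matrix downwards on demand and rejecting positions that run off the right edge.

```python
from itertools import count
from typing import List, Tuple

Span = Tuple[int, int]

def place_subplots_py(num_cols: int, spans: List[Span]) -> Tuple[List[Span], List[List[int]]]:
    """Torch-free sketch of the greedy layout. Returns the 0-based top-left
    position of each sub-plot and the occupancy matrix (-1 = empty)."""
    matrix: List[List[int]] = []
    positions: List[Span] = []
    # Row-major scan over an unbounded grid; the scan never goes backwards
    cells = ((r, c) for r in count() for c in range(num_cols))
    cur = next(cells)
    for idx, (rs, cs) in enumerate(spans):
        if cs > num_cols:
            raise ValueError(f'Sub-plot {idx} is wider than the grid')
        while True:
            r, c = cur
            # Grow the matrix so that rows r .. r+rs-1 exist
            while len(matrix) < r + rs:
                matrix.append([-1] * num_cols)
            fits = c + cs <= num_cols and all(
                matrix[i][j] < 0 for i in range(r, r + rs) for j in range(c, c + cs))
            if fits:
                break
            cur = next(cells)
        # Mark the occupied rectangle with the sub-plot's index
        for i in range(r, r + rs):
            for j in range(c, c + cs):
                matrix[i][j] = idx
        positions.append(cur)
    return positions, matrix

positions, matrix = place_subplots_py(3, [(1, 2), (2, 1), (1, 1), (1, 1)])
print(positions)  # [(0, 0), (0, 2), (1, 0), (1, 1)]
print(matrix)     # [[0, 0, 1], [2, 3, 1]]
```

Because the scan position is never reset, a narrow sub-plot listed after a wide one cannot move back into a gap left earlier, which is what preserves the specified order.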

### Grid cell generator
* Yields (row, column) tuples for a grid with a fixed number of columns and an indefinite number of rows

In [7]:
def grid_cell_gen(num_cols):
    row = 0
    while True:
        for col in range(num_cols): yield (row, col)
        row += 1

In [8]:
cell_gen = grid_cell_gen(3)
[next(cell_gen) for _ in range(10)]

[(0, 0),
 (0, 1),
 (0, 2),
 (1, 0),
 (1, 1),
 (1, 2),
 (2, 0),
 (2, 1),
 (2, 2),
 (3, 0)]
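The same row-major sequence can also be written as a generator expression over `itertools.count`, and `itertools.islice` is an alternative to calling `next` in a loop. `grid_cell_gen_itertools` is a hypothetical name for this variant; it should yield exactly the tuples listed above.

```python
from itertools import count, islice

def grid_cell_gen_itertools(num_cols):
    """Row-major (row, col) tuples for a grid with num_cols columns
    and an unbounded number of rows."""
    return ((row, col) for row in count() for col in range(num_cols))

print(list(islice(grid_cell_gen_itertools(3), 10)))
# [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2), (2, 0), (2, 1), (2, 2), (3, 0)]
```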

### Dynamic resizing
* Checks if the matrix is tall enough to fit a sub-plot with its top-left corner at `pos` and the specified `span`
* If not, appends rows filled with `-1` (unoccupied) and returns the resized matrix

In [9]:
def ensure_matrix_height(matrix:Tensor, pos:Tuple[int,int], span:Tuple[int,int]=(1,1)) -> Tensor:
    num_grid_cols = matrix.shape[1]
    required_rows = pos[0] + span[0] - matrix.shape[0]
    if required_rows > 0:
        new_rows = torch.full((required_rows, num_grid_cols), fill_value=-1)
        matrix = torch.cat((matrix, new_rows), dim=0)
    return matrix

In [10]:
matrix = torch.tensor([[1,1,2],[3,4,2]])
print(matrix)
matrix = ensure_matrix_height(matrix, (2,2), (3,2))
print(matrix)

tensor([[1, 1, 2],
        [3, 4, 2]])
tensor([[ 1,  1,  2],
        [ 3,  4,  2],
        [-1, -1, -1],
        [-1, -1, -1],
        [-1, -1, -1]])
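The number of rows to append is `pos[0] + span[0] - height`; in the example above that is 2 + 3 - 2 = 3, which matches the three rows of `-1`. A list-of-lists analogue is sketched below; `ensure_height_list` is a hypothetical name, and it takes `num_cols` explicitly so that it also works on an empty matrix. Like `torch.cat`, it returns a new object instead of modifying its input.

```python
from typing import List, Tuple

def ensure_height_list(matrix: List[List[int]], num_cols: int,
                       pos: Tuple[int, int], span: Tuple[int, int] = (1, 1)) -> List[List[int]]:
    """Append rows of -1 (unoccupied) until rows pos[0] .. pos[0]+span[0]-1 exist."""
    missing = pos[0] + span[0] - len(matrix)  # e.g. 2 + 3 - 2 = 3
    # max(..., 0): a matrix that is already tall enough comes back unchanged
    return matrix + [[-1] * num_cols for _ in range(max(missing, 0))]

m = ensure_height_list([[1, 1, 2], [3, 4, 2]], 3, (2, 2), (3, 2))
for row in m:
    print(row)
# [1, 1, 2]
# [3, 4, 2]
# [-1, -1, -1]
# [-1, -1, -1]
# [-1, -1, -1]
```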


### Availability check
* Checks if all cells starting at the specified position (top-left corner) and continuing for the specified span (downwards and to the right) are unoccupied (value < 0, i.e. `-1`)
* If the matrix is not tall enough, it is resized first; the (possibly resized) matrix is returned along with the result
* If `span` alone is wider than the matrix, `ValueError` is raised because the sub-plot will never fit
* If `pos[1] + span[1]` exceeds the matrix width, `False` is returned, because the sub-plot may still fit at the start of a new row

In [11]:
def cells_available(matrix:Tensor, pos:Tuple[int,int], span:Tuple[int,int]=(1,1)) -> Tuple[bool, Tensor]:
    # Check width and height
    matrix = ensure_matrix_height(matrix, pos, span)
    if span[1] > matrix.shape[1]: raise ValueError(f'Sub-plot ({span[1]}) is wider than grid ({matrix.shape[1]})')
    if pos[1] + span[1] > matrix.shape[1]: return False, matrix  # Off the right-most edge

    # Check cell contents
    target_cells = matrix[pos[0]:pos[0]+span[0], pos[1]:pos[1]+span[1]]
    is_available = bool((target_cells < 0).all())
    return is_available, matrix

In [12]:
matrix = torch.tensor([[1,1,2],[3,4,2]])
print(matrix)
is_available, matrix = cells_available(matrix, (1,1))
print(is_available, matrix)
is_available, matrix = cells_available(matrix, (2,1), (1,2))
print(is_available, matrix)
is_available, matrix = cells_available(matrix, (2,1), (1,3))
print(is_available, matrix)

tensor([[1, 1, 2],
        [3, 4, 2]])
False tensor([[1, 1, 2],
        [3, 4, 2]])
True tensor([[ 1,  1,  2],
        [ 3,  4,  2],
        [-1, -1, -1]])
False tensor([[ 1,  1,  2],
        [ 3,  4,  2],
        [-1, -1, -1]])


In [13]:
pos, span = (1,1), (1,1)
matrix[pos[0]:pos[0]+span[0], pos[1]:pos[1]+span[1]]

tensor([[4]])
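The tensor slice `matrix[r0:r0+h, c0:c0+w]` has no direct equivalent on a list of lists, so a torch-free version of the availability check needs a comprehension over the sliced rows. The sketch below mirrors the three outcomes described above (error, off the right edge, occupancy test); `cells_available_list` is a hypothetical name and `-1` is assumed to mark empty cells.

```python
from typing import List, Tuple

def cells_available_list(matrix: List[List[int]], num_cols: int,
                         pos: Tuple[int, int], span: Tuple[int, int] = (1, 1)):
    """Sketch mirroring cells_available on a list of lists (-1 = empty).
    Returns (is_available, possibly grown matrix)."""
    # Grow first, as the torch version does
    missing = pos[0] + span[0] - len(matrix)
    matrix = matrix + [[-1] * num_cols for _ in range(max(missing, 0))]
    if span[1] > num_cols:
        raise ValueError(f'Sub-plot ({span[1]}) is wider than grid ({num_cols})')
    if pos[1] + span[1] > num_cols:
        return False, matrix  # off the right edge; may fit on a new row
    # Equivalent of the 2-D slice followed by (block < 0).all()
    block = [row[pos[1]:pos[1] + span[1]] for row in matrix[pos[0]:pos[0] + span[0]]]
    return all(v < 0 for row in block for v in row), matrix

matrix = [[1, 1, 2], [3, 4, 2]]
print(cells_available_list(matrix, 3, (1, 1))[0])          # False: cell holds 4
print(cells_available_list(matrix, 3, (2, 1), (1, 2))[0])  # True: new row of -1
print(cells_available_list(matrix, 3, (2, 1), (1, 3))[0])  # False: off the right edge
```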

### Algorithm

In [14]:
def place_subplots(num_grid_cols:int, spans:List[Tuple[int, int]]):
    matrix = torch.empty(0, num_grid_cols)
    positions:List[Tuple[int,int]] = []
    cell_gen = grid_cell_gen(num_grid_cols)
    cur_pos = next(cell_gen)
    for spi, span in enumerate(spans):
        is_available, matrix = cells_available(matrix, cur_pos, span)
        while not is_available: 
            cur_pos = next(cell_gen)
            is_available, matrix = cells_available(matrix, cur_pos, span)
        matrix[cur_pos[0]:cur_pos[0]+span[0], cur_pos[1]:cur_pos[1]+span[1]] = spi
        positions.append(cur_pos)
    num_grid_rows = matrix.shape[0]
    specs = calc_specs(num_grid_cols, num_grid_rows, positions, spans)
    return num_grid_rows, positions, specs, matrix

In [15]:
num_grid_cols = 3
spans = [(1,2), (2,1), (1,1), (1,1)]
num_grid_rows, positions, specs, matrix = place_subplots(num_grid_cols, spans)
print(matrix)
show_subplots(num_grid_rows, num_grid_cols, specs, positions)

tensor([[0., 0., 1.],
        [2., 3., 1.]])
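One way to sanity-check a result like the matrix above is to confirm that each sub-plot occupies exactly the rectangle given by its position and span, with no overlaps. The helper below is a hypothetical sketch operating on nested lists; with the `torch` version it could be called as `check_placement(matrix.tolist(), positions, spans)`. The `int(v)` conversion handles the float values seen in the output above. It checks consistency only, not whether the layout is compact.

```python
from typing import Dict, List, Sequence, Tuple

def check_placement(matrix: Sequence[Sequence[float]],
                    positions: List[Tuple[int, int]],
                    spans: List[Tuple[int, int]]) -> bool:
    """Hypothetical check: sub-plot i must occupy exactly the rectangle
    starting at positions[i] with size spans[i], and nothing else."""
    expected: Dict[Tuple[int, int], int] = {}
    for i, ((r, c), (rs, cs)) in enumerate(zip(positions, spans)):
        for dr in range(rs):
            for dc in range(cs):
                cell = (r + dr, c + dc)
                if cell in expected:  # two sub-plots claim the same cell
                    return False
                expected[cell] = i
    # Cells with a negative value (-1) are treated as empty
    actual = {(r, c): int(v) for r, row in enumerate(matrix)
              for c, v in enumerate(row) if v >= 0}
    return expected == actual

print(check_placement([[0., 0., 1.], [2., 3., 1.]],
                      [(0, 0), (0, 2), (1, 0), (1, 1)],
                      [(1, 2), (2, 1), (1, 1), (1, 1)]))  # True
```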


## More tests

In [16]:
def test_placement(num_grid_cols, spans):
    num_grid_rows, positions, specs, matrix = place_subplots(num_grid_cols, spans)
    show_subplots(num_grid_rows, num_grid_cols, specs, positions)

In [17]:
test_placement(4, [(2,2),(1,3),(2,1)])

In [18]:
test_placement(2, [(2,2),(1,1),(2,1),(1,1)])
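The plots are easier to interpret next to the occupancy matrix. The hypothetical helper below renders a matrix as text (sub-plot 0 as `A`, 1 as `B`, and so on; empty cells as `.`), and would be called on a tensor as `render_grid(matrix.tolist())`. The two matrices shown are worked out by hand from the greedy rules for the two tests above. In the first test the top-right 2×2 block stays empty: the 3-column sub-plot does not fit beside the 2×2 one, and the scan does not revisit the cells it skipped.

```python
import string
from typing import List, Sequence

def render_grid(matrix: Sequence[Sequence[float]]) -> str:
    """Draw an occupancy matrix as text: sub-plot i -> i-th capital letter,
    empty (-1) -> '.'. Only supports up to 26 sub-plots."""
    return '\n'.join(
        ''.join(string.ascii_uppercase[int(v)] if v >= 0 else '.' for v in row)
        for row in matrix)

# test_placement(4, [(2,2),(1,3),(2,1)])
print(render_grid([[0, 0, -1, -1], [0, 0, -1, -1], [1, 1, 1, 2], [-1, -1, -1, 2]]))
# AA..
# AA..
# BBBC
# ...C

# test_placement(2, [(2,2),(1,1),(2,1),(1,1)])
print(render_grid([[0, 0], [0, 0], [1, 2], [3, 2]]))
# AA
# AA
# BC
# DC
```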