## Introduction
This notebook uses trainplotkit with PyTorch Lightning.

The model and data are based on the [Lightning in 15 minutes](https://lightning.ai/docs/pytorch/stable/starter/introduction.html) tutorial on PyTorch Lightning's [documentation site](https://lightning.ai/docs/pytorch/stable/) with validation and testing from the [next step](https://lightning.ai/docs/pytorch/stable/common/evaluation_basic.html) of their tutorial.

This notebook also acts as preparation for the implementation of adapters to minimize the code needed to use trainplotkit with PyTorch Lightning.


## Imports

In [None]:
import sys, importlib
import os
import torch
from torch import optim, nn, utils, Tensor
from torch.nn import functional as F
from torchvision import datasets, transforms
import lightning as L

sys.path.append('..')
from trainplotkit.plotgrid import PlotGrid
from trainplotkit.subplots.basic import TrainingCurveSP, MetricSP, ValidLossSP, ImageSP, PredImageSP, ClassProbsSP

In [2]:
def reload_imports():
    importlib.reload(sys.modules['trainplotkit.layout'])
    importlib.reload(sys.modules['trainplotkit.plotgrid'])
    importlib.reload(sys.modules['trainplotkit.subplots.basic'])
    global PlotGrid, TrainingCurveSP, MetricSP, ValidLossSP, ImageSP, PredImageSP, ClassProbsSP
    from trainplotkit.plotgrid import PlotGrid
    from trainplotkit.subplots.basic import TrainingCurveSP, MetricSP, ValidLossSP, ImageSP, PredImageSP, ClassProbsSP

## Data preparation
* From [tutorial](https://lightning.ai/docs/pytorch/stable/starter/introduction.html#define-a-dataset): Lightning supports ANY iterable (DataLoader, numpy, etc…) for the train/val/test/predict splits.
* Added test and validation splits from the [next step](https://lightning.ai/docs/pytorch/stable/common/evaluation_basic.html)

In [3]:
train_set    = datasets.MNIST(root="MNIST", download=True, train=True, transform=transforms.ToTensor())
test_set     = datasets.MNIST(root="MNIST", download=True, train=False, transform=transforms.ToTensor())

train_set_size = int(len(train_set) * 0.8)
valid_set_size = len(train_set) - train_set_size
seed = torch.Generator().manual_seed(42)
train_set, valid_set = utils.data.random_split(train_set, [train_set_size, valid_set_size], generator=seed)

train_loader = utils.data.DataLoader(train_set)
valid_loader = utils.data.DataLoader(valid_set, num_workers=15)
test_loader  = utils.data.DataLoader(test_set, num_workers=15)

## Model
* From [tutorial](https://lightning.ai/docs/pytorch/stable/starter/introduction.html#define-a-lightningmodule): A `LightningModule` enables your PyTorch `nn.Module` to play together in complex ways inside the `training_step` (there is also an optional `validation_step` and `test_step`).
* Added test and validation steps from the [next step](https://lightning.ai/docs/pytorch/stable/common/evaluation_basic.html)

In [4]:
# define any number of nn.Modules (or use your current ones)
encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))
decoder = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))

# define the LightningModule
class LitAutoEncoder(L.LightningModule):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop.
        # it is independent of forward
        x, _ = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        # Logging to TensorBoard (if installed) by default
        self.log("train_loss", loss)
        return loss
    
    def validation_step(self, batch, batch_idx):
        # this is the validation loop
        x, _ = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        val_loss = F.mse_loss(x_hat, x)
        self.log("val_loss", val_loss)
    
    def test_step(self, batch, batch_idx):
        # this is the test loop
        x, _ = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        test_loss = F.mse_loss(x_hat, x)
        self.log("test_loss", test_loss)

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=1e-3)
        return optimizer
    
# init the autoencoder
autoencoder = LitAutoEncoder(encoder, decoder)

### Model visualization (optional)
Just skip this section if you don't have `idlmav`. Nothing depends on it.

In [5]:
x = train_set[0][0]
x = x.view(x.size(0), -1)
z = encoder(x)
x.shape, z.shape

(torch.Size([1, 784]), torch.Size([1, 3]))

In [6]:
from idlmav import MAV, MavOptions, RenderOptions, plotly_renderer
enc_mav = MAV(encoder, x)
dec_mav = MAV(decoder, z)

In [7]:
with plotly_renderer('notebook_connected'): enc_mav.show_figure(RenderOptions(height_px=200))

In [8]:
with plotly_renderer('notebook_connected'): dec_mav.show_figure(RenderOptions(height_px=200))

## Training
From [tutorial](https://lightning.ai/docs/pytorch/stable/starter/introduction.html#train-the-model): The Lightning `Trainer` “mixes” any `LightningModule` with any dataset and abstracts away all the engineering complexity needed for scale.

In [9]:
# Uncomment to clear logs before fitting
!rm -rf lightning_logs

In [10]:
# train the model (hint: here are some helpful Trainer arguments for rapid idea iteration)
trainer = L.Trainer(limit_train_batches=100, limit_val_batches=100, max_epochs=5)
trainer.fit(model=autoencoder, train_dataloaders=train_loader, val_dataloaders=valid_loader)

You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 3060 Laptop GPU') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type       | Params | Mode 
-----------------------------------------------
0 | encoder | Sequential | 50.4 K | train
1 | decoder | Sequential | 51.2 K | train
-----------------------------------------------
101 K     Trainable params
0         Non-trainable params
101 K     Total params
0.407     Total estimated mode

Sanity Checking: |          | 0/? [00:00<?, ?it/s]


The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.



Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=5` reached.


From [tutorial](https://lightning.ai/docs/pytorch/stable/starter/introduction.html#train-the-model): The Lightning Trainer automates 40+ tricks including:
* Epoch and batch iteration
* `optimizer.step()`, `loss.backward()`, `optimizer.zero_grad()` calls
* Calling of `model.eval()`, enabling/disabling grads during evaluation
* Checkpoint Saving and Loading
* Tensorboard (see loggers options)
* Multi-GPU support
* TPU
* 16-bit precision AMP support


## Test the model

In [11]:
trainer = L.Trainer(limit_test_batches=100)
trainer.test(autoencoder, dataloaders=test_loader)

You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: |          | 0/? [00:00<?, ?it/s]

[{'test_loss': 0.05456400290131569}]

## Use the model
From [tutorial](https://lightning.ai/docs/pytorch/stable/starter/introduction.html#use-the-model): Once you’ve trained the model you can export to onnx, torchscript and put it into production or simply load the weights and run predictions.

In [13]:
# load checkpoint
checkpoint = "./lightning_logs/version_0/checkpoints/epoch=4-step=500.ckpt"
autoencoder = LitAutoEncoder.load_from_checkpoint(checkpoint, encoder=encoder, decoder=decoder)

# choose your trained nn.Module
encoder = autoencoder.encoder
encoder.eval()

# embed 4 fake images!
fake_image_batch = torch.rand(4, 28 * 28, device=autoencoder.device)
embeddings = encoder(fake_image_batch)
print("⚡" * 20, "\nPredictions (4 image embeddings):\n", embeddings, "\n", "⚡" * 20)

⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡ 
Predictions (4 image embeddings):
 tensor([[0.3963, 0.3712, 1.7157],
        [0.3116, 0.5385, 1.5479],
        [0.4815, 0.5032, 1.8297],
        [0.6178, 0.3892, 1.5489]], device='cuda:0', grad_fn=<AddmmBackward0>) 
 ⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡


## Visualize training using tensorboard
From [tutorial](https://lightning.ai/docs/pytorch/stable/starter/introduction.html): If you have tensorboard installed, you can use it for visualizing experiments.

### Check for running instances

In [None]:
!ps aux | grep  tensorboard

dev        10131  0.0  0.0   4788  3236 pts/9    Ss+  20:14   0:00 /bin/bash -c ps aux | grep  tensorboard
dev        10133  0.0  0.0   4032  1988 pts/9    S+   20:14   0:00 grep tensorboard


### Start as sub-process

In [None]:
import subprocess
tb_server = subprocess.Popen(["tensorboard", "--logdir=/home/dev/ai/trainplotkit/nbs/lightning_logs", "--port=6006", "--host=localhost"])

TensorFlow installation not found - running with reduced feature set.

NOTE: Using experimental fast data loading logic. To disable, pass
    "--load_fast=false" and report issues on GitHub. More details:
    https://github.com/tensorflow/tensorboard/issues/4784

TensorBoard 2.19.0 at http://localhost:6006/ (Press CTRL+C to quit)


### Open in browser
From WSL, running the URL through `explorer.exe` opens the link correctly in the default browser

In [None]:
!explorer.exe http://localhost:6006/

### Terminate sub-process

In [None]:
tb_server.terminate()

TensorBoard caught SIGTERM; exiting...


### Other interactions
* Blocking start (not very valuable for Notebook environment):
  ```python
  import os
  os.system("tensorboard --logdir=/home/dev/ai/trainplotkit/nbs --load_fast=true --port 6006 &")
  ```
* Terminate using `pkill`
  ```bash
  pkill -f tensorboard
  ```
* Embed Tensorboard UI in notebook (may require some setup)
  ```raw
  %load_ext tensorboard
  %tensorboard --logdir . --port 6006
  ```

## Adding trainplotkit
The example presented in this tutorial can be visualized using trainplotkit by making the following updates:
* Return dictionaries from `LightningModule.training_step` and `LightningModule.validation_step` containing at least the following keys:
  * `loss`: The loss for that batch, usually returned already, but sometimes under another name or without a dictionary
  * `predictions`: The outputs after applying the model to `inputs`
  * (optional) `targets`: The quantity that is compared to `predictions` in the loss function (specify only if not equal to `batch[1]`)
* Write a [Ligthning callback](https://lightning.ai/docs/pytorch/stable/api_references.html#callbacks) that implements the following methods and forwards the relevent information to a `PlotGrid` object:
  * `on_fit_start`
  * `on_train_batch_end`
  * `on_train_epoch_end`
  * `on_validation_batch_end`
  * `on_validation_epoch_end`
  * `on_fit_end`

The callback will likely change very little between modules and we might as well release `trainplotkit` with a `LightningAdapter` which performs this function, minimizing the amount of additional code required to add `trainplotkit` to an existing lightning workflow 

### Adding outputs to `training_step` and `validation_step`
Same as above, just a new `return` statement

In [14]:
class LitAutoEncoder2(LitAutoEncoder):
    def training_step(self, batch, batch_idx):
        x, _ = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        self.log("train_loss", loss)
        return dict(loss=loss, predictions=x_hat, targets=x)
    
    def validation_step(self, batch, batch_idx):
        x, _ = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        val_loss = F.mse_loss(x_hat, x)
        self.log("val_loss", val_loss)
        return dict(loss=val_loss, predictions=x_hat, targets=x)
    
encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))
decoder = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))
autoencoder = LitAutoEncoder2(encoder, decoder)

### Experimental callback
This callback just prints custom messages to trace the sequence of the callbacks

In [15]:
class PrintCallback(L.Callback):
    def __init__(self, batch_modulus=100):
        super().__init__()
        self.batch_modulus = batch_modulus

    def on_fit_start(self, trainer, pl_module):
        print("on_fit_start")

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        if batch_idx % self.batch_modulus == 0: 
            print(f'        on_train_batch_end: {batch_idx}')

    def on_train_epoch_end(self, trainer, pl_module):
        print(f'    on_train_epoch_end')

    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        if batch_idx % self.batch_modulus == 0: 
            print(f'        on_validation_batch_end: {batch_idx}')

    def on_validation_epoch_end(self, trainer, pl_module):
        print(f'    on_validation_epoch_end')

    def on_fit_end(self, trainer, pl_module):
        print("on_fit_end")

In [None]:
trainer = L.Trainer(callbacks=[PrintCallback()], limit_train_batches=500, limit_val_batches=500, max_epochs=5)
trainer.fit(model=autoencoder, train_dataloaders=train_loader, val_dataloaders=valid_loader)

### LightningAdapter

In [17]:
class LightningAdapter(L.Callback):
    def __init__(self, pg:PlotGrid):
        super().__init__()
        self.pg = pg
        self.busy_with_sanity_check = False
    
    def on_sanity_check_start(self, trainer, pl_module): self.busy_with_sanity_check = True
    def on_sanity_check_end(self, trainer, pl_module): self.busy_with_sanity_check = False

    def on_fit_start(self, trainer, pl_module):
        self.pg.before_fit()

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        if self.busy_with_sanity_check: return
        inputs = outputs['inputs'] if 'inputs' in outputs else batch[0]
        targets = outputs['targets'] if 'targets' in outputs else batch[1]
        self.pg.after_batch(training=True, inputs=inputs, targets=targets,
                            predictions=outputs['predictions'], loss=outputs['loss'])

    def on_train_epoch_end(self, trainer, pl_module):
        if self.busy_with_sanity_check: return
        self.pg.after_epoch(training=True)

    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        if self.busy_with_sanity_check: return
        inputs = outputs['inputs'] if 'inputs' in outputs else batch[0]
        targets = outputs['targets'] if 'targets' in outputs else batch[1]
        self.pg.after_batch(training=False, inputs=inputs, targets=targets,
                            predictions=outputs['predictions'], loss=outputs['loss'])

    def on_validation_epoch_end(self, trainer, pl_module):
        if self.busy_with_sanity_check: return
        self.pg.after_epoch(training=False)

    def on_fit_end(self, trainer, pl_module):
        self.pg.after_fit()

In [None]:
reload_imports()

In [241]:
# Subplots
batch_loss_fn = lambda preds,targs: F.mse_loss(preds,targs,reduction='none').mean(dim=1)
sps = [
    TrainingCurveSP(colspan=2), 
    ImageSP(valid_set),
    ValidLossSP(batch_loss_fn, remember_past_epochs=True, colspan=2), 
    PredImageSP(remember_past_epochs=True, img_size=(28,28)),
]
pg = PlotGrid(num_grid_cols=3, subplots=sps, fig_height=600)
pg.show()

FigureWidget({
    'data': [{'mode': 'lines+markers',
              'name': 'Training loss',
              'type': 'scatter',
              'uid': 'f1a71210-fa0c-4b7c-ac86-a0432211e42b',
              'x': [],
              'xaxis': 'x',
              'y': [],
              'yaxis': 'y'},
             {'mode': 'lines+markers',
              'name': 'Validation loss',
              'type': 'scatter',
              'uid': 'a3da3c93-0763-4b53-8fec-dd0b375b3a6b',
              'x': [],
              'xaxis': 'x',
              'y': [],
              'yaxis': 'y'},
             {'hoverinfo': 'skip',
              'marker': {'color': 'rgba(0,0,0,0.2)', 'line': {'color': 'black', 'width': 2}},
              'mode': 'markers',
              'showlegend': False,
              'type': 'scatter',
              'uid': '1f02b445-6d70-4aa9-a1e1-cf1a9c655c25',
              'x': [],
              'xaxis': 'x',
              'y': [],
              'yaxis': 'y'},
             {'type': 'image',
        

In [242]:
encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))
decoder = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))
autoencoder = LitAutoEncoder2(encoder, decoder)

trainer = L.Trainer(callbacks=[LightningAdapter(pg)], limit_train_batches=500, limit_val_batches=500, max_epochs=5)
trainer.fit(model=autoencoder, train_dataloaders=train_loader, val_dataloaders=valid_loader)

You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type       | Params | Mode 
-----------------------------------------------
0 | encoder | Sequential | 50.4 K | train
1 | decoder | Sequential | 51.2 K | train
-----------------------------------------------
101 K     Trainable params
0         Non-trainable params
101 K     Total params
0.407     Total estimated model params size (MB)
8         Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]


The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.



Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=5` reached.


In [243]:
pg.show_static()

## Debugging

### Subplots shifting for `show_static`
At first, `pg.show_static` would display the two image subplots in front of the two scatter subplots, obscuring the latter. This section was written to determine why this happens and how it can be avoided.

Investigation revealed the following:
* The issue can be reproduced with minimal steps:
  * Create a widget with 1 row and 2 columns using `make_subplots`
  * Add a `go.Scatter` trace to the first and a `go.Image` trace to the second
  * Render the widget in the notebook by returning it as the output of a cell or via `display`
  * Create a `go.Figure` with the widget as argument
  * Render the figure in the notebook using `show()`
* Calling `make_subplots` again and passing the newly created figure as an argument to `make_subplots` seems to fix this
* `PlotGrid.show_static` was updated to reflect this fix

In [218]:
from plotly import graph_objects as go
from plotly.subplots import make_subplots

In [237]:
reload_imports()

In [238]:
batch_loss_fn = lambda preds,targs: F.mse_loss(preds,targs,reduction='none').mean(dim=1)
sps = [
    TrainingCurveSP(), 
    ImageSP(valid_set),
]
pg = PlotGrid(num_grid_cols=2, subplots=sps, fig_height=300)
pg.show()

FigureWidget({
    'data': [{'mode': 'lines+markers',
              'name': 'Training loss',
              'type': 'scatter',
              'uid': 'eb0e8c71-4630-459b-a530-e6fa4b4cbec6',
              'x': [],
              'xaxis': 'x',
              'y': [],
              'yaxis': 'y'},
             {'mode': 'lines+markers',
              'name': 'Validation loss',
              'type': 'scatter',
              'uid': '157fb3e3-c3ed-4d49-b728-400761e0cf6e',
              'x': [],
              'xaxis': 'x',
              'y': [],
              'yaxis': 'y'},
             {'hoverinfo': 'skip',
              'marker': {'color': 'rgba(0,0,0,0.2)', 'line': {'color': 'black', 'width': 2}},
              'mode': 'markers',
              'showlegend': False,
              'type': 'scatter',
              'uid': 'c39f7de8-c5ae-432a-8f00-856c1944dc45',
              'x': [],
              'xaxis': 'x',
              'y': [],
              'yaxis': 'y'},
             {'type': 'image',
        

In [239]:
fig = go.Figure(pg.widget)
fig.show(renderer='notebook_connected')

In [240]:
pg.show_static()

In [221]:
d = pg.widget.to_dict()
list(d.keys())

['data', 'layout']

In [222]:
list(d['layout'].keys())

['annotations',
 'template',
 'xaxis',
 'yaxis',
 'xaxis2',
 'yaxis2',
 'margin',
 'height',
 'autosize']

In [223]:
print(pg.widget.__dict__["_grid_str"])

This is the format of your plot grid:
[ (1,1) x,y   ]  [ (1,2) x2,y2 ]



In [224]:
d['layout']['margin']

{'l': 0, 'r': 0, 't': 20, 'b': 10}

In [225]:
pg.widget.layout['xaxis2']

layout.XAxis({
    'anchor': 'y2', 'title': {'text': ''}
})

In [226]:
list(pg.widget._get_subplot_coordinates())

[(1, 1), (1, 2)]

In [227]:
z = valid_set[0][0].permute((1,2,0)).tile((1,1,3))
display(z.shape)
display(z.min(), z.max())

torch.Size([28, 28, 3])

tensor(0.)

tensor(1.)

In [234]:
from plotly.subplots import make_subplots
import plotly.graph_objects as go
specs=[[{'rowspan': 1, 'colspan': 1}, {'rowspan': 1, 'colspan': 1}]]
sp_titles = ['Training curve', 'Input: sample 0<br>Target=2']
widget = go.FigureWidget(make_subplots(rows=1, cols=2, specs=specs, subplot_titles=sp_titles))
widget.update_layout(height=300, margin=dict(l=0, r=0, t=20, b=10))

train_loss_trace = go.Scatter(x=[], y=[], mode='lines+markers', name='Training loss')
valid_loss_trace = go.Scatter(x=[], y=[], mode='lines+markers', name='Validation loss')
marker_trace     = go.Scatter(x=[], y=[], mode='markers', showlegend=False, hoverinfo='skip',
                                marker=dict(color='rgba(0,0,0,0.2)', line=dict(color='black', width=2)))
widget.add_trace(train_loss_trace, row=1, col=1)
widget.add_trace(valid_loss_trace, row=1, col=1)
widget.add_trace(marker_trace, row=1, col=1)
widget.update_layout(**{'xaxis': {'title_text': 'Epoch'}, 'yaxis': {'title_text': 'Loss'}})

img_trace = go.Image(z=z.tolist(), zmin=[0.0,0.0,0.0,0], zmax=[1.0,1.0,1.0,1])
widget.add_trace(img_trace, row=1, col=2)
widget.update_layout(**{'xaxis2': {'title_text': ''}, 'yaxis2': {'title_text': ''}})

widget

FigureWidget({
    'data': [{'mode': 'lines+markers',
              'name': 'Training loss',
              'type': 'scatter',
              'uid': 'f0573cb9-2eb1-416d-b430-1acb89a1939e',
              'x': [],
              'xaxis': 'x',
              'y': [],
              'yaxis': 'y'},
             {'mode': 'lines+markers',
              'name': 'Validation loss',
              'type': 'scatter',
              'uid': '34457335-918c-419b-beb1-39e8b7dfd0d1',
              'x': [],
              'xaxis': 'x',
              'y': [],
              'yaxis': 'y'},
             {'hoverinfo': 'skip',
              'marker': {'color': 'rgba(0,0,0,0.2)', 'line': {'color': 'black', 'width': 2}},
              'mode': 'markers',
              'showlegend': False,
              'type': 'scatter',
              'uid': '7adb4d57-4228-4d5b-b396-06f755eaa8fb',
              'x': [],
              'xaxis': 'x',
              'y': [],
              'yaxis': 'y'},
             {'type': 'image',
        

In [235]:
from plotly.subplots import make_subplots
import plotly.graph_objects as go
widget = go.FigureWidget(make_subplots(rows=1, cols=2, specs=[[{}, {}]]))
widget.update_layout(height=300, margin=dict(l=0, r=0, t=20, b=10))
widget.add_trace(go.Scatter(x=[], y=[]), row=1, col=1)
widget.add_trace(go.Image(z=z.tolist(), zmin=[0,0,0,0], zmax=[1,1,1,1]), row=1, col=2)
widget

FigureWidget({
    'data': [{'type': 'scatter',
              'uid': 'a53a3fa3-70cb-444b-893e-0d17b04838cb',
              'x': [],
              'xaxis': 'x',
              'y': [],
              'yaxis': 'y'},
             {'type': 'image',
              'uid': '61385971-061e-4f79-b7ac-4f2fad572626',
              'xaxis': 'x2',
              'yaxis': 'y2',
              'z': [[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0,
                    0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0,
                    0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0],
                    [0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0,
                    0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0,
                    0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0],
                    [0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0,
                    0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], 

In [236]:
fig = go.Figure(widget)
fig = make_subplots(rows=1, cols=2, specs=[[{}, {}]], figure=fig)
fig.show(renderer='notebook_connected')