In [1]:
import os
import math


import torch
import torchvision

import ipywidgets as widgets
import matplotlib.pyplot as plt

from determined.experimental import Checkpoint
from torchvision.utils import make_grid

# Shared tensor -> PIL.Image converter; generate_images uses it to turn the
# make_grid output into an image that can be saved and shown with imshow.
# NOTE(review): ToPILImage expects float tensors in [0, 1], hence the
# shift_image_range call before conversion.
to_pil = torchvision.transforms.ToPILImage()

In [2]:
def shift_image_range(images: torch.Tensor, range_in=(-1, 1), range_out=(0, 1)):
    """Linearly map tensor values from ``range_in`` to ``range_out``.

    With the defaults this maps [-1, 1] (presumably a tanh-activated generator
    output) to the [0, 1] range used for image conversion. The result is a new
    tensor detached from the autograd graph; the input tensor is not modified.

    Args:
        images: Tensor of any shape.
        range_in: ``(low, high)`` of the source range; the bounds must differ.
        range_out: ``(low, high)`` of the target range. May be inverted
            (``low > high``) to flip intensities.

    Returns:
        A new tensor where ``range_in[0]`` maps to ``range_out[0]`` and
        ``range_in[1]`` maps to ``range_out[1]``, clamped to the target range.

    Raises:
        ValueError: If ``range_in`` has zero width.
    """
    in_low, in_high = range_in
    out_low, out_high = range_out
    if in_high == in_low:
        raise ValueError(f'range_in must have non-zero width, got {range_in!r}')

    scale = (out_high - out_low) / (in_high - in_low)
    bias = out_low - in_low * scale

    # detach() drops autograd history; the arithmetic allocates a fresh tensor,
    # so the caller's tensor is never touched and no explicit clone is needed.
    result = images.detach() * scale + bias

    # Clamp with ordered bounds: torch.clamp with min > max would set every
    # element to `max`, which breaks inverted output ranges.
    result.clamp_(min=min(out_low, out_high), max=max(out_low, out_high))
    return result

In [3]:
def generate_images(checkpoint, z, out_path='../outputs'):
    """Sample images from a checkpoint's generator, save them as a grid and plot it.

    Args:
        checkpoint: Path to a Determined checkpoint directory. Its last path
            component names the saved PNG and the plot.
        z: Latent batch passed to the generator; ``len(z)`` is the number of
            images. (Built as ``(N, 128, 1, 1)`` in this notebook — other
            generators may expect a different shape.) The grid shows
            ``floor(sqrt(len(z)))`` images per row.
        out_path: Directory the grid PNG is written to; created if missing.
    """
    # Derive the layout from the batch itself rather than the notebook-level
    # `num_images` global, so the function works for any latent batch size.
    images_per_row = max(1, int(math.sqrt(len(z))))
    # normpath strips a trailing separator, so '.../run_01/' still yields 'run_01'.
    name = os.path.basename(os.path.normpath(checkpoint))

    print(name)
    print()

    trial = Checkpoint.load_from_path(checkpoint, map_location=torch.device('cpu'))
    generator = trial.generator.eval()

    # Inference only: skip building an autograd graph while sampling.
    with torch.no_grad():
        images = shift_image_range(generator(z))
    grid = to_pil(make_grid(images, nrow=images_per_row))

    os.makedirs(out_path, exist_ok=True)
    grid.save(os.path.join(out_path, f'{name}.png'))

    fig, ax = plt.subplots(dpi=200, figsize=(6, 6))
    ax.imshow(grid)
    ax.set_title(name)
    ax.axis('off')
    plt.show()

In [4]:
# Fix the seed so the same latent batch (and therefore the same images) is
# produced on every Restart & Run All, making checkpoints comparable.
SEED = 0
torch.manual_seed(SEED)

# Latent size fed to the generator; assumed to match the trained model's
# input channels — confirm against the training config if it changes.
LATENT_DIM = 128
num_images = 25
z = torch.randn(num_images, LATENT_DIM, 1, 1)

checkpoints_path = '../checkpoints'

# Every sub-directory of `checkpoints_path` is offered as a checkpoint, sorted by path.
with os.scandir(checkpoints_path) as entries:
    checkpoints_paths = sorted(entry.path for entry in entries if entry.is_dir())

# '---' is a placeholder entry meaning "nothing selected yet".
dropdown = widgets.Dropdown(
    options=['---', *checkpoints_paths],
    value='---',
    description='Checkpoints:',
    disabled=False,
    layout={'width': '500px'}
)

output = widgets.Output()


def on_dropdown_change(event):
    """Render the grid for the newly selected checkpoint into `output`.

    Intended as a traitlets observer on `dropdown`. Events other than a
    change of the `value` trait are ignored. Selecting the '---' placeholder
    prints a hint instead of loading a checkpoint.
    """
    if event['type'] != 'change' or event['name'] != 'value':
        return

    with output:
        # Replace the previous checkpoint's grid instead of appending below it.
        output.clear_output()
        if event['new'] == '---':
            print('Please select a valid checkpoint!')
        else:
            generate_images(event['new'], z)

# The handler only acts on `value` changes, so subscribe to exactly those.
dropdown.observe(on_dropdown_change, names='value')

# Show the selector with the output area directly beneath it.
display(dropdown, output)

Dropdown(description='Checkpoints:', layout=Layout(width='500px'), options=('---', '../checkpoints/01_batch_no…

Output()