In [1]:
import os
import torch
from IPython.display import display
import ipywidgets as widgets
from torchsummary import summary

from csfnst.fastneuralstyle.networks import TransformerNet, BottleneckType

In [2]:
# Per-sample input shape passed to torchsummary (no batch dimension);
# presumably (channels, height, width).
input_size = (3, 256, 256)
# Set to False to run on the GPU whenever CUDA is available.
force_cpu = True

if not force_cpu and torch.cuda.is_available():
    device_type = 'cuda'
else:
    device_type = 'cpu'
device = torch.device(device_type)


checkpoints_dir = '../checkpoints/'
# Offer only regular, non-hidden files as checkpoint choices. Hidden entries
# (such as the '.gitkeep' placeholder or OS metadata like '.DS_Store') and
# sub-directories are not loadable checkpoints and would fail in torch.load.
models = sorted(
    name
    for name in os.listdir(checkpoints_dir)
    if not name.startswith('.')
    and os.path.isfile(os.path.join(checkpoints_dir, name))
)

# UI controls: choose a checkpoint in the dropdown, then press the button to
# print its summary into the `out` output area below the controls.
out = widgets.Output()
dropdown_models = widgets.Dropdown(description='Model:', options=models)
button_summarize = widgets.Button(button_style='Info', description='Summarize')
hbox = widgets.HBox(children=[dropdown_models, button_summarize])


def perform_summarization(button):
    """Print a torchsummary report for the checkpoint chosen in the dropdown.

    Click handler registered on ``button_summarize``. Rebuilds a
    ``TransformerNet`` from the hyper-parameters stored in the selected
    checkpoint, loads its weights onto ``device``, and prints the summary for
    an input of shape ``input_size`` into the ``out`` widget.

    Args:
        button: The clicked ``widgets.Button``. Unused; required by the
            ``on_click`` callback signature.
    """
    out.clear_output()
    # All work runs inside the Output context. On some front ends (e.g.
    # JupyterLab), exceptions raised by widget callbacks are not displayed
    # anywhere, which would leave "Calculating..." on screen after a failed
    # load. An exception raised inside ``with out:`` is rendered in the widget.
    with out:
        if dropdown_models.value is None:
            # The dropdown's value is None when it has no options, i.e. no
            # checkpoints were found.
            print('No checkpoint selected.')
            return
        print('Calculating...')

        checkpoint_path = os.path.join(checkpoints_dir, dropdown_models.value)
        # map_location=device remaps tensors saved on any GPU (not only
        # cuda:0) onto the configured device, so CPU runs can open them.
        # NOTE: torch.load unpickles the file -- only open trusted checkpoints.
        checkpoint = torch.load(checkpoint_path, map_location=device)

        # The enum may be stored as str(member), e.g. 'BottleneckType.NAME';
        # strip the class prefix so the member can be looked up by name.
        bottleneck_name = checkpoint['bottleneck_type'].replace('BottleneckType.', '')
        model = TransformerNet(
            channel_multiplier=checkpoint['channel_multiplier'],
            expansion_factor=checkpoint['expansion_factor'],
            bottleneck_type=BottleneckType[bottleneck_name],
            bottleneck_size=checkpoint['bottleneck_size'],
            intermediate_activation_fn=checkpoint['intermediate_activation_fn'],
            final_activation_fn=checkpoint['final_activation_fn'],
        ).to(device)
        model.load_state_dict(checkpoint['model_state_dict'])
        model.eval()

        # Replace the progress message with the summary table.
        out.clear_output()
        summary(model, device=device_type, input_size=input_size)

# Hook the handler to the button, then render the controls followed by the
# output area the handler writes into.
button_summarize.on_click(perform_summarization)
display(hbox)
display(out)

HBox(children=(Dropdown(description='Model:', options=('experiment1__net01__s5__m32__residual_block.pth', 'exp…

Output()