In [1]:
%load_ext autoreload
%autoreload 2

import os
import torch

from torchviz import make_dot
from pt_tnn.temporal_graph import TemporalGraph

# Directory (relative path) that the rendered graph files in this notebook are written under.
graph_viz_dir = "graph_viz"

In [2]:
def generate_graph_viz(config_file, T=3, **dot_kwargs):
    """Render the autograd graph of a TemporalGraph unrolled for ``T`` steps to a PDF.

    The network is built from ``config_file``, run once on an all-ones input,
    and the graph of the mean of its output is drawn with ``torchviz.make_dot``.

    Args:
        config_file: Path to the config file handed to ``TemporalGraph``.
        T: Number of time steps; sets both the time dimension of the dummy
            input and the ``n_times`` argument of the forward pass.
        **dot_kwargs: Extra keyword arguments forwarded unchanged to
            ``make_dot`` (e.g. ``show_attrs=True``).

    Side effects:
        Prints the network, the output shape and every parameter name/shape,
        and writes ``<graph_viz_dir>/<config stem>_unroll_<T>.pdf``.
    """
    net = TemporalGraph(config_file)
    print(net)

    N = 1  # a single sample is enough to trace the graph
    C, H, W = net.input_shape
    inputs = torch.ones(N, T, C, H, W)

    outputs = net(inputs, n_times=T, cuda=False)
    print(outputs.shape)

    # Build the name -> parameter mapping once; it is both printed and passed to make_dot.
    params = dict(net.named_parameters())
    for n, p in params.items():
        print(n, p.shape)

    # make_dot is given a scalar (the mean) as the root of the graph.
    dot = make_dot(outputs.mean(), params=params, **dot_kwargs)
    dot.format = "pdf"

    # Derive the output name from the config's base name minus its final
    # extension. Unlike splitting on '/' and '.', this is separator-agnostic and
    # keeps dots inside the stem, so "a.v1.json" and "a.v2.json" do not collide.
    config_stem = os.path.splitext(os.path.basename(config_file))[0]
    fname_base = os.path.join(graph_viz_dir, f"{config_stem}_unroll_{T}")

    # cleanup=True makes graphviz delete the intermediate DOT source after
    # rendering, leaving only the PDF.
    dot.render(fname_base, cleanup=True)

In [3]:
# Render the "feedforward large skip" test config unrolled for 4 time steps.
fname = "../configs/test_feedforward_large_skip.json"
generate_graph_viz(fname, T=4)

TemporalGraph(
  (conv1): RecurrentModule(
    (_harbor_policy): ResizeConcat()
    (_pre_memory): BasicConv(
      (conv): Conv2d(3, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (dropout): Identity()
      (bn): Identity()
    )
    (_recurrent_cell): IdentityCell()
    (_post_memory): Identity()
  )
  (conv2): RecurrentModule(
    (_harbor_policy): ResizeConcat()
    (_pre_memory): BasicConv(
      (conv): Conv2d(2, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (dropout): Identity()
      (bn): Identity()
    )
    (_recurrent_cell): IdentityCell()
    (_post_memory): Identity()
  )
  (output): RecurrentModule(
    (_harbor_policy): ResizeConcat()
    (_pre_memory): BasicConv(
      (conv): Conv2d(4, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (dropout): Identity()
      (bn): Identity()
    )
    (_recurrent_cell): IdentityCell()
    (_post_memory): Identity()
  )
)
torch.Size([1, 1, 6, 6])
conv1._pre_memory.conv.weight torch.Size([2, 3, 

In [4]:
# Render the "identity feedback" test config with the default unroll length (T=3).
fname = "../configs/test_identity_feedback.json"
generate_graph_viz(fname)

TemporalGraph(
  (conv1): RecurrentModule(
    (_harbor_policy): ResizeConcat()
    (_pre_memory): BasicConv(
      (conv): Conv2d(4, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (dropout): Identity()
      (bn): Identity()
    )
    (_recurrent_cell): ConvRNNBasicCell(
      (conv_input): Conv2d(2, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv_state): Conv2d(2, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (activation): ReLU()
    )
    (_post_memory): Identity()
  )
  (output): RecurrentModule(
    (_harbor_policy): ResizeConcat()
    (_pre_memory): BasicConv(
      (conv): Conv2d(2, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (dropout): Identity()
      (bn): Identity()
    )
    (_recurrent_cell): IdentityCell()
    (_post_memory): Identity()
  )
)
torch.Size([1, 1, 6, 6])
conv1._pre_memory.conv.weight torch.Size([2, 4, 3, 3])
conv1._pre_memory.conv.bias torch.Size([2])
conv1._recurrent_cell.conv_input.weight torch.Si

In [5]:
# Same "identity feedback" config as the previous cell, unrolled for 4 time steps instead of 3.
fname = "../configs/test_identity_feedback.json"
generate_graph_viz(fname, T=4)

TemporalGraph(
  (conv1): RecurrentModule(
    (_harbor_policy): ResizeConcat()
    (_pre_memory): BasicConv(
      (conv): Conv2d(4, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (dropout): Identity()
      (bn): Identity()
    )
    (_recurrent_cell): ConvRNNBasicCell(
      (conv_input): Conv2d(2, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv_state): Conv2d(2, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (activation): ReLU()
    )
    (_post_memory): Identity()
  )
  (output): RecurrentModule(
    (_harbor_policy): ResizeConcat()
    (_pre_memory): BasicConv(
      (conv): Conv2d(2, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (dropout): Identity()
      (bn): Identity()
    )
    (_recurrent_cell): IdentityCell()
    (_post_memory): Identity()
  )
)
torch.Size([1, 1, 6, 6])
conv1._pre_memory.conv.weight torch.Size([2, 4, 3, 3])
conv1._pre_memory.conv.bias torch.Size([2])
conv1._recurrent_cell.conv_input.weight torch.Si

In [6]:
# Render the basic feed-forward test config with the default unroll length (T=3).
fname = "../configs/test_feedforward.json"
generate_graph_viz(fname)

TemporalGraph(
  (conv1): RecurrentModule(
    (_harbor_policy): ResizeConcat()
    (_pre_memory): BasicConv(
      (conv): Conv2d(3, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (dropout): Identity()
      (bn): Identity()
    )
    (_recurrent_cell): IdentityCell()
    (_post_memory): Identity()
  )
  (output): RecurrentModule(
    (_harbor_policy): ResizeConcat()
    (_pre_memory): BasicConv(
      (conv): Conv2d(2, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (dropout): Identity()
      (bn): Identity()
    )
    (_recurrent_cell): IdentityCell()
    (_post_memory): Identity()
  )
)
torch.Size([1, 1, 6, 6])
conv1._pre_memory.conv.weight torch.Size([2, 3, 3, 3])
conv1._pre_memory.conv.bias torch.Size([2])
output._pre_memory.conv.weight torch.Size([1, 2, 3, 3])
output._pre_memory.conv.bias torch.Size([1])


In [7]:
# Render the "feedforward large" test config with the default unroll length (T=3).
fname = "../configs/test_feedforward_large.json"
generate_graph_viz(fname)

TemporalGraph(
  (conv1): RecurrentModule(
    (_harbor_policy): ResizeConcat()
    (_pre_memory): BasicConv(
      (conv): Conv2d(3, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (dropout): Identity()
      (bn): Identity()
    )
    (_recurrent_cell): IdentityCell()
    (_post_memory): Identity()
  )
  (conv2): RecurrentModule(
    (_harbor_policy): ResizeConcat()
    (_pre_memory): BasicConv(
      (conv): Conv2d(2, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (dropout): Identity()
      (bn): Identity()
    )
    (_recurrent_cell): IdentityCell()
    (_post_memory): Identity()
  )
  (output): RecurrentModule(
    (_harbor_policy): ResizeConcat()
    (_pre_memory): BasicConv(
      (conv): Conv2d(2, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (dropout): Identity()
      (bn): Identity()
    )
    (_recurrent_cell): IdentityCell()
    (_post_memory): Identity()
  )
)
torch.Size([1, 1, 6, 6])
conv1._pre_memory.conv.weight torch.Size([2, 3, 

In [8]:
# Re-render the "feedforward large" config with a longer unroll (T=8) to compare against the T=3 graph.
fname = "../configs/test_feedforward_large.json"
generate_graph_viz(fname, T=8)  # should be identical to T=3

TemporalGraph(
  (conv1): RecurrentModule(
    (_harbor_policy): ResizeConcat()
    (_pre_memory): BasicConv(
      (conv): Conv2d(3, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (dropout): Identity()
      (bn): Identity()
    )
    (_recurrent_cell): IdentityCell()
    (_post_memory): Identity()
  )
  (conv2): RecurrentModule(
    (_harbor_policy): ResizeConcat()
    (_pre_memory): BasicConv(
      (conv): Conv2d(2, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (dropout): Identity()
      (bn): Identity()
    )
    (_recurrent_cell): IdentityCell()
    (_post_memory): Identity()
  )
  (output): RecurrentModule(
    (_harbor_policy): ResizeConcat()
    (_pre_memory): BasicConv(
      (conv): Conv2d(2, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (dropout): Identity()
      (bn): Identity()
    )
    (_recurrent_cell): IdentityCell()
    (_post_memory): Identity()
  )
)
torch.Size([1, 1, 6, 6])
conv1._pre_memory.conv.weight torch.Size([2, 3, 

In [9]:
# Render the feed-forward config that uses an identity harbor policy, default unroll length (T=3).
fname = "../configs/test_feedforward_identity_harbor.json"
generate_graph_viz(fname)

TemporalGraph(
  (conv1): RecurrentModule(
    (_harbor_policy): Identity()
    (_pre_memory): BasicConv(
      (conv): Conv2d(3, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (dropout): Identity()
      (bn): Identity()
    )
    (_recurrent_cell): IdentityCell()
    (_post_memory): Identity()
  )
  (output): RecurrentModule(
    (_harbor_policy): Identity()
    (_pre_memory): BasicConv(
      (conv): Conv2d(2, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (dropout): Identity()
      (bn): Identity()
    )
    (_recurrent_cell): IdentityCell()
    (_post_memory): Identity()
  )
)
torch.Size([1, 1, 6, 6])
conv1._pre_memory.conv.weight torch.Size([2, 3, 3, 3])
conv1._pre_memory.conv.bias torch.Size([2])
output._pre_memory.conv.weight torch.Size([1, 2, 3, 3])
output._pre_memory.conv.bias torch.Size([1])


In [10]:
# Render the "identity feedback, identity harbor" test config unrolled for 4 time steps.
fname = "../configs/test_identity_feedback_identity_harbor.json"
generate_graph_viz(fname, T=4)

TemporalGraph(
  (conv1): RecurrentModule(
    (_harbor_policy): Identity()
    (_pre_memory): BasicConv(
      (conv): Conv2d(4, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (dropout): Identity()
      (bn): Identity()
    )
    (_recurrent_cell): IdentityCell()
    (_post_memory): Identity()
  )
  (output): RecurrentModule(
    (_harbor_policy): Identity()
    (_pre_memory): BasicConv(
      (conv): Conv2d(2, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (dropout): Identity()
      (bn): Identity()
    )
    (_recurrent_cell): IdentityCell()
    (_post_memory): Identity()
  )
)
torch.Size([1, 1, 6, 6])
conv1._pre_memory.conv.weight torch.Size([2, 4, 3, 3])
conv1._pre_memory.conv.bias torch.Size([2])
output._pre_memory.conv.weight torch.Size([1, 2, 3, 3])
output._pre_memory.conv.bias torch.Size([1])


In [11]:
# Same "identity feedback, identity harbor" config as the previous cell, unrolled for 5 time steps.
fname = "../configs/test_identity_feedback_identity_harbor.json"
generate_graph_viz(fname, T=5)

TemporalGraph(
  (conv1): RecurrentModule(
    (_harbor_policy): Identity()
    (_pre_memory): BasicConv(
      (conv): Conv2d(4, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (dropout): Identity()
      (bn): Identity()
    )
    (_recurrent_cell): IdentityCell()
    (_post_memory): Identity()
  )
  (output): RecurrentModule(
    (_harbor_policy): Identity()
    (_pre_memory): BasicConv(
      (conv): Conv2d(2, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (dropout): Identity()
      (bn): Identity()
    )
    (_recurrent_cell): IdentityCell()
    (_post_memory): Identity()
  )
)
torch.Size([1, 1, 6, 6])
conv1._pre_memory.conv.weight torch.Size([2, 4, 3, 3])
conv1._pre_memory.conv.bias torch.Size([2])
output._pre_memory.conv.weight torch.Size([1, 2, 3, 3])
output._pre_memory.conv.bias torch.Size([1])


In [12]:
# Render the two-layer variant of the "identity feedback, identity harbor" config, unrolled for 5 time steps.
fname = "../configs/test_identity_feedback_identity_harbor_two_layer.json"
generate_graph_viz(fname, T=5)

TemporalGraph(
  (conv1): RecurrentModule(
    (_harbor_policy): Identity()
    (_pre_memory): BasicConv(
      (conv): Conv2d(5, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (dropout): Identity()
      (bn): Identity()
    )
    (_recurrent_cell): IdentityCell()
    (_post_memory): Identity()
  )
  (conv2): RecurrentModule(
    (_harbor_policy): Identity()
    (_pre_memory): BasicConv(
      (conv): Conv2d(3, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (dropout): Identity()
      (bn): Identity()
    )
    (_recurrent_cell): IdentityCell()
    (_post_memory): Identity()
  )
  (output): RecurrentModule(
    (_harbor_policy): Identity()
    (_pre_memory): BasicConv(
      (conv): Conv2d(2, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (dropout): Identity()
      (bn): Identity()
    )
    (_recurrent_cell): IdentityCell()
    (_post_memory): Identity()
  )
)
torch.Size([1, 1, 6, 6])
conv1._pre_memory.conv.weight torch.Size([2, 5, 3, 3])
conv1

In [13]:
# Render the AlexNet config unrolled for 8 time steps; show_attrs/show_saved are forwarded to make_dot.
fname = "../configs/alexnet.json"
generate_graph_viz(fname, T=8, show_attrs=True, show_saved=True)

TemporalGraph(
  (conv1): RecurrentModule(
    (_harbor_policy): ResizeConcat()
    (_pre_memory): BasicConvReLUPool(
      (conv): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
      (dropout): Identity()
      (bn): Identity()
      (relu): ReLU()
      (maxpool): MaxPool2d(kernel_size=[3, 3], stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (_recurrent_cell): IdentityCell()
    (_post_memory): Identity()
  )
  (conv2): RecurrentModule(
    (_harbor_policy): ResizeConcat()
    (_pre_memory): BasicConvReLUPool(
      (conv): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (dropout): Identity()
      (bn): Identity()
      (relu): ReLU()
      (maxpool): MaxPool2d(kernel_size=(3, 3), stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (_recurrent_cell): IdentityCell()
    (_post_memory): Identity()
  )
  (conv3): RecurrentModule(
    (_harbor_policy): ResizeConcat()
    (_pre_memory): BasicConvReLU(
      (conv): Conv2d(1

In [14]:
# Render the "alexnet_227" config unrolled for 8 time steps; show_attrs/show_saved are forwarded to make_dot.
fname = "../configs/alexnet_227.json"
generate_graph_viz(fname, T=8, show_attrs=True, show_saved=True)

TemporalGraph(
  (conv1): RecurrentModule(
    (_harbor_policy): ResizeConcat()
    (_pre_memory): BasicConvReLUPool(
      (conv): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
      (dropout): Identity()
      (bn): Identity()
      (relu): ReLU()
      (maxpool): MaxPool2d(kernel_size=[3, 3], stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (_recurrent_cell): IdentityCell()
    (_post_memory): Identity()
  )
  (conv2): RecurrentModule(
    (_harbor_policy): ResizeConcat()
    (_pre_memory): BasicConvReLUPool(
      (conv): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (dropout): Identity()
      (bn): Identity()
      (relu): ReLU()
      (maxpool): MaxPool2d(kernel_size=(3, 3), stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (_recurrent_cell): IdentityCell()
    (_post_memory): Identity()
  )
  (conv3): RecurrentModule(
    (_harbor_policy): ResizeConcat()
    (_pre_memory): BasicConvReLU(
      (conv): Conv2d(1

In [15]:
# Render the "alexnet_64" config unrolled for 7 time steps; show_attrs/show_saved are forwarded to make_dot.
fname = "../configs/alexnet_64.json"
generate_graph_viz(fname, T=7, show_attrs=True, show_saved=True)

TemporalGraph(
  (conv1): RecurrentModule(
    (_harbor_policy): ResizeConcat()
    (_pre_memory): BasicConvReLUPool(
      (conv): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
      (dropout): Identity()
      (bn): Identity()
      (relu): ReLU()
      (maxpool): MaxPool2d(kernel_size=[3, 3], stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (_recurrent_cell): IdentityCell()
    (_post_memory): Identity()
  )
  (conv2): RecurrentModule(
    (_harbor_policy): ResizeConcat()
    (_pre_memory): BasicConvReLUPool(
      (conv): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (dropout): Identity()
      (bn): Identity()
      (relu): ReLU()
      (maxpool): MaxPool2d(kernel_size=(3, 3), stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (_recurrent_cell): IdentityCell()
    (_post_memory): Identity()
  )
  (conv3): RecurrentModule(
    (_harbor_policy): ResizeConcat()
    (_pre_memory): BasicConvReLU(
      (conv): Conv2d(1

### Test Case: reference graph for torchvision's `AlexNet`

In [16]:
from torchvision.models import AlexNet

# Reference case: draw the autograd graph of torchvision's stock AlexNet on a
# single random 3x227x227 input.
model = AlexNet()

x = torch.randn(1, 3, 227, 227)
y = model(x)
dot = make_dot(y, params=dict(model.named_parameters()), show_attrs=True, show_saved=True)

dot.format = "pdf"
fname_base = os.path.join(graph_viz_dir, "generic_alexnet_227")
# cleanup=True makes graphviz delete the intermediate DOT source after
# rendering, leaving only the PDF.
dot.render(fname_base, cleanup=True)