## Load dependencies

In [1]:
import sys
from collections import defaultdict
# The ABC aliases (e.g. `collections.Iterable`) were removed from `collections`
# in Python 3.10; they live in `collections.abc` (available since Python 3.3).
from collections.abc import Iterable

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.nn import Parameter
from torch.nn import MSELoss, L1Loss, SmoothL1Loss, CrossEntropyLoss

from factslab.utility import load_glove_embedding
from factslab.datastructures import ConstituencyTree, DependencyTree
from factslab.pytorch.temporal_events_attention import Attention_mlp


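## Toy data

Four tokenised sentences. Each sentence has two spans, given as lists of token indices into that sentence, and one root token index per span.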

In [2]:
# Four tokenised toy sentences (lists of string tokens).
# NOTE(review): "belived" in the third sentence is misspelled; kept verbatim.
X = [
    ["The", "boy", "ran", "into", "the", "garden", "and", "started", "playing", "."],
    ["Susan", "dies", ".", "So", "does", "Jack"],
    ["He", "belived", "in", "humanity", "but", "didn't", "believe", "in", "God", "."],
    ["The", "man", "stole", "his", "umbrella", "and", "started", "running", "."],
]

# Per sentence: two spans, each a list of token indices into that sentence.
spans = [
    [[1, 2], [5, 6, 7]],
    [[1, 2, 3], [4]],
    [[1], [6]],
    [[1, 2, 3], [5, 6, 7]],
]

# Per sentence: one token index per span; in this toy data every index lies
# inside the corresponding span above.
roots = [
    [2, 7],
    [1, 4],
    [1, 6],
    [2, 6],
]


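## Define models

Six `Attention_mlp` variants cover every combination of two settings:
- predicate attention: `None`, `const-span-attention`, `param-span-attention`
- relation type: `concat`, `param-sent-attention`

All variants share the same settings: 1024-dimensional embeddings, regression hidden sizes `[24, 16]`, a single output, CPU device, and batch size 2.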

In [3]:
# Settings shared by every model variant (previously repeated six times inline).
SEED = 0
EMBEDDING_SIZE = 1024
REGRESSION_HIDDEN_SIZES = [24, 16]
OUTPUT_SIZE = 1
BATCH_SIZE = 2
DEVICE = torch.device(type="cpu")

# The grid of configurations. Relation type is the outer loop and predicate
# attention the inner loop, which reproduces the original model1..model6 order:
#   model1..3 -> relation "concat",               pred attention None/const/param
#   model4..6 -> relation "param-sent-attention", pred attention None/const/param
PRED_ATTENTION_TYPES = [None, "const-span-attention", "param-span-attention"]
RELATION_TYPES = ["concat", "param-sent-attention"]

# Seed before construction so parameter initialisation (and therefore the toy
# outputs printed later) is reproducible after a kernel restart.
torch.manual_seed(SEED)

models = [
    Attention_mlp(embedding_size=EMBEDDING_SIZE,
                  pred_attention_type=pred_attention_type,
                  relation_type=relation_type,
                  # fresh list per model, as in the original hand-written calls
                  regression_hidden_sizes=list(REGRESSION_HIDDEN_SIZES),
                  output_size=OUTPUT_SIZE,
                  device=DEVICE,
                  batch_size=BATCH_SIZE)
    for relation_type in RELATION_TYPES
    for pred_attention_type in PRED_ATTENTION_TYPES
]

# Keep the individual names so any cell referring to model1..model6 still works.
model1, model2, model3, model4, model5, model6 = models

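## Toy outputs

Each model is run on the first two toy sentences as a single batch, which matches `batch_size=2`. The models are untrained, so these numbers only confirm that each configuration runs end to end and returns one score per sentence.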

In [4]:
# Score the first two toy sentences (one batch; the models were built with
# batch_size=2) with every model and print the resulting tensor.
first_batch = slice(0, 2)
for model in models:
    scores = model(X[first_batch], spans[first_batch], roots[first_batch])
    print(scores)

tensor([-0.1136, -0.2035], grad_fn=<SqueezeBackward0>)
tensor([-0.1834, -0.0989], grad_fn=<SqueezeBackward0>)
tensor([0.1353, 0.1841], grad_fn=<SqueezeBackward0>)
tensor([0.0918, 0.0200], grad_fn=<SqueezeBackward0>)
tensor([0.2911, 0.2882], grad_fn=<SqueezeBackward0>)
tensor([-0.1958, -0.2061], grad_fn=<SqueezeBackward0>)


## Model parameters

The name and shape of every trainable parameter (`requires_grad=True`) in each model.

In [5]:
# For each model, print a numbered banner followed by the name and shape of
# every parameter that requires gradients, then a closing rule.
for i, model in enumerate(models):
    banner = "########## .   Model {} Parameters   ##############\n".format(i + 1)
    print(banner)
    trainable = [(pname, p) for pname, p in model.named_parameters() if p.requires_grad]
    for pname, p in trainable:
        print(pname, p.shape)
    print("##############################################\n")


########## .   Model 1 Parameters   ##############

linear_maps.0.weight torch.Size([24, 2048])
linear_maps.0.bias torch.Size([24])
linear_maps.1.weight torch.Size([16, 24])
linear_maps.1.bias torch.Size([16])
linear_maps.2.weight torch.Size([1, 16])
linear_maps.2.bias torch.Size([1])
##############################################

########## .   Model 2 Parameters   ##############

linear_maps.0.weight torch.Size([24, 2048])
linear_maps.0.bias torch.Size([24])
linear_maps.1.weight torch.Size([16, 24])
linear_maps.1.bias torch.Size([16])
linear_maps.2.weight torch.Size([1, 16])
linear_maps.2.bias torch.Size([1])
att_map.weight torch.Size([1, 1024])
##############################################

########## .   Model 3 Parameters   ##############

linear_maps.0.weight torch.Size([24, 2048])
linear_maps.0.bias torch.Size([24])
linear_maps.1.weight torch.Size([16, 24])
linear_maps.1.bias torch.Size([16])
linear_maps.2.weight torch.Size([1, 16])
linear_maps.2.bias torch.Size([1])
att_map.w