In [2]:
import spacy, random, math, time
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchtext.datasets import TranslationDataset, Multi30k, IWSLT
from torchtext.data import Field, BucketIterator, RawField, Dataset

%load_ext autoreload
%autoreload 2

#### Experiment with just GCN

In [14]:
class GCNLayer(nn.Module):
    def __init__(self, input_dim, output_dim, dropout = 0.2):
        """
        each layer has the following form of computation
        H = f(A * H * W)
        H: (b, seq len, ninp)
        A: (b, seq len, seq len)
        W: (ninp, nout)
        """
        super(GCNLayer, self).__init__()
        self.W = nn.Parameter(torch.randn(input_dim, output_dim))
        self.b = nn.Parameter(torch.randn(output_dim))
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, A):
        """
        H = relu(A * x * W)
        x: (b, seq len, ninp)
        A: (b, seq len, seq len)
        W: (ninp, nout)
        """
        x = self.dropout(x)
        x = torch.bmm(A, x)  # x: (b, seq len, ninp)
        x = x.matmul(self.W) + self.b
        x = self.relu(x)
        return x
    
def initialize_weights(m):
    if hasattr(m, 'weight') and m.weight.dim() > 1:
        nn.init.xavier_uniform_(m.weight.data)

#### Baby Test

From the test, we know the right way to batch matrix is:
1. align the batch dim
2. with in the same batch, pad smaller matrix on the right and bottom

In [15]:
layer1 = GCNLayer(4, 2, 0)
layer1.apply(initialize_weights)

GCNLayer(
  (relu): ReLU()
  (dropout): Dropout(p=0, inplace=False)
)

In [16]:
bs = 1
A = torch.tensor([[[1,0,0],
                   [1,1,1], 
                   [0,1,1]],
                  [[1,0,0],
                   [1,1,0], 
                   [0,0,0]]], dtype=torch.float)
x = torch.tensor([[[1,2,3,4],
                   [4,5,6,7], 
                   [7,8,9,8]],
                  [[100, 200, 300, 400],
                   [200, 300, 400, 500], 
                   [0, 0, 0, 0]]], dtype=torch.float)

In [17]:
layer1(x, A)

tensor([[[0.0000e+00, 9.1991e+00],
         [0.0000e+00, 4.6133e+01],
         [0.0000e+00, 3.6460e+01]],

        [[0.0000e+00, 9.6678e+02],
         [0.0000e+00, 2.1880e+03],
         [1.3163e+00, 0.0000e+00]]], grad_fn=<ReluBackward0>)

In [18]:
x[0]

tensor([[1., 2., 3., 4.],
        [4., 5., 6., 7.],
        [7., 8., 9., 8.]])

In [19]:
layer1.W

Parameter containing:
tensor([[ 0.1265,  0.0858],
        [ 0.5089,  0.4087],
        [ 0.9336, -0.5888],
        [-3.0925,  2.6340]], requires_grad=True)

In [20]:
b = 1
A[b].matmul(x[b]).matmul(layer1.W)+layer1.b

tensor([[-8.4119e+02,  9.6678e+02],
        [-1.8360e+03,  2.1880e+03],
        [ 1.3163e+00, -4.7340e-01]], grad_fn=<AddBackward0>)