In [1]:
import spacy, random, math, time
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchtext.datasets import TranslationDataset, Multi30k, IWSLT
from torchtext.data import Field, BucketIterator, RawField, Dataset

%load_ext autoreload
%autoreload 2

#### Experiment with just GCN

In [2]:
class GCNLayer(nn.Module):
    def __init__(self, input_dim, output_dim, dropout = 0.2):
        """
        each layer has the following form of computation
        H = f(A * H * W)
        H: (b, seq len, ninp)
        A: (b, seq len, seq len)
        W: (ninp, nout)
        """
        super(GCNLayer, self).__init__()
        self.W = nn.Parameter(torch.randn(input_dim, output_dim))
        self.b = nn.Parameter(torch.randn(output_dim))
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, A):
        """
        H = relu(A * x * W)
        x: (b, seq len, ninp)
        A: (b, seq len, seq len)
        W: (ninp, nout)
        """
        x = self.dropout(x)
        x = torch.bmm(A, x)  # x: (b, seq len, ninp)
        x = x.matmul(self.W) + self.b
        x = self.relu(x)
        return x
    
def initialize_weights(m):
    if hasattr(m, 'weight') and m.weight.dim() > 1:
        nn.init.xavier_uniform_(m.weight.data)

#### Baby Test

From the test, we know the right way to batch matrix is:
1. align the batch dim
2. with in the same batch, pad smaller matrix on the right and bottom

In [11]:
layer1 = GCNLayer(3, 2, 0)
layer1.apply(initialize_weights)

GCNLayer(
  (relu): ReLU()
  (dropout): Dropout(p=0, inplace=False)
)

In [12]:
bs = 1
A = torch.tensor([[[1,0,0],
                   [1,1,1], 
                   [0,1,1]],
                  [[1,0,0],
                   [1,1,0], 
                   [0,0,0]]], dtype=torch.float)
x = torch.tensor([[[1,2,3],
                   [4,5,6], 
                   [7,8,9]],
                  [[100, 200, 300],
                   [200, 300, 400], 
                   [0, 0, 0]]], dtype=torch.float)

In [13]:
layer1(x, A)

tensor([[[ 1.3465,  1.7680],
         [ 0.0000,  6.5830],
         [ 0.0000,  7.0198]],

        [[55.2780,  0.0000],
         [75.6653,  0.0000],
         [ 0.8018,  2.2048]]], grad_fn=<ReluBackward0>)

In [14]:
x[0]

tensor([[1., 2., 3.],
        [4., 5., 6.],
        [7., 8., 9.]])

In [15]:
layer1.W

Parameter containing:
tensor([[-0.5969,  0.5413],
        [-0.3737,  1.2504],
        [ 0.6297, -1.1596]], requires_grad=True)

In [17]:
b = 1
A[b].matmul(x[b]).matmul(layer1.W)+layer1.b

tensor([[ 55.2780, -41.4739],
        [ 75.6653, -21.9457],
        [  0.8018,   2.2048]], grad_fn=<AddBackward0>)