In [8]:
# Re-import edited project modules (e.g. `model.layers` from ../src/) before each
# cell runs, so changes are picked up without restarting the kernel.
# `%reload_ext` (rather than `%load_ext`) also works if the extension is already loaded.
%reload_ext autoreload
%autoreload 2

In [9]:
import sys

# Make the project package under ../src/ (e.g. `model.layers`) importable.
# The path is resolved relative to the kernel's current working directory.
# Guarded so that re-running this cell does not keep prepending duplicate
# entries to sys.path.
SRC_DIR = '../src/'
if SRC_DIR not in sys.path:
    sys.path.insert(0, SRC_DIR)

In [10]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import model.layers as layers

In [11]:
# Seed torch's global RNG so the random word vectors, token ids, inputs and the
# Linear-layer parameter initialisation below are reproducible on Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Small sizes keep the printed tensors readable.
BATCH_SIZE = 5
HIDDEN_SIZE = 3
DROP_PROB = 0.0  # drop probability of 0.0 so inspected outputs are not randomly masked
SEQ_LEN = 7

## `layers.Embedding`

In [12]:
# Random stand-in for a word-vector table: one embed_size-wide row per vocabulary id.
vocab_size, embed_size = 11, 6
wordvecs = torch.randn(vocab_size, embed_size)
wordvecs

tensor([[ 1.1414,  1.4722,  0.2721, -0.9674, -0.7047, -1.1054],
        [-0.8321,  1.0332,  1.0490, -0.0857,  0.7630,  1.4631],
        [-2.6773,  0.5008,  0.8457,  0.2306, -0.1543, -0.7707],
        [-0.3787,  0.2165, -0.3395, -1.4682,  0.7444,  1.1926],
        [-0.6989,  1.9091,  0.1079,  0.6685, -0.1941,  0.0660],
        [ 0.9258, -0.5765,  0.7634,  0.3824,  1.1020, -0.0477],
        [-0.7579, -1.5932,  2.9906, -0.1598, -1.1834,  0.2845],
        [ 0.1882, -0.8804, -1.1744,  0.7914, -0.4334, -0.2785],
        [-1.1068, -0.4291, -0.1081,  0.4690,  0.4352, -0.5876],
        [-1.4428, -0.5450, -0.5684,  0.6086, -0.2083,  1.4488],
        [ 0.1462,  1.6937,  1.0190,  1.5619, -0.2005, -0.2793]])

In [13]:
# Embedding layer under test, built from the random word-vector table with the
# hidden size and drop probability from the config cell.
embed = layers.Embedding(
    wordvecs,
    HIDDEN_SIZE,
    DROP_PROB,
)
embed

Embedding(
  (embed): Embedding(11, 6)
  (proj): Linear(in_features=6, out_features=3, bias=False)
  (hwy): HighwayEncoder(
    (transforms): ModuleList(
      (0): Linear(in_features=3, out_features=3, bias=True)
      (1): Linear(in_features=3, out_features=3, bias=True)
    )
    (gates): ModuleList(
      (0): Linear(in_features=3, out_features=3, bias=True)
      (1): Linear(in_features=3, out_features=3, bias=True)
    )
  )
)

In [14]:
# `seq_len` aliases the config constant instead of repeating the literal 7, so the
# two cannot drift apart; SEQ_LEN is what the shape checks below use.
seq_len = SEQ_LEN

# Random token ids in [0, vocab_size) with shape (BATCH_SIZE, SEQ_LEN).
inpt = torch.randint(0, vocab_size, size=(BATCH_SIZE, SEQ_LEN))
inpt

tensor([[ 7,  9,  1,  7,  5, 10,  6],
        [ 9, 10,  3,  2,  5,  8,  3],
        [ 3,  3,  5,  1,  4,  3, 10],
        [10,  8, 10, 10,  0,  6, 10],
        [ 8,  6,  0,  3,  6,  8,  1]])

In [15]:
# Step through Embedding's sub-modules by hand (embed -> proj -> hwy) to inspect
# the shape after each stage; no autograd graph is needed for this.
# NOTE(review): calling the sub-modules directly skips whatever else
# Embedding.forward does between them (e.g. any dropout) — confirm if that matters.
with torch.no_grad():
    embed_output = embed.embed(inpt)
    proj_output = embed.proj(embed_output)
    hwy_output = embed.hwy(proj_output)
tuple(stage.shape for stage in (embed_output, proj_output, hwy_output))

(torch.Size([5, 7, 6]), torch.Size([5, 7, 3]), torch.Size([5, 7, 3]))

In [16]:
# (batch, seq) is preserved through every stage; only the feature width changes:
# embed_size after the lookup, HIDDEN_SIZE after the projection and highway.
assert tuple(embed_output.shape) == (BATCH_SIZE, SEQ_LEN, embed_size)
assert tuple(proj_output.shape) == (BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE)
assert tuple(hwy_output.shape) == (BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE)

In [17]:
# Full Embedding forward pass. torch.no_grad() only skips autograd bookkeeping;
# it does not switch dropout off. Put the module in eval mode first so this
# inspection stays deterministic even if DROP_PROB is changed from 0.0.
# NOTE(review): assumes layers.Embedding gates its dropout on `self.training`
# (the usual nn.Module convention) — confirm in model/layers.py.
embed.eval()
with torch.no_grad():
    output = embed(inpt)
output

tensor([[[ 0.4233,  0.0963, -0.2024],
         [-0.0488,  0.2129, -0.1705],
         [-0.3025,  0.4297,  0.4915],
         [ 0.4233,  0.0963, -0.2024],
         [-0.0042,  0.1524,  0.3704],
         [ 0.2882,  0.9089,  0.7716],
         [-1.3659, -0.3301, -0.6784]],

        [[-0.0488,  0.2129, -0.1705],
         [ 0.2882,  0.9089,  0.7716],
         [-0.2322, -0.2023, -0.0227],
         [-0.3450, -0.4699, -0.5621],
         [-0.0042,  0.1524,  0.3704],
         [ 0.0638, -0.2787, -0.2946],
         [-0.2322, -0.2023, -0.0227]],

        [[-0.2322, -0.2023, -0.0227],
         [-0.2322, -0.2023, -0.0227],
         [-0.0042,  0.1524,  0.3704],
         [-0.3025,  0.4297,  0.4915],
         [ 0.2775,  0.5744,  0.4712],
         [-0.2322, -0.2023, -0.0227],
         [ 0.2882,  0.9089,  0.7716]],

        [[ 0.2882,  0.9089,  0.7716],
         [ 0.0638, -0.2787, -0.2946],
         [ 0.2882,  0.9089,  0.7716],
         [ 0.2882,  0.9089,  0.7716],
         [ 0.1193,  0.0700,  0.1444],
      

In [18]:
# Token ids of shape (BATCH_SIZE, SEQ_LEN) come out as (BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE).
assert tuple(output.shape) == (BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE)

## `layers.HighwayEncoder`

In [19]:
# Stand-alone HighwayEncoder with the same depth (two transform/gate pairs) and
# width as the `hwy` sub-module shown inside Embedding above.
num_layers = 2
highway = layers.HighwayEncoder(
    num_layers,
    HIDDEN_SIZE,
)
highway

HighwayEncoder(
  (transforms): ModuleList(
    (0): Linear(in_features=3, out_features=3, bias=True)
    (1): Linear(in_features=3, out_features=3, bias=True)
  )
  (gates): ModuleList(
    (0): Linear(in_features=3, out_features=3, bias=True)
    (1): Linear(in_features=3, out_features=3, bias=True)
  )
)

In [20]:
# Random float features already at the hidden width, shape (BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE).
inpt = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE)
inpt

tensor([[[-2.7623,  0.2548, -0.8532],
         [-0.6439,  1.2678,  0.1147],
         [-1.0045, -0.3908, -0.6559],
         [ 0.4075, -1.7423, -1.5121],
         [-0.4982,  0.7425, -0.0890],
         [-0.4536,  0.8275, -0.2277],
         [ 1.0304,  0.0636, -2.3095]],

        [[-0.4858,  0.1347, -1.6507],
         [ 0.1572,  0.3561,  0.3294],
         [-0.1289, -0.8367, -1.1735],
         [ 0.1685, -1.1265, -0.5369],
         [ 0.4764,  0.0111,  1.6539],
         [ 0.8682, -0.5777,  0.9782],
         [ 0.0155, -0.3161, -0.4211]],

        [[ 0.6981, -1.1824,  0.8674],
         [-0.0533,  0.2441, -0.0291],
         [ 1.2880, -0.1411, -1.1692],
         [-1.5560, -0.1891,  0.5358],
         [-1.9611, -0.1812, -0.5768],
         [ 0.4714, -0.1694,  0.5322],
         [ 0.8916, -0.5556, -0.0128]],

        [[-0.6312, -0.0763, -0.5359],
         [ 0.5823, -0.9829,  0.9143],
         [ 0.4783,  0.0525, -0.7761],
         [ 0.0366,  1.2533,  0.6225],
         [-0.5989, -0.4228, -0.8854],
      

In [21]:
# Forward pass without recording an autograd graph; set_grad_enabled(False) used
# as a context manager is equivalent to torch.no_grad().
with torch.set_grad_enabled(False):
    output = highway(inpt)
output

tensor([[[-2.1061,  0.3132, -0.4994],
         [-0.2118,  0.8178,  0.1355],
         [-0.8197, -0.2709, -0.4870],
         [ 0.4071, -1.5224, -1.1336],
         [-0.2439,  0.5121, -0.0488],
         [-0.2092,  0.5703, -0.1436],
         [ 0.8281,  0.0540, -1.3558]],

        [[-0.3821,  0.0963, -1.0313],
         [ 0.2335,  0.2770,  0.2613],
         [-0.1056, -0.6733, -0.8632],
         [ 0.1647, -0.9303, -0.4257],
         [ 0.5562,  0.0577,  1.4045],
         [ 0.8008, -0.4622,  0.8325],
         [ 0.0252, -0.2509, -0.3225]],

        [[ 0.6569, -0.9875,  0.7496],
         [ 0.0318,  0.1861, -0.0222],
         [ 1.0667, -0.1238, -0.8322],
         [-1.1935, -0.0457,  0.4204],
         [-1.5594, -0.0109, -0.3924],
         [ 0.4552, -0.1171,  0.4337],
         [ 0.7849, -0.4652, -0.0103]],

        [[-0.4947, -0.0548, -0.3970],
         [ 0.5528, -0.8002,  0.7849],
         [ 0.4054,  0.0429, -0.5665],
         [ 0.3169,  0.9120,  0.5386],
         [-0.4926, -0.3127, -0.6496],
      

In [22]:
# The highway stack leaves the (BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE) input shape unchanged.
assert tuple(output.shape) == (BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE)