In [1]:
# Get training data.
!wget https://raw.githubusercontent.com/lorenlugosch/infer_missing_vowels/master/data/train/war_and_peace.txt

--2023-04-04 12:23:56--  https://raw.githubusercontent.com/lorenlugosch/infer_missing_vowels/master/data/train/war_and_peace.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.110.133, 185.199.111.133, 185.199.109.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.110.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 3196229 (3.0M) [text/plain]
Saving to: ‘war_and_peace.txt’


2023-04-04 12:23:57 (5.66 MB/s) - ‘war_and_peace.txt’ saved [3196229/3196229]



In [15]:
# Imports
import torch
from tqdm import tqdm
from torch import nn
import torch.nn.functional as F
import math
import IPython

In [16]:
# Encoder network
# The encoder is any network that can take as input a variable-length sequence: so, RNNs, CNNs, and self-attention/Transformer encoders will all work.

class Encoder(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, num_heads, dropout):
        super(Encoder, self).__init__()
        self.pos_encoder = PositionalEncoding(input_size)
        encoder_layer = nn.TransformerEncoderLayer(d_model=input_size, nhead=num_heads, dim_feedforward=hidden_size, dropout=dropout)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers, norm=nn.LayerNorm(input_size))
        self.dropout = nn.Dropout(dropout)

    def forward(self, src, mask):
        IPython.embed()
        src = self.pos_encoder(src)
        output = self.transformer_encoder(src, src_key_padding_mask=mask)
        output = self.dropout(output)
        return output

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)
    
# Testing output
input = torch.randn(10, 32, 512)
encoder = Encoder(512, 1024, 6, 8, 0.1)
enc_output = encoder(input, None)


Python 3.8.16 (default, Dec  7 2022, 01:27:54) 
Type 'copyright', 'credits' or 'license' for more information
IPython 8.6.0 -- An enhanced Interactive Python. Type '?' for help.

Out[1]: 
tensor([[[-2.8397e-01, -9.0466e-01, -1.3844e-01,  ...,  1.2482e-01,
          -1.6752e+00, -9.1686e-01],
         [ 2.2266e-01,  3.1771e-01, -4.3722e-01,  ..., -1.5623e-01,
          -7.4187e-01,  8.8158e-02],
         [ 1.7360e-01, -4.5565e-01, -1.2003e-01,  ...,  2.5013e-01,
           7.7036e-01,  3.0499e-01],
         ...,
         [ 5.8290e-01, -1.1187e-01,  9.9949e-01,  ...,  4.4725e-01,
          -9.8940e-01,  1.5680e-01],
         [-6.3546e-01, -2.5261e-01,  1.5095e+00,  ..., -1.8255e+00,
           8.3142e-01, -2.1155e-01],
         [-4.5400e-01, -8.5388e-01, -2.2525e-01,  ..., -2.6675e-01,
          -3.1050e-01,  1.5431e+00]],

        [[-1.3305e+00, -5.4794e-01,  9.7787e-02,  ...,  1.3837e+00,
          -1.2150e+00,  1.0774e+00],
         [-1.1048e+00, -5.8884e-01, -9.2180e-01,  ..., -3.780

KeyboardInterrupt: Interrupted by user

In [14]:
# Predictor network
# The predictor is any causal network (= can't look at the future): in other words, unidirectional RNNs, causal convolutions, or masked self-attention.

class MaskedSelfAttention(nn.Module):
    def __init__(self, input_size, hidden_size, num_heads, dropout):
        super(MaskedSelfAttention, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.attention = nn.MultiheadAttention(input_size, num_heads, dropout)
        self.layer_norm = nn.LayerNorm(input_size)

    def forward(self, x, mask):
        # x: (seq_len, batch_size, input_size)
        x = x.transpose(0, 1) # (batch_size, seq_len, input_size)
        self.attention_output, _ = self.attention(x, x, x, attn_mask=mask)
        # attention_output: (batch_size, seq_len, input_size)
        x = x + self.attention_output
        x = self.layer_norm(x)
        x = x.transpose(0, 1) # (seq_len, batch_size, input_size)
        return x

class PredictorNetwork(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, num_heads, dropout):
        super(PredictorNetwork, self).__init__()
        self.masked_self_attention_layers = nn.ModuleList([
            MaskedSelfAttention(input_size, hidden_size, num_heads, dropout) for _ in range(num_layers)
        ])
        self.feedforward_layers = nn.ModuleList([
            nn.Sequential(
                nn.Linear(input_size, hidden_size),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(hidden_size, input_size),
                nn.Dropout(dropout),
            ) for _ in range(num_layers)
        ])

    def forward(self, x, mask):
        for attention_layer, feedforward_layer in zip(self.masked_self_attention_layers, self.feedforward_layers):
            x = attention_layer(x, mask)
            x = feedforward_layer(x)
        return x
    
# Testing output with a mask
input = torch.randn(10, 32, 512)
predictor = PredictorNetwork(512, 1024, 6, 8, 0.1)

mask = torch.zeros(32, 32)
mask = torch.triu(mask, diagonal=1)
mask = mask.bool()

pred_output = predictor(input, mask)
pred_output.shape

torch.Size([10, 32, 512])