In [38]:
import torch
import torch.nn as nn
from torch.nn import TransformerEncoder, TransformerEncoderLayer
import io
import math
import numpy

In [39]:
EMBEDDING_SIZE = 10
input_text = 'text.txt'

char_to_token = {}
token_to_char = {}
def load_char_to_token():
    global char_to_token
    with io.open(input_text, 'r', encoding='utf-8') as f:
        while True:
            c = f.read(1)
            if not c:
                break
            
            if c not in char_to_token:
                next_token = len(char_to_token)
                char_to_token[c] = next_token
                token_to_char[next_token] = c
load_char_to_token()

def tokens_from_file(path):
    tokens = []
    with io.open(input_text, 'r', encoding='utf-8') as f:
        tokens = [[char_to_token[c] for c in f.read()]]
    return torch.LongTensor(tokens)

def tokens_to_string(tokens):
    if isinstance(tokens, torch.Tensor):
        return ''.join([token_to_char[t.item()] for t in tokens.squeeze(0)])
    else:
        return ''.join([token_to_char[t] for t in tokens])

_embedding = nn.Embedding(len(char_to_token), EMBEDDING_SIZE)
_embedding.weight.data.uniform_(-0.1, 0.1)

def get_embedding_from_str(in_str):
    tokens = [[char_to_token[c]] for c in in_str]
    return _embedding(torch.LongTensor(tokens))

def get_embedding(tensor):
    return _embedding(tensor)

In [40]:
# data tests
file_tokens = tokens_from_file(input_text)
assert torch.all(torch.eq(file_tokens[0, :10], torch.Tensor([[0, 0, 0, 0, 1, 2, 3, 4, 5, 6]])))
assert tokens_to_string([1, 2, 3, 4, 23]) == '™😀Thf'
assert tokens_to_string(file_tokens).startswith('\x00\x00\x00\x00™😀This ™is a tutorial on how to train a seq')

In [41]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

In [42]:
sz = 1
src_mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
src_mask = src_mask.float().masked_fill(src_mask == 0, float('-inf')).masked_fill(src_mask == 1, float(0.0))
src_mask

tensor([[0.]])

In [43]:
nlayers = 10
nhead = 5
nhid = 10
dropout = 0.0

pe = PositionalEncoding(EMBEDDING_SIZE, dropout=dropout)
encoder_layers = TransformerEncoderLayer(EMBEDDING_SIZE, nhead, nhid, dropout)
transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
decoder = nn.Linear(EMBEDDING_SIZE, len(char_to_token))
decoder.bias.data.zero_()
decoder.weight.data.uniform_(-0.1, 0.1)
softmax_layer = nn.Softmax(dim=2)

def forward_pass(input_tokens):
    input_x = get_embedding(input_tokens)
    positional_embedding = pe(input_x)

    assert input_x.size() == torch.Size([1, 15, 10])
    assert positional_embedding.size() == torch.Size([1, 15, 10])
    assert target_y.size()[0] == input_x.size()[1]
    
    output = transformer_encoder(positional_embedding, src_mask)  # src_mask ?

    # why is this `torch.Size([1, 750, 52])` ? I want the output to be the next character `torch.Size([1, 1, 52])`
    output = decoder(output)

    assert output.size()[1] == target_y.size()[0]
    
    return output

def full_forward_pass(input_tokens):
    assert input_tokens.size()[0] == 1

    output = forward_pass(input_tokens)
    output = softmax_layer(output)

    sampled_tokens = []
    for d1 in output:
        for d2 in d1:
            sampled_tokens.append(numpy.random.choice(len(char_to_token), 1, p=d2.detach().numpy())[0])

    return tokens_to_string(sampled_tokens)

In [44]:
criterion = nn.CrossEntropyLoss()
lr = 0.1 # learning rate
all_params = list(_embedding.parameters()) + list(transformer_encoder.parameters()) + list(decoder.parameters())
optimizer = torch.optim.SGD(all_params, lr=lr)

In [45]:
INPUT_SIZE = 15
file_tokens = tokens_from_file(input_text)

# print(full_forward_pass(input_x))
input_x = file_tokens[:, 0:INPUT_SIZE]
target_y = file_tokens[0, INPUT_SIZE:INPUT_SIZE + INPUT_SIZE] # todo: batch x/y
print('Input : ' + tokens_to_string(input_x))
print('Target: ' + tokens_to_string(target_y))

print(full_forward_pass(input_x))
for i in range(0, 1000):
    output = forward_pass(input_x)
    
    optimizer.zero_grad()
    loss = criterion(output.view(-1, len(char_to_token)), target_y)
    print('loss={}'.format(loss))
    loss.backward()
    optimizer.step()
print(full_forward_pass(input_x))

Input :     ™😀This ™is 
Target: a tutorial on h
h,t+i/+Pgk1ikak
loss=3.9363009929656982
loss=3.813607692718506
loss=3.699815273284912
loss=3.5930111408233643
loss=3.4915006160736084
loss=3.3945586681365967
loss=3.3021998405456543
loss=3.2146971225738525
loss=3.1319186687469482
loss=3.0546085834503174
loss=2.9831597805023193
loss=2.9175543785095215
loss=2.8576085567474365
loss=2.803119659423828
loss=2.7536938190460205
loss=2.7089877128601074
loss=2.6685683727264404
loss=2.632028818130493
loss=2.598968029022217
loss=2.569014072418213
loss=2.541811466217041
loss=2.517005443572998
loss=2.494253396987915
loss=2.4733152389526367
loss=2.45390248298645
loss=2.4358396530151367
loss=2.418870687484741
loss=2.403000593185425
loss=2.3878915309906006
loss=2.3740644454956055
loss=2.363231658935547
loss=2.3611907958984375
loss=2.3905413150787354
loss=2.3361425399780273
loss=2.323998212814331
loss=2.3287689685821533
loss=2.3063817024230957
loss=2.320010185241699
loss=2.2854959964752197
loss=2.292831659

loss=0.8581545352935791
loss=0.847164511680603
loss=0.8396280407905579
loss=0.8375439047813416
loss=0.8517509698867798
loss=0.8462415337562561
loss=0.8871976137161255
loss=0.8402477502822876
loss=0.8354102969169617
loss=0.8332251310348511
loss=0.8454775214195251
loss=0.837508499622345
loss=0.8579017519950867
loss=0.8367950320243835
loss=0.8433831334114075
loss=0.8263952136039734
loss=0.8251783847808838
loss=0.816170334815979
loss=0.8150919079780579
loss=0.8083489537239075
loss=0.8065115809440613
loss=0.8008866906166077
loss=0.7983406186103821
loss=0.7956148386001587
loss=0.7938087582588196
loss=0.7920390367507935
loss=0.7906805872917175
loss=0.7893834710121155
loss=0.7882734537124634
loss=0.787223756313324
loss=0.7862738966941833
loss=0.7853627800941467
loss=0.7845226526260376
loss=0.78371262550354
loss=0.782952070236206
loss=0.7822150588035583
loss=0.7815171480178833
loss=0.780839204788208
loss=0.7801944613456726
loss=0.7795661687850952
loss=0.7789669036865234
loss=0.7783814072608948


loss=0.8919863104820251
loss=0.8508689999580383
loss=0.833502471446991
loss=0.8234732151031494
loss=0.8158571124076843
loss=0.809741199016571
loss=0.8046063780784607
loss=0.8001479506492615
loss=0.796467125415802
loss=0.7935218811035156
loss=0.7909994721412659
loss=0.7888745069503784
loss=0.7870274782180786
loss=0.7853626608848572
loss=0.783816933631897
loss=0.7824181914329529
loss=0.781092643737793
loss=0.7798481583595276
loss=0.7786790132522583
loss=0.7776117920875549
loss=0.7766193151473999
loss=0.7756927609443665
loss=0.774821400642395
loss=0.7739995121955872
loss=0.7732225656509399
loss=0.7724856734275818
loss=0.7717859148979187
loss=0.7711225748062134
loss=0.770485520362854
loss=0.7698779702186584
loss=0.7692974805831909
loss=0.7687426209449768
loss=0.768214762210846
loss=0.7677106261253357
loss=0.7672275304794312
loss=0.7667643427848816
loss=0.7663194537162781
loss=0.7658847570419312
loss=0.7654721140861511
loss=0.7650667428970337
loss=0.7646773457527161
loss=0.7643033266067505


In [None]:
encoder_layer = nn.TransformerEncoderLayer(d_model=EMBEDDING_SIZE, nhead=10)
transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)
src = torch.rand(1, 750, EMBEDDING_SIZE)
transformer_encoder(src)

In [None]:
transformer_encoder.layers[0].self_attn.out_proj.weight

In [None]:
char_to_token['d']

In [None]:
a = torch.randn(10)

In [None]:
import torch.nn as nn

In [None]:
dictionary = 'abcdefghiklmnopqrstuvwxyz '
def tokenize(s):
    return [dictionary.index(c) for c in s]
embed = nn.Embedding(len(dictionary), 3)
embed(torch.LongTensor(tokenize('hello world')))

In [None]:
embed(torch.LongTensor([10]))

In [None]:
embed.weight

In [None]:
torch.LongTensor([1, 2])

In [None]:
torch.LongTensor(1)

In [None]:
dictionary.index('z')

In [None]:
import math
max_len = 8
d_model = 4
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0).transpose(0, 1)
pe