In [1]:
import torch
from torch import nn, optim
from torch.nn import functional as F
torch.set_printoptions(sci_mode=False)

In [2]:
device = 'cuda:0'

In [3]:
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.pre_tokenizers import Whitespace, Sequence, ByteLevel
from tokenizers.decoders import ByteLevel as ByteLevelDecoder
from tokenizers.trainers import BpeTrainer

tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
tokenizer.pre_tokenizer = ByteLevel()
tokenizer.decoder = ByteLevelDecoder()

trainer = BpeTrainer(special_tokens=["[UNK]", "[START]", "[END]", "[PAD]"], vocab_size=16384)

files = ["corpus150.txt"]
tokenizer.train(files, trainer)

#tokenizer.encode('[START] A duck is a carnivorous animal')

In [5]:
import os
import random
import codecs
from timeit import timeit

class Data:
    def __init__(self):
        filename = 'corpus150.txt'
        
        self.file = codecs.open(filename, 'r', encoding='utf-8', errors='ignore')
        self.file_length = os.stat(filename).st_size
        
        print('Loaded dataset file of size', self.file_length)
        
    def sample_batch(self, n=32, length=240):
        # sample a lot of strings of certain length
        strs = []
        for i in range(n):
            self.file.seek(random.randrange(0, self.file_length - length))
            strs.append(self.file.read(length))
        
        # encode with tokenizer
        x = [encoding.ids for encoding in tokenizer.encode_batch(strs)]
        
        # shorten the long ones
        min_len = min(map(len, x))
        x = [ids[0:min_len] for ids in x]
        
        # put it into pytorch preferred format (torch.tensor, with shape (sequence, batch))
        x = torch.tensor(x)
        x = x.transpose(1, 0)
        
        return x
        
dataset = Data()

#timeit(dataset.sample_batch, number=100) / 100

Loaded dataset file of size 1537774


In [6]:
#(torch.rand((2,2)) > 0.8).float() * torch.ones()

In [7]:
def checker_board(d_model):
    half = (d_model) // 2
    texture = torch.cat([
        torch.ones((half, 1)),
        torch.zeros((half, 1))
    ], dim=1).view((-1,))
    
    return texture

print(checker_board(8))
print(-checker_board(8) + 1)

tensor([1., 0., 1., 0., 1., 0., 1., 0.])
tensor([0., 1., 0., 1., 0., 1., 0., 1.])


In [8]:
def pos_embedding(x):
        # x: (pos, n, i)
        
        length = x.shape[0]
        batch_size = x.shape[1]
        d_model = x.shape[2]

        i = torch.arange(0, d_model).view((1, 1, -1)).expand(length, -1, d_model).to(device).float()
        pos = torch.arange(0, length).view((-1, 1, 1)).expand(length, -1, d_model).to(device).float()
        
        z = pos / 10000 ** (i / d_model)
        
        sin = torch.sin(z)
        cos = torch.cos(z)
        
        sin_mask = checker_board(d_model).to(device)
        cos_mask = -sin_mask + 1
                
        pe = (sin_mask * sin) + (cos_mask * cos)
        pe = pe.expand(length, batch_size, d_model)
        
        return x + pe

In [9]:
class Model(nn.Module):
    def __init__(self, dropout=0.1, embedding_dim=512, heads=8, num_layers=3):
        super(Model, self).__init__()
        # config
        self.dropout = dropout
        self.embedding_dim = embedding_dim
        self.heads = heads
        self.num_layers = num_layers
        
        self.start_token = torch.tensor([[1]]).to(device)
        
        # layers
        self.embedding = nn.Embedding(num_embeddings=16384, embedding_dim=embedding_dim)
        
        encoder_layer = nn.TransformerEncoderLayer(embedding_dim, heads, dim_feedforward=2048, dropout=dropout)
        decoder_layer = nn.TransformerDecoderLayer(embedding_dim, heads, dim_feedforward=2048, dropout=dropout)
        
        self.encoder = torch.nn.TransformerEncoder(encoder_layer, num_layers=3)
        self.decoder = torch.nn.TransformerDecoder(decoder_layer, num_layers=3)
        
        self.unembedding = nn.Linear(512, 16384)
        self.unembedding.weight.data = self.embedding.weight.data
    
    def forward(self, x):
        x = self.embedding(x)
        
        source = pos_embedding(x)
        target = pos_embedding(torch.cat([
            self.embedding(self.start_token).expand(1, source.shape[1], -1),
            x
        ], dim=0)[:-1])
        
        source_mask = (torch.rand((source.shape[0], source.shape[1], 1)) > 0.1).float().expand(-1, -1, self.embedding_dim).to(device)
        source = source * source_mask
        
        memory = self.encoder(source)
        output = self.decoder(target, memory)
        
        output = self.unembedding(output)
        
        return output
        
model = Model().to(device)

In [10]:
from IPython.display import clear_output

optimizer = optim.Adam(model.parameters(), lr=0.0006)
torch.cuda.empty_cache()
for i in range(1000):
    optimizer.zero_grad()
    x = dataset.sample_batch(n=64, length=320).to(device)
    
    y = model.forward(x)
    
    #print(y[1,0])
    
    loss = nn.CrossEntropyLoss()(y.view((-1, 16384)), x.reshape((-1)))
    loss.backward()
    optimizer.step()
    
    print('loss', loss)    
    #print(x.shape, y.shape)
    #print(x[0:5, 0], y[0:5, 0].argmax(dim=1))
    print(tokenizer.decode(x[:, 0].tolist()))
    print('================================')
    print(tokenizer.decode(y[:, 0].argmax(dim=1).tolist()))
    clear_output(wait=True)
    #break

loss tensor(0.9766, device='cuda:0', grad_fn=<NllLossBackward>)
 hells) from the island. Ordnance is still buried or lying on the ground. Other items have washed down gullies and still other unexploded ordnance lies beneath the waters offshore. In 1981, the entire island was included on the National Register of Historic Places.

 hellsls from the island. Or Orance still still buried or lying on the ground. Other items have needs down burn and and still unexploded unexploded ordnance lies b bath the waters offshore. In 1981, the entire was was included on the National Reg hundreds of little scientists..

