In [1]:
from k import model

import torch
from torch import nn
import torch.nn.functional as F
import torch.distributions as dist

import numpy as np

from argparse import ArgumentParser, Namespace
import random, tqdm, gzip
import os
from k import PileDataset
from transformers import GPT2TokenizerFast

In [2]:
def sample(lnprobs, temperature=1.0):
    """Pick one index from a vector of unnormalised scores.

    Args:
        lnprobs: tensor of scores; the softmax is taken over dim 0.
        temperature: the scores are divided by this value before the
            softmax. A value of exactly 0.0 skips sampling and returns the
            arg-max of ``lnprobs`` instead.

    Returns:
        A tensor holding the chosen index (0-dim for a 1-D ``lnprobs``).
    """
    greedy = temperature == 0.0
    if not greedy:
        scaled_scores = lnprobs / temperature
        probabilities = F.softmax(scaled_scores, dim=0)
        return dist.Categorical(probabilities).sample()

    # Zero temperature: deterministic choice of the highest score.
    return lnprobs.argmax()

In [3]:
def get_dataloader(batch_size):
    """Wrap a freshly constructed ``PileDataset`` in a shuffling DataLoader.

    Args:
        batch_size: number of dataset items per batch.

    Returns:
        A ``torch.utils.data.DataLoader`` with shuffling enabled and two
        worker processes.
    """
    loader_options = dict(batch_size=batch_size, shuffle=True, num_workers=2)
    pile = PileDataset()
    return torch.utils.data.DataLoader(pile, **loader_options)

In [4]:
def sample_batch(data, seq_len, batch_size):
    """Cut random next-token training windows out of one long tensor.

    Every window yields an input slice of ``seq_len`` items and a target
    slice of the same length shifted one position to the right, so
    ``target[b, t]`` is the item that follows ``inputs[b, t]`` in ``data``.

    Args:
        data: tensor indexed along dim 0 by position. It needs at least
            ``seq_len + 1`` entries; ``torch.randint`` raises otherwise.
        seq_len: length of every sampled window.
        batch_size: number of windows to draw (start positions are drawn
            independently, so windows may repeat or overlap).

    Returns:
        ``(inputs, target)``, both cast to ``torch.long`` with
        ``batch_size`` rows.
    """
    # torch.randint's `high` is exclusive. The last valid start is
    # data.size(0) - seq_len - 1 (its target slice then ends exactly at the
    # end of `data`), so the exclusive bound is data.size(0) - seq_len.
    # A bound of data.size(0) - seq_len - 1 would never reach the final
    # element and would fail when `data` holds exactly seq_len + 1 items.
    starts = torch.randint(
        low=0, high=data.size(0) - seq_len, size=(batch_size,)
    ).tolist()

    inputs = torch.stack([data[s : s + seq_len] for s in starts]).to(torch.long)
    target = torch.stack([data[s + 1 : s + seq_len + 1] for s in starts]).to(
        torch.long
    )
    return inputs, target

In [5]:
def sample_sequence(
    transformer, seed, max_context, length, temperature=0.5, verbose=False
):
    """Autoregressively extend ``seed`` by ``length`` sampled tokens.

    Args:
        transformer: callable that takes a ``(1, t)`` tensor of token ids and
            returns scores that are read as ``output[0, -1, :]`` for the next
            token.
        seed: prompt token ids of shape ``(1, t)``; it is cloned, not
            modified.
        max_context: maximum number of trailing tokens fed to the model at
            each step.
        length: number of tokens to generate.
        temperature: forwarded to ``sample``; 0.0 means greedy decoding.
        verbose: when true, print the decoded prompt and every generated
            token using the pretrained ``"gpt2"`` tokenizer.

    Returns:
        The prompt followed by the generated tokens, shape
        ``(1, t + length)``.
    """
    sequence = seed.detach().clone()

    # Load the tokenizer once instead of once per printed token.
    tokenizer = GPT2TokenizerFast.from_pretrained("gpt2") if verbose else None

    if verbose:
        print("Prompt:")
        print(tokenizer.decode(seed[0][:]))
        print("Generation:")

    for _ in range(length):
        # `sequence` is (1, t): the context window must be cut along the
        # token axis (dim 1). Slicing dim 0 would keep the whole sequence and
        # let the model input grow without bound.
        context = sequence[:, -max_context:]
        output = transformer(context)
        c = sample(output[0, -1, :], temperature)

        if verbose:
            print(tokenizer.decode(c))

        sequence = torch.cat([sequence, c.unsqueeze(0).unsqueeze(0)], dim=1)
    return sequence

In [6]:
def _make_training_pair(raw, seq_len):
    """Split a (batch, tokens) id tensor into next-token (source, target), each seq_len wide."""
    raw = raw.to(torch.long)
    batch = raw.size(0)
    source = raw[:, :seq_len]
    target = raw[:, 1 : seq_len + 1]

    source_pad = seq_len - source.size(1)
    if source_pad > 0:
        filler = torch.zeros((batch, source_pad), dtype=torch.long, device=raw.device)
        source = torch.cat([source, filler], dim=1)

    target_pad = seq_len - target.size(1)
    if target_pad > 0:
        # -100 is the default `ignore_index` of F.nll_loss, so positions that
        # have no real next token do not contribute to the loss.
        filler = torch.full(
            (batch, target_pad), -100, dtype=torch.long, device=raw.device
        )
        target = torch.cat([target, filler], dim=1)

    return source, target


def infer(params):
    """Train a ``model.Transformer`` on batches from ``get_dataloader`` and
    periodically print sampled text.

    Despite its name this function runs a training loop: for
    ``params.num_batches`` steps it draws a batch, computes the next-token
    NLL loss, clips gradients, and steps Adam plus a linear warmup schedule.
    On every ``params.test_every``-th step (except step 0) and on the final
    step it samples ``params.sample_length`` tokens from a prompt taken from
    the current batch.

    Args:
        params: namespace providing ``seed``, ``batch_size``,
            ``embedding_dim``, ``num_heads``, ``mask``, ``dropout``,
            ``forward_expansion``, ``depth``, ``seq_len``, ``vocab_len``,
            ``device``, ``lr``, ``lr_warmup``, ``num_batches``,
            ``gradient_clipping``, ``test_every``, ``test_size`` and
            ``sample_length``.

    Returns:
        None.
    """
    # A negative seed requests a random one. The chosen seed is applied and
    # printed so that a run can be repeated.
    seed = params.seed
    if seed < 0:
        seed = random.randint(0, 1000000)
        print("random seed: ", seed)
    torch.manual_seed(seed)

    data_train = get_dataloader(params.batch_size)

    transformer = model.Transformer(
        params.embedding_dim,
        params.num_heads,
        params.mask,
        params.dropout,
        params.forward_expansion,
        params.depth,
        params.seq_len,
        params.vocab_len,
        params.device,
    ).to(params.device)

    opt = torch.optim.Adam(lr=params.lr, params=transformer.parameters())

    # Linear warmup measured in optimizer steps; a non-positive warmup
    # disables it instead of dividing by zero.
    warmup_steps = params.lr_warmup / params.batch_size
    sch = torch.optim.lr_scheduler.LambdaLR(
        opt,
        lambda step: min(step / warmup_steps, 1.0) if warmup_steps > 0 else 1.0,
    )

    # Keep one iterator alive across steps. Re-creating it for every batch
    # restarts the DataLoader workers each time and only ever yields the
    # first batch of a fresh shuffle.
    batches = iter(data_train)
    for i in tqdm.trange(params.num_batches):
        opt.zero_grad()

        try:
            raw = next(batches)
        except StopIteration:
            # Dataset exhausted: start the next epoch.
            batches = iter(data_train)
            raw = next(batches)

        source, target = _make_training_pair(raw, params.seq_len)
        # Follow the configured device so the batch lands where the model is.
        source = source.to(params.device)
        target = target.to(params.device)

        output = transformer(source)
        # nll_loss expects the class dimension at index 1.
        loss = F.nll_loss(output.transpose(2, 1), target, reduction="mean")
        loss.backward()
        if params.gradient_clipping > 0.0:
            nn.utils.clip_grad_norm_(transformer.parameters(), params.gradient_clipping)
        opt.step()
        sch.step()

        if i != 0 and (i % params.test_every == 0 or i == params.num_batches - 1):
            # The prompt comes from the first row of the current training
            # batch, restricted to its first `test_size` tokens.
            data_test = raw[:1, : params.test_size]
            # random.randint includes both end points, so stop at the last
            # valid index to avoid an empty prompt.
            seedfr = random.randint(0, max(data_test.size(1) - 1, 0))
            prompt = data_test[:, seedfr : seedfr + params.seq_len].to(
                device=params.device, dtype=torch.long
            )

            # Sample with dropout disabled, then restore the previous mode.
            was_training = transformer.training
            transformer.eval()
            try:
                with torch.no_grad():
                    sample_sequence(
                        transformer,
                        prompt,
                        max_context=params.seq_len,
                        verbose=True,
                        length=params.sample_length,
                    )
            finally:
                transformer.train(was_training)

In [8]:
# Run configuration. Every value below is copied into `params` and read by
# infer() unless noted otherwise.
num_batches = 100  # number of optimizer steps
batch_size = 1
data = None  # not read by infer()
embedding_dim = 2
num_heads = 2
# NOTE(review): passed straight to model.Transformer. If this flag controls
# causal masking, next-token training normally needs it enabled; confirm.
mask = False
dropout = 0.8
forward_expansion = 4
depth = 8
seq_len = 128  # width of each training window and of the sampling context
# NOTE(review): the GPT-2 tokenizer used for decoding has 50257 entries;
# confirm 52015 matches the ids produced by PileDataset.
vocab_len = 52015
# Fall back to the CPU when no CUDA device is present, so the notebook still
# runs on machines without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
seed = -1  # negative: infer() draws a random seed
lr = 0.0001
final = False  # not read by infer()
test_every = 1500  # larger than num_batches, so text is sampled only on the last step
gradient_clipping = 1  # max gradient norm; values <= 0 disable clipping
lr_warmup = 5000  # divided by batch_size to get the number of warmup steps
sample_length = 600  # tokens generated per sampling run
test_size = 25  # leading tokens of a batch from which the prompt is cut

params = Namespace(
    num_batches=num_batches,
    batch_size=batch_size,
    data=data,
    embedding_dim=embedding_dim,
    num_heads=num_heads,
    mask=mask,
    dropout=dropout,
    forward_expansion=forward_expansion,
    depth=depth,
    seq_len=seq_len,
    vocab_len=vocab_len,
    device=device,
    seed=seed,
    lr=lr,
    final=final,
    test_every=test_every,
    gradient_clipping=gradient_clipping,
    lr_warmup=lr_warmup,
    sample_length=sample_length,
    test_size=test_size,
)

infer(params)

 99%|█████████▉| 99/100 [00:11<00:00,  8.82it/s]

Prompt:
 understand a Chinese ship is working to map it.ckerAtopobium vagina
Generation:
 GoPro
new
MAC
 sulf
Changes
 traveler
 filtered
HB
gro
 tends
 degree
 swift
 conferences
network
 predator
 Noon
elfare
 libertarians
 hars
Insert
 APIs
772
pleasant
 renewable


 99%|█████████▉| 99/100 [01:44<00:01,  1.06s/it]


KeyboardInterrupt: 