In [1]:
import torch
import torch.nn as nn
from torch.nn import functional as fn

# Seed torch's global RNG so weight initialisation and multinomial sampling
# draw the same numbers on every fresh run of the notebook.
SEED = 1337
torch.manual_seed(SEED)

# All tensors and the model in this notebook are placed on `device`.
# It is pinned to the CPU here; to use a GPU when one is present, switch to:
#   device = 'cuda' if torch.cuda.is_available() else 'cpu'
device = 'cpu'
device

'cpu'

In [2]:
from tqdm import tqdm

In [3]:
from importlib import reload

In [4]:
class Config:
    """Bag of hyperparameters shared by the dataset loader and the model.

    Args:
        context_size: Number of tokens per training sequence.
        batch_size: Number of sequences per mini-batch.
        n_layer: Number of layers (default 1). Stored for later use; the
            model in this notebook does not read it yet.
        n_head: Number of attention heads (default 4). Stored for later use;
            the model in this notebook does not read it yet.
        embed_size: Width of the token embedding (default 8).
        vocab_size: Number of distinct token codes. Usually unknown until the
            dataset has been loaded, so it defaults to ``None`` and is
            assigned afterwards (``config.vocab_size = len(dataset.vocab)``).
    """

    def __init__(self, context_size, batch_size, n_layer=1, n_head=4,
                 embed_size=8, vocab_size=None):
        self.n_layer = n_layer
        self.n_head = n_head
        self.embed_size = embed_size
        self.context_size = context_size
        self.batch_size = batch_size
        # Declared up front so the attribute always exists, even before the
        # dataset has been loaded.
        self.vocab_size = vocab_size


In [5]:
# Hyperparameters for this run: a 16-token context window and batches of 64.
config = Config(context_size=16, batch_size=64)


In [6]:
import utils.preprocess as pp
reload(pp)

<module 'utils.preprocess' from '/home/davids/projects/intro-transformer/utils/preprocess.py'>

In [7]:
# Load the restriction-site sequences onto `device`, using the context and
# batch sizes from the shared config.
dataset = pp.ShortSequenceDataset(
    'data/restriction-sites.txt',
    device,
    context_size=config.context_size,
    batch_size=config.batch_size,
)

# The vocabulary is only known once the data has been read; store its size on
# the config so the model can size its embedding and output layers.
config.vocab_size = len(dataset.vocab)

In [8]:
# Peek at the first ten examples. Slicing returns a pair of tensors
# (inputs, targets); in the printed output each target row is its input row
# shifted one position to the left with a trailing end code appended.
dataset[0:10]

(tensor([[30, 30, 30, 30, 30, 30, 30, 30, 30, 15, 15,  4, 17, 19, 26, 26],
         [30, 30, 30, 30, 30, 30, 30, 30, 30, 15,  4, 15, 19, 17, 26, 26],
         [30, 30, 30, 30, 30, 30, 30, 30, 30, 15, 15, 26,  4, 15, 26, 26],
         [30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30,  4, 15, 15, 26, 26],
         [30, 30, 30, 30, 30, 30, 30, 30, 30, 15, 15, 26,  4, 15, 26, 26],
         [30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30,  4, 15, 15, 26, 26],
         [30, 30, 30, 30, 30, 30, 30, 30, 30, 15,  4, 17, 15, 26, 19, 26],
         [30, 30, 30, 30, 30, 30, 30, 30, 30, 15,  4, 17, 17, 19, 19, 26],
         [30, 30, 30, 30, 30, 15, 17, 17, 26, 19, 17,  1,  9,  4, 13,  2],
         [30, 30, 30, 30, 30, 30, 30, 30, 15,  4, 17, 17, 28, 19, 19, 26]]),
 tensor([[30, 30, 30, 30, 30, 30, 30, 30, 15, 15,  4, 17, 19, 26, 26,  0],
         [30, 30, 30, 30, 30, 30, 30, 30, 15,  4, 15, 19, 17, 26, 26,  0],
         [30, 30, 30, 30, 30, 30, 30, 30, 15, 15, 26,  4, 15, 26, 26,  0],
         [30, 30, 30, 3

In [9]:
# Decode example 7 back into text to eyeball the (input, target) pair.
(
    dataset.decode(dataset[7][0]),
    dataset.decode(dataset[7][1]),
)

('^^^^^^^^^A/CCGGT', '^^^^^^^^A/CCGGT$')

In [10]:
# Number of examples the dataset exposes.
len(dataset)

456

In [11]:
# The same three-example slice that is fed to the untrained model as a
# smoke-test batch further down.
dataset[0:3]

(tensor([[30, 30, 30, 30, 30, 30, 30, 30, 30, 15, 15,  4, 17, 19, 26, 26],
         [30, 30, 30, 30, 30, 30, 30, 30, 30, 15,  4, 15, 19, 17, 26, 26],
         [30, 30, 30, 30, 30, 30, 30, 30, 30, 15, 15, 26,  4, 15, 26, 26]]),
 tensor([[30, 30, 30, 30, 30, 30, 30, 30, 15, 15,  4, 17, 19, 26, 26,  0],
         [30, 30, 30, 30, 30, 30, 30, 30, 15,  4, 15, 19, 17, 26, 26,  0],
         [30, 30, 30, 30, 30, 30, 30, 30, 15, 15, 26,  4, 15, 26, 26,  0]]))

In [12]:
class GPT(nn.Module):
    """Minimal next-token model: an embedding table followed by a linear
    read-out over the vocabulary.

    The logits at each position depend only on the token at that position
    (there is no attention and no positional information), so this behaves
    like a bigram model routed through an ``embed_size``-wide bottleneck.

    Args:
        config: Object exposing ``vocab_size`` and ``embed_size``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed = nn.Embedding(config.vocab_size, config.embed_size)
        self.unembed = nn.Linear(config.embed_size, config.vocab_size)

    def forward(self, X, Y=None):
        """Compute next-token logits and, optionally, the training loss.

        Args:
            X: Integer token codes of shape (B, T).
            Y: Optional target token codes of shape (B, T).

        Returns:
            A ``(logits, loss)`` pair.
            Without ``Y``: logits have shape (B, T, C) and loss is ``None``.
            With ``Y``: logits are flattened to (B*T, C) and loss is the
            scalar mean cross-entropy over all B*T positions.
        """
        embeds = self.embed(X)  # (B, T, D)
        logits = self.unembed(embeds)  # (B, T, C)

        if Y is None:
            return logits, None

        B, T, C = logits.shape
        # cross_entropy wants (N, C) scores against (N,) targets.
        # reshape (rather than view) also accepts non-contiguous tensors,
        # e.g. targets that were produced by a transpose or strided slice.
        logits = logits.reshape(B * T, C)
        loss = fn.cross_entropy(logits, Y.reshape(B * T))
        return logits, loss

    @torch.no_grad()  # sampling never needs gradients; skip building the graph
    def generate(self, X, n):
        """
        Generate n new token codes given context X.

        Args:
            X: Integer token codes of shape (B, T) used as the prompt.
            n: Number of tokens to sample and append.

        Returns:
            Tensor of shape (B, T + n): the prompt followed by the samples.
        """
        for _ in range(n):
            # get prediction
            logits, _loss = self(X)  # (B, T, C); loss is None here
            # focus on last time step
            last_logits = logits[:, -1, :]  # (B, C)
            probs = fn.softmax(last_logits, dim=-1)  # (B, C)
            new_x = torch.multinomial(probs, num_samples=1)  # (B, 1)
            X = torch.cat((X, new_x), dim=1)  # (B, T+1)
        return X

In [13]:
# Smoke test: run a three-example batch through the untrained model and
# display the initial loss.
model = GPT(config)
model = model.to(device)
batch = dataset[0:3]
inputs, targets = batch[0], batch[1]
logits, loss = model(inputs, targets)
loss

tensor(3.7275, grad_fn=<NllLossBackward0>)

In [14]:
# A (1, 1) prompt holding only the dataset's initial token code, created
# directly on the target device; used as the seed for generation.
init_code = torch.tensor([[dataset.init_code]], device=device)

In [15]:
# Sample 8 tokens from the untrained model and decode the single sequence.
dataset.decode(
    model.generate(init_code, 8)[0]
)

'^H6T_YYS3'

In [16]:
# AdamW over every model parameter with a learning rate of 1e-3.
optim = torch.optim.AdamW(params=model.parameters(), lr=1e-3)

In [17]:
n_epochs = 100

# Reset so that a `loss` left over from an earlier cell can never be reported
# as the training result if the loader turns out to be empty.
loss = None

for epoch in tqdm(range(n_epochs)):
    # One pass over every mini-batch the dataset's loader yields.
    for X, Y in dataset.loader:
        # evaluate loss
        logits, loss = model(X, Y)
        # Clear old gradients, backpropagate, then update the parameters.
        optim.zero_grad(set_to_none=True)
        loss.backward()
        optim.step()

if loss is None:
    raise RuntimeError('dataset.loader yielded no batches; nothing was trained')

# Note: this is the loss of the very last mini-batch only, not an epoch mean.
print(loss.item())

100%|██████████████████████████████████████████████████████████████████| 100/100 [00:02<00:00, 35.05it/s]

1.6972150802612305





In [20]:
# Sample 8 tokens again, now from the trained model, and decode the sequence.
dataset.decode(
    model.generate(init_code, 8)[0]
)

'^^^H)$C$T'