<a href="https://colab.research.google.com/github/hnipun/ColabProjects/blob/master/neox.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install labml



In [2]:
import math

import torch
from torch import optim
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

import labml.utils.download
from labml import lab
from labml import logger, monit
import torch.nn.functional as F

In [3]:
SEQ_LENGTH = 128
BATCH_SIZE = 64


class CustomTextDataset(Dataset):
    def __init__(self):
        labml.utils.download.download_file(
            'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt',
            lab.get_data_path() / 'tiny_shakespeare.txt',
        )
        with open(lab.get_data_path() / 'tiny_shakespeare.txt', "r") as f:
            self.text = f.read()
        self.text = list(self.text)
        self.vocab = {t: ids for ids, t in enumerate(set(self.text))}
        self.text_tensor = torch.tensor(self._to_vocab(self.text))

    def _to_vocab(self, text):
        res = []

        for t in text:
            res.append(self.vocab[t])

        return res

    def __len__(self):
        return len(self.text) // SEQ_LENGTH

    def __getitem__(self, idx):
        return self.text_tensor[SEQ_LENGTH * idx:SEQ_LENGTH * (idx + 1)]


In [4]:
ctd = CustomTextDataset()
train_dataloader = DataLoader(CustomTextDataset(), batch_size=BATCH_SIZE, shuffle=True)

In [5]:
next(iter(train_dataloader))

tensor([[ 2, 12,  7,  ..., 34, 18, 34],
        [40, 34, 59,  ..., 58, 29,  1],
        [54, 53, 30,  ..., 45,  1, 59],
        ...,
        [ 0, 55, 45,  ..., 59, 59, 32],
        [36, 12,  7,  ..., 51, 45,  0],
        [64, 45,  0,  ...,  0, 11, 39]])

In [6]:
import torch.nn as nn

In [7]:
NUM_EMBEDDINGS = 65
EMBEDDING_DIM = 16
N_TRANSOFRMER_LAYERS = 3


NUM_ATTENTION_HEADS = 8
NUM_FEATURES = 32

In [8]:
device = torch.device('cuda')
print(device)

def get_mask(seq_length): # write a test for mask
  # batch_size, seq_length, seq_length, heads is enough
  # 1, seq_length, seq_length, 1 is enough
  a = torch.ones(seq_length, seq_length, dtype=bool)
  a = torch.triu(a, diagonal=1) 
  a = torch.unsqueeze(a, 0)
  a = torch.unsqueeze(a, -1)

  return a

class MultiHeadAttentionLayer(nn.Module):
  def __init__(self):
    super().__init__()
    
    self.norm_layer = nn.LayerNorm(EMBEDDING_DIM)
    
    self.key_linear = torch.nn.Linear(EMBEDDING_DIM, NUM_ATTENTION_HEADS * NUM_FEATURES)
    self.query_linear = torch.nn.Linear(EMBEDDING_DIM, NUM_ATTENTION_HEADS * NUM_FEATURES)
    self.value_linear = torch.nn.Linear(EMBEDDING_DIM, NUM_ATTENTION_HEADS * NUM_FEATURES)

    self.softmax = torch.nn.Softmax(dim=2)
    self.linear_layer =  torch.nn.Linear(NUM_ATTENTION_HEADS * NUM_FEATURES, EMBEDDING_DIM)


  def forward(self, x):
    x_residual = x
    x = self.norm_layer(x)
    seq_length = x.shape[1]
    kx = self.key_linear(x)
    # logger.log(f'kx layer {kx.size()}')
    kx = kx.view(BATCH_SIZE,seq_length , NUM_ATTENTION_HEADS, NUM_FEATURES)
    # logger.log(f'kx layer {kx.size()}')

    qx = self.query_linear(x)
    qx = kx.view(BATCH_SIZE, seq_length, NUM_ATTENTION_HEADS, NUM_FEATURES)

    vx = self.value_linear(x)
    vx = kx.view(BATCH_SIZE, seq_length, NUM_ATTENTION_HEADS, NUM_FEATURES)

    score = torch.einsum('bihd,bjhd ->bijh', qx, kx) #is this correct ?
    score = score/(NUM_FEATURES)**0.5

    # logger.log(f'score layer {score.size()}')

    mask = get_mask(seq_length).to(x.device)
    score = score.masked_fill(mask, float("-inf"))

    probs = self.softmax(score)

    # logger.log(f'probs layer {probs.size()}')

    output = torch.einsum('bijh,bjhd ->bihd', probs, vx)
    # logger.log(f'output layer {output.size()}')
    output = output.reshape(BATCH_SIZE, seq_length, NUM_ATTENTION_HEADS * NUM_FEATURES)
    # logger.log(f'output layer {output.size()}')
    output = self.linear_layer(output)
    # logger.log(f'output layer {output.size()}')

    # logger.log(f'kx layer {kx.size()}')

    return x + x_residual


class FeedForwardLayer(nn.Module):
  def __init__(self):
    self.norm_layer = nn.LayerNorm(EMBEDDING_DIM)
    self.linear_layer1 =  torch.nn.Linear(BATCH_SIZE, SEQ_LENGTH, EMBEDDING_DIM/2)
    self.activation_layer = nn.ReLU()
    self.linear_layer2 =  torch.nn.Linear(BATCH_SIZE, SEQ_LENGTH, EMBEDDING_DIM)


  def forward(self, x):
    x_residual = x
    x = self.norm_layer(x)
    x = self.linear_layer1(x)
    x = self.activation_layer(x)
    x = self.linear_layer2(x)

    return x + x_residual


class TrnsformerLayer(nn.Module):
  def __init__(self):
    super().__init__()
    
    self.multi_head_attention = MultiHeadAttentionLayer()
    self.feed_forward_layer = FeedForwardLayer()

  def forward(self, x): # (batch_size, seq_length, emd_dim)
    x = self.multi_head_attention(x)
    # logger.log(f'multihead attention layer {x.size()}')
    return self.feed_forward_layer(x)

cuda


In [9]:
class Trsnsformer(nn.Module):
  def __init__(self):
        super().__init__()
        self.embeddings = torch.nn.Embedding(NUM_EMBEDDINGS, EMBEDDING_DIM)
        self.transormer_layers = nn.ModuleList([TrnsformerLayer()] * N_TRANSOFRMER_LAYERS)
        self.readout_layer = torch.nn.Linear(EMBEDDING_DIM, NUM_EMBEDDINGS)

  def forward(self, x):
        # logger.log(f'input {x.size()}') # (batch_size, seq_length)
        x = self.embeddings(x)
        # logger.log(f'embedding layer {x.size()}') # (batch_size, seq_length, emd_dim)
        for transormer_layer in self.transormer_layers:
          x = transormer_layer(x)
        # logger.log(f'transormer layer {x.size()}') # (batch_size, seq_length, emd_dim)

        return self.readout_layer(x) # (batch_size, seq_length, vocab_size)

In [11]:
t = Trsnsformer()
t(next(iter(train_dataloader))).size()

torch.Size([64, 128, 65])

In [11]:
model = Trsnsformer()
model = model.to(device)
optimizer = optim.Adam(model.parameters(), lr=3e-4)

NUM_EPOCHS = 100
for epoch in monit.loop(NUM_EPOCHS):
    for train_data in monit.iterate('train', train_dataloader):
      train_data = train_data.to(device)

      if train_data.size()[0] < 64:
        continue

      optimizer.zero_grad()
      out = model(train_data[:,:-1])
      out = out.to(device)

      loss = F.cross_entropy(out.view(BATCH_SIZE*127, 65),train_data[:,1:].reshape(BATCH_SIZE*127))
      loss.backward()
      optimizer.step()

    logger.log(str(loss.item()))

AttributeError: ignored

In [34]:
d = next(iter(train_dataloader))
print(d.size())
print(d[:,:-1].size())
print(d[:,1:].size())

torch.Size([64, 128])
torch.Size([64, 127])
torch.Size([64, 127])
