In [1]:
import torch
import string
import torch.nn as nn
import torch.nn.functional as F
import time
import math
import tqdm

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Current device:", device)


Current device: cuda


In [2]:
file = open("/content/tiny-shakespeare.txt","r",encoding="utf-8")
text = file.read()
text = text.replace("\n" , " ").lower()
punctuation_chars = string.punctuation
text = ''.join(char for char in text if char not in punctuation_chars)

In [3]:
tokens = text.split(" ")
vocab = list(set(tokens))

In [4]:
for i in tqdm.tqdm(vocab):
    if tokens.count(i) < 5:
        tokens.remove(i)
vocab = list(set(tokens))

100%|██████████| 12849/12849 [00:57<00:00, 223.69it/s]


In [5]:
vocab_to_idx = {}
idx_to_vocab = {}
vocab_size = len(vocab)
for idx,v in enumerate(vocab):
    vocab_to_idx[v] = idx
    idx_to_vocab[idx] = v

In [6]:
tokens_num = []
for i in tokens:
    tokens_num.append(vocab_to_idx[i])

In [7]:
x = []
y = []
x_num = []
y_num = []
max_len = 10
for i in range(len(tokens) - max_len - 1):
    x.append(tokens[i:max_len+i])
    y.append(tokens[max_len+i])
    x_num.append(tokens_num[i:max_len+i])
    y_num.append(tokens_num[max_len+i])

In [8]:
for i in range(10):
    print(x[i])
    print(y[i])
for i in range(10):
    print(x_num[i])
    print(y_num[i])

['first', 'citizen', 'before', 'we', 'proceed', 'any', 'further', 'hear', 'me', 'speak']

['citizen', 'before', 'we', 'proceed', 'any', 'further', 'hear', 'me', 'speak', '']
all
['before', 'we', 'proceed', 'any', 'further', 'hear', 'me', 'speak', '', 'all']
speak
['we', 'proceed', 'any', 'further', 'hear', 'me', 'speak', '', 'all', 'speak']
speak
['proceed', 'any', 'further', 'hear', 'me', 'speak', '', 'all', 'speak', 'speak']

['any', 'further', 'hear', 'me', 'speak', '', 'all', 'speak', 'speak', '']
first
['further', 'hear', 'me', 'speak', '', 'all', 'speak', 'speak', '', 'first']
citizen
['hear', 'me', 'speak', '', 'all', 'speak', 'speak', '', 'first', 'citizen']
you
['me', 'speak', '', 'all', 'speak', 'speak', '', 'first', 'citizen', 'you']
are
['speak', '', 'all', 'speak', 'speak', '', 'first', 'citizen', 'you', 'are']
all
[5632, 3263, 3904, 2364, 972, 2094, 5173, 5004, 4088, 6176]
0
[3263, 3904, 2364, 972, 2094, 5173, 5004, 4088, 6176, 0]
383
[3904, 2364, 972, 2094, 5173, 5004, 4

In [9]:
dmodel = 512
heads = 4
batch_size = 32
max_len = 10
shape = (batch_size,max_len,dmodel)
sentence = torch.Tensor(x_num).long()
label = torch.Tensor(y_num).long()

In [10]:
batch = []
for i in range(sentence.shape[0]//32):
    if i == 0:
        batch.append([sentence[0:32],label[0:32]])
    else:
        batch.append([sentence[i*32:(i+1)*32],label[i*32:(i+1)*32]])

In [48]:
class PositionalEncoding(nn.Module):
    '''
    Converts the vector embedding of batch of sequence to their positional encoding vectors.

    Arguments:
            encoded_sentence : embbeding vector which is to be Positional Encoded.
            shape : shape of embbeding vector => tuple(batch_size,max_len,dmodel)

    Returns :
            positional encoded vector

    '''
    def __init__(self,shape):
        super().__init__()
        self.max_len = shape[1]
        self.dmodel = shape[2]
        self.batch_size = shape[0]

    def forward(self,x):
        #create a position vector containing position of words
        position = torch.arange(0, self.max_len, device=device).float().unsqueeze(1)

        #applies the formula for and creates divsion term
        div_term = torch.exp(torch.arange(0, self.dmodel, 2, device=device).float() * -(math.log(10000.0) / self.dmodel))

        #creates the zeros vector of sentence shape
        pos_enc = torch.zeros((self.batch_size, self.max_len, self.dmodel), device=device)

        #applies the formula for sin(even) and cos(even)
        pos_enc[:,:,0::2] = torch.sin(position * div_term)
        pos_enc[:,:,1::2 ] = torch.cos(position * div_term)

        #shape(batch_size,max_len,dmodel)
        return pos_enc


In [49]:
class MultiHeadAttention(nn.Module):
    def __init__(self,shape,heads):
        super().__init__()

        self.shape = shape
        self.max_len = shape[1]
        self.dmodel = shape[2]
        self.batch_size = shape[0]
        self.heads = heads
        self.head_size = int(self.dmodel/heads)

        #defines the shape of multiheaded matrix
        self.multi_headed_shape = (self.shape[0],self.shape[1],self.heads,self.head_size)

        self.k_linear = nn.Linear(self.dmodel,self.dmodel)
        self.q_linear = nn.Linear(self.dmodel,self.dmodel)
        self.v_linear = nn.Linear(self.dmodel,self.dmodel)

    def split_heads(self,matrix,shape):
        return matrix.view(*self.shape)


    def attention(self,k,q,v):
        '''
        applies the attention formula for single heads

        Arguments:
                k : key
                q : query
                v : value)
        Returns :
                single matrix same as shape of k,q,v
        '''
        return torch.matmul(F.softmax((torch.matmul(q,k.transpose(-1,-2)))/(torch.sqrt(torch.tensor(dmodel/heads))),dim=-1) , v)

    def forward(self,x):
        # shape(batch_size,max_len,dmodel)
        K_prime = self.k_linear(x)
        Q_prime = self.q_linear(x)
        V_prime = self.v_linear(x)

        #applies split head
        K_prime = self.split_heads(K_prime,self.shape)
        Q_prime = self.split_heads(Q_prime,self.shape)
        V_prime = self.split_heads(V_prime,self.shape)

        #applies attention and then concatinate
        return self.attention(K_prime,Q_prime,V_prime).view(*self.shape)



In [50]:
class AddAndNorm(nn.Module):
    def __init__(self,dmodel):
        super().__init__()
        self.dmodel = dmodel

    def forward(self,x,residual):
        return torch.add(residual , F.layer_norm(x,normalized_shape=(self.dmodel,)))


In [51]:
class FeedForward(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(512,512,bias=True)
        self.relu1 = nn.ReLU()
        self.linear2 = nn.Linear(512,512,bias=True)

    def forward(self,x):
        return self.linear2(self.relu1(self.linear1(x)))



In [52]:
class Encoder(nn.Module):
    def __init__(self,vocab_size,shape):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)
        self.positional_encoding =  PositionalEncoding(shape)
        self.multi_headed_attention = MultiHeadAttention(shape,4)
        self.add_and_norm1 = AddAndNorm(512)
        self.feed_forward = FeedForward()
        self.add_and_norm2 = AddAndNorm(512)
        self.softmax = nn.Softmax(dim=1)
        self.linear3 = nn.Linear(512,512)
        self.linear4 = nn.Linear(512*10,vocab_size)

    def forward(self,x):
        out = self.token_embedding_table(x)
        residual = self.positional_encoding(out)
        out = self.multi_headed_attention(residual)
        residual = self.add_and_norm1(out,residual,)
        out = self.feed_forward(residual)
        out = self.add_and_norm2(out,residual)
        out = self.linear3(out)
        out = out.view(32,-1)
        out = self.linear4(out)
        return self.softmax(out)


In [53]:
model = Encoder(vocab_size,shape)
criterition = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [54]:
import random
random.shuffle(batch)

In [56]:
model = model.to(device)

In [None]:
for epoch in range(10):
    losses = []
    for b in tqdm.tqdm(batch):
        out = model(b[0].to(device))
        l = F.one_hot(b[1],num_classes=vocab_size).float().to(device)
        loss = criterition(out ,l)
        loss.backward()
        optimizer.step()
        losses.append(loss)
    print(sum(losses) / len(losses))