In [16]:
import torch
import torch.nn
from torch.nn import functional as F

In [3]:
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"

In [4]:
with open('shakespeare.txt', 'r', encoding='utf-8') as f:
    text = f.read()

In [17]:
chars = sorted(list(set(text)))
vocab_size = len(chars)

In [6]:
stoi = {ch:i for i,ch in enumerate(chars)}
itos = {i:ch for ch,i in stoi.items()}

encode = lambda s : [stoi[ch] for ch in s]
decode = lambda l : ''.join([itos[i] for i in l])

In [7]:
data = torch.tensor(encode(text), dtype=torch.long)

In [8]:
n = int(0.9 * len(data))
train_data = data[:n]
val_data = data[:n]

In [13]:
# dimension values
batch_size = 4
block_size = 8
n_embd = 32
n_head = 4
d_k = 32
d_v = d_k

In [10]:
def get_batch(split):
    # need to get batch_size number of samples
    # choose from the specified split
    split_data = train_data if split == 'train' else val_data

    # get batch_size number of starting indicies
    ix = torch.randint(len(split_data) - block_size, (batch_size,))

    print(ix.shape)

    # for each i in ix,
    # i is the starting index
    # so we take from i to i + block_size
    x = torch.stack([split_data[i:i+block_size] for i in ix])
    # y is just the element after each x
    y = torch.stack([split_data[i+1:i+block_size+1] for i in ix])

    return x, y

In [14]:
xb, yb = get_batch('train')
xb

torch.Size([4])


tensor([[42, 57,  1, 46, 43,  1, 51, 63],
        [30, 16,  1, 21, 34, 10,  0, 14],
        [10,  0, 32, 46, 43, 56, 43,  1],
        [46, 43, 57, 58,  8,  0,  0, 19]])

In [None]:
class SelfAttention(nn.Module):

    def __init__(self, n_embd, n_head):
        
        self.n_embd = n_embd
        self.n_head = n_head
        self.d_k = n_embd // n_head
        
        assert (self.n_head * self.d_k == self.n_embd), "Embedding size not divisible by number of heads"

        self.query = nn.Linear(self.n_embd, self.d_k)
        self.key = nn.Linear(self.n_embd, self.d_k)
        self.value = nn.Linear(self.n_embd, self.d_k)

    def forward(self, X):
        Q = self.query(X)
        K = self.key(X)
        V = self.value(X)

        