In [2]:
import random as r

from engine import Value
from modules import *
from ops import *
from gpt import *

In [3]:
batch_size = 2
vocab_len = 10
model_dim = 8
max_seq_len = 5
seq_len = 3
num_heads = 2
head_dim = 4
mlp_mult = 4

In [30]:
class CrossEntropyLoss(Module):
    def __init__(self, vocab_len: int, pad_token: int = None):
        self.vocab_len = vocab_len
        self.pad_token = pad_token

    def __call__(self, logits, targets):
        '''
        inputs: 
        logits - list of lists of lists of shape (batch_size, seq_len, vocab_len) full of Value objects
        targets - list of lists of shape (batch_size, seq_len) full of integers representing token indices

        output: a single Value object representing loss of the model
        '''
        assert isinstance(targets, list) and isinstance(targets[0], list) and isinstance(targets[0][0], int)
        one_hots = vector_wise_apply(self._one_hot, targets)

        log_logits = vectorr_wise_apply(log, logits)

    def _one_hot(self, targets_vec):
        '''
        turns list of tokens into list of one-hot vectors with 1's at the index of the given token
        meant to be used with vector_wise_apply
        '''
        assert all(isinstance(t, int) for t in targets_vec)
        return [[0] * t + [1] + [0] * (vocab_len - t - 1) for t in targets_vec]

In [26]:
# TODO: make Embedding module also do unembedding w/ shared weights? 
# wouldn't be exactly faithful to pytorch implementation but i'd like to use gradient accumulation & save parameters

In [28]:
logits = [[[Value(r.uniform(-4,4)).exp() for _ in range(vocab_len)]
      for _ in range(seq_len)]
     for _ in range(batch_size)]
logits = vector_wise_apply(softmax, logits)
pretty_print_tensor(logits)
celoss = CrossEntropyLoss(vocab_len, pad_token = vocab_len - 1)
targets = [[r.randint(0, vocab_len - 1) for _ in range(seq_len)]
           for _ in range(batch_size)]
pretty_print_tensor(targets)
loss = celoss(logits, targets)
print(loss)
pretty_print_tensor(loss)

[
  [
    [Value(data=0.000, grad=0.000), Value(data=0.000, grad=0.000), Value(data=0.000, grad=0.000), Value(data=0.000, grad=0.000), Value(data=0.000, grad=0.000), Value(data=0.000, grad=0.000), Value(data=0.000, grad=0.000), Value(data=0.000, grad=0.000), Value(data=0.000, grad=0.000), Value(data=1.000, grad=0.000)]
    [Value(data=0.000, grad=0.000), Value(data=0.000, grad=0.000), Value(data=0.000, grad=0.000), Value(data=0.003, grad=0.000), Value(data=0.000, grad=0.000), Value(data=0.000, grad=0.000), Value(data=0.000, grad=0.000), Value(data=0.997, grad=0.000), Value(data=0.000, grad=0.000), Value(data=0.000, grad=0.000)]
    [Value(data=0.000, grad=0.000), Value(data=0.715, grad=0.000), Value(data=0.000, grad=0.000), Value(data=0.000, grad=0.000), Value(data=0.000, grad=0.000), Value(data=0.285, grad=0.000), Value(data=0.000, grad=0.000), Value(data=0.000, grad=0.000), Value(data=0.000, grad=0.000), Value(data=0.000, grad=0.000)]
  ]
  [
    [Value(data=0.000, grad=0.000), Value