In [3]:
%load_ext autoreload
%autoreload 2
from alphatoe import evals, data, game, train
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torch.nn.functional import cross_entropy
from tqdm import tqdm
from transformer_lens import HookedTransformerConfig, HookedTransformer
import json
from transformer_lens import HookedTransformer, HookedTransformerConfig
from typing import Callable, Any
import einops
from sklearn.model_selection import train_test_split
import torch.nn.functional as F
import math

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [307]:
"""
I want to generate training data that can be used to have a model predict a dice roll. Input data is 0, prediction [1, 2, 3, 4, 5, 6]
I then want to train an MLP using either one hot encoding or probabilistic encoding. They should both reach the same theoretical loss.
What is the difference in convergence speed?
"""
print("entropy for pdf d6: ", sum([1/6 * -math.log(1/6) for _ in range(6)]))
print("cross entropy pdf d6 with p(X=1): ", F.cross_entropy(torch.tensor([1/6 for _ in range(6)]), torch.tensor([1., 0., 0., 0., 0., 0.])))
print("cross entropy pdf d6 with pdf d6: ", F.cross_entropy(torch.tensor([1/6 for _ in range(6)]), torch.tensor([1/6 for _ in range(6)])))


entropy for pdf d6:  1.7917594692280547
cross entropy pdf d6 with p(X=1):  tensor(1.7918)
cross entropy pdf d6 with pdf d6:  tensor(1.7918)


In [311]:
def generate_data_one_hot_labels(num_samples=10000):
    label_data = torch.randint(0, 6, (num_samples, 1))
    labels = torch.zeros((num_samples, 6))
    labels[torch.arange(num_samples), label_data.squeeze()] = 1
    data = torch.randn((num_samples, 1))
    return data, labels

def generate_data_pdf_labels(num_samples=10000):
    value = 1/6
    labels = torch.ones((num_samples, 6)) * value
    data = torch.randn((num_samples, 1))
    return data, labels

In [312]:
class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.layers = nn.Sequential(
            nn.Linear(1, 8),
            nn.ReLU(),
            nn.Linear(8, 6),  
            nn.LogSoftmax(dim=1) 
        )

    def forward(self, x):
        return self.layers(x)

In [313]:
model_one_hot = MLP()
model_pdf = MLP()
loss_fn = nn.CrossEntropyLoss() 
optimizer_one_hot = optim.Adam(model_one_hot.parameters(), lr=0.001)
optimizer_pdf = optim.Adam(model_pdf.parameters(), lr=0.001)

data_oh, labels_oh = generate_data_one_hot_labels()
data_pdf, labels_pdf = generate_data_pdf_labels()

train_data_oh, _, train_labels_oh, _ = train_test_split(data_oh, labels_oh, test_size=0.2)
train_data_pdf, _, train_labels_pdf, _ = train_test_split(data_pdf, labels_pdf, test_size=0.2)



In [318]:
for epoch in range(1000):
    for batch in range(0, len(train_data_oh), 2048):
        inputs_logits = model_one_hot(train_data_oh[batch:batch+2048])
        targets = train_labels_oh[batch:batch+2048]

        loss = loss_fn(inputs_logits, targets)

        optimizer_one_hot.zero_grad()
        loss.backward()
        optimizer_one_hot.step()
        test_input = torch.randn(1, 1)
        output_oh = torch.softmax(model_one_hot(test_input), dim=1)
        print(f"one hot logits: {output_oh}")

    print(f"Epoch {epoch+1}, Loss: {loss.item()}")

    

one hot logits: tensor([[0.1794, 0.1540, 0.1514, 0.1665, 0.1672, 0.1814]],
       grad_fn=<SoftmaxBackward0>)
one hot logits: tensor([[0.1692, 0.1758, 0.1594, 0.1576, 0.1688, 0.1692]],
       grad_fn=<SoftmaxBackward0>)
one hot logits: tensor([[0.1664, 0.1658, 0.1631, 0.1543, 0.1775, 0.1729]],
       grad_fn=<SoftmaxBackward0>)
one hot logits: tensor([[0.1702, 0.1797, 0.1580, 0.1590, 0.1655, 0.1676]],
       grad_fn=<SoftmaxBackward0>)
Epoch 1, Loss: 1.787723422050476
one hot logits: tensor([[0.1686, 0.1738, 0.1600, 0.1569, 0.1705, 0.1701]],
       grad_fn=<SoftmaxBackward0>)
one hot logits: tensor([[0.1688, 0.1745, 0.1598, 0.1572, 0.1699, 0.1698]],
       grad_fn=<SoftmaxBackward0>)
one hot logits: tensor([[0.1698, 0.1476, 0.1828, 0.1579, 0.1725, 0.1694]],
       grad_fn=<SoftmaxBackward0>)
one hot logits: tensor([[0.1875, 0.1592, 0.1256, 0.1743, 0.1613, 0.1922]],
       grad_fn=<SoftmaxBackward0>)
Epoch 2, Loss: 1.7877225875854492
one hot logits: tensor([[0.1810, 0.1550, 0.1465, 0.16

In [315]:
for epoch in range(1000):
    for batch in range(0, len(train_data_pdf), 2048):
        inputs_logits = model_pdf(train_data_pdf[batch:batch+2048])
        targets = train_labels_pdf[batch:batch+2048]

        loss = loss_fn(inputs_logits, targets)

        optimizer_pdf.zero_grad()
        loss.backward()
        optimizer_pdf.step()

    print(f"Epoch {epoch+1}, Loss: {loss.item()}")

Epoch 1, Loss: 1.8690282106399536
Epoch 2, Loss: 1.8630291223526
Epoch 3, Loss: 1.8573734760284424
Epoch 4, Loss: 1.852067470550537
Epoch 5, Loss: 1.8471137285232544
Epoch 6, Loss: 1.84250807762146
Epoch 7, Loss: 1.8382412195205688
Epoch 8, Loss: 1.8343006372451782
Epoch 9, Loss: 1.8306701183319092
Epoch 10, Loss: 1.8273322582244873
Epoch 11, Loss: 1.824267864227295
Epoch 12, Loss: 1.8214582204818726
Epoch 13, Loss: 1.8188846111297607
Epoch 14, Loss: 1.8165283203125
Epoch 15, Loss: 1.8143720626831055
Epoch 16, Loss: 1.812401294708252
Epoch 17, Loss: 1.8106001615524292
Epoch 18, Loss: 1.808956265449524
Epoch 19, Loss: 1.8074578046798706
Epoch 20, Loss: 1.8060917854309082
Epoch 21, Loss: 1.804848551750183
Epoch 22, Loss: 1.8037182092666626
Epoch 23, Loss: 1.8026912212371826
Epoch 24, Loss: 1.8017594814300537
Epoch 25, Loss: 1.8009151220321655
Epoch 26, Loss: 1.8001501560211182
Epoch 27, Loss: 1.799457311630249
Epoch 28, Loss: 1.7988295555114746
Epoch 29, Loss: 1.798261046409607
Epoch 30,

In [316]:
with torch.no_grad():
    test_input = torch.randn(1, 1)
    print(test_input)
    output_oh = torch.softmax(model_one_hot(test_input), dim=1)
    output_pdf = torch.softmax(model_pdf(test_input), dim=1)
    print(f"Logits for dice roll one hot: {output_oh}")
    print(f"Logits for dice roll PDF: {output_pdf}")


tensor([[0.1415]])
Logits for dice roll one hot: tensor([[0.1683, 0.1766, 0.1588, 0.1599, 0.1673, 0.1691]])
Logits for dice roll PDF: tensor([[0.1658, 0.1668, 0.1661, 0.1669, 0.1673, 0.1672]])
