# FractaFormer

this base version is going to be absurdly terribly no-good inefficient because we're taking the biggest computational issue ($O(t^2)$ attention) and making it way worse by doing MANY of them at once and then having to keep track of each parameter's gradient from MANY different perspectives. This is basically just an extension of [MatFormer+](https://github.com/evintunador/matryoshkaGPT/blob/main/MatFormer%2B.ipynb) where instead of one inner model, we have 2 (or whatever number you specify) models inside 1 at each layer

# TODO
- ~output~
    - ~tensor~
    - ~tuple~
    - triple check test
- ~loss~
    - ~tuple~
    - triple check test
- ~model itself~
    - ~tensor~
    - ~tuple~
    - triple check test
- adjust verboseness to be function-specific

In [1]:
# Importing pytorch
import torch
import torch.nn as nn
from torch.nn import functional as F

# used for the tokenizer
import pickle
import os

# Imports used for the config
import dataclasses 
from typing import Optional

# Imports used for the model
import re
from typing import Any, List, Sequence, Tuple, Union
import numpy as np

# used in the training loop
import time

# The Dataset

the dataset we'll be using is just TinyShakespeare for sake of simplicity & ability to do run/train locally on any computer

In [2]:
# load the dataset
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

# the first 200 characters. It's just one continuous text document with all of the works of shakespeare back-to-back
print(text[:200])

# here are all the unique characters that occur in this text and how many there are
chars = sorted(list(set(text)))
v = len(chars)
print('\n', chars, v)

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you

 ['\n', ' ', '!', '$', '&', "'", ',', '-', '.', '3', ':', ';', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z'] 65


# The Tokenizer

We'll be using a very simple tokenizer I previoiusly trained off of the TinyShakespeare dataset that has 128 total tokens and ignores stuff like special tokens & regex. 

In [3]:
import classes.simple_tokenizer as simple_tokenizer

# Load the tokenizer data using pickle
with open('./tokenizers/tokenizer.model', 'rb') as f:
    loaded_tokenizer_data = pickle.load(f)

# Extract the stoi mapping and merges from the loaded data
loaded_stoi = loaded_tokenizer_data['stoi']
loaded_merges = loaded_tokenizer_data['merges']


# Example usage
# Assuming loaded_stoi and loaded_merges are already loaded from the tokenizer.model file

tokenizer = simple_tokenizer.SimpleTokenizer(loaded_stoi, loaded_merges)
print("vocab length: ", tokenizer.vocab_len)

# Encoding text
encoded_text = tokenizer.encode("JULIET:\nO Romeo, Romeo! wherefore art thou Romeo?")
print("Encoded:", encoded_text, len(encoded_text))

# Decoding back
decoded_text = tokenizer.decode(encoded_text)
print("Decoded:", decoded_text, len(decoded_text))

vocab length:  128
Encoded: [22, 33, 24, 21, 17, 32, 71, 27, 1, 30, 53, 83, 53, 66, 30, 53, 83, 53, 2, 1, 61, 87, 93, 105, 43, 1, 77, 58, 1, 65, 67, 1, 30, 53, 83, 53, 12] 37
Decoded: JULIET:
O Romeo, Romeo! wherefore art thou Romeo? 49


# Config

In [4]:
import classes.ff_config as ff_config
config = ff_config.Config(tokenizer.vocab_len)

print("single large model -> hierarchy of many smaller models inside")
print(f"model_count: {config.model_count}")
print(f"model_dim_list: {config.model_dim_list}")
print(f"head_dim_list: {config.head_dim_list}")
print(f"verbose: {config.verbose}")

single large model -> hierarchy of many smaller models inside
model_count: [1, 2, 4]
model_dim_list: [128, 64, 32]
head_dim_list: [32, 16, 8]
verbose: {'RMSNorm': False, 'MLP': False, 'MQA': False, 'Layer': False, 'OutputLayer': False, 'FractalLoss': False, 'FractalFormer': False, 'Sampler': False, 'Generate': False}


# Rotary Positional Encoding (RoPE)

i don't think i need to adjust the code for this one as long as i always call it individually

# RMSNorm

Layernorm is relatively simple code-wise. However, of note is the fact that during training, the entire full length vector gets normalized whereas during inference we only layernorm the sub-vector we've been given if we're not using the full model size. This is interesting because RMSNorm puts a vector of length $d$ onto a hypersphere of radius $\sqrt{d}$ which means that while the embeddings of the largest model exist on a hypersphere of the aforementioned size, for each number of layers $i\in\mathbb{N}$ s.t. $0 < i \leq$ `config.model_count` the embeddings are placed onto a hypersphere of radius $\sqrt{\frac{d}{s^i}}$ where $s=$`config.split`. I'm not sure yet exactly how to interpret this concatenation of vectors geometrically. When you combine the entries of two hypserspheres to make a larger hypserspheres, what happens to the feature groupings on the surface of the smaller hyperspheres? I presume there are some type of interaction effects or something. 

The following cell was designed to help you visualize what's happening with RMSNorm's splicing. With RMSNorm we'll only have to think about doing this with individual tensors, but with future methods like MLP and MQA we'll have to create an entirely separate forward method used during training that deals with tuples of tensors. The thing to pay attention to here is the size of the scale weights. scale_weights' entries are 0's because we've not yet undergone training

In [5]:
import classes.rms_norm as rms_norm

# Testing our RMSNorm's forward()
print("--------- Micro Hyperparameters -------")
hold = config.hidden_size
config.hidden_size = 4
print("model_count: ", config.model_count)
print("model_dim_list: ", config.model_dim_list)

print(f"|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|- the big model |-|-|-|-|-|-|-|-|-|-|-|-")
x = torch.rand(1,2,config.hidden_size)
print(f"x: {x.shape}\n{x}")
norm = rms_norm.RMSNorm(config.hidden_size)
norm.verbose = True
y = norm(x)
print(f"y: {y.shape}\n{y}")

print(f"|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|- the first sub-model |-|-|-|-|-|-|-|-|-|-|-|-")
x = torch.rand(1,2,config.hidden_size//2)
print(f"x: {x.shape}\n{x}")
norm = rms_norm.RMSNorm(config.hidden_size)
norm.verbose = True
y = norm(x)
print(f"y: {y.shape}\n{y}")

print(f"|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|- the second sub-model |-|-|-|-|-|-|-|-|-|-|-|-")
x = torch.rand(1,2,config.hidden_size//2)
print(f"x: {x.shape}\n{x}")
norm = rms_norm.RMSNorm(config.hidden_size)
norm.verbose = True
y = norm(x, model=1)
print(f"y: {y.shape}\n{y}")

print("---------- RESET CONFIG --------")
config.hidden_size = hold
print("model_count: ", config.model_count)

# clear up memory
del hold, x, y, norm

--------- Micro Hyperparameters -------
model_count:  [1, 2, 4]
model_dim_list:  [4, 2, 1]
|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|- the big model |-|-|-|-|-|-|-|-|-|-|-|-
x: torch.Size([1, 2, 4])
tensor([[[0.1906, 0.8151, 0.7329, 0.6793],
         [0.1616, 0.3718, 0.9625, 0.6126]]])
------------- RMSNorm.forward() ------------
x: torch.Size([1, 2, 4])
tensor([[[0.1906, 0.8151, 0.7329, 0.6793],
         [0.1616, 0.3718, 0.9625, 0.6126]]])
normed x: torch.Size([1, 2, 4])
tensor([[[0.2924, 1.2506, 1.1244, 1.0423],
         [0.2669, 0.6142, 1.5899, 1.0118]]])
dim: 4
skip: 0
spliced scale: torch.Size([4])
tensor([0., 0., 0., 0.], grad_fn=<SliceBackward0>)
scaled normed x: torch.Size([1, 2, 4])
tensor([[[0.2924, 1.2506, 1.1244, 1.0423],
         [0.2669, 0.6142, 1.5899, 1.0118]]], grad_fn=<MulBackward0>)
------------- END RMSNorm.forward() ------------
y: torch.Size([1, 2, 4])
tensor([[[0.2924, 1.2506, 1.1244, 1.0423],
         [0.2669, 0.6142, 1.5899, 1.0118]]], grad_fn=<MulBackward

# Multi-Layer Perceptron


<p align="center">
<img src="./images/ffwd.jpeg" width="512"/>
</p>

The following two cells are designed to help you comprehend what's happening. If you walk through every single print statement and follow along even down to watching what happens to each weight, you'll be able to clearly see what's happening with the odd splicing behavior. In order to make this somewhat feasible, I've set very small matrices for these examples. However I will admit it is still inevitably a pain, which is why I included the drawings above.

In [6]:
import classes.mlp as my_mlp
# Testing our MLP's forwardTensor()
verbose = True
print("--------- Micro Hyperparameters -------")
hold = config.hidden_size
config.hidden_size = 4
print("model_count: ", config.model_count)
print("model_dim_list: ", config.model_dim_list)

print(f"|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|- the big model |-|-|-|-|-|-|-|-|-|-|-|-")
x = torch.rand(1,2,4)
print(f"x: {x.shape}\n{x}")
mlp = my_mlp.MLP(4,8)
y = mlp(x)
print(f"y: {y.shape}\n{y}")

print(f"|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|- the first sub-model |-|-|-|-|-|-|-|-|-|-|-|-")
x = torch.rand(1,2,2)
print(f"x: {x.shape}\n{x}")
mlp = my_mlp.MLP(4,8)
y = mlp(x)
print(f"y: {y.shape}\n{y}")

print(f"|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|- the second sub-model |-|-|-|-|-|-|-|-|-|-|-|-")
x = torch.rand(1,2,2)
print(f"x: {x.shape}\n{x}")
mlp = my_mlp.MLP(4,8)
y = mlp(x, model=1)
print(f"y: {y.shape}\n{y}")

verbose = False
print("---------- RESET CONFIG --------")
config.hidden_size = hold
print("model_count: ", config.model_count)

# clear up memory
del hold, x, y, mlp

--------- Micro Hyperparameters -------
model_count:  [1, 2, 4]
model_dim_list:  [4, 2, 1]
|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|- the big model |-|-|-|-|-|-|-|-|-|-|-|-
x: torch.Size([1, 2, 4])
tensor([[[0.9599, 0.7425, 0.4285, 0.6888],
         [0.4940, 0.4207, 0.9316, 0.2704]]])
y: torch.Size([1, 2, 4])
tensor([[[ 0.2219, -0.0904,  0.1551,  0.0908],
         [ 0.3040, -0.0764,  0.0677,  0.1387]]], grad_fn=<AddBackward0>)
|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|- the first sub-model |-|-|-|-|-|-|-|-|-|-|-|-
x: torch.Size([1, 2, 2])
tensor([[[0.1443, 0.7489],
         [0.0171, 0.4325]]])
y: torch.Size([1, 2, 2])
tensor([[[0.0676, 0.3493],
         [0.0720, 0.3463]]], grad_fn=<AddBackward0>)
|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|- the second sub-model |-|-|-|-|-|-|-|-|-|-|-|-
x: torch.Size([1, 2, 2])
tensor([[[0.4050, 0.4628],
         [0.0385, 0.7982]]])
y: torch.Size([1, 2, 2])
tensor([[[-0.3495,  0.3693],
         [-0.3573,  0.3609]]], grad_fn=<AddBackward0>

In [7]:
# Testing our MLP's forwardTuple()
verbose = True
print("--------- Micro Hyperparameters -------")
hold1, hold2 = config.hidden_size, config.levels
config.hidden_size = 4
config.levels = 2
print("model_count: ", config.model_count)
print("model_dim_list: ", config.model_dim_list)

mlp = my_mlp.MLP(4,8)
x = ((torch.randn((1,2,4)),),
     (torch.randn((1,2,2)),torch.randn((1,2,2)))
    )
print(f"x: {x}")
out = mlp(x)
print(f"out: {out}")

verbose = False
print("---------- RESET CONFIG --------")
config.hidden_size = hold1
config.levels = hold2
print("model_count: ", config.model_count)
print("model_dim_list: ", config.model_dim_list)

# clear up memory
del hold1, hold2, x, out, mlp

--------- Micro Hyperparameters -------
model_count:  [1, 2]
model_dim_list:  [4, 2]
x: ((tensor([[[-1.7272e-02,  1.7605e+00,  1.3065e+00,  1.0095e-01],
         [ 2.2699e+00,  1.1732e-03, -5.2974e-01, -9.1841e-03]]]),), (tensor([[[-0.4960,  0.9603],
         [ 0.4245,  1.8475]]]), tensor([[[ 0.7208,  0.4512],
         [-1.1678,  1.6859]]])))
out: ((tensor([[[ 0.0859, -0.5522, -0.5765, -0.2148],
         [ 0.0753,  0.2515, -0.0000,  0.5056]]], grad_fn=<MulBackward0>),), (tensor([[[ 0.2315, -0.3805],
         [ 0.0000, -0.4768]]], grad_fn=<MulBackward0>), tensor([[[-0.0000, 0.1465],
         [-0.0000, 0.1279]]], grad_fn=<MulBackward0>)))
---------- RESET CONFIG --------
model_count:  [1, 2, 4]
model_dim_list:  [128, 64, 32]


# Attention

To subset the attention heads, we have to not only splice according to the model's embedding dimension but also take into account new smaller head sizes and how they're spaced throughout the matrix. I'm assuming you know how self-attention works well enough to look at this weight matrix and get the idea

<p align="center">
<img src="./images/sa.jpeg" width="512"/>
</p>

then we've gotta concatenate the outputs of each head

<p align="center">
<img src="./images/mha_concat.jpeg" width="512"/>
</p>

and after that linearly project them

<p align="center">
<img src="./images/mha_proj.jpeg" width="512"/>
</p>

this is the place where our splicing gets conceptually annoying. instead of just grabbing the matrix in the upper corner, because of the way attention head output concatenation works we actually need to skip over certain parts of the linear projection matrix and then concatenate them together in order to use them. Here's an example of what the matrix multiplication looks like. on the left is a simplified version of the concatenated attention heads where i just showed it as a matrix rather than a tensor, and then on the right is the actual projection matrix. notice how the numbers in the pink output matrix look similar to the first column of the purple output matrix with a positive number, its negative, and then a smaller positive number; that's the self-similarity in action. the yellow arrows point to the parts that get skipped over. obviously this would look a lot uglier with bigger matrices & incorporating the blue/green layer

<p align="center">
<img src="./images/mha_proj_matmul.jpeg" width="512"/>
</p>

And here are the detailed print statements for the attention mechanism

In [8]:
import classes.multi_query_attention as mqa

# Testing our Attention's forwardTensor()
verbose = True

print("--------- Micro Hyperparameters -------")
hold1, hold2, hold3, hold4 = config.hidden_size, config.num_attention_heads, config.head_dim, config.max_position_embeddings
config.hidden_size = 8
config.num_attention_heads = 2
config.head_dim = 4
config.max_position_embeddings = 3
print("model_count: ", config.model_count)
print("model_dim_list: ", config.model_dim_list)
print("head_dim_list: ", config.head_dim_list)

print(f"|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|- the big model |-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-")
x = torch.rand(1,3,8)
print(f"x: {x.shape}\n{x}")
att = mqa.MultiQueryAttention(config)
y = att(x)
print(f"y: {y.shape}\n{y}")

print(f"|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|- the first sub-model |-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-")
x = torch.rand(1,3,4)
print(f"x: {x.shape}\n{x}")
att = mqa.MultiQueryAttention(config)
y = att(x)
print(f"y: {y.shape}\n{y}")

print(f"|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|- the second sub-model |-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-")
x = torch.rand(1,3,4)
print(f"x: {x.shape}\n{x}")
att = mqa.MultiQueryAttention(config)
y = att(x, model=1)
print(f"y: {y.shape}\n{y}")

verbose = False
print("---------- RESET CONFIG --------")
config.hidden_size = hold1
config.num_attention_heads = hold2
config.head_dim = hold3
config.max_position_embeddings = hold4
print("model_count: ", config.model_count)
print("model_dim_list: ", config.model_dim_list)
print("head_dim_list: ", config.head_dim_list)

# clear up memory
del hold1, hold2, hold3, hold4, x, att, y

--------- Micro Hyperparameters -------
model_count:  [1, 2, 4]
model_dim_list:  [8, 4, 2]
head_dim_list:  [4, 2, 1]
|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|- the big model |-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-
x: torch.Size([1, 3, 8])
tensor([[[0.0263, 0.9244, 0.4639, 0.6496, 0.1376, 0.2457, 0.0823, 0.5641],
         [0.0526, 0.2470, 0.2450, 0.9060, 0.0645, 0.1045, 0.1379, 0.9632],
         [0.7057, 0.2345, 0.5654, 0.4798, 0.5494, 0.4583, 0.3758, 0.6214]]])
y: torch.Size([1, 3, 8])
tensor([[[ 0.0852, -0.2793, -0.0199, -0.0007, -0.1245,  0.2309, -0.1641,
           0.2679],
         [ 0.1774, -0.1397, -0.0198, -0.0094, -0.0488,  0.2022, -0.1168,
           0.1838],
         [ 0.2226, -0.1038, -0.0217,  0.0018, -0.0344,  0.2372, -0.1085,
           0.2067]]], grad_fn=<UnsafeViewBackward0>)
|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|- the first sub-model |-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-
x: torch.Size([1, 3, 4])
tensor([[[0.3838, 0.1565, 0.9070, 0.441

In [9]:
# Testing our Attention's forwardTuple()
verbose = True

print("--------- Micro Hyperparameters -------")
hold1, hold2, hold3, hold4, hold5 = config.hidden_size, config.num_attention_heads, config.head_dim, config.levels, config.max_position_embeddings
config.hidden_size = 8
config.num_attention_heads = 2
config.head_dim = 4
config.levels = 2
config.max_position_embeddings = 3
print("model_count: ", config.model_count)
print("model_dim_list: ", config.model_dim_list)
print("head_dim_list: ", config.head_dim_list)

att = mqa.MultiQueryAttention(config)
# we need to make sure to send in a tuple of the expected size. above we set hidden_size=4 and levels=2
x = ((torch.randn((1,3,8)),),(torch.randn((1,3,4)),torch.randn((1,3,4))))
print(f"x: {x}")
out = att(x)
print(f"out: {out}")

verbose = False
print("---------- RESET CONFIG --------")
config.hidden_size = hold1
config.num_attention_heads = hold2
config.head_dim = hold3
config.levels = hold4
config.max_position_embeddings = hold5
print("model_count: ", config.model_count)
print("model_dim_list: ", config.model_dim_list)
print("head_dim_list: ", config.head_dim_list)

# clear up memory
del hold1, hold2, hold3, hold4, hold5, x, att, out

--------- Micro Hyperparameters -------
model_count:  [1, 2]
model_dim_list:  [8, 4]
head_dim_list:  [4, 2]
x: ((tensor([[[-2.2259e+00,  4.0993e-01, -1.0243e+00,  1.2434e+00,  9.5531e-01,
          -8.1069e-02,  2.4555e-01,  5.9443e-01],
         [-6.1069e-01,  7.2299e-01, -3.8122e-01,  9.6473e-02,  4.5401e-01,
           5.4707e-01,  1.5144e+00, -1.6327e+00],
         [ 4.8189e-01, -9.3894e-01, -1.5024e+00, -4.4929e-04, -3.3427e-01,
          -2.0939e+00, -7.9109e-02,  1.4455e+00]]]),), (tensor([[[ 0.1748, -2.0631,  0.1800, -0.4377],
         [-0.1954, -2.6433, -2.1828, -0.4821],
         [ 0.9416, -0.8894, -0.9836, -0.0447]]]), tensor([[[-1.9438, -0.5415,  0.7859,  2.4925],
         [-0.9194, -0.2384,  0.5221, -0.0625],
         [-0.8240,  0.1747,  0.8672,  0.9473]]])))
out: ((tensor([[[-0.0224, -0.1037, -0.0299, -0.3154, -0.0000,  0.1995, -0.0606,
          -0.1362],
         [ 0.1232,  0.2063, -0.2081, -0.1087, -0.1401,  0.3522,  0.0790,
          -0.2274],
         [-0.1498, -0.22

# Layer

nothing too interesting here besides the absurd amount of memory we're probably taking up with these tuples

In [10]:
import classes.layer as layyyyer
# Testing our Layer's forwardTensor()
verbose = True

print("--------- Micro Hyperparameters -------")
hold1, hold2, hold3, hold4 = config.hidden_size, config.num_attention_heads, config.head_dim, config.max_position_embeddings
config.hidden_size = 8
config.num_attention_heads = 2
config.head_dim = 4
config.max_position_embeddings = 3
print("model_count: ", config.model_count)
print("model_dim_list: ", config.model_dim_list)
print("head_dim_list: ", config.head_dim_list)

print(f"|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|- the big model |-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-")
x = torch.rand(1,3,config.hidden_size)
print(f"x: {x.shape}\n{x}")
layer = layyyyer.Layer(config)
y = layer(x)
print(f"y: {y.shape}\n{y}")

print(f"|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|- the first sub-model |-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-")
x = torch.rand(1,3,config.hidden_size//config.split)
print(f"x: {x.shape}\n{x}")
y = layer(x)
print(f"y: {y.shape}\n{y}")

print(f"|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|- the second sub-model |-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-")
x = torch.rand(1,3,config.hidden_size//config.split)
print(f"x: {x.shape}\n{x}")
y = layer(x, model=1)
print(f"y: {y.shape}\n{y}")

verbose = False
print("---------- RESET CONFIG --------")
config.hidden_size = hold1
config.num_attention_heads = hold2
config.head_dim = hold3
config.max_position_embeddings = hold4
print("model_count: ", config.model_count)
print("model_dim_list: ", config.model_dim_list)
print("head_dim_list: ", config.head_dim_list)

# clear up memory
del hold1, hold2, hold3, hold4, x, layer, y

--------- Micro Hyperparameters -------
model_count:  [1, 2, 4]
model_dim_list:  [8, 4, 2]
head_dim_list:  [4, 2, 1]
|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|- the big model |-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-
x: torch.Size([1, 3, 8])
tensor([[[0.2900, 0.3502, 0.3640, 0.6985, 0.3658, 0.9914, 0.2533, 0.3429],
         [0.6997, 0.9056, 0.1009, 0.3460, 0.8976, 0.8859, 0.0885, 0.6918],
         [0.5278, 0.8558, 0.6953, 0.0507, 0.0577, 0.9630, 0.4264, 0.3923]]])
y: torch.Size([1, 3, 8])
tensor([[[ 0.4562,  0.2389,  0.4436,  0.5384,  0.1791,  0.7488,  0.5640,
          -0.0758],
         [ 0.9391,  0.4394,  0.1791,  0.1711,  0.6333,  0.5040,  0.4082,
           0.0405],
         [ 0.6622,  0.6216,  0.7912, -0.1079, -0.2573,  0.6049,  0.6932,
          -0.0120]]], grad_fn=<AddBackward0>)
|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|- the first sub-model |-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-
x: torch.Size([1, 3, 4])
tensor([[[0.6970, 0.5876, 0.6012, 0.4722],
   

In [11]:


# Testing our Layer's forwardTuple()
verbose = True

print("--------- Micro Hyperparameters -------")
hold1, hold2, hold3, hold4, hold5 = config.hidden_size, config.num_attention_heads, config.head_dim, config.levels, config.max_position_embeddings
config.hidden_size = 8
config.num_attention_heads = 2
config.head_dim = 4
config.levels = 2
config.max_position_embeddings = 3
print("model_count: ", config.model_count)
print("model_dim_list: ", config.model_dim_list)
print("head_dim_list: ", config.head_dim_list)

layer = layyyyer.Layer(config)
# we need to make sure to send in a tuple of the expected size. above we set hidden_size=4 and levels=2
x = ((torch.randn((1,3,8)),),(torch.randn((1,3,4)),torch.randn((1,3,4))))
print(f"x: {x}")
out = layer(x)
print(f"out: {out}")

verbose = False
print("---------- RESET CONFIG --------")
config.hidden_size = hold1
config.num_attention_heads = hold2
config.head_dim = hold3
config.levels = hold4
config.max_position_embeddings = hold5
print("model_count: ", config.model_count)
print("model_dim_list: ", config.model_dim_list)
print("head_dim_list: ", config.head_dim_list)

# clear up memory
del hold1, hold2, hold3, hold4, hold5, x, layer, out

--------- Micro Hyperparameters -------
model_count:  [1, 2]
model_dim_list:  [8, 4]
head_dim_list:  [4, 2]
x: ((tensor([[[-0.4899,  0.6559,  0.8784, -0.1362,  0.4753,  0.3774,  1.1159,
           0.1347],
         [-1.1374,  1.5822, -1.3046,  0.2054, -0.1641, -0.8444,  0.9152,
          -1.2537],
         [-1.6206, -0.2979, -0.5589, -1.7389,  0.2237,  0.8430,  0.2355,
          -0.6535]]]),), (tensor([[[ 0.1224,  1.3570, -0.3866, -1.7016],
         [ 0.5770, -0.3415,  0.8091,  1.2086],
         [-0.0276, -0.4859, -0.2800,  1.2966]]]), tensor([[[-1.6194, -1.2758,  0.4798, -1.4096],
         [ 1.1032, -0.4800,  0.1265, -0.7214],
         [ 1.0496, -0.6565, -0.1099,  0.6584]]])))
out: ((tensor([[[-1.4147,  0.4076,  1.6376, -0.7523,  0.6936, -0.1867,  1.3497,
           0.4808],
         [-1.5719,  1.1740, -0.7212, -0.0457,  0.3139, -1.3633,  1.0313,
          -0.6909],
         [-1.9420, -0.5451, -0.0535, -1.9322,  0.5954,  0.0571,  0.5031,
          -0.7269]]], grad_fn=<AddBackward0>),)

# Output Layer

In [12]:
import classes.output_layer as output_layer
# Testing our OutputLayer's forwardTensor()
verbose = True

print("--------- Micro Hyperparameters -------")
hold1, hold2 = config.hidden_size, config.vocab_size
config.hidden_size = 4
config.vocab_size = 5
print("model_count: ", config.model_count)
print("model_dim_list: ", config.model_dim_list)

embedding = torch.randn(config.vocab_size, config.hidden_size)
print(f"embedding: {embedding.shape}\n{embedding}")

print(f"|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|- the big model |-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-")
x = torch.rand(1,3,config.hidden_size)
print(f"x: {x.shape}\n{x}")
layer = output_layer.OutputLayer(embedding, config)
y = layer(x)
print(f"y: {y.shape}\n{y}")

print(f"|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|- the first sub-model |-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-")
x = torch.rand(1,3,config.hidden_size//config.split)
print(f"x: {x.shape}\n{x}")
y = layer(x)
print(f"y: {y.shape}\n{y}")

print(f"|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|- the second sub-model |-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-")
x = torch.rand(1,3,config.hidden_size//config.split)
print(f"x: {x.shape}\n{x}")
y = layer(x, model=1)
print(f"y: {y.shape}\n{y}")

verbose = False
print("---------- RESET CONFIG --------")
config.hidden_size = hold1
config.vocab_size = hold2
print("model_count: ", config.model_count)
print("model_dim_list: ", config.model_dim_list)

# clear up memory
del hold1, hold2, x, layer, y

--------- Micro Hyperparameters -------
model_count:  [1, 2, 4]
model_dim_list:  [4, 2, 1]
embedding: torch.Size([5, 4])
tensor([[ 1.3118, -0.6339, -1.6035, -1.1829],
        [ 0.1080,  0.9321,  0.3213,  1.3060],
        [-0.3626, -1.0572,  0.1246,  0.0184],
        [ 0.5701, -1.2204, -1.1347,  0.6017],
        [-1.7379, -0.3902,  0.6982, -0.8460]])
|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|- the big model |-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-
x: torch.Size([1, 3, 4])
tensor([[[0.5452, 0.3534, 0.4683, 0.2556],
         [0.3696, 0.0615, 0.8479, 0.3796],
         [0.0931, 0.4082, 0.7007, 0.8249]]])
y: torch.Size([1, 3, 5])
tensor([[[-1.0834,  2.5314, -2.1498, -1.2731, -2.2169],
         [-2.2045,  2.1072, -0.3068, -1.2832, -0.7550],
         [-3.1223,  3.5583, -1.1117, -1.3775, -0.8727]]],
       grad_fn=<UnsafeViewBackward0>)
|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|- the first sub-model |-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-
x: torch.Size([1, 3, 2])
tens

In [13]:
# Testing our Layer's forwardTuple()
verbose = True

print("--------- Micro Hyperparameters -------")
hold1, hold2, hold3, hold4 = config.hidden_size, config.levels, config.max_position_embeddings, config.hidden_size
config.hidden_size = 4
config.levels = 2
config.max_position_embeddings = 3
config.vocab_size = 5
print("model_count: ", config.model_count)
print("model_dim_list: ", config.model_dim_list)

embedding = torch.randn(config.vocab_size, config.hidden_size)
print(f"embedding: {embedding.shape}\n{embedding}")

layer = layyyyer.Layer(config)
# we need to make sure to send in a tuple of the expected size. above we set hidden_size=4 and levels=2
x = ((torch.randn((1,3,config.hidden_size)),),
     (torch.randn((1,3,config.hidden_size//config.split)),torch.randn((1,3,config.hidden_size//config.split))))
print(f"x: {x}")
out = layer(x)
print(f"out: {out}")

verbose = False
print("---------- RESET CONFIG --------")
config.hidden_size = hold1
config.levels = hold2
config.max_position_embeddings = hold3
config.vocab_size = hold4
print("model_count: ", config.model_count)
print("model_dim_list: ", config.model_dim_list)

# clear up memory
del hold1, hold2, hold3, hold4, x, layer, out, embedding

--------- Micro Hyperparameters -------
model_count:  [1, 2]
model_dim_list:  [4, 2]
embedding: torch.Size([5, 4])
tensor([[ 0.2224, -0.1793, -0.4423,  0.8017],
        [-0.1385, -1.2339, -0.8531,  0.1575],
        [-0.5926,  2.3769,  0.0477, -0.0497],
        [ 2.3957,  0.3885, -1.6190,  0.4254],
        [ 0.2385,  0.6624, -0.7998,  0.4367]])
x: ((tensor([[[-0.4031,  1.4628, -1.1429, -0.2040],
         [ 0.2421,  1.5902,  2.5933,  0.8067],
         [-1.5138,  1.1025,  1.0453, -1.1641]]]),), (tensor([[[-0.9932, -2.0957],
         [-0.3885, -0.8978],
         [-0.5158,  1.3748]]]), tensor([[[-0.7641, -1.2413],
         [-1.6640, -1.5790],
         [ 0.3750, -0.1341]]])))
out: ((tensor([[[-1.0999,  2.1041, -1.2098, -0.8285],
         [-0.1051,  1.9027,  2.5249,  0.7481],
         [-1.9295,  1.6160,  0.8707, -1.1355]]], grad_fn=<AddBackward0>),), (tensor([[[-1.0618, -2.1726],
         [-0.4744, -0.9711],
         [-0.6836,  1.4032]]], grad_fn=<AddBackward0>), tensor([[[-0.4680, -1.5190],


# Loss Function

In [14]:
import classes.fractal_loss as fractal_loss

# Testing our FractalLoss
verbose = True

print("--------- Micro Hyperparameters -------")
hold1, hold2, hold3, hold4 = config.hidden_size, config.levels, config.max_position_embeddings, config.hidden_size
config.hidden_size = 4
config.levels = 2
config.max_position_embeddings = 3
config.vocab_size = 5
print("model_count: ", config.model_count)
print("model_dim_list: ", config.model_dim_list)

embedding = torch.randn(config.vocab_size, config.hidden_size)
print(f"embedding: {embedding.shape}\n{embedding}")

loss = fractal_loss.FractalLoss(config)
# we need to make sure to send in a tuple of the expected size. above we set hidden_size=4 and levels=2
logits = ((torch.randn((2,3,config.vocab_size)),),
     (torch.randn((2,3,config.vocab_size)),torch.randn((2,3,config.vocab_size))))
print(f"logits: {logits}")
target = torch.randint(config.vocab_size, (2,3)).unsqueeze(0)
print(f"target: {target}")
out = loss(logits, target)
print(f"out: {out}")

verbose = False
print("---------- RESET CONFIG --------")
config.hidden_size = hold1
config.levels = hold2
config.max_position_embeddings = hold3
config.vocab_size = hold4
print("model_count: ", config.model_count)
print("model_dim_list: ", config.model_dim_list)

# clear up memory
del hold1, hold2, hold3, hold4, embedding, loss, logits, target, out

--------- Micro Hyperparameters -------
model_count:  [1, 2]
model_dim_list:  [4, 2]
embedding: torch.Size([5, 4])
tensor([[ 0.4059, -2.0473, -0.2735, -0.5570],
        [ 0.7993, -0.7838,  1.1809,  1.2022],
        [ 0.1442,  1.4582,  0.6399, -0.5694],
        [-1.1310, -0.1624,  0.3523, -0.3711],
        [-0.3034,  0.2289, -0.0687,  2.2043]])
logits: ((tensor([[[ 0.6436,  0.3571,  0.1013,  0.2318, -0.1090],
         [ 0.1568,  0.2087, -2.3688,  0.8231,  0.9686],
         [ 0.6816, -1.4875,  0.2660, -1.7342,  0.5034]],

        [[-1.2039,  1.9883, -0.2778,  1.4461,  0.3314],
         [-0.3510,  0.1949,  1.0903, -0.1730, -1.4147],
         [-1.0661,  0.9731, -0.8389, -0.3800,  0.1361]]]),), (tensor([[[ 0.7224, -1.7460, -1.7500,  0.1530,  0.7799],
         [ 0.7485,  1.1115, -0.2012, -0.4903,  0.0189],
         [ 1.0612,  0.8323, -0.1647, -3.4104,  0.6588]],

        [[ 0.2231,  0.1398,  1.6119,  0.4783, -1.0464],
         [-0.5405, -0.1441,  1.0267,  0.5986, -0.3172],
         [-0.5817,

# The Model itself

# Training-related Functions

In [15]:
# Train and test splits
data = torch.tensor(tokenizer.encode(text), dtype=torch.long)
n = int(0.9*len(data)) # first 90% will be our training dataset, the rest for validation
train_data = data[:n]
val_data = data[n:]

In [16]:
# data loading for training which generates a small batch of data of inputs x and targets y
def get_batch(split, batch_size):
    # whether we grab from our training or validation dataset
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - config.max_position_embeddings, (batch_size,))
    x = torch.stack([data[i:i+config.max_position_embeddings] for i in ix])
    y = torch.stack([data[i+1:i+config.max_position_embeddings+1] for i in ix])
    x, y = x.to(config.device), y.to(config.device)
    return x, y

In [17]:
# a demonstration of what a batch with batch_size=1 looks like. Notice the one-token offset in characters
xb, yb = get_batch('train', 1)
print(xb)
print(tokenizer.decode(xb.squeeze(0).tolist()))
print("-------")
print(yb)
print(tokenizer.decode(yb.squeeze(0).tolist()))

tensor([[ 37,   1,  34,  21,  75,  23,  21,  26,  19,   1,  17,  16,  35,  13,
          30,  16,   1,  21,  34,  71,  26, 103,   1,  87,  93,   1,  39,   1,
          54,  68,  47,  53,  42,   1,  94,   1,  58,  59,  51,  59,  50,  58,
          59,  67,  57,   1,  40, 115,  47,  50,  57,   8,   0,  13,  61, 106,
           1, 100,  65,   1,  27,  62, 105,  42,   1,  84,   1,  20,  39,  83,
          57,   1,  15, 107,  58,  99,   1,  80,  56,  39,  47, 122,  58,  71,
          18,  73,   1,  31,  53,  83,  56,  91,  58,  66,  94,  44,   1, 100,
          65,   1, 102,  57,   1,  45,  59,  47,  50,  58,  63,   1,  87,  39,
          42,   8,   0,  19,  53,  66,  98,  77,   1,  72,  51,   1,  87,  52,
         104, 125,  21,   1, 100,  81,   1, 116,   1,  87,  77,   1,  72,  51,
           1,  57,  54,  43,  39,  49,  85,  27,  36,  18,  27,  30,  16,  71,
          18,  73,   1, 101,   1,  54,  77,  58,  66,  21,   5,  81,   1, 116,
           1,  58, 115,  59,  40,  99,   1,  72,  43

In [18]:
@torch.no_grad()
def estimate_loss(model, batch_size, eval_iters = 10): # to estimate loss during the training loop
    out = {}
    model.eval() # sets model to eval mode
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split, batch_size)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train() # just resets to training mode
    return out

# Instantiating a brand new model

In [19]:
import classes.fractal_former_base as ff_base

# just to make sure nothing got messed up above. 
# if an error gets thrown in one of the test cells then the config values won't reset
print(config)

model = ff_base.FractalFormer_base(config, tokenizer).to(config.device)

# print the number of parameters in the model
print(sum(p.numel() for p in model.parameters())/1e3, 'K parameters')

print(model)

Config(max_position_embeddings=256, num_hidden_layers=4, num_attention_heads=4, num_key_value_heads=1, hidden_size=128, head_dim=32, rms_norm_eps=1e-06)
972.672 K parameters
FractalFormer_base(
  (embedder): Embedding(128, 128)
  (embedder_norm): RMSNorm()
  (layers): ModuleList(
    (0-3): 4 x Layer(
      (self_attn): MultiQueryAttention(
        (drop): Dropout(p=0.1, inplace=False)
      )
      (mlp): MLP(
        (drop): Dropout(p=0.1, inplace=False)
      )
      (input_layernorm): RMSNorm()
      (post_attention_layernorm): RMSNorm()
    )
  )
  (output_layer): OutputLayer(
    (embedding_norm): RMSNorm()
    (final_norm): RMSNorm()
  )
  (criterion): FractalLoss(
    (criterion): CrossEntropyLoss()
  )
)


# Training

In [20]:
# create a PyTorch optimizer
# this is not what they used, but this learning rate & weight decay work for our tiny minGemma
learning_rate = 3e-5
weight_decay = 0.01
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

# how long we want to train for
max_iters = 5000

# how often we want to check & see how our loss is doing
eval_interval = 250

# batch size to use
batch_size = 12

# if you want to do debugging
config.verbose['RMSNorm'] = False
config.verbose['MLP'] = False
config.verbose['MQA'] = False
config.verbose['Layer'] = False
config.verbose['OutputLayer'] = False
config.verbose['FractalLoss'] = False
config.verbose['FractalFormer'] = False
config.verbose['Sampler'] = False
config.verbose['Generate'] = False

# ------------ BOOKMARK ----------------

In [21]:
model.train()
start_time = time.time()

# Enable anomaly detection. uncomment these lines if you need to do extensive debugging
#torch.autograd.set_detect_anomaly(True)

print(f"max_iters: {max_iters}")

for iter in range(max_iters):

    # sample a batch of data
    xb, yb = get_batch('train', batch_size)
    
    # train
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
    
    # every once in a while evaluate the loss on train and val sets
    if iter % eval_interval == 0 or iter == max_iters - 1 or iter % 10 == 0:
        current_time = time.time()
        elapsed_time = current_time - start_time
        losses = estimate_loss(model, batch_size)
        print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}, time elapsed: {elapsed_time:.2f} seconds")

# Disable anomaly detection after the training loop
#torch.autograd.set_detect_anomaly(False)

max_iters: 5000
step 0: train loss 646.7269, val loss 647.3767, time elapsed: 1.78 seconds
step 10: train loss 638.0670, val loss 639.3229, time elapsed: 28.80 seconds
step 20: train loss 626.4349, val loss 628.4929, time elapsed: 56.47 seconds
step 30: train loss 610.5033, val loss 613.5657, time elapsed: 83.83 seconds
step 40: train loss 590.6318, val loss 593.1642, time elapsed: 113.13 seconds
step 50: train loss 565.5765, val loss 570.2756, time elapsed: 140.79 seconds
step 60: train loss 532.3127, val loss 537.9030, time elapsed: 170.35 seconds
step 70: train loss 495.9375, val loss 502.2664, time elapsed: 197.25 seconds


KeyboardInterrupt: 

# Saving your model

In [None]:
# save the model currently held in memory
# the filename specifies the model's class, hyperparameters, and date/time it was saved
import os

# Ensure the directory exists
model_dir = 'models'
if not os.path.exists(model_dir):
    os.makedirs(model_dir)

# Create a shorter, more concise filename
filename = (f'{model.__class__.__name__}'
           f'-v{config.vocab_size}'
           f'-max_t{config.max_position_embeddings}'
           f'-layers{config.num_hidden_layers}'
           f'-heads{config.num_attention_heads}'
           f'-kv_heads{config.num_key_value_heads}'
           f'-hidden{config.hidden_size}'
           f'-intermediate{config.intermediate_size}'
           f'-head_dim{config.head_dim}'
           f'-theta{config.rope_theta}'
           f'-levels{config.levels}'
           f'-split{config.split}'
           f'-lr{learning_rate}'
           f'-decay{weight_decay}'
           f'-batch{batch_size}'
            f'-train_iter{15000}'
           f'--{time.strftime("%Y-%m-%d|%H-%M-%S")}.pth')

# Save the model
model_path = os.path.join(model_dir, filename)
torch.save(model.state_dict(), model_path)

# Load a Pretrained Model

In [None]:
# Initialize a blank model
model = FractalFormer_base(config, tokenizer).to(config.device)  

# here's the path to a minGemma model that i've trained with roughly 1m parameters
path = 'models/FractalFormer_base-v128-max_t256-layers4-heads4-kv_heads1-hidden128-intermediate512-head_dim32-theta100.0-levels3-split2-lr0.0003-decay0.01-batch12--2024-03-06|07-14-57.pth'

# Load the saved state dictionary
model.load_state_dict(torch.load(path))
# REMEMBER TO CHANGE VALUES IN CONFIG TO MATCH THE MODEL YOU'VE LOADED

# print the number of parameters in the model
print(sum(p.numel() for p in model.parameters())/1e3, 'K parameters')

# If you only plan to do inference, switch to evaluation mode
model.eval()

# If you plan to continue training the model, switch to training mode
#model.train()

# Inference

In [None]:
model.eval() # sets model to eval mode
input_str = "JULIET:\nO Romeo, Romeo! wherefore art thou R" # the classic line
max_useable_output_len = config.max_position_embeddings - len(input_str)

for i in range(config.levels):
    for j in range(config.model_count[i]):
        print(f"level: {i}, model: {j}")
        output = model.generate(input_str, 
                                output_len = max_useable_output_len, 
                                temperature=0.7, 
                                top_k = 3, 
                                top_p = 0.95,
                               level = i,
                               model = j)
        print(output)

so there's almost definitely something wrong happening here