diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 0920fd9..152de72 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -19,6 +19,7 @@ jobs: steps: - uses: actions/checkout@v3 + - name: Set up Python 3.10 uses: actions/setup-python@v3 with: @@ -28,7 +29,7 @@ jobs: uses: mamba-org/setup-micromamba@v1 with: micromamba-version: '1.5.6-0' # any version from https://github.com/mamba-org/micromamba-releases - environment-file: environment.cpu_only.test.yml + environment-file: requirements/environment.cpu.test.yml init-shell: >- bash cache-environment: true diff --git a/.gitmodules b/.gitmodules deleted file mode 100644 index 1265e42..0000000 --- a/.gitmodules +++ /dev/null @@ -1,3 +0,0 @@ -[submodule "llm.c"] - path = llm.c - url = https://github.com/karpathy/llm.c.git diff --git a/README.md b/README.md index d8a4ad4..cf33123 100644 --- a/README.md +++ b/README.md @@ -1,75 +1,90 @@ # Tricycle -> It don't go fast but it do be goin' +

+ tricycle_logo +

-Ever wanted to learn how a deep learning framework *actually* works under the hood? Tricycle might be for you.
+Tricycle is a fast, minimal, fully functional deep learning library written from scratch using only python and numpy.
 
-## Overview
-Tricycle is a minimal framework for deep learning. The goal of this library is
-not to match the speed or complexity or Pytorch or Tensorflow, but instead to get a good understanding of how
-deep learning actually works at every level: from automatic differentiation all the way up to modern Language Models. It is built using nothing but standard
-Python and Numpy which means that everything be understandable to anyone who knows a bit of Python.
+The file `train_smol_gpt.py` trains a 49M parameter, GPT-2 style language model that can produce passable python code in ~2 days on a single RTX 3090.
 
-Here are some things you can do with Tricycle:
-- Create tensors
-- Perform operations (addition, exponentiation, cosine, ...) on tensors
-- Automatic differentiation of tensors
-- Manipulate tensors with [einstein notation](https://en.wikipedia.org/wiki/Einstein_notation)
-- Successfully train deep learning models
-- Use a GPU
-- Train a Transformer to produce infinite shakespeare(!)
+The entire library, from the automatic differentiation engine to a GPT, is written in ~4500 lines of python + numpy code.
 
-If you would like to learn more about the process of building tricycle, you can check out my [blog](http://bclarkson-code.com)
+Using [CuPy](https://cupy.dev/), all Tricycle code can run on either a GPU or a CPU.
 
-## Installation
-Tricycle uses [conda](https://docs.conda.io/en/latest/) to manage dependencies. While we do support CPU-only computation, at time of writing, not effort has been put into optimising it. If you do have a CUDA capable GPU I would strongly reccommend installing the gpu version of Tricycle.
+- [Installation](#installation)
+  - [CPU Installation](#cpu-installation)
+- [Training a GPT on Shakespeare](#training-a-gpt-on-shakespeare)
+- [How it works](#how-it-works)
+  - [Automatic Differentiation](#automatic-differentiation)
+  - [Einsum](#einsum)
+    - [Summing along an axis](#summing-along-an-axis)
+    - [Sum of an entire tensor](#sum-of-an-entire-tensor)
+    - [Transpose](#transpose)
+    - [Matrix multiplication](#matrix-multiplication)
+  - [Building a simple neural network](#building-a-simple-neural-network)
+  - [Optimisations](#optimisations)
+    - [Batching](#batching)
+    - [GPU](#gpu)
+    - [Fusing](#fusing)
+    - [Other optimisations](#other-optimisations)
+      - [Inplace tensor updates](#inplace-tensor-updates)
+      - [Mathematical optimisations](#mathematical-optimisations)
+      - [Hardware optimisations](#hardware-optimisations)
+- [Contact](#contact)
 
-Training Smol GPT on my GPU takes ~30 mins while training Smol GPT on CPU takes ~122 hours.
 
 ## Installation
+Tricycle uses [conda](https://docs.conda.io/en/latest/) to manage dependencies. While we do support CPU-only computation, optimisation efforts have been focussed on GPU computation, so the CPU version is pretty slow. If you do have a CUDA capable GPU, I would strongly recommend installing the GPU version of Tricycle.
 
-### GPU Installation
 If you have a CUDA capable GPU, you can install Tricycle as follows.
 
 ```bash
-conda env create -f environment.yml
+conda env create -f environment.yml -n tricycle
 conda activate tricycle
 ```
-If you want to install test-dependencies you can do the following.
+
+<details>
+  <summary>CPU and test installation</summary>
+If you want to install test dependencies, you can do the following.
 
 ```bash
-conda env create -f environment.test.yml
+conda env create -f environment.test.yml -n tricycle
 conda activate tricycle
 ```
 
 ### CPU Installation
 If you want to install Tricycle for CPU, you can do the following.
 
 ```bash
-conda env create -f environment.cpu_only.yml
+conda env create -f environment.cpu.yml -n tricycle
 conda activate tricycle
 ```
-If you want to install test-dependencies you can do the following.
+If you want to install test dependencies on CPU, you can do the following.
 
 ```bash
-conda env create -f environment.cpu_only.test.yml
+conda env create -f environment.cpu.test.yml -n tricycle
 conda activate tricycle
 ```
+</details>
-## Usage
-Theoretically, as a fully functional deep learning library, you can build any modern Deep Learning model with Tricycle. For example, this is how you can train a small language model on the shakespeare dataset:
+
+
+## Training a GPT on Shakespeare
+The following toy script will train a small GPT to generate convincing Shakespeare.
+On my RTX 3090, this takes ~30 mins. For a more realistic training script with metric tracking, gradient accumulation, a validation dataset etc., take a look at `train_smol_gpt.py`.
 
 ```python
 import pickle
 
 from tqdm import tqdm
 
-from tricycle.configs import SmolGPTConfig
+from tricycle.configs import ShakespeareConfig
 from tricycle.dataset import CausalLMDataset
-from tricycle.loss import cross_entropy
+from tricycle.loss import CrossEntropy
 from tricycle.models import GPT
 from tricycle.optimisers import AdamW
-from tricycle_datasets.shakespeare import ShakespeareChar
+from tricycle_datasets.shakespeare import Shakespeare
 
-config = SmolGPTConfig()
+config = ShakespeareConfig()
 model = GPT(config)
 
-tokens = ShakespeareChar(vocab_size=config.vocab_size)
+tokens = Shakespeare(vocab_size=config.vocab_size)
 dataset = (
     CausalLMDataset(
         tokens=tokens,
@@ -78,10 +93,10 @@ dataset = (
         context_window=config.context_window,
     )
     .batch()
+    .shuffle()
     .to_tensor()
-    .to_vector()
 )
-loss_fn = cross_entropy
+loss_fn = CrossEntropy()
 optimiser = AdamW(
     learning_rate=config.max_learning_rate,
     weight_decay=config.weight_decay,
@@ -89,37 +104,487 @@ optimiser = AdamW(
 )
 
 model.to_gpu()
-
-best_loss = float("inf")
-losses = []
-for step in tqdm(range(config.steps)):
+loading_bar = tqdm(range(config.steps))
+for step in loading_bar:
     optimiser.step()
     inputs, outputs = next(dataset)
     inputs = inputs.to_gpu()
     outputs = outputs.to_gpu()
 
     logits = model(inputs)
-    loss = loss_fn(outputs, logits).sum() / (
-        config.gradient_accumulation_steps
-        * config.batch_size
-        * config.context_window
-    )
+    loss = loss_fn(outputs, logits)
     loss.backward()
+    loading_bar.set_description(f"loss: {loss:.3f}")
 
     model.update(optimiser)
 
 # save results
 with open("model.pkl", "wb") as f:
     pickle.dump(model, f)
 ```
+Once trained, you can generate infinite Shakespeare plays as follows:
+
+```bash
+python inference.py model.pkl
+```
+
+## How it works
+Tricycle code centers around objects called `Tensors`. A `Tensor` is a wrapper around a numpy array that adds some extra features:
+
+```python
+from tricycle.tensor import to_tensor
+
+tensor = to_tensor([1,2,3])
+print(tensor) # Output: Tensor([1. 2. 3.])
+```
+
+You can do a lot of things with a tensor:
+
+```python
+from tricycle.functions import Softmax
+
+a = to_tensor([1,2,3])
+b = to_tensor([4,5,6])
+
+# addition
+print(a + b) # Output: Tensor([5. 7. 9.], name=badd)
+
+# comparison
+print(a < b) # Output: Tensor([ True  True  True])
+
+# more complex functions
+print(Softmax()(a)) # Output: Tensor([0.09003057 0.24472848 0.66524094], name=softmax)
+```
+
+### Automatic Differentiation
+Unlike vanilla numpy, every operation in Tricycle is attached to a derivative.
+When you do some operations on your `Tensor`, Tricycle keeps track of what you did and allows you to differentiate the output.
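+Under the hood, each operation returns a tensor that records which tensors it was made from and a derivative function for each of them (the `args` and `back_fns` described in more detail below). Here is a minimal, self-contained sketch of that bookkeeping. The names here are hypothetical and heavily simplified, not Tricycle's real API: the real implementation lives in `tensor.py` and, unlike this naive walk, topologically sorts the graph first.
+
+```python
+import numpy as np
+
+
+class TinyTensor:
+    """A toy stand-in for Tricycle's Tensor, for illustration only."""
+
+    def __init__(self, array):
+        self.array = np.asarray(array, dtype=float)
+        self.args = ()      # the tensors this one was computed from
+        self.back_fns = ()  # one derivative function per arg
+        self.grad = None
+
+    def backward(self):
+        # the derivative of the output with respect to itself is 1
+        self.grad = np.ones_like(self.array)
+        # walk back through the recorded operations, applying each back_fn
+        # (fine for simple chains; a real implementation sorts the graph first)
+        stack = [self]
+        while stack:
+            node = stack.pop()
+            for arg, back_fn in zip(node.args, node.back_fns):
+                grad = back_fn(node.grad)
+                arg.grad = grad if arg.grad is None else arg.grad + grad
+                stack.append(arg)
+
+
+def multiply(a, b):
+    out = TinyTensor(a.array * b.array)
+    out.args = (a, b)
+    # d(a*b)/da = b and d(a*b)/db = a, each applied to the incoming gradient
+    out.back_fns = (lambda g: g * b.array, lambda g: g * a.array)
+    return out
+
+
+x = TinyTensor([1.0, 2.0, 3.0])
+y = TinyTensor([4.0, 5.0, 6.0])
+z = multiply(x, y)
+z.backward()
+print(x.grad)  # [4. 5. 6.] (the derivative of x * y wrt x is y)
+```
+
+In everyday use you never touch this machinery directly; you just call operations and then `backward`: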
+
+```python
+x = to_tensor(2)
+
+y = x ** 2 + 3 * x + 4
+print(y) # Output: Tensor(14.0, name=+ 4)
+
+# derivative of y with respect to (wrt) x is
+# 2 * x + 3 = 7
+y.backward() # differentiate wrt y
+print(x.grad) # Output: Tensor(7.0)
+```
+
+This also works on multidimensional tensors:
+
+```python
+import numpy as np
+
+shape = (6,5,4,3,2)
+a = to_tensor(np.random.random(shape))
+b = to_tensor(np.random.random(shape))
+
+c = a * b # elementwise multiply
+
+c.backward() # differentiate wrt c
+assert a.grad.close_to(b) # derivative of c wrt a is b
+assert b.grad.close_to(a) # derivative of c wrt b is a
+```
+
+And it even works through complex operations like attention:
+
+```python
+from tricycle.blocks import MultiHeadSelfAttention
+
+attention = MultiHeadSelfAttention(
+    embedding_dim=32,
+    n_heads=2,
+    context_window=32,
+)
+
+# batch_size, n_tokens, embedding_dim
+shape = (4,32,32)
+input = to_tensor(np.ones(shape), is_batched=True)
+
+output = attention(input)
+output.backward() # differentiate wrt output
+
+print(input.grad) # Output: Tensor([[[ 2.5441039 -2.0558214 -1.7923143 ...
+assert input.grad.shape == (4,32,32)
+```
+
+When you run an operation (`Op`), the output has two pieces of information attached:
+ - `args`: The inputs to the function
+ - `back_fns`: The functions that should be executed to calculate the derivative wrt each of the inputs
+
+Surprisingly, this is all that you need to perform automatic differentiation on an arbitrarily complicated sequence of `Op`s.
+Because we keep track of the `args` for each operation, we can start at the output of a set of `Op`s and traverse through them to reach every input to the sequence: the operations form a tree.
+
+Thanks to the [chain rule](https://en.wikipedia.org/wiki/Chain_rule), if we apply each `back_fn` that we pass on our way through the tree, then by the time we reach an input, we will have calculated the derivative of the output wrt that input.
+Despite implementing it myself, I still feel like this couldn't possibly work, and yet it does!
+
+The entirety of the algorithm can be found in [`tensor.py`](https://github.com/bclarkson-code/Tricycle/blob/update-readme/src/tricycle/tensor.py#L145).
+It ends up being a topological sort to figure out which order to traverse the tree and then a simple traversal, applying the `back_fns` along the way.
+
+If you want a more detailed explanation, I've talked about it on [my blog](https://bclarkson-code.com/posts/llm-from-scratch-scalar-autograd/post.html).
+
+### Einsum
+
+Tricycle makes use of (in my opinion underutilised) einsum operations.
+Einsum is a generalisation of a large number of matrix operations.
+
+You can use it by assigning each axis in your matrices a letter of the
+alphabet (called an index). You can define the operation you want to perform
+by simply listing the indices you want in your inputs and output, separated by
+an arrow.
+
+For example, you can define the transpose of a 2d tensor as follows:
+
+```python
+from tricycle.einsum import Einsum
+
+a = to_tensor([[1,2],[3,4]])
+print(Einsum("ij->ji")(a)) # Output: Tensor([[1. 3.], [2. 4.]], name=einsum ij->ji)
+```
+
+Here, we use einsum to swap indices i and j: a transpose.
+
+There are only two rules to remember with einsum:
+ - If an index does not appear in the output, any inputs that contain it
+   will be summed along that axis:
+
+   ```python
+   print(Einsum("ij->i")(a)) # Tensor([3. 7.], name=einsum ij->i)
+   ```
+
+ - If an index appears in more than one input, the tensors will be multiplied
+   along that axis:
+
+   ```python
+   b = to_tensor([[5,6],[7,8]])
+   print(Einsum("ij,jk->ik")(a,b)) # Tensor([[19. 22.], [43. 50.]], name=einsum ij,jk->ik)
+   ```
+
+For example:
+
+#### Summing along an axis
+
+https://github.com/bclarkson-code/Tricycle/assets/57139598/c575c958-19ed-4406-8a1b-d2390663ba96
+
+#### Sum of an entire tensor
+
+https://github.com/bclarkson-code/Tricycle/assets/57139598/efbb5eaa-656c-40e5-a32d-b0f5e7bd28f5
+
+#### Transpose
+
+https://github.com/bclarkson-code/Tricycle/assets/57139598/f8b35a6b-f102-44f1-a7cd-b6b2e765f275
+
+#### Matrix multiplication
+
+https://github.com/bclarkson-code/Tricycle/assets/57139598/1ed18428-11de-4990-a0f4-12d1310d6898
+
+Because every `Op` in Tricycle needs a derivative, we need to figure out what the
+derivative of `Einsum` is. Thankfully, if you sit down and go through the
+maths (index notation is really helpful here) you'll find that you can follow
+these two simple rules to differentiate an einsum operation wrt a
+given input:
+
+ - Swap the indices for the input and output
+ - Replace the original input with your current derivative
+
+For example, the derivative of a transpose works like this:
+
+```python
+# forward operation
+y = Einsum('ij->ji')(a)
+
+# swap the input with the current grad (a grid of ones in this case)
+grad = to_tensor(np.ones_like(y))
+
+# swap the indices
+derivative = Einsum('ji->ij')(grad)
+```
+
+And for a more complex operation (a dense layer on a 4d input), it works like this:
+
+```python
+# forward operation
+inputs = to_tensor(np.random.random((5, 4, 3, 2)))
+weights = to_tensor(np.random.random((2,6)))
+y = Einsum('zxTb,bW->zxTW')(inputs, weights)
+
+grad = to_tensor(np.ones_like(y))
+
+# swap the indices + replace inputs
+derivative = Einsum('zxTb,zxTW->bW')(inputs, grad)
+```
+
+This little trick significantly simplifies the code, as well as reducing the
+amount of maths I had to do to implement different operations.
+
+### Building a simple neural network
+
+Einsum and an automatic differentiation engine are all we need to build a simple neural network. Let's try to train a model on the [iris dataset](https://scikit-learn.org/stable/auto_examples/datasets/plot_iris_dataset.html).
+We can start with a [`Dense` layer](https://github.com/bclarkson-code/Tricycle/blob/main/src/tricycle/layers.py#L34).
+
+```python
+from tricycle.layers import Dense
+
+x = to_tensor([1,2,3])
+layer = Dense(from_size=3, to_size=1)
+
+print(layer(x)) # Output: Tensor([-2.238703], name=dense)
+```
+
+Next, neural networks need a nonlinearity (otherwise they reduce to expensive linear regressions).
+Tricycle has a few [nonlinearities](https://github.com/bclarkson-code/Tricycle/blob/update-readme/src/tricycle/activation.py) (also called activation functions). Here we can choose the simplest: `ReLU`.
+
+```python
+from tricycle.activation import ReLU
+
+x = to_tensor([-1, 0, 1])
+activation_fn = ReLU()
+
+print(activation_fn(x)) # Output: Tensor([0. 0. 1.], name=> 0)
+```
+
+We also need a loss function. We're predicting a category, so we can use `CrossEntropy`:
+
+```python
+from tricycle.loss import CrossEntropy
+
+label = to_tensor([0, 1, 2], dtype=int)
+predicted = to_tensor([[0,0,1], [0,0,1], [0,0,1]])
+loss = CrossEntropy()
+
+print(loss(label, predicted)) # Output: Tensor(1.2181114, name=cross_entropy)
+```
+
+Finally, we need an optimiser to update our weights.
We can use [Stochastic Gradient Descent](https://github.com/bclarkson-code/Tricycle/blob/main/src/tricycle/optimsers.py#L14).
+In Tricycle, you can use an optimiser to update the weights of a model as follows:
+
+```python
+from tricycle.activation import ReLU
+from tricycle.layers import Dense, Sequential
+from tricycle.optimisers import StochasticGradientDescent
+
+# build a model
+layer_1 = Dense(4, 16)
+layer_2 = Dense(16, 3)
+relu = ReLU()
+model = Sequential(layer_1, relu, layer_2)
+
+# create an optimiser
+optimiser = StochasticGradientDescent(learning_rate=1e-1)
+
+# do a forward and backward pass
+x = to_tensor([1,2,3,4])
+out = model(x)
+out.backward()
+
+# update the weights
+model.update(optimiser)
+```
+
+We can put all of this together to train a simple neural network on the iris
+dataset.
+
+```python
+import numpy as np
+from sklearn.datasets import load_iris
+
+from tricycle.activation import ReLU
+from tricycle.tensor import to_tensor
+from tricycle.layers import Dense, Sequential
+from tricycle.loss import CrossEntropy
+from tricycle.optimisers import StochasticGradientDescent
+
+LEARNING_RATE = 1e-1
+N_STEPS = 1000
+
+np.random.seed(42)
+X, y = load_iris(return_X_y=True)
+inputs = to_tensor(X, is_batched=True)
+
+# The class labels need to be ints for crossentropy
+outputs = to_tensor(y, is_batched=True, dtype=int)
+
+# create a model
+layer_1 = Dense(4, 16)
+layer_2 = Dense(16, 3)
+relu = ReLU()
+model = Sequential(layer_1, relu, layer_2)
+
+loss_fn = CrossEntropy()
+optimiser = StochasticGradientDescent(learning_rate=LEARNING_RATE)
+
+for step in range(N_STEPS):
+    y_pred = model(inputs)
+    loss = loss_fn(outputs, y_pred)
+    if step == 0:
+        print(f"Initial loss: {loss}") # Output: Initial loss: Tensor(3.974701, name=cross_entropy)
+
+    loss.backward()
+    model.update(optimiser)
+
+print(f"Final loss: {loss}") # Output: Final loss: Tensor(0.08622341, name=cross_entropy)
+
+# Calculate accuracy
+predicted_labels = np.argmax(y_pred.array, axis=-1)
+accuracy = (predicted_labels == outputs.array).mean()
+print(f"Accuracy: {accuracy:.2f}") # Output: Accuracy: 0.97
+```
+
+### Optimisations
+
+Deep learning is famously computationally heavy. If we want to train anything
+in a reasonable amount of time, there are several optimisations we need to make.
+
+#### Batching
+The first, and arguably most important, optimisation is batching. Instead of
+applying operations to each input individually, if we are clever about how we design
+an operation, we can apply it to multiple inputs at once.
+
+For example, suppose we are multiplying a batch of tensors by a weight matrix.
+We could do it like this:
+
+```python
+# batch of 1024 64x64 tensors
+inputs = to_tensor(np.ones((1024, 64, 64)))
+weights = to_tensor(np.random.random((64,64)))
+
+output = [Einsum('ij,jk->ik')(inp, weights) for inp in inputs]
+# 62.2 ms ± 186 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)
+```
+
+But we can use the properties of `Einsum` to do the same thing like this:
+
+```python
+output = Einsum('aij,jk->aik')(inputs, weights)
+# 29.1 ms ± 99.2 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)
+```
+
+Which is more than 2x faster.
+
+Some `Op`s in tricycle behave slightly differently, depending on
+whether a tensor is batched or not. You can tell tricycle to use the batched
+version of `Op`s for a tensor by simply calling `.to_batched`. To convert it
+back, you can call `.from_batched`.
+
+#### GPU
+As well as batching, another improvement that has a big impact on performance
+is using a GPU.
For this, we can use a library called [CuPy](https://cupy.dev/).
+CuPy lets you run numpy code on a GPU. This means that we can use the same code
+for CPU as well as GPU computation, which greatly simplifies the codebase
+(and avoids me needing to write CUDA kernels, for now).
+
+Every tensor in tricycle has an `.xp` attribute. By default, this is just the
+numpy library:
+
+```python
+import numpy as np
+
+tensor = to_tensor([1,2,3])
+
+assert tensor.xp == np
+```
+
+But if you call `.to_gpu` on a tensor, `.xp` becomes the CuPy library:
+
+```python
+import cupy as cp
+
+tensor = to_tensor([1,2,3])
+
+tensor.to_gpu()
+
+assert tensor.xp == cp
+```
+
+(`xp` stands for `np` or `cp` because x is an "unknown"). This is really handy
+because it lets us write functions like this:
+
+```python
+def forward(self, tensor: Tensor):
+    """
+    Apply softmax. The softmax is only applied to the final
+    dimension of the tensor
+    Note: the tensor is normalised for numeric stability
+    """
+    xp = tensor.xp
+
+    exp = xp.exp(
+        # subtract the largest value for numeric stability
+        tensor.array - xp.max(tensor.array, axis=-1, keepdims=True)
+    )
+    denominator = xp.sum(exp, axis=-1, keepdims=True)
+    self._out = exp / denominator
+
+    result = to_tensor(self._out)
+    result.args = (tensor,)
+    result.name = "softmax"
+    result.is_batched = tensor.is_batched
+    result.back_fns = (self.backward,)
+
+    return result
+```
+
+Because CuPy has the same interface as numpy, this function will automatically
+run on the right device, with no code changes.
+
+#### Fusing
+
+One of the problems I faced when trying to use Tricycle is that it used up
+a lot more memory than I expected. Because the `args` and `back_fns` need to
+be stored for every `Op`, a lot of memory was being used to store intermediate
+values.
+
+For more complex operations like `Softmax`, this quickly adds up. We could have
+built `Softmax` (see above) entirely out of low level Tricycle operations, and
+this does work, but we can avoid a lot of the overhead by pre-computing the
+combined derivative. When you sit down and work out the derivative for softmax
+manually, it turns out to be pretty simple:
+
+```python
+def backward(self, grad: Tensor) -> Tensor:
+    xp = grad.xp
+
+    inner = xp.sum(grad.array * self._out, axis=-1, keepdims=True)
+    self._grad = self._out * (grad.array - inner)
+    return to_tensor(
+        self._grad,
+        is_batched=grad.is_batched,
+        requires_grad=grad.requires_grad,
+    )
+```
+
+This is a very common optimisation technique in deep learning
+called 'Operator Fusing'. It ends up being a big optimisation for tricycle
+because it lets us replace operations like `MultiHeadSelfAttention`, which
+would usually have tens of intermediate values, with a single `forward` and
+`backward` function with a minimal set of intermediate values.
+
+#### Other optimisations
+While batching, using a GPU and fusing are the major optimisations, I'd like
+to provide some honorable mentions.
+
+##### Inplace tensor updates
+While probably obvious to many readers, updating tensors in-place rather than
+replacing them with a new tensor caused a big speedup.
+
+##### Mathematical optimisations
+Operations like `CrossEntropy` can be implemented by applying a softmax and then
+applying the cross-entropy operation but, if you do a bit of algebra,
+you can use the `log-sum-exp` trick to simplify the expression
+and cut down on the computation needed.
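+To make the trick concrete, here is a rough numpy sketch of the algebra (an illustration only, not Tricycle's actual `CrossEntropy` code). The cross-entropy of a softmax is `-log(softmax(x)[label])`, which simplifies to `logsumexp(x) - x[label]`, so the softmax never needs to be materialised:
+
+```python
+import numpy as np
+
+
+def naive_cross_entropy(logits, label):
+    # build the full softmax, then take -log of a single entry
+    exp = np.exp(logits - logits.max())
+    probs = exp / exp.sum()
+    return -np.log(probs[label])
+
+
+def logsumexp_cross_entropy(logits, label):
+    # -log(softmax(x)[label]) == logsumexp(x) - x[label], so we can
+    # skip the divide and an extra exp/log round trip entirely
+    shifted = logits - logits.max()  # subtract the max for numeric stability
+    return np.log(np.exp(shifted).sum()) - shifted[label]
+
+
+logits = np.array([2.0, 1.0, 0.1])
+print(naive_cross_entropy(logits, 0))      # 0.41703...
+print(logsumexp_cross_entropy(logits, 0))  # 0.41703... (same value, fewer ops)
+```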
-This will fetch the complete works of shakespeare, build it into a dataset, tokenise it, and train a simple GPT on it.
+##### Hardware optimisations
+As mentioned above, the GPU computation was performed on an NVIDIA RTX 3090.
+Understandably, this gets quite hot when training (probably something to do with
+it being in my cupboard?) which can reduce performance due to thermal
+throttling. However, I found that by removing my computer case and placing
+a household fan on top, I get about 30% better performance.
 
-As you can see, it looks pretty similar to other frameworks like PyTorch. However, because Tricycle is much smaller and simpler, if you want to figure out how something works, you can dive into the code and get an answer in a few minutes instead of hours.
+![IMG_0713](https://github.com/bclarkson-code/Tricycle/assets/57139598/958f12b4-caaa-4f2a-b9d0-2f5a7fc1e5a5)
 
-For a proper training script with all the bells and whistles (logging, gradient accumulation etc.) take a look at `train_smol_gpt.py` which will train a transformer to produce infinite shakespeare in ~35 minutes (on my machine, with an RTX 3090).
+Putting all of these things together, Tricycle can train a small language model on Shakespeare in ~30 mins. Andrej Karpathy can [do this in pytorch](https://github.com/karpathy/nanoGPT/tree/master) in around 7 minutes on my machine (with a like-for-like config) which, given that the entire Tricycle project is in python, means that Tricycle is surprisingly fast. That said, more work is needed to close the gap.
 
 ## Contact
-Want to learn more / have a chat / work together?
-You can send an email to: [bclarkson-code@proton.me](mailto:bclarkson-code@proton.me)
+Want to work together? You can reach me at: [bclarkson-code@proton.me](mailto:bclarkson-code@proton.me)
diff --git a/assets/EinsumIJKToIK.mp4 b/assets/EinsumIJKToIK.mp4
new file mode 100644
index 0000000..5424d16
Binary files /dev/null and b/assets/EinsumIJKToIK.mp4 differ
diff --git a/assets/EinsumIJTo.mp4 b/assets/EinsumIJTo.mp4
new file mode 100644
index 0000000..0c67f59
Binary files /dev/null and b/assets/EinsumIJTo.mp4 differ
diff --git a/assets/EinsumIJToJ.mp4 b/assets/EinsumIJToJ.mp4
new file mode 100644
index 0000000..e7e9dce
Binary files /dev/null and b/assets/EinsumIJToJ.mp4 differ
diff --git a/assets/EinsumIJToJI.mp4 b/assets/EinsumIJToJI.mp4
new file mode 100644
index 0000000..5db64cb
Binary files /dev/null and b/assets/EinsumIJToJI.mp4 differ
diff --git a/experiment.py b/experiment.py
deleted file mode 100644
index a03d087..0000000
--- a/experiment.py
+++ /dev/null
@@ -1,244 +0,0 @@
-"""
-Training script for training a SmolGPT model on the complete
-works of shakespeare.
- -The hyperparams for this model are very much a work in progress -""" - -import os -import pickle -from pathlib import Path - -import mlflow -import numpy as np -from omegaconf import OmegaConf -from ray import train, tune -from tqdm import tqdm - -from inference import get_sample -from tricycle.binary import _shapes_match -from tricycle.dataset import CausalLMDataset -from tricycle.loss import cross_entropy -from tricycle.models import GPT -from tricycle.optimisers import StochasticGradientDescent -from tricycle.scheduler import lr_schedule -from tricycle.tokeniser import BPETokeniser -from tricycle_datasets.shakespeare import Shakespeare - -EXPERIMENT_NAME = "SmolGPT:base:find_lr_schedule" - -search_space = { - "model": { - "embedding_dim": 384, - "context_window": 256, - "vocab_size": 1024, - "n_heads": 6, - "n_layers": 6, - "expansion_ratio": 4, - "activation_fn": "gelu", - "input_dropout_prob": 0.2, - "attention_dropout_prob": 0.2, - "residual_dropout_prob": 0.2, - "linear_dropout_prob": 0.2, - "batch_size": 12, - }, - "train": { - "max_learning_rate": 1e-3, - "min_learning_rate": 1e-4, - "warmup_steps": 100, - "weight_decay": 0, - "momentum": 0, - "shuffle": True, - }, - "mlflow": { - "tracking_uri": "http://localhost:5000", - "experiment_name": EXPERIMENT_NAME, - }, - "experiment": { - "train_steps": 25_000, - "valid_steps": 5, - "valid_every": 25, - "num_trials": 1, - }, -} - - -def train_model(config): - np.random.seed(0) - config = OmegaConf.create(config) - - mlflow.set_tracking_uri(config.mlflow.tracking_uri) - mlflow.set_experiment(config.mlflow.experiment_name) - os.environ["MLFLOW_ENABLE_SYSTEM_METRICS_LOGGING"] = "true" - - model = GPT(config.model) - - current_dir = Path(__file__).parent.absolute() - raw_data_path = current_dir / "datasets/shakespeare/raw_data.txt" - tokeniser_path = current_dir / "datasets/shakespeare/tokeniser.pkl" - token_path = current_dir / "datasets/shakespeare/tokens_1024.pkl" - shakespeare = Shakespeare( - vocab_size=config.model.vocab_size, - raw_data_path=raw_data_path, - tokeniser_path=tokeniser_path, - token_path=token_path, - ) - - # train-test split - n_valid_tokens = ( - config.model.context_window - + config.experiment.valid_steps * config.model.batch_size - + 1 - ) - n_train_tokens = ( - config.model.context_window - + config.experiment.train_steps * config.model.batch_size - + 1 - ) - assert n_train_tokens + n_valid_tokens < len( - shakespeare - ), "Dataset too small" - train_dataset = ( - CausalLMDataset( - tokens=shakespeare[:n_train_tokens], - vocab_size=config.model.vocab_size, - batch_size=config.model.batch_size, - context_window=config.model.context_window, - ) - .batch() - .to_tensor() - .to_vector() - ) - if config.train.shuffle: - train_dataset = train_dataset.shuffle() - test_dataset = ( - CausalLMDataset( - tokens=shakespeare[-n_valid_tokens:], - vocab_size=config.model.vocab_size, - batch_size=config.model.batch_size, - context_window=config.model.context_window, - ) - .batch() - .to_tensor() - .to_vector() - ) - loss_fn = cross_entropy - optimiser = StochasticGradientDescent( - learning_rate=lr_schedule( - 0, - max_learning_rate=config.train.max_learning_rate, - min_learning_rate=config.train.min_learning_rate, - warmup_steps=config.train.warmup_steps, - total_steps=config.experiment.train_steps, - ), - weight_decay=config.train.weight_decay, - momentum=config.train.momentum, - ) - - model.to_gpu() - - with mlflow.start_run(): - for key, values in config.items(): - mlflow.log_params({f"{key}/{k}": v for k, v in 
values.items()}) - - for step, (inputs, outputs) in tqdm( - enumerate(train_dataset), total=config.experiment.train_steps - ): - inputs = inputs.to_gpu() - outputs = outputs.to_gpu() - - logits = model(inputs) - loss = loss_fn(outputs, logits).from_vector().mean().mean() - loss.backward() - model.update(optimiser) - - mlflow.log_metric("train_loss", float(loss.numpy()), step=step) - - # clean up the computational graph - loss.cleanup() - - # update the lr - optimiser.learning_rate = lr_schedule( - step, - max_learning_rate=config.train.max_learning_rate, - min_learning_rate=config.train.min_learning_rate, - warmup_steps=config.train.warmup_steps, - total_steps=config.experiment.train_steps, - ) - mlflow.log_metric( - "learning_rate", optimiser.learning_rate, step=step - ) - - # validation - if step % config.experiment.valid_every == 0: - valid_loss = 0 - for inputs, outputs in test_dataset: - logits = model(inputs) - try: - loss = ( - loss_fn(outputs, logits) - .from_vector() - .mean() - .mean() - ) - except Exception as e: - raise Exception( - inputs.shape, outputs.shape, logits.shape - ) - - valid_loss += float(loss.numpy()) - loss.cleanup() - valid_loss /= len(test_dataset) - - sample_text = "HAMLET: To be or not to be" - assert isinstance(shakespeare.tokeniser, BPETokeniser) - predicted = get_sample( - sample_text, model=model, tokeniser=shakespeare.tokeniser - ) - model.zero_grad() - - mlflow.log_metric("valid_loss", valid_loss, step=step) - mlflow.log_text(predicted, f"generated/{step}.txt") - train.report({"valid_loss": valid_loss}) - - # final loss - valid_loss = 0 - for inputs, outputs in test_dataset: - logits = model(inputs) - loss = loss_fn(outputs, logits).from_vector().mean().mean() - valid_loss += float(loss.numpy()) - loss.cleanup() - valid_loss /= len(test_dataset) - mlflow.log_metric("valid_loss", valid_loss, step=len(train_dataset)) - - # save the model - model_dir = Path( - f"/home/ben/Documents/Tricycle/results/{EXPERIMENT_NAME}/models" - ) - model_dir.mkdir(parents=True, exist_ok=True) - with open( - model_dir / f"lr_{config.train.learning_rate}.pkl", "wb" - ) as f: - pickle.dump(model, f) - - return {"valid_loss": valid_loss} - - -if __name__ == "__main__": - tuner = tune.Tuner( - tune.with_resources( - train_model, - {"gpu": 1, "cpu": 16}, - ), - tune_config=tune.TuneConfig( - metric="valid_loss", - num_samples=search_space["experiment"]["num_trials"], - ), - run_config=train.RunConfig( - storage_path=Path("results").absolute(), - name=EXPERIMENT_NAME, - ), - param_space=search_space, - ) - results = tuner.fit() - results.get_dataframe().to_csv(f"{EXPERIMENT_NAME}_results.csv") diff --git a/get_memory_usage.py b/get_memory_usage.py deleted file mode 100644 index eb12298..0000000 --- a/get_memory_usage.py +++ /dev/null @@ -1,19 +0,0 @@ -import humanize -import pandas as pd - -df = pd.read_csv("memory.log") -df["memory_diff"] = df["used_bytes"] - df["used_bytes"].shift() -df["time_diff"] = df["timestamp"] - df["timestamp"].shift() -df = df.dropna() - -print(df.iloc[:49]) -grouped = df.groupby("stage").agg({"memory_diff": "mean", "time_diff": "mean"}) -grouped = grouped.sort_values(by="memory_diff", ascending=False) -grouped["memory_diff"] = grouped["memory_diff"].apply(humanize.naturalsize) -grouped["time_diff"] = (grouped["time_diff"] * 1e6).astype(int) - - -print(grouped) - -# df["memory_diff_human"] = df["memory_diff"].apply(humanize.naturalsize) -# print(df[["stage", "memory_diff_human", "time_diff_μs"]]) diff --git a/inference.py b/inference.py index 
6aa963b..9b086f8 100644 --- a/inference.py +++ b/inference.py @@ -6,12 +6,12 @@ import tiktoken from tqdm import tqdm -from tricycle.configs import SmolGPTConfig +from tricycle.configs import ShakespeareConfig, SmolGPTConfig from tricycle.functions import Softmax from tricycle.layers import Dropout, Layer from tricycle.models import GPT from tricycle.tensor import to_tensor -from tricycle.tokeniser.tokeniser import BPETokeniser, BPETokeniserNumba +from tricycle.tokeniser import BPETokeniser from tricycle_datasets.codeparrot import CodeParrot from tricycle_datasets.shakespeare import Shakespeare @@ -48,7 +48,6 @@ def deactivate_dropout(model: Layer) -> Layer: # TODO: allow tokensiers that arent shakespeare def generate( model: GPT, - tokeniser: BPETokeniser | BPETokeniserNumba | tiktoken.core.Encoding, tokens: np.ndarray | None = None, sample=True, temperature=0.8, @@ -65,17 +64,17 @@ def generate( encoded = to_tensor( [tokens], dtype=int, requires_grad=False - ).to_vector() + ).to_batched() pred = model(encoded) pred = Softmax()(pred / temperature) if pred.on_gpu: probabilities = pred.xp.asnumpy( - pred._data[0][config.context_window - 1] + pred.array[0][config.context_window - 1] ) else: - probabilities = pred._data[0][config.context_window - 1] + probabilities = pred.array[0][config.context_window - 1] # sample according to probabilities if sample: @@ -90,7 +89,7 @@ def generate( def get_sample( model: GPT, - tokeniser: BPETokeniser | BPETokeniserNumba | tiktoken.core.Encoding, + tokeniser: BPETokeniser | tiktoken.core.Encoding, sample_tokens: np.ndarray | None = None, ) -> str: """ @@ -123,18 +122,20 @@ def get_sample( if __name__ == "__main__": np.random.seed(0) - config = SmolGPTConfig() + config = ShakespeareConfig() dataset = Shakespeare(config.vocab_size) - model = load_model(sys.argv[1]) - model.to_gpu(0) + import cupy + + with cupy.cuda.Device(1): + model = load_model(sys.argv[1]) + model.to_gpu(1) deactivate_dropout(model) sample_text = dataset.raw_data_path.read_text()[:2048] - for token in generate( - text=sample_text, model=model, tokeniser=dataset, sample=True - ): + sample_tokens = dataset.tokeniser.encode(sample_text) + for token in generate(tokens=sample_tokens, model=model, sample=True): token = int(token) token = dataset.decode([token]) print(token, end="", flush=True) diff --git a/environment.cpu_only.test.yml b/requirements/environment.cpu.test.yml similarity index 90% rename from environment.cpu_only.test.yml rename to requirements/environment.cpu.test.yml index fa302c3..5eafa8f 100644 --- a/environment.cpu_only.test.yml +++ b/requirements/environment.cpu.test.yml @@ -14,6 +14,8 @@ dependencies: - mlflow - psutil - numba + - tiktoken + - datasets # test dependencies - scikit-learn - pytest @@ -23,4 +25,4 @@ dependencies: # install tricycle - pip - pip: - - -e . + - -e ../ diff --git a/environment.cpu_only.yml b/requirements/environment.cpu.yml similarity index 81% rename from environment.cpu_only.yml rename to requirements/environment.cpu.yml index e1afae2..8634e1c 100644 --- a/environment.cpu_only.yml +++ b/requirements/environment.cpu.yml @@ -12,6 +12,8 @@ dependencies: - mlflow - psutil - numba + - tiktoken + - datasets - pip - pip: - - -e . 
+ - -e ../ diff --git a/environment.test.yml b/requirements/environment.test.yml similarity index 87% rename from environment.test.yml rename to requirements/environment.test.yml index 4d35151..bdbb011 100644 --- a/environment.test.yml +++ b/requirements/environment.test.yml @@ -14,7 +14,10 @@ dependencies: - mlflow - psutil - numba + - tiktoken + - datasets # gpu dependencies + - cuda-version==12 - cudnn - cutensor - nccl @@ -29,4 +32,4 @@ dependencies: # install tricycle - pip - pip: - - -e . + - -e ../ diff --git a/environment.yml b/requirements/environment.yml similarity index 82% rename from environment.yml rename to requirements/environment.yml index 6098305..8d7e5bc 100644 --- a/environment.yml +++ b/requirements/environment.yml @@ -13,7 +13,10 @@ dependencies: - mlflow - psutil - numba + - tiktoken + - datasets # gpu dependencies + - cuda-version==12 - cudnn - cutensor - nccl @@ -22,4 +25,4 @@ dependencies: # install tricycle - pip - pip: - - -e . + - -e ../ diff --git a/src/tricycle/activation.py b/src/tricycle/activation.py index 2dceab7..9639da7 100644 --- a/src/tricycle/activation.py +++ b/src/tricycle/activation.py @@ -51,11 +51,11 @@ def backward(self, grad: Tensor): left = xp.tanh(inner) cosh = xp.cosh(inner) right = coef / (cosh * cosh) - self._grad = 0.5 * (1 + left + right) * grad._data + self._grad = 0.5 * (1 + left + right) * grad.array result = to_tensor( self._grad, - is_vector=grad.is_vector, + is_batched=grad.is_batched, requires_grad=grad.requires_grad, ) result.name = "gelu_back" @@ -63,13 +63,13 @@ def backward(self, grad: Tensor): def forward(self, tensor: Tensor): xp = tensor.xp - self._input = tensor._data - inner = self.CONST_1 * (tensor._data + self.CONST_2 * tensor._data**3) - result = tensor._data * 0.5 * (1 + xp.tanh(inner)) + self._input = tensor.array + inner = self.CONST_1 * (tensor.array + self.CONST_2 * tensor.array**3) + result = tensor.array * 0.5 * (1 + xp.tanh(inner)) result = to_tensor( result, - is_vector=tensor.is_vector, + is_batched=tensor.is_batched, requires_grad=tensor.requires_grad, ) result.name = "gelu" @@ -135,7 +135,7 @@ def forward(self, x: Tensor): x = self.linear(x) # this is slow and terrible hack left, right = x.split(2) - if right.is_vector: + if right.is_batched: bias = self.bias.repeat(right.shape[1]) else: bias = self.bias.repeat(right.shape[0]) diff --git a/src/tricycle/attention.py b/src/tricycle/attention.py index bbd8076..fa2a0dc 100644 --- a/src/tricycle/attention.py +++ b/src/tricycle/attention.py @@ -40,7 +40,7 @@ def backward(self, grad: Tensor): xp = grad.xp in_shape = (self.batch_size, self.context_window, self.embedding_dim) - attention = grad._data + attention = grad.array # TODO: come up with a better name # smush @@ -92,7 +92,7 @@ def backward(self, grad: Tensor): def forward(self, tensor: Tensor): xp = tensor.xp - assert tensor.is_vector + assert tensor.is_batched # split the input into 3 peices self._input = tensor @@ -101,7 +101,7 @@ def forward(self, tensor: Tensor): value = tensor[:, :, self.embedding_dim * 2 :] # Figure out how big everything is - self.batch_size = key._data.shape[0] + self.batch_size = key.array.shape[0] self.head_size = self.embedding_dim // self.n_heads self.n_tokens = key.shape[-2] head_shape = ( @@ -113,9 +113,9 @@ def forward(self, tensor: Tensor): out_shape = (self.batch_size, self.n_tokens, self.embedding_dim) # reshape and reorder the heads - key = key._data - query = query._data - value = value._data + key = key.array + query = query.array + value = value.array key = 
key.reshape(head_shape) query = query.reshape(head_shape) @@ -149,7 +149,7 @@ def forward(self, tensor: Tensor): attention = xp.einsum("BNIj, BNjH -> BINH", attention, value) attention = attention.reshape(out_shape) - result = to_tensor(attention, is_vector=True) + result = to_tensor(attention, is_batched=True) result.back_fns = (self.backward,) result.args = (self._input,) return result diff --git a/src/tricycle/binary.py b/src/tricycle/binary.py index 0779966..3d47706 100644 --- a/src/tricycle/binary.py +++ b/src/tricycle/binary.py @@ -1,21 +1,40 @@ -from functools import partial +""" +In tricycle (because it makes the derivatives easier) we only allow operations +on two matrices if they are the same shape. We call these `binary` operations. +This file contains all of the binary operations in tricycle + +In deep learning, almost all of the time you can use an einsum operation to +handle what you want to do. This includes: + - Transposing + - Elementwise multiplication + - Matrix multiplication + - ... + +Interestingly, all of the operations here can be made out of clever +combinations of unary operations and einsums, (exercise for the reader?) +but it is a bit more efficient to give them their own, optimised `Op`s +""" from numpy.typing import ArrayLike from tricycle.ops import Einsum, Op from tricycle.tensor import Tensor, nothing, select_backend, to_tensor -from tricycle.unary import UnaryDivide, UnaryMultiply +from tricycle.unary import UnaryDivide def _shapes_match(tensor_1: Tensor, tensor_2: Tensor) -> bool: + """ + Binary operations can only be performed if the matrices are the same shape + This function checks that we are allowed to apply a binary Op. + """ # sourcery skip: assign-if-exp, merge-duplicate-blocks, remove-redundant-if - if tensor_1.is_vector and tensor_2.is_vector: + if tensor_1.is_batched and tensor_2.is_batched: shape_1 = tensor_1.shape shape_2 = tensor_2.shape - elif tensor_1.is_vector: + elif tensor_1.is_batched: shape_1 = tensor_1.shape[1:] shape_2 = tensor_2.shape - elif tensor_2.is_vector: + elif tensor_2.is_batched: shape_1 = tensor_1.shape shape_2 = tensor_2.shape[1:] else: @@ -24,7 +43,7 @@ def _shapes_match(tensor_1: Tensor, tensor_2: Tensor) -> bool: if shape_1 != shape_2: raise ValueError( - f"Shapes {shape_1} and {shape_2} do not match: {tensor_1._data.shape}, {tensor_2._data.shape}" + f"Shapes {shape_1} and {shape_2} do not match: {tensor_1.array.shape}, {tensor_2.array.shape}" ) return shape_1 == shape_2 @@ -34,45 +53,45 @@ def forward(self, tensor_1: Tensor, tensor_2: Tensor) -> Tensor: """ Applies the cosine function, elementwise, to a tensor """ - xp = select_backend(tensor_1._data, tensor_2._data) + xp = select_backend(tensor_1.array, tensor_2.array) assert _shapes_match(tensor_1, tensor_2) - self._out = xp.add(tensor_1._data, tensor_2._data) + self._out = xp.add(tensor_1.array, tensor_2.array) result = to_tensor(self._out) result.args = (tensor_1, tensor_2) result.back_fns = (nothing, nothing) result.name = "badd" - if tensor_1.is_vector or tensor_2.is_vector: - result.is_vector = True + if tensor_1.is_batched or tensor_2.is_batched: + result.is_batched = True return result class BinarySubtract(Op): def back_fn_2(self, grad: Tensor) -> Tensor: - self._grad = -grad._data + self._grad = -grad.array result = to_tensor(self._grad) - result.is_vector = grad.is_vector + result.is_batched = grad.is_batched return result def forward(self, tensor_1: Tensor, tensor_2: Tensor) -> Tensor: """ Subtract one tensor from another """ - xp = 
select_backend(tensor_1._data, tensor_2._data) + xp = select_backend(tensor_1.array, tensor_2.array) assert _shapes_match(tensor_1, tensor_2) - self._out = xp.subtract(tensor_1._data, tensor_2._data) + self._out = xp.subtract(tensor_1.array, tensor_2.array) result = to_tensor(self._out) result.args = (tensor_1, tensor_2) result.back_fns = (nothing, self.back_fn_2) result.name = "badd" - if tensor_1.is_vector or tensor_2.is_vector: - result.is_vector = True + if tensor_1.is_batched or tensor_2.is_batched: + result.is_batched = True return result @@ -109,15 +128,15 @@ class BinaryMax(Op): _is_bigger_2: ArrayLike | None def back_fn_1(self, grad: Tensor) -> Tensor: - self._grad_1 = grad._data * self._is_bigger_1 + self._grad_1 = grad.array * self._is_bigger_1 result = to_tensor(self._grad_1) - result.is_vector = grad.is_vector + result.is_batched = grad.is_batched return result def back_fn_2(self, grad: Tensor) -> Tensor: - self._grad_2 = grad._data * self._is_bigger_2 + self._grad_2 = grad.array * self._is_bigger_2 result = to_tensor(self._grad_2) - result.is_vector = grad.is_vector + result.is_batched = grad.is_batched return result def forward(self, tensor_1: Tensor, tensor_2: Tensor) -> Tensor: @@ -128,19 +147,19 @@ def forward(self, tensor_1: Tensor, tensor_2: Tensor) -> Tensor: The two tensors must have the same shape if elements are equal, return the first """ - xp = select_backend(tensor_1._data, tensor_2._data) + xp = select_backend(tensor_1.array, tensor_2.array) assert _shapes_match(tensor_1, tensor_2) - self._out = xp.maximum(tensor_1._data, tensor_2._data) + self._out = xp.maximum(tensor_1.array, tensor_2.array) - self._is_bigger_1 = tensor_1._data > tensor_2._data - self._is_bigger_2 = tensor_1._data <= tensor_2._data + self._is_bigger_1 = tensor_1.array > tensor_2.array + self._is_bigger_2 = tensor_1.array <= tensor_2.array result = to_tensor(self._out) result.args = (tensor_1, tensor_2) result.back_fns = (self.back_fn_1, self.back_fn_2) result.name = "bmax" - result.is_vector = tensor_1.is_vector or tensor_2.is_vector + result.is_batched = tensor_1.is_batched or tensor_2.is_batched return result @@ -149,15 +168,15 @@ class BinaryMin(Op): _is_smaller_2: Tensor | None def back_fn_1(self, grad: Tensor) -> Tensor: - self._grad_1 = grad._data * self._is_smaller_1 + self._grad_1 = grad.array * self._is_smaller_1 result = to_tensor(self._grad_1) - result.is_vector = grad.is_vector + result.is_batched = grad.is_batched return result def back_fn_2(self, grad: Tensor) -> Tensor: - self._grad_2 = grad._data * self._is_smaller_2 + self._grad_2 = grad.array * self._is_smaller_2 result = to_tensor(self._grad_2) - result.is_vector = grad.is_vector + result.is_batched = grad.is_batched return result def forward(self, tensor_1: Tensor, tensor_2: Tensor) -> Tensor: @@ -168,19 +187,19 @@ def forward(self, tensor_1: Tensor, tensor_2: Tensor) -> Tensor: The two tensors must have the same shape if elements are equal, return the first """ - xp = select_backend(tensor_1._data, tensor_2._data) + xp = select_backend(tensor_1.array, tensor_2.array) assert _shapes_match(tensor_1, tensor_2) - self._out = xp.minimum(tensor_1._data, tensor_2._data) + self._out = xp.minimum(tensor_1.array, tensor_2.array) - self._is_smaller_1 = tensor_1._data < tensor_2._data - self._is_smaller_2 = tensor_1._data >= tensor_2._data + self._is_smaller_1 = tensor_1.array < tensor_2.array + self._is_smaller_2 = tensor_1.array >= tensor_2.array result = to_tensor(self._out) result.args = (tensor_1, tensor_2) result.back_fns = 
(self.back_fn_1, self.back_fn_2) result.name = "bmax" - result.is_vector = tensor_1.is_vector or tensor_2.is_vector + result.is_batched = tensor_1.is_batched or tensor_2.is_batched return result @@ -188,11 +207,11 @@ class BinaryMask(Op): _mask: ArrayLike | None = None def back_fn(self, grad: Tensor) -> Tensor: - xp = select_backend(grad._data, self._mask) - self._grad = xp.where(self._mask, grad._data, 0) + xp = select_backend(grad.array, self._mask) + self._grad = xp.where(self._mask, grad.array, 0) result = to_tensor(self._grad) - result.is_vector = grad.is_vector + result.is_batched = grad.is_batched return result def forward(self, tensor: Tensor, mask: Tensor) -> Tensor: @@ -200,19 +219,19 @@ def forward(self, tensor: Tensor, mask: Tensor) -> Tensor: Apply a binary mask to a numpy array, setting values to 0 where the mask is True """ - xp = select_backend(tensor._data, mask._data) + xp = select_backend(tensor.array, mask.array) assert _shapes_match(tensor, mask) assert ( not mask.requires_grad ), "Cannot compute gradient of a binary mask" - self._out = xp.where(mask._data, tensor._data, 0) - self._mask = mask._data + self._out = xp.where(mask.array, tensor.array, 0) + self._mask = mask.array result = to_tensor(self._out) result.args = (tensor,) result.back_fns = (self.back_fn,) result.name = "bmask" - result.is_vector = tensor.is_vector + result.is_batched = tensor.is_batched return result diff --git a/src/tricycle/blocks.py b/src/tricycle/blocks.py index 840a5a0..bece37a 100644 --- a/src/tricycle/blocks.py +++ b/src/tricycle/blocks.py @@ -32,9 +32,9 @@ def masked_fill( Apply an attention_mask to a tensor """ xp = tensor.xp - repeats = tensor.shape[1] if tensor.is_vector else tensor.shape[0] + repeats = tensor.shape[1] if tensor.is_batched else tensor.shape[0] mask = xp.stack( - [full_mask[: mask_shape[0], : mask_shape[1]]._data] * repeats + [full_mask[: mask_shape[0], : mask_shape[1]].array] * repeats ) mask = to_tensor(mask, requires_grad=False, name="mask") result = tensor + mask @@ -56,7 +56,7 @@ def __init__( embedding_dim: int, n_heads: int, context_window: int, - residual_dropout_prob: float, + residual_dropout_prob: float = 0.0, initialiser=init_xavier, ): # set the constants diff --git a/src/tricycle/configs.py b/src/tricycle/configs.py index 46cc1c4..d4cee07 100644 --- a/src/tricycle/configs.py +++ b/src/tricycle/configs.py @@ -3,6 +3,10 @@ class GPTConfig: + """ + Base config for GPT models + """ + embedding_dim: int context_window: int vocab_size: int @@ -33,6 +37,51 @@ class GPTConfig: mlflow_tracking_uri: str mlflow_experiment_name: str + def dict(self) -> dict[str, int | float | str | bool]: + out = {} + for k, v in self.__class__.__dict__.items(): + if k.startswith("__"): + continue + + if callable(v): + continue + out[k] = v + return out + + +class ShakespeareConfig(GPTConfig): + embedding_dim = 384 + context_window = 256 + vocab_size = 1024 + n_heads = 6 + n_layers = 6 + expansion_ratio = 4 + activation_fn = "gelu" + + input_dropout_prob = 0.2 + residual_dropout_prob = 0.2 + linear_dropout_prob = 0.2 + + max_learning_rate = 1e-3 + min_learning_rate = 1e-4 + warmup_steps = 100 + weight_decay = 1e-1 + momentum = 0 + beta1 = 0.9 + beta2 = 0.99 + + steps = 5000 + eval_interval = 250 + eval_steps = 128 + batch_size = 32 + gradient_accumulation_steps = 1 + sample_size = 512 + + device_idx = 0 + + mlflow_enabled = True + mlflow_tracking_uri = "http://localhost:5000" + class SmolGPTConfig(GPTConfig): embedding_dim = 384 @@ -66,14 +115,3 @@ class SmolGPTConfig(GPTConfig): 
mlflow_enabled = True mlflow_tracking_uri = "http://localhost:5000" - - def dict(self) -> dict[str, int | float | str | bool]: - out = {} - for k, v in SmolGPTConfig.__dict__.items(): - if k.startswith("__"): - continue - - if callable(v): - continue - out[k] = v - return out diff --git a/src/tricycle/cuda/softmax.cu b/src/tricycle/cuda/softmax.cu deleted file mode 100644 index 4c73c3b..0000000 --- a/src/tricycle/cuda/softmax.cu +++ /dev/null @@ -1,171 +0,0 @@ -// Kernels for efficiently computing softmax derivatives -extern "C" __global__ void -softmax_back_fn_3d(const float *softmax_result, const float *grad, - const int n_batches, const int n_tokens, - const int n_elements, float *out) { - // find indices for batch and token - int i = blockDim.x * blockIdx.x + threadIdx.x; - - if (i < n_batches * n_tokens * n_tokens * n_elements) { - int batch_idx = i / (n_tokens * n_tokens * n_elements); - int head_idx = (i / (n_tokens * n_elements)) % n_tokens; - int token_idx = (i / n_elements) % n_tokens; - int element_idx = i % n_elements; - - int offset = batch_idx * n_tokens * n_elements * n_elements + - head_idx * n_tokens * n_elements + token_idx * n_elements; - - float *out_idx = out + offset; - const float *softmax_idx = softmax_result + offset; - const float *grad_idx = grad + offset; - - float result = 0.0; - for (int j = 0; j < n_elements; j++) { - float indicator = j == element_idx ? 1.0f : 0.0f; - float deriv = softmax_idx[element_idx] * (indicator - softmax_idx[j]); - result += deriv * grad_idx[element_idx]; - } - out_idx[element_idx] = result; - } - - /* - for (int i = 0; i < n_elements; i++) { - float result = 0.0; - - for (int j = 0; j < n_elements; j++) { - float indicator = i == j ? 1.0f : 0.0f; - float deriv = softmax_idx[i] * (indicator - softmax_idx[j]); - result += deriv * grad_idx[i]; - } - out_id[i] = __float2int_rn(i); - } - */ - out_id[0] = 1.0; -} -// __global__ void softmax_back_fn_2d(const float *softmax_result, -// const float *grad, const int n_elements, -// float *out) { -// int indicator, i, j, deriv; -// -// // index for vector -// int t = blockDim.x * blockIdx.x + threadIdx.x; -// -// // index for element in vector -// int tid = blockDim.y * blockIdx.y + threadIdx.y; -// i = tid / n_elements; -// j = tid % n_elements; -// -// if (i == j) { -// indicator = 1; -// } else { -// indicator = 0; -// } -// -// deriv = softmax_result[t, i] * (indicator - softmax_result[t, j]); -// out[t, j] = deriv * grad[t, i]; -// } -// -// __global__ void softmax_autoregressive_backward_kernel2( -// const float *grad, const float *softmax_result, int n_batches, int -// n_tokens, int n_elements, float *out, ) { -// int t3 = blockIdx.x * blockDim.x + threadIdx.x; -// int idx = blockIdx.y * n_tokens * n_tokens; -// if (t3 >= n_tokens) { -// return; -// } -// -// for (int t = t3; t < n_tokens; t++) { -// float result = 0.0; -// const float *softmax_result_bth = softmax_result + idx + t * n_tokens; -// const float *grad_bth = grad + idx + t * n_tokens; -// float *out_bth = out + idx + t * n_tokens; -// -// for (int t2 = 0; t2 <= t; t2++) { -// float indicator = t2 == t3 ? 
1.0f : 0.0f; -// float local_derivative = -// softmax_result_bth[t2] * (indicator - softmax_result_bth[t3]); -// result += local_derivative * grad_bth[t2]; -// } -// -// out_bth[t3] = result; -// } -// } -// __global__ void softmax_back_fn_3d(const float *softmax_result, -// const float *grad, -// -// const int n_tokens, const int n_elements, -// float *out) { -// int indicator, i, j, b, t, deriv; -// -// // find indices for batch and token -// int xid = blockDim.x * blockIdx.x + threadIdx.x; -// b = xid / n_tokens; -// t = xid % n_tokens; -// -// // index for element in vector -// int tid = blockDim.y * blockIdx.y + threadIdx.y; -// i = tid / n_elements; -// j = tid % n_elements; -// -// if (i == j) { -// indicator = 1; -// } else { -// indicator = 0; -// } -// -// deriv = softmax_result[b, t, i] * (indicator - softmax_result[b, t, j]); -// out[b, t, j] = deriv * grad[b, t, i]; -// } -// __global__ void softmax_back_fn_4d(const float *softmax_result, -// const float *grad, const int n_heads, -// const int n_tokens, const int n_elements, -// float *out) { -// int indicator, i, j, b, h, t, remainder, deriv; -// -// // find indices for batch and token -// int xid = blockDim.x * blockIdx.x + threadIdx.x; -// b = xid / (n_tokens * n_heads); -// remainder = xid % (n_tokens * n_heads); -// h = remainder / n_tokens; -// t = remainder % n_tokens; -// -// // index for element in vector -// int tid = blockDim.y * blockIdx.y + threadIdx.y; -// i = tid / n_elements; -// j = tid % n_elements; -// -// if (i == j) { -// indicator = 1; -// } else { -// indicator = 0; -// } -// -// deriv = softmax_result[b, h, t, i] * (indicator - softmax_result[b, h, t, -// j]); -// // sometimes this returns nans -// out[b, h, t, j] = deriv * grad[b, h, t, i]; -// } -// __global__ void softmax_back_fn_3d_a(const float *grad, -// const float *softmax_result, -// const int n_batches, const int n_tokens, -// const int n_elements, float *out) { -// int t3 = blockIdx.x * blockDim.x + threadIdx.x; -// int idx = blockIdx.y * n_tokens * n_tokens; -// if (t3 >= n_tokens) { -// return; -// } -// for (int t = t3; t < n_tokens; t++) { -// float result = 0.0; -// const float *softmax_result_bth = softmax_result + idx + t * n_tokens; -// const float *grad_bth = grad + idx + t * n_tokens; -// float *out_bth = out + idx + t * n_tokens; -// for (int t2 = 0; t2 <= t; t2++) { -// float indicator = t2 == t3 ? 
1.0f : 0.0f; -// float local_derivative = -// softmax_result_bth[t2] * (indicator - softmax_result_bth[t3]); -// result += local_derivative * grad_bth[t2]; -// } -// out_bth[t3] = result; -// } -// } -// } diff --git a/src/tricycle/dataset.py b/src/tricycle/dataset.py index 0f6ab4b..fe23249 100644 --- a/src/tricycle/dataset.py +++ b/src/tricycle/dataset.py @@ -57,7 +57,7 @@ def copy(self): class InfiniteBatchDataset(Dataset): is_infinite = True _to_tensor = False - is_vector = False + is_batched = True def __init__(self, inputs: Sequence, outputs: Sequence, batch_size: int): super().__init__(inputs, outputs) @@ -83,12 +83,12 @@ def __getitem__(self, idx: int): if self._to_tensor: batch_inputs = to_tensor( batch_inputs, - is_vector=self.is_vector, + is_batched=self.is_batched, dtype=batch_outputs.dtype, ) batch_outputs = to_tensor( batch_outputs, - is_vector=self.is_vector, + is_batched=self.is_batched, dtype=batch_outputs.dtype, ) return batch_inputs, batch_outputs @@ -97,14 +97,6 @@ def to_tensor(self): self._to_tensor = True return self - def to_vector(self): - self.is_vector = True - return self - - def from_vector(self): - self.is_vector = False - return self - class CausalLMDataset: def __init__( @@ -156,14 +148,14 @@ def __getitem__(self, idx: int): inputs, requires_grad=False, name="inputs", - is_vector=self.is_batch, + is_batched=self.is_batch, dtype=np.uint32, ) outputs = to_tensor( outputs, requires_grad=False, name="output", - is_vector=self.is_batch, + is_batched=self.is_batch, dtype=np.uint32, ) if self.device is not None: diff --git a/src/tricycle/einsum.py b/src/tricycle/einsum.py index 9aa9286..116b89c 100644 --- a/src/tricycle/einsum.py +++ b/src/tricycle/einsum.py @@ -1,3 +1,40 @@ +""" +Einsum is a generalisation of a large number of matrix operations. + +You can use it by assigning each axis in your matrices a letter of the +alphabet (called an index). You can define the operation you want to perform +by simply listing the indices you want in your inputs and output, separated by +an arrow. + +For example, you can define the transpose of a 2d tensor as follows: + +>>> a = to_tensor([[1,2],[3,4]]) +>>> Einsum("ij->ji")(a) +Tensor([[1. 3.] + [2. 4.]], name=einsum ij->ji) + +Here, we use einsum to swap indices i and j: a transpose. + +There are only two rules to remember with einsum: + - If an index does not appear in the output, any inputs that contain it + will be summed along that axis: + + >>> Einsum("ij->i")(a) + Tensor([3. 7.], name=einsum ij->i) + + - If an index appears in more than one input, the tensors will be multiplied + along that axis + + >>> b = to_tensor([[5,6],[7,8]) + >>> Einsum("ij,jk->ik")(a,b) + Tensor([[19. 22.] + [43. 50.]], name=einsum ij,jk->ik) + + + +You can use einsum to perform all of these operations: +""" + import itertools import re from typing import Sequence @@ -52,6 +89,12 @@ def __str__(self): class EinsumBackOp: + """ + The backward operation for an einsum operation. This is done by + swapping the indices and tensors for an input with the output. 
+ E.g "ij,jk->ik" with idx = 0 would become "ik,jk->ij" + """ + def __init__( self, idx: int, tensors: Sequence[Tensor], subscript: Subscript ): @@ -119,9 +162,9 @@ def _build_back_ops(self, tensors: Sequence[Tensor], subscript: Subscript): """ assert len(tensors) == len(subscript.inputs) - # To avoid adding a bunch of special cases for vectorised - # operations, we replace any vectorised operations with - # their non-vectorised counterparts + # To avoid adding a bunch of special cases for batched + # operations, we replace any batched operations with + # their non-batched counterparts subscript = Subscript(subscript.subscript.replace("z", "")) back_functions = [] @@ -156,7 +199,7 @@ def _handle_single_tensor( [tensor] = tensors ones = to_tensor( xp.ones(tensor.shape), - is_vector=tensor.is_vector, + is_batched=tensor.is_batched, requires_grad=False, ) tensors = [tensor, ones] @@ -168,33 +211,33 @@ def _handle_single_tensor( return subscript, tensors - def _handle_vectorised( + def _handle_batched( self, subscript: Subscript, tensors: Sequence[Tensor] ) -> tuple[Subscript, Sequence[Tensor], bool]: """ - If a tensor is labelled as being vectorised, add an extra dimension + If a tensor is labelled as being batched, add an extra dimension to its indices. """ inputs = [] - vectorise_output = False + batch_output = False for idx, tensor in zip(subscript.inputs, tensors): - if tensor.is_vector: + if tensor.is_batched: inputs.append(["z"] + idx) - vectorise_output = True + batch_output = True else: inputs.append(idx) output = subscript.output - if vectorise_output: + if batch_output: if "z" in subscript.subscript: raise ValueError( "`z` cannot be used in an einsum subscript on " - "non-vectorised tensors because " - "it is reserved for vectorised indices." + "non-batched tensors because " + "it is reserved for batched indices." 
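
The index-swapping rule in the `EinsumBackOp` docstring above can also be verified numerically: for `C = einsum("ij,jk->ik", A, B)`, swapping input 0's indices with the output's gives `einsum("ik,jk->ij", grad, B)`, which should equal the familiar `grad @ B.T`. A numpy sketch (variable names are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.normal(size=(3, 4))
B = rng.normal(size=(4, 5))
grad = np.ones((3, 5))  # upstream gradient, e.g. from sum(C)

# "ij,jk->ik" with idx = 0 becomes "ik,jk->ij": grad takes A's place.
dA = np.einsum("ik,jk->ij", grad, B)

# For matrix multiplication, dL/dA = grad @ B.T, so these must agree.
assert np.allclose(dA, grad @ B.T)
```
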
) output = ["z"] + output subscript = Subscript.from_split(inputs, output) - return subscript, tensors, vectorise_output + return subscript, tensors, batch_output def _replace_infinity(self, tensors: Sequence[Tensor]): """ @@ -204,12 +247,13 @@ def _replace_infinity(self, tensors: Sequence[Tensor]): xp = select_backend(*tensors) processed = [] for tensor in tensors: - if not xp.isinf(tensor._data).any(): + if not xp.isinf(tensor.array).any(): processed.append(tensor) continue new_tensor = to_tensor( - xp.nan_to_num(tensor._data), is_vector=tensor.is_vector + xp.nan_to_num(tensor.array), + is_batched=tensor.is_batched, ) new_tensor.args = tensor.args new_tensor.back_fns = tensor.back_fns @@ -220,14 +264,14 @@ def _replace_infinity(self, tensors: Sequence[Tensor]): def __call__(self, *tensors: Tensor): xp = select_backend(*tensors) - subscript, tensors, vectorise_output = self._handle_vectorised( + subscript, tensors, batch_output = self._handle_batched( self.subscript, tensors ) subscript, tensors = self._handle_single_tensor(subscript, tensors) - tensor_data = [t._data for t in tensors] + tensor_data = [t.array for t in tensors] result = to_tensor(xp.einsum(str(subscript), *tensor_data)) - if vectorise_output: - result.is_vector = True + if batch_output: + result.is_batched = True result.args = tuple(tensors) result.back_fns = tuple(self._build_back_ops(tensors, subscript)) diff --git a/src/tricycle/functions.py b/src/tricycle/functions.py index a534f38..7fe8e39 100644 --- a/src/tricycle/functions.py +++ b/src/tricycle/functions.py @@ -8,11 +8,11 @@ class Softmax(Op): def backward(self, grad: Tensor) -> Tensor: xp = grad.xp - inner = xp.sum(grad._data * self._out, axis=-1, keepdims=True) - self._grad = self._out * (grad._data - inner) + inner = xp.sum(grad.array * self._out, axis=-1, keepdims=True) + self._grad = self._out * (grad.array - inner) return to_tensor( self._grad, - is_vector=grad.is_vector, + is_batched=grad.is_batched, requires_grad=grad.requires_grad, ) @@ -25,7 +25,9 @@ def forward(self, tensor: Tensor): xp = tensor.xp exp = xp.exp( - tensor._data - xp.max(tensor._data, axis=-1, keepdims=True) + # subtract the largest value for numeric stability + tensor.array + - xp.max(tensor.array, axis=-1, keepdims=True) ) denominator = xp.sum(exp, axis=-1, keepdims=True) self._out = exp / denominator @@ -33,7 +35,7 @@ def forward(self, tensor: Tensor): result = to_tensor(self._out) result.args = (tensor,) result.name = "softmax" - result.is_vector = tensor.is_vector + result.is_batched = tensor.is_batched result.back_fns = (self.backward,) return result diff --git a/src/tricycle/layers.py b/src/tricycle/layers.py index 429e165..5562b8a 100644 --- a/src/tricycle/layers.py +++ b/src/tricycle/layers.py @@ -50,24 +50,24 @@ def __init__( def weight_back_fn(self, grad: Tensor): xp = grad.xp - result = xp.einsum(self._weight_subscript, self._input, grad._data) + result = xp.einsum(self._weight_subscript, self._input, grad.array) return to_tensor( result, requires_grad=grad.requires_grad, name="back_dense", - is_vector=False, + is_batched=False, ) def grad_back_fn(self, grad: Tensor): xp = grad.xp result = xp.einsum( - self._grad_subscript, self.weights._data, grad._data + self._grad_subscript, self.weights.array, grad.array ) return to_tensor( result, requires_grad=grad.requires_grad, name="back_dense", - is_vector=True, + is_batched=True, ) def forward(self, tensor: Tensor): @@ -91,24 +91,24 @@ def forward(self, tensor: Tensor): case _: raise NotImplementedError( f"Cannot pass tensor with 
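
For context on the `Softmax` op in `functions.py` above: subtracting the row maximum before exponentiating avoids overflow without changing the output, and the backward pass uses the closed-form Jacobian-vector product rather than materialising the full Jacobian. A standalone numpy sketch of the same two formulas (a sketch, not the library's code):

```python
import numpy as np

def softmax(x):
    # exp(x - max) == exp(x) / exp(max): same result, no overflow
    exp = np.exp(x - x.max(axis=-1, keepdims=True))
    return exp / exp.sum(axis=-1, keepdims=True)

def softmax_backward(out, grad):
    # dL/dx = s * (g - sum(g * s)), matching Softmax.backward
    inner = (grad * out).sum(axis=-1, keepdims=True)
    return out * (grad - inner)

x = np.array([[1000.0, 1001.0, 1002.0]])  # would overflow a naive exp()
s = softmax(x)
assert np.isfinite(s).all() and np.allclose(s.sum(axis=-1), 1.0)
```
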
shape {tensor.shape} " - f"and {tensor.is_vector=}" + f"and {tensor.is_batched=}" "through a Dense layer" ) result = to_tensor( tensor.xp.einsum( subscript, - tensor._data, - self.weights._data, + tensor.array, + self.weights.array, ) ) self._grad_subscript = grad_subscript self._weight_subscript = weight_subscript - self._input = tensor._data + self._input = tensor.array result.name = "dense" result.args = (self.weights, tensor) result.back_fns = (self.weight_back_fn, self.grad_back_fn) - result.is_vector = tensor.is_vector + result.is_batched = tensor.is_batched return result @@ -127,30 +127,6 @@ def from_gpu(self): return self -# class Dropout(Layer): -# def __init__(self, probability: float): -# self.probability = probability -# -# def backward(self, grad: Tensor): -# return to_tensor(self._mask * grad._data, is_vector=grad.is_vector) -# -# def forward(self, tensor: Tensor): -# if self.probability == 0: -# return tensor -# xp = tensor.xp -# coef = 1 / (1 - self.probability) -# -# self._mask = (xp.random.rand(*tensor.shape) > self.probability).astype( -# tensor.dtype -# ) * coef -# self._out = self._mask * tensor._data -# result = to_tensor(self._out, is_vector=tensor.is_vector) -# result.args = (tensor,) -# result.back_fns = (self.backward,) -# -# return result - - class Dropout(Layer): def __init__(self, probability: float): self.probability = probability @@ -164,7 +140,7 @@ def forward(self, tensor: Tensor): xp.random.rand(*tensor.shape) > self.probability ).astype(tensor.dtype) * coef random_mask = to_tensor( - random_mask, is_vector=True, requires_grad=False + random_mask, is_batched=True, requires_grad=False ) return BinaryMultiply()(tensor, random_mask) @@ -175,10 +151,10 @@ def __init__(self, embedding_dim: int, eps=1e-5): self.eps = eps self.gamma = to_tensor( - np.ones((embedding_dim,)), requires_grad=True, is_vector=False + np.ones((embedding_dim,)), requires_grad=True, is_batched=False ) self.beta = to_tensor( - np.zeros((embedding_dim,)), requires_grad=True, is_vector=False + np.zeros((embedding_dim,)), requires_grad=True, is_batched=False ) def forward(self, tensor: Tensor): @@ -192,7 +168,7 @@ def forward(self, tensor: Tensor): numpy.ndarray: Normalized tensor of the same shape as x. 
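
One detail of the rewritten `Dropout` layer worth calling out: it implements "inverted" dropout, scaling the surviving activations by `1/(1 - p)` at train time so the expected value of each activation is unchanged and no rescaling is needed at inference. A small numpy sketch of the mask construction (shapes and seed are illustrative):

```python
import numpy as np

def dropout_mask(shape, probability, rng):
    # Zero each element with `probability`, and scale the survivors
    # so that the expected value of (mask * x) equals x.
    coef = 1.0 / (1.0 - probability)
    return (rng.random(shape) > probability).astype(np.float32) * coef

rng = np.random.default_rng(0)
x = np.ones((1000, 100), dtype=np.float32)
dropped = x * dropout_mask(x.shape, probability=0.3, rng=rng)

assert abs(dropped.mean() - 1.0) < 0.05  # expectation is preserved
```
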
""" xp = tensor.xp - x = tensor._data + x = tensor.array # Compute mean and variance along the feature dimension self._mean = x.mean(axis=-1, keepdims=True) @@ -201,11 +177,11 @@ def forward(self, tensor: Tensor): # Normalize and scale x_norm = (x - self._mean) / xp.sqrt(self._var + self.eps) - output = self.gamma._data * x_norm + self.beta._data + output = self.gamma.array * x_norm + self.beta.array output = to_tensor( output, - is_vector=tensor.is_vector, + is_batched=tensor.is_batched, requires_grad=tensor.requires_grad, ) output.back_fns = (self.back_fn, self.beta_back_fn, self.gamma_back_fn) @@ -223,8 +199,8 @@ def gamma_back_fn(self, grad: Tensor): # Compute intermediate values x_norm = (self._input - self._mean) / xp.sqrt(self._var + self.eps) axes = tuple(range(grad.ndim - 1)) - result = xp.sum(grad._data * x_norm, axis=axes) - return to_tensor(result, is_vector=False) + result = xp.sum(grad.array * x_norm, axis=axes) + return to_tensor(result, is_batched=False) def beta_back_fn(self, grad: Tensor): """ @@ -234,8 +210,8 @@ def beta_back_fn(self, grad: Tensor): # Compute intermediate values axes = tuple(range(grad.ndim - 1)) - result = xp.sum(grad._data, axis=axes) - return to_tensor(result, is_vector=False) + result = xp.sum(grad.array, axis=axes) + return to_tensor(result, is_batched=False) def back_fn(self, grad: Tensor): """ @@ -247,7 +223,7 @@ def back_fn(self, grad: Tensor): n = self._input.shape[-1] # Gradients with respect to x - dx_norm = grad._data * self.gamma._data + dx_norm = grad.array * self.gamma.array dvar = xp.sum( dx_norm * (self._input - self._mean) @@ -271,7 +247,7 @@ def back_fn(self, grad: Tensor): return to_tensor( result, - is_vector=grad.is_vector, + is_batched=grad.is_batched, requires_grad=grad.requires_grad, name="back_ln", ) @@ -315,11 +291,11 @@ def build_back_fn(self, rms, input_): def rmsnorm_weight_back_fn(grad): xp = grad.xp result = xp.sum(input_ / rms, axis=-2).sum(0).squeeze() - return to_tensor(result, is_vector=False) + return to_tensor(result, is_batched=False) def rmsnorm_back_fn(grad): xp = grad.xp - scaled_grad = xp.multiply(grad._data, self.weights._data) + scaled_grad = xp.multiply(grad.array, self.weights.array) left = scaled_grad / rms @@ -337,22 +313,24 @@ def rmsnorm_back_fn(grad): f"RMSNorm with tensors of size {input_.ndim} are not yet supported" ) right = square_prod * coef - return to_tensor(left - right, is_vector=grad.is_vector) + return to_tensor(left - right, is_batched=grad.is_batched) return rmsnorm_weight_back_fn, rmsnorm_back_fn def forward(self, tensor: Tensor): xp = tensor.xp - square_sum = (tensor._data * tensor._data).mean(axis=-1) + square_sum = (tensor.array * tensor.array).mean(axis=-1) rms = xp.sqrt(square_sum) rms = xp.expand_dims(rms, -1) - result = xp.divide(tensor._data, (rms + self.REALLY_SMALL_NUMBER)) - result = xp.einsum("...a,a->...a", result, self.weights._data) + result = xp.divide(tensor.array, (rms + self.REALLY_SMALL_NUMBER)) + result = xp.einsum("...a,a->...a", result, self.weights.array) weight_back_fn, back_fn = self.build_back_fn( - rms=rms, input_=tensor._data, is_vector=tensor.is_vector + rms=rms, input_=tensor.array, is_batched=tensor.is_batched + ) + result = to_tensor( + result, is_batched=tensor.is_batched, name="rmsnorm" ) - result = to_tensor(result, is_vector=tensor.is_vector, name="rmsnorm") result.back_fns = ( weight_back_fn, back_fn, @@ -403,9 +381,9 @@ def back_fn(self, grad: Tensor): match grad.ndim - self.input.ndim: case 1: - xp.add.at(out, self.input._data, grad._data) + 
xp.add.at(out, self.input.array, grad.array) case 2: - xp.add.at(out, self.input._data, grad._data.sum(axis=0)) + xp.add.at(out, self.input.array, grad.array.sum(axis=0)) case _: raise NotImplementedError( f"{grad.ndim=}, {self.input.ndim=} are not supported" @@ -419,13 +397,13 @@ def forward(self, tensor: Tensor): ), "Cannot embed a differentiable tensor" self.input = tensor - if tensor.is_vector: - self._out = self.weights._data[tensor._data.flatten()].reshape( - tensor._data.shape + (-1,) + if tensor.is_batched: + self._out = self.weights.array[tensor.array.flatten()].reshape( + tensor.array.shape + (-1,) ) else: - self._out = self.weights._data[tensor._data] - result = to_tensor(self._out, is_vector=tensor.is_vector) + self._out = self.weights.array[tensor.array] + result = to_tensor(self._out, is_batched=tensor.is_batched) result.args = (tensor, self.weights) diff --git a/src/tricycle/loss.py b/src/tricycle/loss.py index 8835911..b0419a4 100644 --- a/src/tricycle/loss.py +++ b/src/tricycle/loss.py @@ -19,135 +19,6 @@ def mean_square_error(y_true: Tensor, y_pred: Tensor): return square_error.mean() -class CrossEntropy_(Op): - REALLY_SMALL_NUMBER = 1e-8 - REALLY_BIG_NUMBER = 1e8 - - def backward(self, grad: Tensor) -> Tensor: - xp = grad.xp - - self._grad = xp.where(self._y_true == 1, -1 / self._y_pred, 0) - self._grad *= xp.expand_dims(grad._data, -1) - return to_tensor(self._grad, is_vector=grad.is_vector) - - def forward(self, y_true: Tensor, y_pred: Tensor) -> Tensor: - # sourcery skip: assign-if-exp, reintroduce-else - """ - Calculate the cross entropy loss - """ - # normalise - # TODO: fuse normalising and calculation together - y_pred = Softmax()(y_pred) - - xp = y_pred.xp - - # clip for numeric stability - y_pred._data = y_pred._data.clip( - min=self.REALLY_SMALL_NUMBER, max=self.REALLY_BIG_NUMBER - ) - - # cache inputs for calculating the backwards operations later - self._y_true = y_true._data - self._y_pred = y_pred._data - - indicator = xp.where(y_true._data == 1, -xp.log(y_pred._data), 0) - - self._out = indicator.sum(axis=-1) - - result = to_tensor(self._out, is_vector=y_pred.is_vector) - result.back_fns = (self.backward,) - - # y_true never requires grad so we dont calculate gradients for it - result.args = (y_pred,) - result.name = "cross_entropy" - - return result - - -class BinaryCrossEntropy(Op): - """ - Calculate cross entropy loss, given logits and target indices (as opposed - to one-hot encoded tensors) - """ - - REALLY_SMALL_NUMBER = 1e-8 - REALLY_BIG_NUMBER = 1e8 - - def backward(self, grad: Tensor) -> Tensor: - xp = grad.xp - - match self._y_pred.ndim: - case 3: - out = xp.zeros_like(self._y_pred) - batch_indices = xp.arange(self._y_true.shape[0]) - token_indices = xp.arange(self._y_true.shape[1]) - for b in batch_indices: - out[b, token_indices, self._y_true[b]] = ( - -1 / self._y_pred[b, token_indices, self._y_true[b]] - ) * grad._data[b] - case 2: - indices = xp.arange(self._y_true.shape[0]) - out = -1 / self._y_pred[indices, self._y_true._data] - out *= grad._data - case _: - raise NotImplementedError( - "BinaryCrossEntropy with predictions with ndim: " - f"{self._y_pred.ndim} are not yet supported" - ) - self._grad = out - - return to_tensor(self._grad, is_vector=grad.is_vector) - - def forward(self, y_true: Tensor, y_pred: Tensor) -> Tensor: - # sourcery skip: assign-if-exp, reintroduce-else - """ - Calculate the cross entropy loss - """ - # normalise - # TODO: fuse normalising and calculation together - y_pred = Softmax()(y_pred) - - xp = y_pred.xp - - # 
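
The `Embedding` changes above keep the same structure: the forward pass is a row lookup into the weight table, so the backward pass must scatter-add the incoming gradient into the rows that were used. `add.at` (an unbuffered, in-place indexed add) handles repeated token ids correctly where plain fancy-indexed assignment would not. A plain-numpy sketch of both directions:

```python
import numpy as np

rng = np.random.default_rng(0)
weights = rng.normal(size=(10, 4))   # (vocab_size, embedding_dim)
tokens = np.array([1, 3, 1])         # token 1 appears twice

# Forward: one embedding row per token id.
out = weights[tokens]                # shape (3, 4)

# Backward: accumulate the output gradient into the lookup rows.
grad_out = np.ones_like(out)
grad_weights = np.zeros_like(weights)
np.add.at(grad_weights, tokens, grad_out)

assert np.allclose(grad_weights[1], 2.0)  # two lookups, two gradients
assert np.allclose(grad_weights[3], 1.0)
```
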
clip for numeric stability - y_pred._data = y_pred._data.clip( - min=self.REALLY_SMALL_NUMBER, max=self.REALLY_BIG_NUMBER - ) - - # cache inputs for calculating the backwards operations later - self._y_true = y_true._data - self._y_pred = y_pred._data - - match self._y_pred.ndim: - case 3: - out = xp.zeros_like(y_true._data) - batch_indices = xp.arange(y_true.shape[0]) - token_indices = xp.arange(y_true.shape[1]) - for b in batch_indices: - out[b] = -xp.log( - y_pred._data[b, token_indices, y_true._data[b]] - ) - case 2: - indices = xp.arange(y_true.shape[0]) - out = -xp.log(y_pred[indices, y_true._data]) - case _: - raise NotImplementedError( - "BinaryCrossEntropy with predictions with ndim: " - f"{self._y_pred.ndim} are not yet supported" - ) - - self._out = out - - result = to_tensor(self._out, is_vector=y_pred.is_vector) - result.back_fns = (self.backward,) - - # y_true never requires grad so we dont calculate gradients for it - result.args = (y_pred,) - result.name = "cross_entropy" - - return result - - class CrossEntropy(Op): """ Calculate cross entropy loss, given logits and target indices (as opposed @@ -156,11 +27,11 @@ class CrossEntropy(Op): def log_softmax(self, tensor: Tensor): xp = tensor.xp - x_max = xp.max(tensor._data, axis=-1, keepdims=True) + x_max = xp.max(tensor.array, axis=-1, keepdims=True) log_sum_exp = x_max + xp.log( - xp.sum(xp.exp(tensor._data - x_max), axis=-1, keepdims=True) + xp.sum(xp.exp(tensor.array - x_max), axis=-1, keepdims=True) ) - return tensor._data - log_sum_exp + return tensor.array - log_sum_exp def forward(self, y_true: Tensor, y_pred: Tensor) -> Tensor: """ @@ -172,7 +43,7 @@ def forward(self, y_true: Tensor, y_pred: Tensor) -> Tensor: log_softmax_pred = self.log_softmax(y_pred) # Cache for backward pass - self._y_true = y_true._data + self._y_true = y_true.array self._log_softmax_pred = log_softmax_pred ndim = log_softmax_pred.ndim @@ -181,13 +52,13 @@ def forward(self, y_true: Tensor, y_pred: Tensor) -> Tensor: batch_indices = xp.arange(y_true.shape[0], dtype=int) token_indices = xp.arange(y_true.shape[1], dtype=int) loss = -log_softmax_pred[ - batch_indices[:, None], token_indices, y_true._data + batch_indices[:, None], token_indices, y_true.array ] elif ndim == 2: indices = xp.arange(y_true.shape[0], dtype=int) - loss = -log_softmax_pred[indices, y_true._data] + loss = -log_softmax_pred[indices, y_true.array] elif ndim == 1: - loss = -log_softmax_pred[y_true._data] + loss = -log_softmax_pred[y_true.array] else: raise NotImplementedError( f"BinaryCrossEntropy with predictions with ndim: {ndim} are not yet supported" @@ -197,7 +68,7 @@ def forward(self, y_true: Tensor, y_pred: Tensor) -> Tensor: loss = loss.mean() self._out = loss - result = to_tensor(self._out, is_vector=False) + result = to_tensor(self._out, is_batched=False) result.back_fns = (self.backward,) result.args = (y_pred,) @@ -216,7 +87,7 @@ def backward(self, grad: Tensor) -> Tensor: grad_output[ batch_indices[:, None], token_indices, self._y_true ] -= 1 - grad_output *= grad._data / ( + grad_output *= grad.array / ( self._y_true.shape[0] * self._y_true.shape[1] ) @@ -224,15 +95,15 @@ def backward(self, grad: Tensor) -> Tensor: indices = xp.arange(self._y_true.shape[0], dtype=int) grad_output = xp.exp(self._log_softmax_pred) grad_output[indices, self._y_true] -= 1 - grad_output *= grad._data / self._y_true.shape[0] + grad_output *= grad.array / self._y_true.shape[0] elif ndim == 1: grad_output = xp.exp(self._log_softmax_pred) grad_output[self._y_true] -= 1 - grad_output *= 
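
The surviving `CrossEntropy` op above leans on two standard identities: log-softmax computed with the max-subtraction trick, and the gradient of the loss with respect to the logits being `softmax(logits) - one_hot(target)`. A self-contained numpy sketch of the unbatched case, with the analytic gradient checked against a finite difference:

```python
import numpy as np

def log_softmax(x):
    x_max = x.max(axis=-1, keepdims=True)
    return x - (x_max + np.log(np.exp(x - x_max).sum(axis=-1, keepdims=True)))

def cross_entropy(logits, target):
    return -log_softmax(logits)[target]

logits = np.array([2.0, -1.0, 0.5])
target = 1

# Analytic gradient: softmax with 1 subtracted at the target index.
grad = np.exp(log_softmax(logits))
grad[target] -= 1

# Finite-difference check on the first logit.
eps = 1e-5
bumped = logits.copy()
bumped[0] += eps
numeric = (cross_entropy(bumped, target) - cross_entropy(logits, target)) / eps
assert abs(numeric - grad[0]) < 1e-4
```
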
grad._data + grad_output *= grad.array else: raise NotImplementedError( f"BinaryCrossEntropy with predictions with ndim: {ndim} are not yet supported" ) self._grad = grad_output - return to_tensor(self._grad, is_vector=grad.is_vector) + return to_tensor(self._grad, is_batched=grad.is_batched) diff --git a/src/tricycle/models.py b/src/tricycle/models.py index d21ce58..5b2cead 100644 --- a/src/tricycle/models.py +++ b/src/tricycle/models.py @@ -59,7 +59,7 @@ def forward(self, tensor: Tensor): if tensor.ndim == 1: n_tokens = 1 context_window = tensor.shape[-1] - tensor._data = xp.expand_dims(tensor._data, 0) + tensor.array = xp.expand_dims(tensor.array, 0) else: n_tokens, context_window = tensor.shape assert n_tokens <= self.context_window, ( diff --git a/src/tricycle/ops.py b/src/tricycle/ops.py index dc947f4..88d19fb 100644 --- a/src/tricycle/ops.py +++ b/src/tricycle/ops.py @@ -35,7 +35,9 @@ def forward(self, tensor: Tensor, repeats: int): subscript = Subscript("...,...a->...a") new_shape = tensor.shape + (repeats,) ones = to_tensor( - xp.ones(new_shape), is_vector=tensor.is_vector, requires_grad=False + xp.ones(new_shape), + is_batched=tensor.is_batched, + requires_grad=False, ) return Einsum(subscript)(tensor, ones) @@ -74,10 +76,10 @@ def back_fn(self, grad: Tensor, idx: int) -> Tensor: indices.append(slice(start, end)) else: indices.append(slice(None)) - self._grad[idx][tuple(indices)] = grad._data + self._grad[idx][tuple(indices)] = grad.array result = to_tensor(self._grad[idx]) - result.is_vector = grad.is_vector + result.is_batched = grad.is_batched return result def forward( @@ -90,7 +92,7 @@ def forward( assert isinstance(n_splits, int) - self._out = xp.split(tensor._data, n_splits, axis=axis) + self._out = xp.split(tensor.array, n_splits, axis=axis) self._in_shape = tensor.shape self._axis = axis self._n_splits = n_splits @@ -106,7 +108,7 @@ def back_fn(grad, idx=idx): result = to_tensor(result) result.back_fns = (back_fn,) result.args = (tensor,) - result.is_vector = tensor.is_vector + result.is_batched = tensor.is_batched results.append(result) return results @@ -117,23 +119,23 @@ class Reshape(Op): def back_fn(self, grad: Tensor) -> Tensor: # sourcery skip: assign-if-exp xp = grad.xp - self._grad = xp.reshape(grad._data, self._original_shape) + self._grad = xp.reshape(grad.array, self._original_shape) result = to_tensor(self._grad) - result.is_vector = grad.is_vector + result.is_batched = grad.is_batched return result def forward(self, tensor: Tensor, shape: Sequence[int]) -> Tensor: xp = tensor.xp - if tensor.is_vector: + if tensor.is_batched: shape = [tensor.shape[0]] + list(shape) - self._out = xp.reshape(tensor._data, shape) + self._out = xp.reshape(tensor.array, shape) self._original_shape = tensor.shape result = to_tensor(self._out) result.args = (tensor,) result.back_fns = (self.back_fn,) - result.is_vector = tensor.is_vector + result.is_batched = tensor.is_batched return result diff --git a/src/tricycle/optimisers.py b/src/tricycle/optimisers.py index b003e2b..3919900 100644 --- a/src/tricycle/optimisers.py +++ b/src/tricycle/optimisers.py @@ -32,8 +32,8 @@ def update_weight(self, tensor: Tensor): """ assert tensor.grad is not None - if tensor.grad.is_vector: - tensor.grad = tensor.grad.from_vector().e("z...->...") + if tensor.grad.is_batched: + tensor.grad = tensor.grad.from_batched().e("z...->...") grad = self.learning_rate * tensor.grad @@ -48,14 +48,14 @@ def update_weight(self, tensor: Tensor): last_momentum = self.momentum_store[tensor._id] grad += self.momentum 
* last_momentum - self.momentum_store[tensor._id] = to_tensor(grad._data) + self.momentum_store[tensor._id] = to_tensor(grad.array) # update the value only, leave everything else result = to_tensor( tensor - grad, requires_grad=tensor.requires_grad, name=tensor.name, - is_vector=tensor.is_vector, + is_batched=tensor.is_batched, _id=tensor._id, ) @@ -100,28 +100,28 @@ def update_weight(self, tensor: Tensor) -> Tensor: # initialise stores if key not in self.m: - self.m[key] = xp.zeros_like(tensor._data) + self.m[key] = xp.zeros_like(tensor.array) if key not in self.v: - self.v[key] = tensor.xp.zeros_like(tensor._data) + self.v[key] = tensor.xp.zeros_like(tensor.array) self.m[key] = ( self.betas[0] * self.m[key] - + (1 - self.betas[0]) * tensor.grad._data + + (1 - self.betas[0]) * tensor.grad.array ) self.v[key] = self.betas[1] * self.v[key] + (1 - self.betas[1]) * ( - tensor.grad._data * tensor.grad._data + tensor.grad.array * tensor.grad.array ) m_hat = self.m[key] / (1 - self.betas[0] ** self.t) v_hat = self.v[key] / (1 - self.betas[1] ** self.t) - tensor._data -= self.learning_rate * ( + tensor.array -= self.learning_rate * ( m_hat / (xp.sqrt(v_hat) + self.eps) - + self.weight_decay * tensor._data + + self.weight_decay * tensor.array ) - tensor.grad._data.fill(0) + tensor.grad.array.fill(0) return tensor diff --git a/src/tricycle/reduce.py b/src/tricycle/reduce.py index 2fe641f..c4748e8 100644 --- a/src/tricycle/reduce.py +++ b/src/tricycle/reduce.py @@ -26,13 +26,13 @@ def __call__(self, tensor: Tensor, subscript: Subscript | str): if not reduce_along_axes: return tensor - indicator = tensor._data == tensor.xp.max( - tensor._data, axis=tuple(reduce_along_axes), keepdims=True + indicator = tensor.array == tensor.xp.max( + tensor.array, axis=tuple(reduce_along_axes), keepdims=True ) indicator = to_tensor( - indicator, requires_grad=False, is_vector=tensor.is_vector + indicator, requires_grad=False, is_batched=tensor.is_batched ) - indicator._data = indicator._data.astype(tensor.xp.int8) + indicator.array = indicator.array.astype(tensor.xp.int8) new_subscript = Subscript.from_split([idx, idx], subscript.output) @@ -65,13 +65,13 @@ def __call__(self, tensor: Tensor, subscript: Subscript | str): if not reduce_along_axes: return tensor - indicator = tensor._data == tensor.xp.min( - tensor._data, axis=tuple(reduce_along_axes), keepdims=True + indicator = tensor.array == tensor.xp.min( + tensor.array, axis=tuple(reduce_along_axes), keepdims=True ) indicator = to_tensor( - indicator, requires_grad=False, is_vector=tensor.is_vector + indicator, requires_grad=False, is_batched=tensor.is_batched ) - indicator._data = indicator._data.astype(tensor.xp.int8) + indicator.array = indicator.array.astype(tensor.xp.int8) new_subscript = Subscript.from_split([idx, idx], subscript.output) diff --git a/src/tricycle/tensor.py b/src/tricycle/tensor.py index fb139e0..68c09b0 100644 --- a/src/tricycle/tensor.py +++ b/src/tricycle/tensor.py @@ -22,20 +22,20 @@ class Tensor: """ _id: int - _data: np.ndarray | ArrayLike + array: np.ndarray | ArrayLike args: tuple["Tensor", ...] | None = None back_fns: tuple[Op, ...] 
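
The `AdamW` update in `optimisers.py` above follows the standard recipe: exponential moving averages of the gradient and the squared gradient, bias correction for the zero-initialised moments, and weight decay applied directly to the weights rather than folded into the gradient. A compact numpy sketch of a single step (the hyperparameter defaults here are assumptions, not necessarily the library's):

```python
import numpy as np

def adamw_step(param, grad, m, v, t,
               lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2):
    m = betas[0] * m + (1 - betas[0]) * grad          # first moment
    v = betas[1] * v + (1 - betas[1]) * grad * grad   # second moment
    m_hat = m / (1 - betas[0] ** t)                   # bias correction
    v_hat = v / (1 - betas[1] ** t)
    # Decoupled weight decay: added outside the adaptive term.
    param = param - lr * (m_hat / (np.sqrt(v_hat) + eps) + weight_decay * param)
    return param, m, v

param, m, v = np.ones(3), np.zeros(3), np.zeros(3)
param, m, v = adamw_step(param, np.array([0.1, -0.2, 0.3]), m, v, t=1)
```
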
| None = None parents: set["Tensor"] | None = None grad: Optional["Tensor"] = None name: Optional[str] = None requires_grad: bool = False - is_vector: bool = False + is_batched: bool = False def __init__( self, data: np.ndarray | ArrayLike, requires_grad: bool = False, - is_vector: bool = False, + is_batched: bool = False, name: str | None = None, _id: int | None = None, ): @@ -44,14 +44,14 @@ def __init__( import cupy if isinstance(data, (np.ndarray, cupy.ndarray)): - self._data = data + self.array = data else: - self._data = np.array(data) + self.array = np.array(data) else: - self._data = np.array(data) + self.array = np.array(data) self.requires_grad = requires_grad - self.is_vector = is_vector + self.is_batched = is_batched self.name = name def _attach_parents(self): @@ -89,9 +89,9 @@ def _calculate_gradients(self, clip: float | None = None): has been computed """ self.grad = to_tensor( - self.xp.ones(self._data.shape, dtype=self.dtype), + self.xp.ones(self.array.shape, dtype=self.dtype), requires_grad=False, - is_vector=self.is_vector, + is_batched=self.is_batched, ) stack: list["Tensor"] = [self] @@ -125,13 +125,13 @@ def _calculate_gradients(self, clip: float | None = None): # gradient clipping if clip is not None: - grad._data = grad.xp.clip(grad._data, -clip, clip) + grad.array = grad.xp.clip(grad.array, -clip, clip) # add gradient if arg.grad is None: arg.grad = grad else: - arg.grad._data += grad._data + arg.grad.array += grad.array except Exception as e: raise e @@ -142,27 +142,6 @@ def _calculate_gradients(self, clip: float | None = None): arg.parents = None stack.append(arg) - def cleanup(self): - """ - Traverse through the graph, deleting all non-parameter nodes in - the graph to avoid a memory leak - """ - stack: list["Tensor"] = [self] - while stack: - node = stack.pop() - - # add children to stack - if node.args: - stack.extend(iter(node.args)) - del node.args - else: - continue - - # delete node - if hasattr(node, "grad") and node.grad is not None: - del node.grad - del node - def backward(self, clip: float | None = None): """ Perform a backward pass through the graph, calculating the gradient @@ -309,47 +288,47 @@ def __pow__(self, other) -> "Tensor": def __lt__(self, other): if isinstance(other, Tensor): - return Tensor(self._data < other._data) - return Tensor(self._data < other) + return Tensor(self.array < other.array) + return Tensor(self.array < other) def __le__(self, other): if isinstance(other, Tensor): - return Tensor(self._data <= other._data) - return Tensor(self._data <= other) + return Tensor(self.array <= other.array) + return Tensor(self.array <= other) def __eq__(self, other): if isinstance(other, Tensor): - return Tensor(self._data == other._data) - return Tensor(self._data == other) + return Tensor(self.array == other.array) + return Tensor(self.array == other) def __ne__(self, other): if isinstance(other, Tensor): - return Tensor(self._data != other._data) - return Tensor(self._data != other) + return Tensor(self.array != other.array) + return Tensor(self.array != other) def __gt__(self, other): if isinstance(other, Tensor): - return Tensor(self._data > other._data) - return Tensor(self._data > other) + return Tensor(self.array > other.array) + return Tensor(self.array > other) def __ge__(self, other): if isinstance(other, Tensor): - return Tensor(self._data >= other._data) - return Tensor(self._data >= other) + return Tensor(self.array >= other.array) + return Tensor(self.array >= other) def __repr__(self): name = f", name={self.name}" if self.name 
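
`Tensor._calculate_gradients` above is a reverse-mode sweep: seed the output's gradient with ones, walk back through each tensor's `args` and `back_fns`, and accumulate with `+=` so a tensor used in several places collects every contribution. A deliberately stripped-down sketch of the same idea (a toy `Node` class, not Tricycle's `Tensor`; it also omits the `parents` bookkeeping Tricycle uses to wait for all contributions before propagating further):

```python
import numpy as np

class Node:
    def __init__(self, value):
        self.value = np.asarray(value, dtype=np.float64)
        self.grad = None
        self.args = ()      # nodes this node was computed from
        self.back_fns = ()  # one gradient function per argument

def mul(a, b):
    out = Node(a.value * b.value)
    out.args = (a, b)
    out.back_fns = (lambda g: g * b.value, lambda g: g * a.value)
    return out

def backward(output):
    output.grad = np.ones_like(output.value)
    stack = [output]
    while stack:
        node = stack.pop()
        for arg, back_fn in zip(node.args, node.back_fns):
            grad = back_fn(node.grad)
            # Accumulate: a node used twice sums both contributions.
            arg.grad = grad if arg.grad is None else arg.grad + grad
            stack.append(arg)

x = Node(3.0)
y = mul(x, x)   # y = x**2, so dy/dx = 2x = 6
backward(y)
assert np.allclose(x.grad, 6.0)
```
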
is not None else "" - return f"Tensor({self._data.__str__()}{name})" + return f"Tensor({self.array.__str__()}{name})" def __getitem__(self, idx): - return to_tensor(self._data[idx], requires_grad=self.requires_grad) + return to_tensor(self.array[idx], requires_grad=self.requires_grad) def __setitem__(self, idx, value): - self._data[idx] = value + self.array[idx] = value @property def xp(self): - return select_backend(self._data) + return select_backend(self.array) def e(self, subscript: str) -> "Tensor": """ @@ -366,15 +345,15 @@ def repeat(self, n_repeats: int) -> "Tensor": @property def shape(self) -> Sequence[int]: - return self._data.shape + return self.array.shape @property def ndim(self) -> int: - return self._data.ndim + return self.array.ndim @property def dtype(self) -> np.dtype: - return self._data.dtype + return self.array.dtype def reshape(self, shape: Sequence[int]) -> "Tensor": from tricycle.ops import Reshape @@ -393,11 +372,6 @@ def mean(self) -> "Tensor": def sum(self) -> "Tensor": from tricycle.unary import UnarySum - # if self.is_vector: - # indices = "abcdefghijklmnopqrstuvwxy"[: self.ndim - 1] - # else: - # indices = "abcdefghijklmnopqrstuvwxy"[: self.ndim] - # return self.e(f"{indices}->") return UnarySum()(self) def close_to( @@ -413,27 +387,27 @@ def close_to( """ if not isinstance(other, Tensor): return self.xp.allclose( - self._data, + self.array, self.xp.array(other), equal_nan=equal_nan, rtol=rtol, **kwargs, ) return self.xp.allclose( - self._data, other._data, equal_nan=equal_nan, rtol=rtol, **kwargs + self.array, other.array, equal_nan=equal_nan, rtol=rtol, **kwargs ) - def to_vector(self): + def to_batched(self): """ - Treat this tensor as a vector + Treat this tensor as a batch of tensors """ - return vectorise(self) + return batch(self) - def from_vector(self): + def from_batched(self): """ - Treat a vectorised tensor as a normal tensor + Treat a batched tensor as a normal tensor """ - return unvectorise(self) + return unbatch(self) @property def on_gpu(self): @@ -441,7 +415,7 @@ def on_gpu(self): return False import cupy - return isinstance(self._data, cupy.ndarray) + return isinstance(self.array, cupy.ndarray) def to_gpu(self, device: int = 0): """ @@ -454,7 +428,7 @@ def to_gpu(self, device: int = 0): import cupy cupy.cuda.Device(device).use() - self._data = cupy.asarray(self._data) + self.array = cupy.asarray(self.array) return self def from_gpu(self): @@ -467,7 +441,7 @@ def from_gpu(self): ) import cupy - self._data = cupy.asnumpy(self._data) + self.array = cupy.asnumpy(self.array) return self def zero_grad(self): @@ -479,18 +453,18 @@ def zero_grad(self): def numpy(self): if not CUPY_ENABLED: - return self._data + return self.array import cupy - return cupy.asnumpy(self._data) if self.on_gpu else self._data + return cupy.asnumpy(self.array) if self.on_gpu else self.array def to_tensor( tensor_like: ArrayLike, name: Optional[str] = None, requires_grad: bool = True, - is_vector: bool = False, + is_batched: bool = False, _id: int | None = None, dtype: np.dtype | None = np.float32, **kwargs, @@ -503,7 +477,7 @@ def to_tensor( import cupy if isinstance(tensor_like, Tensor): - array = tensor_like._data + array = tensor_like.array elif isinstance(tensor_like, (np.ndarray, cupy.ndarray)): array = tensor_like if dtype is not None: @@ -514,7 +488,7 @@ def to_tensor( array = np.asarray(tensor_like, dtype=dtype, **kwargs) elif isinstance(tensor_like, Tensor): - array = tensor_like._data + array = tensor_like.array else: array = np.asarray(tensor_like, 
dtype=dtype, **kwargs) @@ -522,45 +496,45 @@ def to_tensor( array, name=name, requires_grad=requires_grad, - is_vector=is_vector, + is_batched=is_batched, _id=_id, ) -def vectorise(tensor: Tensor) -> Tensor: +def batch(tensor: Tensor) -> Tensor: """ - Tell Tricycle to treat this tensor as a group of vectors + Tell Tricycle to treat this tensor as a batch of tensors """ - if tensor.is_vector: + if tensor.is_batched: return tensor result = to_tensor( - tensor._data, - is_vector=True, + tensor.array, + is_batched=True, requires_grad=tensor.requires_grad, - dtype=tensor._data.dtype, + dtype=tensor.array.dtype, ) result.args = (tensor,) - result.back_fns = (unvectorise,) + result.back_fns = (unbatch,) return result -def unvectorise(tensor: Tensor) -> Tensor: +def unbatch(tensor: Tensor) -> Tensor: """ Tell Tricycle to treat this tensor as a single tensor - (not a group of vectors) + (not a batch of tensors) """ - if not tensor.is_vector: + if not tensor.is_batched: return tensor result = to_tensor( - tensor._data, - is_vector=False, + tensor.array, + is_batched=False, requires_grad=tensor.requires_grad, - dtype=tensor._data.dtype, + dtype=tensor.array.dtype, ) result.args = (tensor,) - result.back_fns = (vectorise,) + result.back_fns = (batch,) return result diff --git a/src/tricycle/unary.py b/src/tricycle/unary.py index b9572c7..4afecdf 100644 --- a/src/tricycle/unary.py +++ b/src/tricycle/unary.py @@ -1,3 +1,10 @@ +""" +When doing tensor calculus, some Operations have a single input and output. +I'm calling these `unary` operations. + +This file contains all of the unary operations in Tricycle +""" + import numbers from numpy.typing import ArrayLike @@ -5,8 +12,6 @@ from tricycle.ops import Op from tricycle.tensor import Tensor, nothing, to_tensor -grad = False - class UnaryAdd(Op): def forward(self, tensor: Tensor, constant: float) -> Tensor: @@ -19,13 +24,13 @@ def forward(self, tensor: Tensor, constant: float) -> Tensor: assert isinstance(tensor, Tensor) assert isinstance(constant, numbers.Number) - self._out = xp.add(tensor._data, constant) + self._out = xp.add(tensor.array, constant) result = to_tensor(self._out) result.args = (tensor,) result.back_fns = (nothing,) result.name = f"+ {constant}" - result.is_vector = tensor.is_vector + result.is_batched = tensor.is_batched return result @@ -35,9 +40,9 @@ class UnaryMultiply(Op): def back_fn(self, grad: Tensor) -> Tensor: xp = grad.xp - self._grad = xp.multiply(grad._data, self._constant) + self._grad = xp.multiply(grad.array, self._constant) result = to_tensor(self._grad) - result.is_vector = grad.is_vector + result.is_batched = grad.is_batched return result def forward(self, tensor: Tensor, constant: float) -> Tensor: @@ -50,14 +55,14 @@ def forward(self, tensor: Tensor, constant: float) -> Tensor: assert isinstance(tensor, Tensor) assert xp.isscalar(constant) - self._out = xp.multiply(tensor._data, constant) + self._out = xp.multiply(tensor.array, constant) self._constant = constant result = to_tensor(self._out) result.args = (tensor,) result.back_fns = (self.back_fn,) result.name = f"+ {constant}" - result.is_vector = tensor.is_vector + result.is_batched = tensor.is_batched return result @@ -78,12 +83,12 @@ def back_fn(self, grad: Tensor) -> Tensor: xp = grad.xp self._grad = xp.power( - self.input._data, self.constant - 1, dtype=self.input.dtype + self.input.array, self.constant - 1, dtype=self.input.dtype ) - self._grad *= self.constant * grad._data + self._grad *= self.constant * grad.array result = to_tensor(self._grad) - 
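
Note the symmetry in the renamed `batch`/`unbatch` helpers above: neither copies or modifies the underlying array; each just flips `is_batched` and registers the other function as its `back_fn`, so a gradient flowing backwards through `batch` is automatically un-batched (and vice versa). Assuming the API shown in this diff, usage looks like:

```python
from tricycle.tensor import to_tensor

x = to_tensor([[1.0, 2.0], [3.0, 4.0]])
assert not x.is_batched

# Treat the leading axis as a batch dimension...
xb = x.to_batched()
assert xb.is_batched

# ...and undo it again; only the flag changes, never the data.
assert not xb.from_batched().is_batched
```
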
result.is_vector = grad.is_vector + result.is_batched = grad.is_batched return result def forward(self, tensor: Tensor, constant: float) -> Tensor: @@ -96,7 +101,7 @@ def forward(self, tensor: Tensor, constant: float) -> Tensor: assert isinstance(tensor, Tensor) assert xp.isscalar(constant) - self._out = xp.power(tensor._data, constant) + self._out = xp.power(tensor.array, constant) self.input = tensor self.constant = constant @@ -104,7 +109,7 @@ def forward(self, tensor: Tensor, constant: float) -> Tensor: result.args = (tensor,) result.back_fns = (self.back_fn,) result.name = f"^ {constant}" - result.is_vector = tensor.is_vector + result.is_batched = tensor.is_batched return result @@ -125,10 +130,10 @@ class UnaryMax(Op): is_bigger: Tensor def back_fn(self, grad: Tensor) -> Tensor: - self._grad = grad._data * self.is_bigger._data + self._grad = grad.array * self.is_bigger.array result = to_tensor(self._grad) - result.is_vector = grad.is_vector + result.is_batched = grad.is_batched return result def forward(self, tensor: Tensor, constant: float) -> Tensor: @@ -141,16 +146,16 @@ def forward(self, tensor: Tensor, constant: float) -> Tensor: assert isinstance(tensor, Tensor) assert xp.isscalar(constant) - self._out = xp.maximum(tensor._data, constant, dtype=tensor.dtype) + self._out = xp.maximum(tensor.array, constant, dtype=tensor.dtype) self.is_bigger = tensor > constant - self.is_bigger.is_vector = tensor.is_vector + self.is_bigger.is_batched = tensor.is_batched result = to_tensor(self._out) result.args = (tensor,) result.back_fns = (self.back_fn,) result.name = f"> {constant}" - result.is_vector = tensor.is_vector + result.is_batched = tensor.is_batched return result @@ -159,10 +164,10 @@ class UnaryMin(Op): is_smaller: Tensor def back_fn(self, grad: Tensor) -> Tensor: - self._grad = grad._data * self.is_smaller._data + self._grad = grad.array * self.is_smaller.array result = to_tensor(self._grad) - result.is_vector = grad.is_vector + result.is_batched = grad.is_batched return result def forward(self, tensor: Tensor, constant: float) -> Tensor: @@ -175,26 +180,26 @@ def forward(self, tensor: Tensor, constant: float) -> Tensor: assert isinstance(tensor, Tensor) assert xp.isscalar(constant) - self._out = xp.minimum(tensor._data, constant, dtype=tensor.dtype) + self._out = xp.minimum(tensor.array, constant, dtype=tensor.dtype) self.is_smaller = tensor < constant - self.is_smaller.is_vector = tensor.is_vector + self.is_smaller.is_batched = tensor.is_batched result = to_tensor(self._out) result.args = (tensor,) result.back_fns = (self.back_fn,) result.name = f"> {constant}" - result.is_vector = tensor.is_vector + result.is_batched = tensor.is_batched return result class UnaryExp(Op): def back_fn(self, grad: Tensor) -> Tensor: - self._grad = grad._data * self._out + self._grad = grad.array * self._out result = to_tensor(self._grad) - result.is_vector = grad.is_vector + result.is_batched = grad.is_batched return result def forward(self, tensor: Tensor) -> Tensor: @@ -203,13 +208,13 @@ def forward(self, tensor: Tensor) -> Tensor: """ xp = tensor.xp - self._out = xp.exp(tensor._data) + self._out = xp.exp(tensor.array) result = to_tensor(self._out) result.args = (tensor,) result.back_fns = (self.back_fn,) result.name = "exp" - result.is_vector = tensor.is_vector + result.is_batched = tensor.is_batched return result @@ -221,10 +226,10 @@ class UnaryLog(Op): def back_fn(self, grad: Tensor) -> Tensor: xp = grad.xp denominator = self._input + self.REALLY_SMALL_NUMBER - self._grad = grad._data * 
xp.divide(1, denominator)
+        self._grad = grad.array * xp.divide(1, denominator)

         result = to_tensor(self._grad)
-        result.is_vector = grad.is_vector
+        result.is_batched = grad.is_batched
         return result

     def forward(self, tensor: Tensor) -> Tensor:
@@ -233,15 +238,15 @@ def forward(self, tensor: Tensor) -> Tensor:
         """
         xp = tensor.xp

-        self._out = xp.log(tensor._data)
-        self._input = tensor._data
+        self._out = xp.log(tensor.array)
+        self._input = tensor.array

         result = to_tensor(self._out)
         result.args = (tensor,)
         result.back_fns = (self.back_fn,)
         result.name = "log"
-        result.is_vector = tensor.is_vector
+        result.is_batched = tensor.is_batched

         return result
@@ -251,10 +256,10 @@ class UnarySin(Op):
     def back_fn(self, grad: Tensor) -> Tensor:
         xp = grad.xp

-        self._grad = grad._data * xp.cos(self._input)
+        self._grad = grad.array * xp.cos(self._input)

         result = to_tensor(self._grad)
-        result.is_vector = grad.is_vector
+        result.is_batched = grad.is_batched
         return result

     def forward(self, tensor: Tensor) -> Tensor:
@@ -263,14 +268,14 @@ def forward(self, tensor: Tensor) -> Tensor:
         """
         xp = tensor.xp

-        self._out = xp.sin(tensor._data)
-        self._input = tensor._data
+        self._out = xp.sin(tensor.array)
+        self._input = tensor.array

         result = to_tensor(self._out)
         result.args = (tensor,)
         result.back_fns = (self.back_fn,)
         result.name = "sin"
-        result.is_vector = tensor.is_vector
+        result.is_batched = tensor.is_batched

         return result
@@ -280,10 +285,10 @@ class UnaryCos(Op):
     def back_fn(self, grad: Tensor) -> Tensor:
         xp = grad.xp

-        self._grad = grad._data * -xp.sin(self._input)
+        self._grad = grad.array * -xp.sin(self._input)

         result = to_tensor(self._grad)
-        result.is_vector = grad.is_vector
+        result.is_batched = grad.is_batched
         return result

     def forward(self, tensor: Tensor) -> Tensor:
@@ -292,14 +297,14 @@ def forward(self, tensor: Tensor) -> Tensor:
         """
         xp = tensor.xp

-        self._out = xp.cos(tensor._data)
-        self._input = tensor._data
+        self._out = xp.cos(tensor.array)
+        self._input = tensor.array

         result = to_tensor(self._out)
         result.args = (tensor,)
         result.back_fns = (self.back_fn,)
         result.name = "cos"
-        result.is_vector = tensor.is_vector
+        result.is_batched = tensor.is_batched

         return result
@@ -314,15 +319,15 @@ def forward(self, tensor: Tensor):
 class UnarySum(Op):
     _in_shape: tuple[int]
-    _in_is_vector: bool
+    _in_is_batched: bool

     def back_fn(self, grad: Tensor) -> Tensor:
         xp = grad.xp

-        self._grad = xp.full(self._in_shape, grad._data)
+        self._grad = xp.full(self._in_shape, grad.array)

         result = to_tensor(self._grad)
-        result.is_vector = self._in_is_vector
+        result.is_batched = self._in_is_batched
         return result

     def forward(self, tensor: Tensor) -> Tensor:
@@ -332,13 +337,13 @@ def forward(self, tensor: Tensor) -> Tensor:
         xp = tensor.xp

         # Sum all the values in the tensor
-        self._out = xp.sum(tensor._data)
+        self._out = xp.sum(tensor.array)
         self._in_shape = tensor.shape
-        self._in_is_vector = tensor.is_vector
+        self._in_is_batched = tensor.is_batched

         result = to_tensor(self._out)
         result.args = (tensor,)
         result.back_fns = (self.back_fn,)
         result.name = "sum"
-        result.is_vector = False  # The result of the sum is a scalar
+        result.is_batched = False  # The result of the sum is a scalar
         return result
diff --git a/src/tricycle_datasets/shakespeare.py b/src/tricycle_datasets/shakespeare.py
index 7117aad..b3d43df 100644
--- a/src/tricycle_datasets/shakespeare.py
+++ b/src/tricycle_datasets/shakespeare.py
@@ -5,7 +5,7 @@
 import numpy as np
 import requests

-from tricycle.tokeniser import BPETokeniserNumba
+from 
tricycle.tokeniser import BPETokeniser class Shakespeare(abc.Sequence): @@ -61,7 +61,7 @@ def download(self): with open(self.raw_data_path, "w") as f: f.write(raw_data) - def generate(self) -> BPETokeniserNumba: + def generate(self) -> BPETokeniser: """ Download and tokenise the shakespeare dataset """ @@ -70,7 +70,7 @@ def generate(self) -> BPETokeniserNumba: list(self.raw_data_path.read_bytes()), dtype=np.int32 ) if self.tokeniser is None: - self.tokeniser = BPETokeniserNumba(self.vocab_size) + self.tokeniser = BPETokeniser(self.vocab_size) return self.tokeniser.train_ints(raw_data, loading_bar=True) def __getitem__(self, idx: int) -> int | list[int]: diff --git a/temp.py b/temp.py deleted file mode 100644 index 17520d8..0000000 --- a/temp.py +++ /dev/null @@ -1,164 +0,0 @@ -from pathlib import Path - -import numpy as np -import regex as re -from numba import njit, prange -from numba.typed import List -from tqdm.auto import tqdm - -GPT4_SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+""" -DATA_PATH = Path("/Users/bco60/Documents/Tricycle/datasets/bee_movie.txt") - - -@njit() -def count_pairs(data, token_id, counts): - for i in range(len(data) - 1): - left, right = data[i], data[i + 1] - counts[left * (token_id + 1) + right] += 1 - return counts - - -@njit -def replace_pair( - data: np.ndarray, pair: tuple[int, int], token_id: int -) -> np.ndarray: - """ - Replace every occurrence of `pair` with `token_id` for a single array - """ - if len(data) == 1: - return data - new = 0 - old = 0 - - while old < len(data) - 1: - left = data[old] - right = data[old + 1] - - if (left, right) == pair: - data[new] = token_id - old += 1 - else: - data[new] = left - new += 1 - old += 1 - - # handle final id not being a match - if old == len(data) - 1 and old != 0: - data[new] = data[old] - new += 1 - - return data[:new] - - -@njit(parallel=True) -def replace_all_pairs(data: np.ndarray, pair: tuple[int, int], token_id: int): - for i in prange(len(data)): - data[i] = replace_pair(data[i], pair, token_id) - return data - - -@njit(parallel=True) -def count_all_pairs(data, token_id): - counts = np.zeros((token_id + 1) ** 2, dtype=np.int32) - for i in range(len(data)): - chunk = data[i] - counts = count_pairs(chunk, token_id, counts) - return counts - - -@njit -def flatten(int_array, size=1024, scaling_factor=2): - flat = np.zeros(size, dtype=np.int32) - idx = 0 - - for chunk in int_array: - for token in chunk: - flat[idx] = token - idx += 1 - if idx == size: - size = int(size * scaling_factor) - empty = np.zeros(size, dtype=np.int32) - empty[: size // 2] = flat - flat = empty - return flat[:idx] - - -class Tokeniser: - vocab_size: int - merges: dict[tuple[int, int | None], int] - pairs: list[tuple[int, int | None]] - - # we cant have less than the number of possible single bytes - MIN_TOKENS = 256 - - def __init__(self, vocab_size: int): - assert ( - vocab_size >= self.MIN_TOKENS - ), f"vocab_size must be >= {self.MIN_TOKENS}" - self.vocab_size = vocab_size - - # initialise our pairs and merges with single byte tokens - self.pairs = [(idx, None) for idx in range(self.MIN_TOKENS)] - self.merges = {(idx, None): idx for idx in range(self.MIN_TOKENS)} - self.vocab = [idx.to_bytes(1, "big") for idx in range(self.MIN_TOKENS)] - - def encode(self, text: str, loading_bar: bool = False): - pattern = re.compile(GPT4_SPLIT_PATTERN) - chunks = pattern.findall(text) - int_array = List( - [ - np.array(list(chunk.encode("utf-8")), dtype=np.int32) 
- for chunk in chunks - ] - ) - token_ids = range(self.MIN_TOKENS, self.vocab_size) - if loading_bar: - token_ids = tqdm(token_ids, desc="Training tokeniser") - - for token_id in token_ids: - # find the most common pair of tokens - most_common_pair = self.most_common_pair( - count_all_pairs(int_array, token_id), token_id - ) - if most_common_pair is None: - break - - # replace every occurrence of the pair with the new token - int_array = replace_all_pairs( - int_array, most_common_pair, token_id - ) - - # store the new pair and token - self.merges[most_common_pair] = token_id - self.pairs.append(most_common_pair) - left, right = most_common_pair - self.vocab.append(self.vocab[left] + self.vocab[right]) - - if len(self.pairs) != self.vocab_size: - warn(f"Expected {self.vocab_size} pairs, got {len(self.pairs)}") - return flatten(int_array) - - def most_common_pair( - self, counts: np.ndarray, token_id: int - ) -> tuple[int, int] | None: - """ - Return the most common pair - """ - most_common_idx = np.argmax(counts) - - # check if there are no more repeated pairs - if counts[most_common_idx] in {0, 1}: - return None - - left = most_common_idx // (token_id + 1) - right = most_common_idx % (token_id + 1) - - return left, right - - -if __name__ == "__main__": - sample_text = DATA_PATH.read_text() - tokeniser = Tokeniser(1000) - tokens = tokeniser.encode(sample_text) - breakpoint() - print(tokens) diff --git a/tests/test_activations.py b/tests/test_activations.py index 5af1eca..5d15570 100644 --- a/tests/test_activations.py +++ b/tests/test_activations.py @@ -25,9 +25,9 @@ def test_gelu_full(): assert y.close_to([-0.158808, 0.0, 0.841192]) -def test_gelu_vectorised(): +def test_gelu_batched(): x = to_tensor([[-1, 0, 1], [-1, 0, 1], [-1, 0, 1]]) - x = x.to_vector() + x = x.to_batched() gelu = GeLU(approximate=False) y = gelu(x) assert y.close_to( diff --git a/tests/test_attention.py b/tests/test_attention.py index 6b7f75d..cf0fdbe 100644 --- a/tests/test_attention.py +++ b/tests/test_attention.py @@ -64,9 +64,9 @@ def test_attention_combined(): in_tensor = np.random.uniform( -5, 5, (batch_size, n_tokens, projected_size) ) - in_tensor = to_tensor(in_tensor).to_vector() + in_tensor = to_tensor(in_tensor).to_batched() - x = torch.from_numpy(in_tensor._data) + x = torch.from_numpy(in_tensor.array) x.requires_grad = True qu, k, v = x.split(embedding_dim, dim=-1) # pytorch @@ -86,13 +86,13 @@ def test_attention_combined(): n_heads=n_heads, context_window=context_window, ) - tricycle_result = tricycle_attention(in_tensor).from_vector() + tricycle_result = tricycle_attention(in_tensor).from_batched() assert tricycle_result.close_to( pytorch_result.detach().numpy(), equal_nan=True, rtol=1e-3, atol=1e-5 ) - tricycle_result.from_vector().sum().backward() + tricycle_result.from_batched().sum().backward() pytorch_result.sum().backward() assert in_tensor.grad.close_to( diff --git a/tests/test_binary.py b/tests/test_binary.py index 2298025..8416acb 100644 --- a/tests/test_binary.py +++ b/tests/test_binary.py @@ -39,8 +39,8 @@ def test_can_badd(): # sourcery skip: extract-duplicate-method def test_can_bsub(): # sourcery skip: extract-duplicate-method - in_tensor_1 = to_tensor(np.arange(12).reshape(3, 4), is_vector=True) - in_tensor_2 = to_tensor(np.arange(1, 13).reshape(3, 4), is_vector=True) + in_tensor_1 = to_tensor(np.arange(12).reshape(3, 4), is_batched=True) + in_tensor_2 = to_tensor(np.arange(1, 13).reshape(3, 4), is_batched=True) out_tensor = BinarySubtract()(in_tensor_1, in_tensor_2) @@ -68,8 +68,8 @@ def 
test_can_bsub(): # sourcery skip: extract-duplicate-method def test_can_bmul(): - in_tensor_1 = to_tensor(np.arange(12).reshape(3, 4), is_vector=True) - in_tensor_2 = to_tensor(np.arange(1, 13).reshape(3, 4), is_vector=True) + in_tensor_1 = to_tensor(np.arange(12).reshape(3, 4), is_batched=True) + in_tensor_2 = to_tensor(np.arange(1, 13).reshape(3, 4), is_batched=True) out_tensor = BinaryMultiply()(in_tensor_1, in_tensor_2) @@ -87,10 +87,10 @@ def test_can_bmul(): def test_can_bdiv(): in_tensor_1 = to_tensor( - np.arange(12).reshape(3, 4), is_vector=True, dtype=float + np.arange(12).reshape(3, 4), is_batched=True, dtype=float ) in_tensor_2 = to_tensor( - np.arange(1, 13).reshape(3, 4), is_vector=True, dtype=float + np.arange(1, 13).reshape(3, 4), is_batched=True, dtype=float ) out_tensor = BinaryDivide()(in_tensor_1, in_tensor_2) @@ -110,17 +110,17 @@ def test_can_bdiv(): assert in_tensor_1.grad is not None assert in_tensor_2.grad is not None - assert in_tensor_1.grad.close_to(1 / in_tensor_2._data) + assert in_tensor_1.grad.close_to(1 / in_tensor_2.array) assert in_tensor_2.grad.close_to( - -in_tensor_1._data / (in_tensor_2._data**2) + -in_tensor_1.array / (in_tensor_2.array**2) ) def test_can_bmax(): - in_tensor_1 = to_tensor(np.arange(12).reshape(3, 4), is_vector=True) + in_tensor_1 = to_tensor(np.arange(12).reshape(3, 4), is_batched=True) in_tensor_2 = to_tensor( - [[0, 0, 0, 0], [100, 100, 100, 100], [8, 9, 10, 11]], is_vector=True + [[0, 0, 0, 0], [100, 100, 100, 100], [8, 9, 10, 11]], is_batched=True ) out_tensor = BinaryMax()(in_tensor_1, in_tensor_2) @@ -141,9 +141,9 @@ def test_can_bmax(): def test_can_bmin(): - in_tensor_1 = to_tensor(np.arange(12).reshape(3, 4), is_vector=True) + in_tensor_1 = to_tensor(np.arange(12).reshape(3, 4), is_batched=True) in_tensor_2 = to_tensor( - [[0, 0, 0, 0], [100, 100, 100, 100], [8, 9, 10, 11]], is_vector=True + [[0, 0, 0, 0], [100, 100, 100, 100], [8, 9, 10, 11]], is_batched=True ) out_tensor = BinaryMin()(in_tensor_1, in_tensor_2) @@ -164,10 +164,10 @@ def test_can_bmin(): def test_can_bmask(): - in_tensor = to_tensor(np.arange(12).reshape(3, 4), is_vector=True) + in_tensor = to_tensor(np.arange(12).reshape(3, 4), is_batched=True) mask = to_tensor( [[0, 0, 0, 0], [1, 0, 1, 0], [1, 1, 1, 1]], - is_vector=True, + is_batched=True, requires_grad=False, ) out_tensor = BinaryMask()(in_tensor, mask) diff --git a/tests/test_blocks.py b/tests/test_blocks.py index 86fcd0c..d36eea9 100644 --- a/tests/test_blocks.py +++ b/tests/test_blocks.py @@ -20,7 +20,7 @@ def test_attention_individually(): This operation is pretty complex so we'll perform each stage with pytorch and then compare the results. 
     Here, I'm comparing with Andrej Karpathy's implementation from NanoGPT
-    For this test, we're doing everything non-vectorised
+    For this test, we're doing everything non-batched
     """
     # setup
     embedding_dim = 15
@@ -38,7 +38,7 @@ def test_attention_individually():
     in_tensor = np.random.uniform(-5, 5, (n_tokens, projected_size))
     in_tensor = to_tensor(in_tensor)

-    x = torch.from_numpy(in_tensor._data)
+    x = torch.from_numpy(in_tensor.array)
     qu, k, v = x.split(embedding_dim, dim=-1)  # pytorch
     query, key, value = in_tensor.split(3, axis=-1)  # tricycle
@@ -173,9 +173,9 @@ def test_attention_combined():
     in_tensor = np.random.uniform(
         -5, 5, (batch_size, n_tokens, projected_size)
     )
-    in_tensor = to_tensor(in_tensor).to_vector()
+    in_tensor = to_tensor(in_tensor).to_batched()

-    x = torch.from_numpy(in_tensor._data)
+    x = torch.from_numpy(in_tensor.array)
     qu, k, v = x.split(embedding_dim, dim=-1)  # pytorch
     query, key, value = in_tensor.split(3, axis=-1)  # tricycle
@@ -212,7 +212,7 @@ def test_attention_combined():
         context_window=context_window,
         residual_dropout_prob=0,
     )
-    tricycle_result = tricycle_attention.attention(in_tensor).from_vector()
+    tricycle_result = tricycle_attention.attention(in_tensor).from_batched()

     assert np.allclose(andrej_result, pytorch_result, rtol=1e-3)
     assert tricycle_result.close_to(andrej_result)
@@ -264,7 +264,7 @@ def test_attention_block():
         out_projection_weights, name="out_proj"
     )

-    in_tensor = to_tensor(x, requires_grad=False).to_vector()
+    in_tensor = to_tensor(x, requires_grad=False).to_batched()
     tricycle_result = tricycle_attention(in_tensor)

     c_attn = torch.nn.Linear(embedding_dim, 3 * embedding_dim, bias=False)
@@ -288,7 +288,7 @@ def test_attention_block():
         andrej_result.detach().numpy(), rtol=1e-3, atol=1e-4
     )

-    tricycle_loss = tricycle_result.from_vector().e("abc->")
+    tricycle_loss = tricycle_result.from_batched().e("abc->")
     andrej_loss = andrej_result.sum()

     assert tricycle_loss.close_to(andrej_loss.detach().numpy())
@@ -296,7 +296,7 @@ def test_attention_block():
     tricycle_loss.backward()
     andrej_loss.backward()

-    assert not tricycle_attention.out_projection.weights.is_vector
+    assert not tricycle_attention.out_projection.weights.is_batched
     tricycle_out_weights = tricycle_attention.out_projection.weights.grad

     assert tricycle_out_weights.close_to(c_proj.weight.grad.T.numpy())
@@ -319,7 +319,7 @@ def test_MLPBlock():
     block.linear_1.weights = to_tensor(np.ones(block.linear_1.weights.shape))
     block.linear_2.weights = to_tensor(np.ones(block.linear_2.weights.shape))

-    out_tensor = block(in_tensor.to_vector())
+    out_tensor = block(in_tensor.to_batched())

     assert out_tensor.shape == (3, 4)
@@ -337,7 +337,7 @@ def test_MLPBlock():
     )
     correct_output = to_tensor(correct_output)

-    assert out_tensor.is_vector
+    assert out_tensor.is_batched
     assert out_tensor.close_to(correct_output)

     out_tensor.backward()
@@ -362,7 +362,8 @@ def test_GPT2TransformerBlock():
     embedding_dim = 7 * n_heads

     in_tensor = to_tensor(
-        np.random.random((batch_size, n_tokens, embedding_dim)), is_vector=True
+        np.random.random((batch_size, n_tokens, embedding_dim)),
+        is_batched=True,
     )
     block = GPT2TransformerBlock(
         embedding_dim=embedding_dim,
@@ -371,7 +372,7 @@
         context_window=32,
     )

-    out_tensor = block(in_tensor.to_vector())
+    out_tensor = block(in_tensor.to_batched())

     assert out_tensor.shape == (batch_size, n_tokens, embedding_dim)
diff --git a/tests/test_einsum.py b/tests/test_einsum.py
index 872186f..cb80b45 100644
--- a/tests/test_einsum.py
+++ b/tests/test_einsum.py
@@ -4,7 +4,7 @@ from tricycle.tensor import to_tensor


-def test_vector_reduce():
+def test_batched_reduce():
     x = to_tensor(np.arange(5))
     op = Einsum("a->")
     result = op(x)
@@ -38,7 +38,7 @@ def test_matrix_partial_reduce():
 def test_transpose():
     x = to_tensor(np.arange(20).reshape(4, 5))
     op = Einsum("ij->ji")
-    assert op(x).close_to(x._data.T)
+    assert op(x).close_to(x.array.T)

     op(x).backward()
     assert x.grad is not None
diff --git a/tests/test_layers.py b/tests/test_layers.py
index 3d5b54b..dc1395e 100644
--- a/tests/test_layers.py
+++ b/tests/test_layers.py
@@ -46,16 +46,16 @@ def test_dropout():  # sourcery skip: square-identity
     size = 100
     dropout_prob = 0.3

-    # non-vectorised
+    # non-batched
     in_tensor = to_tensor(
         np.random.normal(size=(size, size)), name="in_tensor"
     )
     dropout = Dropout(dropout_prob)
-    out_tensor = dropout(in_tensor.to_vector())
+    out_tensor = dropout(in_tensor.to_batched())

     assert out_tensor.shape == in_tensor.shape

-    zero_x_idx, zero_y_idx = np.where(out_tensor._data == 0)
+    zero_x_idx, zero_y_idx = np.where(out_tensor.array == 0)
     n_zeros = len(zero_x_idx)
     expected_n_zeros = int(size * size * dropout_prob)
@@ -77,13 +77,13 @@ def test_layer_norm():
     np.random.seed(0)
     in_tensor = to_tensor(np.random.normal(size=(100, 100)), name="in_tensor")
     layer_norm = LayerNorm(100)
-    out_tensor = layer_norm(in_tensor.to_vector())
+    out_tensor = layer_norm(in_tensor.to_batched())

     assert out_tensor.shape == in_tensor.shape
     out_tensor.backward()

     assert copy(out_tensor).mean().close_to(0, atol=1e-6)
-    assert np.allclose(np.std(out_tensor._data), [1] * 100, atol=1e-7)
+    assert np.allclose(np.std(out_tensor.array), [1] * 100, atol=1e-7)

     assert in_tensor.grad is not None
     assert in_tensor.grad.shape == in_tensor.shape
@@ -125,7 +125,7 @@ def test_embedding():
     )


-def test_embedding_vectorised():
+def test_embedding_batched():
     np.random.seed(0)
     vocab_size = 3
     out_shape = 5
@@ -133,7 +133,7 @@
         [[0, 1, 2, 0], [1, 2, 2, 1]],
         requires_grad=False,
         dtype=np.int8,
-    ).to_vector()
+    ).to_batched()
     embedding_layer = Embedding(from_size=vocab_size, to_size=out_shape)

     weights = np.indices((vocab_size * out_shape,)).reshape(
@@ -174,10 +174,10 @@ def test_rms_norm():
     np.random.seed(0)
     in_tensor = to_tensor(np.random.normal(size=(100, 100)), name="in_tensor")
     layer_norm = RMSNorm(100)
-    out_tensor = layer_norm(in_tensor.to_vector())
+    out_tensor = layer_norm(in_tensor.to_batched())

     assert out_tensor.shape == in_tensor.shape
-    assert np.allclose((out_tensor._data**2).mean(), 1)
+    assert np.allclose((out_tensor.array**2).mean(), 1)

     out_tensor.backward()
     assert in_tensor.grad is not None
diff --git a/tests/test_loss.py b/tests/test_loss.py
index 0dbacc6..0f4e193 100644
--- a/tests/test_loss.py
+++ b/tests/test_loss.py
@@ -35,7 +35,7 @@ def test_can_CrossEntropy():
     assert loss.close_to(1.0986122886681098)


-def test_CrossEntropy_vectorised():
+def test_CrossEntropy_batched():
     batch_size = 3
     n_tokens = 5
     vocab_size = 7
@@ -43,8 +43,8 @@
     y_true = np.random.randint(0, vocab_size, size=(batch_size, n_tokens))
     y_pred = np.random.random((batch_size, n_tokens, vocab_size))

-    y_true = to_tensor(y_true, dtype=int).to_vector()
-    y_pred = to_tensor(y_pred).to_vector()
+    y_true = to_tensor(y_true, dtype=int).to_batched()
+    y_pred = to_tensor(y_pred).to_batched()

     loss = CrossEntropy()(y_true, y_pred)
@@ -87,8 +87,8 @@ def test_single_lr_step_with_multiple_datapoints():
     slope = to_tensor([0.02])
     intercept = to_tensor([0.01])

-    x_input = to_tensor(x, requires_grad=False, name="x", is_vector=True)
-    y_input = to_tensor(y, requires_grad=False, name="y", is_vector=True)
+    x_input = to_tensor(x, requires_grad=False, name="x", is_batched=True)
+    y_input = to_tensor(y, requires_grad=False, name="y", is_batched=True)

     y_pred = x_input * slope + intercept
     loss = mean_square_error(y_input, y_pred)
@@ -233,16 +233,16 @@ def model(X, slope, intercept):
         return Einsum("i,ij->j")(X, slope) + intercept

     for idx in range(loops):
-        X = to_tensor(X_data).to_vector()
-        y = to_tensor(y_data).to_vector()
+        X = to_tensor(X_data).to_batched()
+        y = to_tensor(y_data).to_batched()

         # predict an output
         y_pred = model(X, slope, intercept)

         # calculate the loss
         loss = mean_square_error(y, y_pred)

-        # we need to unvectorise the loss before finding its average
-        loss = loss.from_vector().mean()
+        # we need to unbatch the loss before finding its average
+        loss = loss.from_batched().mean()
         losses[idx] = loss.numpy()

         loss.backward()
@@ -250,8 +250,8 @@
     assert slope.grad is not None
     assert intercept.grad is not None

-    slope.grad = slope.grad.from_vector().e("abc->bc")
-    intercept.grad = intercept.grad.from_vector().e("ab->b")
+    slope.grad = slope.grad.from_batched().e("abc->bc")
+    intercept.grad = intercept.grad.from_batched().e("ab->b")

     slope = (slope - slope.grad * learning_rate).zero_grad()
     intercept = (intercept - intercept.grad * learning_rate).zero_grad()
diff --git a/tests/test_model_matches_pytorch.py b/tests/test_model_matches_pytorch.py
index d62f5a6..d89438a 100644
--- a/tests/test_model_matches_pytorch.py
+++ b/tests/test_model_matches_pytorch.py
@@ -46,7 +46,7 @@ def integer(draw):

 @st.composite
 def tokens(draw):
     """
-    Tokens are a list of integers. They can be either 1d or 2d and vectorised
+    Tokens are a list of integers. They can be either 1d or 2d and batched
     """
     shape = draw(xp.array_shapes(min_dims=1, max_dims=2, max_side=64))
     tokens_ = draw(
@@ -57,7 +57,10 @@
         )
     )
     return to_tensor(
-        tokens_, is_vector=len(shape) == 2, dtype=np.int64, requires_grad=False
+        tokens_,
+        is_batched=len(shape) == 2,
+        dtype=np.int64,
+        requires_grad=False,
     )
@@ -80,20 +83,20 @@ def tensor(draw):
     """
     Generate a single, initial tensor (not as the result of an operation)
     For our model, we need the following tensors:
-    - 1d non-vector
-    - 2d non-vector
-    - 2d vector
-    - 3d vector
+    - 1d non-batch
+    - 2d non-batch
+    - 2d batch
+    - 3d batch
     """
     shape_ = draw(tensor_shape())
     data = draw(xp.arrays(dtype=np.float64, shape=shape_))
     match len(shape_):
         case 1:
-            is_vector = False
+            is_batched = False
         case 2:
-            is_vector = draw(st.booleans())
+            is_batched = draw(st.booleans())
         case 3:
-            is_vector = True
+            is_batched = True
     requires_grad = True
     if CUPY_ENABLED:
         on_gpu = draw(st.booleans())
@@ -103,7 +106,7 @@
     tensor = to_tensor(
         data,
-        is_vector=is_vector,
+        is_batched=is_batched,
         requires_grad=requires_grad,
     )
     if on_gpu:
@@ -114,7 +117,7 @@
 @st.composite
 def embedding_shape(draw):
     """
-    Embeddings are either 2d or 3d and vectorised
+    Embeddings are either 2d or 3d and batched
     """
     return draw(
         st.lists(
         )
     )


-def build_tensor(shape_, is_vector):
+def build_tensor(shape_, is_batched):
     """
     Generate a single, initial tensor (not as the result of an operation)
     For our model, we need the following tensors:
-    - 1d non-vector
-    - 2d non-vector
-    - 2d vector
-    - 3d vector
+    - 1d non-batch
+    - 2d non-batch
+    - 2d batch
+    - 3d batch
     """
     np.random.seed(0)
     data = np.random.random(shape_).astype(np.float32)
     match len(shape_):
         case 1:
-            is_vector = False
+            is_batched = False
         case 2:
-            is_vector = is_vector
+            is_batched = is_batched
         case 3:
-            is_vector = True
+            is_batched = True
     requires_grad = True

     return to_tensor(
         data,
-        is_vector=is_vector,
+        is_batched=is_batched,
         requires_grad=requires_grad,
     )
@@ -158,7 +161,7 @@ def small_tensor(draw):
     """
     shape = draw(st.integers(min_value=1, max_value=4))
     data = draw(xp.arrays(dtype=np.float64, shape=shape))
-    is_vector = len(shape) in {3, 4}
+    is_batched = len(shape) in {3, 4}
     requires_grad = draw(st.booleans())
     if CUPY_ENABLED:
         on_gpu = draw(st.booleans())
@@ -168,7 +171,7 @@
     tensor = to_tensor(
         data,
-        is_vector=is_vector,
+        is_batched=is_batched,
         requires_grad=requires_grad,
     )
     if on_gpu:
@@ -188,12 +191,12 @@ def tensor_pair_same_shape(draw):
     tensors = []
     for _ in range(2):
         data = draw(xp.arrays(dtype=np.float64, shape=shape))
-        is_vector = draw(st.booleans())
+        is_batched = draw(st.booleans())

         if draw(st.booleans()):
             data = data[1:]

-        tensor = to_tensor(data, is_vector=is_vector)
+        tensor = to_tensor(data, is_batched=is_batched)
         tensors.append(tensor)
     return tensors
@@ -201,9 +204,9 @@

 @given(tensor_shape(), integer(), st.booleans())
 @settings(deadline=1000)
-def test_tricycle_dense_matches_pytorch(in_shape, out_shape, is_vector):
-    tensor = build_tensor(in_shape, is_vector)
-    assume(np.isfinite(tensor._data).all())
+def test_tricycle_dense_matches_pytorch(in_shape, out_shape, is_batched):
+    tensor = build_tensor(in_shape, is_batched)
+    assume(np.isfinite(tensor.array).all())

     from_size = tensor.shape[-1]
@@ -213,7 +216,7 @@
     tr_layer = Dense(from_size=from_size, to_size=out_shape)
     tr_layer.weights = to_tensor(pt_layer.weight.detach().numpy().T)

-    pt_out = pt_layer(torch.tensor(tensor._data))
+    pt_out = pt_layer(torch.tensor(tensor.array))
     tr_out = tr_layer(tensor)

     assert np.allclose(
@@ -221,7 +224,7 @@
     )

     pt_out.sum().backward()
-    tr_out.from_vector().sum().backward()
+    tr_out.from_batched().sum().backward()

     assert np.allclose(
         pt_layer.weight.grad.detach().numpy().T,
@@ -232,21 +235,15 @@

 @given(tokens(), integer())
 @settings(deadline=1000)
-# @example(
-#     tokens_=to_tensor(
-#         [[1, 1], [1, 1]], dtype=np.int64, requires_grad=False, is_vector=True
-#     ),
-#     out_shape=1,
-# )
 def test_embedding_matches(tokens_, out_shape):
-    vocab_size = tokens_._data.max() + 1
+    vocab_size = tokens_.array.max() + 1
     pt_layer = torch.nn.Embedding(
         num_embeddings=vocab_size, embedding_dim=out_shape
     )
     tr_layer = Embedding(from_size=vocab_size, to_size=out_shape)
     tr_layer.weights = to_tensor(pt_layer.weight.detach().numpy())

-    pt_out = pt_layer(torch.tensor(tokens_._data))
+    pt_out = pt_layer(torch.tensor(tokens_.array))
     tr_out = tr_layer(tokens_)

     assert np.allclose(
@@ -254,7 +251,7 @@
     )

     pt_out.sum().backward()
-    tr_out.from_vector().sum().backward()
+    tr_out.from_batched().sum().backward()

     assert np.allclose(
         pt_layer.weight.grad.detach().numpy(),
@@ -265,17 +262,17 @@

 @given(tensor_shape(force_divisible_by_32=True), st.booleans())
 @settings(deadline=1000)
-@example(in_shape=[1, 1, 1, 128], is_vector=True)
-def test_tricycle_softmax_matches_pytorch(in_shape, is_vector):
-    tensor = build_tensor(in_shape, is_vector)
-    assume(np.isfinite(tensor._data).all())
+@example(in_shape=[1, 1, 1, 128], is_batched=True)
+def test_tricycle_softmax_matches_pytorch(in_shape, is_batched):
+    tensor = build_tensor(in_shape, is_batched)
+    assume(np.isfinite(tensor.array).all())

     tensor.requires_grad = True

     pt_layer = torch.nn.functional.softmax
     tr_layer = Softmax()

-    pt_input = torch.tensor(tensor._data, requires_grad=True)
+    pt_input = torch.tensor(tensor.array, requires_grad=True)
     pt_out = pt_layer(pt_input, dim=-1)
     tr_out = tr_layer(tensor)
@@ -285,7 +282,7 @@
     )

     pt_out.sum().backward()
-    tr_out.from_vector().sum().backward()
+    tr_out.from_batched().sum().backward()

     assert np.allclose(
         pt_input.grad.detach().numpy(),
@@ -296,25 +293,25 @@

 @given(tensor_shape(), st.booleans())
-@example(in_shape=[2, 2, 4], is_vector=False)
-def test_crossentropy_matches(in_shape, is_vector):
-    y_pred = build_tensor(in_shape, is_vector)
+@example(in_shape=[2, 2, 4], is_batched=False)
+def test_crossentropy_matches(in_shape, is_batched):
+    y_pred = build_tensor(in_shape, is_batched)
     y_true = np.random.randint(0, in_shape[-1], size=in_shape[:-1])
-    y_true = to_tensor(y_true, is_vector=is_vector, dtype=int)
-    assume(np.isfinite(y_pred._data).all())
+    y_true = to_tensor(y_true, is_batched=is_batched, dtype=int)
+    assume(np.isfinite(y_pred.array).all())

-    tr_out = CrossEntropy()(y_true, y_pred).from_vector()
+    tr_out = CrossEntropy()(y_true, y_pred).from_batched()
     if len(in_shape) > 1:
         tr_out = tr_out.mean()

     if len(in_shape) == 1:
-        p_y_pred = copy(y_pred._data)
+        p_y_pred = copy(y_pred.array)
     if len(in_shape) == 2:
-        p_y_pred = copy(y_pred._data)
+        p_y_pred = copy(y_pred.array)
     if len(in_shape) == 3:
-        p_y_pred = copy(y_pred._data).transpose(0, -1, 1)
+        p_y_pred = copy(y_pred.array).transpose(0, -1, 1)

     p_y_pred = torch.tensor(p_y_pred, requires_grad=True)
-    p_y_true = torch.tensor(y_true._data, dtype=torch.long)
+    p_y_true = torch.tensor(y_true.array, dtype=torch.long)

     p_out = torch.nn.CrossEntropyLoss()(
         input=p_y_pred,
diff --git a/tests/test_optimisers.py b/tests/test_optimisers.py
index 8ed2568..26c4039 100644
--- a/tests/test_optimisers.py
+++ b/tests/test_optimisers.py
@@ -32,7 +32,7 @@ def test_can_train_simple_neural_network_no_wd():
     optimiser = StochasticGradientDescent(learning_rate=1e-2)

     losses = []
-    batches = ds.to_tensor().to_vector()
+    batches = ds.to_tensor()
     # sourcery skip: no-loop-in-tests
     # sourcery skip: no-conditionals-in-tests
     for step, (x, y) in enumerate(batches):
@@ -74,7 +74,7 @@ def test_can_train_simple_neural_network_wd():
     optimiser = StochasticGradientDescent(learning_rate=1e-2, weight_decay=1e1)

     losses = []
-    batches = ds.to_tensor().to_vector()
+    batches = ds.to_tensor()
     # sourcery skip: no-loop-in-tests
     # sourcery skip: no-conditionals-in-tests
     for step, (x, y) in enumerate(batches):
@@ -116,7 +116,7 @@ def test_can_train_simple_neural_network_momentum():
     optimiser = StochasticGradientDescent(learning_rate=1e-2, momentum=0.9)

     losses = []
-    batches = ds.to_tensor().to_vector()
+    batches = ds.to_tensor()
     # sourcery skip: no-loop-in-tests
     # sourcery skip: no-conditionals-in-tests
     for step, (x, y) in enumerate(batches):
diff --git a/tests/test_simple_neural_network.py b/tests/test_simple_neural_network.py
index dd0538d..28e243f 100644
--- a/tests/test_simple_neural_network.py
+++ b/tests/test_simple_neural_network.py
@@ -49,7 +49,7 @@ def test_can_train_simple_neural_network():
     # sourcery skip: no-loop-in-tests
     # sourcery skip: no-conditionals-in-tests
     i = 0
-    batches = ds.to_tensor().to_vector()
+    batches = ds.to_tensor()
     for step, (x_in, y_out) in enumerate(batches):
         if step > N_STEPS:
             break
@@ -102,7 +102,7 @@ def test_can_train_simple_neural_network_gpu():
     # sourcery skip: no-loop-in-tests
     # sourcery skip: no-conditionals-in-tests
     i = 0
-    batches = ds.to_tensor().to_vector()
+    batches = ds.to_tensor()
     for step, (x_in, y_out) in enumerate(batches):
         if step > N_STEPS:
             break
@@ -110,7 +110,7 @@
         y_out = y_out.to_gpu()

         y_pred = model(x_in)
-        loss = loss_fn(y_out, y_pred).from_vector().e("a->") / BATCH_SIZE
+        loss = loss_fn(y_out, y_pred).from_batched().e("a->") / BATCH_SIZE
         loss.backward()

         model.update(optimiser)
diff --git a/tests/test_tensor_pbt.py b/tests/test_tensor_pbt.py
index 6f9a2c0..388c225 100644
--- a/tests/test_tensor_pbt.py
+++ b/tests/test_tensor_pbt.py
@@ -20,7 +20,7 @@
 )
 from tricycle.einsum import EinsumBackOp
 from tricycle.layers import Dense
-from tricycle.tensor import nothing, to_tensor, unvectorise, vectorise
+from tricycle.tensor import batch, nothing, to_tensor, unbatch
 from tricycle.tokeniser import BPETokeniser
 from tricycle.unary import (
     UnaryAdd,
@@ -117,17 +117,17 @@ def tensor(draw):
     data = draw(xp.arrays(dtype=np.float32, shape=shape))
     match len(shape):
         case 1:
-            is_vector = False
+            is_batched = False
         case 2:
-            is_vector = draw(st.booleans())
+            is_batched = draw(st.booleans())
         case 3:
-            is_vector = draw(st.booleans())
+            is_batched = draw(st.booleans())
         case 4:
-            is_vector = True
+            is_batched = True
     requires_grad = draw(st.booleans())
     return to_tensor(
         data,
-        is_vector=is_vector,
+        is_batched=is_batched,
         requires_grad=requires_grad,
     )
@@ -140,7 +140,7 @@ def small_tensor(draw):
     """
     shape = draw(st.integers(min_value=1, max_value=4))
     data = draw(xp.arrays(dtype=np.float64, shape=shape))
-    is_vector = len(shape) in {3, 4}
+    is_batched = len(shape) in {3, 4}
     requires_grad = draw(st.booleans())
     if CUPY_ENABLED:
         on_gpu = draw(st.booleans())
@@ -150,7 +150,7 @@
     tensor = to_tensor(
         data,
-        is_vector=is_vector,
+        is_batched=is_batched,
         requires_grad=requires_grad,
     )
     if on_gpu:
@@ -170,12 +170,12 @@ def tensor_pair_same_shape(draw):
     tensors = []
     for _ in range(2):
         data = draw(xp.arrays(dtype=np.float64, shape=shape))
-        is_vector = draw(st.booleans())
+        is_batched = draw(st.booleans())

         if draw(st.booleans()):
             data = data[1:]

-        tensor = to_tensor(data, is_vector=is_vector)
+        tensor = to_tensor(data, is_batched=is_batched)
         tensors.append(tensor)
     return tensors
@@ -202,7 +202,7 @@ def test_tensor_addition_same_shape(tensors):
     assert result.args == (tensor_1, tensor_2)
     assert result.back_fns == (nothing, nothing)

-    assert result.is_vector == tensor_1.is_vector or tensor_2.is_vector
+    assert result.is_batched == tensor_1.is_batched or tensor_2.is_batched


 @given(tensor(), scalar())
@@ -221,7 +221,7 @@ def test_tensor_addition_scalar(tensor, scalar):
     assert result.args == (tensor,)
     assert result.back_fns == (nothing,)

-    assert result.is_vector == tensor.is_vector
+    assert result.is_batched == tensor.is_batched


 @given(tensor_pair_same_shape())
@@ -248,39 +248,39 @@ def test_tensor_multiplication(tensors):
     assert isinstance(result.back_fns[0], EinsumBackOp)
     assert isinstance(result.back_fns[1], EinsumBackOp)

-    assert result.is_vector == tensor_1.is_vector or tensor_2.is_vector
+    assert result.is_batched == tensor_1.is_batched or tensor_2.is_batched


 @given(tensor())
 def test_close_to(tensor):
-    equal_nan = np.isnan(tensor._data).any()
+    equal_nan = np.isnan(tensor.array).any()
     assert tensor.close_to(tensor, equal_nan=equal_nan, rtol=1e-6, atol=1e-8)


 @given(tensor())
-def test_can_vectorise_and_unvectorise(tensor):
-    assume(not tensor.is_vector)
+def test_can_batch_and_unbatch(tensor):
+    assume(not tensor.is_batched)

-    vectorised = tensor.to_vector()
-    assert vectorised.is_vector
+    batched = tensor.to_batched()
+    assert batched.is_batched

-    unvectorised = vectorised.from_vector()
-    assert not unvectorised.is_vector
+    unbatched = batched.from_batched()
+    assert not unbatched.is_batched

-    assert tensor.close_to(unvectorised, equal_nan=True)
+    assert tensor.close_to(unbatched, equal_nan=True)

     # sourcery skip: no-conditionals-in-tests
     if tensor.requires_grad:
-        assert len(unvectorised.args) == 1
-        assert unvectorised.args[0].close_to(tensor, equal_nan=True)
-        assert unvectorised.back_fns == (vectorise,)
+        assert len(unbatched.args) == 1
+        assert unbatched.args[0].close_to(tensor, equal_nan=True)
+        assert unbatched.back_fns == (batch,)

-        assert len(unvectorised.args[0].args) == 1
-        assert unvectorised.args[0].args[0].close_to(tensor, equal_nan=True)
-        assert unvectorised.args[0].back_fns == (unvectorise,)
+        assert len(unbatched.args[0].args) == 1
+        assert unbatched.args[0].args[0].close_to(tensor, equal_nan=True)
+        assert unbatched.args[0].back_fns == (unbatch,)

-        assert unvectorised.requires_grad
+        assert unbatched.requires_grad
@@ -309,7 +309,7 @@ def test_unary_ops(tensor, op):
     else:
         result = op(tensor)
     assert result.shape == tensor.shape
-    assert result.is_vector == tensor.is_vector
+    assert result.is_batched == tensor.is_batched
     assert result.on_gpu == tensor.on_gpu
@@ -327,7 +327,7 @@ def test_binary_ops(tensors, op):
     result = op(tensor_1, tensor_2)

     assert result.shape in [tensor_1.shape, tensor_2.shape]
-    assert result.is_vector == any([tensor_1.is_vector, tensor_2.is_vector])
+    assert result.is_batched == any([tensor_1.is_batched, tensor_2.is_batched])
     assert result.on_gpu == any([tensor_1.on_gpu, tensor_2.on_gpu])
@@ -361,7 +361,7 @@ def test_tokeniser_train_encode_decode(text):
 def test_tricycle_dense_matches_pytorch(tensor, out_shape):
     np.random.seed(0)
     torch.manual_seed(0)
-    assume(np.isfinite(tensor._data).all())
+    assume(np.isfinite(tensor.array).all())

     from_size = tensor.shape[-1]
@@ -376,7 +376,7 @@
         pt_layer.weight.detach().numpy().T, dtype=np.float32
     )

-    pt_out = pt_layer(torch.tensor(tensor._data))
+    pt_out = pt_layer(torch.tensor(tensor.array))
     tr_out = tr_layer(tensor)

     pt_out_np = pt_out.detach().numpy()
@@ -389,7 +389,7 @@
     assert tr_out.close_to(pt_out_np, rtol=1e-2, equal_nan=True)

     pt_out.sum().backward()
-    match (tensor.ndim, tensor.is_vector):
+    match (tensor.ndim, tensor.is_batched):
         case 1, False:
             tr_out.e("a->").backward()
         case 2, False:
@@ -397,11 +397,11 @@
         case 3, False:
             tr_out.e("abc->").backward()
         case 2, True:
-            tr_out.from_vector().e("ab->").backward()
+            tr_out.from_batched().e("ab->").backward()
         case 3, True:
-            tr_out.from_vector().e("abc->").backward()
+            tr_out.from_batched().e("abc->").backward()
         case 4, True:
-            tr_out.from_vector().e("abcd->").backward()
+            tr_out.from_batched().e("abcd->").backward()

     assert np.allclose(
         pt_layer.weight.grad.detach().numpy().T,
diff --git a/tests/test_unary_ops.py b/tests/test_unary_ops.py
index 224bab1..c922829 100644
--- a/tests/test_unary_ops.py
+++ b/tests/test_unary_ops.py
@@ -91,7 +91,7 @@ def test_can_udiv():
     )
     with np.errstate(divide="ignore"):
         out_tensor.backward()
-    correct = -np.power(in_tensor._data, -2) * 2
+    correct = -np.power(in_tensor.array, -2) * 2

     assert in_tensor.grad is not None
     assert in_tensor.grad.close_to(correct)
diff --git a/tests/test_vectorise.py b/tests/test_vectorise.py
index 50fb582..2409e8a 100644
--- a/tests/test_vectorise.py
+++ b/tests/test_vectorise.py
@@ -5,10 +5,10 @@ from tricycle.functions import Softmax
 from tricycle.layers import Dense, Sequential
 from tricycle.loss import CrossEntropy, mean_square_error
-from tricycle.tensor import to_tensor, unvectorise, vectorise
+from tricycle.tensor import batch, to_tensor, unbatch


-def test_can_vectorise_single_einsum():
+def test_can_batch_single_einsum():
     input_1 = np.arange(1, 4)
     input_2 = np.arange(2, 5)
     input_3 = np.arange(3, 6)
@@ -23,16 +23,16 @@
     assert output_2 == 9
     assert output_3 == 12

-    input_vector = to_tensor([input_1, input_2, input_3])
-    input_vector = vectorise(input_vector)
+    input_batch = to_tensor([input_1, input_2, input_3])
+    input_batch = batch(input_batch)

     op = Einsum("a->")
-    output_vector = op(input_vector)
-    output_vector = unvectorise(output_vector)
+    output_batch = op(input_batch)
+    output_batch = unbatch(output_batch)

-    assert output_vector.close_to([6, 9, 12])
+    assert output_batch.close_to([6, 9, 12])


-def test_can_vectorise_entire_model():
+def test_can_batch_entire_model():
     np.random.seed(42)
     layer_1 = Dense(4, 16)
     layer_2 = Dense(16, 3)
@@ -47,20 +47,20 @@
     output_2 = model(to_tensor(input_2))
     output_3 = model(to_tensor(input_3))

-    input_vector = to_tensor([input_1, input_2, input_3])
+    input_batch = to_tensor([input_1, input_2, input_3])
     correct_output = to_tensor(
-        [output_1._data, output_2._data, output_3._data]
+        [output_1.array, output_2.array, output_3.array]
     )

-    input_vector = vectorise(input_vector)
-    correct_output = vectorise(correct_output)
-    output_vector = model(input_vector)
-    output_vector = unvectorise(output_vector)
+    input_batch = batch(input_batch)
+    correct_output = batch(correct_output)
+    output_batch = model(input_batch)
+    output_batch = unbatch(output_batch)

-    assert output_vector.close_to(correct_output)
+    assert output_batch.close_to(correct_output)


-def test_can_vectorise_mse():
+def test_can_batch_mse():
     y_true = to_tensor([0, 0, 1, 0])

     input_1 = to_tensor(np.arange(1, 5))
@@ -71,23 +71,23 @@
     output_2 = mean_square_error(y_true, input_2)
     output_3 = mean_square_error(y_true, input_3)

-    input_y_true = to_tensor(np.array([y_true._data] * 3))
-    input_vector = to_tensor(
-        np.array([input_1._data, input_2._data, input_3._data])
+    input_y_true = to_tensor(np.array([y_true.array] * 3))
+    input_batch = to_tensor(
+        np.array([input_1.array, input_2.array, input_3.array])
     )
     correct_output = to_tensor(
-        np.array([output_1._data, output_2._data, output_3._data]).sum()
+        np.array([output_1.array, output_2.array, output_3.array]).sum()
     )

-    input_y_true = vectorise(input_y_true)
-    input_vector = vectorise(input_vector)
-    output_vector = mean_square_error(input_y_true, input_vector)
-    output_vector = unvectorise(output_vector)
+    input_y_true = batch(input_y_true)
+    input_batch = batch(input_batch)
+    output_batch = mean_square_error(input_y_true, input_batch)
+    output_batch = unbatch(output_batch)

-    assert output_vector.close_to(correct_output)
+    assert output_batch.close_to(correct_output)


-def test_can_vectorise_softmax():
+def test_can_batch_softmax():
     input_1 = to_tensor(np.arange(1, 5))
     input_2 = to_tensor(np.arange(2, 6))
     input_3 = to_tensor(np.arange(3, 7))
@@ -96,26 +96,26 @@
     output_1 = Softmax()(input_1)
     output_2 = Softmax()(input_2)
     output_3 = Softmax()(input_3)

-    input_vector = to_tensor(
-        np.array([input_1._data, input_2._data, input_3._data])
+    input_batch = to_tensor(
+        np.array([input_1.array, input_2.array, input_3.array])
     )
     correct_output = to_tensor(
-        np.array([output_1._data, output_2._data, output_3._data])
+        np.array([output_1.array, output_2.array, output_3.array])
     )

-    input_vector = vectorise(input_vector)
-    output_vector = Softmax()(input_vector)
-    output_vector = unvectorise(output_vector)
+    input_batch = batch(input_batch)
+    output_batch = Softmax()(input_batch)
+    output_batch = unbatch(output_batch)

-    assert output_vector.close_to(correct_output)
+    assert output_batch.close_to(correct_output)


-def test_can_vectorise_split():
+def test_can_batch_split():
     in_tensor = to_tensor(
         [[1, 2, 3, 4, 5, 6], [1, 2, 3, 4, 5, 6]], name="in_tensor"
     )
-    out_tensors = in_tensor.to_vector().split(3)
+    out_tensors = in_tensor.to_batched().split(3)

     assert len(out_tensors) == 3
     assert out_tensors[0].shape == (2, 2)
@@ -126,9 +126,9 @@
     assert out_tensors[1].close_to([[3, 4], [3, 4]])
     assert out_tensors[2].close_to([[5, 6], [5, 6]])

-    assert out_tensors[0].is_vector
-    assert out_tensors[1].is_vector
-    assert out_tensors[2].is_vector
+    assert out_tensors[0].is_batched
+    assert out_tensors[1].is_batched
+    assert out_tensors[2].is_batched

     out_tensors[0].backward()
diff --git a/train_smol_gpt.py b/train_smol_gpt.py
index 4192a80..61fc519 100644
--- a/train_smol_gpt.py
+++ b/train_smol_gpt.py
@@ -115,7 +115,7 @@ def estimate_loss(
         # forward and backward pass
         logits = model(inputs)
         loss = loss_fn(outputs, logits)
-        batch_loss += loss._data / config.eval_steps
+        batch_loss += loss.array / config.eval_steps

     return batch_loss
@@ -169,7 +169,7 @@ def estimate_loss(
     # forward and backward pass
     logits = model(inputs)
     loss = loss_fn(outputs, logits)
-    batch_loss += loss._data / config.gradient_accumulation_steps
+    batch_loss += loss.array / config.gradient_accumulation_steps
     loss.backward()

     # Use the optimiser to update weights
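The renames in this patch are mechanical: `to_vector`/`from_vector`/`is_vector` become `to_batched`/`from_batched`/`is_batched`, the free functions `vectorise`/`unvectorise` become `batch`/`unbatch`, and the private `_data` attribute is now read as `array`. For orientation, here is a minimal sketch (not part of the patch itself) of how downstream training code reads after the change. It assumes only the signatures exercised by the tests above; the shapes and values are illustrative:

```python
import numpy as np

from tricycle.layers import Dense
from tricycle.loss import mean_square_error
from tricycle.tensor import to_tensor

# illustrative shapes: a batch of 4 inputs with 3 features, 2 targets each
x = to_tensor(np.random.random((4, 3)), is_batched=True, requires_grad=False)
y = to_tensor(np.random.random((4, 2)), is_batched=True, requires_grad=False)

layer = Dense(from_size=3, to_size=2)

y_pred = layer(x)                    # forward pass over the whole batch
loss = mean_square_error(y, y_pred)  # still batched: one loss per element

# un-batch before reducing to a scalar, as the tests above do,
# then backpropagate through the graph
loss = loss.from_batched().mean()
loss.backward()

print(loss.array)  # the backing numpy array is now `.array`, not `._data`
```

Note also that `ds.to_tensor()` in `tests/test_optimisers.py` and `tests/test_simple_neural_network.py` loses its trailing `.to_vector()` call outright rather than gaining a `.to_batched()`, presumably because the dataset now yields batched tensors on its own.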