In [1]:
from __future__ import annotations

import os
os.environ['DEBUG'] = '4'

import time, math
from collections import defaultdict
from functools import partialmethod, reduce
from itertools import accumulate
import numpy as np
from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, Any, Iterable, Set

from tinygrad.helpers import ImageDType, argfix, make_pair, getenv, IMAGE, DEBUG, flatten, DType, dtypes, prod, all_int
from tinygrad.lazy import LazyBuffer
from tinygrad.ops import Device, LoadOps
from tinygrad.shape.symbolic import sint
from tinygrad.realize import run_schedule

In [2]:
class Function:
    """Base class for differentiable ops.

    `apply` builds one instance per call; that instance is the autograd context
    of the call and is stored on the resulting tensor as `_ctx` when gradients
    are wanted. Subclasses implement `forward` and `backward`.
    """

    def __init__(self, device: str, *tensors: Tensor):
        self.device = device
        # one entry per input, copied from each tensor's requires_grad (True / False / None)
        self.needs_input_grad = [tensor.requires_grad for tensor in tensors]
        # tri-state result: True if any input wants a gradient,
        # otherwise None if any input is still undecided, otherwise False
        if any(self.needs_input_grad):
            self.requires_grad = True
        elif None in self.needs_input_grad:
            self.requires_grad = None
        else:
            self.requires_grad = False
        # the inputs are only remembered when a backward pass may need them
        if self.requires_grad:
            self.parents = tensors

    def forward(self, *args, **kwargs):
        raise NotImplementedError

    def backward(self, *args, **kwargs):
        raise NotImplementedError

    @classmethod
    def apply(fxn : Type[Function], *x: Tensor, **kwargs) -> Tensor:
        """Run `forward` on the lazy buffers of `x` and wrap the result in a new Tensor."""
        context = fxn(x[0].device, *x)  # the context lives on the first input's device
        buffers = [tensor.lazydata for tensor in x]
        out = Tensor(context.forward(*buffers, **kwargs), device=context.device, requires_grad=context.requires_grad)
        # only link the result into the autograd graph when gradients are wanted and enabled
        if context.requires_grad and not Tensor.no_grad:
            out._ctx = context
        return out

In [4]:
class Tensor:
    __slots__ = "lazydata", "requires_grad", "grad", "_ctx" # instances may only carry these four attributes (no per-instance __dict__), which saves memory
    __deletable__ = ('_ctx',) # declares which attributes may be deleted to save memory; backward() does `del t0._ctx`
    training: ClassVar[bool] = True # class-wide training-mode flag; the `train` context manager sets it to True and restores the previous value on exit

    # Context manager that switches Tensor.training on for the duration of a `with` block and restores the previous value afterwards
    class train:
        """Force `Tensor.training` to True inside a `with` block; the prior value is restored on exit."""

        def __enter__(self):
            # remember the current flag and switch training on in a single step
            self.prev, Tensor.training = Tensor.training, True

        def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any):
            # put back whatever was active before the block was entered
            Tensor.training = self.prev

    no_grad: ClassVar[bool] = False # when True, Function.apply does not attach an autograd context to its result; gradients are enabled by default
    default_type: ClassVar[DType] = dtypes.float32 # dtype used when a tensor is created without an explicit dtype

    def __init__(self, data: Union[int, float, list, LazyBuffer, np.ndarray], device: Optional[str] = None, dtype: Optional[DType] = None, requires_grad: Optional[bool] = None):
        """Wrap `data` in a LazyBuffer on `device` and store it as `self.lazydata`.

        Accepted inputs:
          * LazyBuffer      - used as is (an explicit `dtype` must match its dtype)
          * int / float     - becomes a CONST load op with shape ()
          * list            - converted through numpy using `dtype` or `Tensor.default_type`
          * np.ndarray      - a 0-d array becomes a CONST load op; otherwise it is loaded
                              from CPU, cast first when `dtype` has a numpy equivalent
        Anything else raises RuntimeError.
        """
        assert dtype is None or isinstance(dtype, DType), f"invalid dtype {dtype}"
        device = Device.canonicalize(device)

        # autograd state: tensors have gradients, buffers do not
        self.grad: Optional[Tensor] = None
        # tri-state: False and None mean "no gradient", True means "gradient".
        # None (the default) will be updated to True if it's put in an optimizer
        self.requires_grad: Optional[bool] = requires_grad
        # Function instance that produced this tensor (None for leaves)
        self._ctx: Optional[Function] = None

        # normalise every supported input into a LazyBuffer
        if isinstance(data, LazyBuffer):
            assert dtype is None or dtype == data.dtype, f"dtype doesn't match, and casting isn't supported"
            buf = data
        elif isinstance(data, (int, float)):
            buf = LazyBuffer.loadop(LoadOps.CONST, tuple(), dtype or Tensor.default_type, device, data)
        elif data.__class__ is list:
            assert dtype is None or dtype.np is not None, f"{dtype} doesnt have a numpy dtype"
            buf = LazyBuffer.fromCPU(np.array(data, dtype=(dtype or Tensor.default_type).np))
        elif isinstance(data, np.ndarray):
            assert dtype is None or dtype.np is not None, f"{dtype} doesn't have a numpy dtype"
            if data.shape == ():
                # 0-d array: treat it like a python scalar
                buf = LazyBuffer.loadop(LoadOps.CONST, tuple(), dtype or dtypes.from_np(data.dtype), device, data.item())
            elif dtype is not None and dtype.np is not None:
                buf = LazyBuffer.fromCPU(data.astype(dtype.np))
            else:
                buf = LazyBuffer.fromCPU(data)
        else:
            raise RuntimeError(f"can't create Tensor from {data}")

        # the buffer may have been created on another device; move it if so
        if buf.device != device:
            buf = buf.copy_to_device(device)
        self.lazydata = buf

    def __repr__(self):
        """Show the underlying buffer, the device, and the gradient's buffer (None when there is no gradient)."""
        grad_buffer = self.grad.lazydata if self.grad else None
        return f"<Tensor {self.lazydata!r} on {self.device} with grad {grad_buffer!r}>"

    # Hash by identity. Python's GC does not move objects, so id() stays stable for the tensor's lifetime.
    def __hash__(self):
        return id(self)

    @property
    def device(self) -> str:
        """Device name, read from the underlying lazy buffer."""
        return self.lazydata.device

    @property
    def shape(self) -> Tuple[sint, ...]:
        """Shape, read from the underlying lazy buffer."""
        return self.lazydata.shape

    @property
    def dtype(self) -> DType:
        """Element type, read from the underlying lazy buffer."""
        return self.lazydata.dtype

    # ! data handlers ====================

    @staticmethod
    def corealize(lst: Iterable[Tensor]):
        """Realize several tensors together: collect each one's schedule against a shared `seen` set, then run the combined schedule once."""
        visited: Set[LazyBuffer] = set()
        combined = []
        for tensor in lst:
            # the shared set is handed to every schedule() call
            combined.extend(tensor.lazydata.schedule(visited))
        run_schedule(combined)

    def realize(self) -> Tensor:
        """Run the schedule of this tensor's lazy buffer and return `self` so calls can be chained."""
        pending = self.lazydata.schedule()
        run_schedule(pending)
        return self
    
    def assign(self, x) -> Tensor:
        """Replace the contents of this tensor with `x` and return `self`.

        `x` may be a Tensor or anything the Tensor constructor accepts; in the
        latter case it is wrapped using this tensor's dtype. Outside the DISK
        path, `x` must have the same shape and live on the same device as
        `self`, and must not require gradients.
        """
        # TODO: this is a hack for writing to DISK
        if self.device.startswith("DISK"):
            if x.__class__ is not Tensor: x = Tensor(x, device='CPU', dtype=self.dtype) # make tensor
            # copy the host data into the realized buffer (the attribute chain is `.realized._copyin`)
            self.contiguous().realize().lazydata.realized._copyin(x.numpy())
            return self
        if x.__class__ is not Tensor: x = Tensor(x, device=self.device, dtype=self.dtype)
        # compare against x.device: comparing self.device with itself can never fail
        assert self.shape == x.shape and self.device == x.device, f"assign shape mismatch {self.shape} != {x.shape} or device mismatch {self.device} != {x.device}"
        assert not x.requires_grad
        if DEBUG >= 4: print(f"assign {self.lazydata} <- {x.lazydata}")
        # when dtypes match and we already own a realized buffer, offer it as x's output buffer
        # (presumably so the result is written in place); DISALLOW_ASSIGN turns this off
        if self.dtype == x.dtype and self.lazydata.realized is not None and not getenv("DISALLOW_ASSIGN"): x.lazydata.output_buffer = self.lazydata.realized
        self.lazydata = x.lazydata
        return self
    
    def detach(self) -> Tensor:
        """Return a new Tensor over the same lazy buffer with `requires_grad=False` and no autograd context."""
        return Tensor(self.lazydata, device=self.device, requires_grad=False)
    detatch = detach # misspelled original name, kept as an alias so existing callers keep working

    def numpy(self) -> np.ndarray:
        """Realize this tensor on CPU and return its contents as a numpy array of the same shape.

        Requires a fully concrete (non-symbolic) shape and a dtype that has a
        numpy equivalent.
        """
        assert all_int(self.shape), f"no numpy if shape is symbolic, {self.shape=}"
        assert self.dtype.np is not None, f"no numpy dtype for {self.dtype}"
        # detach first so the copy is not tracked by autograd
        return self.detach().cast(dtypes.from_np(self.dtype.np)).contiguous().to('CPU').realize().lazydata.realized.toCPU().reshape(self.shape)

    # TODO: if things are realized this won't work
    def to_(self, device: str):
        """Retarget this (still unrealized) tensor, and its gradient if it has one, to `device` in place."""
        assert self.lazydata.realized is None
        self.lazydata.device = device
        grad = self.grad
        if grad:
            grad.to_(device)
    
    def to(self, device: str) -> Tensor:
        """Return a new Tensor built from this one's lazy buffer on `device`; an existing gradient is moved the same way."""
        moved = Tensor(self.lazydata, device)
        if self.grad:
            moved.grad = self.grad.to(device)
        return moved
    
    #! Creation llop entrypoint ====================

    @staticmethod
    def _loadop(op, sz, device: Optional[str] = None, dtype: Optional[DType] = None, arg = None, **kwargs):
        """Create a flat tensor of `sz` elements from load op `op`; extra kwargs go to the Tensor constructor."""
        buffer_dtype = Tensor.default_type if dtype is None else dtype
        buffer = LazyBuffer.loadop(op, (sz,), buffer_dtype, Device.canonicalize(device), arg)
        return Tensor(buffer, dtype=dtype, device=device, **kwargs)
    
    @staticmethod
    def empty(*shape, **kwargs):
        """Create a tensor of the given concrete shape from an EMPTY load op."""
        assert all_int(shape), f"cannot create with symbolic shape {shape}"
        # the load op is created flat, then viewed in the requested shape
        flat = Tensor._loadop(LoadOps.EMPTY, prod(shape), **kwargs)
        return flat.reshape(shape)
    
    _seed: int = int(time.time()) # class-level RNG seed, initialised from the wall clock when the class is defined
    @staticmethod
    def manual_seed(seed=0):
        """Overwrite the class-wide seed that `rand` increments and passes to the RAND load op."""
        Tensor._seed = seed

    @staticmethod
    def rand(*shape, **kwargs):
        """Create a tensor of the given concrete shape from a RAND load op, seeded with the (freshly incremented) class seed."""
        assert all_int(shape), f"cannot create with symbolic shape {shape}"
        Tensor._seed += 1  # bump the seed so consecutive calls use different values
        flat = Tensor._loadop(LoadOps.RAND, prod(shape), arg=Tensor._seed, **kwargs)
        return flat.reshape(shape)
    
    # ! Creation helper functions ====================

    @staticmethod
    def full(shape: Tuple[sint, ...], fill_value, **kwargs):
        """Create a tensor of `shape` filled with `fill_value` by expanding a single scalar tensor (no per-element storage is built here)."""
        scalar = Tensor(fill_value, **kwargs)
        new_shape = argfix(shape)
        # view the scalar as an all-ones shape of the right rank, then broadcast it
        return scalar.reshape([1]*len(new_shape)).expand(new_shape)

    @staticmethod
    def zeros(*shape, **kwargs):
        """Tensor of the given shape filled with 0; the shape may be given as separate ints or as one tuple/list."""
        target_shape = argfix(*shape)
        return Tensor.full(target_shape, 0, **kwargs)

    @staticmethod
    def ones(*shape, **kwargs):
        """Tensor of the given shape filled with 1; the shape may be given as separate ints or as one tuple/list."""
        target_shape = argfix(*shape)
        return Tensor.full(target_shape, 1, **kwargs)

    @staticmethod
    def arrange(start, stop=None, step=1, **kwargs):
        """Return the 1-D tensor start, start+step, ... up to (but excluding) `stop`.

        With a single positional argument the range is 0 .. start. Extra
        kwargs are forwarded to `Tensor.full`.
        """
        if stop is None: stop, start = start, 0
        count = math.ceil((stop-start)/step)
        # the running sum of a constant `step` vector is step, 2*step, ...;
        # shifting it by (start - step) yields start, start+step, ...
        # (the vector must be filled with `step`, not `start`, for this to hold)
        return Tensor.full((count,), step, **kwargs).cumsum() + (start - step)
    
    @staticmethod
    def eye(dim: int, **kwargs):
        """Build a dim x dim identity matrix using only movement ops.

        A column of ones is padded to (dim, dim+1). Flattened, the ones sit at
        offsets k*(dim+1); dropping the last `dim` entries and re-reading the
        data in rows of length `dim` puts the k-th one at position (k, k).
        """
        column = Tensor.full((dim,1),1,**kwargs)
        widened = column.pad(((0,0),(0,dim)))
        flat = widened.reshape(dim*(dim+1))
        trimmed = flat.shrink(((0,dim*dim),))
        return trimmed.reshape(dim, dim)

    def full_like(self, fill_value, **kwargs):
        """Tensor with this tensor's shape filled with `fill_value`; dtype and device default to this tensor's unless overridden in kwargs."""
        dtype = kwargs.pop("dtype", self.dtype)
        device = kwargs.pop("device", self.device)
        return Tensor.full(self.shape, fill_value=fill_value, dtype=dtype, device=device, **kwargs)
    def zeros_like(self, **kwargs):
        """Same as `full_like` with a fill value of 0."""
        return self.full_like(0, **kwargs)
    def ones_like(self, **kwargs):
        """Same as `full_like` with a fill value of 1."""
        return self.full_like(1, **kwargs)

    # ! rng hlops ====================

    @staticmethod
    def randn(*shape, dtype: Optional[DType] = None, **kwargs) -> Tensor:
        """Normally distributed samples of the given shape, built from two uniform draws via the Box-Muller transform."""
        # https://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform
        src = Tensor.rand(2, *shape, **kwargs)  # src[0] and src[1] are the two uniform inputs
        angle_term = src[0].mul(2*math.pi).cos()
        radius_term = (1 - src[1]).log().mul(-2).sqrt()
        out_dtype = Tensor.default_type if dtype is None else dtype
        return angle_term.mul(radius_term).cast(out_dtype)
    
    @staticmethod
    def normal(*shape, mean=0.0, std=1.0, **kwargs) -> Tensor:
        """`randn` samples scaled by `std` and shifted by `mean`."""
        noise = Tensor.randn(*shape, **kwargs)
        return (std * noise) + mean

    @staticmethod
    def uniform(*shape, low=-1.0, high=1.0, **kwargs) -> Tensor:
        """`rand` samples rescaled by (high - low), cast to `dtype` (default `Tensor.default_type`), then shifted by `low`."""
        # dtype is removed from kwargs so it is applied by the cast, not forwarded to rand
        dtype = kwargs.pop("dtype", Tensor.default_type)
        width = high-low
        scaled = width * Tensor.rand(*shape, **kwargs)
        return scaled.cast(dtype) + low

    @staticmethod
    def scaled_uniform(*shape, **kwargs) -> Tensor:
        """`uniform` samples multiplied by 1/sqrt(number of elements)."""
        sample = Tensor.uniform(*shape, **kwargs)
        return sample.mul(prod(shape)**-0.5)

    # https://www.tensorflow.org/api_docs/python/tf/keras/initializers/GlorotUniform
    @staticmethod
    def glorot_uniform(*shape, **kwargs) -> Tensor:
        """`uniform` samples multiplied by sqrt(6 / (shape[0] + prod(shape[1:])))."""
        sample = Tensor.uniform(*shape, **kwargs)
        return sample.mul((6/(shape[0]+prod(shape[1:])))**0.5)

    # https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_uniform_
    @staticmethod
    def kaiming_uniform(*shape, a:float = 0.01, **kwargs) -> Tensor:
        """Uniform samples in [-bound, bound] with bound = sqrt(3) * sqrt(2 / (1 + a^2)) / sqrt(prod(shape[1:]))."""
        gain = math.sqrt(2.0 / (1 + a ** 2))
        fan = prod(shape[1:])  # product of every dimension except the first
        bound = math.sqrt(3.0) * gain / math.sqrt(fan)
        return Tensor.uniform(*shape, low=-bound, high=bound, **kwargs)

    # https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_normal_
    @staticmethod
    def kaiming_normal(*shape, a:float = 0.01, **kwargs) -> Tensor:
        """Zero-mean normal samples with std = sqrt(2 / (1 + a^2)) / sqrt(prod(shape[1:]))."""
        gain = math.sqrt(2.0 / (1 + a ** 2))
        fan = prod(shape[1:])  # product of every dimension except the first
        std = gain / math.sqrt(fan)
        return Tensor.normal(*shape, mean=0.0, std=std, **kwargs)
    
    # ! toposort and backward pass ====================
    def deepwalk(self):
        """Return the tensors reachable from `self` that have an autograd context, in topological order.

        Every returned tensor appears after all of its parents, so `self`
        (when it has a context) is last. Tensors without a `_ctx` (leaves) are
        visited but not returned: `backward` calls `_ctx.backward` on every
        element of this list, which only makes sense for tensors produced by a
        Function.
        """
        def _deepwalk(node, visited, nodes):
            visited.add(node)
            ctx = getattr(node, "_ctx", None)  # getattr: `_ctx` may have been deleted
            if ctx:
                for parent in ctx.parents:
                    if parent not in visited: _deepwalk(parent, visited, nodes)
                # post-order append, and only for nodes that can be back-propagated through
                nodes.append(node)
            return nodes
        return _deepwalk(self, set(), [])
    
    def backward(self):
        """Back-propagate from this scalar tensor, accumulating `.grad` on every upstream tensor that requires it.

        Raises AssertionError when called on a non-scalar tensor. Each visited
        tensor's `_ctx` is deleted once its gradients have been pushed to its
        parents, so the graph cannot be walked a second time.
        """
        assert self.shape == tuple(), f"backward can only be called for scalar tensor, but it has shape {self.shape}"

        # fill in the first grad with one. don't use Tensor.ones because we don't need contiguous
        # this is "implicit gradient creation"
        self.grad = Tensor(1, device=self.device, requires_grad=False)

        # walk the graph in reverse topological order (outputs before inputs)
        for t0 in reversed(self.deepwalk()):
            ctx = getattr(t0, "_ctx", None)
            # a tensor without a context was not produced by a Function: nothing to propagate through
            if ctx is None: continue
            assert (t0.grad is not None)
            grads = ctx.backward(t0.grad.lazydata)
            # a single-parent op returns one buffer instead of a sequence; normalise to a list
            grads = [Tensor(g, device=self.device, requires_grad=False) if g is not None else None
                        for g in ([grads] if len(ctx.parents) == 1 else grads)]
            for t, g in zip(ctx.parents, grads):
                if g is not None and t.requires_grad:
                    assert g.shape == t.shape, f"grad shape must match tensor shape, {g.shape!r} != {t.shape!r}"
                    t.grad = g if t.grad is None else (t.grad + g)
            # drop the context once, after all parents are handled; deleting it inside the
            # parent loop raises AttributeError on the second parent because the slot is already gone
            del t0._ctx

    #  ! movement mlops ====================