In [1]:
from __future__ import annotations

import os
os.environ['DEBUG'] = '4'

import time, math
from collections import defaultdict
from functools import partialmethod, reduce
from itertools import accumulate
import numpy as np
from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, Any, Iterable, Set

from tinygrad.helpers import ImageDType, argfix, make_pair, getenv, IMAGE, DEBUG, flatten, DType, dtypes, prod, all_int
from tinygrad.lazy import LazyBuffer
from tinygrad.ops import Device, LoadOps
from tinygrad.shape.symbolic import sint
from tinygrad.realize import run_schedule

In [2]:
class Function:
    # Base class for a differentiable op. Subclasses implement forward/backward;
    # `apply` runs forward on the inputs' LazyBuffers and wraps the result in a Tensor.
    def __init__(self, device: str, *tensors: Tensor):
        self.device = device
        # one entry per input tensor: True / False / None
        self.needs_input_grad = [tensor.requires_grad for tensor in tensors]
        # tri-state result: True if any input needs grad, else None if any input is undecided, else False
        if any(self.needs_input_grad):
            self.requires_grad = True
        elif None in self.needs_input_grad:
            self.requires_grad = None
        else:
            self.requires_grad = False
        # the inputs are only kept when a backward pass may need them
        if self.requires_grad:
            self.parents = tensors

    def forward(self, *args, **kwargs):
        raise NotImplementedError  # provided by each concrete op

    def backward(self, *args, **kwargs):
        raise NotImplementedError  # provided by each concrete op

    @classmethod
    def apply(fxn : Type[Function], *x: Tensor, **kwargs) -> Tensor:
        # the first input decides the device of the op context
        ctx = fxn(x[0].device, *x)
        out_buffer = ctx.forward(*[t.lazydata for t in x], **kwargs)
        ret = Tensor(out_buffer, device=ctx.device, requires_grad=ctx.requires_grad)
        # record the graph edge for the autograd engine unless gradients are globally disabled
        if ctx.requires_grad and not Tensor.no_grad:
            ret._ctx = ctx
        return ret

In [4]:
import tinygrad.mlops as mlops

class Tensor:
    """Study copy of tinygrad's Tensor front end.

    A Tensor wraps a LazyBuffer (`self.lazydata`) and, when it was produced by a
    differentiable op, the Function context (`self._ctx`) that `backward()` walks.
    NOTE(review): several methods used below (cast, contiguous, cumsum, mul, div, exp,
    log, sqrt, cos, where, sign, relu, square, the arithmetic/comparison dunders, ...)
    belong to the "processing ops" section, which is not written yet in this notebook.
    """
    __slots__ = "lazydata", "requires_grad", "grad", "_ctx" # declares which attributes are allowed; no per-instance __dict__ saves memory
    __deletable__ = ('_ctx',) # attributes that may be deleted to save memory (backward() deletes _ctx)
    training: ClassVar[bool] = True # global training flag, toggled by the `train` context manager below

    # Context manager that turns Tensor.training on for the duration of a `with` block
    class train:
        def __enter__(self):
            self.prev = Tensor.training # remember the previous state
            Tensor.training = True
        def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any):
            Tensor.training = self.prev # revert to previous state (also runs when the block raises)

    no_grad: ClassVar[bool] = False # when True, Function.apply does not attach _ctx (no autograd graph is recorded)
    default_type: ClassVar[DType] = dtypes.float32 # dtype used when none is given

    def __init__(self, data: Union[int, float, list, LazyBuffer, np.ndarray], device: Optional[str] = None, dtype: Optional[DType] = None, requires_grad: Optional[bool] = None):
        assert dtype is None or isinstance(dtype, DType), f"invalid dtype {dtype}"
        device = Device.canonicalize(device) # normalize the requested device name
        # tensors have gradients, buffers do not
        self.grad: Optional[Tensor] = None

        # NOTE: this can be in three states. False and None: no gradient, True: gradient
        # None (the default) will be updated to True if it's put in an optimizer
        self.requires_grad: Optional[bool] = requires_grad

        # internal variable used for autograd graph construction: the Function that produced this tensor
        self._ctx: Optional[Function] = None

        # convert every supported input type into a LazyBuffer
        if isinstance(data, LazyBuffer):
            assert dtype is None or dtype == data.dtype, f"dtype doesn't match, and casting isn't supported"
        elif isinstance(data, (int, float)): # python scalar -> 0-d CONST buffer
            data = LazyBuffer.loadop(LoadOps.CONST, tuple(), dtype or Tensor.default_type, device, data)
        elif data.__class__ is list: # (nested) python list -> numpy array -> buffer
            assert dtype is None or dtype.np is not None, f"{dtype} doesnt have a numpy dtype"
            data = LazyBuffer.fromCPU(np.array(data, dtype=(dtype or Tensor.default_type).np))
        elif isinstance(data, np.ndarray):
            assert dtype is None or dtype.np is not None, f"{dtype} doesn't have a numpy dtype"
            if data.shape == (): # a 0-d array is treated like a scalar
                data = LazyBuffer.loadop(LoadOps.CONST, tuple(), dtype or dtypes.from_np(data.dtype), device, data.item())
            else: data = LazyBuffer.fromCPU(data.astype(dtype.np) if dtype is not None and dtype.np is not None else data)
        else: raise RuntimeError(f"can't create Tensor from {data}")

        # data is now a LazyBuffer, but it might be on the wrong device;
        # the input ends up stored in the .lazydata attribute
        self.lazydata = data if data.device == device else data.copy_to_device(device)

    def __repr__(self):
        return f"<Tensor {self.lazydata!r} on {self.device} with grad {(self.grad.lazydata if self.grad else None)!r}>"

    # Python has a non moving GC, so this should be okay
    def __hash__(self): return id(self)

    @property
    def device(self) -> str: return self.lazydata.device

    @property
    def shape(self) -> Tuple[sint, ...]: return self.lazydata.shape

    @property
    def dtype(self) -> DType: return self.lazydata.dtype

    # number of dimensions; used by gather, squeeze, repeat, chunk, argmax and __getitem__
    @property
    def ndim(self) -> int: return len(self.shape)

    # ! data handlers ====================

    @staticmethod
    def corealize(lst: Iterable[Tensor]):
        # Realize several tensors with one combined schedule. A single `seen` set is shared
        # across the schedule() calls (presumably so common subgraphs are scheduled only once).
        seen: Set[LazyBuffer] = set()
        sched = []
        for t in lst: sched += t.lazydata.schedule(seen)
        run_schedule(sched)

    def realize(self) -> Tensor:
        # schedule and run all pending ops of this tensor; returns self for chaining
        run_schedule(self.lazydata.schedule())
        return self

    def assign(self, x) -> Tensor:
        # Replace this tensor's data with x (converted to a Tensor first if needed); returns self.
        # TODO: this is a hack for writing to DISK
        if self.device.startswith("DISK"):
            if x.__class__ is not Tensor: x = Tensor(x, device='CPU', dtype=self.dtype)
            self.contiguous().realize().lazydata.realized_copyin(x.numpy())
            return self
        if x.__class__ is not Tensor: x = Tensor(x, device=self.device, dtype=self.dtype)
        # both the shape and the device of x must match this tensor
        assert self.shape == x.shape and self.device == x.device, f"assign shape mismatch {self.shape} != {x.shape} or device mismatch {self.device} != {x.device}"
        assert not x.requires_grad
        if DEBUG >= 4: print(f"assign {self.lazydata} <- {x.lazydata}")
        # if this tensor is already realized, let x write its result straight into that buffer
        if self.dtype == x.dtype and self.lazydata.realized is not None and not getenv("DISALLOW_ASSIGN"): x.lazydata.output_buffer = self.lazydata.realized
        self.lazydata = x.lazydata
        return self

    # same data, new Tensor without autograd history
    def detach(self) -> Tensor: return Tensor(self.lazydata, device=self.device, requires_grad=False)
    detatch = detach # backward-compatible alias for the original (misspelled) name

    def numpy(self) -> np.ndarray:
        # copy the data to the CPU and return it as a numpy array with this tensor's shape
        assert all_int(self.shape), f"no numpy if shape is symbolic, {self.shape=}"
        assert self.dtype.np is not None, f"no numpy dtype for {self.dtype}"
        return self.detach().cast(dtypes.from_np(self.dtype.np)).contiguous().to('CPU').realize().lazydata.realized.toCPU().reshape(self.shape)

    # in-place device move; only valid before the data is realized
    # TODO: if things are realized this won't work
    def to_(self, device: str):
        assert self.lazydata.realized is None
        self.lazydata.device = device
        if self.grad: self.grad.to_(device)

    # returns a new Tensor on `device` (the gradient is moved along with it)
    def to(self, device: str) -> Tensor:
        ret = Tensor(self.lazydata, device)
        if self.grad: ret.grad = self.grad.to(device)
        return ret

    #! Creation loadop entrypoint ====================

    @staticmethod
    def _loadop(op, sz, device: Optional[str] = None, dtype: Optional[DType] = None, arg = None, **kwargs):
        # create a flat (sz,) tensor from a LoadOps op; callers reshape it afterwards
        return Tensor(LazyBuffer.loadop(op, (sz,), Tensor.default_type if dtype is None else dtype, Device.canonicalize(device), arg), dtype=dtype, device=device, **kwargs)

    @staticmethod
    def empty(*shape, **kwargs):
        assert all_int(shape), f"cannot create with symbolic shape {shape}"
        return Tensor._loadop(LoadOps.EMPTY, prod(shape), **kwargs).reshape(shape)

    _seed: int = int(time.time()) # class-level RNG seed, initialised from the current time
    @staticmethod
    def manual_seed(seed=0): Tensor._seed = seed # set seed

    @staticmethod
    def rand(*shape, **kwargs):
        assert all_int(shape), f"cannot create with symbolic shape {shape}"
        Tensor._seed += 1 # every call uses a fresh seed
        return Tensor._loadop(LoadOps.RAND, prod(shape), arg=Tensor._seed, **kwargs).reshape(shape)

    # ! Creation helper functions ====================

    # scalar tensor reshaped to all-ones dims, then broadcast (expanded) to `shape`
    @staticmethod
    def full(shape: Tuple[sint, ...], fill_value, **kwargs): return Tensor(fill_value, **kwargs).reshape([1]*len(new_shape := argfix(shape))).expand(new_shape)

    @staticmethod
    def zeros(*shape, **kwargs): return Tensor.full(argfix(*shape), 0, **kwargs)

    @staticmethod
    def ones(*shape, **kwargs): return Tensor.full(argfix(*shape), 1, **kwargs)

    @staticmethod
    def arange(start, stop=None, step=1, **kwargs):
        # values start, start+step, ... stopping before `stop`, like range()/np.arange
        if stop is None: stop, start = start, 0 # arange(n) == arange(0, n)
        # a tensor filled with `step` cumsums to step, 2*step, ...; adding (start - step) gives start, start+step, ...
        return Tensor.full((math.ceil((stop-start)/step),), step, **kwargs).cumsum() + (start - step)
    arrange = arange # backward-compatible alias for the original (misspelled) name

    # identity matrix: a (dim, 1) column of ones padded to width dim+1 and flattened puts a 1
    # every dim+1 elements; trimming to dim*dim elements and reshaping lays them on the diagonal
    @staticmethod
    def eye(dim: int, **kwargs): return Tensor.full((dim,1),1,**kwargs).pad(((0,0),(0,dim))).reshape(dim*(dim+1)).shrink(((0,dim*dim),)).reshape(dim, dim)

    # *_like helpers default to this tensor's shape, dtype and device
    def full_like(self, fill_value, **kwargs):
        return Tensor.full(self.shape, fill_value=fill_value, dtype=kwargs.pop("dtype", self.dtype), device=kwargs.pop("device", self.device), **kwargs)
    def zeros_like(self, **kwargs): return self.full_like(0, **kwargs)
    def ones_like(self, **kwargs): return self.full_like(1, **kwargs)

    # ! rng hlops ====================

    @staticmethod
    def randn(*shape, dtype: Optional[DType] = None, **kwargs) -> Tensor:
        # standard normal samples from two uniform draws (Box-Muller)
        # https://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform
        src = Tensor.rand(2, *shape, **kwargs)
        return src[0].mul(2*math.pi).cos().mul((1 - src[1]).log().mul(-2).sqrt()).cast(Tensor.default_type if dtype is None else dtype)

    @staticmethod
    def normal(*shape, mean=0.0, std=1.0, **kwargs) -> Tensor: return (std * Tensor.randn(*shape, **kwargs)) + mean

    @staticmethod
    def uniform(*shape, low=-1.0, high=1.0, **kwargs) -> Tensor:
        dtype = kwargs.pop("dtype", Tensor.default_type)
        return ((high-low) * Tensor.rand(*shape, **kwargs)).cast(dtype) + low

    @staticmethod
    def scaled_uniform(*shape, **kwargs) -> Tensor: return Tensor.uniform(*shape, **kwargs).mul(prod(shape)**-0.5)

    # https://www.tensorflow.org/api_docs/python/tf/keras/initializers/GlorotUniform
    @staticmethod
    def glorot_uniform(*shape, **kwargs) -> Tensor: return Tensor.uniform(*shape, **kwargs).mul((6/(shape[0]+prod(shape[1:])))**0.5)

    # https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_uniform_
    @staticmethod
    def kaiming_uniform(*shape, a:float = 0.01, **kwargs) -> Tensor:
        bound = math.sqrt(3.0) * math.sqrt(2.0 / (1 + a ** 2)) / math.sqrt(prod(shape[1:]))
        return Tensor.uniform(*shape, low=-bound, high=bound, **kwargs)

    # https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_normal_
    @staticmethod
    def kaiming_normal(*shape, a:float = 0.01, **kwargs) -> Tensor:
        std = math.sqrt(2.0 / (1 + a ** 2)) / math.sqrt(prod(shape[1:]))
        return Tensor.normal(*shape, mean=0.0, std=std, **kwargs)

    # ! toposort and backward pass ====================
    def deepwalk(self):
        # Topologically sort the autograd graph rooted at self (parents come before children).
        # Only tensors produced by a Function (i.e. with a _ctx) are returned: leaves have nothing
        # to backpropagate through, and backward() calls t0._ctx.backward on every returned node.
        def _deepwalk(node, visited, nodes):
            visited.add(node)
            if getattr(node, "_ctx", None):
                for i in node._ctx.parents:
                    if i not in visited: _deepwalk(i, visited, nodes)
                nodes.append(node)
            return nodes
        return _deepwalk(self, set(), [])

    def backward(self):
        # Backpropagate from this scalar tensor, accumulating gradients into .grad of every
        # tensor in the graph that has requires_grad set.
        assert self.shape == tuple(), f"backward can only be called for scalar tensor, but it has shape {self.shape}"

        # fill in the first grad with one. don't use Tensor.ones because we don't need contiguous
        # this is "implicit gradient creation"
        self.grad = Tensor(1, device=self.device, requires_grad=False)

        # walk the graph in reverse topological order (outputs first)
        for t0 in reversed(self.deepwalk()):
            assert (t0.grad is not None)
            grads = t0._ctx.backward(t0.grad.lazydata)
            # a single-parent op returns one buffer, a multi-parent op returns one per parent
            grads = [Tensor(g, device=self.device, requires_grad=False) if g is not None else None
                        for g in ([grads] if len(t0._ctx.parents) == 1 else grads)]
            for t, g in zip(t0._ctx.parents, grads):
                if g is not None and t.requires_grad:
                    assert g.shape == t.shape, f"grad shape must match tensor shape, {g.shape!r} != {t.shape!r}"
                    t.grad = g if t.grad is None else (t.grad + g) # accumulate grads from multiple uses
            # free the graph only after every parent has received its gradient
            # (deleting inside the loop raised AttributeError for ops with several parents)
            del t0._ctx

    #  ! movement mlops ====================
    def reshape(self, shape, *args) -> Tensor:
        new_shape = argfix(shape, *args)  # standardize the shape argument with argfix
        assert 0 not in new_shape, f"zeros not allowed in shape {new_shape}" # ensure no zero dimensions in new shape
        # a -1 entry is inferred: prod(new_shape) is negative because it contains the -1,
        # so -prod(self.shape) // prod(new_shape) is the (positive) missing size
        return mlops.Reshape.apply(self, shape=tuple([-prod(self.shape) // prod(new_shape) if s == -1 else s  for s in new_shape]))

    def expand(self, shape, *args) -> Tensor:
        # broadcast to `shape`; a -1 entry keeps the original size of that dimension
        return mlops.Expand.apply(self, shape=tuple([x if x != -1 else s for s, x in zip(self.shape, argfix(shape, *args))]))

    def permute(self, order, *args) -> Tensor:
        # reorder dimensions according to `order` (standardized by argfix)
        return mlops.Permute.apply(self, order=argfix(order, *args))

    def flip(self, axis, *args) -> Tensor:
        # reverse the tensor along the given axes; negative axes count from the end
        return mlops.Flip.apply(self, axis=[x if x >= 0 else x + len(self.shape) for x in argfix(axis, *args)])

    def shrink(self, arg: Tuple[Tuple[sint, sint], ...]) -> Tensor:
        # keep [start, end) per dimension; a no-op shrink (full range everywhere) returns self
        return mlops.Shrink.apply(self, arg=arg) if any(x != (0, s) for x, s in zip(arg, self.shape)) else self

    def pad(self, arg: Tuple[Tuple[int, int], ...], value: float = 0) -> Tensor:
        # pad (before, after) per dimension with zeros; all-zero padding returns self
        ret = mlops.Pad.apply(self, arg=arg) if any(x != (0, 0) for x in arg) else self
        # for a non-zero `value`: pad a tensor of ones the same way (1 inside, 0 in the padding),
        # map it to 0 / value with where, and add it so only the padded region becomes `value`
        return ret if 0 == value else ret + mlops.Pad.apply(Tensor.ones_like(self), arg=arg).where(0, value)

    # ***** movement hlops *****

    # - Negative indices are taken relative to the end of the sequence, so X[-2] returns the 2nd-to-last element
    # - A slice i:j returns the elements with indices in [i, j)
    #    - If omitted, i and j will default to 0 and N, respectively, where N is the length of the sequence
    #    - Negative values for i and j are taken relative to the end of the sequence
    #    - Both i and j will be clamped to the range (-N, N], where N is the length of the sequence
    # - Indexing with None on a given axis will add a new dimension of size one before that axis
    # - Empty slices are not allowed (tensors with 0s in shape have to be supported first, for all backends).
    # - For a slice [i:j:k] finding the correct indices is delegated to slice.indices(len).
    # - Strides > 1 and < 0 are now allowed!:
    #    - This works by applying Shrink -> [[Flip -> ] Pad -> Reshape -> Shrink] -> Reshape (ops in brackets are optional)
    #    - Idea of stride < 0 support:
    #        - Do the slice first, flip the axes where slice.step is negative, do slice.step -> -slice.step. Go to steps below.
    #    - Idea of stride `s` > 1 support (Pad -> Reshape -> Shrink):
    #        - Instead of doing [::s] on axis [dim_sz], do [:, 0] on axes [dim_sz_padded // s, s].
    #        - So pad dim_sz with as many zeros as needed (dim_sz -> dim_sz_padded) so that reshape to [dim_sz_padded // s, s]
    #          is possible.
    #        - Apply Shrink to do the slice [:, 0] on axes of shapes [dim_sz_padded // s, s].
    # - Fancy indexing and combined indexing is supported
    #    - Combined indexing works by letting regular slicing finish first -> computing the resulting dims w.r.t to Tensors passed in -> fancy indexing
    #    - Any Tensors passed in __getitem__ will perform (CMPEQ with arange -> MUL with self -> SUM_REDUCE) iteratively
    #        - The first iteration will expand the dim of self while consecutive iterations will reduce the dim
    #    - There's a special case where a permute is needed at the end:
    #        - if first Tensor passed in (expand dims) is not at dim 0
    #        - and following Tensors does not follow consecutively to the end of fancy indexing's dims

    def __getitem__(self, val):
        def normalize_int(e, i, dim_sz):
            # bounds-check an int index; -1 maps to dim_sz-1 so that slice(e, e+1) is not empty
            if -dim_sz <= e < dim_sz: return e if e != -1 else dim_sz-1
            raise IndexError(f"index {e} is out of bounds for dimension {i} with size {self.shape[i]}")

        orig_slices = list(val) if isinstance(val, tuple) else [val] # store the original slices arg
        count = defaultdict(list)
        for i, v in enumerate(orig_slices):
            count[type(v)].append(i) # positions of each kind of index (int, slice, Tensor, None, Ellipsis)

        # check for too many indices (None and Ellipsis do not consume a dimension)
        if (num_slices := len(count[int]) + len(count[slice]) + len(count[Tensor])) > len(self.shape):
            raise IndexError(f"too many indices for tensor of dimension {len(self.shape)}")

        # check for too many ellipses
        if len(ellipsis_found := count[type(Ellipsis)]) > 1:
            raise IndexError(f"an index can only have a single ellipsis ('...')")

        # replace the ellipsis (or, if there is none, append at the end) with exactly enough
        # full slices that every dimension of self is indexed once
        ellipsis_idx = ellipsis_found[0] if ellipsis_found else len(orig_slices)
        orig_slices[ellipsis_idx:ellipsis_idx+1] = [slice(None)] * (len(self.shape) - num_slices)

        # One slice per dimension of self (None entries only add axes and are handled below):
        # - slices are kept as-is,
        # - ints become the length-1 slice [e, e+1) after normalizing,
        # - anything else (Tensors) becomes a full slice and is applied in the fancy-indexing step.
        valid_slices = [v for v in orig_slices if v is not None]
        valid_slices = [
            v if isinstance(v, slice)
            else slice(y_ := normalize_int(v, i, dim_sz), y_ + 1) if isinstance(v, int)
            else slice(None)
            for i, (v, dim_sz) in enumerate(zip(valid_slices, self.shape))
        ]

        # start/stop/step per dimension via slice.indices; empty tuples for a 0-d tensor
        start, stop, strides = zip(*y) if (y := [s.indices(dim_sz) for s, dim_sz in zip(valid_slices, self.shape)]) else ((), (), ())
        # for negative strides, take the same elements in forward order, then flip those axes
        new_slice = tuple((s, e) if st > 0 else (e + 1, s + 1) for s, e, st in zip(start, stop, strides))
        sliced_tensor = self.shrink(new_slice).flip(axis=[i for i, s in enumerate(strides) if s < 0])
        new_shape = sliced_tensor.shape

        if any(abs(s) != 1 for s in strides):
            # only the stride magnitude matters from here on (direction was handled by flip)
            strides = tuple(abs(s) for s in strides)
            # Pad: round each dim up to a multiple of its stride, [dim_sz] -> [dim_sz_padded]
            padded_tensor = sliced_tensor.pad(tuple((0, s - (dim_sz % s) if dim_sz % s != 0 else 0) for s, dim_sz in zip(strides, sliced_tensor.shape)))
            # Reshape: [dim_sz_padded] -> [dim_sz_padded // s, s]
            reshaped_tensor = padded_tensor.reshape(flatten([sh // s, s] for sh, s in zip(padded_tensor.shape, strides)))
            new_shape = reshaped_tensor.shape[::2] # the outer (dim_sz_padded // s) sizes
            # Shrink: take [:, 0] on every (dim_sz_padded // s, s) pair, i.e. every s-th element
            sliced_tensor = reshaped_tensor.shrink(tuple(flatten(((0, sh), (0, 1)) for sh in new_shape)))

        final_shape, it_shape, dim, tensors, dim_collapsed = [], iter(new_shape), [], [], 0
        # walk the original index to build the output shape and locate Tensor indices
        for i, s in enumerate(orig_slices):
            if s is None:
                final_shape.append(1)  # None inserts a new axis of size 1
            else:  # s is int, slice or Tensor and consumes one sliced dimension
                dim_shape = next(it_shape)
                if isinstance(s, int): dim_collapsed += 1  # int indices drop their dimension
                else:
                    assert isinstance(dim_shape, int), f"does not support symbolic shape {dim_shape}"
                    final_shape.append(dim_shape)
                    if isinstance(s, Tensor): # remember the tensor and its position in the output
                        tensors.append(s)
                        dim.append(i - dim_collapsed)

        # reshape the sliced tensor to its final shape
        ret = sliced_tensor.reshape(tuple(final_shape))

        if tensors: # Fancy/tensor indexing
            # normalize idx (negative indices get the dimension size added)
            # TODO: first contiguous fixes torch+cpu_only CI, but it causes llvm to fail. Second one fixes llvm
            idx = [t.sign().contiguous().__neg__().contiguous().relu() * ret.shape[d] + t for d,t in zip(dim, tensors)]
            max_dim = max(i.ndim for i in idx)
            # compute sum_dim, arange, and idx
            sum_dim = [d if n==0 else d+max_dim-n for n,d in enumerate(dim)]
            arange = [Tensor.arange(ret.shape[d], dtype=dtypes.int32, requires_grad=False, device=self.device).reshape(*[1]*sd, ret.shape[d], *[1]*(ret.ndim + max_dim - n - sd - 1)) for n,(sd,d) in enumerate(zip(sum_dim, dim))]
            first_idx = [idx[0].reshape(*[1]*dim[0], *[1]*(1 + max_dim - idx[0].ndim), *idx[0].shape, *[1]*(ret.ndim - dim[0] - 1))]
            rest_idx = [i.reshape(*[1]*dim[0], *[1]*(max_dim - i.ndim), *i.shape, *[1]*(ret.ndim - dim[0] - n)) for n,i in enumerate(idx[1:], 1)]
            idx = first_idx + rest_idx
            ret = ret.reshape(*ret.shape[:sum_dim[0]+1], *[1]*max_dim, *ret.shape[sum_dim[0]+1:])
            # iteratively fancy index
            for a,i,sd in zip(arange, idx, sum_dim): ret = (a==i).mul(ret).sum(sd)
            # special permute case
            if dim[0] != 0 and len(dim) != 1 and dim != list(range(dim[0], dim[-1]+1)):
                ret_dims = list(range(ret.ndim))
                ret = ret.permute(ret_dims[dim[0]:dim[0]+max_dim] + ret_dims[:dim[0]] + ret_dims[dim[0]+max_dim:])
        # plain (non-fancy) indexing must also return the result
        return ret

    def __setitem__(self,s,v): return self.__getitem__(s).assign(v)

    # NOTE: using slice is discouraged and things should migrate to pad and shrink
    def slice(self, arg:Sequence[Optional[Tuple[int, sint]]], value:float=0) -> Tensor:
        # (start, end) per dim, may reach outside the tensor: pad the out-of-range part with `value`, then shrink
        arg_ = tuple([a if a is not None else (0,s) for s,a in zip(self.shape, arg)])
        padding = tuple([(max(0, -p[0]), max(0, p[1]-self.shape[i])) for i,p in enumerate(arg_)])
        return self.pad(padding, value=value).shrink(tuple([(p[0] + padding[i][0], p[1] + padding[i][0]) for i,p in enumerate(arg_)]))

    def gather(self: Tensor, idx: Tensor, dim: int):
        # out[..., i, ...] = self[..., idx[..., i, ...], ...] along `dim`, implemented with an arange mask + sum
        assert idx.ndim == self.ndim, "self.ndim must equal idx.ndim"
        assert all(s >= i for s,i in zip(self.shape, idx.shape)), "all dim of idx.shape must be smaller than self.shape"
        if dim < 0: dim += self.ndim
        idx = idx.transpose(ax1=dim, ax2=0).unsqueeze(-1)
        permarg = list(range(self.ndim))
        permarg = permarg[1:dim] + [permarg[0]] + permarg[dim+1:] + [permarg[dim]] if dim != 0 else permarg[1:] + [permarg[0]]
        return ((idx == Tensor.arange(self.shape[dim], dtype=dtypes.int32, requires_grad=False, device=self.device)) * self.permute(*permarg).shrink(tuple([*[(0,sh) for sh in idx.shape[1:-1]], (0,self.shape[dim])])).unsqueeze(0)).sum(-1).transpose(ax1=0, ax2=dim)

    def cat(self, *args, dim=0):
        # concatenate along `dim`: pad each tensor to the full output size at its offset, then sum them
        dim = (dim + len(self.shape)) if dim < 0 else dim
        assert all(len(y.shape) == len(self.shape) and all(y.shape[i] == s for i,s in enumerate(self.shape) if i != dim) for y in args)
        catargs = [self, *args]
        assert all(t.shape for t in catargs), "zero-dimensional tensor cannot be concatenated"
        shapes = [s.shape[dim] for s in catargs]
        shape_cumsum = [0, *accumulate(shapes)]
        slc = [[(0, 0) for _ in self.shape] for _ in catargs]
        for shp,k,s in zip(shapes, shape_cumsum[:-1], slc):
            s[dim] = (k, shape_cumsum[-1] - k - shp)
        return reduce(Tensor.__add__, [arg.pad(tuple(s)) for arg,s in zip(catargs, slc)])

    @staticmethod
    def stack(tensors, dim=0):
        # insert a new axis `dim` in every tensor and concatenate along it
        first = tensors[0].unsqueeze(dim)
        unsqueezed_tensors = [tensor.unsqueeze(dim) for tensor in tensors[1:]]
        # checks for shapes and number of dimensions delegated to cat
        return first.cat(*unsqueezed_tensors, dim=dim)

    def repeat(self, repeats):
        # tile the tensor `repeats[i]` times along each dim via reshape -> expand -> reshape
        base_shape = (1,) * (len(repeats) - self.ndim) + self.shape
        new_shape = [x for b in base_shape for x in [1, b]]
        expand_shape = [x for rs in zip(repeats, base_shape) for x in rs]
        final_shape = [r*s for r,s in zip(repeats, base_shape)]
        return self.reshape(new_shape).expand(expand_shape).reshape(final_shape)

    def chunk(self, num:int, dim:int) -> List[Tensor]:
        # split into pieces of ceil(size/num) along `dim` (the last piece may be smaller)
        assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
        dim, step = dim + self.ndim if dim < 0 else dim, math.ceil(self.shape[dim]/num)
        slice_params = [[slice(None)]*dim + [slice(k, k + step)] for k in range(0, self.shape[dim], step)]
        return [self[tuple(sl)] for sl in slice_params]

    def squeeze(self, dim=None):
        # remove size-1 dimensions (all of them, or only `dim`)
        if dim is None: return self if 1 not in self.shape else self.reshape(*[size for size in self.shape if size != 1])
        if dim <= 0 and self.ndim == 0: return self # This is to match PyTorch behavior
        if not -self.ndim <= dim < self.ndim: raise IndexError(f"Dimension out of range (expected to be in range of [{-self.ndim if self.ndim > 0 else self.ndim-1}, {self.ndim-1 if self.ndim > 0 else self.ndim}], but got {dim})")
        if dim < 0: dim += self.ndim
        return self if self.shape[dim] != 1 else self.reshape(*[size for idx, size in enumerate(self.shape) if idx != dim])

    def unsqueeze(self, dim):
        # insert a size-1 dimension at `dim` (negative dims count from the end of the new shape)
        if dim < 0: dim = len(self.shape) + dim + 1
        return self.reshape(self.shape[:dim] + (1,) + self.shape[dim:])

    # (padding_left, padding_right, padding_top, padding_bottom)
    def pad2d(self, padding:Union[List[int], Tuple[int, ...]], value:float=0):
        slc = [(-p0, s+p1) for p0,p1,s in zip(padding[::2], padding[1::2], self.shape[::-1])][::-1]
        return self.slice([(0,s) for s in self.shape[:-(len(padding)//2)]] + slc, value=value)

    @property
    def T(self) -> Tensor: return self.transpose()
    def transpose(self, ax1=1, ax2=0) -> Tensor:
        # swap two axes (defaults swap the first two)
        order = list(range(len(self.shape)))
        order[ax1], order[ax2] = order[ax2], order[ax1]
        return self.permute(order)

    def flatten(self, start_dim=0):
        # merge all dims from start_dim onwards into one
        return self.reshape(shape=self.shape[:start_dim] + (-1,))

    # ! reduce ops ====================

    def _reduce(self, fxn:Type[Function], axis:Optional[Union[int, Tuple[int, ...]]]=None, keepdim=False) -> Tensor:
        # reduce over `axis` (all axes if None); the op keeps reduced dims as size 1, dropped unless keepdim
        axis_: List[int] = list(range(len(self.shape))) if axis is None else ([axis] if axis.__class__ is int else list(axis)) # type: ignore
        axis_ = [x if x >= 0 else x+len(self.shape) for x in axis_]
        shape = [s for i,s in enumerate(self.shape) if i not in axis_]
        ret = fxn.apply(self, new_shape=tuple([1 if i in axis_ else s for i,s in enumerate(self.shape)]))
        return ret if keepdim else ret.reshape(shape=shape)

    def sum(self, axis=None, keepdim=False): return self._reduce(mlops.Sum, axis, keepdim)
    def max(self, axis=None, keepdim=False): return self._reduce(mlops.Max, axis, keepdim)
    def min(self, axis=None, keepdim=False): return -((-self).max(axis=axis, keepdim=keepdim))

    def mean(self, axis=None, keepdim=False):
        # sum divided by the number of reduced elements
        assert all_int(self.shape), "does not support symbolic shape"
        out = self.sum(axis=axis, keepdim=keepdim)
        return out.mul(prod(out.shape)/prod(self.shape))

    def std(self, axis=None, keepdim=False, correction=1):
        # sqrt(sum of squared deviations / (N - correction)); correction=1 is Bessel's correction
        assert all_int(self.shape), "does not support symbolic shape"
        square_sum = ((self - self.mean(axis=axis, keepdim=True)).square()).sum(axis=axis, keepdim=keepdim)
        return square_sum.div(prod(self.shape)/prod(square_sum.shape)-correction).sqrt()

    def _softmax(self, axis):
        # shared helper: shifted input (minus max for numerical stability), its exp, and the exp sum
        m = self - self.max(axis=axis, keepdim=True)
        e = m.exp()
        return m, e, e.sum(axis=axis, keepdim=True)

    def softmax(self, axis=-1):
        _, e, ss = self._softmax(axis)
        return e.div(ss)

    def log_softmax(self, axis=-1):
        m, _, ss = self._softmax(axis)
        return m - ss.log()

    def argmax(self, axis=None, keepdim=False):
        # The max-mask is multiplied by a DESCENDING arange, so the first maximum gets the largest
        # value; n - that value - 1 converts it back into the index of the first maximum.
        if axis is None:
            idx = (self == self.max(axis)) * Tensor.arange(prod(self.shape)-1,-1,-1, dtype=dtypes.int32, requires_grad=False, device=self.device).reshape(self.shape)
            return prod(self.shape) - idx.max() - 1
        axis = axis + len(self.shape) if axis < 0 else axis
        m = self == self.max(axis=axis, keepdim=True)
        idx = m * Tensor.arange(self.shape[axis]-1,-1,-1, dtype=dtypes.int32, requires_grad=False, device=self.device).reshape(self.shape[axis], *[1]*(self.ndim-axis-1))
        return self.shape[axis]-idx.max(axis=axis, keepdim=keepdim)-1

    def argmin(self, axis=None, keepdim=False): return (-self).argmax(axis=axis, keepdim=keepdim)

    # ! processing ops ====================
    # TODO: cast, contiguous, cumsum and the elementwise/binary ops used above still need to be added here

In [12]:
from tinygrad.helpers import argfix, prod

# experiments

argfix(0,)

(0,)

In [16]:
# Learning the walrus operator (assignment expressions, PEP 572)
# https://www.python.org/dev/peps/pep-0572/

# example 1: bind len_ and compare it in one expression; the name stays usable inside the block
if (len_ := len([1, 2, 3])) > 2: print(len_)

3


In [14]:
(n := len([1, 2, 3]))  # an unparenthesized `n := ...` statement is a SyntaxError; the parentheses are required here

3