
fangpin/dl-framework

Introduction

Have you ever wondered what really happens behind the scenes when you call .backward() to compute gradients in PyTorch or TensorFlow?

We use these mature deep learning tools every day, but few of us truly explore their underlying implementations—the magic of automatic differentiation, the construction of computational graphs, the optimization of tensor operations... These core principles hidden behind the APIs are the true foundations of deep learning.

Today, I'm excited to introduce Needle—a lightweight deep learning framework implemented from scratch. It's not meant to replace existing mature frameworks, but rather to help you uncover the mysteries of deep learning:

  • You'll witness firsthand how automatic differentiation is implemented through computational graphs and the chain rule
  • You'll understand the underlying logic and implementation principles of tensor operations
  • You'll learn how to build your own deep learning components from the ground up

Needle has a clear code structure with detailed comments, and all features are implemented from scratch. It's the perfect project for learning the core principles of deep learning.

Core Features of Needle

Despite being lightweight, Needle implements all the core features of a deep learning framework:

🧮 Automatic Differentiation System

Automatic differentiation is one of the core technologies of deep learning. Needle implements an automatic differentiation system based on computational graphs. When you perform tensor operations, the framework automatically builds a computational graph, recording all operations and dependencies. When gradients need to be computed, the framework propagates backward along the computational graph, automatically calculating the gradient for each parameter without manual derivation.

➕ Rich Tensor Operations

Needle supports various basic tensor operations, including:

  • Element-wise operations (Add, Mul, Div, Pow, etc.)
  • Matrix operations (MatMul, Transpose, Dot, etc.)
  • Reduction operations (Sum, Mean, Max, etc.)
  • Transformation operations (Reshape, Broadcast, Concatenate, etc.)

These operations follow a similar API design to mainstream frameworks like PyTorch, making it easy to get started while maintaining code simplicity and understandability.

🔌 Multi-backend Support

Needle supports two backends:

  • NumPy backend: For rapid prototyping and CPU computation, easy for debugging and learning
  • Custom NDArray backend: Supports CPU and CUDA acceleration, providing higher performance

You can easily switch backends using the NEEDLE_BACKEND environment variable to meet the needs of different scenarios.

📦 Neural Network Modules

Needle provides complete neural network module support, including:

  • Basic components like linear layers (Linear), convolutional layers (Conv), etc.
  • Activation functions like ReLU, Sigmoid, Softmax, etc.
  • Loss functions like CrossEntropyLoss, MSELoss, etc.

These modules seamlessly integrate with the automatic differentiation system, making it convenient to build various complex neural network models.

🎯 Optimization Algorithms

Needle implements multiple commonly used optimization algorithms, including:

  • Stochastic Gradient Descent (SGD)
  • Momentum
  • RMSprop
  • Adam

These optimization algorithms can be directly used for training neural network models, supporting features like weight decay and learning rate scheduling.

📊 Data Loading

Needle provides data loading and preprocessing support, including:

  • Custom Dataset interface
  • DataLoader for batch loading
  • Common data transformation operations

These features help users easily handle various datasets and improve training efficiency.

Next, let's dive deep into the deep learning principles behind Needle. I'll use simple examples and code to explain core concepts like automatic differentiation and computational graphs.

Computational Graph & Automatic Differentiation from Scratch

Early machine learning typically started with manual gradient calculation, but as networks became more complex, manual gradient calculation became increasingly difficult.

To solve this problem, automatic differentiation was proposed. Automatic differentiation is a method to compute the gradient of a function at any point automatically.

Automatic differentiation represents a computation as a computational graph: values (including tensors) are the nodes, and operations (ops) are the edges (one op may correspond to multiple edges, one per input). Computing gradients then becomes backpropagation over this graph: starting from the output, the chain rule is applied one op at a time until the gradient with respect to every input has been accumulated.

Let's implement an automatic differentiation framework from scratch. The complete code can be found here.

First, define the edges of the computational graph, i.e., the operations (ops) themselves. The abstract class is very simple:

class Op:
    """Operator definition."""
    def __call__(self, *args):
        raise NotImplementedError()

    def compute(self, *args: Tuple[NDArray]):
        """
        Calculate forward pass of operator.

        Parameters
        ----------
        input: np.ndarray
            A list of input arrays to the function

        Returns
        -------
        output: nd.array
            Array output of the operation
        """
        raise NotImplementedError()

    def gradient(
        self, out_grad: "Value", node: "Value"
    ) -> Union["Value", Tuple["Value"]]:
        """
        Compute partial adjoint for each input value for a given output adjoint.

        Parameters
        ----------
        out_grad: Value
            The adjoint wrt to the output value.

        node: Value
            The value node of forward evaluation.

        Returns
        -------
        input_grads: Value or Tuple[Value]
            A list containing partial gradient adjoints to be propagated to
            each of the input node.
        """
        raise NotImplementedError()

compute and gradient are the two directions of computation on this edge. compute is the forward pass: it takes all inputs of the operation and returns the value of the output node. gradient is the backward pass: it takes the adjoint of the output together with the output node itself (which gives access to the inputs) and returns the adjoint contribution for each input node of the operation.
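To make the two directions concrete, here is a minimal sketch of an element-wise multiplication op. It is a simplification and not Needle's actual code: the class name EWiseMulSketch is hypothetical, it works directly on NumPy arrays, and gradient receives the raw inputs instead of a Value node.

```python
import numpy as np


class EWiseMulSketch:
    """Hypothetical element-wise multiply op on plain NumPy arrays."""

    def compute(self, a, b):
        # Forward pass: output value from the input values.
        return a * b

    def gradient(self, out_grad, inputs):
        # Backward pass: for y = a * b, dy/da = b and dy/db = a, so each
        # input adjoint is the output adjoint times the other input.
        a, b = inputs
        return out_grad * b, out_grad * a


op = EWiseMulSketch()
a = np.array([1.0, 2.0, 3.0])
b = np.array([4.0, 5.0, 6.0])
y = op.compute(a, b)
grad_a, grad_b = op.gradient(np.ones_like(y), (a, b))
```

With an output adjoint of all ones, the adjoint of a is simply b and the adjoint of b is simply a.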

Deep learning usually deals with tensor values rather than single values, so we define a set of operations (ops) specifically for handling tensors.

class TensorOp(Op):
    """Op class specialized to output tensors, will be alternate subclasses for other structures"""
    def __call__(self, *args):
        return Tensor.make_from_op(self, args)

class TensorTupleOp(Op):
    """Op class specialized to output TensorTuple"""
    def __call__(self, *args):
        return TensorTuple.make_from_op(self, args)

Next, define the nodes of the computational graph, i.e., the values themselves. The abstract class is also very simple:

class Value:
    """A value in the computational graph."""

    # trace of computational graph
    op: Optional[Op]
    inputs: List["Value"]
    # The following fields are cached fields for
    # dynamic computation
    cached_data: NDArray
    requires_grad: bool

    def realize_cached_data(self):
        """Run compute to realize the cached data"""
        # avoid recomputation
        if self.cached_data is not None:
            return self.cached_data
        # note: data implicitly calls realized cached data
        self.cached_data = self.op.compute(
            *[x.realize_cached_data() for x in self.inputs]
        )
        return self.cached_data

    def is_leaf(self):
        return self.op is None

    def __del__(self):
        global TENSOR_COUNTER
        TENSOR_COUNTER -= 1

    def _init(
        self,
        op: Optional[Op],
        inputs: List["Tensor"],
        *, num_outputs: int = 1,
        cached_data: List[object] = None,
        requires_grad: Optional[bool] = None,
    ):
        global TENSOR_COUNTER
        TENSOR_COUNTER += 1
        if requires_grad is None:
            requires_grad = any(x.requires_grad for x in inputs)
        self.op = op
        self.inputs = inputs
        self.num_outputs = num_outputs
        self.cached_data = cached_data
        self.requires_grad = requires_grad

    @classmethod
    def make_const(cls, data, *, requires_grad=False):
        value = cls.__new__(cls)
        value._init(
            None,
            [],
            cached_data=data,
            requires_grad=requires_grad,
        )
        return value

    @classmethod
    def make_from_op(cls, op: Op, inputs: List["Value"]):
        value = cls.__new__(cls)
        value._init(op, inputs)

        if not LAZY_MODE:
            if not value.requires_grad:
                return value.detach()
            value.realize_cached_data()
        return value

The definition of Value connects the entire computational graph. Each value node records which operation (op) produced this value and all input nodes (also value nodes) of this operation. A cache area is defined to avoid repeated calculations, along with a flag indicating whether gradient calculation is needed. The is_leaf method is used to determine whether this value node is a leaf node, i.e., whether it has no input nodes. This information can be used to topologically sort the computational graph, thereby completing the calculation of forward and backward propagation.
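The caching behaviour can be seen in isolation with a small sketch. LazyValue below is a hypothetical stand-in for Value (plain Python floats instead of arrays, a bare function instead of an Op), written only to show that a node shared by several consumers is computed once.

```python
class LazyValue:
    """Hypothetical stand-in for Value: a node that computes its data on demand."""

    def __init__(self, fn=None, inputs=(), data=None):
        self.fn = fn                  # forward function; None for leaf nodes
        self.inputs = list(inputs)    # upstream nodes
        self.cached_data = data       # filled in lazily for non-leaf nodes
        self.compute_calls = 0        # counts how often fn actually ran

    def is_leaf(self):
        return self.fn is None

    def realize(self):
        # Reuse the cached result if this node was already evaluated.
        if self.cached_data is not None:
            return self.cached_data
        self.compute_calls += 1
        self.cached_data = self.fn(*[x.realize() for x in self.inputs])
        return self.cached_data


a = LazyValue(data=2.0)
b = LazyValue(data=3.0)
c = LazyValue(lambda x, y: x + y, [a, b])   # c = a + b
d = LazyValue(lambda x, y: x * y, [c, c])   # d = c * c, uses c twice
result = d.realize()
```

Although d consumes c twice, c.compute_calls stays at 1, because the second request is served from cached_data.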

Next, let's introduce the main character, the definition of Tensor, the most commonly used Value.

class Tensor(Value):
    grad: "Tensor"

    def __init__(
        self,
        array,
        *, device: Optional[Device] = None,
        dtype=None,
        requires_grad=True,
        **kwargs,
    ):
        if isinstance(array, Tensor):
            if device is None:
                device = array.device
            if dtype is None:
                dtype = array.dtype
            if device == array.device and dtype == array.dtype:
                cached_data = array.realize_cached_data()
            else:
                # fall back, copy through numpy conversion
                cached_data = Tensor._array_from_numpy(
                    array.numpy(), device=device, dtype=dtype
                )
        else:
            device = device if device else cpu()
            cached_data = Tensor._array_from_numpy(array, device=device, dtype=dtype)

        self._init(
            None,
            [],
            cached_data=cached_data,
            requires_grad=requires_grad,
        )

    @staticmethod
    def make_const(data, requires_grad=False):
        tensor = Tensor.__new__(Tensor)
        tensor._init(
            None,
            [],
            cached_data=(data if not isinstance(data, Tensor) else data.realize_cached_data()),
            requires_grad=requires_grad,
        )
        return tensor

    @property
    def data(self):
        return self.detach()

    @property
    def shape(self):
        return self.realize_cached_data().shape

    @property
    def dtype(self):
        return self.realize_cached_data().dtype

    @property
    def device(self):
        data = self.realize_cached_data()
        # numpy array always sits on cpu
        if array_api is numpy:
            return cpu()
        return data.device

    def backward(self, out_grad=None):
        out_grad = (
            out_grad
            if out_grad is not None
            else init.ones(*self.shape, dtype=self.dtype, device=self.device)
        )
        compute_gradient_of_variables(self, out_grad)

    def __add__(self, other):
        if isinstance(other, Tensor):
            return needle.ops.EWiseAdd()(self, other)
        else:
            return needle.ops.AddScalar(other)(self)

Here, the device is used to specify the computing device, supporting CPU and GPU. We will implement different computing backends based on CPU and GPU separately later. Tensor itself needs to support various matrix numerical calculations, including addition, subtraction, multiplication, division, matrix multiplication, summation, broadcasting, reshape, etc. For simplicity, only the implementation of addition and scalar addition is defined here. For the complete code, please refer to autograd.py.

The compute_gradient_of_variables function above is used to compute the backpropagation of a tensor. We need to traverse the computational graph where this tensor is located and perform backpropagation in reverse topological order. The main implementation is as follows:

def compute_gradient_of_variables(output_tensor, out_grad):
    """
    Take gradient of output node with respect to each node in node_list.

    Store the computed result in the grad field of each Variable.
    """
    # a map from node to a list of gradient contributions from each output node
    node_to_output_grads_list: Dict[Tensor, List[Tensor]] = {}
    # Special note on initializing gradient of
    # We are really taking a derivative of the scalar reduce_sum(output_node)
    # instead of the vector output_node. But this is the common case for loss function.
    node_to_output_grads_list[output_tensor] = [out_grad]

    # Traverse graph in reverse topological order given the output_node that we are taking gradient wrt.
    reverse_topo_order = list(reversed(find_topo_sort([output_tensor])))

    for tensor in reverse_topo_order:
        tensor_grad = sum_node_list(node_to_output_grads_list[tensor])
        tensor.grad = tensor_grad
        # leaf node
        if tensor.op is None:
            continue
        inputs_out_grads = tensor.op.gradient_as_tuple(tensor_grad, tensor)
        for input, out_grad in zip(tensor.inputs, inputs_out_grads):
            if input not in node_to_output_grads_list:
                node_to_output_grads_list[input] = []
            node_to_output_grads_list[input].append(out_grad)


def find_topo_sort(node_list: List[Value]) -> List[Value]:
    """
    Given a list of nodes, return a topological sort list of nodes ending in them.

    A simple algorithm is to do a post-order DFS traversal on the given nodes,
    going backwards based on input edges. Since a node is added to the ordering
    after all its predecessors are traversed due to post-order DFS, we get a topological
    sort.
    """
    visited = set()
    order = []
    for node in node_list:
        topo_sort_dfs(node, visited, order)
    return order


def topo_sort_dfs(node, visited, topo_order):
    """Post-order DFS"""
    if node not in visited:
        visited.add(node)
        for n in node.inputs:
            topo_sort_dfs(n, visited, topo_order)
        topo_order.append(node)

We have completed the construction of the entire computational graph. Next, we need to refine the forward and backward propagation logic for various operations (ops) that tensors need to support. The data inside tensors is matrices, so they support various common matrix operations. The implementation here is mainly manual work, using the underlying matrix calculations to implement forward and backward propagation. For the specific complete implementation, please refer to ops_logarithmic.py and ops_mathematic.py.
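To see the reverse topological traversal at work end to end, the following self-contained sketch runs backpropagation on scalars. The Node class and the add, mul, topo_sort, and backward helpers are hypothetical simplifications (each node carries its own backward rule instead of an Op object), not the Needle implementation.

```python
class Node:
    """Hypothetical scalar node: a value, its inputs, and a local backward rule."""

    def __init__(self, value, inputs=(), backward_fn=None):
        self.value = value
        self.inputs = list(inputs)
        self.backward_fn = backward_fn  # maps out_grad -> tuple of input grads
        self.grad = 0.0


def add(x, y):
    return Node(x.value + y.value, [x, y], lambda g: (g, g))


def mul(x, y):
    return Node(x.value * y.value, [x, y], lambda g: (g * y.value, g * x.value))


def topo_sort(node, visited, order):
    # Post-order DFS: a node is appended only after all of its inputs.
    if id(node) in visited:
        return
    visited.add(id(node))
    for n in node.inputs:
        topo_sort(n, visited, order)
    order.append(node)


def backward(output):
    order = []
    topo_sort(output, set(), order)
    output.grad = 1.0
    # Reverse topological order guarantees a node's gradient is complete
    # before it is pushed to its inputs.
    for node in reversed(order):
        if node.backward_fn is None:
            continue
        for inp, g in zip(node.inputs, node.backward_fn(node.grad)):
            inp.grad += g


a, b = Node(2.0), Node(3.0)
y = mul(add(a, b), b)   # y = (a + b) * b
backward(y)
```

Because b feeds both the addition and the multiplication, its gradient is the sum of two contributions: dy/db = a + 2b = 8, while dy/da = b = 3.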

Matrix Computation from Scratch

Matrix representation: No matter what dimension (shape) the matrix is, we can represent it as a one-dimensional array, but we need to calculate the position of each element according to strides.

Understanding strides is the key to efficient matrix computation. The stride of a dimension is the number of elements to skip in the underlying array when the index along that dimension increases by one. For example, a row-major $2 \times 3$ matrix has strides $[3, 1]$: moving to the next row skips $3$ elements ($12$ bytes if the element size is $4$ bytes), and moving to the next column skips $1$ element. Element $(i, j)$ therefore lives at offset $3i + j$. To transpose this $2 \times 3$ matrix into a $3 \times 2$ matrix, we only need to swap the shape and the strides of the original matrix, i.e., $[3, 1] \rightarrow [1, 3]$. There is no need to copy the underlying matrix elements at all. Operations that can often be implemented this way, by changing only the shape, strides, and offset, include but are not limited to:

  • transpose
  • reshape
  • broadcast
  • view
  • squeeze/unsqueeze
  • slice
  • flip
  • reverse
  • permute
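The stride arithmetic above can be checked with NumPy, which exposes strides directly (in bytes, so the sketch divides by the item size). The element_at helper is a hypothetical name introduced here for illustration.

```python
import numpy as np

a = np.arange(6, dtype=np.float32).reshape(2, 3)
t = a.T  # transpose: a view, no data copied

# NumPy reports strides in bytes; divide by itemsize to get element strides.
a_strides = tuple(s // a.itemsize for s in a.strides)
t_strides = tuple(s // t.itemsize for s in t.strides)


def element_at(flat, strides, index, offset=0):
    """Look up an element in a flat buffer using element strides."""
    return flat[offset + sum(i * s for i, s in zip(index, strides))]


flat = a.ravel()  # the underlying one-dimensional storage
```

Here a_strides is (3, 1) and t_strides is (1, 3), and both arrays share the same memory. Indexing the flat buffer with the swapped strides yields exactly the elements of the transposed matrix.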

On both CPU and GPU, the key to optimizing matrix operations is making good use of the faster levels of storage.

  • For CPU: How to implement cache-friendly matrix operations to reduce main memory access caused by cache misses.
  • For GPU: How to make good use of shared memory to reduce access to GPU global memory.

We will introduce the implementation of matrix operations on CPU and GPU in detail below.

CPU Matrix Operation Implementation

The complete code can be found here.

Use tiled matrix multiplication to decompose one large matrix multiplication into many small ones. Each small multiplication can be computed in a cache-friendly way, reducing the main memory accesses caused by cache misses. Each small multiplication has size $TILE \times TILE$, where $TILE$ is a tunable parameter, generally 8 or 16. These values fit the 64-byte cache line of modern CPUs, which holds 16 float32 elements: one row of a 16-wide tile, or two rows of an 8-wide tile.

#define ALIGNMENT 256
#define TILE 8
typedef float scalar_t;
const size_t ELEM_SIZE = sizeof(scalar_t);

/**
 * This is a utility structure for maintaining an array aligned to ALIGNMENT
 * boundaries in memory.  This alignment should be at least TILE * ELEM_SIZE,
 * though we make it even larger here by default.
 */
struct AlignedArray {
  AlignedArray(const size_t size) {
    int ret = posix_memalign((void **)&ptr, ALIGNMENT, size * ELEM_SIZE);
    if (ret != 0)
      throw std::bad_alloc();
    this->size = size;
  }
  ~AlignedArray() { free(ptr); }
  size_t ptr_as_int() { return (size_t)ptr; }
  scalar_t *ptr;
  size_t size;
};


void MatmulTiled(const AlignedArray &a, const AlignedArray &b,
                 AlignedArray *out, uint32_t m, uint32_t n, uint32_t p) {
  /**
   * Matrix multiplication on tiled representations of array.  In this
   * setting, a, b, and out are all *4D* compact arrays of the appropriate
   * size, e.g. a is an array of size a[m/TILE][n/TILE][TILE][TILE] You should
   * do the multiplication tile-by-tile to improve performance of the array
   * (i.e., this function should call `AlignedDot()` implemented above).
   *
   * Note that this function will only be called when m, n, p are all
   * multiples of TILE, so you can assume that this division happens without
   * any remainder.
   *
   * Args:
   *   a: compact 4D array of size m/TILE x n/TILE x TILE x TILE
   *   b: compact 4D array of size n/TILE x p/TILE x TILE x TILE
   *   out: compact 4D array of size m/TILE x p/TILE x TILE x TILE to write to
   *   m: rows of a / out
   *   n: columns of a / rows of b
   *   p: columns of b / out
   *
   */
  for (size_t i = 0; i < m * p; ++i) {
    out->ptr[i] = 0.;
  }
  for (size_t i = 0; i < m / TILE; i++) {
    for (size_t j = 0; j < p / TILE; j++) {
      float *_out = out->ptr + i * p * TILE + j * TILE * TILE;
      for (size_t k = 0; k < n / TILE; k++) {
        const float *_a = a.ptr + i * n * TILE + k * TILE * TILE;
        const float *_b = b.ptr + k * p * TILE + j * TILE * TILE;
        AlignedDot(_a, _b, _out);
      }
    }
  }
}
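The tiled layout and the i/j/k loop nest can be prototyped in NumPy before writing any C++. The sketch below is only an illustration: to_tiled, from_tiled, and matmul_tiled are hypothetical helper names, and the `@` on a pair of tiles stands in for AlignedDot, assuming that function accumulates one TILE x TILE product into its output.

```python
import numpy as np

TILE = 8


def to_tiled(x, tile=TILE):
    """Rearrange a 2D (r, c) matrix into a compact (r/tile, c/tile, tile, tile) array."""
    r, c = x.shape
    return np.ascontiguousarray(
        x.reshape(r // tile, tile, c // tile, tile).transpose(0, 2, 1, 3)
    )


def from_tiled(x4):
    """Inverse of to_tiled: back to an ordinary 2D matrix."""
    rt, ct, tile, _ = x4.shape
    return x4.transpose(0, 2, 1, 3).reshape(rt * tile, ct * tile)


def matmul_tiled(a4, b4):
    """Multiply tile by tile, mirroring the i/j/k loop nest above."""
    mt, nt, tile, _ = a4.shape
    pt = b4.shape[1]
    out = np.zeros((mt, pt, tile, tile), dtype=a4.dtype)
    for i in range(mt):
        for j in range(pt):
            for k in range(nt):
                # Stand-in for AlignedDot: accumulate one TILE x TILE product.
                out[i, j] += a4[i, k] @ b4[k, j]
    return out


rng = np.random.default_rng(0)
a = rng.standard_normal((16, 24)).astype(np.float32)
b = rng.standard_normal((24, 8)).astype(np.float32)
result = from_tiled(matmul_tiled(to_tiled(a), to_tiled(b)))
```

Comparing result against a @ b is a quick way to confirm that the tile offsets are right before porting the loops to pointer arithmetic.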

GPU Matrix Operation Implementation

The complete code can be found here.

Use tiled matrix multiplication to decompose one large matrix multiplication into many small ones. Each small multiplication can be computed out of shared memory, reducing accesses to GPU global memory. Each small multiplication has size $TILE \times TILE$, where $TILE$ is a tunable parameter, generally 8 or 16 (the listing below sets it to 4). For reference, shared memory on modern NVIDIA GPUs is organized into 32 banks of 4 bytes each, so one 128-byte row across the banks holds 32 float32 elements.

#define BASE_THREAD_NUM 256

#define TILE 4
typedef float scalar_t;
const size_t ELEM_SIZE = sizeof(scalar_t);

struct CudaArray {
  CudaArray(const size_t size) {
    cudaError_t err = cudaMalloc(&ptr, size * ELEM_SIZE);
    if (err != cudaSuccess)
      throw std::runtime_error(cudaGetErrorString(err));
    this->size = size;
  }
  ~CudaArray() { cudaFree(ptr); }
  size_t ptr_as_int() { return (size_t)ptr; }

  scalar_t *ptr;
  size_t size;
};

struct CudaDims {
  dim3 block, grid;
};

CudaDims CudaOneDim(size_t size) {
  /**
   * Utility function to get cuda dimensions for 1D call
   */
  CudaDims dim;
  size_t num_blocks = (size + BASE_THREAD_NUM - 1) / BASE_THREAD_NUM;
  dim.block = dim3(BASE_THREAD_NUM, 1, 1);
  dim.grid = dim3(num_blocks, 1, 1);
  return dim;
}

#define MAX_VEC_SIZE 8
struct CudaVec {
  uint32_t size;
  int32_t data[MAX_VEC_SIZE];
};

CudaVec VecToCuda(const std::vector<int32_t> &x) {
  CudaVec shape;
  if (x.size() > MAX_VEC_SIZE)
    throw std::runtime_error("Exceeded CUDA supported max dimensions");
  shape.size = x.size();
  for (size_t i = 0; i < x.size(); i++) {
    shape.data[i] = x[i];
  }
  return shape;
}


__global__ void MatmulKernel(const scalar_t *a, const scalar_t *b,
                             scalar_t *out, uint32_t M, uint32_t N,
                             uint32_t P) {
  __shared__ scalar_t a_tile[TILE][TILE];
  __shared__ scalar_t b_tile[TILE][TILE];

  size_t thread_x = threadIdx.x;
  size_t thread_y = threadIdx.y;

  size_t block_x = blockIdx.x;
  size_t block_y = blockIdx.y;

  size_t x = thread_x + block_x * blockDim.x;
  size_t y = thread_y + block_y * blockDim.y;

  size_t cnt = (N + TILE - 1) / TILE;
  scalar_t sum = 0;
  // Walk the TILE-wide blocks along one row of a and one column of b and
  // accumulate their products into this thread's output element.
  for (size_t i = 0; i < cnt; i++) {
    // Guard the row/column as well as the shared dimension so that edge
    // blocks never read out of bounds; out-of-range slots are zero-filled.
    a_tile[thread_x][thread_y] = (x < M && i * TILE + thread_y < N)
                                     ? a[x * N + i * TILE + thread_y]
                                     : 0;
    b_tile[thread_x][thread_y] = (y < P && i * TILE + thread_x < N)
                                     ? b[(i * TILE + thread_x) * P + y]
                                     : 0;

    __syncthreads();

    if (x < M && y < P) {
      for (size_t j = 0; j < TILE; j++)
        if (i * TILE + j < N)
          sum += a_tile[thread_x][j] * b_tile[j][thread_y];
    }

    __syncthreads();
  }

  if (x < M && y < P)
    out[x * P + y] = sum;
}

void Matmul(const CudaArray &a, const CudaArray &b, CudaArray *out, uint32_t M,
            uint32_t N, uint32_t P) {
  /**
   * Multiply two (compact) matrices into an output (also compact) matrix.  You
   * will want to look at the lecture and notes on GPU-based linear algebra to
   * see how to do this.  Since ultimately mugrade is just evaluating
   * correctness, you _can_ implement a version that simply parallelizes over
   * (i,j) entries in the output array.  However, to really get the full benefit
   * of this problem, we would encourage you to use cooperative fetching, shared
   * memory register tiling, and other ideas covered in the class notes.  Note
   * that unlike the tiled matmul function in the CPU backend, here you should
   * implement a single function that works across all size matrices, whether or
   * not they are a multiple of a tile size.  As with previous CUDA
   * implementations, this function here will largely just set up the kernel
   * call, and you should implement the logic in a separate MatmulKernel() call.
   *
   *
   * Args:
   *   a: compact 2D array of size m x n
   *   b: compact 2D array of size n x p
   *   out: compact 2D array of size m x p to write the output to
   *   M: rows of a / out
   *   N: columns of a / rows of b
   *   P: columns of b / out
   */

  dim3 block = dim3(TILE, TILE, 1);
  size_t grid_x = (M + TILE - 1) / TILE;
  size_t grid_y = (P + TILE - 1) / TILE;
  dim3 grid = dim3(grid_x, grid_y, 1);
  MatmulKernel<<<grid, block>>>(a.ptr, b.ptr, out->ptr, M, N, P);
}
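The block structure of the kernel can also be emulated on the CPU to check the edge handling for sizes that are not multiples of TILE. This is a NumPy sketch under simplifying assumptions: matmul_blocked is a hypothetical name, each (bx, by) pair plays the role of one thread block, and partial tiles are zero-padded, which has the same effect on the result as bounds checks in the kernel.

```python
import numpy as np

TILE = 4


def matmul_blocked(a, b, tile=TILE):
    """Emulate the kernel block by block; each (bx, by) pair is one thread block."""
    M, N = a.shape
    P = b.shape[1]
    out = np.zeros((M, P), dtype=a.dtype)
    n_tiles = (N + tile - 1) // tile  # ceil division, like `cnt` in the kernel
    for bx in range((M + tile - 1) // tile):
        for by in range((P + tile - 1) // tile):
            rows = slice(bx * tile, min((bx + 1) * tile, M))
            cols = slice(by * tile, min((by + 1) * tile, P))
            acc = np.zeros((tile, tile), dtype=a.dtype)
            for i in range(n_tiles):
                ks = slice(i * tile, min((i + 1) * tile, N))
                # "Shared memory" tiles, zero-padded at the matrix edges.
                a_tile = np.zeros((tile, tile), dtype=a.dtype)
                b_tile = np.zeros((tile, tile), dtype=a.dtype)
                a_blk = a[rows, ks]
                b_blk = b[ks, cols]
                a_tile[: a_blk.shape[0], : a_blk.shape[1]] = a_blk
                b_tile[: b_blk.shape[0], : b_blk.shape[1]] = b_blk
                acc += a_tile @ b_tile
            # Only in-range threads write their result back.
            out[rows, cols] = acc[: rows.stop - rows.start, : cols.stop - cols.start]
    return out


rng = np.random.default_rng(0)
a = rng.standard_normal((5, 7)).astype(np.float32)
b = rng.standard_normal((7, 3)).astype(np.float32)
result = matmul_blocked(a, b)
```

The 5 x 7 by 7 x 3 case exercises partial tiles in all three dimensions, so a match with a @ b covers the boundary logic.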

Triton Matrix Operation Implementation

You can also implement matrix multiplication using Triton, a Python-embedded language and compiler for writing GPU kernels. If you're interested in Triton, you can refer to some of my previous articles:

Building the Deep Learning Framework

The sections above cover the low-level building blocks of a deep learning framework: automatic differentiation, and the CPU and GPU implementations of matrix computation. Next, we add the remaining pieces needed for a basic, ready-to-use framework.

Model Parameter

A Parameter is essentially a Tensor, wrapped in its own type so that nn.Module can find and manage it.

class Parameter(Tensor):
    """A special kind of tensor that represents parameters."""


def _unpack_params(value: object) -> List[Tensor]:
    if isinstance(value, Parameter):
        return [value]
    elif isinstance(value, Module):
        return value.parameters()
    elif isinstance(value, dict):
        params = []
        for k, v in value.items():
            params += _unpack_params(v)
        return params
    elif isinstance(value, (list, tuple)):
        params = []
        for v in value:
            params += _unpack_params(v)
        return params
    else:
        return []


def _child_modules(value: object) -> List["Module"]:
    if isinstance(value, Module):
        modules = [value]
        modules.extend(_child_modules(value.__dict__))
        return modules
    if isinstance(value, dict):
        modules = []
        for k, v in value.items():
            modules += _child_modules(v)
        return modules
    elif isinstance(value, (list, tuple)):
        modules = []
        for v in value:
            modules += _child_modules(v)
        return modules
    else:
        return []

nn Module

nn.Module is the base class for all neural network modules. A Module is a unified unit that encapsulates model structure, parameters, and computation.

Each basic model or custom model is defined by inheriting nn.Module.

It can contain learnable parameters (automatically registered through nn.Parameter or submodules like nn.Linear) and forward computation logic.

Module supports hierarchical nesting (a Module can contain other Modules), automatically managing parameters, states (like training/evaluation mode), and devices (CPU/GPU).

The key code is as follows. For the complete code, please refer to nn_basic.py.

class Module:
    def __init__(self):
        self.training = True

    def parameters(self) -> List[Tensor]:
        """Return the list of parameters in the module."""
        return _unpack_params(self.__dict__)

    def _children(self) -> List["Module"]:
        return _child_modules(self.__dict__)

    def eval(self):
        self.training = False
        for m in self._children():
            m.training = False

    def train(self):
        self.training = True
        for m in self._children():
            m.training = True

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)
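The recursive parameter collection can be tried out with a stripped-down sketch. Everything here is a hypothetical stand-in written for illustration: Parameter holds only a name, unpack_params is a condensed variant of the helper shown earlier, and Layer and Net are made-up modules.

```python
class Parameter:
    """Hypothetical stand-in for a trainable tensor."""

    def __init__(self, name):
        self.name = name


class Module:
    def parameters(self):
        return unpack_params(self.__dict__)


def unpack_params(value):
    # Recursively walk attributes, dicts, lists and tuples for Parameters.
    if isinstance(value, Parameter):
        return [value]
    if isinstance(value, Module):
        return value.parameters()
    if isinstance(value, dict):
        return [p for v in value.values() for p in unpack_params(v)]
    if isinstance(value, (list, tuple)):
        return [p for v in value for p in unpack_params(v)]
    return []


class Layer(Module):
    def __init__(self, tag):
        self.weight = Parameter(f"{tag}.weight")
        self.bias = Parameter(f"{tag}.bias")


class Net(Module):
    def __init__(self):
        self.layers = [Layer("l0"), Layer("l1")]   # submodules inside a list
        self.scale = Parameter("scale")            # parameter set directly


names = [p.name for p in Net().parameters()]
```

Parameters are found whether they are set directly on the module or buried inside submodules held in a list, which is what lets an optimizer receive every trainable tensor from a single parameters() call.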

Common Built-in Modules

To make the framework ready to use, we mirror PyTorch's API and implement some common neural network modules, such as Linear, ReLU, and SoftmaxLoss.

We start with the simplest one, the Linear layer, followed by the other modules:

class Linear(Module):
    def __init__(
        self, in_features, out_features, bias=True, device=None, dtype="float32"
    ):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(
            init.kaiming_uniform(
                in_features,
                out_features,
                device=device,
                dtype=dtype,
                requires_grad=True,
            )
        )
        if bias:
            self.bias = Parameter(
                init.kaiming_uniform(
                    out_features, 1, device=device, dtype=dtype, requires_grad=True
                ).transpose()
            )
        else:
            self.bias = None

    def forward(self, X: Tensor) -> Tensor:
        ret = X @ self.weight
        if self.bias is not None:
            b = self.bias.broadcast_to(ret.shape)
            ret = ret + b
        return ret


class Flatten(Module):
    def forward(self, X):
        s = X.shape
        batch = s[0]
        other = 1
        for i, x in enumerate(s):
            if i == 0:
                continue
            other *= x
        return X.reshape((batch, other))


class ReLU(Module):
    def forward(self, x: Tensor) -> Tensor:
        return ops.relu(x)


class Sequential(Module):
    def __init__(self, *modules):
        super().__init__()
        self.modules = modules

    def forward(self, x: Tensor) -> Tensor:
        input = x
        for module in self.modules:
            input = module(input)
        return input


class SoftmaxLoss(Module):
    def forward(self, logits: Tensor, y: Tensor):
        one_hot_y = init.one_hot(logits.shape[-1], y)
        z_y = ops.summation(logits * one_hot_y, axes=(-1,))
        log_sum = ops.logsumexp(logits, axes=(-1,))
        return ops.summation((log_sum - z_y) / logits.shape[0])


class BatchNorm1d(Module):
    def __init__(self, dim, eps=1e-5, momentum=0.1, device=None, dtype="float32"):
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.momentum = momentum
        self.weight = Parameter(init.ones(dim, device=device, requires_grad=True))
        self.bias = Parameter(init.zeros(dim, device=device, requires_grad=True))
        self.running_mean = init.zeros(dim)
        self.running_var = init.ones(dim)

    def forward(self, x: Tensor) -> Tensor:
        if self.training:
            batch_mean = x.sum((0,)) / x.shape[0]
            batch_var = ((x - batch_mean.broadcast_to(x.shape)) ** 2).sum(
                (0,)
            ) / x.shape[0]
            self.running_mean = (
                1 - self.momentum
            ) * self.running_mean + self.momentum * batch_mean.data
            self.running_var = (
                1 - self.momentum
            ) * self.running_var + self.momentum * batch_var.data
            norm = (x - batch_mean.broadcast_to(x.shape)) / (
                batch_var.broadcast_to(x.shape) + self.eps
            ) ** 0.5
            return self.weight.broadcast_to(x.shape) * norm + self.bias.broadcast_to(
                x.shape
            )
        else:
            norm = (x - self.running_mean.broadcast_to(x.shape)) / (
                self.running_var.broadcast_to(x.shape) + self.eps
            ) ** 0.5
            return self.weight.broadcast_to(x.shape) * norm + self.bias.broadcast_to(
                x.shape
            )


class LayerNorm1d(Module):
    def __init__(self, dim, eps=1e-5, device=None, dtype="float32"):
        super().__init__()
        self.dim = dim
        self.eps = eps
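        # Pass the requested device through so the parameters live where the caller asked.
        self.weights = Parameter(init.ones(dim, device=device, requires_grad=True))
        self.bias = Parameter(init.zeros(dim, device=device, requires_grad=True))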

    def forward(self, x: Tensor) -> Tensor:
        mean = (
            (ops.summation(x, axes=(-1,)) / x.shape[-1])
            .reshape((x.shape[0], 1))
            .broadcast_to(x.shape)
        )
        var = (
            (ops.summation((x - mean) ** 2, axes=(-1,)) / x.shape[-1])
            .reshape((x.shape[0], 1))
            .broadcast_to(x.shape)
        )
        deno = (var + self.eps) ** 0.5
        return self.weights.broadcast_to(x.shape) * (
            (x - mean) / deno
        ) + self.bias.broadcast_to(x.shape)


class Dropout(Module):
    def __init__(self, p=0.5):
        super().__init__()
        self.p = p

    def forward(self, x: Tensor) -> Tensor:
        if self.training:
            mask = init.randb(*x.shape, p=1 - self.p) / (1 - self.p)
            x = x * mask
        return x


class Residual(Module):
    def __init__(self, fn: Module):
        super().__init__()
        self.fn = fn

    def forward(self, x: Tensor) -> Tensor:
        return self.fn(x) + x

Initializing Module parameters requires some common initialization methods, such as kaiming_uniform. The code lives in init_initializers.py and is not reproduced here.
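
As an illustration of what such an initializer computes, here is a minimal plain-Python sketch of Kaiming uniform initialization. It is not the code from init_initializers.py: the function name, its signature, and the nested-list representation are assumptions made for this example. Weights are drawn from U(-bound, bound) with bound = gain * sqrt(3 / fan_in), where gain = sqrt(2) is the usual choice for ReLU networks.

```python
import math
import random


def kaiming_uniform_sketch(fan_in, fan_out, gain=math.sqrt(2.0), seed=None):
    """Hypothetical helper: return a fan_in x fan_out weight matrix as nested lists.

    Each value is drawn from U(-bound, bound) with
    bound = gain * sqrt(3 / fan_in), so the variance is gain**2 / fan_in.
    """
    rng = random.Random(seed)
    bound = gain * math.sqrt(3.0 / fan_in)
    return [
        [rng.uniform(-bound, bound) for _ in range(fan_out)]
        for _ in range(fan_in)
    ]


weights = kaiming_uniform_sketch(fan_in=256, fan_out=128, seed=0)
flat = [w for row in weights for w in row]
mean = sum(flat) / len(flat)
variance = sum((w - mean) ** 2 for w in flat) / len(flat)
# The empirical variance should be close to gain**2 / fan_in = 2 / 256.
print(len(weights), len(weights[0]), variance)
```

Scaling the range by fan_in keeps the variance of each layer's pre-activations roughly constant as depth grows, which is the reason this family of initializers is preferred over a fixed range.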

More complex Modules, such as RNN and CNN, are not implemented yet.

If you are interested in Modules commonly used in large language models, such as the implementation of transformers, you can refer to my articles:

Optimizer

The optimizer updates model parameters to minimize the loss function. Two optimizers are implemented here: SGD and Adam.

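from collections import defaultdict

import needle as ndl  # assumed import alias; the code below calls ndl.Tensor


class Optimizer: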
    def __init__(self, params):
        self.params = params

    def step(self):
        raise NotImplementedError()

    def reset_grad(self):
        for p in self.params:
            p.grad = None


class SGD(Optimizer):
    def __init__(self, params, lr=0.01, momentum=0.0, weight_decay=0.0):
        super().__init__(params)
        self.lr = lr
        self.momentum = momentum
        self.u = defaultdict(float)
        self.weight_decay = weight_decay

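    def step(self):
        for p in self.params:
            if p.grad is None:
                # Parameter did not take part in the last backward pass.
                continue
            grad = p.grad.data
            if self.weight_decay > 0:
                # L2 regularization folded into the gradient.
                grad = grad + self.weight_decay * p.data
            # Momentum buffer: exponential moving average of the gradients.
            self.u[p] = self.momentum * self.u[p] + (1 - self.momentum) * grad
            p.data = p.data - ndl.Tensor(self.lr * self.u[p], dtype=p.dtype)
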

class Adam(Optimizer):
    def __init__(
        self,
        params,
        lr=0.01,
        beta1=0.9,
        beta2=0.999,
        eps=1e-8,
        weight_decay=0.0,
    ):
        super().__init__(params)
        self.lr = lr
        self.beta1 = beta1
        self.beta2 = beta2
        self.eps = eps
        self.weight_decay = weight_decay
        self.t = 0

        self.m = defaultdict(float)
        self.v = defaultdict(float)

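    def step(self):
        self.t += 1
        for p in self.params:
            if p.grad is None:
                # Parameter did not take part in the last backward pass.
                continue
            grad = p.grad.data
            if self.weight_decay > 0:
                # L2 regularization folded into the gradient.
                grad = grad + self.weight_decay * p.data
            # First and second moment estimates (moving averages of grad and grad**2).
            self.m[p] = self.beta1 * self.m[p] + (1 - self.beta1) * grad
            self.v[p] = self.beta2 * self.v[p] + (1 - self.beta2) * (grad**2)
            # Bias correction compensates for the zero initialization of m and v.
            unbiased_m = self.m[p] / (1 - self.beta1**self.t)
            unbiased_v = self.v[p] / (1 - self.beta2**self.t)
            p.data = p.data - ndl.Tensor(
                self.lr * unbiased_m / (unbiased_v**0.5 + self.eps),
                dtype=p.dtype,
            )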
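
To see the two update rules in isolation, the following plain-Python sketch applies the same formulas to a single scalar parameter and minimizes the toy objective f(w) = (w - 3)^2. The function names and the objective are assumptions made for this example, and no Needle tensors are involved. The momentum buffer uses the same (1 - momentum) * grad form as the SGD class above.

```python
def grad_fn(w):
    """Gradient of the toy objective f(w) = (w - 3) ** 2."""
    return 2.0 * (w - 3.0)


def sgd_momentum(w, steps, lr=0.1, momentum=0.9):
    """Scalar version of the SGD update with a moving-average momentum buffer."""
    u = 0.0
    for _ in range(steps):
        u = momentum * u + (1 - momentum) * grad_fn(w)
        w = w - lr * u
    return w


def adam(w, steps, lr=0.1, beta1=0.9, beta2=0.999, eps=1e-8):
    """Scalar version of the Adam update with bias-corrected moments."""
    m, v = 0.0, 0.0
    for t in range(1, steps + 1):
        g = grad_fn(w)
        m = beta1 * m + (1 - beta1) * g
        v = beta2 * v + (1 - beta2) * g * g
        m_hat = m / (1 - beta1**t)  # bias correction for the first moment
        v_hat = v / (1 - beta2**t)  # bias correction for the second moment
        w = w - lr * m_hat / (v_hat**0.5 + eps)
    return w


w_sgd = sgd_momentum(0.0, steps=500)
w_adam = adam(0.0, steps=500)
# Both values should end up close to the minimizer w = 3.
print(w_sgd, w_adam)
```

In the first Adam step the bias correction makes the update equal to lr times the sign of the gradient, which is why Adam's early steps have a size that does not depend on the gradient's scale.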

DataLoader

DataLoader loads datasets in batches and supports shuffling, parallel loading, and other features. Here, a simple Dataset and a corresponding DataLoader are implemented, along with some common data transformation methods. For the specific code, please refer to data_basic.py.
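
The Dataset and DataLoader pattern can be sketched in a few lines of plain Python. This is not the code from data_basic.py: the class names ListDataset and SimpleDataLoader, their parameters, and the list-based storage are hypothetical, and the sketch is single-process (no parallel loading).

```python
import random


class ListDataset:
    """Hypothetical dataset: wraps parallel lists of features and labels."""

    def __init__(self, features, labels):
        assert len(features) == len(labels)
        self.features = features
        self.labels = labels

    def __len__(self):
        return len(self.features)

    def __getitem__(self, index):
        return self.features[index], self.labels[index]


class SimpleDataLoader:
    """Hypothetical loader: yields (features, labels) batches, optionally shuffled."""

    def __init__(self, dataset, batch_size=1, shuffle=False, seed=None):
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.rng = random.Random(seed)

    def __iter__(self):
        indices = list(range(len(self.dataset)))
        if self.shuffle:
            # Draw a fresh permutation at the start of every epoch.
            self.rng.shuffle(indices)
        for start in range(0, len(indices), self.batch_size):
            chunk = indices[start : start + self.batch_size]
            features, labels = zip(*(self.dataset[i] for i in chunk))
            yield list(features), list(labels)


dataset = ListDataset([[i, i + 1] for i in range(10)], [i % 2 for i in range(10)])
loader = SimpleDataLoader(dataset, batch_size=4, shuffle=True, seed=0)
batch_sizes = [len(labels) for _, labels in loader]
print(batch_sizes)  # -> [4, 4, 2]
```

Keeping indexing in the dataset and batching in the loader means shuffling and batch size can change without touching the code that reads individual samples.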
