
[v0.2.7] Optimize Linear layer: eliminate CPU transfers #69

@m96-chan

Description

Problem

The current Linear layer in pygpukit.llm.model has significant CPU overhead: it performs multiple GPU↔CPU transfers on every forward pass.

Current Implementation (Linear.__call__)

# Weight transpose: GPU→CPU→GPU
weight_np = self.weight.to_numpy()      # GPU→CPU
weight_t = from_numpy(weight_np.T.copy())  # CPU→GPU

y = matmul(x, weight_t)  # GPU compute

# Bias addition: GPU→CPU→GPU
if self.bias is not None:
    y_np = y.to_numpy()      # GPU→CPU
    bias_np = self.bias.to_numpy()  # GPU→CPU
    y_np += bias_np           # CPU compute
    y = from_numpy(y_np)      # CPU→GPU

Impact

  • 5 GPU↔CPU transfers per Linear call (2 for the weight transpose, 3 for the bias add)
  • 2 CPU operations per Linear call (the transpose and the bias add)
  • For a 12-layer model with 2 Linear calls per layer: 24 calls × 5 = 120 transfers per forward pass
  • GPU utilization measured at 1-25% (should be 80%+)

Benchmark Evidence (RTX 3090 Ti)

nvidia-smi dmon -s u results:
# gpu     sm    mem
    0      5     22   <- Very low!
    0     25     11   <- Still low

Solution Options

Option 1: Pre-transpose weights (Quick fix)

Store the weight as [in_features, out_features] instead of [out_features, in_features], so the transpose happens once at load time rather than on every call.

def __init__(self, weight, bias):
    # Transpose once at construction: [out_features, in_features] -> [in_features, out_features]
    self.weight_t = from_numpy(weight.to_numpy().T.copy())
    self.bias = bias

def __call__(self, x):
    y = matmul(x, self.weight_t)  # no per-call transpose
    # Bias addition still round-trips through the CPU (see Option 2)
    return y

Option 2: Implement broadcast add on GPU (Recommended)

Add an add_bias(x, bias) operation in C++/CUDA that broadcasts the bias across the batch dimension.

y = matmul(x, self.weight_t)
y = add_bias(y, self.bias)  # GPU-native broadcast add
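For reference, here is a NumPy sketch of the broadcast semantics the CUDA op would need to match (the add_bias name comes from this proposal; the reference function below is illustrative only):

import numpy as np

def add_bias_reference(y: np.ndarray, bias: np.ndarray) -> np.ndarray:
    # y: [..., out_features], bias: [out_features].
    # The GPU kernel should broadcast bias across all leading
    # (batch/sequence) dimensions, exactly as NumPy does here.
    assert bias.ndim == 1 and y.shape[-1] == bias.shape[0]
    return y + bias

On the GPU this is a single elementwise pass over y, so its cost should be negligible next to the matmul.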

Option 3: Fused Linear kernel

Implement y = xW^T + b as a single CUDA kernel (like cuBLAS GEMM with beta).
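To pin down the semantics, a NumPy sketch of the fused computation in GEMM-with-beta form (C preloaded with the broadcast bias, alpha = beta = 1); this is a reference, not an implementation:

import numpy as np

def fused_linear_reference(x: np.ndarray, weight: np.ndarray, bias: np.ndarray) -> np.ndarray:
    # weight: [out_features, in_features], x: [..., in_features].
    # cuBLAS GEMM computes C = alpha * A @ B + beta * C, so filling C
    # with the broadcast bias and setting beta = 1 fuses the bias add.
    out = np.broadcast_to(bias, x.shape[:-1] + bias.shape).copy()  # C := bias
    out += x @ weight.T                                            # C := A @ B + C
    return out

If cublasLt is an option, its bias epilogue (CUBLASLT_EPILOGUE_BIAS) achieves the same fusion without pre-filling C.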

Acceptance Criteria

  • GPU utilization > 70% during LLM forward pass
  • No GPU↔CPU transfers in the Linear layer forward path
  • Benchmark shows a significant speedup (target: 5-10x; see the timing sketch after this list)
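A minimal timing sketch for the speedup criterion (bench_linear is hypothetical; it assumes from_numpy is importable from the pygpukit top level and that to_numpy() blocks until GPU work completes; adjust to the actual API):

import time
import numpy as np
from pygpukit import from_numpy  # import path assumed

def bench_linear(linear, rows=1024, in_features=768, iters=100):
    x = from_numpy(np.random.randn(rows, in_features).astype(np.float32))
    linear(x).to_numpy()  # warm-up; to_numpy also forces completion
    t0 = time.perf_counter()
    for _ in range(iters):
        y = linear(x)
    y.to_numpy()          # sync once on the last result
    return (time.perf_counter() - t0) / iters  # seconds per call

Running this against the current Linear and the optimized one gives the ratio to compare with the 5-10x target.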

Related

  • Multi-LLM demo: examples/demo_v026_multi_llm.py
  • Linear implementation: src/pygpukit/llm/model.py
