In [1]:
!pip install -q wurlitzer ninja
%load_ext wurlitzer

In [2]:
import os, time
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.cpp_extension import load_inline
from torchvision import datasets, transforms

# Request blocking CUDA kernel launches (a debugging aid for the custom kernel).
# NOTE(review): this is set after `import torch`; presumably it must precede the first
# CUDA call to take effect, and it may influence the timings reported by the
# benchmark cells further down — confirm this is intended.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
# Default device name; later cells reassign this global.
device = "cuda"

# Seed PyTorch's global RNG so weight initialisation and random inputs are repeatable.
torch.manual_seed(0)

<torch._C.Generator at 0x7a1216646f90>

# Conv + ReLU Fused Python Kernel

In [3]:
import torch
import torch.nn.functional as F

def run_kernel(f, times, *args):
    """Emulate a 1-D GPU launch on the CPU.

    Calls ``f(index, *args)`` sequentially for every ``index`` in
    ``range(times)``; ``index`` plays the role of the global thread id.

    Args:
        f: callable taking the flat index followed by ``*args``.
        times: number of "threads" to run.
        *args: extra positional arguments forwarded unchanged to every call.
    """
    for thread_idx in range(times):
        f(thread_idx, *args)

In [None]:
# Description:
# This function is supposed to perform a 3x3 convolution on each pixel of the input image
# and then apply the ReLU activation function to the result.
#
# In other words, this is the "kernel" part — where the actual multiply–accumulate
# operations of convolution happen.
#
# Parameters:
# - i : the index of the pixel (from 0 to N*H*W - 1)
# - x : the input tensor of shape [N, 1, H, W]
# - w : the convolution filter (weights) of shape [1, 1, 3, 3]
# - b : the bias term (a single-element tensor)
# - out : the output tensor to store results
# - N, H, W : dimensions of the input (batch size, height, width)
#
# Inside this function, students should:
# 1. Convert i into (n, h, w) indices — to locate the correct pixel in the batch
# 2. Compute the accumulated sum (acc) by multiplying the 3x3 neighborhood by the weights
# 3. Add the bias term b
# 4. Apply ReLU (if acc < 0, set it to 0)
# 5. Store the result in out[n, 0, h, w]

def conv_relu_kernel_py(i, x, w, b, out, N, H, W):
    """Compute one output pixel of a fused 3x3 convolution + ReLU.

    The 3x3 window is centred on the pixel, and taps that fall outside the
    image contribute zero (i.e. zero padding of 1). The filter is applied
    without flipping, which is the convention of ``torch.nn.functional.conv2d``.

    Args:
        i: flat pixel index in ``[0, N*H*W)``; indices beyond that are ignored.
        x: input tensor indexed as ``x[n, 0, h, w]``.
        w: filter tensor indexed as ``w[0, 0, kh, kw]`` with ``kh, kw`` in 0..2.
        b: bias tensor holding a single value.
        out: output tensor; element ``out[n, 0, h, w]`` is written in place.
        N, H, W: batch size, image height and image width.
    """
    # Unflatten the "thread id" into (image, row, column).
    n, pixel = divmod(i, H * W)
    row, col = divmod(pixel, W)
    if n >= N:
        # Same out-of-range guard a GPU kernel needs when the grid is rounded up.
        return

    acc = float(b.reshape(-1)[0])  # start from the bias
    for kh in range(3):
        src_row = row + kh - 1
        if src_row < 0 or src_row >= H:
            continue  # zero padding: rows outside the image add nothing
        for kw in range(3):
            src_col = col + kw - 1
            if src_col < 0 or src_col >= W:
                continue  # zero padding: columns outside the image add nothing
            acc += float(x[n, 0, src_row, src_col]) * float(w[0, 0, kh, kw])

    # ReLU: clamp negative sums to zero before storing.
    out[n, 0, row, col] = acc if acc > 0.0 else 0.0

In [None]:
# Function: conv_relu_py
# ----------------------
# Description:
# This is the higher-level wrapper function that coordinates the operation.
# It:
# 1. Checks that the input tensors have the expected shapes (using assert)
# 2. Creates an output tensor with the same size as x
# 3. Calls the run_kernel function, which runs conv_relu_kernel_py
#    for each pixel index i (from 0 to N*H*W - 1)
# 4. Returns the output tensor
#
# So, this function organizes and launches the lower-level computation.
def conv_relu_py(x, w, b):
    """Run the per-pixel Python kernel over a whole batch.

    Validates the shapes, allocates the output, and invokes
    ``conv_relu_kernel_py`` once per pixel through ``run_kernel``.

    Args:
        x: input tensor of shape ``[N, 1, H, W]``.
        w: filter tensor of shape ``[1, 1, 3, 3]``.
        b: bias tensor holding exactly one value.

    Returns:
        A new tensor with the same shape, dtype and device as ``x``.

    Raises:
        AssertionError: if any argument has an unsupported shape.
    """
    assert x.dim() == 4 and x.size(1) == 1, "x must have shape [N, 1, H, W]"
    assert tuple(w.shape) == (1, 1, 3, 3), "w must have shape [1, 1, 3, 3]"
    assert b.numel() == 1, "b must hold exactly one bias value"

    N, _, H, W = x.shape
    # Zero-initialised so the result is deterministic even if a pixel is skipped.
    out = torch.zeros_like(x)
    run_kernel(conv_relu_kernel_py, N * H * W, x, w, b, out, N, H, W)
    return out

In [None]:
# Compare the pure-Python kernel against PyTorch's Conv2d + ReLU on a small CPU batch.
# A cell-local device name is used so the notebook-wide `device` is not overwritten.
ref_device = "cpu"

H, W = 8, 8

x = torch.randn(4, 1, H, W, device=ref_device)

conv = torch.nn.Conv2d(1, 1, kernel_size=3, padding=1, bias=True).to(ref_device)

with torch.no_grad():
    conv.weight.copy_(torch.randn_like(conv.weight))
    conv.bias.copy_(torch.randn_like(conv.bias))
    y_ref = F.relu(conv(x))

w = conv.weight.detach()
b = conv.bias.detach()

y_py = conv_relu_py(x, w, b)

max_diff = (y_ref - y_py).abs().max().item()
print("Max diff:", max_diff)
# Fail loudly instead of relying on the reader to eyeball the printed number.
assert max_diff < 1e-5, f"Python kernel disagrees with the reference (max diff {max_diff:.3e})"


# CUDA Kernel

In [None]:
# Shared prelude for the inline CUDA extension: headers, input-validation macros,
# and a ceil-division helper used to size the launch grid.
cuda_begin = r'''
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAException.h>
#include <limits>

#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

inline unsigned int cdiv(unsigned int a, unsigned int b) { return (a + b - 1) / b;}
'''

# Device kernel plus the host-side launcher exposed to Python.
cuda_src = cuda_begin + r'''

// One thread per output pixel: 3x3 cross-correlation with zero padding of 1,
// plus bias, followed by ReLU. Tensors are contiguous [N,1,H,W] / [1,1,3,3] / [1].
__global__ void conv_relu_kernel(
    const float* __restrict__ x,
    const float* __restrict__ w,
    const float* __restrict__ b,
    float* __restrict__ out,
    int N,
    int H,
    int W
) {
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    const int n_pix = N * H * W;
    if (idx >= n_pix) return;  // the grid is rounded up to whole blocks

    // With a single channel, the flat offset of (n, 0, row, col) equals idx.
    const int col = idx % W;
    const int row = (idx / W) % H;
    const int n   = idx / (W * H);
    const float* img = x + n * H * W;

    float acc = b[0];
    for (int kh = 0; kh < 3; ++kh) {
        const int src_row = row + kh - 1;
        if (src_row < 0 || src_row >= H) continue;      // zero padding
        for (int kw = 0; kw < 3; ++kw) {
            const int src_col = col + kw - 1;
            if (src_col < 0 || src_col >= W) continue;  // zero padding
            acc += img[src_row * W + src_col] * w[kh * 3 + kw];
        }
    }

    // ReLU written so that a NaN accumulator stays NaN.
    out[idx] = acc < 0.0f ? 0.0f : acc;
}

torch::Tensor conv_relu_fused(torch::Tensor x,
                              torch::Tensor w,
                              torch::Tensor b) {
    CHECK_INPUT(x);
    CHECK_INPUT(w);
    CHECK_INPUT(b);

    TORCH_CHECK(x.scalar_type() == torch::kFloat32 &&
                w.scalar_type() == torch::kFloat32 &&
                b.scalar_type() == torch::kFloat32,
                "only float32 tensors supported");

    TORCH_CHECK(x.dim() == 4, "x must be [N,C,H,W]");
    TORCH_CHECK(w.dim() == 4, "w must be [C_out,C_in,3,3]");
    TORCH_CHECK(b.dim() == 1, "b must be [C_out]");

    TORCH_CHECK(x.size(1) == 1, "only C_in=1 supported");
    TORCH_CHECK(w.size(0) == 1 && w.size(1) == 1 &&
                w.size(2) == 3 && w.size(3) == 3,
                "only 1x1x3x3 kernel supported");
    TORCH_CHECK(b.size(0) == 1, "only 1 output channel supported");

    // The kernel indexes with 32-bit ints.
    TORCH_CHECK(x.numel() <= std::numeric_limits<int>::max(),
                "input has too many elements for 32-bit indexing");

    // CHECK_INPUT already guarantees contiguity, so no extra copies are needed.
    const int N = static_cast<int>(x.size(0));
    const int H = static_cast<int>(x.size(2));
    const int W = static_cast<int>(x.size(3));

    auto out = torch::empty_like(x);

    const int n_pix = N * H * W;
    if (n_pix == 0) {
        return out;  // a launch with zero blocks is an invalid configuration
    }

    const int threads = 256;
    const int blocks = cdiv(n_pix, threads);

    // Launch on PyTorch's current stream so the kernel is ordered with other ops.
    conv_relu_kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
        x.data_ptr<float>(),
        w.data_ptr<float>(),
        b.data_ptr<float>(),
        out.data_ptr<float>(),
        N, H, W
    );
    C10_CUDA_KERNEL_LAUNCH_CHECK();

    return out;
}
'''

# Declaration only; load_inline generates the Python binding from it.
cpp_src = r'''
torch::Tensor conv_relu_fused(torch::Tensor x,
                              torch::Tensor w,
                              torch::Tensor b);
'''

module = load_inline(
    name="conv_relu_fused_ext",
    cpp_sources=[cpp_src],
    cuda_sources=[cuda_src],
    functions=["conv_relu_fused"],
    extra_cuda_cflags=["-O3"],
    verbose=False,
)


In [8]:
class FusedConvReLUFn(torch.autograd.Function):
    """Autograd wrapper around the fused CUDA conv3x3 + ReLU extension.

    The forward pass calls ``module.conv_relu_fused``; the backward pass is
    assembled from PyTorch's convolution-gradient helpers, masked by where the
    forward output was positive.
    """

    @staticmethod
    def forward(ctx, x, weight, bias):
        """Return ``relu(conv2d(x, weight, bias, padding=1))`` via the extension.

        Raises:
            AssertionError: if a tensor is not on CUDA or has an unsupported shape.
        """
        assert x.is_cuda and weight.is_cuda and bias.is_cuda
        assert x.dim() == 4 and x.size(1) == 1, "only C_in=1 supported"
        assert weight.shape == (1, 1, 3, 3), "only 1x1x3x3 kernel supported"
        assert bias.shape == (1,), "only 1 output channel supported"

        # The extension rejects non-contiguous tensors; this is a no-op otherwise.
        x = x.contiguous()

        y = module.conv_relu_fused(x, weight, bias)
        # bias is not needed by backward, so it is not saved.
        ctx.save_for_backward(x, weight, y)
        return y

    @staticmethod
    def backward(ctx, grad_output):
        """Return gradients for ``(x, weight, bias)``; ``None`` where not required."""
        x, weight, y = ctx.saved_tensors
        need_x, need_weight, need_bias = ctx.needs_input_grad

        # grad ReLU: gradient flows only where the forward output was positive.
        mask = (y > 0).to(grad_output.dtype)
        grad_z = grad_output * mask

        grad_x = grad_weight = grad_bias = None

        # dL/dx — skipped when the input does not require grad (e.g. raw images).
        if need_x:
            grad_x = torch.nn.grad.conv2d_input(
                x.shape, weight, grad_z, padding=1
            )
        # dL/dW
        if need_weight:
            grad_weight = torch.nn.grad.conv2d_weight(
                x, weight.shape, grad_z, padding=1
            )
        # dL/db
        if need_bias:
            grad_bias = grad_z.sum(dim=[0, 2, 3])

        return grad_x, grad_weight, grad_bias


In [9]:
class FusedConvReLU(nn.Module):
    """Single-channel 3x3 convolution + ReLU layer backed by ``FusedConvReLUFn``.

    Parameters:
        weight: learnable filter of shape ``(1, 1, 3, 3)``, drawn from a standard normal.
        bias: learnable bias of shape ``(1,)``, initialised to zero.
    """

    def __init__(self):
        super().__init__()
        initial_weight = torch.randn(1, 1, 3, 3)
        initial_bias = torch.zeros(1)
        self.weight = nn.Parameter(initial_weight)
        self.bias = nn.Parameter(initial_bias)

    def forward(self, x):
        """Apply the fused op to ``x`` using this layer's parameters."""
        fused_op = FusedConvReLUFn.apply
        return fused_op(x, self.weight, self.bias)


In [10]:
class CNNBaseline(nn.Module):
    """Reference classifier built from stock PyTorch layers.

    One shared 1->1 channel 3x3 convolution is applied ``num_convs`` times
    (each followed by ReLU), then 2x2 max-pooling, then two linear layers
    producing 10 outputs. ``fc1`` expects 14*14 features after pooling.
    """

    def __init__(self, num_convs=5):
        super().__init__()

        self.num_convs = num_convs
        self.conv = nn.Conv2d(
            in_channels=1, out_channels=1, kernel_size=3, padding=1, bias=True
        )
        self.pool = nn.MaxPool2d(kernel_size=2)
        self.fc1 = nn.Linear(in_features=14 * 14, out_features=64)
        self.fc2 = nn.Linear(in_features=64, out_features=10)

    def forward(self, x):
        """Map a batch of single-channel images to 10 class scores per image."""
        features = x
        for _ in range(self.num_convs):
            # The same conv weights are reused on every pass.
            features = torch.relu(self.conv(features))

        pooled = self.pool(features)
        flat = pooled.view(pooled.size(0), -1)
        hidden = torch.relu(self.fc1(flat))
        return self.fc2(hidden)


In [11]:
class CNNFused(nn.Module):
    """Same architecture as the baseline, with the conv+ReLU pair replaced by ``FusedConvReLU``.

    The single fused layer is applied ``num_convs`` times, followed by 2x2
    max-pooling and two linear layers producing 10 outputs.
    """

    def __init__(self, num_convs=5):
        super().__init__()

        self.num_convs = num_convs
        self.conv = FusedConvReLU()
        self.pool = nn.MaxPool2d(kernel_size=2)
        self.fc1 = nn.Linear(in_features=14 * 14, out_features=64)
        self.fc2 = nn.Linear(in_features=64, out_features=10)

    def forward(self, x):
        """Map a batch of single-channel images to 10 class scores per image."""
        features = x
        for _ in range(self.num_convs):
            # ReLU is already applied inside the fused layer.
            features = self.conv(features)

        pooled = self.pool(features)
        flat = pooled.view(pooled.size(0), -1)
        hidden = torch.relu(self.fc1(flat))
        return self.fc2(hidden)


In [None]:
# MNIST train/test datasets (downloaded to ./data on first use) and their loaders.
batch_size = 128

transform = transforms.ToTensor()

# Train split first, then test split.
train_ds, test_ds = (
    datasets.MNIST(root="./data", train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

# Only the training loader shuffles; every other loader setting is shared.
train_loader, test_loader = (
    torch.utils.data.DataLoader(
        split, batch_size=batch_size, shuffle=do_shuffle, num_workers=2, pin_memory=True
    )
    for split, do_shuffle in ((train_ds, True), (test_ds, False))
)


In [13]:
def evaluate(model, loader, device):
    """Return the top-1 accuracy of ``model`` over every batch in ``loader``.

    The model is switched to eval mode (and left there) and gradients are
    disabled while iterating.

    Args:
        model: module mapping a batch of inputs to per-class scores.
        loader: iterable of ``(images, labels)`` batches.
        device: device the batches are moved to before the forward pass.

    Returns:
        Fraction of correctly classified samples as a Python float.
    """
    model.eval()
    n_correct, n_seen = 0, 0
    with torch.no_grad():
        for batch_images, batch_labels in loader:
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)
            predicted = model(batch_images).argmax(dim=1)
            n_correct += predicted.eq(batch_labels).sum().item()
            n_seen += batch_labels.size(0)
    return n_correct / n_seen


In [14]:
def train_model_timed(model, train_loader, test_loader, device, epochs=3, lr=1e-3):
    """Train ``model`` with Adam and report per-epoch timing and test accuracy.

    Each training step is split into three timed phases: forward (model call
    plus loss), backward (``zero_grad`` plus ``loss.backward``) and optimizer
    step (reported as "other"). Data loading and host-to-device copies are
    only included in the total epoch time. Evaluation runs after the epoch
    timer has stopped.

    Args:
        model: module to train; it is moved to ``device``.
        train_loader: loader of ``(images, labels)`` training batches; must
            expose ``.dataset`` for loss averaging.
        test_loader: loader passed to ``evaluate`` after every epoch.
        device: target device (string or ``torch.device``). The GPU is only
            synchronised when this is a CUDA device, so CPU runs also work
            where CUDA is unavailable.
        epochs: number of passes over ``train_loader``.
        lr: Adam learning rate.

    Returns:
        Dict of per-epoch lists with keys ``total_times``, ``forward_times``,
        ``backward_times``, ``other_times`` and ``test_accuracies``.
    """
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    target = torch.device(device)
    on_cuda = target.type == "cuda"

    def tick():
        """Wait for pending GPU work (CUDA only) and return a monotonic timestamp."""
        if on_cuda:
            torch.cuda.synchronize(target)
        return time.perf_counter()

    epoch_total_times   = []
    epoch_forward_times = []
    epoch_backward_times= []
    epoch_other_times   = []
    test_accuracies     = []

    for epoch in range(1, epochs+1):
        model.train()
        running_loss = 0.0

        fwd_time = 0.0
        bwd_time = 0.0
        other_time = 0.0

        epoch_start = tick()

        for images, labels in train_loader:
            images = images.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            # ----- forward -----
            t_fwd_start = tick()
            outputs = model(images)
            loss = criterion(outputs, labels)
            t_fwd_end = tick()
            fwd_time += (t_fwd_end - t_fwd_start)

            # ----- backward -----
            t_bwd_start = tick()
            optimizer.zero_grad()
            loss.backward()
            t_bwd_end = tick()
            bwd_time += (t_bwd_end - t_bwd_start)

            # ----- optimizer step -----
            t_opt_start = tick()
            optimizer.step()
            t_opt_end = tick()
            other_time += (t_opt_end - t_opt_start)

            # Weight by batch size so the epoch average is per sample.
            running_loss += loss.item() * labels.size(0)

        epoch_end = tick()
        epoch_time = epoch_end - epoch_start

        avg_loss = running_loss / len(train_loader.dataset)
        acc = evaluate(model, test_loader, device)

        epoch_total_times.append(epoch_time)
        epoch_forward_times.append(fwd_time)
        epoch_backward_times.append(bwd_time)
        epoch_other_times.append(other_time)
        test_accuracies.append(acc)

        print(
            f"Epoch {epoch}: "
            f"loss={avg_loss:.4f}, "
            f"test_acc={acc*100:.2f}%, "
            f"time_total={epoch_time:.2f}s "
            f"(fwd={fwd_time:.2f}s, bwd={bwd_time:.2f}s, other={other_time:.2f}s)"
        )

    return {
        "total_times": epoch_total_times,
        "forward_times": epoch_forward_times,
        "backward_times": epoch_backward_times,
        "other_times": epoch_other_times,
        "test_accuracies": test_accuracies,
    }


In [None]:
num_convs = 2  # 2 and 10 should be tested

# Three identically shaped baselines (one per backend configuration) ...
baseline_cudnn, baseline_cudnn_off, baseline_cpu = (
    CNNBaseline(num_convs).to(device) for _ in range(3)
)
# ... and the variant that uses the fused CUDA layer.
fused = CNNFused(num_convs).to(device)

In [16]:
def copy_weights(dst, src):
    """Copy the conv, fc1 and fc2 weights and biases from ``src`` into ``dst`` in place.

    Both models must expose ``conv``, ``fc1`` and ``fc2`` attributes, each with
    ``weight`` and ``bias`` tensors of matching shapes. Copies are made without
    tracking gradients.
    """
    with torch.no_grad():
        for layer_name in ("conv", "fc1", "fc2"):
            dst_layer = getattr(dst, layer_name)
            src_layer = getattr(src, layer_name)
            for tensor_name in ("weight", "bias"):
                getattr(dst_layer, tensor_name).copy_(getattr(src_layer, tensor_name))

# Start every variant from the cuDNN baseline's parameters so the runs are comparable.
for _target in (fused, baseline_cudnn_off, baseline_cpu):
    copy_weights(_target, baseline_cudnn)

In [None]:
# Baseline run on the GPU with cuDNN enabled and its autotuner switched on.
print("=== Training baseline CNN (Conv2d + ReLU) - cuDNN  ===")
device = "cuda"

torch.backends.cudnn.enabled = torch.backends.cudnn.benchmark = True

stats_base1 = train_model_timed(
    model=baseline_cudnn,
    train_loader=train_loader,
    test_loader=test_loader,
    device=device,
    epochs=5,
)

In [None]:
print("\n=== Training fused CNN (FusedConvReLU) ===")

# The fused layer only runs on CUDA tensors, so pin the device here instead of
# relying on whatever value another cell last assigned to `device`.
device = "cuda"

stats_fused = train_model_timed(fused, train_loader, test_loader, device, epochs=5)


In [None]:
print("=== Training baseline CNN (Conv2d + ReLU) - cuDNN off ===")

# Set the device explicitly so this GPU run does not depend on cell execution order.
device = "cuda"

# Disable cuDNN (and its autotuner) for this run.
torch.backends.cudnn.enabled = False
torch.backends.cudnn.benchmark = False

stats_base2 = train_model_timed(baseline_cudnn_off, train_loader, test_loader, device, epochs=5)

In [None]:
# Baseline run on the CPU for comparison with the GPU variants.
print("=== Training baseline CNN (Conv2d + ReLU) - CPU ===")
device = "cpu"

stats_base3 = train_model_timed(
    model=baseline_cpu,
    train_loader=train_loader,
    test_loader=test_loader,
    device=device,
    epochs=5,
)