# Cody Nichols' Assignment 3 COMP 6970

In [23]:
import torch
from torch.autograd import gradcheck
import torch
from torch.autograd import Function

class BatchNorm1dManual(torch.autograd.Function):
    """Training-mode batch normalization over dim 0 with a hand-written backward.

    Each column of ``x`` is normalized with its batch mean and biased variance,
    then scaled and shifted by ``gamma`` and ``beta``.
    """

    @staticmethod
    def forward(ctx, x, gamma, beta, eps=1e-5):
        """Return ``gamma * (x - mean) / sqrt(var + eps) + beta``.

        Args:
            x: input of shape ``(N, D)``; statistics are taken over dim 0.
            gamma: per-feature scale, broadcastable against ``x``.
            beta: per-feature shift, broadcastable against ``x``.
            eps: constant added to the variance for numerical stability.
        """
        # Step 1: per-feature batch statistics (biased variance)
        mu = x.mean(dim=0)
        var = x.var(dim=0, unbiased=False)

        # Step 2: normalize
        std = torch.sqrt(var + eps)
        x_hat = (x - mu) / std

        # Step 3: scale and shift
        result = gamma * x_hat + beta

        # Step 4: keep only what backward needs. Saving 1/std (computed with the
        # exact Python-float eps) avoids re-deriving it from a float32-rounded
        # eps tensor in backward.
        ctx.save_for_backward(x_hat, 1.0 / std, gamma)
        return result

    @staticmethod
    def backward(ctx, dout):
        """Return gradients for ``(x, gamma, beta, eps)``; eps gets ``None``."""
        x_hat, std_inv, gamma = ctx.saved_tensors
        n = x_hat.shape[0]

        # Gradients of the affine parameters
        dgamma = (dout * x_hat).sum(dim=0)
        dbeta = dout.sum(dim=0)

        # Backprop through the scale, then through the normalization
        dx_hat = dout * gamma
        dx = (std_inv / n) * (
            n * dx_hat
            - dx_hat.sum(dim=0)
            - x_hat * (dx_hat * x_hat).sum(dim=0)
        )

        # One gradient slot per forward input. eps is not differentiable, so it
        # gets None; a trailing None is also accepted when apply() omitted eps.
        return dx, dgamma, dbeta, None

N, D = 5, 4
# Seed so the check is reproducible
torch.manual_seed(0)

# Create random inputs. gamma/beta are random rather than ones/zeros: with
# gamma == 1 the gamma factor in dx is never exercised by the check.
x = torch.randn(N, D, dtype=torch.double, requires_grad=True)
gamma = torch.randn(D, dtype=torch.double, requires_grad=True)
beta = torch.randn(D, dtype=torch.double, requires_grad=True)

# Wrap inputs in tuple and perform gradcheck
input_tuple = (x, gamma, beta)
grad_check = gradcheck(BatchNorm1dManual.apply, input_tuple, eps=1e-6, atol=1e-4)
print("Gradient check passed:", grad_check)

Gradient check passed: True


In [24]:
import torch
from torch.autograd import Function
from torch.autograd import gradcheck

class InstanceNorm2d(Function):
    """Instance normalization for ``(N, C, H, W)`` input with a manual backward.

    Each ``(n, c)`` plane is normalized with its own spatial mean and biased
    variance, then scaled/shifted by the per-channel ``gamma``/``beta``.
    """

    @staticmethod
    def forward(ctx, x, gamma, beta, eps=1e-5):
        """Return ``gamma * x_hat + beta`` with per-instance, per-channel x_hat.

        Args:
            x: 4-D input ``(N, C, H, W)``; may be non-contiguous.
            gamma: per-channel scale with ``C`` elements.
            beta: per-channel shift with ``C`` elements.
            eps: constant added to the variance for numerical stability.
        """
        N, C, H, W = x.shape

        # Flatten the spatial dims. reshape (unlike view) also accepts
        # non-contiguous inputs.
        x_flat = x.reshape(N, C, H * W)

        # Mean and biased variance per instance per channel
        mu = x_flat.mean(dim=2, keepdim=True)
        var = x_flat.var(dim=2, unbiased=False, keepdim=True)

        # Normalize and restore the spatial layout
        std = torch.sqrt(var + eps)
        x_hat = ((x_flat - mu) / std).reshape(N, C, H, W)

        # Scale and shift, broadcasting per channel
        result = gamma.reshape(1, C, 1, 1) * x_hat + beta.reshape(1, C, 1, 1)

        # Save 1/std computed with the exact eps used above
        ctx.save_for_backward(x_hat, 1.0 / std, gamma)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        """Return gradients for ``(x, gamma, beta, eps)``; eps gets ``None``."""
        x_hat, std_inv, gamma = ctx.saved_tensors
        N, C, H, W = grad_output.shape
        M = H * W

        # grad_output is not guaranteed to be contiguous, so reshape, not view
        x_hat = x_hat.reshape(N, C, M)
        dout = grad_output.reshape(N, C, M)

        # Gradients w.r.t. scale (gamma) and shift (beta), shaped like the inputs
        dgamma = (dout * x_hat).sum(dim=(0, 2)).reshape(gamma.shape)
        dbeta = dout.sum(dim=(0, 2)).reshape(gamma.shape)

        # Backprop through the scale, then through the per-plane normalization
        dx_hat = dout * gamma.reshape(1, C, 1)
        dx_hat_sum = dx_hat.sum(dim=2, keepdim=True)
        x_hat_dx_hat_sum = (dx_hat * x_hat).sum(dim=2, keepdim=True)
        dx = (std_inv / M) * (M * dx_hat - dx_hat_sum - x_hat * x_hat_dx_hat_sum)

        # eps is not differentiable; a trailing None is accepted even when
        # apply() was called without eps.
        return dx.reshape(N, C, H, W), dgamma, dbeta, None

N, C, H, W = 2, 3, 4, 4
# Seed so the check is reproducible
torch.manual_seed(0)

# Example usage. gamma/beta are random rather than ones/zeros so the gamma
# factor in the input gradient is actually exercised by gradcheck.
x = torch.randn(N, C, H, W, dtype=torch.double, requires_grad=True)
gamma = torch.randn(C, dtype=torch.double, requires_grad=True)
beta = torch.randn(C, dtype=torch.double, requires_grad=True)

# Wrap inputs in tuple and perform gradcheck
input_tuple = (x, gamma, beta)
grad_check = gradcheck(InstanceNorm2d.apply, input_tuple, eps=1e-6, atol=1e-4)
print("Gradient check passed:", grad_check)

Gradient check passed: True


In [30]:
import torch
import torch.nn as nn
import torch.optim as optim

# Define a simple model with Linear followed by BatchNorm1d
class OriginalModel(nn.Module):
    """Reference (un-fused) model: a Linear layer followed by BatchNorm1d.

    Args:
        input_size: number of input features of the linear layer.
        output_size: number of output features, which is also the number of
            batch-norm channels.
    """

    def __init__(self, input_size, output_size):
        super().__init__()
        # Register fc before bn so parameter/state_dict ordering is fc, then bn.
        self.fc = nn.Linear(input_size, output_size)
        self.bn = nn.BatchNorm1d(output_size)

    def forward(self, x):
        """Return ``bn(fc(x))``."""
        projected = self.fc(x)
        return self.bn(projected)

# Define a model that only uses the fused Linear layer (BatchNorm merged)
class FusedModel(nn.Module):
    """Thin wrapper that runs a single pre-fused layer.

    Args:
        fused_layer: module whose output ``forward`` returns unchanged; in this
            notebook it is the Linear layer built by the fusion function.
    """

    def __init__(self, fused_layer):
        super().__init__()
        self.fc_fused = fused_layer

    def forward(self, x):
        """Return ``fc_fused(x)``."""
        return self.fc_fused(x)

# Function to test equivalence between original and fused models
def test_fusion(original_model, fused_model, input_data, atol=1e-5):
    """Check that two models produce matching outputs on ``input_data``.

    Prints a human-readable verdict and also returns it, so the check can be
    used programmatically.

    Args:
        original_model: reference callable (e.g. the un-fused model).
        fused_model: callable under test (e.g. the fused model).
        input_data: tensor fed to both models.
        atol: absolute tolerance passed to ``torch.allclose``.

    Returns:
        True if both outputs have the same shape and are element-wise close.
    """
    # Pure inference: no autograd graph is needed for the comparison.
    with torch.no_grad():
        original_output = original_model(input_data)
        fused_output = fused_model(input_data)

    # torch.allclose broadcasts, so compare shapes explicitly to avoid a
    # false "match" between e.g. (10, 64) and (1, 64).
    outputs_match = (
        original_output.shape == fused_output.shape
        and torch.allclose(original_output, fused_output, atol=atol)
    )
    if outputs_match:
        print("Fusion is correct! The outputs match.")
    else:
        print("Fusion is incorrect! The outputs do not match.")
    return outputs_match


# A helper despite its ``test_`` prefix; keep pytest from collecting it.
test_fusion.__test__ = False

# Fuse batch norm into the linear layer
def fuse_batch_norm_into_linear(linear_layer, batch_norm_layer):
    """Fold an eval-mode batch norm into the Linear layer that feeds it.

    With running statistics, ``bn(W x + b)`` equals a single affine map
    ``W' x + b'`` where ``scale = gamma / sqrt(running_var + eps)``,
    ``W' = scale[:, None] * W`` and ``b' = scale * (b - running_mean) + beta``.

    Args:
        linear_layer: the ``nn.Linear`` whose output feeds the batch norm; it
            may have been built with ``bias=False``.
        batch_norm_layer: the batch-norm layer; it may be ``affine=False`` but
            must track running statistics.

    Returns:
        A new ``nn.Linear`` with the fused weight and bias.

    Raises:
        ValueError: if the batch norm has no running statistics.
    """
    running_mean = batch_norm_layer.running_mean
    running_var = batch_norm_layer.running_var
    if running_mean is None or running_var is None:
        raise ValueError(
            "batch_norm_layer has no running statistics "
            "(track_running_stats=False), so it cannot be fused"
        )

    # Parameter arithmetic only; no autograd history is needed.
    with torch.no_grad():
        W = linear_layer.weight
        # Absent pieces fall back to the identity:
        # no linear bias -> 0; affine=False batch norm -> gamma = 1, beta = 0.
        if linear_layer.bias is not None:
            b = linear_layer.bias
        else:
            b = torch.zeros_like(running_mean)
        if batch_norm_layer.weight is not None:
            gamma = batch_norm_layer.weight
        else:
            gamma = torch.ones_like(running_var)
        if batch_norm_layer.bias is not None:
            beta = batch_norm_layer.bias
        else:
            beta = torch.zeros_like(running_mean)

        # Per-output-feature scale
        scale = gamma / torch.sqrt(running_var + batch_norm_layer.eps)

        # Row i of W is scaled by scale[i]; the bias is shifted and scaled to match.
        W_fused = scale.unsqueeze(1) * W
        b_fused = scale * (b - running_mean) + beta

    # Create a new Linear layer with the fused parameters
    fused_layer = nn.Linear(linear_layer.in_features, linear_layer.out_features)
    fused_layer.weight = nn.Parameter(W_fused)
    fused_layer.bias = nn.Parameter(b_fused)

    return fused_layer

# Define input size and output size
input_size = 128
output_size = 64

# Seed so the check is reproducible
torch.manual_seed(0)

# Instantiate the original model
original_model = OriginalModel(input_size, output_size)

# A freshly built BatchNorm has running_mean=0, running_var=1, gamma=1, beta=0,
# which is almost the identity in eval mode and would hide mistakes in the
# fusion formula. Give it non-trivial statistics and affine parameters.
with torch.no_grad():
    original_model.bn.running_mean.uniform_(-1.0, 1.0)
    original_model.bn.running_var.uniform_(0.5, 2.0)
    original_model.bn.weight.uniform_(0.5, 1.5)
    original_model.bn.bias.uniform_(-0.5, 0.5)

# Eval mode so BatchNorm uses the running statistics (required for fusion)
original_model.eval()

# Generate some random input data
input_data = torch.randn(10, input_size)

# Forward pass through the original model
with torch.no_grad():
    original_output = original_model(input_data)

# Fuse the BatchNorm layer into the linear layer and wrap it
fused_layer = fuse_batch_norm_into_linear(original_model.fc, original_model.bn)
fused_model = FusedModel(fused_layer)
fused_model.eval()

# Test if the original and fused models produce the same output
test_fusion(original_model, fused_model, input_data)

Fusion is correct! The outputs match.


In [35]:
import torch
from torch.autograd import Function
from torch.autograd import gradcheck

class LinearFunction(Function):
    """Affine map ``x @ weight.T + bias`` with a hand-written backward.

    ``bias`` is optional, mirroring ``torch.nn.functional.linear``.
    """

    @staticmethod
    def forward(ctx, x, weight, bias=None):
        """Return ``x.mm(weight.t())`` plus ``bias`` broadcast over rows.

        Args:
            x: 2-D input ``(batch, in_features)``.
            weight: ``(out_features, in_features)``.
            bias: optional ``(out_features,)``; ``None`` means no bias.
        """
        # save_for_backward accepts None, so the bias-free case needs no branch here
        ctx.save_for_backward(x, weight, bias)

        y = x.mm(weight.t())
        if bias is not None:
            y += bias.unsqueeze(0).expand_as(y)
        return y

    @staticmethod
    def backward(ctx, dout):
        """Return gradients for ``(x, weight, bias)``; skip those not required."""
        x, weight, bias = ctx.saved_tensors
        dx = dw = db = None

        # Only compute the gradients autograd actually asks for
        if ctx.needs_input_grad[0]:
            dx = dout.mm(weight)
        if ctx.needs_input_grad[1]:
            dw = dout.t().mm(x)
        if bias is not None and ctx.needs_input_grad[2]:
            db = dout.sum(dim=0)

        # A trailing None for bias is accepted even when apply() omitted it
        return dx, dw, db

def gradient_check_and_output_check_linear():
    """Compare LinearFunction with torch.nn.Linear, then gradcheck it.

    Prints whether the forward outputs agree and whether the hand-written
    gradients pass ``gradcheck``.
    """
    torch.manual_seed(0)

    # Double-precision tensors for reliable finite-difference checks
    inputs = torch.randn(5, 3, dtype=torch.double, requires_grad=True)
    w = torch.randn(2, 3, dtype=torch.double, requires_grad=True)
    b = torch.randn(2, dtype=torch.double, requires_grad=True)

    custom_fn = LinearFunction.apply
    custom_out = custom_fn(inputs, w, b)

    # Built-in layer loaded with identical (detached) parameters
    reference = torch.nn.Linear(3, 2, bias=True).double()
    reference.weight = torch.nn.Parameter(w.detach().clone())
    reference.bias = torch.nn.Parameter(b.detach().clone())
    reference_out = reference(inputs)

    outputs_agree = torch.allclose(custom_out, reference_out, atol=1e-6)
    print('Output Consistency Check for Linear Module:', outputs_agree)

    grads_ok = gradcheck(custom_fn, (inputs, w, b), eps=1e-6, atol=1e-4)
    print('Gradient Check for Linear Module:', grads_ok)


# Run the checks
gradient_check_and_output_check_linear()

Output Consistency Check for Linear Module: True
Gradient Check for Linear Module: True


In [13]:
import torch
from torch.autograd import Function
from torch.autograd import gradcheck

class ReLUFunction(Function):
    """Element-wise ReLU, ``max(x, 0)``, with a hand-written backward."""

    @staticmethod
    def forward(ctx, x):
        """Return ``x`` clamped below at 0; ``x`` is saved for the backward mask."""
        y = x.clamp(min=0)
        ctx.save_for_backward(x)
        return y

    @staticmethod
    def backward(ctx, dout):
        """Pass ``dout`` through where ``x > 0``; zero elsewhere (including x == 0)."""
        x, = ctx.saved_tensors
        # Cast the mask to dout's dtype instead of a fixed float64, so
        # float32/float16 gradients are not promoted to double precision.
        mask = (x > 0).to(dout.dtype)
        return dout * mask

def gradient_check_and_output_check_relu():
    """Compare ReLUFunction with torch.nn.ReLU, then gradcheck it.

    Prints whether the forward outputs agree and whether the hand-written
    gradient passes ``gradcheck``.
    """
    torch.manual_seed(0)
    sample = torch.randn(5, 3, dtype=torch.double, requires_grad=True)

    custom_fn = ReLUFunction.apply
    custom_out = custom_fn(sample)
    reference_out = torch.nn.ReLU().double()(sample)

    print('Output Consistency Check for ReLU Module:',
          torch.allclose(custom_out, reference_out, atol=1e-6))
    print('Gradient Check for ReLU Module:',
          gradcheck(custom_fn, (sample,), eps=1e-6, atol=1e-4))


# Run the checks
gradient_check_and_output_check_relu()

Output Consistency Check for ReLU Module: True
Gradient Check for ReLU Module: True


In [36]:
import torch
from torch.autograd import Function
from torch.autograd import gradcheck

class SoftMaxFunction(Function):
    """Numerically stable softmax along ``dim`` with a hand-written backward."""

    @staticmethod
    def forward(ctx, x, dim=1):
        """Return ``exp(x) / sum(exp(x))`` along ``dim``.

        Args:
            x: input tensor.
            dim: dimension to normalize over; defaults to 1 (the class
                dimension for ``(batch, classes)`` input).
        """
        # Subtracting the max along dim leaves the result unchanged but keeps
        # exp() from overflowing.
        x_max = x.max(dim=dim, keepdim=True)[0]
        exp_x = torch.exp(x - x_max)
        s = exp_x / exp_x.sum(dim=dim, keepdim=True)

        # The Jacobian-vector product only needs the output itself
        ctx.save_for_backward(s)
        ctx.dim = dim
        return s

    @staticmethod
    def backward(ctx, dout):
        """Return ``s * (dout - <dout, s>)`` for x, and None for ``dim``."""
        s, = ctx.saved_tensors
        s_dout = (dout * s).sum(dim=ctx.dim, keepdim=True)
        dx = s * (dout - s_dout)
        # dim is not differentiable; a trailing None is accepted even when
        # apply() was called without it.
        return dx, None

def gradient_check_and_output_check_softmax():
    """Compare SoftMaxFunction with torch.nn.Softmax(dim=1), then gradcheck it.

    Prints whether the forward outputs agree and whether the hand-written
    gradient passes ``gradcheck``.
    """
    torch.manual_seed(0)
    logits = torch.randn(5, 3, dtype=torch.double, requires_grad=True)

    custom_fn = SoftMaxFunction.apply
    custom_out = custom_fn(logits)

    reference = torch.nn.Softmax(dim=1).double()
    reference_out = reference(logits)

    agree = torch.allclose(custom_out, reference_out, atol=1e-6)
    print('Output Consistency Check for Softmax Module:', agree)

    passed = gradcheck(custom_fn, (logits,), eps=1e-6, atol=1e-4)
    print('Gradient Check for Softmax Module:', passed)


# Run the checks
gradient_check_and_output_check_softmax()

Output Consistency Check for Softmax Module: True
Gradient Check for Softmax Module: True
