In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

## 1) Implement tiled matrix multiplication

In [9]:
def tmul(tileA, tileB, t):
    """Multiply a single pair of tiles, emulating a fixed-size t x t matmul unit.

    Args:
        tileA: left tile; its first two dimensions must each be <= t.
        tileB: right tile; its second dimension must be <= t.
        t: tile size (maximum extent of each checked dimension).

    Returns:
        ``tileA @ tileB``.

    Raises:
        AssertionError: if a checked tile dimension exceeds t.
    """
    # Checks run in the same order as before: rows of A, cols of A, cols of B.
    n_rows = tileA.size(0)
    assert n_rows <= t
    n_inner = tileA.size(1)
    assert n_inner <= t
    n_cols = tileB.size(1)
    assert n_cols <= t

    return torch.matmul(tileA, tileB)

In [18]:
from math import ceil

def mmul_tiling(matA, matB, t):
    """Compute ``matA @ matB`` by accumulating products of tiles of at most t x t.

    The output is split into t x t blocks; block C[i, j] is the sum over k of
    ``tmul(A[i, k], B[k, j], t)``. Edge tiles are smaller when a dimension is
    not a multiple of t (slicing past the end simply truncates).

    Args:
        matA: 2-D tensor of shape (a, c).
        matB: 2-D tensor of shape (c, b).
        t: tile size; must be a positive integer.

    Returns:
        Tensor of shape (a, b) on matA's device, with the dtype that the
        untiled ``matA @ matB`` would produce.

    Raises:
        ValueError: if t is not positive or the inner dimensions differ.
    """
    # A non-positive t would make every range below empty and silently return zeros.
    if t <= 0:
        raise ValueError(f"tile size must be a positive integer, got {t}")

    a, c = matA.size()
    c_b, b = matB.size()
    # Without this check, extra rows of matB beyond c would be silently ignored.
    if c != c_b:
        raise ValueError(
            f"inner dimensions differ: matA is {tuple(matA.shape)}, matB is {tuple(matB.shape)}"
        )

    # Match the untiled product's dtype/device instead of always float32 on CPU.
    matC = torch.zeros(a, b, dtype=torch.result_type(matA, matB), device=matA.device)

    # 3-level loop over tile origins: output rows, output cols, reduction dim.
    for r in range(0, a, t):
        for col in range(0, b, t):
            for k in range(0, c, t):
                matC[r:r + t, col:col + t] += tmul(
                    matA[r:r + t, k:k + t], matB[k:k + t, col:col + t], t
                )

    return matC

## 2) Implement convolution via lowering (im2col + matrix multiply)

In [11]:
def conv2d(inputs, weights, padding, tiling, tile_size, bias=None, stride=1):
    """2-D convolution computed as one matrix multiply ("lowering", a.k.a. im2col).

    Weights are flattened to (o_chn, i_chn*k_h*k_w) and input patches to
    (i_chn*k_h*k_w, bs*out_h*out_w). Their product is reshaped back to
    (bs, o_chn, out_h, out_w). Only zero padding is supported; there is no
    dilation and no groups.

    Args:
        inputs: tensor of shape (bs, i_chn, in_h, in_w).
        weights: tensor of shape (o_chn, i_chn, k_h, k_w).
        padding: int or (pad_h, pad_w); zeros added on both sides of H / W.
        tiling: if truthy, multiply with ``mmul_tiling(..., tile_size)``
            instead of ``@``.
        tile_size: tile size forwarded to ``mmul_tiling`` when tiling.
        bias: optional tensor with o_chn elements, added per output channel.
        stride: int or (stride_h, stride_w); defaults to 1.

    Returns:
        Contiguous tensor of shape (bs, o_chn, out_h, out_w), where per
        spatial dim ``out = (in + 2*pad - k) // stride + 1``.

    Raises:
        ValueError: on a channel mismatch, malformed padding/stride, a
            non-positive stride, or an empty output.
    """
    def as_pair(value, name):
        # Accept an int or a 2-sequence (nn.Conv2d stores padding/stride as tuples).
        pair = (value, value) if isinstance(value, int) else tuple(value)
        if len(pair) != 2:
            raise ValueError(f"{name} must be an int or a pair, got {value!r}")
        return pair

    o_chn, w_chn, k_h, k_w = weights.size()
    bs, i_chn, in_h, in_w = inputs.size()
    if w_chn != i_chn:
        raise ValueError(f"weights expect {w_chn} input channels, inputs have {i_chn}")

    pad_h, pad_w = as_pair(padding, "padding")
    stride_h, stride_w = as_pair(stride, "stride")
    if stride_h <= 0 or stride_w <= 0:
        raise ValueError(f"stride must be positive, got {stride!r}")

    out_h = (in_h + 2 * pad_h - k_h) // stride_h + 1
    out_w = (in_w + 2 * pad_w - k_w) // stride_w + 1
    if out_h <= 0 or out_w <= 0:
        raise ValueError(
            f"kernel {(k_h, k_w)} does not fit padded input {(in_h + 2 * pad_h, in_w + 2 * pad_w)}"
        )

    def weight_lowering():
        # Row order of the flattened weights is (channel, ki, kj).
        return weights.reshape(o_chn, i_chn * k_h * k_w)

    def inputs_lowering():
        # Zero-pad H and W. F.pad lists the last dim first, and it keeps dtype/device.
        inputs_padded = F.pad(inputs, (pad_w, pad_w, pad_h, pad_h))
        # Extent spanned by out_h (out_w) window positions along each axis.
        span_h = stride_h * (out_h - 1) + 1
        span_w = stride_w * (out_w - 1) + 1
        # One strided view per kernel offset (i, j), each of shape (bs, i_chn, out_h, out_w).
        shifted = [
            inputs_padded[..., i:i + span_h:stride_h, j:j + span_w:stride_w]
            for i in range(k_h)
            for j in range(k_w)
        ]
        # (bs, i_chn, k_h*k_w, out_h, out_w) -> (i_chn, k_h*k_w, bs, out_h, out_w).
        # Rows are ordered (channel, ki, kj) like weight_lowering; columns are (batch, y, x).
        patches = torch.stack(shifted, dim=2).permute(1, 2, 0, 3, 4)
        return patches.reshape(i_chn * k_h * k_w, bs * out_h * out_w)

    def outputs_lifting(outputs):
        # (o_chn, bs*out_h*out_w) -> (bs, o_chn, out_h, out_w), contiguous like F.conv2d.
        return outputs.reshape(o_chn, bs, out_h, out_w).transpose(0, 1).contiguous()

    weights_transformed = weight_lowering()
    inputs_transformed = inputs_lowering()

    if tiling:
        lowered_outputs = mmul_tiling(weights_transformed, inputs_transformed, tile_size)
    else:
        lowered_outputs = weights_transformed @ inputs_transformed

    outputs = outputs_lifting(lowered_outputs)

    if bias is not None:
        outputs = outputs + bias.view(1, o_chn, 1, 1)

    return outputs

## 3) Extend nn.Conv2d with a lowering mode

In [12]:
class MMConv2d(nn.Conv2d):
    """nn.Conv2d whose forward pass can be switched to lowering (im2col + matmul).

    Modes, selected with ``set_mode``:
      * lowering=False: the stock nn.Conv2d forward (the default).
      * lowering=True, tiling=False: ``conv2d()`` with one matrix multiply.
      * lowering=True, tiling=True: ``conv2d()`` with a tiled multiply of
        ``tile_size``.

    The lowering path is used only for stride 1, dilation 1, groups 1 and zero
    padding; other configurations raise NotImplementedError in lowering mode.

    Note: the 4th positional constructor parameter is ``padding``, whereas
    nn.Conv2d's is ``stride``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, padding=1, stride=1, bias=False):
        super().__init__(in_channels, out_channels, kernel_size, padding=padding, stride=stride, bias=bias)
        self.lowering = False
        self.tiling = False
        self.tile_size = -1  # placeholder; set_mode requires a positive value when tiling

    def _lowered_forward(self, inputs):
        """Run the lowered ``conv2d()`` after checking this layer's config is supported."""
        # Previously stride etc. were silently dropped here, giving wrong outputs.
        if (self.stride != (1, 1) or self.dilation != (1, 1)
                or self.groups != 1 or self.padding_mode != "zeros"):
            raise NotImplementedError(
                "lowering supports only stride=1, dilation=1, groups=1 and zero padding; "
                f"got stride={self.stride}, dilation={self.dilation}, "
                f"groups={self.groups}, padding_mode={self.padding_mode!r}"
            )
        return conv2d(inputs, self.weight, padding=self.padding, tiling=self.tiling,
                      tile_size=self.tile_size, bias=self.bias)

    def forward(self, inputs):
        if self.lowering:
            return self._lowered_forward(inputs)
        # Stock path: honours every nn.Conv2d setting (stride, dilation, groups, ...).
        return super().forward(inputs)

    # APIs
    def set_mode(self, lowering, tiling, tile_size=-1):
        """Select the forward implementation.

        Raises:
            ValueError: if lowering with tiling and ``tile_size`` is not positive.
        """
        # A non-positive tile size would make the tiled multiply do no work.
        if lowering and tiling and tile_size <= 0:
            raise ValueError(f"tiling requires a positive tile_size, got {tile_size}")
        self.lowering = lowering
        self.tiling = tiling
        self.tile_size = tile_size

    def lowering_test(self, inputs):
        """Print shapes and whether the lowered output matches nn.Conv2d within 1e-5."""
        assert self.lowering

        # Compute Output
        print("Input size: \t", inputs.size())
        print("Weight size: \t", self.weight.size())
        pred_outputs = self._lowered_forward(inputs)
        print("Output size: \t", pred_outputs.size())
        print("=============================================")

        # Evaluation: the reference is the stock nn.Conv2d forward, so it uses the
        # same stride/padding as the layer (the old reference dropped stride).
        true_outputs = super().forward(inputs)
        correct = (pred_outputs - true_outputs).abs().max() < 1e-5
        print("Correctness: \t", correct.item(), '\n')

## 4) Verify

In [13]:
# Configurations
SEED = 0  # seeds torch's global RNG (weight init and the random inputs drawn later)
BS = 8
RES_X, RES_Y = (32, 32)
I_CHN = 3
O_CHN = 64
KERNEL_SIZE = 3
PADDING = 1
BIAS = False

torch.manual_seed(SEED)
# Pass padding by keyword: MMConv2d's 4th positional parameter is padding,
# whereas nn.Conv2d's is stride, so a positional value is easy to misread.
layer = MMConv2d(I_CHN, O_CHN, KERNEL_SIZE, padding=PADDING, bias=BIAS)

In [14]:
# Lowering Test: im2col lowering followed by a single (untiled) matrix multiply.
lowering, tiling = True, False
layer.set_mode(lowering=lowering, tiling=tiling)

inputs = torch.randn(BS, I_CHN, RES_Y, RES_X)
layer.lowering_test(inputs)

Input size: 	 torch.Size([8, 3, 32, 32])
Weight size: 	 torch.Size([64, 3, 3, 3])
Output size: 	 torch.Size([8, 64, 32, 32])
Correctness: 	 True 



In [19]:
# Tiling test: same lowering, but the matrix multiply runs tile by tile.
# NOTE: with these shapes the lowered multiply is (64 x 27) @ (27 x 8192), so a
# tile size of 4 means about 16 * 2048 * 7 = 229k Python-level tmul calls; expect a wait.
lowering, tiling, tile_size = True, True, 4
layer.set_mode(lowering=lowering, tiling=tiling, tile_size=tile_size)

inputs = torch.randn(BS, I_CHN, RES_Y, RES_X)
layer.lowering_test(inputs)

Input size: 	 torch.Size([8, 3, 32, 32])
Weight size: 	 torch.Size([64, 3, 3, 3])
Output size: 	 torch.Size([8, 64, 32, 32])
Correctness: 	 True 

