# CP Forward and Backward Propagation for Linear layer

Import libraries

In [1]:
import torch
import tltorch

SEED = 0
torch.manual_seed(SEED)  # make the random initialization and inputs below reproducible on re-run
DTYPE = torch.float64



Given a matrix
$\mathbf{W} \in \mathbb{R}^{M \times N}$ where $M$ = `in_features` and $N$ = `out_features`

In [3]:
# M (rows of W, = in_features) and N (columns of W, = out_features)
in_features, out_features = 256, 128

Tensorized into $\mathcal{W} \in \mathbb{R}^{m_1 \times m_2 \times \cdots \times m_D \times n_1 \times n_2 \times \cdots \times n_D}$ where $M = \prod_{d=1}^D m_d$ and $N = \prod_{d=1}^D n_d$, $D$=`order`, `in_tensorized_shape` = $(m_1, m_2, ..., m_D)$ and `out_tensorized_shape` = $(n_1, n_2, ..., n_D)$

In [4]:
order = 3  # D: number of sub-dimensions each of M and N is factored into
tensorized_shape = tltorch.utils.get_tensorized_shape(in_features, out_features, order=order)

Tensorizing (in, out)=((256, 128)) -> (((4, 4, 16), (4, 4, 8)))


CP decomposition of $\mathcal{W}$ is given by:

$\mathcal{W} = \sum_{r=1}^R \mathbf{gm}_1[:,r] \otimes \mathbf{gm}_2[:,r] \otimes \cdots \otimes \mathbf{gm}_D[:,r] \otimes \mathbf{gn}_1[:,r] \otimes \mathbf{gn}_2[:,r] \otimes \cdots \otimes \mathbf{gn}_D[:,r]$ where $R$=`rank`, $\mathbf{gm}_d \in \mathbb{R}^{m_d \times R}$ for $d = 1, \dots, D$ are the `in_factors`, and $\mathbf{gn}_d \in \mathbb{R}^{n_d \times R}$ for $d = 1, \dots, D$ are the `out_factors` 

In [10]:
rank = 100
cp_tensor = tltorch.TensorizedTensor.new(
    tensorized_shape,
    rank,
    factorization='CP',
    dtype=DTYPE,
)
tltorch.tensor_init(cp_tensor)

# Summary of the CP tensor first, then its individual factor matrices
for summary in (cp_tensor, cp_tensor.factors):
    print(summary)

CPTensorized, shape=[256, 128], tensorized_shape=((4, 4, 16), (4, 4, 8)), rank=100)
FactorList(
    (factor_0): Parameter containing: [torch.DoubleTensor of size 4x100]
    (factor_1): Parameter containing: [torch.DoubleTensor of size 4x100]
    (factor_2): Parameter containing: [torch.DoubleTensor of size 16x100]
    (factor_3): Parameter containing: [torch.DoubleTensor of size 4x100]
    (factor_4): Parameter containing: [torch.DoubleTensor of size 4x100]
    (factor_5): Parameter containing: [torch.DoubleTensor of size 8x100]
)


From the factors we can reconstruct the tensor $\mathcal{W}$ or the matrix $\mathbf{W}$

In [9]:
# Full tensorized W first, then its matricized (in_features x out_features) form
for reconstruction in (cp_tensor.to_tensor(), cp_tensor.to_matrix()):
    print(reconstruction.shape)

torch.Size([4, 4, 16, 4, 4, 8])
torch.Size([256, 128])




A typical linear layer involves $\mathbf{Y=XW}$ where $\mathbf{X} \in \mathbb{R}^{B \times M}$ and $\mathbf{Y} \in \mathbb{R}^{B \times N}$ and $B$ = `batch_size`

A tensorized linear layer will equivalently do $\mathcal{Y=XW}$ where $\mathcal{X} \in \mathbb{R}^{B \times m_1 \times m_2 \times \cdots \times m_D}$ and $\mathcal{Y} \in \mathbb{R}^{B \times n_1 \times n_2 \times \cdots \times n_D}$

Instead of reconstructing the tensor or the matrix we can do **factorized forward propagation** with CP factors that is given by:

$\mathcal{Y}=\sum_{r=1}^R (\mathcal{X} \times_1 \mathbf{gm}_1[:,r] \times_2 \mathbf{gm}_2[:,r] \times_3 \cdots \times_D \mathbf{gm}_D[:,r]) \otimes \mathbf{gn}_1[:,r] \otimes \mathbf{gn}_2[:,r] \otimes \cdots \otimes \mathbf{gn}_D[:,r]$ where $\times_d$ is $d$-mode product which is dot product along dimension $d$

In [11]:
def cp_times_matrix_fwd(tensor, matrix):
    """
    Multiply an input matrix by a CP-tensorized weight matrix without reconstructing it.

    Computes ``matrix @ W`` (``W`` being the matrix represented by ``tensor``) by
    contracting the tensorized input with the CP factors one mode at a time.

    Parameters
    ----------
    tensor : CP tensorized matrix (e.g. a ``tltorch`` CP ``TensorizedTensor``)
        ``tensorized_shape`` is ``((m_1, ..., m_D), (n_1, ..., n_D))``; ``factors``
        holds the ``D`` input factors followed by the ``D`` output factors, each of
        shape ``(size, rank)``; ``shape[1]`` is ``N``. If ``tensor`` has a
        ``weights`` attribute (the CP weights, one per rank component), every
        rank-one term is scaled by it; otherwise unit weights are used.
    matrix : torch.Tensor of shape ``(batch, M)``

    Returns
    -------
    output : torch.Tensor of shape ``(batch, N)``
    saved_tensors : list of torch.Tensor
        The tensorized input followed by the result of every factor contraction
        except the last one; consumed by ``cp_times_matrix_bwd``.
    """
    order = len(tensor.tensorized_shape[0])
    batch_size = matrix.shape[0]
    weights = getattr(tensor, 'weights', None)
    saved_tensors = []

    # tensorize the input: (batch, M) -> (batch, m_1, ..., m_D)
    output = matrix.reshape((batch_size,) + tuple(tensor.tensorized_shape[0]))
    saved_tensors.append(output)

    # contract the input modes one at a time; the rank axis r is kept last
    output = torch.einsum('na...,ar->n...r', output, tensor.factors[0])
    saved_tensors.append(output)
    for factor in tensor.factors[1:order]:
        output = torch.einsum('na...r,ar->n...r', output, factor)
        saved_tensors.append(output)

    # expand the output modes; every output factor except the last keeps r
    for factor in tensor.factors[order:len(tensor.factors) - 1]:
        output = torch.einsum('n...r,ar->n...ar', output, factor)
        saved_tensors.append(output)

    # the last output factor also sums over r, so the CP weights are folded in here
    last_factor = tensor.factors[-1]
    if weights is None:
        output = torch.einsum('n...r,ar->n...a', output, last_factor)
    else:
        output = torch.einsum('n...r,ar,r->n...a', output, last_factor, weights)

    # vectorize the output: (batch, n_1, ..., n_D) -> (batch, N)
    output = output.reshape((batch_size, tensor.shape[1]))

    return output, saved_tensors

In [13]:
def cp_times_matrix_bwd(tensor, grad, saved_tensors):
    """
    Backward pass of ``cp_times_matrix_fwd`` (``Y = X @ W`` with ``W`` in CP form).

    Parameters
    ----------
    tensor : the same CP tensorized matrix that was passed to ``cp_times_matrix_fwd``
        If it has a ``weights`` attribute (the CP weights), it is taken into account;
        otherwise unit weights are used.
    grad : torch.Tensor of shape ``(batch, N)``, the upstream gradient dL/dY
    saved_tensors : list returned by ``cp_times_matrix_fwd``

    Returns
    -------
    factor_grads : list of torch.Tensor
        dL/d(factor) for every factor, in the same order as ``tensor.factors``.
        No gradient is returned for the CP weights.
    grad : torch.Tensor of shape ``(batch, M)``, the gradient dL/dX
    """
    order = len(tensor.tensorized_shape[0])
    n_factors = len(tensor.factors)
    weights = getattr(tensor, 'weights', None)
    factor_grads = []

    # derivative of the output reshape: (batch, N) -> (batch, n_1, ..., n_D)
    grad = grad.reshape((grad.shape[0],) + tuple(tensor.tensorized_shape[1]))

    # derivatives for the last contraction 'n...r,ar(,r)->n...a'
    last_factor = tensor.factors[-1]
    if weights is None:
        factor_grads.append(torch.einsum('...a,...r->ar', grad, saved_tensors[-1]))
        grad = torch.einsum('...a,ar->...r', grad, last_factor)
    else:
        factor_grads.append(torch.einsum('...a,...r,r->ar', grad, saved_tensors[-1], weights))
        grad = torch.einsum('...a,ar,r->...r', grad, last_factor, weights)

    # derivatives for the remaining output-mode expansions 'n...r,ar->n...ar'
    for factor, saved_tensor in zip(reversed(tensor.factors[order:n_factors - 1]),
                                    reversed(saved_tensors[order:n_factors - 1])):
        factor_grads.append(torch.einsum('...ar,...r->ar', grad, saved_tensor))
        grad = torch.einsum('...ar,ar->...r', grad, factor)

    # derivatives for the input-mode contractions 'na...r,ar->n...r'
    for factor, saved_tensor in zip(reversed(tensor.factors[1:order]),
                                    reversed(saved_tensors[1:order])):
        factor_grads.append(torch.einsum('n...r,na...r->ar', grad, saved_tensor))
        grad = torch.einsum('n...r,ar->na...r', grad, factor)

    # derivatives for the first contraction 'na...,ar->n...r'
    factor_grads.append(torch.einsum('n...r,na...->ar', grad, saved_tensors[0]))
    grad = torch.einsum('n...r,ar->na...', grad, tensor.factors[0])

    # derivative of the input reshape: back to (batch, M)
    grad = grad.reshape((saved_tensors[0].shape[0], tensor.shape[0]))

    # gradients were collected from the last factor backwards; restore tensor.factors order
    return factor_grads[::-1], grad

Let's check if the **factorized forward propagation** is working properly

In [12]:
batch_size = 32
X = torch.randn((batch_size, in_features), dtype=DTYPE, requires_grad=True)

# Reference path: reconstruct the dense weight matrix and use an ordinary matmul
W = cp_tensor.to_matrix()
standard_fwd = torch.matmul(X, W)

# Factorized path, without autograd; its intermediates feed the manual backward below
with torch.no_grad():
    factorized_fwd, saved_tensors = cp_times_matrix_fwd(cp_tensor, X)

fwd_matches = torch.allclose(standard_fwd, factorized_fwd)
print(fwd_matches)

True


Let's check if the **factorized backward propagation** is working properly

In [15]:
dy = torch.randn_like(standard_fwd)
standard_fwd.backward(dy)  # autograd reference: fills .grad of the factors and of X

with torch.no_grad():
    factor_grad, dx = cp_times_matrix_bwd(cp_tensor, dy, saved_tensors)

# one line per factor (in cp_tensor.factors order), then one for the input gradient
for factor, manual_grad in zip(cp_tensor.factors, factor_grad):
    print(torch.allclose(factor.grad, manual_grad))

print(torch.allclose(dx, X.grad))

True
True
True
True
True
True
True


Let's compare the number of parameters in standard and **factorized linear**

In [19]:
# counts the entries of the factor matrices only
cp_param_count = sum(factor.numel() for factor in cp_tensor.factors)
dense_param_count = W.numel()
print('The number of parameters in CP format: {}'.format(cp_param_count))
print('The number of parameters in matrix format: {}'.format(dense_param_count))

The number of parameters in CP format: 4000
The number of parameters in matrix format: 32768


# CP Forward and Backward Propagation for CONV layer

In [1]:
import torch
import tltorch
import torch.nn.functional as F

SEED = 0
torch.manual_seed(SEED)  # make the random initialization and inputs below reproducible on re-run
DTYPE = torch.float64



A convolutional kernel is a 4 dimensional tensor of $\mathcal{W} \in \mathbb{R}^{C_o \times C_i \times K_h \times K_w}$ where $C_o$=`out_channels`, $C_i$=`in_channels`, $K_h$=`kernel_size_h`, and $K_w$=`kernel_size_w` 

$\mathcal{W}$ can be reshaped (tensorized) into a tensor $\mathcal{W}' \in \mathbb{R}^{{C_o}_1 \times {C_o}_2 \times \cdots \times {C_o}_D \times {C_i}_1 \times {C_i}_2 \times \cdots \times {C_i}_D \times K_h \times K_w}$ where $C_o = \prod_{d=1}^D {C_o}_d$, $C_i = \prod_{d=1}^D {C_i}_d$ and $D$=`order`

In [27]:
out_channels, in_channels = 32, 16
kernel_size_h = kernel_size_w = 3
order = 2

# split each channel dimension into `order` sub-dimensions; the kernel modes stay as one group
channel_shape = tltorch.utils.get_tensorized_shape(out_channels, in_channels, order, verbose=False)
tensorized_shape = channel_shape + ((kernel_size_h, kernel_size_w),)
print('Tensorized shape: {}'.format(tensorized_shape))

Tensorized shape: ((4, 8), (4, 4), (3, 3))


Define $\mathcal{W}$ with CP decomposition

$\mathcal{W} = \sum_{r=1}^R \mathbf{gn}_1[:,r] \otimes \mathbf{gn}_2[:,r] \otimes \cdots \otimes \mathbf{gn}_D[:,r] \otimes \mathbf{gm}_1[:,r] \otimes \mathbf{gm}_2[:,r] \otimes \cdots \otimes \mathbf{gm}_D[:,r] \otimes \mathbf{kh}[:,r] \otimes \mathbf{kw}[:,r]$ where $R$=`rank`, $\mathbf{gn}_d \in \mathbb{R}^{{C_o}_d \times R}$ for $d = 1, \dots, D$ are the `out_factors`, $\mathbf{gm}_d \in \mathbb{R}^{{C_i}_d \times R}$ for $d = 1, \dots, D$ are the `in_factors`, and $\mathbf{kh} \in \mathbb{R}^{K_h \times R}$, $\mathbf{kw} \in \mathbb{R}^{K_w \times R}$ are the `kernel_factors`

In [31]:
rank = 100
tensor = tltorch.TensorizedTensor.new(
    tensorized_shape,
    rank,
    factorization='CP',
    dtype=DTYPE,
)
tltorch.tensor_init(tensor)

# Summary of the tensorized kernel first, then its individual factor matrices
for summary in (tensor, tensor.factors):
    print(summary)

CPTensorized, shape=[32, 16, 9], tensorized_shape=((4, 8), (4, 4), (3, 3)), rank=100)
FactorList(
    (factor_0): Parameter containing: [torch.DoubleTensor of size 4x100]
    (factor_1): Parameter containing: [torch.DoubleTensor of size 8x100]
    (factor_2): Parameter containing: [torch.DoubleTensor of size 4x100]
    (factor_3): Parameter containing: [torch.DoubleTensor of size 4x100]
    (factor_4): Parameter containing: [torch.DoubleTensor of size 3x100]
    (factor_5): Parameter containing: [torch.DoubleTensor of size 3x100]
)


Given an input tensor $\mathcal{X} \in \mathbb{R}^{C_i \times H \times W}$, its convolution with $\mathcal{W}$ is given by:

$\mathcal{Y}(o,h,w) = \sum_{i=1}^{C_i} \sum_{k_h=1}^{K_h} \sum_{k_w=1}^{K_w} \mathcal{W}[o,i,k_h,k_w]\, \mathcal{X}[i,h-k_h,w-k_w]\ \ \forall o \in \{1,\dots,C_o\},\ h \in \{1,\dots,H\},\ w \in \{1,\dots,W\}$

Its **factorized forward propagation** is given by:

$\mathcal{Y}(o_1,o_2,...,o_D,h,w) = \sum_{r=1}^R \left(\sum_{k_h=1}^{K_h} \sum_{k_w=1}^{K_w} \mathbf{kh}[k_h,r]\, \mathbf{kw}[k_w,r] \left(\mathcal{X}' \times_1 \mathbf{gm}_1[:,r] \times_2 \mathbf{gm}_2[:,r] \times_3 \cdots \times_D \mathbf{gm}_D[:,r]\right)[h-k_h,w-k_w]\right) \mathbf{gn}_1[o_1,r]\, \mathbf{gn}_2[o_2,r] \cdots \mathbf{gn}_D[o_D,r]$ 

where $\mathcal{X}' \in \mathbb{R}^{{C_i}_1 \times {C_i}_2 \times \cdots \times {C_i}_D \times H \times W}$ is $\mathcal{X}$ with its channel dimension reshaped (tensorized)

This can be seen as:
1. forward propagate with `in_factors`
2. convolve with `kernel_factors`
3. forward propagate with `out_factors`

In [32]:
def cp_conv_fwd(tensor, input_tensor, order=None):
    """
    2D convolution with a CP-tensorized kernel, computed factor by factor.

    Gives the same result as ``conv2d(input_tensor, kernel, padding='same')`` where
    ``kernel`` is the dense ``(C_o, C_i, K_h, K_w)`` kernel represented by ``tensor``.

    Parameters
    ----------
    tensor : CP tensorized kernel (e.g. a ``tltorch`` CP ``TensorizedTensor``)
        ``tensorized_shape`` is ``((C_o1, ..., C_oD), (C_i1, ..., C_iD), (K_h, K_w))``;
        ``factors`` are the output-channel factors, then the input-channel factors,
        then the ``K_h`` and ``K_w`` factors, each of shape ``(size, rank)``.
        If ``tensor`` has a ``weights`` attribute (the CP weights, one per rank
        component), every rank-one term is scaled by it; otherwise unit weights are used.
    input_tensor : torch.Tensor of shape ``(batch, C_i, H, W)``
    order : int, optional
        Number of sub-dimensions ``D`` per channel dimension. Defaults to
        ``len(tensor.tensorized_shape[0])``.

    Returns
    -------
    output : torch.Tensor of shape ``(batch, C_o, H, W)``
    saved_tensors : list of torch.Tensor
        Intermediate activations consumed by ``cp_conv_bwd``.
    """
    if order is None:
        order = len(tensor.tensorized_shape[0])
    kernel_h, kernel_w = tensor.tensorized_shape[-1]
    weights = getattr(tensor, 'weights', None)

    saved_tensors = []

    # tensorize the input channels: (N, C_i, H, W) -> (N, C_i1, ..., C_iD, H, W)
    output = input_tensor.reshape((input_tensor.shape[0],)
                                  + tuple(tensor.tensorized_shape[1])
                                  + tuple(input_tensor.shape[-2:]))
    saved_tensors.append(output)

    # contract the input-channel modes; the rank axis r sits right after the batch axis
    output = torch.einsum('na...xy,ar->nr...xy', output, tensor.factors[order])
    saved_tensors.append(output)
    for factor in tensor.factors[order + 1:-2]:
        output = torch.einsum('nra...xy,ar->nr...xy', output, factor)
        saved_tensors.append(output)

    # separable spatial convolution, one group per rank component:
    # first along H with the K_h factor, then along W with the K_w factor
    output = torch.nn.functional.conv2d(output,
                                        tensor.factors[-2].T.reshape(tensor.rank, 1, kernel_h, 1),
                                        padding='same',
                                        groups=tensor.rank)
    saved_tensors.append(output)

    output = torch.nn.functional.conv2d(output,
                                        tensor.factors[-1].T.reshape(tensor.rank, 1, 1, kernel_w),
                                        padding='same',
                                        groups=tensor.rank)
    saved_tensors.append(output)

    # expand the output-channel modes; every output factor except the last keeps r
    for factor in tensor.factors[:order - 1]:
        output = torch.einsum('nr...xy,ar->nr...axy', output, factor)
        saved_tensors.append(output)

    # the last output factor also sums over r, so the CP weights are folded in here
    last_factor = tensor.factors[order - 1]
    if weights is None:
        output = torch.einsum('nr...xy,ar->n...axy', output, last_factor)
    else:
        output = torch.einsum('nr...xy,ar,r->n...axy', output, last_factor, weights)

    # merge the output-channel modes: (N, C_o1, ..., C_oD, H, W) -> (N, C_o, H, W)
    output = output.reshape((output.shape[0], tensor.shape[0], output.shape[-2], output.shape[-1]))

    return output, saved_tensors

In [33]:
def cp_conv_bwd(tensor, dy, saved_tensors):
    """
    Backward pass of ``cp_conv_fwd``.

    Parameters
    ----------
    tensor : the same CP tensorized kernel that was passed to ``cp_conv_fwd``
        If it has a ``weights`` attribute (the CP weights), it is taken into account;
        otherwise unit weights are used.
    dy : torch.Tensor of shape ``(batch, C_o, H, W)``, the upstream gradient dL/dY
    saved_tensors : list returned by ``cp_conv_fwd``

    Returns
    -------
    factor_grads : list of torch.Tensor
        dL/d(factor) for every factor, in the same order as ``tensor.factors``.
        No gradient is returned for the CP weights.
    dy : torch.Tensor of shape ``(batch, C_i, H, W)``, the gradient dL/dX

    Notes
    -----
    The spatial part uses a symmetric padding of ``K // 2``, i.e. it assumes odd
    ``K_h`` and ``K_w`` (for which this matches the forward's ``padding='same'``).
    """
    order = len(tensor.tensorized_shape[0])
    kernel_h, kernel_w = tensor.tensorized_shape[-1]
    weights = getattr(tensor, 'weights', None)
    batch_size = saved_tensors[0].shape[0]
    spatial_shape = tuple(saved_tensors[0].shape[-2:])

    out_factor_grads = []

    # derivative of the output reshape: (N, C_o, H, W) -> (N, C_o1, ..., C_oD, H, W)
    dy = dy.reshape((batch_size,) + tuple(tensor.tensorized_shape[0]) + spatial_shape)

    # derivatives for the last output contraction 'nr...xy,ar(,r)->n...axy'
    last_factor = tensor.factors[order - 1]
    if weights is None:
        out_factor_grads.append(torch.einsum('n...axy,nr...xy->ar', dy, saved_tensors[-1]))
        dy = torch.einsum('n...axy,ar->nr...xy', dy, last_factor)
    else:
        out_factor_grads.append(torch.einsum('n...axy,nr...xy,r->ar', dy, saved_tensors[-1], weights))
        dy = torch.einsum('n...axy,ar,r->nr...xy', dy, last_factor, weights)

    # derivatives for the remaining output expansions 'nr...xy,ar->nr...axy'
    for factor, saved_tensor in zip(reversed(tensor.factors[:order - 1]),
                                    reversed(saved_tensors[-order:-1])):
        out_factor_grads.append(torch.einsum('nr...axy,nr...xy->ar', dy, saved_tensor))
        dy = torch.einsum('nr...axy,ar->nr...xy', dy, factor)

    factor_grads = []

    # K_w convolution. Its factor gradient is the per-rank cross-correlation of the
    # conv input with dy, computed as a grouped conv3d over the (batch, H, W) volume.
    pad = kernel_w // 2
    factor_grads.append(F.conv3d(torch.einsum('ncxy->cnxy', saved_tensors[-order - 1]).unsqueeze(0),
                                 torch.einsum('ncxy->cnxy', dy).unsqueeze(1),
                                 padding=(0, 0, pad),
                                 groups=tensor.rank).squeeze(0).reshape(tensor.rank, kernel_w).T)
    # its input gradient is the transposed convolution with the same (depthwise) weight
    dy = F.conv_transpose2d(dy,
                            tensor.factors[-1].T.reshape(tensor.rank, 1, 1, kernel_w),
                            padding=(0, pad),
                            groups=tensor.rank)

    # K_h convolution, same scheme along H
    pad = kernel_h // 2
    factor_grads.append(F.conv3d(torch.einsum('ncxy->cnxy', saved_tensors[-order - 2]).unsqueeze(0),
                                 torch.einsum('ncxy->cnxy', dy).unsqueeze(1),
                                 padding=(0, pad, 0),
                                 groups=tensor.rank).squeeze(0).reshape(tensor.rank, kernel_h).T)
    dy = F.conv_transpose2d(dy,
                            tensor.factors[-2].T.reshape(tensor.rank, 1, kernel_h, 1),
                            padding=(pad, 0),
                            groups=tensor.rank)

    # derivatives for the later input contractions 'nra...xy,ar->nr...xy'
    for factor, saved_tensor in zip(reversed(tensor.factors[order + 1:order * 2]),
                                    reversed(saved_tensors[1:-order - 2])):
        factor_grads.append(torch.einsum('nr...xy,nra...xy->ar', dy, saved_tensor))
        dy = torch.einsum('nr...xy,ar->nra...xy', dy, factor)

    # derivatives for the first input contraction 'na...xy,ar->nr...xy'
    factor_grads.append(torch.einsum('nr...xy,na...xy->ar', dy, saved_tensors[0]))
    dy = torch.einsum('nr...xy,ar->na...xy', dy, tensor.factors[order])

    # derivative of the input reshape: back to (N, C_i, H, W)
    dy = dy.reshape((batch_size, tensor.shape[1]) + spatial_shape)

    # gradients were collected from the last factor backwards; restore tensor.factors order
    factor_grads = out_factor_grads[::-1] + factor_grads[::-1]

    return factor_grads, dy

Let's check if the **factorized forward propagation** is working properly

In [34]:
batch_size = 1
x = 32
y = 32
input_tensor = torch.randn((batch_size, in_channels, x, y), dtype=DTYPE, requires_grad=True)

# Dense reference kernel of shape (C_o, C_i, K_h, K_w)
kernel = tensor.to_matrix().reshape(out_channels, in_channels, kernel_size_h, kernel_size_w)
standard_fwd = F.conv2d(input_tensor, kernel, bias=None, padding='same')

# Factorized path, without autograd; its intermediates feed the manual backward below
with torch.no_grad():
    factorized_fwd, saved_tensors = cp_conv_fwd(tensor, input_tensor, order)

print(torch.allclose(standard_fwd, factorized_fwd))

True


Let's check if the **factorized backward propagation** is working properly

In [35]:
dy = torch.randn_like(standard_fwd)
standard_fwd.backward(dy)  # autograd reference: fills .grad of the factors and of the input

# the manual backward needs no autograd graph (same as the linear-layer check)
with torch.no_grad():
    factor_grads, dx = cp_conv_bwd(tensor, dy, saved_tensors)

for i, grad in enumerate(factor_grads):
    print(torch.allclose(tensor.factors[i].grad, grad))
print(torch.allclose(dx, input_tensor.grad))

True
True
True
True
True
True
True


Let's compare the number of parameters in standard and **factorized convolution**

In [36]:
# counts the entries of the factor matrices only
cp_param_count = sum(factor.numel() for factor in tensor.factors)
dense_param_count = kernel.numel()
print('The number of parameters in CP format: {}'.format(cp_param_count))
print('The number of parameters in matrix format: {}'.format(dense_param_count))

The number of parameters in CP format: 2600
The number of parameters in matrix format: 4608
