In [3]:
import torch
import torch.nn as nn

import tltorch

DTYPE = torch.float64

A PyTorch implementation of a low-rank tensor fusion layer, whose weight tensor is stored in CP (CANDECOMP/PARAFAC) form, looks like this.

In [23]:
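class LowRankFusion(nn.Module):
    """Fuses M input vectors through a CP-factorized weight tensor."""

    def __init__(self, input_sizes, output_size, rank, device=None, dtype=None):

        super().__init__()

        # Weight tensor of shape (s_1, ..., s_M, output_size), stored as CP factors
        tensorized_shape = tuple(input_sizes) + (output_size,)

        self.weight_tensor = tltorch.TensorizedTensor.new(tensorized_shape,
                                                          rank,
                                                          factorization='CP',
                                                          device=device,
                                                          dtype=dtype)  # honor the requested dtype
        tltorch.tensor_init(self.weight_tensor)

    def forward(self, inputs):

        # Project each input onto the rank components and multiply elementwise: (batch, rank)
        output = 1.0
        for x, factor in zip(inputs, self.weight_tensor.factors[:-1]):
            output = output * (x @ factor)

        # Map the fused rank components to the output space: (batch, output_size)
        output = output @ self.weight_tensor.factors[-1].T

        return output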

In [24]:
input_sizes = (16, 32, 64)
output_size = 32
rank = 10
fusion_layer = LowRankFusion(input_sizes, output_size, rank, dtype=DTYPE)

In [25]:
batch_size = 32
inputs = [torch.randn((batch_size, input_size), dtype=DTYPE) for input_size in input_sizes]

In [26]:
out = fusion_layer(inputs)

In [27]:
out.shape

torch.Size([32, 32])
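As a sanity check on `forward`, the sketch below (plain PyTorch without tltorch; the sizes and variable names are illustrative assumptions) reconstructs the full weight tensor from random CP factors with `torch.einsum` and confirms that contracting it with every input matches the factorized computation. Like `forward`, it treats all CP component weights as 1.

```python
import torch

torch.manual_seed(0)
dtype = torch.float64

# Small illustrative sizes (assumed, not the notebook's values)
input_sizes, output_size, rank, batch = (3, 4, 5), 6, 2, 7

# One CP factor per mode, each of shape (size, rank); the last one plays the role of U^(out)
factors = [torch.randn(s, rank, dtype=dtype) for s in input_sizes + (output_size,)]
inputs = [torch.randn(batch, s, dtype=dtype) for s in input_sizes]

# Factorized forward pass, mirroring LowRankFusion.forward
fused = torch.ones(batch, rank, dtype=dtype)
for x, U in zip(inputs, factors[:-1]):
    fused = fused * (x @ U)                 # (batch, rank)
out_factorized = fused @ factors[-1].T      # (batch, output_size)

# Full weight tensor: W[a, b, c, o] = sum_r U1[a, r] * U2[b, r] * U3[c, r] * Uout[o, r]
W = torch.einsum("ar,br,cr,or->abco", *factors)

# Contract W with every input: out[n, o] = sum_{a,b,c} x1[n, a] x2[n, b] x3[n, c] W[a, b, c, o]
out_full = torch.einsum("na,nb,nc,abco->no", *inputs, W)

print(W.shape)                                   # torch.Size([3, 4, 5, 6])
print(torch.allclose(out_factorized, out_full))  # True
```

For the notebook's sizes, the full weight tensor would hold 16·32·64·32 = 1,048,576 entries, while the CP factors hold (16 + 32 + 64 + 32)·10 = 1,440.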

The layer requires us to fix the rank in advance, but determining the CP rank of a tensor is NP-hard.


Instead, we can use the Bayesian rank-determination algorithm from https://arxiv.org/abs/2010.08689 to learn the rank during training.

The goal of Bayesian training here is maximum a posteriori (MAP) estimation: we look for the parameters that maximize the posterior distribution $P(\theta|\mathcal{D}) = \frac{P(\mathcal{D}|\theta)P(\theta)}{P(\mathcal{D})}$, where $\theta$ denotes the model parameters, $\mathcal{D}$ the dataset, $P(\mathcal{D}|\theta)$ the likelihood, and $P(\theta)$ the prior distribution over the parameters.

Since the evidence $P(\mathcal{D})$ does not depend on $\theta$, maximizing the numerator $P(\mathcal{D}|\theta)P(\theta)$ yields the same optimal $\theta$.
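A tiny numerical illustration, assuming a toy one-parameter model (data drawn from Normal($\theta$, 1), prior $\theta \sim$ Normal(0, 1)) evaluated on a grid: normalizing by $P(\mathcal{D})$ shifts every log-posterior value by the same constant, so the maximizer does not move.

```python
import torch
from torch.distributions import Normal

torch.manual_seed(0)
f64 = torch.float64

# Toy data (assumed): 20 draws from Normal(2, 1)
data = 2.0 + torch.randn(20, dtype=f64)
grid = torch.linspace(-5.0, 5.0, 10001, dtype=f64)  # candidate theta values, spacing 0.001

# Log numerator log P(D | theta) + log P(theta), evaluated at every grid point
log_lik = Normal(grid.unsqueeze(1), 1.0).log_prob(data).sum(dim=1)
log_prior = Normal(torch.tensor(0.0, dtype=f64), torch.tensor(1.0, dtype=f64)).log_prob(grid)
log_numerator = log_lik + log_prior

# Subtract a grid approximation of log P(D); any constant offset leaves the argmax unchanged
log_posterior = log_numerator - torch.logsumexp(log_numerator, dim=0)

theta_from_numerator = grid[log_numerator.argmax()].item()
theta_from_posterior = grid[log_posterior.argmax()].item()

# Closed-form MAP for this conjugate toy model: sum(data) / (n + 1)
theta_closed_form = (data.sum() / (len(data) + 1)).item()
print(theta_from_numerator, theta_from_posterior, theta_closed_form)
```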

Let's look at the parameters of an `AdaptiveRankFusion` layer, defined below.

The fusion weight is decomposed into CP factors $\Phi = [U^{(1)}, U^{(2)}, \dots, U^{(M)}, U^{(out)}]$, where $U^{(m)} \in \mathbb{R}^{s_m \times R}$ and $U^{(out)} \in \mathbb{R}^{s_{out} \times R}$. Here $s_m$ is the $m$-th entry of `input_sizes`, $s_{out}$ is `output_size`, and $R$ is `max_rank`.

We also have a `rank_parameter` $\lambda \in \mathbb{R}^{R}$ whose entry $\lambda_r$ learns the standard deviation of the $r$-th column, shared by every factor.
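A shape-level sketch of these parameters for the notebook's sizes, using plain tensors as assumed stand-ins for the tltorch factors. A scale vector of shape $(R,)$ broadcasts over the rows of each $s_m \times R$ factor, which is how a single $\lambda_r$ governs column $r$ of every factor.

```python
import torch
from torch.distributions import Normal

torch.manual_seed(0)

# Sizes from the notebook; max_rank plays the role of R
input_sizes, output_size, max_rank = (16, 32, 64), 32, 10

# Phi = [U^(1), ..., U^(M), U^(out)], each of shape (s_m, R)
factors = [torch.randn(s, max_rank) for s in input_sizes + (output_size,)]
shapes = [tuple(U.shape) for U in factors]
print(shapes)  # [(16, 10), (32, 10), (64, 10), (32, 10)]

# lambda: one standard deviation per rank component (column index r), kept positive
rank_parameter = torch.rand(max_rank).clamp(min=1e-5)

# Normal(0, lambda) with lambda of shape (R,) scores column r of a factor under Normal(0, lambda_r)
U = factors[0]
log_p = Normal(0.0, rank_parameter).log_prob(U)  # shape (16, 10)
r = 3
column_r = Normal(0.0, rank_parameter[r]).log_prob(U[:, r])
print(torch.allclose(log_p[:, r], column_r))  # True
```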

Therefore, writing $U^{(M+1)} = U^{(out)}$, the prior factorizes as $P(\theta)=P(\lambda)P(\Phi|\lambda)=\prod_{r=1}^R P(\lambda_r) \prod_{m=1}^{M+1} \prod_{i,j} P(U^{(m)}[i,j]\,|\,\lambda_j)$.
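Taking logs turns the product into a sum, so the log-prior is a sum of elementwise log-densities. The sketch below (small assumed sizes, $\eta = 0.01$) checks that a vectorized sum matches an explicit loop over every $\lambda_r$ and every factor entry $U^{(m)}[i,j]$.

```python
import math
import torch
from torch.distributions import HalfCauchy, Normal

torch.manual_seed(0)
f64 = torch.float64

# Small illustrative sizes (assumed): two input factors plus the output factor, R = 3
sizes, R, eta = (2, 3, 4), 3, 0.01
factors = [torch.randn(s, R, dtype=f64) for s in sizes]  # U^(1), U^(2), U^(out)
lam = torch.rand(R, dtype=f64).clamp(min=1e-5)           # lambda_1, ..., lambda_R

rank_prior = HalfCauchy(torch.tensor(eta, dtype=f64))
zero = torch.tensor(0.0, dtype=f64)

# Vectorized: Normal(0, lam) broadcasts lam[j] over column j of each factor
vectorized = rank_prior.log_prob(lam).sum()
for U in factors:
    vectorized = vectorized + Normal(zero, lam).log_prob(U).sum()

# Explicit: one log-density term per lambda_r and per factor entry U^(m)[i, j]
explicit = 0.0
for r in range(R):
    explicit += rank_prior.log_prob(lam[r]).item()
for U in factors:
    for i in range(U.shape[0]):
        for j in range(R):
            explicit += Normal(zero, lam[j]).log_prob(U[i, j]).item()

print(vectorized.item(), explicit)
```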

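In theory, if the $r$-th column of every factor is pushed close to zero, the $r$-th rank-one component (the outer product of these columns) contributes almost nothing to the weight tensor, so it can be pruned and the rank drops by 1. Because the component is an outer product, zeroing the column in a single factor would already remove it.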
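This can be checked numerically. In the sketch below (arbitrary assumed sizes, with `torch.einsum` standing in for CP reconstruction), zeroing column $r$ reproduces exactly the tensor built from the remaining $R - 1$ columns.

```python
import torch

torch.manual_seed(0)
f64 = torch.float64

# Illustrative third-order CP tensor with R = 4 components (assumed sizes)
sizes, R = (3, 4, 5), 4
factors = [torch.randn(s, R, dtype=f64) for s in sizes]


def cp_to_tensor(factors):
    # Sum over r of the outer product of the r-th columns of all three factors
    return torch.einsum("ar,br,cr->abc", *factors)


r = 2
keep = [j for j in range(R) if j != r]

# Zero the r-th column in every factor
zeroed = [U.clone() for U in factors]
for U in zeroed:
    U[:, r] = 0.0

# Rank R - 1 decomposition built from the remaining columns
pruned = [U[:, keep] for U in factors]
print(torch.allclose(cp_to_tensor(zeroed), cp_to_tensor(pruned)))  # True

# Zeroing the column in a single factor already removes the component
one_zeroed = [factors[0].clone(), factors[1], factors[2]]
one_zeroed[0][:, r] = 0.0
print(torch.allclose(cp_to_tensor(one_zeroed), cp_to_tensor(pruned)))  # True
```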

Since our goal is to shrink the effective rank below `max_rank`, we place a `rank_prior` on each $\lambda_r$ that favors values close to 0, such as HalfCauchy(0, $\eta$) or the (improper) log-uniform prior on $(0, \infty)$, $P(\lambda_r) \propto 1/\lambda_r$.

In other words, $P(\lambda_r)$ is higher when $\lambda_r$ is close to 0 and lower when $\lambda_r$ is far from 0.
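Evaluating the log-densities at increasing values of $\lambda_r$ shows this preference. The sketch uses `torch.distributions.HalfCauchy` with $\eta = 0.01$ (the default `eta` of `AdaptiveRankFusion`) and, for comparison, the unnormalized log-uniform log-density $-\log \lambda_r$.

```python
import torch
from torch.distributions import HalfCauchy

eta = 0.01
rank_prior = HalfCauchy(torch.tensor(eta))

lam = torch.tensor([1e-4, 1e-3, 1e-2, 1e-1, 1.0])
half_cauchy_log_p = rank_prior.log_prob(lam)

# Improper log-uniform prior: log P(lambda) = -log(lambda) + const
log_uniform_log_p = -torch.log(lam)

print(half_cauchy_log_p)
print(log_uniform_log_p)
```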

Finally, we give every entry in the $r$-th column of each factor the `factor_prior` Normal(0, $\lambda_r$), so a small $\lambda_r$ strongly penalizes large values in the $r$-th column of every factor.
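A minimal sketch of this penalty, assuming 32-entry columns: the same Normal(0, $\lambda_r$) log-density is summed over a column with entries of order 1 and over one already shrunk toward 0, under a narrow ($\lambda_r = 0.01$) and a wide ($\lambda_r = 1$) prior. The helper `column_log_prior` is hypothetical.

```python
import torch
from torch.distributions import Normal

torch.manual_seed(0)
f64 = torch.float64


def column_log_prior(column, lam_r):
    # Hypothetical helper: sum of Normal(0, lam_r) log-densities over one factor column
    prior = Normal(torch.tensor(0.0, dtype=f64), torch.tensor(lam_r, dtype=f64))
    return prior.log_prob(column).sum().item()


large = torch.randn(32, dtype=f64)         # column with entries of order 1
small = 0.01 * torch.randn(32, dtype=f64)  # column already shrunk toward 0

narrow_large = column_log_prior(large, 0.01)  # small lambda_r, large entries: heavy penalty
narrow_small = column_log_prior(small, 0.01)  # small lambda_r, small entries
wide_large = column_log_prior(large, 1.0)     # large lambda_r, large entries: mild penalty
print(narrow_large, narrow_small, wide_large)
```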

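During training, we therefore maximize $\log P(\mathcal{D}|\theta) + \log P(\theta)$, which equals $\log P(\theta|\mathcal{D}) = \log P(\mathcal{D}|\theta) + \log P(\theta) - \log P(\mathcal{D})$ up to the constant $\log P(\mathcal{D})$.

The layer below therefore computes the log-prior term $\log P(\theta)$ in its `get_log_prior` method.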

In [1]:
from torch.distributions.half_cauchy import HalfCauchy
from torch.distributions.normal import Normal

class AdaptiveRankFusion(LowRankFusion):

    def __init__(self, input_sizes, output_size, max_rank, eta=0.01, device=None, dtype=None):

        super().__init__(input_sizes, output_size, max_rank, device, dtype)

        # lambda: one learnable standard deviation per rank component
        self.rank_parameter = nn.Parameter(torch.rand((max_rank,), device=device, dtype=dtype))
        # Half-Cauchy prior on lambda; the scale tensor matches the parameter's device and dtype
        self.rank_prior = HalfCauchy(torch.tensor(eta, device=device, dtype=dtype))

    def get_log_prior(self):

        # Keep every standard deviation strictly positive, as Normal requires
        with torch.no_grad():
            self.rank_parameter.clamp_(min=1e-5)

        # log P(lambda) = sum_r log P(lambda_r)
        log_prior = torch.sum(self.rank_prior.log_prob(self.rank_parameter))

        # log P(Phi | lambda): entries in column r of every factor ~ Normal(0, lambda_r)
        factor_prior = Normal(0, self.rank_parameter)
        for factor in self.weight_tensor.factors:
            log_prior = log_prior + torch.sum(factor_prior.log_prob(factor))

        return log_prior


In [None]:
fusion_layer = AdaptiveRankFusion(input_sizes, output_size, max_rank=rank, dtype=DTYPE)

In [None]:
fusion_layer.get_log_prior()
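A minimal training sketch, assuming a Gaussian likelihood with a known noise standard deviation. To stay runnable without tltorch it uses a hypothetical stand-in class `TinyAdaptiveFusion` that stores the CP factors in an `nn.ParameterList`; the synthetic data, learning rate, step count, and pruning threshold `1e-2` are all illustrative assumptions, not values from the paper.

```python
import torch
import torch.nn as nn
from torch.distributions import HalfCauchy, Normal

torch.manual_seed(0)
f64 = torch.float64


class TinyAdaptiveFusion(nn.Module):
    """Hypothetical tltorch-free stand-in for AdaptiveRankFusion."""

    def __init__(self, input_sizes, output_size, max_rank, eta=0.01):
        super().__init__()
        shapes = tuple(input_sizes) + (output_size,)
        # CP factors U^(1), ..., U^(M), U^(out), each of shape (s_m, max_rank)
        self.factors = nn.ParameterList(
            [nn.Parameter(0.1 * torch.randn(s, max_rank, dtype=f64)) for s in shapes])
        self.rank_parameter = nn.Parameter(torch.rand(max_rank, dtype=f64))
        self.rank_prior = HalfCauchy(torch.tensor(eta, dtype=f64))

    def forward(self, inputs):
        out = 1.0
        for x, U in zip(inputs, self.factors):  # zip stops after the M input factors
            out = out * (x @ U)
        return out @ self.factors[-1].T

    def get_log_prior(self):
        with torch.no_grad():
            self.rank_parameter.clamp_(min=1e-5)
        log_prior = self.rank_prior.log_prob(self.rank_parameter).sum()
        factor_prior = Normal(torch.tensor(0.0, dtype=f64), self.rank_parameter)
        for U in self.factors:
            log_prior = log_prior + factor_prior.log_prob(U).sum()
        return log_prior


# Synthetic data (assumed): targets from a random rank-2 fusion plus Gaussian noise
input_sizes, output_size, max_rank, n, noise_std = (4, 5, 6), 3, 8, 256, 0.1
inputs = [torch.randn(n, s, dtype=f64) for s in input_sizes]
true_factors = [torch.randn(s, 2, dtype=f64) for s in input_sizes + (output_size,)]
signal = torch.ones(n, 2, dtype=f64)
for x, U in zip(inputs, true_factors):
    signal = signal * (x @ U)
targets = signal @ true_factors[-1].T + noise_std * torch.randn(n, output_size, dtype=f64)

model = TinyAdaptiveFusion(input_sizes, output_size, max_rank)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
noise = Normal(torch.tensor(0.0, dtype=f64), torch.tensor(noise_std, dtype=f64))

for step in range(200):
    optimizer.zero_grad()
    # log P(D | theta): Gaussian likelihood with the (assumed known) noise std
    log_lik = noise.log_prob(targets - model(inputs)).sum()
    # Minimize -(log P(D | theta) + log P(theta)), the negative log posterior up to a constant
    loss = -(log_lik + model.get_log_prior())
    loss.backward()
    optimizer.step()

# Hypothetical pruning rule: count components whose std stays above a threshold
effective_rank = int((model.rank_parameter > 1e-2).sum())
print(loss.item(), effective_rank)
```

With the notebook's `AdaptiveRankFusion`, the loop body would call `fusion_layer(inputs)` and `fusion_layer.get_log_prior()` in the same way. When training on mini-batches, the batch log-likelihood is commonly rescaled by (dataset size / batch size) so that it stays balanced against the prior, which is counted once per dataset.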