In [1]:
import torch
import torch.nn as nn
import numpy as np
from abc import abstractmethod, ABC
from tensor_layers import truncated_normal
from tensor_layers import low_rank_tensors
from tensor_layers import TensorizedLinear
import torch.distributions as td

import tensorly as tl
tl.set_backend('pytorch')

In [28]:
class CP_Linear(nn.Module):
    '''
    Tensorized linear layer parameterised by a CP (CANDECOMP/PARAFAC)
    decomposition, with optional Bayesian rank adaptation.

    The CP tensor has modes ``dims = tuple(input_sizes) + (output_size,)`` and
    one factor matrix of shape ``(dims[i], init_rank)`` per mode, stored in
    ``self.factors``.

    When ``rank_adaptive`` is True the layer additionally builds
      * ``factor_distributions``: Normal distributions centred on the factors
        with a fixed scale of 1e-9 (the variance is not learned);
      * ``rank_parameter``: a ``(1, init_rank)`` tensor of per-component prior
        standard deviations, refreshed by ``update_rank_parameters``;
      * ``factor_prior_distributions``: zero-mean Normal priors on each factor
        whose scale is a broadcast view of ``rank_parameter``.

    Args:
        input_sizes: sizes of the tensorized input modes (tuple or list of ints).
        output_size: size of the output mode.
        init_rank: number of CP components.
        rank_adaptive: build the distributions and rank prior described above.
        prior_type: rank prior, ``'log_uniform'`` or ``'half_cauchy'``.
        em_stepsize: weight given to the new estimate in
            ``update_rank_parameters`` (1.0 replaces the old value outright).
        init_method: stored for reference; only the ``'nn'`` initialisation is
            implemented and it is always used.
        eta: half-Cauchy scale; required when ``prior_type == 'half_cauchy'``.

    NOTE(review): no ``forward`` is defined yet. ``rank_parameter`` and the
    distribution lists are plain attributes, so they are not part of
    ``state_dict`` and do not follow ``.to(device)``.
    '''

    def __init__(self, input_sizes, output_size, init_rank=20, rank_adaptive=True,
                 prior_type='log_uniform', em_stepsize=1.0, init_method='nn', eta=None):

        super(CP_Linear, self).__init__()

        # normalise to a tuple so that list inputs are accepted as well
        self.input_sizes = tuple(input_sizes)
        self.output_size = output_size

        self.dims = self.input_sizes + (output_size,)
        self.order = len(self.dims)
        self.init_rank = init_rank
        self.rank_adaptive = rank_adaptive
        self.prior_type = prior_type
        # both are read later by update_rank_parameters / get_rank_parameters_update
        self.em_stepsize = em_stepsize
        self.eta = eta
        self.init_method = init_method

        self._init_factors()

        if rank_adaptive:
            self._init_factor_distributions()
            self._init_rank_priors_and_prior_distributions()

    def _init_factors(self):
        '''
        Sample the CP factors using the 'nn' initialisation from Cole's
        implementation.

        The per-factor std is
            factor_stddev = (target_stddev / sqrt(init_rank)) ** (1 / order)
        with target_stddev = sqrt(2 / prod(input_sizes)). Each factor of shape
        (dims[i], init_rank) is sampled from a truncated normal with
        loc=0, scale=factor_stddev, a=-3*factor_stddev, b=3*factor_stddev.
        '''
        self.target_stddev = np.sqrt(2 / np.prod(self.input_sizes))

        self.factor_stddev = torch.pow(self.target_stddev /
                                       torch.sqrt(torch.tensor(1.0 * self.init_rank)),
                                       1.0 / self.order)

        factor_init_dist = truncated_normal.TruncatedNormal(loc=0.0,
                                                            scale=self.factor_stddev,
                                                            a=-3.0 * self.factor_stddev,
                                                            b=3.0 * self.factor_stddev)

        self.factors = nn.ParameterList(
            [nn.Parameter(factor_init_dist.sample([x, self.init_rank]))
             for x in self.dims])

    def _init_factor_distributions(self):
        '''
        Build one Independent(Normal) distribution per factor, centred on the
        factor with a fixed scale of 1e-9 (learned_scale == False in Cole's
        implementation, i.e. the variance is not learned).
        '''
        factor_scale_multiplier = 1e-9

        # (dims[i], init_rank) tensors filled with factor_scale_multiplier
        self.factor_scales = [factor_scale_multiplier * torch.ones(factor.shape)
                              for factor in self.factors]

        # loc is (a broadcast view of) the factor parameter itself;
        # reinterpreted_batch_ndims=2 treats the whole matrix as one event
        self.factor_distributions = [
            td.Independent(td.Normal(loc=factor, scale=factor_scale),
                           reinterpreted_batch_ndims=2)
            for factor, factor_scale in zip(self.factors, self.factor_scales)]

    def _init_rank_priors_and_prior_distributions(self):
        '''
        Initialise ``rank_parameter`` (shape (1, init_rank)) from the current
        factors and build a zero-mean Normal prior per factor whose scale is
        ``rank_parameter`` broadcast over the factor's rows.
        '''
        self.rank_parameter = torch.sqrt(
            self.get_rank_parameters_update().clone().detach()).view([1, self.init_rank])

        # Normal broadcasts scale against loc, so each prior's scale is a view
        # of rank_parameter; in-place updates to rank_parameter propagate here.
        self.factor_prior_distributions = []
        for x in self.dims:
            zero_mean = torch.zeros([x, self.init_rank])
            base_dist = td.Normal(loc=zero_mean, scale=self.rank_parameter)
            independent_dist = td.Independent(base_dist, reinterpreted_batch_ndims=2)
            self.factor_prior_distributions.append(independent_dist)

    def _expected_squared_column_norms(self):
        '''
        Return M of shape (init_rank,): for every CP component r, the sum over
        all factors and rows of mean**2 + stddev**2 of factor_distributions.
        '''
        per_factor = [torch.sum(torch.square(dist.mean) + torch.square(dist.stddev), dim=0)
                      for dist in self.factor_distributions]
        return torch.sum(torch.stack(per_factor), dim=0)

    def get_rank_parameters_update(self):
        '''
        Closed-form estimate of the per-component prior variance.

        With M = ``_expected_squared_column_norms()`` and D = sum(dims):
          * log_uniform:  M / (D + 1)
          * half_cauchy:  (M - D*eta^2 + sqrt(M^2 + (2D + 8)*eta^2*M + eta^4*D^2)) / (2D + 4)

        Returns:
            Tensor of shape (init_rank,).

        Raises:
            ValueError: if ``prior_type`` is not supported, or if ``eta`` is
                None for the half-Cauchy prior.
        '''
        if self.prior_type == 'log_uniform':
            M = self._expected_squared_column_norms()
            D = float(sum(self.dims) + 1)
            return M / D

        if self.prior_type == 'half_cauchy':
            if self.eta is None:
                raise ValueError("eta must be provided for the half_cauchy prior")
            M = self._expected_squared_column_norms()
            D = float(sum(self.dims))
            eta_sq = float(self.eta) ** 2
            radicand = torch.square(M) + (2.0 * D + 8.0) * eta_sq * M + eta_sq ** 2 * D ** 2
            return (M - D * eta_sq + torch.sqrt(radicand)) / (2.0 * D + 4.0)

        raise ValueError("Prior type not supported")

    def update_rank_parameters(self):
        '''
        EM-style update of ``rank_parameter``:
            rank_parameter <- sqrt((1 - em_stepsize) * rank_parameter**2
                                   + em_stepsize * get_rank_parameters_update())
        '''
        with torch.no_grad():
            rank_update = self.get_rank_parameters_update()
            new_variance = ((1 - self.em_stepsize) * self.rank_parameter.data ** 2
                            + self.em_stepsize * rank_update)
            # Copy in place (via .data, so autograd's version counter is not
            # bumped). The prior distributions hold views of rank_parameter
            # and must see the new value.
            self.rank_parameter.data.copy_(
                torch.sqrt(new_variance).to(self.rank_parameter.device))

    def get_rank_variance(self):
        '''Return relu(rank_parameter) ** 2, shape (1, init_rank).'''
        return torch.square(torch.relu(self.rank_parameter))

    def prune_ranks(self, threshold=1e-5):
        '''
        Return a boolean mask of shape (1, init_rank) that is True for CP
        components whose rank variance is below ``threshold``.

        TODO: actually remove the masked columns from every factor and rebuild
        the factor distributions, rank prior and prior distributions. At the
        moment nothing is modified.
        '''
        return self.get_rank_variance() < threshold
        

In [29]:
# Seed torch's global RNG so that the sampled factors (and hence the
# rank_parameter shown below) are reproducible on Restart & Run All.
# NOTE(review): assumes TruncatedNormal samples from torch's global generator — confirm.
torch.manual_seed(0)
layer = CP_Linear(input_sizes=(10, 20, 30), output_size=3)

In [30]:
# one factor of shape (dims[i], init_rank) per mode, with dims = input_sizes + (output_size,)
assert [tuple(f.shape) for f in layer.factors] == [(d, layer.init_rank) for d in layer.dims]
layer.factors

ParameterList(
    (0): Parameter containing: [torch.FloatTensor of size 10x20]
    (1): Parameter containing: [torch.FloatTensor of size 20x20]
    (2): Parameter containing: [torch.FloatTensor of size 30x20]
    (3): Parameter containing: [torch.FloatTensor of size 3x20]
)

In [35]:
# one prior standard deviation per CP component, shaped (1, init_rank) so it broadcasts over factor rows
assert tuple(layer.rank_parameter.shape) == (1, layer.init_rank)
layer.rank_parameter

tensor([[0.2780, 0.2373, 0.2650, 0.2288, 0.2519, 0.2620, 0.2554, 0.2148, 0.2193,
         0.2417, 0.2414, 0.2323, 0.2310, 0.2832, 0.2792, 0.2771, 0.2530, 0.2779,
         0.2337, 0.2838]])

In [40]:
# Each factor prior's scale is rank_parameter broadcast over the factor's rows,
# so every row of the displayed tensor repeats rank_parameter.
prior_std = layer.factor_prior_distributions[0].stddev
assert torch.allclose(prior_std, layer.rank_parameter.expand_as(prior_std))
prior_std

tensor([[0.2780, 0.2373, 0.2650, 0.2288, 0.2519, 0.2620, 0.2554, 0.2148, 0.2193,
         0.2417, 0.2414, 0.2323, 0.2310, 0.2832, 0.2792, 0.2771, 0.2530, 0.2779,
         0.2337, 0.2838],
        [0.2780, 0.2373, 0.2650, 0.2288, 0.2519, 0.2620, 0.2554, 0.2148, 0.2193,
         0.2417, 0.2414, 0.2323, 0.2310, 0.2832, 0.2792, 0.2771, 0.2530, 0.2779,
         0.2337, 0.2838],
        [0.2780, 0.2373, 0.2650, 0.2288, 0.2519, 0.2620, 0.2554, 0.2148, 0.2193,
         0.2417, 0.2414, 0.2323, 0.2310, 0.2832, 0.2792, 0.2771, 0.2530, 0.2779,
         0.2337, 0.2838],
        [0.2780, 0.2373, 0.2650, 0.2288, 0.2519, 0.2620, 0.2554, 0.2148, 0.2193,
         0.2417, 0.2414, 0.2323, 0.2310, 0.2832, 0.2792, 0.2771, 0.2530, 0.2779,
         0.2337, 0.2838],
        [0.2780, 0.2373, 0.2650, 0.2288, 0.2519, 0.2620, 0.2554, 0.2148, 0.2193,
         0.2417, 0.2414, 0.2323, 0.2310, 0.2832, 0.2792, 0.2771, 0.2530, 0.2779,
         0.2337, 0.2838],
        [0.2780, 0.2373, 0.2650, 0.2288, 0.2519, 0.2620, 0.2