In [94]:
import torch
import torch.nn as nn
import numpy as np
from abc import abstractmethod, ABC
from tensor_layers import truncated_normal
from tensor_layers import low_rank_tensors
from tensor_layers import TensorizedLinear
import torch.distributions as td

import tensorly as tl
tl.set_backend('pytorch')

In [113]:
class CP_Linear(nn.Module):
    '''
    Linear layer whose weight tensor is stored in CP (CANDECOMP/PARAFAC) format,
    with a zero-mean Normal prior whose per-component scale (``rank_parameter``)
    can be re-estimated via ``update_rank_parameters``.

    The full weight has shape ``input_sizes + (output_size,)`` and is represented
    by one factor matrix of shape ``(dim_i, max_rank)`` per mode.

    Args:
        input_sizes: sizes of the input modes (tuple or list).
        output_size: size of the output mode.
        max_rank: number of CP components.
        prior_type: 'log_uniform' or 'half_cauchy'; anything else raises ValueError.
        em_stepsize: blend factor used by ``update_rank_parameters``; 1.0 replaces
            the squared rank parameters with the new update outright.
        init_method: stored for reference; only the 'nn' initialisation is implemented.
        eta: scale of the half-Cauchy prior; required when ``prior_type == 'half_cauchy'``.

    NOTE(review): ``rank_parameter`` and ``factor_scales`` are plain tensors (not
    registered parameters/buffers), so ``Module.to(device)`` does not move them.
    '''

    def __init__(self, input_sizes, output_size, max_rank=20, prior_type='log_uniform',
                 em_stepsize=1.0, init_method='nn', eta=None):

        super(CP_Linear, self).__init__()

        self.input_sizes = input_sizes
        self.output_size = output_size

        # tuple() so list-valued input_sizes also work with the concatenation.
        self.dims = tuple(input_sizes) + (output_size,)
        self.order = len(self.dims)
        self.max_rank = max_rank
        self.prior_type = prior_type
        # Previously accepted but dropped, which made update_rank_parameters()
        # and the half-Cauchy prior fail with AttributeError.
        self.em_stepsize = em_stepsize
        self.init_method = init_method
        self.eta = eta

        self._init_factors()
        self._init_factor_distributions()
        self._init_rank_prior()

    def _init_factors(self):
        '''
        Sample the CP factor matrices (assuming the 'nn' init from Cole's implementation).

        Each factor has shape (dim_i, max_rank) and is drawn from a normal truncated
        at +/- 3 stddev, with stddev = (target_stddev / sqrt(max_rank))^(1 / order)
        and target_stddev = sqrt(2 / prod(input_sizes)).
        '''
        # Cast to a Python float so the arithmetic below is float-by-tensor
        # rather than a numpy scalar mixed with a torch tensor.
        self.target_stddev = float(np.sqrt(2.0 / np.prod(self.input_sizes)))

        # factor_stddev = {target_stddev / sqrt(max_rank)}^(1 / order)
        self.factor_stddev = torch.pow(self.target_stddev /
                                       torch.sqrt(torch.tensor(1.0 * self.max_rank)),
                                       1.0 / self.order)

        factor_init_dist = truncated_normal.TruncatedNormal(loc=0.0,
                                                            scale=self.factor_stddev,
                                                            a=-3.0 * self.factor_stddev,
                                                            b=3.0 * self.factor_stddev)

        self.factors = nn.ParameterList(
            [nn.Parameter(factor_init_dist.sample([x, self.max_rank]))
             for x in self.dims])

    def _init_factor_distributions(self):
        '''
        Wrap each factor in Independent(Normal(loc=factor, scale=1e-9), 2).

        Assumes learned_scale == False from Cole's implementation, i.e. the
        variance is fixed and not learned.
        '''
        factor_scale_multiplier = 1e-9

        self.factor_scales = [factor_scale_multiplier * torch.ones(factor.shape)
                              for factor in self.factors]

        self.factor_distributions = []
        for factor, factor_scale in zip(self.factors, self.factor_scales):
            self.factor_distributions.append(
                td.Independent(td.Normal(loc=factor, scale=factor_scale),
                               reinterpreted_batch_ndims=2))

    def _init_rank_prior(self):
        '''
        Initialise rank_parameter (shape (1, max_rank)) as the square root of the
        first rank update, and build one zero-mean Independent Normal prior per
        factor with that per-component scale.
        '''
        self.rank_parameter = torch.sqrt(
            self.get_rank_parameters_update().clone().detach()).view([1, self.max_rank])

        self.factor_prior_distributions = []
        for x in self.dims:
            zero_mean = torch.zeros([x, self.max_rank])
            base_dist = td.Normal(loc=zero_mean, scale=self.rank_parameter)
            self.factor_prior_distributions.append(
                td.Independent(base_dist, reinterpreted_batch_ndims=2))

    def _factor_second_moments(self):
        '''Per-component sum over all factors of mean^2 + stddev^2; shape (max_rank,).'''
        return torch.sum(torch.stack([torch.sum(torch.square(x.mean)
                                                + torch.square(x.stddev), dim=0)
                                      for x in self.factor_distributions]), dim=0)

    def get_rank_parameters_update(self):
        '''
        Return the new value for the squared rank parameters, shape (max_rank,).

        Raises:
            ValueError: if prior_type is unsupported, or if prior_type is
                'half_cauchy' and eta is None.
        '''
        if self.prior_type == 'log_uniform':
            M = self._factor_second_moments()
            D = 1.0 * (sum(self.dims) + 1.0)
            return M / D

        if self.prior_type == 'half_cauchy':
            if self.eta is None:
                raise ValueError("eta must be provided when prior_type is 'half_cauchy'")
            eta = float(self.eta)
            M = self._factor_second_moments()
            D = 1.0 * sum(self.dims)
            return (M - D * eta ** 2
                    + torch.sqrt(torch.square(M)
                                 + (2.0 * D + 8.0) * eta ** 2 * M
                                 + eta ** 4 * D ** 2)) / (2.0 * D + 4.0)

        raise ValueError("Prior type not supported")

    def update_rank_parameters(self):
        '''
        Blend the squared rank parameters toward the new update with weight
        em_stepsize, then store the square root back into rank_parameter.
        '''
        with torch.no_grad():
            rank_update = self.get_rank_parameters_update()
            new_value = torch.sqrt((1 - self.em_stepsize) * self.rank_parameter ** 2
                                   + self.em_stepsize * rank_update)
            # Updated in place (as before) so the prior distributions built from
            # this tensor keep seeing the current values.
            self.rank_parameter.copy_(new_value.to(self.rank_parameter.device))


In [114]:
# Reference layer from the project's tensor_layers package with a CP-format weight.
# NOTE(review): argument meaning (100 inputs, 10 outputs, weight tensorized as
# shape (10, 10, 10)) is assumed from the names — confirm against TensorizedLinear.
l = TensorizedLinear(100, 10, tensor_type='CP', shape=(10,10,10))

In [115]:
# Prior scales (rank_parameter) of the wrapped CP tensor: one value per CP
# component, shown below with shape (1, max_rank).
l.tensor.rank_parameter

Parameter containing:
tensor([[0.2609, 0.3344, 0.3115, 0.3825, 0.3885, 0.3099, 0.2874, 0.3719, 0.2414,
         0.2997, 0.3410, 0.3306, 0.2563, 0.3023, 0.2519, 0.2998, 0.3102, 0.3000,
         0.3021, 0.3274]])

In [112]:
# Stddev of the first factor's variational distribution; the output shows every
# entry equal to 1e-9. NOTE(review): consider `.unique()` or `.shape` to avoid
# dumping the full matrix.
l.tensor.factor_distributions[0].stddev

tensor([[1.0000e-09, 1.0000e-09, 1.0000e-09, 1.0000e-09, 1.0000e-09, 1.0000e-09,
         1.0000e-09, 1.0000e-09, 1.0000e-09, 1.0000e-09, 1.0000e-09, 1.0000e-09,
         1.0000e-09, 1.0000e-09, 1.0000e-09, 1.0000e-09, 1.0000e-09, 1.0000e-09,
         1.0000e-09, 1.0000e-09],
        [1.0000e-09, 1.0000e-09, 1.0000e-09, 1.0000e-09, 1.0000e-09, 1.0000e-09,
         1.0000e-09, 1.0000e-09, 1.0000e-09, 1.0000e-09, 1.0000e-09, 1.0000e-09,
         1.0000e-09, 1.0000e-09, 1.0000e-09, 1.0000e-09, 1.0000e-09, 1.0000e-09,
         1.0000e-09, 1.0000e-09],
        [1.0000e-09, 1.0000e-09, 1.0000e-09, 1.0000e-09, 1.0000e-09, 1.0000e-09,
         1.0000e-09, 1.0000e-09, 1.0000e-09, 1.0000e-09, 1.0000e-09, 1.0000e-09,
         1.0000e-09, 1.0000e-09, 1.0000e-09, 1.0000e-09, 1.0000e-09, 1.0000e-09,
         1.0000e-09, 1.0000e-09],
        [1.0000e-09, 1.0000e-09, 1.0000e-09, 1.0000e-09, 1.0000e-09, 1.0000e-09,
         1.0000e-09, 1.0000e-09, 1.0000e-09, 1.0000e-09, 1.0000e-09, 1.0000e-09,
       

In [108]:
# Notebook CP_Linear: input modes (10, 20, 30) plus an output mode of size 3, so
# the factors have shapes (10, R), (20, R), (30, R), (3, R) with R = max_rank = 20.
layer = CP_Linear((10, 20, 30), 3)

In [109]:
# The variational mean of each factor distribution is the factor Parameter itself.
layer.factor_distributions[0].mean

Parameter containing:
tensor([[-8.5084e-02, -4.6769e-01, -1.0849e-01, -2.4412e-01, -1.5046e-01,
          7.4886e-02,  4.7293e-01,  1.0978e-02,  1.0956e-01,  4.2085e-01,
          3.1141e-01, -1.5022e-01,  5.0437e-01,  2.1918e-01, -3.3649e-01,
         -6.7118e-02, -3.3188e-01,  7.3080e-02,  8.0571e-02,  1.1317e-02],
        [-1.0309e-01, -4.4172e-02,  4.7485e-01, -3.3470e-02,  1.3069e-01,
          1.4269e-02,  2.0047e-01, -6.8038e-02,  2.8591e-01,  6.9216e-02,
          1.2770e-02,  2.3743e-01, -3.7915e-01,  1.8009e-01,  1.3312e-02,
         -2.0657e-01,  4.0994e-02, -4.5407e-01,  1.7976e-01,  2.2039e-01],
        [ 6.3784e-01, -2.4571e-01, -5.0475e-01, -2.2820e-01, -6.1975e-02,
         -1.3783e-01,  1.0974e-01, -3.0980e-02, -6.3791e-01, -1.8114e-01,
         -5.2292e-01, -5.6264e-01,  1.5986e-01,  4.6694e-01, -3.6123e-01,
         -2.0816e-01,  4.5117e-01,  2.2959e-01,  7.4908e-03,  9.1826e-02],
        [ 2.8955e-01,  7.7498e-02,  6.2985e-02,  1.7664e-01,  3.1129e-02,
          1.1

In [110]:
# Each factor prior is a zero-mean Normal whose per-component scale is rank_parameter.
layer.factor_prior_distributions[0].mean

tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])

In [23]:
class Low_Rank_Tensor(nn.Module):
    '''
    Base class for a low-rank tensor with a prior on its rank parameters.

    The constructor stores the configuration and then calls the three hooks in
    order: ``_build_factors``, ``_build_factor_distributions``,
    ``_build_low_rank_prior``. Subclasses create their variables through
    ``add_variable`` so they are registered with this module.

    NOTE(review): the class does not use ABCMeta, so ``@abstractmethod`` is not
    enforced here; unimplemented hooks silently do nothing.
    '''

    def __init__(self, dims, max_rank, prior_type, em_stepsize, init_method,
                 target_stddev, eta, learned_scale):

        super(Low_Rank_Tensor, self).__init__()

        self.eps = 1e-12

        self.dims = dims
        self.order = len(dims)
        self.max_rank = max_rank
        self.prior_type = prior_type
        self.em_stepsize = em_stepsize
        self.init_method = init_method
        self.target_stddev = target_stddev
        # Previously accepted but dropped; priors such as half-Cauchy need it.
        self.eta = eta
        self.learned_scale = learned_scale

        # ParameterList (not a plain list) so every variable created through
        # add_variable is registered: visible to parameters(), optimizers,
        # state_dict() and Module.to().
        self.trainable_variables = nn.ParameterList()

        self._build_factors()
        self._build_factor_distributions()
        self._build_low_rank_prior()

    def add_variable(self, init_value, trainable=True):
        '''
        Create a Parameter from a detached copy of ``init_value``, register it
        in ``trainable_variables`` and return it.

        Args:
            init_value: tensor holding the initial value (it is cloned, not aliased).
            trainable: sets ``requires_grad`` on the new Parameter.
        '''
        new_variable = nn.Parameter(init_value.clone().detach(), requires_grad=trainable)
        self.trainable_variables.append(new_variable)

        return new_variable

    @abstractmethod
    def _build_factors(self):
        '''Create the factor variables (subclass hook).'''
        pass

    @abstractmethod
    def _build_factor_distributions(self):
        '''Create the variational distributions over the factors (subclass hook).'''
        pass

    @abstractmethod
    def _build_low_rank_prior(self):
        '''Create the rank parameters and factor prior distributions (subclass hook).'''
        pass
    
    

In [77]:
class CP(Low_Rank_Tensor):
    '''
    CP (CANDECOMP/PARAFAC) low-rank tensor with a prior on the component scales.

    A tensor of shape ``dims`` is represented by one factor matrix of shape
    ``(dim_i, max_rank)`` per mode. Each factor gets an Independent Normal
    variational distribution centred on the factor, and a zero-mean Normal prior
    whose per-component scale is ``rank_parameter`` (lambda).

    Supported ``prior_type`` values are 'log_uniform' and 'half_cauchy' (the
    latter requires ``eta``); anything else raises ValueError at construction.
    '''

    def __init__(self, dims, max_rank, prior_type, em_stepsize, init_method,
                 target_stddev, eta, learned_scale=True):
        # Set before the base constructor, because it calls _build_low_rank_prior(),
        # which needs eta for the half-Cauchy update. A plain (non-Parameter,
        # non-Module) attribute may be assigned before nn.Module.__init__ runs.
        self.eta = eta

        super().__init__(dims, max_rank, prior_type, em_stepsize, init_method,
                         target_stddev, eta, learned_scale)

    def _build_factors(self):
        '''
        Sample one (dim_i, max_rank) factor per mode and register it via add_variable.

        Assumes init_method == 'nn': entries come from a normal truncated at
        +/- 3 stddev with stddev = (target_stddev / sqrt(max_rank))^(1 / order).
        '''
        self.factor_stddev = torch.pow(self.target_stddev /
                                       torch.sqrt(torch.tensor(1.0 * self.max_rank)),
                                       1.0 / self.order)

        init_dist = truncated_normal.TruncatedNormal(loc=0.0, scale=self.factor_stddev,
                                                     a=-3.0 * self.factor_stddev,
                                                     b=3.0 * self.factor_stddev)

        # No separate CP weight vector is used.
        self.weights = None
        self.factors = [self.add_variable(init_dist.sample(sample_shape=[x, self.max_rank]))
                        for x in self.dims]

    def _build_factor_distributions(self):
        '''
        Build Independent(Normal(loc=factor_i, scale=factor_scale_i), 2) per factor.

        The scales start at 1e-9 and are registered via add_variable; they are
        trainable only when learned_scale is True.
        NOTE(review): the raw scale is used directly, so if it is trained nothing
        keeps it positive — confirm this is intended.
        '''
        factor_scale_multiplier = 1e-9

        self.factor_scales = [self.add_variable(factor_scale_multiplier *
                                                torch.ones(factor.shape),
                                                trainable=self.learned_scale)
                              for factor in self.factors]

        self.factor_distributions = [
            td.Independent(base_distribution=td.Normal(loc=factor, scale=factor_scale),
                           reinterpreted_batch_ndims=2)
            for factor, factor_scale in zip(self.factors, self.factor_scales)]

    def _build_low_rank_prior(self):
        '''
        Create rank_parameter (lambda, shape (1, max_rank), not trainable) as the
        square root of the first rank update, and one zero-mean Independent
        Normal prior per factor with that per-component scale.
        '''
        self.rank_parameter = self.add_variable(
            torch.sqrt(self.get_rank_parameters_update().clone().detach()).view([1, self.max_rank]),
            trainable=False)

        self.factor_prior_distributions = []
        for x in self.dims:
            zero_mean = torch.zeros([x, self.max_rank])
            base_dist = td.Normal(loc=zero_mean, scale=self.rank_parameter)
            self.factor_prior_distributions.append(
                td.Independent(base_dist, reinterpreted_batch_ndims=2))

    def get_rank_parameters_update(self):
        '''
        Return the new value for the squared rank parameters, shape (max_rank,).

        Raises:
            ValueError: if prior_type is unsupported, or if prior_type is
                'half_cauchy' and eta is None.
        '''
        def second_moments():
            # Per-component sum over all factors of mean^2 + stddev^2.
            return torch.sum(torch.stack([torch.sum(torch.square(x.mean) +
                                                    torch.square(x.stddev), dim=0)
                                          for x in self.factor_distributions]), dim=0)

        if self.prior_type == 'log_uniform':
            M = second_moments()
            D = 1.0 * (sum(self.dims) + 1.0)
            return M / D

        if self.prior_type == 'half_cauchy':
            if self.eta is None:
                raise ValueError("eta must be provided when prior_type is 'half_cauchy'")
            eta = float(self.eta)
            M = second_moments()
            D = 1.0 * sum(self.dims)
            return (M - D * eta ** 2
                    + torch.sqrt(torch.square(M)
                                 + (2.0 * D + 8.0) * eta ** 2 * M
                                 + eta ** 4 * D ** 2)) / (2.0 * D + 4.0)

        raise ValueError("Prior type not supported")

In [78]:
# NOTE(review): `Tensorized_Linear` is neither defined nor imported in this notebook
# (only `TensorizedLinear` is imported from tensor_layers), so this cell depends on
# hidden kernel state and raises NameError on a fresh kernel. TODO: define/import it.
# Also rebinds `layer`, replacing the CP_Linear instance built earlier.
layer = Tensorized_Linear((2, 2, 2), 1)

In [79]:
# One prior scale per CP component: the output shows shape (1, max_rank).
layer.tensor.rank_parameter.shape

torch.Size([1, 20])

In [80]:
# One variational distribution per mode; the output shows (dim_i, max_rank) shapes
# for modes (2, 2, 2) plus the size-1 output mode.
layer.tensor.factor_distributions

[Independent(Normal(loc: torch.Size([2, 20]), scale: torch.Size([2, 20])), 2),
 Independent(Normal(loc: torch.Size([2, 20]), scale: torch.Size([2, 20])), 2),
 Independent(Normal(loc: torch.Size([2, 20]), scale: torch.Size([2, 20])), 2),
 Independent(Normal(loc: torch.Size([1, 20]), scale: torch.Size([1, 20])), 2)]

In [81]:
# The priors have the same (dim_i, max_rank) shapes as the variational distributions above.
layer.tensor.factor_prior_distributions

[Independent(Normal(loc: torch.Size([2, 20]), scale: torch.Size([2, 20])), 2),
 Independent(Normal(loc: torch.Size([2, 20]), scale: torch.Size([2, 20])), 2),
 Independent(Normal(loc: torch.Size([2, 20]), scale: torch.Size([2, 20])), 2),
 Independent(Normal(loc: torch.Size([1, 20]), scale: torch.Size([1, 20])), 2)]