In [1]:
import torch
import torch.nn as nn
import tltorch
import numpy as np
from tensor_fusion.model import AdaptiveRankFactorizedTextSubNet, SubNet
from tensor_fusion.fusion_layer import AdaptiveRankFusionLayer
# Use the GPU when one is present; fall back to CPU so the notebook still runs
# end-to-end on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
DTYPE = torch.float64
# Seed torch's global RNG so the randomly generated tensors in this notebook are
# reproducible across Restart & Run All.
SEED = 0
torch.manual_seed(SEED)



In [2]:
# Factorized text subnetwork. The three positional sizes are passed through
# unchanged — presumably (input, hidden, output) dims; TODO confirm against the
# AdaptiveRankFactorizedTextSubNet signature.
TextSubNet = AdaptiveRankFactorizedTextSubNet(
    300, 128, 64,
    dropout=0.0,
    prior_type='half_cauchy',
    eta=0.01,
    device=device,
    dtype=DTYPE,
)

0.01
0.01


In [3]:
# Random input batch; presumably laid out as (batch, seq, features) — TODO confirm.
a = torch.randn(4, 20, 300, device=device, dtype=DTYPE)
out = TextSubNet(a)

In [4]:
out.size()

torch.Size([4, 64])

In [2]:
# NOTE(review): AdaptiveRankFactorizedLinear is not imported in the imports cell
# (only AdaptiveRankFactorizedTextSubNet, SubNet and AdaptiveRankFusionLayer are),
# so this line only works if the name leaked in from another cell or session.
# TODO add the proper import to the imports cell.
layer = AdaptiveRankFactorizedLinear(
    300, 128,
    max_rank=10,
    tensor_type='CP',
    prior_type='half_cauchy',
    eta=0.01,
    device=device,
    dtype=DTYPE,
)

In [3]:
# A single random input row for the dense-vs-factorized comparison.
a = torch.randn(1, 300, device=device, dtype=DTYPE)

In [4]:
# Forward pass through the factorized linear layer (called via the module so any
# registered hooks run).
out = layer(a)

In [5]:
# Materialise the factorized weight as a dense (in_features, out_features) matrix
# so that `a @ W` reproduces the layer. The sizes come from the actual input and
# output rather than repeating the 300/128 used to build the layer, so this cell
# stays correct if those sizes change.
W = layer.weight_tensor.get_full().reshape(a.shape[-1], out.shape[-1])

In [6]:
# Dense reference result to compare against the factorized layer's output.
out_ = torch.matmul(a, W)

In [7]:
# Element-wise closeness with torch's default rtol/atol, collapsed to one Python bool.
torch.isclose(out_, out).all().item()

True

In [12]:
class AdaptiveRankFactorizedLSTM(nn.Module):
    """Single-layer, batch-first LSTM built from two factorized linear layers.

    At each step the gate pre-activations are ``layer_ih(x_t) + layer_hh(h_{t-1})``.
    They are split, in order, into the input (i), forget (f), cell-candidate (g)
    and output (o) chunks, each ``hidden_size`` wide. This is the same gate order
    as ``torch.nn.LSTM``.

    Args:
        input_size: number of features per time step of the input.
        hidden_size: number of features in the hidden and cell states.
        bias, dropout, max_rank, tensor_type, prior_type, eta, device, dtype:
            forwarded unchanged to both ``AdaptiveRankFactorizedLinear`` layers.
    """

    def __init__(self, input_size, hidden_size, bias=True, dropout=0.0,
                 max_rank=20, tensor_type='CP', prior_type='half_cauchy', eta=None,
                 device=None, dtype=None):
        super().__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size

        # Pass the factorization settings by keyword so they cannot be silently
        # mis-assigned if the sub-layer's positional order differs.
        factor_kwargs = dict(max_rank=max_rank, tensor_type=tensor_type,
                             prior_type=prior_type, eta=eta,
                             device=device, dtype=dtype)
        # Each projection emits all four gates at once, hence 4 * hidden_size.
        self.layer_ih = AdaptiveRankFactorizedLinear(input_size, hidden_size * 4,
                                                     bias, dropout, **factor_kwargs)
        self.layer_hh = AdaptiveRankFactorizedLinear(hidden_size, hidden_size * 4,
                                                     bias, dropout, **factor_kwargs)

    def forward(self, x, hx=None):
        """Run the LSTM over every time step of ``x``.

        Args:
            x: batch-first input of shape ``(batch, seq_len, input_size)``.
            hx: optional initial state ``(h_0, c_0)``, each of shape
                ``(batch, hidden_size)``. Zeros are used when omitted.

        Returns:
            ``(output, (h_n, c_n))``. ``output`` stacks the hidden state of every
            step into ``(batch, seq_len, hidden_size)``. ``h_n`` and ``c_n`` are
            the final hidden and cell states; they equal ``hx`` (or zeros) when
            ``seq_len == 0``.

        Raises:
            ValueError: if ``x`` is not 3-dimensional.
        """
        if x.dim() != 3:
            raise ValueError(
                f"expected batch-first input of shape (batch, seq_len, input_size), "
                f"got shape {tuple(x.shape)}")
        batch_size, seq_len = x.shape[0], x.shape[1]

        if hx is None:
            h = x.new_zeros((batch_size, self.hidden_size))
            c = x.new_zeros((batch_size, self.hidden_size))
        else:
            h, c = hx

        outputs = []
        # Iterate over the real sequence length; a fixed bound would truncate
        # longer inputs or index past the end of shorter ones.
        for t in range(seq_len):
            gates = self.layer_ih(x[:, t, :]) + self.layer_hh(h)
            i, f, g, o = torch.split(gates, self.hidden_size, dim=1)
            i, f, o = torch.sigmoid(i), torch.sigmoid(f), torch.sigmoid(o)
            g = torch.tanh(g)
            c = f * c + i * g
            h = o * torch.tanh(c)
            outputs.append(h)

        if outputs:
            output = torch.stack(outputs, dim=1)
        else:
            # torch.stack rejects an empty list; return an empty time axis instead.
            output = x.new_zeros((batch_size, 0, self.hidden_size))

        return output, (h, c)

In [13]:
rnn = AdaptiveRankFactorizedLSTM(
    input_size=300,
    hidden_size=128,
    eta=0.01,
    device=device,
    dtype=DTYPE,
)

In [14]:
# Batch-first LSTM input: 4 sequences of 20 steps with 300 features each.
a = torch.randn(4, 20, 300, device=device, dtype=DTYPE)

In [18]:
# Run the LSTM over the whole batch: per-step hidden states plus the final (h, c).
output, (h_n, c_n) = rnn(a)

In [19]:
output.size()

torch.Size([4, 20, 128])

In [20]:
h_n.size()

torch.Size([4, 128])

In [21]:
c_n.size()

torch.Size([4, 128])