In [1]:
import torch
import tltorch

DTYPE = torch.float64



## Tensor Fusion

Typically, when inputs from multiple sources (**multimodal inputs**) are given, we concatenate them and feed the result through a linear layer to get a single vector. This is simple, but it does not match how humans combine information from different sources. We weigh the relationships between modalities when making a decision, whereas a linear layer on a concatenation can only add up separate per-modality contributions and therefore cannot capture intermodal interactions. To address this, the authors of https://arxiv.org/abs/1707.07250 proposed tensor fusion.
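A quick way to see the limitation of concatenation: a linear layer applied to $[x_1, x_2]$ splits column-wise into one linear map per modality, so its output is $W_1 x_1 + W_2 x_2 + b$ with no product terms between modalities. The sketch below checks this numerically in plain PyTorch. The sizes and weights are arbitrary stand-ins, not values from the paper.

```python
import torch

torch.manual_seed(0)
DTYPE = torch.float64

# Two hypothetical modalities for a single sample
x1 = torch.randn(4, dtype=DTYPE)
x2 = torch.randn(3, dtype=DTYPE)

# A linear layer applied to the concatenation [x1, x2]
W = torch.randn(5, 4 + 3, dtype=DTYPE)
b = torch.randn(5, dtype=DTYPE)
h_concat = W @ torch.cat([x1, x2]) + b

# The same map, split column-wise into one block per modality
W1, W2 = W[:, :4], W[:, 4:]
h_split = W1 @ x1 + W2 @ x2 + b

# The output is a sum of separate contributions: no x1 * x2 cross terms appear
print(torch.allclose(h_concat, h_split))  # True
```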

In `tensor_fusion_layer`, a constant 1 is appended to each of the **multimodal inputs**, and the resulting vectors are combined by an outer product into a `fusion_tensor`:

$\mathcal{Z} = \bigotimes_{m=1}^M [x_m, 1]$, where $\mathcal{Z}$ is the `fusion_tensor`, $M$ is the number of modalities, and $x_m \in \mathbb{R}^{s_m}$ is the input from modality $m$, with $s_m$ the corresponding entry of `input_sizes`. Hence $\mathcal{Z} \in \mathbb{R}^{(s_1+1) \times (s_2+1) \times \cdots \times (s_M+1)}$. Because of the appended 1, $\mathcal{Z}$ contains not only the $M$-way products of input features but also every lower-order product, down to each $x_m$ on its own.  
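For $M = 2$ the role of the appended 1 is easy to see, since $\mathcal{Z}$ is then the matrix $[x_1, 1][x_2, 1]^\top$: its top-left block is $x_1 x_2^\top$, its last column and last row hold $x_1$ and $x_2$, and its corner entry is 1. A small sketch with arbitrary random inputs:

```python
import torch

torch.manual_seed(0)
DTYPE = torch.float64

x1 = torch.randn(3, dtype=DTYPE)
x2 = torch.randn(2, dtype=DTYPE)
one = torch.ones(1, dtype=DTYPE)

z1 = torch.cat([x1, one])   # [x_1, 1], shape (4,)
z2 = torch.cat([x2, one])   # [x_2, 1], shape (3,)
Z = torch.outer(z1, z2)     # fusion tensor for M = 2, shape (4, 3)

# Block structure of Z
bimodal = Z[:-1, :-1]       # x1 x2^T: pairwise cross-modal interactions
unimodal_1 = Z[:-1, -1]     # x1 itself
unimodal_2 = Z[-1, :-1]     # x2 itself
constant = Z[-1, -1]        # 1

print(torch.allclose(bimodal, torch.outer(x1, x2)))                 # True
print(torch.allclose(unimodal_1, x1), torch.allclose(unimodal_2, x2))  # True True
print(constant.item())                                               # 1.0
```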

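Then the `fusion_tensor` is contracted with a `fusion_weight` $\mathcal{W} \in \mathbb{R}^{(s_1+1) \times (s_2+1) \times \cdots \times (s_M+1) \times s_{out}}$, where $s_{out}$ is `output_size`:  

$h = \mathcal{Z} \cdot \mathcal{W} \in \mathbb{R}^{s_{out}}$, where $h$ is the `fusion_output` and $\cdot$ sums over the $M$ modality modes shared by $\mathcal{Z}$ and $\mathcal{W}$, leaving only the output mode. In the code below, every input also carries a leading batch dimension.  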
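Since $\mathcal{Z}$ enters linearly, $h = \mathcal{Z} \cdot \mathcal{W}$ is an ordinary linear layer applied to the flattened fusion tensor, which has $\prod_{m}(s_m+1)$ features. The single-sample sketch below (no batch dimension, random stand-in data, sizes $64, 4, 16$ and $s_{out} = 10$) computes $h$ both as an einsum contraction and as a flatten-then-matmul:

```python
import math
import torch

torch.manual_seed(0)
DTYPE = torch.float64

input_sizes = [64, 4, 16]
output_size = 10

# One sample per modality, each with a 1 appended: z_m = [x_m, 1]
zs = [torch.cat([torch.randn(s, dtype=DTYPE), torch.ones(1, dtype=DTYPE)]) for s in input_sizes]

# Fusion tensor Z = z_1 ⊗ z_2 ⊗ z_3, shape (65, 5, 17)
Z = torch.einsum('i,j,k->ijk', *zs)

W = torch.randn([s + 1 for s in input_sizes] + [output_size], dtype=DTYPE)

# h = Z · W: contract every modality mode of Z with the matching mode of W
h = torch.einsum('ijk,ijko->o', Z, W)

# Equivalent view: a linear map on the flattened fusion tensor
n_features = math.prod(s + 1 for s in input_sizes)  # 65 * 5 * 17 = 5525
h_flat = Z.reshape(n_features) @ W.reshape(n_features, output_size)

print(h.shape, n_features, torch.allclose(h, h_flat))
```

With these sizes the dense weight already has 5525 × 10 = 55,250 entries, and that count grows multiplicatively with every added modality.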

In [2]:
batch_size = 32
input_sizes = [64, 4, 16]  # s_1, s_2, s_3: three modalities
inputs = [torch.randn((batch_size, input_size), dtype=DTYPE) for input_size in input_sizes]

output_size = 10  # s_out
# W has one mode of size s_m + 1 per modality, plus the output mode
fusion_weight_shape = [x+1 for x in input_sizes] + [output_size]
fusion_weight = torch.randn(fusion_weight_shape, dtype=DTYPE)

In [3]:
print('Input sizes: {}'.format(input_sizes))
print('Output size: {}'.format(output_size))
print('Fusion weight shape: {}'.format(fusion_weight.shape))

Input sizes: [64, 4, 16]
Output size: 10
Fusion weight shape: torch.Size([65, 5, 17, 10])


In [4]:
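def concatenate_one(inputs):
    # Append a column of ones to each (batch, s_m) input, giving (batch, s_m + 1);
    # the ones follow each input's own dtype and device
    return [torch.cat([x, torch.ones((x.shape[0], 1), dtype=x.dtype, device=x.device)], dim=1) for x in inputs]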

In [5]:
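def tensor_fusion_layer(inputs, fusion_weight):
    # z_m = [x_m, 1] for every modality, each of shape (batch, s_m + 1)
    concatenated_inputs = concatenate_one(inputs)

    # Batched outer product; afterwards fusion_tensor has shape
    # (batch, s_1 + 1, ..., s_M + 1)
    fusion_tensor = concatenated_inputs[0]
    for x in concatenated_inputs[1:]:
        fusion_tensor = torch.einsum('n...,na->n...a', fusion_tensor, x)

    # Contract all modality modes with fusion_weight, keeping the batch and output modes
    output = torch.einsum('n...,...o->no', fusion_tensor, fusion_weight)

    return output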

In [6]:
fusion_output = tensor_fusion_layer(inputs, fusion_weight)

In [7]:
print('Fusion output shape: {}'.format(fusion_output.shape))

Fusion output shape: torch.Size([32, 10])


## CP Tensor Fusion

Write $z_m = [x_m, 1]$ for $m = 1, \dots, M$, so that

$\mathcal{Z} = \bigotimes_{m=1}^M z_m$.

Suppose $\mathcal{W}$ has a rank-$R$ CP decomposition with factors

$[U^{(1)}, U^{(2)}, \dots, U^{(M)}, U^{(out)}]$

where $U^{(m)} \in \mathbb{R}^{(s_m+1) \times R}$ for $m = 1, \dots, M$ and $U^{(out)} \in \mathbb{R}^{s_{out} \times R}$ are the `factors`, and $R$ is the `rank`.

The `fusion_weight` is then reconstructed as a sum of $R$ rank-one tensors:

$\mathcal{W} = \sum_{r=1}^R \bigotimes_{m=1}^M U^{(m)}[:,r] \otimes U^{(out)}[:,r]$
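The reconstruction formula maps directly onto a single `torch.einsum` call in which the shared index `r` is summed. The sketch compares it with an explicit loop over the $R$ rank-one terms. The factors are random stand-ins with the shapes used later ($M = 3$, $R = 10$), not a trained decomposition:

```python
import torch

torch.manual_seed(0)
DTYPE = torch.float64

shape = [65, 5, 17, 10]   # (s_1+1, s_2+1, s_3+1, s_out)
rank = 10

# Stand-in CP factors U^(1), U^(2), U^(3), U^(out), each of shape (dim, R)
factors = [torch.randn(dim, rank, dtype=DTYPE) for dim in shape]

# W = sum_r U1[:, r] ⊗ U2[:, r] ⊗ U3[:, r] ⊗ Uout[:, r], written as one einsum
W = torch.einsum('ir,jr,kr,or->ijko', *factors)

# The same sum, accumulated one rank-one term at a time
W_loop = torch.zeros(shape, dtype=DTYPE)
for r in range(rank):
    u1, u2, u3, uo = (f[:, r] for f in factors)
    W_loop += torch.einsum('i,j,k,o->ijko', u1, u2, u3, uo)

print(W.shape, torch.allclose(W, W_loop))
```

The factors hold $R \left( \sum_m (s_m+1) + s_{out} \right) = 10 \cdot 97 = 970$ numbers, compared with 55,250 for the dense $\mathcal{W}$.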

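The computational cost of the tensor fusion layer can be reduced with the CP decomposition, because neither $\mathcal{Z}$ nor $\mathcal{W}$ has to be formed explicitly. The key step is that the inner product of two outer products factorizes, $\left\langle \bigotimes_{m=1}^M a_m, \bigotimes_{m=1}^M b_m \right\rangle = \prod_{m=1}^M a_m^\top b_m$, so each rank-one term collapses to a product of $M$ scalars:

$h = \mathcal{Z} \cdot \mathcal{W} = \left( \bigotimes_{m=1}^M z_m \right) \cdot \sum_{r=1}^R \left( \bigotimes_{m=1}^M U^{(m)}[:,r] \otimes U^{(out)}[:,r] \right) = \sum_{r=1}^R \left( \prod_{m=1}^M z_m^\top U^{(m)}[:,r] \right) U^{(out)}[:,r] = \left( \Lambda_{m=1}^M z_m^\top U^{(m)} \right) {U^{(out)}}^\top$  

where $\Lambda_{m=1}^M$ denotes the elementwise product of the $M$ row vectors $z_m^\top U^{(m)} \in \mathbb{R}^{1 \times R}$. Per sample, this costs on the order of $R \left( \sum_{m=1}^M (s_m+1) + s_{out} \right)$ multiplications, instead of $s_{out} \prod_{m=1}^M (s_m+1)$ for the contraction with the dense $\mathcal{W}$.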
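The derivation can be checked numerically for a single sample. In this sketch the $z_m$ are arbitrary random vectors rather than $[x_m, 1]$, which is fine because the identity holds for any vectors. The left side builds $\mathcal{Z}$ and $\mathcal{W}$ explicitly; the right side uses only the factors:

```python
import torch

torch.manual_seed(0)
DTYPE = torch.float64

dims = [65, 5, 17]          # s_m + 1 for each modality
output_size, rank = 10, 10

zs = [torch.randn(d, dtype=DTYPE) for d in dims]          # stand-ins for z_m
Us = [torch.randn(d, rank, dtype=DTYPE) for d in dims]    # U^(m), shape (s_m + 1, R)
U_out = torch.randn(output_size, rank, dtype=DTYPE)       # U^(out), shape (s_out, R)

# Left side: form Z and W explicitly, then contract
Z = torch.einsum('i,j,k->ijk', *zs)
W = torch.einsum('ir,jr,kr,or->ijko', *Us, U_out)
h_full = torch.einsum('ijk,ijko->o', Z, W)

# Right side: elementwise product of the R-vectors z_m^T U^(m), then multiply by U^(out)^T
proj = torch.ones(rank, dtype=DTYPE)
for z, U in zip(zs, Us):
    proj = proj * (z @ U)   # shape (R,)
h_cp = proj @ U_out.T       # shape (s_out,)

print(torch.allclose(h_full, h_cp))  # True
```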

In [8]:
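rank = 10
# A CP-factorized fusion weight with the same shape as the dense one above, built with tltorch
fusion_weight = tltorch.TensorizedTensor.new(fusion_weight_shape, rank, factorization='CP', dtype=DTYPE)
tltorch.tensor_init(fusion_weight)  # random initialization of the factorized weight
print(fusion_weight)
print(fusion_weight.factors)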

CPTensorized(shape=[65, 5, 17, 10], tensorized_shape=[65, 5, 17, 10], rank=10)
FactorList(
    (factor_0): Parameter containing: [torch.DoubleTensor of size 65x10]
    (factor_1): Parameter containing: [torch.DoubleTensor of size 5x10]
    (factor_2): Parameter containing: [torch.DoubleTensor of size 17x10]
    (factor_3): Parameter containing: [torch.DoubleTensor of size 10x10]
)


In [9]:
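def cp_tensor_fusion_layer(inputs, fusion_weight):
    # z_m = [x_m, 1], each of shape (batch, s_m + 1)
    concatenated_inputs = concatenate_one(inputs)

    # Elementwise product over modalities of z_m^T U^(m), each of shape (batch, R).
    # Only the factors are used, i.e. any CP weights are assumed to equal 1.
    fusion_output = 1.0
    for x, factor in zip(concatenated_inputs, fusion_weight.factors[:-1]):
        fusion_output = fusion_output * (x @ factor)
    # Multiply by U^(out)^T: (batch, R) -> (batch, s_out)
    fusion_output = fusion_output @ fusion_weight.factors[-1].T

    return fusion_output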
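`cp_tensor_fusion_layer` uses only `fusion_weight.factors`. Many CP implementations, for example TensorLy's `(weights, factors)` convention, also carry a weight vector $\lambda \in \mathbb{R}^R$, so that $\mathcal{W} = \sum_r \lambda_r \bigotimes_m U^{(m)}[:,r] \otimes U^{(out)}[:,r]$. If such weights are present and not all ones, they have to be folded into the elementwise product. The sketch below does this in plain PyTorch, independent of tltorch. `append_one`, `cp_fusion_weighted` and `full_fusion` are hypothetical helper names, and `full_fusion` is hard-coded to three modalities:

```python
import torch

torch.manual_seed(0)
DTYPE = torch.float64

def append_one(x):
    # [x, 1] for every row of a (batch, s_m) tensor, matching x's dtype and device
    return torch.cat([x, x.new_ones((x.shape[0], 1))], dim=1)

def cp_fusion_weighted(inputs, factors, weights):
    """CP tensor fusion with an explicit weight vector lambda of shape (R,).

    factors = [U^(1), ..., U^(M), U^(out)], each of shape (dim, R).
    """
    proj = weights                                 # (R,), broadcast over the batch
    for x, U in zip(inputs, factors[:-1]):
        proj = proj * (append_one(x) @ U)          # (batch, R)
    return proj @ factors[-1].T                    # (batch, s_out)

def full_fusion(inputs, W):
    # Dense reference for M = 3: build Z per sample, then contract with W
    z1, z2, z3 = (append_one(x) for x in inputs)
    Z = torch.einsum('ni,nj,nk->nijk', z1, z2, z3)
    return torch.einsum('nijk,ijko->no', Z, W)

batch_size, input_sizes, output_size, rank = 32, [64, 4, 16], 10, 10
inputs = [torch.randn(batch_size, s, dtype=DTYPE) for s in input_sizes]
factors = [torch.randn(s + 1, rank, dtype=DTYPE) for s in input_sizes]
factors.append(torch.randn(output_size, rank, dtype=DTYPE))
weights = torch.rand(rank, dtype=DTYPE)            # deliberately not all ones

# Dense W = sum_r lambda_r * U1[:, r] ⊗ U2[:, r] ⊗ U3[:, r] ⊗ Uout[:, r]
W = torch.einsum('r,ir,jr,kr,or->ijko', weights, *factors)

out_cp = cp_fusion_weighted(inputs, factors, weights)
out_full = full_fusion(inputs, W)
print(out_cp.shape, torch.allclose(out_cp, out_full))
```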

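Let's check that `cp_tensor_fusion_layer` gives the same output as the dense `tensor_fusion_layer` applied to the reconstructed weight, and compare their run times from a single call each: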

In [10]:
import time

tic = time.time()
fusion_output = cp_tensor_fusion_layer(inputs, fusion_weight)
toc = time.time()
cp_time = toc - tic
tic = time.time()
# Dense layer on the full weight; reconstructing it with to_matrix() is inside the timed region
fusion_output_ = tensor_fusion_layer(inputs, fusion_weight.to_matrix())
toc = time.time()
reg_time = toc - tic
print('Are the outputs close enough? {}'.format(torch.allclose(fusion_output, fusion_output_)))
print('Is CP tensor fusion faster than regular tensor fusion? {}'.format(cp_time < reg_time))
print('CP fusion time: {}'.format(cp_time))
print('Regular fusion time: {}'.format(reg_time))

Are the outputs close enough? True
Is CP tensor fusion faster than regular tensor fusion? False
CP fusion time: 0.008232831954956055
Regular fusion time: 0.006335020065307617
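A single call timed with `time.time()` is easily dominated by noise and one-off overheads at these sizes, so the `False` above may say little about the relative cost. A somewhat fairer comparison repeats each call many times with `timeit` and keeps the best run, next to the multiply-add counts from the derivation. The sketch uses plain PyTorch with random CP factors (no tltorch) and builds the dense weight once, outside the timed code. Absolute numbers depend on hardware and PyTorch version, so none are shown here:

```python
import math
import timeit
import torch

torch.manual_seed(0)
DTYPE = torch.float64

batch_size, input_sizes, output_size, rank = 32, [64, 4, 16], 10, 10
# z_m = [x_m, 1] for each modality, shape (batch, s_m + 1)
zs = [torch.cat([torch.randn(batch_size, s, dtype=DTYPE),
                 torch.ones(batch_size, 1, dtype=DTYPE)], dim=1) for s in input_sizes]
factors = [torch.randn(s + 1, rank, dtype=DTYPE) for s in input_sizes]
U_out = torch.randn(output_size, rank, dtype=DTYPE)
W = torch.einsum('ir,jr,kr,or->ijko', *factors, U_out)  # dense weight, built once

def dense():
    Z = torch.einsum('ni,nj,nk->nijk', *zs)
    return torch.einsum('nijk,ijko->no', Z, W)

def cp():
    proj = torch.ones(batch_size, rank, dtype=DTYPE)
    for z, U in zip(zs, factors):
        proj = proj * (z @ U)
    return proj @ U_out.T

assert torch.allclose(dense(), cp())

# Repeat each call many times and keep the best run to reduce noise and warm-up effects
n = 200
t_dense = min(timeit.repeat(dense, number=n, repeat=5)) / n
t_cp = min(timeit.repeat(cp, number=n, repeat=5)) / n

# Multiply-adds per sample for the weight contraction only
dense_macs = math.prod(s + 1 for s in input_sizes) * output_size  # 55250 (ignores forming Z)
cp_macs = rank * (sum(s + 1 for s in input_sizes) + output_size)  # 970 (ignores the elementwise products)

print(f'dense: {t_dense * 1e6:.1f} us/call, {dense_macs} MACs/sample')
print(f'CP:    {t_cp * 1e6:.1f} us/call, {cp_macs} MACs/sample')
```

Which variant is faster in wall-clock time at these small sizes can vary from machine to machine, but the arithmetic gap widens as modalities are added or grow.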


