In [1]:
import torch
import torch.nn as nn
import torch.optim as optim

from torch.autograd import Variable
from torch.utils.data import DataLoader

import numpy as np

import pickle

from datasets import Multimodal_Binary_Dataset
from fusion_model import CP_Tensor_Fusion_Network

import time

In [64]:
def train_cmu_mosi(batch_size=32, epochs=100, lr=.001, max_rank=20, rank_adaptive=True,  
                   warmup_epochs=50, kl_multiplier=1e-4, no_kl_epochs=5, accelerated=True):

    # load dataset file
    file = open('../../dataset/cmu-mosi/mosi_20_seq_data.pkl', 'rb')
    data = pickle.load(file)
    file.close()

    # prepare the datasets and data loaders
    train_set = Multimodal_Binary_Dataset(data['train']['text'], data['train']['audio'],
                                  data['train']['vision'], data['train']['labels'])
    valid_set = Multimodal_Binary_Dataset(data['valid']['text'], data['valid']['audio'],
                                  data['valid']['vision'], data['valid']['labels'])

    train_dataloader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    valid_dataloader = DataLoader(valid_set, batch_size=len(valid_set))

    # set up model
    input_sizes = (train_set[0][0]['audio'].shape[0], train_set[0][0]['vision'].shape[0],
                   train_set[0][0]['text'].shape[1])
    hidden_sizes = (32, 32, 128)
    output_size = 1
    
    model = CP_Tensor_Fusion_Network(input_sizes, hidden_sizes, output_size, max_rank,
                                     rank_adaptive)
    # set up training
    DTYPE = torch.FloatTensor
    optimizer = optim.Adam(list(model.parameters()), lr=lr)
    criterion = nn.BCEWithLogitsLoss()
    
    # train and validate
    for e in range(1, epochs + 1):
        # train
        tic = time.time()
        model.train()
        train_loss = 0.0
        for batch in train_dataloader:
            model.zero_grad()

            features, label = batch
            
            x_a = Variable(features['audio'].float().type(DTYPE), requires_grad=False)
            x_v = Variable(features['vision'].float().type(DTYPE), requires_grad=False)
            x_t = Variable(features['text'].float().type(DTYPE), requires_grad=False)
            y = Variable(label.view(-1, 1).float().type(DTYPE), requires_grad=False)
            
            output = model([x_a, x_v, x_t])
            nll_loss = criterion(output, y)
            nlp_loss = get_neg_log_prior_loss(model.tensor_fusion_layer.weight_tensor.rank_parameter,
                                              model.tensor_fusion_layer.weight_tensor.factors)

            nll_loss.backward()
            optimizer.step()
            train_loss += nll_loss.item()
        
        return model
        print(train_loss)

In [69]:
def get_neg_log_prior_loss(rank_parameter, factors):
    dist = td.Normal(0, rank_parameter)
    for factor in factors:
        p = -torch.sum(torch.log(dist.cdf(factor)))
        print(p)
    return p

In [70]:
model = train_cmu_mosi()

tensor(644.1651, grad_fn=<NegBackward>)
tensor(621.0089, grad_fn=<NegBackward>)
tensor(2558.0603, grad_fn=<NegBackward>)
tensor(18.5479, grad_fn=<NegBackward>)
tensor(644.5923, grad_fn=<NegBackward>)
tensor(621.0871, grad_fn=<NegBackward>)
tensor(2557.8948, grad_fn=<NegBackward>)
tensor(18.5685, grad_fn=<NegBackward>)
tensor(644.6938, grad_fn=<NegBackward>)
tensor(621.0804, grad_fn=<NegBackward>)
tensor(2557.6550, grad_fn=<NegBackward>)
tensor(18.5952, grad_fn=<NegBackward>)
tensor(644.5192, grad_fn=<NegBackward>)
tensor(620.9665, grad_fn=<NegBackward>)
tensor(2557.5671, grad_fn=<NegBackward>)
tensor(18.6205, grad_fn=<NegBackward>)
tensor(644.5486, grad_fn=<NegBackward>)
tensor(620.8978, grad_fn=<NegBackward>)
tensor(2557.4680, grad_fn=<NegBackward>)
tensor(18.6370, grad_fn=<NegBackward>)
tensor(644.6428, grad_fn=<NegBackward>)
tensor(620.8658, grad_fn=<NegBackward>)
tensor(2557.3743, grad_fn=<NegBackward>)
tensor(18.6571, grad_fn=<NegBackward>)
tensor(644.5530, grad_fn=<NegBackward>)


In [48]:
y.shape

torch.Size([32, 1])

In [45]:
label.shapeb

torch.Size([32, 1, 1])

In [9]:
import torch.distributions as td

In [10]:
normal = td.Normal(0, 1)

In [22]:
normal.cdf(torch.tensor(-1))

tensor(0.1587)

In [28]:
tensor = model.tensor_fusion_layer.weight_tensor

In [29]:
factor_scale = 1e-9
factor_dist = td.Normal(loc=0, scale=tensor.rank_parameter)

In [34]:
tensor.factors[0][0]

tensor([ 0.1616,  0.2713, -0.1114,  0.0177,  0.2770,  0.1499, -0.0087,  0.0943,
         0.1177, -0.0481,  0.0643,  0.1426, -0.0518, -0.0145,  0.0151,  0.0486,
        -0.0546, -0.0239,  0.3509,  0.0476], grad_fn=<SelectBackward>)

In [38]:
factor_dist.cdf(tensor.factors[0][0])

tensor([[0.8359, 0.9501, 0.2737, 0.5453, 0.9428, 0.8083, 0.4792, 0.7029, 0.7608,
         0.3821, 0.6439, 0.7983, 0.3786, 0.4666, 0.5341, 0.6188, 0.3771, 0.4454,
         0.9798, 0.6127]], grad_fn=<MulBackward0>)

In [37]:
factor_dist.cdf(tensor.factors[0])

tensor([[0.8359, 0.9501, 0.2737, 0.5453, 0.9428, 0.8083, 0.4792, 0.7029, 0.7608,
         0.3821, 0.6439, 0.7983, 0.3786, 0.4666, 0.5341, 0.6188, 0.3771, 0.4454,
         0.9798, 0.6127],
        [0.6176, 0.7911, 0.2937, 0.6401, 0.9798, 0.4173, 0.8173, 0.7636, 0.7208,
         0.9178, 0.8128, 0.9612, 0.8145, 0.4543, 0.4445, 0.5084, 0.4040, 0.1602,
         0.2736, 0.2127],
        [0.8112, 0.4306, 0.7227, 0.6602, 0.8213, 0.0386, 0.6663, 0.5947, 0.1200,
         0.3541, 0.3249, 0.9399, 0.2074, 0.6484, 0.4268, 0.1778, 0.4798, 0.7598,
         0.0732, 0.3149],
        [0.8229, 0.0717, 0.2684, 0.0650, 0.8956, 0.8945, 0.6467, 0.2294, 0.1787,
         0.4025, 0.0838, 0.3088, 0.2935, 0.4559, 0.4754, 0.9647, 0.0382, 0.0778,
         0.4810, 0.0702],
        [0.7859, 0.1282, 0.0739, 0.6209, 0.9866, 0.7137, 0.7140, 0.5657, 0.6711,
         0.1053, 0.0805, 0.9103, 0.1144, 0.4349, 0.8268, 0.9001, 0.5141, 0.6955,
         0.5202, 0.9097],
        [0.1704, 0.8412, 0.0365, 0.0144, 0.4696, 0.7992, 0.4

In [32]:
print(tensor.rank_parameter)

Parameter containing:
tensor([[0.1653, 0.1648, 0.1852, 0.1560, 0.1755, 0.1720, 0.1666, 0.1770, 0.1661,
         0.1602, 0.1742, 0.1707, 0.1674, 0.1738, 0.1759, 0.1608, 0.1744, 0.1742,
         0.1711, 0.1663]])
