In [4]:
import torch
from torch.utils.data import Dataset, DataLoader
from torch.autograd import Variable
import torch.optim as optim

import torch.nn as nn

import numpy as np

import pickle

In [5]:
class MultimodalDataset(Dataset):
    
    def __init__(self, text, audio, vision, labels):
        
        self.text = text
        self.audio = audio
        self.vision = vision
        self.labels = labels
        
    def __len__(self):
        return len(self.labels)
    
    def __getitem__(self, idx):
        
        return self.text[idx], self.audio[idx], self.vision[idx], self.labels[idx]

In [7]:
def get_cmu_mosi_dataset(path='../../dataset/cmu-mosi/mosi_50_seq_data.pkl'):
    
    file = open(path, 'rb')
    data = pickle.load(file)
    file.close()
    
    text = torch.tensor(data['train']['text'], dtype=torch.float32)
    audio = torch.tensor(data['train']['audio'], dtype=torch.float32)
    vision = torch.tensor(data['train']['vision'], dtype=torch.float32)
    labels = torch.tensor(data['train']['labels'], dtype=torch.float32).squeeze(1)
    train_set = MultimodalDataset(text, audio, vision, labels)

    text = torch.tensor(data['valid']['text'], dtype=torch.float32)
    audio = torch.tensor(data['valid']['audio'], dtype=torch.float32)
    vision = torch.tensor(data['valid']['vision'], dtype=torch.float32)
    labels = torch.tensor(data['valid']['labels'], dtype=torch.float32).squeeze(1)
    valid_set = MultimodalDataset(text, audio, vision, labels)
    
    text = torch.tensor(data['test']['text'], dtype=torch.float32)
    audio = torch.tensor(data['test']['audio'], dtype=torch.float32)
    vision = torch.tensor(data['test']['vision'], dtype=torch.float32)
    labels = torch.tensor(data['test']['labels'], dtype=torch.float32).squeeze(1)
    test_set = MultimodalDataset(text, audio, vision, labels)
    
    return train_set, valid_set, test_set

In [48]:
class SubNet(nn.Module):
    
    def __init__(self, input_size, hidden_size, dropout, output_size):
        
        super(SubNet, self).__init__()
        
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers=1, batch_first=True, dropout=dropout)
        self.dropout = nn.Dropout(dropout)
        self.linear = nn.Linear(hidden_size, output_size)
        
    def forward(self, x):
        
        _, final_states = self.lstm(x)
        h = self.dropout(final_states[0]).squeeze()
        y = self.linear(h)
        return y

In [49]:
class SubNets(nn.Module):

    def __init__(self, input_size, hidden_size, dropout):
        
        super(SubNets, self).__init__()
        
        self.t_subnet = SubNet(input_size[0], hidden_size[0], dropout[0], hidden_size[0])
        self.a_subnet = SubNet(input_size[1], hidden_size[1], dropout[1], hidden_size[1])
        self.v_subnet = SubNet(input_size[2], hidden_size[2], dropout[2], hidden_size[2])
        
    def forward(self, x_t, x_a, x_v):
        
        return self.t_subnet(x_t), self.a_subnet(x_a), self.v_subnet(x_v)

In [50]:
class FusionLayer(nn.Module):
    
    def __init__(self, input_size, dropout, output_size, rank):
        
        super(FusionLayer, self).__init__()
        
        self.t_linear = nn.Linear(input_size[0], rank, bias=False)
        self.a_linear = nn.Linear(input_size[1], rank, bias=False)
        self.v_linear = nn.Linear(input_size[2], rank, bias=False)
        self.dropout = nn.Dropout(dropout)
        self.o_linear = nn.Linear(rank, output_size)
        
    def forward(self, x_t, x_a, x_v):
        
        y_t = self.t_linear(x_t)
        y_a = self.a_linear(x_a)
        y_v = self.v_linear(x_v)
        
        y = self.o_linear(self.dropout(y_t * y_a * y_v))
        
        return y

In [51]:
class FusionNetwork(nn.Module):
    
    def __init__(self, s_input_size, s_hidden_size, s_dropout,
                 f_input_size, f_dropout, f_output_size, f_rank):
        
        super(FusionNetwork, self).__init__()
        
        self.subnets = SubNets(s_input_size, s_hidden_size, s_dropout)
        self.fusion_layer = FusionLayer(f_input_size, f_dropout, f_output_size, f_rank)
        
    def forward(self, x_t, x_a, x_v):
            
        y_t, y_a, y_v = self.subnets(x_t, x_a, x_v)
        
        batch_size = y_t.shape[0]
        
        y_t = torch.cat((y_t, torch.ones((batch_size, 1))), dim=1)
        y_a = torch.cat((y_a, torch.ones((batch_size, 1))), dim=1)
        y_v = torch.cat((y_v, torch.ones((batch_size, 1))), dim=1)

        y = self.fusion_layer(y_t, y_a, y_v)
        
        return y

In [68]:
def train(s_hidden_size, s_dropout, f_dropout, f_rank, epochs=100, batch_size=32):
    
    train_set, valid_set, test_set = get_cmu_mosi_dataset()
    train_dataloader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    valid_dataloader = DataLoader(valid_set, batch_size=len(valid_set))
    test_dataloader = DataLoader(test_set, batch_size=len(test_set))
    
    s_input_size = (300, 5, 20)
    
    f_input_size = tuple(x+1 for x in s_hidden_size)
    f_output_size = 1
    
    model = FusionNetwork(s_input_size, s_hidden_size, s_dropout, 
                          f_input_size, f_dropout, f_output_size, f_rank)
    
    optimizer = optim.Adam(list(model.parameters()))
    criterion = nn.MSELoss()
    min_error = 1000
    
    for epoch in range(epochs):
        # print('Epoch {}'.format(epoch))
        model.train()
        train_loss = 0.0
        for batch in train_dataloader:
            model.zero_grad()
            
            text, audio, vision, labels = batch
            output = model(text, audio, vision)
            
            loss = criterion(output, labels)
            
            loss.backward()
            optimizer.step()
            
            train_loss += loss.item()

        # print('Train Loss {:.4f}'.format(train_loss/len(train_dataloader)))
    
        model.eval()
        for batch in valid_dataloader:

            text, audio, vision, labels = batch
            output = model(text, audio, vision)
            
        valid_error = nn.functional.mse_loss(output, labels).item()
        # print('Valid Error {:.4f}'.format(valid_error))
        
        for batch in test_dataloader:

            text, audio, vision, labels = batch
            output = model(text, audio, vision)
            
        test_error = nn.functional.mse_loss(output, labels).item()
        # print('Test Error {:.4f}'.format(test_error))
        
        if test_error < min_error:
            min_error = test_error
    
    return min_error 

In [69]:
subnet_hyper_params = dict()
subnet_hyper_params['text_hidden'] = [2, 4, 8, 16, 32, 64, 128]
subnet_hyper_params['audio_hidden'] = [2, 4, 8, 16]
subnet_hyper_params['vision_hidden'] = [2, 4, 8, 16]
subnet_hyper_params['dropout'] = [0, 0.1, 0.2, 0.3, 0.4, 0.5]
subnet_hyper_params['fusion_rank'] = [1, 2, 3, 4, 5]

In [None]:
min_setting = ()
min_error = 1000
for t_hid in subnet_hyper_params['text_hidden']:
    for a_hid in subnet_hyper_params['audio_hidden']:
        for v_hid in subnet_hyper_params['vision_hidden']:
            for t_drop in subnet_hyper_params['dropout']:
                for a_drop in subnet_hyper_params['dropout']:
                    for v_drop in subnet_hyper_params['dropout']:
                        for f_drop in subnet_hyper_params['dropout']:
                            for f_rank in subnet_hyper_params['fusion_rank']:
                                s_hidden_size = (t_hid, a_hid, v_hid)
                                s_dropout = (t_drop, a_drop, v_drop)
                                error = train(s_hidden_size, s_dropout, f_drop, f_rank)
                                if error < min_error:
                                    min_error = error
                                    min_setting = (s_hidden_size, s_dropout, f_drop, f_rank)
                                print('+++++++++++++++++++++++++++++++++++++++')
                                print('Hidden sizes {}'.format(s_hidden_size))
                                print('Subnet Dropouts {}'.format(s_dropout))
                                print('Fusion Dropout {}'.format(f_drop))
                                print('Fusion Rank {}'.format(f_rank))
                                print('MSE {:.4f}'.format(error))
                                print('min MSE {:.4f}'.format(min_error))

+++++++++++++++++++++++++++++++++++++++
Hidden sizes (2, 2, 2)
Subnet Dropouts (0, 0, 0)
Fusion Dropout 0
Fusion Rank 1
MSE 1.7419
min MSE 1.7419
+++++++++++++++++++++++++++++++++++++++
Hidden sizes (2, 2, 2)
Subnet Dropouts (0, 0, 0)
Fusion Dropout 0
Fusion Rank 2
MSE 2.3133
min MSE 1.7419
+++++++++++++++++++++++++++++++++++++++
Hidden sizes (2, 2, 2)
Subnet Dropouts (0, 0, 0)
Fusion Dropout 0
Fusion Rank 3
MSE 2.1248
min MSE 1.7419
+++++++++++++++++++++++++++++++++++++++
Hidden sizes (2, 2, 2)
Subnet Dropouts (0, 0, 0)
Fusion Dropout 0
Fusion Rank 4
MSE 2.1415
min MSE 1.7419
+++++++++++++++++++++++++++++++++++++++
Hidden sizes (2, 2, 2)
Subnet Dropouts (0, 0, 0)
Fusion Dropout 0
Fusion Rank 5
MSE 2.1516
min MSE 1.7419
+++++++++++++++++++++++++++++++++++++++
Hidden sizes (2, 2, 2)
Subnet Dropouts (0, 0, 0)
Fusion Dropout 0.1
Fusion Rank 1
MSE 2.0867
min MSE 1.7419
+++++++++++++++++++++++++++++++++++++++
Hidden sizes (2, 2, 2)
Subnet Dropouts (0, 0, 0)
Fusion Dropout 0.1
Fusion Rank 2




+++++++++++++++++++++++++++++++++++++++
Hidden sizes (2, 2, 2)
Subnet Dropouts (0, 0, 0.3)
Fusion Dropout 0
Fusion Rank 1
MSE 2.2334
min MSE 1.7419
+++++++++++++++++++++++++++++++++++++++
Hidden sizes (2, 2, 2)
Subnet Dropouts (0, 0, 0.3)
Fusion Dropout 0
Fusion Rank 2
MSE 2.0720
min MSE 1.7419
+++++++++++++++++++++++++++++++++++++++
Hidden sizes (2, 2, 2)
Subnet Dropouts (0, 0, 0.3)
Fusion Dropout 0
Fusion Rank 3
MSE 2.2268
min MSE 1.7419
+++++++++++++++++++++++++++++++++++++++
Hidden sizes (2, 2, 2)
Subnet Dropouts (0, 0, 0.3)
Fusion Dropout 0
Fusion Rank 4
MSE 2.1743
min MSE 1.7419
+++++++++++++++++++++++++++++++++++++++
Hidden sizes (2, 2, 2)
Subnet Dropouts (0, 0, 0.3)
Fusion Dropout 0
Fusion Rank 5
MSE 2.1731
min MSE 1.7419
+++++++++++++++++++++++++++++++++++++++
Hidden sizes (2, 2, 2)
Subnet Dropouts (0, 0, 0.3)
Fusion Dropout 0.1
Fusion Rank 1
MSE 2.0132
min MSE 1.7419
+++++++++++++++++++++++++++++++++++++++
Hidden sizes (2, 2, 2)
Subnet Dropouts (0, 0, 0.3)
Fusion Dropout 0.1




+++++++++++++++++++++++++++++++++++++++
Hidden sizes (2, 2, 2)
Subnet Dropouts (0, 0, 0.4)
Fusion Dropout 0
Fusion Rank 1
MSE 2.2156
min MSE 1.7419
+++++++++++++++++++++++++++++++++++++++
Hidden sizes (2, 2, 2)
Subnet Dropouts (0, 0, 0.4)
Fusion Dropout 0
Fusion Rank 2
MSE 1.9261
min MSE 1.7419
+++++++++++++++++++++++++++++++++++++++
Hidden sizes (2, 2, 2)
Subnet Dropouts (0, 0, 0.4)
Fusion Dropout 0
Fusion Rank 3
MSE 1.9228
min MSE 1.7419
+++++++++++++++++++++++++++++++++++++++
Hidden sizes (2, 2, 2)
Subnet Dropouts (0, 0, 0.4)
Fusion Dropout 0
Fusion Rank 4
MSE 2.2239
min MSE 1.7419
+++++++++++++++++++++++++++++++++++++++
Hidden sizes (2, 2, 2)
Subnet Dropouts (0, 0, 0.4)
Fusion Dropout 0
Fusion Rank 5
MSE 1.9947
min MSE 1.7419
+++++++++++++++++++++++++++++++++++++++
Hidden sizes (2, 2, 2)
Subnet Dropouts (0, 0, 0.4)
Fusion Dropout 0.1
Fusion Rank 1
MSE 2.1905
min MSE 1.7419
+++++++++++++++++++++++++++++++++++++++
Hidden sizes (2, 2, 2)
Subnet Dropouts (0, 0, 0.4)
Fusion Dropout 0.1




+++++++++++++++++++++++++++++++++++++++
Hidden sizes (2, 2, 2)
Subnet Dropouts (0, 0, 0.5)
Fusion Dropout 0
Fusion Rank 1
MSE 2.0511
min MSE 1.7419
+++++++++++++++++++++++++++++++++++++++
Hidden sizes (2, 2, 2)
Subnet Dropouts (0, 0, 0.5)
Fusion Dropout 0
Fusion Rank 2
MSE 2.1302
min MSE 1.7419
+++++++++++++++++++++++++++++++++++++++
Hidden sizes (2, 2, 2)
Subnet Dropouts (0, 0, 0.5)
Fusion Dropout 0
Fusion Rank 3
MSE 2.1907
min MSE 1.7419
+++++++++++++++++++++++++++++++++++++++
Hidden sizes (2, 2, 2)
Subnet Dropouts (0, 0, 0.5)
Fusion Dropout 0
Fusion Rank 4
MSE 2.0746
min MSE 1.7419
+++++++++++++++++++++++++++++++++++++++
Hidden sizes (2, 2, 2)
Subnet Dropouts (0, 0, 0.5)
Fusion Dropout 0
Fusion Rank 5
MSE 1.9617
min MSE 1.7419
+++++++++++++++++++++++++++++++++++++++
Hidden sizes (2, 2, 2)
Subnet Dropouts (0, 0, 0.5)
Fusion Dropout 0.1
Fusion Rank 1
MSE 1.9698
min MSE 1.7419
+++++++++++++++++++++++++++++++++++++++
Hidden sizes (2, 2, 2)
Subnet Dropouts (0, 0, 0.5)
Fusion Dropout 0.1
