In [33]:
import torch
from torch.utils.data import Dataset, DataLoader
from torch.autograd import Variable
import torch.optim as optim

import torch.nn as nn

import numpy as np

import pickle

In [15]:
class MultimodalDataset(Dataset):
    
    def __init__(self, text, audio, vision, labels):
        
        self.text = text
        self.audio = audio
        self.vision = vision
        self.labels = labels
        
    def __len__(self):
        return len(self.labels)
    
    def __getitem__(self, idx):
        
        return self.text[idx], self.audio[idx], self.vision[idx], self.labels[idx]

In [16]:
class SubNet(nn.Module):
    
    def __init__(self, input_size, hidden_size, dropout, output_size):
        
        super(SubNet, self).__init__()
        
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers=1, batch_first=True, dropout=dropout)
        self.dropout = nn.Dropout(dropout)
        self.linear = nn.Linear(hidden_size, output_size)
        
    def forward(self, x):
        
        _, final_states = self.lstm(x)
        h = self.dropout(final_states[0]).squeeze()
        y = self.linear(h)
        return y

In [50]:
def get_cmu_mosi_dataset(path='../../dataset/cmu-mosi/mosi_50_seq_data.pkl'):
    
    file = open(path, 'rb')
    data = pickle.load(file)
    file.close()
    
    text = torch.tensor(data['train']['text'], dtype=torch.float32)
    audio = torch.tensor(data['train']['audio'], dtype=torch.float32)
    vision = torch.tensor(data['train']['vision'], dtype=torch.float32)
    labels = torch.tensor(data['train']['labels'], dtype=torch.float32).squeeze(1)
    train_set = MultimodalDataset(text, audio, vision, labels)

    text = torch.tensor(data['valid']['text'], dtype=torch.float32)
    audio = torch.tensor(data['valid']['audio'], dtype=torch.float32)
    vision = torch.tensor(data['valid']['vision'], dtype=torch.float32)
    labels = torch.tensor(data['valid']['labels'], dtype=torch.float32).squeeze(1)
    valid_set = MultimodalDataset(text, audio, vision, labels)
    
    text = torch.tensor(data['test']['text'], dtype=torch.float32)
    audio = torch.tensor(data['test']['audio'], dtype=torch.float32)
    vision = torch.tensor(data['test']['vision'], dtype=torch.float32)
    labels = torch.tensor(data['test']['labels'], dtype=torch.float32).squeeze(1)
    test_set = MultimodalDataset(text, audio, vision, labels)
    
    return train_set, valid_set, test_set

In [51]:
class SubNets(nn.Module):

    def __init__(self, input_size, hidden_size, dropout, output_size):
        
        super(SubNets, self).__init__()
        
        self.t_subnet = SubNet(input_size[0], hidden_size[0], dropout[0], output_size[0])
        self.a_subnet = SubNet(input_size[1], hidden_size[1], dropout[1], output_size[1])
        self.v_subnet = SubNet(input_size[2], hidden_size[2], dropout[2], output_size[2])
        
    def forward(self, x_t, x_a, x_v):
        
        return self.t_subnet(x_t), self.a_subnet(x_a), self.v_subnet(x_v)

In [52]:
class Fusion(nn.Module):
    
    def __init__(self, input_size, dropout, output_size, rank):
        
        super(Fusion, self).__init__()
        
        self.t_linear = nn.Linear(input_size[0], rank, bias=False)
        self.a_linear = nn.Linear(input_size[1], rank, bias=False)
        self.v_linear = nn.Linear(input_size[2], rank, bias=False)
        self.dropout = nn.Dropout(dropout)
        self.y_linear = nn.Linear(rank, output_size)
        
    def forward(self, x_t, x_a, x_v):
        
        y_t = self.dropout(self.t_linear(x_t))
        y_a = self.dropout(self.a_linear(x_a))
        y_v = self.dropout(self.v_linear(x_v))
        
        y = self.y_linear(y_t * y_a * y_v)
        
        return y

In [53]:
class TFN(nn.Module):
    
    def __init__(self, s_input_size, s_hidden_size, s_dropout, s_output_size,
                 f_input_size, f_dropout, f_output_size, f_rank):
        
        super(TFN, self).__init__()
        
        self.subnets = SubNets(s_input_size, s_hidden_size, s_dropout, s_output_size)
        self.fusion = Fusion(f_input_size, f_dropout, f_output_size, f_rank)
        
    def forward(self, x_t, x_a, x_v):
            
        y_t, y_a, y_v = self.subnets(x_t, x_a, x_v)
        
        batch_size = y_t.shape[0]
        
        y_t = torch.cat((y_t, torch.ones((batch_size, 1))), dim=1)
        y_a = torch.cat((y_a, torch.ones((batch_size, 1))), dim=1)
        y_v = torch.cat((y_v, torch.ones((batch_size, 1))), dim=1)

        y = self.fusion(y_t, y_a, y_v)
        
        return y

In [89]:
def train(epochs=100, batch_size=32):
    
    train_set, valid_set, test_set = get_cmu_mosi_dataset()
    train_dataloader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    valid_dataloader = DataLoader(valid_set, batch_size=len(valid_set))
    test_dataloader = DataLoader(test_set, batch_size=len(test_set))
    
    s_input_size = (300, 5, 20)
    s_hidden_size = (2, 2, 2)
    s_dropout = (0.2, 0.2, 0.2)
    s_output_size = (2, 2, 2)
    
    f_input_size = (3, 3, 3)
    f_dropout = 0.2
    f_output_size = 1
    f_rank = 2
    
    model = TFN(s_input_size, s_hidden_size, s_dropout, s_output_size, 
              f_input_size, f_dropout, f_output_size, f_rank)
    
    optimizer = optim.Adam(list(model.parameters()))
    criterion = nn.MSELoss()
    
    model.train()
    for epoch in range(epochs):
        train_loss = 0.0
        for batch in train_dataloader:
            model.zero_grad()
            
            text, audio, vision, labels = batch
            output = model(text, audio, vision)
            
            loss = criterion(output, labels)
            
            loss.backward()
            optimizer.step()
            
            train_loss += loss.item()

        print('Train Loss {:.4f}'.format(train_loss/len(train_dataloader)))
    
        model.eval()
        for batch in valid_dataloader:

            text, audio, vision, labels = batch
            output = model(text, audio, vision)
            
        valid_error = nn.functional.mse_loss(output, labels).item()
        print('Valid Error {:.4f}'.format(valid_error))
        
        for batch in test_dataloader:

            text, audio, vision, labels = batch
            output = model(text, audio, vision)
            
        test_error = nn.functional.mse_loss(output, labels).item()
        print('Test Error {:.4f}'.format(test_error))

In [90]:
train()

Train Loss 2.3869
Valid Error 2.7649
Test Error 3.2211
Train Loss 2.3052
Valid Error 2.7421
Test Error 3.1121
Train Loss 2.3283
Valid Error 2.7267
Test Error 3.0047
Train Loss 2.2943
Valid Error 2.7194
Test Error 2.9113
Train Loss 2.2162
Valid Error 2.7216
Test Error 2.8074
Train Loss 2.2338
Valid Error 2.7199
Test Error 2.7831
Train Loss 2.2003
Valid Error 2.7157
Test Error 2.7526
Train Loss 2.2198
Valid Error 2.7154
Test Error 2.6983
Train Loss 2.1227
Valid Error 2.7015
Test Error 2.6580
Train Loss 2.1160
Valid Error 2.6870
Test Error 2.6421
Train Loss 2.0494
Valid Error 2.6679
Test Error 2.6714
Train Loss 2.0248
Valid Error 2.6529
Test Error 2.6570
Train Loss 1.9404
Valid Error 2.6817
Test Error 2.5457
Train Loss 1.9017
Valid Error 2.6498
Test Error 2.6037
Train Loss 1.9008
Valid Error 2.6323
Test Error 2.5926
Train Loss 1.8067
Valid Error 2.6352
Test Error 2.5270
Train Loss 1.7306
Valid Error 2.6149
Test Error 2.5564
Train Loss 1.7246
Valid Error 2.6094
Test Error 2.5290
Train Loss