In [168]:
import pickle
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import torch

In [169]:
# Model Definition

# Modality Specific Module
class ModalitySpecificModule:
    """Per-modality encoder plus the three symbolic Keras inputs of the model.

    Every modality goes through the same architecture (with its own, unshared
    weights): a bidirectional GRU that keeps only its final state, followed by
    two ReLU Dense layers.
    """

    @staticmethod
    def extract_modality_specific_interactions(input_layer):
        """Encode one modality sequence into a 64-unit feature vector.

        BiGRU(128 units per direction, final state only -> 256 features),
        then Dense(128, relu), then Dense(64, relu). Fresh layers are created
        on every call, so each modality gets its own weights.
        """
        hidden = tf.keras.layers.Bidirectional(tf.keras.layers.GRU(128))(input_layer)
        for width in (128, 64):
            hidden = tf.keras.layers.Dense(width, activation='relu')(hidden)
        return hidden

    # Symbolic inputs, created once when this class body executes.
    # Per-sample shape is (50, feature_dim); the GRU consumes axis 0 as the
    # sequence axis.
    text_input = tf.keras.Input(shape=(50, 300))
    audio_input = tf.keras.Input(shape=(50, 5))
    visual_input = tf.keras.Input(shape=(50, 20))

    @staticmethod
    def compute(text_input, audio_input, visual_input):
        """Encode each modality; returns a (text, audio, visual) feature tuple."""
        encode = ModalitySpecificModule.extract_modality_specific_interactions
        return tuple(encode(modality) for modality in (text_input, audio_input, visual_input))


# Dense Multimodal Fusion Module
class DenseMultimodalFusionModule:
    """Stacked dense fusion of three modality feature vectors.

    Each fusion layer concatenates the pairs (f1, f3), (f1, f2), (f2, f3),
    projects each pair through its own Dense(64, relu), and records the
    element-wise sum f1 + f2 + f3 of the layer's *inputs* as a residual
    feature. After the last layer, the sum of its outputs is recorded too.
    """

    # Legacy shared sink, used only when `dense_fusion_layer` is called without
    # an explicit list. `compute` no longer writes here; appending to this
    # class-level list made repeated model builds accumulate stale tensors.
    residual_features = []

    @staticmethod
    def df_block(df_input1, df_input2, df_input3):
        """Project each concatenated pair through its own Dense(64, relu)."""
        df_output1 = tf.keras.layers.Dense(64, activation='relu')(df_input1)
        df_output2 = tf.keras.layers.Dense(64, activation='relu')(df_input2)
        df_output3 = tf.keras.layers.Dense(64, activation='relu')(df_input3)
        return df_output1, df_output2, df_output3

    @staticmethod
    def _sum_features(f1, f2, f3):
        """Element-wise sum of three same-shaped tensors via a Keras Add layer."""
        # A Keras layer (rather than raw `tf.add`) keeps the graph Keras-native.
        return tf.keras.layers.Add()([f1, f2, f3])

    @staticmethod
    def dense_fusion_layer(f1, f2, f3, residual_features=None):
        """Run one fusion layer and record its residual feature.

        Args:
            f1, f2, f3: the three modality features; must share a shape so
                they can be summed.
            residual_features: list to append this layer's residual to. If
                omitted, the legacy class-level list is used.

        Returns:
            Tuple of three fused features, one per concatenated pair.
        """
        if residual_features is None:
            residual_features = DenseMultimodalFusionModule.residual_features
        df_input1 = tf.keras.layers.Concatenate()([f1, f3])
        df_input2 = tf.keras.layers.Concatenate()([f1, f2])
        df_input3 = tf.keras.layers.Concatenate()([f2, f3])
        residual_features.append(DenseMultimodalFusionModule._sum_features(f1, f2, f3))
        return DenseMultimodalFusionModule.df_block(df_input1, df_input2, df_input3)

    @staticmethod
    def compute(f1_0, f2_0, f3_0, num_layers=3):
        """Apply `num_layers` fusion layers and return all residual features.

        Args:
            f1_0, f2_0, f3_0: initial modality features (same shape as the
                64-unit df_block outputs, so the sums line up).
            num_layers: number of stacked fusion layers (default 3).

        Returns:
            A new list of `num_layers + 1` residual tensors, built fresh on
            every call.

        Raises:
            ValueError: if num_layers is negative.
        """
        if num_layers < 0:
            raise ValueError("num_layers must be non-negative")
        residual_features = []
        f1, f2, f3 = f1_0, f2_0, f3_0
        for _ in range(num_layers):
            f1, f2, f3 = DenseMultimodalFusionModule.dense_fusion_layer(
                f1, f2, f3, residual_features
            )
        residual_features.append(DenseMultimodalFusionModule._sum_features(f1, f2, f3))
        return residual_features


# Multimodal Residual Module
class MultimodalResidualModule:
    """Aggregates the residual features produced by the dense fusion module."""

    @staticmethod
    def compute(residual_features):
        """Return the element-wise sum of all residual feature tensors.

        Args:
            residual_features: non-empty sequence of same-shaped tensors.

        Returns:
            The single element unchanged if there is only one, otherwise the
            output of a Keras Add layer over all of them.

        Raises:
            ValueError: if residual_features is empty.
        """
        features = list(residual_features)
        if not features:
            raise ValueError("residual_features must contain at least one tensor")
        if len(features) == 1:
            return features[0]
        # One Add layer over the whole list instead of a chain of raw tf.add ops.
        return tf.keras.layers.Add()(features)


# Sentiment Classification Module
class SentimentClassificationModule:
    """Seven-class sentiment head plus helpers that turn scores into one-hot targets."""

    # Ordered sentiment classes; position i of a one-hot vector is SENTIMENT_CLASSES[i].
    SENTIMENT_CLASSES = (-3, -2, -1, 0, 1, 2, 3)

    @staticmethod
    def cmumosiRound(element):
        """Map a continuous sentiment score to an integer class in [-3, 3].

        Non-zero scores are rounded away from zero to the next integer
        (sign(x) * ceil(|x|)) and clipped to [-3, 3]; only an exact 0 maps to 0.
        Examples: 0.2 -> 1, -0.2 -> -1, 1.0 -> 1, -2.5 -> -3, 4.0 -> 3.

        NOTE(review): the common CMU-MOSI 7-class convention rounds to the
        nearest integer instead; confirm this away-from-zero binning is intended.

        Raises:
            ValueError: if element is NaN (it previously fell through every
                branch and raised UnboundLocalError).
        """
        if element != element:  # NaN is the only value not equal to itself
            raise ValueError("cmumosiRound received NaN; cannot assign a sentiment class")
        if element < -2:
            return -3
        if element < -1:
            return -2
        if element < 0:
            return -1
        if element == 0:
            return 0
        if element <= 1:
            return 1
        if element <= 2:
            return 2
        return 3

    @staticmethod
    def convertToTensor(result):
        """Return the one-hot float32 torch tensor for a class label in [-3, 3].

        Raises:
            KeyError: if result is not one of SENTIMENT_CLASSES.
        """
        classes = SentimentClassificationModule.SENTIMENT_CLASSES
        if result not in classes:
            raise KeyError(result)
        one_hot = [0.] * len(classes)
        one_hot[classes.index(result)] = 1.
        return torch.tensor(one_hot)

    @staticmethod
    def compute(residual_feature):
        """Classification head: Dense 64 -> 32 -> 16 -> 8 (relu), then softmax.

        The softmax width equals the number of sentiment classes, so it stays
        in sync with the one-hot targets from convertToTensor.
        """
        features = residual_feature
        for units in (64, 32, 16, 8):
            features = tf.keras.layers.Dense(units, activation='relu')(features)

        num_classes = len(SentimentClassificationModule.SENTIMENT_CLASSES)
        return tf.keras.layers.Dense(num_classes, activation='softmax')(features)


# Wire the four modules into one Keras functional model.
text_input = ModalitySpecificModule.text_input
audio_input = ModalitySpecificModule.audio_input
visual_input = ModalitySpecificModule.visual_input

text_msm_output, audio_msm_output, visual_msm_output = ModalitySpecificModule.compute(
    text_input, audio_input, visual_input
)
residual_features = DenseMultimodalFusionModule.compute(
    text_msm_output, audio_msm_output, visual_msm_output
)
final_residual_feature = MultimodalResidualModule.compute(residual_features)
output = SentimentClassificationModule.compute(final_residual_feature)

model = tf.keras.Model(
    inputs=[text_input, audio_input, visual_input],
    outputs=output,
)
model.summary()

Model: "model_7"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 input_88 (InputLayer)       [(None, 50, 300)]            0         []                            
                                                                                                  
 input_89 (InputLayer)       [(None, 50, 5)]              0         []                            
                                                                                                  
 input_90 (InputLayer)       [(None, 50, 20)]             0         []                            
                                                                                                  
 bidirectional_87 (Bidirect  (None, 256)                  330240    ['input_88[0][0]']            
 ional)                                                                                     

In [170]:
# The head is a 7-way softmax trained against one-hot targets, so use a
# classification loss and an arg-max based accuracy. (MeanAbsoluteError on
# probabilities is a poor training signal, and `tf.keras.metrics.Accuracy`
# checks exact element-wise equality between probabilities and one-hot labels,
# so it stays near 0 regardless of model quality.)
model.compile(
    optimizer='adam',
    loss=tf.keras.losses.CategoricalCrossentropy(),
    metrics=[
        tf.keras.metrics.CategoricalAccuracy(),
        tf.keras.metrics.F1Score(),
    ],
)

In [171]:
# `pickle` is already imported in the notebook's first cell.
# SECURITY: pickle.load can execute arbitrary code while unpickling; only load
# mosi_data.pkl from a trusted source.
with open('mosi_data.pkl', 'rb') as fp:
    data = pickle.load(fp)

In [172]:
# Inspect the split/key layout and the per-modality array shapes of the train split.
train_split = data['train']
print(data.keys())
print(train_split.keys())
for modality in ('vision', 'audio', 'text', 'labels'):
    print(train_split[modality].shape)
print(train_split['id'][0].shape)

dict_keys(['valid', 'test', 'train'])
dict_keys(['vision', 'labels', 'text', 'audio', 'id'])
(1284, 50, 20)
(1284, 50, 5)
(1284, 50, 300)
(1284, 1, 1)
(3,)


In [173]:
def prepare_split(split):
    """Return (text, audio, vision, one_hot_labels) for one split of the MOSI dict.

    `split` is one value of `data` (e.g. `data['train']`). Its 'labels' array
    is indexed as [:, 0, 0] to get one score per sample. Each score is binned
    by SentimentClassificationModule.cmumosiRound and one-hot encoded by
    convertToTensor. The per-sample torch tensors are converted to NumPy and
    stacked into one array with one row per sample.
    """
    scores = split['labels'][:, 0, 0]
    one_hot_rows = [
        SentimentClassificationModule.convertToTensor(
            SentimentClassificationModule.cmumosiRound(score)
        ).numpy()
        for score in scores
    ]
    return split['text'], split['audio'], split['vision'], np.stack(one_hot_rows)


train_text_input, train_audio_input, train_visual_input, train_labels = prepare_split(data['train'])
valid_text_input, valid_audio_input, valid_visual_input, valid_labels = prepare_split(data['valid'])

In [178]:
# Distinct raw training scores, and the distinct classes they bin into.
labels = sorted(set(np.squeeze(data['train']['labels'])))
labelset = sorted({SentimentClassificationModule.cmumosiRound(value) for value in labels})
print(labelset)

[-3, -2, -1, 0, 1, 2, 3]


In [176]:
# Input order must match `model.inputs`: text, audio, visual.
training_dataset = [
    train_text_input,
    train_audio_input,
    train_visual_input,
]
validation_data = (
    [valid_text_input, valid_audio_input, valid_visual_input],
    valid_labels,
)

In [177]:
# Train for 10 epochs, evaluating on the validation split after each epoch.
model.fit(
    x=training_dataset,
    y=train_labels,
    validation_data=validation_data,
    epochs=10,
)

Epoch 1/10


Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<keras.src.callbacks.History at 0x7ff63278c710>