In [1]:
from tensor_fusion.net import ARF, TFN, LMF
import torch


input_sizes = (300, 5, 20)
hidden_sizes = (128, 32, 32)
fusion_size = 128
out_size = 1
dropouts = (0, 0, 0, 0, 0)
device = 'cpu'
dtype = torch.float32

# define a full model and a compression model
full_model = TFN(input_sizes, hidden_sizes, fusion_size, out_size, dropouts, device=device, dtype=dtype)
low_rank_model = LMF(input_sizes, hidden_sizes, fusion_size, 100, out_size, dropouts, device, dtype)
compressed_model = ARF(input_sizes, hidden_sizes, fusion_size, 100, out_size, dropouts, device=device, dtype=dtype)



In [2]:
from tensor_fusion.dataset import get_cmu_mosi_dataset

# get dataset
train_set, valid_set, test_set = get_cmu_mosi_dataset(binary=True, device=device, dtype=dtype)

In [3]:
from tensor_fusion.train import binary_map_train_ARF, binary_map_train_TFN, binary_map_train_LMF

coeff = 1e-1
lr = 1e-2

print("Train Full Model")
full_max_accuracy, full_max_model, full_result = binary_map_train_TFN(full_model, train_set, test_set, learning_rate=5e-4, print_result=True)
print("==============================")
print("Train L.R. Model")
lr_max_accuracy, lr_max_model, lr_result = binary_map_train_LMF(low_rank_model, train_set, test_set, 1e-4, print_result=True)
print("==============================")
print("Train Comp. Model")
comp_max_accuracy, comp_max_model, comp_result = binary_map_train_ARF(compressed_model, train_set, test_set, coeff, lr, print_result=True)
comp_max_model.compress()

Train Full Model
Epoch 0
Train Loss 0.6814
Valid Loss 0.6941
Test Bin Acc 0.4038
Epoch 1
Train Loss 0.5864
Valid Loss 0.5962
Test Bin Acc 0.7114
Epoch 2
Train Loss 0.4960
Valid Loss 0.5430
Test Bin Acc 0.7274
Epoch 3
Train Loss 0.4500
Valid Loss 0.5518
Test Bin Acc 0.7099
Epoch 4
Train Loss 0.3872
Valid Loss 0.5449
Test Bin Acc 0.7216
Epoch     6: reducing learning rate of group 0 to 5.0000e-05.
Epoch 5
Train Loss 0.3672
Valid Loss 0.5445
Test Bin Acc 0.7449
Epoch 6
Train Loss 0.2866
Valid Loss 0.6078
Test Bin Acc 0.7405
Epoch 7
Train Loss 0.2596
Valid Loss 0.6237
Test Bin Acc 0.7405
Epoch     9: reducing learning rate of group 0 to 5.0000e-06.
Epoch 8
Train Loss 0.2431
Valid Loss 0.6307
Test Bin Acc 0.7347
Epoch 9
Train Loss 0.2335
Valid Loss 0.6431
Test Bin Acc 0.7376
Epoch 10
Train Loss 0.2307
Valid Loss 0.6498
Test Bin Acc 0.7362
Epoch    12: reducing learning rate of group 0 to 5.0000e-07.
Epoch 11
Train Loss 0.2308
Valid Loss 0.6589
Test Bin Acc 0.7318
Epoch 12
Train Loss 0.2283


In [4]:
rank = comp_max_model.fusion_layer.max_rank
low_rank_model_2 = LMF(input_sizes, hidden_sizes, fusion_size, rank, out_size, dropouts, device, dtype)
print("==============================")
print("Train L.R. Model 2")
lr_max_accuracy_2, lr_max_model_2, lr_result_2 = binary_map_train_LMF(low_rank_model_2, train_set, test_set, 1e-4, print_result=True)

Train L.R. Model 2
Epoch 0
Train Loss 0.6919
Valid Loss 0.6997
Test Bin Acc 0.4038
Epoch 1
Train Loss 0.6918
Valid Loss 0.7002
Test Bin Acc 0.4038
Epoch 2
Train Loss 0.6914
Valid Loss 0.6999
Test Bin Acc 0.4038
Epoch 3
Train Loss 0.6875
Valid Loss 0.6895
Test Bin Acc 0.4111
Epoch 4
Train Loss 0.6581
Valid Loss 0.6201
Test Bin Acc 0.7216
Epoch 5
Train Loss 0.5671
Valid Loss 0.5690
Test Bin Acc 0.7230
Epoch 6
Train Loss 0.5069
Valid Loss 0.5637
Test Bin Acc 0.7289
Epoch 7
Train Loss 0.4804
Valid Loss 0.5623
Test Bin Acc 0.7245
Epoch 8
Train Loss 0.4575
Valid Loss 0.6101
Test Bin Acc 0.7143
Epoch 9
Train Loss 0.4390
Valid Loss 0.5646
Test Bin Acc 0.7332
Epoch    11: reducing learning rate of group 0 to 1.0000e-05.
Epoch 10
Train Loss 0.4204
Valid Loss 0.5972
Test Bin Acc 0.7187
Epoch 11
Train Loss 0.3949
Valid Loss 0.5617
Test Bin Acc 0.7347
Epoch 12
Train Loss 0.3926
Valid Loss 0.5789
Test Bin Acc 0.7303
Epoch 13
Train Loss 0.3884
Valid Loss 0.5734
Test Bin Acc 0.7318
Epoch    15: reduci

In [5]:
print('Full Model : Acc. {:.4f}     Total Params. {:10}     Fusion Params. {:10}'.format(
    full_max_accuracy, full_max_model.count_parameters(), full_max_model.count_fusion_parameters()))

print('========================================================================================')

print('L.R. Model : Acc. {:.4f}     Total Params. {:10}     Fusion Params. {:10}'.format(
    lr_max_accuracy, lr_max_model.count_parameters(), lr_max_model.count_fusion_parameters()))

print('Rank {}  Total Comp. Ratio {:5}   Fusion Comp. Ratio {:5}'.format(
    lr_max_model.fusion_layer.rank,
    full_max_model.count_parameters() / lr_max_model.count_parameters(), 
    full_max_model.count_fusion_parameters() / lr_max_model.count_fusion_parameters()))

print('========================================================================================')

print('L.R. Model 2: Acc. {:.4f}     Total Params. {:10}     Fusion Params. {:10}'.format(
    lr_max_accuracy_2, lr_max_model_2.count_parameters(), lr_max_model_2.count_fusion_parameters()))

print('Rank {}  Total Comp. Ratio {:5}   Fusion Comp. Ratio {:5}'.format(
    lr_max_model_2.fusion_layer.rank,
    full_max_model.count_parameters() / lr_max_model_2.count_parameters(), 
    full_max_model.count_fusion_parameters() / lr_max_model_2.count_fusion_parameters()))

print('========================================================================================')

print('Comp. Model: Acc. {:.4f}     Total Params. {:10}     Fusion Params. {:10}'.format(
    comp_max_accuracy, comp_max_model.count_parameters(), comp_max_model.count_fusion_parameters()))

print('Rank {}  Total Comp. Ratio {:5}   Fusion Comp. Ratio {:5}'.format(
    comp_max_model.fusion_layer.max_rank,
    full_max_model.count_parameters() / comp_max_model.count_parameters(), 
    full_max_model.count_fusion_parameters() / comp_max_model.count_fusion_parameters()))

Full Model : Acc. 0.7274     Total Params.   18256609     Fusion Params.   17981696
L.R. Model : Acc. 0.7464     Total Params.     307341     Fusion Params.      32428
Rank 100  Total Comp. Ratio 59.401801256584704   Fusion Comp. Ratio 554.5114098926854
L.R. Model 2: Acc. 0.7347     Total Params.     300558     Fusion Params.      25645
Rank 79  Total Comp. Ratio 60.74238250187984   Fusion Comp. Ratio 701.1774614934685
Comp. Model: Acc. 0.7522     Total Params.     300637     Fusion Params.      25724
Rank 79  Total Comp. Ratio 60.726420899623136   Fusion Comp. Ratio 699.0241020059088


In [6]:
import numpy as np
full_train_time_per_epoch = np.mean(full_result['train_time'])
lr_train_time_per_epoch = np.mean(lr_result['train_time'])
lr_train_time_per_epoch_2 = np.mean(lr_result_2['train_time'])
comp_train_time_per_epoch = np.mean(comp_result['train_time'])
print("Full Model Train Time Per Epoch: {:.2f}".format(full_train_time_per_epoch))
print("L.R. Model Train Time Per Epoch: {:.2f}".format(lr_train_time_per_epoch))
print("L.R. Model 2 Train Time Per Epoch: {:.2f}".format(lr_train_time_per_epoch_2))
print("Comp. Model Train Time Per Epoch: {:.2f}".format(comp_train_time_per_epoch))

Full Model Train Time Per Epoch: 5.31
L.R. Model Train Time Per Epoch: 1.31
L.R. Model 2 Train Time Per Epoch: 1.31
Comp. Model Train Time Per Epoch: 1.46


In [7]:
import time
text, audio, vision, label = test_set[:2]
tic = time.time()
output = full_max_model(text, audio, vision)
toc = time.time()
print("Full Model Inference Time: {:.4f}".format(toc - tic))
tic = time.time()
output = lr_max_model(text, audio, vision)
toc = time.time()
print("L.R. Model Inference Time: {:.4f}".format(toc - tic))
tic = time.time()
output = lr_max_model_2(text, audio, vision)
toc = time.time()
print("L.R Model 2 Inference Time: {:.4f}".format(toc - tic))
tic = time.time()
output = comp_max_model(text, audio, vision)
toc = time.time()
print("Comp. Model Inference Time: {:.4f}".format(toc - tic))

Full Model Inference Time: 0.0292
L.R. Model Inference Time: 0.0050
L.R Model 2 Inference Time: 0.0044
Comp. Model Inference Time: 0.0045


In [18]:
ind = full_result['test_accuracy'].index(full_max_accuracy)
print(full_result['test_f1'][ind])

ind = lr_result['test_accuracy'].index(lr_max_accuracy)
print(lr_result['test_f1'][ind])

ind = lr_result_2['test_accuracy'].index(lr_max_accuracy_2)
print(lr_result_2['test_f1'][ind])

ind = comp_result['test_accuracy'].index(comp_max_accuracy)
print(comp_result['test_f1'][ind])

0.557919621749409
0.6448979591836734
0.6591760299625469
0.6473029045643154
