EmbraceNet: A robust deep learning architecture for multimodal classification

Introduction

EmbraceNet is a multimodal integration architecture for deep learning models. It is compatible with any network structure, considers correlations between different modalities in depth, and handles missing data seamlessly. This repository contains the official PyTorch- and TensorFlow-based implementations of the EmbraceNet model, which is described in the following paper.

  • J.-H. Choi, J.-S. Lee. EmbraceNet: A robust deep learning architecture for multimodal classification. Information Fusion, vol. 51, pp. 259-270, Nov. 2019 [Paper] [arXiv]
@article{choi2019embracenet,
  title={EmbraceNet: A robust deep learning architecture for multimodal classification},
  author={Choi, Jun-Ho and Lee, Jong-Seok},
  journal={Information Fusion},
  volume={51},
  pages={259--270},
  year={2019},
  publisher={Elsevier}
}
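
As a rough illustration of the missing-data handling mentioned above, the sketch below follows the fusion step as the paper describes it: each modality is first mapped ("docked") to a vector of the same size, and then, for every feature index, one available modality is sampled to supply the value. This is a simplified NumPy sketch, not the repository's implementation. The function name `embrace`, the random docking weights, and the uniform default probabilities are assumptions made for illustration.

```python
import numpy as np


def embrace(docked, available, rng, probs=None):
    """Fuse docked modality vectors by sampling one modality per feature index.

    docked:    array of shape (num_modalities, embracement_size)
    available: boolean flags of shape (num_modalities,)
    probs:     optional selection probabilities of shape (num_modalities,)
    """
    docked = np.asarray(docked, dtype=float)
    available = np.asarray(available, dtype=bool)
    num_modalities, size = docked.shape
    if not available.any():
        raise ValueError("at least one modality must be available")

    probs = np.ones(num_modalities) if probs is None else np.asarray(probs, dtype=float)
    # Missing modalities get zero probability; the rest are renormalized.
    probs = np.where(available, probs, 0.0)
    probs = probs / probs.sum()

    # For every feature index, pick the modality that supplies the value.
    chosen = rng.choice(num_modalities, size=size, p=probs)
    return docked[chosen, np.arange(size)], chosen


rng = np.random.default_rng(0)
input_sizes = [512, 128]
embracement_size = 256

# Hypothetical docking layers: one random linear map followed by ReLU per modality.
weights = [rng.normal(scale=0.05, size=(n, embracement_size)) for n in input_sizes]
inputs = [rng.normal(size=n) for n in input_sizes]
docked = np.stack([np.maximum(x @ w, 0.0) for x, w in zip(inputs, weights)])

# Pretend the second modality is missing: only the first one is ever selected.
fused, chosen = embrace(docked, available=[True, False], rng=rng)
```

Because a missing modality simply receives zero selection probability, the fused vector keeps the same size regardless of which inputs are present, so the downstream network does not need to change.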

Dependencies

PyTorch-based

  • Python 3.7+
  • PyTorch 1.5+

TensorFlow-based (2.x)

  • Python 3.8+
  • TensorFlow 2.4+

TensorFlow-based (1.x)

  • Python 3.6+
  • TensorFlow 1.8+ (<2.0)

Getting started

The implementations of the EmbraceNet model are in the embracenet_pytorch/, embracenet_tf2/, and embracenet_tf1/ folders. Copy the appropriate folder for your framework to your code base and import it.

# for PyTorch-based
from embracenet_pytorch import EmbraceNet

# for TensorFlow-based (2.x)
from embracenet_tf2 import EmbraceNet

# for TensorFlow-based (1.x)
from embracenet_tf1 import EmbraceNet

The snippets below show how to use EmbraceNet with each framework.

PyTorch-based

import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Build a pre-processing network for each modality.
# Assume that there are two pre-processed modalities (modality1, modality2) having sizes of 512 and 128.

# Create an EmbraceNet object.
embracenet = EmbraceNet(device=device, input_size_list=[512, 128], embracement_size=256)

# Feed the output of the pre-processing network to EmbraceNet at the "forward" function of your module.
embraced_output = embracenet(input_list=[modality1, modality2])

# Feed embraced_output to a post-processing network.

Please refer to the comments in embracenet_pytorch/embracenet.py for more information.
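
To show where the steps above sit inside a module, here is a minimal sketch of a two-modality classifier. So that it runs without the repository, it uses a hypothetical `PlaceholderFusion` class that mimics the constructor and call style shown above but is NOT the EmbraceNet algorithm (it only docks and averages). The class names, the 784-dimensional inputs, and the 10 output classes are assumptions for illustration. In real code, pass `EmbraceNet` as `fusion_cls` instead.

```python
import torch
from torch import nn


class PlaceholderFusion(nn.Module):
    """Stand-in with the same call style as EmbraceNet (not the EmbraceNet algorithm)."""

    def __init__(self, device, input_size_list, embracement_size):
        super().__init__()
        self.docking = nn.ModuleList(
            [nn.Linear(n, embracement_size) for n in input_size_list]
        )
        self.to(device)

    def forward(self, input_list):
        # Map every modality to the common size, then average them.
        docked = [torch.relu(layer(x)) for layer, x in zip(self.docking, input_list)]
        return torch.stack(docked, dim=0).mean(dim=0)


class MultimodalClassifier(nn.Module):
    def __init__(self, device, fusion_cls=PlaceholderFusion, num_classes=10):
        super().__init__()
        # Pre-processing networks: one per modality, producing 512 and 128 features.
        self.pre1 = nn.Sequential(nn.Linear(784, 512), nn.ReLU())
        self.pre2 = nn.Sequential(nn.Linear(784, 128), nn.ReLU())
        # Fusion layer, created with the same arguments as in the snippet above.
        self.fusion = fusion_cls(
            device=device, input_size_list=[512, 128], embracement_size=256
        )
        # Post-processing network that consumes the fused representation.
        self.post = nn.Linear(256, num_classes)
        self.to(device)

    def forward(self, x1, x2):
        modality1 = self.pre1(x1)
        modality2 = self.pre2(x2)
        embraced_output = self.fusion(input_list=[modality1, modality2])
        return self.post(embraced_output)


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = MultimodalClassifier(device)
x1 = torch.randn(4, 784, device=device)
x2 = torch.randn(4, 784, device=device)
logits = model(x1, x2)  # one row of class scores per sample in the batch
```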

TensorFlow-based (2.x)

# Build a pre-processing network for each modality.
# Assume that there are two pre-processed modalities (modality1, modality2) having sizes of 512 and 128.

# Create an EmbraceNet object.
embracenet = EmbraceNet(input_size_list=[512, 128], embracement_size=256)

# Feed the output of the pre-processing network to EmbraceNet in the "call" function of your Keras model or layer.
embraced_output = embracenet(input_list=[modality1, modality2])

# Feed embraced_output to a post-processing network.

Please refer to the comments in embracenet_tf2/embracenet.py for more information.

TensorFlow-based (1.x)

# Create an EmbraceNet object.
embracenet = EmbraceNet(batch_size=16, embracement_size=256)

# Build a pre-processing network for each modality.
# Then, feed the output of the pre-processing network to EmbraceNet.
embracenet.add_modality(input_data=modality1, input_size=512)
embracenet.add_modality(input_data=modality2, input_size=128)

# Integrate the modality data.
embraced_output = embracenet.embrace()

# Build a post-processing network that takes embraced_output as input.

Please refer to the comments in embracenet_tf1/embracenet.py for more information.

Examples

Example code that uses EmbraceNet to build Fashion MNIST classifiers is included in the examples/fashion_mnist_pytorch/, examples/fashion_mnist_tf2/, and examples/fashion_mnist_tf1/ folders.
