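This repository provides the code for training and testing the Multi-Gated Mixture of Experts with Exclusivity (MMoEEx) model for multi-task learning, as described in the paper "Heterogeneous Multi-task Learning with Expert Diversity" (Aoki, Tung, and Oliveira, BIOKDD 2021).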
- To run the code in a terminal: create a virtual environment using the provided requirements, then run:
  python main.py config/config_file.yaml
- Most important parameters in the config files:
  a. tasks: defines the tasks the model will use; it depends on the data being used.
  b. models: MMoE (no task exclusivity/exclusion, no MAML), Md (task exclusivity/exclusion, no MAML), or MMoEEx (task exclusivity/exclusion + MAML).
  c. save_tensor: boolean.
  d. seqlen, prop, and lambda: required only if data is mimic. seqlen is the maximum size of the sequences; prop is the proportion of the dataset being used; lambda is the loss weight.
We also added four examples of config files in our repo.
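To illustrate how the parameters above relate to each other, here is a minimal sketch of a consistency check on a parsed config. It is not part of the repository: the helper name check_config, the exact key spellings, the task identifiers, and all example values are assumptions made for illustration, and the real config files may be structured differently.

```python
# Hypothetical helper, not taken from the repository.
MIMIC_ONLY_KEYS = ("seqlen", "prop", "lambda")
KNOWN_MODELS = ("MMoE", "Md", "MMoEEx")


def check_config(config):
    """Return a list of problems found in a parsed config dictionary."""
    problems = []
    for key in ("data", "tasks", "models", "save_tensor"):
        if key not in config:
            problems.append(f"missing key: {key}")
    if "models" in config and config["models"] not in KNOWN_MODELS:
        problems.append(f"unknown model: {config['models']}")
    if "save_tensor" in config and not isinstance(config["save_tensor"], bool):
        problems.append("save_tensor must be a boolean")
    # seqlen, prop and lambda are required only when the data is mimic.
    if config.get("data") == "mimic":
        for key in MIMIC_ONLY_KEYS:
            if key not in config:
                problems.append(f"missing key for mimic: {key}")
    return problems


example = {
    "data": "mimic",
    "tasks": ["ihm", "decomp", "los", "pheno"],  # assumed task identifiers
    "models": "MMoEEx",
    "save_tensor": False,
    "seqlen": 48,   # assumed value
    "prop": 1.0,    # assumed value
    "lambda": 1.0,  # assumed value
}
print(check_config(example))  # prints []
```

A check like this would catch a mimic config that omits seqlen, prop, or lambda before any training starts.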
- main.py: calls the data loaders and handles the training phase (loss functions, optimization);
- data_preprocessing.py: reads the data, performs all the preprocessing, and creates and outputs the data loaders;
- mmoeex.py: contains the torch.nn.Module classes for the models, implementing MMoE and MMoEEx;
- utils.py: support functions for main.py (organizing outputs, calculating AUC values, gradients, etc.).
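Since utils.py is described as calculating AUC values, the following sketch shows what that metric measures for one binary task. It is a dependency-free illustration, not the repository's implementation (which may well rely on a library routine); the function name binary_auc is hypothetical.

```python
def binary_auc(labels, scores):
    """Probability that a random positive is scored above a random negative.

    Ties between a positive and a negative score count as half a win.
    """
    positives = [s for y, s in zip(labels, scores) if y == 1]
    negatives = [s for y, s in zip(labels, scores) if y == 0]
    if not positives or not negatives:
        raise ValueError("AUC needs at least one positive and one negative label")
    wins = 0.0
    for p in positives:
        for n in negatives:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(positives) * len(negatives))


print(binary_auc([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8]))  # prints 0.75
```

The pairwise loop is quadratic in the number of examples, so it is meant for understanding the metric rather than for evaluating large test sets.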
Short description of the datasets
Census
3 tasks:
- income (above 50,000)
- marital status (is married)
- education (at least undergraduate).
Set sizes: training (199,523), validation (49,881), test (49,881); 482 features.
This dataset is publicly available, so our code downloads it into the mtl_datasets/census_dataset folder. If you already have the data in this folder, the code only loads it.
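The download-or-load behaviour described above can be sketched as follows. This is an illustration under stated assumptions, not the repository's code: ensure_dataset, the fetch callable, and the file names train.csv and test.csv are all hypothetical, and a stand-in writer replaces the real download so the example runs offline.

```python
import tempfile
from pathlib import Path


def ensure_dataset(folder, filenames, fetch):
    """Download only the files that are missing from `folder`.

    `fetch(name, destination)` is a hypothetical callable that writes one file.
    Returns the names of the files that had to be downloaded.
    """
    folder = Path(folder)
    folder.mkdir(parents=True, exist_ok=True)
    downloaded = []
    for name in filenames:
        destination = folder / name
        if not destination.exists():
            fetch(name, destination)
            downloaded.append(name)
    return downloaded


def fake_fetch(name, destination):
    # Stand-in for a real download; writes a placeholder file.
    destination.write_text("placeholder for " + name)


with tempfile.TemporaryDirectory() as tmp:
    target = Path(tmp) / "mtl_datasets" / "census_dataset"
    first = ensure_dataset(target, ["train.csv", "test.csv"], fake_fetch)
    second = ensure_dataset(target, ["train.csv", "test.csv"], fake_fetch)

print(first, second)  # prints ['train.csv', 'test.csv'] []
```

The second call finds both files already present and downloads nothing, which mirrors the "only load the data" case.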
MIMIC-III
4 tasks:
- IHM (in-hospital mortality)
- decomp (decompensation, time-series)
- LOS (length-of-stay, time-series)
- pheno (phenotyping).
Set sizes: training (28,233), validation (6,152), test (6,056); time-series data with 76 features.
This dataset is not publicly available, so you need to submit a request to work with this data at https://mimic.mit.edu/iii/gettingstarted/. We followed the pre-processing steps available at https://github.com/YerevaNN/mimic3-benchmarks and saved the results in a folder named mimic_dataset.
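For this time-series dataset, the config parameters seqlen (maximum sequence size) and prop (proportion of the dataset used) could be applied roughly as in the sketch below. How the repository actually applies them is not stated here, so keeping the first seqlen time steps, sampling uniformly at random, and the function names truncate_sequences and subsample are all assumptions.

```python
import random


def truncate_sequences(sequences, seqlen):
    """Keep at most the first `seqlen` time steps of every sequence (assumed policy)."""
    return [seq[:seqlen] for seq in sequences]


def subsample(items, prop, seed=0):
    """Keep a random fraction `prop` (0 < prop <= 1) of the items, reproducibly."""
    if not 0 < prop <= 1:
        raise ValueError("prop must be in (0, 1]")
    # Keep at least one item, but never more than are available.
    count = min(len(items), max(1, round(len(items) * prop)))
    rng = random.Random(seed)
    return rng.sample(items, count)


stays = [[1, 2, 3, 4], [5, 6]]
print(truncate_sequences(stays, 3))            # prints [[1, 2, 3], [5, 6]]
print(len(subsample(list(range(10)), 0.5)))    # prints 5
```

Fixing the seed keeps the chosen subset identical across runs, which makes experiments with a reduced prop comparable.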
128 tasks: 439,863 molecules with 1,024 features. Each task is a biological target.
This dataset is publicly available. We provide it in our repo because its pre-processing takes more time than that of the Census dataset. The accompanying Jupyter Notebook shows where to download the data and how to pre-process it.
If you use MMoEEx-MTL, please consider citing:
@inproceedings{AokiBIOKDD21,
author = {Raquel Aoki and Frederick Tung and Gabriel L. Oliveira},
title = {{Heterogeneous Multi-task Learning with Expert Diversity}},
booktitle = {BIOKDD},
year = {2021},
}