Tensorflow implementation of three architectures for multi-task learning, a paradigm to learn different prediction tasks jointly using one model

clabrugere/multitask-learning

Multi-task learning

This repository contains implementations of three architectures for multi-task learning: shared bottom (a), mixture of experts (b) and multi-gate mixture of experts (c). Multi-task learning is a paradigm where one model learns several tasks jointly by sharing some of its parameters across tasks. Compared with training one model per task, it saves resources (compute time, memory), reduces the engineering complexity and the number of points of failure in prediction pipelines, and can even improve predictive performance when tasks are correlated and sharing information between them is beneficial. However, it can also suffer from negative transfer when tasks are too different or have conflicting objectives.

One industry application of this paradigm is modeling the conversion funnel in advertising. The CTR (click-through rate) and CVR (conversion rate) tasks usually take the same inputs, namely user and item characteristics, but their feedback signals and sample spaces differ. Special properties of the funnel can be encoded directly into the architecture and the loss to improve the overall performance of the system:

  • encode the causal nature of the funnel: a conversion can only happen after a click, hence $P(click) \geq P(conversion)$. A penalty term can be added to the loss function to enforce this constraint.
  • encode the difference in sample space: we are usually interested in predicting a conversion after a click, but clicked impressions are only a subset of all impressions. For a two-task problem modeling CTR and CVR, the loss can then be $\mathcal{L}_{ctr} + \mathcal{L}_{cvr}$, where $\mathcal{L}_{cvr} = -\left[ y_{ctr}y_{cvr} \log(\hat{y}_{ctr}\hat{y}_{cvr}) + (1 - y_{ctr}y_{cvr}) \log(1 - \hat{y}_{ctr}\hat{y}_{cvr}) \right]$. Since $P(click, conversion) = P(click) \cdot P(conversion \mid click)$, supervising the product $\hat{y}_{ctr}\hat{y}_{cvr}$ with the label $y_{ctr}y_{cvr}$ over all impressions lets $\hat{y}_{cvr}$ model the post-click conversion probability $P(conversion \mid click)$.
  • use different experts for the different entities represented in the data (such as users and items), as in two-tower architectures.
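The two loss terms from the first two points can be sketched in NumPy as below. This illustrates the formulas only and is not the repository's TensorFlow code: the function names `bce`, `funnel_loss` and `causal_penalty` are hypothetical, and the hinge form of the causal penalty is one possible choice among others.

```python
import numpy as np

EPS = 1e-7  # clip probabilities so that log() never receives 0


def bce(y, p):
    """Mean binary cross-entropy: -[y log p + (1 - y) log(1 - p)]."""
    p = np.clip(p, EPS, 1.0 - EPS)
    return float(np.mean(-(y * np.log(p) + (1.0 - y) * np.log(1.0 - p))))


def funnel_loss(y_ctr, y_cvr, p_ctr, p_cvr):
    """L_ctr + L_cvr, with L_cvr computed on the click-and-convert event.

    The CVR head is only supervised through the product p_ctr * p_cvr over all
    impressions, so the (undefined) CVR label of unclicked impressions is never used.
    """
    l_ctr = bce(y_ctr, p_ctr)
    l_cvr = bce(y_ctr * y_cvr, p_ctr * p_cvr)
    return l_ctr + l_cvr


def causal_penalty(p_click, p_conv, weight=1.0):
    """Hinge penalty on violations of P(click) >= P(conversion).

    Relevant when P(click) and P(conversion) come from two independent heads;
    with the product formulation above the constraint holds by construction.
    """
    return weight * float(np.mean(np.maximum(0.0, p_conv - p_click)))


y_ctr = np.array([1.0, 1.0, 0.0, 0.0])
y_cvr = np.array([1.0, 0.0, 0.0, 0.0])
p_ctr = np.array([0.9, 0.6, 0.2, 0.1])
p_cvr = np.array([0.7, 0.3, 0.5, 0.5])
print(funnel_loss(y_ctr, y_cvr, p_ctr, p_cvr))
# Only the second impression violates the constraint (0.7 > 0.6).
print(causal_penalty(p_ctr, np.array([0.5, 0.7, 0.1, 0.1])))
```

Because the CVR label only enters through $y_{ctr}y_{cvr}$, any value placed in `y_cvr` for unclicked impressions leaves the loss unchanged.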

Models

Architecture

  • Shared bottom is the simplest architecture. Inputs are projected to an embedding space and a single encoder learns a global representation shared by every task. This representation is the input of task-specific encoders (called towers in the literature) that output the final prediction for each task.
  • Mixture of experts uses an ensemble of encoders (called experts), shared by all tasks, that take a global embedding projection as input and learn different representations of the data. A single global gating mechanism computes a linear combination of the expert outputs, which every task-specific encoder takes as input and projects into its own task-specific space.
  • Multi-gate mixture of experts is very similar to the mixture of experts model, except that each task has its own gate. This routes information from the different expert representations to the tasks more effectively, as each gate can focus on capturing only the information relevant to its task.
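To make the difference between the three architectures concrete, here is a forward-pass sketch in NumPy with random, untrained weights. It is not the repository's TensorFlow implementation: the helper names, the single-layer experts and towers, and the sigmoid outputs are simplifying assumptions. The only change between mixture of experts and multi-gate mixture of experts is whether the tasks share one gate or each get their own.

```python
import numpy as np


def relu_layer(rng, in_dim, out_dim):
    """Randomly initialised ReLU layer (untrained stand-in for an encoder)."""
    w = rng.normal(scale=0.1, size=(in_dim, out_dim))
    return lambda x: np.maximum(0.0, x @ w)


def sigmoid_head(rng, in_dim):
    """Task tower reduced to a single sigmoid output for a binary task."""
    w = rng.normal(scale=0.1, size=in_dim)
    return lambda h: 1.0 / (1.0 + np.exp(-(h @ w)))


def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)  # numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def shared_bottom(x, num_tasks, hidden=16, seed=0):
    rng = np.random.default_rng(seed)
    h = relu_layer(rng, x.shape[1], hidden)(x)  # one representation shared by every task
    return [sigmoid_head(rng, hidden)(h) for _ in range(num_tasks)]


def mixture_of_experts(x, num_tasks, num_experts=4, hidden=16, multi_gate=False, seed=0):
    rng = np.random.default_rng(seed)
    experts = [relu_layer(rng, x.shape[1], hidden) for _ in range(num_experts)]
    expert_out = np.stack([expert(x) for expert in experts], axis=1)  # (batch, experts, hidden)

    # MoE: one global gate; MMoE: one gate per task. Gates are linear + softmax.
    num_gates = num_tasks if multi_gate else 1
    gates = [rng.normal(scale=0.1, size=(x.shape[1], num_experts)) for _ in range(num_gates)]

    preds = []
    for task in range(num_tasks):
        gate_w = gates[task if multi_gate else 0]
        weights = softmax(x @ gate_w)  # (batch, experts), each row sums to 1
        mixed = np.einsum("be,beh->bh", weights, expert_out)  # weighted sum of experts
        preds.append(sigmoid_head(rng, hidden)(mixed))  # task-specific tower
    return preds


x = np.random.default_rng(1).normal(size=(8, 10))  # 8 examples, 10 already-embedded features
print([p.shape for p in shared_bottom(x, num_tasks=2)])  # [(8,), (8,)]
print([p.shape for p in mixture_of_experts(x, num_tasks=2, multi_gate=True)])  # [(8,), (8,)]
```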

In this repository's implementation, gates are simple linear projections followed by a softmax activation. Temperature scaling could be added to the softmax to control how sharply a gate concentrates its weight on a few experts rather than spreading it across all of them. In addition, all task encoders share the same architecture for simplicity's sake, but this can easily be adapted to more complex applications. Finally, a learnt linear projection is applied to the continuous input features before they are concatenated with the learnt embeddings of the discrete features, in order to project them into the same latent space.
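A minimal NumPy sketch of the two mechanisms described above: a gate as a linear projection followed by a temperature-scaled softmax, and an input layer that concatenates embedding lookups for discrete features with a linear projection of continuous features. The function names, the per-feature embedding tables and the `temperature` parameter are assumptions for illustration (the text above presents temperature as a possible addition); in a real model the tables and projections would be learnt rather than random.

```python
import numpy as np


def softmax_gate(x, w, b, temperature=1.0):
    """Gate = linear projection + softmax over experts, with temperature scaling.

    temperature < 1 sharpens the weights toward a few experts;
    temperature > 1 flattens them toward a uniform mixture.
    """
    logits = (x @ w + b) / temperature
    logits = logits - logits.max(axis=-1, keepdims=True)  # numerical stability
    e = np.exp(logits)
    return e / e.sum(axis=-1, keepdims=True)


def embed_inputs(sparse_ids, dense, tables, dense_proj):
    """Build the model input from discrete and continuous features.

    sparse_ids: (batch, num_sparse) integer ids, one embedding table per column.
    dense:      (batch, num_dense) continuous features.
    dense_proj: (num_dense, emb_dim) projection into the embedding space.
    """
    embs = [tables[i][sparse_ids[:, i]] for i in range(sparse_ids.shape[1])]
    embs.append(dense @ dense_proj)  # continuous features mapped to the same latent size
    return np.concatenate(embs, axis=-1)


rng = np.random.default_rng(0)
x = rng.normal(size=(4, 6))
w, b = rng.normal(size=(6, 3)), np.zeros(3)
sharp = softmax_gate(x, w, b, temperature=0.1)
flat = softmax_gate(x, w, b, temperature=10.0)

tables = [rng.normal(size=(10, 4)), rng.normal(size=(5, 4))]  # vocab sizes 10 and 5, emb_dim 4
sparse_ids = np.array([[1, 2], [3, 4]])
dense = rng.normal(size=(2, 3))
features = embed_inputs(sparse_ids, dense, tables, rng.normal(size=(3, 4)))
print(features.shape)  # (2, 12)
```

Lowering the temperature can only increase the weight of each row's top expert, which is one way to read "controlling the collective influence of experts".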

A simple multi-task loss for multiple binary classification tasks is implemented in models/loss.py.
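As an illustration of what such a loss computes (not a copy of models/loss.py), a multi-task binary cross-entropy can be written as a weighted sum of per-task BCE terms. The NumPy sketch below assumes labels and predicted probabilities stacked as `(batch, num_tasks)` arrays; the function name and the `task_weights` argument are hypothetical and not necessarily part of `MultiTaskBCE`.

```python
import numpy as np


def multitask_bce(y_true, y_pred, task_weights=None, eps=1e-7):
    """Weighted sum over tasks of the mean binary cross-entropy of each task.

    y_true, y_pred: (batch, num_tasks) arrays; y_pred holds probabilities.
    task_weights:   optional (num_tasks,) weights, 1 for every task by default.
    """
    y_pred = np.clip(y_pred, eps, 1.0 - eps)  # avoid log(0)
    per_example = -(y_true * np.log(y_pred) + (1.0 - y_true) * np.log(1.0 - y_pred))
    per_task = per_example.mean(axis=0)  # (num_tasks,)
    if task_weights is None:
        task_weights = np.ones_like(per_task)
    return float(np.dot(task_weights, per_task))


y_true = np.array([[1.0, 0.0], [0.0, 1.0]])
y_pred = np.full((2, 2), 0.5)
print(multitask_bce(y_true, y_pred))  # 2 * log(2), about 1.386
```

Per-task weights are one simple lever for balancing tasks, for example down-weighting a noisy auxiliary task.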

Dependencies

This repository has the following dependencies:

  • Python 3.9+
  • TensorFlow 2.12+

Getting Started

git clone https://github.com/clabrugere/multitask-learning.git

Usage

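import tensorflow as tf

# MultiGateMixtureOfExperts and MultiTaskBCE come from this repository
# (the multi-task loss is in models/loss.py); import them before running this snippet.

# load your dataset
train_sparse_data = ...  # discrete features, mapped to learnt embeddings
train_dense_data = ...   # continuous features, linearly projected by the model
train_labels = ...       # binary labels for each task
test_sparse_data = ...
test_dense_data = ...

model = MultiGateMixtureOfExperts(
    num_tasks=num_tasks,
    num_emb=num_embeddings,
    ...
)

# train the model
loss = MultiTaskBCE(num_tasks=num_tasks)
optimizer = tf.keras.optimizers.Adam()

model.compile(optimizer=optimizer, loss=loss)
model.fit(
    x=[train_sparse_data, train_dense_data],
    y=train_labels,
    epochs=20,
)

# make predictions, passing inputs in the same [sparse, dense] order as during training
y_pred = model.predict([test_sparse_data, test_dense_data])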
