
Commit

added support for Nash-MTL (#7)
Baijiong-Lin committed Jul 22, 2022
1 parent d70664b commit 04cf41a
Showing 4 changed files with 159 additions and 11 deletions.
13 changes: 12 additions & 1 deletion LibMTL/config.py
@@ -42,6 +42,10 @@
## CAGrad
_parser.add_argument('--calpha', type=float, default=0.5, help='calpha for CAGrad')
_parser.add_argument('--rescale', type=int, default=1, help='rescale for CAGrad')
## Nash_MTL
_parser.add_argument('--update_weights_every', type=int, default=1, help='update_weights_every for Nash_MTL')
_parser.add_argument('--optim_niter', type=int, default=20, help='optim_niter for Nash_MTL')
_parser.add_argument('--max_norm', type=float, default=1.0, help='max_norm for Nash_MTL')

# args for architecture
## CGC
@@ -62,7 +66,7 @@ def prepare_args(params):
    """
    kwargs = {'weight_args': {}, 'arch_args': {}}
    if params.weighting in ['EW', 'UW', 'GradNorm', 'GLS', 'RLW', 'MGDA', 'IMTL',
                            'PCGrad', 'GradVac', 'CAGrad', 'GradDrop', 'DWA', 'DIY']:
                            'PCGrad', 'GradVac', 'CAGrad', 'GradDrop', 'DWA', 'Nash_MTL']:
        if params.weighting in ['DWA']:
            if params.T is not None:
                kwargs['weight_args']['T'] = params.T
@@ -97,6 +101,13 @@ def prepare_args(params):
                kwargs['weight_args']['rescale'] = params.rescale
            else:
                raise ValueError('CAGrad needs keyword calpha and rescale')
        elif params.weighting in ['Nash_MTL']:
            if params.update_weights_every is not None and params.optim_niter is not None and params.max_norm is not None:
                kwargs['weight_args']['update_weights_every'] = params.update_weights_every
                kwargs['weight_args']['optim_niter'] = params.optim_niter
                kwargs['weight_args']['max_norm'] = params.max_norm
            else:
                raise ValueError('Nash_MTL needs update_weights_every, optim_niter, and max_norm')
    else:
        raise ValueError('Unsupported weighting method {}'.format(params.weighting))
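
The three new flags map one-to-one onto ``kwargs['weight_args']`` and are consumed by ``Nash_MTL.backward`` in the new file below. A minimal sketch of passing them through ``prepare_args``, assuming the parser is exposed as ``LibMTL_args`` and the strategy is selected with ``--weighting`` (both assumed from the library's entry scripts, not shown in this diff):

# Sketch only: LibMTL_args and the --weighting flag are assumed, not part of this diff.
from LibMTL.config import LibMTL_args, prepare_args

params = LibMTL_args.parse_args(['--weighting', 'Nash_MTL',
                                 '--update_weights_every', '1',
                                 '--optim_niter', '20',
                                 '--max_norm', '1.0'])
kwargs = prepare_args(params)
# kwargs['weight_args'] -> {'update_weights_every': 1, 'optim_niter': 20, 'max_norm': 1.0}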

138 changes: 138 additions & 0 deletions LibMTL/weighting/Nash_MTL.py
@@ -0,0 +1,138 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from LibMTL.weighting.abstract_weighting import AbsWeighting

try:
    import cvxpy as cp
except ModuleNotFoundError:
    # cvxpy is needed to solve the per-step bargaining problem; install it on demand.
    from pip._internal import main as pip
    pip(['install', '--user', 'cvxpy'])
    import cvxpy as cp

class Nash_MTL(AbsWeighting):
    r"""Nash-MTL.

    This method is proposed in `Multi-Task Learning as a Bargaining Game (ICML 2022) <https://proceedings.mlr.press/v162/navon22a/navon22a.pdf>`_ \
    and implemented by modifying from the `official PyTorch implementation <https://github.com/AvivNavon/nash-mtl>`_.

    Args:
        update_weights_every (int, default=1): Period of weights update.
        optim_niter (int, default=20): The max iteration of optimization solver.
        max_norm (float, default=1.0): The max norm of the gradients.

    .. warning::
            Nash_MTL does not support representation gradients, i.e., ``rep_grad`` must be ``False``.
    """
    def __init__(self):
        super(Nash_MTL, self).__init__()

    def init_param(self):
        self.step = 0
        self.prvs_alpha_param = None
        self.init_gtg = np.eye(self.task_num)
        self.prvs_alpha = np.ones(self.task_num, dtype=np.float32)
        self.normalization_factor = np.ones((1,))

    def _stop_criteria(self, gtg, alpha_t):
        return (
            (self.alpha_param.value is None)
            or (np.linalg.norm(gtg @ alpha_t - 1 / (alpha_t + 1e-10)) < 1e-3)
            or (
                np.linalg.norm(self.alpha_param.value - self.prvs_alpha_param.value)
                < 1e-6
            )
        )

    def solve_optimization(self, gtg: np.ndarray):
        self.G_param.value = gtg
        self.normalization_factor_param.value = self.normalization_factor

        alpha_t = self.prvs_alpha
        for _ in range(self.optim_niter):
            self.alpha_param.value = alpha_t
            self.prvs_alpha_param.value = alpha_t

            try:
                self.prob.solve(solver=cp.ECOS, warm_start=True, max_iters=100)
            except Exception:
                # fall back to the previous solution if the solver fails
                self.alpha_param.value = self.prvs_alpha_param.value

            if self._stop_criteria(gtg, alpha_t):
                break

            alpha_t = self.alpha_param.value

        if alpha_t is not None:
            self.prvs_alpha = alpha_t

        return self.prvs_alpha

    def _calc_phi_alpha_linearization(self):
        G_prvs_alpha = self.G_param @ self.prvs_alpha_param
        prvs_phi_tag = 1 / self.prvs_alpha_param + (1 / G_prvs_alpha) @ self.G_param
        phi_alpha = prvs_phi_tag @ (self.alpha_param - self.prvs_alpha_param)
        return phi_alpha

    def _init_optim_problem(self):
        self.alpha_param = cp.Variable(shape=(self.task_num,), nonneg=True)
        self.prvs_alpha_param = cp.Parameter(
            shape=(self.task_num,), value=self.prvs_alpha
        )
        self.G_param = cp.Parameter(
            shape=(self.task_num, self.task_num), value=self.init_gtg
        )
        self.normalization_factor_param = cp.Parameter(
            shape=(1,), value=np.array([1.0])
        )

        self.phi_alpha = self._calc_phi_alpha_linearization()

        G_alpha = self.G_param @ self.alpha_param
        constraint = []
        for i in range(self.task_num):
            constraint.append(
                -cp.log(self.alpha_param[i] * self.normalization_factor_param)
                - cp.log(G_alpha[i])
                <= 0
            )
        obj = cp.Minimize(
            cp.sum(G_alpha) + self.phi_alpha / self.normalization_factor_param
        )
        self.prob = cp.Problem(obj, constraint)

    def backward(self, losses, **kwargs):
        self.update_weights_every = kwargs['update_weights_every']
        self.optim_niter = kwargs['optim_niter']
        self.max_norm = kwargs['max_norm']

        if self.step == 0:
            self._init_optim_problem()
        if (self.step % self.update_weights_every) == 0:
            self.step += 1

            if self.rep_grad:
                raise ValueError('Nash_MTL does not support representation gradients (rep_grad=True)')
            else:
                self._compute_grad_dim()
                grads = self._compute_grad(losses, mode='autograd')

            # Gram matrix of the per-task gradients, normalized by its Frobenius norm.
            GTG = torch.mm(grads, grads.t())
            self.normalization_factor = torch.norm(GTG).detach().cpu().numpy().reshape((1,))
            GTG = GTG / self.normalization_factor.item()
            alpha = self.solve_optimization(GTG.cpu().detach().numpy())
            alpha = torch.from_numpy(alpha).to(torch.float32).to(self.device)
        else:
            self.step += 1
            alpha = self.prvs_alpha

        torch.sum(alpha*losses).backward()

        if self.max_norm > 0:
            torch.nn.utils.clip_grad_norm_(self.get_share_params(), self.max_norm)

        return alpha.detach().cpu().numpy()
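
For intuition about what ``solve_optimization`` computes, here is a minimal standalone sketch of one solver round for two toy tasks (toy gradients, a fixed linearization point, and the same ECOS call as above); it illustrates the bargaining subproblem only and is not code from this commit or from LibMTL's API:

import numpy as np
import cvxpy as cp

# Toy per-task gradients of the shared parameters (2 tasks, 3 shared weights).
grads = np.array([[1.0, 0.0, 0.5],
                  [0.2, 1.0, -0.3]])
gtg = grads @ grads.T                              # Gram matrix G^T G
norm_factor = np.linalg.norm(gtg)
gtg = gtg / norm_factor                            # same normalization as in backward()

prvs_alpha = np.ones(2)                            # linearization point (previous alpha)
alpha = cp.Variable(2, nonneg=True)
G_alpha = gtg @ alpha

# First-order surrogate of sum_i log(alpha_i * (G alpha)_i) around prvs_alpha,
# mirroring _calc_phi_alpha_linearization().
G_prvs = gtg @ prvs_alpha
prvs_phi_tag = 1.0 / prvs_alpha + (1.0 / G_prvs) @ gtg
phi_alpha = prvs_phi_tag @ (alpha - prvs_alpha)

constraints = [-cp.log(alpha[i] * norm_factor) - cp.log(G_alpha[i]) <= 0 for i in range(2)]
prob = cp.Problem(cp.Minimize(cp.sum(G_alpha) + phi_alpha / norm_factor), constraints)
prob.solve(solver=cp.ECOS, warm_start=True, max_iters=100)

print(alpha.value)   # task weights; the training loss is then sum_i alpha_i * loss_i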
6 changes: 2 additions & 4 deletions LibMTL/weighting/__init__.py
@@ -9,9 +9,8 @@
from LibMTL.weighting.PCGrad import PCGrad
from LibMTL.weighting.GradVac import GradVac
from LibMTL.weighting.IMTL import IMTL
# from LibMTL.weighting.MOML import MOML
from LibMTL.weighting.CAGrad import CAGrad
# from LibMTL.weighting.RotoGrad import RotoGrad
from LibMTL.weighting.Nash_MTL import Nash_MTL
from LibMTL.weighting.RLW import RLW

__all__ = ['AbsWeighting',
@@ -25,7 +24,6 @@
           'PCGrad',
           'GradVac',
           'IMTL',
           # 'MOML',
           'CAGrad',
           # 'RotoGrad',
           'Nash_MTL',
           'RLW']
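
With the import and ``__all__`` entry above, the new strategy is exposed at the package level, so it can be selected by name via ``--weighting Nash_MTL`` or, if needed, imported directly (a usage sketch, not part of the diff):

from LibMTL.weighting import Nash_MTL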
13 changes: 7 additions & 6 deletions README.md
@@ -7,7 +7,8 @@

## News

- **[Jul 21 2022]**: Added support for [Learning to Branch](http://proceedings.mlr.press/v119/guo20e/guo20e.pdf). Many thanks to [@yuezhixiong](https://github.com/yuezhixiong) ([#14](https://github.com/median-research-group/LibMTL/pull/14)).
- **[Jul 22 2022]**: Added support for [Nash-MTL](https://proceedings.mlr.press/v162/navon22a/navon22a.pdf) (ICML 2022).
- **[Jul 21 2022]**: Added support for [Learning to Branch](http://proceedings.mlr.press/v119/guo20e/guo20e.pdf) (ICML 2020). Many thanks to [@yuezhixiong](https://github.com/yuezhixiong) ([#14](https://github.com/median-research-group/LibMTL/pull/14)).
- **[Mar 29 2022]**: Paper is now available on the [arXiv](https://arxiv.org/abs/2203.14338).

## Table of Content
@@ -28,7 +29,7 @@
## Features

- **Unified**: ``LibMTL`` provides a unified code base to implement MTL algorithms and a consistent evaluation procedure (covering data processing, metric objectives, and hyper-parameters) on several representative MTL benchmark datasets, which allows quantitative, fair, and consistent comparisons between different MTL algorithms.
- **Comprehensive**: ``LibMTL`` supports 96 MTL models obtained by combining 8 architectures and 12 loss weighting strategies. Meanwhile, ``LibMTL`` provides a fair comparison on 3 computer vision datasets.
- **Comprehensive**: ``LibMTL`` supports 104 MTL models obtained by combining 8 architectures and 13 loss weighting strategies. Meanwhile, ``LibMTL`` provides a fair comparison on 3 computer vision datasets.
- **Extensible**: ``LibMTL`` follows modular design principles, which allows users to flexibly and conveniently add customized components or make personalized modifications. Therefore, users can easily and quickly develop novel loss weighting strategies and architectures or apply the existing MTL algorithms to new application scenarios with the support of ``LibMTL``.

## Overall Framework
@@ -41,7 +42,7 @@ Each module is introduced in [Docs](https://libmtl.readthedocs.io/en/latest/docs

``LibMTL`` currently supports the following algorithms:

- 12 loss weighting strategies.
- 13 loss weighting strategies.

| Weighting Strategy | Venues | Comments |
| ------------------------------------------------------------ | ------------------- | ------------------------------------------------------------ |
@@ -56,6 +57,7 @@ Each module is introduced in [Docs](https://libmtl.readthedocs.io/en/latest/docs
| Impartial Multi-Task Learning ([IMTL](https://openreview.net/forum?id=IMPnRXEWpvr)) | ICLR 2021 | Implemented by us |
| Gradient Vaccine ([GradVac](https://openreview.net/forum?id=F1vEjWK-lH_)) | ICLR 2021 Spotlight | Implemented by us |
| Conflict-Averse Gradient descent ([CAGrad](https://openreview.net/forum?id=_61Qh8tULj_)) | NeurIPS 2021 | Referenced from [official PyTorch implementation](https://github.com/Cranial-XIX/CAGrad) |
| [Nash-MTL](https://proceedings.mlr.press/v162/navon22a/navon22a.pdf) | ICML 2022 | Referenced from [official PyTorch implementation](https://github.com/AvivNavon/nash-mtl) |
| Random Loss Weighting ([RLW](https://arxiv.org/abs/2111.10603)) | arXiv | Implemented by us |

- 8 architectures.
@@ -66,8 +68,7 @@ Each module is introduced in [Docs](https://libmtl.readthedocs.io/en/latest/docs
| Cross-stitch Networks ([Cross_stitch](https://openaccess.thecvf.com/content_cvpr_2016/papers/Misra_Cross-Stitch_Networks_for_CVPR_2016_paper.pdf)) | CVPR 2016 | Implemented by us |
| Multi-gate Mixture-of-Experts ([MMoE](https://dl.acm.org/doi/10.1145/3219819.3220007)) | KDD 2018 | Implemented by us |
| Multi-Task Attention Network ([MTAN](https://openaccess.thecvf.com/content_CVPR_2019/papers/Liu_End-To-End_Multi-Task_Learning_With_Attention_CVPR_2019_paper.pdf)) | CVPR 2019 | Referenced from [official PyTorch implementation](https://github.com/lorenmt/mtan) |
| Customized Gate Control ([CGC](https://dl.acm.org/doi/10.1145/3383313.3412236)) | ACM RecSys 2020 Best Paper | Implemented by us |
| Progressive Layered Extraction ([PLE](https://dl.acm.org/doi/10.1145/3383313.3412236)) | ACM RecSys 2020 Best Paper | Implemented by us |
| Customized Gate Control ([CGC](https://dl.acm.org/doi/10.1145/3383313.3412236)), Progressive Layered Extraction ([PLE](https://dl.acm.org/doi/10.1145/3383313.3412236)) | ACM RecSys 2020 Best Paper | Implemented by us |
| Learning to Branch ([LTB](http://proceedings.mlr.press/v119/guo20e/guo20e.pdf)) | ICML 2020 | Implemented by us |
| [DSelect-k](https://proceedings.neurips.cc/paper/2021/hash/f5ac21cd0ef1b88e9848571aeb53551a-Abstract.html) | NeurIPS 2021 | Referenced from [official TensorFlow implementation](https://github.com/google-research/google-research/tree/master/dselect_k_moe) |

Expand Down Expand Up @@ -132,7 +133,7 @@ If you have any question or suggestion, please feel free to contact us by [raisi

## Acknowledgements

We would like to thank the authors who released the public repositories (listed alphabetically): [CAGrad](https://github.com/Cranial-XIX/CAGrad), [dselect_k_moe](https://github.com/google-research/google-research/tree/master/dselect_k_moe), [MultiObjectiveOptimization](https://github.com/isl-org/MultiObjectiveOptimization), and [mtan](https://github.com/lorenmt/mtan).
We would like to thank the authors who released the public repositories (listed alphabetically): [CAGrad](https://github.com/Cranial-XIX/CAGrad), [dselect_k_moe](https://github.com/google-research/google-research/tree/master/dselect_k_moe), [MultiObjectiveOptimization](https://github.com/isl-org/MultiObjectiveOptimization), [mtan](https://github.com/lorenmt/mtan), and [nash-mtl](https://github.com/AvivNavon/nash-mtl).

## License


