abstract_weighting.py
import torch, sys, random
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class AbsWeighting(nn.Module):
    r"""An abstract class for weighting strategies.
    """
    def __init__(self):
        super(AbsWeighting, self).__init__()

    def init_param(self):
        r"""Define and initialize some trainable parameters required by specific weighting methods.
        """
        pass

    def _compute_grad_dim(self):
        self.grad_index = []
        for param in self.get_share_params():
            self.grad_index.append(param.data.numel())
        self.grad_dim = sum(self.grad_index)

    def _grad2vec(self):
        grad = torch.zeros(self.grad_dim)
        count = 0
        for param in self.get_share_params():
            if param.grad is not None:
                beg = 0 if count == 0 else sum(self.grad_index[:count])
                end = sum(self.grad_index[:(count+1)])
                grad[beg:end] = param.grad.data.view(-1)
            count += 1
        return grad
    def _compute_grad(self, losses, mode, rep_grad=False):
        '''
        mode: backward, autograd
        '''
        if not rep_grad:
            grads = torch.zeros(self.task_num, self.grad_dim).to(self.device)
            for tn in range(self.task_num):
                if mode == 'backward':
                    losses[tn].backward(retain_graph=True) if (tn+1)!=self.task_num else losses[tn].backward()
                    grads[tn] = self._grad2vec()
                elif mode == 'autograd':
                    grad = list(torch.autograd.grad(losses[tn], self.get_share_params(), retain_graph=True))
                    grads[tn] = torch.cat([g.view(-1) for g in grad])
                else:
                    raise ValueError('No support {} mode for gradient computation'.format(mode))
                self.zero_grad_share_params()
        else:
            if not isinstance(self.rep, dict):
                grads = torch.zeros(self.task_num, *self.rep.size()).to(self.device)
            else:
                grads = [torch.zeros(*self.rep[task].size()) for task in self.task_name]
            for tn, task in enumerate(self.task_name):
                if mode == 'backward':
                    losses[tn].backward(retain_graph=True) if (tn+1)!=self.task_num else losses[tn].backward()
                    grads[tn] = self.rep_tasks[task].grad.data.clone()
        return grads
    def _reset_grad(self, new_grads):
        count = 0
        for param in self.get_share_params():
            if param.grad is not None:
                beg = 0 if count == 0 else sum(self.grad_index[:count])
                end = sum(self.grad_index[:(count+1)])
                param.grad.data = new_grads[beg:end].contiguous().view(param.data.size()).data.clone()
            count += 1
    def _get_grads(self, losses, mode='backward'):
        r"""This function is used to return the gradients of representations or shared parameters.

        If ``rep_grad`` is ``True``, it returns a list with two elements. The first element is \
        the gradients of the representations with the size of [task_num, batch_size, rep_size]. \
        The second element is the resized gradients with size of [task_num, -1], which means \
        the gradient of each task is resized as a vector.

        If ``rep_grad`` is ``False``, it returns the gradients of the shared parameters with size \
        of [task_num, -1], which means the gradient of each task is resized as a vector.
        """
        if self.rep_grad:
            per_grads = self._compute_grad(losses, mode, rep_grad=True)
            if not isinstance(self.rep, dict):
                grads = per_grads.reshape(self.task_num, self.rep.size()[0], -1).sum(1)
            else:
                try:
                    grads = torch.stack(per_grads).sum(1).view(self.task_num, -1)
                except:
                    raise ValueError('The representation dimensions of different tasks must be consistent')
            return [per_grads, grads]
        else:
            self._compute_grad_dim()
            grads = self._compute_grad(losses, mode)
            return grads
    def _backward_new_grads(self, batch_weight, per_grads=None, grads=None):
        r"""This function is used to reset the gradients and make a backward.

        Args:
            batch_weight (torch.Tensor): A tensor with size of [task_num].
            per_grads (torch.Tensor): It is needed if ``rep_grad`` is True. The gradients of the representations.
            grads (torch.Tensor): It is needed if ``rep_grad`` is False. The gradients of the shared parameters.
        """
        if self.rep_grad:
            if not isinstance(self.rep, dict):
                # transformed_grad = torch.einsum('i, i... -> ...', batch_weight, per_grads)
                transformed_grad = sum([batch_weight[i] * per_grads[i] for i in range(self.task_num)])
                self.rep.backward(transformed_grad)
            else:
                for tn, task in enumerate(self.task_name):
                    rg = True if (tn+1)!=self.task_num else False
                    self.rep[task].backward(batch_weight[tn]*per_grads[tn], retain_graph=rg)
        else:
            # new_grads = torch.einsum('i, i... -> ...', batch_weight, grads)
            new_grads = sum([batch_weight[i] * grads[i] for i in range(self.task_num)])
            self._reset_grad(new_grads)
    def backward(self, losses, **kwargs):
        r"""
        Args:
            losses (list): A list of losses of each task.
            kwargs (dict): A dictionary of hyperparameters of weighting methods.
        """
        pass
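
# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the original file: two minimal concrete
# weighting methods built on AbsWeighting, showing how its helpers are meant
# to be used. The class names `EW_Example` and `RandomGradWeight_Example` are
# hypothetical, and both assume the attributes AbsWeighting itself relies on
# (self.task_num, self.device, self.rep_grad, get_share_params(),
# zero_grad_share_params(), ...) are provided by the model it is mixed into.
# ---------------------------------------------------------------------------
class EW_Example(AbsWeighting):
    r"""Equal-weighting sketch: every task loss gets the same weight 1/task_num."""
    def backward(self, losses, **kwargs):
        # Combine the task losses with equal weights and backpropagate once;
        # no per-task gradient manipulation is needed for this strategy.
        loss = torch.mul(torch.stack(losses), 1.0 / self.task_num).sum()
        loss.backward()
        return np.ones(self.task_num)


class RandomGradWeight_Example(AbsWeighting):
    r"""Toy gradient-weighting sketch (assumes ``rep_grad`` is False): combine
    the per-task gradients of the shared parameters with random weights."""
    def backward(self, losses, **kwargs):
        batch_weight = torch.softmax(torch.randn(self.task_num, device=self.device), dim=-1)
        self._compute_grad_dim()
        grads = self._compute_grad(losses, mode='backward')  # [task_num, grad_dim]
        # Weighted sum over the task dimension, then write the combined
        # gradient back into the shared parameters.
        new_grads = torch.einsum('i, i... -> ...', batch_weight, grads)
        self._reset_grad(new_grads)
        return batch_weight.detach().cpu().numpy()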