This repository has been archived by the owner on Aug 31, 2021. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 70
/
ewc.py
119 lines (103 loc) · 3.89 KB
/
ewc.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
# Copyright 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
from .common import MLP, ResNet18
class Net(torch.nn.Module):
    """Continual learner with an EWC-style quadratic penalty.

    While a task is being observed, a small memory of its first samples is
    kept.  When the task id changes, the squared gradients of the loss on
    that memory (``self.fisher``) and a snapshot of the parameters
    (``self.optpar``) are stored for the finished task.  Every subsequent
    update adds ``reg * fisher * (p - optpar) ** 2`` for each earlier task
    to the training loss.
    """

    def __init__(self,
                 n_inputs,
                 n_outputs,
                 n_tasks,
                 args):
        """Build the network, optimizer, loss and empty memories.

        Args:
            n_inputs: input dimensionality (used by the MLP backbone only).
            n_outputs: total number of output units.
            n_tasks: number of tasks; used to split the outputs per task
                when the data file is ``cifar100.pt``.
            args: namespace providing ``n_layers``, ``n_hiddens``,
                ``memory_strength``, ``data_file``, ``lr`` and
                ``n_memories``.
        """
        super(Net, self).__init__()
        nl, nh = args.n_layers, args.n_hiddens
        self.reg = args.memory_strength

        # setup network
        self.is_cifar = (args.data_file == 'cifar100.pt')
        if self.is_cifar:
            self.net = ResNet18(n_outputs)
        else:
            self.net = MLP([n_inputs] + [nh] * nl + [n_outputs])

        # setup optimizer
        self.opt = torch.optim.SGD(self.net.parameters(), lr=args.lr)

        # setup losses
        self.bce = torch.nn.CrossEntropyLoss()

        # setup memories
        self.current_task = 0
        self.fisher = {}    # task id -> list of squared gradients
        self.optpar = {}    # task id -> list of parameter snapshots
        self.memx = None
        self.memy = None

        if self.is_cifar:
            # True division: may be a float, compute_offsets() casts to int.
            self.nc_per_task = n_outputs / n_tasks
        else:
            self.nc_per_task = n_outputs
        self.n_outputs = n_outputs
        self.n_memories = args.n_memories

    def compute_offsets(self, task):
        """Return the ``(start, end)`` output-column range used by ``task``.

        For the CIFAR setting each task owns ``nc_per_task`` consecutive
        outputs; otherwise every task uses the full output range.
        """
        if self.is_cifar:
            offset1 = task * self.nc_per_task
            offset2 = (task + 1) * self.nc_per_task
        else:
            offset1 = 0
            offset2 = self.n_outputs
        return int(offset1), int(offset2)

    def forward(self, x, t):
        """Run the backbone on ``x``; in the CIFAR setting mask other tasks.

        Outputs outside the column range of task ``t`` are overwritten in
        place with ``-10e10``.
        """
        output = self.net(x)
        if self.is_cifar:
            # make sure we predict classes within the current task
            offset1, offset2 = self.compute_offsets(t)
            if offset1 > 0:
                output[:, :offset1].data.fill_(-10e10)
            if offset2 < self.n_outputs:
                output[:, offset2:self.n_outputs].data.fill_(-10e10)
        return output

    def observe(self, x, t, y):
        """Take one SGD step on batch ``(x, y)`` belonging to task ``t``.

        If ``t`` differs from the task seen so far, the finished task is
        consolidated first (Fisher estimate + parameter snapshot) and the
        memory is reset.  The loss is the cross-entropy on the batch plus
        the quadratic penalty of every consolidated task with id ``< t``.
        """
        self.net.train()

        # next task?
        if t != self.current_task:
            self._consolidate_current_task()
            self.current_task = t
            self.memx = None
            self.memy = None

        self._update_memory(x, y)

        self.net.zero_grad()
        loss = self._task_loss(x, y, t)
        for tt in range(t):
            if tt not in self.fisher:
                # Task id never consolidated (e.g. ids are not contiguous
                # or do not start at 0): nothing to penalise.
                continue
            for i, p in enumerate(self.net.parameters()):
                l = self.reg * self.fisher[tt][i]
                l = l * (p - self.optpar[tt][i]).pow(2)
                loss = loss + l.sum()
        loss.backward()
        self.opt.step()

    def _task_loss(self, x, y, t):
        """Cross-entropy of batch ``(x, y)`` restricted to task ``t``."""
        if self.is_cifar:
            # Only the logits of task t are scored; labels are shifted so
            # that they index into that slice.
            offset1, offset2 = self.compute_offsets(t)
            return self.bce(self.net(x)[:, offset1:offset2], y - offset1)
        return self.bce(self(x, t), y)

    def _consolidate_current_task(self):
        """Store squared gradients and parameters for ``current_task``."""
        if self.memx is None or self.memx.size(0) == 0:
            # No sample of the finished task was ever stored (the first
            # observed task id was not the initial one, or n_memories is
            # 0); there is nothing to estimate the Fisher from.
            return

        self.net.zero_grad()
        self._task_loss(self.memx, self.memy, self.current_task).backward()

        snapshots = []
        importances = []
        for p in self.net.parameters():
            snapshots.append(p.data.clone())
            if p.grad is None:
                # Parameter received no gradient: zero importance.
                importances.append(torch.zeros_like(p.data))
            else:
                importances.append(p.grad.data.clone().pow(2))
        self.optpar[self.current_task] = snapshots
        self.fisher[self.current_task] = importances

    def _update_memory(self, x, y):
        """Append ``(x, y)`` to the memory, keeping at most n_memories."""
        if self.memx is None:
            self.memx = x.data.clone()
            self.memy = y.data.clone()
        elif self.memx.size(0) < self.n_memories:
            self.memx = torch.cat((self.memx, x.data.clone()))
            self.memy = torch.cat((self.memy, y.data.clone()))
        # Enforce the cap in every case, including the very first batch.
        if self.memx.size(0) > self.n_memories:
            self.memx = self.memx[:self.n_memories]
            self.memy = self.memy[:self.n_memories]