-
Notifications
You must be signed in to change notification settings - Fork 4
/
elastic_weight_consolidation.py
61 lines (51 loc) · 2.46 KB
/
elastic_weight_consolidation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch import autograd
import numpy as np
from torch.utils.data import DataLoader
class ElasticWeightConsolidation:
    """Elastic Weight Consolidation (Kirkpatrick et al., 2017) for continual learning.

    Wraps a classification model. After training on a task, call
    ``register_ewc_params`` to snapshot the current parameters ("anchors") and a
    diagonal Fisher-information estimate. ``consolidation_loss`` then returns a
    quadratic penalty that discourages later training from moving parameters
    that were important for the earlier task.
    """

    def __init__(self, model, weight=0.1):
        """
        Args:
            model: ``torch.nn.Module`` whose output feeds ``log_softmax`` over dim 1
                (i.e. a classifier producing per-class logits).
            weight: EWC penalty coefficient (the lambda of the paper).
        """
        self.model = model
        self.weight = weight

    def _update_mean_params(self):
        # Snapshot current parameter values as the anchor means.
        # '.' is replaced with '__' because register_buffer rejects dotted names.
        for param_name, param in self.model.named_parameters():
            buff_name = param_name.replace('.', '__')
            self.model.register_buffer(buff_name + '_estimated_mean', param.data.clone())

    def _update_fisher_params(self, current_ds, batch_size, num_batch):
        """Estimate the diagonal Fisher information from up to ``num_batch`` batches.

        Args:
            current_ds: iterable yielding ``(input, target)`` batches (e.g. a DataLoader).
            batch_size: unused; retained for backward compatibility with existing callers.
            num_batch: number of batches to accumulate into the estimate.
        """
        # Bug fix: use the model's own device instead of hard-coding CUDA,
        # so this also runs on CPU-only machines.
        device = next(self.model.parameters()).device
        log_likelihoods = []
        for i, (inputs, targets) in enumerate(current_ds):
            if i >= num_batch:
                break
            inputs = inputs.to(device)
            targets = targets.to(device)
            log_probs = F.log_softmax(self.model(inputs), dim=1)
            # Bug fix: select each sample's own target log-probability.
            # The original `output[:, target]` built a (B, B) matrix — every
            # sample crossed with every target — instead of a length-B vector.
            log_likelihoods.append(log_probs.gather(1, targets.unsqueeze(1)).squeeze(1))
        log_likelihood = torch.cat(log_likelihoods).mean()
        grads = autograd.grad(log_likelihood, self.model.parameters())
        buff_names = [name.replace('.', '__') for name, _ in self.model.named_parameters()]
        for buff_name, grad in zip(buff_names, grads):
            # Diagonal Fisher approximation: squared gradient of the log-likelihood.
            self.model.register_buffer(buff_name + '_estimated_fisher', grad.data.clone() ** 2)

    def register_ewc_params(self, dataset, batch_size, num_batches):
        """Record Fisher estimates and parameter anchors for the just-finished task."""
        self._update_fisher_params(dataset, batch_size, num_batches)
        self._update_mean_params()

    def _compute_consolidation_loss(self, weight):
        """Return (weight/2) * sum_i F_i * (theta_i - theta*_i)^2, or 0 if no
        EWC buffers have been registered yet (AttributeError on getattr)."""
        try:
            losses = []
            for param_name, param in self.model.named_parameters():
                buff_name = param_name.replace('.', '__')
                estimated_mean = getattr(self.model, '{}_estimated_mean'.format(buff_name))
                estimated_fisher = getattr(self.model, '{}_estimated_fisher'.format(buff_name))
                losses.append((estimated_fisher * (param - estimated_mean) ** 2).sum())
            return (weight / 2) * sum(losses)
        except AttributeError:
            # register_ewc_params has not been called yet: no penalty.
            return 0

    def consolidation_loss(self, output, target):
        """Return only the EWC penalty term.

        NOTE(review): ``output`` and ``target`` are ignored — the caller is
        presumably expected to add its own task loss to this penalty.
        """
        return self._compute_consolidation_loss(self.weight)

    def save(self, filename):
        """Serialize the whole wrapped model (architecture + weights + EWC buffers)."""
        torch.save(self.model, filename)

    def load(self, filename):
        """Load a previously saved model.

        SECURITY: ``torch.load`` unpickles arbitrary objects — only load files
        from trusted sources.
        """
        self.model = torch.load(filename)