-
Notifications
You must be signed in to change notification settings - Fork 94
/
dsm.py
38 lines (29 loc) · 1.42 KB
/
dsm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import torch
import torch.autograd as autograd
def dsm(energy_net, samples, sigma=1):
    """Denoising score matching loss for an energy-based model.

    Perturbs ``samples`` with Gaussian noise of scale ``sigma``, obtains the
    model score at the perturbed points via autograd (scaled by ``sigma**2``),
    and penalizes its squared distance to the negated noise direction.

    NOTE: sets ``requires_grad_`` on ``samples`` in place — a side effect on
    the caller's tensor that this implementation deliberately preserves.

    :param energy_net: callable mapping a batch of inputs to per-sample energies.
    :param samples: batch of inputs; last dim is reduced by the norm.
    :param sigma: noise standard deviation (default 1).
    :return: scalar loss tensor (differentiable w.r.t. the energy net).
    """
    samples.requires_grad_(True)
    noise = torch.randn_like(samples) * sigma
    noisy = samples + noise
    log_density = -energy_net(noisy)
    # d/dx log p(x) at the noisy points, scaled by sigma^2; create_graph keeps
    # the result differentiable so the loss can be backpropagated.
    score = sigma ** 2 * autograd.grad(log_density.sum(), noisy, create_graph=True)[0]
    residual = score + noise
    per_sample = torch.norm(residual, dim=-1) ** 2
    return per_sample.mean() / 2.
def dsm_score_estimation(scorenet, samples, sigma=0.01):
    """Denoising score matching loss for a score network, single noise level.

    The regression target at the perturbed point x~ = x + sigma*z is the score
    of the Gaussian perturbation kernel, -(x~ - x) / sigma^2; the loss is
    0.5 * E ||scorenet(x~) - target||^2.

    :param scorenet: callable mapping perturbed inputs to predicted scores.
    :param samples: batch of clean inputs (any trailing shape).
    :param sigma: noise standard deviation (default 0.01).
    :return: scalar loss tensor.
    """
    perturbed = samples + torch.randn_like(samples) * sigma
    # Score of N(x, sigma^2 I) evaluated at the perturbed point.
    target = - 1 / (sigma ** 2) * (perturbed - samples)
    predicted = scorenet(perturbed)
    # Compare with all non-batch dimensions flattened.
    flat_target = target.view(target.shape[0], -1)
    flat_pred = predicted.view(predicted.shape[0], -1)
    per_sample = ((flat_pred - flat_target) ** 2).sum(dim=-1)
    return 0.5 * per_sample.mean(dim=0)
def anneal_dsm_score_estimation(scorenet, samples, labels, sigmas, anneal_power=2.):
    """Annealed denoising score matching loss over multiple noise levels.

    Each sample i is perturbed at scale ``sigmas[labels[i]]``; its squared
    score-matching error is weighted by that sigma raised to ``anneal_power``
    so that all noise levels contribute at a comparable magnitude.

    :param scorenet: callable ``scorenet(perturbed, labels)`` -> predicted scores.
    :param samples: batch of clean inputs (any trailing shape).
    :param labels: per-sample indices into ``sigmas``.
    :param sigmas: 1-D tensor of noise standard deviations.
    :param anneal_power: exponent for the per-sample sigma weight (default 2).
    :return: scalar loss tensor (mean over the batch).
    """
    batch = samples.shape[0]
    # Broadcast each sample's sigma across all non-batch dimensions.
    broadcast_shape = (batch,) + (1,) * (samples.dim() - 1)
    used_sigmas = sigmas[labels].view(broadcast_shape)
    perturbed = samples + torch.randn_like(samples) * used_sigmas
    # Score of the per-sample Gaussian perturbation kernel.
    target = - 1 / (used_sigmas ** 2) * (perturbed - samples)
    predicted = scorenet(perturbed, labels)
    flat_target = target.view(batch, -1)
    flat_pred = predicted.view(batch, -1)
    per_sample = 0.5 * ((flat_pred - flat_target) ** 2).sum(dim=-1)
    weighted = per_sample * used_sigmas.squeeze() ** anneal_power
    return weighted.mean(dim=0)