adv.py
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Function


class LambdaScheduler(nn.Module):
    """Smoothly increases the gradient-reversal coefficient lambda from 0
    during training: lambda(p) = 2 / (1 + exp(-gamma * p)) - 1, where
    p = curr_iter / max_iter is the training progress in [0, 1]."""
    def __init__(self, gamma=1.0, max_iter=1000, **kwargs):
        super(LambdaScheduler, self).__init__()
        self.gamma = gamma
        self.max_iter = max_iter
        self.curr_iter = 0

    def lamb(self):
        p = self.curr_iter / self.max_iter
        lamb = 2. / (1. + np.exp(-self.gamma * p)) - 1
        return lamb

    def step(self):
        # Advance the iteration counter, saturating at max_iter.
        self.curr_iter = min(self.curr_iter + 1, self.max_iter)
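

# A minimal usage sketch (not part of the original file): with gamma=10.0,
# a common choice in DANN-style training, lamb() rises from 0 and saturates
# near 1 as step() is called. The helper name is hypothetical.
def _demo_lambda_schedule():
    scheduler = LambdaScheduler(gamma=10.0, max_iter=2000)
    for i in range(2001):
        if i % 500 == 0:
            print(f"iter {i}: lambda = {scheduler.lamb():.4f}")
        scheduler.step()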


class AdversarialLoss(nn.Module):
    '''
    Domain-adversarial loss: a domain classifier is trained to tell source
    features from target features through a gradient reversal layer, which
    pushes the feature extractor toward domain-invariant features.
    Acknowledgement: The adversarial loss implementation is inspired by
    http://transfer.thuml.ai/
    '''
    def __init__(self, gamma=1.0, max_iter=1000, use_lambda_scheduler=True, **kwargs):
        super(AdversarialLoss, self).__init__()
        self.domain_classifier = Discriminator()
        self.use_lambda_scheduler = use_lambda_scheduler
        if self.use_lambda_scheduler:
            self.lambda_scheduler = LambdaScheduler(gamma, max_iter)

    def forward(self, source, target):
        lamb = 1.0
        if self.use_lambda_scheduler:
            lamb = self.lambda_scheduler.lamb()
            self.lambda_scheduler.step()
        source_loss = self.get_adversarial_result(source, True, lamb)
        target_loss = self.get_adversarial_result(target, False, lamb)
        adv_loss = 0.5 * (source_loss + target_loss)
        return adv_loss
    def get_adversarial_result(self, x, source=True, lamb=1.0):
        # Reverse gradients so that minimizing the domain-classification
        # loss trains the feature extractor to fool the classifier.
        x = ReverseLayerF.apply(x, lamb)
        domain_pred = self.domain_classifier(x)
        device = domain_pred.device
        # Source samples are labeled 1, target samples 0.
        if source:
            domain_label = torch.ones(len(x), 1).long()
        else:
            domain_label = torch.zeros(len(x), 1).long()
        loss_fn = nn.BCELoss()
        loss_adv = loss_fn(domain_pred, domain_label.float().to(device))
        return loss_adv


class ReverseLayerF(Function):
    """Gradient reversal layer (GRL): identity in the forward pass,
    multiplies the incoming gradient by -alpha in the backward pass."""
    @staticmethod
    def forward(ctx, x, alpha):
        ctx.alpha = alpha
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        output = grad_output.neg() * ctx.alpha
        # No gradient with respect to alpha.
        return output, None
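

# A small sanity check (assumed usage, not in the original file): the layer
# is the identity in the forward pass, and every gradient entry flowing back
# through it equals -alpha. The helper name is hypothetical.
def _demo_gradient_reversal(alpha=0.5):
    x = torch.randn(4, 8, requires_grad=True)
    y = ReverseLayerF.apply(x, alpha)
    assert torch.equal(y, x)  # identity forward
    y.sum().backward()
    assert torch.allclose(x.grad, torch.full_like(x, -alpha))  # reversed grad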


class Discriminator(nn.Module):
    """Domain classifier: a two-hidden-layer MLP mapping a feature vector
    to the probability that it comes from the source domain."""
    def __init__(self, input_dim=256, hidden_dim=256):
        super(Discriminator, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()
        ]
        self.layers = torch.nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)
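

# A hedged end-to-end sketch: the batch size (32) and feature dimension
# (256, matching the Discriminator default) are illustrative assumptions,
# not values prescribed by this file.
if __name__ == "__main__":
    criterion = AdversarialLoss(gamma=1.0, max_iter=1000)
    source_feat = torch.randn(32, 256)  # e.g. backbone features, source batch
    target_feat = torch.randn(32, 256)  # e.g. backbone features, target batch
    loss = criterion(source_feat, target_feat)
    loss.backward()  # gradients reach the discriminator (and, in a real
                     # pipeline, the feature extractor via the reversal layer)
    print(f"adversarial loss: {loss.item():.4f}")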