-
Notifications
You must be signed in to change notification settings - Fork 0
/
ddc.py
39 lines (29 loc) · 1.08 KB
/
ddc.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
# Define a DDC layer
import torch.nn as nn
class none_layer(nn.Module):
    """Pass-through module that remembers the last tensor it was given.

    ``forward`` hands its argument back unchanged (the very same object) and
    also stores it on ``self.output`` so it can be inspected after a forward
    pass. Note that this stored reference keeps that tensor alive until the
    next call.
    """

    def __init__(self):
        super().__init__()
        # Most recent forward input; None until the first call.
        self.output = None

    def forward(self, x):
        """Cache ``x`` on ``self.output`` and return it unchanged."""
        self.output = x
        return x
class DDC(nn.Module):
    def __init__(self, input_dim, cfg):
        """
        DDC clustering module.

        Produces soft cluster assignments with a linear + softmax head. The
        head is applied either to a hidden MLP representation or, in direct
        mode, straight to the input.

        :param input_dim: Shape of inputs. Only ``input_dim[0]`` is used, as the
            number of input features of the first linear layer (presumably the
            input is a flat feature vector per sample -- TODO confirm).
        :param cfg: DDC config. See `config.defaults.DDC`. Attributes read:
            ``direct`` always, ``n_clusters`` always, and ``n_hidden`` and
            ``use_bn`` only when ``direct`` is false.
        """
        super().__init__()
        in_features = input_dim[0]
        # Only construct the layers that are actually used. Building the MLP
        # and then replacing it in direct mode would waste memory and RNG draws.
        if cfg.direct:
            # No hidden MLP: the pass-through layer returns (and caches) the
            # input itself as the "hidden" representation.
            self.hidden = none_layer()
            head_in = in_features
        else:
            hidden_layers = [nn.Linear(in_features, cfg.n_hidden), nn.ReLU()]
            if cfg.use_bn:
                # BatchNorm is applied after the ReLU activation.
                hidden_layers.append(nn.BatchNorm1d(num_features=cfg.n_hidden))
            self.hidden = nn.Sequential(*hidden_layers)
            head_in = cfg.n_hidden
        # Softmax over dim=1 assumes a (batch, features) input -- TODO confirm.
        self.output = nn.Sequential(nn.Linear(head_in, cfg.n_clusters), nn.Softmax(dim=1))

    def forward(self, x):
        """
        Compute cluster assignments.

        :param x: Input batch; its feature size must equal ``input_dim[0]``.
        :return: Tuple ``(output, hidden)``, where ``output`` holds the softmax
            cluster probabilities (``n_clusters`` per row) and ``hidden`` is the
            representation fed to the head (the input itself in direct mode).
        """
        hidden = self.hidden(x)
        output = self.output(hidden)
        return output, hidden