# model.py
import torch
import torch.nn as nn
import torch.nn.functional as F


# helper conv function
def conv(in_channels, out_channels, kernel_size, stride=2, padding=1, batch_norm=True):
    """Creates a convolutional layer, with optional batch normalization."""
    layers = []
    conv_layer = nn.Conv2d(in_channels, out_channels,
                           kernel_size, stride, padding, bias=False)
    # append conv layer
    layers.append(conv_layer)
    if batch_norm:
        # append batchnorm layer
        layers.append(nn.BatchNorm2d(out_channels))
    # using Sequential container
    return nn.Sequential(*layers)
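
# With the defaults (kernel_size=4, stride=2, padding=1), each conv halves the
# spatial dimensions: out = floor((in + 2*padding - kernel_size) / stride) + 1,
# e.g. 32 -> floor((32 + 2 - 4) / 2) + 1 = 16. This is the arithmetic behind the
# 32x32 -> 16x16 -> 8x8 -> 4x4 comments in Discriminator below.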
class Discriminator(nn.Module):

    def __init__(self, conv_dim=32):
        super(Discriminator, self).__init__()

        self.conv_dim = conv_dim

        # 32x32 input
        self.conv1 = conv(3, conv_dim, 4, batch_norm=False)  # first layer, no batch_norm
        # 16x16 out
        self.conv2 = conv(conv_dim, conv_dim * 2, 4)
        # 8x8 out
        self.conv3 = conv(conv_dim * 2, conv_dim * 4, 4)
        # 4x4 out

        # final, fully-connected layer
        self.fc = nn.Linear(conv_dim * 4 * 4 * 4, 1)

    def forward(self, x):
        # all hidden layers + leaky relu activation
        out = F.leaky_relu(self.conv1(x), 0.2)
        out = F.leaky_relu(self.conv2(out), 0.2)
        out = F.leaky_relu(self.conv3(out), 0.2)

        # flatten
        out = out.view(-1, self.conv_dim * 4 * 4 * 4)

        # final output layer
        out = self.fc(out)

        return out
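
# Shape check for Discriminator (a sketch, assuming conv_dim=32 and 32x32 RGB input):
#   input         (batch, 3, 32, 32)
#   conv1 + relu  (batch, 32, 16, 16)
#   conv2 + relu  (batch, 64, 8, 8)
#   conv3 + relu  (batch, 128, 4, 4)
#   flatten       (batch, 128*4*4) = (batch, 2048)
#   fc            (batch, 1)  -- a raw logit, typically paired with BCEWithLogitsLoss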
# helper deconv function
def deconv(in_channels, out_channels, kernel_size, stride=2, padding=1, batch_norm=True):
    """Creates a transposed-convolutional layer, with optional batch normalization."""
    # create a sequence of transpose + optional batch norm layers
    layers = []
    transpose_conv_layer = nn.ConvTranspose2d(in_channels, out_channels,
                                              kernel_size, stride, padding, bias=False)
    # append transpose convolutional layer
    layers.append(transpose_conv_layer)
    if batch_norm:
        # append batchnorm layer
        layers.append(nn.BatchNorm2d(out_channels))
    return nn.Sequential(*layers)
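
# With the defaults (kernel_size=4, stride=2, padding=1), each transposed conv
# doubles the spatial dimensions: out = (in - 1)*stride - 2*padding + kernel_size,
# e.g. 4 -> (4 - 1)*2 - 2 + 4 = 8. The Generator below walks 4x4 -> 8x8 ->
# 16x16 -> 32x32 this way, mirroring the Discriminator in reverse.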
class Generator(nn.Module):

    def __init__(self, z_size, conv_dim=32):
        super(Generator, self).__init__()

        self.conv_dim = conv_dim

        # first, fully-connected layer
        self.fc = nn.Linear(z_size, conv_dim * 4 * 4 * 4)

        # transpose conv layers
        self.t_conv1 = deconv(conv_dim * 4, conv_dim * 2, 4)
        self.t_conv2 = deconv(conv_dim * 2, conv_dim, 4)
        self.t_conv3 = deconv(conv_dim, 3, 4, batch_norm=False)

    def forward(self, x):
        # fully-connected + reshape
        out = self.fc(x)
        out = out.view(-1, self.conv_dim * 4, 4, 4)  # (batch_size, depth, 4, 4)

        # hidden transpose conv layers + relu
        out = F.relu(self.t_conv1(out))
        out = F.relu(self.t_conv2(out))

        # last layer + tanh activation
        out = self.t_conv3(out)
        out = torch.tanh(out)

        return out
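
# Shape check for Generator (a sketch, assuming conv_dim=32):
#   z               (batch, z_size)
#   fc + reshape    (batch, 128, 4, 4)
#   t_conv1 + relu  (batch, 64, 8, 8)
#   t_conv2 + relu  (batch, 32, 16, 16)
#   t_conv3 + tanh  (batch, 3, 32, 32)
# tanh scales pixel values to [-1, 1], so real images fed to the
# Discriminator should be rescaled to the same range.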
def weights_init_normal(m):
    """
    Applies initial weights to certain layers in a model.
    Conv and linear weights are taken from a normal distribution
    with mean = 0, std dev = 0.02; batch norm scale weights use
    mean = 1, std dev = 0.02.
    :param m: A module or layer in a network
    """
    # classname will be something like:
    # `Conv`, `BatchNorm2d`, `Linear`, etc.
    classname = m.__class__.__name__

    # apply initial weights to convolutional and linear layers
    # (note: 'ConvTranspose2d' also contains 'Conv', so the generator's
    # transposed convolutions are covered by the same branch)
    if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
        nn.init.normal_(m.weight.data, mean=0, std=0.02)
    elif classname.find('BatchNorm2d') != -1:
        # the batch norm scale (gamma) is initialized around 1, not 0, so that
        # normalized activations are not zeroed out at the start of training
        nn.init.normal_(m.weight.data, mean=1.0, std=0.02)
        nn.init.constant_(m.bias.data, 0)
def build_network(d_conv_dim, g_conv_dim, z_size):
    # define discriminator and generator
    D = Discriminator(d_conv_dim)
    G = Generator(z_size=z_size, conv_dim=g_conv_dim)

    # initialize model weights
    D.apply(weights_init_normal)
    G.apply(weights_init_normal)

    return D, G
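

# A minimal smoke test (not part of the original module): builds both networks
# with assumed sizes (d_conv_dim=32, g_conv_dim=32, z_size=100 are illustrative
# choices, not values fixed by this file) and verifies the tensor shapes that
# the comments above promise.
if __name__ == "__main__":
    D, G = build_network(d_conv_dim=32, g_conv_dim=32, z_size=100)

    z = torch.randn(16, 100)   # a batch of 16 latent vectors
    fake_images = G(z)         # expected (16, 3, 32, 32), values in [-1, 1]
    logits = D(fake_images)    # expected (16, 1) raw logits

    assert fake_images.shape == (16, 3, 32, 32)
    assert logits.shape == (16, 1)
    print("smoke test passed:", fake_images.shape, logits.shape)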