Skip to content

Commit

Permalink
adding support for batching for cppn image representation by creating…
Browse files Browse the repository at this point in the history
… a batch of networks. this should allow for forcing diversity when representing images via cppn networks
  • Loading branch information
shaibagon committed Dec 23, 2021
1 parent 6123386 commit 7312165
Showing 1 changed file with 34 additions and 29 deletions.
63 changes: 34 additions & 29 deletions lucent/optvis/param/cppn.py
Expand Up @@ -27,7 +27,7 @@ def forward(self, x):
return torch.cat([x/0.67, (x*x)/0.6], 1)


def cppn(size, num_output_channels=3, num_hidden_channels=24, num_layers=8,
def cppn(size, num_output_channels=3, batch=None, num_hidden_channels=24, num_layers=8,
activation_fn=CompositeActivation, normalize=False):

r = 3 ** 0.5
Expand All @@ -40,32 +40,37 @@ def cppn(size, num_output_channels=3, num_hidden_channels=24, num_layers=8,

input_tensor = torch.stack([x, y], dim=0).unsqueeze(0).to(device)

layers = []
kernel_size = 1
for i in range(num_layers):
out_c = num_hidden_channels
in_c = out_c * 2 # * 2 for composite activation
if i == 0:
in_c = 2
if i == num_layers - 1:
out_c = num_output_channels
layers.append(('conv{}'.format(i), torch.nn.Conv2d(in_c, out_c, kernel_size)))
if normalize:
layers.append(('norm{}'.format(i), torch.nn.InstanceNorm2d(out_c)))
if i < num_layers - 1:
layers.append(('actv{}'.format(i), activation_fn()))
else:
layers.append(('output', torch.nn.Sigmoid()))
# batching is handled via a network per example in the batch
batch = 1 if batch is None else batch
nets = torch.nn.ModuleList()
for bi in range(batch):
layers = []
kernel_size = 1
for i in range(num_layers):
out_c = num_hidden_channels
in_c = out_c * 2 # * 2 for composite activation
if i == 0:
in_c = 2
if i == num_layers - 1:
out_c = num_output_channels
layers.append(('conv{}'.format(i), torch.nn.Conv2d(in_c, out_c, kernel_size)))
if normalize:
layers.append(('norm{}'.format(i), torch.nn.InstanceNorm2d(out_c)))
if i < num_layers - 1:
layers.append(('actv{}'.format(i), activation_fn()))
else:
layers.append(('output', torch.nn.Sigmoid()))

# Initialize model
net = torch.nn.Sequential(OrderedDict(layers)).to(device)
# Initialize weights
def weights_init(module):
if isinstance(module, torch.nn.Conv2d):
torch.nn.init.normal_(module.weight, 0, np.sqrt(1/module.in_channels))
if module.bias is not None:
torch.nn.init.zeros_(module.bias)
net.apply(weights_init)
# Set last conv2d layer's weights to 0
torch.nn.init.zeros_(dict(net.named_children())['conv{}'.format(num_layers - 1)].weight)
return net.parameters(), lambda: net(input_tensor)
# Initialize model
net = torch.nn.Sequential(OrderedDict(layers)).to(device)
# Initialize weights
def weights_init(module):
if isinstance(module, torch.nn.Conv2d):
torch.nn.init.normal_(module.weight, 0, np.sqrt(1/module.in_channels))
if module.bias is not None:
torch.nn.init.zeros_(module.bias)
net.apply(weights_init)
# Set last conv2d layer's weights to 0
torch.nn.init.zeros_(dict(net.named_children())['conv{}'.format(num_layers - 1)].weight)
nets.append(net)
return nets.parameters(), lambda: torch.cat([net(input_tensor) for net in nets], dim=0)

0 comments on commit 7312165

Please sign in to comment.