In [50]:
import glob
import sys
import os
import argparse
from networks import model_resnet_64px_prog
import torch
import numpy as np

In [5]:
gen, _ = model_resnet_64px_prog.get_network(z_dim=128, sigm=True)

  nn.init.xavier_uniform(self.dense.weight.data, 1.)
  nn.init.xavier_uniform(self.final.weight.data, 1.)
  nn.init.xavier_uniform(self.conv1.weight.data, 1.)
  nn.init.xavier_uniform(self.conv2.weight.data, 1.)
  nn.init.xavier_uniform(self.conv1.weight.data, 1.)
  nn.init.xavier_uniform(self.conv2.weight.data, 1.)
  nn.init.xavier_uniform(self.bypass_conv.weight.data, np.sqrt(2))
  nn.init.xavier_uniform(self.conv1.weight.data, 1.)
  nn.init.xavier_uniform(self.conv2.weight.data, 1.)
  nn.init.xavier_uniform(self.bypass_conv.weight.data, np.sqrt(2))


Generator:
Generator(
  (dense): Linear(in_features=128, out_features=8192, bias=True)
  (final): Conv2d(64, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (rbn1): ResBlockGenerator(
    (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU()
    (ups): Upsample(scale_factor=2, mode=nearest)
    (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (bypass): Sequential(
      (0): Upsample(scale_factor=2, mode=nearest)
    )
  )
  (rbn2): ResBlockGenerator(
    (conv1): Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU()
    (up

  nn.init.xavier_uniform(self.fc.weight.data, 1.)


Download `g_100.pkl` from here: https://mega.nz/#!sP5TmQgL!G3t_928bR3uQRmjETOHfK_xAjISKPYJ61iGvedkL7A4 and place it in the `models/` folder

In [157]:
dat = torch.load("models/g_100.pkl")
gen.load_state_dict(dat)

In [125]:
z_dim = 128

In [126]:
gen.eval()

Generator(
  (dense): Linear(in_features=128, out_features=8192, bias=True)
  (final): Conv2d(64, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (rbn1): ResBlockGenerator(
    (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU()
    (ups): Upsample(scale_factor=2, mode=nearest)
    (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (bypass): Sequential(
      (0): Upsample(scale_factor=2, mode=nearest)
    )
  )
  (rbn2): ResBlockGenerator(
    (conv1): Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU()
    (ups): Upsampl

In [12]:
import ipywidgets as widgets
from IPython.display import clear_output
from ipywidgets import IntSlider, Output
from ipywidgets.widgets.interaction import show_inline_matplotlib_plots
import matplotlib.pyplot as plt
%matplotlib inline

In [13]:
def grid_from_z(zval):
    sz = int(np.ceil(np.sqrt(len(zval))))
    grid = np.zeros((1, sz, sz))
    c = 0
    for i in range(sz):
        for j in range(sz):
            if c == len(zval):
                return grid
            grid[0, i, j] = zval[c]
            c += 1
    return grid

In [22]:
def sample(bs):
    z = np.random.binomial(1, 0.5, size=(bs, z_dim)).astype(np.float32)
    return z

-----------

## Randomly generating individual images

In [127]:
b = widgets.Button(
    description='Generate image',
    disabled=False,
    button_style='info',
    tooltip='Click me',
    icon='check'
)
display(b)

out = widgets.Output()
display(out)

def on_button_clicked(b):
    with out:
        clear_output()
        z = sample(1)
        z_torch = (torch.from_numpy(z).float() - 0.5) / 0.5
        grid = grid_from_z(z[0])

        f, axes = plt.subplots(1, 2)
        axes[0].matshow(grid[0])
        axes[0].set_title('z')
        gen_img = (gen(z_torch)*0.5 + 0.5).detach().numpy()[0]
        axes[1].imshow(gen_img.swapaxes(0,1).swapaxes(1,2), interpolation='bilinear')
        axes[1].set_title('G(z)')
        # 
        for ax in axes:
            ax.set_xticklabels([])
            ax.set_xticks([])
            ax.set_yticklabels([])
            ax.set_yticks([])
        axes[0].grid(color='w', linestyle='-', linewidth=1)
        axes[0].xaxis.set_ticks(np.arange(12)+0.5)
        axes[0].yaxis.set_ticks(np.arange(12)+0.5)
        show_inline_matplotlib_plots()     

b.on_click(on_button_clicked)

Button(button_style='info', description='Generate image', icon='check', style=ButtonStyle(), tooltip='Click me…

Output()

-----------

## Combining two randomly generated images via bitwise OR

In [128]:
b2 = widgets.Button(
    description='Generate image',
    disabled=False,
    button_style='info',
    tooltip='Click me',
    icon='check'
)
display(b2)

out2 = widgets.Output()
display(out2)

def on_button_clicked(b):
    with out2:
        clear_output()
        
        z = sample(2)
        z_torch = (torch.from_numpy(z).float() - 0.5) / 0.5      
        z_mix = (z[0].astype(np.bool) | z[1].astype(np.bool)).astype(np.float32)  
        z_mix_torch = (torch.from_numpy(z_mix).float() - 0.5) / 0.5
        
        grid1 = grid_from_z(z[0])
        grid2 = grid_from_z(z[1])
        grid3 = grid_from_z(z_mix)

        f, axes = plt.subplots(1, 6, figsize=(10,12))
        f.set_figheight(10)
        f.set_figwidth(15)
        axes[0].matshow(grid1[0])
        axes[0].set_title('z1')
        gen_img = (gen(z_torch)*0.5 + 0.5).detach().numpy()
        axes[1].imshow(gen_img[0].swapaxes(0,1).swapaxes(1,2), interpolation='bilinear')
        axes[1].set_title('G(z1)')
        axes[2].matshow(grid2[0])
        axes[2].set_title('z2')
        axes[3].imshow(gen_img[1].swapaxes(0,1).swapaxes(1,2), interpolation='bilinear')
        axes[3].set_title('G(z2)')
        axes[4].matshow(grid3[0])
        axes[4].set_title('z1 OR z2')
        gen_mix_img = (gen(z_mix_torch)*0.5 + 0.5).detach().numpy()[0]
        axes[5].imshow(gen_mix_img.swapaxes(0,1).swapaxes(1,2), interpolation='bilinear')
        axes[5].set_title('G(z1 OR z2)')
        for b, ax in enumerate(axes):
            ax.set_xticklabels([])
            ax.set_xticks([])
            ax.set_yticklabels([])
            ax.set_yticks([])
            if b in [0,2,4]:
                axes[b].grid(color='w', linestyle='-', linewidth=1)
                axes[b].xaxis.set_ticks(np.arange(12)+0.5)
                axes[b].yaxis.set_ticks(np.arange(12)+0.5)
        
        show_inline_matplotlib_plots()

b2.on_click(on_button_clicked)

Button(button_style='info', description='Generate image', icon='check', style=ButtonStyle(), tooltip='Click me…

Output()

----------

## Combining two randomly generated images via a continuous interpolation

In [166]:
slider_int = widgets.FloatSlider(
    value=0,
    min=0,
    max=1,
    step=0.1,
    description='Interp coef:',
    disabled=False,
    continuous_update=False,
    orientation='horizontal',
    readout=True,
    readout_format='.1f',
)
display(slider_int)

b_int = widgets.Button(
    description='Generate image',
    disabled=False,
    button_style='info',
    tooltip='Click me',
    icon='check'
)
display(b_int)

out_int = widgets.Output()
display(out_int)

def on_button_clicked(b):
    with out_int:
        clear_output()
        
        z = sample(2)
        z_torch = (torch.from_numpy(z).float() - 0.5) / 0.5      
        z_mix = slider_int.value*z[0] + (1-slider_int.value)*z[1]
        z_mix_torch = (torch.from_numpy(z_mix).float() - 0.5) / 0.5
        
        grid1 = grid_from_z(z[0])
        grid2 = grid_from_z(z[1])
        grid3 = grid_from_z(z_mix)

        f, axes = plt.subplots(1, 6, figsize=(10,12))
        f.set_figheight(10)
        f.set_figwidth(15)
        axes[0].matshow(grid1[0])
        axes[0].set_title('z1')
        gen_img = (gen(z_torch)*0.5 + 0.5).detach().numpy()
        axes[1].imshow(gen_img[0].swapaxes(0,1).swapaxes(1,2), interpolation='bilinear')
        axes[1].set_title('G(z1)')
        axes[2].matshow(grid2[0])
        axes[2].set_title('z2')
        axes[3].imshow(gen_img[1].swapaxes(0,1).swapaxes(1,2), interpolation='bilinear')
        axes[3].set_title('G(z2)')
        axes[4].matshow(grid3[0])
        axes[4].set_title('interp z')
        gen_mix_img = (gen(z_mix_torch)*0.5 + 0.5).detach().numpy()[0]
        axes[5].imshow(gen_mix_img.swapaxes(0,1).swapaxes(1,2), interpolation='bilinear')
        axes[5].set_title('G(interp)')
        for b, ax in enumerate(axes):
            ax.set_xticklabels([])
            ax.set_xticks([])
            ax.set_yticklabels([])
            ax.set_yticks([])
            if b in [0,2,4]:
                axes[b].grid(color='w', linestyle='-', linewidth=1)
                axes[b].xaxis.set_ticks(np.arange(12)+0.5)
                axes[b].yaxis.set_ticks(np.arange(12)+0.5)
        
        show_inline_matplotlib_plots()

b_int.on_click(on_button_clicked)

FloatSlider(value=0.0, continuous_update=False, description='Interp coef:', max=1.0, readout_format='.1f')

Button(button_style='info', description='Generate image', icon='check', style=ButtonStyle(), tooltip='Click me…

Output()

In [161]:
slider_int.value

0.4

--------

## Manually select the bits to turn on

(Use your SHIFT/CMD keys to select multiple z's!)

In [129]:
selector = widgets.SelectMultiple(
    options=['z%i'%i for i in range(128)],
    value=['z1'],
    rows=16,
    description='zs',
    disabled=False,
)
b3 = widgets.Button(
    description='Generate image',
    disabled=False,
    button_style='info',
    tooltip='Click me',
    icon='check'
)

display(selector)
display(b3)

out3 = widgets.Output()
display(out3)

def on_button_clicked(b):
    with out3:
        clear_output()
        
        z = np.zeros((1, z_dim)).astype(np.float32)
        z[:, selector.index] += 1.

        z_torch = (torch.from_numpy(z).float() - 0.5) / 0.5
        grid = grid_from_z(z[0])

        f, (ax1, ax2) = plt.subplots(1, 2)
        ax1.matshow(grid[0])
        ax1.set_xticklabels([])
        ax1.set_xticks([])
        ax1.set_yticklabels([])
        ax1.set_yticks([])
        ax1.grid(color='w', linestyle='-', linewidth=1)
        ax1.set_title('z')
        ax1.xaxis.set_ticks(np.arange(12)+0.5)
        ax1.yaxis.set_ticks(np.arange(12)+0.5)
        gen_img = (gen(z_torch)*0.5 + 0.5).detach().numpy()[0]
        ax2.imshow(gen_img.swapaxes(0,1).swapaxes(1,2), interpolation='bilinear')
        ax2.set_title('G(z)')
        ax2.axis('off')
        show_inline_matplotlib_plots()

b3.on_click(on_button_clicked)

SelectMultiple(description='zs', index=(1,), options=('z0', 'z1', 'z2', 'z3', 'z4', 'z5', 'z6', 'z7', 'z8', 'z…

Button(button_style='info', description='Generate image', icon='check', style=ButtonStyle(), tooltip='Click me…

Output()

-----

## Progressively turning on more bits

In [130]:
slider = widgets.IntSlider(
    value=0,
    min=0,
    max=127,
    step=1,
    description='# bits on:',
    disabled=False,
    continuous_update=False,
    orientation='horizontal',
    readout=True,
    readout_format='d'
)
display(slider)

b4 = widgets.Button(
    description='Generate image',
    disabled=False,
    button_style='info',
    tooltip='Click me',
    icon='check'
)

display(b4)

out4 = widgets.Output()
display(out4)

def on_button_clicked(b):
    with out4:
        clear_output()
        z = np.zeros((1, z_dim)).astype(np.float32)
        z[:, 0:slider.value] += 1.

        z_torch = (torch.from_numpy(z).float() - 0.5) / 0.5
        grid = grid_from_z(z[0])

        f, (ax1, ax2) = plt.subplots(1, 2)
        ax1.matshow(grid[0])
        ax1.set_title('z')
        ax1.set_xticklabels([])
        ax1.set_xticks([])
        ax1.set_yticklabels([])
        ax1.set_yticks([])
        ax1.grid(color='w', linestyle='-', linewidth=1)
        ax1.xaxis.set_ticks(np.arange(12)+0.5)
        ax1.yaxis.set_ticks(np.arange(12)+0.5)
        gen_img = (gen(z_torch)*0.5 + 0.5).detach().numpy()[0]
        ax2.imshow(gen_img.swapaxes(0,1).swapaxes(1,2), interpolation='bilinear')
        ax2.set_title('G(z)')
        ax2.axis('off')
        show_inline_matplotlib_plots()

b4.on_click(on_button_clicked)

IntSlider(value=0, continuous_update=False, description='# bits on:', max=127)

Button(button_style='info', description='Generate image', icon='check', style=ButtonStyle(), tooltip='Click me…

Output()

--------

## Turn off a bit at one position

In [131]:
slider2 = widgets.IntSlider(
    value=0,
    min=0,
    max=127,
    step=1,
    description='Position #',
    disabled=False,
    continuous_update=False,
    orientation='horizontal',
    readout=True,
    readout_format='d',
)
display(slider2)

b5 = widgets.Button(
    description='Generate image',
    disabled=False,
    button_style='info',
    tooltip='Click me',
    icon='check'
)

display(b5)

out5 = widgets.Output()
display(out5)

def on_button_clicked(b):
    with out5:
        clear_output()
        z = np.ones((1, z_dim)).astype(np.float32)
        z[:, slider2.value] = 0.

        z_torch = (torch.from_numpy(z).float() - 0.5) / 0.5
        grid = grid_from_z(z[0])

        f, (ax1, ax2) = plt.subplots(1, 2)
        ax1.matshow(grid[0])
        ax1.set_title('z')
        ax1.set_xticklabels([])
        ax1.set_xticks([])
        ax1.set_yticklabels([])
        ax1.set_yticks([])
        ax1.grid(color='w', linestyle='-', linewidth=1)
        ax1.xaxis.set_ticks(np.arange(12)+0.5)
        ax1.yaxis.set_ticks(np.arange(12)+0.5)
        gen_img = (gen(z_torch)*0.5 + 0.5).detach().numpy()[0]
        ax2.imshow(gen_img.swapaxes(0,1).swapaxes(1,2), interpolation='bilinear')
        ax2.set_title('G(z)')
        ax2.axis('off')
        show_inline_matplotlib_plots()

b5.on_click(on_button_clicked)

IntSlider(value=0, continuous_update=False, description='Position #', max=127)

Button(button_style='info', description='Generate image', icon='check', style=ButtonStyle(), tooltip='Click me…

Output()

------

In [147]:
%%bash
rm -r tmp
mkdir -p tmp/images
mkdir -p tmp/codes

In [148]:
rand_z = sample(128)
rand_z.shape

(128, 128)

In [149]:
# https://jakevdp.github.io/blog/2013/08/07/conways-game-of-life/
def life_step(X):
    """Game of life step using generator expressions"""
    nbrs_count = sum(np.roll(np.roll(X, i, 0), j, 1)
                     for i in (-1, 0, 1) for j in (-1, 0, 1)
                     if (i != 0 or j != 0))
    return (nbrs_count == 3) | (X & (nbrs_count == 2))


In [150]:
N = 100
X = np.zeros((12, 12), dtype=bool)
X = np.random.random((12, 12)) > 0.7
for j in range(N):
    if j % 100 == 0:
        print("Progress: %i" % j)
    this_z = X.flatten()[0:128].astype(np.float32)
    
    plt.matshow(grid_from_z(this_z)[0])
    num = '{:03d}'.format(j)
    plt.savefig('tmp/codes/%s.png' % num)
    plt.axis('off')
    plt.title('z')
    plt.close()
    
    z_torch = (torch.from_numpy(this_z).float() - 0.5) / 0.5
    plt.imshow(gen(z_torch).detach().numpy()[0].swapaxes(0,1).swapaxes(1,2)*0.5 + 0.5,
              interpolation='bilinear')
    plt.savefig('tmp/images/%s.png' % num)
    plt.close()
    
    X = life_step(X)
    print(np.sum(X))
    

Progress: 0




58
56
46
35
38
41
37
36
40
42
34
32
30
42
32
33
37
51
42
45
51
37
39
37
43
27
31
37
37
39
32
38
44
52
32
38
33
25
21
22
26
29
27
20
17
18
19
17
23
24
29
23
26
30
25
31
22
23
29
22
26
20
28
25
24
24
32
27
34
21
17
11
9
10
12
16
17
22
18
17
20
25
24
20
25
33
30
27
31
28
17
18
10
6
5
3
2
0
0
0


In [151]:
from skimage.io import imread, imsave

In [152]:
%%bash
rm -r tmp/final
mkdir tmp/final

rm: tmp/final: No such file or directory


In [153]:
for img_filename in glob.glob("tmp/images/*.png"):
    code_filename = "tmp/codes/%s" % os.path.basename(img_filename)
    img1 = imread(img_filename)
    img2 = imread(code_filename)
    imgf = np.hstack((img2, img1))
    imsave(fname="tmp/final/%s" % os.path.basename(img_filename), arr=imgf)

----