In [None]:
# Copyright 2020 Erik Härkönen. All rights reserved.
# This file is licensed to you under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. You may obtain a copy
# of the License at http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software distributed under
# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR REPRESENTATIONS
# OF ANY KIND, either express or implied. See the License for the specific language
# governing permissions and limitations under the License.

# Comparison to GAN steerability
%matplotlib inline
from notebook_init import *

out_root = Path('out/figures/steerability_comp')
makedirs(out_root, exist_ok=True)

def rand():
    """Return a random integer in [0, int32 max) drawn from NumPy's global RNG (handy as a seed)."""
    upper = np.iinfo(np.int32).max
    return np.random.randint(upper)

In [None]:
# Load BigGAN-512 (output class 'husky') instrumented at layer 'generator.gen_z'.
# NOTE(review): `inst=inst` reads a pre-existing `inst` (presumably defined by
# notebook_init so an already-loaded model can be reused) — confirm this cell
# runs on a fresh kernel.
inst = get_instrumented_model('BigGAN-512', 'husky', 'generator.gen_z', device, inst=inst)
model = inst.model

# Latent components for the same model / layer / class. get_or_compute returns a
# file path that np.load opens below (an archive holding lat_comp, lat_mean and
# lat_stdev); presumably it reuses a cached dump when one exists — confirm.
pc_config = Config(components=80, n=1_000_000,
    layer='generator.gen_z', model='BigGAN-512', output_class='husky')
dump_name = get_or_compute(pc_config, inst)

with np.load(dump_name) as data:
    lat_comp = data['lat_comp']
    lat_mean = data['lat_mean']  # not used in the cells shown here
    lat_std = data['lat_stdev']  # not used in the cells shown here

# Indices determined by visual inspection: component 0 is used as our
# translation direction and component 6 as our zoom direction.
delta_ours_transl = lat_comp[0]
delta_ours_zoom = lat_comp[6]

In [None]:
def apply_edit(seed, delta, alpha, n_frames=7):
    """Render a latent walk ``z + a*delta`` for ``a`` evenly spaced in [-alpha, alpha].

    ``z`` is the latent returned by ``model.sample_latent(1, seed=seed)``,
    moved to the CPU and converted to NumPy. ``n_frames`` step sizes are used
    (both endpoints included), and each shifted latent is rendered with
    ``model.sample_np``.

    Returns:
        list with one ``model.sample_np`` result per step, ordered from
        ``-alpha`` to ``+alpha``.
    """
    base = model.sample_latent(1, seed=seed).cpu().numpy()
    steps = np.linspace(-alpha, alpha, n_frames)
    return [model.sample_np(base + step*delta) for step in steps]

def show_strip(frames):
    """Show ``frames`` side by side (padded with ``pad_frames(frames, 64)``) as one large, axis-free image."""
    strip = np.hstack(pad_frames(frames, 64))
    _fig, ax = plt.subplots(figsize=(20, 20))
    ax.axis('off')
    ax.imshow(strip)
    plt.show()

In [None]:
import pickle

def _load_steer_direction(path, key):
    """Read ``key`` from the pickled dict at ``path`` and reshape it to (1, 128).

    NOTE(review): pickle.load can execute arbitrary code; only open trusted files.
    """
    with open(path, 'rb') as f:
        return pickle.load(f)[key].reshape(1, 128)


delta_steerability_zoom = _load_steer_direction('gan_steer-linear_zoom_512.pkl', 'w_zoom')
delta_steerability_transl = _load_steer_direction('gan_steer-linear_shiftx_512.pkl', 'w_shiftx')

In [None]:
def normalize(t):
    """Return ``t`` divided by its L2 norm, computed over all elements as one flat vector.

    The input is not modified; a new array with the same shape is returned.
    """
    flat = t.reshape(-1)
    length = np.sqrt(np.sum(flat**2))
    return t / length

# Scale all four directions to unit length so that their dot products are
# cosine similarities.
(delta_steerability_zoom, delta_steerability_transl,
 delta_ours_zoom, delta_ours_transl) = [
    normalize(direction) for direction in (
        delta_steerability_zoom, delta_steerability_transl,
        delta_ours_zoom, delta_ours_transl)
]

In [None]:
# Angles: with unit-norm inputs, the dot product is the cosine similarity.
dotp_zoom = np.dot(delta_steerability_zoom.ravel(), delta_ours_zoom.ravel())
dotp_transl = np.dot(delta_steerability_transl.ravel(), delta_ours_transl.ravel())

# A direction and its negation span the same edit axis, so flip ours (in place)
# wherever it points away from the steerability direction. The printed values
# below are the similarities measured before any flip.
if dotp_zoom < 0:
    delta_ours_zoom *= -1
if dotp_transl < 0:
    delta_ours_transl *= -1

print('Zoom similarity:', dotp_zoom)
print('Translate similarity:', dotp_transl)

In [None]:
# Truncation value used for all sampling below (presumably the BigGAN truncation
# trick applied in model.sample_latent — confirm in the model wrapper).
model.truncation = 0.6

def compute(prefix, imgclass, seeds, d_ours, scale_ours, d_steer, scale_steer):
    """Render, save and display edit sequences for our direction vs. GAN steerability.

    For each seed, both directions (each pre-multiplied by its own scale) are
    applied with ``apply_edit(seed, delta*scale, 1.0)``. Every frame is saved
    as a PNG under ``out_root / prefix``, and the padded strip of frames is
    shown inline.

    Args:
        prefix: Sub-directory of ``out_root`` and plot-title prefix (e.g. 'zoom').
        imgclass: Class passed to ``model.set_output_class``; also the
            file-name prefix.
        seeds: Iterable of latent seeds forwarded to ``apply_edit``.
        d_ours, scale_ours: Our direction and the multiplier applied to it.
        d_steer, scale_steer: Steerability direction and its multiplier.

    Side effects: changes the model's output class, creates directories,
    writes PNG files, prints each seed and displays figures.
    """
    model.set_output_class(imgclass)
    out_dir = out_root / prefix
    makedirs(out_dir, exist_ok=True)

    # Materialise so the number of seeds is known even if a generator is passed.
    seeds = list(seeds)
    # Without the seed in the file name, several seeds would overwrite each
    # other's PNGs. Keep the plain names for a single seed (existing outputs)
    # and add the seed only when it is needed to disambiguate.
    tag_seed = len(seeds) > 1
    methods = [
        ('ours', d_ours, scale_ours),
        ('steerability', d_steer, scale_steer),
    ]

    for seed in seeds:
        print(seed)
        for name, delta, scale in methods:
            frames = apply_edit(seed, delta*scale, 1.0)

            stem = f'{imgclass}_{seed}_{name}' if tag_seed else f'{imgclass}_{name}'
            for i, frame in enumerate(frames):
                # Clip before the uint8 cast so out-of-range values cannot wrap around.
                pixels = np.uint8(np.clip(frame, 0.0, 1.0)*255)
                Image.fromarray(pixels).save(out_dir / f'{stem}_{i}.png')

            strip = np.hstack(pad_frames(frames, 64))
            plt.figure(figsize=(12,12))
            plt.imshow(strip)
            plt.axis('off')
            # Set the title before tight_layout so the layout accounts for it.
            plt.title(f'{prefix} - {name}, scale={scale}')
            plt.tight_layout()
            plt.show()


# Each row: (prefix, class, seeds, our direction, our scale,
#            steerability direction, steerability scale)
comparisons = [
    ('zoom', 'robin', [560157313], delta_ours_zoom, 3.0, delta_steerability_zoom, 5.5),
    ('zoom', 'ship', [107715983], delta_ours_zoom, 3.0, delta_steerability_zoom, 5.0),
    ('translate', 'golden_retriever', [552411435], delta_ours_transl, 2.0, delta_steerability_transl, 4.5),
    ('translate', 'lemon', [331582800], delta_ours_transl, 3.0, delta_steerability_transl, 6.0),
]
for comparison in comparisons:
    compute(*comparison)

print('Done')

