In [1]:
# Copyright 2020 Erik Härkönen. All rights reserved.
# This file is licensed to you under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. You may obtain a copy
# of the License at http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software distributed under
# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR REPRESENTATIONS
# OF ANY KIND, either express or implied. See the License for the specific language
# governing permissions and limitations under the License.

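# Consistency of latent bases for StyleGAN2 ffhq, measured with Grassmannian metrics
# ('geodesic' and 'proj') over subspace dims 1-50: random orthogonal baseline, random w vs.
# random w, nearby w (eps), random w vs. GANSpace / SeFa global directions, and an eps ablation.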
%matplotlib inline
from notebook_init import *
# Stdlib modules are imported after the star import so it cannot shadow these names.
import pickle
import time

# Output directory for this notebook's figures and pickled result files.
out_root = Path('out/consistency')
out_root.mkdir(parents=True, exist_ok=True)

In [2]:
# Load StyleGAN2-ffhq instrumented at its 'style' layer. Any previously loaded `inst` is
# handed back to the loader (NOTE(review): on a fresh kernel this relies on `inst` being
# pre-defined, presumably as None by notebook_init — confirm).
use_w = True
inst = get_instrumented_model(
    'StyleGAN2',
    'ffhq',
    'style',
    device,
    inst=inst,
    use_w=use_w,
)
model = inst.model
model.truncation = 1.0  # presumably 1.0 disables the truncation trick — confirm

Loading ../models/checkpoints/stylegan2/stylegan2_ffhq_1024.pt


In [3]:
def _load_global_directions(npy_path):
    """Load a saved numpy array from `npy_path` and return it as a torch tensor on `device`."""
    return torch.from_numpy(np.load(npy_path)).to(device)


# Precomputed global directions. Note: only the ffhq files are provided.
gs_dir = _load_global_directions('./global_directions/ganspace_directions_ffhq.npy')  # GANSpace
sf_dir = _load_global_directions('./global_directions/sefa_directions_ffhq.npy')  # SeFa


class compare_basis_config:
    """Default settings passed to the basis-consistency evaluation helpers.

    Fields are class attributes; assigning on an instance overrides them for that instance.
    """
    n_samples = 50  # number of samples per evaluation (presumably averaged over — confirm)
    seed = 0  # random seed
    subspace_dim = 2  # dimension of the compared subspaces

In [4]:
# Shared, mutable evaluation settings; each experiment cell below overrides fields on it.
eval_config = compare_basis_config()
# Turn autograd back on globally. NOTE(review): presumably the evaluate_* helpers need
# gradients through the generator and notebook_init disabled them — confirm.
torch.set_grad_enabled(True)

## Grassmannian metric between two random orthogonal matrices (O(n) baseline)

In [None]:
# Baseline: Grassmannian metric between two random orthogonal bases, evaluated for
# subspace dimensions 1..50 under both metrics. Results go to out_root as .png and .dill.
eval_config.n_samples = 100
subspace_dims = range(1, 51)

for metric_type in ['geodesic', 'proj']:
    consistency_list = []
    time_list = []  # wall-clock seconds per dimension, kept for inspection

    for dim in subspace_dims:
        eval_config.subspace_dim = dim

        timer = time.perf_counter()
        consistency = evaluate_random_basis_consistency(model, eval_config, metric_type=metric_type)
        consistency_list.append(consistency)
        time_list.append(round(time.perf_counter() - timer, 2))

        # `dim` is already 1-based, so report it directly every 10 dimensions.
        if dim % 10 == 0:
            print(f'Evaluated dim {dim}')

    fig, ax = plt.subplots()
    ax.plot(list(subspace_dims), consistency_list)
    ax.set(xlabel='subspace dim', ylabel=f'consistency ({metric_type} metric)',
           title=f'Random O(n) bases, n_samples={eval_config.n_samples}')
    fig.savefig(out_root / f'consistency_o(n)_metric_{metric_type}_n_{eval_config.n_samples}.png')
    plt.show()
    plt.close(fig)

    # Despite the .dill extension, results are written with the stdlib pickle module.
    with open(out_root / f'consistency_o(n)_metric_{metric_type}_n_{eval_config.n_samples}.dill', 'wb') as f:
        pickle.dump(consistency_list, f)

## Grassmannian metric between two random w

In [None]:
# Grassmannian metric between the bases obtained at two random w, for subspace
# dimensions 1..50 under both metrics. Results go to out_root as .png and .dill.
eval_config.n_samples = 1000
subspace_dims = range(1, 51)

for metric_type in ['geodesic', 'proj']:
    consistency_list = []
    time_list = []  # wall-clock seconds per dimension, kept for inspection

    for dim in subspace_dims:
        eval_config.subspace_dim = dim

        timer = time.perf_counter()
        consistency = evaluate_basis_consistency(model, eval_config, metric_type=metric_type)
        consistency_list.append(consistency)
        time_list.append(round(time.perf_counter() - timer, 2))

        # `dim` is already 1-based, so report it directly every 10 dimensions.
        if dim % 10 == 0:
            print(f'Evaluated dim {dim}')

    fig, ax = plt.subplots()
    ax.plot(list(subspace_dims), consistency_list)
    ax.set(xlabel='subspace dim', ylabel=f'consistency ({metric_type} metric)',
           title=f'Random w vs. random w, n_samples={eval_config.n_samples}')
    fig.savefig(out_root / f'consistency_random_metric_{metric_type}_n_{eval_config.n_samples}.png')
    plt.show()
    plt.close(fig)

    # Despite the .dill extension, results are written with the stdlib pickle module.
    with open(out_root / f'consistency_random_metric_{metric_type}_n_{eval_config.n_samples}.dill', 'wb') as f:
        pickle.dump(consistency_list, f)
    

## Grassmannian metric between two nearby w

In [None]:
# Grassmannian metric between the bases obtained at two nearby w (neighbourhood size eps),
# for subspace dimensions 1..50 under both metrics. Results go to out_root as .png and .dill.
eval_config.n_samples = 1000
subspace_dims = range(1, 51)
eps = 1e-1  # loop-invariant neighbourhood size passed to the local evaluation

for metric_type in ['geodesic', 'proj']:
    consistency_list = []
    time_list = []  # wall-clock seconds per dimension, kept for inspection

    for dim in subspace_dims:
        eval_config.subspace_dim = dim

        timer = time.perf_counter()
        consistency = evaluate_basis_consistency_local(model, eval_config, eps=eps, metric_type=metric_type)
        consistency_list.append(consistency)
        time_list.append(round(time.perf_counter() - timer, 2))

        # `dim` is already 1-based, so report it directly every 10 dimensions.
        if dim % 10 == 0:
            print(f'Evaluated dim {dim}')

    fig, ax = plt.subplots()
    ax.plot(list(subspace_dims), consistency_list)
    ax.set(xlabel='subspace dim', ylabel=f'consistency ({metric_type} metric)',
           title=f'Nearby w (eps={eps}), n_samples={eval_config.n_samples}')
    fig.savefig(out_root / f'consistency_local_metric_{metric_type}_eps_{eps}_n_{eval_config.n_samples}.png')
    plt.show()
    plt.close(fig)

    # Despite the .dill extension, results are written with the stdlib pickle module.
    with open(out_root / f'consistency_local_metric_{metric_type}_eps_{eps}_n_{eval_config.n_samples}.dill', 'wb') as f:
        pickle.dump(consistency_list, f)

## Grassmannian metric between random w and GANSpace directions

In [None]:
# Grassmannian metric between the bases at random w and the GANSpace global directions,
# for subspace dimensions 1..50 under both metrics. Results go to out_root as .png and .dill.
# This comparison has no eps parameter, so none appears in the output filenames
# (referencing `eps` here would depend on a variable leaked from another cell).
eval_config.n_samples = 1000
subspace_dims = range(1, 51)
# gs_dir is squeezed, transposed and moved to the CPU before evaluation.
global_basis = gs_dir.squeeze().t().cpu()

for metric_type in ['geodesic', 'proj']:
    consistency_list = []
    time_list = []  # wall-clock seconds per dimension, kept for inspection

    for dim in subspace_dims:
        eval_config.subspace_dim = dim

        timer = time.perf_counter()
        consistency = evaluate_basis_consistency_to_global(model, eval_config, global_basis, metric_type=metric_type)
        consistency_list.append(consistency)
        time_list.append(round(time.perf_counter() - timer, 2))

        # `dim` is already 1-based, so report it directly every 10 dimensions.
        if dim % 10 == 0:
            print(f'Evaluated dim {dim}')

    fig, ax = plt.subplots()
    ax.plot(list(subspace_dims), consistency_list)
    ax.set(xlabel='subspace dim', ylabel=f'consistency ({metric_type} metric)',
           title=f'Random w vs. GANSpace, n_samples={eval_config.n_samples}')
    fig.savefig(out_root / f'consistency_to_global_ganspace_metric_{metric_type}_n_{eval_config.n_samples}.png')
    plt.show()
    plt.close(fig)

    # Despite the .dill extension, results are written with the stdlib pickle module.
    with open(out_root / f'consistency_to_global_ganspace_metric_{metric_type}_n_{eval_config.n_samples}.dill', 'wb') as f:
        pickle.dump(consistency_list, f)

## Grassmannian metric between random w and SeFa directions

In [None]:
# Grassmannian metric between the bases at random w and the SeFa global directions,
# for subspace dimensions 1..50 under both metrics. Results go to out_root as .png and .dill.
# This comparison has no eps parameter, so none appears in the output filenames
# (referencing `eps` here would depend on a variable leaked from another cell).
eval_config.n_samples = 1000
subspace_dims = range(1, 51)
# sf_dir is squeezed, transposed and moved to the CPU before evaluation.
global_basis = sf_dir.squeeze().t().cpu()

for metric_type in ['geodesic', 'proj']:
    consistency_list = []
    time_list = []  # wall-clock seconds per dimension, kept for inspection

    for dim in subspace_dims:
        eval_config.subspace_dim = dim

        timer = time.perf_counter()
        consistency = evaluate_basis_consistency_to_global(model, eval_config, global_basis, metric_type=metric_type)
        consistency_list.append(consistency)
        time_list.append(round(time.perf_counter() - timer, 2))

        # `dim` is already 1-based, so report it directly every 10 dimensions.
        if dim % 10 == 0:
            print(f'Evaluated dim {dim}')

    fig, ax = plt.subplots()
    ax.plot(list(subspace_dims), consistency_list)
    ax.set(xlabel='subspace dim', ylabel=f'consistency ({metric_type} metric)',
           title=f'Random w vs. SeFa, n_samples={eval_config.n_samples}')
    fig.savefig(out_root / f'consistency_to_global_sefa_metric_{metric_type}_n_{eval_config.n_samples}.png')
    plt.show()
    plt.close(fig)

    # Despite the .dill extension, results are written with the stdlib pickle module.
    with open(out_root / f'consistency_to_global_sefa_metric_{metric_type}_n_{eval_config.n_samples}.dill', 'wb') as f:
        pickle.dump(consistency_list, f)

## Ablation study of eps for nearby w

In [None]:
# Ablation over the neighbourhood size eps of the nearby-w comparison: for each metric and
# each eps, evaluate subspace dimensions 1..50 and write a .png and .dill to out_root.
# Note: eps=1e-1 uses the same output filenames as the single-eps nearby-w run.
eval_config.n_samples = 1000
subspace_dims = range(1, 51)
eps_values = [1e-2, 1e-1, 0.25, 5e-1]

for metric_type in ['geodesic', 'proj']:
    for eps in eps_values:
        consistency_list = []
        time_list = []  # wall-clock seconds per dimension, kept for inspection

        for dim in subspace_dims:
            eval_config.subspace_dim = dim

            timer = time.perf_counter()
            consistency = evaluate_basis_consistency_local(model, eval_config, eps=eps, metric_type=metric_type)
            consistency_list.append(consistency)
            time_list.append(round(time.perf_counter() - timer, 2))

            # `dim` is already 1-based, so report it directly every 10 dimensions.
            if dim % 10 == 0:
                print(f'Evaluated dim {dim}')

        fig, ax = plt.subplots()
        ax.plot(list(subspace_dims), consistency_list)
        ax.set(xlabel='subspace dim', ylabel=f'consistency ({metric_type} metric)',
               title=f'Nearby w (eps={eps}), n_samples={eval_config.n_samples}')
        fig.savefig(out_root / f'consistency_local_metric_{metric_type}_eps_{eps}_n_{eval_config.n_samples}.png')
        plt.show()
        plt.close(fig)

        # Despite the .dill extension, results are written with the stdlib pickle module.
        with open(out_root / f'consistency_local_metric_{metric_type}_eps_{eps}_n_{eval_config.n_samples}.dill', 'wb') as f:
            pickle.dump(consistency_list, f)
    