In [13]:
%load_ext autoreload
%autoreload 2
import os

import matplotlib.image as mpimg
import matplotlib.pyplot as plt
import numpy as np
import sys
sys.path.append('..')
import transects
from ganwrapper import Generator
from tqdm import tqdm

import util
from config import *
# Generator wrapper from ganwrapper, used by the transect cells below (via kwargs['G'] and G=G); image_size=512.
G = Generator(image_size=512)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [18]:
# Load one original image and, for each regularization value in `regs`,
# its reconstruction (.png) and latent (.npy) from DIR_PROCESSED.
im_num = 0
regs = [10000, 1e-1, 0]
ims = []  # reconstructions, one per reg; must exist before the appends below
latents = []
stem = f'{im_num + 1:05}'  # file names are 1-based (im_num + 1)
im_orig = mpimg.imread(oj(DIR_IMS, f'{stem}.jpg'))
for reg in regs:
    folder = f'generated_images_{reg}'
    ims.append(mpimg.imread(oj(DIR_PROCESSED, folder, f'{stem}.png')))
    latents.append(np.load(oj(DIR_PROCESSED, folder, f'{stem}.npy')))

# Now let's manipulate the images

In [22]:
# Common arguments for transects.make_transects, unpacked via **kwargs.
kwargs = dict(
    # settings that change between runs
    save_dir='results/tnew',
    latents=latents,

    # settings that rarely change
    G=G,
    model_dir=DIR_LINEAR_DIRECTIONS,  # formerly "transects/data/latent-models/"
    orth=True,
    randomize_seeds=False,
    return_ims=True,
)

# Value limits for the 1-D transects, keyed by attribute code.
# C, H and G are calibrated; every other attribute uses an uncalibrated
# placeholder range of [-2, 2] (a fresh list per attribute).
LIMS = {
    'C': [-1.5, 1.7],
    'H': [-0.5, 0.0],
    'G': [-1.75, 1.75],
    **{code: [-2, 2] for code in 'ABMSEW'},
}

In [24]:
# Attributes to sweep and number of images per transect for this run.
# Alternative (larger) configuration:
#     ATTRS = 'CGAH'  # HAGCBMSEW # CHG
#     N_IMS = 11
ATTRS = 'C'
N_IMS = 7
LIMS['C'] = [-1.4, 1.4]  # per-run override of the 'C' limits (mutates LIMS)

# One transect per attribute; each make_transects call receives all
# latents in kwargs['latents'].
transects_1d = {}
for attr in ATTRS:
    ims, attr_df = transects.make_transects(
        attr=attr,
        N_IMS_LIST=[N_IMS],
        LIMS_LIST=LIMS[attr],
        force_project_to_boundary=False,
        **kwargs
    )
    transects_1d[attr] = ims

# TODO: make CHG transects, previously sketched as
#     ims, attr_df = transects.make_transects(**kwargs)

100%|██████████| 3/3 [00:01<00:00,  2.32it/s]


'\nims, attr_df = transects.make_transects(\n    **kwargs\n)\n'

In [None]:

def visualize_individual_latent(transects_1d, attrs=None, n_ims=None):
    '''Plot 1-D transects for an individual latent, one grid row per attribute.

    Parameters
    ----------
    transects_1d : dict
        Maps each attribute code to the images returned by
        ``transects.make_transects`` for that attribute.
    attrs : str or sequence of str, optional
        Attribute codes to plot (one row each). Defaults to the global ``ATTRS``.
    n_ims : int, optional
        Images per transect (grid columns). Defaults to the global ``N_IMS``.
    '''
    attrs = ATTRS if attrs is None else attrs
    n_ims = N_IMS if n_ims is None else n_ims
    ims = np.array([transects_1d[a] for a in attrs])
    # NOTE(review): assumes each transect holds exactly n_ims images (a single
    # latent); the reshape raises otherwise — confirm make_transects' output shape.
    ims = ims.reshape((len(attrs), n_ims, *ims.shape[2:]))
    util.plot_grid(ims, ylabs=[ATTR_LABELS[a] for a in attrs])
    

# Human-readable row label for each regularization value.
REG_LABS = dict(zip(
    (10000, 1e-1, 0),
    ('Unexpanded', 'Regularized', 'Unregularized'),
))
def visualize_varying_reg(transects_1d, reg_vals=None):
    '''Plot a 1-D transect (one attribute) for each regularization value.

    Parameters
    ----------
    transects_1d : dict
        Maps attribute code to the images from ``transects.make_transects``.
        Only the attributes in the global ``ATTRS`` are used.
    reg_vals : sequence, optional
        Regularization values, one grid row each, in the order the latents
        were loaded. Defaults to the global ``regs``. Each value must be a
        key of ``REG_LABS``.
    '''
    reg_vals = regs if reg_vals is None else reg_vals
    ims = np.array([transects_1d[a] for a in ATTRS])
    # Size the rows from the same list that labels them, rather than from the
    # global `latents`, which other cells may rebind to a different length.
    ims = ims.reshape((len(reg_vals), N_IMS, *ims.shape[2:]))
    # Emphasize the middle column of each transect.
    util.plot_grid(ims, ylabs=[REG_LABS[reg] for reg in reg_vals], emphasize_col=N_IMS // 2)

    
# To plot per-attribute transects for a single latent instead, call
# visualize_individual_latent(transects_1d).

# Varying-reg figure: the original image on its own (saved as
# 'manipulations_orig'), then the transect grid (saved as 'manipulations_full').
ax = plt.subplot(111, facecolor='white')
plt.imshow(im_orig)
ax.set_ylabel('Original image')
util.emphasize_box(ax)
util.savefig('manipulations_orig')

visualize_varying_reg(transects_1d)
util.savefig('manipulations_full')

# Project images to the neutral point of an attribute

In [None]:
# load latents
for im_num in np.arange(10):
    im_orig_fname = oj(DIR_IMS, f'{im_num + 1:05}.jpg')
    im_gen_fname = oj(DIR_GEN, f'{im_num + 1:05}.png')
    latents = [np.load(oj(DIR_GEN, f'{im_num + 1:05}.npy'))]
    im_orig = mpimg.imread(im_orig_fname)
    im_gen = mpimg.imread(im_gen_fname)


    attr = 'C'
    im_neutral, W_neutral = transects.make_transects(
        G=G,
        attr=attr,
        latents=latents,
        force_project_to_boundary=False,
        N_IMS_LIST=[1],
        LIMS_LIST=[0, 0],
        return_project_to_boundary=True,
        orth=True,
    )
    util.plot_row([im_orig, im_gen, im_neutral], annot_list=['orig', 'rec', 'neutral'])

# Let's look at a batch of reconstructions
(at the best regularization value, `reg = 0.1`)

In [15]:
# Paths for this comparison.
# NOTE(review): DIR_ORIG is not read below; originals are loaded from DIR_IMS.
DIR_ORIG = '../data/celeba-hq/ims/'
DIRS_GEN = '../data_processed/celeba-hq/'

reg = 0.1  # regularization value whose reconstructions are inspected
IM_NUMS = np.arange(30, 60)

# Collect original/reconstruction pairs for each selected image number.
folder = f'generated_images_{reg}'
ims_orig, ims_rec = [], []
for IM_NUM in IM_NUMS:
    stem = f'{IM_NUM:05}'
    ims_orig.append(mpimg.imread(oj(DIR_IMS, stem + '.jpg')))
    ims_rec.append(mpimg.imread(oj(DIRS_GEN, folder, stem + '.png')))

In [None]:
# Grid of the original images collected in `ims_orig` for the image numbers in IM_NUMS.
util.plot_grid(ims_orig)

In [None]:
# Grid of the reconstructions in `ims_rec` (same image numbers), for comparison with the originals.
util.plot_grid(ims_rec)

### 