In [None]:
%load_ext autoreload
%autoreload 2

import io
import os
import sys
import IPython.display
import PIL.Image
from pprint import pformat

import numpy as np

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

import tensorflow_hub as hub

# Make the grandparent of the current working directory importable
# (presumably the directory that contains the `research` package imported
# below — verify against the repository layout).
dir2 = os.path.abspath('..')  # parent of the CWD
dir1 = os.path.dirname(dir2)  # grandparent of the CWD
if dir1 not in sys.path:
    sys.path.append(dir1)
    
from research.scripts.bigbigan import *

In [None]:
from research.data.things_dataset import ThingsDataset
from research.imagenet_classes import imagenet_classes
import torchvision.transforms as T
import torch.nn.functional as F

# Per-sample preprocessing: resize with torchvision's Resize(256), then convert
# the PIL image to a float tensor with ToTensor().
transform = T.Compose([T.Resize(256), T.ToTensor()])

# THINGS concept images, read through the project's ThingsDataset wrapper with
# the transform above applied to every sample.
things_dataset = ThingsDataset(
    transform=transform,
    root="X:\\Datasets\\EEG\\Things-concepts-and-images\\",
)

In [None]:
# Which BigBiGAN variant to load from TF Hub (uncomment exactly one).
model_name = 'bigbigan-resnet50'  # ResNet-50
#model_name = 'bigbigan-revnet50x4' # RevNet-50 x4
module_path = f'https://tfhub.dev/deepmind/{model_name}/1'

# Inference-only module. For training use instead:
# module = hub.Module(module_path, trainable=True, tags={'train'})
module = hub.Module(module_path)

# Dump each signature exposed by the module with its input and output specs.
for signature in module.get_signature_names():
    print('Signature:', signature)
    for label, get_info in (('Inputs:', module.get_input_info_dict),
                            ('Outputs:', module.get_output_info_dict)):
        print(label, pformat(get_info(signature)))
    print()

In [None]:
# Build the TF1 graph for every BigBiGAN operation used later in the notebook,
# wrapping the hub `module` loaded above. Nothing is executed here (v2 behavior
# is disabled, so these are graph tensors); they are evaluated via `sess.run`.
bigbigan = BigBiGAN(module)

# Make input placeholders for x (`enc_ph`) and z (`gen_ph`).
enc_ph = bigbigan.make_encoder_ph()
gen_ph = bigbigan.make_generator_ph()

# Compute samples G(z) from generator input z (`gen_ph`).
gen_samples = bigbigan.generate(gen_ph)

# Compute reconstructions G(E(x)) of encoder input x (`enc_ph`).
# upsample=True: presumably resizes G(E(x)) to the encoder input resolution —
# confirm in BigBiGAN.reconstruct_x.
recon_x = bigbigan.reconstruct_x(enc_ph, upsample=True)

# Compute encoder features used for representation learning evaluations given
# encoder input x (`enc_ph`). With return_all_features=True this is a dict of
# named feature tensors (later cells index it by name, e.g. 'z_mean').
enc_features = bigbigan.encode(enc_ph, return_all_features=True)

# Compute discriminator scores for encoder pairs (x, E(x)) given x (`enc_ph`)
# and generator pairs (G(z), z) given z (`gen_ph`).
disc_scores_enc = bigbigan.discriminate(*bigbigan.enc_pairs_for_disc(enc_ph))
disc_scores_gen = bigbigan.discriminate(*bigbigan.gen_pairs_for_disc(gen_ph))

# Compute losses.
losses = bigbigan.losses(enc_ph, gen_ph)

In [None]:
# Open a TF1 session on the default graph, then initialize every global
# variable that exists in the graph at this point.
sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)

In [None]:
from scipy import ndimage, misc
import h5py
from pathlib import Path

# Stimulus IDs of the 50 natural test images, assumed to be in the same order
# as the rows of the latent files below (TODO confirm against the decoder).
stimulus_ids = ['1443537.022563', '1621127.019020', '1677366.018182', '1846331.017038', '1858441.011077', '1943899.024131', '1976957.013223', '2071294.046212', 
                '2128385.020264', '2139199.010398', '2190790.015121', '2274259.024319', '2416519.012793', '2437136.012836', '2437971.005013', '2690373.007713', 
                '2797295.015411', '2824058.018729', '2882301.014188', '2916179.024850', '2950256.022949', '2951358.023759', '3064758.038750', '3122295.031279', 
                '3124170.013920', '3237416.058334', '3272010.011001', '3345837.012501', '3379051.008496', '3452741.024622', '3455488.028622', '3482252.022530', 
                '3495258.009895', '3584254.005040', '3626115.019498', '3710193.022225', '3716966.028524', '3761084.043533', '3767745.000109', '3941684.021672', 
                '3954393.010038', '4210120.009062', '4252077.010859', '4254777.016338', '4297750.025624', '4387400.016693', '4507155.021299', '4533802.019479', 
                '4554684.053399', '4572121.003262']

features_root = Path('X:\\Datasets\\Deep-Image-Reconstruction\\derivatives\\decoded_features')
z_pred = np.load(features_root / 'bigbigan-resnet50__z_mean__sub-03__test-prediction__v4.npy')
z_target = np.load(features_root / 'bigbigan-resnet50__z_mean__sub-03__test__v4.npy')

root = 'X:\\Datasets\\Deep-Image-Reconstruction'
stimulus_images = h5py.File(Path(root) / "derivatives" / "stimulus_images.hdf5", "r")

# Resize each stimulus to 128x128 so it can be tiled next to the generated
# images, and map [0, 255] to [-1, 1]. Dividing by 255 (not 256) sends
# full-intensity pixels exactly to 1.0, matching the ToTensor() * 2 - 1 scaling
# used elsewhere in this notebook (assumes 8-bit pixel data).
x_stim = []
for stimulus_id in stimulus_ids:
    x = stimulus_images[stimulus_id]['data'][:]
    x = ndimage.zoom(x, (128 / x.shape[0], 128 / x.shape[1], 1))
    x = x / 255 * 2 - 1
    x_stim.append(x)
x_stim = np.stack(x_stim)

# Decode target latents one at a time.
x_target = np.concatenate([
    sess.run(gen_samples, feed_dict={gen_ph: z[None]})
    for z in z_target
])

# Decode predicted latents one at a time, each divided by its own standard
# deviation first (NOTE(review): presumably to bring them to unit variance —
# confirm intent).
x_pred = np.concatenate([
    sess.run(gen_samples, feed_dict={gen_ph: z[None] / z.std()})
    for z in z_pred
])

# The panel cell below splits these arrays in lock-step; they must line up.
assert len(x_stim) == len(x_target) == len(x_pred), (len(x_stim), len(x_target), len(x_pred))

In [None]:
from PIL import Image

splits = 10
# Stimuli per panel, derived from the data rather than a hard-coded 50.
cols = len(x_stim) // splits

out_path = 'X:\\Datasets\\Deep-Image-Reconstruction\\derivatives\\results'
name = 'bigbigan_result03_v2'
# Image.save does not create missing directories.
Path(out_path).mkdir(parents=True, exist_ok=True)

# Each panel tiles `cols` stimuli followed by their `cols` reconstructions
# (presumably laid out row-major by imgrid: stimuli on top, predictions below).
for i, (stim, pred) in enumerate(zip(np.split(x_stim, splits), np.split(x_pred, splits))):
    out = np.concatenate([stim, pred])
    img = imgrid(image_to_uint8(out), cols=cols)
    Image.fromarray(img).save(Path(out_path) / f'{name}_{i*cols}-{(i+1)*cols}.png')
    imshow(img)

In [None]:
from scipy import ndimage, misc
import h5py
from pathlib import Path

# Stimulus IDs per session, assumed to be in the same order as the rows of
# the decoded latent files (TODO confirm). The 26 imagery IDs are the first 26
# natural_test IDs, in the same order.
stimulus_ids = {
    'natural_test': ['1443537.022563', '1621127.019020', '1677366.018182', '1846331.017038', '1858441.011077', '1943899.024131', 
                     '1976957.013223', '2071294.046212', '2128385.020264', '2139199.010398', '2190790.015121', '2274259.024319', 
                     '2416519.012793', '2437136.012836', '2437971.005013', '2690373.007713', '2797295.015411', '2824058.018729', 
                     '2882301.014188', '2916179.024850', '2950256.022949', '2951358.023759', '3064758.038750', '3122295.031279', 
                     '3124170.013920', '3237416.058334', '3272010.011001', '3345837.012501', '3379051.008496', '3452741.024622', 
                     '3455488.028622', '3482252.022530', '3495258.009895', '3584254.005040', '3626115.019498', '3710193.022225', 
                     '3716966.028524', '3761084.043533', '3767745.000109', '3941684.021672', '3954393.010038', '4210120.009062', 
                     '4252077.010859', '4254777.016338', '4297750.025624', '4387400.016693', '4507155.021299', '4533802.019479', 
                     '4554684.053399', '4572121.003262'],
    'imagery': ['1443537.022563', '1621127.019020', '1677366.018182', '1846331.017038', '1858441.011077', '1943899.024131', 
                '1976957.013223', '2071294.046212', '2128385.020264', '2139199.010398', '2190790.015121', '2274259.024319', 
                '2416519.012793', '2437136.012836', '2437971.005013', '2690373.007713', '2797295.015411', '2824058.018729', 
                '2882301.014188', '2916179.024850', '2950256.022949', '2951358.023759', '3064758.038750', '3122295.031279', 
                '3124170.013920', '3237416.058334'],
}

subject = 'sub-03'
version = 3
features_root = Path('X:\\Datasets\\Deep-Image-Reconstruction\\derivatives\\decoded_features')

# Target latents for the natural test stimuli. The file name is built from
# `subject` so changing the subject above cannot silently pair one subject's
# predictions with another subject's targets.
z_target = np.load(features_root / f'bigbigan-resnet50__z_mean__{subject}__test__v4.npy')

# Variational latent predictions for `subject`, one file per session.
z_test_pred = np.load(features_root / 'bigbigan-resnet50' / 'z_mean' / subject / f'natural_test__variational__v{version}.npy')
z_imagery_pred = np.load(features_root / 'bigbigan-resnet50' / 'z_mean' / subject / f'imagery__variational__v{version}.npy')

root = 'X:\\Datasets\\Deep-Image-Reconstruction'
stimulus_images = h5py.File(Path(root) / "derivatives" / "stimulus_images.hdf5", "r")

def _stimulus_to_generator_range(image, size=128):
    """Resize an (H, W, C) stimulus to ``size`` x ``size`` and map [0, 255] to [-1, 1].

    Dividing by 255 (not 256) sends full-intensity pixels exactly to 1.0,
    matching the ``ToTensor() * 2 - 1`` scaling used elsewhere in this
    notebook. Assumes 8-bit pixel data — TODO confirm the HDF5 dtype.
    """
    zoomed = ndimage.zoom(image, (size / image.shape[0], size / image.shape[1], 1))
    return zoomed / 255 * 2 - 1


def _generate_one_by_one(latents):
    """Run the BigBiGAN generator on each latent vector separately and concatenate the images."""
    return np.concatenate([
        sess.run(gen_samples, feed_dict={gen_ph: z[None]})
        for z in latents
    ])


def make_images(z_pred, stimulus_ids, targets=None, num_samples=20, seed=None):
    """Build stimulus, target-decoded and sampled-prediction images for one session.

    Args:
        z_pred: Variational latent predictions, indexed as ``z_pred[:, 0]``
            (mean) and ``z_pred[:, 1]`` (std); presumably shaped
            (n_stimuli, 2, z_dim) — TODO confirm.
        stimulus_ids: Keys into the global ``stimulus_images`` HDF5 file.
        targets: Latents decoded into ``x_target``. Defaults to the global
            ``z_target``, which is what this function always used before.
        num_samples: Number of latents sampled per stimulus.
        seed: Optional seed for the sampling; ``None`` keeps using numpy's
            global random state.

    Returns:
        ``(x_stim, x_target, x_pred)`` where ``x_stim`` stacks the resized
        stimuli, ``x_target`` has one generated image per row of ``targets``,
        and ``x_pred[i, j]`` is the image generated from sample ``j`` of
        stimulus ``i``.
    """
    x_stim = np.stack([
        _stimulus_to_generator_range(stimulus_images[stimulus_id]['data'][:])
        for stimulus_id in stimulus_ids
    ])

    # NOTE(review): by default the global natural-test targets are decoded even
    # for the imagery session; they line up with the imagery stimuli only
    # because the imagery IDs are the first natural-test IDs. Pass `targets`
    # explicitly to decouple the two.
    if targets is None:
        targets = z_target
    x_target = _generate_one_by_one(targets)

    # Draw `num_samples` latents per stimulus: mean + std * N(0, 1).
    z_mean = z_pred[:, 0]
    z_std = z_pred[:, 1]
    rng = np.random if seed is None else np.random.RandomState(seed)
    noise = rng.randn(z_pred.shape[0], num_samples, z_pred.shape[2])
    z_samples = z_mean[:, None] + z_std[:, None] * noise

    x_pred = np.stack([_generate_one_by_one(per_stimulus) for per_stimulus in z_samples])

    return x_stim, x_target, x_pred

x_test_stim, x_test_target, x_test_pred = make_images(z_test_pred, stimulus_ids['natural_test'])
x_imagery_stim, x_imagery_target, x_imagery_pred = make_images(z_imagery_pred, stimulus_ids['imagery'])

In [None]:
from PIL import Image
from pathlib import Path

cols = 5

out_path = Path('X:\\Datasets\\Deep-Image-Reconstruction\\derivatives\\results').joinpath(
    'bigbigan-resnet50', 'variational_encoder', subject)
out_path.mkdir(exist_ok=True, parents=True)
name = f'bigbigan-resnet50_variational_result_v{version}-0'

groups = [
    ('natural_test', x_test_stim, x_test_target, x_test_pred),
    ('imagery', x_imagery_stim, x_imagery_target, x_imagery_pred),
]

# One PNG per stimulus: the stimulus, the image decoded from its target latent,
# then up to 18 images decoded from sampled latents, tiled `cols` per row.
for session, x_stim, x_target, x_pred in groups:
    save_path = out_path / session
    save_path.mkdir(exist_ok=True, parents=True)

    for i, (stim, target, pred) in enumerate(zip(x_stim, x_target, x_pred)):
        out = np.concatenate([stim[np.newaxis], target[np.newaxis], pred[:18]])
        img = imgrid(image_to_uint8(out), cols=cols)
        Image.fromarray(img).save(save_path / f'{name}_id-{i}.png')
        imshow(img)

In [None]:
# Decode predicted latents 32..47, multiplied by 4 before decoding
# (NOTE(review): the reason for this factor is not recorded here).
# This rebinds the global `z_pred`.
z_pred = np.load(
    'X:\\Datasets\\Deep-Image-Reconstruction\\derivatives\\bigbigan-resnet50-test-prediction.npy'
)
feed_dict = {gen_ph: 4 * z_pred[32:48]}
_out_samples = sess.run(gen_samples, feed_dict)
print('samples shape:', _out_samples.shape)
imshow(imgrid(image_to_uint8(_out_samples), cols=4))

In [None]:
# Decode latents 16..31 from bigbigan-resnet50-test.npy and show them 4 per row.
z_test = np.load(
    'X:\\Datasets\\Deep-Image-Reconstruction\\derivatives\\bigbigan-resnet50-test.npy'
)
feed_dict = {gen_ph: z_test[16:32]}
_out_samples = sess.run(fetches=gen_samples, feed_dict=feed_dict)
print('samples shape:', _out_samples.shape)
imshow(imgrid(image_to_uint8(_out_samples), cols=4))

In [None]:
# Decode one latent of size 120 drawn from a standard normal distribution.
feed_dict = {gen_ph: np.random.standard_normal((1, 120))}
_out_samples = sess.run(gen_samples, feed_dict)
print('samples shape:', _out_samples.shape)
imshow(imgrid(image_to_uint8(_out_samples), cols=4))

In [None]:
from functools import partial
from tqdm.notebook import tqdm
from PIL import Image
import h5py
from pathlib import Path

dataset_path = Path('D:\\Datasets\\NSD')
derivatives_path = dataset_path / 'derivatives' / 'stimulus_embeddings'
stimulus_path = dataset_path / 'nsddata_stimuli' / 'stimuli' / 'nsd' / 'nsd_stimuli.hdf5'

from torchvision import transforms as T
# NOTE(review): `normalize` is not applied below; inputs are scaled to
# [-1, 1] instead of ImageNet mean/std.
normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
preprocess = T.Compose([T.Resize(256), T.ToTensor(),])

# h5py.File(..., "a") creates the file but not missing parent directories.
derivatives_path.mkdir(parents=True, exist_ok=True)

# Both files are opened in `with` blocks so they are closed even if the loop
# fails part-way.
with h5py.File(stimulus_path, 'r') as stimulus_file:
    stimulus_images = stimulus_file['imgBrick']
    with h5py.File(derivatives_path / f"{model_name}-embeddings.hdf5", "a") as f:
        N = stimulus_images.shape[0]

        for stimulus_id in tqdm(range(N)):
            image = Image.fromarray(stimulus_images[stimulus_id])
            # (C, H, W) tensor from ToTensor -> [-1, 1] numpy batch (1, H, W, C).
            x = (preprocess(image).unsqueeze(0) * 2. - 1.).permute(0, 2, 3, 1).numpy()

            out_recons, out_features = sess.run([recon_x, enc_features], feed_dict={enc_ph: x})
            features = {'reconstruction': out_recons, **out_features}

            for feature_name, feature in features.items():
                # One dataset per output with a leading stimulus axis. The
                # size-1 batch axis is kept so files from earlier runs still
                # match the required shape.
                dataset = f.require_dataset(feature_name, (N, *feature.shape), feature.dtype)
                dataset[stimulus_id] = feature


In [None]:
from functools import partial
from tqdm.notebook import tqdm
from PIL import Image
import h5py

# Import Path before it is first used in this cell.
from pathlib import Path

derivatives_path = Path('X:\\Datasets\\Deep-Image-Reconstruction\\derivatives\\')

root = 'X:\\Datasets\\Deep-Image-Reconstruction'
stimulus_images = h5py.File(Path(root) / "derivatives" / "stimulus_images.hdf5", "r")

from torchvision import transforms as T
# NOTE(review): `normalize` is not applied below; inputs are scaled to
# [-1, 1] instead of ImageNet mean/std.
normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
preprocess = T.Compose([T.Resize(256), T.ToTensor(),])

# One HDF5 group per stimulus, holding the reconstruction and every encoder
# output with the batch axis dropped. Re-running overwrites existing entries.
with h5py.File(derivatives_path / f"{model_name}-features.hdf5", "a") as f:
    for stimulus_id, stimulus_image in tqdm(stimulus_images.items()):
        image = Image.fromarray(stimulus_image['data'][:])
        # (C, H, W) tensor from ToTensor -> [-1, 1] numpy batch (1, H, W, C).
        x = (preprocess(image).unsqueeze(0) * 2. - 1.).permute(0, 2, 3, 1).numpy()

        out_recons, out_features = sess.run([recon_x, enc_features], feed_dict={enc_ph: x})
        features = {'reconstruction': out_recons, **out_features}

        stimulus = f.require_group(stimulus_id)
        for node_name, feature in features.items():
            feature = feature[0]  # drop the size-1 batch axis
            if node_name in stimulus:
                stimulus[node_name][:] = feature
            else:
                stimulus[node_name] = feature

In [None]:
# Re-open the stimulus image file read-only.
stimulus_images = h5py.File(Path(root).joinpath("derivatives", "stimulus_images.hdf5"), "r")



In [None]:
from ipywidgets import interact
import matplotlib.pyplot as plt

stimulus_images = h5py.File(Path(root) / "derivatives" / "stimulus_images.hdf5", "r")
reconstructions = h5py.File(derivatives_path / f"{model_name}-features.hdf5", "r")

@interact(stimulus_id=list(stimulus_images.keys()))
def compare(stimulus_id):
    """Show a stimulus next to its stored BigBiGAN reconstruction.

    The reconstruction is mapped with ``* 0.5 + 0.5`` (presumably from
    [-1, 1]) and clipped to [0, 1], so imshow never has to clip out-of-range
    float pixels on its own. Each title shows the array's shape and value range.
    """
    original = stimulus_images[stimulus_id]['data'][:]
    reconstruction = reconstructions[stimulus_id]['reconstruction'][:]

    fig, (ax_orig, ax_recon) = plt.subplots(1, 2, figsize=(10, 5))
    ax_orig.imshow(original)
    ax_orig.set_title(f'original {original.shape}\nrange [{original.min()}, {original.max()}]')
    ax_recon.imshow(np.clip(reconstruction * 0.5 + 0.5, 0, 1))
    ax_recon.set_title(
        f'reconstruction {reconstruction.shape}\n'
        f'range [{reconstruction.min():.3f}, {reconstruction.max():.3f}]'
    )
    for ax in (ax_orig, ax_recon):
        ax.axis('off')
    plt.show()

In [None]:
import torch
from torch.utils.data import DataLoader
from PIL import Image
from pathlib import Path

batch_size = 16
dataloader = DataLoader(things_dataset, batch_size=batch_size)
out_path = Path("X:\\Results\\Neurophysical-Data-Decoding\\BigBiGAN-Inversions\\resnet50_1\\")
image_path = out_path / "images"
latent_path = out_path / "latents"
image_path.mkdir(exist_ok=True, parents=True)
latent_path.mkdir(exist_ok=True, parents=True)

# Encoder outputs saved per image, stacked along axis 1 in this order.
# NOTE(review): np.stack needs all of these to have the same shape — confirm
# that 'default' has the same size as the z outputs for this module.
save_features = ['z_sample', 'z_mean', 'z_stdev', 'default']
for i, batch in enumerate(dataloader):
    # NCHW float batch in [0, 1] -> NHWC numpy batch in [-1, 1].
    data = torch.movedim(batch['data'], 1, -1).numpy() * 2. - 1.

    out_recons, out_features = sess.run([recon_x, enc_features], feed_dict={enc_ph: data})

    inputs_and_recons = interleave(data, out_recons)
    out = imgrid(image_to_uint8(inputs_and_recons), cols=8)
    image = Image.fromarray(out)

    # Files are named by the inclusive range of dataset indices they contain;
    # the final batch can be shorter than batch_size.
    start = i * batch_size
    name = f"{start}-{start + len(data) - 1}"
    image.save(image_path / f"{name}.png")

    latents = np.stack([out_features[feature] for feature in save_features], axis=1)
    np.save(latent_path / f"{name}.npy", latents)

In [None]:
import numpy as np
from pathlib import Path

# Same directory the inversion export writes to (spelled "BigBiGAN" there, so
# this also works on case-sensitive file systems).
latents_path = Path("X:\\Results\\Neurophysical-Data-Decoding\\BigBiGAN-Inversions\\resnet50_1\\latents\\")

# Files are named "<first>-<last>.npy". Sort numerically by the first index so
# batches concatenate in dataset order; lexicographic order would put
# "112-127" before "16-31".
latent_file_paths = sorted(
    latents_path.glob("*.npy"),
    key=lambda file_path: int(file_path.stem.split("-")[0]),
)
if not latent_file_paths:
    raise FileNotFoundError(f"No latent .npy files found in {latents_path}")
latents = np.concatenate([np.load(latent_file_path) for latent_file_path in latent_file_paths])

In [None]:
out_path = Path("X:\\Datasets\\EEG\\Things-supplementary\\Latents\\bigbigan-resnet50\\")
# np.save does not create missing directories.
out_path.mkdir(parents=True, exist_ok=True)

# Columns of `latents` follow the order in which they were stacked when saved
# (z_sample, z_mean, z_stdev, then 'default'); only the first three are exported.
for i, feature in enumerate(['z_sample', 'z_mean', 'z_stdev']):
    latent = latents[:, i]
    np.save(out_path / f"{feature}.npy", latent)

In [None]:
# Choose which TF Hub BigGAN module to use; keep exactly one `module_path`
# line uncommented. NOTE(review): this rebinds the `module_path` used for
# BigBiGAN earlier. The BigGAN cells further down hard-code their own URLs,
# so confirm this selection is still needed.
# BigGAN-deep models
# module_path = 'https://tfhub.dev/deepmind/biggan-deep-128/1'  # 128x128 BigGAN-deep
module_path = 'https://tfhub.dev/deepmind/biggan-deep-256/1'  # 256x256 BigGAN-deep
# module_path = 'https://tfhub.dev/deepmind/biggan-deep-512/1'  # 512x512 BigGAN-deep

# BigGAN (original) models
# module_path = 'https://tfhub.dev/deepmind/biggan-128/2'  # 128x128 BigGAN
# module_path = 'https://tfhub.dev/deepmind/biggan-256/2'  # 256x256 BigGAN
# module_path = 'https://tfhub.dev/deepmind/biggan-512/2'  # 512x512 BigGAN

In [None]:
# Imports for the BigGAN experiments below. Apart from `truncnorm`, these
# repeat imports and the TF1-compat switch from the first cell.
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

import os
import io
import IPython.display
import numpy as np
import PIL.Image
from scipy.stats import truncnorm
import tensorflow_hub as hub

In [None]:
# Per-concept ImageNet class vectors read from a comma-separated file
# (presumably one row per THINGS concept — TODO confirm).
# NOTE(review): "resetnet152" looks like a typo for "resnet152"; it is left
# as-is because it has to match the directory on disk.
class_vector_path = Path("X:\\Datasets\\EEG\\Things-supplementary\\ImageNet-Classification\\resetnet152\\concept_averages\\one_hot.csv")
concept_imagenet_classes = np.loadtxt(fname=class_vector_path, delimiter=",")

In [None]:
import torch
from pytorch_pretrained_biggan import (BigGAN, one_hot_from_names, truncated_noise_sample,
                                       save_as_images, display_in_terminal)
from pathlib import Path

# OPTIONAL: if you want to have more information on what's happening, activate the logger as follows
import logging
logging.basicConfig(level=logging.INFO)

# Load the pre-trained BigGAN-deep 256x256 generator.
model = BigGAN.from_pretrained('biggan-deep-256')

# Prepare the inputs: one row of `concept_imagenet_classes` per sample (the
# same concept three times here), each with independent truncated noise.
truncation = 0.4
concept_rows = [8, 8, 8]
#class_vector = one_hot_from_names(['soap bubble', 'coffee', 'mushroom'], batch_size=3)
class_vector = concept_imagenet_classes[concept_rows]
noise_vector = truncated_noise_sample(truncation=truncation, batch_size=len(concept_rows))

# All in tensors
noise_vector = torch.from_numpy(noise_vector)
class_vector = torch.from_numpy(class_vector).float()

# Use the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
noise_vector = noise_vector.to(device)
class_vector = class_vector.to(device)
model.to(device)

# Generate an image
with torch.no_grad():
    output = model(noise_vector, class_vector, truncation)

# Move the result back to the CPU (a no-op when it is already there)
output = output.to('cpu')

# Save results as png images
#save_as_images(output)

In [None]:
# pytorch_pretrained_biggan publishes only BigGAN-deep weights
# ('biggan-deep-128' / '-256' / '-512'); 'biggan-128' is not one of its model
# names, so load the 128x128 BigGAN-deep generator instead.
model = BigGAN.from_pretrained('biggan-deep-128')

In [None]:
# Load BigGAN-deep 128 module.
# NOTE(review): this rebinds `module`, replacing the BigBiGAN hub module
# loaded earlier. Its variables are also not initialized in `sess`, which was
# created before this module existed.
module = hub.Module('https://tfhub.dev/deepmind/biggan-deep-128/1')

# Sample random noise (z) and ImageNet label (y) inputs.
batch_size = 8
truncation = 0.5  # scalar truncation value in [0.0, 1.0]
z = truncation * tf.random.truncated_normal([batch_size, 128])  # noise sample
y_index = tf.random.uniform([batch_size], maxval=1000, dtype=tf.int32)
y = tf.one_hot(y_index, 1000)  # one-hot ImageNet label

# Call BigGAN on a dict of the inputs to generate a batch of images with shape
# [8, 128, 128, 3] and range [-1, 1]. The module's documented usage calls the
# default signature, which returns the image tensor directly; the previous
# signature="image_feature_vector", as_dict=True call does not match this module.
samples = module(dict(y=y, z=z, truncation=truncation))

In [None]:
# TF Hub URL of the 128x128 BigGAN-deep generator (the same URL hard-coded in
# the hub.Module call above).
module_path = 'https://tfhub.dev/deepmind/biggan-deep-128/1' 