In [None]:
# Fetch the GANalyze code; the PyTorch implementation used below lives in GANalyze/pytorch.
!git clone https://github.com/LoreGoetschalckx/GANalyze.git

Cloning into 'GANalyze'...
remote: Enumerating objects: 57, done.[K
remote: Counting objects: 100% (7/7), done.[K
remote: Compressing objects: 100% (6/6), done.[K
remote: Total 57 (delta 1), reused 4 (delta 1), pack-reused 50[K
Unpacking objects: 100% (57/57), 864.46 KiB | 4.57 MiB/s, done.


In [None]:
# Work from the PyTorch implementation; the relative paths used later (checkpoints/, output/) assume this directory.
# NOTE: the path is relative, so re-running this cell from inside GANalyze/pytorch fails.
%cd GANalyze/pytorch
# Download the pretrained weights (EmoNet assessor and BigGAN generator, per the output below).
!sh download_pretrained.sh

/content/GANalyze/pytorch
Downloading EmoNet weights
--2023-03-22 14:29:40--  http://ganalyze.csail.mit.edu/models/EmoNet_valence_moments_resnet50_5_best.pth.tar
Resolving ganalyze.csail.mit.edu (ganalyze.csail.mit.edu)... 128.30.100.223
Connecting to ganalyze.csail.mit.edu (ganalyze.csail.mit.edu)|128.30.100.223|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 188384043 (180M) [application/x-tar]
Saving to: ‘assessors/EmoNet_valence_moments_resnet50_5_best.pth.tar’


2023-03-22 14:30:16 (5.09 MB/s) - ‘assessors/EmoNet_valence_moments_resnet50_5_best.pth.tar’ saved [188384043/188384043]

Downloading BigGAN weights
--2023-03-22 14:30:16--  http://ganalyze.csail.mit.edu/models/biggan-128.pth
Resolving ganalyze.csail.mit.edu (ganalyze.csail.mit.edu)... 128.30.100.223
Connecting to ganalyze.csail.mit.edu (ganalyze.csail.mit.edu)|128.30.100.223|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 282218374 (269M)
Saving to: ‘generators/biggan-128.

**Be careful!** The command below deletes all previous checkpoints.

In [None]:
# Irreversibly deletes every saved training checkpoint. A fresh run (--checkpoint_resume 0) needs the
# checkpoint directory to be absent: the training code below creates it with os.makedirs(..., exist_ok=False).
!rm -rf /content/GANalyze/pytorch/checkpoints/*

## Training
The command below runs the `train_pytorch.py` script to train the transformation. `--num_samples` sets the number of training samples; with the batch size of 4, training takes `num_samples / 4` optimization steps. Set `--checkpoint_resume` to a saved batch index (or `-1` for the latest stored checkpoint) to continue training from that checkpoint; `0` starts a fresh run.
For further information, check out https://github.com/LoreGoetschalckx/GANalyze

In [None]:
# Train the latent-space transformation. --num_samples counts training samples (4 per optimization step);
# --checkpoint_resume 0 starts a fresh run in a new checkpoint directory.
!python train_pytorch.py \
 --generator biggan256 None \
 --assessor emonet \
 --transformer OneDirection None \
 --train_alpha_a -0.5 --train_alpha_b 0.5 \
 --gpu_id 0 --num_samples 2000 --checkpoint_resume 0


approach:  one_direction

[0/2000] Loss 0.131149 (0.131149)
saving checkpoint
[4/2000] Loss 0.036781 (0.083965)
[8/2000] Loss 0.040570 (0.069500)
[12/2000] Loss 0.102708 (0.077802)
[16/2000] Loss 0.161770 (0.094595)
[20/2000] Loss 0.018938 (0.081986)
[24/2000] Loss 0.162383 (0.093471)
[28/2000] Loss 0.079837 (0.091767)
[32/2000] Loss 0.134422 (0.096506)
[36/2000] Loss 0.028711 (0.089727)
[40/2000] Loss 0.060338 (0.087055)
[44/2000] Loss 0.112123 (0.089144)
[48/2000] Loss 0.043444 (0.085629)
[52/2000] Loss 0.027143 (0.081451)
[56/2000] Loss 0.137553 (0.085191)
[60/2000] Loss 0.115322 (0.087074)
[64/2000] Loss 0.085983 (0.087010)
[68/2000] Loss 0.113234 (0.088467)
[72/2000] Loss 0.029761 (0.085377)
[76/2000] Loss 0.033193 (0.082768)
[80/2000] Loss 0.074507 (0.082375)
[84/2000] Loss 0.203287 (0.087871)
[88/2000] Loss 0.135228 (0.089930)
[92/2000] Loss 0.082601 (0.089624)
[96/2000] Loss 0.203917 (0.094196)
[100/2000] Loss 0.134702 (0.095754)
[104/2000] Loss 0.111092 (0.096322)
[108/2000] 

The training script is reproduced below for reference. You do **not** need to run this cell if you already trained the model with the command above. Feel free to experiment with it, e.g. to try other generator + assessor combinations.

In [None]:
import argparse
import json
import os
import subprocess

import numpy as np
import torch
import torch.optim as optim

import assessors
import generators
import transformations.pytorch as transformations
import utils.common
import utils.pytorch
import matplotlib.pyplot as plt

# Collect command line arguments
# --------------------------------------------------------------------------------------------------------------
parser = argparse.ArgumentParser()
# type=int so that "-1" (CPU) from the command line compares equal to the integer -1.
parser.add_argument('--gpu_id', type=int, default=0, help='which gpu to use.')
parser.add_argument('--num_samples', type=int, default=400000, help='number of samples to train for')
parser.add_argument('--checkpoint_resume', type=int, default=0, help='which checkpoint to load based on batch_start. -1 for latest stored checkpoint')
parser.add_argument('--train_alpha_a', type=float, default=-0.5, help='lower limit for step sizes to use during training')
parser.add_argument('--train_alpha_b', type=float, default=0.5, help='upper limit for step sizes to use during training')
# metavar must be a tuple (one name per nargs value) for argparse to render it as "name arguments".
parser.add_argument('--generator', default=["biggan256", "None"], nargs=2, type=str, metavar=("name", "arguments"), help='generator function to use')
parser.add_argument('--assessor', type=str, default="emonet", help='assessor function to compute the image property of interest')
parser.add_argument('--transformer', default=["OneDirection", "None"], nargs=2, type=str, metavar=("name", "arguments"), help="transformer function")

# Inside a Jupyter kernel, sys.argv holds the kernel's own flags (e.g. `-f .../kernel-*.json`), which
# parse_args() rejects with SystemExit. parse_known_args() tolerates them; anything ignored is reported
# so that genuine command-line typos stay visible.
args, unknown_args = parser.parse_known_args()
if unknown_args:
    print("Ignoring unrecognized arguments:", unknown_args)
opts = vars(args)

# Verify
# Resuming skips whole batches, so a resume point must sit on a batch boundary (batch size is 4).
if opts["checkpoint_resume"] not in (0, -1):
    assert opts["checkpoint_resume"] % 4 == 0  # Needs to be a multiple of the batch size

# Choose GPU
# Compare as text so "-1" selects the CPU whether gpu_id arrived as an int or as a command-line string
# (previously the string "-1" never equalled the int -1, so the CPU could not be forced).
if str(opts["gpu_id"]) != "-1" and torch.cuda.is_available():
    device = torch.device("cuda:" + str(opts["gpu_id"]))
else:
    device = torch.device("cpu")

# Creating directory to store checkpoints
# Runs are grouped by generator / assessor / transformer and by the git revision of the code that produced them.
git_describe_output = subprocess.check_output(["git", "describe", "--always"])
version = git_describe_output.strip().decode("utf-8")
checkpoint_dir = os.path.join(*[
    "./checkpoints",
    "_".join(opts["generator"]),
    opts["assessor"],
    "_".join(opts["transformer"]),
    version,
])

# A fresh run refuses to reuse an existing directory (exist_ok=False raises);
# a resumed run expects the directory to exist already.
if not opts["checkpoint_resume"]:
    os.makedirs(checkpoint_dir, exist_ok=False)

# Saving training settings, including the resolved code version
opts["version"] = version
opts_file = os.path.join(checkpoint_dir, "opts.json")
with open(opts_file, 'w') as fp:
    json.dump(opts, fp)

# Setting up file to store loss values
loss_file = os.path.join(checkpoint_dir, "losses.txt")

# Some characteristics
# --------------------------------------------------------------------------------------------------------------
# Latent size (dim_z) and number of classes (vocab_size) for each supported generator.
generator_specs = {
    'biggan256': {'dim_z': 140, 'vocab_size': 1000},
    'biggan512': {'dim_z': 128, 'vocab_size': 1000},
}
if opts['generator'][0] not in generator_specs:
    # Fail fast: an unknown name used to yield None here, which only surfaced later inside
    # the transformer / sampling code with an unrelated-looking error.
    raise ValueError("Unsupported generator {!r}; expected one of {}".format(
        opts['generator'][0], sorted(generator_specs)))

dim_z = generator_specs[opts['generator'][0]]['dim_z']
vocab_size = generator_specs[opts['generator'][0]]['vocab_size']

# Setting up Transformer
# --------------------------------------------------------------------------------------------------------------
def _parse_key_value_arguments(argument_string):
    """Parse a command-line spec such as ``"a=1,b=2"`` into ``{"a": "1", "b": "2"}``.

    The literal string ``"None"`` means "no arguments" and yields ``{}``. Values are kept as
    strings. Only the first ``=`` in each pair separates key from value, so values may contain ``=``.

    Raises:
        ValueError: if a comma-separated item has no ``=``.
    """
    if argument_string == "None":
        return {}
    parsed = {}
    for pair in argument_string.split(","):
        key, separator, value = pair.partition("=")
        if not separator:
            raise ValueError("Expected key=value, got {!r} in {!r}".format(pair, argument_string))
        parsed[key] = value
    return parsed


transformer = opts["transformer"][0]
transformer_arguments = _parse_key_value_arguments(opts["transformer"][1])

transformation = getattr(transformations, transformer)(dim_z, vocab_size, **transformer_arguments)
transformation = transformation.to(device)

# Setting up Generator
# --------------------------------------------------------------------------------------------------------------
generator = opts["generator"][0]
generator_arguments = _parse_key_value_arguments(opts["generator"][1])

generator = getattr(generators, generator)(**generator_arguments)

# The pretrained generator is frozen: only the transformation is optimized.
for p in generator.parameters():
    p.requires_grad = False
generator.eval()
generator = generator.to(device)

# Setting up Assessor
# --------------------------------------------------------------------------------------------------------------
# The assessor factory returns either a bare model or a (model, input_transform, output_transform) tuple.
assessor_elements = getattr(assessors, opts['assessor'])(True)
if isinstance(assessor_elements, tuple):
    assessor = assessor_elements[0]
    input_transform = assessor_elements[1]
    output_transform = assessor_elements[2]
else:
    assessor = assessor_elements

    def input_transform(x):
        return x  # identity, no preprocessing

    def output_transform(x):
        return x  # identity, no postprocessing

# Freeze the assessor; it only scores images. eval()/to(device) used to sit inside the parameter
# loop, so they ran once per parameter and never ran at all for an assessor without parameters.
if hasattr(assessor, 'parameters'):
    for p in assessor.parameters():
        p.requires_grad = False
    assessor.eval()
    assessor.to(device)

# Training
# --------------------------------------------------------------------------------------------------------------
# optimizer: only the transformation's parameters are trained
optimizer = optim.Adam(transformation.parameters(), lr=0.0002)
losses = utils.common.AverageMeter(name='Loss')

# figure out where to resume
if opts["checkpoint_resume"] == 0:
    checkpoint_resume = 0
elif opts["checkpoint_resume"] == -1:
    # Checkpoints are named pytorch_model_<batch_start>.pth. Batch numbers are compared as integers:
    # comparing the raw strings ranked e.g. "96" above "1000", and the resulting string later broke
    # the `batch_start <= checkpoint_resume` comparison in the training loop.
    available_checkpoints = {}
    for file_name in os.listdir(checkpoint_dir):
        if file_name.endswith(".pth"):
            available_checkpoints[int(file_name.split('.')[0].split("_")[-1])] = file_name
    if not available_checkpoints:
        raise FileNotFoundError("No .pth checkpoints found in {}".format(checkpoint_dir))
    checkpoint_resume = max(available_checkpoints)
    # map_location lets a checkpoint saved on GPU be restored on a CPU-only machine.
    transformation.load_state_dict(
        torch.load(os.path.join(checkpoint_dir, available_checkpoints[checkpoint_resume]), map_location=device))
else:
    transformation.load_state_dict(torch.load(
        os.path.join(checkpoint_dir, "pytorch_model_{}.pth".format(opts["checkpoint_resume"])),
        map_location=device))
    checkpoint_resume = opts["checkpoint_resume"]

#  training settings
batch_size = 4
optim_iter = 0
num_samples = opts["num_samples"]
train_alpha_a, train_alpha_b = opts["train_alpha_a"], opts["train_alpha_b"]

# create training set
# The fixed seed makes the class labels (and the step sizes drawn later in the loop) reproducible;
# whether truncated_z_sample also draws from this global generator depends on its implementation.
np.random.seed(seed=0)
truncation = 1
zs = utils.common.truncated_z_sample(num_samples, dim_z, truncation)
ys = np.random.randint(0, vocab_size, size=zs.shape[0])

# loop over data batches
for batch_start in range(0, num_samples, batch_size):

    # zero the parameter gradients
    optimizer.zero_grad()

    # skip batches we've already done (this would happen when resuming from a checkpoint)
    if batch_start <= checkpoint_resume and checkpoint_resume != 0:
        optim_iter = optim_iter + 1
        continue

    # input batch; the last one is shorter when num_samples is not a multiple of batch_size,
    # so per-batch arrays are sized from the actual slice instead of from batch_size
    s = slice(batch_start, min(num_samples, batch_start + batch_size))
    z = torch.from_numpy(zs[s]).type(torch.FloatTensor).to(device)
    y = torch.from_numpy(ys[s]).to(device)
    y_one_hot = utils.pytorch.one_hot(y)
    current_batch_size = z.shape[0]
    # one step size per image, drawn uniformly from [train_alpha_a, train_alpha_b)
    step_sizes = (train_alpha_b - train_alpha_a) * \
        np.random.random(size=current_batch_size) + train_alpha_a
    step_sizes_broadcast = np.repeat(step_sizes, dim_z).reshape([current_batch_size, dim_z])
    step_sizes_broadcast = torch.from_numpy(step_sizes_broadcast).type(torch.FloatTensor).to(device)

    # ganalyze steps
    # 1) score the untransformed images; the target is that score shifted by the step size
    gan_images = generator(z, y_one_hot)
    img = gan_images  # raw generator output, kept for the sample-image plot below
    gan_images = input_transform(utils.pytorch.denorm(gan_images))
    gan_images = gan_images.view(-1, *gan_images.shape[-3:])
    gan_images = gan_images.to(device)
    out_scores = output_transform(assessor(gan_images)).to(device).float()
    target_scores = out_scores + torch.from_numpy(step_sizes).to(device).float()

    # 2) move z with the transformation and score the resulting images
    z_transformed = transformation.transform(z, y_one_hot, step_sizes_broadcast)
    gan_images_transformed = generator(z_transformed, y_one_hot)
    gan_images_transformed = input_transform(utils.pytorch.denorm(gan_images_transformed))
    gan_images_transformed = gan_images_transformed.view(-1, *gan_images_transformed.shape[-3:])
    gan_images_transformed = gan_images_transformed.to(device)
    out_scores_transformed = output_transform(assessor(gan_images_transformed)).to(device).float()

    # compute loss (batch_start and loss_file are also passed, presumably so compute_loss can log the value)
    loss = transformation.compute_loss(out_scores_transformed, target_scores, batch_start, loss_file)

    # backwards
    loss.backward()
    optimizer.step()

    # print loss, weighted by the real number of samples in this batch
    losses.update(loss.item(), current_batch_size)
    print(f'[{batch_start}/{num_samples}] {losses}')

    if optim_iter % 50 == 0:
        print("saving checkpoint")
        torch.save(transformation.state_dict(), os.path.join(checkpoint_dir, "pytorch_model_{}.pth".format(batch_start)))
        # plot sample images; the scaling assumes generator output in [-1, 1] (TODO confirm), and
        # +128 instead of +127.5 adds 0.5 so that the truncating uint8 cast rounds to nearest
        img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
        img_np = img.detach().cpu().numpy()  # channels-last; batch axis kept even for a 1-image batch

        fig, ax = plt.subplots(nrows=2, ncols=2, figsize=(10, 10))
        # zip stops at the shorter sequence, so a final batch with fewer than 4 images no longer fails
        for axis, image in zip(ax.flat, img_np):
            axis.imshow(image.squeeze())
        fig.savefig(os.path.join(checkpoint_dir, "sample_image_{}.png".format(batch_start)))
        plt.show()
        plt.close(fig)  # otherwise every checkpoint leaves one more open figure in memory

    optim_iter = optim_iter + 1

torch.save(transformation.state_dict(), os.path.join(checkpoint_dir, "pytorch_model_{}.pth".format(opts["num_samples"])))


## Testing
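Use the command below to run the testing phase and generate the interpolation results. Adjust the parameters to choose the checkpoint: `--checkpoint_dir` must point to the training run's directory (its last path component is the git revision recorded at training time), and `--checkpoint` must be a batch index for which a `pytorch_model_<index>.pth` file was saved.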

In [None]:
# Clear previous test results: the test code creates its result directory with exist_ok=False.
!rm -rf output/*

In [None]:
# The last component of --checkpoint_dir is the git revision recorded at training time (45d4139 here);
# --checkpoint must match a saved pytorch_model_<N>.pth inside that directory.
!python test_pytorch.py \
--alpha 0.1 --test_truncation 1 \
--checkpoint_dir /content/GANalyze/pytorch/checkpoints/biggan256_None/emonet/OneDirection_None/45d4139 \
--checkpoint 1000 \
--gpu_id 0

{'gpu_id': '0', 'alpha': 0.1, 'test_truncation': 1.0, 'checkpoint_dir': '/content/GANalyze/pytorch/checkpoints/biggan256_None/emonet/OneDirection_None/45d4139', 'checkpoint': 1000, 'mode': 'bigger_step'}

approach:  one_direction

bigger_step
y:  100
bigger_step
Traceback (most recent call last):
  File "/content/GANalyze/pytorch/test_pytorch.py", line 301, in <module>
    ims.append(ims_batch)
AttributeError: 'numpy.ndarray' object has no attribute 'append'


Similarly, the testing script is reproduced below for reference. You do **not** need to run this cell if you already ran the test phase with the command above.

In [None]:
import argparse
import json
import os
import subprocess

import numpy as np
import PIL.ImageDraw
import PIL.ImageFont
import torch

import assessors
import generators
import transformations.pytorch as transformations
import utils.common
import utils.pytorch

# Collect command line arguments
# --------------------------------------------------------------------------------------------------------------
parser = argparse.ArgumentParser()
# type=int so that "-1" (CPU) from the command line compares equal to the integer -1.
parser.add_argument('--gpu_id', type=int, default=0, help='which gpu to use.')
parser.add_argument('--alpha', type=float, default=0.1, help='stepsize for testing')
parser.add_argument('--test_truncation', type=float, default=1, help='truncation to use in test phase')
parser.add_argument('--checkpoint_dir', type=str, default="", help='path for directory with the checkpoints of the trained model we want to use')
parser.add_argument('--checkpoint', type=int, default=400000, help='which checkpoint to load')
parser.add_argument('--mode', default="bigger_step", choices=["iterative", "bigger_step"],
                    help="how to make the test sequences. bigger_step was used in the paper.")

# Inside a Jupyter kernel, sys.argv holds the kernel's own flags (e.g. `-f .../kernel-*.json`), which
# parse_args() rejects with SystemExit. parse_known_args() tolerates them and reports what was ignored.
args, unknown_args = parser.parse_known_args()
if unknown_args:
    print("Ignoring unrecognized arguments:", unknown_args)
opts = vars(args)
print(opts)

# Choose GPU
# Compare as text so "-1" selects the CPU whether gpu_id arrived as an int or as a command-line string
# (previously the string "-1" never equalled the int -1, so the CPU could not be forced).
if str(opts["gpu_id"]) != "-1" and torch.cuda.is_available():
    device = torch.device("cuda:" + str(opts["gpu_id"]))
else:
    device = torch.device("cpu")

# Creating directory to store output visualizations
# The training run's settings (generator, assessor, transformer, version) are read back from its opts.json.
train_opts_file = os.path.join(opts["checkpoint_dir"], "opts.json")
with open(train_opts_file) as f:
    train_opts = json.load(f)

# Normalise a transformer entry stored as a plain string into a list.
if not isinstance(train_opts["transformer"], list):
    train_opts["transformer"] = [train_opts["transformer"]]

# Decode the git output (as the training script does) so the recorded value is e.g. "45d4139"
# rather than the bytes repr "b'45d4139'" produced by str(bytes).
test_version = subprocess.check_output(["git", "describe", "--always"]).strip().decode("utf-8")
result_dir = os.path.join("./output",
                          "-".join(train_opts["generator"]),
                          train_opts["assessor"],
                          "-".join(train_opts["transformer"]),
                          train_opts["version"],
                          "alpha_" + str(opts["alpha"]) + "_truncation_" + str(opts["test_truncation"]) + "_iteration_" + str(opts["checkpoint"]) + "_" + opts["mode"])

# exist_ok=False: refuse to overwrite the results of an earlier test run.
os.makedirs(result_dir, exist_ok=False)

# checkpoint_dir
checkpoint_dir = opts["checkpoint_dir"]

# Saving testing settings
opts_file = os.path.join(result_dir, "opts.json")
opts["test_version"] = test_version
with open(opts_file, 'w') as fp:
    json.dump(opts, fp)

# Some characteristics
# --------------------------------------------------------------------------------------------------------------
# Latent size (dim_z) and number of classes (vocab_size) for each supported generator.
generator_specs = {
    'biggan256': {'dim_z': 140, 'vocab_size': 1000},
    'biggan512': {'dim_z': 128, 'vocab_size': 1000},
}
if train_opts['generator'][0] not in generator_specs:
    # Fail fast instead of letting dim_z / vocab_size silently become None.
    raise ValueError("Unsupported generator {!r}; expected one of {}".format(
        train_opts['generator'][0], sorted(generator_specs)))

dim_z = generator_specs[train_opts['generator'][0]]['dim_z']
vocab_size = generator_specs[train_opts['generator'][0]]['vocab_size']

# ImageNet class names, one per line; used to name the saved result images.
categories_file = "./generators/categories_imagenet.txt"
with open(categories_file) as categories_fp:  # closed deterministically (was left open)
    categories = [x.strip() for x in categories_fp]

# Setting up Transformer
# --------------------------------------------------------------------------------------------------------------
def _parse_key_value_arguments(argument_string):
    """Parse a command-line spec such as ``"a=1,b=2"`` into ``{"a": "1", "b": "2"}``.

    The literal string ``"None"`` means "no arguments" and yields ``{}``. Values are kept as
    strings. Only the first ``=`` in each pair separates key from value, so values may contain ``=``.

    Raises:
        ValueError: if a comma-separated item has no ``=``.
    """
    if argument_string == "None":
        return {}
    parsed = {}
    for pair in argument_string.split(","):
        key, separator, value = pair.partition("=")
        if not separator:
            raise ValueError("Expected key=value, got {!r} in {!r}".format(pair, argument_string))
        parsed[key] = value
    return parsed


transformer = train_opts["transformer"][0]
transformer_arguments = _parse_key_value_arguments(train_opts["transformer"][1])

transformation = getattr(transformations, transformer)(dim_z, vocab_size, **transformer_arguments)
transformation = transformation.to(device)

# Setting up Generator
# --------------------------------------------------------------------------------------------------------------
generator = train_opts["generator"][0]
generator_arguments = _parse_key_value_arguments(train_opts["generator"][1])

generator = getattr(generators, generator)(**generator_arguments)

# The pretrained generator is used for inference only.
for p in generator.parameters():
    p.requires_grad = False
generator.eval()
generator = generator.to(device)

# Setting up Assessor
# --------------------------------------------------------------------------------------------------------------
# The assessor factory returns either a bare model or a (model, input_transform, output_transform) tuple.
assessor_elements = getattr(assessors, train_opts['assessor'])(True)
if isinstance(assessor_elements, tuple):
    assessor = assessor_elements[0]
    input_transform = assessor_elements[1]
    output_transform = assessor_elements[2]
else:
    assessor = assessor_elements

    def input_transform(x): return x  # identity, no preprocessing

    def output_transform(x): return x  # identity, no postprocessing

# Freeze the assessor. eval()/to(device) used to sit inside the parameter loop, so they ran once per
# parameter and never ran at all for an assessor without parameters.
if hasattr(assessor, 'parameters'):
    for p in assessor.parameters():
        p.requires_grad = False
    assessor.eval()
    assessor.to(device)

# Testing
# --------------------------------------------------------------------------------------------------------------
# Figure out which transformation weights to load. Checkpoint 0 loads nothing, so the
# transformation keeps the weights it was constructed with.
if opts["checkpoint"] == 0:
    checkpoint = 0
elif opts["checkpoint"] == -1:
    # Checkpoint files are named pytorch_model_<batch>.pth; compare batch numbers as integers
    # (comparing the raw strings ranked e.g. "96" above "1000").
    available_checkpoints = {}
    for file_name in os.listdir(checkpoint_dir):
        if file_name.endswith(".pth"):
            available_checkpoints[int(file_name.split('.')[0].split("_")[-1])] = file_name
    if not available_checkpoints:
        raise FileNotFoundError("No .pth checkpoints found in {}".format(checkpoint_dir))
    checkpoint = max(available_checkpoints)
    # map_location lets a checkpoint saved on GPU be restored on a CPU-only machine.
    transformation.load_state_dict(
        torch.load(os.path.join(checkpoint_dir, available_checkpoints[checkpoint]), map_location=device))
else:
    transformation.load_state_dict(torch.load(
        os.path.join(checkpoint_dir, "pytorch_model_" + str(opts["checkpoint"]) + ".pth"),
        map_location=device))
    checkpoint = opts["checkpoint"]

# helper function


def make_image(z, y, step_size, transform):
    """Generate images for latents ``z`` and one-hot classes ``y``, then score them with the assessor.

    Args:
        z: latent batch fed to the generator.
        y: one-hot class batch fed to the generator.
        step_size: step sizes passed to ``transformation.transform`` (only used when ``transform``).
        transform: if True, first move ``z`` with the trained transformation.

    Returns:
        (images, z, scores): the denormalised images as a channels-last numpy array, the latent that
        was actually rendered (iterative mode keeps stepping from it), and the assessor scores as a
        numpy array (1-D scores get a trailing axis).
    """
    # Pure inference. Without no_grad the transformation's trainable parameters make autograd record
    # the whole generator/assessor forward pass, and in iterative mode the returned z chained those
    # graphs from one step to the next.
    with torch.no_grad():
        if transform:
            z_transformed = transformation.transform(z, y, step_size)
            # Rescale back to the input norm. NOTE(review): .norm() without `dim` is the norm of the
            # whole batch, so samples are not normalised individually; kept to match existing results.
            z = z.norm() * z_transformed / z_transformed.norm()

        gan_images = utils.pytorch.denorm(generator(z, y))
        gan_images_np = gan_images.permute(0, 2, 3, 1).detach().cpu().numpy()
        gan_images = input_transform(gan_images)
        gan_images = gan_images.view(-1, *gan_images.shape[-3:]).to(device)

        out_scores_current = output_transform(assessor(gan_images)).detach().cpu().numpy()

    if out_scores_current.ndim == 1:
        out_scores_current = np.expand_dims(out_scores_current, 1)

    return gan_images_np, z, out_scores_current


# Test settings
num_samples = 10   # latent codes rendered per category
iters = 3          # images generated on each side of the seed image
truncation = opts["test_truncation"]
annotate = True    # overlay the assessor score on every image
np.random.seed(seed=999)

# Debug setting: only the first category is rendered to keep runs short.
# Set DEBUG_SINGLE_CATEGORY = False to render all vocab_size categories.
DEBUG_SINGLE_CATEGORY = True
num_categories = 1 if (vocab_size == 0 or DEBUG_SINGLE_CATEGORY) else vocab_size

import PIL.Image  # used below; previously reachable only as a side effect of importing PIL.ImageDraw


def _record_step(ims_batch, outscores_batch, x, outscore):
    """Append one step's images (score-annotated when `annotate` is set) and its scores."""
    ims_batch.append(utils.common.annotate_outscore(x, outscore) if annotate else x)
    outscores_batch.append(outscore)


for y in range(num_categories):

    ims = []
    outscores = []

    # Latents for this category; every sample uses class y and the same step size alpha.
    zs = utils.common.truncated_z_sample(num_samples, dim_z, truncation)
    ys = np.repeat(y, num_samples)
    zs = torch.from_numpy(zs).type(torch.FloatTensor).to(device)
    ys = torch.from_numpy(ys).to(device)
    ys = utils.pytorch.one_hot(ys, vocab_size)
    step_sizes = np.repeat(np.array(opts["alpha"]), num_samples * dim_z).reshape([num_samples, dim_z])
    step_sizes = torch.from_numpy(step_sizes).type(torch.FloatTensor).to(device)
    feed_dicts = []
    for batch_start in range(0, num_samples, 4):
        s = slice(batch_start, min(num_samples, batch_start + 4))
        feed_dicts.append({"z": zs[s], "y": ys[s], "truncation": truncation, "step_sizes": step_sizes[s]})

    for feed_dict in feed_dicts:
        ims_batch = []
        outscores_batch = []
        z_start = feed_dict["z"]
        step_sizes = feed_dict["step_sizes"]

        if opts["mode"] == "iterative":
            print("iterative")
        else:
            print("bigger_step")

        # original seed image
        x, _z, outscore = make_image(feed_dict["z"], feed_dict["y"], feed_dict["step_sizes"], transform=False)
        _record_step(ims_batch, outscores_batch, np.uint8(x), outscore)

        if opts["mode"] == "iterative":
            # Each image takes one alpha step from the previous image's latent.
            # negative clone images
            z_next = z_start
            step_sizes = -step_sizes
            for step in range(iters):
                feed_dict["step_sizes"] = step_sizes
                feed_dict["z"] = z_next
                x, z_next, outscore = make_image(feed_dict["z"], feed_dict["y"], feed_dict["step_sizes"], transform=True)
                _record_step(ims_batch, outscores_batch, np.uint8(x), outscore)

            # Most negative first. The scores must be reversed too; previously only the images were,
            # so the scores no longer lined up with their images.
            ims_batch.reverse()
            outscores_batch.reverse()

            # positive clone images
            step_sizes = -step_sizes
            z_next = z_start
            for step in range(iters):
                feed_dict["step_sizes"] = step_sizes
                feed_dict["z"] = z_next
                x, z_next, outscore = make_image(feed_dict["z"], feed_dict["y"], feed_dict["step_sizes"], transform=True)
                _record_step(ims_batch, outscores_batch, np.uint8(x), outscore)

        else:
            # bigger_step: image k is a single jump of (k + 1) * alpha from the seed latent.
            # negative clone images
            step_sizes = -step_sizes
            for step in range(iters):
                feed_dict["step_sizes"] = step_sizes * (step + 1)
                x, _z, outscore = make_image(feed_dict["z"], feed_dict["y"], feed_dict["step_sizes"], transform=True)
                _record_step(ims_batch, outscores_batch, np.uint8(x), outscore)

            ims_batch.reverse()
            outscores_batch.reverse()

            # positive clone images
            step_sizes = -step_sizes
            for step in range(iters):
                feed_dict["step_sizes"] = step_sizes * (step + 1)
                x, _z, outscore = make_image(feed_dict["z"], feed_dict["y"], feed_dict["step_sizes"], transform=True)
                _record_step(ims_batch, outscores_batch, np.uint8(x), outscore)

        # Stack the steps, then swap (steps, batch, ...) -> (batch, steps, ...).
        ims.append(np.transpose(np.stack(ims_batch, axis=0), (1, 0, 2, 3, 4)))
        outscores.append(np.transpose(np.stack(outscores_batch, axis=0), (1, 0, 2)))

    ims = np.concatenate(ims, axis=0)
    outscores = np.concatenate(outscores, axis=0)
    # One grid row per sample, one column per step (iters negative + seed + iters positive).
    ims_final = np.reshape(ims, (ims.shape[0] * ims.shape[1], ims.shape[2], ims.shape[3], ims.shape[4]))
    I = PIL.Image.fromarray(utils.common.imgrid(ims_final, cols=iters * 2 + 1))
    I.save(os.path.join(result_dir, categories[y] + ".jpg"))
    print("y: ", y)
