In [None]:
from IPython.display import display, HTML

# Widen the notebook container and the result area to 98% of the browser window.
for _css in (
    "<style>.container { width:98% !important; }</style>",
    "<style>.output_result { max-width:98% !important; }</style>",
):
    display(HTML(_css))

In [None]:
import os

# Machine-specific locations: edit these placeholders. Later cells read FID_DIR and
# EXP_PATH back through os.getenv.
PROJ_DIR = '/PATH/TO/REPO/ddo'
CACHE_DIR = '/PATH/TO/CACHE/FOLDER'

os.environ["PROJ_DIR"] = PROJ_DIR
os.environ["FID_DIR"] = os.path.join(CACHE_DIR, 'fid-stats')
os.environ["EXP_PATH"] = os.path.join(CACHE_DIR, 'exp')

# Work from the repo root (presumably so the `utils` and `gano` imports below resolve).
# Use the same constant as PROJ_DIR so the two paths cannot drift apart.
os.chdir(PROJ_DIR)
print(os.getcwd())

In [None]:
# Show which GPUs are visible: the sampling cells below build the GRF with device='cuda'.
! nvidia-smi

In [None]:
import importlib, sys
import functools
import math

import numpy as np

import torch
import torch.nn as nn

import matplotlib
import matplotlib.pyplot as plt

MYBACKEND = plt.get_backend()
print(MYBACKEND)

%matplotlib inline

# 1. dataset

In [None]:
import argparse
import os

from utils import datasets
from utils.visualize import plot_scatter, plot_contourf

# NOTE(review): presumably importing utils.visualize switches the matplotlib backend.
# Restore the one captured in MYBACKEND; this must stay AFTER the utils imports.
# Confirm in utils/visualize.py.
matplotlib.use(MYBACKEND)


In [None]:
from gano import *
from gano import _folder_name_key_tuples
import argparse

# Rebuild the training-time configuration. The experiment folder name (args.folder_path)
# is derived from these hyperparameters, so they must match the run being loaded.
args = argparse.Namespace()


# arguments
args.seed = 1
args.lr = 1e-4
args.ema_decay = 0.999
args.weight_decay = 0.
args.train_batch_size = 32
args.vis_batch_size = 36
args.num_iterations = 1000000

# i/o
args.exp_path = os.getenv('EXP_PATH')
args.print_every = 1000
args.save_every = 5000
args.ckpt_every = 10000
args.eval_every = 20000
args.vis_every = 10000
args.resume = True

# ddp
args.num_proc_node = 1
args.num_process_per_node = 1
args.node_rank = 0
args.local_rank = 0
args.global_rank = 0
args.global_size = 1
args.master_address  = '127.0.0.1'
args.master_port = '6020'

# dataset
args.train_img_height = 32
args.input_dim = 1
args.coord_dim = 2
args.upsample = True
args.upsample_resolution = 64
args.upsample_filter_size = 9
args.use_radial = False
args.dataset = 'mnistsdf_{:d}'.format(args.train_img_height)
# Outside a SLURM job SLURM_TMPDIR is unset, and os.path.join(None, ...) would raise
# TypeError. Fall back to the current working directory in that case.
args.data = os.path.join(os.getenv('SLURM_TMPDIR') or os.getcwd(), 'data')

# init plot params
args.plot = False
args.eval_img_height = 64
args.eval_batch_size = 512
args.eval_use_ema = True
args.eval_fid = True
args.eval_pr = False
args.eval_num_samples = 50000
args.eval_resize_mode = 'tensor'
args.eval_interpolation = 'bilinear'
args.eval_antialias = 'false'
args.eval_cache = False
args.fid_dir = os.getenv('FID_DIR')

# init model
args.model = 'gano-uno'
args.npad = 0 # padding ratio 8/128 in U-NO
args.modes = 32 # number of Fourier modes in the initial FNO layer
args.d_co_domain= args.upsample_resolution if args.upsample else args.train_img_height # the dimension of the co-domain of the initial U-NO layer.
args.lmbd_grad = 10.0 # Lagrange coefficient for gradient penalty
args.n_critic = 10 # every n_critic iteration the generator is updated


def _make_folder_name(cfg, key_tuples):
    """Build the experiment folder name from ``(attribute, abbreviation)`` pairs.

    Each included attribute contributes ``abbr + value``, and the parts are joined
    with '-'. Rules:

    - Attributes missing from ``cfg`` or set to None are skipped.
    - ``upsample_*`` attributes are skipped when ``cfg.upsample`` is False.
    - ``ch_mult`` elements are stringified and concatenated.
    - ``act`` is represented by its lower-cased class name.
    - Every other value is ``str(value).lower()``.
    """
    parts = []
    for k, abbr in key_tuples:
        if not hasattr(cfg, k) or getattr(cfg, k) is None:
            continue
        if k.startswith('upsample_') and cfg.upsample is False:
            continue
        value = getattr(cfg, k)
        if k == 'ch_mult':
            parts.append(abbr + ''.join(str(i) for i in value))
        elif k == 'act':
            parts.append(abbr + value.__class__.__name__.lower())
        else:
            parts.append(abbr + str(value).lower())
    return '-'.join(parts)


# folder name
folder_name = _make_folder_name(args, _folder_name_key_tuples)
args.folder_path = os.path.join(args.exp_path, args.dataset, folder_name)

print(args.folder_path)
print(os.path.isdir(args.folder_path))

In [None]:
# Single-process run (no DDP), so the per-GPU batch sizes are the global ones.
args.distributed = False
args.train_batch_size_per_gpu, args.eval_batch_size_per_gpu = args.train_batch_size, args.eval_batch_size

# Load the generator from the experiment folder. NOTE: `args` is rebound here to the
# Namespace returned by load_model, so later cells read that object, not the one built above.
checkpoint_file = os.path.join(args.folder_path, 'checkpoint_fid.pt')
gen, count, best_fid_score, args = load_model(checkpoint_file)
print(checkpoint_file)
print(count, best_fid_score)

# Sampling only; no optimizer state is needed.
gen_optimizer = None

In [None]:
# 64x64 samples: render an 8x8 grid through from_sdf_to_01_clip ('orig') and
# from_sdf_to_mask ('masked'), and save each figure as a PDF under samples/.
my_dpi = 300
num_samples = 64
args.eval_img_height = 64
args.eval_batch_size = 64
args.eval_use_ema = True

# init grf
grf = GaussianRF_idct(args.coord_dim, args.eval_img_height, alpha=1.5, tau=1.0, device='cuda')

# init model generator
fn_create_generator = functools.partial(
    create_generator,
    gen=gen,
    grf=grf,
    transform=None,
    clip=False,
)

# Only the first batch is consumed. NOTE(review): presumably num_samples only caps how
# many batches the generator may yield; confirm in create_generator.
g = fn_create_generator(batch_size=args.eval_batch_size, num_samples=args.eval_num_samples)
sample = next(g)[:num_samples]
nrow = int(num_samples**0.5)
assert len(sample) >= nrow**2, f'need {nrow**2} samples for a {nrow}x{nrow} grid, got {len(sample)}'

# savefig does not create missing directories.
os.makedirs('samples', exist_ok=True)
plt.close('all')

for tag, to_image in (('orig', from_sdf_to_01_clip), ('masked', from_sdf_to_mask)):
    image = get_grid_image(to_image(sample[:nrow**2]).cpu(), nrow=nrow, pad_value=0, padding=2, to_numpy=True)
    # figsize is (width, height) in inches: image columns/rows divided by the dpi.
    fig, ax = plt.subplots(figsize=(image.shape[1]/my_dpi, image.shape[0]/my_dpi), dpi=my_dpi)
    ax.imshow(image)
    ax.axis(False)
    fig.savefig(f'samples/gano_{args.eval_img_height}x{args.eval_img_height}_{tag}_n{num_samples}.pdf', dpi=my_dpi, bbox_inches='tight', transparent=True)
    plt.show()

In [None]:
# 128x128 samples: render the full 8x8 grid and its top-left 4x4 sub-grid through
# from_sdf_to_01_clip ('orig') and from_sdf_to_mask ('masked'), saving PDFs under samples/.
my_dpi = 300
num_samples = 64
args.eval_img_height = 128
args.eval_batch_size = 64
args.eval_use_ema = True

# init grf
grf = GaussianRF_idct(args.coord_dim, args.eval_img_height, alpha=1.5, tau=1.0, device='cuda')

# init model generator
fn_create_generator = functools.partial(
    create_generator,
    gen=gen,
    grf=grf,
    transform=None,
    clip=False,
)

# Only the first batch is consumed. NOTE(review): presumably num_samples only caps how
# many batches the generator may yield; confirm in create_generator.
g = fn_create_generator(batch_size=args.eval_batch_size, num_samples=args.eval_num_samples)
sample = next(g)[:num_samples]
nrow = int(num_samples**0.5)
assert len(sample) >= nrow**2, f'need {nrow**2} samples for a {nrow}x{nrow} grid, got {len(sample)}'

# Flat indices of the top-left new_nrow x new_nrow block of the row-major nrow x nrow grid.
new_num_samples = 16
new_nrow = int(new_num_samples**0.5)
indices = ((torch.arange(new_nrow)*nrow)[..., None] + torch.arange(new_nrow)[None]).reshape(-1)

# savefig does not create missing directories.
os.makedirs('samples', exist_ok=True)
plt.close('all')

# (samples, grid columns, padding in px, count used in the filename)
grids = (
    (sample[:nrow**2], nrow, 4, num_samples),
    (sample[indices], new_nrow, 8, new_num_samples),
)
for subset, grid_nrow, padding, n_shown in grids:
    for tag, to_image in (('orig', from_sdf_to_01_clip), ('masked', from_sdf_to_mask)):
        image = get_grid_image(to_image(subset).cpu(), nrow=grid_nrow, pad_value=0, padding=padding, to_numpy=True)
        # figsize is (width, height) in inches: image columns/rows divided by the dpi.
        fig, ax = plt.subplots(figsize=(image.shape[1]/my_dpi, image.shape[0]/my_dpi), dpi=my_dpi)
        ax.imshow(image)
        ax.axis(False)
        fig.savefig(f'samples/gano_{args.eval_img_height}x{args.eval_img_height}_{tag}_n{n_shown}.pdf', dpi=my_dpi, bbox_inches='tight', transparent=True)
        plt.show()


In [None]:
# 256x256 samples: render the full 8x8 grid and its top-left 2x2 sub-grid through
# from_sdf_to_01_clip ('orig') and from_sdf_to_mask ('masked'), saving PDFs under samples/.
my_dpi = 300
num_samples = 64
args.eval_img_height = 256
args.eval_batch_size = 64
args.eval_use_ema = True

# init grf
grf = GaussianRF_idct(args.coord_dim, args.eval_img_height, alpha=1.5, tau=1.0, device='cuda')

# init model generator
fn_create_generator = functools.partial(
    create_generator,
    gen=gen,
    grf=grf,
    transform=None,
    clip=False,
)

# Only the first batch is consumed. NOTE(review): presumably num_samples only caps how
# many batches the generator may yield; confirm in create_generator.
g = fn_create_generator(batch_size=args.eval_batch_size, num_samples=args.eval_num_samples)
sample = next(g)[:num_samples]
nrow = int(num_samples**0.5)
assert len(sample) >= nrow**2, f'need {nrow**2} samples for a {nrow}x{nrow} grid, got {len(sample)}'

# Flat indices of the top-left new_nrow x new_nrow block of the row-major nrow x nrow grid.
new_num_samples = 4
new_nrow = int(new_num_samples**0.5)
indices = ((torch.arange(new_nrow)*nrow)[..., None] + torch.arange(new_nrow)[None]).reshape(-1)

# savefig does not create missing directories.
os.makedirs('samples', exist_ok=True)
plt.close('all')

# (samples, grid columns, padding in px, count used in the filename)
grids = (
    (sample[:nrow**2], nrow, 8, num_samples),
    (sample[indices], new_nrow, 8, new_num_samples),
)
for subset, grid_nrow, padding, n_shown in grids:
    for tag, to_image in (('orig', from_sdf_to_01_clip), ('masked', from_sdf_to_mask)):
        image = get_grid_image(to_image(subset).cpu(), nrow=grid_nrow, pad_value=0, padding=padding, to_numpy=True)
        # figsize is (width, height) in inches: image columns/rows divided by the dpi.
        fig, ax = plt.subplots(figsize=(image.shape[1]/my_dpi, image.shape[0]/my_dpi), dpi=my_dpi)
        ax.imshow(image)
        ax.axis(False)
        fig.savefig(f'samples/gano_{args.eval_img_height}x{args.eval_img_height}_{tag}_n{n_shown}.pdf', dpi=my_dpi, bbox_inches='tight', transparent=True)
        plt.show()
