In [None]:
from IPython.display import display, HTML

# Widen the notebook container and the output area to 98% of the window.
_WIDE_LAYOUT_CSS = (
    "<style>.container { width:98% !important; }</style>",
    "<style>.output_result { max-width:98% !important; }</style>",
)
for _css in _WIDE_LAYOUT_CSS:
    display(HTML(_css))

In [None]:
import os

# Edit these two placeholders for your machine; every path below derives from them,
# so the repo path and the working directory can no longer drift apart.
PROJ_DIR = '/PATH/TO/REPO/ddo'
CACHE_DIR = '/PATH/TO/CACHE/FOLDER'

os.environ["PROJ_DIR"] = PROJ_DIR
os.environ["FID_DIR"] = os.path.join(CACHE_DIR, 'fid-stats')
os.environ["EXP_PATH"] = os.path.join(CACHE_DIR, 'exp')

# Work from the repo root so project modules (`main`, `utils`) import. Unlike `%cd`,
# os.chdir raises on a bad path instead of only printing the error.
os.chdir(PROJ_DIR)
print(os.getcwd())

In [None]:
# Show the available GPUs and their current utilization.
! nvidia-smi

In [None]:
import importlib, sys
import functools
import math

import numpy as np

import torch
import torch.nn as nn

import matplotlib
import matplotlib.pyplot as plt

MYBACKEND = plt.get_backend()
print(MYBACKEND)

%matplotlib inline

# 1. dataset

In [None]:
import os
import argparse

from utils import datasets
from utils.visualize import plot_contourf, plot_scatter

# Re-apply the backend recorded as MYBACKEND; presumably guards against the
# project imports above switching it (TODO confirm).
matplotlib.use(backend=MYBACKEND)


In [None]:
from main import *
from main import _folder_name_key_tuples
import argparse

import tempfile

# Hand-built stand-in for the training command-line arguments; it is used below to
# rebuild the run folder name (via main._folder_name_key_tuples) and locate the run.
args = argparse.Namespace()

# seed
args.seed = 0
args.command_type = 'train'

# i/o
args.exp_path = os.getenv('EXP_PATH')
args.print_every = 1000
args.save_every = 5000
args.ckpt_every = 100000
args.eval_every = 50000
args.vis_every = 10000
args.plot = False
args.resume = True

# optimization
args.train_batch_size = 32
args.vis_batch_size = 36
args.optimizer = 'adam'
args.lr = 0.0001
args.lr_scheduler = 'none'
args.ema_decay = 0.999
args.weight_decay = 0.
args.beta1 = 0.9
args.beta2 = 0.999
args.lr_rampup_kimg = 0
args.num_iterations = 2000000

# dataset
args.train_img_height = 32
args.dataset = 'mnistsdf_{:d}'.format(args.train_img_height)
# SLURM_TMPDIR exists only inside a SLURM job; os.path.join(None, ...) would raise
# TypeError elsewhere, so fall back to the system temp directory.
args.data = os.path.join(os.getenv('SLURM_TMPDIR') or tempfile.gettempdir(), 'data')
args.dequantize = False
args.transform = 'sdf'
args.input_dim = 1
args.coord_dim = 2

# model
args.model = 'fnounet2d'
args.modes = 32
args.act = None
args.ch = 64
args.ch_mult = (1,2,2)
args.num_res_blocks = 4
args.dropout = 0.
args.discard_resamp_with_conv = False
args.use_pointwise_op = True
args.use_radial = False
args.use_pos = True
args.norm = 'group_norm'

# upsample
args.upsample = True
args.upsample_resolution = 64
args.upsample_filter_size = 9

# forward
args.timestep_sampler = 'low_discrepancy'
args.ns_method = 'vp_cosine'
args.disp_method = 'sine'
args.sigma_blur_min = 0.05
args.sigma_blur_max = 0.25

# noise
args.gp_type = 'exponential'
args.gp_exponent = 2.0
args.gp_length_scale = 0.05
args.gp_sigma = 1.0
args.gp_modes = None

# EM param
args.num_steps = 250
args.s_min = 1e-4
args.sampler = 'denoise'
args.eval_lmbd = 0.

# eval
args.checkpoint_file = 'checkpoint.pt'
args.eval_img_height = 64
args.eval_batch_size = 1024
args.eval_use_ema = True
args.eval_fid = True
args.eval_pr = False
args.eval_num_samples = 50000
args.eval_resize_mode = 'tensor'
args.eval_interpolation = 'bilinear'
args.eval_antialias = False
args.eval_cache = False
args.fid_dir = os.getenv('FID_DIR')

# ddp
args.num_proc_node = 1
args.num_process_per_node = 1
args.node_rank = 0
args.local_rank = 0
args.global_rank = 0
args.global_size = 1
args.master_address  = '127.0.0.1'
args.master_port = None


# folder_name: encode every hyperparameter listed in `_folder_name_key_tuples` as
# `abbr + value`, join the tokens with '-', and locate the run under exp_path/dataset.
def _folder_name_token(args, key, abbr):
    """Return the folder-name token for `key`, or None if it should be left out.

    A key is left out when it is missing from `args` or set to None, or when the
    option it belongs to is switched off: upsample_* without `upsample`, use_lmbd /
    use_pos set to False, or min_scale / blur sigmas without a `disp_method`.
    """
    value = getattr(args, key, None)
    if value is None:
        return None
    if (
        (key.startswith('upsample_') and args.upsample is False)
        or (key == 'use_lmbd' and value is False)
        or (key == 'use_pos' and value is False)
        or (key in ('min_scale', 'sigma_blur_min', 'sigma_blur_max') and args.disp_method is None)
    ):
        return None
    if key == 'ch_mult':
        return abbr + ''.join(map(str, value))
    if key == 'use_pos' and value:
        # enabled flag: the abbreviation alone marks it
        return abbr
    return abbr + str(value).lower()


_folder_name = [
    token
    for token in (_folder_name_token(args, k, abbr) for k, abbr in _folder_name_key_tuples)
    if token is not None
]
folder_name = '-'.join(_folder_name)
args.folder_path = os.path.join(args.exp_path, args.dataset, folder_name)

print(args.folder_path)
print(os.path.isdir(args.folder_path))

In [None]:
# Load the best-FID checkpoint of the run located above. Note that `args` is
# reassigned by load_model; NOTE(review): confirm the per-GPU settings set here
# survive that reassignment.
args.distributed = False
args.train_batch_size_per_gpu = args.train_batch_size
args.eval_batch_size_per_gpu = args.eval_batch_size

checkpoint_file = os.path.join(args.folder_path, 'checkpoint_fid.pt')
# Print the path before loading so a failed load still shows which file was tried.
print(checkpoint_file)
if not os.path.isfile(checkpoint_file):
    raise FileNotFoundError(f'checkpoint not found: {checkpoint_file}')
gen_sde, count, best_fid_score, args = load_model(checkpoint_file)
print(count, best_fid_score)
# No optimizer is restored, so the hasattr-guarded EMA swaps in later cells are no-ops.
gen_sde_optimizer = None

In [None]:
# Sample `num_samples` 64x64 images from the loaded model, log the sample grid
# (raw and masked views) to the run's Writer, then save both grids as PDF figures.
from utils.utils import Writer

# init EM param
args.num_steps = 250

# init plot params
my_dpi = 300
num_samples = 64
args.eval_img_height = 64
args.eval_batch_size = 64
args.eval_use_ema = True
sampler = 'denoise'
print('Visualize {:d}x{:d}-res samples at iter {:d}'.format(args.eval_img_height, args.eval_img_height, count))
print('')


# log
writer = Writer(0, args.folder_path)
print(f'folder path: {args.folder_path}')

# switch to EMA parameters (skipped when the optimizer has no EMA support, e.g. None)
if args.eval_use_ema and hasattr(gen_sde_optimizer, 'swap_parameters_with_ema'):
    gen_sde_optimizer.swap_parameters_with_ema(store_params_in_ema=True)

# sample in batches of at most eval_batch_size, then trim to exactly num_samples
args.eval_batch_size = min(num_samples, args.eval_batch_size)
sample = []
num_iters = int(np.ceil(num_samples / args.eval_batch_size))
for i in range(num_iters):
    print(i+1, '/', num_iters)
    sample_ = sample_image(gen_sde, batch_size=args.eval_batch_size, img_height=args.eval_img_height, num_steps=args.num_steps, transform=None, clip=False, disable_tqdm=False, sampler=sampler)
    sample += [sample_]
sample = torch.cat(sample, dim=0)[:num_samples]
nrow = int(num_samples**0.5)
grid_samples = sample[:nrow**2]

tag = 'test/samples/{}/{}/{:d}x{:d}'.format(sampler, args.num_steps, args.eval_img_height, args.eval_img_height)
writer.add_image(tag,
                 get_grid_image(from_sdf_to_01_clip(grid_samples), nrow=nrow, pad_value=0, padding=2, to_numpy=False),
                 count,
                )
writer.add_image(tag + '/masked',
                 get_grid_image(from_sdf_to_mask(grid_samples), nrow=nrow, pad_value=0, padding=2, to_numpy=False),
                 count,
                )
writer.flush()
print('Visualized samples at iter {:d}'.format(count))

# switch back to original parameters
if args.eval_use_ema and hasattr(gen_sde_optimizer, 'swap_parameters_with_ema'):
    gen_sde_optimizer.swap_parameters_with_ema(store_params_in_ema=True)

# plot: one PDF per view; savefig does not create missing directories
os.makedirs('samples', exist_ok=True)
for view, to_image in (('orig', from_sdf_to_01_clip), ('masked', from_sdf_to_mask)):
    image = get_grid_image(to_image(grid_samples).cpu(), nrow=nrow, pad_value=0, padding=2, to_numpy=True)
    # figsize is (width, height) while imshow reads the array as (rows, cols[, channels])
    plt.figure(figsize=(image.shape[1]/my_dpi, image.shape[0]/my_dpi), dpi=my_dpi)
    plt.imshow(image)
    ax = plt.gca()
    ax.axis(False)
    plt.savefig(f'samples/ddo_{args.eval_img_height}x{args.eval_img_height}_{view}_n{num_samples}.pdf', dpi=my_dpi, bbox_inches='tight', transparent=True)
    plt.show()


In [None]:
# Sample `num_samples` 128x128 images from the loaded model, log the sample grid
# (raw and masked views) to the run's Writer, then save PDFs of the full grid and
# of its top-left new_nrow x new_nrow corner.
from utils.utils import Writer

# init EM param
args.num_steps = 250

# init plot params
my_dpi = 300
num_samples = 64
args.eval_img_height = 128
args.eval_batch_size = 32
args.eval_use_ema = True
sampler = 'denoise'
print('Visualize {:d}x{:d}-res samples at iter {:d}'.format(args.eval_img_height, args.eval_img_height, count))
print('')


# log
writer = Writer(0, args.folder_path)
print(f'folder path: {args.folder_path}')

# switch to EMA parameters (skipped when the optimizer has no EMA support, e.g. None)
if args.eval_use_ema and hasattr(gen_sde_optimizer, 'swap_parameters_with_ema'):
    gen_sde_optimizer.swap_parameters_with_ema(store_params_in_ema=True)

# sample in batches of at most eval_batch_size, then trim to exactly num_samples
args.eval_batch_size = min(num_samples, args.eval_batch_size)
sample = []
num_iters = int(np.ceil(num_samples / args.eval_batch_size))
for i in range(num_iters):
    print(i+1, '/', num_iters)
    sample_ = sample_image(gen_sde, batch_size=args.eval_batch_size, img_height=args.eval_img_height, num_steps=args.num_steps, transform=None, clip=False, disable_tqdm=False, sampler=sampler)
    sample += [sample_]
sample = torch.cat(sample, dim=0)[:num_samples]
nrow = int(num_samples**0.5)
grid_samples = sample[:nrow**2]

tag = 'test/samples/{}/{}/{:d}x{:d}'.format(sampler, args.num_steps, args.eval_img_height, args.eval_img_height)
# log the clipped view, the same transform used for the plotted 'orig' figures
writer.add_image(tag,
                 get_grid_image(from_sdf_to_01_clip(grid_samples), nrow=nrow, pad_value=0, padding=2, to_numpy=False),
                 count,
                )
writer.add_image(tag + '/masked',
                 get_grid_image(from_sdf_to_mask(grid_samples), nrow=nrow, pad_value=0, padding=2, to_numpy=False),
                 count,
                )
writer.flush()
print('Visualized samples at iter {:d}'.format(count))

# switch back to original parameters
if args.eval_use_ema and hasattr(gen_sde_optimizer, 'swap_parameters_with_ema'):
    gen_sde_optimizer.swap_parameters_with_ema(store_params_in_ema=True)

# plot: the full grid, plus the top-left new_nrow x new_nrow corner of it
new_num_samples = 16
new_nrow = int(new_num_samples**0.5)
assert new_nrow <= nrow, 'the sub-grid must fit inside the full sample grid'
# row-major indices of the corner cells within the nrow-wide grid
indices = ((torch.arange(new_nrow)*nrow)[..., None] + torch.arange(new_nrow)[None]).reshape(-1)

# savefig does not create missing directories
os.makedirs('samples', exist_ok=True)
figure_specs = (
    # (samples, grid columns, padding, sample count used in the file name)
    (grid_samples, nrow, 4, num_samples),
    (sample[indices], new_nrow, 8, new_num_samples),
)
for subset, grid_nrow, padding, n_in_name in figure_specs:
    for view, to_image in (('orig', from_sdf_to_01_clip), ('masked', from_sdf_to_mask)):
        image = get_grid_image(to_image(subset).cpu(), nrow=grid_nrow, pad_value=0, padding=padding, to_numpy=True)
        # figsize is (width, height) while imshow reads the array as (rows, cols[, channels])
        plt.figure(figsize=(image.shape[1]/my_dpi, image.shape[0]/my_dpi), dpi=my_dpi)
        plt.imshow(image)
        ax = plt.gca()
        ax.axis(False)
        plt.savefig(f'samples/ddo_{args.eval_img_height}x{args.eval_img_height}_{view}_n{n_in_name}.pdf', dpi=my_dpi, bbox_inches='tight', transparent=True)
        plt.show()


In [None]:
# Sample `num_samples` 256x256 images from the loaded model, log the sample grid
# (raw and masked views) to the run's Writer, then save PDFs of the full grid and
# of its top-left new_nrow x new_nrow corner.
from utils.utils import Writer

# init EM param
args.num_steps = 250

# init plot params
my_dpi = 300
num_samples = 64
args.eval_img_height = 256
args.eval_batch_size = 16
args.eval_use_ema = True
sampler = 'denoise'
print('Visualize {:d}x{:d}-res samples at iter {:d}'.format(args.eval_img_height, args.eval_img_height, count))
print('')


# log
writer = Writer(0, args.folder_path)
print(f'folder path: {args.folder_path}')

# switch to EMA parameters (skipped when the optimizer has no EMA support, e.g. None)
if args.eval_use_ema and hasattr(gen_sde_optimizer, 'swap_parameters_with_ema'):
    gen_sde_optimizer.swap_parameters_with_ema(store_params_in_ema=True)

# sample in batches of at most eval_batch_size, then trim to exactly num_samples
args.eval_batch_size = min(num_samples, args.eval_batch_size)
sample = []
num_iters = int(np.ceil(num_samples / args.eval_batch_size))
for i in range(num_iters):
    print(i+1, '/', num_iters)
    sample_ = sample_image(gen_sde, batch_size=args.eval_batch_size, img_height=args.eval_img_height, num_steps=args.num_steps, transform=None, clip=False, disable_tqdm=False, sampler=sampler)
    sample += [sample_]
sample = torch.cat(sample, dim=0)[:num_samples]
nrow = int(num_samples**0.5)
grid_samples = sample[:nrow**2]

tag = 'test/samples/{}/{}/{:d}x{:d}'.format(sampler, args.num_steps, args.eval_img_height, args.eval_img_height)
writer.add_image(tag,
                 get_grid_image(from_sdf_to_01_clip(grid_samples), nrow=nrow, pad_value=0, padding=2, to_numpy=False),
                 count,
                )
writer.add_image(tag + '/masked',
                 get_grid_image(from_sdf_to_mask(grid_samples), nrow=nrow, pad_value=0, padding=2, to_numpy=False),
                 count,
                )
writer.flush()
print('Visualized samples at iter {:d}'.format(count))

# switch back to original parameters
if args.eval_use_ema and hasattr(gen_sde_optimizer, 'swap_parameters_with_ema'):
    gen_sde_optimizer.swap_parameters_with_ema(store_params_in_ema=True)

# plot: the full grid, plus the top-left new_nrow x new_nrow corner of it
new_num_samples = 4
new_nrow = int(new_num_samples**0.5)
assert new_nrow <= nrow, 'the sub-grid must fit inside the full sample grid'
# row-major indices of the corner cells within the nrow-wide grid
indices = ((torch.arange(new_nrow)*nrow)[..., None] + torch.arange(new_nrow)[None]).reshape(-1)

# savefig does not create missing directories
os.makedirs('samples', exist_ok=True)
figure_specs = (
    # (samples, grid columns, padding, sample count used in the file name)
    (grid_samples, nrow, 8, num_samples),
    (sample[indices], new_nrow, 8, new_num_samples),
)
for subset, grid_nrow, padding, n_in_name in figure_specs:
    for view, to_image in (('orig', from_sdf_to_01_clip), ('masked', from_sdf_to_mask)):
        image = get_grid_image(to_image(subset).cpu(), nrow=grid_nrow, pad_value=0, padding=padding, to_numpy=True)
        # figsize is (width, height) while imshow reads the array as (rows, cols[, channels])
        plt.figure(figsize=(image.shape[1]/my_dpi, image.shape[0]/my_dpi), dpi=my_dpi)
        plt.imshow(image)
        ax = plt.gca()
        ax.axis(False)
        plt.savefig(f'samples/ddo_{args.eval_img_height}x{args.eval_img_height}_{view}_n{n_in_name}.pdf', dpi=my_dpi, bbox_inches='tight', transparent=True)
        plt.show()
