In [21]:
import random
import os
from typing import NamedTuple

import numpy as np
import torch
import torchvision
from torch.utils import data
from PIL import Image

import model
from tensor_transforms import convert_to_coord_format

In [22]:
class FFHQ256Arguments(NamedTuple):
    """CIPSskip for FFHQ-256"""
    Generator = 'CIPSskip'
    output_dir = 'ffhq256_g_ema.pt'
    out_path = 'checkpoint'
    size = 256
    coords_size = 256
    fc_dim = 512 
    latent = 512
    style_dim = 512
    n_mlp = 8
    activation = None
    channel_multiplier = 2
    ckpt = os.path.join(out_path, output_dir)
    coords_integer_values = False
    path="checkpoint/ffhq_256_g_ema.pt"

In [None]:
class FFHQ1024Arguments(NamedTuple):
    """CIPSskip-progressive for FFHQ-1024"""
    Generator = 'CIPSskip'
    output_dir = 'ffhq1024_g_ema.pt'
    out_path = 'checkpoint'
    size = 256
    coords_size = 1024
    fc_dim = 512 
    latent = 512
    style_dim = 512
    n_mlp = 8
    activation = None
    channel_multiplier = 2
    ckpt = os.path.join(out_path, output_dir)
    coords_integer_values = False
    path="checkpoint/ffhq1024_g_ema.pt"

In [None]:
class Churches256Arguments(NamedTuple):
    """CIPSskip for LSUN-Churches-256"""
    Generator = 'CIPSskip'
    output_dir = 'churches_g_ema.pt'
    out_path = 'checkpoint'
    size = 256
    coords_size = 256
    fc_dim = 512 
    latent = 512
    style_dim = 512
    n_mlp = 8
    activation = None
    channel_multiplier = 2
    ckpt = os.path.join(out_path, output_dir)
    coords_integer_values = False
    path="checkpoint/churchs_g_ema.pt"

In [None]:
class Lanscapes256Arguments(NamedTuple):
    """CIPSres for Landscapes-256"""
    Generator = 'CIPSres'
    output_dir = 'landscapes_g_ema.pt'
    out_path = 'checkpoint'
    size = 256
    coords_size = 256
    fc_dim = 512 
    latent = 512
    style_dim = 512
    n_mlp = 8
    activation = None
    channel_multiplier = 2
    ckpt = os.path.join(out_path, output_dir)
    coords_integer_values = False
    path="checkpoint/landscapes_g_ema.pt"

In [None]:
args=FFHQ256Arguments()
device = 'cuda'

In [None]:
Generator = getattr(model, args.Generator)
g_ema = Generator(size=args.size, hidden_size=args.fc_dim, style_dim=args.latent, n_mlp=args.n_mlp,
                  activation=args.activation, channel_multiplier=args.channel_multiplier,
                  ).to(device)
g_ema.eval();

In [None]:
path=args.path
ckpt = torch.load(path)
g_ema.load_state_dict(ckpt)

In [None]:
def get_image(tensor, nrow=2, padding=2,
               normalize=False, range=None, scale_each=False, pad_value=0):
    
    grid = torchvision.utils.make_grid(tensor, nrow=nrow,padding=padding, pad_value=pad_value,
                     normalize=normalize, range=range, scale_each=scale_each)
    ndarr = grid.mul(255).clamp(0, 255).byte().permute(1, 2, 0).cpu().numpy()
    im = Image.fromarray(ndarr)
    return im

## Finding mean for truncation trick 

In [None]:
n_sample = 1
sample_z = torch.randn(n_sample, args.latent, device=device)
converted_full = convert_to_coord_format(sample_z.size(0), args.coords_size, args.coords_size, device,
                                         integer_values=args.coords_integer_values)

latents = []
samples = []
with torch.no_grad():
    for _ in range(100):
        sample_z = torch.randn(n_sample, args.latent, device=device)
        sample, latent = g_ema(converted_full, [sample_z], return_latents=True)
        latents.append(latent.cpu())
        samples.append(sample.cpu())

samples = torch.cat(samples, 0)
latents = torch.cat(latents, 0)

truncation_latent = latents.mean(0).cuda()

print('truncation_latent', truncation_latent.shape)
assert len(truncation_latent.shape)==1 and truncation_latent.size(0) == 512, 'smt wrong'

## Sampling with truncation trick 

In [31]:
n_sample = 8
sample_z = torch.randn(n_sample, args.latent, device=device)
converted_full = convert_to_coord_format(sample_z.size(0), args.coords_size, args.coords_size, device,
                                         integer_values=args.coords_integer_values)
    
print(converted_full)
with torch.no_grad():
    style = g_ema.style(sample_z)
    print(style.size())
    sample, _ = g_ema(converted_full, [style], 
                      truncation=0.6,
                      truncation_latent=truncation_latent,
                      input_is_latent=True,)
    
im = get_image(sample,                        
                nrow=4,
                normalize=True,
                range=(-1, 1),)

im.show()

tensor([[[[-1.0000, -0.9922, -0.9843,  ...,  0.9843,  0.9922,  1.0000],
          [-1.0000, -0.9922, -0.9843,  ...,  0.9843,  0.9922,  1.0000],
          [-1.0000, -0.9922, -0.9843,  ...,  0.9843,  0.9922,  1.0000],
          ...,
          [-1.0000, -0.9922, -0.9843,  ...,  0.9843,  0.9922,  1.0000],
          [-1.0000, -0.9922, -0.9843,  ...,  0.9843,  0.9922,  1.0000],
          [-1.0000, -0.9922, -0.9843,  ...,  0.9843,  0.9922,  1.0000]],

         [[-1.0000, -1.0000, -1.0000,  ..., -1.0000, -1.0000, -1.0000],
          [-0.9922, -0.9922, -0.9922,  ..., -0.9922, -0.9922, -0.9922],
          [-0.9843, -0.9843, -0.9843,  ..., -0.9843, -0.9843, -0.9843],
          ...,
          [ 0.9843,  0.9843,  0.9843,  ...,  0.9843,  0.9843,  0.9843],
          [ 0.9922,  0.9922,  0.9922,  ...,  0.9922,  0.9922,  0.9922],
          [ 1.0000,  1.0000,  1.0000,  ...,  1.0000,  1.0000,  1.0000]]],


        [[[-1.0000, -0.9922, -0.9843,  ...,  0.9843,  0.9922,  1.0000],
          [-1.0000, -0.9922,

wandb: ERROR Dropped streaming file chunk (see wandb/debug-internal.log)
ERROR:root:dropped chunk 404 Client Error: Not Found for url: https://api.wandb.ai/files/poong2/CIPS/3sdb7cub/file_stream
NoneType: None


# 좌표 변환 실험.

In [30]:
n_sample = 1
sample_z = torch.randn(n_sample, args.latent, device=device)
converted_full = convert_to_coord_format(sample_z.size(0), args.coords_size, args.coords_size, device,
                                         integer_values=args.coords_integer_values)

    
print(converted_full)
with torch.no_grad():
    style = g_ema.style(sample_z)
    print(style.size())
    sample, _ = g_ema(converted_full, [style], 
                      truncation=0.6,
                      truncation_latent=truncation_latent,
                      input_is_latent=True,)
    
im = get_image(sample,                        
                nrow=int(n_sample ** 0.5),
                normalize=True,
                range=(-1, 1),)

im.show()

tensor([[[[-1.0000, -0.9980, -0.9961,  ..., -0.0039, -0.0020,  0.0000],
          [-1.0000, -0.9980, -0.9961,  ..., -0.0039, -0.0020,  0.0000],
          [-1.0000, -0.9980, -0.9961,  ..., -0.0039, -0.0020,  0.0000],
          ...,
          [-1.0000, -0.9980, -0.9961,  ..., -0.0039, -0.0020,  0.0000],
          [-1.0000, -0.9980, -0.9961,  ..., -0.0039, -0.0020,  0.0000],
          [-1.0000, -0.9980, -0.9961,  ..., -0.0039, -0.0020,  0.0000]],

         [[-1.0000, -1.0000, -1.0000,  ..., -1.0000, -1.0000, -1.0000],
          [-0.9980, -0.9980, -0.9980,  ..., -0.9980, -0.9980, -0.9980],
          [-0.9961, -0.9961, -0.9961,  ..., -0.9961, -0.9961, -0.9961],
          ...,
          [-0.0039, -0.0039, -0.0039,  ..., -0.0039, -0.0039, -0.0039],
          [-0.0020, -0.0020, -0.0020,  ..., -0.0020, -0.0020, -0.0020],
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]]]],
       device='cuda:0')
torch.Size([1, 512])


wandb: ERROR Dropped streaming file chunk (see wandb/debug-internal.log)
ERROR:root:dropped chunk 404 Client Error: Not Found for url: https://api.wandb.ai/files/poong2/CIPS/3sdb7cub/file_stream
NoneType: None
wandb: ERROR Dropped streaming file chunk (see wandb/debug-internal.log)
ERROR:root:dropped chunk 404 Client Error: Not Found for url: https://api.wandb.ai/files/poong2/CIPS/3sdb7cub/file_stream
NoneType: None


# Wandb 코드

In [35]:

import wandb
wandb.init(
    project="CIPS",
    name="FFHQ-256 pretrained generating",
)
example=[]
n_sample=8
for iter in range(100):
    sample_z = torch.randn(n_sample, args.latent, device=device)
    converted_full = convert_to_coord_format(sample_z.size(0), args.coords_size, args.coords_size, device,
                                         integer_values=args.coords_integer_values)

    with torch.no_grad():
        style = g_ema.style(sample_z)
        print(style.size())
        sample, _ = g_ema(converted_full, [style], 
                        truncation=0.6,
                        truncation_latent=truncation_latent,
                        input_is_latent=True,)
        
        im = get_image(sample,                        
                        nrow=4,
                        normalize=True,
                        range=(-1, 1),)
        image=wandb.Image(im,caption=f"iter{iter}")
        example.append(image)
wandb.log({"examples":example})
wandb.finish()
    
    

VBox(children=(Label(value=' 76.78MB of 76.78MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.…

[34m[1mwandb[0m: wandb version 0.12.21 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


torch.Size([8, 512])




torch.Size([8, 512])
torch.Size([8, 512])
torch.Size([8, 512])
torch.Size([8, 512])
torch.Size([8, 512])
torch.Size([8, 512])
torch.Size([8, 512])
torch.Size([8, 512])
torch.Size([8, 512])
torch.Size([8, 512])
torch.Size([8, 512])
torch.Size([8, 512])
torch.Size([8, 512])
torch.Size([8, 512])
torch.Size([8, 512])
torch.Size([8, 512])
torch.Size([8, 512])
torch.Size([8, 512])
torch.Size([8, 512])
torch.Size([8, 512])
torch.Size([8, 512])
torch.Size([8, 512])
torch.Size([8, 512])
torch.Size([8, 512])
torch.Size([8, 512])
torch.Size([8, 512])
torch.Size([8, 512])
torch.Size([8, 512])
torch.Size([8, 512])
torch.Size([8, 512])
torch.Size([8, 512])
torch.Size([8, 512])
torch.Size([8, 512])
torch.Size([8, 512])
torch.Size([8, 512])
torch.Size([8, 512])
torch.Size([8, 512])
torch.Size([8, 512])
torch.Size([8, 512])
torch.Size([8, 512])
torch.Size([8, 512])
torch.Size([8, 512])
torch.Size([8, 512])
torch.Size([8, 512])
torch.Size([8, 512])
torch.Size([8, 512])
torch.Size([8, 512])
torch.Size([8

VBox(children=(Label(value=' 70.66MB of 70.66MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.…

In [37]:
import wandb
from tensor_transforms import convert_to_dif_coord_format
wandb.init(
    project="CIPS",
    name="FFHQ-256 [-1,0][-1,0]",
)
example=[]
n_sample=8
for iter in range(100):
    sample_z = torch.randn(n_sample, args.latent, device=device)
    converted_full = convert_to_dif_coord_format(sample_z.size(0), args.coords_size, args.coords_size, device,
                h_range=(-1.0,0),w_range=(-1.0,0),integer_values=args.coords_integer_values)

    with torch.no_grad():
        style = g_ema.style(sample_z)
        print(style.size())
        sample, _ = g_ema(converted_full, [style], 
                        truncation=0.6,
                        truncation_latent=truncation_latent,
                        input_is_latent=True,)
        
        im = get_image(sample,                        
                        nrow=4,
                        normalize=True,
                        range=(-1, 1),)
        image=wandb.Image(im,caption=f"iter{iter}")
        example.append(image)
wandb.log({"examples":example})

ImportError: cannot import name 'convert_to_dif_coord_format' from 'tensor_transforms' (/home/jaepoong/바탕화면/CIPS/tensor_transforms.py)