In [15]:
import argparse
import os

import torch
from torchvision.utils import save_image
from tqdm import tqdm

from vqvae import VQVAE
from pixelsnail import PixelSNAIL
import numpy as np
# from matplotlib import pyplot as plt
from torch import nn
from skimage import io, color
from torchvision import datasets, transforms, utils

from matplotlib import pyplot as plt

from sklearn import preprocessing
import copy

In [16]:
@torch.no_grad()
def sample_model(model, device, batch, size, temperature, condition=None):
    """Autoregressively sample a grid of discrete codes from a PixelSNAIL-style prior.

    Positions are filled in raster order (row by row, left to right). For each
    position the model is run on the rows generated so far, the logits of the
    current position are divided by ``temperature``, and one code per batch
    element is drawn from the resulting softmax distribution.

    Args:
        model: callable invoked as ``model(rows, condition=..., cache=...)`` and
            returning ``(logits, cache)``; logits are indexed
            ``[batch, class, row, col]``.
        device: device on which the code grid is allocated.
        batch: number of grids sampled in parallel.
        size: ``(height, width)`` of the grid.
        temperature: softmax temperature; larger values flatten the distribution.
        condition: optional conditioning forwarded unchanged to ``model``.

    Returns:
        int64 tensor of shape ``(batch, height, width)`` holding the sampled codes.
    """
    height, width = size[0], size[1]
    codes = torch.zeros(batch, *size, dtype=torch.int64).to(device)
    state = {}

    for r in tqdm(range(height)):
        for c in range(width):
            # Only rows 0..r are visible to the model at this step.
            logits, state = model(codes[:, : r + 1, :], condition=condition, cache=state)
            probs = torch.softmax(logits[:, :, r, c] / temperature, 1)
            codes[:, r, c] = torch.multinomial(probs, 1).squeeze(-1)

    return codes


def load_model(model, checkpoint, device):
    """Build one of the project's networks, load its weights and switch it to eval mode.

    Args:
        model: which network to build -- ``'vqvae'``, ``'pixelsnail_top'`` or
            ``'pixelsnail_bottom'``.
        checkpoint: path to a state-dict file readable by ``torch.load``.
        device: target device, e.g. ``'cuda'`` or ``'cpu'``.

    Returns:
        The network, moved to ``device`` and put in eval mode.

    Raises:
        ValueError: if ``model`` is not one of the names listed above.

    Note:
        The PixelSNAIL branches read ``channel``, ``n_res_block``,
        ``n_res_channel``, ``dropout``, ``n_out_res_block`` and
        ``n_cond_res_block`` from notebook globals (set in the config cell),
        so those must exist before calling this with a PixelSNAIL name.
    """
    import warnings

    if model == 'vqvae':
        net = VQVAE(
            in_channel=3,
            channel=128,
            n_res_block=6,
            n_res_channel=32,
            embed_dim=1,
            n_embed=256,
            decay=0.99,
        )
    elif model == 'pixelsnail_top':
        net = PixelSNAIL(
            [32, 32],
            512,
            channel,
            5,
            4,
            n_res_block,
            n_res_channel,
            dropout=dropout,
            n_out_res_block=n_out_res_block,
        )
    elif model == 'pixelsnail_bottom':
        net = PixelSNAIL(
            [64, 64],
            512,
            channel,
            5,
            4,
            n_res_block,
            n_res_channel,
            attention=False,
            dropout=dropout,
            n_cond_res_block=n_cond_res_block,
            cond_res_channel=n_res_channel,
        )
    else:
        # Previously an unknown name left `model` as a str and crashed later
        # with an unhelpful AttributeError.
        raise ValueError(
            f"Unknown model {model!r}; expected 'vqvae', 'pixelsnail_top' "
            "or 'pixelsnail_bottom'"
        )

    net = net.to(device)

    # map_location lets a checkpoint saved on a GPU load on any device.
    state_dict = torch.load(checkpoint, map_location=device)

    # strict=False is kept for compatibility with existing checkpoints, but
    # mismatches are reported rather than silently ignored: missing keys keep
    # whatever values the constructor gave them.
    incompatible = net.load_state_dict(state_dict, strict=False)
    if incompatible.missing_keys or incompatible.unexpected_keys:
        warnings.warn(
            f"{model}: checkpoint {checkpoint!r} did not match the network exactly "
            f"(missing keys: {incompatible.missing_keys}, "
            f"unexpected keys: {incompatible.unexpected_keys})"
        )

    net.eval()
    return net

In [3]:
def convert_img(out, nrow=4):
    """Render a batch of image tensors as an image grid and return it as an array.

    The tensor is written with ``torchvision.utils.save_image`` (min-max
    normalised, ``nrow`` images per grid row) to a PNG and read back with
    ``skimage.io.imread``. The result is therefore exactly what the saved
    file contains.

    The PNG lives in a private temporary directory that is always cleaned up.
    An existing ``test1.png`` in the working directory is never overwritten or
    deleted, concurrent calls cannot collide, and nothing is leaked if
    reading the file fails.

    Args:
        out: image tensor accepted by ``torchvision.utils.save_image``.
        nrow: number of images per row of the grid (default 4, as before).

    Returns:
        The image array returned by ``skimage.io.imread``.
    """
    import tempfile

    with tempfile.TemporaryDirectory() as tmp_dir:
        path = os.path.join(tmp_dir, 'test1.png')
        utils.save_image(out, path, nrow=nrow, normalize=True)
        img = io.imread(path)
    return img

In [17]:
# --- Configuration ----------------------------------------------------------
# Prefer the GPU, but fall back to the CPU instead of failing outright.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# PixelSNAIL hyperparameters; load_model reads these as globals when building
# 'pixelsnail_top' / 'pixelsnail_bottom'.
n_res_block=4
n_res_channel=256
n_out_res_block=0
n_cond_res_block=3
channel=256
dropout=0.1

# Number of grids per PixelSNAIL sampling call (sampling is currently disabled).
batch=1

# Alternative checkpoints from earlier runs, kept for reference.
# vqvae='runs/pixelsnail_emb_dim_1_n_embed_256_bc_left_4x_768_zeros/bottom/vqvae_010_train_0.01976_test_0.01965.pt'
# top='runs/14.04.2024/65_pixelsnail_top_train_loss_1.281939_acc_0.379639_test_loss_1.230041_acc_0.391327.pt'
# bottom='runs/pixelsnail_emb_dim_1_n_embed_256_bc_left_4x_768_zeros/bottom/13_pixelsnail_bottom_train_loss_3.277416_acc_0.120651_test_loss_2.938762_acc_0.138824.pt'

# dataset_path = '../diffusion/runs/wc_co_median_latent/samples/'
# vqvae='../vqvae2_rosalinity/runs/emb_dim_1_n_embed_256_o_bc_left_4x_768_360_median/vqvae_006_train_0.00047_test_0.00053.pt'
vqvae='runs_unique/emb_dim_1_n_embed_256_bc_left_4x_768_360_728/vqvae_004_train_0.00041_test_0.00039.pt'

filename='test.png'  # NOTE(review): not referenced by any visible cell
temp=10  # softmax temperature for sample_model (sampling is currently disabled)

model_vqvae = load_model('vqvae', vqvae, device)
# model_top = load_model('pixelsnail_top', top, device)
# model_bottom = load_model('pixelsnail_bottom', bottom, device)

In [19]:
# Decode bottom-level latent images, read from PNGs produced under
# ../diffusion/runs/.../generated, with the VQ-VAE decoder, and save the
# reconstructions to latent_diff_dataset/<sample_set>/.
# PixelSNAIL sampling of the codes is disabled:
# top_sample = sample_model(model_top, device, batch, [32, 32], temp)
# bottom_sample = sample_model(
#     model_bottom, device, batch, [64, 64], temp, condition=top_sample_z
# )

# The top-level latent is fed as all zeros, shape (1, 1, 64, 64).
top_sample_z = torch.zeros((1, 64, 64), dtype=torch.float32).to(device).unsqueeze(dim=0)

names_new={'Ultra_Co11':'AB_Co11_medium',
           'Ultra_Co15':'AB_Co15_medium_small',
           'Ultra_Co25':'AB_Co25_small',
           'Ultra_Co6_2':'AB_Co6_large',
           'Ultra_Co8':'AB_Co_8_medium_small'}

sample_set = 'Ultra_Co25'
n_samples = 100
latent_dir = f'../diffusion/runs/embs_images_emb_dim_1_n_embed_256_o_bc_left_4x_768_360_768_median_{sample_set}/generated'
out_dir = f'latent_diff_dataset/{sample_set}'
# plt.imsave does not create missing directories.
os.makedirs(out_dir, exist_ok=True)

# Moving the model is loop-invariant; do it once, on the configured device.
model_vqvae.to(device)

for i in range(n_samples):
    bottom_sample_o = color.rgb2gray(io.imread(f'{latent_dir}/{i}.png'))

    # Min-max rescale the intensities to [-1.5, 1.5].
    # NOTE(review): presumably the value range the decoder expects -- confirm.
    old_shape = copy.copy(bottom_sample_o.shape)
    scaler = preprocessing.MinMaxScaler((-1.5, 1.5))
    scaler.fit(bottom_sample_o.reshape(-1, 1))
    bottom_sample = scaler.transform(bottom_sample_o.reshape(-1, 1)).reshape(old_shape)

    # Add batch and channel axes: (H, W) -> (1, 1, H, W).
    bottom_sample = torch.FloatTensor(bottom_sample).unsqueeze(dim=0).unsqueeze(dim=0).to(device)

    with torch.no_grad():
        decoded_sample = model_vqvae.decode(top_sample_z, bottom_sample)

    decoded_sample = convert_img(decoded_sample)
    plt.imsave(f'{out_dir}/{i}.png', decoded_sample)

In [14]:
# Inspect the shape of the grayscale latent image currently held in `bottom_sample_o`.
bottom_sample_o.shape

(256, 256)

In [8]:
# Decode a single latent image (latent_medium_0099.jpg) with the VQ-VAE and
# save the reconstruction.

# All-zero top-level latent, shape (1, 1, 64, 64).
top_sample_z = torch.zeros((1, 64, 64), dtype=torch.float32).to(device).unsqueeze(dim=0)

bottom_sample_o = color.rgb2gray(io.imread('latent_medium_0099.jpg'))

# Min-max rescale the intensities to [-1.5, 1.5].
old_shape = copy.copy(bottom_sample_o.shape)
scaler = preprocessing.MinMaxScaler((-1.5, 1.5))
scaler.fit(bottom_sample_o.reshape(-1, 1))
bottom_sample = scaler.transform(bottom_sample_o.reshape(-1, 1)).reshape(old_shape)

# Add batch and channel axes: (H, W) -> (1, 1, H, W).
bottom_sample = torch.FloatTensor(bottom_sample).unsqueeze(dim=0).unsqueeze(dim=0).to(device)

# Keep the model on the same device as its inputs.
model_vqvae.to(device)

with torch.no_grad():
    decoded_sample = model_vqvae.decode(top_sample_z, bottom_sample)

decoded_sample = convert_img(decoded_sample)

plt.imsave('latent_medium_0099_restored.jpeg', decoded_sample)

In [None]:
# Re-decode the current `top_sample_z` / `bottom_sample` pair into an image array.
# Use the configured device so the model sits where its inputs were placed.
model_vqvae.to(device)

with torch.no_grad():
    decoded_sample = model_vqvae.decode(top_sample_z, bottom_sample)

decoded_sample = convert_img(decoded_sample)

In [None]:
# Show the decoded image without axes and save it to ps.png.
fig, ax = plt.subplots()
ax.imshow(decoded_sample)
ax.axis("off")
fig.savefig('ps.png', bbox_inches='tight')
plt.show()

In [None]:
# Draw unconditional top-level codes from the top PixelSNAIL prior.
# sample_model performs the same raster-order loop as the previous inline
# version, but under torch.no_grad(), so no autograd graph is built per step.
# NOTE(review): `model_top` only exists if its load_model line in the config
# cell is uncommented; otherwise this cell raises NameError.
# Local names avoid overwriting the config cell's `batch` and `temp`.
top_batch = 5
top_size = [32, 32]
top_temp = 1
row = sample_model(model_top, device, top_batch, top_size, top_temp, condition=None)

In [None]:
# Obtain a channels-last image array from `decoded_sample`.
# After convert_img, `decoded_sample` is already an image array (no .cpu()),
# so only a raw decoder tensor needs converting.
if torch.is_tensor(decoded_sample):
    img = decoded_sample[0].detach().cpu().numpy()
    # Channels-first -> channels-last: (C, H, W) -> (H, W, C).
    # np.swapaxes(img, 0, -1) would give (W, H, C), i.e. a transposed image.
    img = np.moveaxis(img, 0, -1)
else:
    img = np.asarray(decoded_sample)