In [1]:
import argparse
import pickle

import torch
from torch.utils.data import DataLoader
from torchvision import transforms
import lmdb
from tqdm import tqdm

from dataset import ImageFileDataset, CodeRow
from vqvae import VQVAE
from skimage import io, color
import numpy as np
import copy

from sklearn import preprocessing
from matplotlib import pyplot as plt

import time
import os

In [None]:
def extract(lmdb_env, loader, model, device):
    """Encode every batch with ``model`` and store the code indices in an LMDB database.

    Each stored value is a pickled ``CodeRow(top=..., bottom=..., filename=...)``.
    Its key is the UTF-8 encoded decimal string of a running counter starting at 0.
    After all batches, a ``'length'`` key holds the total number of rows as a
    UTF-8 decimal string. Everything is written inside a single write transaction.

    Args:
        lmdb_env: Open LMDB environment to write into.
        loader: Iterable of ``(images, _, filenames)`` batches.
        model: Object whose ``encode`` returns a 5-tuple; the last two elements
            (top and bottom indices) are stored.
        device: Device the image batch is moved to before encoding.
    """
    count = 0

    with lmdb_env.begin(write=True) as txn:
        progress = tqdm(loader)

        for images, _, names in progress:
            encoded = model.encode(images.to(device))
            top_ids = encoded[3].detach().cpu().numpy()
            bottom_ids = encoded[4].detach().cpu().numpy()

            for name, top_code, bottom_code in zip(names, top_ids, bottom_ids):
                record = CodeRow(top=top_code, bottom=bottom_code, filename=name)
                key = str(count).encode('utf-8')
                txn.put(key, pickle.dumps(record))
                count += 1
                progress.set_description(f'inserted: {count}')

        length_key = 'length'.encode('utf-8')
        txn.put(length_key, str(count).encode('utf-8'))

In [2]:
def extract2img(new_path, loader, model, device):
    """Encode every batch and save each bottom-level quantized map as an 8-bit RGB image.

    For each image, the second element of ``model.encode``'s output
    (``quant_b``) is taken. Its first channel is min-max rescaled to [0, 255]
    per image, replicated to 3 channels, cast to ``uint8``, and written to
    ``<new_path>/<filename>``.

    Args:
        new_path: Output root directory; created together with any missing
            parents.
        loader: Iterable of ``(images, _, filenames)`` batches. ``filenames``
            are presumably dataset-relative paths such as ``class_dir\\img.png``
            or ``class_dir/img.png`` (TODO confirm against ImageFileDataset).
            Either separator is accepted, and the sub-folders are recreated
            under ``new_path``.
        model: Object whose ``encode`` returns
            ``(quant_t, quant_b, diff, id_t, id_b)``.
        device: Device the image batch is moved to before encoding.
    """
    # makedirs instead of mkdir: a missing parent such as 'runs/' no longer raises.
    os.makedirs(new_path, exist_ok=True)

    pbar = tqdm(loader)

    for img, _, filenames in pbar:
        img = img.to(device)

        # Inference only: don't build an autograd graph we never backpropagate through.
        with torch.no_grad():
            # encode returns (quant_t, quant_b, diff_t + diff_b, id_t, id_b)
            _, quant_b, _, _, _ = model.encode(img)
        quant_b = quant_b.detach().cpu().numpy()

        for code_map, rel_name in zip(quant_b, filenames):
            # Only channel 0 is saved, so rescale it by its own min/max.
            # With embed_dim=1 this equals fitting on the whole map.
            channel = code_map[0]
            scaler = preprocessing.MinMaxScaler((0, 255))
            scaled = scaler.fit_transform(channel.reshape(-1, 1)).reshape(channel.shape)

            # Replicate the single channel to RGB; astype truncates toward zero, as before.
            image = np.repeat(scaled[:, :, np.newaxis], 3, axis=2).astype(np.uint8)

            # Normalise Windows/POSIX separators so the relative path (any depth)
            # is rebuilt on every OS instead of failing on a tuple-unpack.
            parts = [part for part in str(rel_name).replace('\\', '/').split('/') if part]
            out_file = os.path.join(new_path, *parts)
            os.makedirs(os.path.dirname(out_file), exist_ok=True)

            io.imsave(out_file, image)


In [5]:
# resize_shape = (256, 256)

# Mapping from source dataset folder names to descriptive names.
# NOTE(review): not referenced by any other cell in this notebook — remove if unused.
names_new = {
    'Ultra_Co11': 'AB_Co11_medium',
    'Ultra_Co15': 'AB_Co15_medium_small',
    'Ultra_Co25': 'AB_Co25_small',
    'Ultra_Co6_2': 'AB_Co6_large',
    'Ultra_Co8': 'AB_Co_8_medium_small',
}

resize_shape = (512, 512)
# Fall back to CPU so the notebook still runs on machines without a CUDA GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

dataset_path = '../../datasets/original/o_bc_left_4x_768_360_median'
# NOTE(review): this run folder ends in '_728' while new_path below uses '_768';
# confirm this is the intended checkpoint.
model_path = 'runs_unique/emb_dim_1_n_embed_256_bc_left_4x_768_360_728/vqvae_004_train_0.00041_test_0.00039.pt'
new_path = 'runs/embs_images_emb_dim_1_n_embed_256_o_bc_left_4x_768_360_768_median'

# Center-crop to resize_shape, convert to a tensor, and map each channel
# from [0, 1] to [-1, 1].
transform = transforms.Compose(
    [
        # transforms.Resize(resize_shape),
        transforms.CenterCrop(resize_shape),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ]
)

dataset = ImageFileDataset(dataset_path, transform=transform)
# shuffle=False keeps the processing order deterministic.
loader = DataLoader(dataset, batch_size=32, shuffle=False, num_workers=4)

# These hyper-parameters must match the checkpoint: load_state_dict is strict
# by default and raises on missing/unexpected keys or shape mismatches.
model = VQVAE(
    in_channel=3,
    channel=128,
    n_res_block=6,
    n_res_channel=32,
    embed_dim=1,
    n_embed=256,
    decay=0.99,
).to(device)

# map_location lets a checkpoint saved on a GPU load onto whichever device was
# selected above. (Module.to already placed the model on `device`.)
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval()

# 100 GiB LMDB map size; only needed by the commented-out LMDB export below.
map_size = 100 * 1024 * 1024 * 1024

# Alternative export path: store code indices in LMDB instead of images.
# env = lmdb.open(new_path, map_size=map_size)

# extract(env, loader, model, device)

In [None]:
# Write one image per dataset entry under new_path.
extract2img(new_path=new_path, loader=loader, model=model, device=device)

100%|███████████████████████████████████████████████████████████████████████████████▋| 225/226 [00:59<00:00,  4.89it/s]

In [None]:
# Preview one image written by extract2img. `file` was only a local inside that
# function, so it is re-read from disk; this also works on a fresh kernel.
saved_files = sorted(
    os.path.join(root, name)
    for root, _, names in os.walk(new_path)
    for name in names
)
assert saved_files, f'no images found under {new_path}'

preview = io.imread(saved_files[0])
# Saved images are grayscale replicated to 3 channels; show a single channel.
preview_2d = preview[:, :, 0] if preview.ndim == 3 else preview

fig, ax = plt.subplots()
ax.imshow(preview_2d, cmap='gray')
ax.set_title(os.path.relpath(saved_files[0], new_path))
ax.axis('off')
plt.show()