In [3]:
import argparse
import pickle

import torch
from torch.utils.data import DataLoader
from torchvision import transforms
import lmdb
from tqdm import tqdm

from dataset import ImageFileDataset, CodeRow
from vqvae import VQVAE

In [4]:
def extract(lmdb_env, loader, model, device):
    """Encode every batch from ``loader`` with ``model`` and store the codes in LMDB.

    Each sample is written as a pickled ``CodeRow(top=..., bottom=..., filename=...)``
    under the UTF-8 key ``str(i)``. Here ``i`` counts samples in loader order,
    starting at 0. After all batches, the total row count is written under the
    key ``'length'``. All writes happen inside a single write transaction.

    Args:
        lmdb_env: An open LMDB environment (e.g. from ``lmdb.open``) to write into.
        loader: Iterable yielding ``(img, _, filename)`` batches. ``filename`` is
            zipped element-wise with the per-sample codes.
        model: Object whose ``encode(img)`` returns a 5-tuple. Its last two items
            are stored as ``top`` and ``bottom`` respectively.
        device: Device each image batch is moved to before encoding.
    """
    index = 0

    with lmdb_env.begin(write=True) as txn:
        pbar = tqdm(loader)

        # Pure inference: no_grad avoids building an autograd graph for every
        # batch (previously built and then thrown away by .detach()).
        with torch.no_grad():
            for img, _, filename in pbar:
                img = img.to(device)

                _, _, _, id_t, id_b = model.encode(img)
                id_t = id_t.detach().cpu().numpy()
                id_b = id_b.detach().cpu().numpy()

                for file, top, bottom in zip(filename, id_t, id_b):
                    row = CodeRow(top=top, bottom=bottom, filename=file)
                    txn.put(str(index).encode('utf-8'), pickle.dumps(row))
                    index += 1

                # Refresh the progress label once per batch instead of once per row.
                pbar.set_description(f'inserted: {index}')

        # Readers use this key to know how many numbered rows exist.
        txn.put('length'.encode('utf-8'), str(index).encode('utf-8'))

In [7]:

# Configuration for extracting VQ-VAE codes from the image dataset into LMDB.
resize_shape = (256, 256)  # center-crop size applied to every image (see transform)
# Fall back to CPU so the cell still runs on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
dataset_path = '../../datasets/original/o_bc_left_4x_768'
model_path = 'runs/emb_dim_64_n_embed_512_bc_left_4x_768_360/vqvae_013_train_0.01824_test_0.0179.pt'
new_path = 'runs/embs_emb_dim_64_n_embed_512_bc_left_4x_768'

# Images are center-cropped (not resized) to resize_shape, converted to tensors,
# and normalized per channel with mean 0.5 / std 0.5.
# NOTE(review): confirm this matches the preprocessing used when training the model.
transform = transforms.Compose(
    [
        transforms.CenterCrop(resize_shape),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ]
)

# shuffle=False keeps rows written to LMDB in dataset order.
dataset = ImageFileDataset(dataset_path, transform=transform)
loader = DataLoader(dataset, batch_size=128, shuffle=False, num_workers=4)

model = VQVAE()
# map_location lets a checkpoint saved on one device load on whichever device is selected.
model.load_state_dict(torch.load(model_path, map_location=device))
model = model.to(device)
model.eval()

# Upper bound on the LMDB database size: 100 GiB.
map_size = 100 * 1024 * 1024 * 1024

env = lmdb.open(new_path, map_size=map_size)
try:
    extract(env, loader, model, device)
finally:
    # Release the environment even if extraction fails midway.
    env.close()

inserted: 1800: 100%|██████████████████████████████████████████████████████████████████| 15/15 [00:12<00:00,  1.23it/s]
