In [4]:
from univ.utils.folder2lmdb import dump_pickle
from torch.utils.data import DataLoader, Dataset
import os.path as osp
import os
import pickle
import lmdb
import random
import numpy as np
import pandas as pd
import json

In [5]:
# How many classes to keep from the source dataset.
num_sampled_classes = 50

# Source (ImageNet100) and destination (ImageNet50) LMDB files for the train split.
# To subsample the validation split instead, use:
#   input_lmdb_path = "/root/data/imagenet100/val.lmdb"
#   output_lmdb_path = "/root/data/imagenet50/val.lmdb"
input_lmdb_path, output_lmdb_path = (
    "/root/univ-data/imagenet100/train.lmdb",
    "/root/univ-data/imagenet50/train.lmdb",
)


In [None]:
class MyImageFolderLMDB(Dataset):
    """Map-style dataset over pickled ``(image, label)`` records stored in an LMDB.

    The LMDB must contain a pickled ``b'__len__'`` entry (sample count) and a
    pickled ``b'__keys__'`` entry (list of record keys), plus one pickled
    record per key. The read handle used by ``__getitem__`` is opened lazily.

    NOTE(review): records are decoded with ``pickle``, which can execute
    arbitrary code -- only point this at trusted LMDB files.

    Args:
        db_path: Path to the LMDB (a directory or a single file; detected
            via ``os.path.isdir``).
        transform: Optional callable applied to the image.
        target_transform: Optional callable applied to the (squeezed) label.
    """

    def __init__(self, db_path, transform=None, target_transform=None):
        self.db_path = db_path
        self.transform = transform
        self.target_transform = target_transform

        # Read only the metadata here and close the environment explicitly
        # instead of relying on garbage collection; the handle used for
        # samples is opened lazily by open_lmdb().
        env = lmdb.open(self.db_path, subdir=osp.isdir(self.db_path),
                        readonly=True, lock=False,
                        readahead=False, meminit=False)
        try:
            with env.begin(write=False) as txn:
                self.length = pickle.loads(txn.get(b'__len__'))
                self.keys = pickle.loads(txn.get(b'__keys__'))
        finally:
            env.close()

    def open_lmdb(self):
        """Open the long-lived read-only environment/transaction used by ``__getitem__``.

        Also re-reads ``length`` and ``keys`` from this handle's snapshot.
        """
        self.env = lmdb.open(self.db_path, subdir=osp.isdir(self.db_path),
                             readonly=True, lock=False,
                             readahead=False, meminit=False)
        # buffers=True returns zero-copy buffers; pickle.loads copies what it needs.
        self.txn = self.env.begin(write=False, buffers=True)
        self.length = pickle.loads(self.txn.get(b'__len__'))
        self.keys = pickle.loads(self.txn.get(b'__keys__'))

    def __getitem__(self, index):
        """Return ``(img, target)`` for the record at position *index*.

        ``img`` is ``record[0][0]`` and ``target`` is ``record[1].squeeze()``,
        each passed through its optional transform.
        """
        if not hasattr(self, 'txn'):
            # Opened on first access, presumably so that each DataLoader worker
            # process creates its own LMDB handle.
            self.open_lmdb()

        byteflow = self.txn.get(self.keys[index])
        unpacked = pickle.loads(byteflow)

        # record[0] appears to be a batch of size 1; take its single element.
        img = unpacked[0][0]

        # The label is stored as an array; squeeze away singleton dimensions.
        target = unpacked[1].squeeze()

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        return self.length

    def __repr__(self):
        return self.__class__.__name__ + ' (' + self.db_path + ')'


class ConsecutiveLabelMapper:
    """Relabel arbitrary integer labels to 0, 1, 2, ... in order of first appearance.

    The lmdb files for ImageNet50 do not have consecutive labels (e.g. labels
    95, 96 appear although there are only 50 classes), which breaks cross
    entropy. Recreating the lmdb files would also fix this; this mapper was
    quicker to implement.

    Attributes:
        label_map: dict from original ``int`` label to its consecutive index.
        num_classes: number of distinct labels seen so far.
    """

    def __init__(self) -> None:
        self.label_map = {}
        self.num_classes = 0

    def __call__(self, label: np.ndarray):
        """Return the consecutive index for *label*, assigning a new one if unseen."""
        original = int(label)
        if original not in self.label_map:
            # First time this label is seen: give it the next free index.
            self.label_map[original] = self.num_classes
            self.num_classes += 1
        return self.label_map[original]


dataset = MyImageFolderLMDB(input_lmdb_path)
# The label mapping below runs in this (main) process, in the order the
# DataLoader yields batches (shuffle=False by default), so it is reproducible.
data_loader = DataLoader(dataset, num_workers=1)

# Start from a clean output. LMDB is opened below in single-file mode
# (subdir=False), which uses `<path>` for data and `<path>-lock` for the lock.
if osp.isdir(output_lmdb_path):
    # os.remove cannot delete a directory-style LMDB; fail with a clear message
    # rather than recursively deleting a directory.
    raise IsADirectoryError(
        f"{output_lmdb_path} is a directory; expected the path of a single-file LMDB")
for stale_path in (output_lmdb_path, output_lmdb_path + "-lock"):
    if osp.exists(stale_path):
        os.remove(stale_path)
# LMDB does not create missing parent directories in single-file mode.
os.makedirs(osp.dirname(output_lmdb_path) or ".", exist_ok=True)

# The output path was just removed, so this is False -> single-file LMDB.
isdir = osp.isdir(output_lmdb_path)

print(f"Generating LMDB to {output_lmdb_path}")
map_size = int(1e10)  # maximum DB size in bytes; adjust to the OS / expected output size
db = lmdb.open(output_lmdb_path, subdir=isdir,
               map_size=map_size, readonly=False,
               meminit=False, map_async=True)

print(len(dataset), len(data_loader))

# Collect every label present in the source data (this loads every record once).
classes = {label.item() for _, label in data_loader}
random.seed(34786567)
# random.sample() requires a sequence (sets raise TypeError since Python 3.11);
# sorting also makes the seeded draw independent of set iteration order.
sampled_classes = sorted(random.sample(sorted(classes), num_sampled_classes))
sampled_class_set = set(sampled_classes)  # O(1) membership in the loop below
print("Classes in original file", classes)
print("Sampled classes", sampled_classes)

write_frequency = 4000  # progress is printed every this many kept samples
sample_index = 0  # number of samples written so far == next record key
all_labels = []
target_transform = ConsecutiveLabelMapper()
try:
    with db.begin(write=True) as txn:
        for scanned, (image, label) in enumerate(data_loader, start=1):
            original_label = label.item()  # batch size is 1 -> single scalar
            if original_label not in sampled_class_set:
                continue
            mapped_label = target_transform(original_label)
            all_labels.append(mapped_label)
            # MyImageFolderLMDB reads record[0][0] and record[1], so the batched
            # image (batch size 1) is stored as-is next to the label array.
            txn.put(str(sample_index).encode('ascii'),
                    dump_pickle((image, np.array(mapped_label))))
            sample_index += 1
            if sample_index % write_frequency == 0:
                print("[%d/%d] kept %d samples" % (scanned, len(data_loader), sample_index))

    # Metadata entries read back by MyImageFolderLMDB.
    keys = [str(k).encode('ascii') for k in range(sample_index)]
    with db.begin(write=True) as txn:
        txn.put(b'__keys__', dump_pickle(keys))
        txn.put(b'__len__', dump_pickle(len(keys)))

    print("Flushing database ...")
    db.sync()
finally:
    # Release the environment even if writing failed part-way.
    db.close()

# Save the consecutive labels of the written samples, in write order.
labels_file_path = osp.splitext(output_lmdb_path)[0] + "_labels.csv"
pd.Series(all_labels).to_csv(labels_file_path)

# Save the original -> consecutive label mapping (JSON stores the int keys as strings).
original_dataset = osp.basename(osp.dirname(input_lmdb_path))
output_dataset = osp.basename(osp.dirname(output_lmdb_path))
label_map_path = osp.splitext(output_lmdb_path)[0] + f"_{original_dataset}_to_{output_dataset}_labelmap.json"
with open(label_map_path, "w") as f:
    json.dump(target_transform.label_map, f)