In [1]:
import os
import cv2
import lmdb
import caffe
import random
import numpy as np
from datetime import datetime

In [12]:
def parse_txt(txt_name):
    """Read an image list file into ``[name, label]`` pairs.

    Every non-empty line must hold an image path followed by an integer
    class label, separated by whitespace.  The label is taken from the last
    whitespace-separated field, so paths that contain spaces stay intact.
    Blank lines (including a trailing newline at the end of the file) are
    skipped instead of aborting the parse.

    Args:
        txt_name: Path of the list file.

    Returns:
        A list of ``[name, label]`` lists in file order; ``label`` is an int.

    Raises:
        ValueError: If a non-empty line has no label field, or the label is
            not an integer.
    """
    name_labels = []
    with open(txt_name, 'r') as f:
        # Iterate the file lazily instead of materialising it with readlines().
        for line_no, line in enumerate(f, 1):
            line = line.strip()
            if not line:
                continue
            # Split once from the right: everything before the label is the path.
            parts = line.rsplit(None, 1)
            if len(parts) != 2:
                raise ValueError(
                    "{}:{}: expected '<name> <label>', got {!r}".format(txt_name, line_no, line))
            name, label = parts
            name_labels.append([name, int(label)])
    return name_labels

def parse_dir(dir_name):
    """Collect ``[path, label]`` pairs from a one-folder-per-class image tree.

    Each sub-directory of ``dir_name`` is one class.  Classes are sorted by
    name and labelled 0, 1, 2, ... in that order.  Regular files that sit
    directly in ``dir_name`` are ignored, so they neither crash the scan nor
    shift the label numbering.

    Args:
        dir_name: Root directory that contains one sub-directory per class.

    Returns:
        A list of ``[path, label]`` lists where ``path`` is
        ``dir_name/class/file`` for every file ending in ``.jpg``.  Entries
        are ordered by class and then by file name.
    """
    # Only directories are classes; a stray file here would otherwise be
    # passed to os.listdir() and raise.
    classes = sorted(
        entry for entry in os.listdir(dir_name)
        if os.path.isdir(os.path.join(dir_name, entry))
    )

    name_labels = []
    for label, class_name in enumerate(classes):
        class_dir = os.path.join(dir_name, class_name)
        # os.listdir() order is arbitrary; sort for a reproducible result.
        for file_name in sorted(os.listdir(class_dir)):
            if file_name.endswith(".jpg"):
                name_labels.append([os.path.join(class_dir, file_name), label])
    return name_labels

def read_image(img_name, size=224):
    """Load an image, resize it to a square and return it channel-first.

    Args:
        img_name: Path of the image file, passed to ``cv2.imread``.
        size: Edge length in pixels of the square output.

    Returns:
        The resized image as an array transposed from (height, width,
        channel) to (channel, height, width) order.

    Raises:
        IOError: If ``cv2.imread`` cannot load the file (it signals this by
            returning ``None`` instead of raising).
    """
    img = cv2.imread(img_name)
    if img is None:
        # Fail with the offending path rather than an opaque cv2.resize error.
        raise IOError("failed to read image: {}".format(img_name))
    img = cv2.resize(img, (size, size))
    return np.transpose(img, (2, 0, 1))  # hwc to chw

def make_lmdb(data_src, lmdb_name, append=False, size=224):
    """Write a shuffled image dataset into an LMDB of Caffe ``Datum`` records.

    Args:
        data_src: Either a list file (parsed with ``parse_txt``) or a
            directory with one sub-folder per class (parsed with
            ``parse_dir``).
        lmdb_name: Path of the LMDB environment to write to.
        append: If True, new records are numbered after the largest key
            already stored.  If False, numbering starts at 0; records that
            already exist under those keys are overwritten.
        size: Edge length of the square images stored in the database.

    Returns:
        None.  Prints "invalid data source." and returns early when
        ``data_src`` is neither a file nor a directory.

    Raises:
        ValueError: In append mode, if the largest existing key is not a
            decimal number (i.e. was not written by this function).
    """
    if os.path.isfile(data_src):
        name_labels = parse_txt(data_src)
    elif os.path.isdir(data_src):
        name_labels = parse_dir(data_src)
    else:
        print("invalid data source.")
        return

    # A single shuffle already produces a uniformly random order.
    random.shuffle(name_labels)

    N = len(name_labels)
    bytes_per_sample = 3 * size * size * 3  # last 3 for redundancy purposes
    # Floor for the map size so that an empty or tiny data set still leaves
    # LMDB room for its own bookkeeping pages.
    min_map_size = 10 * 1024 * 1024

    if append:
        env = lmdb.open(lmdb_name, map_size=max(N * bytes_per_sample, min_map_size))
        try:
            num_existing = env.stat()["entries"]
            # Continue after the largest stored key rather than after the
            # entry count: once records have been deleted the count is
            # smaller than the largest key and would cause overwrites.
            # Keys are zero-padded, so the last key in byte order is also the
            # numerically largest one.
            with env.begin() as txn:
                cursor = txn.cursor()
                if cursor.last():
                    next_key = int(cursor.key().decode('ascii')) + 1
                else:
                    next_key = 0
        finally:
            env.close()

        # reopen env with bigger size
        map_size = max((num_existing + N) * bytes_per_sample, min_map_size)
        print("Existed num of samples: ", num_existing)
    else:
        next_key = 0
        map_size = max(N * bytes_per_sample, min_map_size)
        print("Create new lmdb.")

    env = lmdb.open(lmdb_name, map_size=map_size)
    try:
        # A single write transaction: if any image fails, nothing is committed.
        with env.begin(write=True) as txn:
            for i, (name, label) in enumerate(name_labels):
                img = read_image(name, size)

                datum = caffe.proto.caffe_pb2.Datum()
                datum.channels = 3
                datum.height = size
                datum.width = size
                datum.data = img.tobytes()  # or .tostring() if numpy < 1.9
                datum.label = label
                str_id = '{:08}'.format(i + next_key)

                # The encode is only essential in Python 3
                txn.put(str_id.encode('ascii'), datum.SerializeToString())

                if i % 1000 == 0:
                    print(datetime.now(), " --> ", i)
    finally:
        # Always release the environment so later opens in this process work.
        env.close()

                
def get_num_records(lmdb_name):
    """Return the number of records stored in an LMDB environment.

    The environment is opened read-only, so a wrong path raises an error
    instead of silently creating an empty database there.

    Args:
        lmdb_name: Path of an existing LMDB environment.

    Returns:
        The ``"entries"`` value of the environment statistics.  This is a
        record count; it equals "largest key + 1" only while the keys are
        consecutive and start at 0.
    """
    env = lmdb.open(lmdb_name, readonly=True)
    try:
        return env.stat()["entries"]
    finally:
        env.close()


def delete_from_lmdb(lmdb_name, keys, size=224):
    """Delete records from an LMDB by their integer ids.

    Args:
        lmdb_name: Path of the LMDB environment.
        keys: Iterable of integer ids; each is formatted as an 8-digit,
            zero-padded ASCII key before the lookup.
        size: Image edge length used to estimate the map size passed to
            ``lmdb.open``.

    Returns:
        None.  Prints "Does not exist: <key>" for every id that was not
        found.  All deletions happen in one write transaction.
    """
    N = get_num_records(lmdb_name)
    # last 3 for redundancy purposes; at least one sample's worth so that an
    # empty database does not request a map size of 0.
    map_size = max(N, 1) * 3 * size * size * 3
    env = lmdb.open(lmdb_name, map_size=map_size)
    try:
        with env.begin(write=True) as txn:
            for key in keys:
                str_id = '{:08}'.format(key).encode('ascii')
                # txn.delete() reports whether the key was present.
                if not txn.delete(str_id):
                    print("Does not exist:", key)
    finally:
        # Release the handle so the environment can be reopened afterwards.
        env.close()

In [None]:
data_src = "/home/hdd0/Data/caffe-testdata/dataset_kaggledogvscat/test_set"
lmdb_name = "/home/hdd0/Data/caffe-testdata/dataset_kaggledogvscat/test_set-lmdb"

%time make_lmdb(data_src, lmdb_name, append=False)

In [13]:
lmdb_name = "/home/hdd0/Data/caffe-testdata/dataset_kaggledogvscat/test_set-lmdb"
keys = [2, 5, 19, 32, 35, 85]

# Record count before the deletion ...
print(get_num_records(lmdb_name))

delete_from_lmdb(lmdb_name, keys)

# ... and after it.
print(get_num_records(lmdb_name))

1994
Does not exist: 2
Does not exist: 5
Does not exist: 19
Does not exist: 32
Does not exist: 35
Does not exist: 85
1994


In [None]:
N = 1000

# Let's pretend this is interesting data.
# np.random.randint excludes `high`, so 256 covers the full uint8 range and
# 2 yields both labels 0 and 1 (high=1 would make every label 0).
X = np.random.randint(low=0, high=256, size=(N, 3, 32, 32), dtype=np.uint8)
y = np.random.randint(low=0, high=2, size=N, dtype=np.int64)

# We need to prepare the database for the size. We'll set it 3 times
# greater than what we theoretically need. There is little drawback to
# setting this too big. If you still run into problem after raising
# this, you might want to try saving fewer entries in a single
# transaction.
map_size = X.nbytes * 3
print(X.nbytes, X.shape)

# `env` is deliberately left open: the cells below keep using it.
env = lmdb.open('mylmdb', map_size=map_size)

with env.begin(write=True) as txn:
    # txn is a Transaction object
    for i in range(N):
        datum = caffe.proto.caffe_pb2.Datum()
        datum.channels = X.shape[1]
        datum.height = X.shape[2]
        datum.width = X.shape[3]
        datum.data = X[i].tobytes()  # or .tostring() if numpy < 1.9
        datum.label = int(y[i])
        str_id = '{:08}'.format(i)

        # The encode is only essential in Python 3
        txn.put(str_id.encode('ascii'), datum.SerializeToString())

In [None]:
# Disabled example: read the record stored under key b'00000000' back from
# 'mylmdb' and decode it into an array and a label.
# NOTE(review): np.fromstring is deprecated for binary input in recent NumPy
# releases; np.frombuffer is the likely replacement if this cell is
# re-enabled -- confirm against the installed NumPy version.
# env = lmdb.open('mylmdb', readonly=True)
# with env.begin() as txn:
#     raw_datum = txn.get(b'00000000')

# datum = caffe.proto.caffe_pb2.Datum()
# datum.ParseFromString(raw_datum)

# flat_x = np.fromstring(datum.data, dtype=np.uint8)
# x = flat_x.reshape(datum.channels, datum.height, datum.width)
# y = datum.label
# print(x.shape, y)

In [None]:
# Walk every record, echo its key, and remember the raw values stored under
# the two keys of interest (None when a key is absent).
wanted = {b'00000000': None, b'00001000': None}

with env.begin() as txn:
    cursor = txn.cursor()
    for key, value in cursor:
        print(key)
        if key in wanted:
            print(type(key), key)
            wanted[key] = value

value0 = wanted[b'00000000']
value1000 = wanted[b'00001000']

print(type(value0), type(value1000))
print(value0 == value1000)

In [None]:
# Store a second copy of the N samples under keys N .. 2N-1.

# The same LMDB environment must not be open twice in one process, so release
# the handle left over from the earlier cell before reopening it.
env.close()

# Twice the records need roughly twice the room; `map_size` itself is left
# untouched so re-running this cell does not keep growing it.
env = lmdb.open('mylmdb', map_size=map_size * 2)

with env.begin(write=True) as txn:
    # txn is a Transaction object
    for i in range(N):
        datum = caffe.proto.caffe_pb2.Datum()
        datum.channels = X.shape[1]
        datum.height = X.shape[2]
        datum.width = X.shape[3]
        datum.data = X[i].tobytes()  # or .tostring() if numpy < 1.9
        datum.label = int(y[i])
        # Offset by N (1000 here) so the first batch is not overwritten.
        str_id = '{:08}'.format(i + N)

        # The encode is only essential in Python 3
        txn.put(str_id.encode('ascii'), datum.SerializeToString())