In [1]:
# !wget https://gist.githubusercontent.com/huseinzol05/98974ae8c6c7a65d4bc0af9f5003786a/raw/2e06e71ef7349a57bc58cc9913ae6bae1f9f8447/mp.py

In [2]:
import os

os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

In [3]:
# !pip3 install hf-transfer datasets -U 

In [4]:
from huggingface_hub import snapshot_download
from tqdm import tqdm
from glob import glob
import pandas as pd
import mp

In [5]:
# Fetch the dataset repository from the Hugging Face Hub and keep the returned
# local directory; the parquet shards are looked up under `{folder}/data/` below.
folder = snapshot_download(repo_id="malaysia-ai/crawl-google-image-malaysia-location", repo_type = 'dataset')

In [6]:
# Gather every parquet shard, then order the list by the numeric shard id that
# sits between the first two dashes of names like "train-00000-of-01000.parquet".
files = glob(f'{folder}/data/*.parquet')
files.sort(key = lambda path: int(path.rsplit('/', 1)[-1].split('-')[1]))
len(files)

1000

In [7]:
def loop(files):
    """Count the rows of each parquet file in one work chunk.

    Parameters
    ----------
    files : tuple
        A 2-tuple ``(paths, _)``; only the first element (an iterable of
        parquet file paths) is used, the second is ignored.

    Returns
    -------
    list of dict
        One ``{'f': path, 'len': row_count}`` entry per file that could be
        read, in input order. Files that fail to load are skipped and reported
        on stdout.
    """
    files, _ = files
    r = []
    for f in tqdm(files):
        try:
            df = pd.read_parquet(f)
        except Exception as e:
            # Best-effort scan: keep going past an unreadable shard, but say so.
            # `Exception` (not a bare except) lets KeyboardInterrupt/SystemExit
            # stop a long-running scan.
            print(f'skipping {f}: {e!r}')
            continue
        r.append({
            'f': f,
            'len': len(df)
        })
    return r

In [8]:
loop((files[:2], 0))

100%|██████████| 2/2 [00:00<00:00,  2.54it/s]


[{'f': '/home/ubuntu/.cache/huggingface/hub/datasets--malaysia-ai--crawl-google-image-malaysia-location/snapshots/da8d9dc03a7d6779de6062836030a1c90c0e74ea/data/train-00000-of-01000.parquet',
  'len': 4039},
 {'f': '/home/ubuntu/.cache/huggingface/hub/datasets--malaysia-ai--crawl-google-image-malaysia-location/snapshots/da8d9dc03a7d6779de6062836030a1c90c0e74ea/data/train-00001-of-01000.parquet',
  'len': 4039}]

In [9]:
# Run `loop` over all shards via the local `mp` helper module.
# NOTE(review): `loop` unpacks its argument as `(files, _)`, so the helper
# presumably splits `files` into `cores` chunks and passes each worker a
# (chunk, index) tuple, then concatenates the returned lists — confirm in mp.py.
r = mp.multiprocessing(files, loop, cores = 5)

100%|██████████| 200/200 [00:42<00:00,  4.71it/s]
100%|██████████| 200/200 [23:20<00:00,  7.00s/it]
100%|██████████| 200/200 [40:27<00:00, 12.14s/it]
100%|██████████| 200/200 [40:31<00:00, 12.16s/it]
100%|██████████| 200/200 [41:35<00:00, 12.48s/it]


In [12]:
import json

# Save the per-shard row counts (list of {'f': path, 'len': n}) to lengths.json.
with open('lengths.json', 'w') as fopen:
    json.dump(r, fopen)

In [13]:
r

[{'f': '/home/ubuntu/.cache/huggingface/hub/datasets--malaysia-ai--crawl-google-image-malaysia-location/snapshots/da8d9dc03a7d6779de6062836030a1c90c0e74ea/data/train-00000-of-01000.parquet',
  'len': 4039},
 {'f': '/home/ubuntu/.cache/huggingface/hub/datasets--malaysia-ai--crawl-google-image-malaysia-location/snapshots/da8d9dc03a7d6779de6062836030a1c90c0e74ea/data/train-00001-of-01000.parquet',
  'len': 4039},
 {'f': '/home/ubuntu/.cache/huggingface/hub/datasets--malaysia-ai--crawl-google-image-malaysia-location/snapshots/da8d9dc03a7d6779de6062836030a1c90c0e74ea/data/train-00002-of-01000.parquet',
  'len': 4039},
 {'f': '/home/ubuntu/.cache/huggingface/hub/datasets--malaysia-ai--crawl-google-image-malaysia-location/snapshots/da8d9dc03a7d6779de6062836030a1c90c0e74ea/data/train-00003-of-01000.parquet',
  'len': 4039},
 {'f': '/home/ubuntu/.cache/huggingface/hub/datasets--malaysia-ai--crawl-google-image-malaysia-location/snapshots/da8d9dc03a7d6779de6062836030a1c90c0e74ea/data/train-00004-

In [15]:
# Build a lookup keyed by each shard's first global row index. Note that the
# 'end' field stores the shard's row count, not an absolute end offset.
global_indices = {}
start = 0
for entry in r:
    n_rows = entry['len']
    global_indices[start] = {'start': start, 'end': n_rows, 'filename': entry['f']}
    start += n_rows

In [16]:
global_indices

{0: {'start': 0,
  'end': 4039,
  'filename': '/home/ubuntu/.cache/huggingface/hub/datasets--malaysia-ai--crawl-google-image-malaysia-location/snapshots/da8d9dc03a7d6779de6062836030a1c90c0e74ea/data/train-00000-of-01000.parquet'},
 4039: {'start': 4039,
  'end': 4039,
  'filename': '/home/ubuntu/.cache/huggingface/hub/datasets--malaysia-ai--crawl-google-image-malaysia-location/snapshots/da8d9dc03a7d6779de6062836030a1c90c0e74ea/data/train-00001-of-01000.parquet'},
 8078: {'start': 8078,
  'end': 4039,
  'filename': '/home/ubuntu/.cache/huggingface/hub/datasets--malaysia-ai--crawl-google-image-malaysia-location/snapshots/da8d9dc03a7d6779de6062836030a1c90c0e74ea/data/train-00002-of-01000.parquet'},
 12117: {'start': 12117,
  'end': 4039,
  'filename': '/home/ubuntu/.cache/huggingface/hub/datasets--malaysia-ai--crawl-google-image-malaysia-location/snapshots/da8d9dc03a7d6779de6062836030a1c90c0e74ea/data/train-00003-of-01000.parquet'},
 16156: {'start': 16156,
  'end': 4039,
  'filename': '/

In [17]:
from torch.utils.data import DataLoader, Dataset

In [23]:
class Train(Dataset):
    """Map-style dataset that exposes the rows of many parquet shards as one sequence.

    Parameters
    ----------
    indices : dict
        Maps a shard's first global row index to a dict with keys ``'start'``
        (that same first global index), ``'end'`` (added to ``'start'`` to get
        the exclusive upper bound, i.e. the shard's row count) and
        ``'filename'`` (parquet path).
    maxlen_cache_df : int, default 5
        Maximum number of shard DataFrames kept in memory at once. A value
        of 0 or less disables caching.
    """

    def __init__(self, indices, maxlen_cache_df=5):
        # Expand the shard spans into a per-row lookup:
        # global row index -> shard descriptor.
        self.indices = {}
        for k, v in indices.items():
            for i in range(int(k), v['start'] + v['end'], 1):
                self.indices[i] = v

        self.max_index = len(self.indices)
        # filename -> DataFrame, ordered from least to most recently used
        # (relies on dict insertion order).
        self.cache_df = {}
        self.maxlen_cache_df = maxlen_cache_df

    def __len__(self):
        """Return the total number of rows across all shards."""
        return self.max_index

    def _load_shard(self, filename):
        """Return the DataFrame for `filename`, going through the LRU cache."""
        if filename in self.cache_df:
            # Re-insert so this shard becomes the most recently used entry.
            df = self.cache_df.pop(filename)
            self.cache_df[filename] = df
            return df

        df = pd.read_parquet(filename)
        if self.maxlen_cache_df > 0:
            # Evict the least recently used shards until there is room.
            while len(self.cache_df) >= self.maxlen_cache_df:
                self.cache_df.pop(next(iter(self.cache_df)))
            self.cache_df[filename] = df
        return df

    def __getitem__(self, item):
        """Return ``{'array': row}`` for global row `item`.

        Negative values count from the end. The row is whatever
        ``df.iloc[...]`` yields for the position inside the owning shard.
        """
        if item < 0:
            item = self.max_index + item

        v = self.indices[item]
        df = self._load_shard(v['filename'])
        # Convert the global index to a position inside the shard.
        row = df.iloc[item - v['start']]
        return {'array': row}

In [24]:
train = Train(global_indices)

In [26]:
train[-1]

{'array': alt_text                  KETEREH, KELANTAN, MALAYSIA - Yelp ...
 parent_href    /imgres?q=taman%20Mesra2%20Ketereh&imgurl=http...
 filename                            image-location/54185-24.jpeg
 image          {'bytes': b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x...
 keyword                                     taman Mesra2 Ketereh
 Name: 4037, dtype: object}