In [3]:
import os
import subprocess

import pandas as pd
from tqdm import tqdm

# Columns kept from each shard TSV (passed as `usecols` when trimming); all other columns are discarded.
_WIKIDATA_RELEVANT_COLUMNS = [
    'language',
    'page_url',
    'image_url',
    'caption_reference_description',
    'caption_attribution_description',
    'page_title',
    'section_title',
]

# Base URL that the gzipped shard files are fetched from.
_DOWNLOAD_ROOT_URL = "https://storage.googleapis.com/gresearch/wit/"

# Shard base name (without extension); `shard` and `num_shards` are filled in as zero-padded strings.
_FILENAME = "wit_v1.{split}.all-{shard}-of-{num_shards}"

# Per-split shard CSVs and the combined per-split CSVs are written under this folder.
_ROOT_OUTPUT_FOLDER = 'wiki_data/'

### Download, unarchive, and process the sharded data

In [None]:
# Process splits smallest-first: test/val are much smaller than train, so any issue with
# the pipeline surfaces early instead of after hours of train downloads.
for split in ['test', 'val', 'train']:
    num_shards = 5 if split in ('test', 'val') else 10
    output_dir = os.path.join(_ROOT_OUTPUT_FOLDER, split)
    os.makedirs(output_dir, exist_ok=True)

    for shard in tqdm(range(num_shards), desc=split):
        # Shard index and shard count are zero-padded to 5 digits in the published file names.
        filename = _FILENAME.format(split=split, shard=f"{shard:05d}", num_shards=f"{num_shards:05d}")
        zipped_filename = filename + ".tsv.gz"
        tsv_filename = filename + ".tsv"
        csv_path = os.path.join(output_dir, filename + ".csv")
        url = _DOWNLOAD_ROOT_URL + zipped_filename

        # Resume support: shards finished by a previous (possibly interrupted) run are skipped.
        # A CSV only appears under its final name after it was fully written (see os.replace below).
        if os.path.exists(csv_path):
            continue

        # TODO(pshishodia): Ideally, all of the 3 processes - download, unarchive, trim tsv can be done parallely.
        # i.e, when I'm unarchiving second file, I can download the third. so these can be parallelised.
        # Offline, I just print these commands in a terminal and use a while loop to check whether I can execute the trim
        # step every 10s.

        # check=True stops the loop at the failing step with CalledProcessError; `!cmd` ignores exit codes.
        # `-c` resumes a partial download into the same file instead of saving a new `<name>.1` copy.
        subprocess.run(['wget', '-nv', '-c', url], check=True)
        # pigz is gzip with parallelization: -d decompresses the .tsv.gz into `tsv_filename`,
        # -f overwrites a stale .tsv left behind by an interrupted run.
        subprocess.run(['pigz', '-d', '-f', zipped_filename], check=True)

        shard_df = pd.read_csv(tsv_filename, sep='\t', usecols=_WIKIDATA_RELEVANT_COLUMNS)
        # index=False: the positional index carries no information and would otherwise come back
        # as an 'Unnamed: 0' column when the CSV is read again. Write to a temp name and rename so
        # an interrupted write never leaves a truncated file that looks finished to the skip check.
        tmp_path = csv_path + ".tmp"
        shard_df.to_csv(tmp_path, index=False)
        os.replace(tmp_path, csv_path)
        os.remove(tsv_filename)

### Combine the data
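For each split, concatenate the per-shard CSVs, deduplicate on `(page_url, image_url)`, shuffle, and save a single `<split>.csv` under `wiki_data/`. This takes ≤ 5 s for val/test but ~10 minutes for train.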

In [4]:
# Seed for the final shuffle, so re-running this cell reproduces identical split files.
_SHUFFLE_SEED = 0

for split in ['test', 'val', 'train']:
    output_dir = os.path.join(_ROOT_OUTPUT_FOLDER, split)

    # Per-shard CSVs for this split. Sorted so the concatenation order (and therefore which
    # duplicate row is kept below) does not depend on the arbitrary order of os.listdir.
    csv_files = sorted(os.path.join(output_dir, f) for f in os.listdir(output_dir) if f.endswith('.csv'))
    print(f"{len(csv_files)=}")
    if not csv_files:
        raise FileNotFoundError(f"No shard CSVs found in {output_dir!r}; run the download cell first.")

    # Combine all csv files into a single dataframe
    print(f"Reading {split=}")
    combined_df = pd.concat([pd.read_csv(f) for f in csv_files], ignore_index=True)
    # Shard CSVs written together with their index carry an extra 'Unnamed: 0' column; drop it if present.
    combined_df = combined_df.drop(columns=['Unnamed: 0'], errors='ignore')

    print(f"Deduplicating {split=}")
    # Keep one complete original row per (page_url, image_url) pair. groupby(...).first() would
    # instead take the first non-null value of each column independently, which can stitch
    # together fields from different rows. Rows missing either key are dropped (as groupby did),
    # and the key columns are moved to the front to keep the previous output column order.
    key_columns = ['page_url', 'image_url']
    combined_df = (
        combined_df
        .dropna(subset=key_columns)
        .drop_duplicates(subset=key_columns, keep='first')
    )
    combined_df = combined_df[key_columns + [c for c in combined_df.columns if c not in key_columns]]

    print(f"Shuffling {split=}")
    combined_df = combined_df.sample(frac=1, random_state=_SHUFFLE_SEED).reset_index(drop=True)

    print(f"Saving {split=}")
    combined_df.to_csv(os.path.join(_ROOT_OUTPUT_FOLDER, split + ".csv"), index=False)

    print(f"================= Completed {split=} ================= ")