## What this notebook does
When downloading the dataset with the `img2dataset` tool, the image width and height are not saved in the metadata. However, this information is available in the original parquet files. This notebook shows how to join the global metadata with the original parquet files (on the `url` column) to get the width and height of the images.

In [1]:

import os
os.environ['AWS_CONFIG_FILE'] = '../.aws/config'

import json
import random
from pathlib import Path

import boto3
import s3fs
import polars as pl
from PIL import Image
from io import BytesIO

from fc_ai_pd12m.utils import safe_write_ipc

## Configure S3

In [2]:
def s3_auth(profile_name='default', endpoint_url='https://s3.gra.io.cloud.ovh.net'):
    """Build an authenticated s3fs filesystem from a boto3 profile.

    Args:
        profile_name: Name of the boto3 profile whose credentials are used.
        endpoint_url: URL of the S3-compatible endpoint to connect to.

    Returns:
        A ``(s3_fs, credentials)`` tuple: the ``s3fs.S3FileSystem`` and the
        frozen boto3 credentials it was created with.

    Raises:
        RuntimeError: If boto3 finds no credentials for ``profile_name``.
    """
    session = boto3.session.Session(profile_name=profile_name)

    # get_credentials() returns None when nothing is configured; fail with a
    # clear message instead of an AttributeError on the next line.
    session_credentials = session.get_credentials()
    if session_credentials is None:
        raise RuntimeError(
            f"No AWS credentials found for profile '{profile_name}'"
        )
    credentials = session_credentials.get_frozen_credentials()

    s3_fs = s3fs.S3FileSystem(
        key=credentials.access_key,
        secret=credentials.secret_key,
        endpoint_url=endpoint_url,
    )
    return s3_fs, credentials

In [3]:
# Location of the dataset on the S3-compatible object storage
bucket_name = 'fc-gra-alejandria'
ds_path = f'{bucket_name}/ds/public/PD12M'

# Endpoint and region used for every S3 connection in this notebook
aws_endpoint_url = 'https://s3.gra.io.cloud.ovh.net'
aws_region = 'gra'

# Authenticated s3fs filesystem, plus the credentials it was built from
s3_fs, credentials = s3_auth()

# Options forwarded as `storage_options` when polars reads from S3
s3_storage_options = dict(
    aws_access_key_id=credentials.access_key,
    aws_secret_access_key=credentials.secret_key,
    endpoint_url=aws_endpoint_url,
    aws_region=aws_region,
)

## Test different join methods

In [4]:
# df = pl.DataFrame(
#     {
#         "foo": [1, 2, 3],
#         "bar": [6.0, 7.0, 8.0],
#         "ham": ["a", "b", "c"],
#     }
# )
# other_df = pl.DataFrame(
#     {
#         "apple": ["x", "y", "z"],
#         "ham": ["a", "b", "d"],
#     }
# )

# display(df)
# display(other_df)

# for strategy in ["inner", "left", "right", "full"]:
#     print(f"Strategy: {strategy}")
#     joined_df = df.join(other_df, on="ham", how=strategy)
#     display(joined_df)


## Read data

In [None]:
# --- Original parquet files, stored on the local disk ---
orig_parquet_folder = Path("/ssd/datasets/pd12m/metadata")
if not orig_parquet_folder.exists():
    raise FileNotFoundError(f"Folder {orig_parquet_folder} does not exist")

# Path.glob yields entries in arbitrary order; sort so that the row order of
# the concatenated dataframe is reproducible between runs.
orig_parquet_files = sorted(str(file) for file in orig_parquet_folder.glob("*.parquet"))
if not orig_parquet_files:
    raise FileNotFoundError(f"No parquet files found in {orig_parquet_folder}")
print(f"Number of original parquet files: {len(orig_parquet_files)}")

# --- Parquet files of the downloaded dataset, stored on S3 ---
# Derived from `ds_path` (configuration cell) instead of repeating the literal.
downloaded_parquet_folder = f"s3://{ds_path}"

# The glob results are prefixed with "s3://" so polars can read them as S3 URLs.
downloaded_parquet_files = sorted(
    f"s3://{file}" for file in s3_fs.glob(downloaded_parquet_folder + "/*.parquet")
)
if not downloaded_parquet_files:
    raise FileNotFoundError(f"No parquet files found in {downloaded_parquet_folder}")
print(f"Number of downloaded parquet files: {len(downloaded_parquet_files)}")

In [6]:
def join_parquet_files_in_dataset(parquet_files: list[str], s3_storage_options: dict) -> "pl.DataFrame":
    """Read several parquet files and stack them into one dataframe.

    Files are read one after another, in the order given. Paths that start
    with ``s3://`` are read with ``s3_storage_options``; all other paths are
    read as local files.

    Args:
        parquet_files: Paths (local or ``s3://`` URLs) of the parquet files.
        s3_storage_options: Storage options passed to polars for S3 paths.

    Returns:
        The vertical concatenation of all files, as a polars DataFrame.

    Raises:
        ValueError: If ``parquet_files`` is empty.
    """
    if not parquet_files:
        # Fail early with a clear message rather than inside pl.concat.
        raise ValueError("parquet_files is empty: nothing to read")

    dfs = []
    for parquet_file in parquet_files:
        path = str(parquet_file)
        # Prefix check: only real S3 URLs get the S3 storage options.
        if path.startswith("s3://"):
            df = pl.read_parquet(path, storage_options=s3_storage_options)
        else:
            df = pl.read_parquet(path)
        dfs.append(df)

    # Stack the frames on top of each other
    return pl.concat(dfs, how="vertical")

In [None]:
# Load every original parquet file into a single dataframe
orig_df = join_parquet_files_in_dataset(orig_parquet_files, s3_storage_options)
print(f"Number of rows in original dataframe: {orig_df.height}")

# Load every downloaded parquet file into a single dataframe
downloaded_df = join_parquet_files_in_dataset(downloaded_parquet_files, s3_storage_options)
print(f"Number of rows in downloaded dataframe: {downloaded_df.height}")

# Left join on `url`: every downloaded row is kept and, where the url matches,
# receives the columns of the original dataframe.
joined_df = downloaded_df.join(other=orig_df, how="left", on="url")
print(f"Number of rows in joined dataframe: {joined_df.height}")

### Format the joined dataframe

In [None]:
# Inspect the feather file written by a previous run of this notebook, if any.
# The file is only produced by the final save cell, so on a first run it does
# not exist yet; that is reported instead of raised so that a fresh
# "Restart & Run All" can reach the cells that create it.
feather_path = Path("output/global_pd12m_data.feather")
if feather_path.exists():
    # Local file: no S3 storage options are needed to read it.
    df = pl.read_ipc(feather_path)
    print(f"Number of rows in previously saved dataframe: {df.height}")
else:
    df = None
    print(f"File {feather_path} does not exist yet")

In [None]:

# Remove the columns that are not needed. Only columns actually present are
# listed, so the cell can be re-run on an already formatted dataframe.
columns_to_drop = list(filter(
    lambda col: col in joined_df.columns,
    [
        "error_message", "status", "width", "height",
        "original_width", "original_height", "sha256",
        "hash", "caption_right", "id",
    ],
))
joined_df = joined_df.drop(columns_to_drop)

# Give the remaining columns their final names (again, only those present)
columns_to_rename = {
    "width_right": "image_width",
    "height_right": "image_height",
    "key": "image_id",
}
existing_keys = [k for k in columns_to_rename if k in joined_df.columns]
joined_df = joined_df.rename({k: columns_to_rename[k] for k in existing_keys})

# Derived columns:
#   parquet_id   - first five characters of image_id
#   image_path   - "<parquet_id>/<image_id>.jpg"
#   aspect_ratio - image_width divided by image_height
# parquet_id is added in its own step because image_path is built from it.
joined_df = (
    joined_df
    .with_columns(
        pl.col("image_id").str.slice(0, 5).cast(pl.Utf8).alias("parquet_id"),
    )
    .with_columns(
        (pl.col("parquet_id") + "/" + pl.col("image_id") + ".jpg").alias("image_path"),
        (pl.col("image_width") / pl.col("image_height")).cast(pl.Float64).alias("aspect_ratio"),
    )
)

display(joined_df.head())


### Save the joined dataframe

In [None]:
# Write the formatted dataframe to `dest_path` with the project helper.
dest_path = "output/global_pd12m_data.feather"
try:
    safe_write_ipc(joined_df, dest_path, s3_fs=s3_fs)
except Exception as e:
    # Report the failure, then re-raise: a failed save must mark the cell as
    # failed instead of letting the notebook finish as if the file was written.
    print(f"Error writing joined dataframe: {e}")
    raise
else:
    print(f"Joined dataframe saved to {dest_path}")