# Preprocess the laion-art dataset

In [None]:
from datasets import load_dataset

# Identifier of the dataset to fetch.
dataset_name = "laion/laion-art"

# The two trailing positional arguments are passed explicitly as None.
dataset = load_dataset(dataset_name, None, None)

In [None]:
def _has_kept_language(example):
    """Return True when the example's LANGUAGE field is 'en' or 'nolang'."""
    return example['LANGUAGE'] in ['en', 'nolang']


# Keep only the rows of the train split whose LANGUAGE passes the check above.
en_dataset = dataset['train'].filter(_has_kept_language)

In [None]:
import numpy as np
def print_arr(arr):
    """Print the mean, minimum and maximum of ``arr`` on a single line."""
    summary = (np.mean(arr), np.min(arr), np.max(arr))
    print(*summary)

In [None]:
import random
# Draw 10000 distinct row indices at random from the filtered dataset.
ids = random.sample(range(len(en_dataset)), 10000)

# Take those rows, then hold out 10% of them as the test split.
small_en_dataset = en_dataset.select(ids).train_test_split(test_size=0.1)

# Bare expression: shown as the cell output in a notebook.
small_en_dataset

In [None]:
# Mean / min / max of the 'aesthetic' column for the sampled train split ...
sampled_train_scores = small_en_dataset['train']['aesthetic']
print_arr(sampled_train_scores)

# ... and for the whole filtered dataset, for comparison.
all_scores = en_dataset['aesthetic']
print_arr(all_scores)

In [None]:
# Download the images from their URLs, optionally in parallel with multiprocessing
import requests
from PIL import Image
import io
import cv2
import os
from tqdm import tqdm
import csv
import multiprocessing
from multiprocessing import Pool
import time

def download_image(url, im_path, timeout=10):
    """Download the image at ``url`` and save it to ``im_path``.

    Failures are reported with ``print`` rather than raised, so that a single
    bad URL does not abort a bulk download.

    Args:
        url: Address of the image to fetch.
        im_path: Destination path of the saved image.
        timeout: Seconds to wait on the server before giving up. Without a
            timeout a single unresponsive host can block a worker forever.

    Returns:
        True if the image was saved, False otherwise.
    """
    try:
        response = requests.get(url, timeout=timeout)
        # Fail on HTTP error statuses (404, 500, ...) with a clear message
        # instead of trying to decode an error page as an image.
        response.raise_for_status()
        im = Image.open(io.BytesIO(response.content))
    except Exception as e:
        print(f'Failed to save {im_path} due to {e}')
        return False

    try:
        im.save(im_path)
    except Exception as e:
        print(f'Failed to save {im_path} due to {e}')
        # A failed save may leave a truncated file behind; remove it so a
        # later "file already exists" check does not mistake it for a
        # successful download.
        try:
            os.remove(im_path)
        except OSError:
            pass
        return False
    return True

def download_images(split, sub_dataset):
    """Download every image of ``sub_dataset`` sequentially and write its metadata.

    Images are stored as ``<im_dir>/<split>/<index>.png`` and described in
    ``<im_dir>/<split>/metadata.csv`` with the columns ``file_name``, ``text``
    and ``aesthetic``. ``im_dir`` is read from module scope.

    The metadata file is rewritten on every call. A row is written for each
    image that is present on disk afterwards, whether it was downloaded now or
    on an earlier run, and never for an image whose download failed.

    Args:
        split: Name of the sub-directory to write into (e.g. 'train').
        sub_dataset: Iterable of examples providing 'URL', 'TEXT' and
            'aesthetic' fields.
    """
    sub_im_dir = os.path.join(im_dir, split)
    os.makedirs(sub_im_dir, exist_ok=True)
    cvs_filename = os.path.join(sub_im_dir, "metadata.csv")
    # newline='' is what the csv module requires for files it writes;
    # an explicit encoding keeps non-ASCII captions platform-independent.
    with open(cvs_filename, 'w', newline='', encoding='utf-8') as csvfile:
        csvwriter = csv.writer(csvfile)
        # Header row, matching download_images_mp.
        csvwriter.writerow(['file_name', 'text', 'aesthetic'])
        # Wrapping the dataset (not the enumerate) lets tqdm see its length.
        for im_id, example in enumerate(tqdm(sub_dataset)):
            file_name = f'{im_id}.png'
            im_path = os.path.join(sub_im_dir, file_name)
            try:
                url, text, aesthetic = example['URL'], example['TEXT'], example['aesthetic']
            except (KeyError, TypeError) as e:
                print(f'Failed to save {im_path} due to {e}')
                continue
            # Skip the network round-trip for images fetched on an earlier run.
            if not os.path.isfile(im_path):
                download_image(url, im_path)
            # Only list images that really exist, so the metadata never
            # points at a file whose download failed.
            if os.path.isfile(im_path):
                csvwriter.writerow([file_name, text, aesthetic])

def download_images_mp(split, sub_dataset, num_workers=8):
    """Download every image of ``sub_dataset`` with a process pool and write its metadata.

    Images are stored as ``<im_dir>/<split>/<index>.png`` and described in
    ``<im_dir>/<split>/metadata.csv`` with the columns ``file_name``, ``text``
    and ``aesthetic``. ``im_dir`` is read from module scope.

    The metadata file is written only after all downloads have finished, and
    contains a row for each image that is present on disk at that point,
    whether it was downloaded now or on an earlier run. Images whose download
    failed are left out.

    Args:
        split: Name of the sub-directory to write into (e.g. 'train').
        sub_dataset: Iterable of examples providing 'URL', 'TEXT' and
            'aesthetic' fields.
        num_workers: Number of worker processes in the pool.
    """
    sub_im_dir = os.path.join(im_dir, split)
    os.makedirs(sub_im_dir, exist_ok=True)
    cvs_filename = os.path.join(sub_im_dir, "metadata.csv")

    # Candidate metadata rows, collected while the downloads are scheduled.
    rows = []
    with multiprocessing.Pool(num_workers) as p:
        # Wrapping the dataset (not the enumerate) lets tqdm see its length.
        for im_id, example in enumerate(tqdm(sub_dataset)):
            file_name = f'{im_id}.png'
            im_path = os.path.join(sub_im_dir, file_name)
            try:
                url, text, aesthetic = example['URL'], example['TEXT'], example['aesthetic']
            except (KeyError, TypeError) as e:
                print(f'Failed to save {im_path} due to {e}')
                continue
            rows.append((file_name, text, aesthetic))
            # Skip the network round-trip for images fetched on an earlier run.
            if not os.path.isfile(im_path):
                p.apply_async(download_image, args=(url, im_path))
        # Leaving the ``with`` block terminates the pool, so wait for the
        # queued downloads explicitly before that happens.
        p.close()
        p.join()

    # newline='' is what the csv module requires for files it writes;
    # an explicit encoding keeps non-ASCII captions platform-independent.
    with open(cvs_filename, 'w', newline='', encoding='utf-8') as csvfile:
        csvwriter = csv.writer(csvfile)
        csvwriter.writerow(['file_name', 'text', 'aesthetic'])
        for file_name, text, aesthetic in rows:
            # Only list images that really exist, so the metadata never
            # points at a file whose download failed.
            if os.path.isfile(os.path.join(sub_im_dir, file_name)):
                csvwriter.writerow([file_name, text, aesthetic])

# Root directory for the downloaded images; the download helpers read this
# module-level name.
im_dir = "data/laion-art"
os.makedirs(im_dir, exist_ok=True)

# Only the 'test' split is fetched here; use ('train', 'test') to fetch the
# 'train' split as well.
for split_name in ('test',):
    download_images_mp(split_name, small_en_dataset[split_name])

In [None]:
import os
im_dir = "data/laion-art"
sub_im_dir = os.path.join(im_dir, "train")
cvs_filename = os.path.join(sub_im_dir, "metadata.csv")

import pandas as pd

# Read the metadata file.
df = pd.read_csv(cvs_filename)

# Remove the image-directory prefix ("data/laion-art/train/", 21 characters)
# from file_name so that entries are relative to the metadata file.
# The prefix is stripped only when it is actually present: an unconditional
# s[21:] slice would mangle names that are already relative (e.g. "12.png")
# and would make re-running this cell destructive.
dir_prefix = os.path.join(sub_im_dir, "")


def _strip_dir_prefix(name):
    """Return ``name`` without the leading ``dir_prefix``, if it has one."""
    if isinstance(name, str) and name.startswith(dir_prefix):
        return name[len(dir_prefix):]
    return name


df['file_name'] = df['file_name'].apply(_strip_dir_prefix)

# Write the result back to the same file.
df.to_csv(cvs_filename, index=False)

In [None]:
import os
from glob import glob
import csv

im_dir = "data/laion-art/train"

# Collect the image paths that metadata.csv lists.
csv_filename = os.path.join(im_dir, "metadata.csv")
available_im_paths = set()
# newline='' is what the csv module expects for files it parses.
with open(csv_filename, 'r', newline='', encoding='utf-8') as csvfile:
    csvreader = csv.reader(csvfile)
    for row in csvreader:
        # Skip blank lines and the header row; only the first column
        # (the file name) is needed, so rows with a different number of
        # columns no longer abort the loop.
        if not row or row[0] == 'file_name':
            continue
        available_im_paths.add(os.path.join(im_dir, row[0]))
print(list(available_im_paths)[:10])
print(len(available_im_paths))

im_paths = glob(os.path.join(im_dir, "*.png"))
print(im_paths[:10])

if not available_im_paths:
    # An empty metadata file would otherwise delete every image on disk;
    # that is far more likely an interrupted run than the intended state.
    print(f'No entries found in {csv_filename}; not removing any images')
else:
    # Delete the images that have no metadata row.
    for im_path in im_paths:
        if im_path not in available_im_paths:
            print(im_path)
            os.remove(im_path)

# Upload the dataset to the Hugging Face Hub

In [None]:
import os

from datasets import load_dataset
from huggingface_hub import create_repo

# SECURITY: an access token used to be hard-coded in this cell. Treat that
# token as leaked and revoke it in the Hugging Face account settings.
# The token is now read from the HF_TOKEN environment variable; when it is
# unset, None is passed and the libraries use the locally stored login.
hf_token = os.environ.get("HF_TOKEN")

# Load the downloaded images (with their metadata.csv) from disk.
dataset = load_dataset("data/laion-art")

# Create the target repository if it does not exist yet.
repo_id = create_repo(
    repo_id="fantasyfish/laion-art", exist_ok=True, token=hf_token
).repo_id

# Push to the repository that was just resolved, using the same credentials.
dataset.push_to_hub(repo_id, token=hf_token)