In [1]:
import h5py
import numpy as np
import pandas as pd
from PIL import Image, ImageDraw
from datasets import load_dataset
from time import time
from tqdm import tqdm
from collections import Counter

### Square damage generator:


In [2]:
def generate_square_damage(image: Image.Image, min_square_side=14, max_square_side=36) -> Image.Image:
    """Return a copy of ``image`` with one randomly placed white square drawn on it.

    The input image is left untouched; all drawing happens on a copy.

    Args:
        image: Source picture. The damage is filled with the RGB triple
            ``(255, 255, 255)``.
        min_square_side: Smallest side length that may be drawn (inclusive).
        max_square_side: Largest side length that may be drawn (inclusive).
            It is lowered automatically when the image is too small to hold it.

    Returns:
        A new image carrying the square damage.

    Raises:
        ValueError: If not even a ``min_square_side`` square fits in the image.
    """
    width, height = image.size

    # The top-left corner is sampled from [0, dimension - side), which is only
    # a non-empty range when the side is strictly smaller than both dimensions.
    # Capping the side up front avoids a random "low >= high" failure on images
    # that are smaller than max_square_side.
    largest_fitting_side = min(width, height) - 1
    upper_side = min(max_square_side, largest_fitting_side)
    if upper_side < min_square_side:
        raise ValueError(
            f'image of size {width}x{height} cannot hold a square with a side in '
            f'[{min_square_side}, {max_square_side}]'
        )

    image = image.copy()
    draw = ImageDraw.Draw(image)

    # randint's upper bound is exclusive, hence the +1 to make upper_side reachable.
    square_side = np.random.randint(min_square_side, upper_side + 1)

    x = np.random.randint(0, width - square_side)
    y = np.random.randint(0, height - square_side)
    square_coords = [(x, y), (x + square_side, y + square_side)]

    draw.rectangle(square_coords, fill=(255, 255, 255))

    return image

### Irregular damage generator:

In [3]:
def generate_irregular_damage(image: Image.Image, max_radius=16, min_radius=8, points=10) -> Image.Image:
    """Return a copy of ``image`` with one random white polygon drawn on it.

    The polygon is built from ``points`` vertices placed at evenly spaced
    angles around a random centre, each at its own random distance from that
    centre. The input image is left untouched; drawing happens on a copy.

    Args:
        image: Source picture. The damage is filled with the RGB triple
            ``(255, 255, 255)``.
        max_radius: Largest vertex distance from the centre (inclusive).
        min_radius: Smallest vertex distance from the centre (inclusive).
        points: Number of polygon vertices.

    Returns:
        A new image carrying the irregular damage.

    Raises:
        ValueError: If ``min_radius`` exceeds ``max_radius`` or the image is too
            small to keep a polygon of ``max_radius`` inside its borders.
    """
    width, height = image.size

    if min_radius > max_radius:
        raise ValueError(
            f'min_radius ({min_radius}) must not exceed max_radius ({max_radius})'
        )
    # The centre is sampled from [max_radius, dimension - max_radius); that
    # range is empty unless both dimensions exceed twice the maximum radius.
    if min(width, height) <= 2 * max_radius:
        raise ValueError(
            f'image of size {width}x{height} is too small for max_radius={max_radius}'
        )

    image = image.copy()
    draw = ImageDraw.Draw(image)

    x_center = np.random.randint(max_radius, width - max_radius)
    y_center = np.random.randint(max_radius, height - max_radius)

    angles = np.linspace(0, 2 * np.pi, points, endpoint=False)
    # randint's upper bound is exclusive; +1 makes max_radius reachable (the
    # same convention generate_square_damage uses for its maximum side) and
    # lets min_radius == max_radius work instead of raising.
    radii = np.random.randint(min_radius, max_radius + 1, size=points)
    vertices = [
        (
            int(x_center + radius * np.cos(angle)),
            int(y_center + radius * np.sin(angle))
        )
        for angle, radius in zip(angles, radii)
    ]

    # Repeat the first vertex so the outline is explicitly closed.
    vertices.append(vertices[0])

    draw.polygon(vertices, fill=(255, 255, 255))

    return image

### Loading WikiArt from HuggingFace

In [4]:
# Load the 'train' split of the WikiArt dataset published on the HuggingFace Hub.
ds = load_dataset(path='Artificio/WikiArt', split='train')
print('ds loaded')

README.md:   0%|          | 0.00/663 [00:00<?, ?B/s]

dataset_infos.json:   0%|          | 0.00/1.30k [00:00<?, ?B/s]

(…)-00000-of-00004-3c65976b59bc0ab4.parquet:   0%|          | 0.00/426M [00:00<?, ?B/s]

(…)-00001-of-00004-441bd829579dead0.parquet:   0%|          | 0.00/428M [00:00<?, ?B/s]

(…)-00002-of-00004-7b0bbb36fb350222.parquet:   0%|          | 0.00/429M [00:00<?, ?B/s]

(…)-00003-of-00004-971fec8ddd44fece.parquet:   0%|          | 0.00/429M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/103250 [00:00<?, ? examples/s]

ds loaded


### Samples by style distribution

In [5]:
# Tally how many samples carry each style label.
style_counts = Counter()
style_counts.update(ds['style'])

# Sum of the per-style tallies, reported before the per-style breakdown.
total_count = sum(style_counts.values())
print(f'Style count: {total_count}\n')

# One left-aligned row per style: label padded to 30 characters, then its tally.
for key, value in style_counts.items():
    row = f'{key:<30} {value}'
    print(row)

Style count: 103250

New Realism                    329
Surrealism                     4167
Expressionism                  7013
Art Nouveau (Modern)           4899
Symbolism                      3476
Realism                        10523
Early Renaissance              1351
Divisionism                    338
Romanticism                    9285
Post-Impressionism             5778
Impressionism                  10643
Baroque                        4400
Naïve Art (Primitivism)        2295
Fauvism                        731
Pointillism                    501
Pop Art                        791
Abstract Expressionism         2074
Neo-Dada                       131
Art Informel                   1267
Abstract Art                   979
Luminism                       385
Neoclassicism                  2038
Cubism                         1747
Mannerism (Late Renaissance)   1342
Op Art                         528
Neo-Rococo                     97
Neo-Expressionism              420
Proto Renaissance

### Filtering by style threshold

In [6]:
# A style is kept only when it holds more than 0.6% of all samples.
threshold = 0.006 * total_count

# Collect the styles whose sample count is strictly above the threshold.
filtered_style_counts = {}
for key, value in style_counts.items():
    if value > threshold:
        filtered_style_counts[key] = value

# Set of surviving style labels, used for fast membership tests later on.
style_filter = set(filtered_style_counts)

kept_samples = sum(filtered_style_counts.values())
print(f'Filtered count: {kept_samples}\n')

for key, value in filtered_style_counts.items():
    print(f'{key:<30} {value}')

Filtered count: 87835

Surrealism                     4167
Expressionism                  7013
Art Nouveau (Modern)           4899
Symbolism                      3476
Realism                        10523
Early Renaissance              1351
Romanticism                    9285
Post-Impressionism             5778
Impressionism                  10643
Baroque                        4400
Naïve Art (Primitivism)        2295
Fauvism                        731
Pop Art                        791
Abstract Expressionism         2074
Art Informel                   1267
Abstract Art                   979
Neoclassicism                  2038
Cubism                         1747
Mannerism (Late Renaissance)   1342
None                           986
Northern Renaissance           2379
High Renaissance               1314
Academicism                    972
Ukiyo-e                        1426
Rococo                         2733
Magic Realism                  1002
Lyrical Abstraction            670
Art Deco 

### Applying filter to WikiArt dataset

In [7]:
def _has_kept_style(example):
    """Tell whether the example's style is one of the styles in ``style_filter``."""
    return example['style'] in style_filter


# Drop every sample whose style did not pass the frequency threshold.
filtered_ds = ds.filter(_has_kept_style)

print(f'Filtered dataset size: {len(filtered_ds)}')

Filter:   0%|          | 0/103250 [00:00<?, ? examples/s]

Filtered dataset size: 87835


### Splitting into train, valid and test

In [8]:
# Target proportions of the three subsets.
train_ratio = 0.8
valid_ratio = 0.1
test_ratio = 0.1

# Step 1: split off the combined validation + test share; the remainder is the
# training set.
holdout_ratio = valid_ratio + test_ratio
split_data = filtered_ds.train_test_split(test_size=holdout_ratio, shuffle=True, seed=42)
train_data, temp_data = split_data['train'], split_data['test']

# Step 2: divide the held-out part between validation and test, using the test
# share expressed relative to the held-out part.
split_data = temp_data.train_test_split(test_size=test_ratio / holdout_ratio, shuffle=True, seed=42)
valid_data, test_data = split_data['train'], split_data['test']

In [9]:
# Label columns exported to the per-split annotation CSV files.
_ANNOTATION_COLUMNS = ['artist', 'genre', 'style']

# Select the label columns *before* converting to pandas so that only those
# three columns are converted; calling to_pandas() on the full split would
# convert every column, including 'image', just to discard it afterwards.
annotations = {
    'train.csv': train_data.select_columns(_ANNOTATION_COLUMNS).to_pandas(),
    'valid.csv': valid_data.select_columns(_ANNOTATION_COLUMNS).to_pandas(),
    'test.csv': test_data.select_columns(_ANNOTATION_COLUMNS).to_pandas(),
}

# Output HDF5 file name -> split whose images are written to that file.
datasets = {
    'train.h5': train_data,
    'valid.h5': valid_data,
    'test.h5': test_data,
}

### Creating HDF5 datasets and annotation files

In [10]:
def save_to_h5(filename, dataset, image_shape=(3, 224, 224), batch_size=6000):
    """Resize every image of ``dataset`` and store them all in one HDF5 file.

    The file is opened in ``'w'`` mode and receives a single ``uint8`` dataset
    named ``'image'`` with shape ``(len(dataset), *image_shape)``; each image
    is stored channels-first.

    Args:
        filename: Path of the HDF5 file to write.
        dataset: Collection supporting ``len()`` and slicing, where a slice
            exposes an ``'image'`` entry holding a list of PIL images.
        image_shape: Target ``(channels, height, width)`` of each stored image.
            Images are converted to RGB, so ``channels`` must be 3.
        batch_size: Number of images processed and written per iteration.

    Raises:
        ValueError: If ``image_shape`` does not describe 3 channels or
            ``batch_size`` is not positive.
    """
    channels, height, width = image_shape
    if channels != 3:
        raise ValueError(f'images are stored as RGB, so image_shape[0] must be 3, got {channels}')
    if batch_size <= 0:
        raise ValueError(f'batch_size must be positive, got {batch_size}')

    # PIL expresses sizes as (width, height), the reverse of the array layout.
    target_size = (width, height)

    num_images = len(dataset)
    with h5py.File(filename, 'w') as h5f:
        h5_dataset = h5f.create_dataset('image', shape=(num_images, *image_shape), dtype='uint8')
        for i in tqdm(range(0, num_images, batch_size), desc=f'Processing {filename}'):
            batch_images = dataset[i:i + batch_size]['image']
            # Integer division keeps the batch number an int ("batch 1", not "batch 1.0").
            print(f'{filename} batch {i // batch_size + 1}') # for kaggle log
            # Convert to RGB *before* resizing: Pillow always uses nearest-neighbour
            # resampling for palette ('P') and bilevel ('1') images, so resizing
            # them first would degrade quality.
            batch_data = np.stack([
                np.array(img.convert('RGB').resize(target_size), dtype=np.uint8).transpose(2, 0, 1)
                for img in batch_images
            ])
            h5_dataset[i:i + len(batch_data)] = batch_data

In [11]:
# Write each split to its own HDF5 file inside the Kaggle working directory.
for filename in datasets:
    data = datasets[filename]
    save_to_h5(f'/kaggle/working/{filename}', data)

print('train.h5, valid.h5, test.h5 has been saved.')

Processing /kaggle/working/train.h5:   0%|          | 0/12 [00:00<?, ?it/s]

/kaggle/working/train.h5 batch 1.0


Processing /kaggle/working/train.h5:   8%|▊         | 1/12 [00:18<03:20, 18.19s/it]

/kaggle/working/train.h5 batch 2.0


Processing /kaggle/working/train.h5:  17%|█▋        | 2/12 [00:35<02:55, 17.55s/it]

/kaggle/working/train.h5 batch 3.0


Processing /kaggle/working/train.h5:  25%|██▌       | 3/12 [00:51<02:32, 16.98s/it]

/kaggle/working/train.h5 batch 4.0


Processing /kaggle/working/train.h5:  33%|███▎      | 4/12 [01:08<02:14, 16.76s/it]

/kaggle/working/train.h5 batch 5.0


Processing /kaggle/working/train.h5:  42%|████▏     | 5/12 [01:24<01:56, 16.57s/it]

/kaggle/working/train.h5 batch 6.0


Processing /kaggle/working/train.h5:  50%|█████     | 6/12 [01:40<01:39, 16.57s/it]

/kaggle/working/train.h5 batch 7.0


Processing /kaggle/working/train.h5:  58%|█████▊    | 7/12 [01:57<01:22, 16.51s/it]

/kaggle/working/train.h5 batch 8.0


Processing /kaggle/working/train.h5:  67%|██████▋   | 8/12 [02:13<01:05, 16.45s/it]

/kaggle/working/train.h5 batch 9.0


Processing /kaggle/working/train.h5:  75%|███████▌  | 9/12 [02:30<00:49, 16.63s/it]

/kaggle/working/train.h5 batch 10.0


Processing /kaggle/working/train.h5:  83%|████████▎ | 10/12 [02:47<00:33, 16.59s/it]

/kaggle/working/train.h5 batch 11.0


Processing /kaggle/working/train.h5:  92%|█████████▏| 11/12 [03:03<00:16, 16.61s/it]

/kaggle/working/train.h5 batch 12.0


Processing /kaggle/working/train.h5: 100%|██████████| 12/12 [03:15<00:00, 16.28s/it]
Processing /kaggle/working/valid.h5:   0%|          | 0/2 [00:00<?, ?it/s]

/kaggle/working/valid.h5 batch 1.0


Processing /kaggle/working/valid.h5:  50%|█████     | 1/2 [00:15<00:15, 15.77s/it]

/kaggle/working/valid.h5 batch 2.0


Processing /kaggle/working/valid.h5: 100%|██████████| 2/2 [00:23<00:00, 11.57s/it]
Processing /kaggle/working/test.h5:   0%|          | 0/2 [00:00<?, ?it/s]

/kaggle/working/test.h5 batch 1.0


Processing /kaggle/working/test.h5:  50%|█████     | 1/2 [00:15<00:15, 15.92s/it]

/kaggle/working/test.h5 batch 2.0


Processing /kaggle/working/test.h5: 100%|██████████| 2/2 [00:23<00:00, 11.65s/it]

train.h5, valid.h5, test.h5 has been saved.





In [12]:
# Export the label table of each split as a CSV file, without the index column.
for filename in annotations:
    data = annotations[filename]
    data.to_csv(f'/kaggle/working/{filename}', index=False)

print('train.csv, valid.csv, test.csv has been saved.')

train.csv, valid.csv, test.csv has been saved.
