In [4]:
from PIL import Image
from pipeline import pipeline, segmentation_filter, user_palette_classification_filter
from utils import segmentation_labels, utils
from models import dataset, config
from sklearn.model_selection import train_test_split
import random
import torchvision.transforms as T
import matplotlib.pyplot as plt
from palette_classification import color_processing, palette
import glob
from metrics_and_losses import metrics
import torch

In [2]:
# fetch training dataset
n_classes = len(segmentation_labels.labels)
dataset_path = config.DATASET_PATH
img_paths, label_paths = dataset.get_paths(dataset_path, file_name=config.DATASET_INDEX_NAME)
X_train, _, Y_train, _ = train_test_split(img_paths, label_paths, test_size=0.25, random_state=99, shuffle=True)

# load reference palettes and set up labels/indexes
palettes_dir = 'palette_classification/palettes/'
example_images_dir = 'palette_classification/example_images/'
palette_filenames = glob.glob(palettes_dir + '*.csv')
reference_palettes = [palette.PaletteRGB().load(
    palette_filename.replace('\\', '/'), header=True) for palette_filename in palette_filenames]
relevant_labels = ['skin', 'hair', 'lips', 'eyes']
relevant_indexes = [ utils.from_key_to_index(segmentation_labels.labels, label) for label in relevant_labels ]
skin_idx, hair_idx, lips_idx, eyes_idx = (0, 1, 2, 3)

# define filters
sf = segmentation_filter.SegmentationFilter('local')

# compute metrics on training set
n_samples = len(X_train)
results = { metric: [] for metric in ['intensity', 'value', 'contrast'] }
for i in range(n_samples):
    print(f'Evaluating sample {i+1}/{n_samples}...')
    
    input = Image.open(X_train[i]).convert('RGB')
    img, masks = sf.execute(input)

    # consider only images for which all relevant categories are present
    if (masks[relevant_indexes[0], :, :].sum() == 0 or masks[relevant_indexes[1], :, :].sum() == 0 or
        masks[relevant_indexes[2], :, :].sum() == 0 or masks[relevant_indexes[3], :, :].sum() == 0):
        continue
    
    # compute metrics
    relevant_masks = masks[relevant_indexes, :, :]
    img_masked = color_processing.apply_masks(img, relevant_masks)
    dominants = color_processing.compute_user_embedding(
        img_masked, n_candidates=(3, 3, 3, 3), distance_fn=metrics.rmse)
    dominants_palette = palette.PaletteRGB('dominants', dominants)
    intensity = palette.compute_intensity(dominants[skin_idx])
    value = palette.compute_value(dominants[skin_idx], dominants[hair_idx], dominants[eyes_idx])
    contrast = palette.compute_contrast(dominants[hair_idx], dominants[eyes_idx])
    
    # add metrics values to results dictionary
    results['intensity'].append(intensity)
    results['value'].append(value)
    results['contrast'].append(contrast)

Evaluating sample 1/12417...
Evaluating sample 2/12417...
Evaluating sample 3/12417...
Evaluating sample 4/12417...
Evaluating sample 5/12417...
Evaluating sample 6/12417...
Evaluating sample 7/12417...
Evaluating sample 8/12417...
Evaluating sample 9/12417...
Evaluating sample 10/12417...
Evaluating sample 11/12417...
Evaluating sample 12/12417...
Evaluating sample 13/12417...
Evaluating sample 14/12417...
Evaluating sample 15/12417...
Evaluating sample 16/12417...
Evaluating sample 17/12417...
Evaluating sample 18/12417...
Evaluating sample 19/12417...
Evaluating sample 20/12417...
Evaluating sample 21/12417...
Evaluating sample 22/12417...
Evaluating sample 23/12417...
Evaluating sample 24/12417...
Evaluating sample 25/12417...
Evaluating sample 26/12417...
Evaluating sample 27/12417...
Evaluating sample 28/12417...
Evaluating sample 29/12417...
Evaluating sample 30/12417...
Evaluating sample 31/12417...
Evaluating sample 32/12417...
Evaluating sample 33/12417...
Evaluating sample 3

In [3]:
for metric in list(results.keys()):
    values = torch.tensor(results[metric])

    print()
    print(metric)
    print('|__min=%.3f' % values.min())
    print('|__max=%.3f' % values.max())
    print('|__avg=%.3f' % values.mean())
    print('|__median=%.3f' % values.median())


intensity
|__min=0.000
|__max=0.982
|__avg=0.411
|__median=0.422

value
|__min=0.069
|__max=0.854
|__avg=0.385
|__median=0.390

contrast
|__min=0.000
|__max=0.859
|__avg=0.247
|__median=0.200
