In [None]:
import sys
sys.path.append('..')

import os
import time
import copy
import pickle
import psutil
import numpy as np
import torch
import dnnlib
from torch_utils import distributed as dist
from torch_utils import training_stats
from torch_utils import persistence
from torch_utils import misc
import wandb
import ambient_utils
import json
from collections import defaultdict
import zipfile

In [None]:
from calculate_metrics_quality import load_stats

In [None]:
import torchvision

def view_mosaic(images, nrow=5):
  """Tile a batch of images into one grid and render it inline in the notebook.

  Args:
    images: Sized iterable of image tensors; each element is forwarded to
      `torchvision.utils.make_grid` (assumed to be (C, H, W) tensors — TODO confirm).
    nrow: Requested value for make_grid's `nrow`; it is capped at `len(images)`
      so a small batch does not leave the grid padded with empty slots.
  """
  grid_nrow = min(nrow, len(images))
  grid = torchvision.utils.make_grid(list(images), nrow=grid_nrow)
  to_pil = torchvision.transforms.ToPILImage()
  # `display` is the IPython builtin available inside notebook kernels.
  display(to_pil(grid))

In [None]:
# Keyword arguments unpacked into the dataset constructor in the
# "Verify annotations" section of this notebook.
dataset_kwargs = dnnlib.EasyDict(
  path='../datasets/img512.zip', 
  use_labels=True, 
  corruption_probability=0.0,
  normalize=False,
  use_other_keys=False
)

# File holding the per-image quality scores; it is read with `load_stats`.
annotations_qualities_path = "../annotations/clip_iqa_patch_average.pkl"
# Fraction of the dataset treated as bad data: images whose rank is below
# int(N * (1 - bad_data_percentage)) receive sigma_min = 0.0.
bad_data_percentage = 0.8
# sigma_min assigned to every image that falls outside that rank threshold.
bad_data_sigma_min = 0.2
# When True, sigma_max is additionally derived from the patch-level score
# entries; when False, sigma_max stays at zero for every image.
use_ambient_crops = False

## Verify qualities

In [None]:
# Load the quality-score table used throughout the rest of the notebook.
# NOTE(review): the path ends in .pkl — if load_stats unpickles it, only load trusted files.
annotations_qualities = load_stats(annotations_qualities_path)

In [None]:
# First column of the 'CLIP-IQA' entry; the annotation cell below uses this
# same slice as the per-image global quality score.
annotations_qualities['CLIP-IQA'][:, 0]

In [None]:
# Inspect the 'CLIP-IQA-512' entry.
# NOTE(review): presumably scores at a 512-pixel receptive field, going by the
# 'CLIP-IQA-{pixel_receptive_field}' key pattern used later — confirm.
annotations_qualities['CLIP-IQA-512']

In [None]:
# Inspect the 'CLIP-IQA-32' entry.
# NOTE(review): presumably scores at a 32-pixel receptive field, going by the
# 'CLIP-IQA-{pixel_receptive_field}' key pattern used later — confirm.
annotations_qualities['CLIP-IQA-32']

In [16]:
# List every entry of the score table together with its tensor shape.
for entry_name, entry_values in annotations_qualities.items():
  print(entry_name, entry_values.shape)

hash torch.Size([1281167])
CLIP-IQA torch.Size([1281167, 1])
CLIP-IQA-512 torch.Size([1281167])
CLIP-IQA-256 torch.Size([1281167])
CLIP-IQA-128 torch.Size([1281167])
CLIP-IQA-64 torch.Size([1281167])
CLIP-IQA-32 torch.Size([1281167])


## Verify annotations

In [None]:
# Set up the dataset only; no encoder or network is created in this cell.
print('Loading dataset...')
# The constructor receives the options collected in `dataset_kwargs`.
dataset_obj = ambient_utils.dataset.SyntheticallyCorruptedImageFolderDataset(**dataset_kwargs)

In [None]:
## Annotations
# Build one (sigma_min, sigma_max) pair per image file name and attach the
# resulting mapping to the dataset object.
if annotations_qualities_path is not None:
    annotations_qualities = load_stats(annotations_qualities_path)

    ### Sigma min
    global_qualities = annotations_qualities['CLIP-IQA'][:, 0]
    assert len(global_qualities) == len(dataset_obj), f'Qualities ({len(global_qualities)}) and dataset_obj ({len(dataset_obj)}) must have equal lengths'

    # Number of top-ranked images that are treated as good data. It does not
    # depend on which score is being ranked, so it is computed once.
    rank_threshold = int(len(global_qualities) * (1 - bad_data_percentage))

    sorted_indices = torch.argsort(global_qualities, descending=True)
    # rank[i] must be the position of image i in the descending ordering
    # (0 = best), i.e. the inverse permutation of `sorted_indices`.
    # Indexing `torch.arange(n)` with `sorted_indices` would simply return
    # `sorted_indices` again, which flags the right *number* of images as good
    # but not the highest-quality ones.
    rank = torch.empty_like(sorted_indices)
    rank[sorted_indices] = torch.arange(len(global_qualities), device=sorted_indices.device)
    annotations_sigma_min = torch.where(rank < rank_threshold, 0.0, bad_data_sigma_min)

    ### Sigma max
    latents_receptive_field_to_sigma_max = {8:0.15, 16:0.25, 32:0.95}
    annotations_sigma_max = torch.zeros(len(global_qualities))
    if use_ambient_crops:
        # Iterated in insertion order (8, 16, 32): an image that ranks as good
        # at a larger receptive field overrides the value set by a smaller one.
        for latents_receptive_field, good_data_sigma_max in latents_receptive_field_to_sigma_max.items():
            pixel_receptive_field = 8 * latents_receptive_field
            patch_qualities = annotations_qualities[f'CLIP-IQA-{pixel_receptive_field}']

            sorted_indices = torch.argsort(patch_qualities, descending=True)
            # Same inverse-permutation ranking as for the global scores.
            rank = torch.empty_like(sorted_indices)
            rank[sorted_indices] = torch.arange(len(patch_qualities), device=sorted_indices.device)

            annotations_sigma_max = torch.where(rank < rank_threshold, good_data_sigma_max, annotations_sigma_max)

    ### Annotations tuple
    # Values are kept as 0-d tensors, exactly as produced by tensor indexing.
    annotations = {dataset_obj._image_fnames[i]: (annotations_sigma_min[i], annotations_sigma_max[i]) for i in range(len(global_qualities))}
else:
    # No quality file: every image gets the neutral (0.0, 0.0) annotation.
    annotations = {dataset_obj._image_fnames[i]: (0.0, 0.0) for i in range(len(dataset_obj))}

## Set dataset annotations
dataset_obj.annotations = annotations

In [None]:
## Count number of samples with annotations zero
# Only the first element of each annotation pair (sigma_min) is inspected.
sigma_min_values = (sigma_pair[0] for sigma_pair in annotations.values())
count_zeros = sum(1 for sigma_min in sigma_min_values if sigma_min == 0.0)
total_samples = len(annotations)
if total_samples > 0:
    percentage = (count_zeros / total_samples) * 100
else:
    # Avoid dividing by zero when there are no annotations at all.
    percentage = 0
print(f"Number of samples with annotation 0.0: {count_zeros} out of {total_samples} ({percentage:.2f}%)")

## Visualize top and bottom quality images

In [None]:
# Number of images shown in each mosaic.
num_display = 16
# Per-image global score: first column of the 'CLIP-IQA' entry.
global_qualities = annotations_qualities['CLIP-IQA'][:, 0]
# Image indices ordered from the highest to the lowest score.
sorted_indices = global_qualities.argsort(descending=True)

In [None]:
# Display quality scores for top images
print("Quality scores for top images:")
# `sorted_indices` is in descending score order, so the head holds the best images.
top_quality_image_indices = sorted_indices[:num_display]
top_quality_scores = global_qualities[top_quality_image_indices]

print(" ".join([f"Image {i + 1}: {score:.4f}" for i, score in enumerate(top_quality_scores)]))

# Stack the per-image arrays once with NumPy and wrap the result, instead of
# calling torch.tensor on a Python list: that path is very slow for lists of
# ndarrays and raises for lists of multi-element tensors. The dataset is
# indexed with plain ints rather than 0-d tensors.
# NOTE(review): assumes each 'image' is an ndarray or a CPU tensor — confirm.
images = torch.from_numpy(np.stack([np.asarray(dataset_obj[i]['image']) for i in top_quality_image_indices.tolist()]))

view_mosaic(images, nrow=8)

In [None]:
# Display quality scores for bottom images
print("Quality scores for bottom images:")
# `sorted_indices` is in descending score order, so the tail holds the worst images.
bottom_quality_image_indices = sorted_indices[-num_display:]
bottom_quality_scores = global_qualities[bottom_quality_image_indices]

# Labels are 1-based positions in the full quality ordering.
print(" ".join([f"Image {len(dataset_obj) - num_display + i + 1}: {score:.4f}" for i, score in enumerate(bottom_quality_scores)]))

# Stack the per-image arrays once with NumPy and wrap the result, instead of
# calling torch.tensor on a Python list: that path is very slow for lists of
# ndarrays and raises for lists of multi-element tensors. The dataset is
# indexed with plain ints rather than 0-d tensors.
# NOTE(review): assumes each 'image' is an ndarray or a CPU tensor — confirm.
images = torch.from_numpy(np.stack([np.asarray(dataset_obj[i]['image']) for i in bottom_quality_image_indices.tolist()]))

view_mosaic(images, nrow=8)

## Quality distribution

In [None]:
import seaborn as sns
import numpy as np
import matplotlib.pyplot as plt

# Plot the quality scores, sorted from best to worst, against the fraction of
# the dataset they cover.

# Apply the seaborn style before the axes exist so that it takes effect on them.
sns.set_style("whitegrid")
fig, ax = plt.subplots(figsize=(10, 6))

# Prepare data
data = {
    'num_examples': np.arange(len(global_qualities)) / len(global_qualities),
    'quality': global_qualities[sorted_indices].numpy(),
}

# Every x value is unique, so the default per-x aggregation (mean + error
# band) cannot change the curve; estimator=None draws the observations
# directly and skips a very expensive group-by over the whole dataset.
sns.lineplot(
    x='num_examples',
    y='quality',
    data=data,
    estimator=None,
    linewidth=2.5,
    color='#1f77b4',
    ax=ax,
)

# Add title and improve labels
ax.set_title('Image Quality Distribution', fontsize=16, fontweight='bold')
ax.set_xlabel('Fraction of Dataset (sorted by quality)', fontsize=12)
ax.set_ylabel('CLIP-IQA Quality Score', fontsize=12)

# Add grid for better readability
ax.grid(True, linestyle='--', alpha=0.7)

# Improve tick labels
ax.tick_params(axis='both', labelsize=10)

# Tight layout for better spacing
fig.tight_layout()

# Show the plot
plt.show()