In [None]:
!pip install hf_transfer transformers -q
!export HF_HUB_ENABLE_HF_TRANSFER=True
!huggingface-cli download shadowlilac/aesthetic-shadow-v2 model.safetensors --local-dir /workspace/aesthetic_scorer

In [None]:
import os
from PIL import Image, ImageFile

# Let Pillow decode files whose data ends early instead of raising an error.
ImageFile.LOAD_TRUNCATED_IMAGES = True
# Turn off Pillow's pixel-count limit (decompression-bomb check) so very
# large images can be opened.
Image.MAX_IMAGE_PIXELS = None

import safetensors.torch
from transformers import pipeline, AutoConfig, AutoProcessor, ViTForImageClassification
import torch
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor, as_completed

# Root directory that is walked recursively for images to tag.
base_directory = "/workspace/train_data/animagine-xl-3.1"
# Local weights file passed to ShadowScore.
model_path = '/workspace/aesthetic_scorer/model.safetensors'
# Number of image paths taken from the full list per loop iteration.
batch_size = 64  
# Filename endings (compared against the lowercased name) treated as images.
extensions = ('png', 'jpg', 'jpeg', 'webp')
# Thread count used by load_images_parallel.
num_workers = 8 
# Tags whose presence in a caption file means the image is skipped.
aesthetic_tags = ['very aesthetic', 'aesthetic', 'displeasing', 'very displeasing']  

class ShadowScore:
    """Score images with the aesthetic-shadow-v2 classifier and map scores to tags.

    The model is loaded once when the object is created; ``score`` can then be
    called repeatedly with lists of images.
    """

    # (exclusive lower bound, tag) pairs, checked from highest to lowest.
    # Anything at or below the last bound is tagged 'very displeasing'.
    _TAG_THRESHOLDS = (
        (0.71, 'very aesthetic'),
        (0.45, 'aesthetic'),
        (0.27, 'displeasing'),
    )
    _FALLBACK_TAG = 'very displeasing'

    def __init__(self, pathname, device, high_quality_label='hq'):
        """Store the settings and load the model immediately.

        Args:
            pathname: Path to the local ``.safetensors`` weights file.
            device: Device the pipeline should run on; only its string form
                is used.
            high_quality_label: Classifier label whose probability is used as
                the aesthetic score.
        """
        self.pipe = None
        self.device = device
        self.pathname = pathname
        self.high_quality_label = high_quality_label
        self.initialize_model()

    def initialize_model(self):
        """Build the image-classification pipeline and store it in ``self.pipe``.

        Weights come from the local file at ``self.pathname``; the model
        config and the image processor are fetched by repository name.
        """
        print("Loading model...")
        statedict = safetensors.torch.load_file(self.pathname)
        config = AutoConfig.from_pretrained("shadowlilac/aesthetic-shadow-v2")
        model = ViTForImageClassification.from_pretrained(
            pretrained_model_name_or_path=None,
            state_dict=statedict,
            config=config,
        )
        processor = AutoProcessor.from_pretrained("shadowlilac/aesthetic-shadow-v2")
        # The device is handed to the pipeline as a string, with 'cuda:0'
        # rewritten to plain 'cuda'.
        device_str = str(self.device).replace('cuda:0', 'cuda')
        self.pipe = pipeline(
            "image-classification",
            model=model,
            image_processor=processor,
            device=device_str,
        )
        print("Model loaded.")

    @classmethod
    def _tag_for_value(cls, value):
        """Return the tag for a score; every value maps to exactly one bucket."""
        for lower_bound, tag in cls._TAG_THRESHOLDS:
            if value > lower_bound:
                return tag
        return cls._FALLBACK_TAG

    def score(self, images):
        """Classify a list of images and return ``(score, tag)`` per image.

        Args:
            images: List of images accepted by the pipeline.

        Returns:
            A list of ``(value, tag)`` tuples in the same order as ``images``,
            where ``value`` is the score of ``self.high_quality_label``.

        Raises:
            IndexError: If a prediction does not contain the high-quality label.
        """
        predictions = self.pipe(images=images)
        results = []
        for prediction in predictions:
            value = next(
                (
                    entry['score']
                    for entry in prediction
                    if entry['label'] == self.high_quality_label
                ),
                None,
            )
            if value is None:
                raise IndexError(
                    f"label {self.high_quality_label!r} not found in prediction"
                )
            # Chained lower bounds: a score sitting exactly on a boundary
            # (0.71, 0.45) lands in the bucket below it instead of falling
            # through every branch.
            results.append((value, self._tag_for_value(value)))
        return results

def find_images(base_path, extensions):
    """Recursively collect files under ``base_path`` with a matching extension.

    Args:
        base_path: Directory to walk.
        extensions: One extension or an iterable of extensions, with or
            without a leading dot (``'png'`` and ``'.png'`` are equivalent).
            Matching is case-insensitive.

    Returns:
        A list of full paths in ``os.walk`` order. Empty when nothing matches
        or ``base_path`` does not exist.
    """
    if isinstance(extensions, str):
        extensions = (extensions,)
    # Normalise to dotted, lowercase suffixes so that only real extensions
    # match (a file named 'dpng' is not a PNG) and so that lists work as well
    # as tuples (str.endswith only accepts a str or a tuple).
    suffixes = tuple('.' + ext.lower().lstrip('.') for ext in extensions)

    valid_images = []
    for root, _, files in os.walk(base_path):
        valid_images.extend(
            os.path.join(root, name)
            for name in files
            if name.lower().endswith(suffixes)
        )
    return valid_images

def load_image(path):
    """Load an image fully into memory and return a detached copy.

    Args:
        path: Path of the image file.

    Returns:
        A copy of the image that stays usable after the file is closed, or
        ``None`` when the file cannot be opened or decoded.
    """
    try:
        with Image.open(path) as img:
            # copy() forces the pixel data to be read while the file is open.
            return img.copy()
    except (OSError, SyntaxError, ValueError) as exc:
        # Callers test the result against None, so report the unreadable file
        # and let the rest of the run continue.
        print(f"Skipping unreadable image {path}: {exc}")
        return None

def load_images_parallel(image_paths):
    """Load several images concurrently, keeping the input order.

    Args:
        image_paths: Sequence of image file paths.

    Returns:
        A list as long as ``image_paths`` whose entry ``i`` is the result of
        ``load_image(image_paths[i])``.
    """
    loaded = [None] * len(image_paths)
    with ThreadPoolExecutor(max_workers=num_workers) as pool:
        # Remember which slot each submitted job belongs to, because jobs
        # finish in arbitrary order.
        slot_of = {}
        for position, image_path in enumerate(image_paths):
            slot_of[pool.submit(load_image, image_path)] = position
        for finished in as_completed(slot_of):
            loaded[slot_of[finished]] = finished.result()
    return loaded

def has_aesthetic_tags(image_path, aesthetic_tags):
    """Tell whether the caption file next to an image already has one of the tags.

    The caption file has the image's path with the extension replaced by
    ``.txt`` and is read as comma- and/or newline-separated tags.

    Args:
        image_path: Path of the image file.
        aesthetic_tags: Iterable of tags to look for.

    Returns:
        ``True`` when at least one tag appears as a whole entry in the caption
        file; ``False`` otherwise, including when the file does not exist.
    """
    txt_path = os.path.splitext(image_path)[0] + '.txt'
    try:
        with open(txt_path, 'r', encoding='utf-8') as file:
            content = file.read()
    except FileNotFoundError:
        return False

    # Compare whole tags rather than substrings so that, for example,
    # 'anaesthetic' does not count as 'aesthetic'.
    present = {
        token.strip()
        for line in content.splitlines()
        for token in line.split(',')
    }
    return any(tag in present for tag in aesthetic_tags)

def resize_and_display_image(image_path, base_width=256):
    """Show a preview of an image scaled to ``base_width`` pixels wide.

    The height is scaled by the same factor as the width and truncated to an
    integer.

    Args:
        image_path: Path of the image file.
        base_width: Width of the preview in pixels.
    """
    with Image.open(image_path) as source:
        original_width, original_height = source.size
        scale = base_width / float(original_width)
        target_height = int(float(original_height) * float(scale))
        preview = source.resize((base_width, target_height), Image.LANCZOS)
        preview.show()
        
def append_tag_to_text_file(image_path, tag):
    """Add ``tag`` to the caption file that sits next to an image.

    The caption file has the image's path with the extension replaced by
    ``.txt``. If it does not exist it is created containing only ``tag``.
    If it exists and does not already list ``tag`` as a whole entry, the tag
    is appended after ``', '``; otherwise the file is left untouched.

    Args:
        image_path: Path of the image file.
        tag: Tag to add.
    """
    txt_path = os.path.splitext(image_path)[0] + '.txt'
    if not os.path.exists(txt_path):
        with open(txt_path, 'w', encoding='utf-8') as file:
            file.write(tag)
        return

    with open(txt_path, 'r+', encoding='utf-8') as file:
        content = file.read()
        # Compare whole entries so a tag is not skipped just because it is a
        # substring of another one.
        existing = {
            token.strip()
            for line in content.splitlines()
            for token in line.split(',')
        }
        if tag in existing:
            return

        # Drop trailing whitespace/newlines/commas first; otherwise the new
        # tag would end up on its own line or after a doubled comma.
        base = content.rstrip(' \t\r\n,')
        content = f"{base}, {tag}" if base else tag
        # NOTE(review): this print looks like leftover debug output; kept so
        # the visible output is unchanged.
        print(content)
        file.seek(0)
        file.write(content)
        file.truncate()

# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
shadow_score = ShadowScore(pathname=model_path, device=device)

all_image_paths = find_images(base_directory, extensions)
total_images = len(all_image_paths)

# Walk the image list in slices of batch_size; each slice is filtered,
# loaded, scored and tagged before the next one starts.
for i in tqdm(range(0, total_images, batch_size), desc="Processing batches"):
    batch_paths = all_image_paths[i:i+batch_size]
    batch_images = []
    batch_paths_to_process = []

    for path in batch_paths:
        # Images whose caption file already has an aesthetic tag are skipped.
        if has_aesthetic_tags(path, aesthetic_tags):
            continue
        try:
            image = load_image(path)
        except (OSError, SyntaxError, ValueError) as e:
            # One unreadable file must not abort the whole run.
            print(f"Error loading image: {path}")
            print(f"Exception: {e}")
            continue
        if image is not None:  # Only append if the image was successfully loaded
            batch_paths_to_process.append(path)
            batch_images.append(image)

    if not batch_images:
        continue

    try:
        batch_scores = shadow_score.score(batch_images)
    except Exception as e:
        # A failure while scoring drops this batch only; the run continues.
        print(f"Error processing batch starting with image path: {batch_paths_to_process[0]}")
        print(f"Exception: {e}")
        continue

    for path, (score, tag) in zip(batch_paths_to_process, batch_scores):
        # print(f"Image: {path}, Score: {score:.2f}, Tag: {tag}")
        append_tag_to_text_file(path, tag)
        # resize_and_display_image(path, base_width=512)

    if device.type == 'cuda':
        torch.cuda.empty_cache()
