In [None]:
# Install dependencies. The `!` prefix is an IPython shell escape, so this cell
# only runs inside a Jupyter/Kaggle notebook, not as a plain .py script.
!pip install diffusers transformers torch tqdm pillow

# Download a small test set of three face photos from Wikimedia Commons into
# /kaggle/working/test_dataset/faces (the directory main() reads by default).
# NOTE(review): wget without -nc saves re-downloads as "<name>.1", so re-running
# this cell leaves extra files behind (they are skipped later because of their
# extension) — confirm this is acceptable.
!mkdir -p /kaggle/working/test_dataset/faces
!wget -P /kaggle/working/test_dataset/faces https://upload.wikimedia.org/wikipedia/commons/d/dc/Steve_Jobs_Headshot_2010-CROP_%28cropped_2%29.jpg
!wget -P /kaggle/working/test_dataset/faces https://upload.wikimedia.org/wikipedia/commons/1/18/Mark_Zuckerberg_F8_2019_Keynote_%2832830578717%29_%28cropped%29.jpg
!wget -P /kaggle/working/test_dataset/faces https://upload.wikimedia.org/wikipedia/commons/3/34/Elon_Musk_Royal_Society_%28crop2%29.jpg

import os
import torch
from PIL import Image
from diffusers import StableDiffusionInstructPix2PixPipeline
from tqdm import tqdm
import logging
import json
import warnings
warnings.filterwarnings('ignore')

class DatasetGenerator:
    """Build paired (original, edited) image datasets with InstructPix2Pix.

    For every source image and every entry in ``self.concepts`` the image is
    edited with that concept's natural-language instruction. The model input
    and the edited output are both written to
    ``<base_dir>/processed/<concept_name>/`` as ``original_<name>.png`` and
    ``modified_<name>.png``.
    """

    # Side length (pixels) that inputs are resized to before editing. The code
    # has always used 256, although an old comment claimed 512. 256 stays the
    # default so existing runs are reproduced; pass image_size=512 to try the
    # higher resolution.
    DEFAULT_IMAGE_SIZE = 256

    # Lower-case file extensions treated as input images.
    IMAGE_EXTENSIONS = (".jpg", ".png", ".jpeg")

    def __init__(self, base_dir="/kaggle/working/dataset", image_size=DEFAULT_IMAGE_SIZE):
        """Create the output directories and load the InstructPix2Pix model.

        Args:
            base_dir: root directory for the generated dataset.
            image_size: square side (pixels) each input is resized to before
                being fed to the model; the saved "original" has this size too.
        """
        self.base_dir = base_dir
        self.image_size = image_size
        self.setup_directories()
        self.setup_model()

        # IP2P sampling parameters, which the original author took from the paper.
        self.ip2p_params = {
            "num_inference_steps": 100,  # 100 denoising steps
            "image_guidance_scale": 1.5,  # image guidance 1.5
            "guidance_scale": 7.5,  # text guidance 7.5
        }

        # Concept name (used as the output sub-folder) -> editing instruction.
        self.concepts = {
            "old_person": "make the person look older",
            "vangogh_style": "convert this into Vincent van Gogh painting style",
            "watercolor": "convert this into watercolor painting style",
            "blond_person": "make the person blonde",
            "tan_person": "make the person look tanned",
        }

    def setup_directories(self):
        """Create ``<base_dir>/raw`` and ``<base_dir>/processed`` (idempotent)."""
        self.dirs = {
            "raw": os.path.join(self.base_dir, "raw"),
            "processed": os.path.join(self.base_dir, "processed"),
        }

        for dir_path in self.dirs.values():
            os.makedirs(dir_path, exist_ok=True)

        print("Directory create con successo")

    def setup_model(self):
        """Load the fp16 InstructPix2Pix pipeline onto the CUDA device.

        Raises:
            Exception: whatever ``from_pretrained`` / ``.to`` raised, re-raised
                after printing it.
        """
        try:
            self.model = StableDiffusionInstructPix2PixPipeline.from_pretrained(
                "timbrooks/instruct-pix2pix",
                torch_dtype=torch.float16,
                safety_checker=None,
            ).to("cuda")
            print("Modello caricato con successo")
        except Exception as e:
            print(f"Errore nel caricamento del modello: {e}")
            raise

    @classmethod
    def _list_image_files(cls, source_dir):
        """Return the sorted names of image files directly inside ``source_dir``.

        Sorting makes the processing order (and tqdm output) reproducible;
        sub-directories are skipped even if their name ends in an image extension.
        """
        return sorted(
            name
            for name in os.listdir(source_dir)
            if name.lower().endswith(cls.IMAGE_EXTENSIONS)
            and os.path.isfile(os.path.join(source_dir, name))
        )

    def _load_image(self, image_path):
        """Open ``image_path`` as RGB, resized to the model's input size."""
        # The context manager closes the file handle; convert()/resize() return
        # independent in-memory images, so the result stays usable afterwards.
        with Image.open(image_path) as img:
            return img.convert("RGB").resize((self.image_size, self.image_size))

    def _edit_image(self, image, instruction):
        """Run InstructPix2Pix on an already-loaded PIL image and return the edit."""
        return self.model(
            prompt=instruction,
            image=image,
            **self.ip2p_params,
        ).images[0]

    def process_image(self, image_path, instruction):
        """Edit a single image file with ``instruction``.

        Returns:
            The edited PIL image, or ``None`` if loading or inference failed.
            The error is printed rather than raised, so one bad file does not
            abort a whole dataset run.
        """
        try:
            return self._edit_image(self._load_image(image_path), instruction)
        except Exception as e:
            print(f"Errore nel processamento dell'immagine {image_path}: {e}")
            return None

    def generate_dataset(self, source_dir):
        """Edit every image in ``source_dir`` with every concept and save the pairs.

        Args:
            source_dir: directory containing the input images (not recursive).

        Returns:
            ``{concept_name: {"success": int, "failed": int}}``, the same
            counts that are printed at the end.
        """
        print(f"Inizio processamento immagini da: {source_dir}")

        image_files = self._list_image_files(source_dir)
        print(f"Trovate {len(image_files)} immagini")

        results = {concept: {"success": 0, "failed": 0} for concept in self.concepts}

        # Create each concept's output folder once, not once per image.
        concept_dirs = {}
        for concept_name in self.concepts:
            concept_dir = os.path.join(self.dirs["processed"], concept_name)
            os.makedirs(concept_dir, exist_ok=True)
            concept_dirs[concept_name] = concept_dir

        for img_file in tqdm(image_files, desc="Processing images"):
            img_path = os.path.join(source_dir, img_file)
            base_name = os.path.splitext(img_file)[0]

            # Load and resize once per file. Every concept edits this same
            # input, and this exact input is saved as the "original", so each
            # pair has matching dimensions. (Previously the original was
            # re-read at native resolution while the edit was image_size.)
            try:
                image = self._load_image(img_path)
            except Exception as e:
                tqdm.write(f"Errore nel processamento dell'immagine {img_path}: {e}")
                for stats in results.values():
                    stats["failed"] += 1
                continue

            for concept_name, instruction in self.concepts.items():
                try:
                    modified_img = self._edit_image(image, instruction)
                except Exception as e:
                    tqdm.write(f"Errore nel processamento dell'immagine {img_path}: {e}")
                    results[concept_name]["failed"] += 1
                    continue

                concept_dir = concept_dirs[concept_name]
                image.save(os.path.join(concept_dir, f"original_{base_name}.png"))
                modified_img.save(os.path.join(concept_dir, f"modified_{base_name}.png"))

                results[concept_name]["success"] += 1
                # tqdm.write keeps the progress bar intact, unlike print().
                tqdm.write(f"Salvate immagini per concetto: {concept_name}")

        print("\nRisultati finali:")
        for concept, stats in results.items():
            print(f"{concept}:")
            print(f"  Successi: {stats['success']}")
            print(f"  Falliti: {stats['failed']}")

        print("\nGenerazione dataset completata!")
        return results

def check_gpu():
    """Report whether a CUDA device is usable.

    When one is available, print the name and total memory (in GB) of device 0
    and return True. Otherwise print a warning and return False.
    """
    if not torch.cuda.is_available():
        print("GPU non disponibile!")
        return False

    device_name = torch.cuda.get_device_name(0)
    total_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"GPU disponibile: {device_name}")
    print(f"Memoria totale: {total_gb:.2f} GB")
    return True

def main(source_dir="/kaggle/working/test_dataset/faces"):
    """Generate the edited dataset from the images in ``source_dir``.

    The cheap preconditions (GPU present, input directory present) are checked
    before ``DatasetGenerator`` is built. Building it loads (and may download)
    the InstructPix2Pix weights onto the GPU, so there is no point paying that
    cost just to discover the input is missing.

    Args:
        source_dir: directory containing the input images; defaults to the
            test faces folder used by this notebook.
    """
    if not check_gpu():
        return

    # isdir rather than exists: a plain file at this path would pass exists()
    # and then make os.listdir() fail inside generate_dataset.
    if not os.path.isdir(source_dir):
        print(f"Directory {source_dir} non trovata!")
        return

    generator = DatasetGenerator()
    generator.generate_dataset(source_dir)

if __name__ == "__main__":
    main()

# Closing hints telling the user where the generated image pairs end up.
_RESULT_HINTS = (
    "\nPer visualizzare i risultati:",
    "1. Controlla la cartella /kaggle/working/dataset/processed/",
    "2. Ogni sottocartella contiene le immagini originali e modificate per ogni concetto",
)
for _hint in _RESULT_HINTS:
    print(_hint)