In [None]:
# Fetch the thesis_2025 repository into the current working directory.
!git clone https://github.com/lokissdo/thesis_2025.git

In [None]:
# Make the face_parsing folder the working directory; later cells use paths relative to it.
%cd /kaggle/working/thesis_2025/face_parts_retrieval/CelebAMask-HQ/face_parsing

In [None]:
# Copy masking_model.pth from the `celabamask` input dataset into the repo's models/parsenet folder.
cp /kaggle/input/celabamask/masking_model.pth /kaggle/working/thesis_2025/face_parts_retrieval/CelebAMask-HQ/face_parsing/models/parsenet

In [None]:
# Rename the copied checkpoint to model.pth (presumably the file name main.py loads -- TODO confirm).
mv /kaggle/working/thesis_2025/face_parts_retrieval/CelebAMask-HQ/face_parsing/models/parsenet/masking_model.pth /kaggle/working/thesis_2025/face_parts_retrieval/CelebAMask-HQ/face_parsing/models/parsenet/model.pth

In [None]:
# Sanity check: print the current working directory.
pwd

In [None]:
# Create the folder the extractor script copies its input image into.
# -p keeps the cell idempotent: re-running it does not fail when the folder already exists.
mkdir -p /kaggle/working/thesis_2025/face_parts_retrieval/CelebAMask-HQ/face_parsing/Data_preprocessing/test_img

In [None]:
# Install tensorboardX (presumably required by the repo's main.py, which is not visible here -- TODO confirm).
pip install tensorboardX 

In [None]:
# Install/upgrade google-generativeai, which prompt_face_part_extractor.py imports as `genai`.
pip install -U google-generativeai

In [None]:
import os
from getpass import getpass

# SECURITY: never hard-code the API key in the notebook. A key that was
# previously committed in this cell must be treated as leaked and
# revoked/rotated in the Google Cloud console.
# The key is taken from the environment when it is already set (for
# non-interactive runs, export GOOGLE_API_KEY before starting the kernel);
# otherwise it is requested interactively without being echoed into the
# cell output.
if not os.environ.get('GOOGLE_API_KEY'):
    os.environ['GOOGLE_API_KEY'] = getpass('Google API key: ').strip()

In [None]:
%%writefile /kaggle/working/thesis_2025/face_parts_retrieval/CelebAMask-HQ/face_parsing/prompt_face_part_extractor.py
import os
import base64
from PIL import Image
from io import BytesIO
import subprocess
import argparse
import google.generativeai as genai
import shutil
import csv

def extract_prompt(model_name="gemini-1.5-pro", prompt=None, image_path=None, overwrited_prompt=None):
    """Return the list of face-part labels to segment.

    Every argument is accepted for interface compatibility but none of them
    is consulted: the function always answers with the same three
    mouth-region labels.

    Returns:
        list[str]: a freshly built ``['mouth', 'upper_lip', 'lower_lip']``.
    """
    fixed_labels = ('mouth', 'upper_lip', 'lower_lip')
    # Build a new list on every call so callers may mutate the result freely.
    return list(fixed_labels)

def get_api_key():
    """Return the Google API key, prompting for it when it is not configured.

    The key is read from the ``GOOGLE_API_KEY`` environment variable. When
    the variable is unset or blank, the user is asked for the key on the
    terminal; the prompt goes through ``getpass`` so the secret is not echoed
    to the screen or captured in notebook output.

    Returns:
        str: the API key with surrounding whitespace removed.

    Raises:
        ValueError: no key is available from the environment and none was
            entered (including when no input stream is available to prompt on).
    """
    import getpass  # local import: not part of the file-level import block

    api_key = (os.getenv('GOOGLE_API_KEY') or '').strip()
    if api_key:
        return api_key

    try:
        api_key = getpass.getpass("Please enter your Google API key: ").strip()
    except EOFError:
        # stdin is closed (e.g. the script was launched non-interactively):
        # report it as a missing key instead of leaking a bare EOFError.
        api_key = ''

    if not api_key:
        raise ValueError("API key is required. Set GOOGLE_API_KEY environment variable or provide it when prompted.")
    return api_key

def get_image_path():
    """Parse ``--image_path`` and ``--prompt`` from the command line.

    Both options are mandatory strings.

    Returns:
        tuple[str, str]: ``(image_path, prompt)`` exactly as supplied.

    Raises:
        FileNotFoundError: the supplied ``--image_path`` does not exist.
    """
    cli = argparse.ArgumentParser(description='Extract prompt for face parsing')
    required_options = (
        ('--image_path', 'Path to the input image'),
        ('--prompt', 'Prompt describing the desired change'),
    )
    for flag, help_text in required_options:
        cli.add_argument(flag, type=str, required=True, help=help_text)
    parsed = cli.parse_args()

    # Happy path first; a missing file is reported with the offending path.
    if os.path.exists(parsed.image_path):
        return parsed.image_path, parsed.prompt
    raise FileNotFoundError(f"Image file not found: {parsed.image_path}")

def get_prompt_args():
    """Parse the extractor's full set of command-line options.

    Options:
        --image_path         (required) path to the input image.
        --prompt             (required) text describing the desired change.
        --model_name         optional, defaults to ``gemini-1.5-pro``.
        --overwrited_prompt  optional, defaults to ``None``.

    Returns:
        tuple: ``(image_path, prompt, model_name, overwrited_prompt)``.

    Raises:
        FileNotFoundError: the supplied ``--image_path`` does not exist.
    """
    cli = argparse.ArgumentParser(description='Extract prompt for face parsing')
    # Insertion order of this mapping fixes the order options appear in --help.
    option_specs = {
        '--image_path': dict(required=True, help='Path to the input image'),
        '--prompt': dict(required=True, help='Prompt describing the desired change'),
        '--model_name': dict(default='gemini-1.5-pro',
                             help='Model name to use (default: gemini-1.5-pro)'),
        '--overwrited_prompt': dict(default=None,
                                    help='Custom prompt to overwrite the default prompt template'),
    }
    for flag, spec in option_specs.items():
        cli.add_argument(flag, type=str, **spec)
    parsed = cli.parse_args()

    if os.path.exists(parsed.image_path):
        return parsed.image_path, parsed.prompt, parsed.model_name, parsed.overwrited_prompt
    raise FileNotFoundError(f"Image file not found: {parsed.image_path}")

if __name__ == "__main__":
    # Script entry point: pick the face-part labels for one image, stage the
    # image as ./Data_preprocessing/test_img/0.jpg, run main.py on it to build
    # the combined mask, and append the chosen labels to
    # ./test_results/chosen_labels.csv. Paths are relative, so the script must
    # be launched from the face_parsing folder.

    # Local imports: only this entry point needs them.
    import sys
    import tempfile

    # The Gemini client is left unconfigured on purpose: extract_prompt()
    # currently returns a fixed label list and never calls the API.
    # api_key = get_api_key()
    # genai.configure(api_key=api_key)

    # Get image path and prompt from command line arguments
    image_path, prompt, model_name, overwrited_prompt = get_prompt_args()
    # File name up to its first dot, e.g. ".../12.jpg" -> "12".
    img_index = os.path.basename(image_path).split('.')[0]

    chosen_labels = extract_prompt(
        model_name=model_name,
        prompt=prompt,
        image_path=image_path,
        overwrited_prompt=overwrited_prompt
    )
    print(chosen_labels)
    chosen_labels_str = ' '.join(chosen_labels)

    test_img_dir = './Data_preprocessing/test_img'
    # Make sure the staging folder exists, then empty it so main.py only sees
    # the image copied below.
    os.makedirs(test_img_dir, exist_ok=True)
    for file in os.listdir(test_img_dir):
        os.remove(os.path.join(test_img_dir, file))

    # Copy the image to ./Data_preprocessing/test_img/0.jpg
    shutil.copy(image_path, os.path.join(test_img_dir, '0.jpg'))

    # Build the main.py invocation as an argument list (no shell), so a file
    # name containing spaces or shell metacharacters cannot break or inject
    # into the command. sys.executable guarantees the same interpreter/env.
    command = [
        sys.executable, '-u', 'main.py',
        '--batch_size', '1',
        '--imsize', '512',
        '--version', 'parsenet',
        '--train', 'False',
        '--test_size', '1',
        '--chosen_labels', *chosen_labels,
        '--test_image_path', test_img_dir,
        '--output_mask_name', f'{img_index}_combined_mask.jpg',
    ]

    # stdout is streamed live; stderr is spooled to a temporary file instead of
    # a second pipe. With two pipes and only stdout being drained, the child
    # would block forever once it filled the stderr pipe buffer.
    with tempfile.TemporaryFile(mode='w+', encoding='utf-8', errors='replace') as stderr_file:
        process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=stderr_file, text=True)

        # Continuously read and print the output in real-time
        for line in process.stdout:
            print(line, end="")

        # Wait for the process to complete
        return_code = process.wait()

        stderr_file.seek(0)
        stderr_output = stderr_file.read()

    # Record the chosen labels in ./test_results/chosen_labels.csv.
    # The row is written regardless of the child's exit status.
    csv_path = './test_results/chosen_labels.csv'
    os.makedirs(os.path.dirname(csv_path), exist_ok=True)
    is_empty = not os.path.isfile(csv_path) or os.stat(csv_path).st_size == 0

    with open(csv_path, mode='a', newline='', encoding='utf-8') as csvfile:
        writer = csv.writer(csvfile)

        # Write the header if the file is empty
        if is_empty:
            writer.writerow(['img_index', 'chosen_labels'])

        # Write the data row
        writer.writerow([img_index, chosen_labels_str])

    # Report anything the child wrote to stderr
    if stderr_output:
        print(f"Errors: {stderr_output}")

    # Propagate a failure of main.py instead of silently reporting success.
    if return_code != 0:
        sys.exit(return_code)

In [None]:
import time

# Run the extractor script once per image, for CelebA-HQ images 0.jpg through 299.jpg.
# The script is referenced by a relative path, so the working directory must be the
# face_parsing folder it was written to.
for i in range(0, 300):
    image_path = f"/kaggle/input/celebamaskhq/CelebAMask-HQ/CelebA-HQ-img/{i}.jpg"
    # IPython substitutes {image_path} before handing the line to the shell.
    # NOTE(review): "SMILLING" looks like a typo for "SMILING"; extract_prompt() in the
    # script currently ignores the prompt, so it has no effect yet -- confirm before relying on it.
    !python prompt_face_part_extractor.py --image_path "{image_path}" --prompt "SMILLING AND NOT BALD"
    # Pause 3 seconds between runs (presumably API rate limiting -- TODO confirm it is still needed).
    time.sleep(3)


In [None]:
# Recursively archive the test_results folder into test_results.zip inside the face_parsing folder.
!zip -r /kaggle/working/thesis_2025/face_parts_retrieval/CelebAMask-HQ/face_parsing/test_results.zip /kaggle/working/thesis_2025/face_parts_retrieval/CelebAMask-HQ/face_parsing/test_results