In [None]:
# Install third-party packages imported in the next cell (qwen_vl_utils, jsonformer).
# NOTE(review): versions are unpinned; jsonformer is imported but not used in the cells shown here.
!pip install qwen-vl-utils
!pip install jsonformer

In [None]:
import pandas as pd
import time
import torch
from PIL import Image
import requests
import matplotlib.pyplot as plt
from transformers import Qwen2VLForConditionalGeneration, AutoModelForCausalLM, AutoTokenizer, AutoProcessor, AutoModelForVision2Seq, BitsAndBytesConfig
from qwen_vl_utils import process_vision_info
import os
import gc
from tqdm.notebook import tqdm_notebook as tqdm   
import json
import re
from jsonformer import Jsonformer

In [None]:
# Fine-tuned Qwen2-VL checkpoint stored in a Kaggle input dataset.
model_id = "/kaggle/input/qwen7bfinetuned/qwen2-7b-instruct-artifact"

# 4-bit NF4 weights with nested (double) quantization; compute runs in bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

model = Qwen2VLForConditionalGeneration.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    device_map="auto",
)

# The processor (chat template, tokenizer, image preprocessing) comes from the base
# model repo id rather than from `model_id`.
# NOTE(review): this likely needs Hub access in the Kaggle session — confirm, or load
# it from `model_id` if the fine-tuned checkpoint ships processor files.
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")

In [None]:
# Catalogue of artifact types the model is asked about, one "- " bullet per line.
artifacts_text = """
- Inconsistent object boundaries
- Discontinuous surfaces
- Non-manifold geometries in rigid structures
- Floating or disconnected components
- Asymmetric features in naturally symmetric objects 
- Misaligned bilateral elements in animal faces 
- Irregular proportions in mechanical components 
- Texture bleeding between adjacent regions
- Texture repetition patterns
- Over-smoothing of natural textures 
- Artificial noise patterns in uniform surfaces
- Unrealistic specular highlights
- Inconsistent material properties
- Metallic surface artifacts 
- Dental anomalies in mammals 
- Anatomically incorrect paw structures
- Improper fur direction flows
- Unrealistic eye reflections
- Misshapen ears or appendages
- Impossible mechanical connections
- Inconsistent scale of mechanical parts
- Physically impossible structural elements
- Inconsistent shadow directions
- Multiple light source conflicts
- Missing ambient occlusion
- Incorrect reflection mapping
- Incorrect perspective rendering
- Scale inconsistencies within single objects
- Spatial relationship errors
- Depth perception anomalies
- Over-sharpening artifacts
- Aliasing along high-contrast edges
- Blurred boundaries in fine details
- Jagged edges in curved structures
- Random noise patterns in detailed areas
- Loss of fine detail in complex structures
- Artificial enhancement artifacts
- Incorrect wheel geometry
- Implausible aerodynamic structures
- Misaligned body panels
- Impossible mechanical joints
- Distorted window reflections
- Anatomically impossible joint configurations
- Unnatural pose artifacts
- Biological asymmetry errors
- Regular grid-like artifacts in textures
- Repeated element patterns
- Systematic color distribution anomalies
- Frequency domain signatures
- Color coherence breaks
- Unnatural color transitions
- Resolution inconsistencies within regions
- Unnatural Lighting Gradients
- Incorrect Skin Tones
- Fake depth of field
- Abruptly cut off objects
- Glow or light bleed around object boundaries
- Ghosting effects: Semi-transparent duplicates of elements
- Cinematization Effects
- Excessive sharpness in certain image regions
- Artificial smoothness
- Movie-poster like composition of ordinary scenes
- Dramatic lighting that defies natural physics
- Artificial depth of field in object presentation
- Unnaturally glossy surfaces
- Synthetic material appearance
- Multiple inconsistent shadow sources
- Exaggerated characteristic features
- Impossible foreshortening in animal bodies
- Scale inconsistencies within the same object class
"""

# Drop blank lines, then strip the "- " bullet marker and surrounding whitespace.
_raw_lines = artifacts_text.strip().split("\n")
artifacts_list = [entry.strip("- ").strip() for entry in filter(str.strip, _raw_lines)]

In [None]:
def _load_image(source):
    """Return an in-memory PIL image for ``source``.

    ``source`` may be a ``PIL.Image.Image`` (returned unchanged), an
    ``http://``/``https://`` URL, or a local path (``str`` or ``os.PathLike``).
    File and URL images are copied into memory so no file handle or HTTP
    connection stays open after this returns.
    """
    if isinstance(source, Image.Image):
        return source
    source = os.fspath(source)
    if source.startswith(('http://', 'https://')):
        import io
        # Timeout so a stalled server cannot hang the notebook; raise on HTTP errors
        # instead of handing an error page to PIL. `.content` also undoes any
        # Content-Encoding (gzip), which reading `.raw` does not.
        response = requests.get(source, timeout=30)
        response.raise_for_status()
        with Image.open(io.BytesIO(response.content)) as img:
            return img.copy()
    with Image.open(source) as img:
        # copy() forces the pixel data into memory before the file is closed.
        return img.copy()


def process_image_and_prompt(model, processor, image_path, prompt, max_tokens=1024):
    """Run one image plus a text prompt through a Qwen2-VL style chat model.

    Args:
        model: Generation model; inputs are moved to ``model.device`` and passed
            to ``model.generate``.
        processor: Matching processor providing ``apply_chat_template``,
            ``__call__`` and ``batch_decode``.
        image_path: Local file path, ``http(s)`` URL, or an already-loaded
            ``PIL.Image.Image`` (lets callers avoid re-reading the same file).
        prompt: User text sent alongside the image.
        max_tokens: Maximum number of new tokens to generate.

    Returns:
        list[str]: Decoded generated text per batch element, with the prompt
        tokens removed. For the single image built here it has length 1.
    """
    image = _load_image(image_path)

    # Chat request: a fixed system persona plus one user turn with image and prompt.
    messages = [
        {
            "role": "system",
            "content": [
                {"type": "text", "text": "You are an expert in identifying and analyzing artifacts that indicate why the image may appear unnatural or fake"}
            ]
        },
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": prompt}
            ]
        }
    ]

    # Render the chat template to text and collect the vision inputs referenced in `messages`.
    input_text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(messages)
    # add_special_tokens=False: presumably because the rendered template already
    # contains the model's special tokens.
    inputs = processor(
        text=input_text,
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        add_special_tokens=False,
        return_tensors="pt"
    ).to(model.device)

    with torch.inference_mode():
        # NOTE: `temperature` only takes effect when sampling is enabled (do_sample in
        # the model's generation config); under greedy decoding it is ignored.
        output = model.generate(**inputs, temperature=0.1, max_new_tokens=max_tokens)

    # generate() returns prompt + completion; keep only the newly generated tokens.
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, output)
    ]
    return processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )

In [None]:
def get_explanations(imagelink, show_image=True, artifacts=None):
    """Ask the model to explain each artifact type for a single image.

    The model is prompted once per artifact name (up to 64 new tokens each) and
    its decoded reply is stored verbatim. The prompt tells the model to answer
    ``None`` when an artifact does not apply, so non-applicable entries hold
    that model text rather than being omitted.

    Args:
        imagelink: Local file path of the image (opened with PIL for display
            and passed to ``process_image_and_prompt``).
        show_image: If True (default), render the image with matplotlib first.
            Pass False in batch runs to avoid producing one figure per image.
        artifacts: Iterable of artifact names to query; defaults to the global
            ``artifacts_list``.

    Returns:
        dict[str, str]: artifact name -> raw model response.
    """
    start = time.perf_counter()

    if show_image:
        fig, ax = plt.subplots()
        # The context manager closes the file; imshow converts the image to an array immediately.
        with Image.open(imagelink) as img:
            ax.imshow(img)
        ax.axis('off')
        ax.set_title('Input Image')
        plt.show()
        # Free the figure explicitly: this function runs once per image in batch loops.
        plt.close(fig)

    if artifacts is None:
        artifacts = artifacts_list

    artifact_explanations = {}
    for artifact in tqdm(artifacts, desc="Processing: "):
        prompt = f"""TASK: For the provided artifact, Generate 1-2 explanation for the artifact. 
            IF the artifact is not applicable to the image, print None and NOTHING ELSE. Explain strictly in context of the given image. 
            Limit your explanation to 64 tokens.
            artifact : {artifact}"""
        # 64 new tokens max, matching the limit stated in the prompt.
        artifact_explanations[artifact] = process_image_and_prompt(model, processor, imagelink, prompt, 64)[0]

    end = time.perf_counter()
    print(f"\nTotal processing time: {end - start} seconds")

    return artifact_explanations

In [None]:
# Run the pipeline on a single test image before the full batch run further down.
sample_image_path = "/kaggle/input/testdataadobe/perturbed_images_32/67.png"
artifact_explanations = get_explanations(sample_image_path)

In [None]:
# Kaggle test dataset root and the sub-folder holding the numbered test images.
_dataset_root = "/kaggle/input/testdataadobe"
image_dir = f"{_dataset_root}/perturbed_images_32"

In [None]:
# Paths "<image_dir>/<i>.png" for every i in the half-open range [start, end).
start = 1
end = 301
image_paths = [f"{image_dir}/{i}.png" for i in range(start, end)]

In [None]:
# Run the per-artifact explanation pass on every test image; results are keyed by image path.
# (The original loop referenced undefined names `key` and `group_responses` and crashed
# with NameError on the first image.)
# NOTE(review): this is len(image_paths) x len(artifacts_list) generations; consider
# periodically dumping `all_explanations` to disk (json is already imported) so a
# session timeout does not lose completed work.
all_explanations = {}
for path in tqdm(image_paths):
    artifact_explanations = get_explanations(path)
    all_explanations[path] = artifact_explanations
    