In [1]:
# Install Pytorch & other libraries
%pip install "torch==2.4.0" tensorboard pillow
 
# Install Hugging Face libraries
%pip install  --upgrade \
  "transformers==4.45.1" \
  "datasets==3.0.1" \
  "accelerate==0.34.2" \
  "evaluate==0.4.3" \
  "bitsandbytes==0.44.0" \
  "trl==0.11.1" \
  "peft==0.13.0" \
  "qwen-vl-utils"

Note: you may need to restart the kernel to use updated packages.
Collecting transformers==4.45.1
  Downloading transformers-4.45.1-py3-none-any.whl.metadata (44 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.4/44.4 kB[0m [31m3.1 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting datasets==3.0.1
  Downloading datasets-3.0.1-py3-none-any.whl.metadata (20 kB)
Collecting accelerate==0.34.2
  Downloading accelerate-0.34.2-py3-none-any.whl.metadata (19 kB)
Collecting evaluate==0.4.3
  Downloading evaluate-0.4.3-py3-none-any.whl.metadata (9.2 kB)
Collecting bitsandbytes==0.44.0
  Downloading bitsandbytes-0.44.0-py3-none-manylinux_2_24_x86_64.whl.metadata (3.5 kB)
Collecting trl==0.11.1
  Downloading trl-0.11.1-py3-none-any.whl.metadata (12 kB)
Collecting peft==0.13.0
  Downloading peft-0.13.0-py3-none-any.whl.metadata (13 kB)
Collecting qwen-vl-utils
  Downloading qwen_vl_utils-0.0.8-py3-none-any.whl.metadata (3.6 kB)
Collecting tyro>=0.5.11 (from trl==0.11.1)
  Downloa

In [3]:
import os
import pandas as pd
from tqdm.notebook import tqdm_notebook as tqdm
from datasets import Dataset, Features, ClassLabel, Value, Sequence
import torch
from transformers import AutoModelForVision2Seq, AutoProcessor, BitsAndBytesConfig
from peft import LoraConfig
from trl import SFTConfig
from transformers import Qwen2VLProcessor
from qwen_vl_utils import process_vision_info
from trl import SFTTrainer
from huggingface_hub import login
from kaggle_secrets import UserSecretsClient

In [7]:
user_secrets = UserSecretsClient()
hf_token = user_secrets.get_secret("HUGGINGFACE_TOKEN")
login(token=hf_token)

In [9]:
# List of base directories to process
base_dirs = [
    "/kaggle/input/resized-artifact-dataset/Manually Annotated Dataset",
    "/kaggle/input/resized-artifact-dataset/labeled fake artufacts"  # Add the second directory here
]

# Dictionary to store image paths and their associated classes
image_class_map = {}

# Function to process a single base directory
def process_directory(base_dir):
    datasets = [f for f in os.listdir(base_dir) if os.path.isdir(os.path.join(base_dir, f))]
    for dataset in tqdm(datasets, desc=f"Processing datasets in {base_dir}"):
        dataset_path = os.path.join(base_dir, dataset)
        defect_types = [f for f in os.listdir(dataset_path) if os.path.isdir(os.path.join(dataset_path, f))]
        for defect_type in tqdm(defect_types, desc=f"Processing defect types in {dataset}", leave=False):
            defect_path = os.path.join(dataset_path, defect_type)
            image_files = [f for f in os.listdir(defect_path) if os.path.isfile(os.path.join(defect_path, f))]
            for image_name in tqdm(image_files, desc=f"Processing images in {dataset}/{defect_type}", leave=False):
                # Create a unique key using the full relative path
                unique_key = os.path.join(dataset, image_name)
                image_path = os.path.join(defect_path, image_name)
                if unique_key not in image_class_map:
                    image_class_map[unique_key] = {"path": image_path, "classes": []}
                image_class_map[unique_key]["classes"].append(defect_type)

# Process each base directory
for base_dir in base_dirs:
    process_directory(base_dir)

# Convert dictionary to DataFrame
data = [
    {"image_path": value["path"], "classes": value["classes"], "unique_key": key}
    for key, value in image_class_map.items()
]
df = pd.DataFrame(data)

# Save the DataFrame to a CSV file if needed
df.to_csv("image_classes.csv", index=False)

# Display the DataFrame
print(df)


Processing datasets in /kaggle/input/resized-artifact-dataset/Manually Annotated Dataset:   0%|          | 0/3…

Processing defect types in dog dataset:   0%|          | 0/6 [00:00<?, ?it/s]

Processing images in dog dataset/AI Defects:   0%|          | 0/6 [00:00<?, ?it/s]

Processing images in dog dataset/Reality Breaks:   0%|          | 0/34 [00:00<?, ?it/s]

Processing images in dog dataset/Biological defects:   0%|          | 0/153 [00:00<?, ?it/s]

Processing images in dog dataset/Overprocessing:   0%|          | 0/14 [00:00<?, ?it/s]

Processing images in dog dataset/Scene Oddities:   0%|          | 0/70 [00:00<?, ?it/s]

Processing images in dog dataset/Surface, Depth and Edge:   0%|          | 0/56 [00:00<?, ?it/s]

Processing defect types in cat dataset:   0%|          | 0/6 [00:00<?, ?it/s]

Processing images in cat dataset/AI Defects:   0%|          | 0/36 [00:00<?, ?it/s]

Processing images in cat dataset/Reality Breaks:   0%|          | 0/44 [00:00<?, ?it/s]

Processing images in cat dataset/Biological defects:   0%|          | 0/208 [00:00<?, ?it/s]

Processing images in cat dataset/Overprocessing:   0%|          | 0/39 [00:00<?, ?it/s]

Processing images in cat dataset/Scene Oddities:   0%|          | 0/75 [00:00<?, ?it/s]

Processing images in cat dataset/Surface, Depth and Edge:   0%|          | 0/137 [00:00<?, ?it/s]

Processing defect types in deer dataset:   0%|          | 0/6 [00:00<?, ?it/s]

Processing images in deer dataset/AI Defects:   0%|          | 0/5 [00:00<?, ?it/s]

Processing images in deer dataset/Reality Breaks:   0%|          | 0/8 [00:00<?, ?it/s]

Processing images in deer dataset/Biological defects:   0%|          | 0/133 [00:00<?, ?it/s]

Processing images in deer dataset/Overprocessing:   0%|          | 0/31 [00:00<?, ?it/s]

Processing images in deer dataset/Scene Oddities:   0%|          | 0/59 [00:00<?, ?it/s]

Processing images in deer dataset/Surface, Depth and Edge:   0%|          | 0/40 [00:00<?, ?it/s]

Processing datasets in /kaggle/input/resized-artifact-dataset/labeled fake artufacts:   0%|          | 0/4 [00…

Processing defect types in automobile:   0%|          | 0/6 [00:00<?, ?it/s]

Processing images in automobile/Scene Oddities (lighting, reflections, composition):   0%|          | 0/115 [0…

Processing images in automobile/Biological defects:   0%|          | 0/71 [00:00<?, ?it/s]

Processing images in automobile/Overprocessing:   0%|          | 0/15 [00:00<?, ?it/s]

Processing images in automobile/AI Glitches (digital artifacts, noise, processing errors):   0%|          | 0/…

Processing images in automobile/Reality Breaks (impossible structures, physics violations):   0%|          | 0…

Processing images in automobile/Surface, Depth and Edge:   0%|          | 0/178 [00:00<?, ?it/s]

Processing defect types in ships:   0%|          | 0/5 [00:00<?, ?it/s]

Processing images in ships/Scene Oddities (lighting, reflections, composition):   0%|          | 0/112 [00:00<…

Processing images in ships/Overprocessing:   0%|          | 0/40 [00:00<?, ?it/s]

Processing images in ships/AI Glitches (digital artifacts, noise, processing errors):   0%|          | 0/14 [0…

Processing images in ships/Reality Breaks (impossible structures, physics violations):   0%|          | 0/93 […

Processing images in ships/Surface, Depth and Edge:   0%|          | 0/107 [00:00<?, ?it/s]

Processing defect types in trucks:   0%|          | 0/5 [00:00<?, ?it/s]

Processing images in trucks/Scene Oddities (lighting, reflections, composition):   0%|          | 0/96 [00:00<…

Processing images in trucks/Overprocessing:   0%|          | 0/5 [00:00<?, ?it/s]

Processing images in trucks/AI Glitches (digital artifacts, noise, processing errors):   0%|          | 0/46 […

Processing images in trucks/Reality Breaks (impossible structures, physics violations):   0%|          | 0/19 …

Processing images in trucks/Surface, Depth and Edge:   0%|          | 0/96 [00:00<?, ?it/s]

Processing defect types in aeroplane:   0%|          | 0/5 [00:00<?, ?it/s]

Processing images in aeroplane/Scene Oddities (lighting, reflections, composition):   0%|          | 0/252 [00…

Processing images in aeroplane/Overprocessing:   0%|          | 0/12 [00:00<?, ?it/s]

Processing images in aeroplane/AI Glitches (digital artifacts, noise, processing errors):   0%|          | 0/1…

Processing images in aeroplane/Reality Breaks (impossible structures, physics violations):   0%|          | 0/…

Processing images in aeroplane/Surface, Depth and Edge:   0%|          | 0/279 [00:00<?, ?it/s]

                                             image_path  \
0     /kaggle/input/resized-artifact-dataset/Manuall...   
1     /kaggle/input/resized-artifact-dataset/Manuall...   
2     /kaggle/input/resized-artifact-dataset/Manuall...   
3     /kaggle/input/resized-artifact-dataset/Manuall...   
4     /kaggle/input/resized-artifact-dataset/Manuall...   
...                                                 ...   
1710  /kaggle/input/resized-artifact-dataset/labeled...   
1711  /kaggle/input/resized-artifact-dataset/labeled...   
1712  /kaggle/input/resized-artifact-dataset/labeled...   
1713  /kaggle/input/resized-artifact-dataset/labeled...   
1714  /kaggle/input/resized-artifact-dataset/labeled...   

                                               classes  \
0                     [AI Defects, Biological defects]   
1                     [AI Defects, Biological defects]   
2                     [AI Defects, Biological defects]   
3     [AI Defects, Reality Breaks, Biological defects]   
4

In [10]:
# Define artifact group names
label_names = [
    "Overprocessing",
    "Reality Breaks",
    "AI Defects",
    "Scene Oddities",
    "Surface, Depth and Edge",
    "Biological defects",
]

def load_images_and_labels(df):
    """
    Load image paths and corresponding labels from a DataFrame.
    
    Args:
        df (pd.DataFrame): Input DataFrame containing 'image_path' and 'classes' columns.
        
    Returns:
        dict: A dictionary with 'image' paths and corresponding 'label' indices.
    """
    image_paths, labels = [], []
    for i in range(len(df)):
        row = df.iloc[i]
        image_paths.append(row["image_path"])
        
        # Convert class names to indices based on `label_names`
        class_indices = []
        for cls in row["classes"]:
            # Normalize class names for mapping
            if cls == "AI Glitches (digital artifacts, noise, processing errors)":
                cls = "AI Defects"
            elif cls == "Reality Breaks (impossible structures, physics violations)":
                cls = "Reality Breaks"
            elif cls == "Scene Oddities (lighting, reflections, composition)":
                cls = "Scene Oddities"
            
            class_indices.append(label_names.index(cls))
        
        labels.append(class_indices)
    
    return {"image": image_paths, "label": labels}

# Load datasets
train_data = load_images_and_labels(df)
test_data = load_images_and_labels(df[1500:])

# Define dataset features
features = Features({
    "image": Value("string"),  # Store image paths as strings
    "label": Sequence(feature=ClassLabel(names=label_names)),  # Support multiple labels
})

# Convert dictionaries to Hugging Face datasets
train_dataset = Dataset.from_dict(train_data).cast(features)
test_dataset = Dataset.from_dict(test_data).cast(features)

# Define system message and prompt for image classification
system_message = "You are a group classifier for images with artifacts."
prompt = (
    "You are provided with an image. Your task is to analyze the image and identify which artifact groups it belongs to. "
    f"An image can belong to one or more of the following groups: {', '.join(label_names)}. Below are descriptions of each group to guide your analysis: "
    "1. AI Defects: Floating parts, Noise on flat areas, Weird perspective, Blurred details, Ghosting/Repeats. \n"
    "2. Biological Defects: Misaligned features, Deformations, Fur errors, Unrealistic eyes, Asymmetry. \n"
    "3. Overprocessing: Grid artifacts, Cinematic look, Over-sharpening, Dramatic lighting, Scale issues. \n"
    "4. Reality Breaks: Non-manifold structures, Asymmetric shapes, Proportion errors, Impossible joints, Jagged edges. \n"
    "5. Scene Oddities: Metallic artifacts, Distorted reflections, Specular issues, Shadow inconsistencies, Glossy surfaces. \n"
    "6. Surface, Depth and Edge: Depth anomalies, Blurred edges, Aliasing, Texture bleeding, Fake depth, Synthetic look, Color breaks. \n"
)


Casting the dataset:   0%|          | 0/1715 [00:00<?, ? examples/s]

Casting the dataset:   0%|          | 0/215 [00:00<?, ? examples/s]

In [11]:
# Function to format data
def format_data(sample):
    return {
        "messages": [
            {
                "role": "system",
                "content": [{"type": "text", "text": system_message}],
            },
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": prompt,
                    },
                    {
                        "type": "image",
                        "image": sample["image"],  # Image path or image content
                        "resized_height": 32,
                        "resized_width": 32,
                    },
                ],
            },
            {
                "role": "assistant",
                "content": [
                    {
                        "type": "text",
                        "text": ', '.join([label_names[idx] for idx in sample["label"]]),  # Convert indices to class names
                    }
                ],
            },
        ],
    }


# Process datasets into OpenAI messages
formatted_train = [format_data(sample) for sample in train_dataset]
formatted_test = [format_data(sample) for sample in test_dataset]

# Example output
if formatted_train:
    print(formatted_train[0]["messages"])
else:
    print("Train dataset is empty.")


[{'role': 'system', 'content': [{'type': 'text', 'text': 'You are a group classifier for images with artifacts.'}]}, {'role': 'user', 'content': [{'type': 'text', 'text': 'You are provided with an image. Your task is to analyze the image and identify which artifact groups it belongs to. An image can belong to one or more of the following groups: Overprocessing, Reality Breaks, AI Defects, Scene Oddities, Surface, Depth and Edge, Biological defects. Below are descriptions of each group to guide your analysis: 1. AI Defects: Floating parts, Noise on flat areas, Weird perspective, Blurred details, Ghosting/Repeats. \n2. Biological Defects: Misaligned features, Deformations, Fur errors, Unrealistic eyes, Asymmetry. \n3. Overprocessing: Grid artifacts, Cinematic look, Over-sharpening, Dramatic lighting, Scale issues. \n4. Reality Breaks: Non-manifold structures, Asymmetric shapes, Proportion errors, Impossible joints, Jagged edges. \n5. Scene Oddities: Metallic artifacts, Distorted reflecti

In [12]:
len(formatted_train)

1715

In [13]:
 # Hugging Face model id
model_id = "Qwen/Qwen2-VL-7B-Instruct" 
 
# BitsAndBytesConfig int-4 config
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16
)
 
# Load model and tokenizer
model = AutoModelForVision2Seq.from_pretrained(
    model_id,
    device_map="auto",
    # attn_implementation="flash_attention_2", # not supported for training
    torch_dtype=torch.bfloat16,
    quantization_config=bnb_config
)
processor = AutoProcessor.from_pretrained(model_id)

config.json:   0%|          | 0.00/1.20k [00:00<?, ?B/s]

Unrecognized keys in `rope_scaling` for 'rope_type'='default': {'mrope_section'}


model.safetensors.index.json:   0%|          | 0.00/56.5k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/5 [00:00<?, ?it/s]

model-00001-of-00005.safetensors:   0%|          | 0.00/3.90G [00:00<?, ?B/s]

model-00002-of-00005.safetensors:   0%|          | 0.00/3.86G [00:00<?, ?B/s]

model-00003-of-00005.safetensors:   0%|          | 0.00/3.86G [00:00<?, ?B/s]

model-00004-of-00005.safetensors:   0%|          | 0.00/3.86G [00:00<?, ?B/s]

model-00005-of-00005.safetensors:   0%|          | 0.00/1.09G [00:00<?, ?B/s]

`Qwen2VLRotaryEmbedding` can now be fully parameterized by passing the model config through the `config` argument. All other arguments will be removed in v4.46


Loading checkpoint shards:   0%|          | 0/5 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/244 [00:00<?, ?B/s]

preprocessor_config.json:   0%|          | 0.00/347 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/4.19k [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/2.78M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/1.67M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/7.03M [00:00<?, ?B/s]

chat_template.json:   0%|          | 0.00/1.05k [00:00<?, ?B/s]

In [14]:
# Preparation for inference
text = processor.apply_chat_template(
    formatted_train[2]["messages"], tokenize=False, add_generation_prompt=False
)
print(text)

<|im_start|>system
You are a group classifier for images with artifacts.<|im_end|>
<|im_start|>user
You are provided with an image. Your task is to analyze the image and identify which artifact groups it belongs to. An image can belong to one or more of the following groups: Overprocessing, Reality Breaks, AI Defects, Scene Oddities, Surface, Depth and Edge, Biological defects. Below are descriptions of each group to guide your analysis: 1. AI Defects: Floating parts, Noise on flat areas, Weird perspective, Blurred details, Ghosting/Repeats. 
2. Biological Defects: Misaligned features, Deformations, Fur errors, Unrealistic eyes, Asymmetry. 
3. Overprocessing: Grid artifacts, Cinematic look, Over-sharpening, Dramatic lighting, Scale issues. 
4. Reality Breaks: Non-manifold structures, Asymmetric shapes, Proportion errors, Impossible joints, Jagged edges. 
5. Scene Oddities: Metallic artifacts, Distorted reflections, Specular issues, Shadow inconsistencies, Glossy surfaces. 
6. Surface, 

In [15]:
# LoRA config based on QLoRA paper & Sebastian Raschka experiment
peft_config = LoraConfig(
        lora_alpha=16,
        lora_dropout=0.05,
        r=8,
        bias="none",
        target_modules=["q_proj", "v_proj"],
        task_type="CAUSAL_LM", 
)

In [16]:
args = SFTConfig(
    output_dir="qwen2-7b-instruct-artifact", # directory to save and repository id
    num_train_epochs=2,                     # number of training epochs
    per_device_train_batch_size=4,          # batch size per device during training
    gradient_accumulation_steps=8,          # number of steps before performing a backward/update pass
    gradient_checkpointing=True,            # use gradient checkpointing to save memory
    optim="adamw_torch_fused",              # use fused adamw optimizer
    logging_steps=5,                       # log every 10 steps
    save_strategy="epoch",                  # save checkpoint every epoch
    learning_rate=2e-4,                     # learning rate, based on QLoRA paper
    bf16=True,                              # use bfloat16 precision
    tf32=False,                              # use tf32 precision
    max_grad_norm=0.3,                      # max gradient norm based on QLoRA paper
    warmup_ratio=0.03,                      # warmup ratio based on QLoRA paper
    lr_scheduler_type="constant",           # use constant learning rate scheduler
    push_to_hub=True,                       # push model to hub
    report_to="tensorboard",                # report metrics to tensorboard
    gradient_checkpointing_kwargs = {"use_reentrant": False}, # use reentrant checkpointing
    dataset_text_field="", # need a dummy field for collator
    dataset_kwargs = {"skip_prepare_dataset": True} # important for collator
)
args.remove_unused_columns=False
 
# Create a data collator to encode text and image pairs
def collate_fn(examples):
    # Get the texts and images, and apply the chat template
    texts = [processor.apply_chat_template(example["messages"], tokenize=False) for example in examples]
    image_inputs = [process_vision_info(example["messages"])[0] for example in examples]
     
    # Tokenize the texts and process the images
    batch = processor(text=texts, images=image_inputs, return_tensors="pt", padding=True)
 
    # The labels are the input_ids, and we mask the padding tokens in the loss computation
    labels = batch["input_ids"].clone()
    labels[labels == processor.tokenizer.pad_token_id] = -100  #
    # Ignore the image token index in the loss computation (model specific)
    if isinstance(processor, Qwen2VLProcessor):
        image_tokens = [151652,151653,151655]
    else: 
        image_tokens = [processor.tokenizer.convert_tokens_to_ids(processor.image_token)]
    for image_token_id in image_tokens:
        labels[labels == image_token_id] = -100
    batch["labels"] = labels
 
    return batch

In [17]:
trainer = SFTTrainer(
    model=model,
    args=args,
    train_dataset=formatted_train,
    data_collator=collate_fn,
    dataset_text_field="", # needs dummy value
    peft_config=peft_config,
    tokenizer=processor.tokenizer,
)


Deprecated positional argument(s) used in SFTTrainer, please use the SFTConfig to set these arguments instead.


In [12]:
# start training, the model will be automatically saved to the hub and the output directory
trainer.train()
 
# save model 
trainer.save_model(args.output_dir)

`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...
  with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context:  # type: ignore[attr-defined]


Step,Training Loss
5,2.7129
10,2.4229
15,2.055
20,1.6208
25,1.0651
30,0.5104
35,0.158
40,0.0737
45,0.0653
50,0.0614


Unrecognized keys in `rope_scaling` for 'rope_type'='default': {'mrope_section'}
  with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context:  # type: ignore[attr-defined]
Unrecognized keys in `rope_scaling` for 'rope_type'='default': {'mrope_section'}
Unrecognized keys in `rope_scaling` for 'rope_type'='default': {'mrope_section'}
Unrecognized keys in `rope_scaling` for 'rope_type'='default': {'mrope_section'}
Unrecognized keys in `rope_scaling` for 'rope_type'='default': {'mrope_section'}


events.out.tfevents.1733351630.4240fec3afc5.23.0:   0%|          | 0.00/11.0k [00:00<?, ?B/s]