In [1]:
import functools
import os
import re
import textwrap
import warnings

import pandas as pd
from datasets import load_dataset
from PIL import Image, ImageDraw, ImageFont

  from .autonotebook import tqdm as notebook_tqdm


In [12]:
# --- Corpus sampling -------------------------------------------------------
MAX_PER_CLASS = 5_000          # texts kept per class ("ai" and "human")
MAX_CHAR = 3_500               # texts longer than this many characters are skipped
OUTPUT_DIR = "triplet_output"  # destination folder for the rendered PNGs

# --- Rendering (all values in pixels) --------------------------------------
IMAGE_SIZE = (1024, 1024)      # (width, height) of every rendered image
FONT_SIZE = 14
MARGIN = 15                    # left/top offset and horizontal padding on each side
LINE_SPACING = 9               # extra gap between wrapped lines
# NOTE(review): by rough estimate only ~40 wrapped lines fit in the image height at
# these settings, so texts close to MAX_CHAR may be clipped at the bottom edge —
# verify on a few of the longest samples.

In [13]:
# One font per rendered variant. render_text_to_image falls back to Pillow's
# default font when a path cannot be loaded, so every path must exist locally.
# NOTE(review): on recent macOS releases "Times New Roman.ttf" is usually found under
# /System/Library/Fonts/Supplemental/ — confirm DOC_FONT exists on this machine.
# NOTE(review): LATEX_FONT is Times (close to Times New Roman), not Computer Modern,
# LaTeX's default face — confirm this is the intended "latex" look.
BROWSER_FONT = "/System/Library/Fonts/Supplemental/Arial.ttf"
LATEX_FONT = "/System/Library/Fonts/Supplemental/Times.ttc"
DOC_FONT = "/System/Library/Fonts/Times New Roman.ttf"

In [14]:
@functools.lru_cache(maxsize=None)
def _load_font(font_path, size):
    """Return a Pillow font for ``font_path`` at ``size`` px, loading each pair only once.

    Falls back to ``ImageFont.load_default()`` when the font file cannot be opened,
    and emits a warning so a wrong path does not silently change the rendering font.
    """
    try:
        return ImageFont.truetype(font_path, size)
    except OSError as exc:
        warnings.warn(
            f"Could not load font {font_path!r} ({exc}); falling back to Pillow's default font."
        )
        return ImageFont.load_default()


def render_text_to_image(text, font_path, output_path):
    """Render ``text`` as black, word-wrapped text on a white image and save it.

    Parameters
    ----------
    text : str
        Text to draw. It is wrapped with ``textwrap.fill``, which by default turns
        newlines into spaces, so paragraph breaks are not preserved.
    font_path : str
        Path to a TrueType/OpenType font. If it cannot be opened, Pillow's default
        font is used and a warning is emitted.
    output_path : str
        Destination file; Pillow picks the format from the file extension.

    Text that runs past the bottom of ``IMAGE_SIZE`` is clipped; no check is made.
    """
    font = _load_font(font_path, FONT_SIZE)

    img = Image.new("RGB", IMAGE_SIZE, "white")
    draw = ImageDraw.Draw(img)

    max_width = IMAGE_SIZE[0] - 2 * MARGIN
    # The advance width of "A" is used as the per-character width estimate; in
    # proportional fonts it is typically wider than most lowercase glyphs, so
    # lines usually end somewhat short of max_width.
    avg_char_width = font.getlength("A")
    # textwrap rejects width <= 0 with ValueError, so keep at least one char per line.
    max_chars_per_line = max(1, int(max_width / avg_char_width))

    wrapped_text = textwrap.fill(text, width=max_chars_per_line)

    draw.multiline_text(
        (MARGIN, MARGIN),
        wrapped_text,
        fill="black",
        font=font,
        spacing=LINE_SPACING,
    )

    img.save(output_path)

In [15]:
def select_balanced_rows(rows, max_per_class, max_char):
    """Yield ``(text_id, source, text)`` for a class-balanced subset of ``rows``.

    Each row is a mapping with ``"id"``, ``"text"`` and ``"source"`` keys. ``source``
    is lower-cased before comparison; only ``"ai"`` and ``"human"`` rows are kept.
    Rows whose text is longer than ``max_char`` characters are skipped, and at most
    ``max_per_class`` rows are yielded per class, in input order. Iteration stops
    as soon as both classes are full, without reading any further rows.
    """
    counts = {"ai": 0, "human": 0}
    if max_per_class <= 0:
        return
    for row in rows:
        text = row["text"]
        source = row["source"].lower()
        if source not in counts or counts[source] >= max_per_class:
            continue
        if len(text) > max_char:
            continue
        counts[source] += 1
        yield row["id"], source, text
        if min(counts.values()) >= max_per_class:
            return


def main(max_per_class=None, max_char=None, output_dir=None):
    """Render a balanced AI/human subset of ai-text-detection-pile as images.

    Every selected text is drawn once per font variant ("doc", "latex", "browser")
    into ``<output_dir>/<text_id>_<variant>.png``. Each argument left as None falls
    back to the module constant MAX_PER_CLASS, MAX_CHAR or OUTPUT_DIR, read at call time.

    NOTE(review): this function does not write the ``metadata.csv`` read by later
    cells — confirm which script produces it.

    Returns
    -------
    dict
        Number of rendered texts per class, keyed by ``"ai"`` and ``"human"``.
    """
    max_per_class = MAX_PER_CLASS if max_per_class is None else max_per_class
    max_char = MAX_CHAR if max_char is None else max_char
    output_dir = OUTPUT_DIR if output_dir is None else output_dir

    # Insertion order fixes the rendering order: doc, latex, browser.
    variant_fonts = {"doc": DOC_FONT, "latex": LATEX_FONT, "browser": BROWSER_FONT}

    ds = load_dataset("artem9k/ai-text-detection-pile", split="train")
    os.makedirs(output_dir, exist_ok=True)

    counts = {"ai": 0, "human": 0}
    for text_id, source, text in select_balanced_rows(ds, max_per_class, max_char):
        for variant, font_path in variant_fonts.items():
            render_text_to_image(text, font_path, os.path.join(output_dir, f"{text_id}_{variant}.png"))
        counts[source] += 1

    n_variants = len(variant_fonts)
    print("======================================")
    print(f"AI samples: {counts['ai']} → {counts['ai'] * n_variants} images")
    print(f"Human samples: {counts['human']} → {counts['human'] * n_variants} images")
    print(f"Total images: {(counts['ai'] + counts['human']) * n_variants}")
    print("======================================")
    return counts


if __name__ == "__main__":
    main()

AI samples: 5000 → 15000 images
Human samples: 5000 → 15000 images
Total images: 30000


In [2]:
# NOTE(review): metadata.csv is not written by any cell shown here — confirm which
# script produces it (later cells rely on its text_id, variant and label columns).
df = pd.read_csv("metadata.csv")
print(df["label"].value_counts())

label
0    15000
1    15000
Name: count, dtype: int64


In [3]:
# Experiment 1 keeps only the "doc" rendering of each text.
df_doc = df.loc[df["variant"].eq("doc")].copy()
df_doc.to_csv("metadata_exp1.csv", index=False)

In [4]:
print(f"total images of doc: {len(df_doc)}")
print(df_doc["label"].value_counts())

total images of doc: 10000
label
0    5000
1    5000
Name: count, dtype: int64


In [7]:
# %pip installs into the kernel's own environment; pinned to the version this notebook ran with.
%pip install -q scikit-learn==1.8.0

Collecting scikit-learn
  Downloading scikit_learn-1.8.0-cp312-cp312-macosx_12_0_arm64.whl.metadata (11 kB)
Collecting scipy>=1.10.0 (from scikit-learn)
  Downloading scipy-1.16.3-cp312-cp312-macosx_14_0_arm64.whl.metadata (62 kB)
Collecting joblib>=1.3.0 (from scikit-learn)
  Downloading joblib-1.5.3-py3-none-any.whl.metadata (5.5 kB)
Collecting threadpoolctl>=3.2.0 (from scikit-learn)
  Using cached threadpoolctl-3.6.0-py3-none-any.whl.metadata (13 kB)
Downloading scikit_learn-1.8.0-cp312-cp312-macosx_12_0_arm64.whl (8.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m8.1/8.1 MB[0m [31m19.0 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hDownloading joblib-1.5.3-py3-none-any.whl (309 kB)
Downloading scipy-1.16.3-cp312-cp312-macosx_14_0_arm64.whl (20.9 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m20.9/20.9 MB[0m [31m26.2 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hUsing cached threadpoolctl-3.6.0-py3-none-any.whl (18 kB)
Insta

In [9]:
from sklearn.model_selection import train_test_split

df = pd.read_csv('metadata_exp1.csv')

# Split on text ids (not rows) so every image of a given text lands in the same split.
assert df.groupby('text_id')['label'].nunique().eq(1).all(), "a text_id carries more than one label"
id_labels = df.drop_duplicates('text_id').set_index('text_id')['label']
unique_ids = id_labels.index.to_numpy()

# 70 / 15 / 15 split, stratified by label so train, val and test all keep the class balance.
train_ids, temp_ids = train_test_split(
    unique_ids, test_size=0.3, random_state=42, shuffle=True, stratify=id_labels.to_numpy()
)
val_ids, test_ids = train_test_split(
    temp_ids, test_size=0.5, random_state=42, stratify=id_labels.loc[temp_ids].to_numpy()
)

df['split'] = "train"
df.loc[df['text_id'].isin(val_ids), "split"] = "val"
df.loc[df['text_id'].isin(test_ids), "split"] = "test"

In [10]:
# Rows per split; with one "doc" image per text this equals texts per split.
split_sizes = df["split"].value_counts()
print(split_sizes)

split
train    7000
test     1500
val      1500
Name: count, dtype: int64


In [11]:
# Persist the split assignment; the dataloader cell reads this file as csv_file.
df.to_csv(path_or_buf="metadata_exp1_split.csv", index=False)

In [12]:
# %pip installs into the kernel's own environment; pinned to the version this notebook ran with.
%pip install -q torch==2.9.1

Collecting torch
  Using cached torch-2.9.1-cp312-none-macosx_11_0_arm64.whl.metadata (30 kB)
Collecting setuptools (from torch)
  Using cached setuptools-80.9.0-py3-none-any.whl.metadata (6.6 kB)
Collecting sympy>=1.13.3 (from torch)
  Using cached sympy-1.14.0-py3-none-any.whl.metadata (12 kB)
Collecting networkx>=2.5.1 (from torch)
  Using cached networkx-3.6.1-py3-none-any.whl.metadata (6.8 kB)
Collecting jinja2 (from torch)
  Using cached jinja2-3.1.6-py3-none-any.whl.metadata (2.9 kB)
Collecting mpmath<1.4,>=1.1.0 (from sympy>=1.13.3->torch)
  Using cached mpmath-1.3.0-py3-none-any.whl.metadata (8.6 kB)
Collecting MarkupSafe>=2.0 (from jinja2->torch)
  Downloading markupsafe-3.0.3-cp312-cp312-macosx_11_0_arm64.whl.metadata (2.7 kB)
Using cached torch-2.9.1-cp312-none-macosx_11_0_arm64.whl (74.5 MB)
Using cached networkx-3.6.1-py3-none-any.whl (2.1 MB)
Using cached sympy-1.14.0-py3-none-any.whl (6.3 MB)
Using cached mpmath-1.3.0-py3-none-any.whl (536 kB)
Using cached jinja2-3.1.6-

In [13]:
import torch
from torch.utils.data import Dataset

In [15]:
from text_image_dataset import TextImageDataset

In [2]:
import torch

from data import get_dataloaders

# Seed torch's global RNG before building the loaders so batch order is repeatable
# across runs (same seed as the train/val/test split).
# NOTE(review): assumes get_dataloaders shuffles via torch's default generator — confirm in data.py.
torch.manual_seed(42)

train_loader, val_loader, test_loader = get_dataloaders(
    csv_file="metadata_exp1_split.csv",
    image_dir="triplet_output",
    batch_size=32,
)