Collect some manually edited captions to help zero-shot prompt and fine-tune LLMs for cleaning up the captions.

In [1]:
import psycopg
from pathlib import Path
import random
import ipywidgets as widgets
import yaml
from IPython.display import display, clear_output
import re
from collections import defaultdict
import openai
import json
#from mistralai import Mistral
import os
from dotenv import load_dotenv
#import mistralai
import time
import difflib
import textwrap
import requests
import copy

In [None]:
# Load variables from a .env file into the process environment;
# ask_openrouter() later reads OPENROUTER_API_KEY via os.getenv().
load_dotenv()

In [4]:
#chat_tokenizer = get_chat_template(
#	tokenizer,
#	chat_template = "llama-3.1",
#)

CHAT_PROMPT = """Please edit the user's provided image descriptions following the guidelines below:
1. The edits should be minimal and not affect the details or accuracy of the description.
2. Remove any mention of the image's resolution, but don't remove information about the image's quality.
3. Edit out any self-referential language. For example: "this is a digital painting" -> "a digital painting", "In this photo a woman stands" -> "photo of a woman standing", etc.
4. Randomly swap in informal synonyms for things like "penis", "vulva", etc.
5. Do not modify anything in quotes that are describing text in the image.
6. Randomly swap the word "photograph" to "photo".
7. Remove any duplicates from the description if the description repeats itself.
8. When you make edits, make sure to maintain the original meaning of the sentence, and minimize the number of changes.
9. Only update the grammer if necessary. Do NOT fix any grammar mistakes or oddness that were in the original description. Some of them may be MidJourney prompts or lists of tags.

Respond with only the edited image description.
"""

ADD_PROMPT = """Please edit the user's provided image descriptions following the guidelines below:
1. If the description does not mention that there is a watermark in the image, add a mention of a watermark.
2. If the description already mentions a watermark, no changes are needed. Leave the description as is.
3. If the description says there is no watermark, fix the description to mention a watermark.
4. Fix any conflicting information about the watermark in the description.
5. When you make edits, make sure to maintain the original meaning of the sentence, and minimize the number of changes.
6. Only update the grammer if necessary. Do NOT fix any grammar mistakes or oddness that were in the original description. Some of them may be MidJourney prompts or lists of tags.

Respond with only the edited image description, or the original description if no changes were needed.
"""

REMOVE_PROMPT = """Please edit the user's provided image descriptions following the guidelines below:
1. If the description mentions that there is a watermark in the image, remove the mention of the watermark.
2. If the description does not mention a watermark, no changes are needed. Leave the description as is.
3. If the description says there is no watermark, no changes are needed. Leave the description as is.
4. Fix any conflicting information about the lack of a watermark in the description.
5. When you make edits, make sure to maintain the original meaning of the sentence, and minimize the number of changes.
6. Only update the grammer if necessary. Do NOT fix any grammar mistakes or oddness that were in the original description. Some of them may be MidJourney prompts or lists of tags.

Respond with only the edited image description, or the original description if no changes were needed.
"""

SOURCE_PROMPT = """Please edit the user's provided image descriptions following these guidelines:
1. In some way, add to the description so that it mentions that the image is from "{source}".
2. If the source starts with "r/" then randomly also add a mention of "reddit" to the description.
3. Your edits should be minimal and not affect the details or accuracy of the description in any way.
4. Only update the grammer if necessary. Do NOT fix any grammar mistakes or oddness that were in the original description. Some of them may be MidJourney prompts or lists of tags.
5. There are a variety of ways that you can add this information to the description. Be creative and make sure to maintain the original meaning of the description.
6. Randomly change the capitalization of the source name in the description.

Respond with only the edited image description."""


def format_messages(operator: str, caption: str, extra: str | None, extra_prompt: str | None) -> list[dict]:
	if operator == "add-watermark":
		system_message = ADD_PROMPT
	elif operator == "remove-watermark":
		system_message = REMOVE_PROMPT
	elif operator == "add-source":
		system_message = SOURCE_PROMPT
	else:
		raise ValueError(f"Unknown operator: {operator}")
	
	if operator == "add-source":
		assert extra is not None
		system_message = system_message.format(source=extra)
	
	if extra_prompt is not None:
		system_message = extra_prompt + " " + system_message
	
	return [
		{"role": "system", "content": system_message.strip()},
		{"role": "user", "content": caption.strip()},
	]


def ask_our_model(operator: str, caption: str, extra: str | None):
	"""Request an edited caption from the OpenAI-compatible server at localhost:8000.

	Args:
		operator: "add-watermark", "remove-watermark" or "add-source".
		caption: The caption to edit.
		extra: Passed through to format_messages() (the source name for
			"add-source").

	Returns:
		The message content of the first completion choice.

	Raises:
		ValueError: If ``operator`` has no associated model name.
	"""
	local_client = openai.OpenAI(api_key="EMPTY", base_url="http://localhost:8000/v1")

	# Both watermark operators are sent to the same model name.
	model_by_operator = {
		"add-watermark": "watermark",
		"remove-watermark": "watermark",
		"add-source": "add-source",
	}
	if operator not in model_by_operator:
		raise ValueError(f"Unknown operator: {operator}")

	completion = local_client.chat.completions.create(
		model=model_by_operator[operator],
		messages=format_messages(operator, caption, extra, None),
		temperature=0.6,
		top_p=0.9,
		max_tokens=512,
	)

	first_choice = completion.choices[0]
	return first_choice.message.content


In [5]:
# OpenRouter model ids mapped to the sampling parameters (temperature, top_p)
# used when that model is picked. ask_openrouter() chooses one key at random
# on every call.
OPENROUTER_MODELS = {
	"openai/gpt-4o-2024-08-06": {"temperature": 1.0, "top_p": 1.0 },
	#"openai/o1-preview": {"temperature": 1.0, "top_p": 1.0 },
	"sao10k/l3.1-euryale-70b": {"temperature": 0.6, "top_p": 0.9 },
	"meta-llama/llama-3.1-70b-instruct:free": {"temperature": 0.6, "top_p": 0.9 },
	"meta-llama/llama-3.1-405b-instruct:free": {"temperature": 0.6, "top_p": 0.9 },
	"nousresearch/hermes-3-llama-3.1-405b:free": {"temperature": 0.6, "top_p": 0.9 },
}

def ask_openrouter(operator: str, caption: str, extra: str | None) -> str:
	"""Request an edited caption from a randomly chosen OpenRouter model.

	The model is drawn uniformly from OPENROUTER_MODELS and queried with that
	entry's temperature and top_p. The API key is read from the
	OPENROUTER_API_KEY environment variable.

	Args:
		operator: Operator name forwarded to format_messages().
		caption: The caption to edit.
		extra: Forwarded to format_messages().

	Returns:
		The model's reply. If the request or the reply check fails, the
		exception is printed and a fixed apology string is returned instead,
		so callers cannot tell a failure from a reply by type alone.
	"""
	model_name = random.choice(list(OPENROUTER_MODELS))
	sampling = OPENROUTER_MODELS[model_name]

	print(f"Using model {model_name}")

	client = openai.OpenAI(api_key=os.getenv("OPENROUTER_API_KEY"), base_url="https://openrouter.ai/api/v1")

	try:
		completion = client.chat.completions.create(
			model=model_name,
			messages=format_messages(operator, caption, extra, None),
			temperature=sampling["temperature"],
			top_p=sampling["top_p"],
			max_tokens=512,
		)
		answer = completion.choices[0].message.content
		# A non-string reply is handled like any other failure below.
		assert isinstance(answer, str)
	except Exception as e:
		print(e)
		return "Sorry, I'm having trouble connecting to the model. Please try again later."
	else:
		return answer

In [6]:
# Notebook-wide database connection. `host` is a filesystem path
# (<parent of the current working directory>/pg-socket); libpq treats a host
# that starts with a slash as the directory of a Unix-domain socket.
conn = psycopg.connect(dbname='postgres', user='postgres', host=str(Path.cwd().parent / "pg-socket"))

In [7]:
# The editing task this notebook session works on. Exactly one line should be
# uncommented; the value selects which rows of recaption_dataset are loaded
# and how the CaptionEditor is configured.
#OPERATOR = 'remove_bullshit'
#OPERATOR = 'add-watermark'
#OPERATOR = 'remove-watermark'
OPERATOR = 'add-source'

In [8]:
# Words that CaptionEditor.show_caption() randomly lists in the
# "Selective editing" label when no sources are configured.
SELECTIVE_EDITING = set(["features","depicts"])
# Loaded from synonyms.yaml; modify_caption() iterates it as
# {term: [replacement, ...]}.
SYNONYMS = yaml.safe_load(Path("synonyms.yaml").read_text())
# Source names that are combined with the subreddit names for 'add-source'.
SOURCES = {'fansly', 'flickr', 'onlyfans', 'unsplash'}

# Collect every distinct non-NULL subreddit value from the images table.
with conn.cursor() as cur:
	cur.execute("SELECT DISTINCT subreddit FROM images WHERE subreddit IS NOT NULL")
	SUBREDDITS = set(row[0] for row in cur.fetchall())

In [9]:
# Punctuation characters used to split a caption into fragments.
PUNC = re.compile(r'[,.!?:;]')


def is_repeating(caption: str) -> bool:
	"""Return True when some fragment of ``caption`` occurs more than twice.

	The caption is split on PUNC, and fragments are compared after stripping
	whitespace and lowercasing.

	Empty fragments are ignored. They come from adjacent or trailing
	punctuation (for example "wait... what.") and would otherwise be counted
	as a repeated fragment even though no text repeats.

	Args:
		caption: The caption text to inspect.

	Returns:
		True if any non-empty fragment appears at least three times.
	"""
	counts = defaultdict(int)
	for piece in PUNC.split(caption):
		fragment = piece.strip().lower()
		if fragment:
			counts[fragment] += 1

	return any(count > 2 for count in counts.values())


def check_words_in_quotes(input_string: str, words_list: list[str]) -> bool:
	"""Report whether any of ``words_list`` occurs inside a quoted span.

	A quoted span is text between two quote characters (``"`` or ``'``) with
	no quote character in between. Matching is case-insensitive and
	substring-based: "dog" also matches "dogs".

	An opening quote directly followed by a standalone "s" is skipped so that
	possessives ("woman's dog and man's hat") are not mistaken for quotes.
	This is done with a lookahead, which consumes no text, so a quote that
	begins with the searched word, or with any word starting with "s", is
	still detected.

	Args:
		input_string: Text to search.
		words_list: Words to look for; empty strings are ignored.

	Returns:
		True if a word is found inside quotes. False otherwise, including when
		``words_list`` contains no usable words.
	"""
	# The input is lowercased below, so the words must be lowercased as well.
	words = [word.lower() for word in words_list if word]
	if not words:
		# An empty alternation would match the empty string and flag every
		# quoted span.
		return False

	words_pattern = '|'.join(map(re.escape, words))

	# quote, not a possessive 's, non-quote text, a word, non-quote text, quote
	pattern = rf'["\'](?!s\b)[^"\']*(?:{words_pattern})[^"\']*["\']'

	return re.search(pattern, input_string.lower()) is not None


# Only referenced by the commented-out filters below.
keywords = ['image']
syns = list(SYNONYMS.keys())

# Passing a name to conn.cursor() requests a server-side (named) cursor in
# psycopg, so rows are fetched from the server while iterating.
with conn.cursor('caption-editing') as cur:
	#cur.execute("SELECT caption FROM images WHERE caption IS NOT NULL AND caption LIKE '%this%'")
	#cur.execute("SELECT caption FROM images WHERE caption IS NOT NULL")
	cur.execute("SELECT caption_3 FROM images WHERE caption_3 IS NOT NULL")
	#cur.execute("SELECT caption_2 FROM images WHERE caption_2 IS NOT NULL")
	#cur.execute("SELECT caption FROM images WHERE caption IS NOT NULL AND (caption LIKE '%feature%' OR caption LIKE '%depict%' OR caption LIKE '%likely%' OR caption LIKE '%possibly%' OR caption LIKE '%probably%' OR caption LIKE '%resolution%' OR caption LIKE '%appear%' OR caption LIKE '%seem%')")
	# Pool of candidate captions offered to the annotator.
	g_captions = [row[0] for row in cur]
	#captions = [caption for caption in captions if any(keyword in caption.lower() for keyword in keywords)]
	#captions = [caption for caption in captions if is_repeating(caption)]
	#captions = [caption for caption in captions if check_words_in_quotes(caption, syns)]
	#g_captions = [caption for caption in g_captions if 'watermark' in caption.lower()]
	#g_captions = [caption for caption in g_captions if caption.lower().count('watermark') > 1]

	# Edits already stored for the active OPERATOR, keyed by original caption.
	# CaptionEditor uses this to skip captions that are already done.
	cur.execute("SELECT caption, recaptioned FROM recaption_dataset WHERE operator = %s", (OPERATOR,))
	recaptions = {row[0]: row[1] for row in cur}

In [10]:
def modify_caption(caption: str) -> str:
	"""Randomly replace SYNONYMS terms in ``caption`` with one of their alternatives.

	The caption is split on single spaces. Each piece is left untouched with
	probability 0.5; otherwise the first SYNONYMS key contained in the
	lowercased piece is replaced by a random alternative. The alternative is
	upper-cased or capitalised to follow the piece's casing. Every replacement
	is printed.

	Args:
		caption: The caption to rewrite.

	Returns:
		The pieces re-joined with single spaces, with outer whitespace stripped.
	"""
	rewritten = []

	for token in caption.split(" "):
		# Coin flip: half of the pieces are kept verbatim.
		if random.random() < 0.5:
			rewritten.append(token)
			continue

		lowered = token.lower()
		hit = next(((key, options) for key, options in SYNONYMS.items() if key in lowered), None)
		if hit is None:
			rewritten.append(token)
			continue

		key, options = hit
		replacement = random.choice(options)

		# Follow the casing of the piece being replaced.
		stripped = token.strip()
		if stripped.isupper():
			replacement = replacement.upper()
		elif stripped[0].isupper():
			replacement = replacement.capitalize()

		print(f"Replacing {key} with {replacement} in {token}")

		# The substitution is applied to the lowercased piece.
		rewritten.append(lowered.replace(key, replacement))

	return " ".join(rewritten).strip()

In [None]:
class CaptionEditor:
	"""Notebook UI for hand-editing captions and storing the edits.

	Constructing an instance builds the widgets, displays them and loads a
	first caption. Submitted edits are recorded in the module-level
	``recaptions`` dict and inserted into the ``recaption_dataset`` table
	through the module-level ``conn``, tagged with the module-level ``OPERATOR``.
	"""

	def __init__(self, auto_modify: bool, show_repeats: bool, keywords_to_highlight: set[str], invert_button: bool, sources: set[str]):
		"""Build and display the editor, then show the first caption.

		Args:
			auto_modify: Pre-fill the edit box with modify_caption(caption)
				instead of the unmodified caption.
			show_repeats: Add the read-only area that lists repeated fragments.
			keywords_to_highlight: Lowercase tokens highlighted in yellow in
				the unchanged parts of the diff.
			invert_button: Add the "Submit & Invert" button.
			sources: Candidate source names. When non-empty, one is drawn per
				caption and shown in the "Selective editing" label.
		"""
		self.auto_modify = auto_modify
		self.show_repeats = show_repeats
		self.keywords_to_highlight = keywords_to_highlight
		self.sources = sources

		self.counter_label = widgets.Label(value="0 / 0")
		self.current_operator = widgets.Label(value=f"Operator: {OPERATOR}")
		self.sym_label = widgets.Label(value="")
		self.original_caption = widgets.Textarea(layout=widgets.Layout(width='100%', height='100px'), disabled=True)
		self.text_area = widgets.Textarea(layout=widgets.Layout(width='100%', height='100px'))
		self.highlighted_diff = widgets.HTML()
		self.repeat_area = widgets.Textarea(layout=widgets.Layout(width='100%', height='100px'), disabled=True)
		self.submit_button = widgets.Button(description='Submit')
		self.submit_button.on_click(self.on_submit)
		self.skip_button = widgets.Button(description='Skip')
		self.skip_button.on_click(lambda b: self.next_caption())
		self.chatgpt_button = widgets.Button(description='Ask ChatGPT')
		self.chatgpt_button.on_click(lambda b: self.ask_chatgpt())

		if invert_button:
			self.invert_button = widgets.Button(description='Submit & Invert')
			self.invert_button.on_click(self.on_submit_and_invert)

		self.submit_and_ask_button = widgets.Button(description='Submit & ChatGPT')
		self.submit_and_ask_button.on_click(self.on_submit_and_ask)

		ui = [
			self.counter_label,
			self.current_operator,
			self.sym_label,
			self.original_caption,
			self.text_area,
			self.highlighted_diff,
		]

		if self.show_repeats:
			ui.append(self.repeat_area)

		ui.extend([self.submit_button, self.submit_and_ask_button, self.skip_button, self.chatgpt_button])

		if invert_button:
			ui.append(self.invert_button)

		display(widgets.VBox(ui))

		# Refresh the diff (and the repeats list) whenever the edit box changes.
		self.text_area.observe(self.update_highlight, names='value')

		self.next_caption()

	def ask_chatgpt(self):
		"""Replace the edit box content with the model's edit of the current caption.

		The text after ": " in the "Selective editing" label is passed to the
		model as ``extra``.
		"""
		self.text_area.value = "Asking ChatGPT..."
		#edited_caption = ask_openrouter(OPERATOR, self.current_caption, self.sym_label.value.split(": ")[1].strip())
		edited_caption = ask_our_model(OPERATOR, self.current_caption, self.sym_label.value.split(": ")[1].strip())
		self.text_area.value = edited_caption

	def show_caption(self, caption: str):
		"""Load ``caption`` into the widgets and refresh the labels."""
		# Automatic synonym replacement
		if self.auto_modify:
			modified_caption = modify_caption(caption)
		else:
			modified_caption = caption

		# Selective editing
		selective_editing = []

		if len(self.sources) > 0:
			# Downweight subreddits
			sources = list(self.sources)
			weights = [0.2 if "r/" in source else 1.0 for source in sources]
			source = random.choices(sources, weights=weights)[0]
			selective_editing.append(source)
		else:
			# Each hint word is included independently with probability 0.5.
			for key in SELECTIVE_EDITING:
				if random.random() < 0.5:
					selective_editing.append(key)

		self.original_caption.value = caption
		self.text_area.value = modified_caption
		self.sym_label.value = f"Selective editing: {', '.join(selective_editing)}"
		self.counter_label.value = f"Completed / Total: {len(recaptions)} / {len(g_captions)}"

	def next_caption(self):
		"""Pick a random caption that has no stored edit yet and show it.

		Prints a message and leaves the UI unchanged when none remain.
		"""
		remaining = [caption for caption in g_captions if caption not in recaptions]
		if not remaining:
			print("No more captions to edit!")
			return

		self.current_caption = random.choice(remaining)

		self.show_caption(self.current_caption)

	def on_submit(self, b):
		"""Store the current edit, then advance to the next caption.

		For the 'add-source' operator the source shown in the "Selective
		editing" label is stored in the ``extra_1`` column.
		"""
		recaptions[self.current_caption] = self.text_area.value
		extra = None

		if OPERATOR == 'add-source':
			extra = self.sym_label.value.split(": ")[1].strip()

		with conn.cursor() as cur:
			cur.execute("INSERT INTO recaption_dataset (caption, recaptioned, operator, extra_1) VALUES (%s, %s, %s, %s)", (self.current_caption, self.text_area.value, OPERATOR, extra))
			conn.commit()

		self.next_caption()

	def on_submit_and_ask(self, b):
		"""Submit the current edit, then ask the model about the newly shown caption."""
		self.on_submit(b)
		self.ask_chatgpt()

	def on_submit_and_invert(self, b):
		"""Store the edit plus its mirror image under the inverse operator.

		Only 'remove-watermark' can be inverted (to 'add-watermark'). The
		mirrored row uses the edited text as the caption and the original
		caption as the rewrite.

		Raises:
			ValueError: If the active OPERATOR has no known inverse.
		"""
		recaptions[self.current_caption] = self.text_area.value

		if OPERATOR == 'remove-watermark':
			inverted_operator = 'add-watermark'
		else:
			raise ValueError(f"Don't know how to invert operator {OPERATOR}")

		with conn.cursor() as cur:
			cur.execute("INSERT INTO recaption_dataset (caption, recaptioned, operator) VALUES (%s, %s, %s)", (self.current_caption, self.text_area.value, OPERATOR))
			cur.execute("INSERT INTO recaption_dataset (caption, recaptioned, operator) VALUES (%s, %s, %s)", (self.text_area.value, self.current_caption, inverted_operator))
			conn.commit()

		self.next_caption()

	def update_highlight(self, change):
		"""Observer callback: refresh the HTML diff and the repeats list."""
		original = self.current_caption
		modified = self.text_area.value
		highlighted_html = self.highlight_changes_html(original, modified)
		self.highlighted_diff.value = highlighted_html

		repeats = get_repeats(modified)
		self.repeat_area.value = "\n".join(repeats)

	def highlight_changes_html(self, original: str, modified: str) -> str:
		"""Render a token-level diff of ``original`` and ``modified`` as HTML.

		Deleted tokens are red and struck through, added tokens are green and
		bold, and unchanged tokens found in ``keywords_to_highlight`` get a
		yellow background. Token text is HTML-escaped so that captions
		containing ``<``, ``>`` or ``&`` show up literally instead of being
		interpreted as markup by the HTML widget.

		Args:
			original: The caption before editing.
			modified: The caption after editing.

		Returns:
			The HTML fragment, with no wrapping element.
		"""
		# Local import: the file's top-level import list does not include html.
		import html

		original_tokens = self.tokenize(original)
		modified_tokens = self.tokenize(modified)

		diff = difflib.ndiff(original_tokens, modified_tokens)
		result = []

		for word in diff:
			# Each ndiff entry is a two-character code followed by the token.
			# Entries starting with "? " are hint lines and are skipped.
			raw_token = word[2:]
			token = html.escape(raw_token)
			if word.startswith('  '):  # no change
				# Keyword lookup uses the unescaped text.
				if raw_token.strip().lower() in self.keywords_to_highlight:
					result.append(f'<span style="background-color: yellow;">{token}</span>')
				else:
					result.append(token)
			elif word.startswith('- '):  # deletion
				result.append(f'<span style="color: red; text-decoration: line-through;">{token}</span>')
			elif word.startswith('+ '):  # addition
				result.append(f'<span style="color: green; font-weight: bold;">{token}</span>')

		return ''.join(result)

	@staticmethod
	def tokenize(text: str) -> list[str]:
		"""Split ``text`` into whitespace runs, word runs and single punctuation characters.

		Whitespace is kept as tokens, so joining the result reproduces ``text``.
		"""
		#return re.findall(r'\w+|[^\w\s]', text, re.UNICODE)
		return re.findall(r'\s+|\w+|[^\w\s]', text, re.UNICODE)


def get_repeats(caption: str) -> list[str]:
	"""List the fragments of ``caption`` that occur more than once.

	The caption is split on PUNC; fragments are stripped and lowercased before
	comparison. A repeated fragment appears in the result once per occurrence,
	in caption order.

	Args:
		caption: The caption text to inspect.

	Returns:
		The normalised fragments whose text occurs at least twice.
	"""
	fragments = [part.strip().lower() for part in PUNC.split(caption)]

	# Count each fragment once up front instead of rescanning per fragment.
	occurrences = defaultdict(int)
	for fragment in fragments:
		occurrences[fragment] += 1

	return [fragment for fragment in fragments if occurrences[fragment] > 1]


# Defaults for the editor; the branch for the active OPERATOR overrides them.
auto_modify = False
show_repeats = False
keywords_to_highlight = set()
invert_button = False
sources = set()

if OPERATOR == 'remove_bullshit':
	auto_modify = True
	show_repeats = True
	keywords_to_highlight = {"resolution", "image", "likely", "possibly", "probably", "appear", "seem", "this", "features", "depicts", "the photo", "the photograph"}
elif OPERATOR == 'add-watermark':
	keywords_to_highlight = {"watermark", "watermarks", "watermarked", "logo", "logos"}
elif OPERATOR == 'remove-watermark':
	keywords_to_highlight = {"watermark", "watermarks", "watermarked", "logo", "logos"}
	invert_button = True
elif OPERATOR == 'add-source':
	# Candidate sources: the fixed SOURCES, every subreddit as "r/<name>"
	# (lowercased), and plain "reddit".
	sources = SOURCES.union({f"r/{sub.lower()}" for sub in SUBREDDITS})
	sources.add("reddit")

# Constructing the editor displays the UI and loads the first caption.
CaptionEditor(auto_modify=auto_modify, show_repeats=show_repeats, keywords_to_highlight=keywords_to_highlight, invert_button=invert_button, sources=sources)

In [None]:
# Stats
# For every SYNONYMS key (plus "features"): [occurrences in the original
# captions, occurrences in the edited captions], counted as whole words.
syn_counts = {k: [0, 0] for k in SYNONYMS.keys()}
syn_counts['features'] = [0, 0]

# "<term>-><replacement>" -> number of times the replacement appears in an
# edited caption whose original contained the term.
vulgar_count = defaultdict(int)

for original, edited in recaptions.items():
	# Drop punctuation and lowercase so words compare exactly.
	original = PUNC.sub("", original).lower()
	edited = PUNC.sub("", edited).lower()
	original_words = original.split()
	edited_words = edited.split()

	for keyword in syn_counts.keys():
		original_count = sum(1 for word in original_words if word == keyword)
		edited_count = sum(1 for word in edited_words if word == keyword)
		syn_counts[keyword][0] += original_count
		syn_counts[keyword][1] += edited_count
	
	for k, values in SYNONYMS.items():
		# Substring test here (e.g. plurals count), unlike the exact-word
		# counts above.
		if not any(k in word for word in original_words):
			continue
		
		for v in values:
			count = sum(1 for word in edited_words if word == v)
			vulgar_count[f"{k}->{v}"] += count

for key, (original_count, edited_count) in syn_counts.items():
	# max(0.0001, ...) avoids dividing by zero when a keyword never occurs.
	print(f"{key}: {edited_count} / {original_count} ({edited_count / max(0.0001, original_count) * 100:.2f}%)")

print()
print("Vulgar counts:")
for key, count in vulgar_count.items():
	print(f"{key}: {count}")

In [41]:
# Load the stored watermark edits, keyed by original caption.
with conn.cursor() as cur:
	cur.execute("SELECT caption, recaptioned FROM recaption_dataset WHERE operator = 'add-watermark'")
	add_watermark_recaptions = {row[0]: row[1] for row in cur}
	cur.execute("SELECT caption, recaptioned FROM recaption_dataset WHERE operator = 'remove-watermark'")
	remove_watermark_recaptions = {row[0]: row[1] for row in cur}

# Number of examples whose rewrite differs from the original caption.
add_watermark_recaptions_changes = sum(1 for k, v in add_watermark_recaptions.items() if k != v)
remove_watermark_recaptions_changes = sum(1 for k, v in remove_watermark_recaptions.items() if k != v)

# max(0.0001, ...) avoids dividing by zero when an operator has no rows.
print(f"Add watermark: {add_watermark_recaptions_changes} / {len(add_watermark_recaptions)} ({add_watermark_recaptions_changes / max(0.0001, len(add_watermark_recaptions)) * 100:.2f}%)")
print(f"Remove watermark: {remove_watermark_recaptions_changes} / {len(remove_watermark_recaptions)} ({remove_watermark_recaptions_changes / max(0.0001, len(remove_watermark_recaptions)) * 100:.2f}%)")

Add watermark: 195 / 394 (49.49%)
Remove watermark: 179 / 358 (50.00%)


In [None]:
# Print up to 32 randomly chosen captions that mention a watermark.
# `g_captions` is the caption list this notebook actually defines; the name
# `captions` only appears in commented-out code.
# NOTE(review): confirm g_captions is the intended pool.
watermark_captions = [caption for caption in g_captions if "watermark" in caption.lower()]

# random.sample raises ValueError when asked for more items than exist.
for caption in random.sample(watermark_captions, min(32, len(watermark_captions))):
	print(caption)
	print("###")

## Write training data for add-source

In [21]:
# System prompt template for the add-source training examples; "{source}" is
# filled in per example by write_training().
SOURCE_PROMPT = """Please edit the user's provided image descriptions following these guidelines:
1. In some way, add to the description so that it mentions that the image is from "{source}".
2. If the source starts with "r/" then randomly also add a mention of "reddit" to the description.
3. Your edits should be minimal and not affect the details or accuracy of the description in any way.
4. Only update the grammer if necessary. Do NOT fix any grammar mistakes or oddness that were in the original description. Some of them may be MidJourney prompts or lists of tags.
5. There are a variety of ways that you can add this information to the description. Be creative and make sure to maintain the original meaning of the description.
6. Randomly change the capitalization of the source name in the description.

Respond with only the edited image description."""

def write_training(filename: str, examples: list[tuple[str, str, str]]):
	"""Write add-source examples to ``filename`` as JSON Lines chat records.

	Args:
		filename: Output path; an existing file is overwritten.
		examples: ``(original caption, edited caption, source)`` tuples.

	Each line is a JSON object with a "messages" list holding the system
	prompt (SOURCE_PROMPT with the source filled in), the original caption as
	the user turn and the edited caption as the assistant turn. All three are
	stripped of surrounding whitespace.
	"""
	def as_record(original, rewritten, source):
		# One chat transcript: system prompt, user caption, assistant rewrite.
		turns = [
			("system", SOURCE_PROMPT.format(source=source)),
			("user", original),
			("assistant", rewritten),
		]
		return {"messages": [{"role": role, "content": text.strip()} for role, text in turns]}

	with open(filename, 'w') as f:
		for original, rewritten, source in examples:
			f.write(json.dumps(as_record(original, rewritten, source)) + "\n")


# Fetch every stored add-source example as (caption, recaptioned, extra_1);
# extra_1 is the value on_submit() stored as the source name.
with conn.cursor() as cur:
	cur.execute("SELECT caption, recaptioned, extra_1 FROM recaption_dataset WHERE operator = 'add-source'")
	examples = [row for row in cur]

# Shuffle so the held-out slice below is a random subset.
random.shuffle(examples)

# The first n_test examples become the test set; the rest are training data.
n_test = 32
test_examples = examples[:n_test]
train_examples = examples[n_test:]

write_training("source-train.jsonl", train_examples)
write_training("source-test.jsonl", test_examples)

## Write training data for add/remove watermark

In [5]:
# System prompts for the watermark training examples. They carry the same text
# as the ADD_PROMPT / REMOVE_PROMPT definitions earlier in this file and are
# re-declared here, presumably so this cell can be run on its own.
ADD_PROMPT = """Please edit the user's provided image descriptions following the guidelines below:
1. If the description does not mention that there is a watermark in the image, add a mention of a watermark.
2. If the description already mentions a watermark, no changes are needed. Leave the description as is.
3. If the description says there is no watermark, fix the description to mention a watermark.
4. Fix any conflicting information about the watermark in the description.
5. When you make edits, make sure to maintain the original meaning of the sentence, and minimize the number of changes.
6. Only update the grammer if necessary. Do NOT fix any grammar mistakes or oddness that were in the original description. Some of them may be MidJourney prompts or lists of tags.

Respond with only the edited image description, or the original description if no changes were needed.
"""

REMOVE_PROMPT = """Please edit the user's provided image descriptions following the guidelines below:
1. If the description mentions that there is a watermark in the image, remove the mention of the watermark.
2. If the description does not mention a watermark, no changes are needed. Leave the description as is.
3. If the description says there is no watermark, no changes are needed. Leave the description as is.
4. Fix any conflicting information about the lack of a watermark in the description.
5. When you make edits, make sure to maintain the original meaning of the sentence, and minimize the number of changes.
6. Only update the grammer if necessary. Do NOT fix any grammar mistakes or oddness that were in the original description. Some of them may be MidJourney prompts or lists of tags.

Respond with only the edited image description, or the original description if no changes were needed.
"""

# Pair every stored example with the prompt of its operator:
# (system prompt, original caption, edited caption).
with conn.cursor() as cur:
	cur.execute("SELECT caption, recaptioned FROM recaption_dataset WHERE operator = 'add-watermark'")
	add_watermark_examples = [(ADD_PROMPT, row[0], row[1]) for row in cur]
	cur.execute("SELECT caption, recaptioned FROM recaption_dataset WHERE operator = 'remove-watermark'")
	remove_watermark_examples = [(REMOVE_PROMPT, row[0], row[1]) for row in cur]


def write_training(filename: str, examples: list[tuple[str, str, str]]):
	"""Write watermark examples to ``filename`` as JSON Lines chat records.

	Args:
		filename: Output path; an existing file is overwritten.
		examples: ``(system prompt, original caption, edited caption)`` tuples.

	Each line is a JSON object with a "messages" list holding the system, user
	and assistant turns, each stripped of surrounding whitespace.
	"""
	roles = ("system", "user", "assistant")

	with open(filename, 'w') as f:
		for prompt, original, rewritten in examples:
			# Pair each role with its text, in transcript order.
			turns = zip(roles, (prompt, original, rewritten))
			record = {"messages": [{"role": role, "content": text.strip()} for role, text in turns]}
			f.write(json.dumps(record) + "\n")

# Hold out n_test examples of each operator so the test set has equal numbers
# of add-watermark and remove-watermark examples.
n_test = 16
random.shuffle(add_watermark_examples)
random.shuffle(remove_watermark_examples)
test_examples = add_watermark_examples[:n_test] + remove_watermark_examples[:n_test]
train_examples = add_watermark_examples[n_test:] + remove_watermark_examples[n_test:]
# Mix the two operators within each split.
random.shuffle(test_examples)
random.shuffle(train_examples)

write_training("watermark-train.jsonl", train_examples)
write_training("watermark-test.jsonl", test_examples)

## Write training data for remove_bullshit

In [10]:
# System prompt for the remove_bullshit (general clean-up) training examples.
PROMPT = """Please edit the user's provided image descriptions following the guidelines below:
1. The edits should be minimal and not affect the details or accuracy of the description.
2. Remove any mention of the image's resolution, but don't remove information about the image's quality.
3. Edit out any self-referential language. For example: "this is a digital painting" -> "a digital painting", "In this photo a woman stands" -> "photo of a woman standing", etc.
4. Randomly swap in informal synonyms for things like "penis", "vulva", etc.
5. Do not modify anything in quotes that are describing text in the image.
6. Randomly swap the word "photograph" to "photo".
7. Remove any duplicates from the description if the description repeats itself.
8. When you make edits, make sure to maintain the original meaning of the sentence, and minimize the number of changes.
9. Only update the grammer if necessary. Do NOT fix any grammar mistakes or oddness that were in the original description. Some of them may be MidJourney prompts or lists of tags.

Respond with only the edited image description.
"""

def write_training(filename: str, examples: list[tuple[str, str]]):
	"""Write clean-up examples to ``filename`` as JSON Lines chat records.

	Args:
		filename: Output path; an existing file is overwritten.
		examples: ``(original caption, edited caption)`` pairs.

	Each line is a JSON object with a "messages" list holding the system
	prompt (PROMPT), the original caption as the user turn and the edited
	caption as the assistant turn. All three are stripped of surrounding
	whitespace.
	"""
	# Strip the prompt, as the other training writers in this file and
	# format_messages() do; otherwise this data set alone would carry a
	# trailing newline in its system message.
	system_prompt = PROMPT.strip()

	with open(filename, 'w') as f:
		for k, v in examples:
			example = {
				"messages": [
					{
						"role": "system",
						"content": system_prompt,
					},
					{
						"role": "user",
						"content": k.strip(),
					},
					{
						"role": "assistant",
						"content": v.strip(),
					},
				]
			}
			f.write(json.dumps(example) + "\n")


# NOTE(review): `recaptions` holds the edits loaded for the active OPERATOR,
# so this cell only produces clean-up data when the notebook was run with
# OPERATOR = 'remove_bullshit' — confirm before exporting.
examples = [(k, v) for k, v in recaptions.items()]
random.shuffle(examples)

# The first n_test shuffled examples become the test set.
n_test = 32
test_examples = examples[:n_test]
train_examples = examples[n_test:]

write_training("train.jsonl", train_examples)
write_training("test.jsonl", test_examples)

In [None]:
# Print every stored (original, rewrite) pair for a quick manual review.
for k, v in recaptions.items():
	print(f"Original: {k}")
	print(f"Rewrite: {v}")
	print()