# Grid Manual Tagging

Manually tag images according to whether or not they are a grid of thumbnails.

In [1]:
import random
from transformers import AutoProcessor, AutoModelForZeroShotImageClassification
from PIL import Image
from pathlib import Path
from tqdm import tqdm
import torch
import torch.nn.functional as F
import ipywidgets as widgets
from collections import defaultdict, OrderedDict
import io
import sqlite3
from model import NsfwClassifier
import json
import base64
from ipywidgets import HTML
from IPython.display import display, Javascript
from pathlib import Path

In [2]:
# Load the CLIP processor and model from a single checkpoint name. The
# zero-shot search below uses the processor to tokenize the prompt and uses the
# model's text features, visual_projection and logit_scale.
CLIP_CHECKPOINT = "openai/clip-vit-base-patch32"
clip_processor = AutoProcessor.from_pretrained(CLIP_CHECKPOINT)
clip_model = AutoModelForZeroShotImageClassification.from_pretrained(CLIP_CHECKPOINT)

In [3]:
# Open the embedding database and collect every image path that has a
# non-NULL stored embedding.
conn = sqlite3.connect('../data/clip-embeddings.sqlite3')
cursor = conn.cursor()

cursor.execute("SELECT path FROM images WHERE embedding IS NOT NULL")
all_paths = [row[0] for row in cursor.fetchall()]

# Random subset used by the zero-shot search. Clamp the sample size so a
# database with fewer than 1000 embedded images does not make random.sample
# raise ValueError.
random_paths = random.sample(all_paths, min(1000, len(all_paths)))

In [4]:
# Score images against a text prompt with CLIP zero-shot similarity.
@torch.no_grad()
def zero_shot_search(text: str, paths: list[str | Path] | None = None) -> list[float]:
	"""Return CLIP zero-shot logits of each image against ``text``.

	Each image's stored float16 embedding is read from the ``images`` table,
	passed through ``clip_model.visual_projection`` and compared with the
	prompt's text features by cosine similarity scaled by
	``clip_model.logit_scale.exp()``.

	Args:
		text: Prompt to compare against, e.g. "Preview thumbnails".
		paths: Image paths to score. Defaults to the global ``random_paths``
			(iterating it, so an OrderedDict of paths also works).

	Returns:
		One logit per path, in the same order as ``paths``; ``[]`` when there
		are no paths.
	"""
	if paths is None:
		paths = random_paths
	paths = list(paths)
	if not paths:
		# torch.stack would raise on an empty list.
		return []

	# Keep every tensor on the model's device so this also works on GPU.
	device = clip_model.device

	embeddings = []
	for path in paths:
		cursor.execute("SELECT embedding FROM images WHERE path = ?", (path,))
		# bytearray gives torch.frombuffer a writable buffer.
		blob = bytearray(cursor.fetchone()[0])
		embedding = torch.frombuffer(blob, dtype=torch.float16)
		embeddings.append(embedding.to(torch.float32))

	# NOTE(review): assumes the stored vectors are pre-projection vision
	# features (they are fed into visual_projection) — confirm against the
	# code that wrote the database.
	image_embeds = clip_model.visual_projection(torch.stack(embeddings).to(device))

	inputs = clip_processor(text=[text], return_tensors="pt", padding=True).to(device)
	text_embeds = clip_model.get_text_features(**inputs)

	# L2-normalise so the dot product below is cosine similarity.
	image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
	text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)

	logit_scale = clip_model.logit_scale.exp()
	logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale

	return logits_per_text[0].tolist()


zero_shot_scores = zero_shot_search("Preview thumbnails")

In [4]:
# Load the trained classifier used to rank images before manual rating.
# NOTE(review): arguments presumably are (input dim 768, dropout 0.0,
# 2 classes) — confirm against model.NsfwClassifier.
model = NsfwClassifier(768, 0.0, 2)
# map_location="cpu": a checkpoint saved from a GPU still loads on a CPU-only
# machine, and the model stays on the same device as the CPU embedding tensors
# that trained_filter feeds it.
model.load_state_dict(torch.load('classifier.pt', map_location="cpu"))
model.eval()

@torch.no_grad()
def trained_filter(path, model) -> float:
	"""Return the classifier's class-1 probability for a single image.

	The float16 embedding blob stored for ``path`` in the ``images`` table is
	upcast to float32, run through ``model`` as a batch of one, and softmaxed
	over dim 1; the probability in column 1 is returned as a Python float.
	"""
	cursor.execute("SELECT embedding FROM images WHERE path = ?", (path,))
	(blob,) = cursor.fetchone()
	vector = torch.frombuffer(bytearray(bytes(blob)), dtype=torch.float16).to(torch.float32)
	batch = vector[None, :]  # add a leading batch dimension of size 1
	probs = F.softmax(model(batch), dim=1)
	return probs[:, 1].item()


def get_trained_scores(paths: list[str | Path], model) -> OrderedDict[Path | str, dict]:
	"""Score every path with ``trained_filter`` and rank them.

	Returns an OrderedDict mapping each path to ``{"score": probability}``,
	ordered from highest to lowest score (ties keep their input order).
	"""
	scored = [(p, trained_filter(p, model)) for p in paths]
	scored.sort(key=lambda pair: pair[1], reverse=True)

	ranked = OrderedDict()
	for p, s in scored:
		ranked[p] = {"score": s}
	return ranked

In [5]:
def read_existing_scores() -> dict[str, int]:
	"""Load the manual labels saved so far.

	Returns the JSON object stored in ``manual_scores.json`` (image path ->
	label), or an empty dict when that file does not exist yet.
	"""
	scores_file = Path("manual_scores.json")
	if scores_file.exists():
		with scores_file.open("r") as handle:
			return json.load(handle)
	return {}

In [None]:
class Rater:
	paths: OrderedDict[str | Path, dict]

	def __init__(self, paths: OrderedDict[str | Path, dict], img_width: int, img_height: int):
		self.idx = -1
		self.paths = paths
		self.img_width = img_width
		self.img_height = img_height

		self.image_box = widgets.Image(width=img_width, height=img_height, format='jpg')
		self.grid_button = widgets.Button(description='Grid')
		self.none_button = widgets.Button(description='No grid')
		self.skip_button = widgets.Button(description='Skip')

		self.score_label = widgets.Label(value="Score: ?")

		self.grid_button.on_click(lambda _: self.on_feedback(1))
		self.none_button.on_click(lambda _: self.on_feedback(0))
		self.skip_button.on_click(lambda _: self.next_image())

		self.vbox = widgets.VBox([widgets.HBox([self.grid_button, self.none_button, self.skip_button]), self.score_label, self.image_box])
	
	def on_feedback(self, feedback: int):
		scores = read_existing_scores()
		path = list(self.paths.keys())[self.idx]
		scores[str(path)] = feedback

		# Save feedback to database
		with open("manual_scores.tmp", "w") as f:
			json.dump(scores, f, indent=4)
		
		Path("manual_scores.tmp").rename("manual_scores.json")
		
		self.next_image()
	
	def set_image(self, path: str | Path):
		image = Image.open(path)
		scale1 = self.img_width / image.width
		scale2 = self.img_height / image.height
		width = int(image.width * min(scale1, scale2))
		height = int(image.height * min(scale1, scale2))
		image = image.resize((width, height))
		with io.BytesIO() as output:
			image.save(output, format="JPEG")
			self.image_box.value = output.getvalue()
		self.image_box.width = width
		self.image_box.height = height
		#self.image_box.value = open(path, 'rb').read()

		score = self.paths[path]["score"]
		self.score_label.value = f"Score: {score:.4f}"
	
	def prev_image(self):
		self.idx = max(0, self.idx - 1)
		path = list(self.paths.keys())[self.idx]
		self.set_image(path)
	
	def next_image(self):
		ratings = read_existing_scores()
		self.idx += 1

		while self.idx < len(self.paths):
			path = list(self.paths.keys())[self.idx]

			if path in ratings:
				self.idx += 1
				continue

			self.set_image(path)
			return
		
		self.image_box.value = None
		print("Done!")


# Rank a fresh random sample by the trained classifier (highest score first)
# and start rating. Note: this rebinds the global random_paths to an OrderedDict.
# The sample size is clamped so a small database does not raise ValueError.
random_paths = random.sample(all_paths, min(1000, len(all_paths)))
random_paths = get_trained_scores(random_paths, model)
rater = Rater(random_paths, 800, 800)
rater.next_image()
rater.vbox

## Big Rater

In [None]:
class BigRater:
	paths: OrderedDict[Path | str, dict]

	def __init__(self, paths: OrderedDict[Path | str, dict]):
		self.clicked_paths = []

		# Filter out paths that have already been rated
		scores = read_existing_scores()
		scored_paths = set(scores.keys())
		paths = OrderedDict([(path, x) for path, x in paths.items() if path not in scored_paths])

		self.paths = paths

		self.display_rating_grid(list(self.paths.keys()), self.on_done_clicked, self.on_image_click)

	def on_image_click(self, b):
		image_path = b.tooltip
		self.clicked_paths.append(image_path)
		widget = self.paths[image_path]['widget']
		widget.children[-1].style.button_color = 'red'
	
	def on_done_clicked(self, b):
		self.vbox.close()

		self.display_rating_grid(self.clicked_paths, self.on_verify_clicked, None)
	
	def on_verify_clicked(self, b):
		self.vbox.close()

		# Save feedback to database
		scores = read_existing_scores()

		for path in self.clicked_paths:
			scores[path] = 1

		# Save feedback to database
		with open("manual_scores.tmp", "w") as f:
			json.dump(scores, f, indent=4)

		Path("manual_scores.tmp").rename("manual_scores.json")
	
	def display_rating_grid(self, paths: list[Path | str], on_done, on_image_click):
		images = []
		for path in tqdm(paths):
			image = Image.open(path)
			image = image.resize((256, 256)).convert("RGB")
			with io.BytesIO() as output:
				image.save(output, format="JPEG")
				image_bytes = output.getvalue()
			
			widget = widgets.Button(
				description='Grid',
				layout=widgets.Layout(width='256px', height='25px', padding='0px', margin='0px'),
				tooltip=path,
			)
			widget.style.button_color = 'lightgray'
			if on_image_click is not None:
				widget.on_click(on_image_click)

			image = widgets.Image(value=image_bytes, format='jpeg')

			label = widgets.Label(value=f"Score: {self.paths[path]['score']:.4f}", layout=widgets.Layout(width='256px', height='20px', padding='0px', margin='0px'))

			vbox = widgets.VBox([label, image, widget])
			images.append(vbox)
			self.paths[path]['widget'] = vbox

		grid = widgets.GridBox(
			images,
			layout=widgets.Layout(
				grid_template_columns='repeat(auto-fit, 256px)',
				grid_gap='10px'
			)
		)

		done_button = widgets.Button(description='Done')
		done_button.on_click(on_done)
		self.vbox = widgets.VBox([grid, done_button])
		self.image_widgets = images

		display(self.vbox)


# Rank a fresh random sample by the trained classifier and open the batch
# rater. The sample size is clamped so a small database does not raise ValueError.
random_paths = random.sample(all_paths, min(256, len(all_paths)))
random_paths = get_trained_scores(random_paths, model)
big_rater = BigRater(random_paths)

## Double-check the data by visualizing all of it

In [None]:
# Build output.html: every manually rated image grouped under its label,
# downscaled so its longer side is 256px and embedded inline as WebP.
# Clicking an image copies its path to the clipboard.
from html import escape as html_escape

groups = defaultdict(list)
for path, score in read_existing_scores().items():
	groups[score].append(path)

parts = [
	'<html><head><meta charset="utf-8"><style>img { width: 200px; height: 200px; }</style>',
	"<script>function copyToClipboard(text) { navigator.clipboard.writeText(text); }</script>",
	"</head><body>",
]

for score, paths in groups.items():
	parts.append(f"<h1>{score}</h1>")

	for path in tqdm(paths):
		with Image.open(path) as image:
			scale = 256 / max(image.width, image.height)
			# max(1, ...) guards against a zero-size resize for extreme aspect ratios.
			size = (max(1, int(image.width * scale)), max(1, int(image.height * scale)))
			thumb = image.resize(size)
		with io.BytesIO() as output:
			thumb.save(output, format="WEBP")
			b64 = base64.b64encode(output.getvalue()).decode()

		# json.dumps yields a correctly escaped JS string literal (quotes,
		# Windows backslashes), and html_escape makes that literal safe inside
		# the double-quoted onclick attribute.
		js_path = html_escape(json.dumps(str(path)))
		parts.append(f'<img src="data:image/webp;base64,{b64}" onclick="copyToClipboard({js_path})" />')

parts.append("</body></html>")
html = "".join(parts)

# Explicit UTF-8 so non-ASCII paths do not fail under a non-UTF-8 locale default.
with open("output.html", "w", encoding="utf-8") as f:
	f.write(html)