# Find Grid Images with score > 0
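Check whether any grid images slipped past our score filters: run the classifier over the stored embeddings of images with `score > 0` and display the ones it flags.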

In [1]:
import numpy as np
import random
import psycopg
from transformers import AutoProcessor, AutoModelForZeroShotImageClassification
from PIL import Image
from pathlib import Path
from tqdm import tqdm
import torch
import torch.nn.functional as F
import torch.nn as nn
import ipywidgets as widgets
from collections import defaultdict
import itertools
from torch import optim
from transformers import get_scheduler
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, TensorDataset
from sklearn.linear_model import LogisticRegression
import pickle
import io
import sqlite3
from model import NsfwClassifier
import json

In [2]:
# Rebuild the classifier and restore its trained weights from disk.
# Judging by the module printed below this cell, the arguments correspond to a
# 768-wide input, dropout p=0.0 and 2 output classes (confirm against model.py).
model = NsfwClassifier(768, 0.0, 2)
# map_location='cpu': nothing in this notebook moves the model off the CPU, so
# remap any tensors that were saved from a GPU. Without it, a checkpoint saved
# on CUDA cannot be loaded on a machine that has no GPU.
# NOTE(review): torch.load unpickles the file -- only load checkpoints you trust.
model.load_state_dict(torch.load('classifier.pt', map_location='cpu'))
# Inference mode. As the last expression of the cell this also displays the module.
model.eval()

NsfwClassifier(
  (ln): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  (dropout1): Dropout(p=0.0, inplace=False)
  (linear1): Linear(in_features=768, out_features=1536, bias=True)
  (act_fn): GELU(approximate='none')
  (dropout2): Dropout(p=0.0, inplace=False)
  (linear2): Linear(in_features=1536, out_features=2, bias=True)
)

In [None]:
# Scan the images whose stored score is > 0, in random order, and display the
# ones the classifier flags. Scanning stops after the first batch that brings
# the number of displayed images up to `need_to_find`.
BATCH_SIZE = 256         # images scored per forward pass
FLAG_THRESHOLD = 0.5     # minimum class-1 probability for an image to be shown
MAX_DISPLAY_SIDE = 1024  # longest side, in pixels, of the displayed image

need_to_find = 20

conn = sqlite3.connect('../data/clip-embeddings.sqlite3')
try:
	cursor = conn.cursor()

	# Only (path, score) is loaded up front; embeddings are fetched batch by
	# batch below so the whole table never has to sit in memory at once.
	cursor.execute("SELECT path, score FROM images WHERE embedding IS NOT NULL AND score > 0")
	all_paths = cursor.fetchall()
	random.shuffle(all_paths)

	for i in range(0, len(all_paths), BATCH_SIZE):
		batch = all_paths[i:i+BATCH_SIZE]

		embeddings = []

		for path, _ in batch:
			cursor.execute("SELECT embedding FROM images WHERE path = ?", (path,))
			# The blob is decoded as raw float16 values and widened to float32 for the model.
			embedding = np.frombuffer(cursor.fetchone()[0], dtype=np.float16)
			embeddings.append(torch.tensor(embedding, dtype=torch.float32))

		embeddings = torch.stack(embeddings)

		with torch.no_grad():
			outputs = F.softmax(model(embeddings), dim=1)
			# Column 1 of the softmax output is the probability used for flagging.
			scores = outputs[:, 1].tolist()

		found = False

		for (path, db_score), score in zip(batch, scores):
			if score < FLAG_THRESHOLD:
				continue

			print(f"Score: {score:.2f} (DB: {db_score}), path: {path}")
			# Close the file handle as soon as the resized copy exists; a bare
			# Image.open() would keep one file open per flagged image.
			with Image.open(path) as image:
				scale = MAX_DISPLAY_SIDE / max(image.size)
				# Clamp to 1px so extreme aspect ratios cannot round a side down to 0.
				new_size = (
					max(1, int(image.width * scale)),
					max(1, int(image.height * scale)),
				)
				display(image.resize(new_size))
			found = True
			need_to_find -= 1

		if not found:
			print("Clean")
		elif need_to_find <= 0:
			break
finally:
	# Release the database handle even if a query or an image load fails.
	conn.close()

In [4]:
# Inject problematic images into the quality-arena database.
# They are injected as ties: every pair is inserted twice, once with each image
# as the winner. We don't care which one wins; the goal is to get them into the
# system so they pop up during manual scorings.
# (sqlite3 and itertools come from the import cell at the top of the notebook.)

paths = [
]

pairs = list(itertools.combinations(paths, 2))
print(len(pairs))

# Skip the database entirely when there is nothing to insert, so an empty
# `paths` list never opens (or creates) the ratings file.
if pairs:
	conn = sqlite3.connect("../quality-arena/ratings.sqlite3")
	try:
		# `with conn` commits on success and rolls back on error, but it does
		# NOT close the connection -- hence the explicit close() below.
		with conn:
			conn.executemany(
				"INSERT INTO ratings (win_path, lose_path) VALUES (?, ?)",
				[
					row
					for path1, path2 in pairs
					for row in ((path1, path2), (path2, path1))
				],
			)
	finally:
		conn.close()

3
