# Visualize the Trained Model

In [4]:
import base64
import io
import random
import sqlite3
from contextlib import closing

import torch
from IPython.display import display, HTML
from PIL import Image

from models import QualityClassifier

In [3]:
# Load every labelled image as a (path, score) row. closing() releases the
# connection afterwards; sqlite3's own context manager only ends the
# transaction and leaves the connection open.
with closing(sqlite3.connect('../data/clip-embeddings.sqlite3')) as conn:
	cur = conn.cursor()
	cur.execute('SELECT path, score FROM images WHERE score IS NOT NULL')
	paths = cur.fetchall()

# Cap the sample at the number of labelled rows so a small database does not
# make random.sample raise ValueError.
random_paths = random.sample(paths, min(1024, len(paths)))

# One bucket of image paths per score value.
# NOTE(review): assumes score is an int in 0..9 — confirm against the schema.
bins = [[] for _ in range(10)]

for path, score in random_paths:
	bins[score].append(path)

In [None]:
# Report how many sampled images landed in each score bin.
print("Bin sizes:")
bin_sizes = [str(len(bucket)) for bucket in bins]
print(", ".join(bin_sizes))

def img_html(path):
	"""Return an inline ``<img>`` tag for the image stored at *path*.

	The image is scaled so that its longer side is 512 pixels, re-encoded as
	WebP (quality 80) and embedded in the tag as a base64 data URI.

	NOTE: the tag pins the display width to 512px, so an image that is taller
	than it is wide (resized width below 512) is stretched by the browser.
	"""
	# Open through the context manager so the file handle is released as soon
	# as the resized copy exists.
	with Image.open(path) as image:
		scale = 512 / max(image.size)
		# Clamp each side to at least 1px: an extreme aspect ratio would
		# otherwise truncate to 0 and make resize() fail.
		new_size = (
			max(1, int(image.width * scale)),
			max(1, int(image.height * scale)),
		)
		resized = image.resize(new_size)

	buffer = io.BytesIO()
	resized.save(buffer, format='WebP', quality=80)
	encoded = base64.b64encode(buffer.getvalue()).decode('utf-8')
	return f'<img src="data:image/webp;base64,{encoded}" width="512" style="margin: 5px;">'

# Build one table row per score bin: the bin number followed by up to five
# randomly chosen images from that bin.
rows = []

for bin_number, image_paths in enumerate(bins):
	# A sparsely populated bin may hold fewer than five images; cap the
	# sample size so random.sample does not raise ValueError.
	sampled_images = random.sample(image_paths, min(5, len(image_paths)))

	cells = "".join(f"<td>{img_html(path)}</td>" for path in sampled_images)
	rows.append(f"<tr><td>{bin_number}</td>{cells}</tr>")

html = "<table>" + "".join(rows) + "</table>"

display(HTML(html))

In [None]:
# Rebuild the classifier and restore its trained weights from classifier.pt.
# NOTE(review): 768 is presumably the embedding width and 0.0 the dropout
# rate — confirm against QualityClassifier's signature.
model = QualityClassifier(768, 0.0)
# map_location='cpu' lets a checkpoint that was saved from a GPU load on a
# machine without CUDA; load_state_dict copies the values into the model's
# existing parameters either way.
model.load_state_dict(torch.load('classifier.pt', map_location='cpu'))
# Switch to evaluation mode before running inference.
model.eval()

In [None]:
# Draw a random sample of images that have a stored embedding and decode each
# embedding blob into a float32 tensor.
# closing() releases the connection when the block exits; the trailing `conn`
# keeps sqlite3's own commit/rollback handling for the block.
with closing(sqlite3.connect('../data/clip-embeddings.sqlite3')) as conn, conn:
	cursor = conn.cursor()
	cursor.execute('SELECT path FROM images WHERE embedding IS NOT NULL')
	all_paths = [row[0] for row in cursor.fetchall()]
	# Cap the sample size so a table with fewer than 128 embedded images does
	# not make random.sample raise ValueError.
	random_paths = random.sample(all_paths, min(128, len(all_paths)))

	random_embeddings = []
	for path in random_paths:
		cursor.execute('SELECT embedding FROM images WHERE path = ?', (path,))
		# bytearray gives torch.frombuffer a writable buffer; passing immutable
		# bytes makes it warn about a non-writable buffer.
		embedding = bytearray(cursor.fetchone()[0])
		# The blob is read as float16 values and widened to float32.
		embedding = torch.frombuffer(embedding, dtype=torch.float16).to(torch.float32)
		random_embeddings.append(embedding)
