In [1]:
import glob
import itertools
import json
import numpy as np
import os
import pandas as pd
from PIL import Image
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, random_split
from torchvision.models import resnet50, ResNet50_Weights
from torchvision.transforms import Compose, Resize, ToTensor, Normalize
from torchvision.datasets import ImageFolder
from sklearn.model_selection import KFold

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")

BASE_PATH = os.path.expanduser("~/Code/ModelGeo")
DATA_PATH = os.path.join(BASE_PATH, "data")
if not os.path.exists(DATA_PATH): os.makedirs(DATA_PATH)
QUERY_PATH = os.path.join(BASE_PATH, "queries")
SAVES_PATH = os.path.join(BASE_PATH, "saves")
if not os.path.exists(SAVES_PATH): os.makedirs(SAVES_PATH)
METRICS_FILE = os.path.join(SAVES_PATH, "all.metrics")
if not os.path.isfile(METRICS_FILE): raise Exception()

BATCH_SIZE = 64
EPOCHS = 10
FOLDS = 1
SEED = 37
VAL_RATIO = 0.2

transform = Compose([Resize((224, 224)), ToTensor(), Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

dataset = ImageFolder(DATA_PATH, transform=transform)
n = len(dataset)
k = len(dataset.classes)

train_dataset, val_dataset, test_dataset = random_split(dataset, [int(0.7 * n), int(0.15 * n), n - (int(0.7 * n) + int(0.15 * n))])
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE)

class Model(nn.Module):
	def __init__(self, hidden_dim, num_heads, num_layers, num_classes):
		super().__init__()
		self.id = f"({hidden_dim}-{num_heads}-{num_layers}-{num_classes})"
		self.cnn = resnet50(weights=ResNet50_Weights.DEFAULT)
		self.cnn.fc = nn.Identity()
		self.transformer = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model=2048, nhead=num_heads, dim_feedforward=hidden_dim), num_layers=num_layers)
		self.fc = nn.Linear(2048, num_classes)
	def forward(self, x):
		x = self.cnn(x)
		x = x.unsqueeze(1)
		x = self.transformer(x)
		x = x.squeeze(1)
		return self.fc(x)

def save(model, epoch, train_loss, val_loss):
	model_path = os.path.join(SAVES_PATH, model.id)
	os.makedirs(model_path, exist_ok=True)
	torch.save(model.state_dict(), os.path.join(model_path, f"{model.id}_epoch{epoch}.pth"))
	with open(os.path.join(model_path, f"{model.id}_epoch{epoch}.losses"), "w") as f: json.dump({"train": train_loss, "val": val_loss}, f)

def load(model, epoch=None):
	model_path = os.path.join(SAVES_PATH, model.id)
	model_files = sorted(glob.glob(os.path.join(model_path, "*.pth")), key=os.path.getmtime)
	model_checkpoint = model_files[-1] if epoch is None else os.path.join(model_path, f"{model.id}_epoch{epoch}.pth")
	model.load_state_dict(torch.load(model_checkpoint, map_location=DEVICE))

def train(model, train_loader, val_loader, criterion, optimizer, epochs=EPOCHS):
	for epoch in range(1, epochs + 1):
		model.train()
		train_loss = 0
		for images, labels in train_loader:
			images = images.to(DEVICE)
			labels = labels.to(DEVICE)
			optimizer.zero_grad()
			loss = criterion(model(images), labels)
			loss.backward()
			optimizer.step()
			train_loss += loss.item()
		train_loss /= len(train_loader)

		model.eval()
		val_loss = sum(criterion(model(images.to(DEVICE)), labels.to(DEVICE)).item() for images, labels in val_loader) / len(val_loader)

		print(f"Epoch {epoch}: Train Loss = {train_loss:.5f}, Val Loss = {val_loss:.5f}")

		save(model, epoch, train_loss, val_loss)

def evaluate(model, test_loader):
	load(model)

	model.eval()
	correct = 0
	total = 0
	with torch.no_grad():
		for images, labels in test_loader:
			images = images.to(DEVICE)
			labels = labels.to(DEVICE)
			_, predicted = torch.max(model(images), 1)
			correct += (predicted == labels).sum().item()
			total += labels.size(0)
	print(f"Test Accuracy: {100 * correct / total:.2f}%")

	return round(100 * correct / total, 2)

def crossvalidate(model, criterion, optimizer, dataset, hyperparameters, epochs=EPOCHS):
	if FOLDS == 1:
		split = int(0.8 * len(dataset))
		train_dataset = torch.utils.data.Subset(dataset, range(split))
		test_dataset = torch.utils.data.Subset(dataset, range(split, len(dataset)))

		val_size = int(VAL_RATIO * len(train_dataset))
		train_loader = DataLoader(torch.utils.data.Subset(train_dataset, range(val_size, len(train_dataset))), batch_size=BATCH_SIZE, shuffle=True)
		val_loader = DataLoader(torch.utils.data.Subset(train_dataset, range(val_size)), batch_size=BATCH_SIZE)

		train(model, train_loader, val_loader, criterion, optimizer, epochs=epochs)
		M_list = [evaluate(model, DataLoader(test_dataset, batch_size=BATCH_SIZE))]
	else:
		kf = KFold(n_splits=FOLDS, shuffle=True, random_state=SEED)
		M_list = []
		indices = np.arange(len(dataset))

		for train_i, test_i in kf.split(indices):
			train_dataset = torch.utils.data.Subset(dataset, train_i)
			test_dataset = torch.utils.data.Subset(dataset, test_i)

			val_size = int(VAL_RATIO * len(train_dataset))
			train_loader = DataLoader(torch.utils.data.Subset(train_dataset, range(val_size, len(train_dataset))), batch_size=BATCH_SIZE, shuffle=True)
			val_loader = DataLoader(torch.utils.data.Subset(train_dataset, range(val_size)), batch_size=BATCH_SIZE)

			train(model, train_loader, val_loader, criterion, optimizer, epochs=epochs)
			M_list.append(evaluate(model, DataLoader(test_dataset, batch_size=BATCH_SIZE)))

	rounded_mean_M = round(np.mean(M_list), 8)
	json.dump(
		(json.load(open(METRICS_FILE)) if os.path.getsize(METRICS_FILE) > 0 else []) + [{"hyperparameters": hyperparameters, "M": rounded_mean_M}],
		open(METRICS_FILE, "w"),
		indent=4
	)

def search(grid):
	for column in itertools.product(grid["hidden_dim"], grid["num_heads"], grid["num_layers"]):
		model = Model(*column, k).to(DEVICE)
		criterion = nn.CrossEntropyLoss()
		optimizer = optim.Adam(model.parameters(), lr=grid["lr"][0])

		hyperparameters = f"{model.id}-{grid['lr'][0]}"
		crossvalidate(model, criterion, optimizer, dataset, hyperparameters)

def predict(model):
	load(model)

	model.eval()
	for image_path in glob.glob(os.path.join(QUERY_PATH, "*.jpg")):
		image = Image.open(image_path).convert("RGB")
		tensor_image = transform(image).to(DEVICE).unsqueeze(0)
		with torch.no_grad(): probabilities = torch.softmax(model(tensor_image), dim=1)
		predicted_class = probabilities.argmax().item()
		print(f"{image_path}: Predicted class = {dataset.classes[predicted_class]}")

grid = {
	"hidden_dim": [4096],
	"num_heads": [8],
	"num_layers": [3],
	"lr": [0.0001],
}
search(grid)



Epoch 1: Train Loss = 0.68578, Val Loss = 13.41380
Epoch 2: Train Loss = 0.49313, Val Loss = 18.12076
Epoch 3: Train Loss = 0.51417, Val Loss = 19.14291
Epoch 4: Train Loss = 0.46469, Val Loss = 18.11123
Epoch 5: Train Loss = 0.33173, Val Loss = 15.58720
Epoch 6: Train Loss = 0.17867, Val Loss = 11.78655
Epoch 7: Train Loss = 0.01061, Val Loss = 7.50494
Epoch 8: Train Loss = 0.00177, Val Loss = 3.67135
Epoch 9: Train Loss = 0.09410, Val Loss = 5.57468
Epoch 10: Train Loss = 0.00707, Val Loss = 7.66018
Test Accuracy: 100.00%


In [2]:
model = Model(4096, 8, 3, k).to(DEVICE)
predict(model)

/Users/gregory/Code/ModelGeo/queries/aland-1.jpg: Predicted class = Albania
