In [1]:
import os
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, random_split
from torchvision.models import resnet50, ResNet50_Weights
from torchvision.transforms import Compose, Resize, ToTensor, Normalize
from torchvision.datasets import ImageFolder

device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")

IMAGE_DIR = os.path.expanduser("~/Code/geo/data")
SAVE_DIR = os.path.expanduser("./saves")
MODEL_FILE = os.path.join(SAVE_DIR, "weights.pth")

In [2]:
transform = Compose([
	Resize((224, 224)),
	ToTensor(),
	Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

dataset = ImageFolder(IMAGE_DIR, transform=transform)
n, k = len(dataset), len(dataset.classes)
train_dataset, val_dataset = random_split(dataset, [int(0.8 * n), n - int(0.8 * n)])
train_loader, val_loader = DataLoader(train_dataset, batch_size=32, shuffle=True), DataLoader(val_dataset, batch_size=32)

model = resnet50(weights=ResNet50_Weights.DEFAULT)
model.fc = nn.Linear(model.fc.in_features, k)
model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

def train(epochs=2):
	if not os.path.exists(SAVE_DIR): os.makedirs(SAVE_DIR)
	for epoch in range(epochs):
		model.train()
		for images, labels in train_loader:
			images, labels = images.to(device), labels.to(device)
			optimizer.zero_grad()
			loss = criterion(model(images), labels)
			loss.backward()
			optimizer.step()
		print(f"Epoch {epoch+1}: Loss = {loss.item()}")
	torch.save(model.state_dict(), MODEL_FILE)

def evaluate():
	model.load_state_dict(torch.load(MODEL_FILE, map_location=device))
	model.eval()
	correct = 0
	total = 0
	with torch.no_grad():
		for images, labels in val_loader:
			images, labels = images.to(device), labels.to(device)
			_, predicted = torch.max(model(images), 1)
			correct += (predicted == labels).sum().item()
			total += labels.size(0)
	print(f"Validation Accuracy: {100 * correct / total:.2f}%")

In [3]:
train()
evaluate()

Epoch 1: Loss = 0.6334580183029175
Epoch 2: Loss = 0.5084250569343567
Validation Accuracy: 90.00%
