In [None]:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from DataHandler import *
from model import *
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


In [None]:
# Fetch the three lookups returned by load_data(). The function is not defined in
# this notebook; it presumably comes from the `DataHandler` star import — TODO confirm.
# Of the three results, the cells shown below only read `path_to_coordinates`
# (iterated as image path -> coordinate pairs).
images_by_id, images_by_coordinates, path_to_coordinates = load_data()

In [None]:
# Split the path -> coordinate mapping into two index-aligned lists
path_coord_pairs = list(path_to_coordinates.items())
image_paths = [path for path, _ in path_coord_pairs]
coordinates = [coord for _, coord in path_coord_pairs]

# Build the dataset from the two parallel lists
dataset = ImageGPSDataset(image_paths=image_paths, coordinates=coordinates)

In [None]:
from torch.utils.data import random_split

# Fixed seed so the train/validation membership is the same on every run;
# without it, validation losses from different runs are not comparable.
SPLIT_SEED = 42

# 80/20 split. The validation size is the remainder, so the two sizes always
# add up to len(dataset) even when 0.8 * len(dataset) is not an integer.
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)


In [None]:
from torch.utils.data import DataLoader

batch_size = 32  # Tune this to what your hardware can hold in memory

# Only the training loader shuffles; validation keeps the default unshuffled order.
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=False)


In [None]:
# Instantiate the network and move it onto the selected device
model = ImageGPSModelV3()
model = model.to(device)

# Mean-squared-error objective, optimized with Adam
loss_function = nn.MSELoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)


In [None]:
from torch.cuda.amp import GradScaler, autocast

# Mixed precision is only switched on when the selected device is CUDA. With
# enabled=False the scaler and autocast act as pass-throughs, so the same loop
# runs unchanged on a CPU-only machine.
use_amp = device.type == "cuda"

# Initialize the gradient scaler
scaler = GradScaler(enabled=use_amp)

epochs = 10  # Number of epochs

for epoch in range(epochs):
    # ---- Training pass ----
    model.train()
    total_train_loss = 0.0  # sum of batch losses, each weighted by its batch size
    train_samples = 0

    for images, coords in train_loader:
        images = images.to(device)
        coords = coords.to(device)
        batch_n = images.size(0)

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward pass with automatic mixed precision (when enabled)
        with autocast(enabled=use_amp):
            outputs = model(images)
            loss = loss_function(outputs, coords)

        # Backward pass and optimizer step, routed through the scaler
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Weight by batch size so a smaller final batch does not skew the epoch mean
        total_train_loss += loss.item() * batch_n
        train_samples += batch_n

    # ---- Validation pass (no gradients) ----
    model.eval()
    total_val_loss = 0.0
    val_samples = 0
    with torch.no_grad():
        for images, coords in val_loader:
            images = images.to(device)
            coords = coords.to(device)
            batch_n = images.size(0)

            # Forward pass for validation
            with autocast(enabled=use_amp):
                outputs = model(images)
                loss = loss_function(outputs, coords)

            total_val_loss += loss.item() * batch_n
            val_samples += batch_n

    # Report NaN instead of raising ZeroDivisionError when a loader yielded nothing
    avg_train_loss = total_train_loss / train_samples if train_samples else float("nan")
    avg_val_loss = total_val_loss / val_samples if val_samples else float("nan")

    print(f'Epoch [{epoch+1}/{epochs}], Train Loss: {avg_train_loss}, Validation Loss: {avg_val_loss}')


In [None]:
model_path = "image_gps_model.pth"

# Bundle the state needed to resume training into a single checkpoint file.
torch.save({
    'epoch': epochs,  # number of epochs the training cell was configured to run
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    # Store a plain float loss value rather than the nn.MSELoss module object.
    # `loss` is the tensor left by the last batch processed in the training cell.
    'loss': loss.item(),
    'scaler_state_dict': scaler.state_dict(),  # GradScaler state for mixed precision
}, model_path)