In [1]:
import torch
from torch.utils.data import DataLoader
from dataset.hfdataset import HFImageDataset
from dataset.filedataset import FileDataset
from dataset.transforms.geoclip import geoclip_train_transform
from model.embedding_heads.clip_image import CLIPImageEncoder
from model.embedding_heads.clip_text import CLIPTextEncoder
from model.embedding_heads.dino_image import DINOImageEncoder
from model.embedding_heads.rff_location import RFFLocationEncoder
from model.GeoClip import GeoCLIP

In [2]:
# Build the training dataset from the 10K subset: a metadata CSV plus the
# directory passed alongside it, with the GeoCLIP training transform applied.
# NOTE(review): both paths are absolute and machine-specific.
metadata_csv = "/Users/kshitij/Documents/University/Year4/MLP/RealProject/data/10K/metadata.csv"
image_root = "/Users/kshitij/Documents/University/Year4/MLP/RealProject/data/10K"
train_transform = geoclip_train_transform()
dataset = FileDataset(
    metadata_csv,
    image_root,
    transform=train_transform,
)

In [3]:
# Tell the dataset which CSV columns to use (presumably image file name,
# latitude, longitude, in that order; confirm against FileDataset.set_columns),
# then load the metadata file. The captured progress output below reports
# image paths and coordinates being loaded.
metadata_csv = "/Users/kshitij/Documents/University/Year4/MLP/RealProject/data/10K/metadata.csv"
dataset.set_columns('file_name', 'lat', 'lon')
dataset.load_dataset(metadata_csv)

Loading image paths and coordinates: 9984it [00:00, 48877.60it/s]


In [4]:
# Serve the dataset in mini-batches of 256 samples, reshuffled every epoch.
dataloader = DataLoader(
    dataset=dataset,
    batch_size=256,
    shuffle=True,
)

In [5]:
# Build the location encoder with its default constructor arguments.
# NOTE(review): "RFF" presumably stands for random Fourier features; confirm
# in model/embedding_heads/rff_location.py.
location_encoder = RFFLocationEncoder()

In [6]:
# Build the image encoder with its default constructor arguments. The DINO
# encoder is used here; the CLIP image/text encoders imported at the top are
# not used in the cells shown.
image_encoder = DINOImageEncoder()

In [14]:
# Load saved weights into both encoders. These paths are relative, so they
# resolve against the kernel's current working directory.
# NOTE(review): the image-encoder checkpoint lives under pretrained_geoclip/;
# confirm it matches the DINOImageEncoder architecture.
location_encoder.load_weights('weights/pretrained_geoclip/location_encoder_weights.pth')
image_encoder.load_weights('weights/pretrained_geoclip/image_encoder_weights.pth')

In [15]:
# Assemble the GeoCLIP model from the two encoders plus a CSV of coordinates
# (presumably the GPS gallery that predictions are drawn from; confirm in
# model/GeoClip.py).
# NOTE(review): this path is absolute and machine-specific.
gps_gallery_csv = '/Users/kshitij/Documents/University/Year4/MLP/RealProject/weights/pretrained_geoclip/coordinates_100K.csv'
model = GeoCLIP(
    image_encoder,
    location_encoder,
    gps_gallery_csv,
)

In [18]:
# Load the model-level checkpoint (file name suggests the logit scale;
# confirm in GeoCLIP.load_weights). The path is relative to the kernel's
# working directory.
model.load_weights('weights/pretrained_geoclip/logit_scale_weights.pth')

In [19]:
# Image Upload & Display
# NOTE(review): these imports are unused in this cell; they are kept because
# later cells may rely on them. Ideally they move to the top import cell.
from PIL import Image
from io import BytesIO
import matplotlib.pyplot as plt

# Heatmap
import folium
from folium.plugins import HeatMap

# Query image (resolved against the working directory) and the number of
# candidate locations to report. TOP_K drives the predict call, the header
# and the loop, so the three can no longer drift apart.
IMAGE_PATH = "download.png"
TOP_K = 5

# Prefer Apple's MPS backend (the original target), then CUDA, then CPU, so
# the cell also runs on machines that have no MPS device.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model.to(device)

# Make predictions
top_pred_gps, top_pred_prob = model.predict(IMAGE_PATH, top_k=TOP_K)
print(top_pred_gps.shape)

# Display the top-k GPS predictions. Iterating over the returned rows instead
# of indexing range(5) avoids an IndexError if fewer than TOP_K rows come back.
print(f"Top {TOP_K} GPS Predictions 📍")
print("========================")
for rank, ((lat, lon), prob) in enumerate(zip(top_pred_gps, top_pred_prob), start=1):
    print(f"Prediction {rank}: ({lat:.6f}, {lon:.6f}) - Probability: {prob:.6f}")

torch.Size([5, 2])
Top 5 GPS Predictions 📍
Prediction 1: (32.958195, -97.419975) - Probability: 0.002400
Prediction 2: (29.971237, -98.833809) - Probability: 0.001870
Prediction 3: (44.588921, 22.731888) - Probability: 0.001856
Prediction 4: (48.528713, 8.797903) - Probability: 0.001781
Prediction 5: (50.072697, 31.035105) - Probability: 0.001745
