In [1]:
import os

import torch
from torch.utils.data import DataLoader

from dataset.hfdataset import HFImageDataset
from dataset.filedataset import FileDataset
from dataset.transforms.geoclip import geoclip_train_transform
from model.embedding_heads.clip_image import CLIPImageEncoder
from model.embedding_heads.clip_text import CLIPTextEncoder
from model.embedding_heads.dino_image import DINOImageEncoder
from model.embedding_heads.rff_location import RFFLocationEncoder
from model.GeoClip import GeoCLIP

In [5]:
# Local 10K split: the metadata CSV plus the directory passed as FileDataset's
# second argument (presumably the image root the CSV paths are relative to — TODO confirm).
data_10k_dir = "/Users/kshitij/Documents/University/Year4/MLP/RealProject/data/10K"
dataset = FileDataset(
    f"{data_10k_dir}/metadata.csv",
    data_10k_dir,
    transform=geoclip_train_transform(),
)

In [6]:
# Tell the dataset which CSV columns hold the image file name and the coordinates,
# then read the rows (image paths + coordinates).
# NOTE(review): this same CSV path was already given to the FileDataset constructor;
# confirm whether an explicit load_dataset call is actually required.
metadata_csv = "/Users/kshitij/Documents/University/Year4/MLP/RealProject/data/10K/metadata.csv"
dataset.set_columns('file_name', 'lat', 'lon')
dataset.load_dataset(metadata_csv)

Loading image paths and coordinates: 10000it [00:00, 37112.22it/s]


In [8]:
# Shuffled training mini-batches.
BATCH_SIZE = 256
dataloader = DataLoader(dataset, shuffle=True, batch_size=BATCH_SIZE)

In [9]:
# GPS location encoder ("RFF" presumably = random Fourier features — TODO confirm);
# its weights are loaded from location_encoder_weights.pth in a later cell.
location_encoder = RFFLocationEncoder()

In [11]:
# Image encoder (DINO-based, per the class name); its weights are loaded from
# image_encoder_mlp_weights.pth in a later cell.
image_encoder = DINOImageEncoder()

In [12]:
# Restore both encoders from the pretrained GeoCLIP checkpoint directory.
pretrained_dir = '/Users/kshitij/Documents/University/Year4/MLP/RealProject/weights/pretrained_geoclip'
location_encoder.load_weights(f'{pretrained_dir}/location_encoder_weights.pth')
image_encoder.load_weights(f'{pretrained_dir}/image_encoder_mlp_weights.pth')

In [15]:
# Combine both encoders into GeoCLIP. The third argument is a CSV of coordinates
# (presumably the candidate GPS gallery that predict() ranks against — TODO confirm).
coordinates_csv = '/Users/kshitij/Documents/University/Year4/MLP/RealProject/weights/pretrained_geoclip/coordinates_100K.csv'
model = GeoCLIP(image_encoder, location_encoder, coordinates_csv)

In [16]:
# Load the pretrained logit-scale weights (per the file name) into the GeoCLIP wrapper.
logit_scale_path = '/Users/kshitij/Documents/University/Year4/MLP/RealProject/weights/pretrained_geoclip/logit_scale_weights.pth'
model.load_weights(logit_scale_path)

In [17]:
# Image Upload & Display
# NOTE(review): PIL, BytesIO, matplotlib and folium are imported here but are not
# used by any cell of this notebook yet.
from PIL import Image
from io import BytesIO
import matplotlib.pyplot as plt

# Heatmap
import folium
from folium.plugins import HeatMap

# Number of GPS candidates to request. The printout below iterates over whatever
# predict() returns rather than assuming exactly TOP_K rows.
TOP_K = 5
# Prefer Apple's MPS backend when present; fall back to CPU so the cell also runs elsewhere.
INFERENCE_DEVICE = 'mps' if torch.backends.mps.is_available() else 'cpu'

model.to(INFERENCE_DEVICE)
top_pred_gps, top_pred_prob = model.predict("download.png", top_k=TOP_K)
print(top_pred_gps.shape)

print(f"Top {TOP_K} GPS Predictions 📍")
print("========================")
# Each row of top_pred_gps is a (lat, lon) pair, paired with its probability.
for rank, ((lat, lon), prob) in enumerate(zip(top_pred_gps, top_pred_prob), start=1):
    print(f"Prediction {rank}: ({lat:.6f}, {lon:.6f}) - Probability: {prob:.6f}")

torch.Size([5, 2])
Top 5 GPS Predictions 📍
Prediction 1: (39.052750, -77.033386) - Probability: 0.005673
Prediction 2: (39.059383, -77.031860) - Probability: 0.005458
Prediction 3: (36.906528, -76.199455) - Probability: 0.003734
Prediction 4: (43.531212, -79.645042) - Probability: 0.003402
Prediction 5: (39.871918, -75.675499) - Probability: 0.003370


In [None]:
import wandb

# Start a W&B run for the training loop below. `config` records every
# hyperparameter the run actually uses, so the run can be reproduced from W&B.
# NOTE(review): "dataset" names a HF-style dataset, while this notebook trains on the
# local data/10K FileDataset — confirm the label is accurate.
wandb.init(
    # W&B project the run is logged under
    project="honours-project-aya",
    name="Training GeoCLIP with DINO and RFF Location Encoder",
    config={
        "learning_rate": 3e-5,
        "weight_decay": 1e-6,  # matches the Adam optimizer built for training
        "optimizer": "Adam",
        "batch_size": 256,     # matches the DataLoader batch size
        "architecture": "GeoCLIP",
        "dataset": "Training on GeoGuessing/GeoTaggedImages",
        "epochs": 10,
    },
)

In [19]:
# Build the encoders, GeoCLIP wrapper and optimizer used by the training loop.
# NOTE(review): these are new objects, so the pretrained GeoCLIP weights loaded into
# the earlier encoders/model are NOT carried over. If this run is meant to fine-tune
# from that checkpoint, call `load_weights` on these objects before creating `optim`.
SEED = 42
# Seed torch's global RNG so random parameter init is repeatable. The DataLoader
# (no explicit generator) also draws its shuffle seed from this RNG when iterated.
torch.manual_seed(SEED)

image_encoder = DINOImageEncoder()
location_encoder = RFFLocationEncoder()
model = GeoCLIP(image_encoder, location_encoder, '/Users/kshitij/Documents/University/Year4/MLP/RealProject/weights/pretrained_geoclip/coordinates_100K.csv')
optim = torch.optim.Adam(model.parameters(), lr=3e-5, weight_decay=1e-6)

NameError: name 'torch' is not defined

In [None]:
# Training loop: one call to `train` per epoch with the loss logged to W&B, plus a
# checkpoint written and uploaded every CHECKPOINT_EVERY epochs and after the final epoch.
# NOTE(review): `train` and `save_weights` are not defined or imported in any cell of
# this notebook — import them from the project before running this cell.
# NOTE(review): checkpoints go under .../MLP/Project/ while data and pretrained weights
# live under .../MLP/RealProject/ — confirm this is the intended location.
NUM_EPOCHS = 10  # keep in sync with the W&B config "epochs"
CHECKPOINT_EVERY = 2
CHECKPOINT_ROOT = '/Users/kshitij/Documents/University/Year4/MLP/Project/checkpoints/run_23_01_finetuning_2'
# Weight files uploaded to W&B from each checkpoint directory
# (presumably exactly the files `save_weights` writes there — TODO confirm).
WEIGHT_FILES = (
    'image_encoder_mlp_weights.pth',
    'location_encoder_weights.pth',
    'logit_scale_weights.pth',
)
# Prefer Apple's MPS backend when present; otherwise train on CPU.
train_device = 'mps' if torch.backends.mps.is_available() else 'cpu'

for epoch in range(NUM_EPOCHS):
    # Reuse the DataLoader's batch size instead of repeating the literal.
    loss = train(dataloader, model, batch_size=dataloader.batch_size,
                 device=train_device, optimizer=optim, epoch=epoch)
    wandb.log({'loss': loss})

    # Save periodically, and always after the last epoch so the final model is kept.
    if epoch % CHECKPOINT_EVERY == 0 or epoch == NUM_EPOCHS - 1:
        ckpt_dir = os.path.join(CHECKPOINT_ROOT, str(epoch))
        # exist_ok so that re-running the cell does not crash on an existing directory.
        os.makedirs(ckpt_dir, exist_ok=True)
        save_weights(model, ckpt_dir)
        for weight_file in WEIGHT_FILES:
            wandb.save(os.path.join(ckpt_dir, weight_file))