In [30]:
%load_ext autoreload
%autoreload 2

TORCH_ACCELERATOR = "cpu"
import os
if TORCH_ACCELERATOR == "cpu":
    os.environ["CUDA_VISIBLE_DEVICES"] = ""

import math
import multiprocessing

import torchvision.models as models
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as T
import numpy as np
import pandas
import matplotlib.pyplot as plt

import lightning as L
from lightning.pytorch.loggers import TensorBoardLogger

import torchmetrics
import webdataset as wds
import s2sphere
import tqdm

import label_mapping

torch.cuda.is_available()

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


False

In [2]:
%run datasets.ipynb

DATASET_ROOT = Path.home() / "datasets" / "im2gps" / "outputs"

# Load the s2cell-annotated dataset
annotated_df = pandas.read_pickle(DATASET_ROOT / "s2cell_2007" / "annotated.pkl")
mapping = label_mapping.LabelMapping.read_csv(DATASET_ROOT / "s2cell_2007" / "cells.csv")

mapping.label_to_name

{0: '0097',
 1: '0099',
 2: '009b',
 3: '009d',
 4: '00a7',
 5: '0717',
 6: '07ab',
 7: '0b43',
 8: '0b47',
 9: '0b5d',
 10: '0c41',
 11: '0c47',
 12: '0c61',
 13: '0c6b',
 14: '0d05',
 15: '0d0b',
 16: '0d0d',
 17: '0d11',
 18: '0d13',
 19: '0d15',
 20: '0d17',
 21: '0d18c',
 22: '0d1931',
 23: '0d1933',
 24: '0d1935',
 25: '0d195',
 26: '0d19c',
 27: '0d1b',
 28: '0d1f',
 29: '0d23',
 30: '0d25',
 31: '0d2f',
 32: '0d31',
 33: '0d37',
 34: '0d39',
 35: '0d3b',
 36: '0d3d',
 37: '0d3f',
 38: '0d41',
 39: '0d4224',
 40: '0d42284',
 41: '0d4228c',
 42: '0d42294',
 43: '0d422f',
 44: '0d4234',
 45: '0d45',
 46: '0d47',
 47: '0d49',
 48: '0d4f',
 49: '0d51',
 50: '0d55',
 51: '0d57',
 52: '0d59',
 53: '0d5b',
 54: '0d5d',
 55: '0d5f',
 56: '0d61',
 57: '0d63',
 58: '0d6b',
 59: '0d6d',
 60: '0d6f',
 61: '0d71',
 62: '0d73',
 63: '0d7b',
 64: '0d97',
 65: '0d9f',
 66: '0da1',
 67: '0da7',
 68: '0dad',
 69: '0daf',
 70: '0db1',
 71: '0db3',
 72: '0dbb',
 73: '0dbd',
 74: '0e39',
 75: '0e3b'

In [3]:

image_id_to_s2cell = {row.id: row.s2cell for row in annotated_df.itertuples()}
print(f"Loaded {len(image_id_to_s2cell)} image ids with s2cell")

# train set has ~470k, val set has ~120k
BATCH_SIZE = auto_batch_size()
print("Batch size:", BATCH_SIZE)
NUM_WORKERS = (multiprocessing.cpu_count() * 2) // 3
print("Num workers:", NUM_WORKERS)

NORMALIZE_T = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

TRAIN_T = T.Compose([
    T.RandomResizedCrop(224),
    T.RandomHorizontalFlip(),
    T.ToTensor(),
    NORMALIZE_T,
])

VAL_T = T.Compose([
    T.Resize(224),
    T.ToTensor(),
    NORMALIZE_T,
])

# Transform s2cells to labels, skipping examples without s2cell
def to_img_label(sample):
    img, meta = sample
    s2cell = image_id_to_s2cell.get(meta["id"])
    if s2cell is None:
        return None
    label = mapping.get_label(s2cell)
    # TODO: use VAL_T for the validation set
    return TRAIN_T(img), label

def urls_to_dataset(urls, shuffle=True):
    ds = wds.WebDataset(urls, shardshuffle=shuffle)
    if shuffle:
        ds = ds.shuffle(100)
    return ds.decode("pil").to_tuple("jpg", "json")\
        .map(to_img_label)\
        .batched(BATCH_SIZE)

train_dataset = urls_to_dataset(str(DATASET_ROOT / "wds" / "im2gps_2007_train_{000..028}.tar"))
val_dataset = urls_to_dataset(str(DATASET_ROOT / "wds" / "im2gps_2007_val_{000..007}.tar"))

train_dataloader = wds.WebLoader(train_dataset, batch_size=None, num_workers=NUM_WORKERS)
val_dataloader = wds.WebLoader(val_dataset, batch_size=None, num_workers=NUM_WORKERS)

# Visualize a few loaded samples
for inputs, targets in train_dataloader:
    print(inputs.shape, targets.shape, targets)
    break

for inputs, targets in val_dataloader:
    print(inputs.shape, targets.shape, targets)
    break

Loaded 591441 image ids with s2cell
Batch size: 1
Num workers: 5
torch.Size([1, 3, 224, 224]) torch.Size([1]) tensor([1098])
torch.Size([1, 3, 224, 224]) torch.Size([1]) tensor([790])


In [4]:
# Define a LightningModule for the classifier
class S2CellClassifierMnet3(L.LightningModule):
    def __init__(self, num_classes):
        super().__init__()

        mnet3 = models.mobilenet_v3_large(weights="IMAGENET1K_V2")

        self.features = mnet3.features
        self.avgpool = mnet3.avgpool
        hidden_size = 2048
        self.classifier = nn.Sequential(
            nn.Linear(mnet3.classifier[0].in_features, hidden_size),
            nn.Hardswish(inplace=True),
            nn.Dropout(p=0.2, inplace=True),
            nn.Linear(hidden_size, num_classes), # out is 1776
        )

        torch.nn.init.xavier_uniform_(self.classifier[0].weight)
        torch.nn.init.xavier_uniform_(self.classifier[3].weight)

        self.accuracy = torchmetrics.classification.Accuracy(task='multiclass', num_classes=num_classes)

    def forward(self, x):
        with torch.no_grad():
            x = self.features(x)
            x = self.avgpool(x)
            x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

    def training_step(self, batch, batch_idx):
        x, y = batch
        z = self.forward(x)
        loss = nn.CrossEntropyLoss()(z, y)
        self.log("train_loss", loss, prog_bar=True)

        preds = torch.argmax(z, dim=1)
        self.accuracy(preds, y)
        self.log('train_acc_step', self.accuracy, prog_bar=True)

        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        z = self.forward(x)
        val_loss = nn.CrossEntropyLoss()(z, y)
        self.log("val_loss", val_loss, prog_bar=True)

        preds = torch.argmax(z, dim=1)
        self.accuracy(preds, y)
        self.log('val_acc', self.accuracy, on_step=False, on_epoch=True)
        return val_loss

    def test_step(self, batch, batch_idx):
        x, y = batch
        z = self.forward(x)
        test_loss = nn.CrossEntropyLoss()(z, y)
        self.log("test_loss", test_loss, prog_bar=True)

        preds = torch.argmax(z, dim=1)
        self.accuracy(preds, y)
        self.log('test_acc', self.accuracy, on_step=False, on_epoch=True)
        return test_loss

    def predict_step(self, batch, batch_idx, dataloader_idx=None):
        return self.forward(batch)

    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=0.001)

In [55]:
mnet3_model = S2CellClassifierMnet3(num_classes=len(mapping))

# Quick test run
L.Trainer(
    accelerator=TORCH_ACCELERATOR,
    fast_dev_run=True,
).fit(model=mnet3_model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Running in `fast_dev_run` mode: will run the requested loop using 1 batch(es). Logging and checkpointing is suppressed.
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name       | Type               | Params
--------------------------------------------------
0 | features   | Sequential         | 3.0 M 
1 | avgpool    | AdaptiveAvgPool2d  | 0     
2 | classifier | Sequential         | 5.6 M 
3 | accuracy   | MulticlassAccuracy | 0     
--------------------------------------------------
8.6 M     Trainable params
0         Non-trainable params
8.6 M     Total params
34.316    Total estimated model params size (MB)


Epoch 0: 100%|███████████████████████████████████████| 1/1 [00:40<00:00, 40.05s/it, train_loss=7.520, train_acc_step=0.000, val_loss=7.570]

`Trainer.fit` stopped: `max_steps=1` reached.


Epoch 0: 100%|███████████████████████████████████████| 1/1 [00:40<00:00, 40.05s/it, train_loss=7.520, train_acc_step=0.000, val_loss=7.570]


In [5]:
# Full training (resuming from checkpoint)
mnet3_model = S2CellClassifierMnet3.load_from_checkpoint("checkpoints/s2cell_predict/version1.ckpt", num_classes=len(mapping))
trainer = L.Trainer(
    accelerator=TORCH_ACCELERATOR,
    callbacks=[
        L.pytorch.callbacks.ModelCheckpoint(
            monitor="val_acc",
            mode="max",
            save_last=True,
            save_top_k=5,
        ),
        L.pytorch.callbacks.EarlyStopping(
            patience=5,
            monitor="val_acc",
            mode="max",
            verbose=True,
        ),
    ],
)
trainer.fit(model=mnet3_model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)
print("Training done")

Downloading: "https://download.pytorch.org/models/mobilenet_v3_large-5c1a4163.pth" to /home/ubuntu/.cache/torch/hub/checkpoints/mobilenet_v3_large-5c1a4163.pth
100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 21.1M/21.1M [00:00<00:00, 358MB/s]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(
You are using a CUDA device ('NVIDIA A10') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
Missing logger folder: /home/ubuntu/img2loc/lightning_logs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name       | Type               | Params
--------------------------------

Epoch 0: : 426it [10:37,  1.50s/it, v_num=0, train_loss=5.360, train_acc_step=0.108, val_loss=7.270]                                       

Metric val_acc improved. New best score: 0.083


Epoch 1: : 426it [10:43,  1.51s/it, v_num=0, train_loss=4.860, train_acc_step=0.0591, val_loss=7.680] 

Metric val_acc improved by 0.005 >= min_delta = 0.0. New best score: 0.088


Epoch 2: : 426it [10:35,  1.49s/it, v_num=0, train_loss=5.190, train_acc_step=0.134, val_loss=6.950] 

Metric val_acc improved by 0.014 >= min_delta = 0.0. New best score: 0.102


Epoch 3: : 426it [07:19,  1.03s/it, v_num=0, train_loss=4.520, train_acc_step=0.183, val_loss=6.950]  

Metric val_acc improved by 0.007 >= min_delta = 0.0. New best score: 0.108


Epoch 5: : 426it [10:38,  1.50s/it, v_num=0, train_loss=4.640, train_acc_step=0.113, val_loss=7.380] 

Metric val_acc improved by 0.005 >= min_delta = 0.0. New best score: 0.113


Epoch 6: : 426it [10:37,  1.50s/it, v_num=0, train_loss=4.320, train_acc_step=0.0806, val_loss=7.470]

Metric val_acc improved by 0.005 >= min_delta = 0.0. New best score: 0.118


Epoch 8: : 426it [10:37,  1.50s/it, v_num=0, train_loss=4.380, train_acc_step=0.161, val_loss=7.070] 

Metric val_acc improved by 0.004 >= min_delta = 0.0. New best score: 0.122


Epoch 9: : 426it [10:34,  1.49s/it, v_num=0, train_loss=4.330, train_acc_step=0.215, val_loss=7.480] 

Metric val_acc improved by 0.006 >= min_delta = 0.0. New best score: 0.128


Epoch 10: : 426it [10:38,  1.50s/it, v_num=0, train_loss=4.500, train_acc_step=0.140, val_loss=7.620] 

Metric val_acc improved by 0.002 >= min_delta = 0.0. New best score: 0.130


Epoch 11: : 426it [10:32,  1.48s/it, v_num=0, train_loss=4.450, train_acc_step=0.172, val_loss=7.040] 

Metric val_acc improved by 0.005 >= min_delta = 0.0. New best score: 0.135


Epoch 12: : 426it [10:43,  1.51s/it, v_num=0, train_loss=3.750, train_acc_step=0.231, val_loss=7.140] 

Metric val_acc improved by 0.003 >= min_delta = 0.0. New best score: 0.138


Epoch 13: : 426it [10:41,  1.51s/it, v_num=0, train_loss=3.920, train_acc_step=0.263, val_loss=7.300] 

Metric val_acc improved by 0.004 >= min_delta = 0.0. New best score: 0.142


Epoch 15: : 426it [10:33,  1.49s/it, v_num=0, train_loss=3.330, train_acc_step=0.339, val_loss=7.230] 

Metric val_acc improved by 0.002 >= min_delta = 0.0. New best score: 0.144


Epoch 16: : 426it [10:38,  1.50s/it, v_num=0, train_loss=4.650, train_acc_step=0.0753, val_loss=7.130]

Metric val_acc improved by 0.000 >= min_delta = 0.0. New best score: 0.144


Epoch 17: : 426it [10:43,  1.51s/it, v_num=0, train_loss=3.440, train_acc_step=0.344, val_loss=7.150] 

Metric val_acc improved by 0.005 >= min_delta = 0.0. New best score: 0.149


Epoch 18: : 426it [10:41,  1.51s/it, v_num=0, train_loss=3.380, train_acc_step=0.312, val_loss=7.470] 

Metric val_acc improved by 0.007 >= min_delta = 0.0. New best score: 0.157


Epoch 19: : 426it [10:30,  1.48s/it, v_num=0, train_loss=3.340, train_acc_step=0.376, val_loss=7.630] 

Metric val_acc improved by 0.001 >= min_delta = 0.0. New best score: 0.158


Epoch 22: : 426it [07:19,  1.03s/it, v_num=0, train_loss=4.140, train_acc_step=0.194, val_loss=7.510] 



In [27]:
# Run inference on test set. Measure accuracy using PlaNet criteria
def to_img_label_test(sample):
    img, meta = sample
    label = (meta["latitude"], meta["longitude"])
    return VAL_T(img), label

def predict_all(model):
    test_dataset = wds.WebDataset(str(Path.home() / "datasets" / "im2gps3ktest" / "wds" / "im2gps3ktest_000.tar"))
    test_dataset = test_dataset.decode("pil").to_tuple("jpg", "json")\
        .map(to_img_label_test)\
        .batched(1)
    test_dataloader = wds.WebLoader(test_dataset, batch_size=None, num_workers=NUM_WORKERS)
    #raw_predictions = L.Trainer(accelerator=TORCH_ACCELERATOR).predict(model, test_dataset)
    #return raw_predictions

    results = {
        "pred_token": [],
        "true_lat": [],
        "true_lng": [],
    }

    model.eval()
    with torch.no_grad():
        for inputs, targets in tqdm.tqdm(test_dataloader):
            logits = model.predict_step(inputs, 0)
            pred = torch.argmax(logits)
            pred_s2cell = mapping.get_name(pred.item())

            results["pred_token"].append(pred_s2cell)
            results["true_lat"].append(targets[0][0])
            results["true_lng"].append(targets[0][1])
    return pandas.DataFrame(results)

map_location = {}
if TORCH_ACCELERATOR == "cpu":
    map_location = {"map_location": torch.device("cpu")}
mnet3_model = S2CellClassifierMnet3.load_from_checkpoint("checkpoints/s2cell_predict/version2.ckpt", num_classes=len(mapping), **map_location)
results = predict_all(mnet3_model)
results

2997it [01:08, 43.59it/s]


Unnamed: 0,pred_token,true_lat,true_lng
0,32d5,-33.940916,18.374647
1,12cd,50.484599,5.890045
2,4784c,34.022502,77.603302
3,48bd,37.807359,-122.469470
4,3e5f,51.499473,-0.119862
...,...,...,...
2992,151d,-13.311708,48.115938
2993,12cd,38.785535,121.141777
2994,32d5,51.121411,2.622985
2995,31db,-37.814666,144.954986


In [38]:
# Compute accuracy according to PlaNet criteria
# Street = 1km
# City = 25km
# Region = 200km
# Country = 750km
# Continent = 2500km

def compute_accuracy_planet(results):
    accuracy_planet = {
        "street": 0,
        "city": 0,
        "region": 0,
        "country": 0,
        "continent": 0,
    }

    for r in results.itertuples():
        pred_center = s2sphere.Cell(s2sphere.CellId.from_token(r.pred_token)).get_center()
        pred_latlng = s2sphere.LatLng.from_point(pred_center)
        true_latlng = s2sphere.LatLng.from_degrees(r.true_lat, r.true_lng)
        angle = true_latlng.get_distance(pred_latlng)
        distance_km = 6378.0 * 2 * math.pi * angle.radians
        
        if distance_km <= 1:
            accuracy_planet["street"] += 1
        if distance_km <= 25:
            accuracy_planet["city"] += 1
        if distance_km <= 200:
            accuracy_planet["region"] += 1
        if distance_km <= 750:
            accuracy_planet["country"] += 1
        if distance_km <= 2500:
            accuracy_planet["continent"] += 1

    for k in accuracy_planet.keys():
        accuracy_planet[k] /= 2997.0
        accuracy_planet[k] *= 100.0 # percent

    return accuracy_planet

compute_accuracy_planet(results)

{'street': 0.033366700033366704,
 'city': 1.6016016016016015,
 'region': 3.003003003003003,
 'country': 4.337671004337671,
 'continent': 7.374040707374041}

In [39]:
# Compute max possible accuracy with our current set of s2 cells
all_s2_cell_centers = {
    token: s2sphere.LatLng.from_point(s2sphere.Cell(s2sphere.CellId.from_token(token)).get_center())
    for token in mapping.name_to_label.keys()
}

def best_cell_for(lat, lng):
    true_latlng = s2sphere.LatLng.from_degrees(lat, lng)
    all_distances = [
        (token, true_latlng.get_distance(center).radians)
        for token, center in all_s2_cell_centers.items()
    ]
    return min(all_distances, key=lambda x: x[1])[0]

def perfect_test_results():
    test_dataset = wds.WebDataset(str(Path.home() / "datasets" / "im2gps3ktest" / "wds" / "im2gps3ktest_000.tar"))
    test_dataset = test_dataset.decode("pil").to_tuple("jpg", "json")\
        .map(to_img_label_test)\
        .batched(1)
    test_dataloader = wds.WebLoader(test_dataset, batch_size=None, num_workers=NUM_WORKERS)

    results = {
        "pred_token": [],
        "true_lat": [],
        "true_lng": [],
    }

    for inputs, targets in tqdm.tqdm(test_dataloader):
        pred_s2cell = best_cell_for(targets[0][0], targets[0][1])
        results["pred_token"].append(pred_s2cell)
        results["true_lat"].append(targets[0][0])
        results["true_lng"].append(targets[0][1])
    return pandas.DataFrame(results)

compute_accuracy_planet(perfect_test_results())

2997it [00:52, 56.72it/s]


{'street': 0.9009009009009009,
 'city': 24.491157824491157,
 'region': 49.249249249249246,
 'country': 93.02635969302636,
 'continent': 97.7977977977978}