In [None]:
import msgpack
import glob
import os

from PIL import Image
from io import BytesIO

import matplotlib.pyplot as plt
import pandas as pd
from tqdm import tqdm

#import requests
from geopy.geocoders import Nominatim


# Folder that holds the dataset's shard files (trailing slash included).
dataset_dir = "/kaggle/input/large-dataset-of-geotagged-images/shards/"
# Full paths of shard_0.msg through shard_142.msg, in numeric order.
shard_fnames = [f"{dataset_dir}shard_{shard_no}.msg" for shard_no in range(143)]

def get_image(record):
    """Open the encoded image carried by a shard record.

    Args:
        record: Mapping whose ``"image"`` entry holds the encoded image bytes.

    Returns:
        The image object produced by ``Image.open`` for those bytes.
    """
    encoded = record["image"]
    stream = BytesIO(encoded)
    return Image.open(stream)

def get_country(lat, lon, geocoder=None):
    """Reverse-geocode a coordinate pair to an English country name.

    Args:
        lat: Latitude of the point (numeric).
        lon: Longitude of the point (numeric).
        geocoder: Object exposing a geopy-style
            ``reverse(query, language=...)`` method. When omitted (or falsy)
            no lookup is attempted.

    Returns:
        The ``country`` entry of the result's address, or ``None`` when no
        geocoder was supplied, the lookup raised, nothing was found, or the
        result carries no country.
    """
    if not geocoder:
        return None

    try:
        # Hand the coordinates over as a (lat, lon) pair rather than a
        # formatted string: str() renders small floats in exponent notation
        # (e.g. 1e-05), which is not a coordinate string format.
        location = geocoder.reverse((lat, lon), language="en")
    except Exception as e:
        # Deliberately best-effort: one failed lookup (timeout, service
        # error, rejected coordinates) must not abort a whole batch run.
        print(f"Geocoding error for ({lat}, {lon}): {e}")
        return None

    if not location:
        return None

    # Some results (e.g. open sea) have no address or no country entry.
    address = location.raw.get('address') or {}
    return address.get('country')

def save_image_and_ginfo(record, df, geocoder, img_id):
    """Save a record's image under its country folder and log it in ``df``.

    The record's coordinates are reverse-geocoded to a country. If that
    fails, nothing is written and ``df`` is returned unchanged. Otherwise the
    image is saved as ``<country>/<img_id>.jpg`` (relative to the working
    directory) and one row describing it is appended.

    Args:
        record: Mapping with ``"image"``, ``"latitude"`` and ``"longitude"``
            entries.
        df: DataFrame of images processed so far.
        geocoder: Geocoder forwarded to ``get_country``.
        img_id: Identifier used as the file stem and as ``image_id``.

    Returns:
        A DataFrame with the new row appended, or ``df`` itself when the
        country could not be determined.
    """
    lat, lon = record["latitude"], record["longitude"]
    country = get_country(lat, lon, geocoder)
    if not country:
        return df

    image = get_image(record)
    # Note: the directory (the country name) is what gets recorded in the
    # 'image_path' column, not the full path of the file.
    image_path = country
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(image_path, exist_ok=True)

    # JPEG cannot store alpha or palette data, and Pillow refuses to write
    # such images (e.g. RGBA, LA, P). Convert only those, so images already
    # in a JPEG-compatible mode are saved exactly as before.
    if image.mode not in ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr"):
        image = image.convert("RGB")
    image.save(f"{image_path}/{img_id}.jpg")

    new_row = pd.DataFrame({
        'image_id': [img_id],
        'image_path': [image_path],
        'latitude': [lat],
        'longitude': [lon],
        'country': [country],
    })
    return pd.concat([df, new_row], ignore_index=True)


# One row per image that was successfully geocoded and written to disk.
df = pd.DataFrame(columns=['image_id', 'image_path', 'latitude', 'longitude', 'country'])
# NOTE(review): one reverse-geocoding request is issued per record; check the
# geocoding service's usage policy / rate limits before widening the run.
geocoder = Nominatim(user_agent="my-application1")

# Only the first shard is processed; widen the slice to cover more shards.
for i, shard_fname in enumerate(shard_fnames[:1]):
    print(f"Processing {shard_fname.split('/')[-1]}")
    with open(shard_fname, "rb") as infile:
        for j, record in enumerate(tqdm(msgpack.Unpacker(infile, raw=False))):
            # 1-based shard number in the millions place plus the record's
            # position in the shard; unique while a shard holds fewer than
            # 1,000,000 records.
            img_id = str(1000000*(i+1) + j)
            df = save_image_and_ginfo(record, df, geocoder, img_id)

# head must be called: printing the bare attribute shows the bound method
# (and with it the whole frame's repr) instead of the first rows.
print(df.head())
df.to_csv("processed_images_info.csv", index=False)