In [1]:
import pandas as pd

from geopy.geocoders import Nominatim
from joblib import Parallel, delayed
from tqdm import tqdm
from argparse import Namespace

In [2]:
# Input and output CSV locations for this run; `args` exposes them as attributes.
# NOTE(review): "fin" and "fout" are the same path, so anything written to
# `args.fout` replaces the file that `args.fin` was read from. The commented-out
# alternative used a separate "..._with_country.csv" output; confirm the
# overwrite is intended.
config = dict(
    # fin="../data/csv_performance_all_models/xlmt_inference_test_set.csv",
    fin="../data3/mmendieta/Violence_data/csv_files_global_scale/xlmt_inference_test_set.csv",
    # fout="../data/csv_performance_all_models/xlmt_inference_test_set_with_country.csv",
    fout="../data3/mmendieta/Violence_data/csv_files_global_scale/xlmt_inference_test_set.csv",
)

# Attribute-style access: args.fin / args.fout
args = Namespace(**config)

In [3]:
# Load the input CSV named by args.fin into `df`.
# NOTE: on_bad_lines='skip' drops malformed rows without raising or warning,
# so `df` may contain fewer rows than the file on disk.
read_options = dict(engine='python', on_bad_lines='skip', encoding='utf-8')
df = pd.read_csv(args.fin, **read_options)

In [None]:
# Preview the second and third rows (positions 1 and 2) of the loaded data
df.iloc[1:3]

In [4]:
# Cache for storing results, keyed by (latitude, longitude)
cache = {}

# Function to get country from coordinates with caching
def get_country_cached(geo_x, geo_y):
    """Return the English country name for a coordinate pair, or "Unknown".

    The pair is passed to Nominatim's reverse geocoder as (geo_y, geo_x), i.e.
    geo_y is treated as the latitude and geo_x as the longitude.

    Args:
        geo_x: Longitude of the point.
        geo_y: Latitude of the point.

    Returns:
        The 'country' entry of the reverse-geocoded address, or "Unknown" when
        a coordinate is missing, the lookup fails, or the address has no
        country.

    Side effects:
        Stores completed lookups in the module-level `cache`. Failed lookups
        are not stored, so a transient error (timeout, rate limiting) can be
        retried by calling the function again.
    """
    key = (geo_y, geo_x)  # Use latitude and longitude as key

    # Answer from the cache before doing any other work
    if key in cache:
        return cache[key]

    # Missing coordinates (None/NaN) cannot be geocoded; skip the network call
    if any(pd.isna(value) for value in key):
        return "Unknown"

    try:
        # Instantiate geolocator inside the function to avoid pickling issues;
        # it is only needed on a cache miss
        geolocator = Nominatim(user_agent="geo_locator")
        # Reverse geocode to get the country
        location = geolocator.reverse(key, language='en')
    except Exception:
        # Best-effort lookup: report "Unknown" but do not cache it, so a
        # temporary failure does not permanently label this coordinate
        return "Unknown"

    # No match (e.g. open sea) or an address without a country is a
    # definitive answer, so it is cached like any other result
    address = location.raw.get('address', {}) if location else {}
    country = address.get('country', "Unknown")

    # Store result in cache
    cache[key] = country
    return country

In [None]:
# This code takes approximately 40-50 mins to complete (for 10,000 samples)
# (timing measured before duplicate coordinates were collapsed; repeated
# coordinates now cost a single lookup)
# Enable progress bar for parallel processing
tqdm.pandas()

# Collect one (geo_x, geo_y) pair per row, then keep each distinct pair once
# (first-seen order) so every coordinate is geocoded a single time
coords = list(zip(df['geo_x'], df['geo_y']))
unique_coords = list(dict.fromkeys(coords))

# Apply the function in parallel to infer country.
# Threads are used because the work is network-bound and they share the
# module-level `cache`; joblib's default process-based workers would each
# operate on their own copy of it.
# NOTE(review): the public Nominatim service asks for at most 1 request per
# second; consider lowering n_jobs or adding geopy's RateLimiter.
countries = Parallel(n_jobs=-1, prefer="threads")(
    delayed(get_country_cached)(geo_x, geo_y) for geo_x, geo_y in tqdm(unique_coords)
)

# Map every row back to the country found for its coordinate pair
country_by_coord = dict(zip(unique_coords, countries))
df['country'] = [country_by_coord[pair] for pair in coords]

In [None]:
# Tally how many rows were assigned to each country
country_counts = df['country'].value_counts()

# Print a heading followed by the tally, one per line
print("Country Counts:", country_counts, sep="\n")

In [None]:
# Select the rows whose inferred country is "Unknown"
is_unknown = df['country'].eq("Unknown")
unknown_locations = df.loc[is_unknown]

# Report how many there are and show the first few
print(f"Number of observations with 'Unknown' country: {len(unknown_locations)}")
print(unknown_locations.head())

# Write those rows (without the index) to a CSV in the working directory
unknown_locations.to_csv("unknown_locations.csv", index=False)
print("Observations with 'Unknown' country saved to 'unknown_locations.csv'")

In [None]:
# Extent of the coordinate columns: smallest and largest geo_x and geo_y
x_values = df['geo_x']
y_values = df['geo_y']
geo_x_min = x_values.min()
geo_x_max = x_values.max()
geo_y_min = y_values.min()
geo_y_max = y_values.max()

# Show the ranges
print(f"geo_x: min={geo_x_min}, max={geo_x_max}")
print(f"geo_y: min={geo_y_min}, max={geo_y_max}")

In [None]:
# Write the DataFrame, including the new 'country' column, to args.fout as CSV
# (row index omitted)
output_path = args.fout
df.to_csv(output_path, index=False)