In [2]:
import os
import yaml
import numpy as np
import h3
import geopandas as gpd
import torch
from tqdm import tqdm
from qdrant_client import QdrantClient
from qdrant_client.http.models import Distance, VectorParams, PointStruct
from shapely.geometry import Polygon

import vectorgeo.constants as c
import vectorgeo.transfer as transfer
from vectorgeo.h3_utils import H3GlobalIterator
from vectorgeo.landcover import LandCoverPatches

# Parameters
wipe_qdrant = False
inference_batch_size = 32
h3_resolution = 7
image_size = 32
model_filename = "resnet-triplet-lc.pt"
embed_dim = 16
seed_latlng = (47.475099, -122.170557)  # Seattle, WA
max_iters = None
qdrant_collection = c.QDRANT_COLLECTION_NAME
device = 'cuda'


# Load secrets
secrets = yaml.load(open(os.path.join(c.BASE_DIR, '.secrets.yml')), Loader=yaml.FullLoader)

# Download world geometry
world_path = os.path.join(c.TMP_DIR, 'world.gpkg')
transfer.download_file('misc/world.gpkg', world_path)
world_gdf = gpd.read_file(world_path)
world_geom = world_gdf.iloc[0].geometry.simplify(0.1)

qdrant_client = QdrantClient(
        url=secrets['qdrant_url'],
        api_key=secrets['qdrant_api_key']
    )

# Wipe Qdrant collection if needed
if wipe_qdrant:
    print(f"Wiping Qdrant collection {qdrant_collection}")
    
    qdrant_client.recreate_collection(
        collection_name=qdrant_collection,
        vectors_config=VectorParams(size=embed_dim, distance=Distance.DOT),
    )

# Load the PyTorch model
key = f"models/{model_filename}"
local_model_path = os.path.join(c.TMP_DIR, model_filename)
transfer.download_file(key, local_model_path)
model = torch.load(local_model_path).to(device)
model.eval()
print(f"Loaded model from {key}")

# Download land cover data
lc_key = 'raw/' + c.COPERNICUS_LC_KEY
transfer.download_file(lc_key, c.LC_LOCAL_PATH)
lcp = LandCoverPatches(c.LC_LOCAL_PATH, world_gdf, image_size, full_load=False)

# Initialize H3 iterator
state_filepath = os.path.join(c.TMP_DIR, c.H3_STATE_FILENAME)
try:
    transfer.download_file('misc/h3-state.json', state_filepath)
except Exception as e:
    print(f"Encountered exception {e} while downloading state file")
    print("No state file found; starting from scratch")

iterator = H3GlobalIterator(seed_latlng[0], seed_latlng[1], h3_resolution, state_file=state_filepath)
int_map       = {x: i for i, x in enumerate(c.LC_LEGEND.keys())}
int_map_fn    = np.vectorize(int_map.get)

# Main inference loop
h3_batch = []
xs_batch = []
h3s_processed = set()

for i, cell in enumerate(tqdm(iterator)):
    if i % 5000 == 0:
        print(f"Processing cell {i}: {cell}")
        iterator.save_state()
        transfer.upload(c.H3_STATE_KEY, state_filepath)        

    if max_iters and i >= int(max_iters):
        print(f"Reached max_iters {max_iters}; stopping")
        break

    poly = Polygon((x, y) for y, x in h3.h3_to_geo_boundary(cell))

    if not world_geom.intersects(poly):
        h3s_processed.add(cell)
        continue

    xs = int_map_fn(lcp.h3_to_patch(cell))

    xs_one_hot = np.zeros((c.LC_N_CLASSES, image_size, image_size))

    for i in range(c.LC_N_CLASSES):
        xs_one_hot[i] = (xs == i).squeeze().astype(int)

    h3_batch.append(cell)
    xs_batch.append(xs_one_hot)

    if len(h3_batch) >= inference_batch_size:
        xs_one_hot_tensor = torch.tensor(np.stack(xs_batch, axis=0), dtype=torch.float32).to(device)
        with torch.no_grad():
            zs_batch = model(xs_one_hot_tensor).cpu().numpy().squeeze().tolist()

        coords = [h3.h3_to_geo(h3_index) for h3_index in h3_batch]
        lats, lngs = zip(*coords)

        _ = qdrant_client.upsert(
            collection_name=qdrant_collection,
            wait=True,
            points=[PointStruct(
                id=int("0x" + id, 0),
                vector=vector,
                payload={"location": {"lon": lng, "lat": lat}}
            ) for id, vector, lng, lat in zip(h3_batch, zs_batch, lngs, lats)]
        )
        h3s_processed = h3s_processed.union(set(h3_batch))
        h3_batch = []
        xs_batch = []


File /home/ubuntu/vectorgeo/tmp/world.gpkg already exists; skipping download
File /home/ubuntu/vectorgeo/tmp/resnet-triplet-lc.pt already exists; skipping download
Loaded model from models/resnet-triplet-lc.pt
File /home/ubuntu/vectorgeo/tmp/PROBAV_LC100_global_v3.0.1_2019-nrt_Discrete-Classification-map_EPSG-4326.tif already exists; skipping download
File /home/ubuntu/vectorgeo/tmp/h3-state.json already exists; skipping download


2it [00:00, 18.35it/s]

Processing cell 0: 872d60228ffffff


5002it [04:33, 18.68it/s]

Processing cell 5000: 872d604a6ffffff


7471it [06:48, 18.29it/s]


KeyboardInterrupt: 

### Convert borders shapefile to geopackage

In [2]:
import geopandas as gpd

gdf = gpd.read_file('lql-data/misc/WB_countries_Admin0_10m/WB_countries_Admin0_10m.shp')
new_geoms = [gdf.unary_union]

world_gdf = gpd.GeoDataFrame(geometry=new_geoms, crs=gdf.crs)
world_gdf.to_file('lql-data/misc/world.gpkg', driver='GPKG')

CPLE_AppDefinedError: b'sqlite3_exec(CREATE TRIGGER "trigger_delete_feature_count_world" AFTER DELETE ON "world" BEGIN UPDATE gpkg_ogr_contents SET feature_count = feature_count - 1 WHERE lower(table_name) = lower(\'world\'); END;) failed: disk I/O error'

Exception ignored in: 'fiona.ogrext.gdal_flush_cache'
Traceback (most recent call last):
  File "fiona/_err.pyx", line 198, in fiona._err.GDALErrCtxManager.__exit__
fiona._err.CPLE_AppDefinedError: b'sqlite3_exec(CREATE TRIGGER "trigger_delete_feature_count_world" AFTER DELETE ON "world" BEGIN UPDATE gpkg_ogr_contents SET feature_count = feature_count - 1 WHERE lower(table_name) = lower(\'world\'); END;) failed: disk I/O error'
