# Get ResNet-50 embeddings for photos corresponding to OSM buildings

In [None]:
import os
import glob
import rasterio
import geopandas as gpd
import numpy as np
import torch
import torchvision.transforms as transforms
from torchvision import models
from rasterio.mask import mask as rasterio_mask
from shapely.geometry import box
from PIL import Image
import sys
import pandas as pd


def get_vector_embeddings(input_folder, vector_data_path, output_folder, crs_epsg):
    """Compute a ResNet-50 embedding for each polygon covered by each GeoTIFF.

    For every ``*.tif`` in ``input_folder`` whose bounds intersect the vector
    layer, each intersecting polygon is cropped out of the raster, pushed
    through a ResNet-50 with its final layer removed, and the resulting
    vector is stored as a comma-separated string in a ``resnet50_emb``
    column. One GeoJSON file per raster is written to ``output_folder`` in
    EPSG:4326.

    Parameters
    ----------
    input_folder : str
        Directory searched (non-recursively) for ``*.tif`` files.
    vector_data_path : str
        Path of the vector file read with ``geopandas.read_file``.
    output_folder : str
        Directory that receives the GeoJSON files; created if missing.
    crs_epsg : int
        EPSG code the vector layer is reprojected to right after loading.

    Returns
    -------
    None
        Results are written to disk only.
    """
    tif_files = glob.glob(os.path.join(input_folder, "*.tif"))

    gdf = gpd.read_file(vector_data_path).to_crs(epsg=crs_epsg)

    # Writing the GeoJSON files fails if the target directory does not exist.
    os.makedirs(output_folder, exist_ok=True)

    # NOTE(review): recent torchvision releases deprecate `pretrained=True`
    # in favour of `weights=...`; kept as-is so the call works with whatever
    # version this notebook was written against.
    resnet50 = models.resnet50(pretrained=True)
    # Drop the last child module so the network outputs features, not logits.
    resnet50 = torch.nn.Sequential(*list(resnet50.children())[:-1])
    resnet50.eval()

    # Built once: the preprocessing pipeline is identical for every crop.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    def extract_features(image):
        """Return the ResNet-50 feature vector of a PIL image as a NumPy array."""
        image_tensor = transform(image).unsqueeze(0)
        with torch.no_grad():
            features = resnet50(image_tensor)
        return features.squeeze().numpy()

    def is_relevant_tiff(tif_path, gdf):
        """Tell whether any geometry of ``gdf`` intersects the raster bounds."""
        # NOTE(review): the bounds are compared without reprojection, so this
        # presumes the raster is already in the same CRS as ``gdf`` - confirm.
        with rasterio.open(tif_path) as src:
            bounds = box(*src.bounds)
            return gdf.intersects(bounds).any()

    relevant_tifs = [tif for tif in tif_files if is_relevant_tiff(tif, gdf)]

    for tif_file in relevant_tifs:
        print(f'Processing: {tif_file}')

        with rasterio.open(tif_file) as src:
            if gdf.crs != src.crs:
                gdf = gdf.to_crs(src.crs)

            failures = []

            def embed_building(poly):
                """Crop ``poly`` out of the open raster and embed it, or return None."""
                try:
                    cropped_image, _ = rasterio_mask(src, [poly], crop=True)
                    # rasterio yields (bands, rows, cols); PIL wants (rows, cols, bands).
                    cropped_image = np.transpose(cropped_image, (1, 2, 0))
                    cropped_pil = Image.fromarray(cropped_image.astype(np.uint8))
                    return extract_features(cropped_pil)
                except Exception as e:
                    # Best effort: one bad crop must not abort the whole raster,
                    # but the failure is recorded instead of vanishing silently.
                    failures.append(e)
                    return None

            # Only polygons touching this raster can yield a crop; selecting
            # them up front avoids masking (and failing on) every other
            # polygon in the layer. The copy keeps the shared ``gdf`` free of
            # per-raster columns.
            candidates = gdf[gdf.intersects(box(*src.bounds))].copy()
            vectors = [embed_building(poly) for poly in candidates.geometry]

            if failures:
                print(f'Skipped {len(failures)} geometries in {tif_file} '
                      f'(last error: {failures[-1]!r})')

            has_embedding = np.array([vec is not None for vec in vectors], dtype=bool)
            gdf_expanded = candidates.loc[has_embedding].copy()
            if gdf_expanded.empty:
                print(f'No embeddings produced for: {tif_file}')
                continue

            # GeoJSON properties cannot hold arrays, so each vector is stored
            # as one comma-separated string; the raw array never reaches the
            # output frame.
            gdf_expanded['resnet50_emb'] = [
                ','.join(map(str, vec)) for vec in vectors if vec is not None
            ]

            # Same naming rule as before (text up to the first dot of the file
            # name), but independent of the platform's path separator.
            stem = os.path.basename(tif_file).split('.')[0]
            out_path = os.path.join(output_folder, f"{stem}.geojson")
            gdf_expanded.to_crs(epsg=4326).to_file(out_path, driver='GeoJSON')

In [None]:
# One entry per orthophoto folder to process. The keys mirror the keyword
# arguments of get_vector_embeddings so each entry can be unpacked directly.
datasets = [
    dict(
        input_folder='../data/EE/orto/2020',
        vector_data_path='../data/EE/buildings/osm_buildings_ee.geojson',
        output_folder='../data/ee_2020_out',
        crs_epsg=3301,
    ),
    dict(
        input_folder='../data/EE/orto/2024',
        vector_data_path='../data/EE/buildings/osm_buildings_ee.geojson',
        output_folder='../data/ee_2024_out',
        crs_epsg=3301,
    ),
    dict(
        input_folder='../data/LT/orto/2020',
        vector_data_path='../data/LT/buildings/osm_buildings_lt.geojson',
        output_folder='../data/lt_2020_out',
        crs_epsg=3346,
    ),
    dict(
        input_folder='../data/LT/orto/2024',
        vector_data_path='../data/LT/buildings/osm_buildings_lt.geojson',
        output_folder='../data/lt_2024_out',
        crs_epsg=3346,
    ),
]

# Run the embedding export for each configuration, in the order listed.
for dataset in datasets:
    get_vector_embeddings(**dataset)