In [None]:
import pathlib
import textwrap
import json
import os
import google.generativeai as genai
import yaml
import argparse

import PIL.Image

# Heatmap
import folium
from folium.plugins import HeatMap

from model import get_model
from method import *

def main(args, config_path):
    """Geolocate a single image with the configured model and parsing method.

    Args:
        args: Path to the image file to geolocate (despite the name, this is a
            single path passed to ``PIL.Image.open``).
        config_path: Path to a YAML config. Keys read here: ``model.name``,
            optional ``model.api_key``, ``method.name``, optional
            ``method.round`` (defaults to 1), and ``prompt`` (path to a text
            file whose contents are sent to the model as the prompt).

    Returns:
        A ``(latitude, longitude)`` tuple taken from the parsed method output,
        or ``None`` if the model call raised an exception.
    """
    with open(config_path, 'r') as cfg_file:
        cfg = yaml.safe_load(cfg_file)
    api_key = cfg['model'].get('api_key', None)
    model_name = cfg['model']['name']
    method = cfg['method']['name']
    # Named ``n_round`` so the built-in ``round`` is not shadowed.
    n_round = cfg['method'].get('round', 1)
    model = get_model(model_name)

    with open(cfg['prompt'], 'r') as prompt_file:
        prompt = prompt_file.read()

    # TODO: rounds > 1 previously built the prompt via
    # Top5.get_prompt(ds[model_name][method], round); that is not wired up here.

    # Use the image as a context manager so its file handle is released once
    # inference is done. An unreadable image still raises, as before.
    with PIL.Image.open(args) as img:
        try:
            res = model(prompts=[img, prompt], api_key=api_key)
        except Exception as exc:
            # Report the cause instead of silently swallowing it; Exception
            # (not a bare except) lets KeyboardInterrupt/SystemExit propagate.
            print(f'Model inference error: {exc!r}')
            return None

    field = 'output' if n_round == 1 else 'street_num'
    ds = {model_name: {method: {field: res}}}
    entry = ds[model_name][method]

    # Read back from the key that was just written. Always reading 'output'
    # raised KeyError whenever round > 1.
    text = entry[field]

    entry.update(METHOD_DICT[method](text).to_dict())

    latitude = entry['latitude']
    longitude = entry['longitude']
    location = entry['location']

    print(f'This image was taken at {location}, {latitude}°, {longitude}° ')
    return latitude, longitude

In [2]:
# Geolocate the demo image, then visualise the prediction on a world map.
config_path = "./config.yaml"
image_path = './images/demo.jpg'

prediction = main(image_path, config_path)
if prediction is None:
    # main() returns None when model inference fails. Stop here with a clear
    # message rather than an opaque "cannot unpack NoneType" TypeError.
    raise RuntimeError('Geolocation failed; see the model inference error printed above.')
latitude, longitude = prediction

# zoom_start=2.2 gives a wide, near-global view centred on the prediction.
m = folium.Map(location=[latitude, longitude], zoom_start=2.2)

# A single prediction, given full weight in the heatmap.
weighted_coordinates = [(latitude, longitude, 1.)]

# Define the color gradient
magma = {
    0.0: '#932667',
    0.2: '#b5367a',
    0.4: '#d3466b',
    0.6: '#f1605d',
    0.8: '#fd9668',
    1.0: '#fcfdbf'
}

HeatMap(weighted_coordinates, gradient=magma).add_to(m)

gps_coordinates = [latitude, longitude]
# Mark top coordinate
top_coordinate = gps_coordinates

folium.Marker(
    location=top_coordinate,
    popup=f"Top Prediction: {top_coordinate}",
    icon=folium.Icon(color='orange', icon='star')
).add_to(m)

# Display the map
m

This image was taken at Pisa, Italy, 43.723°, 10.3966° 
