In [1]:
import os, re
import base64
import requests
from mistralai import Mistral

In [2]:
# Load the Mistral API key from a local file and export it through the
# environment so the following cells can read it from there.
with open("../keys/MISTRAL_API_KEY.txt", "r") as key_file:
    # strip() (not strip("\n")) also drops "\r" from CRLF files and stray
    # spaces, either of which would make the key invalid.
    api_key = key_file.read().strip()
os.environ['MISTRAL_API_KEY'] = api_key

In [3]:
# Read the key back out of the environment (exported by the previous cell).
# A missing variable raises KeyError right here, before anything else is set.
api_key = os.environ["MISTRAL_API_KEY"]

# Pixtral vision-language model identifier used for every request below.
model = "pixtral-12b-2409"

# Shared client object; get_vlm_result() reads this global (and `model`).
client = Mistral(api_key=api_key)

In [4]:
def encode_image(image_path):
    """Return the contents of ``image_path`` as a base64-encoded text string.

    Any failure (missing file or otherwise) is reported with ``print`` and
    results in ``None`` rather than an exception.
    """
    try:
        with open(image_path, "rb") as handle:
            raw_bytes = handle.read()
        encoded = base64.b64encode(raw_bytes)
        return encoded.decode('utf-8')
    except FileNotFoundError:
        print(f"Error: The file {image_path} was not found.")
    except Exception as exc:
        print(f"Error: {exc}")
    # Reached only when one of the handlers above ran.
    return None

In [5]:
def extract_number_from_text(text):
    """Return the first run of decimal digits in ``text`` as an int.

    The model is asked to reply with a bare number, but replies may still
    contain extra words (e.g. "There are 12 buildings"), so the first integer
    found is taken as the answer.

    Args:
        text: The model's reply. ``None`` or any non-string value (such as an
            absent message content) is treated as containing no number.

    Returns:
        The first integer in ``text``, or ``None`` if there is none.
    """
    # re.search would raise TypeError on None/non-str input.
    if not isinstance(text, str):
        return None
    # Only the first match is needed, so stop scanning as soon as it is found.
    match = re.search(r'\d+', text)
    return int(match.group()) if match else None

In [6]:
def get_vlm_result(base64_image, temperature=0.1):
    """Ask the vision model to count the buildings in one base64-encoded JPEG.

    Uses the module-level ``client`` and ``model`` configured earlier in the
    notebook.

    Args:
        base64_image: Base64 text of a JPEG image, as produced by
            ``encode_image``. If it is ``None`` or empty (encoding failed),
            no request is sent.
        temperature: Sampling temperature passed to ``client.chat.complete``.

    Returns:
        The building count parsed from the model's reply, or ``None`` when
        there was no image or the reply could not be parsed into a number.
        Errors raised by ``client.chat.complete`` itself propagate to the caller.
    """
    # encode_image returns None on failure; don't send "base64,None" to the API.
    if not base64_image:
        return None

    # Define the messages for the chat
    messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "You are a satellite imagery analyst. Look at the satellite image and count the number of distinct buildings in the picture. Do not count landscape features or vehicles, only manmade buildings such as houses or other roofed structures. Please provide only the number in numeric format, not an explanation, sentence, or any other text."
                },
                {
                    "type": "image_url",
                    "image_url": f"data:image/jpeg;base64,{base64_image}"
                }
            ]
        }
    ]

    # Get the chat response
    chat_response = client.chat.complete(
        model=model,
        messages=messages,
        temperature=temperature
    )

    # The original bare `except:` left `result` unbound on failure, so the
    # return statement raised UnboundLocalError; return None explicitly instead.
    try:
        reply = chat_response.choices[0].message.content
        return extract_number_from_text(reply)
    except (AttributeError, IndexError, TypeError) as exc:
        print(f"error: could not parse model response: {exc}")
        return None

In [7]:
run_num = '1'
image_dir = "../data/test/cropped_jpg/"
folder_path = 'mistral_outputs/'+run_num

# exist_ok makes the cell safe to re-run without a separate existence check.
os.makedirs(folder_path, exist_ok=True)

# Sorted so that repeated runs process images in the same order.
for filename in sorted(os.listdir(image_dir)):
    # Skip non-image entries (e.g. .DS_Store, checkpoint folders).
    if not filename.lower().endswith(('.jpg', '.jpeg')):
        continue
    # os.path.join avoids the double slash from image_dir's trailing '/'.
    base64_image = encode_image(os.path.join(image_dir, filename))
    if base64_image is None:
        continue  # encode_image has already printed the reason
    base_name = filename.split('_pre_disaster.jpg')[0]
    try:
        response = get_vlm_result(base64_image)
    except Exception as exc:
        # Keep going so a single failed request does not abort the whole batch.
        print(f"Error: request failed for {filename}: {exc}")
        continue
    # 0 buildings is a valid answer; only skip when no number was parsed.
    if response is not None:
        with open(os.path.join(folder_path, base_name + '.txt'), 'w') as file:
            file.write(str(response))