In [None]:
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    pipeline
)
import pandas as pd
import json
from sklearn.feature_extraction import DictVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from tqdm import tqdm
from transformers.pipelines.pt_utils import KeyDataset
import numpy as np
from sklearn.metrics import top_k_accuracy_score

In [None]:
# Size tag used to pick the prediction file that is loaded later
# (interpolated into the './pred{imgsz}.json' path).
# Presumably the image size of the detector run that produced it - TODO confirm.
imgsz = 1280

In [None]:
def jsonConvert(json_list, key_field):
    """Re-index a list of JSON objects by the value of one of their fields.

    Args:
        json_list: Iterable of dict-like objects.
        key_field: Name of the field whose value becomes the key of each entry.

    Returns:
        A dict mapping each object's ``key_field`` value to a copy of that
        object without ``key_field``. When several objects share the same
        value, the last one wins.

    Raises:
        KeyError: If an object lacks ``key_field`` or stores ``None`` under it.
    """
    converted = {}
    for entry in json_list:
        identifier = entry.get(key_field)
        # A missing field and an explicit None are both rejected.
        if identifier is None:
            raise KeyError(f"Key '{key_field}' not found in JSON object: {entry}")
        # Copy first so the caller's object is left untouched.
        remainder = dict(entry)
        remainder.pop(key_field, None)
        converted[identifier] = remainder
    return converted

In [None]:
# Local directory from which both the model weights and the tokenizer are loaded.
base_model = './llama2_dota'

# dtype handed to the quantization config as its compute dtype.
compute_dtype = torch.float16

# 4-bit NF4 quantization, without double quantization.
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=False,
    bnb_4bit_compute_dtype=compute_dtype,
)

# NOTE(review): device_map {"": 0} presumably places the whole model on device 0 - confirm.
model = AutoModelForCausalLM.from_pretrained(
    base_model,
    device_map={"": 0},
    quantization_config=quant_config,
)

tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
# The EOS token doubles as the padding token; padding is applied on the right.
tokenizer.padding_side = "right"
tokenizer.pad_token = tokenizer.eos_token

In [None]:
# Validation records; the 'description' and 'filename' columns are used below.
test_data = pd.read_json('../datasets/DOTAv1.5/descriptions/val.json')

# Predictions file for the configured size: discard its own 'description'
# column and expose its 'properties' column under the name 'yolo_pred'.
pred = (
    pd.read_json(f'./pred{imgsz}.json')
    .drop(columns='description')
    .rename(columns={'properties': 'yolo_pred'})
)

# Keep only the files that appear in both tables.
test_data = test_data.merge(pred, on='filename', how='inner')

descriptions = test_data['description'].values

In [None]:
# One reference dict per row of test_data, keyed by each object's 'class' field.
# The stored value is rendered with str(), rewritten into JSON text (single
# quotes -> double quotes, "None" -> "0") and parsed again before re-indexing.
y_true = [
    jsonConvert(
        json.loads(str(properties).replace("'", '"').replace("None", "0")),
        'class',
    )
    for properties in test_data['yolo_pred']
]

In [None]:
# NOTE(review): the spelling "Genereate" is kept on purpose. This string is sent
# to the model verbatim; presumably it matches the prompt used when the model
# was fine-tuned - confirm before correcting it.
base_prompt = "Genereate the object bounding box properties for a remote sensing image with the following description as JSON only: "

# Characters stripped from the generated text before it is parsed as JSON:
# backticks, spaces and newlines.
characters_to_remove = '` \n'
translation_table = str.maketrans('', '', characters_to_remove)

pipe = pipeline(task="text-generation", model=model, tokenizer=tokenizer, max_length=546, truncation=True)

# One instruction-formatted prompt per description, in the same order.
prompts = [
    {'text': f"<s>[INST] {base_prompt + str(desc)} [/INST]"}
    for desc in descriptions
]

errors = 0   # number of generations that could not be turned into a property dict
y_pred = []  # one dict per prompt; {} when parsing failed
for result in tqdm(pipe(KeyDataset(prompts, 'text'))):
    try:
        # Keep only what follows the instruction block.
        completion = str(result[0]['generated_text'].split('[/INST]')[1])
        json_only_result = completion.translate(translation_table).replace("'", '"').replace("None", "0")
        # Cut right after the first closing bracket; anything after it is dropped.
        index = json_only_result.find(']')
        parsed = json.loads(json_only_result[:index + 1])
        y_pred.append(jsonConvert(parsed, 'class'))
    except Exception:
        # Best effort: a malformed generation is recorded as an empty prediction.
        # Catching Exception (not a bare except) lets KeyboardInterrupt and
        # SystemExit propagate, so the run can still be stopped by the user.
        y_pred.append({})
        errors += 1

In [None]:
# similarity_mat[i][j] is the average cosine similarity between prediction i
# (y_pred[i]) and reference j (y_true[j]). The average runs over the keys of the
# reference; a key missing from the prediction contributes 0.
similarity_mat = []
for y_p in tqdm(y_pred):
    avg_sim = []

    for y_t in y_true:
        total_sim = 0
        for key, true_props in y_t.items():
            if key not in y_p:
                continue
            # Vectorize both property dicts over a shared feature space.
            feats = DictVectorizer().fit_transform([true_props, y_p[key]])
            total_sim += cosine_similarity(feats[0], feats[1])[0][0]

        # A reference with no keys has nothing to match: score it 0.0 instead
        # of dividing by zero.
        avg_sim.append(total_sim / len(y_t) if y_t else 0.0)
    similarity_mat.append(avg_sim)

In [None]:
# Rows of the score matrix are predictions, columns are references; the correct
# reference for row i is column i.
similarity = np.stack(similarity_mat)
y_desc_true = np.arange(len(y_true))

# Description-wise
top_k_stats = [
    {'k': k, 'score': top_k_accuracy_score(y_desc_true, similarity, k=k)}
    for k in (1, 3, 5, 10, 20, 30)
]

In [None]:
# Display the list of {'k': ..., 'score': ...} records, one per evaluated k.
top_k_stats

In [None]:
# Display how many generated outputs raised during parsing and were recorded
# as an empty prediction instead.
errors