In [1]:
%load_ext autoreload
%autoreload 2

In [5]:
import time
from pathlib import Path
import pandas as pd
from PIL import Image
from weavingtools.annotation_tools import *
from weavingtools.linkage_tools import *
from weavingtools.embedding_tools import *
import scipy.spatial as sp
import ipyannotations.generic
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import random
sns.set()

In [3]:
def plot_record_pair(record_pair):
    """Show the images of two records side by side for link annotation.

    Both records are looked up by ``record_id`` in the notebook-level
    ``collection_df``. Each image is drawn on its own axis and titled with
    the record id followed by the soft-wrapped description.

    Parameters
    ----------
    record_pair : sequence
        Indexable holding two ``record_id`` values at positions 0 and 1.

    Raises
    ------
    IndexError
        If ``record_pair`` has fewer than two entries, or a record id has no
        matching row in ``collection_df``.
    """
    # Resolve both records before creating the figure, so an unknown id fails
    # with a clear message instead of leaving a half-drawn figure behind.
    records = []
    for i in range(2):
        matches = collection_df[collection_df.record_id == record_pair[i]]
        if matches.empty:
            raise IndexError(
                f"record_id {record_pair[i]!r} not found in collection_df")
        records.append(matches.iloc[0])

    fig, axes = plt.subplots(1, 2, figsize=(15, 7.5))

    for ax, record in zip(axes, records):
        # A missing description (NaN) must not break title construction.
        description = record.description
        if pd.isna(description):
            description = ''
        title = soft_wrap_text(f"{record.record_id} {description}")

        # imshow copies the pixel data, so the file can be closed right away.
        with Image.open(record.img_path) as img:
            ax.imshow(img)
        ax.set_title(title, fontsize=18)
        ax.axis('off')

    plt.show()



In [6]:
# Load the collection database through `load_db`, plus the CSV table that
# holds the per-record fields used when plotting.
collection_db = load_db(
    "hw",
    "heritage_weaver",
    "google/siglip-base-patch16-224",
)
collection_df = pd.read_csv("data/heritage_weaver_data.csv")

# Annotation files are written to this folder; create it if it is missing.
out_path = Path("annotations")
out_path.mkdir(exist_ok=True)

## Retrieval Based on Text Prompts

In [38]:

# Number of results requested per text query (and therefore annotated).
top_n = 20
# Annotator identifier; used as the prefix of the saved annotation file name.
annotator = 'KB'

In [39]:
# Free-text prompt used as the retrieval query.
query = 'a radio transistor'
# Retrieval setting: '1' = image records only, '2' = text records only,
# '3' = no modality filter.
experiment = '1'

In [40]:
# Retrieve the `top_n` closest records for the text query.
# Each experiment restricts the search to a different modality; '3' applies
# no modality filter.
experiment_filters = {
    '1': {'modality': 'image'},
    '2': {'modality': 'text'},
    '3': {},
}
if experiment not in experiment_filters:
    # Fail loudly instead of silently reusing `filters` from an earlier run.
    raise ValueError(
        f"Unknown retrieval experiment {experiment!r}; "
        f"expected one of {sorted(experiment_filters)}")
filters = experiment_filters[experiment]

results = collection_db.query(query_texts=[query], where=filters, n_results=top_n)
query_df = get_query_results(results, source='img_path')
# One entry per retrieved record; these are fed to the annotation widget.
inputs = list(query_df.to_records())


In [41]:
labels = []

# Work on a copy so `inputs` stays intact and this cell can be re-run to
# restart the annotation round from the first record.
retrieval_queue = list(inputs)
retrieval_total = len(retrieval_queue)

widget = ipyannotations.generic.ClassLabeller(
        options=['relevant', 'not relevant'], allow_freetext=False,
        display_function=plot_by_record)


def store_annotations(entity_annotation):
    """Record the submitted label and show the next record, if any.

    Submissions made after every record has been labelled are ignored, so
    `labels` never grows beyond the number of retrieved records.
    """
    if len(labels) >= retrieval_total:
        return
    labels.append(entity_annotation)
    if retrieval_queue:
        widget.display(retrieval_queue.pop(0))
    else:
        print("Finished.")

widget.on_submit(store_annotations)
if retrieval_queue:
    widget.display(retrieval_queue.pop(0))
else:
    print("Nothing to annotate: the query returned no records.")
widget

ClassLabeller(children=(Box(children=(Output(layout=Layout(margin='auto', min_height='50px')),), layout=Layout…

Finished.


In [42]:
# Sanity check: the number of collected labels should equal the number of rows.
(query_df.shape, len(labels))

((20, 7), 20)

In [43]:
# Attach the labels to the retrieved records and save them together with the
# query and experiment setting.
n_retrieved = len(query_df)
if len(labels) < n_retrieved:
    raise ValueError(
        f"Only {len(labels)} of {n_retrieved} records have been labelled; "
        "finish annotating before saving.")
# Slice by the actual number of rows: the query can return fewer than `top_n`.
query_df['labels'] = labels[:n_retrieved]
query_df['query'] = query
query_df['experiment'] = experiment
# The timestamp keeps the files of successive annotation rounds apart.
query_df.to_csv(out_path / f'{annotator}_{time.time()}.csv')

In [None]:
# # TODO: visual prompting (querying by image) is optional at the moment; integrate later.
# idx = 4
# record = collection_df.iloc[idx]
# results = collection_db.query(query_uris=[record.img_path],n_results=top_n, where=filters) # 
# #Image.open(record.img_path)
# query_df = plot_query_results(results, collection_df,source='img_path')

# Link Annotation

In [44]:
# Annotator identifier; used as the prefix of the saved annotation file name.
annotator = 'KB'
# Number of candidate links (edges) presented for annotation.
num_annotations = 10
# The two collections to link; passed to get_edges.
# NOTE(review): presumably collection identifiers in the database; confirm.
coll1, coll2 = 'smg','nms'
# Passed to get_edges. Alternatives noted by the author: 99.95, or False.
# NOTE(review): presumably the similarity percentile from which the cut-off is
# derived, with False switching to the fixed `threshold`; confirm in get_edges.
percentile = 99.0  #99.95 | False
# Passed to get_edges alongside `percentile`.
threshold = 0.8
# When True, the edges are shuffled with a fixed seed before the first
# `num_annotations` are selected.
randomize = True


In [45]:
# Linking setting: '1' = image-image with 'max' aggregation, '2' = text-text
# with 'max', '3' = text-text with 'mean', '4' = edges found by both the image
# and the text comparison.
experiment = '1'

In [46]:
# (modality1, modality2, agg_func) for each single-pass experiment.
experiment_settings = {
    '1': ('image', 'image', 'max'),
    '2': ('text', 'text', 'max'),
    '3': ('text', 'text', 'mean'),
}

if experiment in experiment_settings:
    modality1, modality2, agg_func = experiment_settings[experiment]
    edges, image_similarities, inputs = get_edges(
        collection_db, coll1, coll2, modality1, modality2, agg_func,
        percentile, threshold)
elif experiment == '4':
    # Experiment '4' builds its edges in the following cell; start empty
    # instead of calling get_edges with undefined or stale settings.
    edges = []
else:
    raise ValueError(
        f"Unknown link experiment {experiment!r}; "
        f"expected one of {sorted(experiment_settings) + ['4']}")

len(edges)

Get inputs...
7 7
Compute similarities...
--- Get similarities ---
--- Using 0.759627968792046 as threshold ---
--- Aggregate similarities by record ---
--- Threshold similarities and binarize ---
Retrieve edges...


103268

In [47]:
# Experiment '4': keep only the edges found by both the image-image and the
# text-text comparison.
# NOTE(review): 99.5 is hard-coded here rather than taken from `percentile`;
# confirm this is intentional.
if experiment == '4':
    image_edges, similarities, inputs = get_edges(collection_db, coll1, coll2, 'image', 'image', 'max', 99.5, threshold)
    text_edges, similarities, inputs = get_edges(collection_db, coll1, coll2, 'text', 'text', 'max', 99.5, threshold)
    # Same members as a set intersection, but kept in `image_edges` order
    # (de-duplicated) so the later seeded shuffle selects the same pairs in
    # every kernel session; set iteration order is not stable across sessions.
    text_edge_set = set(text_edges)
    edges = list(dict.fromkeys(
        edge for edge in image_edges if edge in text_edge_set))

In [48]:
# Select the pairs to annotate. Shuffle a copy with a dedicated seeded
# generator: `edges` is left untouched, so re-running this cell always selects
# the same pairs (an in-place shuffle would reshuffle an already shuffled list).
candidate_edges = list(edges)
if randomize:
    random.Random(42).shuffle(candidate_edges)
img_pairs = candidate_edges[:num_annotations]
to_annotate = img_pairs.copy()

labels = []

# Separate queue consumed by the widget; `to_annotate` keeps the full selection
# for saving the annotations afterwards.
link_queue = img_pairs.copy()
link_total = len(link_queue)

widget = ipyannotations.generic.ClassLabeller(
        options=['same object', 'same category', 'same materials','no link'], allow_freetext=True,
        display_function=plot_record_pair)


def store_annotations(entity_annotation):
    """Record the submitted label and show the next record pair, if any.

    Submissions made after every pair has been labelled are ignored, so
    `labels` never grows beyond the number of selected pairs.
    """
    if len(labels) >= link_total:
        return
    labels.append(entity_annotation)
    if link_queue:
        widget.display(link_queue.pop(0))
    else:
        print("Finished.")

widget.on_submit(store_annotations)
if link_queue:
    widget.display(link_queue.pop(0))
else:
    print("Nothing to annotate: no edges were selected.")
widget


ClassLabeller(children=(Box(children=(Output(layout=Layout(margin='auto', min_height='50px')),), layout=Layout…

Finished.


In [49]:
# Combine the annotated pairs with their labels and save them.
annotations_df = pd.DataFrame(to_annotate, columns=['coll1','coll2'])
n_pairs = len(annotations_df)
if len(labels) < n_pairs:
    raise ValueError(
        f"Only {len(labels)} of {n_pairs} pairs have been labelled; "
        "finish annotating before saving.")
# Slice by the actual number of pairs: fewer than `num_annotations` edges may
# have been available.
annotations_df['labels'] = labels[:n_pairs]
annotations_df['experiment'] = experiment
# Write with a .csv extension, matching the retrieval annotation files.
annotations_df.to_csv(out_path / f'{annotator}_{time.time()}.csv')

# Fin.