In [2]:
import ollama
import pandas as pd
from tqdm import tqdm
from grounded_sam.inference import grounded_segmentation
from grounded_sam.plot import plot_detections_plotly, plot_detections
from PIL import Image

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Load the collection metadata table and preview its first rows.
# NOTE(review): the preview shows two leading 'Unnamed' columns, presumably
# index columns written out by earlier saves of this CSV — TODO confirm.
collection_df = pd.read_csv('./data/heritage_weaver_data.csv')
collection_df.head()

Unnamed: 0.1,Unnamed: 0,record_id,name,description,category,taxonomy,img_loc,img_name,img_path,downloaded,img_url,collection
0,0,co8413501,audio cassette,BBC Radio Manchester Audiopak.\n\nBroadcast ca...,Radio Communication,audio tape; sound recording; <information arte...,204/255/medium_cd0620_049_100527_2005_86_35_Pr...,204|255|medium_cd0620_049_100527_2005_86_35_Pr...,smg_imgs/204|255|medium_cd0620_049_100527_2005...,True,https://coimages.sciencemuseumgroup.org.uk/204...,smg
1,1,co8413501,audio cassette,BBC Radio Manchester Audiopak.\n\nBroadcast ca...,Radio Communication,audio tape; sound recording; <information arte...,477/975/medium_SMG00247371.jpg,477|975|medium_SMG00247371.jpg,smg_imgs/477|975|medium_SMG00247371.jpg,True,https://coimages.sciencemuseumgroup.org.uk/477...,smg
2,2,co5606,nuclear fuel,Fuel element (without fuel) from the Dounreay ...,Nuclear Energy,,58/255/medium_1982_1712__0001_.jpg,58|255|medium_1982_1712__0001_.jpg,smg_imgs/58|255|medium_1982_1712__0001_.jpg,True,https://coimages.sciencemuseumgroup.org.uk/58/...,smg
3,3,co5606,nuclear fuel,Fuel element (without fuel) from the Dounreay ...,Nuclear Energy,,58/256/medium_1982_1712__0002_.jpg,58|256|medium_1982_1712__0002_.jpg,smg_imgs/58|256|medium_1982_1712__0002_.jpg,True,https://coimages.sciencemuseumgroup.org.uk/58/...,smg
4,4,co5606,nuclear fuel,Fuel element (without fuel) from the Dounreay ...,Nuclear Energy,,58/257/medium_1982_1712__0003_.jpg,58|257|medium_1982_1712__0003_.jpg,smg_imgs/58|257|medium_1982_1712__0003_.jpg,True,https://coimages.sciencemuseumgroup.org.uk/58/...,smg


In [4]:
# Keep one row per (description, img_path) pair; pandas keeps the first
# occurrence by default. The source frame is left untouched.
df_reduced = collection_df.drop_duplicates(subset=['description','img_path'])
# Show both shapes side by side to see how many rows were dropped.
df_reduced.shape, collection_df.shape

((27690, 12), (28476, 12))

In [5]:
def get_completion_llava(system, prompt, img_path):
  """Query the `llava` model about one image and return its complete reply.

  The request goes through `ollama.generate` in streaming mode; the
  `response` text of every streamed chunk is concatenated in arrival order.

  Args:
    system: System message forwarded to the model.
    prompt: User prompt forwarded to the model.
    img_path: Image reference, sent as the single entry of `images`.

  Returns:
    The concatenated text of all streamed chunks ('' when none arrive).
  """
  request = dict(
    model='llava',
    system=system,
    prompt=prompt,
    images=[img_path],
    stream=True,
  )
  # Streaming yields the answer piece by piece; stitch the pieces together.
  return ''.join(part['response'] for part in ollama.generate(**request))

In [6]:
# Register tqdm's `progress_apply` on pandas objects (used when processing the sample).
tqdm.pandas()
# System message: fixes the assistant's role and the comma-separated output format.
system_message = """
    You are a helpful AI that will assist me extracting keywords from image descriptions. 
    Keywords are returned as a comma separated list.

    """
# User message: run_llava appends the record description to it between triple hashtags.
user_message = """
    What do you see in this image? 
    Please extract two or three keywords based on what you see in the image. 
    Keywords should refer to objects and things that are visible in the image.
    Use only keywords from the description provided below between triple hashtags.
    Return keywords in a comma separated list, such as in the following example:
    keys, metal object, copper string.
    """

# NOTE(review): `responses` is not referenced by any cell visible here — confirm before removing.
responses = []

def run_llava(x, system_message, user_message):
  """Extract keywords for one record with LLaVA, best-effort.

  Builds the prompt by appending the record's description, wrapped in triple
  hashtags, to `user_message`, then sends it with the record's image to
  `get_completion_llava`.

  Args:
    x: Record supporting `x['description']` and `x['img_path']`
      (e.g. a DataFrame row).
    system_message: System message forwarded to the model.
    user_message: Instruction text the description is appended to.

  Returns:
    The model's reply as a string, or None if anything failed. Failures are
    logged as warnings rather than raised, so one bad record does not abort
    a whole `apply` run.
  """
  import logging  # function-scope import keeps this cell self-contained

  try:
    prompt = user_message + f"\n\n###{x['description']}###"

    return get_completion_llava(system_message, prompt, x['img_path'])
  except Exception as e:
    # Still return None, but say why: a silent None is indistinguishable
    # from "the model had nothing to say" when debugging downstream cells.
    logging.getLogger(__name__).warning(
      "run_llava failed for record %r: %s: %s",
      getattr(x, 'name', None), type(e).__name__, e,
    )
    return None




In [7]:
def segment_image(row, threshold=.5, downscale=3):
    """Segment one record's image using its extracted keywords as labels.

    Args:
        row: Record supporting `row['img_path']` (image to open) and
            `row['keywords']` (comma-separated keyword string).
        threshold: Forwarded to `grounded_segmentation` as `threshold`.
        downscale: Factor by which both image dimensions are divided before
            segmentation (default 3, the value previously hard-coded).

    Returns:
        The `(image_array, detections)` pair returned by
        `grounded_segmentation`.
    """
    with Image.open(row['img_path']) as source:
        width, height = source.size
        # Image.resize returns a new image and leaves the source untouched;
        # the result was previously discarded, so the full-size image was
        # segmented. Clamp to 1px so tiny images cannot yield a zero size.
        target_size = (max(1, int(width / downscale)), max(1, int(height / downscale)))
        image = source.resize(target_size)
    # Normalise each keyword: trim whitespace, stray '#' and quotes, lower-case.
    labels = [x.strip().strip('#').strip('"').lower() for x in row['keywords'].split(',')]
    # Drop labels that end up empty (e.g. from a trailing comma in the reply).
    labels = [label for label in labels if label]
    image_array, detections = grounded_segmentation(image, labels, threshold=threshold, polygon_refinement=True)
    return image_array, detections

In [8]:
# Fixed seed: the same 10 rows are drawn on every run, so positional lookups
# into df_sample in later cells keep pointing at the same records.
df_sample = df_reduced.sample(n=10, random_state=42)
print('Extracting keywords and segmenting images...')
# run_llava returns None for records where the model call failed.
df_sample['keywords'] = df_sample.progress_apply(lambda x: run_llava(x, system_message, user_message), axis=1)
# NOTE(review): segment_image calls .split on 'keywords', so a None keyword
# from the step above will raise here — filter such rows first if that occurs.
df_sample['segmentation'] = df_sample.progress_apply(segment_image, axis=1)

Extracting keywords and segmenting images...


100%|██████████| 10/10 [01:23<00:00,  8.36s/it]
100%|██████████| 10/10 [05:11<00:00, 31.16s/it]


In [14]:
# Inspect one sampled record by position (9 is the last of the 10 sampled rows).
idx = 9
# Each 'segmentation' entry is the (image_array, detections) pair from segment_image.
image_array, detections = df_sample.iloc[idx]['segmentation']
# NOTE(review): the recorded run of this cell failed with "ValueError: Mime type
# rendering requires nbformat>=4.2.0 but it is not installed"; install nbformat in
# the kernel environment so the plotly figure can render.
plot_detections_plotly(image_array, detections)


ValueError: Mime type rendering requires nbformat>=4.2.0 but it is not installed