In [1]:
from plip import PLIP
plip = PLIP('vinid/plip')

import torch
import numpy as np
import json

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
import os
from PIL import Image

# Function to get image keys
# input: json data
# output: list of image keys
def getImageVals(data, images_dir):
    image_keys = []
    images=[]
    all_images = os.listdir(images_dir)
    print(all_images)

    for key in data.keys():
        if key+".jpg" in os.listdir(images_dir):
            images.append(images_dir + "/" + key + ".jpg")
            image_keys.append(key)

    print(images)
    return image_keys, images

images_dir = "./pubmed_set/images"

# Function to get image
def getImage(images_dir, image_name):
    image_path = os.path.join(images_dir, image_name)
    image = Image.open(image_path).convert('RGB')

    return image

# Function to process text captions
# Input: list of captions
# Output: dictionary of image key to caption embedding
def getTextEmbeddings(text):
    text_embedding = plip.encode_text(text, batch_size=32)
    text_embedding = text_embedding/np.linalg.norm(text_embedding, ord=2, axis=-1, keepdims=True)

    return text_embedding

# Function to process image
# Input: list of images
# Output: dictionary of image key to image embedding
def getImageEmbeddings(images):
    image_embeddings = plip.encode_images(images, batch_size=32)
    image_embeddings = image_embeddings/np.linalg.norm(image_embeddings, ord=2, axis=-1, keepdims=True)

    return image_embeddings

# Dictionary of the image key and the given embedding (text or image)
# input: keys --> image_key, embeddings --> list of text or image embeddings
# output: dictionary of the key and the given embedding
def pairEmbeddings(keys, embeddings):
    return dict(zip(keys, embeddings))

# Function to insert embeddings into json
# input: data --> json data file, text_embeddings --> pair of key and text_embedding, image_embeddings --> pair of key and image_embedding
def addEmbeddings(keys, embedding):
    newData = {}
    for key in keys:
        newData[key] = embedding.get(key)
        # data[key]["caption_embedding"] = text_embeddings.get(key)
        # data[key]["image_embedding"] = image_embeddings.get(key)
        # print(data[key])
    return newData

In [3]:
# Get data from json file
with open('../pubmed_captions.json', 'r') as f:
    data = json.load(f)
print(data)

{'3f93c716-8fc9-42e9-bc29-bec52a51ab4b': 'ER expression in tumor tissue. IHC staining, original', '9fcdf1e1-139c-4b63-bf1a-79d83c71f41a': 'Nuclear expression of TS (brown) in a colon carcinoma', '00f1ad7a-f4b0-4938-b874-089d40a123ce': 'Nuclear expression of E2F1 (brown) in a colon carcinoma. This is higher magnification of the upper portion of a core shown in an inset (lower left corner)', '9d3aef30-7c8b-4b78-9acf-ec523f952650': 'Cytoplasmic immunoexpression of PD-L1 in oral squamous cell carcinomas with poorer prognosis (OSCCPP). Immunohistochemistry. Total magnification x100', 'b317d529-3626-49fc-9282-e4f28cf3d1cb': 'Nuclear and perinuclear immunoexpression of Foxp3 in oral squamous cell carcinomas with poorer prognosis (OSCCPP). Immunohistochemistry. Total magnification x100', 'e9e99cb6-f795-4c5d-9d66-a8e81475b934': 'Cytoplasmic immunoexpression of PD-L1 in oral squamous cell carcinomas with better prognosis (OSCCBP). Immunohistochemistry. Total magnification x100', 'c707e670-51d0-4

In [4]:

# Add all embeddings to json file
# Note: keys, captions, and image_paths are all in the order of the json file
keys, image_paths = getImageVals(data, "./pubmed_set/images")
# print(image_paths)

['0011f470-940b-4c4a-91eb-a8dbe75ae109.jpg', '0064fd71-1a87-47f1-bb5a-ea0b6dd575f8.jpg', '006eff37-6549-4645-ba06-955836f8ab6c.jpg', '00853952-d6e7-40ff-91b1-b42646b91dc9.jpg', '0094dd3c-2284-4c37-8ace-a950bb1bcad0.jpg', '00ce12d4-1a47-4d59-ab60-ea85cae284d0.jpg', '00ea125a-3f78-4457-b19c-7539d9045996.jpg', '00f1ad7a-f4b0-4938-b874-089d40a123ce.jpg', '00fc972e-974d-4bc9-a4b5-ab9c919b536a.jpg', '0102aba6-6c81-4f3f-b5ef-b59090530da0.jpg', '010b8ee6-3085-42be-9829-496e6b9c4044.jpg', '015a1ed1-3b12-47b5-9ac3-452fab376f81.jpg', '0169872f-24e4-4780-b27e-2a69292eecf7.jpg', '016b6192-9354-4581-a2a8-c85fe9a995c7.jpg', '01908db0-804c-4916-9578-e6efc24e5e62.jpg', '0197e88b-ed1c-47f1-b205-7a52f85d7528.jpg', '019cb835-968b-4364-a1dd-8c3a2a653100.jpg', '019e3eb8-8f2a-49fd-85f6-29d585bb43fe.jpg', '01b4b1a6-6457-4714-ac88-ca0cd52d06e8.jpg', '01bf4aa0-f917-48a2-808f-7d3806f69652.jpg', '01c52d66-d256-4630-8b9f-2f2ea0471182.jpg', '01d19bc9-4495-4e2f-9ec1-87baefca3405.jpg', '01dfcfed-2fe5-46c1-9e65-13f4a0

In [5]:
# get caption embeddings
captions = [data[key] for key in data.keys()]
text_embeddings = [embedding.tolist() for embedding in getTextEmbeddings(captions)]

Map: 100%|██████████| 3309/3309 [00:00<00:00, 4020.42 examples/s]
104it [01:52,  1.08s/it]                         


In [6]:
# get image embeddings
image_embeddings = getImageEmbeddings(image_paths)
image_embeddings = [embedding.tolist() for embedding in image_embeddings]
# print(image_embeddings[0])


103it [03:26,  2.01s/it]                         


In [None]:
print(len(data.keys()))
print(len(keys))
# print(image_embeddings)
# print(image_embeddings[0])

3309
3272
[0.0699537992477417, -0.045385975390672684, -0.017583969980478287, 0.024916479364037514, -0.057095229625701904, -0.01816147193312645, 0.029679380357265472, -0.00893434602767229, 0.030276937410235405, 0.02301786281168461, 0.0012672533048316836, 0.04457440227270126, 0.054404329508543015, -0.01820065639913082, 0.01296603586524725, 0.016232233494520187, 0.08582457154989243, 0.04170924052596092, -0.001555098220705986, 0.041003335267305374, -0.025735968723893166, -0.01338905468583107, 0.019370773807168007, -0.018465779721736908, 0.04437369480729103, -0.016607187688350677, -0.017885493114590645, 0.0010988175636157393, 0.0009069604566320777, -0.025912277400493622, 0.020982669666409492, -0.006900251843035221, -0.010657611303031445, -0.024689119309186935, 0.025223882868885994, -0.018039343878626823, -0.007163587026298046, -0.04945336654782295, -0.009569249115884304, -0.06295471638441086, -0.03010929375886917, 0.006286997348070145, 0.030195264145731926, -0.02833070419728756, -0.02237260

In [None]:
# len(image_embeddings)

3272

In [9]:
text_data = dict(zip(data.keys(), text_embeddings))
image_data = dict(zip(keys, image_embeddings))


In [10]:
# create new file with embeddings
with open('plip_text_embeddings.json', 'w') as f:
    json.dump(text_data, f, indent=4)

In [11]:
with open('plip_image_embeddings.json', 'w') as f:
    json.dump(image_data, f, indent=4)