In [5]:
import os
import sys
# Put the parent directory on the import path; the later `from src.utils import ...`
# presumably relies on this — confirm against the repository layout.
sys.path.append('..')
# Expose only physical GPUs 4 and 5 to CUDA. This is set before torch is imported below.
os.environ["CUDA_VISIBLE_DEVICES"] = "4,5"

import torch
# Number of CUDA devices visible to this process.
torch.cuda.device_count()

2

In [6]:
import ast
import time
from pathlib import Path

import numpy as np
import pandas as pd
from accelerate import Accelerator
from einops import repeat
from huggingface_hub import hf_hub_download
from open_flamingo import create_model_and_transforms
from PIL import Image
from tqdm import tqdm

from demo_utils import image_paths, clean_generation
from src.utils import FlamingoProcessor

  from .autonotebook import tqdm as notebook_tqdm


In [147]:
# Create the Accelerate helper with its default settings.
accelerator = Accelerator() #when using cpu: cpu=True

# Device handle reused below for checkpoint loading and for moving model inputs.
device = accelerator.device

device(type='cuda')

In [None]:
# Build the Flamingo model, load the Med-Flamingo weights into it, and put it in eval mode.
print('Loading model...')

# >>> add your local path to Llama-7B (v1) model here:
llama_path = '../models/llama-7b-hf'
# Fail early with a pointer to the README when the Llama weights are not in place.
if not os.path.exists(llama_path):
    raise ValueError('Llama model not yet set up, please check README for instructions!')

# The same local Llama directory is used for both the language encoder and the tokenizer.
model, image_processor, tokenizer = create_model_and_transforms(
    clip_vision_encoder_path="ViT-L-14",
    clip_vision_encoder_pretrained="openai",
    lang_encoder_path=llama_path,
    tokenizer_path=llama_path,
    cross_attn_every_n_layers=4
)
# Load the Med-Flamingo checkpoint (downloaded into ../models/ via the Hugging Face hub):
checkpoint_path = hf_hub_download("med-flamingo/med-flamingo", "model.pt", cache_dir="../models/")
print(f'Downloaded Med-Flamingo checkpoint to {checkpoint_path}')
# strict=False: keys that do not line up between checkpoint and model are not treated as errors.
model.load_state_dict(torch.load(checkpoint_path, map_location=device), strict=False)
# Wrapper that bundles the tokenizer and the image preprocessor; used by the inference cells.
processor = FlamingoProcessor(tokenizer, image_processor)

# Hand the model to Accelerate, then switch it to eval mode:
model = accelerator.prepare(model)
is_main_process = accelerator.is_main_process
model.eval()

# 전처리 (Preprocessing)

In [18]:
# Absolute project root on the authoring machine; every data path below is derived from it,
# so adjust this when running elsewhere.
root = Path("/home/cmti/jogi/diag_plz/med-flamingo")

In [77]:
# Load the raw case table for one dataset and derive the cleaned frame used by later cells.
data = {}
dataset = "kstr"
raw_csv = {}
raw_csv[dataset] = root / "data" / dataset / f"{dataset}_data.csv"
data[dataset] = pd.read_csv(raw_csv[dataset], index_col=None)

# Only rows that actually have a `findings` text are kept.
select_idx = data[dataset]['findings'].notna()
new_data = (
    data[dataset][select_idx]
    .reset_index(drop=True)
    # Strip a single trailing digit from the diagnosis text.
    .assign(diagnosis=lambda frame: frame['diagnosis'].str.replace(r'\d$', '', regex=True))
    # sort_values returns a new frame (no inplace mutation), so re-running this cell
    # always rebuilds `new_data` from the raw table. Index labels keep their
    # pre-sort values, exactly as before.
    .sort_values(by=['id'])
)

In [50]:
# Sanity check: the number at the end of each case `link` must equal the case number
# embedded in the file name of the first entry of `img_links`. Stops at the first mismatch.
t = data["kstr"]
for i, row in t.iterrows():
    if row.img_links != '[]':
        link_case = row.link.split("=")[-1].strip()
        # `img_links` is stored as the text of a Python list. Parse it as a literal
        # instead of executing CSV content with eval().
        first_link = ast.literal_eval(row.img_links)[0]
        image_case = first_link.split(".")[-2].split("case-")[-1].split("-")[0].strip()
        print(row.id, link_case, image_case, link_case == image_case)
        if link_case != image_case:
            print("error")
            break
    

1 324 324 True
2 155 155 True
3 325 325 True
4 156 156 True
5 426 426 True
6 61 61 True
7 62 62 True
8 427 427 True
9 157 157 True
10 126 126 True
11 323 323 True
12 402 402 True
13 158 158 True
14 159 159 True
15 511 511 True
16 326 326 True
17 63 63 True
18 160 160 True
19 327 327 True
20 322 322 True
21 512 512 True
22 64 64 True
23 65 65 True
24 161 161 True
25 555 555 True
26 66 66 True
27 328 328 True
28 321 321 True
29 403 403 True
30 67 67 True
31 404 404 True
32 162 162 True
33 320 320 True
34 68 68 True
35 329 329 True
36 330 330 True
37 69 69 True
38 428 428 True
39 405 405 True
40 163 163 True
41 70 70 True
42 556 556 True
43 513 513 True
44 406 406 True
45 71 71 True
46 429 429 True
47 72 72 True
48 164 164 True
49 514 514 True
50 165 165 True
51 331 331 True
52 407 407 True
53 166 166 True
54 73 73 True
55 332 332 True
56 333 333 True
57 430 430 True
58 515 515 True
59 167 167 True
60 74 74 True
61 334 334 True
62 168 168 True
63 169 169 True
64 335 335 True
65 516 516 Tr

In [52]:
# Two-column view (case id + image link list) consumed by the download loop.
k = t.loc[:, ['id', 'img_links']]

In [66]:
from urllib.request import urlopen, urlretrieve

# Download every image referenced in `img_links` into one sub-directory per case.
img_dir = Path("/home/cmti/jogi/diag_plz/med-flamingo/data/kstr/new_images")
# Extension spellings tried in turn for each link (defined once; it never changes).
extensions = ["jpg", "JPG", 'gif', "GIF", "png", "PNG", "jpeg", "JPEG"]
fails = []  # links for which every extension attempt failed

pbar = tqdm(k.iterrows(), total=len(k))
for _, case_row in pbar:
    idx = case_row.id
    pbar.set_description(f"Processing {idx:04}")
    # `img_links` is stored as the text of a Python list. Parse it as a literal
    # instead of executing CSV content with eval().
    img_links = ast.literal_eval(case_row.img_links)
    if not img_links:
        continue

    case_dir = img_dir / f"case_{idx:04}"
    # urlretrieve does not create missing directories; without this every download
    # on a fresh run fails and is silently recorded as a failure.
    case_dir.mkdir(parents=True, exist_ok=True)

    for img_no, link in enumerate(img_links):
        success = False
        for ext in extensions:
            # NOTE(review): replace() rewrites every 'jpg' substring of the URL, and a
            # link without 'jpg' is fetched unchanged under each extension name.
            target = case_dir / f"image_{idx:04}_{img_no:02}.{ext}"
            try:
                urlretrieve(link.replace('jpg', ext), target)
                success = True
                break
            except Exception:
                # Best effort: move on to the next extension. Unlike a bare `except`,
                # this lets KeyboardInterrupt / SystemExit stop the loop.
                continue

        if not success:
            fails.append(link)



Processing 0001: : 0it [00:00, ?it/s]

Processing 0078: : 77it [00:29,  1.63it/s]

Failed to download case: 0077


Processing 0125: : 124it [03:11, 39.40s/it]

Failed to download case: 0124


Processing 0647: : 646it [08:51,  1.39s/it]

Failed to download case: 0646


Processing 0719: : 718it [09:36,  1.18it/s]

Failed to download case: 0718


Processing 0722: : 721it [09:38,  1.35it/s]

Failed to download case: 0721


Processing 0918: : 917it [12:43,  1.44it/s]

Failed to download case: 0917


Processing 1026: : 1025it [14:02,  1.28it/s]

Failed to download case: 1025


Processing 1159: : 1158it [16:00,  1.35it/s]

Failed to download case: 1158


Processing 1208: : 1207it [16:55,  1.28it/s]

Failed to download case: 1207


Processing 1278: : 1277it [17:54,  1.21s/it]

Failed to download case: 1277


Processing 1303: : 1302it [18:15,  1.08it/s]

Failed to download case: 1302


Processing 1343: : 1343it [18:57,  1.18it/s]


In [72]:
# Case numbers reported as failed by the download run; list what is on disk for each.
failed_ids = (77, 124, 646, 718, 721, 917, 1025, 1158, 1207, 1277, 1302)
failed = [f"case_{case_id:04}" for case_id in failed_ids]
for case in failed:
    case_files = list((img_dir / case).glob("*"))
    print(case_files)

[PosixPath('/home/cmti/jogi/diag_plz/med-flamingo/data/kstr/new_images/case_0077/image_0077.jpg')]
[PosixPath('/home/cmti/jogi/diag_plz/med-flamingo/data/kstr/new_images/case_0124/image_0124.jpg')]
[PosixPath('/home/cmti/jogi/diag_plz/med-flamingo/data/kstr/new_images/case_0646/image_0646.jpg')]
[PosixPath('/home/cmti/jogi/diag_plz/med-flamingo/data/kstr/new_images/case_0718/image_0718.jpg')]
[PosixPath('/home/cmti/jogi/diag_plz/med-flamingo/data/kstr/new_images/case_0721/image_0721.jpg')]
[PosixPath('/home/cmti/jogi/diag_plz/med-flamingo/data/kstr/new_images/case_0917/image_0917.jpg')]
[PosixPath('/home/cmti/jogi/diag_plz/med-flamingo/data/kstr/new_images/case_1025/image_1025.jpg')]
[PosixPath('/home/cmti/jogi/diag_plz/med-flamingo/data/kstr/new_images/case_1158/image_1158.jpg')]
[PosixPath('/home/cmti/jogi/diag_plz/med-flamingo/data/kstr/new_images/case_1207/image_1207.jpg')]
[PosixPath('/home/cmti/jogi/diag_plz/med-flamingo/data/kstr/new_images/case_1277/image_1277.jpg')]
[PosixPath

In [61]:
# Number of image links per case. `img_links` holds the text of a Python list, so it is
# parsed as a literal (no code execution) rather than passed to eval().
new_data['img_count'] = new_data['img_links'].apply(lambda links: len(ast.literal_eval(links)))


In [62]:
# One-sentence patient history built from age, sex and chief complaint.
# Any sex value other than "M"/"F" (including missing) falls back to "person".
sex_noun = new_data['sex'].map({"M": "man", "F": "woman"}).fillna("person")
new_data['history'] = (
    new_data['age'].astype(str)
    + "-year-old "
    + sex_noun
    + " with a complaint of "
    + new_data['complaint']
    + "."
)

In [63]:
# Display the enriched frame, now including the `img_count` and `history` columns.
new_data

Unnamed: 0,id,link,date,age,sex,complaint,diagnosis,findings,brief_review,img_links,applicants,answer_rates,img_count,history
1297,1,https://kstr.radiology.or.kr/weekly/archive/vi...,1997-11-03,27,M,"fever, HIV(+)",Mediastinal tb. lymphadenitis with esophagome...,"A 26-year-old man presented with fever, cough,...",Esophageal involvement by tuberculosis usually...,['https://kstr.radiology.or.kr/weekly/files/ca...,0,[],5,"27-year-old man with a complaint of fever, HIV..."
1296,16,https://kstr.radiology.or.kr/weekly/archive/vi...,1998-02-16,27,M,incidental CPA abnormality,Pericardial cyst (connected to superior perica...,The vast majority of mesothelial cysts (perica...,,['https://kstr.radiology.or.kr/weekly/files/ca...,0,"['correct:1/0', 'semi:4/0']",4,27-year-old man with a complaint of incidental...
1295,18,https://kstr.radiology.or.kr/weekly/archive/vi...,1998-03-02,29,F,mild dyspnea for several years,Sarcoidosis (with progression from stage II to...,Sarcoidosis is a systemic disease of unknown e...,,['https://kstr.radiology.or.kr/weekly/files/ca...,0,['correct:5/0'],2,29-year-old woman with a complaint of mild dys...
1294,20,https://kstr.radiology.or.kr/weekly/archive/vi...,1998-03-16,56,M,fever with myalgia for 3 days,Leptospirosis,"He is a farmer and lives in Paju-gun, Kyoung-g...",,['https://kstr.radiology.or.kr/weekly/files/ca...,0,"['correct:11/0', 'semi:2/0']",4,56-year-old man with a complaint of fever with...
1293,22,https://kstr.radiology.or.kr/weekly/archive/vi...,1998-03-30,68,M,chest PA abnormality,Atypical Carcinoid,Atypical carcinoids have clinicopathologic fea...,,['https://kstr.radiology.or.kr/weekly/files/ca...,0,['correct:1/0'],4,68-year-old man with a complaint of chest PA a...
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
4,1339,https://kstr.radiology.or.kr/weekly/archive/vi...,2023-06-19,43,F,"Dyspnea for 1 year, SpO2 84%",Hereditary Hemorrhagic Telangiectasia (Osler-R...,Fig 1. Chest PA shows ill-defined peripheral o...,Pulmonary arteriovenous malformations (PAVMs) ...,['https://kstr.radiology.or.kr/weekly/files/ca...,96,"['correct:63/96', 'diff:13/96', 'semi:13/96']",5,43-year-old woman with a complaint of Dyspnea ...
3,1340,https://kstr.radiology.or.kr/weekly/archive/vi...,2023-06-26,51,M,Dyspnea\r\n\r\nPHx: Lung cancer s/p LUL upper ...,Amiodarone pulmonary toxicity,(Fig. 1-2) Chest computed tomography (CT) scan...,The patient had a history of atrial fibrillati...,['https://kstr.radiology.or.kr/weekly/files/ca...,82,"['correct:57/82', 'semi:8/82']",5,51-year-old man with a complaint of Dyspnea\r\...
2,1341,https://kstr.radiology.or.kr/weekly/archive/vi...,2023-07-03,61,M,"Proptosis, visual acuity decrease",Erdheim-Chester disease,Fig 1. Chest PA shows interlobular septal thic...,Erdheim-Chester disease is a rare non-inherite...,['https://kstr.radiology.or.kr/weekly/files/ca...,106,"['correct:56/106', 'diff:11/106', 'semi:1/106']",6,"61-year-old man with a complaint of Proptosis,..."
1,1342,https://kstr.radiology.or.kr/weekly/archive/vi...,2023-07-10,66,F,Incidental finding,Azygos vein aneurysm (AVA),Fig 1. Chest PA shows no significant abnormali...,"- Very rare, only a few cases have been descri...",['https://kstr.radiology.or.kr/weekly/files/ca...,94,"['correct:65/94', 'diff:7/94', 'semi:5/94']",4,66-year-old woman with a complaint of Incident...


In [71]:
# Keep only the columns the prompts need and write them next to the raw CSV.
prompt_columns = ['id', 'history', 'findings', 'diagnosis', 'img_count']
mod_data = new_data[prompt_columns]
filtered_csv = {dataset: root / "data" / dataset / f"{dataset}_data_filtered.csv"}
mod_data.to_csv(filtered_csv[dataset], index=False)
# Label-based lookup (index label 0, not position 0) to inspect the stored scalar type.
type(mod_data.img_count[0])

numpy.int64

In [44]:
# Hand-written one-shot prompt: instruction text, one fully worked example chunk
# (closed by <|endofchunk|>), then a query case that stops right after
# "Final Diagnosis: " so the model has to supply the answer.
# Adjacent string literals are concatenated with no separator, so each piece carries
# its own trailing space.
ex_prompt = (
        "You are a helpful medical assistant. "
        "You are being provided with some images, a medical history about the patient, some image findings, and a final diagnosis. "
        "Follow the examples and provide the final diagnosis for the last case. "
       
        # [EXAMPLE CHUNK]
        # case 431
        # https://kstr.radiology.or.kr/weekly/archive/view.php?number=30&year=&diagnosis=Tracheal%20Leiomyoma&menu_num=2&sub_num=2006#path
        # One <image> placeholder per image of the example case.
        "<image>"
        "<image>"
        "<image>"

        "Medical History: A 55-year-old man had a dyspenea and had been treated for a bronchial asthma. "

        "Image Findings: "
        "Chest radiograph shows soft tissue bulging contour in the right paratracheal area. "
        "Contrast-enhanced chest CT scans show a well-defined, lobulated 4.5-cm sized mass with endotracheal growth in the right posterolateral aspect of trachea, which demonstrates heterogeneous enhancement.\n"
        "He underwent a segmental resection and end-to-end anastomosis of the trachea. "
        "On histopathologic examination, the resected specimen revealed a well-capsulated white tan lobulating tracheal mass with endotracheal growth. "

        "Final Diagnosis: "
        "Tracheal Leiomyoma."
        "<|endofchunk|>"

        # [ANSWER CHUNK]
        # case 904
        # https://kstr.radiology.or.kr/weekly/archive/view.php?number=1002&year=&diagnosis=Langerhans%20cell%20histiocytosis&menu_num=2&sub_num=2015#path
        "<image>"
        "<image>"
        "<image>"
        "<image>"

        "Medical History: A 48-year-old man had a chronic cough. "

        "Image Findings: "
        "Chest radiograph shows ill-defined small nodular opacities in both lungs with upper lung zone predominance. "
        "Chest CT scans demonstrate numerous thick- and thin-walled small lung cysts and micronodules in both lungs with upper lung predominance. "

        "Final Diagnosis: "
        # Expected answer, deliberately left out of the prompt:
        # "Langerhans cell histiocytosis."
    )

In [77]:
# One-shot prompt whose query chunk is filled from a row of the filtered table
# (third row by position) instead of being typed in by hand.
ex = mod_data.iloc[2]
prompt = (
        "You are a helpful medical assistant. "
        "You are being provided with some images, a medical history about the patient, some image findings, and a final diagnosis. "
        "Follow the examples and provide the final diagnosis for the last case. "
       
        # [EXAMPLE CHUNK]
        # case 431
        # https://kstr.radiology.or.kr/weekly/archive/view.php?number=30&year=&diagnosis=Tracheal%20Leiomyoma&menu_num=2&sub_num=2006#path
        
        f"{'<image>' * 3}"

        "Medical History: A 55-year-old man had a dyspenea and had been treated for a bronchial asthma. "

        "Image Findings: "
        "Chest radiograph shows soft tissue bulging contour in the right paratracheal area. "
        "Contrast-enhanced chest CT scans show a well-defined, lobulated 4.5-cm sized mass with endotracheal growth in the right posterolateral aspect of trachea, which demonstrates heterogeneous enhancement.\n"
        "He underwent a segmental resection and end-to-end anastomosis of the trachea. "
        "On histopathologic examination, the resected specimen revealed a well-capsulated white tan lobulating tracheal mass with endotracheal growth. "

        "Final Diagnosis: "
        "Tracheal Leiomyoma."
        "<|endofchunk|>"

        # [ANSWER CHUNK]
        
        # One <image> placeholder per image of the query case.
        f"{'<image>' * ex['img_count']}"

        f"Medical History: {ex['history']} "

        # Trailing space added: adjacent literals are joined with no separator, so the
        # findings text previously ran straight into "Final Diagnosis: ".
        f"Image Findings: {ex['findings']} "

        "Final Diagnosis: "
    )
prompt

'You are a helpful medical assistant. You are being provided with some images, a medical history about the patient, some image findings, and a final diagnosis. Follow the examples and provide the final diagnosis for the last case. <image><image><image>Medical History: A 55-year-old man had a dyspenea and had been treated for a bronchial asthma. Image Findings: Chest radiograph shows soft tissue bulging contour in the right paratracheal area. Contrast-enhanced chest CT scans show a well-defined, lobulated 4.5-cm sized mass with endotracheal growth in the right posterolateral aspect of trachea, which demonstrates heterogeneous enhancement.\nHe underwent a segmental resection and end-to-end anastomosis of the trachea. On histopathologic examination, the resected specimen revealed a well-capsulated white tan lobulating tracheal mass with endotracheal growth. Final Diagnosis: Tracheal Leiomyoma.<|endofchunk|><image><image><image>Medical History: A 34-year-old woman presented for evaluation 

In [105]:
# dataset = "kstr"
separator = "\n"

# Both instruction variants open with the same role sentence and end with a blank line.
_role_sentence = "You are a helpful medical assistant. "
_blank_line = separator * 2

instruction = dict(
    # Few-shot: the prompt contains worked examples that include their diagnoses.
    few="".join([
        _role_sentence,
        "You are being provided with some images, a medical history about the patient, some image findings, and a final diagnosis. ",
        "Follow the examples and provide the final diagnosis for the last case.",
        _blank_line,
    ]),
    # Zero-shot: no examples, so the expected answer format is spelled out instead.
    zero="".join([
        _role_sentence,
        "You are being provided with some images, a medical history about the patient, and some image findings. ",
        "Provide the final diagnosis in short text. Do not provide anything else, such as a discussion or an explanation.",
        _blank_line,
    ]),
)

In [102]:
def make_chunk_list(data_from: str, separator: str = " ") -> list:
    """Build one prompt chunk per case of a filtered dataset.

    Args:
        data_from: Dataset key. Only "kstr" is handled; its CSV path is looked up
            in the module-level `filtered_csv` mapping.
        separator: Text placed after each section of a chunk, and twice after the
            closing `<|endofchunk|>` tag.

    Returns:
        A list of `(case id, chunk text)` tuples in CSV row order. The list is
        empty for any other `data_from`.
    """
    if data_from != "kstr":
        return []

    cases = pd.read_csv(filtered_csv[data_from], index_col=None)
    needed = ("id", "img_count", "history", "findings", "diagnosis")
    result = []
    for case_id, img_count, history, findings, diagnosis in zip(*(cases[col] for col in needed)):
        sections = [
            "<image>" * img_count,  # one placeholder per image of the case
            f"Medical History: {history}",
            f"Image Findings: {findings}",
            f"Final Diagnosis: {diagnosis}.",
            "<|endofchunk|>",
        ]
        result.append((case_id, separator.join(sections) + separator * 2))
    return result

In [103]:
# Build the (case id, chunk text) list for the KSTR dataset using the notebook-wide separator.
chunks = make_chunk_list("kstr",separator)

In [134]:
# Per-dataset directory holding one `case_XXXX` folder of images per case.
image_path = {"kstr": root / "data" / "kstr" / "kstr_images"}

# Parallel lists: entry n of each belongs to the n-th chunk.
zeroshot_prompts, zeroshot_images = [], []
answer_marker = "Final Diagnosis: "
for case_id, chunk in tqdm(chunks):
    case_dir = image_path[dataset] / f"case_{case_id:04}"
    # Only *.jpg files are picked up, in sorted path order.
    zeroshot_images.append([Image.open(jpg_file) for jpg_file in sorted(case_dir.glob("*.jpg"))])
    # Keep the chunk up to its first answer marker and re-append the bare marker, so the
    # prompt ends exactly where the model should start writing the diagnosis.
    question_part, _, _ = chunk.partition(answer_marker)
    zeroshot_prompts.append(instruction["zero"] + question_part + answer_marker)

100%|██████████| 1298/1298 [00:12<00:00, 102.03it/s]


In [124]:
# Both lists are filled in the same loop, so the two counts should match.
print(len(zeroshot_prompts))
print(len(zeroshot_images))

1298
1298


In [106]:
# Preview a few-shot style prompt: instruction text followed by the first two chunk texts.
print(instruction["few"] + chunks[0][1] + chunks[1][1])

You are a helpful medical assistant. You are being provided with some images, a medical history about the patient, some image findings, and a final diagnosis. Follow the examples and provide the final diagnosis for the last case.

<image><image><image><image><image>
Medical History: 27-year-old man with a complaint of fever, HIV(+).
Image Findings: A 26-year-old man presented with fever, cough, and weight loss for two months. Serological test for HIV was positive. Many acid-fast bacilli were found in his sputum. Chest radiograph shows an infiltration in the right upper lung zone with mediastinal widening. CT scans show diffusely enlarged mediastinal lymph nodes that have irregular low density internally and rim enhancement. An irregularly shaped gas collection is seen in the right subcarinal area, which communicates with esophageal gas. An esophagogram with gastrografin (not shown) showed a fistulous communication between the esophagus and the gas-filled space. After two months of anti

# INFER

In [135]:
# Inspect the images and the prompt of the first zero-shot case before running inference.
print(zeroshot_images[0])
print(zeroshot_prompts[0])

[<PIL.JpegImagePlugin.JpegImageFile image mode=L size=794x529 at 0x7F77EC8C6C40>, <PIL.JpegImagePlugin.JpegImageFile image mode=L size=340x454 at 0x7F78695A8B80>, <PIL.JpegImagePlugin.JpegImageFile image mode=L size=340x454 at 0x7F78695A84F0>, <PIL.JpegImagePlugin.JpegImageFile image mode=L size=340x454 at 0x7F78695A87F0>, <PIL.JpegImagePlugin.JpegImageFile image mode=L size=340x454 at 0x7F78695A8550>]
You are a helpful medical assistant. You are being provided with some images, a medical history about the patient, and some image findings. Provide the final diagnosis in short text. Do not provide anything else, such as a discussion or an explanation.

<image><image><image><image><image>
Medical History: 27-year-old man with a complaint of fever, HIV(+).
Image Findings: A 26-year-old man presented with fever, cough, and weight loss for two months. Serological test for HIV was positive. Many acid-fast bacilli were found in his sputum. Chest radiograph shows an infiltration in the right u

In [None]:
"""
Step 3: Preprocess data 
"""
print('Preprocess data')
# Run the first case's images through the processor's image preprocessing.
pixels = processor.preprocess_images(zeroshot_images[0])
# Per the einops pattern: treat `pixels` as (N, c, h, w) and add a leading batch axis
# (b=1) plus an axis T=1 after N, giving (b, N, T, c, h, w).
pixels = repeat(pixels, 'N c h w -> b N T c h w', b=1, T=1)
# Tokenize the matching prompt; the next cell reads "input_ids" and "attention_mask" from it.
tokenized_data = processor.encode_text(zeroshot_prompts[0])

In [None]:
"""
Step 4: Generate response 
"""
# actually run zero-shot prompt through model:
print('Generate from multimodal zero-shot prompt')
# All inputs are moved to the Accelerate device; at most 10 new tokens are generated.
generated_text = model.generate(
    vision_x=pixels.to(device),
    lang_x=tokenized_data["input_ids"].to(device),
    attention_mask=tokenized_data["attention_mask"].to(device),
    max_new_tokens=10,
)
# Decode the first (only) sequence of the batch.
response = processor.tokenizer.decode(generated_text[0])
response = clean_generation(response)
# Keep only the text after the last "Final Diagnosis: " marker.
response = response.split("Final Diagnosis: ")[-1]

print(f'{response=}')

In [59]:
%load_ext autoreload
%autoreload 2

from my_utils import load_prompts, get_prompt_from_id

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [55]:
# Build the prompts via the my_utils helper; shots=0 presumably selects the zero-shot
# format — confirm in my_utils.
prompts = load_prompts("kstr", shots=0)

Loading prompts and images...


  0%|          | 0/1298 [00:00<?, ?it/s]

100%|██████████| 1298/1298 [00:11<00:00, 112.17it/s]


In [None]:
# Show every image attached to the prompt entry returned for id 902
# (presumably PIL images, so show() opens an external viewer — confirm in my_utils).
for img in get_prompt_from_id(prompts,902)['images']:
    img.show()

In [57]:
# Print the prompt text of the second entry (position 1) to check its formatting.
print(prompts[1]['text'])

You are a helpful medical assistant. You are being provided with some images, a medical history about the patient, and some image findings. Provide the final diagnosis in short text. Do not provide anything else, such as a discussion or an explanation.

<image><image><image><image>
Medical History: 27-year-old man with a complaint of incidental CPA abnormality.
Image Findings: The vast majority of mesothelial cysts (pericardial or pleuropericardial) are probably congenital and result from aberrations in the formation of the coelomic cavities. Grossly, the cysts are spherical or oval in shape, thin-walled, and  often translucent: the vast majority are unilocular and contain clear or straw-colored fluid. Of the 72 cysts reviewed by a Mayo Clinic group, 54 were in the cardiophrenic angles and 18 arose at a higher level: 11 of the latter cysts extended into the superior mediastinum. In most cases they range in diameter from 3 to 8 cm but rarely have been reported to be as small as 1 cm or 