In [3]:
import os
# Dataset root and the sub-folders inspected in this cell.
root = "/root/letractien/data/nas07/Dataset/Image/EMBED/"
jpgs, annotated_mass, other_jpgs, csv_files = (
    "jpgs",
    "annotated_mass",
    "other_jpgs",
    "csv_files",
)

# Directory listings: the root itself, then each sub-folder in declaration order.
root_list = os.listdir(root)
jpgs_list, annotated_mass_list, other_jpgs_list, csv_files_list = (
    os.listdir(os.path.join(root, sub_dir))
    for sub_dir in (jpgs, annotated_mass, other_jpgs, csv_files)
)

# One entry count per line, in the same order as the listings above.
print(
    *map(
        len,
        (root_list, jpgs_list, annotated_mass_list, other_jpgs_list, csv_files_list),
    ),
    sep="\n",
)

4
284815
3186
195611
6


In [None]:
# IPython shell escape: print the nvidia-smi report for this machine.
!nvidia-smi

In [None]:
import os
# Process-level CUDA / PyTorch allocator settings, written into the
# environment in a single update. This cell runs before the cell that
# imports torch.
os.environ.update(
    {
        "CUDA_VISIBLE_DEVICES": "5, 6",
        "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
    }
)

In [None]:
import torch
# Seed the torch RNG with a fixed value for this session.
torch.manual_seed(1234)

from qwen_vl_utils import process_vision_info
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor, BitsAndBytesConfig

# Directory passed as `cache_dir` to both `from_pretrained` calls below.
CACHE_DIR = "/root/letractien/Mammo-VLM/.cache"
# Quantization settings handed to the model loader: 4-bit loading with the
# "nf4" quant type, double quantization enabled, float16 compute dtype.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True, 
    bnb_4bit_use_double_quant=True, 
    bnb_4bit_quant_type="nf4", 
    bnb_4bit_compute_dtype=torch.float16,
    # NOTE(review): presumably lets modules that device_map="auto" places on
    # the CPU stay unquantized in fp32 - confirm against the transformers docs.
    llm_int8_enable_fp32_cpu_offload=True
)
model_path = "Qwen/Qwen2.5-VL-7B-Instruct"
# Load the model with the quantization config above; dtype and device
# placement are both left to "auto".
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    model_path, 
    torch_dtype="auto", 
    device_map="auto",
    quantization_config=bnb_config,
    cache_dir=CACHE_DIR
)
# Processor loaded from the same checkpoint; `inference` uses it for both the
# chat template and the image/text tensors.
processor = AutoProcessor.from_pretrained(
    model_path, 
    # Optional pixel-count limits, currently disabled:
    # min_pixels=256*28*28, 
    # max_pixels=1280*28*28,
    cache_dir=CACHE_DIR
)

In [None]:
import json
import random
import io
import ast
from PIL import Image, ImageDraw, ImageFont
from PIL import ImageColor
import xml.etree.ElementTree as ET
# Every color name in Pillow's color map, in the map's own order; appended to
# the hand-picked palette used when drawing boxes.
additional_colors = list(ImageColor.colormap)

def _decode_boxes(raw_response):
    """Best-effort conversion of a raw model response into a list of box entries.

    The response is first stripped of its Markdown fence with ``parse_json``.
    It is then decoded as JSON (so ``null``/``true``/``false`` are accepted)
    and, failing that, as a Python literal. If the full text cannot be decoded,
    one repair is attempted: everything after the last ``"}`` is dropped and
    the list is closed, which recovers responses cut off mid-generation.

    Returns:
        A list of decoded entries; an empty list when nothing could be decoded.
    """
    text = parse_json(raw_response)

    candidates = [text]
    last_complete = text.rfind('"}')
    if last_complete != -1:
        # Only attempt the truncation repair when there is a complete object
        # to truncate to.
        candidates.append(text[:last_complete + len('"}')] + "]")

    for candidate in candidates:
        for loader in (json.loads, ast.literal_eval):
            try:
                decoded = loader(candidate)
            except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
                continue
            if isinstance(decoded, dict):
                # A single object instead of a list of objects.
                return [decoded]
            if isinstance(decoded, (list, tuple)):
                return list(decoded)
    return []


def plot_bounding_boxes(img, bounding_boxes, input_width, input_height, save_path=None):
    """Draw the boxes found in a model response onto ``img``, show it, optionally save it.

    Args:
        img: PIL image to annotate. It is drawn on in place.
        bounding_boxes: Raw text response of the model, expected to contain a
            list of objects with a ``"bbox_2d"`` entry ``[x1, y1, x2, y2]`` and
            an optional ``"label"``.
        input_width: Width of the coordinate space the boxes are expressed in.
        input_height: Height of the coordinate space the boxes are expressed in.
            Both may be plain numbers or zero-dimensional tensors.
        save_path: If given, the annotated image is written to this path.

    Entries that are not objects, or whose ``"bbox_2d"`` is not four numbers,
    are reported and skipped instead of aborting the whole plot. An
    undecodable response results in the image being shown/saved without boxes.
    """
    width, height = img.size
    draw = ImageDraw.Draw(img)

    colors = [
        'red', 'green', 'blue', 'yellow', 'orange', 'pink', 'purple', 'brown',
        'gray', 'beige', 'turquoise', 'cyan', 'magenta', 'lime', 'navy',
        'maroon', 'teal', 'olive', 'coral', 'lavender', 'violet', 'gold',
        'silver',
    ] + additional_colors

    font = ImageFont.load_default()

    # float() accepts both Python numbers and zero-dimensional tensors, and
    # keeps the rescaling below in plain Python floats.
    input_width = float(input_width)
    input_height = float(input_height)

    boxes = _decode_boxes(bounding_boxes)
    if not boxes:
        print("plot_bounding_boxes: no bounding boxes to draw")

    for i, bounding_box in enumerate(boxes):
        color = colors[i % len(colors)]

        coords = bounding_box.get("bbox_2d") if isinstance(bounding_box, dict) else None
        try:
            x1, y1, x2, y2 = (float(value) for value in coords)
        except (TypeError, ValueError):
            print(f"plot_bounding_boxes: skipping malformed entry {i}: {bounding_box!r}")
            continue

        # Rescale from the model's coordinate space to the size of `img`.
        abs_x1 = int(x1 / input_width * width)
        abs_y1 = int(y1 / input_height * height)
        abs_x2 = int(x2 / input_width * width)
        abs_y2 = int(y2 / input_height * height)

        # Normalize so (x1, y1) is the top-left corner.
        if abs_x1 > abs_x2:
            abs_x1, abs_x2 = abs_x2, abs_x1
        if abs_y1 > abs_y2:
            abs_y1, abs_y2 = abs_y2, abs_y1

        draw.rectangle(
            ((abs_x1, abs_y1), (abs_x2, abs_y2)), outline=color, width=4
        )

        if "label" in bounding_box:
            # str() guards against non-string labels (numbers, None).
            draw.text((abs_x1 + 8, abs_y1 + 6), str(bounding_box["label"]), fill=color, font=font)

    img.show()
    if save_path is not None:
        img.save(save_path)

def parse_json(json_output):
    """Extract the body of the first Markdown ```json fence from a model response.

    Args:
        json_output: Raw response text.

    Returns:
        The text between the first line that consists of a ```json marker
        (surrounding whitespace and letter case are ignored) and the next
        ``` marker. If no such opening line exists, the input is returned
        unchanged.
    """
    lines = json_output.splitlines()
    for i, line in enumerate(lines):
        # Models do not always emit the fence flush-left or without trailing
        # spaces, so compare the stripped, lower-cased line.
        if line.strip().lower() == "```json":
            fenced = "\n".join(lines[i + 1:])
            # Keep only what precedes the closing fence (if there is one).
            return fenced.split("```")[0]
    return json_output

In [None]:
def inference(img_url, prompt, system_prompt="You are a helpful assistant in detecting suspicious breast lumps or microcalcifications in mammography images", max_new_tokens=1024):
    """Run the loaded vision-language model on one image and return its text answer.

    Args:
        img_url: Path of the image file to analyse.
        prompt: User prompt sent together with the image.
        system_prompt: System message placed before the user turn.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        A tuple ``(output_text, input_height, input_width)`` where
        ``output_text`` is the decoded answer and the two sizes are derived
        from the processor's ``image_grid_thw`` entry multiplied by 14
        (presumably the vision patch size - TODO confirm). The sizes are
        returned as the zero-dimensional tensors produced by that indexing.
    """
    messages = [
        {
            "role": "system",
            "content": system_prompt
        },
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": prompt
                },
                {
                    "image": img_url
                }
            ]
        }
    ]
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    # Close the file handle as soon as the processor has consumed the pixels.
    with Image.open(img_url) as image:
        inputs = processor(text=[text], images=[image], padding=True, return_tensors="pt")

    # Read the grid size before the inputs are moved, so the returned values
    # stay on the device the processor produced them on.
    input_height = inputs['image_grid_thw'][0][1]*14
    input_width = inputs['image_grid_thw'][0][2]*14

    # The processor returns its tensors without a device argument; generation
    # needs them on the model's device.
    inputs = inputs.to(model.device)

    # Honor the caller's token budget instead of a hard-coded value.
    output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
    # Strip the prompt tokens so only newly generated tokens are decoded.
    generated_ids = [sequence[len(input_ids):] for input_ids, sequence in zip(inputs.input_ids, output_ids)]
    output_text = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
    print("Output: \n", output_text[0])

    return output_text[0], input_height, input_width

In [None]:
import numpy as np
import preprocess

save_dir = "out/EMBED_QW7B"
os.makedirs(save_dir, exist_ok=True)

root = "/root/letractien/data/nas07/Dataset/Image/EMBED/"
jpgs = "jpgs"
annotated_mass = "annotated_mass"
other_jpgs = "other_jpgs"

log_path = os.path.join(save_dir, "log.txt")
img_dir = os.path.join(root, jpgs)

# Number of images to process before stopping.
MAX_IMAGES = 100
count = 0

# The prompt is the same for every image, so read it once up front.
with open("PROMPT_EMBED.txt", "r", encoding="utf-8") as f:
    prompt = f.read()

# sorted() makes the processed subset reproducible; os.listdir gives no
# ordering guarantee.
for idx, file in enumerate(sorted(os.listdir(img_dir))):
    stem = os.path.splitext(file)[0]
    img_in_path = os.path.join(img_dir, file)
    with Image.open(img_in_path) as img:
        img_arr = np.array(img)

    # Crop -> pad to square -> resize to 640x640 -> normalize.
    x, m = preprocess.crop(img_arr, annotation=None)
    x, m, _ = preprocess.pad_image_to_square(x, mask_array=m, annotation=None)
    x, m, _ = preprocess.resize_image(x, mask_array=m, annotation=None, output_shape=(640, 640))
    norm = preprocess.truncation_normalization(x, m)

    # Enhancement chain applied on top of the normalized image.
    step1 = preprocess.median_denoise(norm, disk_radius=3)
    step2 = preprocess.unsharp_enhance(step1, radius=1.0, amount=1.5)
    step3 = preprocess.morphological_tophat(step2, selem_radius=15)
    step4 = preprocess.non_local_means_denoise(step3, patch_size=5, patch_distance=6, h_factor=0.8)
    step5 = preprocess.wavelet_enhancement(step4, wavelet='db8', level=1)
    final = preprocess.clahe(step5, clip_limit=0.02)

    # The enhanced image used to be written to the same path as the
    # normalized one and was overwritten immediately; it now gets its own
    # file so it is no longer lost.
    disp = preprocess.normalize_for_display(final)
    disp = np.nan_to_num(disp)
    enhanced_out_path = os.path.join(save_dir, f"{stem}_enhanced.png")
    Image.fromarray(disp).convert("RGB").save(enhanced_out_path)

    # The normalized image is what is sent to the model, which matches what
    # effectively happened before.
    # NOTE(review): confirm whether the enhanced image was meant to be the
    # model input instead.
    disp = preprocess.normalize_for_display(norm)
    disp = np.nan_to_num(disp)
    img_out_path = os.path.join(save_dir, f"{stem}.png")
    Image.fromarray(disp).convert("RGB").save(img_out_path)

    response, input_height, input_width = inference(img_out_path, prompt)
    print(input_height, input_width)

    with Image.open(img_out_path) as image:
        print(image.size)

        image.thumbnail([640,640], Image.Resampling.LANCZOS)
        bbox_out_path = os.path.join(save_dir, f"{stem}_bbox.png")
        plot_bounding_boxes(image, response, input_width, input_height, bbox_out_path)

    # `count` was never incremented, so the stop condition could not fire and
    # the loop ran over the whole directory.
    count = idx + 1
    if count >= MAX_IMAGES:
        break