In [None]:
# Inspect the available GPUs and their current load before pinning devices in the next cell.
!nvidia-smi

In [None]:
import os

# These must be set before torch initialises CUDA (torch is first imported in the next cell).
# CUDA_VISIBLE_DEVICES is a plain comma-separated list of device indices; avoid embedded spaces.
os.environ["CUDA_VISIBLE_DEVICES"] = "5,6,7"
# Let PyTorch's caching allocator use expandable segments, which PyTorch documents as a way to
# reduce fragmentation when allocation sizes vary (as with large VLM activations).
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

In [None]:
import torch
from qwen_vl_utils import process_vision_info
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor, BitsAndBytesConfig

CACHE_DIR = "/root/letractien/Mammo-VLM/.cache"
model_path = "Qwen/Qwen2.5-VL-7B-Instruct"

# Load the checkpoint in its stored dtype and let device_map="auto" place it across the visible GPUs.
# For a smaller memory footprint, pass quantization_config=BitsAndBytesConfig(load_in_8bit=True).
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    model_path,
    cache_dir=CACHE_DIR,
    device_map="auto",
    torch_dtype="auto",
)

# Tokenizer + image preprocessor. An explicit pixel budget can be set with
# min_pixels=256*28*28 and max_pixels=1280*28*28.
processor = AutoProcessor.from_pretrained(model_path, cache_dir=CACHE_DIR)

In [None]:
import json
import random
import io
import ast
from PIL import Image, ImageDraw, ImageFont
from PIL import ImageColor
import xml.etree.ElementTree as ET
# Every colour name Pillow knows, appended as a fallback palette after the curated colour list.
additional_colors = list(ImageColor.colormap)

def plot_bounding_boxes(img, bounding_boxes, input_width, input_height):
    """Draw model-predicted boxes on ``img`` and show it.

    Box coordinates are taken to be ``[x1, y1, x2, y2]`` in the
    ``input_width`` x ``input_height`` frame and are rescaled to ``img``'s own
    size before drawing.

    Args:
        img: PIL image to draw on (modified in place).
        bounding_boxes: Raw model response text, optionally wrapped in a
            Markdown code fence, holding a list of
            ``{"bbox_2d": [...], "label": ...}`` objects (a single object is
            also accepted).
        input_width: Width of the image frame the coordinates refer to.
        input_height: Height of the image frame the coordinates refer to.

    Entries that are not dicts or lack a 4-number ``bbox_2d`` are skipped. If
    the response cannot be parsed at all, a warning is printed and the image
    is shown without boxes instead of raising.
    """
    width, height = img.size
    draw = ImageDraw.Draw(img)

    colors = [
        'red',
        'green',
        'blue',
        'yellow',
        'orange',
        'pink',
        'purple',
        'brown',
        'gray',
        'beige',
        'turquoise',
        'cyan',
        'magenta',
        'lime',
        'navy',
        'maroon',
        'teal',
        'olive',
        'coral',
        'lavender',
        'violet',
        'gold',
        'silver',
    ] + additional_colors

    font = ImageFont.load_default()
    detections = _parse_detections(parse_json(bounding_boxes))

    for i, bounding_box in enumerate(detections):
        bbox = bounding_box.get("bbox_2d") if isinstance(bounding_box, dict) else None
        if (
            not isinstance(bbox, (list, tuple))
            or len(bbox) != 4
            or not all(isinstance(v, (int, float)) for v in bbox)
        ):
            continue  # malformed entry from the model; nothing sensible to draw

        color = colors[i % len(colors)]
        x1, y1, x2, y2 = bbox
        abs_x1 = int(x1 / input_width * width)
        abs_y1 = int(y1 / input_height * height)
        abs_x2 = int(x2 / input_width * width)
        abs_y2 = int(y2 / input_height * height)

        # The model does not always emit corners in (top-left, bottom-right) order.
        abs_x1, abs_x2 = sorted((abs_x1, abs_x2))
        abs_y1, abs_y2 = sorted((abs_y1, abs_y2))

        draw.rectangle(
            ((abs_x1, abs_y1), (abs_x2, abs_y2)), outline=color, width=4
        )

        if "label" in bounding_box:
            draw.text((abs_x1 + 8, abs_y1 + 6), str(bounding_box["label"]), fill=color, font=font)

    img.show()


def _parse_detections(text):
    """Parse a (possibly truncated) detection list from model output; always returns a list.

    Tries strict JSON first (handles ``true``/``null``), then Python-literal
    syntax, then, for output cut off mid-list, everything up to the last
    complete ``"}`` terminator closed with ``]``.
    """
    candidates = [text]
    end_idx = text.rfind('"}')
    if end_idx != -1:
        candidates.append(text[:end_idx + len('"}')] + "]")

    for candidate in candidates:
        for loader in (json.loads, ast.literal_eval):
            try:
                parsed = loader(candidate)
            except (ValueError, SyntaxError, TypeError):
                continue
            if isinstance(parsed, dict):
                return [parsed]
            if isinstance(parsed, list):
                return parsed

    print(f"Warning: could not parse bounding boxes from model output: {text[:200]!r}")
    return []

def parse_json(json_output):
    """Strip a Markdown code fence from a model response.

    Finds the first line that is ```json or a bare ``` fence (ignoring
    surrounding whitespace and letter case). It then returns the text between
    that line and the next ``` marker. If no such fence line exists, the input
    is returned unchanged.
    """
    lines = json_output.splitlines()
    for i, line in enumerate(lines):
        if line.strip().lower() in ("```json", "```"):
            body = "\n".join(lines[i + 1:])
            return body.split("```")[0]
    return json_output

In [None]:
def inference(img_url, prompt, system_prompt="You are a helpful assistant", max_new_tokens=1024):
    """Run the VLM on one image + text prompt and return the decoded answer.

    Args:
        img_url: Path to the image file (opened with PIL and also referenced in the chat message).
        prompt: User instruction sent alongside the image.
        system_prompt: System message prepended to the conversation.
        max_new_tokens: Generation budget. Detection prompts can produce long JSON lists,
            so a small budget truncates the output.

    Returns:
        ``(output_text, input_height, input_width)``: the decoded response and the pixel
        height/width of the resized image the processor fed to the model. Callers use
        these to rescale returned box coordinates.
    """
    # Load fully and release the file handle (Image.open alone keeps the file open lazily).
    with Image.open(img_url) as src:
        image = src.convert("RGB")

    messages = [
        {"role": "system", "content": system_prompt},
        {
            "role": "user",
            "content": [
                {"type": "text", "text": prompt},
                {"type": "image", "image": img_url},
            ],
        },
    ]
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = processor(text=[text], images=[image], padding=True, return_tensors="pt")
    inputs = inputs.to(model.device)

    output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
    # Drop the echoed prompt tokens so only the newly generated continuation is decoded.
    generated_ids = [out[len(inp):] for inp, out in zip(inputs.input_ids, output_ids)]
    output_text = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
    print("Output: \n", output_text[0])

    # image_grid_thw is a (t, h, w) grid in vision patches; multiply by the patch size for pixels.
    patch_size = getattr(processor.image_processor, "patch_size", 14)
    _, grid_h, grid_w = inputs["image_grid_thw"][0].tolist()
    input_height = int(grid_h) * patch_size
    input_width = int(grid_w) * patch_size

    return output_text[0], input_height, input_width

In [None]:
import os
import dataset

image_annotation_tuples = dataset.load_image_annotation_tuples()

# One (path, annotation) pair per image path. For repeated paths the last annotation wins,
# while the path keeps its first-seen position.
# NOTE(review): images with several findings keep only one annotation. Confirm this is intended.
unique_tuples = list(dict((p, (p, a)) for p, a in image_annotation_tuples).values())

save_dir = "out/qwen_25_vl_7B_instruct_ipynb"
log_path = os.path.join(save_dir, "log.txt")
os.makedirs(save_dir, exist_ok=True)

In [None]:
import pydicom
import numpy as np
import preprocess

# Run the full preprocessing pipeline on a single sample and preview the result.
idx = 5
img_path, annotation = unique_tuples[idx]

folder = annotation['study_id']
study_dir = os.path.join(save_dir, folder)
os.makedirs(study_dir, exist_ok=True)

basename = annotation['image_id']
img_png_path = os.path.join(study_dir, f"{basename}.png")

ds = pydicom.dcmread(img_path)
img_arr = ds.pixel_array.astype(np.float32)

# Geometry: crop, pad to a square and resize to 640x640. The mask `m` and the annotation
# are passed through each step.
x, m, new_annotation = preprocess.crop(img_arr, annotation=annotation)
x, m, new_annotation = preprocess.pad_image_to_square(x, mask_array=m, annotation=new_annotation)
x, m, new_annotation = preprocess.resize_image(x, mask_array=m, annotation=new_annotation, output_shape=(640, 640))
norm = preprocess.truncation_normalization(x, m)

# Intensity enhancement chain, applied in this fixed order.
step1 = preprocess.median_denoise(norm, disk_radius=3)
step2 = preprocess.unsharp_enhance(step1, radius=1.0, amount=1.5)
step3 = preprocess.morphological_tophat(step2, selem_radius=15)
step4 = preprocess.non_local_means_denoise(step3, patch_size=5, patch_distance=6, h_factor=0.8)
step5 = preprocess.wavelet_enhancement(step4, wavelet='db8', level=1)
final = preprocess.clahe(step5, clip_limit=0.02)

disp = np.nan_to_num(preprocess.normalize_for_display(final))

# Optional: draw the ground-truth box (enlarged by 20 px) for visual comparison.
# new_annotation["xmin"] = max(int(new_annotation['xmin']) - 20, 0)
# new_annotation["ymin"] = max(int(new_annotation['ymin']) - 20, 0)
# new_annotation["xmax"] = min(int(new_annotation['xmax']) + 20, new_annotation["width"])
# new_annotation["ymax"] = min(int(new_annotation['ymax']) + 20, new_annotation["height"])
# disp = preprocess.draw_bbox_grayscale(disp.copy(), new_annotation, color=255, thickness=5)

img_png_path_pre = os.path.join(study_dir, f"{basename}_{idx}_preprocessed.png")
preview_rgb = Image.fromarray(disp).convert("RGB")
preview_rgb.save(img_png_path_pre)

import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(4, 4))
ax.imshow(disp, cmap='gray')
ax.axis('off')
plt.show()


In [None]:
import prompt  # project prompt module. NOTE(review): not used below; confirm whether it is needed.

image_path = img_png_path_pre
# Kept under its own name: rebinding `prompt` here would shadow the module imported above.
detection_prompt = "Please strictly mark and select all small, round suspicious masses or suspicious calcifications in the image, and output the corresponding detection frame coordinates for subsequent diagnosis and analysis. The detection frame should fit closely to the detected target, output its bbox coordinates using JSON format."
response, input_height, input_width = inference(image_path, detection_prompt)

image = Image.open(image_path)
print(image.size)

In [None]:
# Shrink in place to fit within 640x640 (thumbnail never enlarges), then overlay the predicted boxes.
image.thumbnail((640, 640), resample=Image.Resampling.LANCZOS)
plot_bounding_boxes(image, response, input_width=input_width, input_height=input_height)

In [None]:
import json

# Batch run: preprocess every unique DICOM, save the enhanced PNG, query the model,
# append the raw response to log_path and display the predicted boxes.

# Loop-invariant, so read once rather than per image.
with open("PROMPT_qwen_25_vl_7B_instruct.txt", "r", encoding="utf-8") as f:
    prompt_text = f.read()


def _append_log(record):
    """Append one JSON record (one line) to the run log at log_path."""
    with open(log_path, "a", encoding="utf-8") as log_f:
        log_f.write(json.dumps(record, ensure_ascii=False) + "\n")


for idx, (img_path, annotation) in enumerate(unique_tuples):
    folder = annotation['study_id']
    os.makedirs(os.path.join(save_dir, folder), exist_ok=True)
    basename = annotation['image_id']

    try:
        ds = pydicom.dcmread(img_path)
        img_arr = ds.pixel_array.astype(np.float32)
    except Exception as e:
        # Best-effort over the dataset: skip unreadable/undecodable files, but report them.
        print(f"[skip] {img_path}: {type(e).__name__}: {e}")
        _append_log({"idx": idx, "dicom": img_path, "error": f"{type(e).__name__}: {e}"})
        continue

    x, m, new_annotation = preprocess.crop(img_arr, annotation=annotation)
    x, m, new_annotation = preprocess.pad_image_to_square(x, mask_array=m, annotation=new_annotation)
    x, m, new_annotation = preprocess.resize_image(x, mask_array=m, annotation=new_annotation, output_shape=(640, 640))
    norm = preprocess.truncation_normalization(x, m)

    step1 = preprocess.median_denoise(norm, disk_radius=3)
    step2 = preprocess.unsharp_enhance(step1, radius=1.0, amount=1.5)
    step3 = preprocess.morphological_tophat(step2, selem_radius=15)
    step4 = preprocess.non_local_means_denoise(step3, patch_size=5, patch_distance=6, h_factor=0.8)
    step5 = preprocess.wavelet_enhancement(step4, wavelet='db8', level=1)
    final = preprocess.clahe(step5, clip_limit=0.02)

    disp = np.nan_to_num(preprocess.normalize_for_display(final))

    img_png_path_pre = os.path.join(save_dir, folder, f"{basename}_{idx}_preprocessed.png")
    # Save as RGB (as the single-sample cell does). `disp` is presumably a 2-D array, which
    # Image.fromarray turns into a single-channel image, and drawing coloured boxes on a
    # single-channel image collapses them to gray.
    Image.fromarray(disp).convert("RGB").save(img_png_path_pre)

    image_path = img_png_path_pre
    response, input_height, input_width = inference(image_path, prompt_text)
    _append_log({"idx": idx, "dicom": img_path, "image": image_path, "response": response})

    image = Image.open(image_path)
    image.thumbnail((640, 640), Image.Resampling.LANCZOS)
    try:
        plot_bounding_boxes(image, response, input_width, input_height)
    except (SyntaxError, ValueError, KeyError, TypeError, IndexError) as e:
        # One unparsable model answer should not abort the whole run.
        print(f"[plot failed] {image_path}: {type(e).__name__}: {e}")
