In [None]:
# Inspect GPU availability / memory usage before choosing which devices to expose via CUDA_VISIBLE_DEVICES.
!nvidia-smi

In [None]:
import os
# Restrict this process to physical GPUs 5, 6 and 7. This only takes effect if
# CUDA has not yet been initialised in this kernel (restart the kernel to change it).
# No spaces: the value is parsed as a comma-separated list of integer IDs, and a
# token with a leading space (e.g. " 6") can stop parsing after the first ID.
os.environ["CUDA_VISIBLE_DEVICES"] = "5,6,7"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig

import torch
torch.manual_seed(1234)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Single source of truth for the checkpoint so tokenizer/model/config never diverge.
MODEL_ID = "Qwen/Qwen-VL-Chat"
CACHE_DIR = "/root/letractien/Mammo-VLM/.cache"

# NOTE(review): trust_remote_code=True executes Python shipped with the checkpoint;
# only use it with a trusted model repository.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True, cache_dir=CACHE_DIR)
# NOTE(review): device_map=device places the whole model on a single device (the first
# visible GPU when CUDA is available); consider device_map="auto" to shard across all
# visible GPUs if memory becomes a constraint.
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, device_map=device, trust_remote_code=True, cache_dir=CACHE_DIR).eval()
model.generation_config = GenerationConfig.from_pretrained(MODEL_ID, trust_remote_code=True, cache_dir=CACHE_DIR)

In [None]:
import numpy as np
import preprocess
from PIL import Image, ImageDraw, ImageFont
import matplotlib.pyplot as plt


save_dir = "out/EMBED_QWVL"
os.makedirs(save_dir, exist_ok=True)

root = "/root/letractien/data/nas07/Dataset/Image/EMBED/"
jpgs = "jpgs"
annotated_mass = "annotated_mass"
other_jpgs = "other_jpgs"

log_path = os.path.join(save_dir, "log.txt")
img_dir = os.path.join(root, annotated_mass)
PROMPT_PATH = "PROMPT_EMBED.txt"


def _preprocess_mammogram(img_arr):
    """Run this notebook's crop/resize + enhancement pipeline on one image array.

    Args:
        img_arr: numpy array of the raw input image (as returned by np.array(Image)).

    Returns:
        The array produced by `preprocess.normalize_for_display`, with NaN/inf
        values replaced via `np.nan_to_num`, ready for `Image.fromarray`.
    """
    # Geometry: crop, pad to square, resize to 640x640; the mask is carried along.
    x, m = preprocess.crop(img_arr, annotation=None)
    x, m, _ = preprocess.pad_image_to_square(x, mask_array=m, annotation=None)
    x, m, _ = preprocess.resize_image(x, mask_array=m, annotation=None, output_shape=(640, 640))
    norm = preprocess.truncation_normalization(x, m)

    # Enhancement chain: each step consumes the previous one, so order matters.
    step1 = preprocess.median_denoise(norm, disk_radius=3)
    step2 = preprocess.unsharp_enhance(step1, radius=1.0, amount=1.5)
    step3 = preprocess.morphological_tophat(step2, selem_radius=15)
    step4 = preprocess.non_local_means_denoise(step3, patch_size=5, patch_distance=6, h_factor=0.8)
    step5 = preprocess.wavelet_enhancement(step4, wavelet='db8', level=1)
    final = preprocess.clahe(step5, clip_limit=0.02)

    disp = preprocess.normalize_for_display(final)
    return np.nan_to_num(disp)


def _load_array(path):
    """Read an image file into a numpy array, closing the file handle afterwards."""
    with Image.open(path) as im:
        return np.array(im)


def _show_side_by_side(left_path, right_path):
    """Plot two saved images side by side, then release the figure."""
    fig, axs = plt.subplots(1, 2, figsize=(6, 6))
    panels = ((left_path, "Old BBox"), (right_path, "New Bbox"))
    for ax, (path, title) in zip(axs, panels):
        ax.imshow(_load_array(path), cmap='gray')
        ax.set_title(title)
        ax.axis("off")
    plt.tight_layout()
    plt.show()
    # Close explicitly so figures do not accumulate over many loop iterations.
    plt.close(fig)


# The prompt is the same for every image, so read it once instead of per iteration.
with open(PROMPT_PATH, "r", encoding="utf-8") as f:
    prompt = f.read()

# os.listdir order is arbitrary; sort so `idx` in the log and in output
# filenames is reproducible across runs.
for idx, file in enumerate(sorted(os.listdir(img_dir))):
    img_in_path = os.path.join(img_dir, file)
    with Image.open(img_in_path) as img:
        img_arr = np.array(img)

    disp = _preprocess_mammogram(img_arr)

    basename = os.path.splitext(file)[0]
    img_out_path = os.path.join(save_dir, f"{basename}.png")
    Image.fromarray(disp).save(img_out_path)

    # The model sees the *preprocessed* PNG, not the raw input file.
    query = tokenizer.from_list_format([
        {'image': img_out_path},
        {'text': prompt}
    ])

    response, history = model.chat(tokenizer, query=query, history=None)
    with open(log_path, "a", encoding="utf-8") as f:
        f.write(f"Response {idx}: {response}\n")
        f.write(f"History {idx}: {history}\n")
        f.write("\n")

    image = tokenizer.draw_bbox_on_latest_picture(response, history)
    img_out_path_bbox = os.path.join(save_dir, f"{basename}_{idx}_bbox.png")
    if image is None:
        # No box was drawn, so no bbox file was written for this image. Skip the
        # comparison plot instead of opening a missing (or stale) bbox file.
        print("No bbox")
        continue
    image.save(img_out_path_bbox)

    _show_side_by_side(img_out_path, img_out_path_bbox)