In [1]:
import sys
import os
sys.path.append(os.path.abspath(".."))

In [2]:
%load_ext autoreload
%autoreload 2

import warnings
import torch
import pandas as pd
import numpy as np
from tqdm import tqdm
from monai.utils import ensure_tuple_rep
from transformers import BertTokenizer, BertModel
from transformers.utils import logging
from torch import nn

warnings.simplefilter("ignore")
logging.set_verbosity_error()
torch.set_printoptions(profile="default")
torch.autograd.set_detect_anomaly(False)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = BertTokenizer.from_pretrained('microsoft/BiomedVLP-CXR-BERT-specialized', do_lower_case=True)
text_encoder = BertModel.from_pretrained("microsoft/BiomedVLP-CXR-BERT-specialized").to(device)
text_encoder.resize_token_embeddings(len(tokenizer))

Embedding(30522, 768, padding_idx=0)

In [3]:
valid_reports = "/project/project_465001111/ct_clip/CT-CLIP-UT/reports/valid_reports.csv"
valid_labels = "/project/project_465001111/ct_clip/CT-CLIP-UT/labels/valid_labels.csv"
reports_df = pd.read_csv(valid_reports)
labels_df = pd.read_csv(valid_labels)

In [4]:
text_reports_dict = {
    row['VolumeName']: f"{str(row['Findings_EN'])} {str(row['Impressions_EN'])}"
    for _, row in reports_df.iterrows()
}

text_reports_dict

{'valid_1_a_1.nii.gz': 'Trachea, both main bronchi are open. Mediastinal main vascular structures, heart contour, size are normal. Thoracic aorta diameter is normal. Pericardial effusion-thickening was not observed. Thoracic esophageal calibration was normal and no significant tumoral wall thickening was detected. No enlarged lymph nodes in prevascular, pre-paratracheal, subcarinal or bilateral hilar-axillary pathological dimensions were detected. When examined in the lung parenchyma window; A few millimetric nonspecific nodules and mild recessions are observed in the upper lobe and lower lobe of the right lung. Aeration of both lung parenchyma is normal and no infiltrative lesion is detected in the lung parenchyma. Pleural effusion-thickening was not detected. Upper abdominal organs included in the sections are normal. No space-occupying lesion was detected in the liver that entered the cross-sectional area. Bilateral adrenal glands were normal and no space-occupying lesion was detect

In [55]:
PATHOLOGIES = labels_df.columns[1:-1]
labels_df['one_hot_labels'] = labels_df[PATHOLOGIES].values.tolist()

In [56]:
def get_cls_embedding(text):
    tokens = tokenizer(
        text,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=512
    )
    tokens = {k: v.to(device) for k, v in tokens.items()}
    with torch.no_grad():
        output = text_encoder(**tokens).last_hidden_state[:, 0, :]  # CLS token
    return output.squeeze(0).cpu()


In [57]:
pathology_embeddings = {}

for pathology in tqdm(PATHOLOGIES):
    present_embeddings = []
    absent_embeddings = []

    for idx, row in labels_df.iterrows():
        vol = row['VolumeName']
        if vol not in text_reports_dict:
            continue
        text = text_reports_dict[vol]
        embedding = get_cls_embedding(text)

        if row[pathology] == 1:
            present_embeddings.append(embedding)
        else:
            absent_embeddings.append(embedding)

    if not present_embeddings or not absent_embeddings:
        print(f"Skipping {pathology}: not enough samples")
        continue

    present_avg = torch.stack(present_embeddings).mean(dim=0)
    absent_avg = torch.stack(absent_embeddings).mean(dim=0)
    diff_vector = present_avg - absent_avg

    pathology_embeddings[pathology] = diff_vector.numpy()

100%|██████████| 17/17 [09:51<00:00, 34.77s/it]


In [60]:
output_file = "/project/project_465001111/ct_clip/CT-CLIP-UT/src/resources/pathology_diff_embeddings.npy"
np.save(output_file, pathology_embeddings)

print(f"Saved pathology embeddings to {output_file}")

Saved pathology embeddings to /project/project_465001111/ct_clip/CT-CLIP-UT/src/resources/pathology_diff_embeddings.npy


In [62]:
pathology_embeddings["Medical material"]

array([ 0.10675037,  0.06374189,  0.07476789,  0.06367601,  0.10834735,
       -0.1170496 , -0.1855166 ,  0.00817448, -0.10134506,  0.15482326,
       -0.14807042,  0.03412744,  0.00220391, -0.05160905,  0.1554439 ,
        0.11838487, -0.02782467, -0.00389355,  0.18813065,  0.01229225,
       -0.02197525,  0.14039484, -0.0654846 , -0.08672924,  0.00867936,
       -0.02128163, -0.05558404, -0.02722287, -0.04295498, -0.09767595,
       -0.0024294 ,  0.04523085, -0.02537951, -0.07898887,  0.03030148,
       -0.04641089, -0.04797716,  0.04694137, -0.04968792,  0.0069387 ,
       -0.11809726,  0.052692  ,  0.02107817, -0.08812366, -0.00412425,
       -0.0742615 ,  0.01818008,  0.01092157, -0.0595842 ,  0.01999733,
       -0.00528154, -0.10838734,  0.06681275, -0.00044985,  0.03231245,
       -0.06177442,  0.07521291,  0.00740924,  0.09587963,  0.0955925 ,
       -0.0617066 , -0.04136835,  0.08584622, -0.04582641, -0.0389861 ,
       -0.04368504,  0.01134068, -0.01094938,  0.03619786, -0.08