In [None]:
import torch
from PIL import Image
from torchvision import transforms
from torchvision.transforms import functional as F
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
from numpy.linalg import norm
from torch.utils.data import Dataset
from torchvision import transforms
import re
import torch.nn as nn
from torch import nn
import torch.optim as optim
from torchvision import transforms, models
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AdamW, get_scheduler
from sklearn.metrics import accuracy_score, roc_auc_score, precision_score, recall_score, f1_score
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from torch.optim.lr_scheduler import CosineAnnealingLR
import nltk
import random
from nltk.corpus import stopwords
from torchvision import transforms
from sklearn.metrics import roc_auc_score
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    classification_report, multilabel_confusion_matrix,ConfusionMatrixDisplay
)
from tqdm import tqdm
from transformers import get_linear_schedule_with_warmup
import os

from sklearn.metrics import precision_recall_fscore_support, roc_auc_score
from transformers import AutoModel
from torchvision.models import swin_b, Swin_B_Weights
import torch.nn.functional as F
from typing import Optional, Tuple, Dict
from data_preprocessing import MIMIC_MultiModalDataset
from mid_fusion import BioFuse

In [None]:
# Dataset locations. The defaults are the original hard-coded paths; set the
# MIMIC_CXR_IMAGE_DIR / MIMIC_CXR_REPORT_DIR environment variables to run the
# notebook on a machine with a different layout without editing this cell.
image_dir = os.environ.get("MIMIC_CXR_IMAGE_DIR", "/data/mimic-cxr/mimic-cxr-jpg")
report_dir = os.environ.get("MIMIC_CXR_REPORT_DIR", "/data/reports/")

In [None]:
# Image pipeline handed to the dataset: resize to 224x224, then convert to a tensor.
_image_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
]
transform = transforms.Compose(_image_steps)


In [None]:
# Held-out split; images go through `transform` defined above.
test_dataset = MIMIC_MultiModalDataset(
    image_dir=image_dir,
    report_dir=report_dir,
    mode="test",
    transform=transform,
)

In [None]:
from transformers import AutoTokenizer

# Hugging Face checkpoint id whose tokenizer is used for the report text.
text_encoder_type = "emilyalsentzer/Bio_ClinicalBERT"
tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path=text_encoder_type,
)

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# NOTE(review): no checkpoint is loaded anywhere in the visible cells; presumably
# BioFuse() restores its trained weights internally -- confirm, otherwise the
# "frozen" encoder used below is randomly initialised.
model = BioFuse()
model = model.to(device)

criterion = nn.BCEWithLogitsLoss()

In [None]:
# Every column of the test metadata frame that is NOT listed here is kept as a
# label column.
_non_label_columns = [
    "id", "findings", "paths", "Unnamed: 0",
    "file_path", "impression", "subject_id", "study_id",
]
label_columns = list(test_dataset.data.drop(columns=_non_label_columns).columns)

# Switch the model to inference mode (kept as the cell's last expression).
model.eval()


In [None]:
# Concept vocabulary. The ORDER matters: position i here names entry i of the
# tensor returned by `extract_text_concepts` and output i of the concept head.

# Concepts named after a radiographic finding.
_finding_concepts = [
    "alveolar_opacity",
    "interstitial_opacity",
    "focal_lung_opacity",
    "diffuse_lung_opacity",
    "blunted_costophrenic_angle",
    "air_fluid_level",
    "hyperlucency",
    "volume_loss",
    "enlarged_cardiac_silhouette",
    "mediastinal_widening",
    "fracture_line",
    "tube_or_line_present",
    "lung_mass_or_nodule",
]

# Concepts carrying the "_language" suffix.
_language_concepts = [
    "infection_language",
    "fluid_language",
    "collapse_language",
    "device_language",
    "acute_finding_language",
    "chronic_finding_language",
]

concept_names = _finding_concepts + _language_concepts

NUM_CONCEPTS = len(concept_names)

In [None]:
def extract_text_concepts(report: str):
    """Derive weak binary concept labels from a report by keyword matching.

    Args:
        report: Free-text report. Anything that is not a ``str`` (e.g. ``None``
            or a pandas NaN for a missing report) is treated as an empty report.

    Returns:
        An integer tensor with one 0/1 entry per concept, in the same order as
        the 19 entries of ``concept_names``.

    Notes:
        * Most keywords are matched as substrings so inflected forms
          ("effusions", "collapsed", "bronchopneumonia") still count.
        * The short keywords "tube", "line" and "mass" are matched as whole
          words (optionally plural). As substrings they fire on unrelated
          words such as "midline", "baseline", "linear", "tuberculosis" and
          "massive".
        * Negation is NOT handled: "no pneumonia" still sets the infection
          concept.
    """
    r = report.lower() if isinstance(report, str) else ""

    def mentions(*stems):
        # Plain substring test: 1 if any stem occurs anywhere in the report.
        return int(any(stem in r for stem in stems))

    def mentions_word(*words):
        # Whole-word test allowing a trailing "s"/"es"; letters on either side
        # of the keyword disqualify the match.
        return int(any(
            re.search(rf"(?<![a-z]){re.escape(word)}(?:e?s)?(?![a-z])", r)
            for word in words
        ))

    return torch.tensor([
        mentions("consolidation", "airspace"),                      # alveolar_opacity
        mentions("interstitial"),                                   # interstitial_opacity
        mentions("focal"),                                          # focal_lung_opacity
        mentions("diffuse", "bilateral"),                           # diffuse_lung_opacity
        mentions("costophrenic"),                                   # blunted_costophrenic_angle
        mentions("air-fluid"),                                      # air_fluid_level
        mentions("hyperlucent", "lucency"),                         # hyperlucency
        mentions("volume loss", "collapse"),                        # volume_loss
        mentions("cardiomegaly"),                                   # enlarged_cardiac_silhouette
        mentions("mediastinal widening"),                           # mediastinal_widening
        mentions("fracture"),                                       # fracture_line
        int(mentions_word("tube", "line") or mentions("catheter")), # tube_or_line_present
        int(mentions_word("mass") or mentions("nodule")),           # lung_mass_or_nodule
        mentions("pneumonia", "infection"),                         # infection_language
        mentions("effusion", "fluid"),                              # fluid_language
        mentions("collapse", "atelectasis"),                        # collapse_language
        mentions("device", "support"),                              # device_language
        mentions("acute"),                                          # acute_finding_language
        mentions("chronic"),                                        # chronic_finding_language
    ])

In [None]:
class FrozenBioFuseEncoder(nn.Module):
    """Wrap a trained multimodal model and expose its fused embedding, frozen.

    All parameters of the wrapped model get ``requires_grad = False`` and the
    model is kept in evaluation mode for the lifetime of the wrapper.

    Args:
        model: Module exposing ``image_encoder``, ``text_encoder`` and
            ``fusion`` sub-modules, called as shown in ``forward``.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model
        for p in self.model.parameters():
            p.requires_grad = False
        self.model.eval()

    def train(self, mode: bool = True):
        """Set the wrapper's mode but always keep the wrapped model in eval mode.

        ``nn.Module.train`` recurses into children, so without this override a
        ``.train()`` call on the wrapper (or on any parent module holding it)
        would put the frozen model back in training mode and re-enable dropout
        and normalisation-statistics updates.
        """
        super().train(mode)
        self.model.eval()
        return self

    def forward(self, images, input_ids, attention_mask):
        """Return the fused image+text representation without tracking gradients.

        Args:
            images: Image batch passed to ``model.image_encoder``.
            input_ids: Token ids passed to ``model.text_encoder``.
            attention_mask: Attention mask passed to ``model.text_encoder``.

        Returns:
            Whatever ``model.fusion(img_feat, txt_feat)`` returns
            (presumably a ``(batch, embed_dim)`` tensor -- TODO confirm).
        """
        with torch.no_grad():
            img_feat = self.model.image_encoder(images)
            txt_feat = self.model.text_encoder(input_ids, attention_mask)
            fused = self.model.fusion(img_feat, txt_feat)
        return fused


In [None]:
class ConceptBottleneck(nn.Module):
    """Single linear layer followed by a sigmoid: embedding -> concept scores.

    Args:
        embed_dim: Size of the last dimension of the input embedding.
        num_concepts: Number of concept scores produced per input.
    """

    def __init__(self, embed_dim=768, num_concepts=NUM_CONCEPTS):
        super().__init__()
        self.fc = nn.Linear(in_features=embed_dim, out_features=num_concepts)

    def forward(self, z):
        """Return per-concept scores in (0, 1) for the embedding ``z``."""
        concept_logits = self.fc(z)
        return concept_logits.sigmoid()


In [None]:
class ConceptToDisease(nn.Module):
    """Linear surrogate mapping concept scores to one raw logit per label.

    Args:
        num_concepts: Length of the concept vector fed to the layer.
        num_labels: Number of output logits.
    """

    def __init__(self, num_concepts, num_labels):
        super().__init__()
        self.fc = nn.Linear(in_features=num_concepts, out_features=num_labels)

    def forward(self, concepts):
        """Return un-normalised label logits (no sigmoid is applied here)."""
        label_logits = self.fc(concepts)
        return label_logits


In [None]:
class ConceptDataset(torch.utils.data.Dataset):
    """View over a base dataset that adds tokenised text and keyword concepts.

    Args:
        base_dataset: Indexable dataset whose items are mappings with the keys
            ``"image"``, ``"text"`` and ``"label"``.
        tokenizer: Callable applied to the report text with
            ``padding="max_length"``, ``truncation=True``, ``max_length=128``
            and ``return_tensors="pt"``.
    """

    def __init__(self, base_dataset, tokenizer):
        self.ds = base_dataset
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        """Return image, token ids, attention mask, labels and concept targets."""
        record = self.ds[idx]
        report_text = record["text"]

        encoded = self.tokenizer(
            report_text,
            padding="max_length",
            truncation=True,
            max_length=128,
            return_tensors="pt",
        )

        sample = {"image": record["image"]}
        # The tokenizer output carries a leading dimension of size 1; drop it
        # so the DataLoader can stack samples into a batch.
        sample["input_ids"] = encoded["input_ids"].squeeze(0)
        sample["attention_mask"] = encoded["attention_mask"].squeeze(0)
        sample["labels"] = record["label"]
        sample["concepts"] = extract_text_concepts(report_text)
        return sample


In [None]:
def train_concept_bottleneck(encoder, cbm, dataloader, device, epochs=15):
    """Fit the concept head on top of a fixed encoder.

    Only ``cbm``'s parameters are optimised (Adam, lr 1e-3) with binary
    cross-entropy against the ``"concepts"`` targets of each batch. The mean
    batch loss is printed once per epoch.

    Args:
        encoder: Callable ``(image, input_ids, attention_mask) -> embedding``;
            put in eval mode here.
        cbm: Module mapping the embedding to concept probabilities.
        dataloader: Sized iterable of batches with the keys ``"image"``,
            ``"input_ids"``, ``"attention_mask"`` and ``"concepts"``.
        device: Device the batch tensors are moved to.
        epochs: Number of passes over ``dataloader``.
    """
    optimizer = torch.optim.Adam(cbm.parameters(), lr=1e-3)
    bce = nn.BCELoss()

    encoder.eval()
    cbm.train()

    for epoch_num in range(1, epochs + 1):
        running_loss = 0
        for batch in dataloader:
            images = batch["image"].to(device)
            token_ids = batch["input_ids"].to(device)
            token_mask = batch["attention_mask"].to(device)

            embedding = encoder(images, token_ids, token_mask)
            concept_probs = cbm(embedding)
            targets = batch["concepts"].float().to(device)
            batch_loss = bce(concept_probs, targets)

            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()
            running_loss += batch_loss.item()

        mean_loss = running_loss / len(dataloader)
        print(f"[CBM] Epoch {epoch_num} | Loss {mean_loss:.4f}")


In [None]:
def train_surrogate(encoder, cbm, surrogate, dataloader, device, epochs=10):
    """Fit the concept-to-label surrogate on top of a fixed encoder and concept head.

    Only ``surrogate``'s parameters are optimised (Adam, lr 1e-3) with
    ``BCEWithLogitsLoss`` against the ``"labels"`` of each batch. The MEAN
    batch loss is printed once per epoch, matching
    ``train_concept_bottleneck``; an epoch that sees no batches prints ``nan``
    instead of failing on an unbound ``loss``.

    Args:
        encoder: Callable ``(image, input_ids, attention_mask) -> embedding``;
            put in eval mode here.
        cbm: Module mapping the embedding to concept scores; put in eval mode
            and run without gradient tracking.
        surrogate: Module mapping concept scores to label logits; trained.
        dataloader: Iterable of batches with the keys ``"image"``,
            ``"input_ids"``, ``"attention_mask"`` and ``"labels"``.
        device: Device the batch tensors are moved to.
        epochs: Number of passes over ``dataloader``.
    """
    opt = torch.optim.Adam(surrogate.parameters(), lr=1e-3)
    loss_fn = nn.BCEWithLogitsLoss()

    encoder.eval()
    cbm.eval()
    surrogate.train()

    for ep in range(epochs):
        total = 0.0
        n_batches = 0
        for b in dataloader:
            # The concept head is fixed at this stage, so its forward pass
            # must not build an autograd graph.
            with torch.no_grad():
                z = encoder(
                    b["image"].to(device),
                    b["input_ids"].to(device),
                    b["attention_mask"].to(device),
                )
                c = cbm(z)

            logits = surrogate(c)
            loss = loss_fn(logits, b["labels"].float().to(device))

            opt.zero_grad()
            loss.backward()
            opt.step()

            total += loss.item()
            n_batches += 1

        # Batches are counted by hand so iterables without __len__ work too.
        mean_loss = total / n_batches if n_batches else float("nan")
        print(f"[SUR] Epoch {ep+1} | Loss {mean_loss:.4f}")


In [None]:
# `train_dataset` was referenced here without ever being created in this
# notebook, so the cell failed with NameError on a fresh kernel. It is built the
# same way as `test_dataset`.
# NOTE(review): assumes MIMIC_MultiModalDataset accepts mode='train' -- confirm
# against data_preprocessing.
train_dataset = MIMIC_MultiModalDataset(
    image_dir=image_dir,
    report_dir=report_dir,
    mode='train', transform=transform
    )

concept_ds = ConceptDataset(train_dataset, tokenizer)
loader = torch.utils.data.DataLoader(concept_ds, batch_size=16, shuffle=True)

# Frozen multimodal encoder -> concept head -> linear surrogate over concepts.
encoder = FrozenBioFuseEncoder(model).to(device)
cbm = ConceptBottleneck().to(device)
surrogate = ConceptToDisease(NUM_CONCEPTS,num_labels=14).to(device)

# Stage 1 fits the concept head; stage 2 fits the surrogate on its outputs.
train_concept_bottleneck(encoder, cbm, loader, device)
train_surrogate(encoder, cbm, surrogate, loader, device)


In [None]:
# Batches of 32 over the test split, in dataset order (no shuffling).
test_loader = DataLoader(dataset=test_dataset, batch_size=32, shuffle=False)

In [None]:
def _index_from_metadata(dataset, sample_id):
    """Return the position of ``sample_id`` in ``dataset.data["id"]``, or None."""
    metadata = getattr(dataset, "data", None)
    try:
        ids = list(metadata["id"])
    except (TypeError, KeyError, IndexError, AttributeError):
        # No metadata table, or it has no "id" column.
        return None
    try:
        return ids.index(sample_id)
    except ValueError:
        return None


def get_sample_by_id(dataset, sample_id):
    """Return the dataset item whose ``"id"`` field equals ``sample_id``.

    Indexing the dataset materialises a full sample, so scanning
    ``dataset[0], dataset[1], ...`` until a match is found is expensive. When
    the dataset exposes a ``data`` table with an ``"id"`` column, the matching
    row position is looked up there first and only that one sample is loaded.
    The loaded sample's id is verified, and the original linear scan is used
    as a fallback if the metadata is missing or not aligned with the dataset
    order.

    Args:
        dataset: Sized, indexable dataset whose items are mappings with an
            ``"id"`` key.
        sample_id: Identifier to look for.

    Returns:
        The matching item.

    Raises:
        ValueError: If no item has the requested id.
    """
    candidate = _index_from_metadata(dataset, sample_id)
    if candidate is not None and 0 <= candidate < len(dataset):
        sample = dataset[candidate]
        if sample["id"] == sample_id:
            return sample

    for i in range(len(dataset)):
        sample = dataset[i]
        if sample["id"] == sample_id:
            return sample
    raise ValueError(f"Sample with id {sample_id} not found.")


In [None]:
# `concept_ds` and `loader` are already built, with identical arguments, in the
# training cell earlier in this notebook. They are reused as-is rather than
# constructed a second time, which also removes this cell's own dependency on
# `train_dataset`.
encoder1 = FrozenBioFuseEncoder(model).to(device)

In [None]:
from torchvision import transforms

# Same two steps as `transform`: resize to 224x224, then convert to a tensor.
preprocess = transforms.Compose(
    [transforms.Resize(size=(224, 224)), transforms.ToTensor()]
)


In [None]:
def explain_sample(image, text, tokenizer, encoder, cbm, device):
    """Score every concept for one (image, report) pair.

    Args:
        image: Image tensor passed to ``encoder`` after being moved to
            ``device`` (assumes a leading batch dimension of 1 -- TODO confirm
            against the caller).
        text: Report text; tokenised as a one-element batch, padded/truncated
            to 128 tokens.
        tokenizer: Tokenizer whose output supports ``.to(device)`` and the keys
            ``"input_ids"`` / ``"attention_mask"``.
        encoder: Callable ``(image, input_ids, attention_mask) -> embedding``.
        cbm: Module mapping the embedding to concept scores.
        device: Device used for the forward pass.

    Returns:
        Dict mapping each name in ``concept_names`` to its score as a Python
        float.
    """
    enc = tokenizer([text], padding='max_length', truncation=True, max_length=128, return_tensors='pt').to(device)

    # Pure inference: skip building an autograd graph through the concept head.
    with torch.no_grad():
        z = encoder(
            image.to(device),
            enc['input_ids'],
            enc['attention_mask']
        )
        c = cbm(z).squeeze(0)

    # One device-to-host transfer instead of one scalar read per concept.
    scores = c.tolist()
    concept_dict = {name: float(scores[i]) for i, name in enumerate(concept_names)}

    return concept_dict


In [None]:
def tensor_to_pil(img_tensor):
    """Convert a channel-first float image tensor to a PIL image.

    Args:
        img_tensor: ``(C, H, W)`` tensor, or ``(1, C, H, W)`` whose leading
            dimension is dropped. Values are interpreted on a 0..1 scale.

    Returns:
        ``PIL.Image.Image`` built from the uint8 version of the tensor.

    Compared with a bare ``(x * 255).astype(uint8)``:
        * values are rounded to the nearest integer rather than truncated, so
          tiny float error cannot turn ``k`` into ``k - 1``;
        * values outside 0..1 are clipped instead of wrapping around;
        * a single-channel input is passed to PIL as a 2-D ``(H, W)`` array,
          because ``Image.fromarray`` does not accept ``(H, W, 1)``.
    """
    if img_tensor.dim() == 4:
        img_tensor = img_tensor.squeeze(0)
    img_tensor = img_tensor.detach().cpu()
    # Channel-first (C, H, W) -> channel-last (H, W, C), which PIL expects.
    img_np = img_tensor.permute(1, 2, 0).numpy()
    img_np = np.clip(np.rint(img_np * 255), 0, 255).astype(np.uint8)
    if img_np.shape[-1] == 1:
        img_np = img_np[..., 0]
    return Image.fromarray(img_np)


In [None]:
def get_concept_explanation(image_tensor, text, encoder, cbm, tokenizer, device):
    """Score every concept for one (image, report) pair.

    This is the same computation as ``explain_sample`` with a different
    parameter order, so it forwards to that function instead of keeping a
    second copy of the logic. Arguments are passed by keyword because the two
    signatures order ``tokenizer`` / ``encoder`` / ``cbm`` differently.

    Args:
        image_tensor: Image tensor for the encoder.
        text: Report text.
        encoder: Callable ``(image, input_ids, attention_mask) -> embedding``.
        cbm: Module mapping the embedding to concept scores.
        tokenizer: Tokenizer for the report text.
        device: Device used for the forward pass.

    Returns:
        Dict mapping each concept name to its score as a Python float.
    """
    return explain_sample(
        image=image_tensor,
        text=text,
        tokenizer=tokenizer,
        encoder=encoder,
        cbm=cbm,
        device=device,
    )


In [None]:
# Test samples to explain. Each id has the form "sub_<number>_idx_<number>";
# the two numeric parts are listed as pairs below.
_sample_keys = [
    (11984732, 1121),
    (11270948, 1479),
    (11714071, 396),
    (11667471, 1073),
]
sample_ids = [f"sub_{subject}_idx_{index}" for subject, index in _sample_keys]


In [None]:
# Collect, for every requested test sample, its concept scores and the five
# highest-scoring concepts.
results = []

for sample_id in sample_ids:
    try:
        sample = get_sample_by_id(test_dataset, sample_id)
        sample_image = sample["image"]
        sample_text  = sample["text"]

        # Round-trip through PIL so the image goes through `preprocess`
        # before a batch dimension of 1 is added.
        pil_image = tensor_to_pil(sample_image)
        input_image = preprocess(pil_image).unsqueeze(0).to(device)

        # The same inference used to run twice (explain_sample followed by
        # get_concept_explanation, identical inputs) with the first result
        # discarded; one call is enough. no_grad: this is inference only.
        with torch.no_grad():
            concept_dict = get_concept_explanation(
                image_tensor=input_image,
                text=sample_text,
                encoder=encoder,
                cbm=cbm,
                tokenizer=tokenizer,
                device=device
            )

        top_concepts = sorted(concept_dict.items(), key=lambda x: x[1], reverse=True)[:5]
        # "<br>" separators: the string is meant for HTML rendering.
        top_concepts_str = "<br>".join([f"{k}: {v:.2f}" for k, v in top_concepts])
        results.append({
            "id": sample_id,
            "text": sample_text,
            "concept_dict":concept_dict,
            "top_concepts": top_concepts,
            "top_concepts_str": top_concepts_str
        })

        print(f"Processed {sample_id}: {top_concepts_str}")
        print(concept_dict)

    except Exception as e:
        # Best effort: report the failing sample and continue with the rest.
        print(f"Error processing {sample_id}: {e}")
