In [None]:
# Mount Google Drive at /content/drive; the dataset folders and saved adapters
# referenced later in this notebook all live under /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')


In [None]:
import os
# Export both CUDA debugging switches in one call, before torch is imported.
# NOTE(review): these look like debugging aids (presumably to surface device-side
# errors at the failing call) - confirm they are still wanted for full runs.
os.environ.update(
    CUDA_LAUNCH_BLOCKING="1",
    TORCH_USE_CUDA_DSA="1",
)

In [None]:
!git clone https://github.com/microsoft/LLaVA-Med.git


In [None]:
%cd LLaVA-Med

In [None]:
# Dependency setup: refresh the packaging tools, install the LLaVA-Med repo in
# editable mode, then pin numpy / torch / torchvision / transformers / peft / accelerate.
# NOTE(review): the explicit pins run after `pip install -e .`, so they presumably
# override whatever versions the editable install resolved - confirm this order is intended.
!pip install --upgrade pip setuptools wheel
!pip install -e .
!pip install numpy==1.26.4
!pip install torch==2.1.2 torchvision==0.16.2 --index-url https://download.pytorch.org/whl/cu118
!pip install transformers==4.36.2 peft==0.4.0 accelerate==0.21.0


In [None]:
%cd LLaVA-Med
!mkdir -p checkpoints
!huggingface-cli download microsoft/llava-med-v1.5-mistral-7b --local-dir checkpoints/llava-med-v1.5-mistral-7b --local-dir-use-symlinks False


# Load Base Model

In [None]:
%cd LLaVA-Med

In [None]:
from torch.utils.data import Dataset
from PIL import Image
from llava.conversation import conv_templates
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
from llava.mm_utils import tokenizer_image_token, get_model_name_from_path
from transformers import TrainingArguments, Trainer
from peft import LoraConfig, get_peft_model
from llava.model.builder import load_pretrained_model
import pandas as pd
import os
import torch

# Load the LLaVA-Med v1.5 checkpoint that was downloaded into checkpoints/.
tokenizer, base_model, image_processor, context_len = load_pretrained_model(
    model_path="checkpoints/llava-med-v1.5-mistral-7b",
    model_base=None,
    model_name="llava-med-v1.5-mistral-7b"
)

# Batch padding needs a pad token; reuse the unknown token when none is configured.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.unk_token

# LoRA adapters on every attention and MLP projection; bias="all" asks PEFT to
# treat the bias terms as trainable too.
lora_config = LoraConfig(
    r=16,
    lora_alpha=64,
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    lora_dropout=0.05,
    bias="all",
    task_type="CAUSAL_LM"
)

model = get_peft_model(base_model, lora_config)

# Single pass over the parameters:
#   1. unfreeze every parameter whose name contains "norm" or "ln";
#   2. keep every trainable tensor in float32 (presumably required because the
#      Trainer below runs with fp16=True - confirm).
# The earlier version tried to exclude names ending in ".lora_A" / ".lora_B" from the
# up-cast, but names yielded by named_parameters() always end with the parameter
# attribute itself (e.g. "...weight"), so that filter never matched and the LoRA
# weights were up-cast as well. The dead filter is removed; behaviour is unchanged.
for name, param in model.named_parameters():
    if any(tag in name.lower() for tag in ("norm", "ln")):
        param.requires_grad = True
    if param.requires_grad:
        param.data = param.data.to(torch.float32)

model.print_trainable_parameters()


# Load PEFT Model

In [None]:
# @title Load a previously saved PEFT adapter
from torch.utils.data import Dataset
from PIL import Image
from llava.conversation import conv_templates
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
from llava.mm_utils import tokenizer_image_token, get_model_name_from_path
from transformers import TrainingArguments, Trainer, AutoTokenizer
from peft import LoraConfig, get_peft_model, PeftConfig, PeftModel
from llava.model.builder import load_pretrained_model
import pandas as pd
import os
import torch

# Adapter checkpoint (saved earlier with save_pretrained) to resume from.
model_file_path = "/content/drive/MyDrive/lora_output/model (BitFit 2 Epoch)"

# The adapter config records which base checkpoint it was trained on.
peft_config = PeftConfig.from_pretrained(model_file_path)
model_name = get_model_name_from_path(peft_config.base_model_name_or_path)

tokenizer, base_model, image_processor, context_len = load_pretrained_model(
    model_path=peft_config.base_model_name_or_path,
    model_base=None,
    model_name=model_name,
    load_8bit=False,
    load_4bit=False,
)

# Same pad-token fallback as the "Load Base Model" cell; without it collate_fn
# would be handed a missing pad id when this cell is the one that was run.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.unk_token

# is_trainable=True: PeftModel.from_pretrained loads adapters frozen (inference
# mode) by default. This cell goes on to prepare parameters for training, so the
# adapter weights must stay trainable or only the norm layers would be updated.
model = PeftModel.from_pretrained(base_model, model_file_path, is_trainable=True)

# Unfreeze parameters whose name contains "norm" or "ln" and keep every trainable
# tensor in float32 (presumably required by fp16=True training - confirm).
# The former ".lora_A" / ".lora_B" name filter never matched a named_parameters()
# name (those end in e.g. "...weight"), so it was dead code and has been removed.
for name, param in model.named_parameters():
    if any(tag in name.lower() for tag in ("norm", "ln")):
        param.requires_grad = True
    if param.requires_grad:
        param.data = param.data.to(torch.float32)


# Training

In [None]:
class MultiClassChestXrayDataset(Dataset):
    """Chest X-ray classification samples rendered as single-turn LLaVA conversations.

    Each item pairs one image with a prompt that embeds the patient's age and sex
    and asks for one of seven conditions; the target text is the condition name.

    Args:
        image_dir: Directory holding the image files named in ``csv_path``.
        csv_path: CSV whose first column is the image filename and whose second
            column is the integer class label (both are read positionally).
        metadata_csv: CSV with ``Image Index``, ``Patient Age`` and ``Patient Sex``
            columns; rows are looked up by filename.
        tokenizer: Tokenizer handed to ``tokenizer_image_token``.
        image_processor: Object exposing ``preprocess(image, return_tensors="pt")``.
    """

    # Label value excluded from the loss.
    IGNORE_INDEX = -100

    def __init__(self, image_dir, csv_path, metadata_csv, tokenizer, image_processor):
        self.data = pd.read_csv(csv_path)
        self.metadata = pd.read_csv(metadata_csv).set_index("Image Index")[["Patient Age", "Patient Sex"]].to_dict(orient="index")
        self.image_dir = image_dir
        self.tokenizer = tokenizer
        self.image_processor = image_processor

        self.label_to_condition = {
            0: "No Finding",
            1: "Pneumothorax",
            2: "Cardiomegaly",
            3: "Infiltration",
            4: "Effusion",
            5: "Emphysema",
            6: "Atelectasis"
        }

        # Keep only rows whose image file actually exists in image_dir.
        self.data = self.data[self.data.iloc[:, 0].apply(lambda f: os.path.exists(os.path.join(image_dir, f)))].reset_index(drop=True)
        print(f"Dataset initialized with {len(self.data)} valid samples")

    def __len__(self):
        """Return the number of samples that survived the file-existence filter."""
        return len(self.data)

    def _render(self, question, answer):
        """Render the conversation; ``answer=None`` yields the prompt without a reply."""
        conv = conv_templates["mistral_instruct"].copy()
        conv.append_message(conv.roles[0], DEFAULT_IMAGE_TOKEN + "\n" + question)
        conv.append_message(conv.roles[1], answer)
        return conv.get_prompt()

    def _tokenize(self, text):
        """Tokenize rendered text (image placeholder included) into a 1-D tensor."""
        return tokenizer_image_token(text, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').flatten()

    def _blank_image(self):
        """Return an all-zero image shaped like the processor's output.

        The size comes from the processor's ``crop_size`` when it exposes one, so the
        placeholder can be stacked with real images in a batch; 224 is kept as the
        last-resort default.
        """
        size = getattr(self.image_processor, "crop_size", None)
        if isinstance(size, dict):
            height, width = size.get("height", 224), size.get("width", 224)
        elif isinstance(size, int):
            height = width = size
        else:
            height = width = 224
        return torch.zeros((3, height, width))

    def __getitem__(self, idx):
        """Build one training example.

        Returns:
            dict with 1-D ``input_ids``, 1-D ``labels`` (prompt positions set to
            ``IGNORE_INDEX``) and ``pixel_values`` for the image.
        """
        filename = self.data.iloc[idx, 0]
        label = int(self.data.iloc[idx, 1])

        meta = self.metadata.get(filename, {})
        age = meta.get('Patient Age', 'unknown')
        sex = meta.get('Patient Sex', 'unknown')

        question = f"You are an expert radiologist. This chest X-ray is from a {age}-year-old {sex} patient. What condition does this chest X-ray show? Choose from: 'No Finding', 'Pneumothorax', 'Cardiomegaly', 'Infiltration', 'Effusion', 'Emphysema', or 'Atelectasis'."
        answer = self.label_to_condition[label]

        input_ids = self._tokenize(self._render(question, answer))

        # Supervise only the answer. The prompt length is measured by tokenizing the
        # same conversation without the reply. The previous approach searched for the
        # first token of the role name, which is not guaranteed to occur in the
        # rendered prompt; when it was absent the whole prompt stayed supervised.
        # Assumes the prompt-only tokens are a prefix of the full sequence - TODO confirm
        # for the tokenizer in use.
        prompt_len = self._tokenize(self._render(question, None)).numel()
        labels = input_ids.clone()
        # Clamp so at least the final token stays supervised even if the prefix
        # assumption above does not hold exactly.
        labels[:min(prompt_len, labels.numel() - 1)] = self.IGNORE_INDEX

        image_path = os.path.join(self.image_dir, filename)
        try:
            image = Image.open(image_path).convert("RGB")
            pixel_values = self.image_processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
        except Exception as e:
            # Best effort: keep the sample but feed a blank image of matching size.
            print(f"Error loading image {filename}: {e}")
            pixel_values = self._blank_image()

        return {
            "input_ids": input_ids,
            "labels": labels,
            "pixel_values": pixel_values,
        }

In [None]:
# Both splits live under one Drive folder and share a single metadata CSV;
# only the image sub-folder and the label CSV differ per split.
_data_root = "/content/drive/MyDrive/preprocessed_pneumothorax_open_classification"
_metadata_csv = f"{_data_root}/pneumothorax_combined_open_classification.csv"

train_dataset = MultiClassChestXrayDataset(
    image_dir=f"{_data_root}/train",
    csv_path=f"{_data_root}/train.csv",
    metadata_csv=_metadata_csv,
    tokenizer=tokenizer,
    image_processor=image_processor,
)

val_dataset = MultiClassChestXrayDataset(
    image_dir=f"{_data_root}/val",
    csv_path=f"{_data_root}/val.csv",
    metadata_csv=_metadata_csv,
    tokenizer=tokenizer,
    image_processor=image_processor,
)

def collate_fn(instances):
    """Pad a list of dataset items into one batch.

    Args:
        instances: Non-empty list of dicts with 1-D ``input_ids`` and ``labels``
            tensors and an image tensor under ``pixel_values`` (all images must
            share one shape so they can be stacked).

    Returns:
        dict with right-padded ``input_ids`` and ``labels`` (labels padded with
        -100), stacked ``images`` and a boolean ``attention_mask``.
    """
    input_ids = [instance['input_ids'] for instance in instances]
    labels = [instance['labels'] for instance in instances]
    images = torch.stack([instance['pixel_values'] for instance in instances], dim=0)

    # Fall back to the unknown-token id (then 0) when the tokenizer has no pad id,
    # instead of failing inside pad_sequence.
    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        pad_id = getattr(tokenizer, "unk_token_id", None)
    if pad_id is None:
        pad_id = 0

    # Record true lengths before padding: the mask is derived from them rather than
    # from `input_ids != pad_id`, which would also hide genuine tokens that happen
    # to share the pad id.
    lengths = torch.tensor([ids.shape[0] for ids in input_ids])

    input_ids = torch.nn.utils.rnn.pad_sequence(
        input_ids, batch_first=True, padding_value=pad_id
    )
    labels = torch.nn.utils.rnn.pad_sequence(
        labels, batch_first=True, padding_value=-100
    )
    positions = torch.arange(input_ids.shape[1])
    attention_mask = positions.unsqueeze(0) < lengths.unsqueeze(1)

    return {
        'input_ids': input_ids,
        'labels': labels,
        'images': images,
        'attention_mask': attention_mask,
    }

In [None]:
from transformers import Trainer

class LlavaTrainer(Trainer):
    """Trainer variant that passes the collated ``images`` tensor to the model."""

    def compute_loss(self, model, inputs, return_outputs=False):
        """Run one forward pass and return the model's loss.

        Args:
            model: Model called with ``input_ids``, ``attention_mask``, ``labels``
                and ``images`` keyword arguments.
            inputs: Batch dict as produced by ``collate_fn``.
            return_outputs: When true, return ``(loss, outputs)`` instead of the loss.

        Raises:
            Exception: Whatever the forward pass raised, re-raised unchanged after
                the batch shapes have been printed.
        """
        # Read rather than pop: the batch dict stays intact for the caller, and the
        # image shape is still available to the diagnostics below. (With pop, the
        # `'images' in inputs` check in the error path could never be true.)
        images = inputs.get("images", None)

        try:
            outputs = model(
                input_ids=inputs['input_ids'],
                attention_mask=inputs.get('attention_mask', None),
                labels=inputs['labels'],
                images=images
            )
        except Exception as e:
            print(f"Error in compute_loss: {e}")
            print(f"Input shapes: input_ids={inputs['input_ids'].shape}, labels={inputs['labels'].shape}")
            if images is not None:
                print(f"images shape: {images.shape}")
            # Bare raise keeps the original traceback.
            raise

        loss = outputs.loss
        return (loss, outputs) if return_outputs else loss


In [None]:
import os

# Re-export the CUDA debugging switches set in the first cells.
# NOTE(review): these presumably need to be set before CUDA is first used to have
# any effect - confirm this repeat is still needed.
os.environ.update(
    CUDA_LAUNCH_BLOCKING="1",
    TORCH_USE_CUDA_DSA="1",
)

training_args = TrainingArguments(
    # Output locations and reporting.
    output_dir="./lora_output",
    logging_dir="./logs",
    report_to="none",
    # Optimisation: one epoch, fp16 mixed precision, 50 warm-up steps.
    num_train_epochs=1,
    learning_rate=2e-4,
    weight_decay=0.01,
    warmup_steps=50,
    fp16=True,
    # Batching: 2 samples per device, gradients accumulated over 4 steps.
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=4,
    group_by_length=False,
    # Step-based cadence: log every 10, evaluate every 50, checkpoint every 100
    # steps, keeping at most 2 checkpoints.
    logging_strategy="steps",
    logging_steps=10,
    evaluation_strategy="steps",
    eval_steps=50,
    save_strategy="steps",
    save_steps=100,
    save_total_limit=2,
    # Data pipeline. NOTE(review): remove_unused_columns=False is presumably what
    # lets the extra `images` key from collate_fn reach compute_loss - confirm.
    remove_unused_columns=False,
    dataloader_pin_memory=False,
)

trainer = LlavaTrainer(
    model=model,
    args=training_args,
    data_collator=collate_fn,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    tokenizer=tokenizer,
)

trainer.train()


In [None]:
# Store the trained adapter and its tokenizer side by side in one Drive folder.
_final_adapter_dir = "/content/drive/MyDrive/lora_open_output/model (Open 1 Epoch)"
model.save_pretrained(_final_adapter_dir)
tokenizer.save_pretrained(_final_adapter_dir)


# Inference

In [None]:
import os
import random
import pandas as pd
from PIL import Image
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, classification_report
from llava.mm_utils import tokenizer_image_token
from llava.conversation import conv_templates
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN

# Class id -> display name. This is the single source of truth for the label set.
label_to_condition = {
    0: "No Finding",
    1: "Pneumothorax",
    2: "Cardiomegaly",
    3: "Infiltration",
    4: "Effusion",
    5: "Emphysema",
    6: "Atelectasis"
}

# Lower-cased name -> class id, derived from the table above so the two mappings
# cannot drift apart.
condition_to_label = {name.lower(): label for label, name in label_to_condition.items()}


def parse_response_to_label(response):
    """Map a free-text model response to a class id.

    Matching is case-insensitive and ignores whitespace. An exact match wins;
    otherwise the condition mentioned earliest in the response is chosen.
    (Previously the first match in dictionary order won, so a response naming
    several conditions was resolved by table order rather than by its content.)

    Args:
        response: Text produced by the model.

    Returns:
        The class id (0-6), or -1 when no known condition is mentioned or the
        input is not a string.
    """
    if not isinstance(response, str):
        return -1

    response_lower = response.lower().strip()
    if response_lower in condition_to_label:
        return condition_to_label[response_lower]

    # Drop all whitespace so "NoFinding" or "no  finding" still match.
    compact = "".join(response_lower.split())

    best_label = -1
    best_pos = len(compact) + 1
    for condition, label in condition_to_label.items():
        pos = compact.find(condition.replace(" ", ""))
        if pos != -1 and pos < best_pos:
            best_label, best_pos = label, pos

    return best_label

def generate_response(image_path, metadata=None):
    """Ask the model which condition a chest X-ray shows and return its text answer.

    Args:
        image_path: Path of the image file to classify.
        metadata: Optional dict with ``Patient Age`` / ``Patient Sex``; missing
            values are rendered as ``unknown`` in the prompt.

    Returns:
        The stripped generated text, or the literal string
        ``"Error generating response"`` when anything in the pipeline raises.
        Callers treat that string like any other unparseable answer.

    Note:
        Generation samples (``do_sample=True``, temperature 0.2), so repeated calls
        on the same image may differ.
    """
    try:
        image = Image.open(image_path).convert('RGB')
        image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values']

        # One device lookup shared by the image and the token ids.
        device = getattr(model, 'device', 'cuda')
        image_tensor = image_tensor.to(device, dtype=torch.float32)

        age = metadata.get('Patient Age', 'unknown') if metadata else 'unknown'
        sex = metadata.get('Patient Sex', 'unknown') if metadata else 'unknown'

        # Kept character-for-character identical to the training prompt built in
        # MultiClassChestXrayDataset.__getitem__ (class names in single quotes); the
        # unquoted variant used here before did not match what the model was tuned on.
        question = f"You are an expert radiologist. This chest X-ray is from a {age}-year-old {sex} patient. What condition does this chest X-ray show? Choose from: 'No Finding', 'Pneumothorax', 'Cardiomegaly', 'Infiltration', 'Effusion', 'Emphysema', or 'Atelectasis'."

        conv = conv_templates["mistral_instruct"].copy()
        conv.append_message(conv.roles[0], DEFAULT_IMAGE_TOKEN + "\n" + question)
        conv.append_message(conv.roles[1], None)
        prompt_formatted = conv.get_prompt()

        input_ids = tokenizer_image_token(prompt_formatted, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt')
        if input_ids.dim() == 1:
            input_ids = input_ids.unsqueeze(0)
        input_ids = input_ids.to(device)

        with torch.inference_mode():
            output_ids = model.base_model.generate(
                input_ids,
                images=image_tensor,
                do_sample=True,
                temperature=0.2,
                top_p=0.9,
                max_new_tokens=64,
                use_cache=True,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )

        # Strip the prompt only when the output really starts with it. Comparing
        # lengths alone (as before) would cut genuine output whenever generate()
        # returns only new tokens and they outnumber the prompt tokens.
        prompt_len = input_ids.shape[1]
        echoes_prompt = (
            output_ids.shape[1] > prompt_len
            and torch.equal(output_ids[0, :prompt_len], input_ids[0])
        )

        if echoes_prompt:
            outputs = tokenizer.batch_decode(output_ids[:, prompt_len:], skip_special_tokens=True)[0]
        else:
            outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
            # Text-level fallback in case the question was decoded back into the output.
            if question in outputs:
                outputs = outputs.split(question)[-1]

        return outputs.strip()

    except Exception as e:
        print(f"Error in generate_response: {e}")
        return "Error generating response"

In [None]:
from llava.model.builder import load_pretrained_model
from peft import PeftModel

# Attach the adapter saved by the training section to `base_model` and switch to
# eval mode for inference.
# NOTE(review): if get_peft_model / PeftModel.from_pretrained already ran on this
# `base_model` earlier in the session, it may already carry adapter layers -
# confirm that reloading onto the same object is intended, or reload the base
# model first in a fresh session.
model = PeftModel.from_pretrained(base_model, "/content/drive/MyDrive/lora_open_output/model (Open 1 Epoch)")
model.eval()

In [None]:
import torch

# Test-set evaluation: generate one answer per image, parse it into a class id
# and accumulate ground-truth / predicted labels for the metrics computed afterwards.
test_folder = "/content/drive/MyDrive/preprocessed_pneumothorax_open_classification/test"
csv_path = f"/content/drive/MyDrive/preprocessed_pneumothorax_open_classification/test.csv"
metadata_csv_path = "/content/drive/MyDrive/preprocessed_pneumothorax_open_classification/pneumothorax_combined_open_classification.csv"

# First CSV column is the filename, second the label (read positionally).
# Building a dict keeps one label per filename.
df = pd.read_csv(csv_path)
filename_to_label = dict(zip(df.iloc[:, 0], df.iloc[:, 1]))
metadata_df = pd.read_csv(metadata_csv_path)
metadata_lookup = metadata_df.set_index("Image Index")[["Patient Age", "Patient Sex"]].to_dict(orient="index")

true_labels, predicted_labels = [], []
parse_failures = 0
# Fixed seed so the evaluation order is the same on every run. Only Python's
# `random` is seeded here; generation itself samples (see generate_response).
shuffled_items = list(filename_to_label.items())
random.seed(42)
random.shuffle(shuffled_items)
# Cast the whole model to float32, matching the float32 image tensors that
# generate_response feeds in.
model = model.float().eval()

for i, (filename, label) in enumerate(shuffled_items):
    image_path = os.path.join(test_folder, filename)
    # Files listed in the CSV but missing on disk are skipped silently.
    if not os.path.exists(image_path):
        continue

    metadata = metadata_lookup.get(filename)
    try:
        response = generate_response(image_path, metadata)
        predicted_label = parse_response_to_label(response)

        # Unparseable answers (including the error string returned by
        # generate_response) are counted but left out of every metric below.
        # NOTE(review): accuracy is therefore computed over parsed answers only.
        if predicted_label == -1:
            parse_failures += 1
            print(f"Failed to parse response: '{response}' for {filename}")
            continue

        true_labels.append(int(label))
        predicted_labels.append(predicted_label)

        print(f"Filename       : {filename}")
        print(f"Ground Truth   : {label_to_condition[int(label)]}")
        print(f"Model Response : {response}")
        print(f"Parsed Label   : {label_to_condition[predicted_label]}")
        print("-" * 60)

        # Running summary whenever the loop index hits a multiple of 20. `i` counts
        # every CSV entry, so skipped or unparsed items advance it too, and an
        # iteration that hit `continue` above prints no summary.
        if (i + 1) % 20 == 0:
            # Plain accuracy over the valid predictions collected so far.
            total = len(true_labels)
            correct = sum([1 for t, p in zip(true_labels, predicted_labels) if t == p])
            accuracy = correct / total if total > 0 else 0

            # Macro-averaged metrics over the labels seen so far.
            precision_macro = precision_score(true_labels, predicted_labels, average='macro', zero_division=0)
            recall_macro = recall_score(true_labels, predicted_labels, average='macro', zero_division=0)
            f1_macro = f1_score(true_labels, predicted_labels, average='macro', zero_division=0)

            print(f"=== INTERMEDIATE RESULTS ({i + 1} samples processed, {total} valid) ===")
            print(f"Parse failures: {parse_failures}")
            print(f"Accuracy: {accuracy:.4f} ({accuracy * 100:.2f}%)")
            print(f"Macro Precision: {precision_macro:.4f}")
            print(f"Macro Recall: {recall_macro:.4f}")
            print(f"Macro F1-Score: {f1_macro:.4f}")
            print("=" * 60)

    except Exception as e:
        # Keep the run going on unexpected per-sample errors.
        print(f"Error processing {filename}: {e}")
        continue


from sklearn.metrics import (
    confusion_matrix, precision_score, recall_score, f1_score,
    balanced_accuracy_score, matthews_corrcoef, classification_report
)

# Final report over every parsed prediction collected by the evaluation loop.
if true_labels and predicted_labels:
    # The seven class ids, passed explicitly wherever per-class output has to line
    # up with the class names, even when some class never occurs in the results.
    class_ids = list(range(7))

    total = len(true_labels)
    correct = sum(1 for t, p in zip(true_labels, predicted_labels) if t == p)
    accuracy = correct / total

    precision_macro = precision_score(true_labels, predicted_labels, average='macro', zero_division=0)
    precision_micro = precision_score(true_labels, predicted_labels, average='micro', zero_division=0)
    precision_weighted = precision_score(true_labels, predicted_labels, average='weighted', zero_division=0)

    recall_macro = recall_score(true_labels, predicted_labels, average='macro', zero_division=0)
    recall_micro = recall_score(true_labels, predicted_labels, average='micro', zero_division=0)
    recall_weighted = recall_score(true_labels, predicted_labels, average='weighted', zero_division=0)

    f1_macro = f1_score(true_labels, predicted_labels, average='macro', zero_division=0)
    f1_micro = f1_score(true_labels, predicted_labels, average='micro', zero_division=0)
    f1_weighted = f1_score(true_labels, predicted_labels, average='weighted', zero_division=0)

    balanced_acc = balanced_accuracy_score(true_labels, predicted_labels)
    mcc = matthews_corrcoef(true_labels, predicted_labels)

    print("\n=== FINAL ACCURACY RESULTS ===")
    # This count is every entry in the test CSV, including files that were missing
    # on disk and answers that could not be parsed.
    print(f"Total samples processed: {len(shuffled_items)}")
    print(f"Valid predictions: {total}")
    print(f"Parse failures: {parse_failures}")
    print(f"Correct predictions: {correct}")
    print(f"Accuracy: {accuracy:.4f} ({accuracy * 100:.2f}%)")

    print("\n=== CONFUSION MATRIX ===")
    cm = confusion_matrix(true_labels, predicted_labels, labels=class_ids)
    print("Confusion Matrix (rows=true, cols=predicted):")
    print("Classes: 0=No Finding, 1=Pneumothorax, 2=Cardiomegaly, 3=Infiltration, 4=Effusion, 5=Emphysema, 6=Atelectasis")
    print(cm)

    print("\n=== METRICS ===")
    print(f"Precision (Macro): {precision_macro:.4f}")
    print(f"Precision (Micro): {precision_micro:.4f}")
    print(f"Precision (Weighted): {precision_weighted:.4f}")
    print(f"Recall (Macro): {recall_macro:.4f}")
    print(f"Recall (Micro): {recall_micro:.4f}")
    print(f"Recall (Weighted): {recall_weighted:.4f}")
    print(f"F1-Score (Macro): {f1_macro:.4f}")
    print(f"F1-Score (Micro): {f1_micro:.4f}")
    print(f"F1-Score (Weighted): {f1_weighted:.4f}")
    print(f"Balanced Accuracy: {balanced_acc:.4f}")
    print(f"MCC: {mcc:.4f}")

    print("\n=== CLASS-WISE STATISTICS ===")

    from collections import Counter
    true_counts = Counter(true_labels)
    pred_counts = Counter(predicted_labels)

    # Both lists are non-empty inside this branch, so the divisions below are safe.
    print("True Label Distribution:")
    for label in class_ids:
        count = true_counts.get(label, 0)
        percentage = (count / len(true_labels)) * 100
        print(f"  {label_to_condition[label]}: {count} samples ({percentage:.1f}%)")

    print("\nPredicted Label Distribution:")
    for label in class_ids:
        count = pred_counts.get(label, 0)
        percentage = (count / len(predicted_labels)) * 100
        print(f"  {label_to_condition[label]}: {count} samples ({percentage:.1f}%)")

    print("\nPer-Class Performance:")
    precision_per_class = precision_score(true_labels, predicted_labels, average=None, zero_division=0, labels=class_ids)
    recall_per_class = recall_score(true_labels, predicted_labels, average=None, zero_division=0, labels=class_ids)
    f1_per_class = f1_score(true_labels, predicted_labels, average=None, zero_division=0, labels=class_ids)

    for label in class_ids:
        # One-vs-rest counts for this class.
        pairs = list(zip(true_labels, predicted_labels))
        tp = sum(1 for t, p in pairs if t == label and p == label)
        fn = sum(1 for t, p in pairs if t == label and p != label)
        fp = sum(1 for t, p in pairs if t != label and p == label)
        tn = sum(1 for t, p in pairs if t != label and p != label)

        accuracy_class = (tp + tn) / (tp + tn + fp + fn) if (tp + tn + fp + fn) > 0 else 0

        print(f"\n{label_to_condition[label]} (Label {label}):")
        print(f"  True Positives: {tp}")
        print(f"  False Positives: {fp}")
        print(f"  False Negatives: {fn}")
        print(f"  True Negatives: {tn}")
        print(f"  Class Accuracy: {accuracy_class:.4f} ({accuracy_class * 100:.2f}%)")
        print(f"  Precision: {precision_per_class[label]:.4f} ({precision_per_class[label] * 100:.2f}%)")
        print(f"  Recall: {recall_per_class[label]:.4f} ({recall_per_class[label] * 100:.2f}%)")
        print(f"  F1-Score: {f1_per_class[label]:.4f} ({f1_per_class[label] * 100:.2f}%)")

    print("\n=== DETAILED CLASSIFICATION REPORT ===")
    target_names = [label_to_condition[i] for i in class_ids]
    # labels= is required here: with seven target_names but fewer than seven classes
    # present in the results, classification_report raises a mismatch error.
    print(classification_report(true_labels, predicted_labels, labels=class_ids, target_names=target_names, zero_division=0))