# In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertModel
import numpy as np
import pandas as pd
import os
from tqdm import tqdm
import argparse
from sklearn.metrics import classification_report
import matplotlib.pyplot as plt
import sys

# Make the project-local packages importable. The path is relative, so
# os.path.abspath resolves it against the current working directory: unless
# the package is installed some other way, the imports below only succeed
# when the script is started from the directory that contains the
# `transformer_maskgit` folder (otherwise Python raises ModuleNotFoundError).
sys.path.append(os.path.abspath('transformer_maskgit'))
from transformer_maskgit import CTViT
from ct_clip import CTCLIP

# Sigmoid function
def sigmoid(tensor):
    """Apply the logistic sigmoid element-wise.

    Delegates to ``torch.sigmoid`` instead of spelling out
    ``1 / (1 + exp(-x))``. The hand-written form overflows ``exp`` for large
    negative inputs: the forward value is still 0, but the backward pass
    multiplies 0 by inf and yields NaN gradients.

    Args:
        tensor: Floating-point ``torch.Tensor`` of logits.

    Returns:
        Tensor of the same shape with every value in [0, 1].
    """
    return torch.sigmoid(tensor)

# Classifier model
class ImageLatentsClassifier(nn.Module):
    """Linear classification head on top of a model that exposes image latents.

    The wrapped ``trained_model`` is always called with
    ``return_latents=True`` and must return a 2-tuple whose second element is
    the image-latent tensor; the first element is ignored.

    Args:
        trained_model: Callable/module producing ``(_, image_latents)``.
        latent_dim: Size of the last dimension of the image latents.
        num_classes: Number of output logits.
        dropout_prob: Dropout probability applied before the linear layer.
    """

    def __init__(self, trained_model, latent_dim, num_classes, dropout_prob=0.3):
        super().__init__()
        self.trained_model = trained_model
        self.dropout = nn.Dropout(dropout_prob)
        self.relu = nn.ReLU()
        self.classifier = nn.Linear(latent_dim, num_classes)

    def forward(self, latents=False, *args, **kwargs):
        """Run the wrapped model and classify its image latents.

        Args:
            latents: When true, return the ReLU-activated latents instead of
                class logits.
            *args, **kwargs: Forwarded unchanged to ``trained_model`` (with
                ``return_latents=True`` forced).

        Returns:
            The activated latents if ``latents`` is true, otherwise the raw
            (pre-sigmoid) logits produced by the linear head.
        """
        kwargs['return_latents'] = True
        _, image_latents = self.trained_model(*args, **kwargs)
        image_latents = self.relu(image_latents)
        if latents:
            return image_latents
        image_latents = self.dropout(image_latents)
        return self.classifier(image_latents)

    def load(self, file_path):
        """Load a state dict saved with ``torch.save(model.state_dict(), ...)``.

        The checkpoint is mapped to CPU first so that weights saved on a GPU
        can also be restored on a CPU-only machine; move the module to the
        target device afterwards with ``.to(device)``.

        Args:
            file_path: Path to the checkpoint file.
        """
        state_dict = torch.load(file_path, map_location="cpu")
        self.load_state_dict(state_dict)

# Custom dataset for inference
class CTReportDatasetinfer(Dataset):
    """Inference dataset that serves resized CT volumes with pathology labels.

    Every ``.npz`` file found (recursively) under ``data_folder`` is one
    sample. The name of the file's parent directory is used as the patient ID
    and the file name without ``.npz`` as the scan name; both are used to look
    up the label row in ``labels_csv``.
    """

    def __init__(self, data_folder, csv_file, labels_csv, target_shape=(480, 480, 480)):
        """
        data_folder: Directory containing .npz files (e.g., 'NIFTI PRE CT SCANS')
        csv_file: Path to metadata.csv
        labels_csv: Path to CSV with pathology labels (PatientID, ScanName, and 18 pathology columns)
        target_shape: (D, H, W) every volume is resized to; defaults to (480, 480, 480)
        """
        self.data_folder = data_folder
        self.metadata = pd.read_csv(csv_file)
        self.labels_df = pd.read_csv(labels_csv)
        self.target_shape = tuple(target_shape)
        # Sorted so the sample order does not depend on the file system's
        # directory-listing order.
        self.npz_files = sorted(
            os.path.join(root, name)
            for root, _, files in os.walk(data_folder)
            for name in files
            if name.endswith('.npz')
        )
        self.pathologies = ['Medical material', 'Arterial wall calcification', 'Cardiomegaly', 'Pericardial effusion',
                            'Coronary artery wall calcification', 'Hiatal hernia', 'Lymphadenopathy', 'Emphysema',
                            'Atelectasis', 'Lung nodule', 'Lung opacity', 'Pulmonary fibrotic sequela',
                            'Pleural effusion', 'Mosaic attenuation pattern', 'Peribronchial thickening',
                            'Consolidation', 'Bronchiectasis', 'Interlobular septal thickening']

    def __len__(self):
        """Return the number of .npz files discovered under ``data_folder``."""
        return len(self.npz_files)

    def __getitem__(self, idx):
        """Load one sample.

        Returns:
            A tuple ``(img_tensor, text, labels, acc_no)`` where
            ``img_tensor`` is a float32 tensor of shape ``(1, *target_shape)``,
            ``text`` is an empty-string placeholder, ``labels`` is a float32
            tensor with one entry per pathology (all zeros when no label row
            matches) and ``acc_no`` is ``"<PatientID>_<ScanName>"``.
        """
        npz_path = self.npz_files[idx]
        patient_id = os.path.basename(os.path.dirname(npz_path))
        # Only files ending in '.npz' are collected, so strip that suffix.
        scan_name = os.path.basename(npz_path)[:-len('.npz')]

        # np.load keeps the archive's file handle open until closed; the
        # context manager releases it as soon as the array is materialised.
        with np.load(npz_path) as data:
            img_data = data['arr_0'].astype(np.float32)

        img_data = self.resize_volume(img_data, self.target_shape)

        # Add the channel dimension -> (1, D, H, W)
        img_tensor = torch.from_numpy(img_data).unsqueeze(0)

        # IDs derived from the path are strings, while pandas may have parsed
        # the CSV columns as numbers; compare both sides as strings so numeric
        # IDs still match instead of silently falling back to all-zero labels.
        labels_df = self.labels_df
        match = labels_df[(labels_df['PatientID'].astype(str) == patient_id) &
                          (labels_df['ScanName'].astype(str) == scan_name)]
        if match.empty:
            labels = torch.zeros(len(self.pathologies), dtype=torch.float32)
        else:
            # Use the first matching row so duplicated rows cannot turn the
            # label vector into a 2-D tensor.
            values = match[self.pathologies].iloc[0].to_numpy(dtype=np.float32)
            labels = torch.tensor(values, dtype=torch.float32)

        # Accession number (use PatientID_ScanName as unique identifier)
        acc_no = f"{patient_id}_{scan_name}"

        # The text slot is a placeholder. It must not be None: the default
        # DataLoader collate function raises TypeError on NoneType elements.
        return img_tensor, "", labels, acc_no

    def resize_volume(self, volume, target_shape):
        """
        Resize 3D volume to target shape using trilinear interpolation.

        volume: 3D array-like (D, H, W); converted to float32.
        target_shape: Desired (D, H, W).
        Returns a float32 numpy array of exactly ``target_shape``.
        """
        volume_tensor = torch.as_tensor(volume, dtype=torch.float32)[None, None]  # (1, 1, D, H, W)
        resized = F.interpolate(volume_tensor, size=target_shape, mode='trilinear', align_corners=False)
        # Index the two leading singleton axes explicitly; squeeze() would
        # also drop any target dimension that happens to be 1.
        return resized[0, 0].numpy()

# Evaluation function
def evaluate_model(args, model, dataloader, device, tokenizer=None):
    """Run inference over ``dataloader`` and write predictions and a report.

    Files written into ``args.save`` (created if missing):
      * ``accessions.txt``        - one sample identifier per line
      * ``labels_weights.npz``    - ground-truth labels under key ``data``
      * ``predicted_weights.npz`` - sigmoid probabilities under key ``data``
      * ``aurocs.xlsx``           - sheet 'Classification Report'

    Args:
        args: Namespace providing ``save`` (output directory).
        model: Module called as ``model(False, text_tokens, inputs,
            device=device)`` and returning per-class logits.
        dataloader: Iterable of ``(inputs, _, labels, acc_no)`` batches.
        device: Device the model and the inputs are moved to.
        tokenizer: Callable used to tokenize the (empty) text prompt. When
            omitted, the module-level global ``tokenizer`` is used, which is
            how this function obtained it before the parameter existed.

    Raises:
        ValueError: If no tokenizer is available or the dataloader yields no
            samples.
    """
    if tokenizer is None:
        # Backward compatibility with callers that define a global tokenizer.
        tokenizer = globals().get('tokenizer')
    if tokenizer is None:
        raise ValueError("No tokenizer available: pass `tokenizer=` or define a module-level `tokenizer`.")

    pathologies = ['Medical material', 'Arterial wall calcification', 'Cardiomegaly', 'Pericardial effusion',
                   'Coronary artery wall calcification', 'Hiatal hernia', 'Lymphadenopathy', 'Emphysema',
                   'Atelectasis', 'Lung nodule', 'Lung opacity', 'Pulmonary fibrotic sequela',
                   'Pleural effusion', 'Mosaic attenuation pattern', 'Peribronchial thickening',
                   'Consolidation', 'Bronchiectasis', 'Interlobular septal thickening']

    model.eval()
    model = model.to(device)

    # The text prompt is the same empty string for every batch, so it is
    # tokenized once instead of on every iteration.
    text_tokens = tokenizer("", return_tensors="pt", padding="max_length",
                            truncation=True, max_length=200).to(device)

    predictedall = []
    realall = []
    accs = []
    with torch.no_grad():
        for inputs, _, labels, acc_no in tqdm(dataloader, desc="Evaluating"):
            inputs = inputs.to(device)
            output = model(False, text_tokens, inputs, device=device)

            # Keep every sample of the batch (not only element 0) so batch
            # sizes above 1 do not silently drop results.
            batch_size = len(acc_no)
            probs = sigmoid(output).cpu().numpy().reshape(batch_size, -1)
            truth = labels.float().cpu().numpy().reshape(batch_size, -1)
            predictedall.extend(probs)
            realall.extend(truth)
            accs.extend(acc_no)
            for acc in acc_no:
                print(f"Accession: {acc}", flush=True)

    if not accs:
        raise ValueError("The dataloader yielded no samples; nothing to evaluate.")

    # Save results
    plotdir = args.save
    os.makedirs(plotdir, exist_ok=True)

    with open(os.path.join(plotdir, "accessions.txt"), "w") as file:
        file.writelines(f"{item}\n" for item in accs)

    predictedall = np.array(predictedall)
    realall = np.array(realall)

    np.savez(os.path.join(plotdir, "labels_weights.npz"), data=realall)
    np.savez(os.path.join(plotdir, "predicted_weights.npz"), data=predictedall)

    # Generate classification report at a fixed 0.5 probability threshold.
    # zero_division=0 keeps the value sklearn already uses for classes
    # without predictions/labels, but without emitting a warning per class.
    predicted_binary = (predictedall > 0.5).astype(int)
    report = classification_report(realall, predicted_binary, target_names=pathologies,
                                   output_dict=True, zero_division=0)
    report_df = pd.DataFrame(report).transpose()

    # Save the classification report. The workbook keeps its historical name
    # 'aurocs.xlsx'; note that no AUROC values are computed here.
    with pd.ExcelWriter(os.path.join(plotdir, 'aurocs.xlsx'), engine='xlsxwriter') as writer:
        report_df.to_excel(writer, sheet_name='Classification Report')

    print("Evaluation complete. Results saved to:", plotdir)

if __name__ == '__main__':
    # Script entry point: build the CT-CLIP based classifier, restore its
    # weights, and evaluate it on every .npz volume under --data_folder.

    # Parse arguments
    parser = argparse.ArgumentParser(description="Inference for CT pathology classification")
    parser.add_argument('--data_folder', type=str, 
                        default=r"C:\Users\20203686\OneDrive - TU Eindhoven\TUe\Master\Year 2\MASTER PROJECT\NIFTI PRE CT SCANS",
                        help="Path to preprocessed .npz files")
    parser.add_argument('--reports_file', type=str, 
                        default=r"C:\Users\20203686\OneDrive - TU Eindhoven\TUe\Master\Year 2\MASTER PROJECT\metadata.csv",
                        help="Path to metadata CSV")
    parser.add_argument('--labels', type=str, 
                        default=r"C:\Users\20203686\OneDrive - TU Eindhoven\TUe\Master\Year 2\MASTER PROJECT\labels.csv",
                        help="Path to labels CSV")
    parser.add_argument('--pretrained', type=str, 
                        default=r"C:\Users\20203686\OneDrive - TU Eindhoven\TUe\Master\Year 2\MASTER PROJECT\CT-CLIP\CT_VocabFine_v2.pt",
                        help="Path to pretrained model checkpoint")
    parser.add_argument('--save', type=str, 
                        default=r"C:\Users\20203686\OneDrive - TU Eindhoven\TUe\Master\Year 2\MASTER PROJECT\results",
                        help="Directory to save results")
    parser.add_argument('--num_workers', type=int, default=8,
                        help="Number of DataLoader worker processes (0 loads data in the main process)")
    # parse_known_args instead of parse_args: inside a Jupyter kernel sys.argv
    # carries the kernel's own "-f <connection file>" option, which
    # parse_args() would reject and exit on.
    args, _unknown_args = parser.parse_known_args()

    # Initialize tokenizer and models.
    # NOTE: `tokenizer` must stay a module-level name; evaluate_model looks it
    # up as a global.
    tokenizer = BertTokenizer.from_pretrained('microsoft/BiomedVLP-CXR-BERT-specialized', do_lower_case=True)
    text_encoder = BertModel.from_pretrained("microsoft/BiomedVLP-CXR-BERT-specialized")
    text_encoder.resize_token_embeddings(len(tokenizer))

    image_encoder = CTViT(
        dim=512,
        codebook_size=8192,
        image_size=480,
        patch_size=20,
        temporal_patch_size=10,
        spatial_depth=4,
        temporal_depth=4,
        dim_head=32,
        heads=8
    )

    clip = CTCLIP(
        image_encoder=image_encoder,
        text_encoder=text_encoder,
        dim_image=294912,
        dim_text=768,
        dim_latent=512,
        extra_latent_projection=False,
        use_mlm=False,
        downsample_image_embeds=False,
        use_all_token_embeds=False
    )

    num_classes = 18
    image_classifier = ImageLatentsClassifier(clip, 512, num_classes)

    # Load pretrained weights through nn.Module.load_state_dict, which every
    # nn.Module provides. The checkpoint is mapped to CPU first so that a file
    # saved on a GPU also loads on a CPU-only machine; evaluate_model moves
    # the model to the target device afterwards.
    state_dict = torch.load(args.pretrained, map_location='cpu')
    image_classifier.load_state_dict(state_dict)

    # Prepare dataset and dataloader
    ds = CTReportDatasetinfer(data_folder=args.data_folder, csv_file=args.reports_file, labels_csv=args.labels)
    dl = DataLoader(ds, num_workers=args.num_workers, batch_size=1, shuffle=False)

    # Evaluate model
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    evaluate_model(args, image_classifier, dl, device)

# Notebook output: ModuleNotFoundError: No module named 'transformer_maskgit'