In [1]:
import os
import numpy as np
import torch
import pydicom
from pathlib import Path
import pickle
from tqdm import tqdm
import torch.nn.functional as F
from scipy.ndimage import zoom
import warnings
warnings.filterwarnings('ignore')

print(f"Using device: {'cuda' if torch.cuda.is_available() else 'cpu'}")
print(f"PyTorch version: {torch.__version__}")

Using device: cuda
PyTorch version: 2.1.0+cu121


In [2]:
# Import CT-CLIP - adjust path as needed
import sys
sys.path.append(r'CT_CLIP')  # Adjust this path to your CT-CLIP directory

try:
    from ct_clip.ct_clip import CTCLIP
    print("✓ CT-CLIP imported successfully!")
except ImportError as e:
    print(f"✗ Error importing CT-CLIP: {e}")
    print("Make sure the CT_CLIP path is correct and the module is installed")

✓ CT-CLIP imported successfully!


In [3]:
def load_dicom_series(folder_path):
    """
    Load a series of DICOM files from a folder and create a 3D volume.
    
    Args:
        folder_path (str): Path to folder containing DICOM files
        
    Returns:
        numpy.ndarray: 3D CT volume (H, W, D)
    """
    dicom_files = []
    folder = Path(folder_path)
    
    # Get all .dcm files
    dcm_files = list(folder.glob("*.dcm"))
    if not dcm_files:
        raise ValueError(f"No .dcm files found in {folder_path}")
    
    print(f"Found {len(dcm_files)} DICOM files in {folder.name}")
    
    # Load DICOM files
    for dcm_file in dcm_files:
        try:
            ds = pydicom.dcmread(dcm_file)
            dicom_files.append(ds)
        except Exception as e:
            print(f"Error reading {dcm_file}: {e}")
            continue
    
    if not dicom_files:
        raise ValueError(f"No valid DICOM files could be loaded from {folder_path}")
    
    # Sort by slice location or instance number
    try:
        dicom_files.sort(key=lambda x: float(x.SliceLocation))
        print("Sorted by SliceLocation")
    except (AttributeError, ValueError):
        try:
            dicom_files.sort(key=lambda x: int(x.InstanceNumber))
            print("Sorted by InstanceNumber")
        except (AttributeError, ValueError):
            print("Warning: Could not sort DICOM files by slice location or instance number")
    
    # Extract pixel arrays and create 3D volume
    slices = []
    for ds in dicom_files:
        # Apply DICOM rescaling
        pixel_array = ds.pixel_array.astype(np.float32)
        
        # Apply rescale slope and intercept if available (convert to Hounsfield Units)
        if hasattr(ds, 'RescaleSlope') and hasattr(ds, 'RescaleIntercept'):
            pixel_array = pixel_array * ds.RescaleSlope + ds.RescaleIntercept
        
        slices.append(pixel_array)
    
    # Stack slices to create 3D volume
    volume = np.stack(slices, axis=-1)  # Shape: (H, W, D)
    
    print(f"Loaded CT volume with shape: {volume.shape}")
    return volume

print("✓ DICOM loading function defined")

✓ DICOM loading function defined


In [4]:
# Configuration - adjust these paths
base_folder = r"C:\Users\20203686\OneDrive - TU Eindhoven\TUe\Master\Year 2\MASTER PROJECT\TEST CT SCANS"

# Get all subdirectories (each containing one CT scan)
ct_scan_folders = [d for d in Path(base_folder).iterdir() if d.is_dir()]

if not ct_scan_folders:
    print(f"No subdirectories found in {base_folder}")
    print("Make sure your CT scans are organized in separate folders")
else:
    print(f"Found {len(ct_scan_folders)} CT scan folders:")
    for folder in ct_scan_folders:
        print(f"  - {folder.name}")
    
    # Test loading the first CT scan
    test_folder = ct_scan_folders[0]
    print(f"\nTesting DICOM loading with: {test_folder.name}")
    
    try:
        ct_volume = load_dicom_series(test_folder)
        print(f"✓ Successfully loaded CT volume with shape: {ct_volume.shape}")
        print(f"✓ HU range: [{ct_volume.min():.1f}, {ct_volume.max():.1f}]")
    except Exception as e:
        print(f"✗ Error loading DICOM: {e}")
        import traceback
        traceback.print_exc()

Found 3 CT scan folders:
  - THX AX MIP 10 8
  - THX BB COR 3 3
  - THX BB SAG 3 3

Testing DICOM loading with: THX AX MIP 10 8
Found 45 DICOM files in THX AX MIP 10 8
Sorted by SliceLocation
Loaded CT volume with shape: (512, 512, 45)
✓ Successfully loaded CT volume with shape: (512, 512, 45)
✓ HU range: [-1024.0, 3071.0]


In [5]:
def preprocess_ct_volume(volume, target_size=(224, 224, 224)):
    """
    Preprocess CT volume for CT-CLIP model.
    Based on typical medical image preprocessing practices.
    
    Args:
        volume (numpy.ndarray): Input CT volume (H, W, D)
        target_size (tuple): Target size (H, W, D)
        
    Returns:
        torch.Tensor: Preprocessed volume ready for model input
    """
    print(f"Original volume shape: {volume.shape}")
    print(f"Volume HU range: [{volume.min():.1f}, {volume.max():.1f}]")
    
    # Apply windowing (typical for CT scans)
    # Using a general soft tissue window: center=40, width=350
    window_center, window_width = 40, 350
    window_min = window_center - window_width // 2
    window_max = window_center + window_width // 2
    
    # Apply windowing
    volume_windowed = np.clip(volume, window_min, window_max)
    
    # Normalize to [0, 1]
    volume_norm = (volume_windowed - window_min) / (window_max - window_min)
    
    print(f"After windowing and normalization: [{volume_norm.min():.3f}, {volume_norm.max():.3f}]")
    
    # Resize volume to target size
    current_shape = volume_norm.shape
    zoom_factors = [target_size[i] / current_shape[i] for i in range(3)]
    print(f"Zoom factors: {zoom_factors}")
    
    volume_resized = zoom(volume_norm, zoom_factors, order=1)
    print(f"Resized volume shape: {volume_resized.shape}")
    
    # Convert to tensor and add batch and channel dimensions
    # CT-CLIP expects shape: (batch, channels, depth, height, width)
    volume_tensor = torch.from_numpy(volume_resized).float()
    volume_tensor = volume_tensor.permute(2, 0, 1)  # (D, H, W)
    volume_tensor = volume_tensor.unsqueeze(0).unsqueeze(0)  # (1, 1, D, H, W)
    
    print(f"Final tensor shape: {volume_tensor.shape}")
    return volume_tensor

print("✓ Preprocessing function defined")

✓ Preprocessing function defined


In [6]:
# Test preprocessing on the loaded CT volume
if 'ct_volume' in locals():
    try:
        preprocessed_volume = preprocess_ct_volume(ct_volume)
        print(f"✓ Successfully preprocessed CT volume")
        print(f"✓ Ready for model input with shape: {preprocessed_volume.shape}")
    except Exception as e:
        print(f"✗ Error preprocessing CT volume: {e}")
        import traceback
        traceback.print_exc()
else:
    print("✗ No CT volume loaded. Run the DICOM loading cell first.")

Original volume shape: (512, 512, 45)
Volume HU range: [-1024.0, 3071.0]
After windowing and normalization: [0.000, 1.000]
Zoom factors: [0.4375, 0.4375, 4.977777777777778]
Resized volume shape: (224, 224, 224)
Final tensor shape: torch.Size([1, 1, 224, 224, 224])
✓ Successfully preprocessed CT volume
✓ Ready for model input with shape: torch.Size([1, 1, 224, 224, 224])


In [14]:
import torch
import os
from transformers import AutoModel, AutoTokenizer
from ct_clip.ct_clip import CTCLIP, TextTransformer, VisionTransformer

def load_ctclip_model(
    model_path=r"C:\Users\20203686\OneDrive - TU Eindhoven\TUe\Master\Year 2\MASTER PROJECT\CT-CLIP\CT_VocabFine_v2.pt",
    device='cuda'
):
    """
    Load CT-CLIP model with proper BiomedVLP tokenizer.
    
    Args:
        model_path (str): Path to pretrained model weights (optional)
        device (str): Device to use
        
    Returns:
        CTCLIP model, tokenizer, biomed_model (for text embeddings)
    """
    print("Initializing CT-CLIP model...")
    
    # Load the specialized biomedical tokenizer and model
    print("Loading BiomedVLP-CXR-BERT-specialized tokenizer and model...")
    url = "microsoft/BiomedVLP-CXR-BERT-specialized"
    try:
        tokenizer = AutoTokenizer.from_pretrained(url, trust_remote_code=True)
        biomed_model = AutoModel.from_pretrained(url, trust_remote_code=True)
        biomed_model.to(device)
        biomed_model.eval()
        print("✓ BiomedVLP tokenizer and model loaded successfully")
    except Exception as e:
        print(f"Error loading BiomedVLP model: {e}")
        raise
    
    # Define text encoder (TextTransformer from ct_clip.py)
    text_encoder = TextTransformer(
        dim=768,  # Matches dim_text
        num_tokens=28897,  # Matches default in CTCLIP
        max_seq_len=256,  # Matches text_seq_len
        depth=12,  # Matches checkpoint's layer count (inferred from text_transformer.encoder.layer.11)
        heads=12,  # Matches BERT-style attention
        dim_head=64,
        rotary_pos_emb=False,  # Default in CTCLIP
        causal=False  # Default in CTCLIP
    )
    
    # Define image encoder (VisionTransformer from ct_clip.py)
    image_encoder = VisionTransformer(
        dim=768,  # Matches expected output dimension
        image_size=224,
        patch_size=16,
        channels=1,  # Grayscale CT scans
        depth=12,  # Matches visual_enc_depth
        heads=12,  # Matches visual_heads
        dim_head=64,  # Matches visual_dim_head
        patch_dropout=0.5  # Default in CTCLIP
    )
    
    # Initialize CTCLIP model
    model = CTCLIP(
        image_encoder=image_encoder,
        text_encoder=text_encoder,
        dim_image=294912,  # Matches checkpoint's to_visual_latent.weight
        dim_text=768,
        dim_latent=512,
        num_text_tokens=28897,
        text_enc_depth=12,
        text_seq_len=256,
        text_heads=12,
        text_dim_head=64,
        text_has_cls_token=False,
        text_pad_id=0,
        text_rotary_pos_emb=False,
        text_causal_mask=False,
        visual_enc_depth=12,
        visual_heads=12,
        visual_dim_head=64,
        visual_image_size=224,
        visual_patch_size=16,
        visual_patch_dropout=0.5,
        visual_has_cls_token=False,
        channels=1,
        use_all_token_embeds=False,
        downsample_image_embeds=False,
        extra_latent_projection=False,
        use_mlm=False
    )
    
    if model_path and os.path.exists(model_path):
        print(f"Loading pretrained weights from {model_path}")
        try:
            checkpoint = torch.load(model_path, map_location='cpu')
            # Check if checkpoint is nested
            state_dict = checkpoint.get('model_state_dict', checkpoint)
            model.load_state_dict(state_dict, strict=False)  # Use strict=False to handle partial matches
            print("✓ Pretrained weights loaded successfully")
        except Exception as e:
            print(f"Warning: Could not load pretrained weights: {e}")
            print("Using randomly initialized model")
    else:
        print("No pretrained weights specified, using randomly initialized model")
        print("Note: For meaningful features, you should use pretrained weights")
    
    model.to(device)
    model.eval()
    return model, tokenizer, biomed_model


def get_text_embeddings(text_prompts, tokenizer, biomed_model, device='cuda'):
    """
    Get text embeddings using the BiomedVLP model.
    
    Args:
        text_prompts (list): List of text strings
        tokenizer: BiomedVLP tokenizer
        biomed_model: BiomedVLP model
        device (str): Device to use
        
    Returns:
        torch.Tensor: Text embeddings
    """
    # Tokenize and compute the sentence embeddings
    tokenizer_output = tokenizer.batch_encode_plus(
        batch_text_or_text_pairs=text_prompts,
        add_special_tokens=True,
        padding='longest',
        return_tensors='pt'
    )
    
    # Move tokenizer output to device
    tokenizer_output = {k: v.to(device) for k, v in tokenizer_output.items()}
    
    with torch.no_grad():
        embeddings = biomed_model.get_projected_text_embeddings(
            input_ids=tokenizer_output['input_ids'],
            attention_mask=tokenizer_output['attention_mask']
        )
    
    return embeddings


def compute_text_similarity(text_prompts, tokenizer, biomed_model, device='cuda'):
    """
    Compute cosine similarity between text prompts.
    
    Args:
        text_prompts (list): List of text strings
        tokenizer: BiomedVLP tokenizer
        biomed_model: BiomedVLP model
        device (str): Device to use
        
    Returns:
        torch.Tensor: Similarity matrix
    """
    embeddings = get_text_embeddings(text_prompts, tokenizer, biomed_model, device)
    
    # Compute the cosine similarity of sentence embeddings
    sim = torch.mm(embeddings, embeddings.t())
    return sim


# Configuration
model_path = r"C:\Users\20203686\OneDrive - TU Eindhoven\TUe\Master\Year 2\MASTER PROJECT\CT-CLIP\CT_VocabFine_v2.pt"
device = 'cuda' if torch.cuda.is_available() else 'cpu'

try:
    model, tokenizer, biomed_model = load_ctclip_model(model_path, device)
    print("✓ Model and tokenizer loaded successfully!")
    print(f"✓ CT-CLIP model is on device: {next(model.parameters()).device}")
    print(f"✓ BiomedVLP model is on device: {next(biomed_model.parameters()).device}")
    
    # Example usage of text embeddings
    text_prompts = [
        "There is no pneumothorax or pleural effusion",
        "No pleural effusion or pneumothorax is seen",
        "The extent of the pleural effusion is constant."
    ]
    
    print("\nTesting text embeddings...")
    embeddings = get_text_embeddings(text_prompts, tokenizer, biomed_model, device)
    print(f"✓ Text embeddings shape: {embeddings.shape}")
    
    # Compute similarity matrix
    similarity_matrix = compute_text_similarity(text_prompts, tokenizer, biomed_model, device)
    print(f"✓ Similarity matrix shape: {similarity_matrix.shape}")
    print("✓ Similarity matrix:")
    print(similarity_matrix)
    
except Exception as e:
    print(f"✗ Error loading model: {e}")
    import traceback
    traceback.print_exc()

Initializing CT-CLIP model...
Loading BiomedVLP-CXR-BERT-specialized tokenizer and model...
✓ BiomedVLP tokenizer and model loaded successfully


The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'CXRBertTokenizer'. 
The class this function is called from is 'BertTokenizer'.


Loading pretrained weights from C:\Users\20203686\OneDrive - TU Eindhoven\TUe\Master\Year 2\MASTER PROJECT\CT-CLIP\CT_VocabFine_v2.pt
✓ Pretrained weights loaded successfully
✓ Model and tokenizer loaded successfully!
✓ CT-CLIP model is on device: cuda:0
✓ BiomedVLP model is on device: cuda:0

Testing text embeddings...
✓ Text embeddings shape: torch.Size([3, 128])
✓ Similarity matrix shape: torch.Size([3, 3])
✓ Similarity matrix:
tensor([[ 1.0000,  0.7456, -0.1916],
        [ 0.7456,  1.0000, -0.4709],
        [-0.1916, -0.4709,  1.0000]], device='cuda:0')


In [22]:
import torch
import torch.nn.functional as F

def extract_features(ct_volume_tensor, model, tokenizer, biomed_model, device='cuda'):
    """
    Extract features for specified medical conditions from a CT volume using CT-CLIP.
    
    Args:
        ct_volume_tensor (torch.Tensor): Preprocessed CT volume tensor (1, 1, D, H, W)
        model (CTCLIP): Loaded CT-CLIP model
        tokenizer: BiomedVLP tokenizer
        biomed_model: BiomedVLP model
        device (str): Device to use ('cuda' or 'cpu')
        
    Returns:
        dict: Dictionary with condition names as keys and binary labels (0 or 1) as values
    """
    # Define the list of conditions based on training data
    conditions = [
        "Arterial wall calcification",
        "Cardiomegaly",
        "Pericardial effusion",
        "Coronary artery wall calcification",
        "Hiatal hernia",
        "Lymphadenopathy",
        "Emphysema",
        "Atelectasis",
        "Lung nodule",
        "Lung opacity",
        "Pulmonary fibrotic sequela",
        "Pleural effusion",
        "Mosaic attenuation pattern",
        "Peribronchial thickening",
        "Consolidation",
        "Bronchiectasis",
        "Interlobular septal thickening"
    ]
    
    # Create text prompts for presence and absence of each condition
    text_prompts = [f"Presence of {condition}" for condition in conditions] + \
                   [f"No {condition}" for condition in conditions]
    
    print(f"Generated {len(text_prompts)} text prompts for {len(conditions)} conditions")
    
    # Move CT volume tensor to device
    ct_volume_tensor = ct_volume_tensor.to(device)
    
    # Generate CT volume embedding
    with torch.no_grad():
        # Use visual_transformer instead of image_encoder
        image_embedding = model.visual_transformer(ct_volume_tensor)  # Shape: (1, num_patches + 1, dim_image) or (1, dim_image)
        # If use_all_token_embeds is True, select non-CLS tokens or average
        if model.use_all_token_embeds:
            image_embedding = image_embedding[:, 1:] if model.visual_has_cls_token else image_embedding  # Exclude CLS token
            image_embedding = torch.mean(image_embedding, dim=1)  # Average over patches
        else:
            image_embedding = image_embedding[:, 0] if model.visual_has_cls_token else image_embedding  # Use CLS token or full embedding
        # Project to latent space
        image_latent = model.to_visual_latent(image_embedding)  # Shape: (1, dim_latent)
        image_latent = F.normalize(image_latent, dim=-1)  # L2 normalize
    
    print(f"CT volume latent shape: {image_latent.shape}")
    
    # Generate text embeddings
    tokenizer_output = tokenizer.batch_encode_plus(
        batch_text_or_text_pairs=text_prompts,
        add_special_tokens=True,
        padding='longest',
        return_tensors='pt'
    )
    
    tokenizer_output = {k: v.to(device) for k, v in tokenizer_output.items()}
    
    with torch.no_grad():
        # Use BiomedVLP model for text embeddings
        text_embeddings = biomed_model.get_projected_text_embeddings(
            input_ids=tokenizer_output['input_ids'],
            attention_mask=tokenizer_output['attention_mask']
        )
        # Project to latent space using CTCLIP's text projection
        text_latents = model.to_text_latent(text_embeddings)  # Shape: (num_prompts, dim_latent)
        text_latents = F.normalize(text_latents, dim=-1)  # L2 normalize
    
    print(f"Text latents shape: {text_latents.shape}")
    
    # Compute cosine similarities
    similarities = F.cosine_similarity(image_latent, text_latents, dim=-1)  # Shape: (num_prompts,)
    print(f"Similarity scores shape: {similarities.shape}")
    
    # Classify each condition based on similarity scores
    results = {}
    for i, condition in enumerate(conditions):
        presence_score = similarities[i]
        absence_score = similarities[i + len(conditions)]
        results[condition] = 1 if presence_score > absence_score else 0
        print(f"{condition}: Presence score = {presence_score:.4f}, Absence score = {absence_score:.4f}, Predicted = {results[condition]}")
    
    return results

# Ensure required variables are available
if 'preprocessed_volume' not in locals() or 'model' not in locals() or 'tokenizer' not in locals() or 'biomed_model' not in locals():
    print("✗ Error: Required variables (preprocessed_volume, model, tokenizer, biomed_model) not found.")
    print("Please run the previous cells to load and preprocess the CT volume and model.")
else:
    try:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print(f"Using device: {device}")
        results = extract_features(preprocessed_volume, model, tokenizer, biomed_model, device)
        print("\nFeature extraction results:")
        for condition, label in results.items():
            print(f"{condition}: {label}")
    except Exception as e:
        print(f"✗ Error during feature extraction: {e}")
        import traceback
        traceback.print_exc()

Using device: cuda
Generated 34 text prompts for 17 conditions
✗ Error during feature extraction: 5


Traceback (most recent call last):
  File "C:\Users\20203686\AppData\Local\Temp\ipykernel_14352\3604583878.py", line 108, in <module>
    results = extract_features(preprocessed_volume, model, tokenizer, biomed_model, device)
  File "C:\Users\20203686\AppData\Local\Temp\ipykernel_14352\3604583878.py", line 51, in extract_features
    image_embedding = model.visual_transformer(ct_volume_tensor)  # Shape: (1, num_patches + 1, dim_image) or (1, dim_image)
  File "c:\Users\20203686\AppData\Local\anaconda3\envs\MasterProject\lib\site-packages\torch\nn\modules\module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "c:\Users\20203686\AppData\Local\anaconda3\envs\MasterProject\lib\site-packages\torch\nn\modules\module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "c:\Users\20203686\Documents\GitHub\CT-CLIP\CT_CLIP\ct_clip\ct_clip.py", line 374, in forward
    x = self.to_tokens(x)
  File "c:\Users\20203686\AppData\Local\a