In [None]:
# Step 1: Import Required Libraries
import torch
from transformers import CLIPProcessor, CLIPModel
from datasets import load_dataset, load_from_disk, Features, Sequence, Value, Image
from torch.utils.data import DataLoader
from torch.optim import AdamW
from tqdm.auto import tqdm # For progress bars
import random



In [2]:
# Step 2: Load Dataset (e.g., ROCOv2 Radiology Dataset)
print("Loading dataset...")
# Define expected features including CUIs as a sequence of strings
# Adjust based on actual dataset structure if needed
expected_features = Features({
    'image': Image(),
    'caption': Value(dtype='string', id=None),
    'image_id': Value(dtype='string', id=None),
    'cui': Sequence(feature=Value(dtype='string', id=None), length=-1, id=None),
    # Add other columns if they exist and you might need them
})

# Using a small subset for faster demonstration/testing
train_split_size = 10000 # How many training samples
eval_split_size = 4000 # How many validation samples
# dataset_name = "eltorio/ROCOv2-radiology"
dataset_name = "roco_updated_cuis"

def count_unique_cuis(dataset):
    unique_cuis = set()
    for sample in dataset:
        # print(sample)
        if "cui" in sample and sample["cui"]:
            # print(f"Sample CUI: {sample['cui']}")
            unique_cuis.update(sample["cui"])  # Add all CUIs in the sample to the set
    return len(unique_cuis), unique_cuis

try:
    full_dataset = load_from_disk(dataset_name)
    print(f"Loaded full dataset with {len(full_dataset)} samples.")

    # Load train split
    train_data = full_dataset['train'].select(range(train_split_size))
    print(f"Selected {len(train_data)} samples for training with columns: {train_data.column_names}")

    # Load validation split
    validation_data = full_dataset['validation'].select(range(eval_split_size))
    print(f"Selected {len(validation_data)} samples for validation with columns: {validation_data.column_names}")

    # # Count unique CUIs in the training and validation datasets
    # train_unique_count, train_unique_cuis = count_unique_cuis(train_data)
    # validation_unique_count, validation_unique_cuis = count_unique_cuis(validation_data)

    # # Print the results
    # print(f"Number of unique CUIs in the training dataset: {train_unique_count}")
    # print(f"Number of unique CUIs in the validation dataset: {validation_unique_count}")

except Exception as e:
    print(f"Error loading dataset: {e}")
    print(f"Please ensure the dataset '{dataset_name}' is accessible and has the expected features (image, caption, cui).")
    print("You might need to adjust 'expected_features' or the split syntax.")
    exit()

Loading dataset...


Loading dataset from disk:   0%|          | 0/27 [00:00<?, ?it/s]

Loaded full dataset with 3 samples.
Selected 10000 samples for training with columns: ['image', 'image_id', 'caption', 'cui']
Selected 4000 samples for validation with columns: ['image', 'image_id', 'caption', 'cui']


In [3]:
# Step 3: Load CLIP Model and Processor
print("Loading CLIP model...")
# Use a standard CLIP model checkpoint
model_id = "openai/clip-vit-base-patch32"
processor = CLIPProcessor.from_pretrained(model_id)
model = CLIPModel.from_pretrained(model_id)

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
print(f"Model loaded on {device}")

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


Loading CLIP model...
Model loaded on cuda


In [None]:
# Step 4: CUI to Text Conversion and Data Preparation Function

def cuis_to_text(cui_list):
    """
    Converts a list of CUIs into a descriptive text string.
    Example: ["C0040405", "C0041618"] -> "Medical concepts observed: C0040405, C0041618"
    Handles empty lists.
    A more advanced version could map CUIs to names (e.g., using UMLS).
    """
    if not cui_list:
        return "No specific medical concepts identified." # Or handle as invalid?
    random.shuffle(cui_list)
    return "Medical concepts observed: " + ", ".join(cui_list)

def collate_fn_cui(batch):
    valid_batch = []
    texts = []

    for example in batch:
        # Basic validation: Check for image and non-empty CUI list
        if example["image"] is not None and example["cui"] and len(example["cui"]) > 0:
             valid_batch.append(example)
             # Convert CUIs to text representation
             texts.append(cuis_to_text(example["cui"]))


    if not valid_batch:
        return None # Skip batch if no valid examples remain

    images = [example["image"] for example in valid_batch]

    # Use the CLIP processor
    inputs = processor(
        text=texts,
        images=images,
        return_tensors="pt",
        padding=True, # Pad text sequences to max length in batch
        truncation=True # Truncate text sequences if they exceed model max length
    )

    # Move tensors to the correct device
    inputs = {k: v.to(device) for k, v in inputs.items()}
    # Store number of valid items for potential index tracking in evaluation
    inputs['num_valid_in_batch'] = len(valid_batch)
    return inputs

# Create DataLoaders
batch_size = 4
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, collate_fn=collate_fn_cui)
eval_loader = DataLoader(validation_data, batch_size=batch_size, shuffle=False, collate_fn=collate_fn_cui)

In [5]:
# Step 5: Fine-Tuning Setup
optimizer = AdamW(model.parameters(), lr=1e-6) # Smaller LR often better for CLIP FT
epochs = 10 # Number of fine-tuning epochs

In [6]:
# Step 6: Fine-Tuning Loop
print("Starting fine-tuning...")
model.train() # Set model to training mode

for epoch in range(epochs):
    total_loss = 0
    num_batches_processed = 0
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}")

    for batch in progress_bar:
        if batch is None: # Skip if collate_fn returned None
            print("Skipping empty batch.")
            continue

        # Extract valid count and remove it from batch dict before passing to model
        num_valid = batch.pop('num_valid_in_batch', None)
        if num_valid is None or num_valid == 0: # Should not happen if batch is not None, but check
             continue

        optimizer.zero_grad()

        # Forward pass: CLIPModel calculates contrastive loss
        outputs = model(**batch, return_loss=True)
        loss = outputs.loss

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches_processed += 1
        progress_bar.set_postfix({"Loss": loss.item()})

    if num_batches_processed > 0:
        avg_loss = total_loss / num_batches_processed
        print(f"Epoch {epoch+1}/{epochs}, Average Training Loss: {avg_loss:.4f}")
    else:
        print(f"Epoch {epoch+1}/{epochs} completed with no valid batches processed.")


Starting fine-tuning...


Epoch 1/10:   0%|          | 0/2500 [00:00<?, ?it/s]

Epoch 1/10, Average Training Loss: 0.5419


Epoch 2/10:   0%|          | 0/2500 [00:00<?, ?it/s]

Epoch 2/10, Average Training Loss: 0.4048


Epoch 3/10:   0%|          | 0/2500 [00:00<?, ?it/s]

Epoch 3/10, Average Training Loss: 0.3476


Epoch 4/10:   0%|          | 0/2500 [00:00<?, ?it/s]

Epoch 4/10, Average Training Loss: 0.3250


Epoch 5/10:   0%|          | 0/2500 [00:00<?, ?it/s]

Epoch 5/10, Average Training Loss: 0.2867


Epoch 6/10:   0%|          | 0/2500 [00:00<?, ?it/s]

Epoch 6/10, Average Training Loss: 0.2533


Epoch 7/10:   0%|          | 0/2500 [00:00<?, ?it/s]

Epoch 7/10, Average Training Loss: 0.2325


Epoch 8/10:   0%|          | 0/2500 [00:00<?, ?it/s]

Epoch 8/10, Average Training Loss: 0.2095


Epoch 9/10:   0%|          | 0/2500 [00:00<?, ?it/s]

Epoch 9/10, Average Training Loss: 0.1960


Epoch 10/10:   0%|          | 0/2500 [00:00<?, ?it/s]

Epoch 10/10, Average Training Loss: 0.1783


In [7]:
# Step 7: Save Fine-Tuned Model
save_directory = "./clip-finetuned-medical-cui-10000-10"
model.save_pretrained(save_directory)
processor.save_pretrained(save_directory)
print(f"Fine-tuning completed and model saved to {save_directory}")

Fine-tuning completed and model saved to ./clip-finetuned-medical-cui-10000-10
