In [14]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from transformers import CLIPProcessor, CLIPModel, T5ForConditionalGeneration, T5Tokenizer
from PIL import Image
import os
import re

# 1. Image Preprocessing
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

def preprocess_image(image_path):
    image = Image.open(image_path).convert("RGB")
    image_tensor = transform(image)
    image_tensor = torch.clamp(image_tensor, 0, 1)
    return image_tensor.unsqueeze(0)  # Add batch dimension

# 2. Text Preprocessing
def clean_text(text):
    text = text.lower()
    text = re.sub(r'[^a-zA-Z0-9\s]', '', text)
    return text

def tokenize_text(text, tokenizer):
    return tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)

# 3. Model Setup
## Image Encoder (CLIP Model)
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

## Text Decoder (T5)
t5_model = T5ForConditionalGeneration.from_pretrained("t5-base")
t5_tokenizer = T5Tokenizer.from_pretrained("t5-base")

# 4. Multimodal Fusion and Text Generation
def generate_caption(image_path):
    # Preprocess image
    image = preprocess_image(image_path)
    inputs = clip_processor(images=image, return_tensors="pt", do_rescale=False)

    # Extract features from CLIP model
    with torch.no_grad():
        image_features = clip_model.get_image_features(**inputs)

    # Prepare text input
    prompt = "Describe this medical image:"
    input_ids = t5_tokenizer(prompt, return_tensors="pt").input_ids

    # Get initial decoder input
    decoder_input_ids = t5_model._shift_right(input_ids)

    # Project image features to match text embedding size
    image_features_proj = nn.Linear(image_features.shape[-1], t5_model.config.d_model)(image_features)
    image_features_proj = image_features_proj.unsqueeze(1)  # Add sequence dimension

    # Generate the caption
    outputs = t5_model.generate(
        input_ids=input_ids,
        decoder_input_ids=decoder_input_ids,
        encoder_outputs=torch.stack([image_features_proj]),
        max_length=50
    )

    caption = t5_tokenizer.decode(outputs[0], skip_special_tokens=True)
    return caption

# 5. Fine-Tuning Process (Placeholder)
def fine_tune_model():
    pass  # Placeholder for future implementation

# 6. Chatbot Integration
def chatbot_response(user_input, image_path):
    caption = generate_caption(image_path)
    response = f"Generated Medical Report: {caption}\nUser Query: {user_input}"
    return response

# Example Usage
if __name__ == "__main__":
    image_path = r"C:/Users/admin/Downloads/chest2.jpeg"
    user_query = "What are the abnormalities?"
    print(chatbot_response(user_query, image_path))

Generated Medical Report: Describe this medical image:nastrntstrntstrntstrntstrntstrntstrntstrnts
User Query: What are the abnormalities?


In [8]:
!pip install sentencepiece




In [2]:
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import T5Tokenizer, T5ForConditionalGeneration, CLIPProcessor, CLIPModel, AdamW
from PIL import Image
import json
import os
import torch.nn as nn
from tqdm import tqdm
import gc
from torch.cuda.amp import autocast, GradScaler

class ImageToTextProjection(nn.Module):
    def __init__(self, input_dim=512, output_dim=768):
        super().__init__()
        # Simplified architecture - direct projection
        self.projection = nn.Linear(input_dim, output_dim)
    
    def forward(self, x):
        return self.projection(x).unsqueeze(1)

class MedicalReportDataset(Dataset):
    def __init__(self, jsonl_file, clip_processor, t5_tokenizer, max_samples=None):
        self.clip_processor = clip_processor
        self.t5_tokenizer = t5_tokenizer
        
        # Add counters for debugging
        total_entries = 0
        valid_entries = 0
        error_types = {}
        
        self.data = []
        with open(jsonl_file, "r") as f:
            for line in f:
                total_entries += 1
                entry = json.loads(line)
                
                # Check if entry has figures
                if not entry.get("figures"):
                    error_types["no_figures"] = error_types.get("no_figures", 0) + 1
                    continue
                    
                figure = entry["figures"][0]
                image_path = os.path.join(entry["location"], figure["graphic_ref"].split("\\")[-1])
                
                # Check if file exists
                if not os.path.exists(image_path):
                    error_types["missing_file"] = error_types.get("missing_file", 0) + 1
                    continue
                
                # Check if we can open the image
                try:
                    with Image.open(image_path) as img:
                        pass
                    valid_entries += 1
                    self.data.append(entry)
                except Exception as e:
                    error_type = str(type(e).__name__)
                    error_types[error_type] = error_types.get(error_type, 0) + 1
                    continue
                
                if max_samples and valid_entries >= max_samples:
                    break
        
        print(f"\nDataset Loading Statistics:")
        print(f"Total entries in JSONL: {total_entries}")
        print(f"Valid entries found: {valid_entries}")
        print("\nError breakdown:")
        for error_type, count in error_types.items():
            print(f"{error_type}: {count}")
        print("\nExample valid image path:", self.data[0]["figures"][0]["graphic_ref"] if self.data else "No valid images found")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        figure = item["figures"][0]
        
        # Print the first few paths for debugging
        if idx < 5:
            print(f"Loading image {idx}: {os.path.join(item['location'], figure['graphic_ref'].split('\\')[-1])}")
            
        image_path = os.path.join(item["location"], figure["graphic_ref"].split("\\")[-1])
        caption = figure["fig_caption"]

        try:
            image = Image.open(image_path).convert("RGB")
            pixel_values = self.clip_processor(images=image, return_tensors=None)["pixel_values"][0]
            labels = self.t5_tokenizer(caption, max_length=128, truncation=True)["input_ids"]
            
            return {
                "pixel_values": torch.tensor(pixel_values),
                "labels": torch.tensor(labels)
            }
        except Exception as e:
            print(f"Error loading image {image_path}: {e}")
            return None

def collate_fn(batch):
    batch = [item for item in batch if item is not None]
    if not batch:
        return None
    
    pixel_values = torch.stack([item["pixel_values"] for item in batch])
    labels = torch.nn.utils.rnn.pad_sequence(
        [item["labels"] for item in batch], 
        batch_first=True, 
        padding_value=-100
    )
    
    return {
        "pixel_values": pixel_values,
        "labels": labels
    }

def train_model(dataloader, clip_model, t5_model, projection_layer, optimizer, scaler, device, max_batches=None):
    total_loss = 0
    batch_count = 0
    
    # Use tqdm for progress bar
    pbar = tqdm(dataloader, desc="Training", total=min(len(dataloader), max_batches) if max_batches else len(dataloader))
    
    for batch in pbar:
        if batch is None:
            continue
            
        if max_batches and batch_count >= max_batches:
            break
            
        # Move batch to device
        pixel_values = batch["pixel_values"].to(device)
        labels = batch["labels"].to(device)

        # Use automatic mixed precision
        with autocast():
            # Get CLIP features
            with torch.no_grad():
                image_features = clip_model.get_image_features(pixel_values=pixel_values)

            # Project features and get T5 outputs
            projected_features = projection_layer(image_features)
            outputs = t5_model(inputs_embeds=projected_features, labels=labels)
            loss = outputs.loss

        # Scaled backward pass
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

        total_loss += loss.item()
        batch_count += 1
        
        # Update progress bar
        pbar.set_postfix({'loss': f'{loss.item():.4f}'})
        
        # Clear cache periodically
        if batch_count % 10 == 0:
            torch.cuda.empty_cache()
            gc.collect()

    return total_loss / batch_count if batch_count > 0 else 0

def main():
    # Set device and enable benchmarking
    device = "cuda" if torch.cuda.is_available() else "cpu"
    torch.backends.cudnn.benchmark = True
    
    # Initialize models
    clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
    t5_tokenizer = T5Tokenizer.from_pretrained("t5-base")
    t5_model = T5ForConditionalGeneration.from_pretrained("t5-base").to(device)
    
    # Simplified projection layer
    projection_layer = ImageToTextProjection().to(device)
    
    # Create dataset with limited samples for testing
    jsonl_file = "A:\Vishal\BiomedCLIP_data_pipeline\_results\data\pubmed_parsed_data.jsonl"
    dataset = MedicalReportDataset(
        jsonl_file, 
        clip_processor, 
        t5_tokenizer,
        max_samples=1000  # Limit samples for testing
    )
    
    # Create dataloader with increased num_workers
    dataloader = DataLoader(
        dataset,
        batch_size=16,  # Increased batch size
        shuffle=True,
        collate_fn=collate_fn,
        num_workers=4,  # Adjust based on your CPU
        pin_memory=True
    )
    
    # Initialize optimizer and gradient scaler
    optimizer = AdamW(
        list(t5_model.parameters()) + list(projection_layer.parameters()),
        lr=5e-5,
        weight_decay=0.01
    )
    scaler = GradScaler()
    
    # Training loop
    num_epochs = 3
    max_batches_per_epoch = 100  # Limit batches per epoch for testing
    
    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch + 1}/{num_epochs}")
        
        # Set models to appropriate modes
        clip_model.eval()
        t5_model.train()
        projection_layer.train()
        
        # Train one epoch
        avg_loss = train_model(
            dataloader, 
            clip_model, 
            t5_model, 
            projection_layer, 
            optimizer, 
            scaler, 
            device,
            max_batches=max_batches_per_epoch
        )
        
        print(f"Epoch {epoch + 1}/{num_epochs}, Average Loss: {avg_loss:.4f}")
        
        # Save checkpoint
        if (epoch + 1) % 1 == 0:
            torch.save({
                'epoch': epoch,
                't5_model_state_dict': t5_model.state_dict(),
                'projection_layer_state_dict': projection_layer.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': avg_loss,
            }, f'checkpoint_epoch_{epoch+1}.pth')

if __name__ == "__main__":
    main()

SyntaxError: f-string expression part cannot include a backslash (97095523.py, line 81)

In [16]:
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModel, T5Tokenizer, T5ForConditionalGeneration, CLIPProcessor, CLIPModel
from PIL import Image
import warnings

# Suppress warnings
warnings.simplefilter("ignore", FutureWarning)

# Set device
device = "cuda" if torch.cuda.is_available() else "cpu"

### Step 1: Load BiomedVLP-CXR-BERT-general Model for Text Embeddings ###
bert_model_name = "microsoft/BiomedVLP-CXR-BERT-general"
bert_tokenizer = AutoTokenizer.from_pretrained(bert_model_name)
bert_model = AutoModel.from_pretrained(bert_model_name).to(device)

### Step 2: Load CLIP Model for Image Feature Extraction ###
class MedCLIPModel(nn.Module):
    def __init__(self):
        super(MedCLIPModel, self).__init__()
        self.model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
    
    def forward(self, image):
        return self.model.get_image_features(image)

# Function to load MedCLIP model
def load_medclip_model():
    try:
        model = MedCLIPModel().to(device)
        print("Image Model weights loaded successfully!")
        return model
    except Exception as e:
        print(f"Error loading MedCLIP model: {e}")
        return None

# Image processing function
def preprocess_image(image_path, processor):
    try:
        image = Image.open(image_path).convert("RGB")
        pixel_values = processor(images=image, return_tensors="pt").pixel_values
        if pixel_values is None:
            raise ValueError("Image preprocessing failed. Check input format.")
        return pixel_values.to(device)
    except Exception as e:
        print(f"Error processing image: {e}")
        return None

### Step 3: Load T5 Model for Medical Report Generation ###
t5_model_name = "t5-base"  # Use a fine-tuned T5 medical model if available
t5_tokenizer = T5Tokenizer.from_pretrained(t5_model_name)
t5_model = T5ForConditionalGeneration.from_pretrained(t5_model_name).to(device)

### Step 4: Generate Medical Caption from Image and Text ###
def generate_medical_report(image_path, medclip_model, processor, bert_model, bert_tokenizer, t5_model, t5_tokenizer):
    if medclip_model is None:
        raise RuntimeError("MedCLIP model is not initialized. Check model loading.")

    # Extract Image Features (not used in this version)
    pixel_values = preprocess_image(image_path, processor)
    if pixel_values is None:
        return "Error: Image processing failed."

    # Simplified prompt for T5
    prompt = "Generate a detailed medical report for a chest X-ray."
    t5_inputs = t5_tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True).to(device)
    output_ids = t5_model.generate(**t5_inputs, max_length=150)
    caption = t5_tokenizer.decode(output_ids[0], skip_special_tokens=True)

    return caption

### Load Models ###
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
medclip_model = load_medclip_model()

### Provide Image Path ###
image_path = "C:/Users/admin/Downloads/chest2.jpeg"  # Replace with your image path
caption = generate_medical_report(image_path, medclip_model, processor, bert_model, bert_tokenizer, t5_model, t5_tokenizer)
print("Generated Medical Report:", caption)

Image Model weights loaded successfully!
Generated Medical Report: a detailed medical report for a chest X-ray. Generate a detailed medical report for a chest X-ray.
