# Notebook 3: Combining RoBERTa and BLIP-2
This notebook handles both yes/no question classification with RoBERTa and image description generation with BLIP-2, choosing between the two models based on the type of input query.

## Data Loading and Preprocessing:
- Load and preprocess the dataset containing medical images (same as in Notebooks 1 and 2).
- Load the dataset of yes/no questions paired with images.
## Image Feature Extraction and Text Generation (BLIP-2):
- Use the pre-trained BLIP-2 model to generate image descriptions.
## Question Classification (RoBERTa):
- Use the RoBERTa model for binary classification (yes/no) for questions related to the medical images.
## Switching Between Models:
Based on the type of query, decide whether to use the RoBERTa model (for classification) or BLIP-2 (for image description generation).
## Testing and Evaluation:
Test the combined approach by feeding different queries and evaluating the output.

In [None]:
import os

import torch
from PIL import Image
from torchvision import transforms
from tqdm import tqdm
from transformers import (
    Blip2ForConditionalGeneration,
    Blip2Processor,
    RobertaForSequenceClassification,
    RobertaTokenizer,
    TFRobertaForSequenceClassification,
)

# Part 1: Load Models

# Load BLIP-2 model and processor for image captioning
blip2_model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-flan-t5-xxl")
blip2_processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xxl")

# Load RoBERTa model and tokenizer for yes/no classification.
# The PyTorch class is required here: the model is moved with `.to(device)` and
# is later fed `return_tensors="pt"` inputs, neither of which the TensorFlow
# (`TF...`) class supports.
# NOTE: 'roberta-base' ships without a trained classification head, so the
# 2-label head is freshly initialised; predictions are not meaningful until the
# model has been fine-tuned on yes/no data.
roberta_model = RobertaForSequenceClassification.from_pretrained('roberta-base', num_labels=2)
roberta_tokenizer = RobertaTokenizer.from_pretrained('roberta-base')

# Run RoBERTa on the GPU when one is available, otherwise on the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
roberta_model = roberta_model.to(device)

# Part 2: Image Preprocessing

# Preprocessing for images (for BLIP-2)
# Deliberately no ImageNet `Normalize` step: generate_image_description()
# converts this tensor back to a PIL image with ToPILImage, which expects float
# values in [0, 1]. A mean/std-normalised tensor contains negative and >1
# values and would be turned into a corrupted image. The BLIP-2 processor
# applies its own rescaling and normalisation to the PIL image it receives.
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

def load_images_from_folder(folder_path):
    """Load and preprocess every JPEG/PNG image found directly in a folder.

    Args:
        folder_path: Directory to scan (not recursive).

    Returns:
        A list of ``(filename, image)`` tuples, ordered by file name.
        ``filename`` is the bare file name (not joined with ``folder_path``)
        and ``image`` is the result of applying the module-level ``preprocess``
        transform to the image converted to RGB.

    Raises:
        FileNotFoundError: If ``folder_path`` does not exist (from ``os.listdir``).
    """
    images = []
    # os.listdir returns entries in arbitrary order; sort for reproducible runs.
    for filename in sorted(os.listdir(folder_path)):
        # Case-insensitive match so files such as "SCAN.JPG" are not skipped.
        if not filename.lower().endswith((".jpg", ".jpeg", ".png")):
            continue
        # convert() returns a new image, so the file can be closed right away.
        with Image.open(os.path.join(folder_path, filename)) as img:
            rgb_image = img.convert("RGB")
        images.append((filename, preprocess(rgb_image)))
    return images

# Part 3: Text Generation (BLIP-2)

def generate_image_description(image_tensor):
    """Generate a BLIP-2 text description for an image tensor.

    The tensor is converted back to a PIL image, passed through the BLIP-2
    processor and the model's ``generate`` method, and the first generated
    sequence is decoded with special tokens removed.
    """
    to_pil = transforms.ToPILImage()
    model_inputs = blip2_processor(images=to_pil(image_tensor), return_tensors="pt")
    generated_ids = blip2_model.generate(**model_inputs)
    return blip2_processor.decode(generated_ids[0], skip_special_tokens=True)

# Part 4: Question Classification (RoBERTa)

def classify_question(question):
    """Classify a single question as 'yes' or 'no' with the RoBERTa model.

    Args:
        question: The question text. A single string is expected, because the
            prediction is extracted with ``.item()``, which requires exactly
            one element.

    Returns:
        ``'yes'`` when the highest-scoring label index is 1, otherwise ``'no'``.
    """
    encoded = roberta_tokenizer(question, return_tensors="pt", padding=True, truncation=True)
    # Inputs must live on the same device as the model.
    encoded = {key: val.to(device) for key, val in encoded.items()}
    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        outputs = roberta_model(**encoded)
    prediction = torch.argmax(outputs.logits, dim=1).item()
    return 'yes' if prediction == 1 else 'no'

# Part 5: Query Handling and Evaluation 

def handle_query(image_path, query):
    """Route a query to image description (BLIP-2) or yes/no classification (RoBERTa).

    Queries that start with "describe" (case-insensitive, ignoring leading
    whitespace) produce a generated description of the image at ``image_path``.
    Every other query is passed to ``classify_question``.

    NOTE: the classification branch looks at the question text only; the image
    is not opened or used for yes/no answers.

    Args:
        image_path: Path of the image file to describe.
        query: The user's query text.

    Returns:
        The generated description, or ``'yes'`` / ``'no'``.
    """
    # Strip first so a query like "  Describe ..." is still routed to BLIP-2.
    if not query.strip().lower().startswith("describe"):
        return classify_question(query)

    # convert() returns a new image, so the file can be closed right away.
    with Image.open(image_path) as img:
        image = img.convert("RGB")
    image_tensor = preprocess(image)
    return generate_image_description(image_tensor)

# Evaluation Function
def evaluate_image_and_queries(image_folder, queries):
    """Run every query against every image in a folder and print the responses.

    Args:
        image_folder: Directory containing the images to evaluate.
        queries: Iterable of query strings passed to ``handle_query``.
    """
    images = load_images_from_folder(image_folder)

    for image_name, _ in images:
        print(f"Processing image: {image_name}")
        # load_images_from_folder returns bare file names; handle_query opens
        # the file itself, so it needs the path including the folder.
        image_path = os.path.join(image_folder, image_name)
        for query in queries:
            response = handle_query(image_path, query)
            print(f"Query: {query} -> Response: {response}")

# Part 6: Example Usage

# Example image folder and queries
image_folder = 'path_to_your_images'  # placeholder: replace with a real directory
queries = [
    "Describe this medical image.",
    "Is there a fracture in the image?",
    "Describe the abnormalities in this radiology scan.",
]

# Evaluate and print results. Guarded so that importing this file does not
# start an evaluation run, and so that the unedited placeholder path produces
# a clear message instead of a raw FileNotFoundError from os.listdir.
if __name__ == "__main__":
    if os.path.isdir(image_folder):
        evaluate_image_and_queries(image_folder, queries)
    else:
        print(f"Image folder not found: {image_folder!r}. "
              "Set image_folder to a directory containing .jpg/.png images.")