# De-Fake Inference Notebook

This notebook demonstrates the inference process for detecting fake images using a combination of CLIP and BLIP models with a neural network classifier.

## Overview

The de-fake pipeline works as follows:
1. **Image Preprocessing**: Load and preprocess the input image
2. **Caption Generation**: Use the BLIP model to generate a descriptive caption
3. **Feature Extraction**: Extract image and text features using CLIP
4. **Classification**: Use a neural network to classify the image as real or fake
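
The four stages above can be sketched as a chain of plain functions. This is only an illustration of the data flow: `run_pipeline` and the stand-in callables passed to it are hypothetical names, not part of the De-Fake code. The real stages are the BLIP, CLIP, and classifier calls defined later in this notebook.

```python
from typing import Callable, Sequence, Tuple

def run_pipeline(
    image_path: str,
    load_image: Callable[[str], object],
    generate_caption: Callable[[object], str],
    extract_features: Callable[[object, str], Sequence[float]],
    classify: Callable[[Sequence[float]], int],
) -> Tuple[int, str]:
    """Chain the four stages; every callable here is a hypothetical stand-in."""
    image = load_image(image_path)               # 1. image preprocessing
    caption = generate_caption(image)            # 2. caption generation (BLIP in this notebook)
    features = extract_features(image, caption)  # 3. image + text features (CLIP in this notebook)
    label = classify(features)                   # 4. classification: 0 = real, 1 = fake
    return label, caption

# Toy stand-ins so the sketch runs without any model files
label, caption = run_pipeline(
    "example.png",
    load_image=lambda path: {"path": path},
    generate_caption=lambda image: "a placeholder caption",
    extract_features=lambda image, caption: [0.0] * 1024,
    classify=lambda features: 0,
)
print(label, caption)
```

Passing the stages in as arguments keeps the sketch runnable on its own; in the notebook the stages are module-level objects instead.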

## Prerequisites

For Google Colab usage, the setup cell in this notebook will:
1. Install the required packages
2. Clone the De-Fake repository
3. Change into the `De-Fake` directory

Make sure you have the required models downloaded:
- `finetune_clip.pt`: Fine-tuned CLIP model
- `clip_linear.pt`: Neural network classifier

**Note for Colab**: Upload these model files to the Colab environment or place them in the De-Fake directory after cloning.
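
A quick check can confirm that both files are in place before the loading cell runs. The sketch below is an optional convenience, not part of the De-Fake repository; `missing_model_files` is a hypothetical helper, and it assumes the files sit in the current working directory.

```python
from pathlib import Path

# File names taken from the list above
REQUIRED_MODEL_FILES = ("finetune_clip.pt", "clip_linear.pt")

def missing_model_files(directory="."):
    """Return the required model files that are not present in `directory`."""
    directory = Path(directory)
    return [name for name in REQUIRED_MODEL_FILES if not (directory / name).is_file()]

missing = missing_model_files(".")
if missing:
    print("Missing model files:", ", ".join(missing))
else:
    print("All model files found.")
```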

In [None]:
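# Install required packages.
# The PyPI package named `clip` is not OpenAI CLIP; the `import clip` used in this
# notebook comes from the OpenAI repository, which also needs ftfy and regex.
!pip install torch torchvision transformers pillow matplotlib scikit-learn tqdm ftfy regex
!pip install git+https://github.com/openai/CLIP.git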

# Clone the De-Fake repository for BLIP models
!git clone https://github.com/dlii0086/De-Fake.git
%cd De-Fake

In [None]:
import torch
import clip
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
import torch.nn.functional as F
import torch.nn as nn
from torchvision import transforms
from tqdm import tqdm
from blipmodels import blip_decoder

# Set device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

## 1. Load Models

Load the pre-trained CLIP, BLIP, and classifier models.

In [None]:
# Load CLIP model
print("Loading CLIP model...")
model, preprocess = clip.load("ViT-B/32", device=device)

# Load BLIP model for caption generation
print("Loading BLIP model...")
blip_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth'
blip = blip_decoder(pretrained=blip_url, image_size=224, vit='base')
blip.eval()
blip = blip.to(device)

# Define the neural network classifier
class NeuralNet(nn.Module):
    def __init__(self, input_size, hidden_size_list, num_classes):
        super(NeuralNet, self).__init__()
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(input_size, hidden_size_list[0])
        self.fc2 = nn.Linear(hidden_size_list[0], hidden_size_list[1])
        self.fc3 = nn.Linear(hidden_size_list[1], num_classes)

    def forward(self, x):
        out = self.fc1(x)
        out = F.relu(out)
        out = self.dropout2(out)
        out = self.fc2(out)
        out = F.relu(out)
        out = self.fc3(out)
        return out

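# Load the fine-tuned CLIP model and classifier.
# Both files are loaded as whole pickled modules, so NeuralNet must be defined above.
# On PyTorch >= 2.6, torch.load defaults to weights_only=True, which rejects pickled
# modules; weights_only=False restores the old behavior (use it only with trusted files).
print("Loading fine-tuned models...")
try:
    model = torch.load("finetune_clip.pt", map_location=device, weights_only=False)
    linear = torch.load("clip_linear.pt", map_location=device, weights_only=False)
    # Switch both models to inference mode so dropout in the classifier is disabled
    model.eval()
    linear.eval()
    print("Models loaded successfully!")
except FileNotFoundError as e:
    print(f"Model file not found: {e}")
    print("Please make sure finetune_clip.pt and clip_linear.pt are in the current directory")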

## 2. Helper Functions

Define utility functions for image preprocessing and visualization.
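
The helpers in this section prepare the same image in two ways: `preprocess_image` squashes it to a 224×224 square before CLIP's own preprocessing, while the BLIP transform in `predict_image` resizes the shorter side to 224 and then center-crops. The sketch below works out the geometry of the second approach in plain Python. The function name is hypothetical, and the rounding is an assumption that may differ from torchvision by a pixel.

```python
def shorter_side_resize_and_center_crop(width, height, size=224):
    """Return the resized (width, height) and the crop box (left, top, right, bottom)."""
    scale = size / min(width, height)
    # max() guards against float rounding pushing the shorter side just below `size`
    new_w = max(size, int(width * scale))
    new_h = max(size, int(height * scale))
    left = (new_w - size) // 2
    top = (new_h - size) // 2
    return (new_w, new_h), (left, top, left + size, top + size)

resized, box = shorter_side_resize_and_center_crop(640, 480)
print(resized, box)  # (298, 224) (37, 0, 261, 224)
```

For a 640×480 input, the center crop keeps the aspect ratio and discards 37 columns on each side, whereas a direct resize to 224×224 keeps every pixel but compresses the image horizontally.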

In [None]:
def preprocess_image(img_path, image_size=224):
    """Preprocess an image for CLIP model"""
    img = Image.open(img_path)
    img = img.resize((image_size, image_size))
    return preprocess(img)

def display_image(img_path, title="Input Image"):
    """Display an image with matplotlib"""
    img = Image.open(img_path)
    plt.figure(figsize=(8, 8))
    plt.imshow(img)
    plt.title(title)
    plt.axis('off')
    plt.show()

def predict_image(image_path):
    """
    Perform inference on a single image
    Returns: prediction (0=real, 1=fake), confidence scores, and generated caption
    """
    # Display the input image
    display_image(image_path, "Input Image for Analysis")
    
    # Preprocess image for BLIP
    tform = transforms.Compose([
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ])
    
    img = Image.open(image_path).convert('RGB')
    img_tensor = tform(img).unsqueeze(0).to(device)
    
    # Generate caption with BLIP
    print("Generating caption with BLIP...")
    with torch.no_grad():
        caption = blip.generate(img_tensor, sample=False, num_beams=3, max_length=60, min_length=5)
    
    generated_caption = caption[0] if isinstance(caption, list) else caption
    print(f"Generated caption: {generated_caption}")
    
    # Preprocess for CLIP
    image = preprocess_image(image_path).unsqueeze(0).to(device)
    # truncate=True guards against captions longer than CLIP's 77-token context
    text = clip.tokenize([generated_caption], truncate=True).to(device)
    
    # Extract features
    print("Extracting features with CLIP...")
    with torch.no_grad():
        image_features = model.encode_image(image)
        text_features = model.encode_text(text)
        
        # Concatenate features
        emb = torch.cat((image_features, text_features), 1)
        
        # Classify
        output = linear(emb.float())
        probabilities = F.softmax(output, dim=1)
        predict = output.argmax(1)
        
    return predict.cpu().numpy()[0], probabilities.cpu().numpy()[0], generated_caption

## 3. Interactive Inference

Run inference on an image and see the results.
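
The confidence scores reported in this section are a softmax over the classifier's two raw outputs, and the predicted label is the index of the larger one. The sketch below reproduces that step in plain Python; the logit values are made up for illustration and are not real model output.

```python
import math

def softmax(logits):
    """Convert raw classifier outputs into probabilities that sum to 1."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]  # subtract the max for numerical stability
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical logits in the order [real, fake]
logits = [0.3, 2.1]
probs = softmax(logits)
label = "FAKE" if probs.index(max(probs)) == 1 else "REAL"
print(label, [round(p, 4) for p in probs])
```

Because softmax preserves the ordering of the logits, taking the argmax of the probabilities gives the same label as taking the argmax of the raw outputs.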

In [None]:
# Example usage
image_path = "CLIP.png"  # Replace with your image path

# Check if the file exists
if not Path(image_path).exists():
    print(f"Image file {image_path} not found. Please provide a valid image path.")
    print("Available files in current directory:")
    for file in Path('.').glob('*'):
        if file.suffix.lower() in ['.jpg', '.jpeg', '.png', '.bmp', '.tiff']:
            print(f"  - {file.name}")
else:
    # Run inference
    prediction, confidence, caption = predict_image(image_path)
    
    # Display results
    print("\n" + "="*50)
    print("INFERENCE RESULTS")
    print("="*50)
    print(f"Generated Caption: {caption}")
    print(f"Prediction: {'FAKE' if prediction == 1 else 'REAL'}")
    print(f"Confidence Scores: Real={confidence[0]:.4f}, Fake={confidence[1]:.4f}")

## 4. Batch Processing

Process multiple images in a directory.
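
Batch results are often easier to inspect once written to disk. The sketch below shows one way to save them as CSV, assuming the result dictionaries use the keys built by `process_directory` in this section. `save_results_csv` and the output file name are hypothetical, and the example row is a placeholder rather than real model output.

```python
import csv

# Column names match the keys of the result dictionaries built in this section
FIELDNAMES = ["image_path", "prediction", "real_confidence", "fake_confidence", "caption"]

def save_results_csv(results, csv_path):
    """Write a list of result dictionaries to a CSV file with a header row."""
    with open(csv_path, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=FIELDNAMES)
        writer.writeheader()
        writer.writerows(results)

# Hypothetical row for illustration only
example_results = [
    {"image_path": "a.png", "prediction": "REAL", "real_confidence": 0.9,
     "fake_confidence": 0.1, "caption": "placeholder caption"},
]
save_results_csv(example_results, "results.csv")
```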

In [None]:
def process_directory(image_dir):
    """Process all images in a directory"""
    image_dir = Path(image_dir)
    if not image_dir.exists():
        print(f"Directory {image_dir} not found.")
        return []  # keep the return type consistent for callers that iterate over results
    
    image_files = []
    for ext in ['*.jpg', '*.jpeg', '*.png', '*.bmp', '*.tiff']:
        image_files.extend(list(image_dir.glob(ext)))
    
    print(f"Found {len(image_files)} images to process.")
    
    results = []
    for img_path in tqdm(image_files, desc="Processing images"):
        try:
            prediction, confidence, caption = predict_image(str(img_path))
            results.append({
                'image_path': str(img_path),
                'prediction': 'FAKE' if prediction == 1 else 'REAL',
                'real_confidence': float(confidence[0]),
                'fake_confidence': float(confidence[1]),
                'caption': caption
            })
        except Exception as e:
            print(f"Error processing {img_path}: {e}")
    
    return results

# Example batch processing
# results = process_directory("path/to/your/images")

## 5. Model Architecture Details

Understanding the model architecture used for classification.
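
As a rough size check, the parameter count of the classifier follows directly from its layer sizes (1024 inputs, hidden layers of 512 and 256, 2 outputs). The sketch below does the arithmetic in plain Python, counting weights and biases for each fully connected layer; `linear_params` is a hypothetical helper name.

```python
def linear_params(in_features, out_features):
    """Weights plus biases of one fully connected layer."""
    return in_features * out_features + out_features

layer_sizes = [1024, 512, 256, 2]  # input, two hidden layers, output
per_layer = [linear_params(a, b) for a, b in zip(layer_sizes, layer_sizes[1:])]
total = sum(per_layer)
print(per_layer, total)  # [524800, 131328, 514] 656642
```

The first layer accounts for most of the roughly 657k parameters, since it maps the concatenated 1024-dimensional embedding down to 512 units.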

In [None]:
# Display model architecture
print("Neural Network Classifier Architecture:")
print(linear)

print("\nInput size: 1024 (512 image features + 512 text features)")
print("Hidden layers: [512, 256]")
print("Output size: 2 (Real vs Fake)")

## 6. Performance Metrics

Evaluate model performance on test data (if available).
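
To make the three metrics concrete, the sketch below computes them by hand for the binary case, treating fake (label 1) as the positive class. It is meant to show what the numbers measure rather than to replace a library implementation; `binary_metrics`, the sample labels, and the choice to return 0.0 when a denominator is zero are all assumptions made for illustration.

```python
def binary_metrics(labels, predictions, positive=1):
    """Accuracy, precision, and recall with `positive` (1 = fake) as the positive class."""
    pairs = list(zip(labels, predictions))
    tp = sum(1 for y, p in pairs if y == positive and p == positive)  # fake, flagged as fake
    fp = sum(1 for y, p in pairs if y != positive and p == positive)  # real, flagged as fake
    fn = sum(1 for y, p in pairs if y == positive and p != positive)  # fake, missed
    correct = sum(1 for y, p in pairs if y == p)
    accuracy = correct / len(pairs) if pairs else 0.0
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    return accuracy, precision, recall

# Hypothetical labels and predictions (0 = real, 1 = fake)
labels = [0, 1, 1, 0, 1]
predictions = [0, 1, 0, 0, 1]
accuracy, precision, recall = binary_metrics(labels, predictions)
print(accuracy, precision, recall)
```

In this toy example every image flagged as fake really is fake (precision 1.0), but one of the three fakes is missed, so recall is 2/3.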

In [None]:
from sklearn.metrics import accuracy_score, precision_score, recall_score

def evaluate_model(test_images, test_labels):
    """Evaluate model performance on test data (labels: 0=real, 1=fake)"""
    predictions = []
    confidences = []
    
    for img_path in test_images:
        prediction, confidence, _ = predict_image(img_path)
        predictions.append(prediction)
        confidences.append(confidence)
    
    # Calculate metrics
    accuracy = accuracy_score(test_labels, predictions)
    precision = precision_score(test_labels, predictions)
    recall = recall_score(test_labels, predictions)
    
    print(f"Accuracy: {accuracy:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall: {recall:.4f}")
    
    return predictions, confidences

# Example usage (requires test data)
# test_images = ["path1.jpg", "path2.jpg", ...]
# test_labels = [0, 1, ...]  # 0=real, 1=fake
# predictions, confidences = evaluate_model(test_images, test_labels)

## 7. Next Steps & Colab Instructions

### For Google Colab Usage:

1. **Upload this notebook to Google Colab**
2. **Run the setup cell**, which will:
   - Install required packages
   - Clone the De-Fake repository
   - Change to the De-Fake directory
3. **Upload model files**: After the setup cell finishes, upload `finetune_clip.pt` and `clip_linear.pt` to the `De-Fake` directory in the Colab environment
4. **Upload test images**: Place your images in the Colab file system or use Google Drive integration
5. **Run the remaining cells sequentially**

### For Local Usage:

1. **Ensure models are available**: Place `finetune_clip.pt` and `clip_linear.pt` in the current directory
2. **Prepare your images**: Place images in a directory or use individual file paths
3. **Run cells sequentially**: Execute cells in order from top to bottom
4. **Customize for your needs**: Modify the image paths and parameters as needed

This notebook provides a complete inference pipeline for the De-Fake project, replicating the functionality of `test.py` in an interactive Jupyter environment that works both locally and on Google Colab.