# Install the required packages

In [None]:
!pip3 install datasets transformers accelerate torchvision matplotlib

# Import Required Libraries

In [None]:
# Import all required packages
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from datasets import DatasetDict
from transformers import AutoTokenizer, BlipForConditionalGeneration, BlipProcessor
import matplotlib.pyplot as plt
from PIL import Image
import io
from torch.utils.data import DataLoader

In [140]:
def analyze_dataset(dataset_dict: dict):
    """Print a basic summary of every split and preview its first sample.

    For each split this prints the split name, the number of samples, the
    feature names and a per-column count of ``None`` entries. If the split
    is non-empty, the first sample's image is shown with its ``label`` as
    the plot title.

    Args:
        dataset_dict: Mapping of split name to a dataset object that
            supports ``len()``, a ``features`` mapping, column access by
            name and row access by integer index.
    """
    for split, dataset in dataset_dict.items():
        print(f"Dataset Split: {split}")
        print(f"Number of Samples: {len(dataset)}")
        print(f"Features: {dataset.features.keys()}")

        # Check for missing values (each column is read in full to count None entries)
        missing_values = {
            col: sum(1 for x in dataset[col] if x is None)
            for col in dataset.features.keys()
        }
        print(f"Missing values: {missing_values}")

        # An empty split has no first sample; skip the preview instead of
        # failing on dataset[0].
        if len(dataset) == 0:
            print("Split is empty; skipping sample preview.")
            continue

        # Show the first sample image with its label as the title
        sample = dataset[0]
        image = sample['image']
        if isinstance(image, Image.Image):
            image = image.convert('RGB')
        plt.imshow(image)
        plt.title(sample['label'])
        plt.axis("off")
        plt.show()

# Load the Dataset

In [None]:
from datasets import load_dataset
# Load the "matching" configuration of the New Yorker caption contest dataset
# into the global `data`, which the following cells index by split name.
data = load_dataset("jmhessel/newyorker_caption_contest", "matching")

# Check the dataset format

In [None]:
# Print the loaded dataset object to inspect its overall structure.
print(data)

In [None]:
# Print the first example of the "train" split to see the raw field values.
print(data["train"][0])

# Analyze the dataset

In [None]:
# Print the per-split summary and preview plot for the loaded dataset.
analyze_dataset(data)

# Evaluation function

In [45]:
from torch.utils.data import DataLoader

def custom_collate_fn(batch):
    """Collate a list of samples into a single batch dictionary.

    Every sample's ``'image'`` is resized to 224x224 and the resized images
    are stacked into one tensor; the ``'label'`` entries are stacked as they
    are. The result uses the same ``'image'`` / ``'label'`` keys.
    """
    # Pull both fields out of every sample before any tensor work happens.
    pictures = [sample['image'] for sample in batch]
    targets = [sample['label'] for sample in batch]

    # One resize transform is reused for the whole batch.
    shrink = transforms.Resize((224, 224))
    stacked_pictures = torch.stack(list(map(shrink, pictures)))

    stacked_targets = torch.stack(targets)
    return {'image': stacked_pictures, 'label': stacked_targets}

# Evaluate model
def evaluate_model(model, dataloader, criterion, device):
    """Return the mean per-batch loss of ``model`` over ``dataloader``.

    The model is put in eval mode and run without gradient tracking. The
    loss is read from ``outputs.loss`` of each forward call.

    Args:
        model: Callable model exposing ``eval()`` whose output has a
            ``loss`` attribute when called with ``pixel_values``,
            ``input_ids`` and ``labels``.
        dataloader: Iterable of batches; each batch is a mapping whose
            ``'image'`` and ``'label'`` entries support ``.to(device)``.
        criterion: Unused; the loss comes from the model output. Kept so
            existing callers do not need to change.
        device: Device the batch tensors are moved to.

    Returns:
        The average of the per-batch losses as a float, or ``nan`` when the
        dataloader yields no batches.
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0
    with torch.no_grad():
        for batch in dataloader:
            images = batch['image'].to(device)
            captions = batch['label'].to(device)
            # Feed the caption tokens as decoder input as well as targets,
            # the same way the training loop in this file calls the model.
            outputs = model(pixel_values=images, input_ids=captions, labels=captions)
            total_loss += outputs.loss.item()
            num_batches += 1

    # Count the batches actually seen so an empty dataloader does not
    # trigger a division by zero.
    if num_batches == 0:
        return float("nan")
    return total_loss / num_batches



# Test Model function

In [None]:
def test_model(dataset_dict):
    """Print the loss of the base BLIP captioning checkpoint on the test split.

    Loads ``Salesforce/blip-image-captioning-base``, wraps
    ``dataset_dict['test']`` in an unshuffled ``DataLoader`` with batch
    size 8, and prints the value returned by ``evaluate_model``.
    """
    use_gpu = torch.cuda.is_available()
    device = torch.device("cuda" if use_gpu else "cpu")

    checkpoint = "Salesforce/blip-image-captioning-base"
    captioner = BlipForConditionalGeneration.from_pretrained(checkpoint)
    captioner = captioner.to(device)

    loader = DataLoader(dataset_dict['test'], batch_size=8, shuffle=False)
    loss_fn = nn.CrossEntropyLoss()
    test_loss = evaluate_model(captioner, loader, loss_fn, device)
    print(f"Test Loss: {test_loss:.4f}")


# The following steps are performed:
##### 1. Load model and processor
##### 2. Preprocessing function
##### 3. Load and preprocess dataset

In [None]:
from transformers import BlipForConditionalGeneration, BlipProcessor
from datasets import load_dataset
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from PIL import Image

# Load model and processor
# Use CUDA when it is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Pretrained BLIP captioning checkpoint, moved onto the selected device.
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to(device)
# Processor from the same checkpoint; preprocess_data uses its tokenizer.
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")

# Image transformations
# Resize to 224x224, convert to a tensor, then normalize each channel.
# NOTE(review): these mean/std values look like the common ImageNet
# statistics rather than the processor's own image settings — confirm this
# is intended for BLIP.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Preprocessing function
def preprocess_data(example):
    """Convert one raw example into a transformed image and caption token ids.

    The image is converted to RGB and passed through the module-level
    ``transform``. The letter stored in ``example["label"]`` ('A', 'B', ...)
    selects one entry of ``example["caption_choices"]``; that caption is
    tokenized to a fixed length of 32 and stored back under ``"label"``.
    The example is modified in place and returned.
    """
    # Grayscale inputs are converted to RGB before the transform runs.
    rgb_image = example["image"].convert("RGB")
    example["image"] = transform(rgb_image)

    # 'A' -> 0, 'B' -> 1, ... picks the matching caption choice.
    choice_index = ord(example["label"]) - ord("A")
    chosen_caption = example["caption_choices"][choice_index]

    encoding = processor.tokenizer(
        chosen_caption,
        padding="max_length",
        truncation=True,
        max_length=32,
        return_tensors="pt",
    )
    token_ids = encoding["input_ids"]
    # Drop the leading batch dimension when the tokenizer returned one.
    if token_ids.dim() > 1:
        token_ids = token_ids.squeeze(0)

    example["label"] = token_ids
    return example


# Load and preprocess dataset
# `data` keeps the raw examples; `processed_data` holds the result of
# running preprocess_data over them via map().
data = load_dataset("jmhessel/newyorker_caption_contest", "matching")
processed_data = data.map(preprocess_data)


In [74]:
import os
# Turn off tokenizer parallelism for this process via the environment.
# NOTE(review): presumably set to silence the Hugging Face tokenizers
# fork/parallelism warning during data loading — confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"


# Train the model on the preprocessed dataset

In [None]:
# Custom collate function for batching
def collate_fn(batch):
    """Collate samples into stacked images and zero-padded label sequences.

    Each item must provide ``"image"`` and ``"label"``. Values that are not
    already tensors are converted with ``torch.tensor`` (labels as
    ``torch.long``). Items that are missing a key or cannot be converted are
    reported and skipped. The input items are left unmodified.

    Args:
        batch: List of mapping-like samples.

    Returns:
        ``{"image": Tensor, "label": Tensor}`` with images stacked along a
        new first dimension and labels right-padded with 0 to a common
        length, or ``{"image": None, "label": None}`` when no item could be
        used.
    """
    images = []
    labels = []
    for item in batch:
        try:
            image = item["image"]
            label = item["label"]
            # Convert into local variables so the caller's item is not
            # overwritten with tensors.
            if not isinstance(image, torch.Tensor):
                image = torch.tensor(image)
            if not isinstance(label, torch.Tensor):
                label = torch.tensor(label, dtype=torch.long)
        except (KeyError, TypeError, ValueError, RuntimeError) as e:
            # Skip only this item; the rest of the batch is still usable.
            print(f"Error processing item: {e}")
            continue
        images.append(image)
        labels.append(label)

    if not images:
        # No usable items: the training loop checks for None and skips.
        return {"image": None, "label": None}

    return {
        "image": torch.stack(images),
        "label": torch.nn.utils.rnn.pad_sequence(
            labels, batch_first=True, padding_value=0
        ),
    }



# Dataloader
# Shuffled batches of 8 preprocessed training examples, assembled by collate_fn.
train_dataloader = DataLoader(
    processed_data["train"], batch_size=8, shuffle=True, collate_fn=collate_fn
)
print("Dataloader created!")

# Training loop
# AdamW over all model parameters with a fixed learning rate of 5e-5.
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
print("Starting training loop...")

# Three passes over the training data.
for epoch in range(3):
    print(f"Epoch {epoch + 1}")
    model.train()
    # Running sum of the per-batch losses for this epoch.
    total_loss = 0
    for batch in train_dataloader:
        # NOTE: `batch` is the dict returned by collate_fn, so len(batch) is
        # its number of keys, not the number of samples in the batch.
        print("batch -----------> ", len(batch))
        if batch["image"] is None:
            print("Skipping batch due to missing data")
            continue  # Skip batch if no valid items
        images = batch["image"].to(device)
        captions = batch["label"].to(device)

        optimizer.zero_grad()
        # The caption token ids are passed both as decoder input and as targets.
        outputs = model(pixel_values=images, input_ids=captions, labels=captions)
        loss = outputs.loss
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # NOTE: this prints the summed loss over all batches, not a per-batch mean.
    print(f"Epoch {epoch + 1} - Loss: {total_loss:.4f}")
    
# Persist only the weights (state_dict); later cells reload this file.
torch.save(model.state_dict(), "blip_captioning_model.pth")
print("Model saved to blip_captioning_model.pth")

print("Training complete!")

In [None]:
import torch
from transformers import BlipForConditionalGeneration, BlipProcessor
from PIL import Image

# Set device
# Uses Apple's MPS backend when available, otherwise the CPU. Note that the
# training cell selects CUDA instead.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

# Load the model architecture
# Start from the base checkpoint, then overwrite its weights with the
# state_dict saved by the training cell.
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
model.load_state_dict(torch.load("blip_captioning_model.pth", map_location=device))
model.to(device)
# Switch to inference mode.
model.eval()

# Load the processor
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")

print("Model loaded successfully for evaluation!")


In [None]:
# Load a sample image from the test dataset
# Reads from the raw `data` (not `processed_data`), so the image has not been
# through `transform`; presumably a PIL image, given the .show() call.
sample_image = data["test"][0]["image"]  # Assuming data is still loaded
sample_image.show()  # Display the image (optional)


In [None]:
def evaluate_model(test_data):
    """Caption every test sample, score the captions and save them to CSV.

    For each sample the ``"image"`` entry is passed to ``generate_caption``
    and the ``"label"`` entry is used as the reference. Successful captions
    are scored with ``compute_metrics``. Samples whose caption generation
    raises are reported, written to the CSV with ``"error"`` as the
    generated caption, and left out of the metrics so a failure is never
    counted as a match.

    NOTE: this rebinds the name ``evaluate_model``, which earlier in this
    file refers to the four-argument loss evaluator that ``test_model``
    calls. ``generate_caption`` and ``compute_metrics`` are not defined in
    the visible file and must be provided before this function is called.

    Args:
        test_data: Iterable of mapping-like samples with ``"image"`` and
            ``"label"`` entries.
    """
    # Local import: pandas is not imported anywhere else in the visible file.
    import pandas as pd

    results = []
    predictions = []
    references = []

    for sample in test_data:
        image_path = sample["image"]
        ground_truth = sample["label"]
        try:
            generated_caption = generate_caption(image_path)
        except Exception as error:
            # Record the failure without retrying the call that just failed.
            print(f"Caption generation failed: {error}")
            results.append([image_path, ground_truth, "error"])
            continue
        # Store results
        predictions.append(generated_caption)
        references.append(ground_truth)
        results.append([image_path, ground_truth, generated_caption])

    # Compute automatic scores
    scores = compute_metrics(predictions, references)
    print("Evaluation Metrics:", scores)

    # Save results
    df = pd.DataFrame(results, columns=["Image Path", "Ground Truth", "Generated Caption"])
    df.to_csv("caption_results.csv", index=False)
    print("Evaluation completed. Results saved in caption_results.csv")

# Evaluate on the test split of `processed_data`, the mapped dataset created
# by data.map(preprocess_data) earlier in this file.
test_data = processed_data["test"]
evaluate_model(test_data)


In [None]:
# Map each uppercase letter to its zero-based position in the alphabet
# ("A" -> 0, "B" -> 1, ..., "Z" -> 25).
label = {chr(ord("A") + position): position for position in range(26)}
# Process the image for the model
# The processor builds the model inputs for the sample image as PyTorch
# tensors, which are then moved to the selected device.
inputs = processor(images=sample_image, return_tensors="pt").to(device)

# Generate a caption
# Gradients are not needed for generation; max_length=50 is the limit passed to generate().
with torch.no_grad():
    generated_ids = model.generate(**inputs, max_length=50)
    generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

# Print results
print("Generated Caption:", generated_caption)
print("Ground Truth Caption:", data["test"][0]["label"])  # NOTE: "label" holds the answer letter used to index caption_choices below, not the caption text
print("Image Location:", data["test"][0]["image_location"])  # Assuming image location is stored
print("Image Description:", data["test"][0]["image_description"])  # Assuming image description is stored
print("Image Uncanny Description:", data["test"][0]["image_uncanny_description"])  # Assuming image uncanny description
print("Caption Choices selected:", data["test"][0]["caption_choices"][label[data["test"][0]["label"]]])  # Caption text selected by the label letter via the `label` mapping
print("Caption Choices:", data["test"][0]["caption_choices"])  # Assuming caption choices are stored

# Convert the saved weights into a Transformers `save_pretrained` directory

In [None]:
import torch
from transformers import BlipProcessor, BlipForConditionalGeneration
from huggingface_hub import notebook_login

# Authenticate with Hugging Face
# notebook_login()

# Define model directory and Hugging Face repo
# model_dir is the local output directory; hf_repo is only referenced by the
# commented-out push calls.
model_dir = "blip_caption_model"
hf_repo = "Nishthaaa/image_captioning"

# Load processor (update to match your training processor)
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")

# Load your `.pth` model
# Rebuild the base architecture, then overwrite its weights with the
# state_dict saved by the training cell (loaded onto the CPU).
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")  # Base model
model.load_state_dict(torch.load("blip_captioning_model.pth", map_location="cpu"))  # Load trained weights

# Save processor and model
# Both are written into model_dir.
processor.save_pretrained(model_dir)
model.save_pretrained(model_dir)

# # Push to Hugging Face Model Hub
# processor.push_to_hub(hf_repo)
# model.push_to_hub(hf_repo)


# How to push to Huggingface

In [None]:
# On your terminal do the following steps:
# 1. Log in to Hugging Face with your access token
# 2. Create the space on Hugging Face
# 3. Push the model to Hugging Face via the code below
# 4. Check the result via the Streamlit app that is created
import os

from huggingface_hub import Repository, HfApi
import torch

# Set your Hugging Face repo name
hf_repo_name = "Nishthaaa/image_captioning"
# Clone the repo locally
# NOTE(review): the conversion cell writes into this same directory with
# save_pretrained; cloning into an existing non-empty, non-git directory may
# fail — confirm, or clone before exporting.
repo = Repository(local_dir="blip_caption_model", clone_from=f"https://huggingface.co/{hf_repo_name}")
# Save the model
# Write the weights to a file inside the clone: the clone directory itself
# is not a valid target for torch.save.
model_path = os.path.join("blip_caption_model", "blip_captioning_model.pth")
torch.save(model.state_dict(), model_path)
# Push to Hugging Face
repo.push_to_hub(commit_message="Upload trained BLIP captioning model")
print("Model pushed successfully!")
