In [None]:
import torch
import torch.nn.functional as F
from torch.optim import AdamW
from torch.utils.data import DataLoader, Subset, random_split
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import numpy as np
from PIL import Image
from tqdm import tqdm
import matplotlib.pyplot as plt
import os

In [None]:
# Define constants: training hyperparameters and the compute device.
# NOTE: the model/training cell further down redefines all of these (e.g.
# MAX_LENGTH=16, NUM_EPOCHS=5), so once it has run its values take precedence.
BATCH_SIZE = 8
NUM_CLASSES = 10        # STL-10 has 10 classes
NUM_EPOCHS = 3
MAX_LENGTH = 32
LEARNING_RATE = 2e-5
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
import numpy as np
import torch
from torch.utils.data import DataLoader, Subset, random_split
import torchvision.transforms as transforms
from torchvision import datasets

# 📌 Define Image Transformations
# No Normalize step here, so tensors stay in [0, 1] and can be shown with imshow directly.
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # 224x224 input resolution used by the models below
    transforms.ToTensor(),
])

# 📌 Load STL-10 Dataset
# NOTE: np.random.choice / random_split are unseeded, so the subset differs per run.
full_dataset = datasets.STL10(root="./data", split="train", transform=transform, download=True)
subset_size = int(0.2 * len(full_dataset))  # Small subset for testing
indices = np.random.choice(len(full_dataset), subset_size, replace=False)
subset_dataset = Subset(full_dataset, indices)

# 📌 Split into Train (80%) & Test (20%)
train_size = int(0.8 * subset_size)
test_size = subset_size - train_size
train_dataset, test_dataset = random_split(subset_dataset, [train_size, test_size])

# 📌 Create DataLoaders
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

# STL-10 class names in label-index order
class_names = [
    "airplane", "bird", "car", "cat", "deer",
    "dog", "horse", "monkey", "ship", "truck"
]

# 📌 Check DataLoader Output
# Draw the preview on its own figure and show it immediately. Previously it was
# drawn on the implicit current figure, which the saving loop below reused and
# closed, so the preview was never displayed.
images, labels = next(iter(train_dataloader))
print(f"Batch shape: {images.shape}, Labels: {labels}")
preview_fig, preview_ax = plt.subplots()
preview_ax.imshow(images[0].permute(1, 2, 0))
preview_ax.set_title(class_names[labels[0].item()])
plt.show()

# 📌 Select a Sample Batch for Inference
sample_images, sample_labels = next(iter(test_dataloader))  # Get a batch of test images

os.makedirs("sample_batch", exist_ok=True)

# 📌 Save each test image with its class name as title, one fresh figure per
# image, closed right after saving to avoid accumulating open figures.
for i, (image, label) in enumerate(zip(sample_images, sample_labels)):
    class_name = class_names[label.item()]
    fig, ax = plt.subplots()
    ax.imshow(image.permute(1, 2, 0).numpy())  # (C, H, W) -> (H, W, C)
    ax.set_title(class_name)
    fig.savefig(f"sample_batch/sample_{i}.png")
    plt.close(fig)


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Subset, random_split
import torchvision.transforms as transforms
from torchvision import datasets
from transformers import ViTFeatureExtractor, ViTModel, BertTokenizer, BertModel
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import os
import datetime
from PIL import Image
import torchvision.utils as vutils

# Define constants
# These override the values from the earlier constants cell (notably
# MAX_LENGTH=16 and NUM_EPOCHS=5) for everything defined in this cell.
NUM_CLASSES = 10
BATCH_SIZE = 8
NUM_EPOCHS = 5
MAX_LENGTH = 16     # caption length, including the leading [CLS] token
LEARNING_RATE = 2e-5
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Class names for STL-10, in label-index order
class_names = "airplane bird car cat deer dog horse monkey ship truck".split()

# Output directories: per-epoch results for the fixed monitoring batch, and
# periodic captioned snapshots taken during training.
os.makedirs("fixed_samples_results", exist_ok=True)
os.makedirs("training_captions", exist_ok=True)

# Image transformations: resize to 224x224, convert to a [0, 1] tensor, then
# normalize per channel with ImageNet mean/std (the plotting code in this cell
# multiplies/adds these same constants back for display).
# NOTE(review): the google/vit-base-patch16-224 image processor is believed to
# use mean = std = 0.5 rather than ImageNet statistics — confirm; if this is
# changed, the hard-coded denormalization constants must change with it.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)

# Define the model
class ImageTextClassifier(nn.Module):
    """Classify images by first generating a caption and then classifying that caption.

    Pipeline (all inside ``forward``):
      1. A pretrained ViT encodes the image; its CLS vector is linearly projected
         and used as the cross-attention memory of a small Transformer decoder.
      2. Starting from ``[CLS]``, the decoder autoregressively emits
         ``max_length - 1`` tokens over the BERT tokenizer's vocabulary. In
         training mode each step uses Gumbel-softmax (``hard=False``) so gradients
         flow through the token choice; in eval mode the greedy argmax is taken
         and one-hot encoded.
      3. The (soft or one-hot) tokens are mapped through BERT's word-embedding
         matrix, fed to BERT via ``inputs_embeds``, and BERT's pooled output goes
         to an MLP classification head.

    Only the classification cross-entropy is used as a loss; the captions get no
    direct supervision and are a by-product of the classification objective.
    """
    def __init__(self, num_classes: int, temperature: float = 1.0, max_length: int = MAX_LENGTH):
        """Build the vision encoder, caption decoder, BERT encoder and classifier.

        Args:
            num_classes: Number of outputs of the classification head.
            temperature: Gumbel-softmax ``tau`` used in training mode.
            max_length: Caption length including the leading ``[CLS]`` token;
                ``max_length - 1`` tokens are generated. The default is the
                module-level ``MAX_LENGTH`` at the time this class was defined.
        """
        super().__init__()
        
        # Vision encoder (ViT), pretrained weights via from_pretrained.
        # NOTE(review): the nn layers below are randomly initialised in creation
        # order, so reordering these statements likely changes the initial weights.
        self.vision_encoder = ViTModel.from_pretrained('google/vit-base-patch16-224')
        # NOTE(review): never used in this class (forward() passes `images` to the
        # encoder directly); ViTFeatureExtractor is also deprecated upstream in
        # favour of ViTImageProcessor. Candidate for removal.
        self.vision_feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224')
        vision_dim = self.vision_encoder.config.hidden_size
        
        # Text tokenizer and generation. The decoder predicts ids in BERT's
        # vocabulary so its outputs can be fed straight into the BERT encoder.
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        self.vocab_size = self.tokenizer.vocab_size
        self.max_length = max_length
        
        # Vision to text projection (vision_dim -> vision_dim) of the ViT CLS vector
        self.vision_to_text_proj = nn.Linear(vision_dim, vision_dim)
        
        # Text generation module: 2-layer Transformer decoder with d_model = vision_dim
        # (vision_dim must be divisible by nhead=8). batch_first=True -> (B, seq, D).
        self.text_gen_transformer = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(d_model=vision_dim, nhead=8, batch_first=True), 
            num_layers=2
        )
        # Maps decoder hidden states to logits over the tokenizer vocabulary
        self.text_output_layer = nn.Linear(vision_dim, self.vocab_size)
        
        # Text embeddings for decoder (separate from BERT's own word embeddings)
        self.text_embeddings = nn.Embedding(self.vocab_size, vision_dim)
        
        # Text encoder (BERT) that reads the generated (soft) tokens
        self.text_encoder = BertModel.from_pretrained('bert-base-uncased')
        text_dim = self.text_encoder.config.hidden_size
        
        # Classification head on BERT's pooled output
        self.classifier = nn.Sequential(
            nn.Linear(text_dim, text_dim // 2),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(text_dim // 2, num_classes)
        )
        
        # Gumbel softmax temperature (tau); only used when self.training is True
        self.temperature = temperature
        
    def forward(self, images: torch.Tensor, labels: "torch.Tensor | None" = None) -> dict:
        """Generate a caption for each image and classify it.

        Args:
            images: Image batch passed directly to the ViT encoder; presumably
                (B, 3, 224, 224) normalized tensors — TODO confirm.
            labels: Optional class indices of shape (B,). When given, a
                cross-entropy loss is added to the result under ``"loss"``.

        Returns:
            dict with:
              ``"logits"``: classification logits, (B, num_classes).
              ``"text_logits"``: per-step decoder logits, (B, max_length - 1, vocab_size).
              ``"generated_text"``: list of B decoded captions (special tokens skipped).
              ``"loss"``: scalar classification loss, only if ``labels`` is not None.
        """
        batch_size = images.size(0)
        
        # Vision encoding
        vision_outputs = self.vision_encoder(images).last_hidden_state
        vision_features = vision_outputs[:, 0, :]  # CLS token
        
        # Project vision features for text generation -> (B, 1, D)
        vision_proj = self.vision_to_text_proj(vision_features).unsqueeze(1)
        
        # Text generation with differentiable sampling
        
        # Start with batch of [CLS] tokens: (B, 1) ids -> (B, 1, D) decoder embeddings
        start_tokens = torch.full((batch_size, 1), self.tokenizer.cls_token_id, 
                                  dtype=torch.long, device=images.device)
        token_embeds = self.text_embeddings(start_tokens)
        
        # Storage for text logits and soft token representations (one entry per step)
        all_text_logits = []
        all_soft_tokens = []
        
        # Autoregressive generation: max_length - 1 steps, one token appended per step.
        # The full prefix is re-decoded every step (no key/value caching).
        for step in range(self.max_length - 1):
            # Causal float mask: 0.0 on/below the diagonal, -inf above it, so
            # position i only attends to positions <= i.
            seq_len = token_embeds.size(1)
            tgt_mask = (torch.triu(torch.ones(seq_len, seq_len, device=images.device)) == 1).transpose(0, 1)
            tgt_mask = tgt_mask.float().masked_fill(tgt_mask == 0, float('-inf')).masked_fill(tgt_mask == 1, float(0.0))
            
            # Decode next token. The single projected image vector is tiled
            # seq_len times as cross-attention memory, so all memory slots are identical.
            tgt_output = self.text_gen_transformer(
                token_embeds, 
                vision_proj.repeat(1, seq_len, 1),
                tgt_mask=tgt_mask
            )
            
            # Logits for the last position only: (B, 1, vocab_size)
            next_token_logits = self.text_output_layer(tgt_output[:, -1:, :])
            all_text_logits.append(next_token_logits)
            
            # Training: Gumbel softmax (soft, hard=False) for differentiable sampling.
            # Eval: greedy argmax encoded as a one-hot vector.
            if self.training:
                soft_tokens = F.gumbel_softmax(next_token_logits, tau=self.temperature, hard=False, dim=-1)
            else:
                indices = torch.argmax(next_token_logits, dim=-1)
                soft_tokens = F.one_hot(indices, num_classes=self.vocab_size).float()
            
            all_soft_tokens.append(soft_tokens)
            
            # Convert soft tokens to embeddings: weighted mix of decoder embedding
            # rows, (B, 1, V) @ (V, D) -> (B, 1, D)
            next_token_embeds = torch.matmul(soft_tokens, self.text_embeddings.weight)
            
            # Append to sequence
            token_embeds = torch.cat([token_embeds, next_token_embeds], dim=1)
        
        # Concatenate all steps: each becomes (B, max_length - 1, V)
        text_logits = torch.cat(all_text_logits, dim=1)
        soft_tokens = torch.cat(all_soft_tokens, dim=1)
        
        # Convert soft tokens to text (always, for sampling during training).
        # Generation has a fixed length and does not stop at [SEP].
        token_indices = torch.argmax(soft_tokens, dim=-1)
        full_tokens = torch.cat([start_tokens, token_indices], dim=1)
        generated_text = [self.tokenizer.decode(tokens, skip_special_tokens=True) for tokens in full_tokens]
        
        # Convert to token IDs for BERT. token_ids has the same values as
        # full_tokens and is only used to build the all-ones attention mask.
        token_ids = torch.cat([start_tokens, torch.argmax(soft_tokens, dim=-1)], dim=1)
        attention_mask = torch.ones_like(token_ids)
        
        # Create soft embeddings: BERT's [CLS] word embedding followed by soft-token
        # mixtures of BERT's word-embedding rows. Assumes the tokenizer vocabulary
        # size matches BERT's embedding rows (both load 'bert-base-uncased').
        bert_inputs = torch.cat([
            self.text_encoder.embeddings.word_embeddings(start_tokens),
            torch.matmul(soft_tokens, self.text_encoder.embeddings.word_embeddings.weight)
        ], dim=1)
        
        # Forward through BERT using inputs_embeds instead of token ids
        text_outputs = self.text_encoder(
            inputs_embeds=bert_inputs,
            attention_mask=attention_mask,
            return_dict=True
        )
        
        # Classification on BERT's pooled output
        text_features = text_outputs.pooler_output
        logits = self.classifier(text_features)
        
        result = {
            "logits": logits,
            "text_logits": text_logits,
            "generated_text": generated_text,
        }
        
        # Only the classification loss is computed; text_logits are not supervised.
        if labels is not None:
            loss_fn = nn.CrossEntropyLoss()
            classification_loss = loss_fn(logits, labels)
            result["loss"] = classification_loss
        
        return result

# Helper function to denormalize image and save with caption
def save_image_with_caption(image, caption, true_label, pred_label, filename):
    """Save ``image`` to ``filename`` with its caption and labels as the title.

    ``image`` is a channel-first (3, H, W) tensor normalized with ImageNet
    mean/std. A detached CPU copy is un-normalized and clamped to [0, 1] before
    plotting, so the input may live on any device.
    """
    channel_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    channel_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
    pixels = torch.clamp(image.clone().cpu().detach() * channel_std + channel_mean, 0, 1)

    fig, ax = plt.subplots(figsize=(5, 5))
    ax.imshow(pixels.permute(1, 2, 0).numpy())
    ax.axis('off')
    ax.set_title(f"Caption: {caption}\nTrue: {true_label}, Pred: {pred_label}", fontsize=10)
    fig.tight_layout()
    fig.savefig(filename)
    plt.close(fig)

# Function to evaluate and save results for fixed sample batch
def _denormalize_for_display(image):
    """Undo the ImageNet Normalize transform on a (3, H, W) tensor; CPU, clamped to [0, 1]."""
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
    return torch.clamp(image.detach().cpu() * std + mean, 0, 1)


def evaluate_fixed_samples(model, fixed_images, fixed_labels, epoch, device=DEVICE):
    """Run the model on a fixed monitoring batch and save captioned visualizations.

    For each sample an annotated PNG is written to
    ``fixed_samples_results/epoch{epoch+1}_sample{i}.png``. All samples are also
    drawn into one grid saved as ``fixed_samples_results/epoch{epoch+1}_grid.png``.
    The grid is 2x4 for 8 samples, rows of 4 for more, and a single row for fewer.

    Args:
        model: Module whose forward returns a dict with ``"logits"`` and
            ``"generated_text"``; it is switched to eval mode.
        fixed_images: Normalized image batch, channel-first per sample.
        fixed_labels: Class indices, one per image.
        epoch: Zero-based epoch index, only used in output file names.
        device: Device the batch is moved to for the forward pass.

    Returns:
        ``(accuracy, captions)``: accuracy as a fraction in [0, 1] and the list
        of generated captions.

    Raises:
        ValueError: If ``fixed_images`` is empty.
    """
    n_samples = len(fixed_images)
    if n_samples == 0:
        raise ValueError("evaluate_fixed_samples() needs at least one sample")

    model.eval()
    fixed_images = fixed_images.to(device)
    fixed_labels = fixed_labels.to(device)

    with torch.no_grad():
        outputs = model(fixed_images)
    _, predicted = outputs["logits"].max(1)

    # Same layouts as before for <= 8 samples (1 x n, or 2 x 4 for n == 8);
    # larger batches get ceil(n / 4) rows so every sample has an axis.
    if n_samples >= 8:
        n_cols = 4
        n_rows = (n_samples + n_cols - 1) // n_cols
    else:
        n_cols = n_samples
        n_rows = 1
    # squeeze=False keeps a 2-D array even for a single sample
    # (plt.subplots(1, 1) would otherwise return a bare Axes with no .flatten()).
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 5 * n_rows), squeeze=False)
    axes = axes.flatten()

    for i in range(n_samples):
        caption = outputs["generated_text"][i]
        true_idx = fixed_labels[i].item()
        pred_idx = predicted[i].item()
        true_class = class_names[true_idx]
        pred_class = class_names[pred_idx]
        status = "✓" if pred_idx == true_idx else "✗"

        # Individual image
        save_image_with_caption(
            fixed_images[i],
            caption,
            true_class,
            pred_class,
            f"fixed_samples_results/epoch{epoch+1}_sample{i}.png"
        )

        # Grid cell
        ax = axes[i]
        ax.imshow(_denormalize_for_display(fixed_images[i]).permute(1, 2, 0).numpy())
        ax.set_title(f"Caption: {caption}\nTrue: {true_class}\nPred: {pred_class} {status}", fontsize=9)
        ax.axis('off')

    # Blank out grid cells left over when n_samples is not a multiple of n_cols.
    for ax in axes[n_samples:]:
        ax.axis('off')

    # Use the explicit handle: save_image_with_caption() opens/closes its own
    # figures inside the loop, so pyplot's "current figure" is not reliable here.
    fig.tight_layout()
    fig.savefig(f"fixed_samples_results/epoch{epoch+1}_grid.png")
    plt.close(fig)

    accuracy = (predicted == fixed_labels).sum().item() / n_samples
    return accuracy, outputs["generated_text"]

# Training function
def _validate_epoch(model, val_loader, device):
    """Evaluate ``model`` on ``val_loader`` without gradients.

    Returns:
        ``(mean_loss, accuracy)``: mean of the per-batch losses and accuracy in percent.
    """
    model.eval()
    val_loss = 0.0
    val_correct = 0
    val_total = 0

    val_progress_bar = tqdm(val_loader, desc="Validation")
    with torch.no_grad():
        for batch_idx, (images, labels) in enumerate(val_progress_bar):
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images, labels)
            val_loss += outputs["loss"].item()
            _, predicted = outputs["logits"].max(1)
            val_total += labels.size(0)
            val_correct += predicted.eq(labels).sum().item()

            val_progress_bar.set_postfix({
                'val_loss': f'{val_loss / (batch_idx + 1):.4f}',
                'val_acc': f'{100 * val_correct / val_total:.1f}%'
            })

    return val_loss / len(val_loader), 100 * val_correct / val_total


def _plot_caption_evolution(caption_evolution, fixed_labels, max_samples=8):
    """Plot the per-epoch captions of up to ``max_samples`` fixed samples, one row each."""
    n_show = min(len(caption_evolution), max_samples)
    fig = plt.figure(figsize=(15, max(n_show, 1) * 2))
    for i in range(n_show):
        ax = fig.add_subplot(n_show, 1, i + 1)
        true_class = class_names[fixed_labels[i].item()]
        ax.set_title(f"Sample {i+1} ({true_class}) Caption Evolution", fontsize=10)
        ax.axis('off')
        captions = caption_evolution[i]
        # 0.2 spacing (as before) up to 5 epochs; tighter beyond that so every
        # line stays inside the subplot's default [0, 1] data range.
        step = min(0.2, 1.0 / max(len(captions), 1))
        for epoch, caption in enumerate(captions):
            ax.text(0, 1 - epoch * step, f"Epoch {epoch+1}: {caption}", fontsize=9)
    fig.tight_layout()
    fig.savefig("fixed_samples_results/caption_evolution.png")
    plt.close(fig)


def train_model(model, train_loader, val_loader, fixed_samples, optimizer, scheduler, num_epochs=NUM_EPOCHS, device=DEVICE):
    """Train ``model`` while monitoring validation data and a fixed sample batch.

    Per epoch:
      * one pass over ``train_loader``, saving a captioned snapshot of every
        10th batch to ``training_captions/``;
      * ``evaluate_fixed_samples`` on the fixed batch;
      * a validation pass;
      * ``scheduler.step(val_loss)``;
      * a checkpoint to ``best_model.pth`` whenever validation accuracy improves.
        The first epoch always writes one.

    After training, the captions of up to 8 fixed samples across epochs are
    plotted to ``fixed_samples_results/caption_evolution.png``.

    Args:
        model: Module whose forward returns a dict with ``"loss"`` (when labels
            are given), ``"logits"`` and ``"generated_text"``.
        train_loader: DataLoader yielding ``(images, labels)`` for training.
        val_loader: DataLoader yielding ``(images, labels)`` for validation.
        fixed_samples: ``(images, labels)`` batch re-evaluated every epoch.
        optimizer: Optimizer over the model's parameters.
        scheduler: Scheduler stepped with the mean validation loss
            (e.g. ``ReduceLROnPlateau``).
        num_epochs: Number of training epochs.
        device: Device for the model and batches.

    Returns:
        The same ``model`` object with its last-epoch weights (the best
        weights are in ``best_model.pth``).
    """
    model.to(device)
    # Start below any achievable accuracy so the first epoch always saves a
    # checkpoint; main() loads best_model.pth unconditionally.
    best_val_acc = float("-inf")

    fixed_images, fixed_labels = fixed_samples
    # Captions of each fixed sample, one entry per epoch
    caption_evolution = {i: [] for i in range(len(fixed_images))}

    for epoch in range(num_epochs):
        # Training
        model.train()
        train_correct = 0
        train_total = 0
        running_loss = 0.0

        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")
        for batch_idx, (images, labels) in enumerate(progress_bar):
            images = images.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(images, labels)
            loss = outputs["loss"]
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            avg_loss = running_loss / (batch_idx + 1)

            _, predicted = outputs["logits"].max(1)
            train_total += labels.size(0)
            batch_correct = predicted.eq(labels).sum().item()
            train_correct += batch_correct
            batch_acc = 100 * batch_correct / labels.size(0)

            progress_bar.set_postfix({
                'loss': f'{avg_loss:.4f}',
                'batch_acc': f'{batch_acc:.1f}%',
                'avg_acc': f'{100 * train_correct / train_total:.1f}%'
            })

            # Save the first sample of every 10th batch (to limit output volume)
            if batch_idx % 10 == 0:
                sample_idx = 0
                save_image_with_caption(
                    images[sample_idx],
                    outputs["generated_text"][sample_idx],
                    class_names[labels[sample_idx].item()],
                    class_names[predicted[sample_idx].item()],
                    f"training_captions/epoch{epoch+1}_batch{batch_idx}.png"
                )

        train_acc = 100 * train_correct / train_total
        train_loss = running_loss / len(train_loader)

        # Fixed-sample monitoring
        print(f"\n--- Evaluating fixed samples after epoch {epoch+1} ---")
        fixed_acc, fixed_captions = evaluate_fixed_samples(model, fixed_images, fixed_labels, epoch, device)
        for i, caption in enumerate(fixed_captions):
            caption_evolution[i].append(caption)

        # Validation
        val_loss, val_acc = _validate_epoch(model, val_loader, device)

        print(f"\nEpoch {epoch+1}/{num_epochs} - "
              f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%, "
              f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%, "
              f"Fixed Samples Acc: {fixed_acc*100:.2f}%")

        print("\nFixed Samples Generated Text:")
        for i in range(min(len(fixed_images), 8)):
            true_class = class_names[fixed_labels[i].item()]
            print(f"Sample {i+1} ({true_class}) Caption: {fixed_captions[i]}")

        # Report LR changes explicitly rather than relying on the scheduler's
        # deprecated `verbose` flag.
        lrs_before = [group["lr"] for group in optimizer.param_groups]
        scheduler.step(val_loss)
        lrs_after = [group["lr"] for group in optimizer.param_groups]
        if lrs_after != lrs_before:
            print(f"Learning rate changed: {lrs_before} -> {lrs_after}")

        # Save best model
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), "best_model.pth")
            print(f"Model saved! Best validation accuracy: {best_val_acc:.2f}%")

    _plot_caption_evolution(caption_evolution, fixed_labels)

    return model

# Main training script
def main():
    """End-to-end run: data, fixed monitoring batch, training, final report.

    Steps:
      * Loads a random 20% subset of the STL-10 train split and splits it 80/20
        into train/test loaders.
      * Saves the first test batch as the fixed monitoring set.
      * Trains ``ImageTextClassifier`` with AdamW + ReduceLROnPlateau.
      * Reloads ``best_model.pth`` and writes a final comparison grid to
        ``fixed_samples_results/final_comparison.png``.
    """
    def denormalize(image):
        """Undo the ImageNet Normalize transform on a (3, H, W) tensor; CPU, clamped to [0, 1]."""
        mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
        std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
        return torch.clamp(image.detach().cpu() * std + mean, 0, 1)

    print(f"Using device: {DEVICE}")

    # Load STL-10 and keep a random 20% subset for quick experiments.
    # NOTE: np.random.choice / random_split are unseeded, so the subset, the
    # split and hence the fixed monitoring batch differ between runs.
    full_dataset = datasets.STL10(root="./data", split="train", transform=transform, download=True)
    subset_size = int(0.2 * len(full_dataset))
    indices = np.random.choice(len(full_dataset), subset_size, replace=False)
    subset_dataset = Subset(full_dataset, indices)

    # Split into Train (80%) & Test (20%)
    train_size = int(0.8 * subset_size)
    test_size = subset_size - train_size
    train_dataset, test_dataset = random_split(subset_dataset, [train_size, test_size])

    train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
    test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

    print(f"Training samples: {len(train_dataset)}")
    print(f"Testing samples: {len(test_dataset)}")

    # Fixed monitoring batch: the first batch of the (unshuffled) test loader
    sample_images, sample_labels = next(iter(test_dataloader))

    # Save the original fixed samples for reference
    os.makedirs("fixed_samples_original", exist_ok=True)
    for i in range(len(sample_images)):
        fig, ax = plt.subplots(figsize=(5, 5))
        ax.imshow(denormalize(sample_images[i]).permute(1, 2, 0).numpy())
        ax.axis('off')
        ax.set_title(f"Class: {class_names[sample_labels[i].item()]}")
        fig.tight_layout()
        fig.savefig(f"fixed_samples_original/sample_{i}.png")
        plt.close(fig)

    model = ImageTextClassifier(num_classes=NUM_CLASSES, temperature=1.0)

    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=0.01)
    # `verbose` is intentionally not passed: it is deprecated for PyTorch LR
    # schedulers (since 2.2) and only emits a deprecation warning.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=2
    )

    # train_model returns the same model object, holding its last-epoch weights.
    model = train_model(
        model, train_dataloader, test_dataloader, (sample_images, sample_labels),
        optimizer, scheduler, num_epochs=NUM_EPOCHS
    )

    print("Training complete!")

    # Final evaluation on fixed samples with the best checkpoint.
    # map_location makes loading independent of the device the file was saved from.
    print("\nFinal evaluation on fixed samples:")
    model.load_state_dict(torch.load("best_model.pth", map_location=DEVICE))
    # epoch=NUM_EPOCHS names these files "epoch{NUM_EPOCHS+1}_*", after the training epochs.
    final_acc, final_captions = evaluate_fixed_samples(model, sample_images, sample_labels, NUM_EPOCHS, DEVICE)
    print(f"Best model accuracy on fixed samples: {final_acc*100:.2f}%")

    # Final comparison grid: original (with class) vs. final caption.
    # squeeze=False keeps `axes` 2-D so axes[i, j] also works for a single sample.
    n_samples = len(sample_images)
    fig, axes = plt.subplots(n_samples, 2, figsize=(10, n_samples * 2.5), squeeze=False)
    for i in range(n_samples):
        img = denormalize(sample_images[i]).permute(1, 2, 0).numpy()

        axes[i, 0].imshow(img)
        axes[i, 0].set_title(f"Original: {class_names[sample_labels[i].item()]}")
        axes[i, 0].axis('off')

        axes[i, 1].imshow(img)
        axes[i, 1].set_title(f"Caption: {final_captions[i]}")
        axes[i, 1].axis('off')

    fig.tight_layout()
    fig.savefig("fixed_samples_results/final_comparison.png")
    plt.close(fig)


if __name__ == "__main__":
    main()