In [None]:
import os
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
from torch.utils.data import DataLoader
from torchvision import transforms
from transformers import (
    ViltProcessor, ViltForQuestionAnswering,
    BlipProcessor, BlipForQuestionAnswering,
    BertTokenizer
)
from functools import partial
from tqdm import tqdm
from PIL import Image

from data import VQADataset, collate_fn_with_tokenizer

# Seed the global PyTorch and NumPy random number generators up front so that
# repeated runs of the notebook draw the same random numbers.
torch.manual_seed(42)
np.random.seed(42)

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Configuration
# Filesystem locations can be overridden through environment variables so the
# notebook also runs on machines that do not share the original drive layout;
# the defaults are unchanged.
dataset_root = os.environ.get('VQA_DATASET_ROOT', 'D:/VQA/cocoqa')
model_type = 'vilt'  # 'vilt' or 'blip'
trained_model_path = os.environ.get(
    'VQA_TRAINED_MODEL_PATH',
    './consistency/consistency_models/vilt/pytorch_model.pth',
)  # Trained model path
batch_size = 8
num_workers = 4
max_samples = 1000  # upper bound on samples pushed through each model
perplexity = 30     # t-SNE perplexity used by the visualisation cells

# Fail fast on a mistyped model family instead of hitting an undefined-name
# error several cells further down.
if model_type not in ('vilt', 'blip'):
    raise ValueError(f"Unsupported model_type {model_type!r}; expected 'vilt' or 'blip'")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

In [None]:
def _resolve_image_path(dataset_root, split, img_id):
    """Return the existing image file for ``img_id``, or ``None`` if there is none.

    Candidates are probed in order: ``.jpg``, ``.png``, ``.jpeg`` and finally
    the bare path with no extension appended.
    """
    base_path = os.path.join(dataset_root, split, 'images', img_id)
    for ext in ('.jpg', '.png', '.jpeg', ''):
        candidate = base_path + ext
        if os.path.exists(candidate):
            return candidate
    return None


def _select(seq, indices):
    """Pick the entries of ``seq`` at ``indices`` from a tensor or a plain sequence."""
    if torch.is_tensor(seq):
        return seq[indices]
    return [seq[i] for i in indices]


def extract_embeddings_and_losses(model, processor, model_name, dataloader, 
                                  dataset_root, split, device, max_samples=2000):
    """Run ``model`` over ``dataloader`` and collect logits, losses, labels and predictions.

    Images are re-read from ``<dataset_root>/<split>/images/<image_id>[.ext]``
    and passed, together with the batch questions, through ``processor``.
    Samples whose image file cannot be found are dropped, and the matching
    question/answer entries are dropped with them so the remaining samples
    stay aligned.

    Args:
        model: Model to evaluate; it is switched to eval mode.
        processor: Processor matching ``model``.
        model_name: Name used to pick the code path; must contain ``"blip"``
            or ``"vilt"`` (case-insensitive, ``"blip"`` is checked first).
        dataloader: Iterable of batches providing ``'image_id'``,
            ``'question'``, ``'answer'`` and optionally ``'answer_text'``.
        dataset_root: Dataset root directory.
        split: Split sub-directory below ``dataset_root``.
        device: Device that inputs and labels are moved to.
        max_samples: Stop starting new batches once this many samples have been
            collected (the last batch may overshoot the limit).

    Returns:
        Tuple ``(embeddings, losses, labels, predictions)`` of NumPy arrays.
        For ViLT the "embedding" is the logit vector and the loss is the
        per-sample cross-entropy. For BLIP the embedding is the logits at
        index 0 of the second axis and every sample in a batch receives that
        batch's scalar loss.

    Raises:
        ValueError: If ``model_name`` names neither supported family, or if no
            batch could be processed.
    """
    name_lower = model_name.lower()
    is_blip = "blip" in name_lower
    is_vilt = "vilt" in name_lower
    if not (is_blip or is_vilt):
        # Previously this surfaced as a swallowed NameError on every batch.
        raise ValueError(
            f"Cannot infer model family from model_name={model_name!r}; "
            "expected it to contain 'blip' or 'vilt'"
        )

    model.eval()
    
    embeddings_list = []
    losses_list = []
    labels_list = []
    predictions_list = []
    
    criterion = nn.CrossEntropyLoss(reduction='none')
    
    print(f"Extracting embeddings from {model_name}...")
    
    num_samples = 0
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Extracting", leave=False):
            if num_samples >= max_samples:
                break
                
            try:
                # Remember WHICH batch entries have an image on disk so the
                # questions and answers can be filtered with the same indices.
                keep_idx = []
                image_paths = []
                for idx, img_id in enumerate(batch['image_id']):
                    path = _resolve_image_path(dataset_root, split, img_id)
                    if path is not None:
                        keep_idx.append(idx)
                        image_paths.append(path)
                
                if not image_paths:
                    continue
                
                images = []
                for path in image_paths:
                    # convert() returns a fully loaded copy, so the file handle
                    # can be closed right away.
                    with Image.open(path) as img:
                        images.append(img.convert('RGB'))
                
                questions = _select(batch['question'], keep_idx)
                answers = _select(batch['answer'], keep_idx).to(device)
                
                if is_blip:
                    if 'answer_text' in batch:
                        answer_texts = _select(batch['answer_text'], keep_idx)
                    else:
                        answer_texts = [str(a.item()) for a in answers]
                    
                    inputs = processor(images=images, text=questions, return_tensors="pt", 
                                     padding=True, truncation=True).to(device)
                    labels = processor(text=answer_texts, return_tensors="pt", padding=True, truncation=True).input_ids.to(device)
                    outputs = model(**inputs, labels=labels)
                    
                    # Only a scalar batch loss is available here; replicate it
                    # so every sample in the batch carries a loss value.
                    batch_loss = outputs.loss.item()
                    losses_batch = np.array([batch_loss] * len(images))
                    first_step_logits = outputs.logits[:, 0, :]
                    embeddings_batch = first_step_logits.cpu().numpy()
                    predictions_batch = torch.argmax(first_step_logits, dim=-1).cpu().numpy()
                    
                else:  # ViLT
                    inputs = processor(images=images, text=questions, return_tensors="pt", 
                                     padding=True, truncation=True, max_length=40).to(device)
                    outputs = model(**inputs)
                    logits = outputs.logits
                    
                    losses_batch = criterion(logits, answers).cpu().numpy()
                    embeddings_batch = logits.cpu().numpy()
                    predictions_batch = torch.argmax(logits, dim=-1).cpu().numpy()
                
                embeddings_list.append(embeddings_batch)
                losses_list.append(losses_batch)
                labels_list.append(answers.cpu().numpy())
                predictions_list.append(predictions_batch)
                
                num_samples += len(images)
                
            except Exception as e:
                # Best effort: report the failing batch and keep going.
                print(f"\n⚠ Error: {e}")
                continue
    
    if not embeddings_list:
        raise ValueError(
            "No samples were extracted: every batch was skipped (missing images "
            "or processing errors). Check dataset_root/split and the messages above."
        )
    
    embeddings = np.vstack(embeddings_list)
    losses = np.concatenate(losses_list)
    labels = np.concatenate(labels_list)
    predictions = np.concatenate(predictions_list)
    
    print(f"✓ Extracted {len(embeddings)} samples")
    print(f"  Embedding shape: {embeddings.shape}")
    print(f"  Loss range: [{losses.min():.2f}, {losses.max():.2f}]")
    
    return embeddings, losses, labels, predictions

In [None]:
def assign_difficulty_groups(losses):
    """Bucket samples into easy / medium / hard tiers by their loss.

    The 33rd and 66th percentiles of ``losses`` are used as cut points:
    tier 0 holds losses below the first cut, tier 1 those from the first cut
    up to (but excluding) the second, and tier 2 everything at or above the
    second. A short summary of the tier sizes is printed.

    Returns:
        Tuple ``(groups, p33, p66)`` where ``groups`` is an integer array with
        one tier index per loss and ``p33`` / ``p66`` are the two cut points.
    """
    p33, p66 = (np.percentile(losses, q) for q in (33, 66))

    # Start everyone in tier 0 and promote by one tier per threshold reached.
    groups = np.zeros(len(losses), dtype=int)
    for tier, threshold in enumerate((p33, p66), start=1):
        groups[losses >= threshold] = tier

    easy_count, medium_count, hard_count = (np.sum(groups == tier) for tier in range(3))

    print(f"\nDifficulty Groups:")
    print(f"  Easy (loss < {p33:.2f}): {easy_count} samples")
    print(f"  Medium ({p33:.2f} <= loss < {p66:.2f}): {medium_count} samples")
    print(f"  Hard (loss >= {p66:.2f}): {hard_count} samples")

    return groups, p33, p66

In [None]:
def _decorate_tsne_axis(ax, title, legend=False):
    """Apply the shared title, axis labels, grid and optional legend to one panel."""
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.set_xlabel('t-SNE Dimension 1', fontsize=12)
    ax.set_ylabel('t-SNE Dimension 2', fontsize=12)
    if legend:
        ax.legend(fontsize=11, loc='upper right')
    ax.grid(True, alpha=0.3)


def visualize_tsne(embeddings, losses, groups, labels, predictions, title, perplexity=30):
    """Project ``embeddings`` to 2-D with t-SNE and draw a 2x2 diagnostic figure.

    Panels: samples coloured by difficulty group, samples coloured by loss,
    samples coloured by prediction correctness (``labels == predictions``),
    and a text box with per-group and overall statistics.

    Args:
        embeddings: 2-D array with one row per sample.
        losses: Per-sample loss array.
        groups: Per-sample difficulty index (0 = easy, 1 = medium, 2 = hard).
        labels: Per-sample ground-truth labels.
        predictions: Per-sample predicted labels.
        title: Text appended to the title of the first panel.
        perplexity: Requested t-SNE perplexity; lowered automatically when it
            is not smaller than the number of samples.

    Returns:
        Array with the 2-D t-SNE coordinates of every sample.
    """
    print(f"\nPerforming t-SNE projection...")
    
    n_samples, n_features = embeddings.shape
    
    if n_features > 50:
        # PCA cannot return more components than there are samples.
        n_components = min(50, n_samples)
        print(f"  PCA: {n_features} -> {n_components}")
        pca = PCA(n_components=n_components, random_state=42)
        embeddings_pca = pca.fit_transform(embeddings)
        print(f"  Explained variance: {pca.explained_variance_ratio_.sum():.2%}")
    else:
        embeddings_pca = embeddings
    
    if perplexity >= n_samples:
        # scikit-learn requires perplexity < n_samples.
        clamped = max(n_samples - 1, 1)
        print(f"  perplexity {perplexity} is too large for {n_samples} samples; using {clamped}")
        perplexity = clamped
    
    print(f"  t-SNE: {embeddings_pca.shape[1]} -> 2 (perplexity={perplexity})")
    # The iteration count is left at the library default of 1000: the keyword
    # was renamed from n_iter to max_iter across scikit-learn releases, so
    # passing either name explicitly breaks on some versions.
    tsne = TSNE(n_components=2, perplexity=perplexity, random_state=42, verbose=0)
    embeddings_2d = tsne.fit_transform(embeddings_pca)
    
    print(f"✓ t-SNE completed")
    
    fig, axes = plt.subplots(2, 2, figsize=(16, 14))
    
    difficulty_names = ['Easy', 'Medium', 'Hard']
    difficulty_colors = ['#2ecc71', '#f39c12', '#e74c3c']
    
    # 1. Difficulty Groups
    ax = axes[0, 0]
    for group_idx in range(3):
        mask = (groups == group_idx)
        ax.scatter(embeddings_2d[mask, 0], embeddings_2d[mask, 1],
                  c=difficulty_colors[group_idx], label=difficulty_names[group_idx],
                  alpha=0.6, s=30, edgecolors='none')
    _decorate_tsne_axis(ax, f'Output Distribution by Difficulty\n{title}', legend=True)
    
    # 2. Loss Heatmap
    ax = axes[0, 1]
    scatter = ax.scatter(embeddings_2d[:, 0], embeddings_2d[:, 1],
                        c=losses, cmap='RdYlGn_r', alpha=0.6, s=30, 
                        edgecolors='none', vmin=losses.min(), vmax=losses.max())
    cbar = plt.colorbar(scatter, ax=ax)
    cbar.set_label('Loss', fontsize=11)
    _decorate_tsne_axis(ax, 'Loss Distribution (Continuous)')
    
    # 3. Correctness (correct points are drawn last so they sit on top)
    ax = axes[1, 0]
    correct = (labels == predictions)
    ax.scatter(embeddings_2d[~correct, 0], embeddings_2d[~correct, 1],
              c='#e74c3c', label='Incorrect', alpha=0.6, s=30, edgecolors='none')
    ax.scatter(embeddings_2d[correct, 0], embeddings_2d[correct, 1],
              c='#2ecc71', label='Correct', alpha=0.6, s=30, edgecolors='none')
    _decorate_tsne_axis(ax, 'Prediction Correctness', legend=True)
    
    # 4. Statistics
    ax = axes[1, 1]
    ax.axis('off')
    
    stats_text = f"Statistics:\n\n"
    stats_text += f"Total Samples: {len(embeddings)}\n\n"
    
    for group_idx in range(3):
        mask = (groups == group_idx)
        
        stats_text += f"{difficulty_names[group_idx]} Group:\n"
        stats_text += f"  Count: {np.sum(mask)}\n"
        if mask.any():
            group_losses = losses[mask]
            group_correct = correct[mask]
            stats_text += f"  Mean Loss: {group_losses.mean():.3f}\n"
            stats_text += f"  Std Loss: {group_losses.std():.3f}\n"
            stats_text += f"  Accuracy: {group_correct.mean():.2%}\n\n"
        else:
            # Avoid nan output and warnings from statistics over an empty group.
            stats_text += "  (no samples)\n\n"
    
    stats_text += f"Overall:\n"
    stats_text += f"  Mean Loss: {losses.mean():.3f}\n"
    stats_text += f"  Std Loss: {losses.std():.3f}\n"
    stats_text += f"  Accuracy: {correct.mean():.2%}\n"
    
    ax.text(0.1, 0.9, stats_text, transform=ax.transAxes,
           fontsize=11, verticalalignment='top', fontfamily='monospace',
           bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.3))
    
    plt.tight_layout()
    plt.show()
    
    return embeddings_2d

In [None]:
def _group_loss_summary(header, losses, groups, names):
    """Build the per-group mean/std text block shown in a statistics panel."""
    text = header
    for group_idx, name in enumerate(names):
        group_losses = losses[groups == group_idx]
        text += f"{name}:\n"
        if group_losses.size:
            text += f"  Mean: {group_losses.mean():.3f}\n"
            text += f"  Std: {group_losses.std():.3f}\n\n"
        else:
            text += "  (no samples)\n\n"
    return text


def _plot_comparison_row(axes_row, tsne_2d, losses, groups, prefix, names, colors, vmin, vmax):
    """Draw the difficulty scatter and the loss heatmap for one model on one row."""
    ax = axes_row[0]
    for group_idx, (name, color) in enumerate(zip(names, colors)):
        mask = (groups == group_idx)
        ax.scatter(tsne_2d[mask, 0], tsne_2d[mask, 1],
                  c=color, label=name, alpha=0.6, s=20, edgecolors='none')
    ax.set_title(f'{prefix}: Difficulty Groups', fontsize=13, fontweight='bold')
    ax.set_xlabel('t-SNE Dim 1', fontsize=10)
    ax.set_ylabel('t-SNE Dim 2', fontsize=10)
    ax.legend(fontsize=9)
    ax.grid(True, alpha=0.3)
    
    ax = axes_row[1]
    scatter = ax.scatter(tsne_2d[:, 0], tsne_2d[:, 1],
                        c=losses, cmap='RdYlGn_r', alpha=0.6, s=20, edgecolors='none',
                        vmin=vmin, vmax=vmax)
    ax.set_title(f'{prefix}: Loss Distribution', fontsize=13, fontweight='bold')
    ax.set_xlabel('t-SNE Dim 1', fontsize=10)
    ax.set_ylabel('t-SNE Dim 2', fontsize=10)
    plt.colorbar(scatter, ax=ax)
    ax.grid(True, alpha=0.3)


def compare_before_after(original_data, trained_data, original_tsne, trained_tsne):
    """Draw a 2x3 figure comparing the original and the trained model.

    Each row shows one model: samples coloured by difficulty group, samples
    coloured by loss, and a text panel with per-group loss mean/std. The
    trained row additionally lists the relative change of the per-group loss
    standard deviation versus the original model. Both loss heatmaps share one
    colour range so that colours are comparable between the rows.

    Args:
        original_data: 5-tuple ``(embeddings, losses, groups, labels,
            predictions)`` for the original model; only ``losses`` and
            ``groups`` are used.
        trained_data: Same 5-tuple for the trained model.
        original_tsne: 2-D t-SNE coordinates for the original model.
        trained_tsne: 2-D t-SNE coordinates for the trained model.
    """
    _, orig_losses, orig_groups, _, _ = original_data
    _, train_losses, train_groups, _, _ = trained_data
    
    fig, axes = plt.subplots(2, 3, figsize=(20, 12))
    
    difficulty_names = ['Easy', 'Medium', 'Hard']
    difficulty_colors = ['#2ecc71', '#f39c12', '#e74c3c']
    
    # Common colour limits over both models; fall back to matplotlib's
    # automatic scaling when there is no finite loss value at all.
    all_losses = np.concatenate([np.ravel(orig_losses), np.ravel(train_losses)])
    finite_losses = all_losses[np.isfinite(all_losses)]
    if finite_losses.size:
        vmin, vmax = finite_losses.min(), finite_losses.max()
    else:
        vmin = vmax = None
    
    # Row 1: Original Model
    _plot_comparison_row(axes[0], original_tsne, orig_losses, orig_groups, 'BEFORE',
                         difficulty_names, difficulty_colors, vmin, vmax)
    
    ax = axes[0, 2]
    ax.axis('off')
    orig_stats = _group_loss_summary("BEFORE Statistics:\n\n", orig_losses, orig_groups,
                                     difficulty_names)
    ax.text(0.1, 0.9, orig_stats, transform=ax.transAxes,
           fontsize=10, verticalalignment='top', fontfamily='monospace',
           bbox=dict(boxstyle='round', facecolor='lightblue', alpha=0.3))
    
    # Row 2: Trained Model
    _plot_comparison_row(axes[1], trained_tsne, train_losses, train_groups, 'AFTER',
                         difficulty_names, difficulty_colors, vmin, vmax)
    
    ax = axes[1, 2]
    ax.axis('off')
    train_stats = _group_loss_summary("AFTER Statistics:\n\n", train_losses, train_groups,
                                      difficulty_names)
    
    train_stats += f"\nImprovement:\n"
    for group_idx, name in enumerate(difficulty_names):
        before = orig_losses[orig_groups == group_idx]
        after = train_losses[train_groups == group_idx]
        std_orig = before.std() if before.size else 0.0
        if before.size and after.size and std_orig > 0:
            # Positive values mean the loss spread within the group shrank.
            improvement = (std_orig - after.std()) / std_orig * 100
            train_stats += f"{name} Std: {improvement:+.1f}%\n"
        else:
            # Empty group or zero baseline spread: the ratio is undefined.
            train_stats += f"{name} Std: n/a\n"
    
    ax.text(0.1, 0.9, train_stats, transform=ax.transAxes,
           fontsize=10, verticalalignment='top', fontfamily='monospace',
           bbox=dict(boxstyle='round', facecolor='lightgreen', alpha=0.3))
    
    plt.tight_layout()
    plt.show()

In [None]:
# Load data
# Image preprocessing handed to the dataset: fixed 224x224 resize, tensor
# conversion, then per-channel normalisation.
image_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

# The collate function needs a tokenizer; bind it once here.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
collate_fn = partial(collate_fn_with_tokenizer, tokenizer=tokenizer)

dataset = VQADataset(
    root_dir=dataset_root,
    split='train',
    transform=image_transform,
)

# shuffle=False keeps the sample order identical for every pass over the data.
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
    collate_fn=collate_fn,
)

print(f"Dataset loaded: {len(dataset)} samples")

In [None]:
# Load models
# Pick the pretrained checkpoint, processor and model class for the configured
# model family.
if model_type == 'vilt':
    pretrained_name = 'dandelin/vilt-b32-finetuned-vqa'
    processor = ViltProcessor.from_pretrained(pretrained_name)
    original_model = ViltForQuestionAnswering.from_pretrained(pretrained_name, use_safetensors=True).to(device)
elif model_type == 'blip':
    pretrained_name = 'Salesforce/blip-vqa-base'
    processor = BlipProcessor.from_pretrained(pretrained_name)
    original_model = BlipForQuestionAnswering.from_pretrained(pretrained_name).to(device)
else:
    # Without this branch an unknown value fell through to an undefined-name
    # error on the print below.
    raise ValueError(f"Unsupported model_type {model_type!r}; expected 'vilt' or 'blip'")

print(f"Original model loaded: {pretrained_name}")

In [None]:
# Extract original model embeddings
# Run the original model over the data, then bucket its losses into
# difficulty groups.
print("\n" + "="*60, "Processing ORIGINAL model...", "="*60, sep="\n")

orig_emb, orig_losses, orig_labels, orig_preds = extract_embeddings_and_losses(
    model=original_model,
    processor=processor,
    model_name=pretrained_name,
    dataloader=dataloader,
    dataset_root=dataset_root,
    split='train',
    device=device,
    max_samples=max_samples,
)

orig_groups, orig_p33, orig_p66 = assign_difficulty_groups(orig_losses)

In [None]:
# Visualize original model
# Keep the returned 2-D coordinates; the comparison cell reuses them.
print("\n" + "="*60, "Visualizing ORIGINAL model...", "="*60, sep="\n")

original_tsne = visualize_tsne(
    embeddings=orig_emb,
    losses=orig_losses,
    groups=orig_groups,
    labels=orig_labels,
    predictions=orig_preds,
    title=f"Original {model_type.upper()} Model",
    perplexity=perplexity,
)

In [None]:
# Load trained model
# Rebuild the pretrained architecture and overwrite its weights with the
# trained checkpoint. ``trained_model`` is set to None when no checkpoint
# exists so the following cells can skip themselves.
if os.path.exists(trained_model_path):
    print(f"\nLoading trained model from: {trained_model_path}")
    
    if model_type == 'vilt':
        trained_model = ViltForQuestionAnswering.from_pretrained(pretrained_name, use_safetensors=True)
    elif model_type == 'blip':
        trained_model = BlipForQuestionAnswering.from_pretrained(pretrained_name)
    else:
        # Previously an unknown value reported success without defining a model.
        raise ValueError(f"Unsupported model_type {model_type!r}; expected 'vilt' or 'blip'")
    
    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code; only load checkpoints from a trusted source.
    # The model is still on the CPU at this point, so the checkpoint is read
    # onto the CPU too instead of being staged on the GPU first.
    trained_model.load_state_dict(torch.load(trained_model_path, map_location='cpu'))
    trained_model = trained_model.to(device)
    
    print("✓ Trained model loaded successfully")
else:
    print(f"\n⚠ WARNING: Trained model not found at {trained_model_path}")
    print("Please train the model first using consistency_train.py")
    trained_model = None

In [None]:
# Extract trained model embeddings
# Skipped when no trained checkpoint was loaded.
if trained_model is not None:
    print("\n" + "="*60, "Processing TRAINED model...", "="*60, sep="\n")
    
    # The model name only needs to contain the family so the matching code
    # path is chosen inside the extraction function.
    train_emb, train_losses, train_labels, train_preds = extract_embeddings_and_losses(
        model=trained_model,
        processor=processor,
        model_name=f"Trained {model_type}",
        dataloader=dataloader,
        dataset_root=dataset_root,
        split='train',
        device=device,
        max_samples=max_samples,
    )
    
    train_groups, train_p33, train_p66 = assign_difficulty_groups(train_losses)

In [None]:
# Visualize trained model
# Skipped when no trained checkpoint was loaded.
if trained_model is not None:
    print("\n" + "="*60, "Visualizing TRAINED model...", "="*60, sep="\n")
    
    trained_tsne = visualize_tsne(
        embeddings=train_emb,
        losses=train_losses,
        groups=train_groups,
        labels=train_labels,
        predictions=train_preds,
        title=f"Consistency-Trained {model_type.upper()} Model",
        perplexity=perplexity,
    )

In [None]:
# Compare before and after
# Needs the results of both extraction and both visualisation cells; skipped
# when no trained checkpoint was loaded.
if trained_model is not None:
    print("\n" + "="*60, "Creating COMPARISON visualization...", "="*60, sep="\n")
    
    original_bundle = (orig_emb, orig_losses, orig_groups, orig_labels, orig_preds)
    trained_bundle = (train_emb, train_losses, train_groups, train_labels, train_preds)
    
    compare_before_after(
        original_data=original_bundle,
        trained_data=trained_bundle,
        original_tsne=original_tsne,
        trained_tsne=trained_tsne,
    )
    
    print("\n✓ Visualization completed!")