# VetVision-LM — Results Visualisation

**Author:** Devarchith Parashara Batchu  
**Repository:** https://github.com/devarchith/vetvision-lm

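This notebook visualises:
1. t-SNE embeddings (synthetic vision + text embeddings, coloured by species)
2. Retrieval metrics comparison (paper-reported, not reproduced)
3. Ablation study results (paper-reported, not reproduced)
4. Zero-shot classification breakdown (paper-reported, not reproduced)
5. Attention maps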

In [None]:
import sys
sys.path.insert(0, '../src')

import json
import numpy as np
import torch
import matplotlib.pyplot as plt
import matplotlib
matplotlib.rcParams['figure.dpi'] = 120
from pathlib import Path

## 1. t-SNE Embedding Visualisation (Synthetic)

In [None]:
from utils.visualize import plot_multimodal_tsne

# Demonstrate the t-SNE plot on synthetic embeddings (no trained model involved).
np.random.seed(42)  # fixed seed so the synthetic clusters are reproducible
N = 100             # total number of samples, split evenly between the two species
EMBED_DIM = 64      # dimensionality of the synthetic embeddings

# Simulate species-separated clusters: the two groups are shifted by +2 / -2
# along the first embedding dimension only; all other dimensions are pure noise.
cluster_offset = np.zeros(EMBED_DIM)
cluster_offset[0] = 2.0
canine_v = np.random.randn(N//2, EMBED_DIM) + cluster_offset
feline_v = np.random.randn(N//2, EMBED_DIM) - cluster_offset
vision_embeds = np.vstack([canine_v, feline_v])
# Text embeddings are the vision embeddings plus Gaussian noise scaled by 0.3.
text_embeds = vision_embeds + np.random.randn(*vision_embeds.shape) * 0.3
# Integer species label per row: 0 for the canine rows, 1 for the feline rows.
species = np.array([0]*(N//2) + [1]*(N//2))

# Create the output directory up front. Whether plot_multimodal_tsne creates
# missing directories itself is not visible from this notebook, and the only
# other mkdir here runs in a later cell and only covers '../results'.
tsne_output_path = '../results/tsne/synthetic_tsne.png'
Path(tsne_output_path).parent.mkdir(parents=True, exist_ok=True)

fig = plot_multimodal_tsne(
    vision_embeds, text_embeds, species,
    output_path=tsne_output_path,
)
plt.show()
print('t-SNE visualisation saved to results/tsne/')

## 2. Retrieval Metrics — Paper-Reported Comparison

In [None]:
# Paper-reported results (NOT reproduced)
# Bar chart of Recall@1 per model; models without a reported value get an
# 'N/A' marker instead of a bar.
models = ['CheXzero\n(baseline)', 'CLIP\n(baseline)', 'VetVision-LM\n(ours)']
recall_at_1 = [42.8, None, 55.1]  # None = not reported in paper
colors = ['#95a5a6', '#95a5a6', '#2ecc71']
# Reference line value is read from the data (first entry = CheXzero) so the
# dashed line cannot drift out of sync with the bar it refers to.
baseline_r1 = recall_at_1[0]

fig, ax = plt.subplots(figsize=(8, 5))
x = np.arange(len(models))
for pos, r, c in zip(x, recall_at_1, colors):
    if r is not None:
        ax.bar(pos, r, color=c, width=0.5, alpha=0.85)
        # Value label just above the top of the bar.
        ax.text(pos, r + 0.5, f'{r:.1f}%', ha='center', va='bottom', fontweight='bold')
    else:
        # Zero-height placeholder keeps the x slot; the label marks the gap.
        ax.bar(pos, 0, color=c, width=0.5, alpha=0.3)
        ax.text(pos, 2, 'N/A', ha='center', va='bottom', color='gray')

ax.set_xticks(x)
ax.set_xticklabels(models, fontsize=11)
ax.set_ylabel('Recall@1 (%)', fontsize=12)
ax.set_title('Image-Text Retrieval: Recall@1\n(Paper-Reported — Not Reproduced)', fontsize=13)
ax.set_ylim(0, 70)
ax.grid(axis='y', alpha=0.3)
ax.axhline(y=baseline_r1, color='red', linestyle='--', alpha=0.5, label='CheXzero baseline')
ax.legend()

plt.tight_layout()
# Make sure the output directory exists before saving.
Path('../results').mkdir(parents=True, exist_ok=True)
plt.savefig('../results/retrieval_comparison.png', dpi=150, bbox_inches='tight')
plt.show()

## 3. Ablation Study Results

In [None]:
# Paper-reported ablation (NOT reproduced)
# One row per configuration: (x-axis label, i2t Recall@1 in %, bar colour).
ablation_rows = [
    ('Full Model', 55.1, '#2ecc71'),
    ('− Species\nModule', 48.6, '#e74c3c'),
    ('− Species\nLoss', 51.3, '#e67e22'),
    ('− Both', 44.8, '#c0392b'),
]
ablation_configs = [row[0] for row in ablation_rows]
ablation_r1 = [row[1] for row in ablation_rows]
colors = [row[2] for row in ablation_rows]

fig, ax = plt.subplots(figsize=(9, 5))
bars = ax.bar(ablation_configs, ablation_r1, color=colors, width=0.5, alpha=0.85)

# Annotate every bar with its value, centred horizontally just above the top.
for rect, r1 in zip(bars, ablation_r1):
    label_x = rect.get_x() + rect.get_width()/2
    label_y = rect.get_height() + 0.3
    ax.text(label_x, label_y, f'{r1:.1f}%',
            ha='center', va='bottom', fontweight='bold')

ax.set_title('Ablation Study — i2t Recall@1\n(Paper-Reported — Not Reproduced)', fontsize=13)
ax.set_ylabel('i2t Recall@1 (%)', fontsize=12)
# The y-axis starts at 35 rather than 0, so bar heights exaggerate the gaps.
ax.set_ylim(35, 62)
ax.grid(axis='y', alpha=0.3)

plt.tight_layout()
plt.savefig('../results/ablation_results.png', dpi=150, bbox_inches='tight')
plt.show()
print('Paper-Reported — Not Reproduced')

## 4. Zero-Shot Classification Breakdown

In [None]:
# Paper-reported (NOT reproduced)
# Bar chart of zero-shot accuracy overall and per species, with a dashed
# reference line at the 50% level.
species = ['Overall', 'Canine', 'Feline']
accuracy = [77.3, 79.0, 75.5]

fig, ax = plt.subplots(figsize=(7, 4))
bars = ax.bar(species, accuracy, color=['#3498db', '#2ecc71', '#e74c3c'], width=0.4, alpha=0.85)

# Annotate each bar with its value, centred just above the top of the bar.
for bar, val in zip(bars, accuracy):
    ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.3,
            f'{val:.1f}%', ha='center', va='bottom', fontweight='bold')

ax.set_ylabel('Zero-Shot Accuracy (%)', fontsize=12)
ax.set_title('Zero-Shot Species Classification\n(Paper-Reported — Not Reproduced)', fontsize=13)
# The lower limit must sit below 50 so the baseline line is inside the visible
# range; with the previous (60, 90) limits the line was clipped away while its
# legend entry was still shown.
ax.set_ylim(40, 90)
ax.grid(axis='y', alpha=0.3)
ax.axhline(y=50, color='gray', linestyle='--', alpha=0.5, label='Random baseline (50%)')
ax.legend()

plt.tight_layout()
plt.savefig('../results/classification_breakdown.png', dpi=150, bbox_inches='tight')
plt.show()