# VetVision-LM — Dataset & Model Exploration

**Author:** Devarchith Parashara Batchu  
**Repository:** https://github.com/devarchith/vetvision-lm

This notebook explores:
1. Data loading and augmentation pipeline
2. Model architecture inspection
3. Loss function behaviour
4. Smoke-test training run

In [None]:
import random
import sys

# Make the project's `../src` directory importable; the later cells import
# `data.*`, `models.*` and `losses.*` (presumably packages under src/).
sys.path.insert(0, '../src')

import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

# Seed every RNG the notebook relies on so that Restart & Run All reproduces
# the same synthetic inputs and printed values.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the RNG on all devices, including CUDA

print('PyTorch version:', torch.__version__)
print('CUDA available:', torch.cuda.is_available())

## 1. Synthetic Data Pipeline

In [None]:
from data.chexpert import SyntheticCheXpertDataset
from data.veterinary import SyntheticVetDataset
from torch.utils.data import DataLoader

# Small synthetic stand-ins for the CheXpert and veterinary datasets, so the
# pipeline can be exercised without the real data. 50 samples keeps this cell fast.
N_SYNTHETIC = 50
chex_ds, vet_ds = (
    SyntheticCheXpertDataset(num_samples=N_SYNTHETIC),
    SyntheticVetDataset(num_samples=N_SYNTHETIC),
)

print(f'CheXpert synthetic dataset: {len(chex_ds)} samples')
print(f'Vet synthetic dataset: {len(vet_ds)} samples')

# Peek at one sample: it exposes an image tensor, a label tensor and a caption
# under the keys "image", "labels" and "text".
item = chex_ds[0]
print(f'\nImage shape: {item["image"].shape}')
print(f'Labels shape: {item["labels"].shape}')
print(f'Text: {item["text"]}')

## 2. Augmentation Pipeline

In [None]:
from pathlib import Path

from data.augmentations import build_train_transform, build_val_transform
import numpy as np

IMG_SIZE = 224
RESULTS_DIR = Path('../results')

train_tf = build_train_transform(img_size=IMG_SIZE)
val_tf = build_val_transform(img_size=IMG_SIZE)

# Create a synthetic X-ray-like image: 3-channel uniform noise in [50, 200).
# A dedicated seeded generator keeps the figure identical across re-runs.
xray_rng = np.random.default_rng(0)
dummy_xray = Image.fromarray(
    xray_rng.integers(50, 200, size=(300, 300, 3), dtype=np.uint8)
)

fig, axes = plt.subplots(1, 3, figsize=(12, 4))
# The image is RGB, so a colormap would be ignored; display it as-is.
axes[0].imshow(dummy_xray)
axes[0].set_title('Original')

panels = [('Train Aug.', train_tf), ('Val (no aug)', val_tf)]
for ax, (label, tf) in zip(axes[1:], panels):
    tensor_img = tf(dummy_xray)
    # The transform output is treated as channels-first; move channels last for
    # matplotlib, then min-max rescale to [0, 1] because the transform may
    # normalise values outside the displayable range (assumption — check data.augmentations).
    img_show = tensor_img.permute(1, 2, 0).numpy()
    img_show = (img_show - img_show.min()) / (img_show.max() - img_show.min() + 1e-8)
    ax.imshow(img_show)
    ax.set_title(label)

plt.tight_layout()
# On a fresh checkout ../results may not exist; create it rather than failing in savefig.
RESULTS_DIR.mkdir(parents=True, exist_ok=True)
plt.savefig(RESULTS_DIR / 'augmentation_comparison.png', dpi=100, bbox_inches='tight')
plt.show()

## 3. Model Architecture

In [None]:
from models.vetvision import VetVisionLM

# Both encoders are built with pretrained=False: random weights suffice for
# inspecting the architecture and counting parameters.
model = VetVisionLM(
    vision_cfg=dict(name='vit_base_patch16_224', pretrained=False),
    text_cfg=dict(
        name='microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext',
        pretrained=False,
    ),
    species_cfg=dict(num_species=2, species_embed_dim=64, output_dim=512),
    # Both towers are declared as 768-wide inputs to the projection heads,
    # which use embed_dim=512.
    proj_cfg=dict(embed_dim=512, hidden_dim=512, vision_input_dim=768, text_input_dim=768),
)

# Tally all parameters and the subset that will receive gradients, in one pass.
total = trainable = 0
for param in model.parameters():
    n_elem = param.numel()
    total += n_elem
    if param.requires_grad:
        trainable += n_elem

print(f'Total parameters:     {total:,}')
print(f'Trainable parameters: {trainable:,}')

## 4. Smoke-Test Forward Pass

In [None]:
BATCH_SIZE = 4

model.eval()  # switch layers such as dropout to inference behaviour
with torch.no_grad():
    images = torch.randn(BATCH_SIZE, 3, 224, 224)
    # NOTE(review): the same caption is used for every species id; fine for a
    # shape smoke test, not for checking alignment quality.
    texts = ['Canine thoracic radiograph.'] * BATCH_SIZE
    # Alternate species ids 0, 1, 0, 1, ... so both ids are exercised.
    species_ids = torch.arange(BATCH_SIZE) % 2

    out = model(images=images, texts=texts, species_ids=species_ids)

    print('vision_embed:', out.vision_embed.shape)
    print('text_embed:  ', out.text_embed.shape)
    print('species_embed:', out.species_embed.shape)

    # Check L2 normalisation: each vision embedding should have unit norm.
    # Assert rather than only print, so a regression fails the notebook run.
    v_norms = out.vision_embed.norm(dim=-1)
    print(f'\nVision embed norms (should be ~1.0): {v_norms.tolist()}')
    assert torch.allclose(v_norms, torch.ones_like(v_norms), atol=1e-3), (
        'vision embeddings are not L2-normalised'
    )

## 5. Loss Function Exploration

In [None]:
import torch.nn.functional as F
from losses.contrastive import ContrastiveLoss
from losses.species_loss import SpeciesContrastiveLoss, CombinedLoss

cont_loss = ContrastiveLoss(temperature=0.07)
spec_loss = SpeciesContrastiveLoss(temperature=0.07, margin=0.2)
combined = CombinedLoss(cont_loss, spec_loss, lambda_species=0.5)

# Test with random unit-norm embeddings. A dedicated seeded generator makes
# the printed loss values reproducible across re-runs.
B = 8
EMBED_DIM = 512
gen = torch.Generator().manual_seed(0)
v = F.normalize(torch.randn(B, EMBED_DIM, generator=gen), p=2, dim=-1)
t = F.normalize(torch.randn(B, EMBED_DIM, generator=gen), p=2, dim=-1)
s = F.normalize(torch.randn(B, EMBED_DIM, generator=gen), p=2, dim=-1)
# First half of the batch is species 0, second half species 1
# (tensor([0, 0, 0, 0, 1, 1, 1, 1]) for B=8).
sp = (torch.arange(B) >= B // 2).long()

# `combined` returns a mapping of loss-term name -> scalar tensor.
result = combined(v, t, s, sp)
print('Combined loss:')
for k, val in result.items():
    print(f'  {k}: {val.item():.4f}')