In [None]:
import os
from pathlib import Path

import matplotlib.pyplot as plt
from PIL import Image
from rich import print
import torch
from torchvision.tv_tensors import Image as TVImage
import torchvision.transforms.functional as F 

from trailcaml import TrailCaML, CategoryLabels

In [None]:
def show_images(batch, n_samples=4):
    """Display up to ``n_samples`` images from ``batch`` side by side.

    Args:
        batch: Indexable, sized collection of image tensors (e.g. a stacked
            batch or a list); each element is converted with
            ``torchvision.transforms.functional.to_pil_image``.
        n_samples: Maximum number of images to show. Clamped to
            ``len(batch)`` so a short batch does not raise IndexError.
    """
    n_shown = min(n_samples, len(batch))
    if n_shown <= 0:
        # Nothing to draw; plt.subplots(1, 0) would raise.
        return
    # squeeze=False keeps `axes` a 2-D array even for a single column;
    # otherwise plt.subplots returns a bare Axes and `axes[i]` fails.
    _, axes = plt.subplots(1, n_shown, figsize=(15, 5), squeeze=False)
    for i, ax in enumerate(axes[0]):
        ax.imshow(F.to_pil_image(batch[i]))
        ax.axis('off')
    plt.show()

In [None]:
image_dir = Path("data/20240624/")
# Match the .jpg extension case-insensitively (".JPG" and ".jpg") and sort so
# the batch order, and hence the prediction output, is reproducible.
jpgs = sorted(
    p for p in image_dir.iterdir()
    if p.is_file() and p.suffix.lower() == ".jpg"
)


def _load_rgb(path):
    """Decode the image at ``path`` fully into memory as RGB and close the file."""
    # Image.open is lazy and keeps the file handle open until the pixel data is
    # read; convert() forces a full decode into a new image, so the `with`
    # block can close the underlying file. RGB also guarantees a consistent
    # channel count for the later torch.stack.
    with Image.open(path) as img:
        return img.convert("RGB")


images = [_load_rgb(jpg) for jpg in jpgs]

In [None]:
# Wrap every PIL image as a torchvision `Image` tv_tensor so the images can be stacked.
tv_images = list(map(TVImage, images))

In [None]:
# Join the per-image tensors along a new leading (batch) dimension.
# torch.stack raises if the images do not all share the same shape.
# NOTE(review): pixel values stay as decoded (no resize/normalize here); confirm
# that TrailCaML does its own preprocessing.
batch = torch.stack(tv_images, dim=0)
batch.shape

In [None]:
# Identify the training run and checkpoint to evaluate.
base_dir = Path("lightning_logs/lightning_logs")
version = 14
epoch = 2
val_loss = 0.47

# The file name embeds val_loss rounded to two decimals, so the value above
# must match the file name on disk.
ckpt_name = f"epoch={epoch}-val_loss={val_loss:.2f}.ckpt"
checkpoint = base_dir / f"version_{version}" / "checkpoints" / ckpt_name

tcml = TrailCaML.load_from_checkpoint(checkpoint)
# Presumably LightningModule.freeze(): eval mode plus requires_grad=False on
# all parameters. Confirm that TrailCaML subclasses LightningModule.
tcml.freeze()

In [None]:
# Run the model without building an autograd graph, then map each raw output
# through an element-wise sigmoid into (0, 1).
# NOTE(review): assumes the model emits per-label logits (multi-label setup).
with torch.no_grad():
    predictions = tcml(batch)
    confidences = torch.sigmoid(predictions)

In [None]:
def _confidence_by_label(scores):
    """Return ``{label: "x.xx"}`` for every category index present in ``scores``.

    CategoryLabels presumably maps output index -> label name; indices past
    the end of ``scores`` are skipped.
    """
    return {
        label: f"{scores[idx].item():.2f}"
        for idx, label in CategoryLabels.items()
        if idx < len(scores)
    }


# One record per image, ordered by file path for stable output.
preds = sorted(
    (
        {"image": jpg, "confidence": _confidence_by_label(scores)}
        for jpg, scores in zip(jpgs, confidences)
    ),
    key=lambda pred: pred["image"],
)
print(*preds)