# DINOv2 model
Notebook to check the implementation of the DINOv2 model from PlantCLEF.

https://github.com/dsgt-kaggle-clef/fungiclef-2025/blob/main/fungiclef/preprocessing/embedding.py

In [None]:
# Enable IPython's autoreload extension in mode 2 so that edited modules are
# re-imported before each cell executes (no kernel restart needed).
%load_ext autoreload
%autoreload 2

: 

In [None]:
from fungiclef.config import get_device

# Resolve the compute device through the project's helper.
device = get_device()
device  # last expression of the cell, so the notebook displays it

In [None]:
import timm
import torch
import pytorch_lightning as pl
from fungiclef.model_setup import setup_fine_tuned_model


class DINOv2LightningModel(pl.LightningModule):
    """PyTorch Lightning module for extracting embeddings from a fine-tuned DINOv2 model.

    On construction the timm model is created from the fine-tuned checkpoint,
    moved to the device returned by ``get_device()`` and switched to eval
    mode. ``self.transform`` holds the matching inference-time preprocessing
    built from the model's timm data config.

    Args:
        model_path: Path to the fine-tuned checkpoint. When ``None`` (the
            default) it is resolved by calling ``setup_fine_tuned_model()``
            at construction time.
        model_name: timm identifier of the backbone architecture.
    """

    def __init__(
        self,
        model_path: "str | None" = None,
        model_name: str = "vit_base_patch14_reg4_dinov2.lvd142m",
    ):
        super().__init__()
        # Resolve the default checkpoint lazily. Using the call as the default
        # argument value would execute it once, when the class is defined,
        # even if no model is ever instantiated.
        if model_path is None:
            model_path = setup_fine_tuned_model()

        self.model_device = get_device()
        self.num_classes = 7806  # total plant species

        # load the fine-tuned model
        self.model = timm.create_model(
            model_name,
            pretrained=False,
            num_classes=self.num_classes,
            checkpoint_path=model_path,
        )

        # load transform
        self.data_config = timm.data.resolve_model_data_config(self.model)
        self.transform = timm.data.create_transform(
            **self.data_config, is_training=False
        )

        # move model to device
        self.model.to(self.model_device)
        self.model.eval()

    def forward(self, batch):
        """Extract [CLS] token embeddings using the fine-tuned model.

        Args:
            batch: Image tensor of shape (B, C, H, W), or a 5-D tensor of
                shape (B, G, C, H, W) whose first two axes are flattened
                before the forward pass.

        Returns:
            Tensor with one embedding row per image fed to the backbone,
            i.e. B rows for 4-D input and B * G rows for 5-D input.
        """
        with torch.no_grad():
            batch = batch.to(self.model_device)  # move to device

            if batch.dim() == 5:  # (B, grid_size**2, C, H, W)
                B, G, C, H, W = batch.shape
                # reshape (unlike view) also accepts non-contiguous input
                batch = batch.reshape(B * G, C, H, W)  # (B * grid_size**2, C, H, W)
            # forward pass
            features = self.model.forward_features(batch)
            embeddings = features[:, 0, :]  # extract [CLS] token

        return embeddings

    def predict_step(self, batch, batch_idx):
        """Run inference on a batch and return its [CLS] token embeddings."""
        return self(batch)  # [CLS] token embeddings

In [None]:
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import ToTensor
from fungiclef.serde import deserialize_image


class FungiDataset(Dataset):
    """Dataset that decodes serialized images stored in a DataFrame column.

    Args:
        df: DataFrame holding one serialized image per row.
        transform: Optional callable applied to the decoded image. When
            ``None``, the image is converted with ``ToTensor()`` instead.
        col_name: Name of the column containing the serialized image bytes.
    """

    def __init__(self, df, transform=None, col_name: str = "data"):
        self.df = df
        self.transform = transform
        self.col_name = col_name

    def __len__(self):
        """Return the number of rows (images) in the DataFrame."""
        return len(self.df)

    def __getitem__(self, idx):
        """Decode the image at positional index ``idx`` and return it transformed."""
        # Select the column first so only one cell is read; df.iloc[idx] would
        # materialise the entire row as a Series on every access.
        img_bytes = self.df[self.col_name].iloc[idx]
        img = deserialize_image(img_bytes)

        # Explicit None check: a transform object that happens to be falsy
        # (e.g. an empty container module) must still be applied.
        if self.transform is not None:
            return self.transform(img)  # (C, H, W)
        return ToTensor()(img)  # (C, H, W)


class FungiDataModule(pl.LightningDataModule):
    """LightningDataModule for handling dataset loading and preparation.

    Args:
        pandas_df: DataFrame with the serialized images to run inference on.
        batch_size: Number of samples per batch.
        num_workers: Number of DataLoader worker processes (0 loads data in
            the main process).
        transform: Optional preprocessing callable. When ``None`` (default),
            ``setup`` builds a ``DINOv2LightningModel`` and uses its
            ``transform``; pass one explicitly to avoid that extra model load.
    """

    def __init__(
        self,
        pandas_df,
        batch_size=32,
        num_workers=4,
        transform=None,
    ):
        super().__init__()
        self.pandas_df = pandas_df
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.transform = transform
        self.model = None  # only populated when setup() has to build the transform

    def setup(self, stage=None):
        """Set up dataset and transformations."""
        # Build the model only when no transform is available yet, so
        # repeated setup() calls do not reload the checkpoint each time.
        if self.transform is None:
            self.model = DINOv2LightningModel()
            self.transform = self.model.transform  # Use the model's transform

        self.dataset = FungiDataset(
            self.pandas_df,
            self.transform,
        )

    def predict_dataloader(self):
        """Returns DataLoader for inference."""
        return DataLoader(
            self.dataset,
            batch_size=self.batch_size,
            shuffle=False,  # keep predictions aligned with the DataFrame rows
            num_workers=self.num_workers,
            # DataLoader rejects persistent_workers=True when num_workers == 0
            persistent_workers=self.num_workers > 0,
        )

In [None]:
import pandas as pd


def pl_trainer_pipeline(
    pandas_df: pd.DataFrame,
    batch_size: int = 32,
    cpu_count: int = 1,
):
    """Extract embeddings for the images in ``pandas_df`` using PyTorch Lightning.

    Args:
        pandas_df: DataFrame with the serialized images.
        batch_size: Batch size used by the prediction DataLoader.
        cpu_count: Number of DataLoader worker processes.

    Returns:
        A tensor with the per-batch outputs of ``DINOv2LightningModel``
        concatenated along dim 0 (one embedding row per image passed through
        the model). When the trainer produces no predictions, a tensor with
        zero rows is returned.
    """

    # initialize DataModule
    data_module = FungiDataModule(
        pandas_df,
        batch_size=batch_size,
        num_workers=cpu_count,
    )

    # initialize Model
    model = DINOv2LightningModel()

    # define Trainer (inference mode)
    trainer = pl.Trainer(
        accelerator=get_device(),
        devices=1,
        enable_progress_bar=True,
    )

    # run inference; yields one embeddings tensor per batch
    predictions = trainer.predict(model, datamodule=data_module)

    # Nothing to concatenate (None or an empty list, e.g. for an empty
    # DataFrame): return a zero-row tensor rather than failing in torch.cat.
    if not predictions:
        embed_dim = getattr(model.model, "num_features", 0)
        return torch.empty((0, embed_dim))

    # stack the per-batch embeddings into a single tensor
    return torch.cat(predictions, dim=0)

In [None]:
# Shared project data directory and the serialized training parquet inside it.
# NOTE(review): the path keeps a literal "~"; this relies on the reader
# expanding the home directory — confirm.
data_path = "~/p-dsgt_clef2025-0/shared/fungiclef/data"
train_path = f"{data_path}/dataset/processed/train_serialized.parquet"

# read train dataframe
train_df = pd.read_parquet(train_path)
train_df.head(5)  # preview the first five rows

In [None]:
# Take the first 700 rows to test the pipeline on a small sample.
subset_df = train_df.head(700)  # noted as ~10% of the data by the author — TODO confirm
len(subset_df)  # displayed as the cell output

### Run the embedding pipeline on the subset

In [None]:
# extract embeddings
# extract embeddings for the subset (batches of 64, 2 DataLoader workers)
embeddings = pl_trainer_pipeline(
    subset_df,
    batch_size=64,
    cpu_count=2,
)

In [None]:
# inspect the shape of the returned embeddings tensor
embeddings.shape

In [None]:
# list the available columns of the subset DataFrame
subset_df.columns

In [None]:
# Pair every filename with its embedding. tolist() converts the tensor into
# one plain Python list per row so it can be stored as a DataFrame column.
# Assumes the embeddings come back in the same row order as subset_df.
embed_sub_df = subset_df[["filename"]].copy()
embed_sub_df["embeddings"] = embeddings.cpu().tolist()
display(embed_sub_df[["filename", "embeddings"]])

In [None]:
import os
from pathlib import Path

# resolve the current user's home directory as the root for output paths
root = Path(os.path.expanduser("~"))
# IPython shell escape: print the current date/time
! date

# write to parquet
project_path = f"{root}/p-dsgt_clef2025-0/shared/fungiclef"
output_path = f"{project_path}/temp/embeddings/subset_train_embeddings.parquet"
# create the output directory (and any missing parents) if it does not exist
Path(output_path).parent.mkdir(parents=True, exist_ok=True)
embed_sub_df.to_parquet(output_path)

In [None]:
# read embed data
embed_df = pd.read_parquet(output_path)
embed_df.head()

### Run the embedding pipeline on the entire training data

In [None]:
# extract embeddings
# extract embeddings for the full training DataFrame (batches of 64, 2 DataLoader workers)
embeddings = pl_trainer_pipeline(
    train_df,
    batch_size=64,
    cpu_count=2,
)

In [None]:
# inspect the shape of the embeddings tensor for the full training data
embeddings.shape

In [None]:
# Pair every filename with its embedding. tolist() converts the tensor into
# one plain Python list per row so it can be stored as a DataFrame column.
# Assumes the embeddings come back in the same row order as train_df.
embed_train_df = train_df[["filename"]].copy()
embed_train_df["embeddings"] = embeddings.cpu().tolist()
display(embed_train_df[["filename", "embeddings"]])

In [None]:
# write to parquet
# destination for the full-training-set embeddings (rebinds output_path)
output_path = f"{project_path}/temp/embeddings/train_embeddings.parquet"
# create the output directory (and any missing parents) if it does not exist
Path(output_path).parent.mkdir(parents=True, exist_ok=True)
embed_train_df.to_parquet(output_path)

In [None]:
# read embed data
# read the training embeddings back from parquet to verify the write
embed_df = pd.read_parquet(output_path)
embed_df.head()

In [None]:
# sanity check: compare the row count of the saved embeddings with the training DataFrame
len(embed_df), len(train_df)