# DINOv2 model
Notebook to check the implementation of the DINOv2 model from PlantCLEF (see the `embedding.py` module linked below).

https://github.com/dsgt-kaggle-clef/fungiclef-2025/blob/main/fungiclef/preprocessing/embedding.py

In [1]:
# IPython autoreload: reload edited modules automatically before running each cell
%load_ext autoreload
%autoreload 2

In [2]:
from fungiclef.config import get_device

# show which device get_device() selects (the model and Trainer below call it again themselves)
device = get_device()
device

'cuda'

In [3]:
import timm
import torch
import pytorch_lightning as pl
from fungiclef.model_setup import setup_fine_tuned_model


class DINOv2LightningModel(pl.LightningModule):
    """PyTorch Lightning module for extracting [CLS] embeddings from a fine-tuned DINOv2 model.

    Args:
        model_path: Checkpoint file loaded into the timm backbone. When omitted
            (``None``), ``setup_fine_tuned_model()`` is called at construction
            time to obtain it.
        model_name: timm architecture name used to build the backbone.
    """

    def __init__(
        self,
        model_path: str | None = None,
        model_name: str = "vit_base_patch14_reg4_dinov2.lvd142m",
    ):
        super().__init__()
        # Resolve the checkpoint here rather than in the signature: a call in a
        # default argument runs once when the class is defined and is then frozen.
        if model_path is None:
            model_path = setup_fine_tuned_model()

        self.model_device = get_device()
        # must match the classifier head stored in the checkpoint (total plant species)
        self.num_classes = 7806

        # load the fine-tuned model
        self.model = timm.create_model(
            model_name,
            pretrained=False,
            num_classes=self.num_classes,
            checkpoint_path=model_path,
        )

        # eval-time preprocessing derived from the model's own data config
        self.data_config = timm.data.resolve_model_data_config(self.model)
        self.transform = timm.data.create_transform(
            **self.data_config, is_training=False
        )

        # move model to device and switch to inference mode
        self.model.to(self.model_device)
        self.model.eval()

    def forward(self, batch):
        """Extract [CLS] token embeddings using the fine-tuned model.

        Args:
            batch: Image tensor of shape (B, C, H, W), or (B, G, C, H, W) when
                each sample is a grid of G tiles.

        Returns:
            Tensor of shape (B, D) for 4-D input, or (B, G, D) for 5-D input,
            where D is the backbone embedding width.
        """
        with torch.no_grad():
            batch = batch.to(self.model_device)  # move to device

            grid_shape = None
            if batch.dim() == 5:  # (B, grid_size**2, C, H, W)
                grid_shape = batch.shape[:2]
                # reshape (not view) also accepts non-contiguous inputs
                batch = batch.reshape(-1, *batch.shape[2:])  # (B * G, C, H, W)

            # forward pass
            features = self.model.forward_features(batch)
            embeddings = features[:, 0, :]  # extract [CLS] token

            if grid_shape is not None:
                # restore the per-sample grid axis removed by the flatten above
                embeddings = embeddings.reshape(*grid_shape, -1)  # (B, G, D)

        return embeddings

    def predict_step(self, batch, batch_idx):
        """Run inference on a batch and return its [CLS] token embeddings."""
        return self(batch)

In [4]:
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import ToTensor
from fungiclef.serde import deserialize_image


class FungiDataset(Dataset):
    """Map-style dataset that decodes one serialized image per DataFrame row.

    Args:
        df: Table indexed positionally via ``.iloc`` that holds the image bytes.
        transform: Optional callable applied to the decoded image; when falsy,
            the image is converted with ``ToTensor()`` instead.
        col_name: Column containing the serialized image bytes.
    """

    def __init__(self, df, transform=None, col_name: str = "data"):
        self.df = df
        self.col_name = col_name
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return the image at positional index ``idx`` as a (C, H, W) tensor."""
        image = deserialize_image(self.df.iloc[idx][self.col_name])
        to_tensor = self.transform if self.transform else ToTensor()
        return to_tensor(image)


class FungiDataModule(pl.LightningDataModule):
    """LightningDataModule that wraps a DataFrame of serialized images for inference.

    Args:
        pandas_df: DataFrame passed to ``FungiDataset`` (image bytes in its
            default ``data`` column).
        batch_size: Images per batch.
        num_workers: DataLoader worker processes; 0 loads in the main process.
        transform: Preprocessing applied to each image. When ``None``, a
            ``DINOv2LightningModel`` is built in ``setup`` and its transform is used.
    """

    def __init__(
        self,
        pandas_df,
        batch_size=32,
        num_workers=4,
        transform=None,
    ):
        super().__init__()
        self.pandas_df = pandas_df
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.transform = transform

    def setup(self, stage=None):
        """Build the dataset, resolving the image transform if none was given."""
        transform = self.transform
        if transform is None:
            # Backward-compatible fallback: this loads a whole model (checkpoint
            # + device move) only to read its eval transform, so callers that
            # already hold a model should pass `transform=` instead.
            self.model = DINOv2LightningModel()
            transform = self.model.transform
        self.dataset = FungiDataset(self.pandas_df, transform)

    def predict_dataloader(self):
        """Return an unshuffled DataLoader so outputs keep DataFrame row order."""
        return DataLoader(
            self.dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            # DataLoader rejects persistent_workers=True when num_workers == 0
            persistent_workers=self.num_workers > 0,
        )

In [5]:
import pandas as pd


def pl_trainer_pipeline(
    pandas_df: pd.DataFrame,
    batch_size: int = 32,
    cpu_count: int = 1,
):
    """Extract DINOv2 [CLS] embeddings for every image in ``pandas_df``.

    Args:
        pandas_df: DataFrame of serialized images (see ``FungiDataset``).
        batch_size: Images per inference batch.
        cpu_count: Number of DataLoader worker processes.

    Returns:
        Tensor with one embedding per row, in DataFrame order (the predict
        DataLoader does not shuffle); shape (len(pandas_df), D) for plain
        images, with D = 768 for the default backbone.

    Raises:
        ValueError: If ``pandas_df`` is empty.
    """
    if len(pandas_df) == 0:
        # otherwise torch.cat fails later on an empty prediction list
        raise ValueError("pandas_df is empty; nothing to embed")

    data_module = FungiDataModule(
        pandas_df,
        batch_size=batch_size,
        num_workers=cpu_count,
    )
    model = DINOv2LightningModel()

    # single-device inference trainer
    trainer = pl.Trainer(
        accelerator=get_device(),
        devices=1,
        enable_progress_bar=True,
    )

    # trainer.predict yields one embedding tensor per batch
    predictions = trainer.predict(model, datamodule=data_module)
    return torch.cat(predictions, dim=0)

In [6]:
# NOTE: data_path keeps a literal "~"; this relies on pandas expanding it on read
data_path = "~/p-dsgt_clef2025-0/shared/fungiclef/data"
train_path = f"{data_path}/dataset/processed/train_serialized.parquet"

train_df = pd.read_parquet(train_path)  # serialized training split
train_df.head(n=5)

Unnamed: 0,eventDate,year,month,day,habitat,countryCode,scientificName,kingdom,phylum,class,...,region,district,filename,category_id,metaSubstrate,poisonous,elevation,landcover,biogeographicalRegion,data
0,2021-01-24,2021,1.0,24.0,Mixed woodland (with coniferous and deciduous ...,DK,Xylohypha ferruginosa (Corda) S.Hughes,Fungi,Ascomycota,Eurotiomycetes,...,Sjælland,Næstved,0-3052832307.JPG,2421,wood,0,0.0,16.0,continental,b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x00...
1,2021-02-27,2021,2.0,27.0,garden,DK,"Comatricha alta Preuss, 1851",Protozoa,Mycetozoa,Myxomycetes,...,Hovedstaden,Gribskov,0-3061954303.JPG,386,wood,0,0.0,17.0,continental,b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x00...
2,2021-02-27,2021,2.0,27.0,garden,DK,"Comatricha alta Preuss, 1851",Protozoa,Mycetozoa,Myxomycetes,...,Hovedstaden,Gribskov,1-3061954303.JPG,386,wood,0,0.0,17.0,continental,b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x00...
3,2021-02-27,2021,2.0,27.0,garden,DK,"Comatricha alta Preuss, 1851",Protozoa,Mycetozoa,Myxomycetes,...,Hovedstaden,Gribskov,2-3061954303.JPG,386,wood,0,0.0,17.0,continental,b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x00...
4,2021-02-27,2021,2.0,27.0,garden,DK,"Comatricha alta Preuss, 1851",Protozoa,Mycetozoa,Myxomycetes,...,Hovedstaden,Gribskov,3-3061954303.JPG,386,wood,0,0.0,17.0,continental,b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x00...


In [7]:
# first 700 rows (~9% of the 7,819 training images) to smoke-test the pipeline
subset_df = train_df.iloc[:700]
subset_df.shape[0]

700

### Run the embedding pipeline on a subset

In [8]:
# [CLS] embeddings for the 700-row subset
embeddings = pl_trainer_pipeline(pandas_df=subset_df, batch_size=64, cpu_count=2)

/storage/home/hcoda1/9/mgustineli3/clef/fungiclef-2025/.venv/lib/python3.10/site-packages/lightning_fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python3.10 /storage/home/hcoda1/9/mgustineli3/clef/fungicle ...
Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: |          | 0/? [00:00<?, ?it/s]

In [9]:
embeddings.size()

torch.Size([700, 768])

In [10]:
subset_df.keys()

Index(['eventDate', 'year', 'month', 'day', 'habitat', 'countryCode',
       'scientificName', 'kingdom', 'phylum', 'class', 'order', 'family',
       'genus', 'specificEpithet', 'hasCoordinate', 'species',
       'iucnRedListCategory', 'substrate', 'latitude', 'longitude',
       'coorUncert', 'observationID', 'region', 'district', 'filename',
       'category_id', 'metaSubstrate', 'poisonous', 'elevation', 'landcover',
       'biogeographicalRegion', 'data'],
      dtype='object')

In [11]:
# one row per image: filename plus its embedding as a Python list of floats
embed_sub_df = subset_df[["filename"]].copy()
embed_sub_df["embeddings"] = embeddings.cpu().tolist()
display(embed_sub_df)

Unnamed: 0,filename,embeddings
0,0-3052832307.JPG,"[-1.2559032440185547, 1.8512070178985596, -0.2..."
1,0-3061954303.JPG,"[-1.6616631746292114, -0.033579133450984955, -..."
2,1-3061954303.JPG,"[0.09439272433519363, 0.09596756845712662, -0...."
3,2-3061954303.JPG,"[1.1593152284622192, 0.2725003957748413, -0.39..."
4,3-3061954303.JPG,"[0.10036885738372803, 1.0950580835342407, 0.41..."
...,...,...
695,0-2864912308.JPG,"[0.12116563320159912, 0.6529977917671204, -0.3..."
696,1-3005502302.JPG,"[0.08340801298618317, 1.1956936120986938, -0.7..."
697,2-3005502302.JPG,"[0.35695040225982666, 0.791736900806427, -0.65..."
698,3-3005502302.JPG,"[-0.5019737482070923, -1.1750932931900024, -2...."


In [12]:
import os
from pathlib import Path

# get list of stored filed in cloud bucket
root = Path(os.path.expanduser("~"))
! date

# write to parquet
project_path = f"{root}/p-dsgt_clef2025-0/shared/fungiclef"
output_path = f"{project_path}/temp/embeddings/subset_train_embeddings.parquet"
# make dir if not exist
Path(output_path).parent.mkdir(parents=True, exist_ok=True)
embed_sub_df.to_parquet(output_path)

Fri May  9 22:56:01 EDT 2025


In [13]:
# read the subset parquet back to check the round trip
embed_df = pd.read_parquet(path=output_path)
embed_df.head(n=5)

Unnamed: 0,filename,embeddings
0,0-3052832307.JPG,"[-1.2559032440185547, 1.8512070178985596, -0.2..."
1,0-3061954303.JPG,"[-1.6616631746292114, -0.033579133450984955, -..."
2,1-3061954303.JPG,"[0.09439272433519363, 0.09596756845712662, -0...."
3,2-3061954303.JPG,"[1.1593152284622192, 0.2725003957748413, -0.39..."
4,3-3061954303.JPG,"[0.10036885738372803, 1.0950580835342407, 0.41..."


### Run the embedding pipeline on the entire training set

In [14]:
# extract embeddings
embeddings = pl_trainer_pipeline(
    train_df,
    batch_size=64,
    cpu_count=2,
)

/storage/home/hcoda1/9/mgustineli3/clef/fungiclef-2025/.venv/lib/python3.10/site-packages/lightning_fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python3.10 /storage/home/hcoda1/9/mgustineli3/clef/fungicle ...
Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: |          | 0/? [00:00<?, ?it/s]

In [15]:
embeddings.size()

torch.Size([7819, 768])

In [16]:
# one row per image: filename plus its embedding as a Python list of floats
embed_train_df = train_df[["filename"]].copy()
embed_train_df["embeddings"] = embeddings.cpu().tolist()
display(embed_train_df)

Unnamed: 0,filename,embeddings
0,0-3052832307.JPG,"[-1.2559032440185547, 1.8512070178985596, -0.2..."
1,0-3061954303.JPG,"[-1.6616631746292114, -0.033579133450984955, -..."
2,1-3061954303.JPG,"[0.09439272433519363, 0.09596756845712662, -0...."
3,2-3061954303.JPG,"[1.1593152284622192, 0.2725003957748413, -0.39..."
4,3-3061954303.JPG,"[0.10036885738372803, 1.0950580835342407, 0.41..."
...,...,...
7814,0-4100093368.JPG,"[-0.8721289038658142, 1.655833125114441, 0.139..."
7815,2-4100093368.JPG,"[-0.47504621744155884, 1.323290228843689, 0.90..."
7816,2-3429079314.JPG,"[0.8044366240501404, 1.585023045539856, 0.8527..."
7817,0-4847339663.JPG,"[-1.722570538520813, 1.0407721996307373, 0.565..."


In [17]:
# persist the full-train embeddings next to the subset file
output_path = os.path.join(project_path, "temp", "embeddings", "train_embeddings.parquet")
os.makedirs(os.path.dirname(output_path), exist_ok=True)  # ensure the target directory exists
embed_train_df.to_parquet(output_path)

In [18]:
# read the full-train parquet back to check the round trip
embed_df = pd.read_parquet(path=output_path)
embed_df.head(n=5)

Unnamed: 0,filename,embeddings
0,0-3052832307.JPG,"[-1.2559032440185547, 1.8512070178985596, -0.2..."
1,0-3061954303.JPG,"[-1.6616631746292114, -0.033579133450984955, -..."
2,1-3061954303.JPG,"[0.09439272433519363, 0.09596756845712662, -0...."
3,2-3061954303.JPG,"[1.1593152284622192, 0.2725003957748413, -0.39..."
4,3-3061954303.JPG,"[0.10036885738372803, 1.0950580835342407, 0.41..."


In [19]:
embed_df.shape[0]

7819