# DINOv2 model
Notebook to check the implementation of the DINOv2 model from PlantCLEF.

https://github.com/dsgt-kaggle-clef/fungiclef-2025/blob/main/fungiclef/preprocessing/embedding.py

In [1]:
# from google.colab import drive
# drive.mount('/content/drive')

In [2]:
# # navigate to fungiclef-2025 repo
# %cd /content/drive/MyDrive/fungiclef-2025/fungiclef-2025

In [3]:
# !ls -la

In [4]:
# !pip install -e .

In [5]:
from fungiclef.config import get_device

# resolve the compute device once via the project helper
device = get_device()
# bare name as the last expression so the notebook echoes the selected device
device

'cuda'

In [None]:
import os
from pathlib import Path


def get_model_dir(root_dir: str = "/content/drive/MyDrive/fungiclef-2025") -> str:
    """
    Return the directory that holds model files, creating it when missing.

    :param root_dir: Project root under which the ``model`` directory lives.
        Defaults to the fungiclef-2025 folder under ``/content/drive/MyDrive``.
    :return: Path to ``<root_dir>/model``.
    """
    model_dir = os.path.join(root_dir, "model")
    # exist_ok=True replaces the separate exists() check, which could race
    # with another process creating the same directory
    os.makedirs(model_dir, exist_ok=True)
    return model_dir


def setup_fine_tuned_model() -> str:
    """
    Locate the fine-tuned DINOv2 checkpoint inside the model directory.

    Nothing is downloaded or extracted here: the checkpoint must already be
    present under ``pretrained_models/`` in the directory returned by
    ``get_model_dir()``.

    :return: Path to the ``model_best.pth.tar`` checkpoint file.
    :raises FileNotFoundError: If the checkpoint file does not exist.
    """
    checkpoint_name = "model_best.pth.tar"
    model_folder = "vit_base_patch14_reg4_dinov2_lvd142m_pc24_onlyclassifier_then_all"
    checkpoint_path = os.path.join(
        get_model_dir(),
        f"pretrained_models/{model_folder}/{checkpoint_name}",
    )

    # hand back the path only when the file is really there
    if os.path.exists(checkpoint_path):
        return checkpoint_path
    raise FileNotFoundError(f"Model file not found at: {checkpoint_path}")

In [None]:
import timm
import torch
import pytorch_lightning as pl

from fungiclef.config import get_device


class DINOv2LightningModel(pl.LightningModule):
    """PyTorch Lightning module for extracting embeddings from a fine-tuned DINOv2 model."""

    def __init__(
        self,
        model_path: str | None = None,
        model_name: str = "vit_base_patch14_reg4_dinov2.lvd142m",
    ):
        """
        Build the timm model from a checkpoint and prepare its eval transform.

        :param model_path: Path to the checkpoint file. When ``None`` the path
            is resolved with ``setup_fine_tuned_model()`` at construction time.
        :param model_name: timm architecture name passed to ``timm.create_model``.
        :raises FileNotFoundError: If no ``model_path`` is given and the default
            checkpoint cannot be found.
        """
        super().__init__()
        # Resolve the default here rather than in the signature: a call in the
        # signature runs once when the class is defined, so a missing
        # checkpoint would break the class definition itself.
        if model_path is None:
            model_path = setup_fine_tuned_model()

        self.model_device = get_device()
        self.num_classes = 7806  # total plant species

        # load the fine-tuned model
        self.model = timm.create_model(
            model_name,
            pretrained=False,
            num_classes=self.num_classes,
            checkpoint_path=model_path,
        )

        # load transform
        self.data_config = timm.data.resolve_model_data_config(self.model)
        self.transform = timm.data.create_transform(
            **self.data_config, is_training=False
        )

        # move model to device
        self.model.to(self.model_device)
        self.model.eval()

    def forward(self, batch):
        """
        Extract [CLS] token embeddings using the fine-tuned model.

        :param batch: Image tensor, either 4-D ``(B, C, H, W)`` or 5-D
            ``(B, G, C, H, W)``; a 5-D batch is flattened to ``(B * G, C, H, W)``.
        :return: Token 0 of ``forward_features`` for every image in the batch.
        """
        with torch.no_grad():
            batch = batch.to(self.model_device)  # move to device

            if batch.dim() == 5:  # (B, grid_size**2, C, H, W)
                B, G, C, H, W = batch.shape
                # reshape (not view) so non-contiguous inputs are accepted too
                batch = batch.reshape(B * G, C, H, W)  # (B * grid_size**2, C, H, W)
            # forward pass
            features = self.model.forward_features(batch)
            embeddings = features[:, 0, :]  # extract [CLS] token

        return embeddings

    def predict_step(self, batch, batch_idx):
        """Run inference on ``batch`` and return its [CLS] token embeddings."""
        return self(batch)

In [None]:
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import ToTensor
from fungiclef.serde import deserialize_image


class FungiDataset(Dataset):
    """Dataset that decodes serialized images stored in one column of a DataFrame."""

    def __init__(self, df, transform=None, col_name: str = "data"):
        # df: table with one image per row
        # transform: optional callable applied to each decoded image
        # col_name: column that holds the serialized image bytes
        self.col_name = col_name
        self.transform = transform
        self.df = df

    def __len__(self):
        """Return the number of rows (images) in the underlying frame."""
        n_rows = len(self.df)
        return n_rows

    def __getitem__(self, idx):
        """Decode the image stored at row ``idx`` and return it transformed."""
        row = self.df.iloc[idx]
        image = deserialize_image(row[self.col_name])

        # without a custom transform, fall back to a plain tensor conversion
        if not self.transform:
            return ToTensor()(image)  # (C, H, W)
        return self.transform(image)  # (C, H, W)


class FungiDataModule(pl.LightningDataModule):
    """LightningDataModule for handling dataset loading and preparation."""

    def __init__(
        self,
        pandas_df,
        batch_size=32,
        num_workers=4,
        transform=None,
    ):
        """
        :param pandas_df: DataFrame handed to ``FungiDataset``.
        :param batch_size: Batch size of the prediction DataLoader.
        :param num_workers: Number of DataLoader worker processes (0 loads in
            the main process).
        :param transform: Optional image transform. When ``None``, ``setup``
            builds a ``DINOv2LightningModel`` and uses its transform.
        """
        super().__init__()
        self.pandas_df = pandas_df
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.transform = transform

    def setup(self, stage=None):
        """Set up dataset and transformations."""
        transform = self.transform
        if transform is None:
            # No transform supplied: load the model only to read its transform.
            # Callers that already hold a model can pass its transform to the
            # constructor and avoid loading the checkpoint a second time.
            self.model = DINOv2LightningModel()
            transform = self.model.transform
        self.dataset = FungiDataset(
            self.pandas_df,
            transform,
        )

    def predict_dataloader(self):
        """Returns DataLoader for inference."""
        return DataLoader(
            self.dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            # DataLoader rejects persistent_workers=True when num_workers is 0
            persistent_workers=self.num_workers > 0,
        )

In [None]:
import pandas as pd
import pytorch_lightning as pl
from tqdm import tqdm


def pl_trainer_pipeline(
    pandas_df: pd.DataFrame,
    batch_size: int = 32,
    cpu_count: int = 1,
):
    """
    Run the embedding model over a DataFrame with a Lightning ``Trainer``.

    :param pandas_df: Frame passed to ``FungiDataModule``.
    :param batch_size: Batch size used for prediction.
    :param cpu_count: Number of DataLoader workers.
    :return: A single tensor made by concatenating the per-batch embeddings
        along dim 0 (the 700-row subset run in this notebook gives (700, 768)).
    """
    # data side: wraps the frame and builds the prediction DataLoader
    data_module = FungiDataModule(
        pandas_df,
        batch_size=batch_size,
        num_workers=cpu_count,
    )

    # model side: the embedding extractor
    model = DINOv2LightningModel()

    # single-device trainer, used only for prediction
    trainer = pl.Trainer(
        accelerator=get_device(),
        devices=1,
        enable_progress_bar=True,
    )

    # trainer.predict yields one embeddings tensor per batch
    batch_embeddings = list(trainer.predict(model, datamodule=data_module))

    # stitch the per-batch tensors into one
    return torch.cat(batch_embeddings, dim=0)

In [None]:
# root of the project data directory
data_path = "/content/drive/MyDrive/fungiclef-2025/data"
# serialized subset of the training data (image bytes are in the "data" column)
parquet_path = f"{data_path}/processed/subset_train_serialized.parquet"

# load the subset and preview its first rows
subset_df = pd.read_parquet(parquet_path)
subset_df.head()

Unnamed: 0,eventDate,year,month,day,habitat,countryCode,scientificName,kingdom,phylum,class,...,region,district,filename,category_id,metaSubstrate,poisonous,elevation,landcover,biogeographicalRegion,data
0,2021-01-24,2021,1.0,24.0,Mixed woodland (with coniferous and deciduous ...,DK,Xylohypha ferruginosa (Corda) S.Hughes,Fungi,Ascomycota,Eurotiomycetes,...,Sjælland,Næstved,0-3052832307.JPG,2421,wood,0,0.0,16.0,continental,b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x00...
1,2021-02-27,2021,2.0,27.0,garden,DK,"Comatricha alta Preuss, 1851",Protozoa,Mycetozoa,Myxomycetes,...,Hovedstaden,Gribskov,0-3061954303.JPG,386,wood,0,0.0,17.0,continental,b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x00...
2,2021-02-27,2021,2.0,27.0,garden,DK,"Comatricha alta Preuss, 1851",Protozoa,Mycetozoa,Myxomycetes,...,Hovedstaden,Gribskov,1-3061954303.JPG,386,wood,0,0.0,17.0,continental,b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x00...
3,2021-02-27,2021,2.0,27.0,garden,DK,"Comatricha alta Preuss, 1851",Protozoa,Mycetozoa,Myxomycetes,...,Hovedstaden,Gribskov,2-3061954303.JPG,386,wood,0,0.0,17.0,continental,b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x00...
4,2021-02-27,2021,2.0,27.0,garden,DK,"Comatricha alta Preuss, 1851",Protozoa,Mycetozoa,Myxomycetes,...,Hovedstaden,Gribskov,3-3061954303.JPG,386,wood,0,0.0,17.0,continental,b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x00...


In [None]:
# number of rows in the subset
len(subset_df)

700

### run embedding pipeline

In [None]:
# extract embeddings for the subset (batches of 64, 2 DataLoader workers)
embeddings = pl_trainer_pipeline(
    subset_df,
    batch_size=64,
    cpu_count=2,
)

INFO:pytorch_lightning.utilities.rank_zero:Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: |          | 0/? [00:00<?, ?it/s]

In [None]:
# sanity check: first dimension should equal the number of rows in subset_df
embeddings.shape

torch.Size([700, 768])

In [None]:
# list the available metadata columns
subset_df.columns

Index(['eventDate', 'year', 'month', 'day', 'habitat', 'countryCode',
       'scientificName', 'kingdom', 'phylum', 'class', 'order', 'family',
       'genus', 'specificEpithet', 'hasCoordinate', 'species',
       'iucnRedListCategory', 'substrate', 'latitude', 'longitude',
       'coorUncert', 'observationID', 'region', 'district', 'filename',
       'category_id', 'metaSubstrate', 'poisonous', 'elevation', 'landcover',
       'biogeographicalRegion', 'data'],
      dtype='object')

In [None]:
# pair each filename with its embedding vector (one list of floats per row);
# the former mid-cell `embed_df.head()` was dropped because its result was discarded
embed_df = subset_df[["filename"]].copy()
embed_df["embeddings"] = embeddings.cpu().tolist()
display(embed_df[["filename", "embeddings"]])

Unnamed: 0,filename,embeddings
0,0-3052832307.JPG,"[-1.2559032440185547, 1.8512070178985596, -0.2..."
1,0-3061954303.JPG,"[-1.6616631746292114, -0.033579133450984955, -..."
2,1-3061954303.JPG,"[0.09439272433519363, 0.09596756845712662, -0...."
3,2-3061954303.JPG,"[1.1593152284622192, 0.2725003957748413, -0.39..."
4,3-3061954303.JPG,"[0.10036885738372803, 1.0950580835342407, 0.41..."
...,...,...
695,0-2864912308.JPG,"[0.12116563320159912, 0.6529977917671204, -0.3..."
696,1-3005502302.JPG,"[0.08340801298618317, 1.1956936120986938, -0.7..."
697,2-3005502302.JPG,"[0.35695040225982666, 0.791736900806427, -0.65..."
698,3-3005502302.JPG,"[-0.5019737482070923, -1.1750932931900024, -2...."


In [None]:
# write the subset embeddings to parquet
output_path = f"{data_path}/embeddings/subset_train_embeddings.parquet"
# create the parent directory if it does not exist yet
Path(output_path).parent.mkdir(parents=True, exist_ok=True)
embed_df.to_parquet(output_path)

In [None]:
# read the embeddings back to confirm the file round-trips
embed_df = pd.read_parquet(output_path)
embed_df.head()

Unnamed: 0,filename,embeddings
0,0-3052832307.JPG,"[-1.2559032440185547, 1.8512070178985596, -0.2..."
1,0-3061954303.JPG,"[-1.6616631746292114, -0.033579133450984955, -..."
2,1-3061954303.JPG,"[0.09439272433519363, 0.09596756845712662, -0...."
3,2-3061954303.JPG,"[1.1593152284622192, 0.2725003957748413, -0.39..."
4,3-3061954303.JPG,"[0.10036885738372803, 1.0950580835342407, 0.41..."


### run embed pipeline on entire training data

In [None]:
# load the full serialized training set and report its row count
train_path = f"{data_path}/processed/train_serialized.parquet"
train_df = pd.read_parquet(train_path)
len(train_df)

7819

In [None]:
import cv2
import numpy as np
import io


def detect_and_fix_bytes(img_bytes, *, force_reencode: bool = False):
    """
    Return usable JPEG bytes for an image, re-encoding it when it looks truncated.

    :param img_bytes: Raw image bytes (``bytes``, ``bytearray`` or ``memoryview``).
    :param force_reencode: When True, skip the end-of-image marker shortcut and
        always decode and re-encode with OpenCV. The marker check alone does not
        prove the stream is complete, so this is an option for files that a
        stricter decoder still rejects as truncated.
    :return: The original bytes if they end with the JPEG EOI marker (and
        ``force_reencode`` is False), freshly encoded JPEG bytes otherwise, or
        ``None`` when the input is missing, empty, or cannot be decoded/encoded.
    """
    # Missing values (None, NaN) and empty buffers cannot be repaired; return
    # None so callers can drop the row instead of hitting a TypeError or a
    # decoder assertion.
    if not isinstance(img_bytes, (bytes, bytearray, memoryview)) or len(img_bytes) == 0:
        print("Missing or empty image bytes")
        return None

    # Check for JPEG EOF marker
    if not force_reencode and img_bytes[-2:] == b"\xff\xd9":
        return img_bytes  # Image is OK

    # Decode using OpenCV
    nparr = np.frombuffer(img_bytes, np.uint8)
    img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
    if img is None:
        print("Unable to decode image — possibly deeply corrupted")
        return None

    # Re-encode to JPEG and return fixed bytes
    success, encoded_img = cv2.imencode(".jpg", img)
    if not success:
        print("Failed to encode image")
        return None
    return encoded_img.tobytes()

In [None]:
from tqdm import tqdm

# register tqdm's pandas integration so progress_apply is available
tqdm.pandas()
# repair each image; rows that cannot be repaired come back as None
train_df["data"] = train_df["data"].progress_apply(detect_and_fix_bytes)
# drop any rows that couldn't be fixed and renumber the index
train_df = train_df.dropna(subset=["data"]).reset_index(drop=True)
# remaining row count after the cleanup
len(train_df)


  0%|          | 0/7819 [00:00<?, ?it/s][A
 86%|████████▌ | 6694/7819 [00:00<00:00, 66694.87it/s][A

Unable to decode image — possibly deeply corrupted
Unable to decode image — possibly deeply corrupted
Unable to decode image — possibly deeply corrupted
Unable to decode image — possibly deeply corrupted
Unable to decode image — possibly deeply corrupted
Unable to decode image — possibly deeply corrupted
Unable to decode image — possibly deeply corrupted
Unable to decode image — possibly deeply corrupted
Unable to decode image — possibly deeply corrupted
Unable to decode image — possibly deeply corrupted
Unable to decode image — possibly deeply corrupted
Unable to decode image — possibly deeply corrupted
Unable to decode image — possibly deeply corrupted
Unable to decode image — possibly deeply corrupted
Unable to decode image — possibly deeply corrupted
Unable to decode image — possibly deeply corrupted
Unable to decode image — possibly deeply corrupted
Unable to decode image — possibly deeply corrupted
Unable to decode image — possibly deeply corrupted
Unable to decode image — possib

100%|██████████| 7819/7819 [00:03<00:00, 2345.89it/s] 

Unable to decode image — possibly deeply corrupted
Unable to decode image — possibly deeply corrupted
Unable to decode image — possibly deeply corrupted
Unable to decode image — possibly deeply corrupted
Unable to decode image — possibly deeply corrupted
Unable to decode image — possibly deeply corrupted
Unable to decode image — possibly deeply corrupted
Unable to decode image — possibly deeply corrupted
Unable to decode image — possibly deeply corrupted
Unable to decode image — possibly deeply corrupted
Unable to decode image — possibly deeply corrupted
Unable to decode image — possibly deeply corrupted
Unable to decode image — possibly deeply corrupted
Unable to decode image — possibly deeply corrupted
Unable to decode image — possibly deeply corrupted
Unable to decode image — possibly deeply corrupted
Unable to decode image — possibly deeply corrupted
Unable to decode image — possibly deeply corrupted





7466

In [None]:
# extract embeddings for the full training set
# NOTE(review): the recorded run of this cell failed with
# "OSError: image file is truncated" raised from deserialize_image, so the
# byte-repair step does not catch every bad image — confirm before relying on it
embeddings = pl_trainer_pipeline(
    train_df,
    batch_size=64,
    cpu_count=2,
)

INFO:pytorch_lightning.utilities.rank_zero:Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: |          | 0/? [00:00<?, ?it/s]

OSError: Caught OSError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/_utils/worker.py", line 349, in _worker_loop
    data = fetcher.fetch(index)  # type: ignore[possibly-undefined]
           ^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/_utils/fetch.py", line 52, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/_utils/fetch.py", line 52, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
            ~~~~~~~~~~~~^^^^^
  File "<ipython-input-12-b3e15a4edbb4>", line 17, in __getitem__
    img = deserialize_image(img_bytes)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/content/drive/MyDrive/fungiclef-2025/fungiclef-2025/fungiclef/serde.py", line 10, in deserialize_image
    return Image.open(buffer).convert("RGB")
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/PIL/Image.py", line 982, in convert
    self.load()
  File "/usr/local/lib/python3.11/dist-packages/PIL/ImageFile.py", line 386, in load
    raise OSError(msg)
OSError: image file is truncated (5 bytes not processed)


In [None]:
# sanity check: first dimension should equal the number of rows in train_df
embeddings.shape

In [None]:
# pair each training filename with its embedding vector (one list of floats per row);
# the former mid-cell `embed_train_df.head()` was dropped because its result was discarded
embed_train_df = train_df[["filename"]].copy()
embed_train_df["embeddings"] = embeddings.cpu().tolist()
display(embed_train_df[["filename", "embeddings"]])

Unnamed: 0,filename,embeddings
0,0-3052832307.JPG,"[-1.2559032440185547, 1.8512070178985596, -0.2..."
1,0-3061954303.JPG,"[-1.6616631746292114, -0.033579133450984955, -..."
2,1-3061954303.JPG,"[0.09439272433519363, 0.09596756845712662, -0...."
3,2-3061954303.JPG,"[1.1593152284622192, 0.2725003957748413, -0.39..."
4,3-3061954303.JPG,"[0.10036885738372803, 1.0950580835342407, 0.41..."
...,...,...
695,0-2864912308.JPG,"[0.12116563320159912, 0.6529977917671204, -0.3..."
696,1-3005502302.JPG,"[0.08340801298618317, 1.1956936120986938, -0.7..."
697,2-3005502302.JPG,"[0.35695040225982666, 0.791736900806427, -0.65..."
698,3-3005502302.JPG,"[-0.5019737482070923, -1.1750932931900024, -2...."


In [None]:
# write the training embeddings to parquet
output_path = f"{data_path}/embeddings/train_embeddings.parquet"
# create the parent directory if it does not exist yet
Path(output_path).parent.mkdir(parents=True, exist_ok=True)
embed_train_df.to_parquet(output_path)

In [None]:
# read the training embeddings back to confirm the file round-trips
embed_df = pd.read_parquet(output_path)
embed_df.head()

Unnamed: 0,filename,embeddings
0,0-3052832307.JPG,"[-1.2559032440185547, 1.8512070178985596, -0.2..."
1,0-3061954303.JPG,"[-1.6616631746292114, -0.033579133450984955, -..."
2,1-3061954303.JPG,"[0.09439272433519363, 0.09596756845712662, -0...."
3,2-3061954303.JPG,"[1.1593152284622192, 0.2725003957748413, -0.39..."
4,3-3061954303.JPG,"[0.10036885738372803, 1.0950580835342407, 0.41..."
