beneroth13/dinov2

 
 

Low-resource finetuning of foundation models beats state-of-the-art in histopathology

This is the repository for "Low-resource finetuning of foundation models beats state-of-the-art in histopathology", which was accepted at ISBI 2024. It is a slightly adapted version of the original DINOv2 GitHub repository.

Finetuning can be compute efficient

We propose finetuning a DINOv2 ViT-S, which performs at least as well as CTransPath and RetCCL while requiring only a fraction of the domain-specific training time. Performance is measured on three datasets: TCGA and CPTAC (WSI-level classification) and NCT-CRC (patch-level classification).

Loss and performance over time

Performance over time when finetuning a ViT-S with DINOv2: a) finetuning on NCT-CRC and evaluating patch-level classification on the external NCT-CRC test set, and b) finetuning on TCGA and evaluating WSI-level classification on TCGA (5-fold cross-validation) and CPTAC (external test set).

Data

For finetuning, we used histopathological data from two primary datasets: NCT-CRC-100K and TCGA.

For testing, we used two additional external datasets: CRC-VAL-HE-7K and CPTAC.

We used the following testing pipeline for TCGA and CPTAC:

Model farm

All models and the heads used for training are publicly available below.

Pretrained models finetuned on NCT-CRC-100K

| model | # of params | # of iterations | CRC-VAL-HE-7K 20-NN balanced acc | CRC-VAL-HE-7K linear balanced acc | teacher backbone |
|---|---|---|---|---|---|
| ViT-S/14 | 21 M | 2k | 93.8% | 92.7% | teacher weights |
| ViT-g/14 | 1,100 M | 10k | 93.4% | 93.7% | teacher weights |
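
The table reports balanced accuracy rather than plain accuracy. As a minimal sketch, assuming the standard definition (the mean of per-class recall, so every tissue class counts equally regardless of how many patches it has), the metric can be computed as follows. The function name and the toy labels are illustrative only and are not taken from the repository's evaluation code.

```python
from collections import defaultdict


def balanced_accuracy(y_true, y_pred):
    """Return the mean of per-class recall over all classes in y_true."""
    correct = defaultdict(int)
    total = defaultdict(int)
    for truth, prediction in zip(y_true, y_pred):
        total[truth] += 1
        if truth == prediction:
            correct[truth] += 1
    # Recall for each class, then an unweighted average across classes.
    recalls = [correct[label] / total[label] for label in total]
    return sum(recalls) / len(recalls)


# Toy labels, made up for illustration (not results from CRC-VAL-HE-7K).
y_true = ["TUM", "TUM", "TUM", "TUM", "STR", "STR"]
y_pred = ["TUM", "TUM", "TUM", "STR", "STR", "STR"]
score = balanced_accuracy(y_true, y_pred)
```

In this toy case plain accuracy would be 5/6, while the balanced score averages a recall of 3/4 for one class and 2/2 for the other, which is why the two metrics can differ on imbalanced test sets.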

Pretrained models finetuned on TCGA

| model | # of params | # of iterations | TCGA AUROC | CPTAC AUROC | teacher backbone |
|---|---|---|---|---|---|
| ViT-S/14 | 21 M | 30k | 89% | 85% | teacher weights |
| ViT-g/14 | 1,100 M | 60k | 84% | 79% | teacher weights |

Load pretrained model

import torch
import torch.nn as nn

# Path to the downloaded finetuned teacher checkpoint (set this before running).
DINO_PATH_FINETUNED_DOWNLOADED = ''

def get_dino_finetuned_downloaded():
    # Load the original DINOv2 architecture from torch.hub; choose ViT-S or ViT-g.
    model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14')
    # model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitg14')
    # Load the finetuned weights onto the CPU.
    pretrained = torch.load(DINO_PATH_FINETUNED_DOWNLOADED, map_location=torch.device('cpu'))
    # Build a state dict that matches the backbone: skip the DINO head and
    # strip the 'backbone.' prefix from the remaining keys.
    new_state_dict = {}
    for key, value in pretrained['teacher'].items():
        if 'dino_head' in key:
            continue
        new_state_dict[key.replace('backbone.', '')] = value
    # The hub model's positional embedding is larger than the one stored in the
    # finetuned checkpoint, so replace it with a parameter of the matching shape.
    # The last dimension depends on the model: 384 for ViT-S, 1536 for ViT-g.
    pos_embed = nn.Parameter(torch.zeros(1, 257, 384))
    # pos_embed = nn.Parameter(torch.zeros(1, 257, 1536))
    model.pos_embed = pos_embed
    # Load the finetuned weights; strict=True fails on any missing or unexpected key.
    model.load_state_dict(new_state_dict, strict=True)
    return model

model = get_dino_finetuned_downloaded()
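
The value 257 in the positional embedding shape is the number of tokens the model sees per image. The sketch below shows one way to arrive at that number, assuming one class token plus one token per 14 x 14 patch and square 224 x 224 inputs; the 224 input size is an assumption that is consistent with 257 tokens but is not stated in this README. The helper name `num_pos_embed_tokens` is hypothetical and is not part of DINOv2.

```python
def num_pos_embed_tokens(image_size: int, patch_size: int = 14, num_cls_tokens: int = 1) -> int:
    """Count the tokens of a ViT positional embedding for a square input image."""
    if image_size % patch_size != 0:
        raise ValueError("image_size must be a multiple of patch_size")
    patches_per_side = image_size // patch_size
    # One token per patch plus the class token(s).
    return num_cls_tokens + patches_per_side * patches_per_side


# Assumed finetuning resolution (not stated in this README): 224 / 14 = 16 patches per side.
tokens_224 = num_pos_embed_tokens(224)
# Resolution of the pretrained hub models: 518 / 14 = 37 patches per side.
tokens_518 = num_pos_embed_tokens(518)
```

Under these assumptions the hub model carries a longer positional embedding than the finetuned checkpoint, which is why the loading code swaps in a parameter of the checkpoint's shape before calling load_state_dict with strict=True.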

Installation

This requires the same prerequisites as the original DINOv2 implementation.

The training and evaluation code requires PyTorch 2.0 and xFormers 0.0.18 as well as a number of other third-party packages. Note that the code has only been tested with the specified versions and expects a Linux environment. To set up all the required dependencies for training and evaluation, please follow the instructions below:

conda (Recommended) - Clone the repository and then create and activate a dinov2 conda environment using the provided environment definition:

conda env create -f conda.yaml
conda activate dinov2

pip - Clone the repository and then use the provided requirements.txt to install the dependencies:

pip install -r requirements.txt

Use the pipeline

Currently, the GitHub repository is meant to run on a single GPU. Once all hyperparameters are set in dinov2/dinov2/configs/ssl_default_config.yaml, start training with:

python dinov2/train/train.py --config-file dinov2/configs/ssl_default_config.yaml

To use more than one GPU, change the sampler in train.py to one that supports sharding (e.g. SamplerType.SHARDED_INFINITE) and change the StateDictType in fsdp/__init__.py. Then launch training with:

torchrun --nproc_per_node=2 dinov2/dinov2/train/train.py --config-file dinov2/configs/ssl_default_config.yaml

nproc_per_node corresponds to the number of GPUs.

Additional arguments can be passed on the command line as well (see the original DINOv2 repository).

Before running, change the dataset paths in dinov2/configs/ssl_default_config.yaml to point to your own data. The CSV files should contain only the paths to the image files.
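
A CSV file of image paths can be generated with a few lines of Python. The following is a minimal sketch, assuming one path per row and no header row; the exact layout the dataset class expects should be checked against the repository code. The function name `write_image_list`, the extension list, and the example paths are hypothetical.

```python
import csv
from pathlib import Path

# Assumed set of image file extensions; adjust to your data.
IMAGE_EXTENSIONS = {".png", ".jpg", ".jpeg", ".tif", ".tiff"}


def write_image_list(image_root, csv_path):
    """Write the path of every image below image_root to csv_path, one per row."""
    root = Path(image_root)
    # Collect the image files first so the output file is never listed itself.
    paths = sorted(
        p for p in root.rglob("*")
        if p.is_file() and p.suffix.lower() in IMAGE_EXTENSIONS
    )
    with open(csv_path, "w", newline="") as handle:
        writer = csv.writer(handle)
        for path in paths:
            writer.writerow([str(path)])
    return len(paths)


# Example call (hypothetical paths):
# write_image_list("/data/NCT-CRC-HE-100K", "/data/nct_train.csv")
```

The resulting file can then be referenced from ssl_default_config.yaml in place of the default dataset paths.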

Continue finetuning

If you want to continue finetuning or use the DINO heads, the remaining weights are listed here:

| model | dataset | # of iterations | student backbone | student DINO head | teacher DINO head |
|---|---|---|---|---|---|
| ViT-S/14 | NCT-CRC-100K | 2k | student backbone | student DINO head | teacher DINO head |
| ViT-g/14 | NCT-CRC-100K | 10k | student backbone | student DINO head | teacher DINO head |
| ViT-S/14 | TCGA | 30k | student backbone | student DINO head | teacher DINO head |
| ViT-g/14 | TCGA | 60k | student backbone | student DINO head | teacher DINO head |

To load these weights, add the path to the config file under head_path. This path must point to a folder containing the weights. After downloading, rename the weight files so that the available code can find them (e.g. student_dino_head_checkpoint.pth). More details can be found in the file /dinov2/dinov2/train/ssl_meta_arch.py.
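
Collecting the downloaded files into one folder under the expected names can be scripted. The sketch below is only an illustration: of the target file names, only student_dino_head_checkpoint.pth is mentioned in this README, so any other target name has to be looked up in ssl_meta_arch.py. The function name `collect_head_weights` and the example source path are hypothetical.

```python
import shutil
from pathlib import Path


def collect_head_weights(downloads, head_dir):
    """Copy downloaded weight files into head_dir under their target names.

    downloads maps each downloaded file path to the file name the training
    code expects, e.g. "student_dino_head_checkpoint.pth".
    """
    head_dir = Path(head_dir)
    head_dir.mkdir(parents=True, exist_ok=True)
    copied = []
    for source, target_name in downloads.items():
        target = head_dir / target_name
        # Copy rather than move so the original download stays untouched.
        shutil.copyfile(source, target)
        copied.append(target)
    return copied


# Example call (the source file name is hypothetical):
# collect_head_weights(
#     {"downloads/vits_student_head.pth": "student_dino_head_checkpoint.pth"},
#     "weights/heads",
# )
```

The folder passed as head_dir is then the value to enter under head_path in the config file.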

Citation

If you find our research helpful, please consider citing:

@misc{roth2024lowresource,
  title={Low-resource finetuning of foundation models beats state-of-the-art in histopathology},
  author={Benedikt Roth and Valentin Koch and Sophia J. Wagner and Julia A. Schnabel and Carsten Marr and Tingying Peng},
  year={2024},
  eprint={2401.04720},
  archivePrefix={arXiv},
  primaryClass={cs.CV}
}
