vox2vec: A Framework for Self-supervised Contrastive Learning of Voxel-level Representations in Medical Images

This repository is the official implementation of vox2vec: A Framework for Self-supervised Contrastive Learning of Voxel-level Representations in Medical Images.

Requirements

Make sure you have installed a version of torch compatible with your CUDA version. To install the other requirements, run

git clone https://github.com/mishgon/vox2vec.git && cd vox2vec && pip install -e .

User guide

Use vox2vec to extract voxel-level features of 3D computed tomography (CT) images.

The recommended preprocessing of CT images before applying vox2vec includes the following steps:

  • cropping to the body;
  • resampling to 1 × 1 × 2 mm³ voxel spacing;
  • clipping the intensities to [-1350, 1000] HU and rescaling them to the [0, 1] interval.
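The clipping-and-rescaling step can be sketched in a few lines of NumPy (a minimal illustration; the function name and this NumPy-based implementation are not part of this repository):

```python
import numpy as np

def preprocess_intensities(image: np.ndarray,
                           lo: float = -1350.0,
                           hi: float = 1000.0) -> np.ndarray:
    """Clip CT intensities to [lo, hi] HU and rescale them to [0, 1]."""
    image = np.clip(image.astype(np.float32), lo, hi)
    # linear mapping: lo -> 0.0, hi -> 1.0
    return (image - lo) / (hi - lo)
```

Cropping and resampling depend on your I/O stack and are omitted here.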

Since CT images are 3D and usually of high resolution, the common practice is to use patch-wise deep learning pipelines: an input image is split into (overlapping) patches, a neural network is applied to each patch individually, and the patch-wise predictions are then aggregated to obtain the final prediction.
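For illustration, the start indices of overlapping patches along one image axis can be computed as follows (a hypothetical helper, not part of this repository; the same logic applies independently to each of the three axes):

```python
def patch_starts(size: int, patch: int, stride: int) -> list:
    """Start indices of overlapping windows of length `patch`
    covering an axis of length `size`, stepping by `stride`."""
    if size <= patch:
        return [0]
    starts = list(range(0, size - patch, stride))
    starts.append(size - patch)  # last window is aligned to the border
    return starts

# e.g. patch_starts(160, 128, 96) -> [0, 32]
```

The patch-wise outputs can then be aggregated, e.g. by averaging the predictions over the overlapping regions.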

vox2vec can be easily plugged into such pipelines as a patch-wise feature extractor. The recommended patch size is (128, 128, 32). For such a patch, vox2vec returns a feature pyramid containing 6 feature maps with increasing numbers of channels and decreasing resolutions. See the example below:

import torch

from vox2vec import vox2vec_contrastive

# the pre-trained model is downloaded from the Huggingface Hub 🤗
vox2vec = vox2vec_contrastive()

# the recommended input patch size is (128, 128, 32)
input_patch = torch.rand((3, 1, 128, 128, 32))

# then, the output feature pyramid has size
# [(3, 16, 128, 128, 32),
#  (3, 32, 64, 64, 16),
#  ...,
#  (3, 512, 4, 4, 1)]
feature_pyramid = vox2vec(input_patch)  

You can also download the pre-trained weights manually from here and initialize the vox2vec.nn.fpn.FPN3d(in_channels=1, base_channels=16, num_scales=6) architecture by hand.

Evaluation of the pre-trained model

Below, we describe how to evaluate the pre-trained vox2vec model on the BTCV dataset [1].

First, follow the instructions to download the BTCV dataset. Make sure that the data are unzipped as follows:

/path/to/btcv
    - RawData.zip
    - RawData
        - Training
            - img
                - img0001.nii.gz
                ...
            - label
                - label0001.nii.gz
                ...
        - Testing
            - img
                - img0061.nii.gz
                ...

Also, prepare empty folders /path/to/cache for caching the preprocessed data, and /path/to/logs for logging.

To evaluate the pre-trained model in the linear and non-linear probing setups, run

python eval.py --btcv_dir /path/to/btcv --cache_dir /path/to/cache --ckpt /path/to/vox2vec.pt --setup probing --log_dir /path/to/logs/

To evaluate the pre-trained model in the fine-tuning setup, run

python eval.py --btcv_dir /path/to/btcv --cache_dir /path/to/cache --ckpt /path/to/vox2vec.pt --setup fine-tuning --log_dir /path/to/logs/

As a baseline, train the same architecture from scratch by running

python eval.py --btcv_dir /path/to/btcv --cache_dir /path/to/cache --setup from_scratch --log_dir /path/to/logs/

You should get results close to:

Model                        Avg Dice Score
FPN3d from scratch           0.77 ± 0.03
vox2vec linear probing       0.65 ± 0.01
vox2vec non-linear probing   0.73 ± 0.01
vox2vec fine-tuning          0.78 ± 0.02

Pre-training

To reproduce the pre-training of vox2vec, we need a large dataset composed of six publicly available CT datasets: AMOS [2], FLARE2022 [3], NLST [4], NSCLC [5], LIDC [6] and MIDRC [7]. We use the amid package, which provides unified interfaces to these datasets as well as to many other publicly available medical image datasets. See the amid docs for data downloading instructions.

Once you have prepared the data, run

python warmup_cache.py --cache_dir /path/to/cache

to debug and warm up the cache for the pre-training dataset.

To pretrain the vox2vec model, run

python pretrain.py --btcv_dir /path/to/btcv --cache_dir /path/to/cache/ --log_dir /path/to/logs

Citation

If you find this code helpful, please consider citing:

@article{goncharov2023vox2vec,
  title={vox2vec: A Framework for Self-supervised Contrastive Learning of Voxel-level Representations in Medical Images},
  author={Goncharov, Mikhail and Soboleva, Vera and Kurmukov, Anvar and Pisov, Maksim and Belyaev, Mikhail},
  journal={arXiv},
  year={2023},
  url={https://arxiv.org/abs/2307.14725}
}

References

[1] Landman, B., et al.: MICCAI Multi-Atlas Labeling Beyond the Cranial Vault – Workshop and Challenge. In: Proc. MICCAI Multi-Atlas Labeling Beyond Cranial Vault Workshop Challenge, vol. 5, p. 12 (2015)

[2] Ji, Y., et al.: AMOS: A large-scale abdominal multi-organ benchmark for versatile medical image segmentation. arXiv preprint arXiv:2206.08023 (2022)

[3] Ma, J., et al.: Fast and low-GPU-memory abdomen CT organ segmentation: The FLARE challenge. Medical Image Analysis 82, 102616 (2022)

[4] Data from the National Lung Screening Trial (NLST) (2013). https://doi.org/10.7937/TCIA.HMQ8-J677, https://wiki.cancerimagingarchive.net/x/-oJY

[5] Aerts, H., Velazquez, E.R., Leijenaar, R., Parmar, C., Grossmann, P., Cavalho, S., Bussink, J., Monshouwer, R., Haibe-Kains, B., Rietveld, D., et al.: Data from NSCLC-Radiomics. The Cancer Imaging Archive (2015)

[6] Armato III, S.G., McLennan, G., Bidaut, L., McNitt-Gray, M.F., Meyer, C.R., Reeves, A.P., Zhao, B., Aberle, D.R., Henschke, C.I., Hoffman, E.A., et al.: The Lung Image Database Consortium (LIDC) and Image Database Resource Initiative (IDRI): A completed reference database of lung nodules on CT scans. Medical Physics 38(2), 915–931 (2011)

[7] Tsai, E., et al.: Medical Imaging Data Resource Center - RSNA International COVID Radiology Database Release 1a - Chest CT COVID+ (MIDRC-RICORD-1a) (2020). https://doi.org/10.7937/VTW4-X588, https://wiki.cancerimagingarchive.net/x/DoDTB
