<a href="https://colab.research.google.com/github/mobarakol/tutorial_notebooks/blob/main/SAMed.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Customized Segment Anything Model for Medical Image Segmentation
### [[Paper](https://arxiv.org/pdf/2304.13785.pdf)] [[GitHub](https://github.com/hitachinsk/SAMed)]
---
[Kaidong Zhang](https://hitachinsk.github.io/) and [Dong Liu](https://faculty.ustc.edu.cn/dongeliu/), technical report, 2023

All rights reserved by the authors of SAMed.

This notebook walks through the entire SAMed inference pipeline.

# Set up the environment

In [1]:
# Install the pinned dependencies of the SAMed code.
# `%pip` (rather than `!pip`) installs into the environment of the running kernel.
# The duplicate, unpinned `tensorboardX` entry was dropped so `tensorboardX==2.6` is the only constraint.
# NOTE(review): torchmetrics is still unpinned — pin it to the version you validated with.
%pip install -q einops==0.6.1 icecream==2.1.3 MedPy==0.4.0 monai==1.1.0 opencv_python==4.5.4.58 SimpleITK==2.2.1 tensorboardX==2.6 ml-collections==0.1.1 onnx==1.13.1 onnxruntime==1.14.1 torchmetrics

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m42.2/42.2 kB[0m [31m624.9 kB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m151.8/151.8 kB[0m [31m2.5 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.2/1.2 MB[0m [31m10.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m60.3/60.3 MB[0m [31m8.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m52.7/52.7 MB[0m [31m6.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m114.5/114.5 kB[0m [31m3.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m77.9/77.9 kB[0m [31m7.6 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━

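# Download code, pretrained weights, and test data

The checkpoints and the test volumes are fetched from Google Drive with PyDrive, so the download cell asks you to authenticate with your Google account.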

In [2]:
# Fetch the SAMed source and make it the working directory. Later cells
# (downloads, unzip, `python test.py`) all use paths relative to the repo root.
# The cell is idempotent: re-running it neither re-clones nor chdirs into a
# nested samed_codes/samed_codes.
import os
import subprocess

CODE_DIR = 'samed_codes'
REPO_URL = 'https://github.com/hitachinsk/SAMed.git'

if os.path.basename(os.getcwd()) != CODE_DIR:
    if not os.path.isdir(os.path.join(CODE_DIR, '.git')):
        # `git clone` creates CODE_DIR itself. check=True makes a failed clone
        # stop here, instead of surfacing much later as a missing test.py.
        subprocess.run(['git', 'clone', REPO_URL, CODE_DIR], check=True)
    os.chdir(CODE_DIR)

Cloning into 'samed_codes'...
remote: Enumerating objects: 225, done.[K
remote: Counting objects: 100% (102/102), done.[K
remote: Compressing objects: 100% (30/30), done.[K
remote: Total 225 (delta 86), reused 72 (delta 72), pack-reused 123[K
Receiving objects: 100% (225/225), 635.01 KiB | 5.88 MiB/s, done.
Resolving deltas: 100% (105/105), done.


In [3]:
from pydrive.auth import GoogleAuth
from pydrive.drive import GoogleDrive
from google.colab import auth
from oauth2client.client import GoogleCredentials
import os

# When False, no Google authentication happens. Only files already present in
# the save directory can then be "downloaded"; missing ones raise RuntimeError.
download_with_pydrive = True


class Downloader(object):
  """Fetch files from Google Drive by file id into a local directory.

  Args:
    use_pydrive: if True, authenticate the Colab user and build a PyDrive
      client in `self.drive`; if False, `self.drive` stays None.
    save_dir: directory the files are written to (default: current directory).
  """

  def __init__(self, use_pydrive, save_dir='.'):
    self.use_pydrive = use_pydrive
    self.save_dir = save_dir
    # Always defined, so download_file can give a clear error instead of an
    # AttributeError when no client was created.
    self.drive = None
    if self.use_pydrive:
      self.authenticate()

  def authenticate(self):
    """Authenticate the Colab user and store a PyDrive client in self.drive."""
    auth.authenticate_user()
    gauth = GoogleAuth()
    gauth.credentials = GoogleCredentials.get_application_default()
    self.drive = GoogleDrive(gauth)

  def download_file(self, file_id, file_name):
    """Download Drive file `file_id` to `<save_dir>/<file_name>`; skip it if already present.

    The content is first written to `<file_name>.part`. It is renamed into
    place only after the download finishes, so an interrupted download never
    leaves a truncated file that a later run would accept as complete.

    Raises:
      RuntimeError: if the file is missing locally and no Drive client exists.
    """
    file_dst = os.path.join(self.save_dir, file_name)
    if os.path.exists(file_dst):
      print(f'{file_name} already exists')
      return
    if self.drive is None:
      raise RuntimeError(
          f'{file_name} not found in {self.save_dir!r} and no Drive client is '
          'available; construct Downloader with use_pydrive=True.')
    tmp_dst = file_dst + '.part'
    downloaded = self.drive.CreateFile({'id': file_id})
    downloaded.FetchMetadata(fetch_all=True)
    try:
      downloaded.GetContentFile(tmp_dst)
    except BaseException:
      # BaseException also covers KeyboardInterrupt (a cell interrupt), so no
      # partial file is left behind.
      if os.path.exists(tmp_dst):
        os.remove(tmp_dst)
      raise
    os.replace(tmp_dst, file_dst)

# Drive file ids and local file names of the three artifacts used by the
# inference step: LoRA checkpoint, SAM ViT-B checkpoint, and the zipped test volumes.
samed_model = {'id': '1P0Bm-05l-rfeghbrT1B62v5eN-3A-uOr', 'name': 'epoch_159.pth'}
sam_model = {'id': '1_oCdoEEu3mNhRfFxeWyRerOKt8OEUvcg', 'name': 'sam_vit_b_01ec64.pth'}
test_data = {'id': '1RczbNSB37OzPseKJZ1tDxa5OO1IIICzK', 'name': 'test_vol_h5.zip'}

downloader = Downloader(download_with_pydrive)
for artifact in (samed_model, sam_model, test_data):
  downloader.download_file(file_id=artifact['id'], file_name=artifact['name'])



KeyboardInterrupt: ignored

In [None]:
# -o overwrites existing files without prompting. Without it, re-running this
# cell hangs on unzip's interactive "replace?" question.
!unzip -o test_vol_h5.zip

Archive:  test_vol_h5.zip
   creating: test_vol_h5/
  inflating: test_vol_h5/case0038.npy.h5  
  inflating: test_vol_h5/case0036.npy.h5  
  inflating: test_vol_h5/case0035.npy.h5  
  inflating: test_vol_h5/case0032.npy.h5  
  inflating: test_vol_h5/case0029.npy.h5  
  inflating: test_vol_h5/case0025.npy.h5  
  inflating: test_vol_h5/case0022.npy.h5  
  inflating: test_vol_h5/case0008.npy.h5  
  inflating: test_vol_h5/case0004.npy.h5  
  inflating: test_vol_h5/case0003.npy.h5  
  inflating: test_vol_h5/case0001.npy.h5  
  inflating: test_vol_h5/case0002.npy.h5  


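# Run SAMed inference

This runs `test.py` on the extracted test volumes. It uses the SAM ViT-B checkpoint (`sam_vit_b_01ec64.pth`) together with the SAMed LoRA checkpoint (`epoch_159.pth`) and prints the mean Dice and HD95 for each case. In the recorded run, the 12 test volumes took about 29 minutes.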

In [None]:
%%bash
# Evaluate SAMed on the extracted test volumes; results go to ./results.
#   --ckpt       SAM ViT-B checkpoint
#   --lora_ckpt  SAMed LoRA checkpoint
python test.py \
  --volume_path test_vol_h5 \
  --output_dir results \
  --ckpt sam_vit_b_01ec64.pth \
  --lora_ckpt epoch_159.pth

Namespace(config=None, volume_path='test_vol_h5', dataset='Synapse', num_classes=8, list_dir='./lists/lists_Synapse/', output_dir='results', img_size=512, input_size=224, seed=1234, is_savenii=False, deterministic=1, ckpt='sam_vit_b_01ec64.pth', lora_ckpt='epoch_159.pth', vit_name='vit_b', rank=4, module='sam_lora_image_encoder')
12 test iterations per epoch
idx 0 case case0008 mean_dice 0.669957 mean_hd95 12.885176
idx 1 case case0022 mean_dice 0.903700 mean_hd95 8.620059
idx 2 case case0038 mean_dice 0.829463 mean_hd95 5.516328
idx 3 case case0036 mean_dice 0.867280 mean_hd95 7.137733
idx 4 case case0032 mean_dice 0.877049 mean_hd95 16.098598
idx 5 case case0002 mean_dice 0.885726 mean_hd95 7.309665
idx 6 case case0029 mean_dice 0.788059 mean_hd95 37.997616
idx 7 case case0003 mean_dice 0.585645 mean_hd95 101.546292
idx 8 case case0001 mean_dice 0.780159 mean_hd95 33.100349
idx 9 case case0004 mean_dice 0.825131 mean_hd95 7.652281
idx 10 case case0025 mean_dice 0.906964 mean_hd95 6.4

0it [00:00, ?it/s]1it [02:48, 168.50s/it]2it [04:24, 126.13s/it]3it [06:11, 117.23s/it]4it [09:46, 155.92s/it]5it [12:23, 156.28s/it]6it [14:49, 152.69s/it]7it [16:33, 136.78s/it]8it [20:29, 168.28s/it]9it [23:13, 166.93s/it]10it [25:42, 161.39s/it]11it [27:12, 139.59s/it]12it [28:39, 123.57s/it]12it [28:39, 143.28s/it]
