<a href="https://colab.research.google.com/github/dudeurv/SAM_MRI/blob/main/SAMed.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Customized Segment Anything Model for Medical Image Segmentation
### [[Paper](https://arxiv.org/pdf/2304.13785.pdf)] [[Github](https://github.com/hitachinsk/SAMed)]
---
[Kaidong Zhang](https://hitachinsk.github.io/) and [Dong Liu](https://faculty.ustc.edu.cn/dongeliu/), technical report, 2023

All rights reserved by the authors of SAMed

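This page walks through the full pipeline: setting up the environment, downloading the code, the SAM checkpoint and the data, and launching SAMed training.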

# Set up the environment

In [1]:
# Install pytorch
!pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html
# Install the other pinned dependencies
!pip install einops==0.6.1
!pip install h5py==3.5.0
!pip install icecream==2.1.3
!pip install imageio==2.10.1
!pip install MedPy==0.4.0
!pip install nibabel==4.0.2
!pip install monai==1.1.0
!pip install numpy==1.21.6
!pip install opencv_python==4.5.4.58
!pip install pycocotools==2.0.6
!pip install safetensors==0.3.1
!pip install scipy==1.7.3
!pip install SimpleITK==2.2.1
!pip install tensorboardX==2.6
!pip install tqdm==4.62.3
!pip install ml-collections==0.1.1
!pip install onnx==1.13.1
!pip install onnxruntime==1.14.1

Looking in links: https://download.pytorch.org/whl/torch_stable.html
Collecting torch==1.11.0+cu113
  Downloading https://download.pytorch.org/whl/cu113/torch-1.11.0%2Bcu113-cp310-cp310-linux_x86_64.whl (1637.0 MB)
Collecting torchvision==0.12.0+cu113
  Downloading https://download.pytorch.org/whl/cu113/torchvision-0.12.0%2Bcu113-cp310-cp310-linux_x86_64.whl (22.3 MB)
Installing collected packages: torch, torchvision
  Attempting uninstall: torch
    Found existing installation: torch 2.1.0+cu121
    Uninstalling torch-2.1.0+cu121:
      Successfully uninstalled torch-2.1.0+cu121
  Attempting uninstall: torchvision
    Found existing installation: torchvision 0.16.0+cu121
    Uninstalling torchvision-0.16.0+cu121:
      Successfully uninstalled 

Collecting opencv_python==4.5.4.58
  Downloading opencv_python-4.5.4.58-cp310-cp310-manylinux2014_x86_64.whl (60.3 MB)
Installing collected packages: opencv_python
  Attempting uninstall: opencv_python
    Found existing installation: opencv-python 4.8.0.76
    Uninstalling opencv-python-4.8.0.76:
      Successfully uninstalled opencv-python-4.8.0.76
Successfully installed opencv_python-4.5.4.58
Collecting pycocotools==2.0.6
  Downloading pycocotools-2.0.6.tar.gz (24 kB)
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done
Building wheels for collected packages: pycocotools
  Building wheel for pycocotools (pyproject.toml) ... done
  Created wheel for pycocotools: filename=pycocotools-2.0.6-cp310-cp310-linux_x86_64.whl size=377152 sha256=01
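Because the install cell downgrades packages that Colab ships preinstalled (torch 2.1.0 and opencv-python 4.8.0.76 were uninstalled above), it can be worth confirming that the pinned versions are what is actually installed. The helper below is a minimal sketch and not part of SAMed: `check_pins` and `PINS` are hypothetical names, `PINS` lists only a subset of the pins, and versions are read with the standard-library `importlib.metadata`.

```python
from importlib import metadata
from typing import Callable, Dict, List, Tuple

# Hypothetical subset of the pins from the install cell (distribution name -> required version).
PINS: Dict[str, str] = {
    "torch": "1.11.0+cu113",
    "torchvision": "0.12.0+cu113",
    "monai": "1.1.0",
    "numpy": "1.21.6",
    "opencv-python": "4.5.4.58",
}


def check_pins(pins: Dict[str, str],
               get_version: Callable[[str], str] = metadata.version) -> List[Tuple[str, str, str]]:
    """Return (package, wanted, found) for every pin that does not match.

    `found` is "missing" when the distribution is not installed at all.
    """
    mismatches = []
    for name, wanted in pins.items():
        try:
            found = get_version(name)
        except metadata.PackageNotFoundError:
            found = "missing"
        if found != wanted:
            mismatches.append((name, wanted, found))
    return mismatches


# Print one line per mismatch; no output means every listed pin is satisfied.
for name, wanted, found in check_pins(PINS):
    print(f"{name}: wanted {wanted}, found {found}")
```

The version lookup is injectable so the comparison logic can be exercised without touching the real environment.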

# Download code, pretrained weights and training data

In [2]:
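# Clone the SAMed repository and work from inside it
import os
CODE_DIR = 'samed_codes'
os.makedirs(f'./{CODE_DIR}', exist_ok=True)  # exist_ok lets the cell be re-run; git clone accepts an existing empty directory
!git clone https://github.com/hitachinsk/SAMed.git $CODE_DIR
os.chdir(f'./{CODE_DIR}')  # later relative paths (checkpoints/, lists/) resolve against this directory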

Cloning into 'samed_codes'...
remote: Enumerating objects: 225, done.
remote: Counting objects: 100% (58/58), done.
remote: Compressing objects: 100% (29/29), done.
remote: Total 225 (delta 42), reused 29 (delta 29), pack-reused 167
Receiving objects: 100% (225/225), 636.92 KiB | 12.74 MiB/s, done.
Resolving deltas: 100% (103/103), done.


In [3]:
# Install the SAM library from Facebook Research
!pip -q install 'git+https://github.com/facebookresearch/segment-anything.git'

# Download the pre-trained SAM ViT-B checkpoint (sam_vit_b_01ec64.pth) for later use
!wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth

  Preparing metadata (setup.py) ... done
  Building wheel for segment-anything (setup.py) ... done
--2023-12-17 15:50:28--  https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth
Resolving dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)... 52.84.251.114, 52.84.251.15, 52.84.251.106, ...
Connecting to dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)|52.84.251.114|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 375042383 (358M) [binary/octet-stream]
Saving to: ‘sam_vit_b_01ec64.pth’


2023-12-17 15:50:30 (196 MB/s) - ‘sam_vit_b_01ec64.pth’ saved [375042383/375042383]
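A cheap guard against a truncated download is to compare the file size with the length wget reported (375042383 bytes in the log above). This is a sketch under that assumption: `download_is_complete` is a hypothetical helper, a size match does not prove the contents are intact (a checksum would be needed for that), and the relative path only works before the file is moved into `checkpoints/` in a later cell.

```python
import os

# Byte count reported by wget for sam_vit_b_01ec64.pth in the log above.
EXPECTED_SAM_VIT_B_BYTES = 375042383


def download_is_complete(path: str, expected_bytes: int) -> bool:
    """Return True if `path` is a regular file with exactly `expected_bytes` bytes."""
    return os.path.isfile(path) and os.path.getsize(path) == expected_bytes


# The working directory is samed_codes/ after the os.chdir in the clone cell.
print(download_is_complete("sam_vit_b_01ec64.pth", EXPECTED_SAM_VIT_B_BYTES))
```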



In [9]:
# Import necessary libraries
from torch.utils.data import Dataset, DataLoader  # PyTorch classes for handling datasets and data loading
from glob import glob                            # For file path pattern matching
import imageio as iio
import matplotlib.pyplot as plt
import numpy as np

# Download the preprocessed training slices (train_npz_new_224.zip) from Google Drive
!gdown https://drive.google.com/uc?id=1VWZsgkd5wbTStLDwGuUJ1G3POwPCuOgQ


Downloading...
From: https://drive.google.com/uc?id=1VWZsgkd5wbTStLDwGuUJ1G3POwPCuOgQ
To: /content/samed_codes/train_npz_new_224.zip
100% 351M/351M [00:10<00:00, 34.0MB/s]
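Before extracting, the archive can be checked with the standard-library `zipfile` module: `ZipFile.testzip()` reads every member and returns the name of the first one whose CRC does not match, or `None`. The sketch below is not part of SAMed; `summarize_zip` is a hypothetical helper, and on a 351 MB archive the check takes a little while because every member is decompressed.

```python
import os
import zipfile


def summarize_zip(path: str, suffix: str = ".npz") -> dict:
    """Validate a zip archive's CRCs and count the members ending in `suffix`."""
    with zipfile.ZipFile(path) as zf:
        bad_member = zf.testzip()  # first corrupt member name, or None if all CRCs match
        names = zf.namelist()
    return {
        "corrupt_member": bad_member,
        "num_files": sum(1 for n in names if n.endswith(suffix)),
    }


# Only run the check if the download from the previous cell is present.
if os.path.exists("train_npz_new_224.zip"):
    print(summarize_zip("train_npz_new_224.zip"))
```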


# Execute SAMed

In [10]:
!unzip -n train_npz_new_224.zip -d /content/samed_codes/

Archive:  train_npz_new_224.zip
   creating: /content/samed_codes/train_npz_new_224/
  inflating: /content/samed_codes/train_npz_new_224/case0039_slice021.npz  
  inflating: /content/samed_codes/train_npz_new_224/case0027_slice025.npz  
  inflating: /content/samed_codes/train_npz_new_224/case0010_slice062.npz  
  inflating: /content/samed_codes/train_npz_new_224/case0034_slice025.npz  
  inflating: /content/samed_codes/train_npz_new_224/case0007_slice006.npz  
  inflating: /content/samed_codes/train_npz_new_224/case0010_slice061.npz  
  inflating: /content/samed_codes/train_npz_new_224/case0023_slice005.npz  
  inflating: /content/samed_codes/train_npz_new_224/case0034_slice022.npz  
  inflating: /content/samed_codes/train_npz_new_224/case0006_slice114.npz  
  inflating: /content/samed_codes/train_npz_new_224/case0021_slice017.npz  
  inflating: /content/samed_codes/train_npz_new_224/case0009_slice106.npz  
  inflating: /content/samed_codes/train_npz_new_224/case0027_slice019.npz  
  i
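The extracted files are 2D slices named `caseXXXX_sliceYYY.npz`, as the unzip log shows. Grouping them by case gives a quick overview of how many volumes the training slices come from. This is a minimal sketch that assumes every file follows that naming pattern; `slices_per_case` and `SLICE_RE` are hypothetical names.

```python
import os
import re
from collections import Counter
from glob import glob

# Matches names such as case0039_slice021.npz from the unzip log above.
SLICE_RE = re.compile(r"case(\d+)_slice(\d+)\.npz$")


def slices_per_case(paths):
    """Count .npz slices per case ID, parsed from filenames like caseXXXX_sliceYYY.npz."""
    counts = Counter()
    for p in paths:
        m = SLICE_RE.search(os.path.basename(p))
        if m:  # files that do not follow the pattern are ignored
            counts[int(m.group(1))] += 1
    return counts


counts = slices_per_case(glob("/content/samed_codes/train_npz_new_224/*.npz"))
print(f"{sum(counts.values())} slices from {len(counts)} cases")
```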

In [6]:
import os

def find_file(root_folder, filename):
    """Return the first path under root_folder whose basename is filename, or None."""
    for root, dirs, files in os.walk(root_folder):
        if filename in files:
            return os.path.join(root, filename)
    return None

# train.py loads the SAM weights from checkpoints/sam_vit_b_01ec64.pth (the `ckpt` value in the Namespace printed by the training cell)
destination = '/content/samed_codes/checkpoints/sam_vit_b_01ec64.pth'

file_path = find_file('/content', 'sam_vit_b_01ec64.pth')
if file_path:
    print(f"File found at: {file_path}")
    os.makedirs(os.path.dirname(destination), exist_ok=True)
    # Skip the move if the checkpoint is already in place (e.g. when re-running this cell)
    if os.path.abspath(file_path) != destination:
        os.replace(file_path, destination)
else:
    print("File not found.")


File found at: /content/samed_codes/sam_vit_b_01ec64.pth


In [7]:
import os

# The previous cell normally moves the checkpoint already; only move it here if it is still in the repo root,
# otherwise os.replace would raise FileNotFoundError.
source = '/content/samed_codes/sam_vit_b_01ec64.pth'
destination = '/content/samed_codes/checkpoints/sam_vit_b_01ec64.pth'

os.makedirs('/content/samed_codes/checkpoints/', exist_ok=True)
if os.path.exists(source):
    os.replace(source, destination)


In [12]:
%%bash
python /content/samed_codes/train.py --root_path /content/samed_codes/train_npz_new_224 --output /content/samed_codes/training_output --warmup --AdamW


Namespace(root_path='/content/samed_codes/train_npz_new_224', output='/content/samed_codes/training_output', dataset='Synapse', list_dir='./lists/lists_Synapse', num_classes=8, max_iterations=30000, max_epochs=200, stop_epoch=160, batch_size=12, n_gpu=2, deterministic=1, base_lr=0.005, img_size=512, seed=1234, vit_name='vit_b', ckpt='checkpoints/sam_vit_b_01ec64.pth', lora_ckpt=None, rank=4, warmup=True, warmup_period=250, AdamW=True, module='sam_lora_image_encoder', dice_param=0.8, is_pretrain=True, exp='Synapse_512')
The length of train set is: 2211
93 iterations per epoch. 18600 max iterations 


  0%|                                         | 0/200 [00:22<?, ?it/s]
Traceback (most recent call last):
  File "/content/samed_codes/train.py", line 122, in <module>
    trainer[dataset_name](args, net, snapshot_path, multimask_output, low_res)
  File "/content/samed_codes/trainer.py", line 78, in trainer_synapse
    outputs = model(image_batch, multimask_output, args.img_size)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/parallel/data_parallel.py", line 166, in forward
    return self.module(*inputs[0], **kwargs[0])
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/content/samed_codes/sam_lora_image_encoder.py", line 187, in forward
    return self.sam(batch

CalledProcessError: ignored
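The log line "93 iterations per epoch. 18600 max iterations" follows from the Namespace if the data loader uses an effective batch of `batch_size * n_gpu` and keeps the last partial batch: ceil(2211 / (12 × 2)) = 93, and 93 × 200 epochs = 18600. The sketch below reproduces that arithmetic and also shows a linear-warmup plus polynomial-decay learning-rate schedule of the kind `--warmup` (with `warmup_period=250`) suggests. Both functions are hypothetical illustrations inferred from the printed values, not SAMed's code; check `trainer.py` for the exact formulas.

```python
import math


def iterations_per_epoch(num_samples: int, batch_size: int, n_gpu: int) -> int:
    """Batches per epoch, assuming an effective batch of batch_size * n_gpu and no drop_last."""
    return math.ceil(num_samples / (batch_size * n_gpu))


def learning_rate(iter_num: int, base_lr: float, warmup_period: int, max_iterations: int) -> float:
    """Assumed schedule: linear warmup to base_lr, then polynomial (power 0.9) decay."""
    if iter_num < warmup_period:
        return base_lr * (iter_num + 1) / warmup_period
    shift = iter_num - warmup_period
    return base_lr * (1.0 - shift / max_iterations) ** 0.9


# Values taken from the Namespace and log printed above.
per_epoch = iterations_per_epoch(num_samples=2211, batch_size=12, n_gpu=2)
max_iters = per_epoch * 200  # max_epochs=200
print(per_epoch, max_iters)  # 93 18600
print(learning_rate(0, 0.005, 250, max_iters), learning_rate(250, 0.005, 250, max_iters))
```

On a single-GPU runtime, setting `n_gpu` to 1 (assuming it is exposed as the `--n_gpu` flag, matching the Namespace field) would give ceil(2211 / 12) = 185 iterations per epoch under the same formula.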