<a href="https://colab.research.google.com/github/dudeurv/SAMed/blob/main/BraTS_SAMed_train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Setup environment

In [1]:
# Install pytorch
!pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html
# Install other necessary environments
!pip install einops==0.6.1
!pip install h5py==3.5.0
!pip install icecream==2.1.3
!pip install imageio==2.10.1
!pip install MedPy==0.4.0
!pip install nibabel==4.0.2
!pip install monai==1.1.0
!pip install numpy==1.21.6
!pip install opencv_python==4.5.4.58
!pip install pycocotools==2.0.6
!pip install safetensors==0.3.1
!pip install scipy==1.7.3
!pip install SimpleITK==2.2.1
!pip install tensorboardX==2.6
!pip install tqdm==4.62.3
!pip install ml-collections==0.1.1
!pip install pycocotools==2.0.6
!pip install onnx==1.13.1
!pip install onnxruntime==1.14.1
!pip install tensorboardX

Looking in links: https://download.pytorch.org/whl/torch_stable.html
Collecting torch==1.11.0+cu113
  Downloading https://download.pytorch.org/whl/cu113/torch-1.11.0%2Bcu113-cp310-cp310-linux_x86_64.whl (1637.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.6/1.6 GB[0m [31m789.9 kB/s[0m eta [36m0:00:00[0m
[?25hCollecting torchvision==0.12.0+cu113
  Downloading https://download.pytorch.org/whl/cu113/torchvision-0.12.0%2Bcu113-cp310-cp310-linux_x86_64.whl (22.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m22.3/22.3 MB[0m [31m53.7 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: torch, torchvision
  Attempting uninstall: torch
    Found existing installation: torch 2.1.0+cu121
    Uninstalling torch-2.1.0+cu121:
      Successfully uninstalled torch-2.1.0+cu121
  Attempting uninstall: torchvision
    Found existing installation: torchvision 0.16.0+cu121
    Uninstalling torchvision-0.16.0+cu121:
      Successfully uninstalle

Collecting opencv_python==4.5.4.58
  Downloading opencv_python-4.5.4.58-cp310-cp310-manylinux2014_x86_64.whl (60.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m60.3/60.3 MB[0m [31m27.0 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: opencv_python
  Attempting uninstall: opencv_python
    Found existing installation: opencv-python 4.8.0.76
    Uninstalling opencv-python-4.8.0.76:
      Successfully uninstalled opencv-python-4.8.0.76
Successfully installed opencv_python-4.5.4.58
Collecting pycocotools==2.0.6
  Downloading pycocotools-2.0.6.tar.gz (24 kB)
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Building wheels for collected packages: pycocotools
  Building wheel for pycocotools (pyproject.toml) ... [?25l[?25hdone
  Created wheel for pycocotools: filename=pycocotools-2.0.6-cp310-cp310-linux_x86_64.whl size=377179 sha256=1

# Download code, pretrained weights and training data

In [2]:
import os

CODE_DIR = 'samed_codes'

# Create the parent directory
os.makedirs(f'./{CODE_DIR}', exist_ok=True)

# Clone the SAMed repository into its subfolder
!git clone https://github.com/dudeurv/SAMed.git $CODE_DIR

os.chdir(f'./{CODE_DIR}')

Cloning into 'samed_codes'...
remote: Enumerating objects: 315, done.[K
remote: Counting objects: 100% (192/192), done.[K
remote: Compressing objects: 100% (120/120), done.[K
remote: Total 315 (delta 144), reused 72 (delta 72), pack-reused 123[K
Receiving objects: 100% (315/315), 662.32 KiB | 27.60 MiB/s, done.
Resolving deltas: 100% (163/163), done.


In [3]:
from pydrive2.auth import GoogleAuth
from pydrive2.drive import GoogleDrive
from google.colab import auth
from oauth2client.client import GoogleCredentials
import os

# When True, the Downloader authenticates against Google Drive on construction.
download_with_pydrive = True

class Downloader(object):
  """Download Google Drive files into ``save_dir`` (the working directory).

  Args:
    use_pydrive: when truthy, authenticate at construction time and create
      the Drive client used by ``download_file``. When falsy, no client is
      created and only already-present files can be "downloaded" (skipped).
  """

  def __init__(self, use_pydrive):
    self.use_pydrive = use_pydrive
    self.save_dir = '.'
    # Always define the attribute so download_file can report a missing
    # client clearly instead of failing with an AttributeError.
    self.drive = None
    if self.use_pydrive:
      self.authenticate()

  def authenticate(self):
    """Authenticate the Colab user and build the Drive client."""
    auth.authenticate_user()
    gauth = GoogleAuth()
    gauth.credentials = GoogleCredentials.get_application_default()
    self.drive = GoogleDrive(gauth)

  def download_file(self, file_id, file_name):
    """Fetch the Drive file ``file_id`` to ``save_dir/file_name``.

    Does nothing (besides printing a notice) when the destination already
    exists.

    Raises:
      RuntimeError: if a download is required but no Drive client exists.
    """
    file_dst = f'{self.save_dir}/{file_name}'
    if os.path.exists(file_dst):
      print(f'{file_name} already exists')
      return
    if getattr(self, 'drive', None) is None:
      raise RuntimeError(
          f'Cannot download {file_name}: no Google Drive client is available '
          '(construct Downloader with use_pydrive=True).')
    downloaded = self.drive.CreateFile({'id': file_id})
    downloaded.FetchMetadata(fetch_all=True)
    # Write to a temporary name and rename only on success, so an interrupted
    # transfer is never mistaken for a complete file by the exists-check above.
    partial_dst = f'{file_dst}.part'
    try:
      downloaded.GetContentFile(partial_dst)
      os.replace(partial_dst, file_dst)
    finally:
      if os.path.exists(partial_dst):
        os.remove(partial_dst)

downloader = Downloader(download_with_pydrive)

# Google Drive ids and local file names of the two assets to fetch.
sam_model = dict(id='1_oCdoEEu3mNhRfFxeWyRerOKt8OEUvcg', name='sam_vit_b_01ec64.pth')
train_data = dict(id='183rFdH3S2OukjY8-DJj6KV7rxJTsFydW', name='Training.zip')

# Fetch the checkpoint first, then the training archive.
for asset in (sam_model, train_data):
  downloader.download_file(file_id=asset['id'], file_name=asset['name'])

In [4]:
!unzip -n Training.zip -d /content/samed_codes/

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
  inflating: /content/samed_codes/Train/BraTS-GLI-t2f/BraTS-GLI-00009-000-t2f.nii.gz_slice4.png  
  inflating: /content/samed_codes/Train/BraTS-GLI-t2f/BraTS-GLI-00009-000-t2f.nii.gz_slice48.png  
  inflating: /content/samed_codes/Train/BraTS-GLI-t2f/BraTS-GLI-00009-000-t2f.nii.gz_slice31.png  
  inflating: /content/samed_codes/Train/BraTS-GLI-t2f/BraTS-GLI-00008-001-t2f.nii.gz_slice145.png  
  inflating: /content/samed_codes/Train/BraTS-GLI-t2f/BraTS-GLI-00009-000-t2f.nii.gz_slice37.png  
  inflating: /content/samed_codes/Train/BraTS-GLI-t2f/BraTS-GLI-00009-000-t2f.nii.gz_slice0.png  
  inflating: /content/samed_codes/Train/BraTS-GLI-t2f/BraTS-GLI-00009-000-t2f.nii.gz_slice15.png  
  inflating: /content/samed_codes/Train/BraTS-GLI-t2f/BraTS-GLI-00008-001-t2f.nii.gz_slice146.png  
  inflating: /content/samed_codes/Train/BraTS-GLI-t2f/BraTS-GLI-00009-000-t2f.nii.gz_slice12.png  
  inflating: /content/samed_codes/Train/BraT

# Execute SAMed

In [6]:
%%bash
# Run the BraTS training script on the extracted slices, starting from the
# downloaded SAM ViT-B checkpoint; outputs go to training_output.
python /content/samed_codes/train_BraTS.py \
  --root_path /content/samed_codes/Train \
  --output /content/samed_codes/training_output \
  --ckpt /content/samed_codes/sam_vit_b_01ec64.pth \
  --vit_name vit_b \
  --warmup \
  --AdamW


torch.Size([3, 1, 1])
Namespace(root_path='/content/samed_codes/Train', output='/content/samed_codes/training_output', dataset='BraTS', list_dir='./lists/lists_Synapse', num_classes=8, max_iterations=30000, max_epochs=200, stop_epoch=160, batch_size=24, n_gpu=1, deterministic=1, base_lr=0.005, img_size=512, seed=1234, vit_name='vit_b', ckpt='/content/samed_codes/sam_vit_b_01ec64.pth', lora_ckpt=None, rank=4, warmup=True, warmup_period=250, AdamW=True, module='sam_lora_image_encoder', dice_param=0.8, lr_exp=0.9, tf32=False, compile=False, use_amp=False, is_pretrain=True, exp='BraTS_512')
The length of train set is: 1395
56 iterations per epoch. 11200 max iterations 


  0%|                                         | 0/200 [00:00<?, ?it/s]  0%|                                         | 0/200 [00:01<?, ?it/s]
Traceback (most recent call last):
  File "/content/samed_codes/train_BraTS.py", line 135, in <module>
    trainer[dataset_name](args, net, snapshot_path, multimask_output, low_res)
  File "/content/samed_codes/trainer_BraTS.py", line 86, in trainer_BraTS
    outputs = model(image_batch, multimask_output, args.img_size)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/content/samed_codes/sam_lora_image_encoder.py", line 187, in forward
    return self.sam(batched_input, multimask_output, image_size)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/content/samed_codes/segment_anything/modeling/sam.py", line 58, in forward
    outputs = self.f

CalledProcessError: ignored