<a href="https://colab.research.google.com/github/dudeurv/SAMed/blob/main/BraTS_SAMed_train_command_line.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Set up the environment

In [1]:
!pip install einops==0.6.1
!pip install icecream==2.1.3
!pip install MedPy==0.4.0
!pip install monai==1.1.0
!pip install opencv_python==4.5.4.58
!pip install SimpleITK==2.2.1
!pip install tensorboardX==2.6
!pip install ml-collections==0.1.1
!pip install onnx==1.13.1
!pip install onnxruntime==1.14.1
!pip install tensorboardX
!pip install torchmetrics

Collecting einops==0.6.1
  Downloading einops-0.6.1-py3-none-any.whl (42 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m42.2/42.2 kB[0m [31m1.1 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: einops
Successfully installed einops-0.6.1
Collecting icecream==2.1.3
  Downloading icecream-2.1.3-py2.py3-none-any.whl (8.4 kB)
Collecting colorama>=0.3.9 (from icecream==2.1.3)
  Downloading colorama-0.4.6-py2.py3-none-any.whl (25 kB)
Collecting executing>=0.3.1 (from icecream==2.1.3)
  Downloading executing-2.0.1-py2.py3-none-any.whl (24 kB)
Collecting asttokens>=2.0.1 (from icecream==2.1.3)
  Downloading asttokens-2.4.1-py2.py3-none-any.whl (27 kB)
Installing collected packages: executing, colorama, asttokens, icecream
Successfully installed asttokens-2.4.1 colorama-0.4.6 executing-2.0.1 icecream-2.1.3
Collecting MedPy==0.4.0
  Downloading MedPy-0.4.0.tar.gz (151 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m151.8/151.8 kB[0m [3

# Download the code, pretrained weights, and BraTS slice data

In [2]:
import os

CODE_DIR = 'samed_codes'

# Create the parent directory
os.makedirs(f'./{CODE_DIR}', exist_ok=True)

# Clone the SAMed repository into its subfolder
!git clone https://github.com/dudeurv/SAMed.git $CODE_DIR

os.chdir(f'./{CODE_DIR}')

os.listdir()

Cloning into 'samed_codes'...
remote: Enumerating objects: 527, done.[K
remote: Counting objects: 100% (360/360), done.[K
remote: Compressing objects: 100% (210/210), done.[K
remote: Total 527 (delta 230), reused 231 (delta 147), pack-reused 167[K
Receiving objects: 100% (527/527), 771.81 KiB | 3.69 MiB/s, done.
Resolving deltas: 100% (291/291), done.


['subsample_datasets.py',
 'requirements.txt',
 '.gitignore',
 'eval_BraTS.py',
 'BraTS_SAMed_train_command_line.ipynb',
 'segment_anything',
 '.git',
 'trainer.py',
 'utils.py',
 'datasets',
 'lists',
 'materials',
 'dataset_BraTS.py',
 'train.py',
 'preprocess',
 'README.md',
 'test.py',
 'sam_lora_image_encoder_mask_decoder.py',
 'sam_lora_image_encoder.py',
 'train_BraTS.py',
 'trainer_BraTS.py']

In [3]:
from pydrive2.auth import GoogleAuth
from pydrive2.drive import GoogleDrive
from google.colab import auth
from oauth2client.client import GoogleCredentials
import os

# Passed to Downloader as use_pydrive: when True, the Downloader authenticates
# the Colab user and builds a PyDrive2 client on construction.
download_with_pydrive = True

class Downloader(object):
  """Downloads Google Drive files by id into ``save_dir`` using PyDrive2.

  Args:
    use_pydrive: When true, authenticate the Colab user on construction and
      build the Drive client that ``download_file`` uses.
  """

  def __init__(self, use_pydrive):
    self.use_pydrive = use_pydrive
    self.save_dir = '.'
    # Always define the attribute so download_file can raise a clear error
    # (instead of an AttributeError) when PyDrive was not enabled.
    self.drive = None
    if self.use_pydrive:
      self.authenticate()

  def authenticate(self):
    """Authenticate the Colab user and store a GoogleDrive client in ``self.drive``."""
    auth.authenticate_user()
    gauth = GoogleAuth()
    gauth.credentials = GoogleCredentials.get_application_default()
    self.drive = GoogleDrive(gauth)

  def download_file(self, file_id, file_name):
    """Download the Drive file ``file_id`` to ``save_dir/file_name``.

    Does nothing (beyond printing a notice) when the destination already
    exists. The content is first written to a ``.part`` file and only moved
    to its final name once complete, so an interrupted download is never
    mistaken for a finished one on the next run.

    Args:
      file_id: Google Drive id of the file to fetch.
      file_name: Name of the destination file inside ``save_dir``.

    Raises:
      RuntimeError: If no Drive client is available (``use_pydrive`` was
        false and ``authenticate`` was never called).
    """
    file_dst = f'{self.save_dir}/{file_name}'
    if os.path.exists(file_dst):
      print(f'{file_name} already exists')
      return
    if self.drive is None:
      raise RuntimeError(
          'No Google Drive client available: create the Downloader with '
          'use_pydrive=True or call authenticate() first.')
    downloaded = self.drive.CreateFile({'id': file_id})
    downloaded.FetchMetadata(fetch_all=True)
    partial_dst = f'{file_dst}.part'
    try:
      downloaded.GetContentFile(partial_dst)
      # Atomic rename: the final path only ever holds a complete file.
      os.replace(partial_dst, file_dst)
    except BaseException:
      # Remove the truncated file (also on KeyboardInterrupt) and propagate.
      if os.path.exists(partial_dst):
        os.remove(partial_dst)
      raise

# Build the Drive downloader, then fetch the SAM ViT-B checkpoint followed by
# the zipped slice archive; files already present locally are skipped.
downloader = Downloader(use_pydrive=download_with_pydrive)

sam_model = dict(id='1_oCdoEEu3mNhRfFxeWyRerOKt8OEUvcg', name='sam_vit_b_01ec64.pth')
data = dict(id='1nHZWlCBpudbT4zzPyqyu2Vi5uILcxSrv', name='Slices.zip')

for asset in (sam_model, data):
  downloader.download_file(file_id=asset['id'], file_name=asset['name'])

In [4]:
!unzip -n Slices.zip -d /content/samed_codes/

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
  inflating: /content/samed_codes/Slices/Train/BraTS-GLI-seg/BraTS-GLI-00008-001-seg.nii.gz_slice121.png  
  inflating: /content/samed_codes/Slices/Train/BraTS-GLI-seg/BraTS-GLI-00008-001-seg.nii.gz_slice119.png  
  inflating: /content/samed_codes/Slices/Train/BraTS-GLI-seg/BraTS-GLI-00008-001-seg.nii.gz_slice109.png  
  inflating: /content/samed_codes/Slices/Train/BraTS-GLI-seg/BraTS-GLI-00008-001-seg.nii.gz_slice108.png  
  inflating: /content/samed_codes/Slices/Train/BraTS-GLI-seg/BraTS-GLI-00008-001-seg.nii.gz_slice100.png  
  inflating: /content/samed_codes/Slices/Train/BraTS-GLI-seg/BraTS-GLI-00008-001-seg.nii.gz_slice129.png  
  inflating: /content/samed_codes/Slices/Train/BraTS-GLI-seg/BraTS-GLI-00008-001-seg.nii.gz_slice113.png  
  inflating: /content/samed_codes/Slices/Train/BraTS-GLI-seg/BraTS-GLI-00008-001-seg.nii.gz_slice116.png  
  inflating: /content/samed_codes/Slices/Train/BraTS-GLI-seg/BraTS-GLI-00008-00

# Train SAMed on BraTS

In [5]:
%%bash
# Run the BraTS training script on the extracted training slices, starting
# from the downloaded SAM ViT-B checkpoint, with the --warmup and --AdamW flags set.
python /content/samed_codes/train_BraTS.py \
  --root_path /content/samed_codes/Slices/Train \
  --output /content/samed_codes/training_output \
  --ckpt /content/samed_codes/sam_vit_b_01ec64.pth \
  --vit_name vit_b \
  --warmup \
  --AdamW


torch.Size([3, 1, 1])
Namespace(root_path='/content/samed_codes/Slices/Train', output='/content/samed_codes/training_output', dataset='BraTS', num_classes=8, max_iterations=100, max_epochs=10, stop_epoch=10, batch_size=10, n_gpu=2, deterministic=1, base_lr=0.005, img_size=512, seed=1234, vit_name='vit_b', ckpt='/content/samed_codes/sam_vit_b_01ec64.pth', lora_ckpt=None, rank=4, warmup=True, warmup_period=250, AdamW=True, module='sam_lora_image_encoder', dice_param=0.8, is_pretrain=True, exp='BraTS_512')
The length of train set is: 1395
279 iterations per epoch. 2790 max iterations 
New best model saved with loss 0.6969
--- Epoch 0/10: Training loss = 0.7125, Testing loss = 0.6969, Best loss = 0.6969, Best epoch = 0
New best model saved with loss 0.2948
--- Epoch 1/10: Training loss = 0.3928, Testing loss = 0.2948, Best loss = 0.2948, Best epoch = 1
Model saved to /content/samed_codes/training_output/BraTS_512_pretrain_vit_b_10k_epo10_bs10_lr0.005/epoch_001.pth
New best model saved with

  0%|                                          | 0/10 [00:00<?, ?it/s] 10%|███▍                              | 1/10 [01:38<14:46, 98.46s/it] 20%|██████▊                           | 2/10 [03:15<13:01, 97.75s/it] 30%|██████████▏                       | 3/10 [04:52<11:20, 97.23s/it] 40%|█████████████▌                    | 4/10 [06:28<09:41, 96.98s/it] 50%|█████████████████                 | 5/10 [08:05<08:04, 96.84s/it] 60%|████████████████████▍             | 6/10 [09:42<06:27, 96.76s/it] 70%|███████████████████████▊          | 7/10 [11:18<04:49, 96.56s/it] 80%|███████████████████████████▏      | 8/10 [12:55<03:13, 96.62s/it] 90%|██████████████████████████████▌   | 9/10 [14:31<01:36, 96.46s/it] 90%|█████████████████████████████▋   | 9/10 [16:07<01:47, 107.53s/it]
