<a href="https://colab.research.google.com/github/dudeurv/SAMed/blob/main/BraTS_SAMed_train_command_line.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Setup environment

In [1]:
!pip install einops==0.6.1
!pip install icecream==2.1.3
!pip install MedPy==0.4.0
!pip install monai==1.1.0
!pip install opencv_python==4.5.4.58
!pip install SimpleITK==2.2.1
!pip install tensorboardX==2.6
!pip install ml-collections==0.1.1
!pip install onnx==1.13.1
!pip install onnxruntime==1.14.1
!pip install tensorboardX

Collecting einops==0.6.1
  Downloading einops-0.6.1-py3-none-any.whl (42 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/42.2 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m42.2/42.2 kB[0m [31m1.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: einops
Successfully installed einops-0.6.1
Collecting icecream==2.1.3
  Downloading icecream-2.1.3-py2.py3-none-any.whl (8.4 kB)
Collecting colorama>=0.3.9 (from icecream==2.1.3)
  Downloading colorama-0.4.6-py2.py3-none-any.whl (25 kB)
Collecting executing>=0.3.1 (from icecream==2.1.3)
  Downloading executing-2.0.1-py2.py3-none-any.whl (24 kB)
Collecting asttokens>=2.0.1 (from icecream==2.1.3)
  Downloading asttokens-2.4.1-py2.py3-none-any.whl (27 kB)
Installing collected packages: executing, colorama, asttokens, icecream
Successfully installed asttokens-2.4.1 colorama-0.4.6 executing-2.0.1 icecream-2.1.3
Collecting MedPy==0.4.0
  Downloading 

# Download code, pretrained weights and training data

In [2]:
import os

CODE_DIR = 'samed_codes'

# Create the parent directory
os.makedirs(f'./{CODE_DIR}', exist_ok=True)

# Clone the SAMed repository into its subfolder
!git clone https://github.com/dudeurv/SAMed.git $CODE_DIR

os.chdir(f'./{CODE_DIR}')

os.listdir()

Cloning into 'samed_codes'...
remote: Enumerating objects: 408, done.[K
remote: Counting objects: 100% (241/241), done.[K
remote: Compressing objects: 100% (146/146), done.[K
remote: Total 408 (delta 153), reused 131 (delta 92), pack-reused 167[K
Receiving objects: 100% (408/408), 710.60 KiB | 15.79 MiB/s, done.
Resolving deltas: 100% (214/214), done.


['subsample_datasets.py',
 'requirements.txt',
 '.gitignore',
 'segment_anything',
 '.git',
 'trainer.py',
 'utils.py',
 'datasets',
 'lists',
 'materials',
 'dataset_BraTS.py',
 'train.py',
 'preprocess',
 'README.md',
 'test.py',
 'sam_lora_image_encoder_mask_decoder.py',
 'sam_lora_image_encoder.py',
 'train_BraTS.py',
 'trainer_BraTS.py']

In [3]:
from pydrive2.auth import GoogleAuth
from pydrive2.drive import GoogleDrive
from google.colab import auth
from oauth2client.client import GoogleCredentials
import os

# Passed to Downloader below; when True the Downloader authenticates with
# Google Drive as soon as it is constructed.
download_with_pydrive = True

class Downloader(object):
  """Download files from Google Drive into a local directory via PyDrive2.

  Args:
    use_pydrive: When true, authenticate the Colab user on construction and
      create the Drive client that `download_file` relies on.
    save_dir: Directory downloaded files are written to. Defaults to the
      current working directory.
  """

  def __init__(self, use_pydrive, save_dir='.'):
    self.use_pydrive = use_pydrive
    self.save_dir = save_dir
    # Always define the attribute so an unauthenticated instance fails with a
    # clear message in download_file rather than an AttributeError.
    self.drive = None
    if self.use_pydrive:
      self.authenticate()

  def authenticate(self):
    """Authenticate the Colab user and build the Drive client in `self.drive`."""
    auth.authenticate_user()
    gauth = GoogleAuth()
    gauth.credentials = GoogleCredentials.get_application_default()
    self.drive = GoogleDrive(gauth)

  def download_file(self, file_id, file_name):
    """Download the Drive file `file_id` to `<save_dir>/<file_name>`.

    Does nothing (apart from printing a notice) when the destination already
    exists, so the cell can be re-run cheaply.

    Args:
      file_id: Google Drive id of the file to fetch.
      file_name: Name of the file to create inside `save_dir`.

    Raises:
      RuntimeError: If the file is missing locally and no Drive client has
        been created (the instance was built with a false `use_pydrive` and
        `authenticate()` was never called).
    """
    file_dst = os.path.join(self.save_dir, file_name)
    if os.path.exists(file_dst):
      print(f'{file_name} already exists')
      return
    if self.drive is None:
      raise RuntimeError(
          f'Cannot download {file_name}: no Google Drive client. Create the '
          'Downloader with use_pydrive=True or call authenticate() first.')

    downloaded = self.drive.CreateFile({'id': file_id})
    downloaded.FetchMetadata(fetch_all=True)

    # Write to a temporary name and rename on success: an interrupted download
    # must not leave a truncated file that later runs treat as "already exists".
    file_tmp = f'{file_dst}.part'
    try:
      downloaded.GetContentFile(file_tmp)
      os.replace(file_tmp, file_dst)
    finally:
      if os.path.exists(file_tmp):
        os.remove(file_tmp)

# Google Drive ids and local file names of the assets this notebook needs:
# the SAM ViT-B checkpoint and the zipped training set.
sam_model = {'id': '1_oCdoEEu3mNhRfFxeWyRerOKt8OEUvcg', 'name': 'sam_vit_b_01ec64.pth'}
train_data = {'id': '183rFdH3S2OukjY8-DJj6KV7rxJTsFydW', 'name': 'Training.zip'}

downloader = Downloader(download_with_pydrive)

# Fetch the checkpoint first, then the training archive.
for asset in (sam_model, train_data):
  downloader.download_file(file_id=asset['id'], file_name=asset['name'])

In [4]:
!unzip -n Training.zip -d /content/samed_codes/

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
  inflating: /content/samed_codes/Train/BraTS-GLI-t2f/BraTS-GLI-00009-000-t2f.nii.gz_slice4.png  
  inflating: /content/samed_codes/Train/BraTS-GLI-t2f/BraTS-GLI-00009-000-t2f.nii.gz_slice48.png  
  inflating: /content/samed_codes/Train/BraTS-GLI-t2f/BraTS-GLI-00009-000-t2f.nii.gz_slice31.png  
  inflating: /content/samed_codes/Train/BraTS-GLI-t2f/BraTS-GLI-00008-001-t2f.nii.gz_slice145.png  
  inflating: /content/samed_codes/Train/BraTS-GLI-t2f/BraTS-GLI-00009-000-t2f.nii.gz_slice37.png  
  inflating: /content/samed_codes/Train/BraTS-GLI-t2f/BraTS-GLI-00009-000-t2f.nii.gz_slice0.png  
  inflating: /content/samed_codes/Train/BraTS-GLI-t2f/BraTS-GLI-00009-000-t2f.nii.gz_slice15.png  
  inflating: /content/samed_codes/Train/BraTS-GLI-t2f/BraTS-GLI-00008-001-t2f.nii.gz_slice146.png  
  inflating: /content/samed_codes/Train/BraTS-GLI-t2f/BraTS-GLI-00009-000-t2f.nii.gz_slice12.png  
  inflating: /content/samed_codes/Train/BraT

# Train SAMed on BraTS

In [5]:
%%bash
# Run the BraTS training script on the extracted Train folder, starting from
# the downloaded SAM ViT-B checkpoint, with the warmup and AdamW flags set.
python /content/samed_codes/train_BraTS.py \
  --root_path /content/samed_codes/Train \
  --output /content/samed_codes/training_output \
  --ckpt /content/samed_codes/sam_vit_b_01ec64.pth \
  --vit_name vit_b \
  --warmup \
  --AdamW


torch.Size([3, 1, 1])
Namespace(root_path='/content/samed_codes/Train', output='/content/samed_codes/training_output', dataset='BraTS', num_classes=8, max_iterations=100, max_epochs=10, stop_epoch=6, batch_size=12, n_gpu=2, deterministic=1, base_lr=0.005, img_size=512, seed=1234, vit_name='vit_b', ckpt='/content/samed_codes/sam_vit_b_01ec64.pth', lora_ckpt=None, rank=4, warmup=True, warmup_period=250, AdamW=True, module='sam_lora_image_encoder', dice_param=0.8, is_pretrain=True, exp='BraTS_512')
The length of train set is: 1395
279 iterations per epoch. 2790 max iterations 
iteration 1 : loss : 2.922372, loss_ce: 10.614069, loss_dice: 0.999448
iteration 2 : loss : 1.636240, loss_ce: 4.205913, loss_dice: 0.993822
iteration 3 : loss : 1.710809, loss_ce: 4.617106, loss_dice: 0.984234
iteration 4 : loss : 1.339381, loss_ce: 2.845310, loss_dice: 0.962899
iteration 5 : loss : 1.176538, loss_ce: 2.017739, loss_dice: 0.966238
iteration 6 : loss : 1.043798, loss_ce: 1.477863, loss_dice: 0.93528

  0%|                                          | 0/10 [00:00<?, ?it/s] 10%|███▎                             | 1/10 [04:52<43:53, 292.56s/it] 20%|██████▌                          | 2/10 [09:53<39:38, 297.37s/it] 30%|█████████▉                       | 3/10 [15:03<35:22, 303.22s/it] 40%|█████████████▏                   | 4/10 [20:14<30:38, 306.38s/it] 50%|████████████████▌                | 5/10 [25:25<25:40, 308.14s/it] 50%|████████████████▌                | 5/10 [30:37<30:37, 367.48s/it]
