In [1]:
import os
import sys
sys.path.append('/workspace/Documents')  ### remove this if not needed!
import numpy as np
import pandas as pd 
from tqdm import tqdm 
import random
from pathlib import Path
import nibabel as nb
import time

import argparse
from einops import rearrange
from natsort import natsorted
from madgrad import MADGRAD

import torch
import torch.backends.cudnn as cudnn
 
from original_SAM.utils.model_util import *
from original_SAM.segment_anything.model import build_model 
from original_SAM.utils.save_utils import *
from original_SAM.utils.config_util import Config
from original_SAM.utils.misc import NativeScalerWithGradNormCount as NativeScaler

# from original_SAM.train_engine import train_loop

import original_SAM.dataset.build_datasets as build_datasets
import original_SAM.functions_collection as ff
import original_SAM.get_args_parser as get_args_parser

# Root folder of the project data; the model output folder, the pre-trained SAM
# checkpoint path and the patient list path in later cells are all joined onto it.
main_path = '/mnt/camca_NAS/SAM_for_CMR/'  # replace with your own path

[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 15]


### define parameters for this experiment
The full list of settings can be found in ```get_args_parser.py```.

In [2]:
# Experiment-specific parameters
trial_name = 'original_SAM_trial'

# Output folder for this trial: <main_path>/example_data_original_sam/models/<trial_name>.
# Both the parent "models" folder and the trial folder are handed to ff.make_folder
# (presumably creating them when missing -- see functions_collection).
models_root = os.path.join(main_path, 'example_data_original_sam/models')
output_dir = os.path.join(models_root, trial_name)
ff.make_folder([models_root, output_dir])

# Training schedule / checkpoint to continue from
pretrained_model = None  # set to a checkpoint path to continue training from it
start_epoch = 1
total_training_epochs = 100  # total number of training epochs

In [7]:
# Checkpoint of the original SAM (ViT-B weights)
original_sam = os.path.join(main_path, 'models/pretrained_sam/sam_vit_b_01ec64.pth')

# Number of segmented classes ## important
num_classes = 3

# Build the parser with this experiment's values and parse an empty argument
# list right away, so `args` only ever refers to the parsed settings.
args = get_args_parser.get_args_parser(
    num_classes=num_classes,
    vit_type="vit_b",
    pretrained_model=pretrained_model,
    original_sam=original_sam,
    start_epoch=start_epoch,
    total_training_epochs=total_training_epochs,
).parse_args([])

# Remaining global settings
cfg = Config(args.config)
cudnn.benchmark = True
# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

### define the training dataset

In [8]:
# SAX training data: spreadsheet listing the patients, and which of its
# entries (by index) are used for training.
patient_list_file_sax = os.path.join(main_path,'example_data_original_sam/Patient_list/patient_list.xlsx')
patient_index_list = np.arange(0,1,1)  # start=0, stop=1, step=1 -> only index 0

# Training dataset with shuffling and augmentation switched on.
dataset_train_sax = build_datasets.build_dataset(
    args,
    augment=True,
    shuffle=True,
    index_list=patient_index_list,
    patient_list_file=patient_list_file_sax,
)

### load pre-trained SAM model (all parameters are trainable; the optional freezing of SAM modules is commented out)

In [10]:
# Build the model, load the original SAM weights into it, create the optimizer,
# and optionally continue from a previously saved training checkpoint.

# set model
model = build_model(args, device)

# Names of trainable / frozen parameters. Both lists stay empty while the
# freezing logic further down is commented out.
train_keys = []
freezed_keys = []

# load pretrained sam model vit_b
if args.model_type.startswith("sam"):
    if args.resume.endswith(".pth"):
        print('args.vit_type = ', args.vit_type)
        # map_location lets a checkpoint that was saved on a GPU be loaded on the
        # CPU fallback as well.
        # NOTE(review): torch.load unpickles the file -- only load checkpoints
        # from a trusted source (consider weights_only=True for plain state dicts).
        with open(args.resume, "rb") as f:
            state_dict = torch.load(f, map_location=device)
        try:
            model.load_state_dict(state_dict)
        except RuntimeError:
            # load_state_dict raises RuntimeError on key/shape mismatch. Only the
            # vit_h variant has a conversion path (load_from); for every other
            # variant re-raise the real error instead of failing later with a
            # NameError on an undefined `new_state_dict`.
            if args.vit_type != "vit_h":
                raise
            new_state_dict = load_from(model, state_dict, args.img_size, 16, [7, 15, 23, 31])
            model.load_state_dict(new_state_dict)

        # # freeze original SAM layers
        # freeze_list = [ "norm1", "attn" , "mlp", "norm2"]

        # Currently every parameter is trained; re-enable the commented lines
        # (and freeze_list above) to freeze the listed SAM sub-modules instead.
        for n, value in model.named_parameters():
            value.requires_grad = True
            # if any(substring in n for substring in freeze_list):
            #     freezed_keys.append(n)
            #     value.requires_grad = False
            # else:
            #     train_keys.append(n)
            #     value.requires_grad = True

## Select optimization method
optimizer = MADGRAD(model.parameters(), lr=args.lr) # momentum=,weight_decay=,eps=)

# Continue training model
if args.pretrained_model is not None and os.path.exists(args.pretrained_model):
    print('loading pretrained model : ', args.pretrained_model)
    args.resume = args.pretrained_model
    finetune_checkpoint = torch.load(args.pretrained_model, map_location=device)
    model.load_state_dict(finetune_checkpoint["model"])
    optimizer.load_state_dict(finetune_checkpoint["optimizer"])
    torch.cuda.empty_cache()
else:
    if args.pretrained_model is not None:
        # A checkpoint was requested but is missing: say so instead of silently
        # starting from scratch.
        print('WARNING: pretrained model not found, training from scratch : ', args.pretrained_model)
    print('new training\n')

args.vit_type =  vit_b
new training



  state_dict = torch.load(f)
