Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Monai's UNETR #155

Merged
merged 7 commits into from
Oct 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
160 changes: 104 additions & 56 deletions experiments/vision-transformer/unetr/livecell_unetr.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,25 +12,55 @@

import torch
import torch_em
from torch_em.model import UNETR
from torch_em.transform.raw import standardize
from torch_em.transform.label import labels_to_binary
from torch_em.data.datasets import get_livecell_loader
from torch_em.util.prediction import predict_with_halo


def get_unetr_model(
model_name: str,
source_choice: str,
patch_shape: Tuple[int, int],
sam_initialization: bool,
output_channels: int
):
"""Returns the expected UNETR model
"""
if source_choice == "torch-em":
from torch_em import model as torch_em_models
model = torch_em_models.UNETR(
encoder=model_name, out_channels=output_channels,
encoder_checkpoint_path="/scratch/usr/nimanwai/models/segment-anything/checkpoints/sam_vit_b_01ec64.pth" if sam_initialization else None
)

elif source_choice == "monai":
from monai.networks import nets as monai_models
model = monai_models.unetr.UNETR(
in_channels=1,
out_channels=output_channels,
img_size=patch_shape,
spatial_dims=2
)
model.out_channels = 2 # type: ignore

else:
raise ValueError(f"The available UNETR models are either from \"torch-em\" or \"monai\", choose from them instead of - {source_choice}")

return model


def do_unetr_training(
input_path: str,
model: UNETR,
model_name: str,
model,
cell_types: List[str],
patch_shape: Tuple[int, int],
device: torch.device,
save_root: str,
iterations: int,
sam_initialization: bool
sam_initialization: bool,
source_choice: str
):
os.makedirs(input_path, exist_ok=True)
print("Run training for cell types:", cell_types)
train_loader = get_livecell_loader(
path=input_path,
split="train",
Expand All @@ -53,14 +83,13 @@ def do_unetr_training(
num_workers=8
)

_name = "livecell-unetr" if cell_types is None else f"livecell-{cell_types}-unetr"

_save_root = os.path.join(
save_root, f"sam-{model_name}" if sam_initialization else "scratch"
save_root,
f"{source_choice}-sam" if sam_initialization else f"{source_choice}-scratch"
) if save_root is not None else save_root

trainer = torch_em.default_segmentation_trainer(
name=_name,
name=f"livecell-{cell_types}",
model=model,
train_loader=train_loader,
val_loader=val_loader,
Expand All @@ -78,42 +107,39 @@ def do_unetr_training(
def do_unetr_inference(
input_path: str,
device: torch.device,
model: UNETR,
model,
cell_types: List[str],
save_dir: str,
root_save_dir: str,
sam_initialization: bool,
model_name: str,
save_root: str
save_root: str,
source_choice: str
):
_save_dir = os.path.join(
save_dir,
f"unetr-torch-em-sam-{model_name}" if sam_initialization else f"unetr-torch-em-scratch-{model_name}"
)

for ctype in cell_types:
test_img_dir = os.path.join(input_path, "images", "livecell_test_images", "*")

model_ckpt = os.path.join(save_root,
f"sam-{model_name}" if sam_initialization else "scratch",
"checkpoints", f"livecell-{ctype}-unetr", "best.pt")
f"{source_choice}-sam" if sam_initialization else f"{source_choice}-scratch",
"checkpoints", f"livecell-{ctype}", "best.pt")
assert os.path.exists(model_ckpt)

model.load_state_dict(torch.load(model_ckpt, map_location=torch.device('cpu'))["model_state"])
model.to(device)
model.eval()

with torch.no_grad():
for img_path in glob(test_img_dir):
for img_path in tqdm(glob(test_img_dir), desc=f"Run inference for all livecell with model {model_ckpt}"):
fname = os.path.split(img_path)[-1]

input_img = imageio.imread(img_path)
input_img = standardize(input_img)
outputs = predict_with_halo(input_img, model, gpu_ids=[device], block_shape=[384, 384], halo=[64, 64])
outputs = predict_with_halo(
input_img, model, gpu_ids=[device], block_shape=[384, 384], halo=[64, 64], disable_tqdm=True
)

fg, bd = outputs[0, :, :], outputs[1, :, :]

fg_save_dir = os.path.join(_save_dir, f"src-{ctype}", "foreground")
bd_save_dir = os.path.join(_save_dir, f"src-{ctype}", "boundary")
fg_save_dir = os.path.join(root_save_dir, f"src-{ctype}", "foreground")
bd_save_dir = os.path.join(root_save_dir, f"src-{ctype}", "boundary")

os.makedirs(fg_save_dir, exist_ok=True)
os.makedirs(bd_save_dir, exist_ok=True)
Expand All @@ -125,21 +151,20 @@ def do_unetr_inference(
def do_unetr_evaluation(
input_path: str,
cell_types: List[str],
save_dir: str,
model_name: str,
sam_initialization: bool
root_save_dir: str,
sam_initialization: bool,
source_choice: str
):
root_save_dir = os.path.join(
save_dir,
f"unetr-torch-em-sam-{model_name}" if sam_initialization else f"unetr-torch-em-scratch-{model_name}"
)
fg_list, bd_list = [], []

for c1 in cell_types:
_save_dir = os.path.join(root_save_dir, f"src-{c1}")
if not os.path.exists(_save_dir):
print("Skipping", _save_dir)
continue

fg_set, bd_set = {"CELL TYPE": c1}, {"CELL TYPE": c1}
for c2 in tqdm(cell_types, desc=f"Evaluation on {c1} source models"):
for c2 in tqdm(cell_types, desc=f"Evaluation on {c1} source models from {_save_dir}"):
fg_dir = os.path.join(_save_dir, "foreground")
bd_dir = os.path.join(_save_dir, "boundary")

Expand All @@ -152,11 +177,19 @@ def do_unetr_evaluation(
fg = imageio.imread(os.path.join(fg_dir, fname))
bd = imageio.imread(os.path.join(bd_dir, fname))

true_fg = labels_to_binary(gt)
true_bd = find_boundaries(gt)

cwise_fg.append(dice_score(fg, true_fg, threshold_gt=0))
cwise_bd.append(dice_score(bd, true_bd, threshold_gt=0))
# Compare the foreground prediction to the ground-truth.
# Here, it's important not to threshold the segmentation. Otherwise EVERYTHING will be set to
# foreground in the dice function, since we have a comparision > 0 in there, and everything in the
# binary prediction evaluates to true.
# For the GT we can set the threshold to 0, because this will map to the correct binary mask.
cwise_fg.append(dice_score(fg, gt, threshold_gt=0, threshold_seg=None))

# Compare the background prediction to the ground-truth.
# Here, we don't need any thresholds: for the prediction the same holds as before.
# For the ground-truth we have already a binary label, so we don't need to threshold it again.
cwise_bd.append(dice_score(bd, true_bd, threshold_gt=None, threshold_seg=None))

fg_set[c2] = np.mean(cwise_fg)
bd_set[c2] = np.mean(cwise_bd)
Expand All @@ -170,56 +203,68 @@ def do_unetr_evaluation(
csv_save_dir = "./results/"
os.makedirs(csv_save_dir, exist_ok=True)

tmp_csv_name = f"sam-{model_name}" if sam_initialization else "scratch"
f_df_fg.to_csv(os.path.join(csv_save_dir, f"foreground-torch-em-unetr-{tmp_csv_name}-results.csv"))
f_df_bd.to_csv(os.path.join(csv_save_dir, f"boundary-torch-em-unetr-{tmp_csv_name}-results.csv"))
tmp_csv_name = f"{source_choice}-sam" if sam_initialization else f"{source_choice}-scratch"
f_df_fg.to_csv(os.path.join(csv_save_dir, f"foreground-unetr-{tmp_csv_name}-results.csv"))
f_df_bd.to_csv(os.path.join(csv_save_dir, f"boundary-unetr-{tmp_csv_name}-results.csv"))


def main(args):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

n_channels = 2
model = UNETR(
encoder=args.model_name, out_channels=n_channels,
encoder_checkpoint_path="/scratch/usr/nimanwai/models/segment-anything/checkpoints/sam_vit_b_01ec64.pth" if args.do_sam_ini else None)
model.to(device)
patch_shape = (512, 512)
output_channels = 2

all_cell_types = ["A172", "BT474", "BV2", "Huh7", "MCF7", "SHSY5Y", "SkBr3", "SKOV3"]

model = get_unetr_model(
model_name=args.model_name,
source_choice=args.source_choice,
patch_shape=patch_shape,
sam_initialization=args.do_sam_ini,
output_channels=output_channels
)
model.to(device)

if args.train:
print("2d UNETR training on LiveCell dataset")
do_unetr_training(
input_path=args.input,
model=model,
model_name=args.model_name,
cell_types=args.cell_type,
patch_shape=(512, 512),
patch_shape=patch_shape,
device=device,
save_root=args.save_root,
iterations=args.iterations,
sam_initialization=args.do_sam_ini
sam_initialization=args.do_sam_ini,
source_choice=args.source_choice
)

root_save_dir = os.path.join(
args.save_dir,
f"unetr-{args.source_choice}-sam" if args.do_sam_ini else f"unetr-{args.source_choice}-scratch"
)
print("Predictions are saved in", root_save_dir)

if args.predict:
print("2d UNETR inference on LiveCell dataset")
do_unetr_inference(
input_path=args.input,
device=device,
model=model,
cell_types=all_cell_types,
save_dir=args.save_dir,
root_save_dir=root_save_dir,
sam_initialization=args.do_sam_ini,
model_name=args.model_name,
save_root=args.save_root
save_root=args.save_root,
source_choice=args.source_choice
)

if args.evaluate:
print("2d UNETR evaluation on LiveCell dataset")
do_unetr_evaluation(
input_path=args.input,
cell_types=all_cell_types,
save_dir=args.save_dir,
model_name=args.model_name,
sam_initialization=args.do_sam_ini
root_save_dir=root_save_dir,
sam_initialization=args.do_sam_ini,
source_choice=args.source_choice
)


Expand All @@ -234,6 +279,9 @@ def main(args):
parser.add_argument("--evaluate", action='store_true',
help="Enables UNETR evaluation on LiveCell dataset")

parser.add_argument("--source_choice", type=str, default="torch-em",
help="The source where the model comes from, i.e. either torch-em / monai")

parser.add_argument("-m", "--model_name", type=str, default="vit_b",
help="Name of the ViT to use as the encoder in UNETR")

Expand All @@ -246,13 +294,13 @@ def main(args):
parser.add_argument("-i", "--input", type=str, default="/scratch/usr/nimanwai/data/livecell",
help="Path where the dataset already exists/will be downloaded by the dataloader")

parser.add_argument("-s", "--save_root", type=str, default="/scratch/usr/nimanwai/models/unetr/torch-em/",
parser.add_argument("-s", "--save_root", type=str, default="/scratch/usr/nimanwai/models/unetr/",
help="Path where checkpoints and logs will be saved")

parser.add_argument("--save_dir", type=str, default="/scratch/usr/nimanwai/predictions/unetr",
help="Path to save predictions from UNETR model")

parser.add_argument("--iterations", type=int, default=10000)
parser.add_argument("--iterations", type=int, default=100000)

args = parser.parse_args()
main(args)
26 changes: 16 additions & 10 deletions experiments/vision-transformer/unetr/submit_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,14 @@
from datetime import datetime


def write_batch_script(out_path, ini_sam=False):
def write_batch_script(out_path, ini_sam=True, source_choice="torch-em"):
"""
inputs:
source_choice:str - [torch_em / monai] source of the unetr model coming from
ini_sam: bool - initialize torch-em's unetr implementation with sam encoder weights
"""
cell_types = ["A172", "BT474", "BV2", "Huh7", "MCF7", "SHSY5Y", "SkBr3", "SKOV3"]

for i, ctype in enumerate(cell_types):
batch_script = """#!/bin/bash
#SBATCH -t 2-00:00:00
Expand All @@ -20,23 +26,23 @@ def write_batch_script(out_path, ini_sam=False):
#SBATCH -A gzz0001
"""
if ini_sam:
batch_script += f"#SBATCH --job-name=unetr-sam-torch-em-{ctype}"
batch_script += f"#SBATCH --job-name=unetr-sam-{source_choice}-{ctype}"
else:
batch_script += f"#SBATCH --job-name=unetr-torch-em-{ctype}"
batch_script += f"#SBATCH --job-name=unetr-{source_choice}-{ctype}"

env_name = "monai2" if source_choice == "monai" else "sam"

batch_script += """
batch_script += f"""

source ~/.bashrc
mamba activate sam
mamba activate {env_name}
python livecell_unetr.py --train """

add_ctype = f"-c {ctype} "
add_input_path = "-i /scratch/usr/nimanwai/data/livecell/ "
add_save_root = "-s /scratch/usr/nimanwai/models/unetr/torch-em/ "
add_sam_ini = "--do_sam_ini "

batch_script += add_ctype + add_input_path + add_save_root
add_source_choice = f"--source_choice {source_choice} "
batch_script += add_ctype + add_source_choice

add_sam_ini = "--do_sam_ini "
if ini_sam:
batch_script += add_sam_ini

Expand Down
Loading