642 train 3d model with lucchi data #650

Draft · wants to merge 26 commits into dev from 642-train-3d-model-with-lucchi-data

Commits (26):
9d75668
implemented 3dsam train routine with lucchi data. still shape mismatch
lufre1 Jun 26, 2024
42f9f36
implemented training routine for 3d sam
lufre1 Jun 27, 2024
9be15d5
tidied up code
lufre1 Jun 27, 2024
b0fc01a
changed dataset esp. label shape not depending on num_classes
lufre1 Jun 27, 2024
8ca1326
added check_loader
lufre1 Jun 28, 2024
ca864ed
Add mentions for annotating 3D RGB volumes (#629)
anwai98 Jun 9, 2024
a66c09f
tidied up code
lufre1 Jun 28, 2024
a5e937a
Add SemanticSam3dLogger (#643)
anwai98 Jun 28, 2024
1592988
added new training and predict scripts
lufre1 Jul 4, 2024
b61ee04
Add simple 3d wrapper and enable freezing the encoder in sam 3d wrapp…
constantinpape Jun 28, 2024
c64944d
Minor fix to trainable sam model functionality (#646)
anwai98 Jun 28, 2024
70cf9b7
Fix dimension order in 3d sam wrappers
constantinpape Jun 29, 2024
09af0a7
Api cleanup (#648)
constantinpape Jul 2, 2024
3d8d879
Fix bug in precompute for 3d data (#649)
constantinpape Jul 3, 2024
9bf0d45
Merge branch 'dev' into 642-train-3d-model-with-lucchi-data
lufre1 Jul 4, 2024
b4f7865
merges...
lufre1 Jul 4, 2024
63b4654
added support for vitl and vith
lufre1 Jul 5, 2024
eaacf7a
changed training for n iterations to n epochs
lufre1 Jul 9, 2024
e3b2dbb
debug train sam without encoder on mitottomo
lufre1 Jul 9, 2024
a19f73d
added parameter for raw transform and min_size for label_transform to…
lufre1 Jul 10, 2024
a90ca2e
added checkpoint to train_with_lucchi
lufre1 Jul 10, 2024
ad76f2e
Add min-size to training and fix other issues
constantinpape Jul 11, 2024
a550893
removed unused code
lufre1 Jul 12, 2024
908e1c1
merged train routine updated
lufre1 Jul 12, 2024
b6a7ce9
updates on train 3d without decoder
lufre1 Jul 12, 2024
3422041
bash script for sbatch
lufre1 Jul 12, 2024
193 changes: 193 additions & 0 deletions development/predict_3d_model_with_lucchi.py
@@ -0,0 +1,193 @@
import os
import argparse
from glob import glob
from typing import Dict, List, Optional, Tuple, Union

import imageio.v3 as imageio
import numpy as np
import torch
from tqdm import tqdm

from elf.io import open_file
from torch_em.transform.raw import normalize
from torch_em.util.prediction import predict_with_halo

from micro_sam.models.sam_3d_wrapper import get_sam_3d_model


class RawTrafoFor3dInputs:
    """Normalize a raw 3d volume to [0, 255] and stack it to three channels for SAM."""
def _normalize_inputs(self, raw):
raw = normalize(raw)
raw = raw * 255
return raw

def _set_channels_for_inputs(self, raw):
raw = np.stack([raw] * 3, axis=0)
return raw

def __call__(self, raw):
raw = self._normalize_inputs(raw)
raw = self._set_channels_for_inputs(raw)
return raw
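
# A quick shape check for RawTrafoFor3dInputs (hypothetical example values):
#     trafo = RawTrafoFor3dInputs()
#     out = trafo(np.random.rand(32, 512, 512))  # grayscale (z, y, x) volume
#     out.shape  # -> (3, 32, 512, 512), values scaled to [0, 255]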


def _run_semantic_segmentation_for_image_3d(
model: torch.nn.Module,
image: np.ndarray,
prediction_path: Union[os.PathLike, str],
patch_shape: Tuple[int, int, int],
halo: Tuple[int, int, int],
):
device = next(model.parameters()).device
block_shape = tuple(bs - 2 * ha for bs, ha in zip(patch_shape, halo))
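    # Example with hypothetical values: patch_shape=(32, 512, 512) and halo=(6, 64, 64)
    # yield block_shape=(20, 384, 384). predict_with_halo tiles the volume into blocks
    # of this shape and enlarges each block by the halo, so that predictions have
    # context at the block boundaries.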

def preprocess(x):
x = 255 * normalize(x)
x = np.stack([x] * 3)
return x

    def prediction_function(net, inp):
        # Note: the input arrives with two singleton axes in front here; both need
        # to be removed to be compatible with the SAM network.
        batched_input = [{
            "image": inp[0, 0], "original_size": inp.shape[-2:]
        }]
        masks = net(batched_input, multimask_output=True)[0]["masks"]
        # Reduce the per-class channel to a single class prediction per pixel.
        masks = torch.argmax(masks, dim=1)
        return masks

    output = np.zeros(image.shape, dtype="float32")
predict_with_halo(
image, model, gpu_ids=[device],
block_shape=block_shape, halo=halo,
preprocess=preprocess, output=output,
prediction_function=prediction_function
)

# save the segmentations
imageio.imwrite(prediction_path, output, compression="zlib")


def run_semantic_segmentation_3d(
model: torch.nn.Module,
image_paths: List[Union[str, os.PathLike]],
prediction_dir: Union[str, os.PathLike],
semantic_class_map: Dict[str, int],
patch_shape: Tuple[int, int, int] = (32, 512, 512),
halo: Tuple[int, int, int] = (6, 64, 64),
image_key: Optional[str] = None,
is_multiclass: bool = False,
):
"""
"""
for image_path in tqdm(image_paths, desc="Run inference for semantic segmentation with all images"):
image_name = os.path.basename(image_path)

assert os.path.exists(image_path), image_path

        # We run the multiclass segmentation only once per volume,
        # so a single (dummy) semantic class name is used.
        semantic_class_name = "all"

        # Skip images that have already been segmented.
        image_name = os.path.splitext(image_name)[0] + ".tif"
prediction_path = os.path.join(prediction_dir, "all", image_name)
if os.path.exists(prediction_path):
continue

if image_key is None:
image = imageio.imread(image_path)
else:
with open_file(image_path, "r") as f:
image = f[image_key][:]

# create the prediction folder
os.makedirs(os.path.join(prediction_dir, semantic_class_name), exist_ok=True)

_run_semantic_segmentation_for_image_3d(
model=model, image=image, prediction_path=prediction_path,
patch_shape=patch_shape, halo=halo,
)


def transform_labels(y):
return (y > 0).astype("float32")


def predict(args):

device = "cuda" if torch.cuda.is_available() else "cpu"
if args.checkpoint_path is not None:
Contributor review comment:
I am not sure why you would ever run prediction without a checkpoint. I would not make this optional.
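A minimal sketch of that suggestion (assuming the argparse setup from main below): drop the default and make the argument mandatory, so prediction always runs from a trained checkpoint.

parser.add_argument(
    "--checkpoint_path", "-c", required=True,
    help="The filepath to the checkpoint directory of the trained model."
)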

        if os.path.exists(args.checkpoint_path):
            cp_path = os.path.join(args.checkpoint_path, "best.pt")
            model = get_sam_3d_model(
                device, n_classes=args.n_classes, image_size=args.patch_shape[1],
                lora_rank=4, model_type=args.model_type,
            )
            # Load the model weights from the trainer checkpoint.
            checkpoint = torch.load(cp_path, map_location=device)
            model.load_state_dict(checkpoint["model_state"])
            model.eval()

data_paths = glob(os.path.join(args.input_path, "**/*test.h5"), recursive=True)
pred_path = args.save_root
semantic_class_map = {"all": 0}

run_semantic_segmentation_3d(
model=model, image_paths=data_paths, prediction_dir=pred_path, semantic_class_map=semantic_class_map,
patch_shape=args.patch_shape, image_key="raw", is_multiclass=True
)


def main():
    parser = argparse.ArgumentParser(description="Run semantic segmentation with a 3d SAM model on the Lucchi dataset.")
    parser.add_argument(
        "--input_path", "-i", default="/scratch/projects/nim00007/sam/data/lucchi/",
        help="The filepath to the Lucchi data."
    )
    parser.add_argument(
        "--model_type", "-m", default="vit_b",
        help="The model type to use for inference. Either vit_t, vit_b, vit_l or vit_h."
    )
    parser.add_argument("--patch_shape", type=int, nargs=3, default=(32, 512, 512), help="Patch shape for data loading (3D tuple).")
    parser.add_argument("--n_classes", type=int, default=3, help="Number of classes to predict.")
    parser.add_argument("--batch_size", type=int, default=1, help="Batch size.")
    parser.add_argument("--num_workers", type=int, default=4, help="Number of workers for data loading.")
    parser.add_argument(
        "--save_root", "-s", default="/scratch-grete/usr/nimlufre/micro-sam3d",
        help="The filepath to where the predictions will be saved."
    )
    parser.add_argument(
        "--checkpoint_path", "-c", default="/scratch-grete/usr/nimlufre/micro-sam3d/checkpoints/3d-sam-vitb-masamhyp-lucchi",
        help="The filepath to the checkpoint directory of the trained model."
    )

args = parser.parse_args()

predict(args)


if __name__ == "__main__":
main()
201 changes: 201 additions & 0 deletions development/train_3d_model_with_lucchi.py
@@ -0,0 +1,201 @@
import os
import argparse
from math import ceil, floor

import numpy as np
import torch
import torch_em
from torch_em.data.datasets import get_lucchi_loader
from torch_em.segmentation import SegmentationDataset
from torch_em.transform.raw import normalize

from micro_sam.models.sam_3d_wrapper import get_sam_3d_model
from micro_sam.training.semantic_sam_trainer import SemanticSamTrainer
from micro_sam.training.util import ConvertToSemanticSamInputs


class RawTrafoFor3dInputs:
    """Normalize a raw 3d volume to [0, 255] and stack it to three channels for SAM."""
def _normalize_inputs(self, raw):
raw = normalize(raw)
raw = raw * 255
return raw

def _set_channels_for_inputs(self, raw):
raw = np.stack([raw] * 3, axis=0)
return raw

def __call__(self, raw):
raw = self._normalize_inputs(raw)
raw = self._set_channels_for_inputs(raw)
return raw


# Variant of the transform above that additionally pads inputs to a fixed shape
# (written with the sega data in mind).
class RawResizeTrafoFor3dInputs(RawTrafoFor3dInputs):
    def __init__(self, desired_shape, padding="constant"):
super().__init__()
self.desired_shape = desired_shape
self.padding = padding

def __call__(self, raw):
raw = self._normalize_inputs(raw)

# let's pad the inputs
tmp_ddim = (
self.desired_shape[0] - raw.shape[0],
self.desired_shape[1] - raw.shape[1],
self.desired_shape[2] - raw.shape[2]
)
ddim = (tmp_ddim[0] / 2, tmp_ddim[1] / 2, tmp_ddim[2] / 2)
raw = np.pad(
raw,
pad_width=(
(ceil(ddim[0]), floor(ddim[0])), (ceil(ddim[1]), floor(ddim[1])), (ceil(ddim[2]), floor(ddim[2]))
),
mode=self.padding
)

raw = self._set_channels_for_inputs(raw)

return raw
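
# A worked example of the padding arithmetic in RawResizeTrafoFor3dInputs
# (hypothetical shapes): desired_shape=(32, 512, 512) with an input of shape
# (28, 500, 500) gives tmp_ddim=(4, 12, 12) and ddim=(2.0, 6.0, 6.0), hence
# symmetric padding of (2, 2), (6, 6) and (6, 6), and an output of shape
# (3, 32, 512, 512) after the channels are added.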


class LucchiSegmentationDataset(SegmentationDataset):
Contributor review comment:
This can now be removed.
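As the review notes, the standard loader setup supersedes this custom dataset; a minimal sketch of the replacement (this mirrors get_loaders further below, using the transforms from this file):

train_loader = get_lucchi_loader(
    input_path, split="train", patch_shape=patch_shape, batch_size=1, download=True,
    raw_transform=RawTrafoFor3dInputs(), label_transform=transform_labels,
)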

    def __init__(self, patch_shape, label_transform=None, **kwargs):
        super().__init__(patch_shape=patch_shape, label_transform=label_transform, **kwargs)

    def __getitem__(self, index):
        raw, label = super().__getitem__(index)
        # Reshape the raw data to (z, 3, y, x); the three (color) channels are
        # created by repeating the grayscale data.
        image_shape = (self.patch_shape[0], 1) + self.patch_shape[1:]
        raw = raw.unsqueeze(2)
        raw = raw.view(image_shape)
        raw = raw.squeeze(0)
        raw = raw.repeat(1, 3, 1, 1)
        # Binarize the label, resulting in shape (1, z, y, x).
        label = (label != 0).to(torch.float)
        return raw, label


def transform_labels(y):
    # Use torch_em to derive foreground and boundary channels from the instance labels.
    transform = torch_em.transform.label.BoundaryTransform(add_binary_target=True)
    one_hot_channels = transform(y)

    # Class 1: foreground (non-zero in the binary channel).
    foreground = np.where(one_hot_channels[0] > 0, 1, 0)
    # Class 2: boundary, which takes priority over foreground; everything else
    # stays class 0 (background).
    combined = np.where(one_hot_channels[1] > 0, 2, foreground)
    return combined
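
# A small sanity check of the label mapping (hypothetical example):
#     y = np.zeros((8, 16, 16), dtype="int64")
#     y[2:6, 4:12, 4:12] = 1                   # a single instance
#     out = transform_labels(y)
#     np.unique(out)  # -> a subset of {0, 1, 2}: background, foreground, boundary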


def get_loaders(input_path, patch_shape, batch_size=1, num_workers=4):
    # batch_size and num_workers are forwarded to the underlying torch DataLoader.
    train_loader = get_lucchi_loader(
        input_path, split="train", patch_shape=patch_shape, batch_size=batch_size,
        num_workers=num_workers, download=True,
        raw_transform=RawTrafoFor3dInputs(), label_transform=transform_labels,
        n_samples=100
    )
    val_loader = get_lucchi_loader(
        input_path, split="test", patch_shape=patch_shape, batch_size=batch_size,
        num_workers=num_workers,
        raw_transform=RawTrafoFor3dInputs(), label_transform=transform_labels
    )
    return train_loader, val_loader


def train_on_lucchi(args):
    input_path = args.input_path
    patch_shape = args.patch_shape
    batch_size = args.batch_size
    num_workers = args.num_workers
    n_classes = args.n_classes
    model_type = args.model_type
    n_epochs = args.n_epochs
    save_root = args.save_root
    cp_path = args.checkpoint_path

    device = "cuda" if torch.cuda.is_available() else "cpu"
    if args.without_lora:
        # Without LoRA the image encoder is kept frozen.
        sam_3d = get_sam_3d_model(
            device, n_classes=n_classes, image_size=patch_shape[1],
            model_type=model_type, lora_rank=None,
        )
    else:
        sam_3d = get_sam_3d_model(
            device, n_classes=n_classes, image_size=patch_shape[1],
            model_type=model_type, lora_rank=4,
        )
    if cp_path is not None and os.path.exists(cp_path):
        # Resume training from the model state stored in the checkpoint.
        checkpoint = torch.load(cp_path, map_location=device)
        sam_3d.load_state_dict(checkpoint["model_state"])

    train_loader, val_loader = get_loaders(
        input_path=input_path, patch_shape=patch_shape,
        batch_size=batch_size, num_workers=num_workers,
    )
    optimizer = torch.optim.Adam(sam_3d.parameters(), lr=args.learning_rate)
    # Note: MASAM trains without a scheduler.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", factor=0.9, patience=15, verbose=True
    )


trainer = SemanticSamTrainer(
name=args.exp_name,
model=sam_3d,
convert_inputs=ConvertToSemanticSamInputs(),
num_classes=n_classes,
train_loader=train_loader,
val_loader=val_loader,
optimizer=optimizer,
lr_scheduler=scheduler,
device=device,
compile_model=False,
save_root=save_root,
    )
    # Uncomment for a visual check of the training data:
    # from torch_em.util.debug import check_loader
    # check_loader(train_loader, n_samples=10)
    trainer.fit(epochs=n_epochs)


def main():
    parser = argparse.ArgumentParser(description="Finetune Segment Anything for 3d semantic segmentation on the Lucchi dataset.")
    parser.add_argument(
        "--input_path", "-i", default="/scratch/projects/nim00007/sam/data/lucchi/",
        help="The filepath to the Lucchi data. If the data does not exist yet it will be downloaded."
    )
    parser.add_argument(
        "--model_type", "-m", default="vit_b",
        help="The model type to use for fine-tuning. Either vit_t, vit_b, vit_l or vit_h."
    )
    parser.add_argument(
        "--without_lora", action="store_true",
        help="Train without LoRA, i.e. with a frozen image encoder."
    )
    parser.add_argument("--patch_shape", type=int, nargs=3, default=(32, 512, 512), help="Patch shape for data loading (3D tuple).")
    parser.add_argument("--n_epochs", type=int, default=400, help="Number of training epochs.")
    parser.add_argument("--n_classes", type=int, default=3, help="Number of classes to predict.")
    parser.add_argument("--batch_size", "-bs", type=int, default=1, help="Batch size (MASAM uses 3).")
    parser.add_argument("--num_workers", type=int, default=4, help="Number of workers for data loading.")
    parser.add_argument("--learning_rate", type=float, default=1e-5, help="Base learning rate (MASAM uses 8e-4).")
    parser.add_argument(
        "--save_root", "-s", default="/scratch-grete/usr/nimlufre/micro-sam3d",
        help="The filepath to where the logs and the checkpoints will be saved."
    )
    parser.add_argument(
        "--checkpoint_path", default=None,
        help="The filepath to where the checkpoints are loaded from."
    )
    parser.add_argument(
        "--exp_name", default="vitb_3d_lora4-microsam-hypam-lucchi",
        help="The name of the experiment, used for the checkpoint and log folders."
    )

args = parser.parse_args()
train_on_lucchi(args)


if __name__ == "__main__":
main()