Update LM generalist training (with new decoder setup) #379

Merged (11 commits, Feb 8, 2024)
2 changes: 1 addition & 1 deletion finetuning/evaluation/.gitignore
@@ -1 +1 @@
-figures/*
+*.png
173 changes: 173 additions & 0 deletions finetuning/evaluation/experiments/run_updated_unetr_evaluations.py
@@ -0,0 +1,173 @@
import os
import re
import subprocess
from glob import glob
from pathlib import Path

import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt


CMD = "python submit_all_evaluation.py "
CHECKPOINT_ROOT = "/scratch/usr/nimanwai/experiments/micro-sam/unetr-decoder-updates/"
EXPERIMENT_ROOT = "/scratch/projects/nim00007/sam/experiments/new_models/test/unetr-decoder-updates"


def run_eval_process(cmd):
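    """Run one evaluation command in a subprocess and terminate it if it exceeds the timeout."""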
    proc = subprocess.Popen(cmd)
    try:
        outs, errs = proc.communicate(timeout=60)
    except subprocess.TimeoutExpired:
        proc.terminate()
        outs, errs = proc.communicate()


def run_specific_experiment(dataset_name, model_type, setup):
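    """Submit the evaluation for all checkpoints of the given decoder setup, for one dataset-model combination."""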
    all_checkpoint_dirs = sorted(glob(os.path.join(CHECKPOINT_ROOT, f"{setup}-*")))
    for checkpoint_dir in all_checkpoint_dirs:
        checkpoint_path = os.path.join(checkpoint_dir, "checkpoints", model_type, "lm_generalist_sam", "best.pt")

        experiment_name = checkpoint_dir.split("/")[-1]
        experiment_folder = os.path.join(EXPERIMENT_ROOT, experiment_name, dataset_name, model_type)

        cmd = CMD + f"-d {dataset_name} " + f"-m {model_type} " + "-e generalist "
        cmd += f"--checkpoint_path {checkpoint_path} "
        cmd += f"--experiment_path {experiment_folder}"
        print(f"Running the command: {cmd} \n")
        _cmd = re.split(r"\s", cmd)
        run_eval_process(_cmd)


def _get_plots(dataset_name, model_type):
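    """Collect the quantitative results (mean segmentation accuracy) for one dataset-model combination
    and plot them as barplots, with one subplot per experiment found under EXPERIMENT_ROOT.
    """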
    experiment_dirs = sorted(glob(os.path.join(EXPERIMENT_ROOT, "*")))

    # assign a fixed color to each setting, for consistent legends across all subplots
    palette = {"amg": "C0", "ais": "C1", "box": "C2", "i_b": "C3", "point": "C4", "i_p": "C5"}

    fig, ax = plt.subplots(1, len(experiment_dirs), figsize=(20, 10), sharex="col", sharey="row")

    for idx, _experiment_dir in enumerate(experiment_dirs):
        all_result_paths = sorted(glob(os.path.join(_experiment_dir, dataset_name, model_type, "results", "*")))
        res_list_per_experiment = []
        for i, result_path in enumerate(all_result_paths):
            # avoid using the grid-search parameters' files
            _tmp_check = os.path.split(result_path)[-1]
            if _tmp_check.startswith("grid_search_"):
                continue

            res = pd.read_csv(result_path)
            setting_name = Path(result_path).stem
            if setting_name == "amg" or setting_name.startswith("instance"):  # saving results from amg or ais
                res_df = pd.DataFrame(
                    {
                        "name": model_type,
                        "type": "amg" if setting_name == "amg" else "ais",
                        "results": res.iloc[0]["msa"]
                    }, index=[i]
                )
            else:  # saving results from iterative prompting (first and last iteration)
                prompt_name = Path(result_path).stem.split("_")[-1]
                res_df = pd.concat(
                    [
                        pd.DataFrame(
                            {
                                "name": model_type,
                                "type": prompt_name,
                                "results": res.iloc[0]["msa"]
                            }, index=[i]
                        ),
                        pd.DataFrame(
                            {
                                "name": model_type,
                                "type": f"i_{prompt_name[0]}",
                                "results": res.iloc[-1]["msa"]
                            }, index=[i]
                        )
                    ]
                )
            res_list_per_experiment.append(res_df)

        res_df_per_experiment = pd.concat(res_list_per_experiment, ignore_index=True)

        container = sns.barplot(
            x="name", y="results", hue="type", data=res_df_per_experiment, ax=ax[idx], palette=palette
        )
        ax[idx].set(xlabel="Experiments", ylabel="Segmentation Quality")
        ax[idx].legend(title="Settings", bbox_to_anchor=(1, 1))

        # add the scores on top of the bars
        for j in container.containers:
            container.bar_label(j, fmt='%.2f')

        # title for each subplot
        ax[idx].title.set_text(_experiment_dir.split("/")[-1])

    # remove the per-subplot legends and keep one common legend for the whole figure
    # (the loop variable is named `axis` so that it does not shadow the `ax` array above)
    all_lines, all_labels = [], []
    for axis in fig.axes:
        lines, labels = axis.get_legend_handles_labels()
        for line, label in zip(lines, labels):
            if label not in all_labels:
                all_lines.append(line)
                all_labels.append(label)
        axis.get_legend().remove()

    fig.legend(all_lines, all_labels)

    # apply the layout tweaks before the figure is shown and saved
    plt.tight_layout()
    plt.subplots_adjust(top=0.90, right=0.95)
    fig.suptitle(dataset_name, fontsize=20)
    plt.show()

    save_path = f"figures/{dataset_name}/{model_type}.png"

    try:
        plt.savefig(save_path)
    except FileNotFoundError:  # create the figure directory on the first run
        os.makedirs(os.path.split(save_path)[0], exist_ok=True)
        plt.savefig(save_path)

    plt.close()
    print(f"Plot saved at {save_path}")


def run_one_setup(all_dataset_list, all_model_list, setup):
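    """Run the evaluation for all dataset-model combinations of one decoder setup."""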
    for dataset_name in all_dataset_list:
        for model_type in all_model_list:
            run_specific_experiment(dataset_name=dataset_name, model_type=model_type, setup=setup)
        breakpoint()  # deliberate pause after each dataset, e.g. to check the submitted jobs before continuing


def for_all_lm(setup):
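    """Run the in-domain LM evaluations for the given decoder setup ("conv-transpose" or "bilinear")."""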
    assert setup in ["conv-transpose", "bilinear"]

    # let's run for in-domain
    run_one_setup(
        all_dataset_list=["tissuenet", "deepbacs", "plantseg/root", "livecell", "neurips-cell-seg"],
        all_model_list=["vit_t", "vit_b", "vit_l", "vit_h"],
        setup=setup
    )


def _run_evaluations():
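    """Switch to the parent `evaluation` folder (where `submit_all_evaluation.py` lives) and run the evaluations."""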
os.chdir("../")
# for_all_lm("conv-transpose")
for_all_lm("bilinear")


def _get_all_plots():
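    """Create the result plots for all dataset-model combinations."""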
    all_datasets = ["tissuenet", "deepbacs", "plantseg/root", "livecell", "neurips-cell-seg"]
    all_models = ["vit_t", "vit_b", "vit_l", "vit_h"]

    for dataset_name in all_datasets:
        for model_type in all_models:
            _get_plots(dataset_name, model_type)


def main():
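    """Entry point: run the evaluations and / or create the plots (toggle via the comments below)."""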
    # _run_evaluations()
    _get_all_plots()


if __name__ == "__main__":
    main()
6 changes: 3 additions & 3 deletions finetuning/evaluation/preprocess_datasets.py
@@ -752,8 +752,8 @@ def neurips_raw_trafo(raw):
 def for_deepbacs(save_dir):
     "Move the datasets from the internal split (provided by default in deepbacs) to our `slices` logic"
     for split in ["val", "test"]:
-        image_paths = os.path.join(ROOT, "deepbacs", "mixed", split, "source", "*")
-        label_paths = os.path.join(ROOT, "deepbacs", "mixed", split, "target", "*")
+        image_paths = sorted(glob(os.path.join(ROOT, "deepbacs", "mixed", split, "source", "*")))
+        label_paths = sorted(glob(os.path.join(ROOT, "deepbacs", "mixed", split, "target", "*")))

         os.makedirs(os.path.join(save_dir, split, "raw"), exist_ok=True)
         os.makedirs(os.path.join(save_dir, split, "labels"), exist_ok=True)
@@ -801,7 +801,7 @@ def main():
     # let's ensure all the data is downloaded
     download_all_datasets(ROOT)

-    # now let's save the slices as tif
+    # now let's save the slices as tif
     preprocess_lm_datasets()
     preprocess_em_datasets()

4 changes: 2 additions & 2 deletions finetuning/evaluation/run_all_evaluations.py
@@ -31,15 +31,15 @@ def run_one_setup(all_dataset_list, all_model_list, all_experiment_set_list, roi
 def for_all_lm():
     # let's run for in-domain
     run_one_setup(
-        all_dataset_list=["tissuenet", "deepbacs", "plantseg_root", "livecell"],
+        all_dataset_list=["tissuenet", "deepbacs", "plantseg/root", "livecell"],
         all_model_list=["vit_b", "vit_h"],
         all_experiment_set_list=["vanilla", "generalist", "specialist"],
         roi="lm"
     )

     # next, let's run for out-of-domain
     run_one_setup(
-        all_dataset_list=["covid_if", "plantseg_ovules", "hpa", "lizard", "mouse-embryo", "ctc", "neurips-cell-seg"],
+        all_dataset_list=["covid_if", "plantseg/ovules", "hpa", "lizard", "mouse-embryo", "ctc", "neurips-cell-seg"],
         all_model_list=["vit_b", "vit_h"],
         all_experiment_set_list=["vanilla", "generalist"],
         roi="lm"
10 changes: 9 additions & 1 deletion finetuning/evaluation/submit_all_evaluation.py
@@ -17,6 +17,7 @@ def write_batch_script(
 #SBATCH -p grete:shared
 #SBATCH -G A100:1
 #SBATCH -A gzz0001
+#SBATCH --constraint=80gb
 #SBATCH --job-name={inference_setup}

 source ~/.bashrc
@@ -130,9 +131,16 @@ def submit_slurm(args):
all_setups = ["precompute_embeddings", "evaluate_amg", "iterative_prompting"]
else:
all_setups = ["precompute_embeddings", "evaluate_amg", "evaluate_instance_segmentation", "iterative_prompting"]

# env name
if model_type == "vit_t":
env_name = "mobilesam"
else:
env_name = "sam"

for current_setup in all_setups:
write_batch_script(
env_name="sam",
env_name=env_name,
out_path=get_batch_script_names(tmp_folder),
inference_setup=current_setup,
checkpoint=checkpoint,
5 changes: 3 additions & 2 deletions finetuning/evaluation/util.py
@@ -86,10 +86,11 @@ def get_model(model_type, ckpt):


 def get_paths(dataset_name, split):
-    assert dataset_name in DATASETS
+    assert dataset_name in DATASETS, dataset_name

     if dataset_name == "livecell":
-        return _get_livecell_paths(input_folder=os.path.join(ROOT, "livecell"), split=split)
+        image_paths, gt_paths = _get_livecell_paths(input_folder=os.path.join(ROOT, "livecell"), split=split)
+        return sorted(image_paths), sorted(gt_paths)

     image_dir, gt_dir = get_dataset_paths(dataset_name, split)
     image_paths = sorted(glob(os.path.join(image_dir)))
@@ -44,9 +44,8 @@ def get_concat_lm_datasets(input_path, patch_shape, split_choice):
             n_samples=1000 if split_choice == "train" else 100
         ),
         datasets.get_livecell_dataset(
-            path=os.path.join(input_path, "livecell"), split=split_choice, patch_shape=patch_shape,
-            label_transform=label_transform, sampler=sampler, label_dtype=label_dtype, raw_transform=identity,
-            n_samples=1000 if split_choice == "train" else 100, download=True
+            path=os.path.join(input_path, "livecell"), split=split_choice, patch_shape=patch_shape, download=True,
+            label_transform=label_transform, sampler=sampler, label_dtype=label_dtype, raw_transform=identity
         ),
         datasets.get_deepbacs_dataset(
             path=os.path.join(input_path, "deepbacs"), split=split_choice, patch_shape=patch_shape,
@@ -41,7 +41,8 @@ def finetune_lm_generalist(args):
         use_sam_stats=True,
         final_activation="Sigmoid",
         use_skip_connection=False,
-        resize_input=True
+        resize_input=True,
+        use_conv_transpose=True
     )
     unetr.to(device)
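
Note: `use_conv_transpose=True` presumably switches the UNETR decoder to transposed-convolution upsampling, i.e. the "conv-transpose" setup that the evaluation script above compares against "bilinear".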

Expand Down