From 6c8358fdf5ce3e8f01e344c6315c3ae9b13cf63a Mon Sep 17 00:00:00 2001 From: anwai98 Date: Wed, 12 Apr 2023 14:17:22 +0200 Subject: [PATCH 1/2] Add Argument (optional) for Target Cell-Types --- .../probabilistic_domain_adaptation/livecell/common.py | 1 + .../livecell/unet_adamatch.py | 10 ++++++++-- .../livecell/unet_adamt.py | 10 ++++++++-- .../livecell/unet_fixmatch.py | 10 ++++++++-- .../livecell/unet_mean_teacher.py | 10 ++++++++-- 5 files changed, 33 insertions(+), 8 deletions(-) diff --git a/experiments/probabilistic_domain_adaptation/livecell/common.py b/experiments/probabilistic_domain_adaptation/livecell/common.py index d60d08ca..f1f6de99 100644 --- a/experiments/probabilistic_domain_adaptation/livecell/common.py +++ b/experiments/probabilistic_domain_adaptation/livecell/common.py @@ -282,5 +282,6 @@ def get_parser(default_batch_size=8, default_iterations=int(1e5)): parser.add_argument("-n", "--n_iterations", default=default_iterations, type=int) parser.add_argument("-s", "--save_root") parser.add_argument("-c", "--cell_types", nargs="+", default=CELL_TYPES) + parser.add_argument("--target_ct", nargs="+", default=None) parser.add_argument("-o", "--output") return parser diff --git a/experiments/probabilistic_domain_adaptation/livecell/unet_adamatch.py b/experiments/probabilistic_domain_adaptation/livecell/unet_adamatch.py index ef723c1e..dd75fadf 100644 --- a/experiments/probabilistic_domain_adaptation/livecell/unet_adamatch.py +++ b/experiments/probabilistic_domain_adaptation/livecell/unet_adamatch.py @@ -84,7 +84,13 @@ def _train_source_target(args, source_cell_type, target_cell_type): def _train_source(args, cell_type): - for target_cell_type in common.CELL_TYPES: + if args.target_ct is None: + target_cell_list = common.CELL_TYPES + else: + target_cell_list = args.target_ct + + for target_cell_type in target_cell_list: + print("Training on target cell type:", target_cell_type) if target_cell_type == cell_type: continue _train_source_target(args, cell_type, target_cell_type) @@ -92,7 +98,7 @@ def _train_source(args, cell_type): def run_training(args): for cell_type in args.cell_types: - print("Start training for cell type:", cell_type) + print("Start training for source cell type:", cell_type) _train_source(args, cell_type) diff --git a/experiments/probabilistic_domain_adaptation/livecell/unet_adamt.py b/experiments/probabilistic_domain_adaptation/livecell/unet_adamt.py index e3d7c2fd..e1ed3743 100644 --- a/experiments/probabilistic_domain_adaptation/livecell/unet_adamt.py +++ b/experiments/probabilistic_domain_adaptation/livecell/unet_adamt.py @@ -71,7 +71,13 @@ def _train_source_target(args, source_cell_type, target_cell_type): def _train_source(args, cell_type): - for target_cell_type in common.CELL_TYPES: + if args.target_ct is None: + target_cell_list = common.CELL_TYPES + else: + target_cell_list = args.target_ct + + for target_cell_type in target_cell_list: + print("Training on target cell type:", target_cell_type) if target_cell_type == cell_type: continue _train_source_target(args, cell_type, target_cell_type) @@ -79,7 +85,7 @@ def _train_source(args, cell_type): def run_training(args): for cell_type in args.cell_types: - print("Start training for cell type:", cell_type) + print("Start training for source cell type:", cell_type) _train_source(args, cell_type) diff --git a/experiments/probabilistic_domain_adaptation/livecell/unet_fixmatch.py b/experiments/probabilistic_domain_adaptation/livecell/unet_fixmatch.py index 6ebad36b..f683d062 100644 --- a/experiments/probabilistic_domain_adaptation/livecell/unet_fixmatch.py +++ b/experiments/probabilistic_domain_adaptation/livecell/unet_fixmatch.py @@ -90,7 +90,13 @@ def _train_source_target(args, source_cell_type, target_cell_type): def _train_source(args, cell_type): - for target_cell_type in common.CELL_TYPES: + if args.target_ct is None: + target_cell_list = common.CELL_TYPES + else: + target_cell_list = args.target_ct + + for target_cell_type in target_cell_list: + print("Training on target cell type:", target_cell_type) if target_cell_type == cell_type: continue _train_source_target(args, cell_type, target_cell_type) @@ -98,7 +104,7 @@ def _train_source(args, cell_type): def run_training(args): for cell_type in args.cell_types: - print("Start training for cell type:", cell_type) + print("Start training for source cell type:", cell_type) _train_source(args, cell_type) diff --git a/experiments/probabilistic_domain_adaptation/livecell/unet_mean_teacher.py b/experiments/probabilistic_domain_adaptation/livecell/unet_mean_teacher.py index 5a5f08a4..dc96fa7e 100644 --- a/experiments/probabilistic_domain_adaptation/livecell/unet_mean_teacher.py +++ b/experiments/probabilistic_domain_adaptation/livecell/unet_mean_teacher.py @@ -75,7 +75,13 @@ def _train_source_target(args, source_cell_type, target_cell_type): def _train_source(args, cell_type): - for target_cell_type in common.CELL_TYPES: + if args.target_ct is None: + target_cell_list = common.CELL_TYPES + else: + target_cell_list = args.target_ct + + for target_cell_type in target_cell_list: + print("Training on target cell type:", target_cell_type) if target_cell_type == cell_type: continue _train_source_target(args, cell_type, target_cell_type) @@ -83,7 +89,7 @@ def _train_source(args, cell_type): def run_training(args): for cell_type in args.cell_types: - print("Start training for cell type:", cell_type) + print("Start training for source cell type:", cell_type) _train_source(args, cell_type) From a94d232d918d6ce1fc71f5c6d655d5c7696dd870 Mon Sep 17 00:00:00 2001 From: anwai98 Date: Wed, 12 Apr 2023 14:23:56 +0200 Subject: [PATCH 2/2] Update README.md - Add Target Domain Argument --- experiments/probabilistic_domain_adaptation/README.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/experiments/probabilistic_domain_adaptation/README.md b/experiments/probabilistic_domain_adaptation/README.md index d7d76f87..651c6e52 100644 --- a/experiments/probabilistic_domain_adaptation/README.md +++ b/experiments/probabilistic_domain_adaptation/README.md @@ -23,6 +23,7 @@ python unet_mean_teacher.py -p [check / train / evaluate] -i -s -o + [(optional) --target_ct ] [(optional) --confidence_threshold ] ``` @@ -33,6 +34,7 @@ python unet_adamt.py -p [check / train / evaluate] -i -s -o + [(optional) --target_ct ] [(optional) --confidence_threshold ] ``` @@ -43,6 +45,7 @@ python unet_fixmatch.py -p [check / train / evaluate] -i -s -o + [(optional) --target_ct ] [(optional) --confidence_threshold ] [(optional) --distribution_alignment ] ``` @@ -54,6 +57,7 @@ python unet_adamatch.py -p [check / train / evaluate] -i -s -o + [(optional) --target_ct ] [(optional) --confidence_threshold ] [(optional) --distribution_alignment ] ```