Skip to content

Commit

Permalink
Merge pull request #121 from anwai98/main
Browse files Browse the repository at this point in the history
Add Argument - Target Domain
  • Loading branch information
constantinpape committed Apr 12, 2023
2 parents da3b4c9 + a94d232 commit b028976
Show file tree
Hide file tree
Showing 6 changed files with 37 additions and 8 deletions.
4 changes: 4 additions & 0 deletions experiments/probabilistic_domain_adaptation/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ python unet_mean_teacher.py -p [check / train / evaluate]
-i <PATH-TO-DATA>
-s <PATH-TO-SAVE-MODEL-WEIGHTS>
-o <PATH-FOR-SAVING-PREDICTIONS>
[(optional) --target_ct <TARGET-DOMAIN-CELL-TYPE>]
[(optional) --confidence_threshold <THRESHOLD-FOR-COMPUTING-FILTER-MASK>]
```

Expand All @@ -33,6 +34,7 @@ python unet_adamt.py -p [check / train / evaluate]
-i <PATH-TO-DATA>
-s <PATH-TO-SAVE-MODEL-WEIGHTS>
-o <PATH-FOR-SAVING-PREDICTIONS>
[(optional) --target_ct <TARGET-DOMAIN-CELL-TYPE>]
[(optional) --confidence_threshold <THRESHOLD-FOR-COMPUTING-FILTER-MASK>]
```

Expand All @@ -43,6 +45,7 @@ python unet_fixmatch.py -p [check / train / evaluate]
-i <PATH-TO-DATA>
-s <PATH-TO-SAVE-MODEL-WEIGHTS>
-o <PATH-FOR-SAVING-PREDICTIONS>
[(optional) --target_ct <TARGET-DOMAIN-CELL-TYPE>]
[(optional) --confidence_threshold <THRESHOLD-FOR-COMPUTING-FILTER-MASK>]
[(optional) --distribution_alignment <ACTIVATES-DISTRIBUTION-ALIGNMENT>]
```
Expand All @@ -54,6 +57,7 @@ python unet_adamatch.py -p [check / train / evaluate]
-i <PATH-TO-DATA>
-s <PATH-TO-SAVE-MODEL-WEIGHTS>
-o <PATH-FOR-SAVING-PREDICTIONS>
[(optional) --target_ct <TARGET-DOMAIN-CELL-TYPE>]
[(optional) --confidence_threshold <THRESHOLD-FOR-COMPUTING-FILTER-MASK>]
[(optional) --distribution_alignment <ACTIVATES-DISTRIBUTION-ALIGNMENT>]
```
Original file line number Diff line number Diff line change
Expand Up @@ -282,5 +282,6 @@ def get_parser(default_batch_size=8, default_iterations=int(1e5)):
parser.add_argument("-n", "--n_iterations", default=default_iterations, type=int)
parser.add_argument("-s", "--save_root")
parser.add_argument("-c", "--cell_types", nargs="+", default=CELL_TYPES)
parser.add_argument("--target_ct", nargs="+", default=None)
parser.add_argument("-o", "--output")
return parser
Original file line number Diff line number Diff line change
Expand Up @@ -84,15 +84,21 @@ def _train_source_target(args, source_cell_type, target_cell_type):


def _train_source(args, cell_type):
for target_cell_type in common.CELL_TYPES:
if args.target_ct is None:
target_cell_list = common.CELL_TYPES
else:
target_cell_list = args.target_ct

for target_cell_type in target_cell_list:
print("Training on target cell type:", target_cell_type)
if target_cell_type == cell_type:
continue
_train_source_target(args, cell_type, target_cell_type)


def run_training(args):
for cell_type in args.cell_types:
print("Start training for cell type:", cell_type)
print("Start training for source cell type:", cell_type)
_train_source(args, cell_type)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,15 +71,21 @@ def _train_source_target(args, source_cell_type, target_cell_type):


def _train_source(args, cell_type):
for target_cell_type in common.CELL_TYPES:
if args.target_ct is None:
target_cell_list = common.CELL_TYPES
else:
target_cell_list = args.target_ct

for target_cell_type in target_cell_list:
print("Training on target cell type:", target_cell_type)
if target_cell_type == cell_type:
continue
_train_source_target(args, cell_type, target_cell_type)


def run_training(args):
for cell_type in args.cell_types:
print("Start training for cell type:", cell_type)
print("Start training for source cell type:", cell_type)
_train_source(args, cell_type)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,15 +90,21 @@ def _train_source_target(args, source_cell_type, target_cell_type):


def _train_source(args, cell_type):
for target_cell_type in common.CELL_TYPES:
if args.target_ct is None:
target_cell_list = common.CELL_TYPES
else:
target_cell_list = args.target_ct

for target_cell_type in target_cell_list:
print("Training on target cell type:", target_cell_type)
if target_cell_type == cell_type:
continue
_train_source_target(args, cell_type, target_cell_type)


def run_training(args):
for cell_type in args.cell_types:
print("Start training for cell type:", cell_type)
print("Start training for source cell type:", cell_type)
_train_source(args, cell_type)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,15 +75,21 @@ def _train_source_target(args, source_cell_type, target_cell_type):


def _train_source(args, cell_type):
for target_cell_type in common.CELL_TYPES:
if args.target_ct is None:
target_cell_list = common.CELL_TYPES
else:
target_cell_list = args.target_ct

for target_cell_type in target_cell_list:
print("Training on target cell type:", target_cell_type)
if target_cell_type == cell_type:
continue
_train_source_target(args, cell_type, target_cell_type)


def run_training(args):
for cell_type in args.cell_types:
print("Start training for cell type:", cell_type)
print("Start training for source cell type:", cell_type)
_train_source(args, cell_type)


Expand Down

0 comments on commit b028976

Please sign in to comment.