Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add FixMatch-based Probabilistic Domain Adaptation Setups #126

Merged
merged 1 commit into from
May 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
146 changes: 146 additions & 0 deletions experiments/probabilistic_domain_adaptation/livecell/punet_adamatch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
import os

import pandas as pd
import torch
import torch_em.self_training as self_training

import common


def check_loader(args, n_images=5):
from torch_em.util.debug import check_loader

cell_types = args.cell_types
print("The cell types", cell_types, "were selected.")
print("Checking the unsupervised loader for the first cell type", cell_types[0])

loader = common.get_unsupervised_loader(
args, "train", cell_types[0],
teacher_augmentation="weak", student_augmentation="strong",
)
check_loader(loader, n_images)


def _train_source_target(args, source_cell_type, target_cell_type):
model = common.get_punet()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.9, patience=10)

# self training functionality
# - when thresh is None, we don't mask the reconstruction loss (RL) with label filters
# - when thresh is passed (float), we mask the RL with weighted consensus label filters
# - when thresh is passed (float) with args.consensus_masking, we mask the RL with masked consensus label filters
thresh = args.confidence_threshold
if thresh is None:
assert args.consensus_masking is False, "Provide a confidence threshold to use consensus masking"

pseudo_labeler = self_training.ProbabilisticPseudoLabeler(activation=torch.nn.Sigmoid(),
confidence_threshold=thresh, prior_samples=16,
consensus_masking=args.consensus_masking)
loss = self_training.ProbabilisticUNetLoss()
loss_and_metric = self_training.ProbabilisticUNetLossAndMetric()

# data loaders
supervised_train_loader = common.get_supervised_loader(args, "train", source_cell_type, args.batch_size)
supervised_val_loader = common.get_supervised_loader(args, "val", source_cell_type, 1)
unsupervised_train_loader = common.get_unsupervised_loader(
args, args.batch_size, "train", target_cell_type,
teacher_augmentation="weak", student_augmentation="strong-joint",
)
unsupervised_val_loader = common.get_unsupervised_loader(
args, 1, "val", target_cell_type,
teacher_augmentation="weak", student_augmentation="strong-joint",
)

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

if args.consensus_masking:
name = f"punet_adamatch/thresh-{thresh}-masking"
else:
name = f"punet_adamatch/thresh-{thresh}"

if args.distribution_alignment:
assert args.output is not None
print(f"Getting scores for Source {source_cell_type} at Targets {target_cell_type}")
pred_folder = args.output + f"punet_source/{source_cell_type}/{target_cell_type}/"
src_dist = common.compute_class_distribution(pred_folder)
name = f"{name}-distro-align"
else:
src_dist = None

name = name + f"/{source_cell_type}/{target_cell_type}"

trainer = self_training.FixMatchTrainer(
name=name,
model=model,
optimizer=optimizer,
lr_scheduler=scheduler,
pseudo_labeler=pseudo_labeler,
unsupervised_loss=loss,
unsupervised_loss_and_metric=loss_and_metric,
supervised_train_loader=supervised_train_loader,
unsupervised_train_loader=unsupervised_train_loader,
supervised_val_loader=supervised_val_loader,
unsupervised_val_loader=unsupervised_val_loader,
supervised_loss=loss,
supervised_loss_and_metric=loss_and_metric,
logger=None,
mixed_precision=True,
device=device,
log_image_interval=100,
save_root=args.save_root,
source_distribution=src_dist
)
trainer.fit(args.n_iterations)


def _train_source(args, cell_type):
if args.target_ct is None:
target_cell_list = common.CELL_TYPES
else:
target_cell_list = args.target_ct

for target_cell_type in target_cell_list:
print("Training on target cell type:", target_cell_type)
if target_cell_type == cell_type:
continue
_train_source_target(args, cell_type, target_cell_type)


def run_training(args):
for cell_type in args.cell_types:
print("Start training for source cell type:", cell_type)
_train_source(args, cell_type)


def run_evaluation(args):
results = []
for ct in args.cell_types:
res = common.evaluate_transfered_model(args, ct, "punet_adamatch")
results.append(res)
results = pd.concat(results)
print("Evaluation results:")
print(results)
result_folder = "./results"
os.makedirs(result_folder, exist_ok=True)
results.to_csv(os.path.join(result_folder, "punet_adamatch.csv"), index=False)


def main():
parser = common.get_parser(default_iterations=10000, default_batch_size=4)
parser.add_argument("--confidence_threshold", default=None, type=float)
parser.add_argument("--consensus_masking", action='store_true')
parser.add_argument("--distribution_alignment", action='store_true', help="Activates Distribution Alignment")
args = parser.parse_args()
if args.phase in ("c", "check"):
check_loader(args)
elif args.phase in ("t", "train"):
run_training(args)
elif args.phase in ("e", "evaluate"):
run_evaluation(args)
else:
raise ValueError(f"Got phase={args.phase}, expect one of check, train, evaluate.")


if __name__ == "__main__":
main()
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,9 @@ def _train_source_target(args, source_cell_type, target_cell_type):
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

if args.consensus_masking:
name = f"punet_adamt/thresh-{thresh}-masking"
name = f"punet_adamt/thresh-{thresh}-masking/{source_cell_type}/{target_cell_type}"
else:
name = f"punet_adamt/thresh-{thresh}"
name = f"punet_adamt/thresh-{thresh}/{source_cell_type}/{target_cell_type}"

trainer = self_training.MeanTeacherTrainer(
name=name,
Expand Down
151 changes: 151 additions & 0 deletions experiments/probabilistic_domain_adaptation/livecell/punet_fixmatch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
import os

import pandas as pd
import torch
import torch_em.self_training as self_training
from torch_em.util import load_model

import common


def check_loader(args, n_images=5):
from torch_em.util.debug import check_loader

cell_types = args.cell_types
print("The cell types", cell_types, "were selected.")
print("Checking the unsupervised loader for the first cell type", cell_types[0])

loader = common.get_unsupervised_loader(
args, "train", cell_types[0],
teacher_augmentation="weak", student_augmentation="strong",
)
check_loader(loader, n_images)


def _train_source_target(args, source_cell_type, target_cell_type):

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

model = common.get_punet()
if args.save_root is None:
src_checkpoint = f"./checkpoints/punet_source/{source_cell_type}"
else:
src_checkpoint = args.save_root + f"checkpoints/punet_source/{source_cell_type}"
model = load_model(checkpoint=src_checkpoint, model=model, device=device)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.9, patience=10)

# self training functionality
# - when thresh is None, we don't mask the reconstruction loss (RL) with label filters
# - when thresh is passed (float), we mask the RL with weighted consensus label filters
# - when thresh is passed (float) with args.consensus_masking, we mask the RL with masked consensus label filters
thresh = args.confidence_threshold
if thresh is None:
assert args.consensus_masking is False, "Provide a confidence threshold to use consensus masking"

pseudo_labeler = self_training.ProbabilisticPseudoLabeler(activation=torch.nn.Sigmoid(),
confidence_threshold=thresh, prior_samples=16,
consensus_masking=args.consensus_masking)
loss = self_training.ProbabilisticUNetLoss()
loss_and_metric = self_training.ProbabilisticUNetLossAndMetric()

# data loaders
unsupervised_train_loader = common.get_unsupervised_loader(
args, args.batch_size, "train", target_cell_type,
teacher_augmentation="weak", student_augmentation="strong-separate",
)
unsupervised_val_loader = common.get_unsupervised_loader(
args, 1, "val", target_cell_type,
teacher_augmentation="weak", student_augmentation="strong-separate",
)

if args.consensus_masking:
name = f"punet_fixmatch/thresh-{thresh}-masking"
else:
name = f"punet_fixmatch/thresh-{thresh}"

if args.distribution_alignment:
assert args.output is not None
print(f"Getting scores for Source {source_cell_type} at Targets {target_cell_type}")
pred_folder = args.output + f"punet_source/{source_cell_type}/{target_cell_type}/"
src_dist = common.compute_class_distribution(pred_folder)
name = f"{name}-distro-align"
else:
src_dist = None

name = name + f"/{source_cell_type}/{target_cell_type}"

trainer = self_training.FixMatchTrainer(
name=name,
model=model,
optimizer=optimizer,
lr_scheduler=scheduler,
pseudo_labeler=pseudo_labeler,
unsupervised_loss=loss,
unsupervised_loss_and_metric=loss_and_metric,
unsupervised_train_loader=unsupervised_train_loader,
unsupervised_val_loader=unsupervised_val_loader,
supervised_loss=loss,
supervised_loss_and_metric=loss_and_metric,
logger=None,
mixed_precision=True,
device=device,
log_image_interval=100,
save_root=args.save_root,
source_distribution=src_dist,
compile_model=False
)
trainer.fit(args.n_iterations)


def _train_source(args, cell_type):
if args.target_ct is None:
target_cell_list = common.CELL_TYPES
else:
target_cell_list = args.target_ct

for target_cell_type in target_cell_list:
print("Training on target cell type:", target_cell_type)
if target_cell_type == cell_type:
continue
_train_source_target(args, cell_type, target_cell_type)


def run_training(args):
for cell_type in args.cell_types:
print("Start training for source cell type:", cell_type)
_train_source(args, cell_type)


def run_evaluation(args):
results = []
for ct in args.cell_types:
res = common.evaluate_transfered_model(args, ct, "punet_fixmatch")
results.append(res)
results = pd.concat(results)
print("Evaluation results:")
print(results)
result_folder = "./results"
os.makedirs(result_folder, exist_ok=True)
results.to_csv(os.path.join(result_folder, "punet_fixmatch.csv"), index=False)


def main():
parser = common.get_parser(default_iterations=10000, default_batch_size=4)
parser.add_argument("--confidence_threshold", default=None, type=float)
parser.add_argument("--consensus_masking", action='store_true')
parser.add_argument("--distribution_alignment", action='store_true', help="Activates Distribution Alignment")
args = parser.parse_args()
if args.phase in ("c", "check"):
check_loader(args)
elif args.phase in ("t", "train"):
run_training(args)
elif args.phase in ("e", "evaluate"):
run_evaluation(args)
else:
raise ValueError(f"Got phase={args.phase}, expect one of check, train, evaluate.")


if __name__ == "__main__":
main()
Loading