Add Probabilistic Self-Training (MT) Approaches #125

Merged: 3 commits, May 3, 2023

Changes from all commits
55 changes: 35 additions & 20 deletions experiments/probabilistic_domain_adaptation/livecell/common.py
@@ -112,6 +112,18 @@ def compute_class_distribution(root_folder, label_threshold=0.5):
return [bg_frequency, fg_frequency]


def get_punet_predictions(model, inputs):
    activation = torch.nn.Sigmoid()
    prior_samples = 16

    # encode the inputs once, then draw multiple samples from the PUNet prior
    # and average their sigmoid activations into a single prediction
    with torch.no_grad():
        model.forward(inputs)
        samples_per_input = [activation(model.sample(testing=True)) for _ in range(prior_samples)]
        avg_pred = torch.stack(samples_per_input, dim=0).sum(dim=0) / prior_samples

    return avg_pred
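
For orientation, a minimal usage sketch of the sampling average above (the input shape is illustrative; common.get_punet is the factory used by the scripts below, and the PUNet is assumed to accept a plain image tensor as in this PR's code):

import torch

import common

punet = common.get_punet()
punet.eval()

batch = torch.rand(1, 1, 256, 256)  # one single-channel crop; the shape is only illustrative
pred = common.get_punet_predictions(punet, batch)  # mean of 16 sigmoid samples
assert 0.0 <= pred.min() and pred.max() <= 1.0  # sigmoid keeps values in [0, 1]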


# use get_model and prediction_function to customize this, e.g. for using it with the PUNet
# set model_state to "teacher_state" when using this with a mean-teacher method
def evaluate_transfered_model(
@@ -121,9 +133,12 @@ def evaluate_transfered_model(
label_root = os.path.join(args.input, "annotations", "livecell_test_images")

results = {"src": [ct_src]}
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    thresh = args.confidence_threshold
    if thresh is None:
        assert args.consensus_masking is False, "Provide a confidence threshold to use consensus masking"

with torch.no_grad():
for ct_trg in CELL_TYPES:

@@ -134,21 +149,32 @@
            if args.output is None:
                out_folder = None
            else:
                out_folder = args.output + f"thresh-{thresh}"

                if args.consensus_masking:
                    out_folder = out_folder + "-masking"

                if args.distribution_alignment:
                    out_folder = out_folder + "-distro-align"

                out_folder = os.path.join(out_folder, ct_src, ct_trg)

            if out_folder is not None:
                os.makedirs(out_folder, exist_ok=True)

            if args.save_root is None:
                ckpt = f"checkpoints/{method}/thresh-{thresh}"
            else:
                ckpt = args.save_root + f"checkpoints/{method}/thresh-{thresh}"

            if args.consensus_masking:
                ckpt = ckpt + "-masking"

            if args.distribution_alignment:
                ckpt = ckpt + "-distro-align"

            ckpt = os.path.join(ckpt, ct_src, ct_trg)

            model = get_model()
            model = load_model(checkpoint=ckpt, model=model, state_key=model_state, device=device)

@@ -184,21 +210,9 @@ def evaluate_transfered_model(
return pd.DataFrame(results)


# use get_model and prediction_function to customize this, e.g. for using it with the PUNet
def evaluate_source_model(args, ct_src, method, get_model=get_unet, prediction_function=None):
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

if args.save_root is None:
ckpt = f"checkpoints/{method}/{ct_src}"
@@ -310,4 +324,5 @@ def get_parser(default_batch_size=8, default_iterations=int(1e5)):
parser.add_argument("-c", "--cell_types", nargs="+", default=CELL_TYPES)
parser.add_argument("--target_ct", nargs="+", default=None)
parser.add_argument("-o", "--output")
parser.add_argument("--distribution_alignment", action='store_true')
return parser
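
The flags above combine into one folder naming scheme, thresh-<t>[-masking][-distro-align]/<src>/<trg>, used for both checkpoints and prediction outputs. A small illustrative helper (not part of the PR) that mirrors that logic:

import os

def folder_name(thresh, consensus_masking, distribution_alignment, ct_src, ct_trg):
    name = f"thresh-{thresh}"
    if consensus_masking:
        name += "-masking"
    if distribution_alignment:
        name += "-distro-align"
    return os.path.join(name, ct_src, ct_trg)

print(folder_name(0.9, True, False, "A172", "BT474"))  # thresh-0.9-masking/A172/BT474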
134 changes: 134 additions & 0 deletions experiments/probabilistic_domain_adaptation/livecell/punet_adamt.py
@@ -0,0 +1,134 @@
import os

import pandas as pd
import torch
import torch_em.self_training as self_training

import common


def check_loader(args, n_images=5):
from torch_em.util.debug import check_loader

cell_types = args.cell_types
print("The cell types", cell_types, "were selected.")
print("Checking the unsupervised loader for the first cell type", cell_types[0])

loader = common.get_unsupervised_loader(
args, "train", cell_types[0],
teacher_augmentation="weak", student_augmentation="weak",
)
check_loader(loader, n_images)


def _train_source_target(args, source_cell_type, target_cell_type):
model = common.get_punet()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.9, patience=10)

# self training functionality
# - when thresh is None, we don't mask the reconstruction loss (RL) with label filters
# - when thresh is passed (float), we mask the RL with weighted consensus label filters
# - when thresh is passed (float) with args.consensus_masking, we mask the RL with masked consensus label filters
thresh = args.confidence_threshold
if thresh is None:
assert args.consensus_masking is False, "Provide a confidence threshold to use consensus masking"

pseudo_labeler = self_training.ProbabilisticPseudoLabeler(activation=torch.nn.Sigmoid(),
confidence_threshold=thresh, prior_samples=16,
consensus_masking=args.consensus_masking)
loss = self_training.ProbabilisticUNetLoss()
loss_and_metric = self_training.ProbabilisticUNetLossAndMetric()

# data loaders
supervised_train_loader = common.get_supervised_loader(args, "train", source_cell_type, args.batch_size)
supervised_val_loader = common.get_supervised_loader(args, "val", source_cell_type, 1)
unsupervised_train_loader = common.get_unsupervised_loader(
args, args.batch_size, "train", target_cell_type,
teacher_augmentation="weak", student_augmentation="weak",
)
unsupervised_val_loader = common.get_unsupervised_loader(
args, 1, "val", target_cell_type,
teacher_augmentation="weak", student_augmentation="weak",
)

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

if args.consensus_masking:
name = f"punet_adamt/thresh-{thresh}-masking"
else:
name = f"punet_adamt/thresh-{thresh}"

trainer = self_training.MeanTeacherTrainer(
name=name,
model=model,
optimizer=optimizer,
lr_scheduler=scheduler,
pseudo_labeler=pseudo_labeler,
unsupervised_loss=loss,
unsupervised_loss_and_metric=loss_and_metric,
supervised_train_loader=supervised_train_loader,
unsupervised_train_loader=unsupervised_train_loader,
supervised_val_loader=supervised_val_loader,
unsupervised_val_loader=unsupervised_val_loader,
supervised_loss=loss,
supervised_loss_and_metric=loss_and_metric,
logger=None,
mixed_precision=True,
device=device,
log_image_interval=100,
save_root=args.save_root,
compile_model=False
)
trainer.fit(args.n_iterations)


def _train_source(args, cell_type):
if args.target_ct is None:
target_cell_list = common.CELL_TYPES
else:
target_cell_list = args.target_ct

    for target_cell_type in target_cell_list:
        if target_cell_type == cell_type:
            continue
        print("Training on target cell type:", target_cell_type)
        _train_source_target(args, cell_type, target_cell_type)


def run_training(args):
for cell_type in args.cell_types:
print("Start training for source cell type:", cell_type)
_train_source(args, cell_type)


def run_evaluation(args):
results = []
for ct in args.cell_types:
res = common.evaluate_transfered_model(args, ct, "punet_adamt", model_state="teacher_state")
results.append(res)
results = pd.concat(results)
print("Evaluation results:")
print(results)
result_folder = "./results"
os.makedirs(result_folder, exist_ok=True)
results.to_csv(os.path.join(result_folder, "punet_adamt.csv"), index=False)


def main():
parser = common.get_parser(default_iterations=100000, default_batch_size=4)
parser.add_argument("--confidence_threshold", default=None, type=float)
parser.add_argument("--consensus_masking", action='store_true')
args = parser.parse_args()
if args.phase in ("c", "check"):
check_loader(args)
elif args.phase in ("t", "train"):
run_training(args)
elif args.phase in ("e", "evaluate"):
run_evaluation(args)
else:
raise ValueError(f"Got phase={args.phase}, expect one of check, train, evaluate.")


if __name__ == "__main__":
main()
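
The threshold semantics documented in the comment block above (weighted vs. masked consensus) can be illustrated standalone. This sketch is not torch_em's ProbabilisticPseudoLabeler, only the idea behind it: draw several prior samples, then turn their agreement into either a soft weight map or a hard mask:

import torch

def consensus_sketch(samples, confidence_threshold, consensus_masking):
    # samples: (n_samples, B, C, H, W) sigmoid outputs from PUNet prior draws
    pseudo_labels = (samples.mean(dim=0) >= 0.5).float()
    # fraction of samples voting foreground at each pixel
    votes = (samples >= confidence_threshold).float().mean(dim=0)
    if consensus_masking:
        # hard mask: keep only pixels on which all samples agree
        mask = torch.logical_or(votes == 1.0, votes == 0.0).float()
    else:
        # soft weights: scale the loss by the degree of agreement
        mask = torch.maximum(votes, 1.0 - votes)
    return pseudo_labels, mask

samples = torch.rand(16, 2, 1, 8, 8)
labels, mask = consensus_sketch(samples, confidence_threshold=0.9, consensus_masking=True)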
141 changes: 141 additions & 0 deletions experiments/probabilistic_domain_adaptation/livecell/punet_mean_teacher.py
@@ -0,0 +1,141 @@
import os

import pandas as pd
import torch
import torch_em.self_training as self_training
from torch_em.util import load_model

import common


def check_loader(args, n_images=5):
from torch_em.util.debug import check_loader

cell_types = args.cell_types
print("The cell types", cell_types, "were selected.")
print("Checking the unsupervised loader for the first cell type", cell_types[0])

loader = common.get_unsupervised_loader(
args, "train", cell_types[0],
teacher_augmentation="weak", student_augmentation="weak",
)
check_loader(loader, n_images)


def _train_source_target(args, source_cell_type, target_cell_type):

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

model = common.get_punet()
if args.save_root is None:
src_checkpoint = f"./checkpoints/punet_source/{source_cell_type}"
else:
src_checkpoint = args.save_root + f"checkpoints/punet_source/{source_cell_type}"
model = load_model(checkpoint=src_checkpoint, model=model, device=device)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.9, patience=10)

# self training functionality
# - when thresh is None, we don't mask the reconstruction loss (RL) with label filters
# - when thresh is passed (float), we mask the RL with weighted consensus label filters
# - when thresh is passed (float) with args.consensus_masking, we mask the RL with masked consensus label filters
thresh = args.confidence_threshold
if thresh is None:
assert args.consensus_masking is False, "Provide a confidence threshold to use consensus masking"

pseudo_labeler = self_training.ProbabilisticPseudoLabeler(activation=torch.nn.Sigmoid(),
confidence_threshold=thresh, prior_samples=16,
consensus_masking=args.consensus_masking)
loss = self_training.ProbabilisticUNetLoss()
loss_and_metric = self_training.ProbabilisticUNetLossAndMetric()

# data loaders
unsupervised_train_loader = common.get_unsupervised_loader(
args, args.batch_size, "train", target_cell_type,
teacher_augmentation="weak", student_augmentation="weak",
)
unsupervised_val_loader = common.get_unsupervised_loader(
args, 1, "val", target_cell_type,
teacher_augmentation="weak", student_augmentation="weak",
)

if args.consensus_masking:
name = f"punet_mean_teacher/thresh-{thresh}-masking/{source_cell_type}/{target_cell_type}"
else:
name = f"punet_mean_teacher/thresh-{thresh}/{source_cell_type}/{target_cell_type}"

trainer = self_training.MeanTeacherTrainer(
name=name,
model=model,
optimizer=optimizer,
lr_scheduler=scheduler,
pseudo_labeler=pseudo_labeler,
unsupervised_loss=loss,
unsupervised_loss_and_metric=loss_and_metric,
unsupervised_train_loader=unsupervised_train_loader,
unsupervised_val_loader=unsupervised_val_loader,
supervised_loss=loss,
supervised_loss_and_metric=loss_and_metric,
logger=None,
mixed_precision=True,
device=device,
log_image_interval=100,
save_root=args.save_root,
compile_model=False
)
trainer.fit(args.n_iterations)


def _train_source(args, cell_type):
if args.target_ct is None:
target_cell_list = common.CELL_TYPES
else:
target_cell_list = args.target_ct

    for target_cell_type in target_cell_list:
        if target_cell_type == cell_type:
            continue
        print("Training on target cell type:", target_cell_type)
        _train_source_target(args, cell_type, target_cell_type)


def run_training(args):
for cell_type in args.cell_types:
print("Start training for source cell type:", cell_type)
_train_source(args, cell_type)


def run_evaluation(args):
results = []
for ct in args.cell_types:
res = common.evaluate_transfered_model(args, ct, "punet_mean_teacher",
get_model=common.get_punet,
model_state="teacher_state",
prediction_function=common.get_punet_predictions)
results.append(res)
results = pd.concat(results)
print("Evaluation results:")
print(results)
result_folder = "./results"
os.makedirs(result_folder, exist_ok=True)
results.to_csv(os.path.join(result_folder, "punet_mean_teacher.csv"), index=False)


def main():
parser = common.get_parser(default_iterations=10000, default_batch_size=4)
parser.add_argument("--confidence_threshold", default=None, type=float)
parser.add_argument("--consensus_masking", action='store_true')
args = parser.parse_args()
if args.phase in ("c", "check"):
check_loader(args)
elif args.phase in ("t", "train"):
run_training(args)
elif args.phase in ("e", "evaluate"):
run_evaluation(args)
else:
raise ValueError(f"Got phase={args.phase}, expect one of check, train, evaluate.")


if __name__ == "__main__":
main()
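
For readers new to the method: the teacher in MeanTeacherTrainer tracks an exponential moving average (EMA) of the student weights, which is why evaluation loads the "teacher_state". A conceptual sketch of such an EMA update (assumed behavior, not torch_em's internals):

import torch

@torch.no_grad()
def ema_update(teacher, student, momentum=0.999):
    # teacher <- momentum * teacher + (1 - momentum) * student
    for p_t, p_s in zip(teacher.parameters(), student.parameters()):
        p_t.mul_(momentum).add_(p_s, alpha=1.0 - momentum)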
@@ -118,5 +118,4 @@ def main():


if __name__ == "__main__":
main()
experiments/probabilistic_domain_adaptation/livecell/unet_fixmatch.py
@@ -52,8 +52,6 @@ def _train_source_target(args, source_cell_type, target_cell_type):
teacher_augmentation="weak", student_augmentation="strong-separate",
)

name = f"unet_fixmatch/thresh-{thresh}"

if args.distribution_alignment:
2 changes: 1 addition & 1 deletion torch_em/self_training/__init__.py
@@ -3,5 +3,5 @@
ProbabilisticUNetLossAndMetric
from .mean_teacher import MeanTeacherTrainer
from .fix_match import FixMatchTrainer
from .pseudo_labeling import DefaultPseudoLabeler, ProbabilisticPseudoLabeler
from .probabilistic_unet_trainer import ProbabilisticUNetTrainer, DummyLoss
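
With this export, the probabilistic pseudo-labeler is importable alongside the default one. A quick smoke test (constructor arguments as used in the training scripts above):

import torch
import torch_em.self_training as self_training

labeler = self_training.ProbabilisticPseudoLabeler(
    activation=torch.nn.Sigmoid(), confidence_threshold=0.9,
    prior_samples=16, consensus_masking=True,
)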