-
Notifications
You must be signed in to change notification settings - Fork 22
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #123 from anwai98/main
Add 2D Probabilistic UNet Training Setup
- Loading branch information
Showing
6 changed files
with
310 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
91 changes: 91 additions & 0 deletions
91
experiments/probabilistic_domain_adaptation/livecell/punet_source.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
import os | ||
import pandas as pd | ||
|
||
import torch | ||
|
||
from torch_em.self_training import ProbabilisticUNetTrainer, \ | ||
ProbabilisticUNetLoss, ProbabilisticUNetLossAndMetric | ||
|
||
import common | ||
|
||
|
||
def _train_cell_type(args, cell_type):
    """Train a source-domain Probabilistic UNet on supervised data for one cell type.

    Builds the train/val loaders, the PUNet model, optimizer and LR scheduler,
    then runs a ProbabilisticUNetTrainer for args.n_iterations iterations.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    train_loader = common.get_supervised_loader(args, "train", cell_type, args.batch_size)
    val_loader = common.get_supervised_loader(args, "val", cell_type, 1)
    name = f"punet_source/{cell_type}"

    model = common.get_punet()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
    # Decay LR by 10% whenever the validation metric plateaus for 10 epochs.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.9, patience=10, verbose=True
    )
    model.to(device)

    supervised_loss = ProbabilisticUNetLoss()
    supervised_loss_and_metric = ProbabilisticUNetLossAndMetric()

    trainer = ProbabilisticUNetTrainer(
        name=name,
        save_root=args.save_root,
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        logger=None,
        device=device,
        lr_scheduler=scheduler,
        optimizer=optimizer,
        mixed_precision=True,
        log_image_interval=100,
        loss=supervised_loss,
        loss_and_metric=supervised_loss_and_metric,
    )
    trainer.fit(iterations=args.n_iterations)
|
||
|
||
def run_training(args):
    """Train one source model per requested cell type, sequentially."""
    for cell_type in args.cell_types:
        print("Start training for cell type:", cell_type)
        _train_cell_type(args, cell_type)
|
||
|
||
def check_loader(args, n_images=5):
    """Visually inspect a few samples from the supervised loader for the first cell type."""
    from torch_em.util.debug import check_loader as _debug_check_loader

    cell_types = args.cell_types
    print("The cell types", cell_types, "were selected.")
    print("Checking the loader for the first cell type", cell_types[0])

    # NOTE(review): unlike _train_cell_type, the loader here is built without
    # split / cell-type / batch-size arguments — confirm that
    # common.get_supervised_loader has suitable defaults for this call.
    loader = common.get_supervised_loader(args)
    _debug_check_loader(loader, n_images)
|
||
|
||
def run_evaluation(args):
    """Evaluate the trained punet_source model for every cell type and write a CSV.

    Results from all cell types are concatenated into one table, printed, and
    saved to ./results/punet_source.csv.
    """
    per_cell_type = [
        common.evaluate_source_model(
            args, ct, "punet_source",
            get_model=common.get_punet,
            prediction_function=common.get_punet_predictions,
        )
        for ct in args.cell_types
    ]
    table = pd.concat(per_cell_type)

    print("Evaluation results:")
    print(table)

    result_folder = "./results"
    os.makedirs(result_folder, exist_ok=True)
    table.to_csv(os.path.join(result_folder, "punet_source.csv"), index=False)
|
||
|
||
def main():
    """CLI entry point: dispatch to check / train / evaluate based on args.phase."""
    parser = common.get_parser(default_iterations=100000)
    args = parser.parse_args()

    phase = args.phase
    if phase in ("c", "check"):
        check_loader(args)
    elif phase in ("t", "train"):
        run_training(args)
    elif phase in ("e", "evaluate"):
        run_evaluation(args)
    else:
        raise ValueError(f"Got phase={args.phase}, expect one of check, train, evaluate.")
|
||
|
||
# Script entry point.
if __name__ == "__main__":
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,7 @@ | ||
from .logger import SelfTrainingTensorboardLogger | ||
from .loss import DefaultSelfTrainingLoss, DefaultSelfTrainingLossAndMetric | ||
from .loss import DefaultSelfTrainingLoss, DefaultSelfTrainingLossAndMetric, ProbabilisticUNetLoss, \ | ||
ProbabilisticUNetLossAndMetric | ||
from .mean_teacher import MeanTeacherTrainer | ||
from .fix_match import FixMatchTrainer | ||
from .pseudo_labeling import DefaultPseudoLabeler | ||
from .probabilistic_unet_trainer import ProbabilisticUNetTrainer, DummyLoss |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
import time | ||
import torch | ||
import torch_em | ||
|
||
|
||
class DummyLoss(torch.nn.Module):
    """No-op placeholder module, passed as the `metric` required by the base trainer."""
|
||
|
||
class ProbabilisticUNetTrainer(torch_em.trainer.DefaultTrainer):
    """This trainer implements training for the 'Probabilistic UNet' of Kohl et al.: (https://arxiv.org/abs/1806.05034).
    This approach combines the learnings from UNet and VAEs (Prior and Posterior networks) to obtain generative
    segmentations. The heuristic trains by taking into account the feature maps from UNet and the samples from
    the posterior distribution, estimating the loss and further sampling from the prior for validation.

    Parameters:
        clipping_value [float] - max gradient norm for the posterior encoder; None disables clipping (default: None)
        prior_samples [int] - number of samples drawn from the prior for logging/validation (default: 16)
        loss [callable] - training loss, called as loss(model, x, y); must not be None
        loss_and_metric [callable] - validation objective, called as loss_and_metric(model, x, y)
            and expected to return a (loss, metric) pair; must not be None
    """

    def __init__(
        self,
        clipping_value=None,
        prior_samples=16,
        loss=None,
        loss_and_metric=None,
        **kwargs
    ):
        super().__init__(loss=loss, metric=DummyLoss(), **kwargs)
        # Bug fix: the original `assert loss, loss_and_metric is not None` only
        # asserted `loss` and used `loss_and_metric is not None` as the assert
        # *message*, so a missing loss_and_metric was never caught. Check both
        # explicitly against None.
        assert loss is not None, "A loss callable must be provided."
        assert loss_and_metric is not None, "A loss_and_metric callable must be provided."

        self.loss_and_metric = loss_and_metric

        self.clipping_value = clipping_value

        self.prior_samples = prior_samples
        self.sigmoid = torch.nn.Sigmoid()

        # Stored for serialization/checkpointing by the base trainer.
        self._kwargs = kwargs

    #
    # functionality for sampling from the network
    #

    def _sample(self):
        """Draw `prior_samples` segmentations from the model's prior."""
        samples = [self.model.sample() for _ in range(self.prior_samples)]
        return samples

    #
    # training and validation functionality
    #

    def _train_epoch_impl(self, progress, forward_context, backprop):
        """Run one training epoch; returns the average time per iteration."""
        self.model.train()

        n_iter = 0
        t_per_iter = time.time()

        for x, y in self.train_loader:
            x, y = x.to(self.device), y.to(self.device)

            self.optimizer.zero_grad()

            with forward_context():
                # We pass the model, the input and the labels to the supervised loss function, so
                # that's how the loss is calculated stays flexible, e.g. here to enable ELBO for PUNet.
                loss = self.loss(self.model, x, y)

            backprop(loss)

            # To counter the exploding gradients in the posterior net
            if self.clipping_value is not None:
                torch.nn.utils.clip_grad_norm_(self.model.posterior.encoder.layers.parameters(), self.clipping_value)

            if self.logger is not None:
                lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
                # Only draw (expensive) prior samples on logging iterations.
                samples = self._sample() if self._iteration % self.log_image_interval == 0 else None
                self.logger.log_train(self._iteration, loss, lr, x, y, samples)

            self._iteration += 1
            n_iter += 1
            if self._iteration >= self.max_iteration:
                break
            progress.update(1)

        t_per_iter = (time.time() - t_per_iter) / n_iter
        return t_per_iter

    def _validate_impl(self, forward_context):
        """Run validation over the full val loader; returns the mean metric."""
        self.model.eval()

        metric_val = 0.0
        loss_val = 0.0

        with torch.no_grad():
            for x, y in self.val_loader:
                x, y = x.to(self.device), y.to(self.device)

                with forward_context():
                    loss, metric = self.loss_and_metric(self.model, x, y)

                loss_val += loss.item()
                # NOTE(review): `metric` is accumulated as-is while the loss uses
                # .item() — if loss_and_metric returns a tensor metric, metric_val
                # stays a tensor. Confirm the expected return type.
                metric_val += metric

        metric_val /= len(self.val_loader)
        loss_val /= len(self.val_loader)

        if self.logger is not None:
            samples = self._sample()
            # x and y here are the last validation batch seen in the loop above.
            self.logger.log_validation(self._iteration, metric_val, loss_val, x, y, samples)

        return metric_val