
Commit

Merge pull request #123 from anwai98/main
Add 2D Probabilistic UNet Training Setup
constantinpape authored Apr 29, 2023
2 parents d53d3dc + 5df0dc6 commit 1f8138a
Showing 6 changed files with 310 additions and 8 deletions.
24 changes: 21 additions & 3 deletions experiments/probabilistic_domain_adaptation/livecell/common.py
@@ -10,6 +10,7 @@
import pandas as pd
import torch
import torch_em
from torch_em.model import ProbabilisticUNet

from elf.evaluation import dice_score
from torch_em.data.datasets.livecell import (get_livecell_loader,
@@ -23,7 +24,6 @@

CELL_TYPES = ["A172", "BT474", "BV2", "Huh7", "MCF7", "SHSY5Y", "SkBr3", "SKOV3"]


#
# The augmentations we use for the LiveCELL experiments:
# - weak augmentations:
@@ -83,6 +83,11 @@ def get_unet():
    return UNet2d(in_channels=1, out_channels=1, initial_features=64, final_activation="Sigmoid", depth=4)


def get_punet():
    return ProbabilisticUNet(input_channels=1, num_classes=1, num_filters=[64, 128, 256, 512],
                             latent_dim=6, no_convs_fcomb=3, beta=1.0, rl_swap=True)


# Computing the Source Distribution for Distribution Alignment
def compute_class_distribution(root_folder, label_threshold=0.5):

@@ -179,20 +184,33 @@ def evaluate_transfered_model(
    return pd.DataFrame(results)


def get_punet_predictions(model, inputs):
    activation = torch.nn.Sigmoid()
    prior_samples = 16

    with torch.no_grad():
        model.forward(inputs)
        samples_per_input = [activation(model.sample(testing=True)) for _ in range(prior_samples)]
        avg_pred = torch.stack(samples_per_input, dim=0).sum(dim=0) / prior_samples

    return avg_pred


# use get_model and prediction_function to customize this, e.g. for using it with the PUNet
def evaluate_source_model(args, ct_src, method, get_model=get_unet, prediction_function=None):
    device = torch.device("cuda")

    if args.save_root is None:
        ckpt = f"checkpoints/{method}/{ct_src}"
    else:
        ckpt = args.save_root + f"checkpoints/{method}/{ct_src}"
    model = get_model()
    model = torch_em.util.get_trainer(ckpt).model
    model = load_model(checkpoint=ckpt, model=model, device=device)

    image_folder = os.path.join(args.input, "images", "livecell_test_images")
    label_root = os.path.join(args.input, "annotations", "livecell_test_images")

    results = {"src": [ct_src]}
    device = torch.device("cuda")

    with torch.no_grad():
        for ct_trg in CELL_TYPES:
91 changes: 91 additions & 0 deletions experiments/probabilistic_domain_adaptation/livecell/punet_source.py
@@ -0,0 +1,91 @@
import os
import pandas as pd

import torch

from torch_em.self_training import ProbabilisticUNetTrainer, \
    ProbabilisticUNetLoss, ProbabilisticUNetLossAndMetric

import common


def _train_cell_type(args, cell_type):
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    train_loader = common.get_supervised_loader(args, "train", cell_type, args.batch_size)
    val_loader = common.get_supervised_loader(args, "val", cell_type, 1)
    name = f"punet_source/{cell_type}"

    model = common.get_punet()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.9, patience=10, verbose=True)
    model.to(device)

    supervised_loss = ProbabilisticUNetLoss()
    supervised_loss_and_metric = ProbabilisticUNetLossAndMetric()

    trainer = ProbabilisticUNetTrainer(
        name=name,
        save_root=args.save_root,
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        logger=None,
        device=device,
        lr_scheduler=scheduler,
        optimizer=optimizer,
        mixed_precision=True,
        log_image_interval=100,
        loss=supervised_loss,
        loss_and_metric=supervised_loss_and_metric
    )

    trainer.fit(iterations=args.n_iterations)


def run_training(args):
    for cell_type in args.cell_types:
        print("Start training for cell type:", cell_type)
        _train_cell_type(args, cell_type)


def check_loader(args, n_images=5):
    from torch_em.util.debug import check_loader

    cell_types = args.cell_types
    print("The cell types", cell_types, "were selected.")
    print("Checking the loader for the first cell type", cell_types[0])

    loader = common.get_supervised_loader(args)
    check_loader(loader, n_images)


def run_evaluation(args):
    results = []
    for ct in args.cell_types:
        res = common.evaluate_source_model(args, ct, "punet_source", get_model=common.get_punet,
                                           prediction_function=common.get_punet_predictions)
        results.append(res)
    results = pd.concat(results)
    print("Evaluation results:")
    print(results)
    result_folder = "./results"
    os.makedirs(result_folder, exist_ok=True)
    results.to_csv(os.path.join(result_folder, "punet_source.csv"), index=False)


def main():
    parser = common.get_parser(default_iterations=100000)
    args = parser.parse_args()
    if args.phase in ("c", "check"):
        check_loader(args)
    elif args.phase in ("t", "train"):
        run_training(args)
    elif args.phase in ("e", "evaluate"):
        run_evaluation(args)
    else:
        raise ValueError(f"Got phase={args.phase}, expect one of check, train, evaluate.")


if __name__ == "__main__":
    main()
4 changes: 3 additions & 1 deletion torch_em/self_training/__init__.py
@@ -1,5 +1,7 @@
from .logger import SelfTrainingTensorboardLogger
from .loss import DefaultSelfTrainingLoss, DefaultSelfTrainingLossAndMetric
from .loss import DefaultSelfTrainingLoss, DefaultSelfTrainingLossAndMetric, ProbabilisticUNetLoss, \
    ProbabilisticUNetLossAndMetric
from .mean_teacher import MeanTeacherTrainer
from .fix_match import FixMatchTrainer
from .pseudo_labeling import DefaultPseudoLabeler
from .probabilistic_unet_trainer import ProbabilisticUNetTrainer, DummyLoss
79 changes: 78 additions & 1 deletion torch_em/self_training/loss.py
@@ -1,5 +1,7 @@
import torch.nn as nn
import torch
import torch_em
import torch.nn as nn
from torch_em.loss import DiceLoss


class DefaultSelfTrainingLoss(nn.Module):
@@ -55,3 +57,78 @@ def __call__(self, model, input_, labels, label_filter=None):
        loss = self.loss(prediction * label_filter, labels * label_filter)
        metric = self.metric(prediction, labels)
        return loss, metric


def l2_regularisation(m):
    l2_reg = None

    for W in m.parameters():
        if l2_reg is None:
            l2_reg = W.norm(2)
        else:
            l2_reg = l2_reg + W.norm(2)
    return l2_reg


class ProbabilisticUNetLoss(nn.Module):
    """
    Loss function for Probabilistic UNet

    Parameters:
        # TODO: Implement a generic utility function for all Probabilistic UNet schemes (ELBO, GECO, etc.)
        loss [nn.Module] - the loss function to be used. (default: None)
    """
    def __init__(self, loss=None):
        super().__init__()
        self.loss = loss

    def __call__(self, model, input_, labels):
        model.forward(input_, labels)

        if self.loss is None:
            elbo = model.elbo(labels)
            reg_loss = l2_regularisation(model.posterior) + l2_regularisation(model.prior) + \
                l2_regularisation(model.fcomb.layers)
            loss = -elbo + 1e-5 * reg_loss

        return loss


class ProbabilisticUNetLossAndMetric(nn.Module):
    """Loss and metric function for Probabilistic UNet.

    Parameters:
        # TODO: Implement a generic utility function for all Probabilistic UNet schemes (ELBO, GECO, etc.)
        loss [nn.Module] - the loss function to be used. (default: None)
        metric [nn.Module] - the metric function to be used. (default: torch_em.loss.DiceLoss)
        activation [nn.Module, callable] - the activation function to be applied to the prediction
            before evaluating the average predictions. (default: torch.nn.Sigmoid)
    """
    def __init__(self, loss=None, metric=DiceLoss(), activation=torch.nn.Sigmoid(), prior_samples=16):
        super().__init__()
        self.activation = activation
        self.metric = metric
        self.loss = loss
        self.prior_samples = prior_samples

    def __call__(self, model, input_, labels):
        model.forward(input_, labels)

        if self.loss is None:
            elbo = model.elbo(labels)
            reg_loss = l2_regularisation(model.posterior) + l2_regularisation(model.prior) + \
                l2_regularisation(model.fcomb.layers)
            loss = -elbo + 1e-5 * reg_loss

        samples_per_distribution = []
        for _ in range(self.prior_samples):
            samples = model.sample(testing=False)
            if self.activation is not None:
                samples = self.activation(samples)
            samples_per_distribution.append(samples)

        avg_samples = torch.stack(samples_per_distribution, dim=0).sum(dim=0) / len(samples_per_distribution)
        metric = self.metric(avg_samples, labels)

        return loss, metric
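
For reference, a minimal usage sketch of the new loss-and-metric class, assuming the ProbabilisticUNet configuration from get_punet above; the batch shapes and number of prior samples are illustrative:

import torch
from torch_em.model import ProbabilisticUNet
from torch_em.self_training import ProbabilisticUNetLossAndMetric

# Hypothetical toy batch: 1 input channel and 1 output class, as in the LiveCELL setup above.
x = torch.rand(2, 1, 256, 256)
y = (torch.rand(2, 1, 256, 256) > 0.5).float()

# Same configuration as get_punet in common.py.
model = ProbabilisticUNet(input_channels=1, num_classes=1, num_filters=[64, 128, 256, 512],
                          latent_dim=6, no_convs_fcomb=3, beta=1.0, rl_swap=True)

loss_fn = ProbabilisticUNetLossAndMetric(prior_samples=4)
loss, metric = loss_fn(model, x, y)  # ELBO-based loss and Dice metric on averaged prior samples
loss.backward()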
6 changes: 3 additions & 3 deletions torch_em/self_training/mean_teacher.py
@@ -13,9 +13,9 @@ class Dummy(torch.nn.Module):


class MeanTeacherTrainer(torch_em.trainer.DefaultTrainer):
"""This trainer implements self-traning for semi-supervised learning and domain following the 'MeanTeacher' approach
of Tarvainen & Vapola (https://arxiv.org/abs/1703.01780). This approach uses a teacher model derived from the
student model via EMA of weights to predict pseudo-labels on unlabeled data.
"""This trainer implements self-training for semi-supervised learning and domain following the 'MeanTeacher'
approach of Tarvainen & Vapola (https://arxiv.org/abs/1703.01780). This approach uses a teacher model derived from
the student model via EMA of weights to predict pseudo-labels on unlabeled data.
We support two training strategies: joint training on labeled and unlabeled data
(with a supervised and unsupervised loss function). And training only on the unsupervised data.
114 changes: 114 additions & 0 deletions torch_em/self_training/probabilistic_unet_trainer.py
@@ -0,0 +1,114 @@
import time
import torch
import torch_em


class DummyLoss(torch.nn.Module):
    pass


class ProbabilisticUNetTrainer(torch_em.trainer.DefaultTrainer):
    """This trainer implements training for the 'Probabilistic UNet' of Kohl et al. (https://arxiv.org/abs/1806.05034).
    This approach combines the learnings from UNet and VAEs (Prior and Posterior networks) to obtain generative
    segmentations. Training takes into account the feature maps from the UNet and the samples from the posterior
    distribution to estimate the loss, and samples from the prior for validation.

    Parameters:
        clipping_value [float] - (default: None)
        prior_samples [int] - (default: 16)
        loss [callable] - (default: None)
        loss_and_metric [callable] - (default: None)
    """

    def __init__(
        self,
        clipping_value=None,
        prior_samples=16,
        loss=None,
        loss_and_metric=None,
        **kwargs
    ):
        super().__init__(loss=loss, metric=DummyLoss(), **kwargs)
        assert loss is not None and loss_and_metric is not None

        self.loss_and_metric = loss_and_metric

        self.clipping_value = clipping_value

        self.prior_samples = prior_samples
        self.sigmoid = torch.nn.Sigmoid()

        self._kwargs = kwargs

    #
    # functionality for sampling from the network
    #

    def _sample(self):
        samples = [self.model.sample() for _ in range(self.prior_samples)]
        return samples

    #
    # training and validation functionality
    #

    def _train_epoch_impl(self, progress, forward_context, backprop):
        self.model.train()

        n_iter = 0
        t_per_iter = time.time()

        for x, y in self.train_loader:
            x, y = x.to(self.device), y.to(self.device)

            self.optimizer.zero_grad()

            with forward_context():
                # We pass the model, the input and the labels to the supervised loss function,
                # so that how the loss is calculated stays flexible, e.g. to enable the ELBO for the PUNet.
                loss = self.loss(self.model, x, y)

            backprop(loss)

            # To counter the exploding gradients in the posterior net
            if self.clipping_value is not None:
                torch.nn.utils.clip_grad_norm_(self.model.posterior.encoder.layers.parameters(), self.clipping_value)

            if self.logger is not None:
                lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
                samples = self._sample() if self._iteration % self.log_image_interval == 0 else None
                self.logger.log_train(self._iteration, loss, lr, x, y, samples)

            self._iteration += 1
            n_iter += 1
            if self._iteration >= self.max_iteration:
                break
            progress.update(1)

        t_per_iter = (time.time() - t_per_iter) / n_iter
        return t_per_iter

    def _validate_impl(self, forward_context):
        self.model.eval()

        metric_val = 0.0
        loss_val = 0.0

        with torch.no_grad():
            for x, y in self.val_loader:
                x, y = x.to(self.device), y.to(self.device)

                with forward_context():
                    loss, metric = self.loss_and_metric(self.model, x, y)

                loss_val += loss.item()
                metric_val += metric

        metric_val /= len(self.val_loader)
        loss_val /= len(self.val_loader)

        if self.logger is not None:
            samples = self._sample()
            self.logger.log_validation(self._iteration, metric_val, loss_val, x, y, samples)

        return metric_val
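
Putting the new pieces together, a sketch of how the trainer is wired up, along the lines of the training script above; train_loader and val_loader stand in for torch_em segmentation loaders, and the clipping value is an illustrative choice:

import torch
from torch_em.model import ProbabilisticUNet
from torch_em.self_training import (ProbabilisticUNetTrainer, ProbabilisticUNetLoss,
                                    ProbabilisticUNetLossAndMetric)

# train_loader / val_loader are placeholders for torch_em segmentation loaders,
# e.g. the supervised LiveCELL loaders used in the training script above.
model = ProbabilisticUNet(input_channels=1, num_classes=1, num_filters=[64, 128, 256, 512],
                          latent_dim=6, no_convs_fcomb=3, beta=1.0, rl_swap=True)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)

trainer = ProbabilisticUNetTrainer(
    name="punet_example",
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    optimizer=optimizer,
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    loss=ProbabilisticUNetLoss(),
    loss_and_metric=ProbabilisticUNetLossAndMetric(),
    clipping_value=1.0,  # optional: clip gradients of the posterior encoder to counter exploding gradients
    logger=None,
    mixed_precision=True,
)
trainer.fit(iterations=10000)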
