Add 2D Probabilistic UNet Training Setup #123

Merged
merged 7 commits into from Apr 29, 2023
@@ -10,6 +10,7 @@
import pandas as pd
import torch
import torch_em
from torch_em.model import ProbabilisticUNet

from elf.evaluation import dice_score
from torch_em.data.datasets.livecell import (get_livecell_loader,
@@ -83,6 +84,11 @@ def get_unet():
return UNet2d(in_channels=1, out_channels=1, initial_features=64, final_activation="Sigmoid", depth=4)


def get_punet():
return ProbabilisticUNet(input_channels=1, num_classes=1, num_filters=[64, 128, 256, 512],
latent_dim=6, no_convs_fcomb=3, beta=1.0, rl_swap=True)


# Computing the Source Distribution for Distribution Alignment
def compute_class_distribution(root_folder, label_threshold=0.5):

@@ -0,0 +1,79 @@
import torch
from torch_em.self_training import ProbabilisticUNetTrainer, DummyLoss, \
ProbabilisticUNetLoss, ProbabilisticUNetLossAndMetric

import common

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _train_cell_type(args, cell_type, device=DEVICE):
train_loader = common.get_supervised_loader(args, "train", cell_type, args.batch_size)
val_loader = common.get_supervised_loader(args, "val", cell_type, 1)
name = f"punet_source/{cell_type}"

model = common.get_punet()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.9, patience=10, verbose=True)
model.to(device)

supervised_loss = ProbabilisticUNetLoss()
supervised_loss_and_metric = ProbabilisticUNetLossAndMetric()

trainer = ProbabilisticUNetTrainer(
name=name,
save_root=args.save_root,
model=model,
train_loader=train_loader,
val_loader=val_loader,
metric=DummyLoss(),
logger=None,
device=device,
lr_scheduler=scheduler,
optimizer=optimizer,
mixed_precision=True,
log_image_interval=100,
loss=supervised_loss,
loss_and_metric=supervised_loss_and_metric
)

trainer.fit(iterations=args.n_iterations)


def run_training(args):
for cell_type in args.cell_types:
print("Start training for cell type:", cell_type)
_train_cell_type(args, cell_type)


def check_loader(args, n_images=5):
from torch_em.util.debug import check_loader

cell_types = args.cell_types
print("The cell types", cell_types, "were selected.")
print("Checking the loader for the first cell type", cell_types[0])

loader = common.get_supervised_loader(args, "train", cell_types[0], args.batch_size)
check_loader(loader, n_images)


def run_evaluation(args):
# TODO
pass


def main():
parser = common.get_parser(default_iterations=100000)
args = parser.parse_args()
if args.phase in ("c", "check"):
check_loader(args)
elif args.phase in ("t", "train"):
run_training(args)
elif args.phase in ("e", "evaluate"):
run_evaluation(args)
else:
raise ValueError(f"Got phase={args.phase}, expect one of check, train, evaluate.")


if __name__ == "__main__":
main()
4 changes: 3 additions & 1 deletion torch_em/self_training/__init__.py
@@ -1,5 +1,7 @@
from .logger import SelfTrainingTensorboardLogger
from .loss import DefaultSelfTrainingLoss, DefaultSelfTrainingLossAndMetric
from .loss import DefaultSelfTrainingLoss, DefaultSelfTrainingLossAndMetric, ProbabilisticUNetLoss, \
ProbabilisticUNetLossAndMetric
from .mean_teacher import MeanTeacherTrainer
from .fix_match import FixMatchTrainer
from .pseudo_labeling import DefaultPseudoLabeler
from .probabilistic_unet_trainer import ProbabilisticUNetTrainer, DummyLoss
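
With these additions the new Probabilistic UNet components are exposed at the package level, so downstream code (such as the training script above) can import them directly. A minimal illustration of the updated namespace:

# Illustrative import, reflecting the updated torch_em.self_training exports.
from torch_em.self_training import (
    ProbabilisticUNetTrainer, DummyLoss, ProbabilisticUNetLoss, ProbabilisticUNetLossAndMetric
)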
79 changes: 78 additions & 1 deletion torch_em/self_training/loss.py
@@ -1,5 +1,7 @@
import torch.nn as nn
import torch
import torch_em
import torch.nn as nn
from torch_em.loss import DiceLoss


class DefaultSelfTrainingLoss(nn.Module):
@@ -55,3 +57,78 @@ def __call__(self, model, input_, labels, label_filter=None):
loss = self.loss(prediction * label_filter, labels * label_filter)
metric = self.metric(prediction, labels)
return loss, metric


def l2_regularisation(m):
l2_reg = None

for W in m.parameters():
if l2_reg is None:
l2_reg = W.norm(2)
else:
l2_reg = l2_reg + W.norm(2)
return l2_reg


class ProbabilisticUNetLoss(nn.Module):
"""
Loss function for Probabilistic UNet

Parameters :
# TODO : Implement a generic utility function for all Probabilistic UNet schemes (ELBO, GECO, etc.)
loss [nn.Module] - the loss function to be used. (default: None)
"""
def __init__(self, loss=None):
super().__init__()
self.loss = loss

def __call__(self, model, input_, labels):
model.forward(input_, labels)

if self.loss is None:
elbo = model.elbo(labels)
reg_loss = l2_regularisation(model.posterior) + l2_regularisation(model.prior) + \
l2_regularisation(model.fcomb.layers)
loss = -elbo + 1e-5 * reg_loss
else:
# Custom loss schemes are not implemented yet (see the TODO above), so fail explicitly
# instead of returning an undefined value.
raise NotImplementedError("Only the default ELBO-based loss is currently supported.")

return loss
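
A minimal usage sketch for ProbabilisticUNetLoss (illustrative only, not part of this PR): the model configuration mirrors get_punet above, the batch shape is an assumption, and the loss is called with (model, input, labels) just as ProbabilisticUNetTrainer does below.

# Illustrative sketch, not part of this PR; assumes the get_punet configuration and a hypothetical 256x256 patch.
from torch_em.model import ProbabilisticUNet

model = ProbabilisticUNet(input_channels=1, num_classes=1, num_filters=[64, 128, 256, 512],
                          latent_dim=6, no_convs_fcomb=3, beta=1.0, rl_swap=True)
x = torch.randn(2, 1, 256, 256)                 # batch of raw images
y = (torch.rand(2, 1, 256, 256) > 0.5).float()  # binary labels

loss_function = ProbabilisticUNetLoss()
loss = loss_function(model, x, y)  # -ELBO plus 1e-5 * L2 regularisation of posterior, prior and fcomb
loss.backward()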


class ProbabilisticUNetLossAndMetric(nn.Module):
"""Loss and metric function for Probabilistic UNet.

Parameters:
# TODO : Implement a generic utility function for all Probabilistic UNet schemes (ELBO, GECO, etc.)
loss [nn.Module] - the loss function to be used. (default: None)

metric [nn.Module] - the metric function to be used. (default: torch_em.loss.DiceLoss)
activation [nn.Module, callable] - the activation function to be applied to the prediction
before evaluating the average predictions. (default: None)
"""
def __init__(self, loss=None, metric=DiceLoss(), activation=torch.nn.Sigmoid(), prior_samples=16):
super().__init__()
self.activation = activation
self.metric = metric
self.loss = loss
self.prior_samples = prior_samples

def __call__(self, model, input_, labels):
model.forward(input_, labels)

if self.loss is None:
elbo = model.elbo(labels)
reg_loss = l2_regularisation(model.posterior) + l2_regularisation(model.prior) + \
l2_regularisation(model.fcomb.layers)
loss = -elbo + 1e-5 * reg_loss
else:
# Custom loss schemes are not implemented yet (see the TODO above).
raise NotImplementedError("Only the default ELBO-based loss is currently supported.")

samples_per_distribution = []
for _ in range(self.prior_samples):
samples = model.sample(testing=False)
if self.activation is not None:
samples = self.activation(samples)
samples_per_distribution.append(samples)

avg_samples = torch.stack(samples_per_distribution, dim=0).sum(dim=0) / len(samples_per_distribution)
metric = self.metric(avg_samples, labels)

return loss, metric
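
A corresponding sketch for ProbabilisticUNetLossAndMetric (illustrative, not part of this PR): with the defaults it returns the same ELBO-plus-regularisation loss, together with a Dice-based metric computed on the average of prior_samples sigmoid-activated prior samples.

# Illustrative sketch, not part of this PR; model, x and y as in the previous example.
loss_and_metric = ProbabilisticUNetLossAndMetric(metric=DiceLoss(), activation=torch.nn.Sigmoid(), prior_samples=8)
loss, metric = loss_and_metric(model, x, y)
print(loss.item(), metric.item())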
6 changes: 3 additions & 3 deletions torch_em/self_training/mean_teacher.py
@@ -13,9 +13,9 @@ class Dummy(torch.nn.Module):


class MeanTeacherTrainer(torch_em.trainer.DefaultTrainer):
"""This trainer implements self-traning for semi-supervised learning and domain following the 'MeanTeacher' approach
of Tarvainen & Vapola (https://arxiv.org/abs/1703.01780). This approach uses a teacher model derived from the
student model via EMA of weights to predict pseudo-labels on unlabeled data.
"""This trainer implements self-training for semi-supervised learning and domain following the 'MeanTeacher'
approach of Tarvainen & Vapola (https://arxiv.org/abs/1703.01780). This approach uses a teacher model derived from
the student model via EMA of weights to predict pseudo-labels on unlabeled data.
We support two training strategies: joint training on labeled and unlabeled data
(with a supervised and unsupervised loss function). And training only on the unsupervised data.

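
The teacher update described in the docstring above is an exponential moving average of the student weights. A hedged sketch of such an update (illustrative only, this is not torch_em's actual implementation and the momentum value is just a typical choice):

# Illustrative EMA weight update for the mean-teacher scheme, not part of this PR.
import torch

def ema_update(teacher, student, momentum=0.999):
    with torch.no_grad():
        for t_param, s_param in zip(teacher.parameters(), student.parameters()):
            t_param.mul_(momentum).add_(s_param, alpha=1.0 - momentum)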
114 changes: 114 additions & 0 deletions torch_em/self_training/probabilistic_unet_trainer.py
@@ -0,0 +1,114 @@
import time
import torch
import torch_em


class DummyLoss(torch.nn.Module):
pass


class ProbabilisticUNetTrainer(torch_em.trainer.DefaultTrainer):
"""This trainer implements training for the 'Probabilistic UNet' of Kohl et al.: (https://arxiv.org/abs/1806.05034).
This approach combines the learnings from UNet and VAEs (Prior and Posterior networks) to obtain generative
segmentations. The heuristic trains by taking into account the feature maps from UNet and the samples from
the posterior distribution, estimating the loss and further sampling from the prior for validation.

Parameters:
clipping_value [float] - gradient norm clipping applied to the posterior encoder parameters (default: None)
prior_samples [int] - number of samples drawn from the prior for logging (default: 16)
loss [callable] - the loss function used for training (default: None)
loss_and_metric [callable] - the combined loss and metric used for validation (default: None)
"""

def __init__(
self,
clipping_value=None,
prior_samples=16,
loss=None,
loss_and_metric=None,
**kwargs
):
super().__init__(loss=loss, **kwargs)
constantinpape marked this conversation as resolved.
Show resolved Hide resolved
assert loss is not None and loss_and_metric is not None

self.loss_and_metric = loss_and_metric

self.clipping_value = clipping_value

self.prior_samples = prior_samples
self.sigmoid = torch.nn.Sigmoid()

self._kwargs = kwargs

#
# functionality for sampling from the network
#

def _sample(self):
samples = [self.model.sample() for _ in range(self.prior_samples)]
return samples

#
# training and validation functionality
#

def _train_epoch_impl(self, progress, forward_context, backprop):
self.model.train()

n_iter = 0
t_per_iter = time.time()

for x, y in self.train_loader:
x, y = x.to(self.device), y.to(self.device)

self.optimizer.zero_grad()

with forward_context():
# We pass the model, the input and the labels to the supervised loss function, so
# that how the loss is calculated stays flexible, e.g. to enable the ELBO objective for the PUNet.
loss = self.loss(self.model, x, y)

backprop(loss)

# To counter the exploding gradients in the posterior net
if self.clipping_value is not None:
torch.nn.utils.clip_grad_norm_(self.model.posterior.encoder.layers.parameters(), self.clipping_value)

if self.logger is not None:
lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
samples = self._sample() if self._iteration % self.log_image_interval == 0 else None
self.logger.log_train(self._iteration, loss, lr, x, y, samples)

self._iteration += 1
n_iter += 1
if self._iteration >= self.max_iteration:
break
progress.update(1)

t_per_iter = (time.time() - t_per_iter) / n_iter
return t_per_iter

def _validate_impl(self, forward_context):
self.model.eval()

metric_val = 0.0
loss_val = 0.0

with torch.no_grad():
for x, y in self.val_loader:
x, y = x.to(self.device), y.to(self.device)

with forward_context():
loss, metric = self.loss_and_metric(self.model, x, y)

loss_val += loss.item()
metric_val += metric

metric_val /= len(self.val_loader)
loss_val /= len(self.val_loader)

if self.logger is not None:
samples = self._sample()
self.logger.log_validation(self._iteration, metric_val, loss_val, x, y, samples)

return metric_val
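
The run_evaluation stub in the training script above is still a TODO. A hedged sketch of what prediction could look like, reusing the sampling-and-averaging scheme from ProbabilisticUNetLossAndMetric and this trainer's _sample helper; note that it mirrors the validation path (forward with labels) because the label-free test-time API of the ProbabilisticUNet is not shown in this PR and is therefore not assumed here.

# Illustrative prediction sketch, not part of this PR; mirrors the validation-time sampling.
import torch

@torch.no_grad()
def sample_average_prediction(model, x, y, n_samples=16):
    model.eval()
    model.forward(x, y)  # populates the prior / posterior distributions, as in the loss functions above
    samples = [torch.sigmoid(model.sample(testing=False)) for _ in range(n_samples)]
    return torch.stack(samples, dim=0).mean(dim=0)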