
Merge pull request #116 from anwai98/main
Update Mean-Teacher and FixMatch Self-Training Scheme(s)
constantinpape committed Apr 2, 2023
2 parents 63fc9bb + 6be42fa commit 944e1cf
Showing 9 changed files with 741 additions and 32 deletions.
59 changes: 59 additions & 0 deletions experiments/probabilistic_domain_adaptation/README.md
@@ -0,0 +1,59 @@
# Probabilistic Domain Adaptation

Implementation of [Probabilistic Domain Adaptation for Biomedical Image Segmentation](https://arxiv.org/abs/2303.11790) in `torch_em`.
Please cite the paper if you use these approaches in your research.

## Self-Training Approaches

The subfolders contain the training scripts for both separate and joint training setups (a sketch of the confidence-based pseudo-label filtering these scripts share follows the list):

- `unet_source.py` (UNet Source Training):
```
python unet_source.py -p [check / train / evaluate]
-c <CELL-TYPE>
-i <PATH-TO-DATA>
-s <PATH-TO-SAVE-MODEL-WEIGHTS>
-o <PATH-FOR-SAVING-PREDICTIONS>
```

- `unet_mean_teacher.py` (UNet Mean-Teacher Separate Training):
```
python unet_mean_teacher.py -p [check / train / evaluate]
-c <CELL-TYPE>
-i <PATH-TO-DATA>
-s <PATH-TO-SAVE-MODEL-WEIGHTS>
-o <PATH-FOR-SAVING-PREDICTIONS>
[(optional) --confidence_threshold <THRESHOLD-FOR-COMPUTING-FILTER-MASK>]
```

- `unet_adamt.py` (UNet Mean-Teacher Joint Training):
```
python unet_adamt.py -p [check / train / evaluate]
-c <CELL-TYPE>
-i <PATH-TO-DATA>
-s <PATH-TO-SAVE-MODEL-WEIGHTS>
-o <PATH-FOR-SAVING-PREDICTIONS>
[(optional) --confidence_threshold <THRESHOLD-FOR-COMPUTING-FILTER-MASK>]
```

- `unet_fixmatch.py` (UNet FixMatch Separate Training):
```
python unet_fixmatch.py -p [check / train / evaluate]
-c <CELL-TYPE>
-i <PATH-TO-DATA>
-s <PATH-TO-SAVE-MODEL-WEIGHTS>
-o <PATH-FOR-SAVING-PREDICTIONS>
[(optional) --confidence_threshold <THRESHOLD-FOR-COMPUTING-FILTER-MASK>]
  [(optional) --distribution_alignment (boolean flag: activates distribution alignment)]
```

- `unet_adamatch.py` (UNet FixMatch Joint Training):
```
python unet_adamatch.py -p [check / train / evaluate]
-c <CELL-TYPE>
-i <PATH-TO-DATA>
-s <PATH-TO-SAVE-MODEL-WEIGHTS>
-o <PATH-FOR-SAVING-PREDICTIONS>
[(optional) --confidence_threshold <THRESHOLD-FOR-COMPUTING-FILTER-MASK>]
  [(optional) --distribution_alignment (boolean flag: activates distribution alignment)]
```
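
All of the self-training scripts accept an optional confidence threshold that decides which pseudo-labeled pixels enter the unsupervised loss. Below is a rough sketch of this filtering for a network with sigmoid outputs; the actual logic lives in `torch_em.self_training.DefaultPseudoLabeler` and may differ in detail:

```python
import torch

def pseudo_label_filter_sketch(teacher_pred: torch.Tensor, threshold: float = 0.9):
    # Hypothetical illustration: binarize the teacher's sigmoid predictions into
    # pseudo-labels and keep only pixels far from the 0.5 decision boundary.
    pseudo_labels = (teacher_pred >= 0.5).float()
    confidence = torch.maximum(teacher_pred, 1.0 - teacher_pred)
    filter_mask = (confidence >= threshold).float()
    return pseudo_labels, filter_mask
```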
83 changes: 66 additions & 17 deletions experiments/probabilistic_domain_adaptation/livecell/common.py
Original file line number Diff line number Diff line change
@@ -19,14 +19,18 @@
from torch_em.util.prediction import predict_with_padding
from torchvision import transforms
from tqdm import tqdm
from torch_em.util import load_model

CELL_TYPES = ["A172", "BT474", "BV2", "Huh7", "MCF7", "SHSY5Y", "SkBr3", "SKOV3"]


#
# The augmentations we use for the LiveCELL experiments:
# - weak augmentations: blurring and additive gaussian noise
# - strong augmentations: TODO
# - weak augmentations:
#     blurring and additive gaussian noise
#
# - strong augmentations:
#     blurring, additive gaussian noise and random contrast adjustment
#


@@ -42,9 +46,33 @@ def weak_augmentations(p=0.25):
    return torch_em.transform.raw.get_raw_transform(normalizer=norm, augmentation1=aug)


# TODO
def strong_augmentations():
    pass
def strong_augmentations(p=0.5, mode=None):
    assert mode is not None
    norm = torch_em.transform.raw.standardize

    if mode == "separate":
        aug1 = transforms.Compose([
            norm,
            transforms.RandomApply([torch_em.transform.raw.GaussianBlur(sigma=(1.0, 4.0))], p=p),
            transforms.RandomApply([torch_em.transform.raw.AdditiveGaussianNoise(
                scale=(0.1, 0.35), clip_kwargs=False)], p=p),
            transforms.RandomApply([torch_em.transform.raw.RandomContrast(
                mean=0.0, alpha=(0.33, 3), clip_kwargs=False)], p=p),
        ])

    elif mode == "joint":
        aug1 = transforms.Compose([
            norm,
            transforms.RandomApply([torch_em.transform.raw.GaussianBlur(sigma=(0.6, 3.0))], p=p),
            transforms.RandomApply([torch_em.transform.raw.AdditiveGaussianNoise(
                scale=(0.05, 0.25), clip_kwargs=False)], p=p/2
            ),
            transforms.RandomApply([torch_em.transform.raw.RandomContrast(
                mean=0.0, alpha=(0.33, 3.0), clip_kwargs=False)], p=p
            )
        ])

    return torch_em.transform.raw.get_raw_transform(normalizer=norm, augmentation1=aug1)
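
# Usage sketch (hypothetical): the "strong-separate" and "strong-joint" options
# handled in get_unsupervised_loader below map onto these two modes, e.g.
#   student_aug = strong_augmentations(p=0.5, mode="separate")  # separate training
#   student_aug = strong_augmentations(p=0.5, mode="joint")     # joint training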


#
@@ -55,12 +83,28 @@ def get_unet():
    return UNet2d(in_channels=1, out_channels=1, initial_features=64, final_activation="Sigmoid", depth=4)


def load_model(model, ckpt, state="model_state", device=None):
    state = torch.load(os.path.join(ckpt, "best.pt"))[state]
    model.load_state_dict(state)
    if device is not None:
        model.to(device)
    return model
# Computing the Source Distribution for Distribution Alignment
def compute_class_distribution(root_folder, label_threshold=0.5):

    bg_list, fg_list = [], []
    total = 0

    files = glob(os.path.join(root_folder, "*"))
    assert len(files) > 0, f"Did not find predictions @ {root_folder}"

    for pl_path in files:
        img = imageio.imread(pl_path)
        img = np.where(img >= label_threshold, 1, 0)
        _, counts = np.unique(img, return_counts=True)
        assert len(counts) == 2
        bg_list.append(counts[0])
        fg_list.append(counts[1])
        total += img.size

    bg_frequency = sum(bg_list) / float(total)
    fg_frequency = sum(fg_list) / float(total)
    assert np.isclose(bg_frequency + fg_frequency, 1.0)
    return [bg_frequency, fg_frequency]


# use get_model and prediction_function to customize this, e.g. for using it with the PUNet
@@ -88,9 +132,12 @@ def evaluate_transfered_model(
    if out_folder is not None:
        os.makedirs(out_folder, exist_ok=True)

    ckpt = f"checkpoints/{method}/thresh-{thresh}/{ct_src}/{ct_trg}"
    if args.save_root is None:
        ckpt = f"checkpoints/{method}/thresh-{thresh}/{ct_src}/{ct_trg}"
    else:
        ckpt = args.save_root + f"checkpoints/{method}/thresh-{thresh}/{ct_src}/{ct_trg}"
    model = get_model()
    model = load_model(model, ckpt, device=device, state=model_state)
    model = load_model(checkpoint=ckpt, model=model, state_key=model_state, device=device)

    label_paths = glob(os.path.join(label_root, ct_trg, "*.tif"))
    scores = []
@@ -190,14 +237,16 @@ def _get_image_paths(args, split, cell_type):
    return image_paths


def get_unsupervised_loader(args, split, cell_type, teacher_augmentation, student_augmentation):
def get_unsupervised_loader(args, batch_size, split, cell_type, teacher_augmentation, student_augmentation):
    patch_shape = (256, 256)

    def _parse_aug(aug):
        if aug == "weak":
            return weak_augmentations()
        elif aug == "strong":
            return strong_augmentations()
        elif aug == "strong-separate":
            return strong_augmentations(mode="separate")
        elif aug == "strong-joint":
            return strong_augmentations(mode="joint")
        assert callable(aug)
        return aug

@@ -211,7 +260,7 @@ def _parse_aug(aug):
        image_paths, patch_shape, raw_transform, transform,
        augmentations=augmentations
    )
    loader = torch_em.segmentation.get_data_loader(ds, batch_size=args.batch_size, num_workers=8, shuffle=True)
    loader = torch_em.segmentation.get_data_loader(ds, batch_size=batch_size, num_workers=8, shuffle=True)
    return loader


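`compute_class_distribution` above returns `[bg_frequency, fg_frequency]` over the source predictions, which the FixMatch trainers consume as `source_distribution`. A minimal sketch of how such frequencies could be used to align teacher predictions, assuming binary sigmoid outputs (the trainer's internal handling may differ):

```python
import torch

def distribution_alignment_sketch(teacher_pred: torch.Tensor, src_dist, eps: float = 1e-6):
    # Hypothetical illustration: rescale the predicted foreground probabilities so
    # that the batch's empirical foreground frequency matches the source frequency.
    _, src_fg = src_dist  # [bg_frequency, fg_frequency] from compute_class_distribution
    batch_fg = teacher_pred.mean().clamp(min=eps)
    aligned = teacher_pred * (src_fg / batch_fg)
    return aligned.clamp(0.0, 1.0)
```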
128 changes: 128 additions & 0 deletions experiments/probabilistic_domain_adaptation/livecell/unet_adamatch.py
@@ -0,0 +1,128 @@
import os

import pandas as pd
import torch
import torch_em.self_training as self_training

import common


def check_loader(args, n_images=5):
    from torch_em.util.debug import check_loader

    cell_types = args.cell_types
    print("The cell types", cell_types, "were selected.")
    print("Checking the unsupervised loader for the first cell type", cell_types[0])

    loader = common.get_unsupervised_loader(
        args, args.batch_size, "train", cell_types[0],
        teacher_augmentation="weak", student_augmentation="strong-joint",
    )
    check_loader(loader, n_images)


def _train_source_target(args, source_cell_type, target_cell_type):
    model = common.get_unet()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=5)

    # self training functionality
    thresh = args.confidence_threshold
    pseudo_labeler = self_training.DefaultPseudoLabeler(confidence_threshold=thresh)
    loss = self_training.DefaultSelfTrainingLoss()
    loss_and_metric = self_training.DefaultSelfTrainingLossAndMetric()

    # data loaders
    supervised_train_loader = common.get_supervised_loader(args, "train", source_cell_type, args.batch_size)
    supervised_val_loader = common.get_supervised_loader(args, "val", source_cell_type, 1)
    unsupervised_train_loader = common.get_unsupervised_loader(
        args, args.batch_size, "train", target_cell_type,
        teacher_augmentation="weak", student_augmentation="strong-joint",
    )
    unsupervised_val_loader = common.get_unsupervised_loader(
        args, 1, "val", target_cell_type,
        teacher_augmentation="weak", student_augmentation="strong-joint",
    )

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    name = f"unet_adamatch/thresh-{thresh}"

    if args.distribution_alignment:
        assert args.output is not None
        print(f"Getting scores for Source {source_cell_type} at Targets {target_cell_type}")
        pred_folder = args.output + f"unet_source/{source_cell_type}/{target_cell_type}/"
        src_dist = common.compute_class_distribution(pred_folder)
        name = f"{name}-distro-align"
    else:
        src_dist = None

    name = name + f"/{source_cell_type}/{target_cell_type}"

    trainer = self_training.FixMatchTrainer(
        name=name,
        model=model,
        optimizer=optimizer,
        lr_scheduler=scheduler,
        pseudo_labeler=pseudo_labeler,
        unsupervised_loss=loss,
        unsupervised_loss_and_metric=loss_and_metric,
        supervised_train_loader=supervised_train_loader,
        unsupervised_train_loader=unsupervised_train_loader,
        supervised_val_loader=supervised_val_loader,
        unsupervised_val_loader=unsupervised_val_loader,
        supervised_loss=loss,
        supervised_loss_and_metric=loss_and_metric,
        logger=self_training.SelfTrainingTensorboardLogger,
        mixed_precision=True,
        device=device,
        log_image_interval=100,
        save_root=args.save_root,
        source_distribution=src_dist
    )
    trainer.fit(args.n_iterations)


def _train_source(args, cell_type):
    for target_cell_type in common.CELL_TYPES:
        if target_cell_type == cell_type:
            continue
        _train_source_target(args, cell_type, target_cell_type)


def run_training(args):
    for cell_type in args.cell_types:
        print("Start training for cell type:", cell_type)
        _train_source(args, cell_type)


def run_evaluation(args):
    results = []
    for ct in args.cell_types:
        res = common.evaluate_transfered_model(args, ct, "unet_adamatch")
        results.append(res)
    results = pd.concat(results)
    print("Evaluation results:")
    print(results)
    result_folder = "./results"
    os.makedirs(result_folder, exist_ok=True)
    results.to_csv(os.path.join(result_folder, "unet_adamatch.csv"), index=False)


def main():
    parser = common.get_parser(default_iterations=25000, default_batch_size=8)
    parser.add_argument("--confidence_threshold", default=None, type=float)
    parser.add_argument("--distribution_alignment", action='store_true', help="Activates Distribution Alignment")
    args = parser.parse_args()
    if args.phase in ("c", "check"):
        check_loader(args)
    elif args.phase in ("t", "train"):
        run_training(args)
    elif args.phase in ("e", "evaluate"):
        run_evaluation(args)
    else:
        raise ValueError(f"Got phase={args.phase}, expect one of check, train, evaluate.")


if __name__ == "__main__":
    main()
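
Putting the pieces together, the unsupervised part of the objective that `FixMatchTrainer` optimizes can be summarized as below. This is a simplified sketch assuming sigmoid outputs and a masked binary cross-entropy, not the trainer's actual implementation:

```python
import torch
import torch.nn.functional as F

def fixmatch_unsupervised_loss_sketch(model, weak_imgs, strong_imgs, threshold=0.9):
    # weak_imgs and strong_imgs are assumed to be two augmented views of the same
    # unlabeled batch: pseudo-labels come from the weak view, the loss from the strong one.
    with torch.no_grad():
        teacher_pred = model(weak_imgs)
    pseudo_labels = (teacher_pred >= 0.5).float()
    confidence = torch.maximum(teacher_pred, 1.0 - teacher_pred)
    mask = (confidence >= threshold).float()
    student_pred = model(strong_imgs)
    bce = F.binary_cross_entropy(student_pred, pseudo_labels, reduction="none")
    return (bce * mask).sum() / mask.sum().clamp(min=1.0)
```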
Original file line number Diff line number Diff line change
@@ -22,7 +22,7 @@ def check_loader(args, n_images=5):


def _train_source_target(args, source_cell_type, target_cell_type):
    model = common.get_model()
    model = common.get_unet()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=5)

@@ -36,11 +36,11 @@ def _train_source_target(args, source_cell_type, target_cell_type):
    supervised_train_loader = common.get_supervised_loader(args, "train", source_cell_type, args.batch_size)
    supervised_val_loader = common.get_supervised_loader(args, "val", source_cell_type, 1)
    unsupervised_train_loader = common.get_unsupervised_loader(
        args, "train", target_cell_type,
        args, args.batch_size, "train", target_cell_type,
        teacher_augmentation="weak", student_augmentation="weak",
    )
    unsupervised_val_loader = common.get_unsupervised_loader(
        args, "val", target_cell_type,
        args, 1, "val", target_cell_type,
        teacher_augmentation="weak", student_augmentation="weak",
    )

@@ -98,7 +98,7 @@ def run_evaluation(args):

def main():
    parser = common.get_parser(default_iterations=75000, default_batch_size=4)
    parser.add_argument("--confidence_threshold", default=0.9)
    parser.add_argument("--confidence_threshold", default=None, type=float)
    args = parser.parse_args()
    if args.phase in ("c", "check"):
        check_loader(args)
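
For contrast with the FixMatch scheme above, the mean-teacher variants (`unet_mean_teacher.py`, `unet_adamt.py`) maintain a separate teacher whose weights track the student by exponential moving average. A minimal sketch of such an update, with the momentum value an assumption rather than the repository's setting:

```python
import torch

@torch.no_grad()
def ema_update_sketch(teacher: torch.nn.Module, student: torch.nn.Module, momentum: float = 0.999):
    # Hypothetical illustration: blend each teacher parameter towards its student counterpart.
    for t_param, s_param in zip(teacher.parameters(), student.parameters()):
        t_param.mul_(momentum).add_(s_param, alpha=1.0 - momentum)
```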