Update Mean-Teacher and FixMatch Self-Training Scheme(s) #116

Merged (8 commits, Apr 2, 2023)
83 changes: 66 additions & 17 deletions experiments/probabilistic_domain_adaptation/livecell/common.py
@@ -19,14 +19,18 @@
from torch_em.util.prediction import predict_with_padding
from torchvision import transforms
from tqdm import tqdm
from torch_em.util import load_model

CELL_TYPES = ["A172", "BT474", "BV2", "Huh7", "MCF7", "SHSY5Y", "SkBr3", "SKOV3"]


#
# The augmentations we use for the LiveCELL experiments:
# - weak augmentations: blurring and additive gaussian noise
# - strong augmentations: TODO
# - weak augmentations:
#       blurring and additive gaussian noise
#
# - strong augmentations:
#       blurring, additive gaussian noise and random contrast adjustment
#


@@ -42,9 +46,33 @@ def weak_augmentations(p=0.25):
return torch_em.transform.raw.get_raw_transform(normalizer=norm, augmentation1=aug)


# TODO
def strong_augmentations():
pass
def strong_augmentations(p=0.5, mode=None):
    assert mode in ("separate", "joint"), f"Invalid augmentation mode: {mode}"
norm = torch_em.transform.raw.standardize

if mode == "separate":
aug1 = transforms.Compose([
norm,
transforms.RandomApply([torch_em.transform.raw.GaussianBlur(sigma=(1.0, 4.0))], p=p),
transforms.RandomApply([torch_em.transform.raw.AdditiveGaussianNoise(
scale=(0.1, 0.35), clip_kwargs=False)], p=p),
transforms.RandomApply([torch_em.transform.raw.RandomContrast(
mean=0.0, alpha=(0.33, 3), clip_kwargs=False)], p=p),
])

elif mode == "joint":
aug1 = transforms.Compose([
norm,
transforms.RandomApply([torch_em.transform.raw.GaussianBlur(sigma=(0.6, 3.0))], p=p),
transforms.RandomApply([torch_em.transform.raw.AdditiveGaussianNoise(
scale=(0.05, 0.25), clip_kwargs=False)], p=p/2
),
transforms.RandomApply([torch_em.transform.raw.RandomContrast(
mean=0.0, alpha=(0.33, 3.0), clip_kwargs=False)], p=p
)
])

return torch_em.transform.raw.get_raw_transform(normalizer=norm, augmentation1=aug1)
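
A quick sketch of how the two modes might be used (illustrative, not part of the diff; it assumes the transform returned by get_raw_transform is a callable on a numpy image):

import numpy as np

# Both modes return a raw transform; "separate" draws stronger blur/noise
# parameters, while "joint" halves the noise probability instead.
aug_separate = strong_augmentations(p=0.5, mode="separate")
aug_joint = strong_augmentations(p=0.5, mode="joint")

image = np.random.rand(256, 256).astype("float32")
augmented = aug_joint(image)  # normalize, then randomly blur / noise / adjust contrast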


#
@@ -55,12 +83,28 @@ def get_unet():
return UNet2d(in_channels=1, out_channels=1, initial_features=64, final_activation="Sigmoid", depth=4)


def load_model(model, ckpt, state="model_state", device=None):
state = torch.load(os.path.join(ckpt, "best.pt"))[state]
model.load_state_dict(state)
if device is not None:
model.to(device)
return model
# Computing the Source Distribution for Distribution Alignment
def compute_class_distribution(root_folder, label_threshold=0.5):

bg_list, fg_list = [], []
total = 0

files = glob(os.path.join(root_folder, "*"))
assert len(files) > 0, f"Did not find predictions @ {root_folder}"

for pl_path in files:
img = imageio.imread(pl_path)
img = np.where(img >= label_threshold, 1, 0)
_, counts = np.unique(img, return_counts=True)
assert len(counts) == 2
bg_list.append(counts[0])
fg_list.append(counts[1])
total += img.size

bg_frequency = sum(bg_list) / float(total)
fg_frequency = sum(fg_list) / float(total)
assert np.isclose(bg_frequency + fg_frequency, 1.0)
return [bg_frequency, fg_frequency]
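
Illustrative usage (the prediction folder path is hypothetical); the two returned frequencies can then be passed to the trainer as the source distribution:

# Sketch: estimate bg/fg frequencies from saved source-model predictions.
src_dist = compute_class_distribution("predictions/unet_source/A172/BT474")
bg_frequency, fg_frequency = src_dist  # two class frequencies that sum to 1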


# use get_model and prediction_function to customize this, e.g. for using it with the PUNet
@@ -88,9 +132,12 @@ def evaluate_transfered_model(
if out_folder is not None:
os.makedirs(out_folder, exist_ok=True)

ckpt = f"checkpoints/{method}/thresh-{thresh}/{ct_src}/{ct_trg}"
if args.save_root is None:
ckpt = f"checkpoints/{method}/thresh-{thresh}/{ct_src}/{ct_trg}"
else:
ckpt = args.save_root + f"checkpoints/{method}/thresh-{thresh}/{ct_src}/{ct_trg}"
model = get_model()
model = load_model(model, ckpt, device=device, state=model_state)
model = load_model(checkpoint=ckpt, model=model, state_key=model_state, device=device)

label_paths = glob(os.path.join(label_root, ct_trg, "*.tif"))
scores = []
@@ -190,14 +237,16 @@ def _get_image_paths(args, split, cell_type):
return image_paths


def get_unsupervised_loader(args, split, cell_type, teacher_augmentation, student_augmentation):
def get_unsupervised_loader(args, batch_size, split, cell_type, teacher_augmentation, student_augmentation):
patch_shape = (256, 256)

def _parse_aug(aug):
if aug == "weak":
return weak_augmentations()
elif aug == "strong":
return strong_augmentations()
elif aug == "strong-separate":
return strong_augmentations(mode="separate")
elif aug == "strong-joint":
return strong_augmentations(mode="joint")
assert callable(aug)
return aug

@@ -211,7 +260,7 @@ def _parse_aug(aug):
image_paths, patch_shape, raw_transform, transform,
augmentations=augmentations
)
loader = torch_em.segmentation.get_data_loader(ds, batch_size=args.batch_size, num_workers=8, shuffle=True)
loader = torch_em.segmentation.get_data_loader(ds, batch_size=batch_size, num_workers=8, shuffle=True)
return loader
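
For reference, a hedged example of the updated call signature, with the batch size now passed explicitly instead of being read from args inside the function:

# Sketch: explicit batch size for training, batch size 1 for validation.
train_loader = get_unsupervised_loader(
    args, args.batch_size, "train", "A172",
    teacher_augmentation="weak", student_augmentation="strong-joint",
)
val_loader = get_unsupervised_loader(
    args, 1, "val", "A172",
    teacher_augmentation="weak", student_augmentation="strong-joint",
)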


128 changes: 128 additions & 0 deletions experiments/probabilistic_domain_adaptation/livecell/unet_adamatch.py
@@ -0,0 +1,128 @@
import os

import pandas as pd
import torch
import torch_em.self_training as self_training

import common


def check_loader(args, n_images=5):
from torch_em.util.debug import check_loader

cell_types = args.cell_types
print("The cell types", cell_types, "were selected.")
print("Checking the unsupervised loader for the first cell type", cell_types[0])

    # "strong-joint" matches the training code below; the bare "strong" option
    # was removed from _parse_aug, and the loader now takes an explicit batch size.
    loader = common.get_unsupervised_loader(
        args, args.batch_size, "train", cell_types[0],
        teacher_augmentation="weak", student_augmentation="strong-joint",
    )
check_loader(loader, n_images)


def _train_source_target(args, source_cell_type, target_cell_type):
model = common.get_unet()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=5)

# self training functionality
thresh = args.confidence_threshold
pseudo_labeler = self_training.DefaultPseudoLabeler(confidence_threshold=thresh)
loss = self_training.DefaultSelfTrainingLoss()
loss_and_metric = self_training.DefaultSelfTrainingLossAndMetric()

# data loaders
supervised_train_loader = common.get_supervised_loader(args, "train", source_cell_type, args.batch_size)
supervised_val_loader = common.get_supervised_loader(args, "val", source_cell_type, 1)
unsupervised_train_loader = common.get_unsupervised_loader(
args, args.batch_size, "train", target_cell_type,
teacher_augmentation="weak", student_augmentation="strong-joint",
)
unsupervised_val_loader = common.get_unsupervised_loader(
args, 1, "val", target_cell_type,
teacher_augmentation="weak", student_augmentation="strong-joint",
)

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

name = f"unet_adamatch/thresh-{thresh}"

if args.distribution_alignment:
assert args.output is not None
print(f"Getting scores for Source {source_cell_type} at Targets {target_cell_type}")
pred_folder = args.output + f"unet_source/{source_cell_type}/{target_cell_type}/"
src_dist = common.compute_class_distribution(pred_folder)
name = f"{name}-distro-align"
else:
src_dist = None

name = name + f"/{source_cell_type}/{target_cell_type}"

trainer = self_training.FixMatchTrainer(
name=name,
model=model,
optimizer=optimizer,
lr_scheduler=scheduler,
pseudo_labeler=pseudo_labeler,
unsupervised_loss=loss,
unsupervised_loss_and_metric=loss_and_metric,
supervised_train_loader=supervised_train_loader,
unsupervised_train_loader=unsupervised_train_loader,
supervised_val_loader=supervised_val_loader,
unsupervised_val_loader=unsupervised_val_loader,
supervised_loss=loss,
supervised_loss_and_metric=loss_and_metric,
logger=self_training.SelfTrainingTensorboardLogger,
mixed_precision=True,
device=device,
log_image_interval=100,
save_root=args.save_root,
source_distribution=src_dist
)
trainer.fit(args.n_iterations)


def _train_source(args, cell_type):
for target_cell_type in common.CELL_TYPES:
if target_cell_type == cell_type:
continue
_train_source_target(args, cell_type, target_cell_type)


def run_training(args):
for cell_type in args.cell_types:
print("Start training for cell type:", cell_type)
_train_source(args, cell_type)


def run_evaluation(args):
results = []
for ct in args.cell_types:
res = common.evaluate_transfered_model(args, ct, "unet_adamatch")
results.append(res)
results = pd.concat(results)
print("Evaluation results:")
print(results)
result_folder = "./results"
os.makedirs(result_folder, exist_ok=True)
results.to_csv(os.path.join(result_folder, "unet_adamatch.csv"), index=False)


def main():
parser = common.get_parser(default_iterations=25000, default_batch_size=8)
parser.add_argument("--confidence_threshold", default=None, type=float)
parser.add_argument("--distribution_alignment", action='store_true', help="Activates Distribution Alignment")
args = parser.parse_args()
if args.phase in ("c", "check"):
check_loader(args)
elif args.phase in ("t", "train"):
run_training(args)
elif args.phase in ("e", "evaluate"):
run_evaluation(args)
else:
raise ValueError(f"Got phase={args.phase}, expect one of check, train, evaluate.")


if __name__ == "__main__":
main()
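
Hedged invocation sketch (the exact phase and cell-type flag names come from common.get_parser and are assumptions here; only --confidence_threshold and --distribution_alignment are defined above):

# python unet_adamatch.py --phase train --cell_types A172 \
#     --confidence_threshold 0.9 --distribution_alignment --output predictions/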
@@ -22,7 +22,7 @@ def check_loader(args, n_images=5):


def _train_source_target(args, source_cell_type, target_cell_type):
model = common.get_model()
model = common.get_unet()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=5)

@@ -36,11 +36,11 @@ def _train_source_target(args, source_cell_type, target_cell_type):
supervised_train_loader = common.get_supervised_loader(args, "train", source_cell_type, args.batch_size)
supervised_val_loader = common.get_supervised_loader(args, "val", source_cell_type, 1)
unsupervised_train_loader = common.get_unsupervised_loader(
args, "train", target_cell_type,
args, args.batch_size, "train", target_cell_type,
teacher_augmentation="weak", student_augmentation="weak",
)
unsupervised_val_loader = common.get_unsupervised_loader(
args, "val", target_cell_type,
args, 1, "val", target_cell_type,
teacher_augmentation="weak", student_augmentation="weak",
)

@@ -98,7 +98,7 @@ def run_evaluation(args):

def main():
parser = common.get_parser(default_iterations=75000, default_batch_size=4)
parser.add_argument("--confidence_threshold", default=0.9)
parser.add_argument("--confidence_threshold", default=None, type=float)
args = parser.parse_args()
if args.phase in ("c", "check"):
check_loader(args)