From 0bf2e173248796d5365c2caaf87f6549c35b910d Mon Sep 17 00:00:00 2001 From: arkkienkeli Date: Fri, 8 Oct 2021 13:11:58 +0200 Subject: [PATCH 1/4] Changes to online_labels_cropgen.py --- .../crop_generators/online_labels_cropgen.py | 59 ++++++++----------- 1 file changed, 23 insertions(+), 36 deletions(-) diff --git a/plugins/crop_generators/online_labels_cropgen.py b/plugins/crop_generators/online_labels_cropgen.py index 2b8cf38d..a7bf9c0e 100644 --- a/plugins/crop_generators/online_labels_cropgen.py +++ b/plugins/crop_generators/online_labels_cropgen.py @@ -19,9 +19,8 @@ class GeneratorClass(deepprofiler.imaging.cropping.CropGenerator): - def __init__(self, config, dset, mode="Training"): + def __init__(self, config, dset, mode="training"): super(GeneratorClass, self).__init__(config, dset) - #self.datagen = tf.keras.preprocessing.image.ImageDataGenerator() self.directory = config["paths"]["single_cell_set"] self.num_channels = len(config["dataset"]["images"]["channels"]) self.box_size = self.config["dataset"]["locations"]["box_size"] @@ -29,22 +28,12 @@ def __init__(self, config, dset, mode="Training"): self.mode = mode # Load metadata - self.all_cells = pd.read_csv(os.path.join(self.directory, "expanded_sc_metadata_tengenes.csv")) - - ## UNCOMMENT FOR ALPHA SET - #self.all_cells.loc[(self.all_cells.Training_Status == "Unused") & self.all_cells.Metadata_Plate.isin([41756,41757]), "Training_Status_Alpha"] = "Validation" - - ## UNCOMMENT FOR SINGLE CELL BALANCED SET - self.all_cells.loc[self.all_cells.Training_Status == "Training", "Training_Status"] = "XXX" - self.all_cells.loc[self.all_cells.Training_Status == "SingleCellTraining", "Training_Status"] = "Training" - self.all_cells.loc[self.all_cells.Training_Status == "Validation", "Training_Status"] = "YYY" - self.all_cells.loc[self.all_cells.Training_Status == "SingleCellValidation", "Training_Status"] = "Validation" - + self.all_cells = pd.read_csv(config["paths"]["sc_index"]) self.target = config["train"]["partition"]["targets"][0] # Index targets for one-hot encoded labels - #self.split_data = self.all_cells[self.all_cells.Training_Status_TenGenes == self.mode].reset_index(drop=True) - self.split_data = self.all_cells[self.all_cells.Training_Status == self.mode].reset_index(drop=True) + self.split_data = self.all_cells[self.all_cells[self.config["train"]["partition"]["split_field"]].isin( + self.config["train"]["partition"][self.mode])].reset_index(drop=True) self.classes = list(self.split_data[self.target].unique()) self.num_classes = len(self.classes) self.classes.sort() @@ -52,43 +41,43 @@ def __init__(self, config, dset, mode="Training"): # Identify targets and samples self.balanced_sample() - self.expected_steps = (self.samples.shape[0] // self.batch_size) + int(self.samples.shape[0] % self.batch_size > 0) + self.expected_steps = (self.samples.shape[0] // self.batch_size) + \ + int(self.samples.shape[0] % self.batch_size > 0) # Report number of classes globally self.config["num_classes"] = self.num_classes print(" >> Number of classes:", self.num_classes) # Online labels - if self.mode == "Training": + if self.mode == "training": self.out_dir = config["paths"]["results"] + "soft_labels/" os.makedirs(self.out_dir, exist_ok=True) self.init_online_labels() - def start(self, session): pass def balanced_sample(self): # Obtain distribution of single cells per class - counts = self.split_data.groupby("Class_Name").count().reset_index()[["Class_Name", "Key"]] + counts = self.split_data.groupby(self.target).count().reset_index()[[self.target, "Key"]] sample_size = int(counts.Key.median()) - counts = {r.Class_Name: r.Key for k,r in counts.iterrows()} + counts = {r[self.target]: r.Key for k, r in counts.iterrows()} # Sample the same number of cells per class class_samples = [] - for cls in self.split_data.Class_Name.unique(): - class_samples.append(self.split_data[self.split_data.Class_Name == cls].sample(n=sample_size, replace=counts[cls] < sample_size)) + for cls in self.split_data[self.target].unique(): + class_samples.append(self.split_data[self.split_data[self.target] == cls].sample( + n=sample_size, replace=counts[cls] < sample_size)) self.samples = pd.concat(class_samples) # Randomize order - if self.mode == "Training": + if self.mode == "training": self.samples = self.samples.sample(frac=1.0).reset_index() else: self.samples = self.samples.sample(frac=0.1).reset_index() print(self.samples[self.target].value_counts()) - - def generate(self, sess, global_step=0): + def generator(self, sess, global_step=0): pointer = 0 while True: x = np.zeros([self.batch_size, self.box_size, self.box_size, self.num_channels]) @@ -99,13 +88,12 @@ def generate(self, sess, global_step=0): pointer = 0 filename = os.path.join(self.directory, self.samples.loc[pointer, "Image_Name"]) im = skimage.io.imread(filename).astype(np.float32) - x[i,:,:,:] = deepprofiler.imaging.cropping.fold_channels(im) + x[i, :, :, :] = deepprofiler.imaging.cropping.fold_channels(im) y.append([self.soft_labels[self.samples.loc[pointer, "index"], :]]) pointer += 1 - yield(x, np.concatenate(y, axis=0)) - + yield x, np.concatenate(y, axis=0) - def generator(self, source="samples"): + def generate(self, source="samples"): pointer = 0 if source == "splits": dataframe = self.split_data @@ -129,8 +117,7 @@ def generator(self, source="samples"): pointer += 1 if len(y) < x.shape[0]: x = x[0:len(y), ...] - yield(x, tf.keras.utils.to_categorical(y, num_classes = self.num_classes)) - + yield x, tf.keras.utils.to_categorical(y, num_classes=self.num_classes) def init_online_labels(self): LABEL_SMOOTHING = 0.2 @@ -143,17 +130,18 @@ def init_online_labels(self): sl = pd.DataFrame(data=self.soft_labels) sl.to_csv(self.out_dir + "0000.csv", index=False) - def update_online_labels(self, model, epoch): # Prepare parameters and predictions LAMBDA = 0.01 predictions = [] # Get predictions with the model - model.get_layer("augmentation_layer").is_training = False - for batch in self.generator(source = "splits"): + if model.layers[1].name == 'augmentation_layer': + model.get_layer("augmentation_layer").is_training = False + for batch in self.generate(source="splits"): predictions.append(model.predict(batch[0])) - model.get_layer("augmentation_layer").is_training = True + if model.layers[1].name == 'augmentation_layer': + model.get_layer("augmentation_layer").is_training = True # Update soft labels predictions = np.concatenate(predictions, axis=0) @@ -164,7 +152,6 @@ def update_online_labels(self, model, epoch): sl = pd.DataFrame(data=self.soft_labels) sl.to_csv(self.out_dir + "{:04d}.csv".format(epoch+1), index=False) - def stop(self, session): pass From 61ffc6fa60407415fe54104c2fddac4f93c5de7c Mon Sep 17 00:00:00 2001 From: arkkienkeli Date: Fri, 8 Oct 2021 13:34:22 +0200 Subject: [PATCH 2/4] Parameterize lambda and label smoothing for online_labels_cropgen.py --- plugins/crop_generators/online_labels_cropgen.py | 4 ++-- tests/files/config/test.json | 4 +++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/plugins/crop_generators/online_labels_cropgen.py b/plugins/crop_generators/online_labels_cropgen.py index a7bf9c0e..93b991de 100644 --- a/plugins/crop_generators/online_labels_cropgen.py +++ b/plugins/crop_generators/online_labels_cropgen.py @@ -120,7 +120,7 @@ def generate(self, source="samples"): yield x, tf.keras.utils.to_categorical(y, num_classes=self.num_classes) def init_online_labels(self): - LABEL_SMOOTHING = 0.2 + LABEL_SMOOTHING = self.config["train"]["model"]["params"]["online_label_smoothing"] self.soft_labels = np.zeros((self.split_data.shape[0], self.num_classes)) + LABEL_SMOOTHING/self.num_classes print("Soft labels:", self.soft_labels.shape) for k, r in self.split_data.iterrows(): @@ -132,7 +132,7 @@ def init_online_labels(self): def update_online_labels(self, model, epoch): # Prepare parameters and predictions - LAMBDA = 0.01 + LAMBDA = self.config["train"]["model"]["params"]["online_lambda"] predictions = [] # Get predictions with the model diff --git a/tests/files/config/test.json b/tests/files/config/test.json index 3c542ab8..2ac897d6 100644 --- a/tests/files/config/test.json +++ b/tests/files/config/test.json @@ -56,7 +56,9 @@ "label_smoothing": 0.0, "feature_dim": 100, "latent_dim": 100, - "epsilon_std": 1.0 + "epsilon_std": 1.0, + "online_label_smoothing": 0.1, + "online_lambda": 0.01 }, "lr_schedule" : { "epoch": [1,3,5], From 4bf77b93e4252a072ff239851ddbe6b6d8cfebbf Mon Sep 17 00:00:00 2001 From: arkkienkeli Date: Tue, 12 Oct 2021 11:18:16 +0200 Subject: [PATCH 3/4] Remove distinction between training and validation for randomizing data --- plugins/crop_generators/online_labels_cropgen.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/plugins/crop_generators/online_labels_cropgen.py b/plugins/crop_generators/online_labels_cropgen.py index 93b991de..c55392e9 100644 --- a/plugins/crop_generators/online_labels_cropgen.py +++ b/plugins/crop_generators/online_labels_cropgen.py @@ -71,11 +71,7 @@ def balanced_sample(self): self.samples = pd.concat(class_samples) # Randomize order - if self.mode == "training": - self.samples = self.samples.sample(frac=1.0).reset_index() - else: - self.samples = self.samples.sample(frac=0.1).reset_index() - print(self.samples[self.target].value_counts()) + self.samples = self.samples.sample(frac=1.0).reset_index() def generator(self, sess, global_step=0): pointer = 0 From a487615b6b576df0111427a18bbc4f3c17562af0 Mon Sep 17 00:00:00 2001 From: arkkienkeli Date: Tue, 12 Oct 2021 11:34:26 +0200 Subject: [PATCH 4/4] Label smoothing is always 0 for cross-entropy if online labels crop generator is used. --- deepprofiler/learning/training.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/deepprofiler/learning/training.py b/deepprofiler/learning/training.py index db186664..7edfd679 100644 --- a/deepprofiler/learning/training.py +++ b/deepprofiler/learning/training.py @@ -9,7 +9,9 @@ def learn_model(config, dset, epoch=1, seed=None, verbose=1): model_module = importlib.import_module("plugins.models.{}".format(config["train"]["model"]["name"])) crop_module = importlib.import_module("plugins.crop_generators.{}".format(config["train"]["model"]["crop_generator"])) - #config["num_classes"] = len(dset.training_images["Target"].unique()) + if config["train"]["model"]["crop_generator"] == 'online_labels_cropgen': + config["train"]["model"]["params"]["label_smoothing"] = 0 + if "metrics" in config["train"]["model"].keys(): if type(config["train"]["model"]["metrics"]) not in [list, dict]: raise ValueError("Metrics should be a list or dictionary.")