diff --git a/deepprofiler/learning/training.py b/deepprofiler/learning/training.py index db186664..7edfd679 100644 --- a/deepprofiler/learning/training.py +++ b/deepprofiler/learning/training.py @@ -9,7 +9,9 @@ def learn_model(config, dset, epoch=1, seed=None, verbose=1): model_module = importlib.import_module("plugins.models.{}".format(config["train"]["model"]["name"])) crop_module = importlib.import_module("plugins.crop_generators.{}".format(config["train"]["model"]["crop_generator"])) - #config["num_classes"] = len(dset.training_images["Target"].unique()) + if config["train"]["model"]["crop_generator"] == 'online_labels_cropgen': + config["train"]["model"]["params"]["label_smoothing"] = 0 + if "metrics" in config["train"]["model"].keys(): if type(config["train"]["model"]["metrics"]) not in [list, dict]: raise ValueError("Metrics should be a list or dictionary.") diff --git a/plugins/crop_generators/online_labels_cropgen.py b/plugins/crop_generators/online_labels_cropgen.py index 2b8cf38d..c55392e9 100644 --- a/plugins/crop_generators/online_labels_cropgen.py +++ b/plugins/crop_generators/online_labels_cropgen.py @@ -19,9 +19,8 @@ class GeneratorClass(deepprofiler.imaging.cropping.CropGenerator): - def __init__(self, config, dset, mode="Training"): + def __init__(self, config, dset, mode="training"): super(GeneratorClass, self).__init__(config, dset) - #self.datagen = tf.keras.preprocessing.image.ImageDataGenerator() self.directory = config["paths"]["single_cell_set"] self.num_channels = len(config["dataset"]["images"]["channels"]) self.box_size = self.config["dataset"]["locations"]["box_size"] @@ -29,22 +28,12 @@ def __init__(self, config, dset, mode="Training"): self.mode = mode # Load metadata - self.all_cells = pd.read_csv(os.path.join(self.directory, "expanded_sc_metadata_tengenes.csv")) - - ## UNCOMMENT FOR ALPHA SET - #self.all_cells.loc[(self.all_cells.Training_Status == "Unused") & self.all_cells.Metadata_Plate.isin([41756,41757]), "Training_Status_Alpha"] = "Validation" - - ## UNCOMMENT FOR SINGLE CELL BALANCED SET - self.all_cells.loc[self.all_cells.Training_Status == "Training", "Training_Status"] = "XXX" - self.all_cells.loc[self.all_cells.Training_Status == "SingleCellTraining", "Training_Status"] = "Training" - self.all_cells.loc[self.all_cells.Training_Status == "Validation", "Training_Status"] = "YYY" - self.all_cells.loc[self.all_cells.Training_Status == "SingleCellValidation", "Training_Status"] = "Validation" - + self.all_cells = pd.read_csv(config["paths"]["sc_index"]) self.target = config["train"]["partition"]["targets"][0] # Index targets for one-hot encoded labels - #self.split_data = self.all_cells[self.all_cells.Training_Status_TenGenes == self.mode].reset_index(drop=True) - self.split_data = self.all_cells[self.all_cells.Training_Status == self.mode].reset_index(drop=True) + self.split_data = self.all_cells[self.all_cells[self.config["train"]["partition"]["split_field"]].isin( + self.config["train"]["partition"][self.mode])].reset_index(drop=True) self.classes = list(self.split_data[self.target].unique()) self.num_classes = len(self.classes) self.classes.sort() @@ -52,43 +41,39 @@ def __init__(self, config, dset, mode="Training"): # Identify targets and samples self.balanced_sample() - self.expected_steps = (self.samples.shape[0] // self.batch_size) + int(self.samples.shape[0] % self.batch_size > 0) + self.expected_steps = (self.samples.shape[0] // self.batch_size) + \ + int(self.samples.shape[0] % self.batch_size > 0) # Report number of classes globally self.config["num_classes"] = self.num_classes print(" >> Number of classes:", self.num_classes) # Online labels - if self.mode == "Training": + if self.mode == "training": self.out_dir = config["paths"]["results"] + "soft_labels/" os.makedirs(self.out_dir, exist_ok=True) self.init_online_labels() - def start(self, session): pass def balanced_sample(self): # Obtain distribution of single cells per class - counts = self.split_data.groupby("Class_Name").count().reset_index()[["Class_Name", "Key"]] + counts = self.split_data.groupby(self.target).count().reset_index()[[self.target, "Key"]] sample_size = int(counts.Key.median()) - counts = {r.Class_Name: r.Key for k,r in counts.iterrows()} + counts = {r[self.target]: r.Key for k, r in counts.iterrows()} # Sample the same number of cells per class class_samples = [] - for cls in self.split_data.Class_Name.unique(): - class_samples.append(self.split_data[self.split_data.Class_Name == cls].sample(n=sample_size, replace=counts[cls] < sample_size)) + for cls in self.split_data[self.target].unique(): + class_samples.append(self.split_data[self.split_data[self.target] == cls].sample( + n=sample_size, replace=counts[cls] < sample_size)) self.samples = pd.concat(class_samples) # Randomize order - if self.mode == "Training": - self.samples = self.samples.sample(frac=1.0).reset_index() - else: - self.samples = self.samples.sample(frac=0.1).reset_index() - print(self.samples[self.target].value_counts()) - + self.samples = self.samples.sample(frac=1.0).reset_index() - def generate(self, sess, global_step=0): + def generator(self, sess, global_step=0): pointer = 0 while True: x = np.zeros([self.batch_size, self.box_size, self.box_size, self.num_channels]) @@ -99,13 +84,12 @@ def generate(self, sess, global_step=0): pointer = 0 filename = os.path.join(self.directory, self.samples.loc[pointer, "Image_Name"]) im = skimage.io.imread(filename).astype(np.float32) - x[i,:,:,:] = deepprofiler.imaging.cropping.fold_channels(im) + x[i, :, :, :] = deepprofiler.imaging.cropping.fold_channels(im) y.append([self.soft_labels[self.samples.loc[pointer, "index"], :]]) pointer += 1 - yield(x, np.concatenate(y, axis=0)) - + yield x, np.concatenate(y, axis=0) - def generator(self, source="samples"): + def generate(self, source="samples"): pointer = 0 if source == "splits": dataframe = self.split_data @@ -129,11 +113,10 @@ def generator(self, source="samples"): pointer += 1 if len(y) < x.shape[0]: x = x[0:len(y), ...] - yield(x, tf.keras.utils.to_categorical(y, num_classes = self.num_classes)) - + yield x, tf.keras.utils.to_categorical(y, num_classes=self.num_classes) def init_online_labels(self): - LABEL_SMOOTHING = 0.2 + LABEL_SMOOTHING = self.config["train"]["model"]["params"]["online_label_smoothing"] self.soft_labels = np.zeros((self.split_data.shape[0], self.num_classes)) + LABEL_SMOOTHING/self.num_classes print("Soft labels:", self.soft_labels.shape) for k, r in self.split_data.iterrows(): @@ -143,17 +126,18 @@ def init_online_labels(self): sl = pd.DataFrame(data=self.soft_labels) sl.to_csv(self.out_dir + "0000.csv", index=False) - def update_online_labels(self, model, epoch): # Prepare parameters and predictions - LAMBDA = 0.01 + LAMBDA = self.config["train"]["model"]["params"]["online_lambda"] predictions = [] # Get predictions with the model - model.get_layer("augmentation_layer").is_training = False - for batch in self.generator(source = "splits"): + if model.layers[1].name == 'augmentation_layer': + model.get_layer("augmentation_layer").is_training = False + for batch in self.generate(source="splits"): predictions.append(model.predict(batch[0])) - model.get_layer("augmentation_layer").is_training = True + if model.layers[1].name == 'augmentation_layer': + model.get_layer("augmentation_layer").is_training = True # Update soft labels predictions = np.concatenate(predictions, axis=0) @@ -164,7 +148,6 @@ def update_online_labels(self, model, epoch): sl = pd.DataFrame(data=self.soft_labels) sl.to_csv(self.out_dir + "{:04d}.csv".format(epoch+1), index=False) - def stop(self, session): pass diff --git a/tests/files/config/test.json b/tests/files/config/test.json index 3c542ab8..2ac897d6 100644 --- a/tests/files/config/test.json +++ b/tests/files/config/test.json @@ -56,7 +56,9 @@ "label_smoothing": 0.0, "feature_dim": 100, "latent_dim": 100, - "epsilon_std": 1.0 + "epsilon_std": 1.0, + "online_label_smoothing": 0.1, + "online_lambda": 0.01 }, "lr_schedule" : { "epoch": [1,3,5],