Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion deepprofiler/learning/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@
def learn_model(config, dset, epoch=1, seed=None, verbose=1):
model_module = importlib.import_module("plugins.models.{}".format(config["train"]["model"]["name"]))
crop_module = importlib.import_module("plugins.crop_generators.{}".format(config["train"]["model"]["crop_generator"]))
#config["num_classes"] = len(dset.training_images["Target"].unique())
if config["train"]["model"]["crop_generator"] == 'online_labels_cropgen':
config["train"]["model"]["params"]["label_smoothing"] = 0

if "metrics" in config["train"]["model"].keys():
if type(config["train"]["model"]["metrics"]) not in [list, dict]:
raise ValueError("Metrics should be a list or dictionary.")
Expand Down
67 changes: 25 additions & 42 deletions plugins/crop_generators/online_labels_cropgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,76 +19,61 @@

class GeneratorClass(deepprofiler.imaging.cropping.CropGenerator):

def __init__(self, config, dset, mode="Training"):
def __init__(self, config, dset, mode="training"):
super(GeneratorClass, self).__init__(config, dset)
#self.datagen = tf.keras.preprocessing.image.ImageDataGenerator()
self.directory = config["paths"]["single_cell_set"]
self.num_channels = len(config["dataset"]["images"]["channels"])
self.box_size = self.config["dataset"]["locations"]["box_size"]
self.batch_size = self.config["train"]["model"]["params"]["batch_size"]
self.mode = mode

# Load metadata
self.all_cells = pd.read_csv(os.path.join(self.directory, "expanded_sc_metadata_tengenes.csv"))

## UNCOMMENT FOR ALPHA SET
#self.all_cells.loc[(self.all_cells.Training_Status == "Unused") & self.all_cells.Metadata_Plate.isin([41756,41757]), "Training_Status_Alpha"] = "Validation"

## UNCOMMENT FOR SINGLE CELL BALANCED SET
self.all_cells.loc[self.all_cells.Training_Status == "Training", "Training_Status"] = "XXX"
self.all_cells.loc[self.all_cells.Training_Status == "SingleCellTraining", "Training_Status"] = "Training"
self.all_cells.loc[self.all_cells.Training_Status == "Validation", "Training_Status"] = "YYY"
self.all_cells.loc[self.all_cells.Training_Status == "SingleCellValidation", "Training_Status"] = "Validation"

self.all_cells = pd.read_csv(config["paths"]["sc_index"])
self.target = config["train"]["partition"]["targets"][0]

# Index targets for one-hot encoded labels
#self.split_data = self.all_cells[self.all_cells.Training_Status_TenGenes == self.mode].reset_index(drop=True)
self.split_data = self.all_cells[self.all_cells.Training_Status == self.mode].reset_index(drop=True)
self.split_data = self.all_cells[self.all_cells[self.config["train"]["partition"]["split_field"]].isin(
self.config["train"]["partition"][self.mode])].reset_index(drop=True)
self.classes = list(self.split_data[self.target].unique())
self.num_classes = len(self.classes)
self.classes.sort()
self.classes = {self.classes[i]: i for i in range(self.num_classes)}

# Identify targets and samples
self.balanced_sample()
self.expected_steps = (self.samples.shape[0] // self.batch_size) + int(self.samples.shape[0] % self.batch_size > 0)
self.expected_steps = (self.samples.shape[0] // self.batch_size) + \
int(self.samples.shape[0] % self.batch_size > 0)

# Report number of classes globally
self.config["num_classes"] = self.num_classes
print(" >> Number of classes:", self.num_classes)

# Online labels
if self.mode == "Training":
if self.mode == "training":
self.out_dir = config["paths"]["results"] + "soft_labels/"
os.makedirs(self.out_dir, exist_ok=True)
self.init_online_labels()


def start(self, session):
pass

def balanced_sample(self):
# Obtain distribution of single cells per class
counts = self.split_data.groupby("Class_Name").count().reset_index()[["Class_Name", "Key"]]
counts = self.split_data.groupby(self.target).count().reset_index()[[self.target, "Key"]]
sample_size = int(counts.Key.median())
counts = {r.Class_Name: r.Key for k,r in counts.iterrows()}
counts = {r[self.target]: r.Key for k, r in counts.iterrows()}

# Sample the same number of cells per class
class_samples = []
for cls in self.split_data.Class_Name.unique():
class_samples.append(self.split_data[self.split_data.Class_Name == cls].sample(n=sample_size, replace=counts[cls] < sample_size))
for cls in self.split_data[self.target].unique():
class_samples.append(self.split_data[self.split_data[self.target] == cls].sample(
n=sample_size, replace=counts[cls] < sample_size))
self.samples = pd.concat(class_samples)

# Randomize order
if self.mode == "Training":
self.samples = self.samples.sample(frac=1.0).reset_index()
else:
self.samples = self.samples.sample(frac=0.1).reset_index()
print(self.samples[self.target].value_counts())

self.samples = self.samples.sample(frac=1.0).reset_index()

def generate(self, sess, global_step=0):
def generator(self, sess, global_step=0):
pointer = 0
while True:
x = np.zeros([self.batch_size, self.box_size, self.box_size, self.num_channels])
Expand All @@ -99,13 +84,12 @@ def generate(self, sess, global_step=0):
pointer = 0
filename = os.path.join(self.directory, self.samples.loc[pointer, "Image_Name"])
im = skimage.io.imread(filename).astype(np.float32)
x[i,:,:,:] = deepprofiler.imaging.cropping.fold_channels(im)
x[i, :, :, :] = deepprofiler.imaging.cropping.fold_channels(im)
y.append([self.soft_labels[self.samples.loc[pointer, "index"], :]])
pointer += 1
yield(x, np.concatenate(y, axis=0))

yield x, np.concatenate(y, axis=0)

def generator(self, source="samples"):
def generate(self, source="samples"):
pointer = 0
if source == "splits":
dataframe = self.split_data
Expand All @@ -129,11 +113,10 @@ def generator(self, source="samples"):
pointer += 1
if len(y) < x.shape[0]:
x = x[0:len(y), ...]
yield(x, tf.keras.utils.to_categorical(y, num_classes = self.num_classes))

yield x, tf.keras.utils.to_categorical(y, num_classes=self.num_classes)

def init_online_labels(self):
LABEL_SMOOTHING = 0.2
LABEL_SMOOTHING = self.config["train"]["model"]["params"]["online_label_smoothing"]
self.soft_labels = np.zeros((self.split_data.shape[0], self.num_classes)) + LABEL_SMOOTHING/self.num_classes
print("Soft labels:", self.soft_labels.shape)
for k, r in self.split_data.iterrows():
Expand All @@ -143,17 +126,18 @@ def init_online_labels(self):
sl = pd.DataFrame(data=self.soft_labels)
sl.to_csv(self.out_dir + "0000.csv", index=False)


def update_online_labels(self, model, epoch):
# Prepare parameters and predictions
LAMBDA = 0.01
LAMBDA = self.config["train"]["model"]["params"]["online_lambda"]
predictions = []

# Get predictions with the model
model.get_layer("augmentation_layer").is_training = False
for batch in self.generator(source = "splits"):
if model.layers[1].name == 'augmentation_layer':
model.get_layer("augmentation_layer").is_training = False
for batch in self.generate(source="splits"):
predictions.append(model.predict(batch[0]))
model.get_layer("augmentation_layer").is_training = True
if model.layers[1].name == 'augmentation_layer':
model.get_layer("augmentation_layer").is_training = True

# Update soft labels
predictions = np.concatenate(predictions, axis=0)
Expand All @@ -164,7 +148,6 @@ def update_online_labels(self, model, epoch):
sl = pd.DataFrame(data=self.soft_labels)
sl.to_csv(self.out_dir + "{:04d}.csv".format(epoch+1), index=False)


def stop(self, session):
pass

Expand Down
4 changes: 3 additions & 1 deletion tests/files/config/test.json
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,9 @@
"label_smoothing": 0.0,
"feature_dim": 100,
"latent_dim": 100,
"epsilon_std": 1.0
"epsilon_std": 1.0,
"online_label_smoothing": 0.1,
"online_lambda": 0.01
},
"lr_schedule" : {
"epoch": [1,3,5],
Expand Down