In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import h5py
from collections import Counter
import numpy as np
from numpy import loadtxt
import tensorflow as tf
from sklearn.model_selection import train_test_split
from sklearn.metrics import balanced_accuracy_score

In [2]:
# f = h5py.File('mouse1sample1.hdf5','r')
# p_scores = pd.read_csv('merfish_M1S1_filtered_periph_scores.csv')
# p_scores = p_scores.loc[p_scores['annotation'] != 'unannotated']

In [3]:
# time intensive step: get list of annotated cell_ids and annotation types 
# annotated_cells = []
# annotations = []
# for cell_id in f['cells']:
#     cell = f['cells'][cell_id]
#     ann = dict(cell.attrs)['annotation']
#     if ann != 'unannotated' and cell_id in list(p_scores['cell_id']):
#         annotated_cells.append(cell_id)
#         annotations.append(ann)
        
# annotations_count = Counter(annotations)
# print("Number of Annotated Samples:", len(annotated_cells))
# print("Num of Cell Types:", len(annotations_count))
# print("Max Number of Samples per Cell Type:", max(annotations_count.values()))
# print("Min Number of Samples per Cell Type:", min(annotations_count.values()))
# print("Average Number of Samples per Cell Type:", sum(annotations_count.values())/93)

In [4]:
# periphery data: shape (len(annotated_cells), num_genes)
# num_genes = (p_scores['gene']).unique().shape[0]
# cell_indices = {k: v for v, k in enumerate(list(p_scores['cell_id'].unique()))}
# gene_indices = {k: v for v, k in enumerate(list(p_scores['gene'].unique()))}

# periphery_data = np.zeros((len(annotated_cells), num_genes))
# for cell_id in annotated_cells:
#     subset = p_scores.loc[p_scores['cell_id'] == cell_id]
#     for index, s in subset.iterrows():
#         cell_index = cell_indices[cell_id]
#         gene_index = gene_indices[s['gene']]
#         periphery_data[cell_index][gene_index] = s['periphery_score']

# np.save('periphery_data', periphery_data)

periphery_data = np.load('periphery_data.npy')

In [5]:
#image data: shape (len(annotated_cells), 228, 432))
# image_data = np.zeros(shape=(len(annotated_cells), 288, 432)).astype('uint8')
# fig = plt.figure()
# plt.gray()
# idx = 0
# for cell_id in annotated_cells:
#     cell = f['cells'][cell_id]
#     keys = list(cell['boundaries'].keys())
#     midpoint = keys[int(len(keys)/2)]
#     boundary = cell['boundaries'][midpoint]
#     xs = boundary[:,0]
#     ys = boundary[:,1]
#     plt.plot(xs,ys)
#     plt.axis('equal')
#     plt.xticks([])
#     plt.yticks([])
#     fig.canvas.draw()
#     fig_array = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
#     fig_array = fig_array.reshape(fig.canvas.get_width_height()[::-1] + (3,))
#     image_data[idx] = fig_array[:,:,0].astype('uint8')
#     idx += 1
#     plt.clf()
#     print(f"\r{idx}", end="")
  
# print(image_data.shape)
# np.save('image_data', image_data)

image_data = np.load('image_data.npy')
image_data = np.repeat(image_data[..., np.newaxis], 3, -1)

MemoryError: Unable to allocate 6.02 GiB for an array with shape (17312, 288, 432, 3) and data type uint8

In [None]:
# one hot encoded label vectors
# annotation_indices = {k: v for v, k in enumerate(list(annotations_count.keys()))}
# labels = np.zeros((len(annotated_cells), len(annotation_indices)))
# for index in range(0, len(annotated_cells)):
#     cell = f['cells'][annotated_cells[index]]
#     ann = dict(cell.attrs)['annotation']
#     arr_index = annotation_indices[ann]
#     labels[index][arr_index] = 1
# print(labels.shape)
# np.save('labels', labels)

labels = np.load('labels.npy')

In [None]:
indices = [x for x in range(labels.shape[0])]
train_indices, test = train_test_split(indices, test_size=0.3)
test_indices, val_indices = train_test_split(test, test_size=0.5)

image_train, periphery_train, labels_train = image_data[train_indices], periphery_data[train_indices], labels[train_indices]
image_test, periphery_test, labels_test = image_data[test_indices], periphery_data[test_indices], labels[test_indices]
image_val, periphery_val, labels_val = image_data[val_indices], periphery_data[val_indices], labels[val_indices]

print(image_train.shape, periphery_train.shape, labels_train.shape)
print(image_test.shape, periphery_test.shape, labels_test.shape)
print(image_val.shape, periphery_val.shape, labels_val.shape)

In [None]:
# initializing individual models + late fusion model
periphery_model = tf.keras.models.Sequential([tf.keras.layers.Input(192),
                                    tf.keras.layers.Dense(100, activation=tf.nn.relu), 
                                    tf.keras.layers.Dense(50, activation=tf.nn.relu), 
                                    tf.keras.layers.Dense(25, activation=tf.nn.relu), 
                                    tf.keras.layers.Dense(80, activation=tf.nn.softmax)])

image_model = tf.keras.models.Sequential([tf.keras.layers.Conv2D(32,(3,3),padding='same', activation='relu',input_shape=(288, 432, 3)),
                                    tf.keras.layers.MaxPool2D((2,2)), 
                                    tf.keras.layers.Flatten(), 
                                    tf.keras.layers.Dense(64, activation=tf.nn.relu), 
                                    tf.keras.layers.Dense(80, activation=tf.nn.softmax)])

model_concat = tf.keras.layers.concatenate([periphery_model.output, image_model.output], axis=-1)
model_concat = tf.keras.layers.Dense(80, activation='softmax')(model_concat)
model = tf.keras.models.Model(inputs=[periphery_model.input, image_model.input], outputs=model_concat)
model.summary()

In [None]:
# compiling late fusion model with exponential decay lr
lr= tf.keras.optimizers.schedules.ExponentialDecay(
    0.0001,
    decay_steps=100000,
    decay_rate=0.95,
    staircase=True)

# early stopping to prevent overfitting 
early_stopping = tf.keras.callbacks.EarlyStopping(monitor="val_loss", patience=3, restore_best_weights=True)

model.compile(loss='categorical_crossentropy',optimizer=tf.keras.optimizers.Adam(learning_rate=lr), metrics=['accuracy'])

In [None]:
# late fusion model training
model.fit(x=[periphery_train, image_train], y=labels_train, batch_size=128, epochs=1, validation_data=([periphery_val, image_val], labels_val))

In [None]:
# evaluation 
model.evaluate([periphery_test, image_test], y=labels_test)

In [None]:
# balanced accuracy evaluation
