In [1]:
%load_ext autoreload
%autoreload 2

import tensorflow as tf
import umap
import numpy as np
import matplotlib.pyplot as plt
from utils import load_data, split_data
import os

# Report the TensorFlow and bundled Keras versions so the run can be tied to a specific build.
print(f"tensorflow: {tf.__version__}")
print(f"keras: {tf.keras.__version__}")

tensorflow: 2.0.0-beta0
keras: 2.2.4-tf


In [2]:
%env DATA_DIR ../data/GSE92742_Broad_LINCS

# Root directory of the dataset, read back from the DATA_DIR environment variable
# (set by the %env magic at the top of this cell).
data_dir = os.environ['DATA_DIR']

# Expression matrix file. Exactly one of the two data_fname lines should be active.
# data_fname = 'GSE92742_Broad_LINCS_Level4_ZSPCINF_mlr12k_n1319138x12328.gctx' # Level 4 data
data_fname = 'GSE92742_Broad_LINCS_Level3_INF_mlr12k_n1319138x12328.gctx' # Level 3 data
data_path = os.path.join(data_dir, data_fname)

# Per-sample metadata table that accompanies the expression matrix.
sample_meta_fname = 'GSE92742_Broad_LINCS_inst_info.txt'
sample_meta_path = os.path.join(data_dir, sample_meta_fname)

env: DATA_DIR=../data/GSE92742_Broad_LINCS


In [None]:
# Read in raw data, selecting for cells by treatment and cell line

# Perturbation types to keep.
pert_types = [
    'trt_cp',       # treated with compound
    'ctl_vehicle',  # control for compound treatment (e.g DMSO) 
    'ctl_untrt'     # untreated samples
]

# Cell lines to keep.
cell_ids = [
    'VCAP', # prostate tumor
    'MCF7', # breast tumor
    'PC3',  # prostate tumor
]

# Load Data
# load_data is a project helper from utils; it is given both file paths plus the two
# filters above and returns the sample metadata, the gene labels and the expression data.
sample_meta, gene_labels, data = load_data(data_path, sample_meta_path, pert_types, cell_ids)

# Normalize expression between 0-1 per gene
# TODO: implement this normalization per batch during training
# NOTE(review): this only divides by the maximum along axis 0, so values land in [0, 1]
# only if the data is non-negative, and a column whose maximum is 0 would divide by
# zero -- confirm both against the actual data.
# NOTE(review): the maxima are computed over all samples, before the train/val/test split
# performed in the next cell.
data_normed = data / data.max(0)

print(f"data size: {data.shape}")

In [None]:
# Split data into training, validation, and testing
# NOTE(review): the meaning of the 0.2 argument (presumably a hold-out fraction) is
# defined inside utils.split_data -- confirm there.
train, val, test = split_data(data_normed, sample_meta, 0.2)

# Element 0 of each split is the array whose first axis (number of samples) is reported
# here; element 1 is used later in the notebook as the matching sample metadata.
print(f"training size:   {train[0].shape[0]:,}")
print(f"validation size: {val[0].shape[0]:,}")
print(f"testing size:    {test[0].shape[0]:,}")

In [None]:
from tensorflow.keras import Sequential, layers
from metrics import PearsonsR

def create_AE(hidden_layers, activation='relu', optimizer='adam', out_size=978):
    """Build and compile a fully-connected autoencoder.

    Args:
        hidden_layers: Indexable sequence of unit counts, one per hidden Dense
            layer, ordered from input to output. Needs at least one entry.
        activation: Activation applied by every hidden layer.
        optimizer: Optimizer handed to ``compile``.
        out_size: Number of features; used both as the input shape of the
            first layer and as the width of the final reconstruction layer.

    Returns:
        The compiled ``Sequential`` model. It is compiled with a mean squared
        error loss and reports cosine similarity and the project's
        ``PearsonsR`` metric.
    """
    autoencoder = Sequential()

    # Only the first layer declares the input shape.
    autoencoder.add(
        layers.Dense(hidden_layers[0], activation=activation, input_shape=(out_size,))
    )
    for idx in range(1, len(hidden_layers)):
        autoencoder.add(layers.Dense(hidden_layers[idx], activation=activation))

    # Reconstruction layer: as wide as the input, always ReLU regardless of `activation`.
    autoencoder.add(layers.Dense(out_size, activation='relu'))

    tracked_metrics = [
        tf.keras.metrics.CosineSimilarity(),
        PearsonsR(),  # custom correlation metric
    ]
    autoencoder.compile(
        optimizer=optimizer,
        loss='mean_squared_error',
        metrics=tracked_metrics,
    )
    return autoencoder

def create_tf_dataset(X, y, shuffle=True, repeated=True, batch_size=32, seed=None):
    """Wrap in-memory inputs and targets in a batched, prefetching ``tf.data.Dataset``.

    Args:
        X: Input samples. ``X.shape[0]`` is used as the shuffle buffer size, so
            the whole set is shuffled uniformly.
        y: Targets, sliced along the first axis together with ``X``.
        shuffle: If True, shuffle the samples.
        repeated: If True, repeat the dataset indefinitely; the consumer then
            has to bound each epoch itself (e.g. ``steps_per_epoch`` in ``fit``).
        batch_size: Number of samples per batch.
        seed: Optional seed forwarded to ``Dataset.shuffle`` for a reproducible
            shuffle order. ``None`` keeps TensorFlow's default seeding.

    Returns:
        The resulting ``tf.data.Dataset`` yielding ``(X_batch, y_batch)`` pairs.
    """
    dataset = tf.data.Dataset.from_tensor_slices((X, y))
    # Shuffle BEFORE repeat: repeating first lets the shuffle buffer mix samples
    # across epoch boundaries, so a sample could appear several times (or not at
    # all) within one epoch. Shuffling first gives one full pass per repetition.
    if shuffle:
        dataset = dataset.shuffle(buffer_size=X.shape[0], seed=seed)
    if repeated:
        dataset = dataset.repeat()
    dataset = dataset.batch(batch_size)
    # `prefetch` lets the dataset fetch batches in the background while the model is training.
    dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
    return dataset

In [None]:
# Training pipeline for the autoencoder: the same array is passed as both input and
# target, so the model learns to reconstruct its input.
batch_size = 64
train_dataset = create_tf_dataset(train[0], train[0], batch_size=batch_size)
# Bare expression so the notebook displays the dataset's repr.
train_dataset

In [None]:
# Tensorboard stuff (currently disabled; it pairs with the commented-out `callbacks`
# argument in the model.fit cell)
# import datetime, os
# logdir = os.path.join("logs", datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
# tensorboard_callback = tf.keras.callbacks.TensorBoard(logdir, histogram_freq=1)
# Seed TensorFlow's global random generator before the model is created.
# NOTE(review): train_dataset was built in the previous cell, i.e. before this seed is
# set, so the seed may not govern its shuffle order -- confirm.
tf.random.set_seed(42)
# Hidden layers of 128, 2 and 128 units; the 2-unit layer is the bottleneck. create_AE
# appends the final out_size-wide output layer itself.
model = create_AE([128, 2, 128])
model.summary()

In [None]:
# Train the autoencoder. train_dataset repeats indefinitely, so steps_per_epoch decides
# how many batches make up one epoch (one pass' worth of samples here).
# `shuffle` is deliberately not passed: Keras ignores it for tf.data.Dataset inputs, and
# shuffling is already done inside create_tf_dataset.
# The returned History is kept so the per-epoch losses/metrics (history.history) can be
# inspected later instead of only echoing the object's repr.
history = model.fit(
    train_dataset,
    epochs = 5,
    steps_per_epoch = train[0].shape[0] // batch_size,
    validation_data = (val[0], val[0]),
#     callbacks = [tensorboard_callback]
)

In [None]:
# Evaluate reconstruction quality on the held-out test split (inputs are also the targets).
# The model is compiled with additional metrics, so evaluate() returns the loss followed
# by one value per metric rather than a single number. Print each value under its own
# name instead of labelling the whole list as "loss".
# `test_loss` keeps holding the raw return value of evaluate().
test_loss = model.evaluate(test[0], test[0])
# np.atleast_1d also covers a model compiled without metrics, where evaluate() returns a scalar.
for metric_name, metric_value in zip(model.metrics_names, np.atleast_1d(test_loss)):
    print(f"{metric_name}: {metric_value}")

### Plotting

In [None]:
from utils import plot_embedding2D, plot_embedding3D
# Encoder sub-model: maps the autoencoder's input to the output of layers[1]. For the
# create_AE([128, 2, 128]) model built earlier, layers[1] is the 2-unit bottleneck; the
# index is hard-coded and must be updated if the architecture changes.
encoder = tf.keras.Model(inputs=model.layers[0].input, outputs=model.layers[1].output)
# Bottleneck activations for every test sample.
h = encoder.predict(test[0])

In [None]:
# Plot the bottleneck embedding of the test samples, passing each sample's cell_id from
# the test metadata as its label (how labels are rendered is up to utils.plot_embedding2D).
plot_embedding2D(h, test[1].cell_id.values, alpha=0.2)

In [None]:
import pandas as pd
# Cell-line annotation table: tab-separated, with '-666' treated as a missing value.
cell_info_fname = 'GSE92742_Broad_LINCS_cell_info.txt'
cell_info_path = os.path.join(data_dir, cell_info_fname)
cell_meta = pd.read_csv(cell_info_path, sep='\t', na_values='-666')

In [None]:
# Count samples per cell line, then chart only the lines with more than 5000 samples.
sample_counts = sample_meta['cell_id'].value_counts()
sample_counts.loc[sample_counts.gt(5000)].plot.bar();

In [None]:
# Display the annotation rows whose base cell line is PC3.
cell_meta.loc[cell_meta['base_cell_id'].eq('PC3')]