In [2]:
# start coding here

In [3]:
from anndata import read_h5ad, AnnData
import pandas as pd
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import train_test_split, KFold
from sklearn.metrics import mean_squared_error, r2_score

import altair as alt
from altair_saver import save as alt_save

In [4]:
# Run configuration injected by Snakemake: the model name comes from the
# rule's wildcards and the tissue label from the rule's params.
model = snakemake.wildcards["model"]
metmap_tissue = snakemake.params["metmap_tissue"]

In [5]:
# Read the AnnData object stored in the first Snakemake input file.
coexp_path = snakemake.input[0]
coexp_adata = read_h5ad(coexp_path)

In [6]:
# Display the configured `metmap_tissue` value as this cell's output.
metmap_tissue

In [7]:
# Display the observation metadata table (`obs`) of the loaded AnnData.
coexp_adata.obs

In [8]:
def model_with_linear_regression(cell_ontology_id, ct_coexp_adata, coef_df, coef_arr,
                                 n_splits=5, random_state=2445):
    """Cross-validate a `LinearRegression` model for a single cell type.

    The response is ``ct_coexp_adata.obs['met_potential_mean']`` and the
    features are the rows of ``ct_coexp_adata.X``. A shuffled K-fold split is
    run; the coefficients of the fold with the lowest test MSE are kept.

    Parameters
    ----------
    cell_ontology_id
        Label written into the returned metadata and per-fold rows.
    ct_coexp_adata
        AnnData (or view) restricted to one cell type. ``X`` may be sparse
        or dense.
    coef_df : pandas.DataFrame
        Running table of best-fold metadata; a new frame with one extra row
        is returned (the input frame itself is not modified).
    coef_arr : list
        Running list of best-fold coefficient vectors. It is appended to
        IN PLACE and also returned.
    n_splits : int, default 5
        Number of cross-validation folds.
    random_state : int, default 2445
        Seed for the shuffled K-fold split.

    Returns
    -------
    (coef_df, coef_arr, pred_df, fold_df)
        ``pred_df`` holds the held-out ``y_test`` / ``y_pred`` values of all
        folds concatenated in fold order; ``fold_df`` has one row per fold
        with its ``MSE`` and ``R2``.

    Notes
    -----
    The module-level ``metmap_tissue`` is read when building the metadata row.
    """
    # Metastatic potential, the response variable
    y = ct_coexp_adata.obs['met_potential_mean'].values
    # Only sparse matrices expose .toarray(); accept dense X as well.
    X = ct_coexp_adata.X
    X = X.toarray() if hasattr(X, "toarray") else np.asarray(X)

    top_coef_meta = None
    top_coef_arr = None
    best_mse = None
    y_test_parts = []
    y_pred_parts = []
    folds = []

    kf = KFold(n_splits=n_splits, shuffle=True, random_state=random_state)
    for fold_i, (train_index, test_index) in enumerate(kf.split(X, y)):
        X_train, X_test = X[train_index], X[test_index]
        y_train, y_test = y[train_index], y[test_index]

        # Train on the training fold, predict the held-out fold.
        regr = LinearRegression()
        regr.fit(X_train, y_train)
        y_pred = regr.predict(X_test)

        y_test_parts.append(y_test)
        y_pred_parts.append(y_pred)

        mse_val = mean_squared_error(y_test, y_pred)
        r2_val = r2_score(y_test, y_pred)

        folds.append({
            "cell_ontology_id": cell_ontology_id,
            "fold": fold_i,
            "MSE": mse_val,
            "R2": r2_val
        })

        # Remember the fold with the lowest test error.
        if best_mse is None or mse_val < best_mse:
            top_coef_meta = {
                "cell_ontology_id": cell_ontology_id,
                "metmap_tissue": metmap_tissue,
                "MSE": mse_val,
                "R2": r2_val,
                "model": "LinearRegression"
            }
            top_coef_arr = regr.coef_
            best_mse = mse_val

    # DataFrame.append was removed in pandas 2.0, so build the row as a frame
    # and concatenate. An empty accumulator is not passed to concat, because
    # it would only contribute object-dtype columns to the result.
    row_df = pd.DataFrame([top_coef_meta])
    if coef_df.empty:
        merged_columns = list(dict.fromkeys([*coef_df.columns, *row_df.columns]))
        coef_df = row_df.reindex(columns=merged_columns)
    else:
        coef_df = pd.concat([coef_df, row_df], ignore_index=True)
    coef_arr.append(top_coef_arr)

    # The leading empty float array keeps the float64 promotion that the
    # original incremental concatenation produced.
    pred_df = pd.DataFrame({
        "y_test": np.concatenate([np.array([]), *y_test_parts]),
        "y_pred": np.concatenate([np.array([]), *y_pred_parts]),
    })

    fold_df = pd.DataFrame(data=folds)

    return coef_df, coef_arr, pred_df, fold_df
    

In [9]:
def model_with_random_forest_regressor(cell_ontology_id, ct_coexp_adata, coef_df, coef_arr,
                                       n_splits=5, random_state=2445):
    """Cross-validate a `RandomForestRegressor` model for a single cell type.

    The response is ``ct_coexp_adata.obs['met_potential_mean']`` and the
    features are the rows of ``ct_coexp_adata.X``. A shuffled K-fold split is
    run; the feature importances of the fold with the lowest test MSE are kept.

    Parameters
    ----------
    cell_ontology_id
        Label written into the returned metadata and per-fold rows.
    ct_coexp_adata
        AnnData (or view) restricted to one cell type. ``X`` may be sparse
        or dense.
    coef_df : pandas.DataFrame
        Running table of best-fold metadata; a new frame with one extra row
        is returned (the input frame itself is not modified).
    coef_arr : list
        Running list of best-fold feature-importance vectors. It is appended
        to IN PLACE and also returned.
    n_splits : int, default 5
        Number of cross-validation folds.
    random_state : int, default 2445
        Seed for both the shuffled K-fold split and the forest, so repeated
        runs give the same metrics and importances.

    Returns
    -------
    (coef_df, coef_arr, pred_df, fold_df)
        ``pred_df`` holds the held-out ``y_test`` / ``y_pred`` values of all
        folds concatenated in fold order; ``fold_df`` has one row per fold
        with its ``MSE`` and ``R2``.

    Notes
    -----
    The module-level ``metmap_tissue`` is read when building the metadata row.
    """
    # Metastatic potential, the response variable
    y = ct_coexp_adata.obs['met_potential_mean'].values
    # Only sparse matrices expose .toarray(); accept dense X as well.
    X = ct_coexp_adata.X
    X = X.toarray() if hasattr(X, "toarray") else np.asarray(X)

    top_coef_meta = None
    top_coef_arr = None
    best_mse = None
    y_test_parts = []
    y_pred_parts = []
    folds = []

    kf = KFold(n_splits=n_splits, shuffle=True, random_state=random_state)
    for fold_i, (train_index, test_index) in enumerate(kf.split(X, y)):
        X_train, X_test = X[train_index], X[test_index]
        y_train, y_test = y[train_index], y[test_index]

        # Seed the forest: an unseeded forest gives different metrics and
        # importances on every run.
        regr = RandomForestRegressor(random_state=random_state)
        regr.fit(X_train, y_train)
        y_pred = regr.predict(X_test)

        y_test_parts.append(y_test)
        y_pred_parts.append(y_pred)

        mse_val = mean_squared_error(y_test, y_pred)
        # Same value as regr.score(X_test, y_test), but reuses y_pred instead
        # of predicting the test fold a second time.
        r2_val = r2_score(y_test, y_pred)

        folds.append({
            "cell_ontology_id": cell_ontology_id,
            "fold": fold_i,
            "MSE": mse_val,
            "R2": r2_val
        })

        # Remember the fold with the lowest test error.
        if best_mse is None or mse_val < best_mse:
            top_coef_meta = {
                "cell_ontology_id": cell_ontology_id,
                "metmap_tissue": metmap_tissue,
                "MSE": mse_val,
                "R2": r2_val,
                "model": "RandomForestRegressor"
            }
            top_coef_arr = regr.feature_importances_
            best_mse = mse_val

    # DataFrame.append was removed in pandas 2.0, so build the row as a frame
    # and concatenate. An empty accumulator is not passed to concat, because
    # it would only contribute object-dtype columns to the result.
    row_df = pd.DataFrame([top_coef_meta])
    if coef_df.empty:
        merged_columns = list(dict.fromkeys([*coef_df.columns, *row_df.columns]))
        coef_df = row_df.reindex(columns=merged_columns)
    else:
        coef_df = pd.concat([coef_df, row_df], ignore_index=True)
    coef_arr.append(top_coef_arr)

    # The leading empty float array keeps the float64 promotion that the
    # original incremental concatenation produced.
    pred_df = pd.DataFrame({
        "y_test": np.concatenate([np.array([]), *y_test_parts]),
        "y_pred": np.concatenate([np.array([]), *y_pred_parts]),
    })

    fold_df = pd.DataFrame(data=folds)

    return coef_df, coef_arr, pred_df, fold_df
    

In [10]:
# Accumulators filled by the per-cell-type modelling loop:
#   coef_arr             - one coefficient / importance vector per cell type
#   coef_df              - the metadata row that goes with each vector
#   cell_type_by_fold_df - cross-validation metrics of every fold
metric_columns = ["MSE", "R2"]
coef_arr = list()
coef_df = pd.DataFrame(columns=["cell_ontology_id", "metmap_tissue", *metric_columns])
cell_type_by_fold_df = pd.DataFrame(columns=["cell_ontology_id", "fold", *metric_columns])

In [11]:
# Nothing has been evaluated yet: the best-scoring cell type, its prediction
# frame and its MSE / R^2 all start out unset.
best_cell_type_pred = best_cell_type_pred_r2 = None
best_cell_type_pred_mse = best_cell_type_pred_df = None

In [12]:
# Build a regression model for each cell type
cell_ontology_ids = coexp_adata.obs['cell_ontology_id'].unique().tolist()

for cell_type in cell_ontology_ids:
    ct_coexp_adata = coexp_adata[coexp_adata.obs['cell_ontology_id'] == cell_type, :]
    if model == "LinearRegression":
        coef_df, coef_arr, pred_df, fold_df = model_with_linear_regression(cell_type, ct_coexp_adata, coef_df, coef_arr)
    elif model == "RandomForestRegressor":
        coef_df, coef_arr, pred_df, fold_df = model_with_random_forest_regressor(cell_type, ct_coexp_adata, coef_df, coef_arr)
    
    pred_mse = mean_squared_error(pred_df["y_test"].values, pred_df["y_pred"].values)
    if best_cell_type_pred_mse is None or pred_mse < best_cell_type_pred_mse:
        best_cell_type_pred_df = pred_df
        best_cell_type_pred_mse = pred_mse
        best_cell_type_pred_r2 = r2_score(pred_df["y_test"].values, pred_df["y_pred"].values)
        best_cell_type_pred = cell_type
    
    cell_type_by_fold_df = cell_type_by_fold_df.append(fold_df, ignore_index=True)

In [13]:
# Rebind pred_df to the predictions of the cell type with the lowest pooled
# MSE. This stays None if the modelling loop never ran an iteration.
pred_df = best_cell_type_pred_df

In [14]:
# Both axes share one domain covering every actual and predicted value, so
# the two scales are directly comparable.
min_val = pred_df.min().min()
max_val = pred_df.max().max()
shared_domain = [min_val, max_val]

pred_title = {
    "text": f"Predicted vs. actual metastasis potential",
    "subtitle": f"{model}, {metmap_tissue} {best_cell_type_pred}, MSE: {best_cell_type_pred_mse:.2f}, R^2: {best_cell_type_pred_r2:.2f}"
}

pred_plot = (
    alt.Chart(pred_df)
    .mark_point()
    .encode(
        x=alt.X("y_test:Q", scale=alt.Scale(domain=shared_domain)),
        y=alt.Y("y_pred:Q", scale=alt.Scale(domain=shared_domain)),
    )
    .properties(title=pred_title)
)
pred_plot

In [16]:
# Box plot of the per-fold MSE values, one box per cell ontology ID.
fold_title = {
    "text": f"5-fold CV MSE per Cell Type",
    "subtitle": f"{model}, {metmap_tissue}"
}
cell_type_axis = alt.X('cell_ontology_id:N', axis=alt.Axis(title="Cell Ontology ID"))
mse_axis = alt.Y('MSE:Q', axis=alt.Axis(title="Mean Squared Error"))

fold_plot = (
    alt.Chart(cell_type_by_fold_df)
    .mark_boxplot()
    .encode(x=cell_type_axis, y=mse_axis)
    .properties(title=fold_title)
)

fold_plot

In [115]:
# Stack the collected vectors along a new last axis, then transpose, so that
# each entry of coef_arr ends up as one row of coef_X.
# NOTE(review): this row layout assumes every entry is 1-D and of equal
# length — confirm.
coef_X = np.stack(coef_arr, axis=-1).T

In [108]:
# Display the (rows, columns) shape of the metadata frame that is used as
# `obs` when the AnnData object is built.
coef_df.shape

In [69]:
# Bundle the stacked vectors (X) with their metadata rows (obs) and write the
# result to the Snakemake "model" output.
model_output_path = snakemake.output["model"]
coef_adata = AnnData(obs=coef_df, X=coef_X)
coef_adata.write(model_output_path)

In [None]:
# Write each chart to its Snakemake output, prediction scatter first.
for chart, output_key in ((pred_plot, "prediction_plot"), (fold_plot, "mse_plot")):
    alt_save(chart, snakemake.output[output_key])