In [1]:
import pandas as pd

from biological_fuzzy_logic_networks.DREAM_analysis.utils import (create_bfz, 
                                                                  prepare_cell_line_data, 
                                                                  cl_data_to_input)
import torch
import json

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# SC1 submission to be scored (the winning entry, going by the variable name).
# No index_col is given, so pandas assigns its default integer index.
sc1_winner = pd.read_csv(
    "/Users/adr/Box/CAR_Tcells/Data/Participants/SC1/icx_bxai_sc1.csv"
)

# Reference measurements to score against. Here the first CSV column becomes
# the row index.
# NOTE(review): mean_RMSE pairs rows by index label, so this relies on that
# first column holding the same labels as sc1_winner's default index - confirm.
sc1_truth = pd.read_csv(
    "/Users/adr/Box/CAR_Tcells/Data/DREAMdata/Challenge_data/sc1_test_time_aligned.csv",
    index_col=0,
)

In [3]:
from math import sqrt
def mean_RMSE(ground_truth, predictions, markers=["p.ERK", "p.Akt.Ser473.", "p.S6", "p.HER2", "p.PLCg2"]):
    """Return the mean of the per-(marker, cell line) RMSE values.

    For every marker and every cell line found in ``ground_truth``, the
    squared difference between the ground-truth and predicted marker columns
    is summed, divided by the number of ground-truth rows of that cell line,
    and square-rooted. The returned score is the plain average of all those
    RMSE values.

    Rows are paired by pandas index label (Series subtraction aligns on the
    index), not by position. A label present in only one of the two frames
    produces NaN, which propagates to the result.

    Args:
        ground_truth: DataFrame with a "cell_line" column and one column per
            marker.
        predictions: DataFrame with the same columns, sharing index labels
            with ``ground_truth``.
        markers: Names of the columns to score. The default list is only
            read, never modified.

    Returns:
        The averaged RMSE as a float (NaN if any pairing is incomplete or a
        compared value is NaN).

    Raises:
        ZeroDivisionError: If there is nothing to average (no markers or no
            cell lines), or a cell line value matches no rows (e.g. NaN).
        KeyError: If "cell_line" or a requested marker column is missing.
    """
    cell_lines = ground_truth["cell_line"].unique()

    # The cell-line selection does not depend on the marker, so slice both
    # frames once per cell line instead of once per (marker, cell line).
    # This holds one extra copy of the selected rows while the function runs.
    per_cell_line = [
        (
            ground_truth.loc[ground_truth["cell_line"] == cl],
            predictions.loc[predictions["cell_line"] == cl],
        )
        for cl in cell_lines
    ]

    rmse_values = []
    for m in markers:
        for truth_cl, pred_cl in per_cell_line:
            squared_error = (truth_cl[m] - pred_cl[m]) ** 2
            # Vectorised sum instead of the built-in sum(), which iterates the
            # Series element by element. skipna=False keeps NaN propagating
            # exactly as a plain Python sum would.
            total = float(squared_error.sum(skipna=False))
            # float / int keeps the ZeroDivisionError for an empty selection.
            rmse_values.append(sqrt(total / len(truth_cl)))
    return sum(rmse_values) / len(rmse_values)

In [4]:
# Display the loaded submission frame as this cell's output for a visual check.
sc1_winner

Unnamed: 0,glob_cellID,cell_line,treatment,time,cellID,fileID,p.ERK,p.Akt.Ser473.,p.S6,p.HER2,p.PLCg2
0,1,AU565,EGF,0.0,1,59,3.901440,4.186802,5.402692,4.357932,2.280056
1,2,AU565,EGF,0.0,1,122,4.050485,4.431242,7.042006,4.999486,2.807537
2,3,AU565,EGF,0.0,2,59,4.171771,4.566755,6.040624,4.716986,2.456247
3,4,AU565,EGF,0.0,2,122,3.412021,3.618416,5.450850,4.169167,1.795671
4,5,AU565,EGF,0.0,3,59,2.383576,2.730812,3.964737,3.504084,1.459668
...,...,...,...,...,...,...,...,...,...,...,...
2383053,2383054,MDAMB436,iPKC,40.0,3563,2108,4.103123,4.384158,7.069998,4.702792,2.854447
2383054,2383055,MDAMB436,iPKC,40.0,3564,2108,3.982192,4.499595,7.004058,4.785437,3.004326
2383055,2383056,MDAMB436,iPKC,40.0,3565,2108,3.698230,4.565942,7.264115,4.523724,2.598044
2383056,2383057,MDAMB436,iPKC,40.0,3566,2108,2.570859,3.822638,4.720550,2.993271,1.501598


In [5]:
# Display the loaded ground-truth frame as this cell's output for a visual check.
sc1_truth

Unnamed: 0,glob_cellID,cell_line,treatment,time,cellID,fileID,p.ERK,p.Akt.Ser473.,p.S6,p.HER2,p.PLCg2
0,1,AU565,EGF,0.0,1,59,2.161190,3.205339,6.83042,4.626783,1.995542
1,2,AU565,EGF,0.0,1,122,3.665097,3.893094,7.65064,4.713031,2.738222
2,3,AU565,EGF,0.0,2,59,4.326529,4.902871,7.67852,5.207019,1.631741
3,4,AU565,EGF,0.0,2,122,2.391616,3.246099,5.90715,3.414914,2.404445
4,5,AU565,EGF,0.0,3,59,1.557082,1.271263,2.46928,4.383723,0.986227
...,...,...,...,...,...,...,...,...,...,...,...
2383053,2383054,MDAMB436,iPKC,40.0,3563,2108,4.099454,4.580278,4.50455,3.499509,2.347711
2383054,2383055,MDAMB436,iPKC,40.0,3564,2108,3.680623,4.352142,6.43669,3.533757,2.389125
2383055,2383056,MDAMB436,iPKC,40.0,3565,2108,3.319599,5.045283,4.85734,3.331996,1.906681
2383056,2383057,MDAMB436,iPKC,40.0,3566,2108,2.152803,4.370816,2.48743,2.313594,1.386166


In [6]:
# Score the submission over all rows using mean_RMSE's default marker list.
# NOTE: sort_values() keeps every row's original index label, and mean_RMSE
# subtracts pandas Series, which pairs values by index label. Rows are
# therefore matched by index label, not by this sort order.
mean_RMSE(sc1_truth.sort_values("glob_cellID"), sc1_winner.sort_values("glob_cellID"))

0.8551523556310505

In [7]:
# Same score, restricted to the rows whose "time" equals 9 and to the single
# marker "p.Akt.Ser473.".
mean_RMSE(
    sc1_truth.loc[sc1_truth["time"] == 9.0],
    sc1_winner.loc[sc1_winner["time"] == 9.0],
    markers=["p.Akt.Ser473."],
)

0.8436160585600524

In [8]:
# torch.load deserialises with pickle, so only use it on checkpoint files
# from a trusted source.
ckpt = torch.load("/Users/adr/Box/CAR_Tcells/Model/Test/MEK_FAK_ERK/9/model.pt")

# JSON configuration stored next to the checkpoint; only its "network_class"
# entry is read in this cell.
with open("/Users/adr/Box/CAR_Tcells/Model/Test/MEK_FAK_ERK/9/9_config.json") as f:
    config = json.loads(f.read())

# Build the network from the SIF file with the configured network class, then
# pass it the checkpoint's "model_state_dict" entry.
model = create_bfz("/Users/adr/Box/CAR_Tcells/Data/DREAMdata/MEK_FAK_ERK.sif", config["network_class"])
# NOTE(review): the return value is not assigned - confirm that
# load_from_checkpoint updates `model` in place rather than returning a new one.
model.load_from_checkpoint(ckpt["model_state_dict"])

In [None]:
# Cell lines to evaluate on, and one CSV path per cell line.
test_cell_lines = ["AU565", "MDAMB436", "EFM19", "HCC2218", "LY2", "MACLS2"]
test_file_paths = [
    f"/Users/adr/Box/CAR_Tcells/Data/DREAMdata/{cell_line_name}.csv"
    for cell_line_name in test_cell_lines
]

# Load the listed files for time point 9.
cl_data = prepare_cell_line_data(data_file=test_file_paths, time_point=9)

# No treatments or cell lines are specified for either split (all None); the
# helper returns a train part and a valid part for each quantity regardless.
(
    train_data,
    valid_data,
    train_inhibitors,
    valid_inhibitors,
    train_input,
    valid_input,
    train,
    valid,
    scaler,
) = cl_data_to_input(
    data=cl_data,
    model=model,
    train_treatments=None,
    valid_treatments=None,
    train_cell_lines=None,
    valid_cell_lines=None,
    inhibition_value=1,
    minmaxscale=True,
    add_root_values=False,
    input_value=None,
    root_nodes=None,
)


def _concat_train_valid(train_part, valid_part):
    """Concatenate the train and valid tensors of every node, train first."""
    return {
        node: torch.cat((train_part[node], valid_part[node]))
        for node in train_part.keys()
    }


# Merge the two splits back together so the whole set can be used at once.
all_test_data = _concat_train_valid(train_data, valid_data)
all_test_input = _concat_train_valid(train_input, valid_input)
all_test_inhibitors = _concat_train_valid(train_inhibitors, valid_inhibitors)

# Stack the train and valid frames in the same order, keep only the listed
# columns and renumber the rows from zero.
all_test = (
    pd.concat([train, valid])[["MEK12", "FAK", "ERK12", "treatment"]]
).reset_index(drop=True)

['AU565' 'MDAMB436' 'EFM19' 'HCC2218' 'LY2' 'MACLS2']


In [None]:
# Calls cl_data_to_input with the same keyword arguments as the unpacked call
# in the data-preparation cell, but keeps the whole return value in `a`
# without unpacking it.
# NOTE(review): looks like a leftover inspection cell - confirm whether `a` is
# still needed.
a = cl_data_to_input(
    data=cl_data,
    model=model,
    train_treatments=None,
    valid_treatments=None,
    train_cell_lines=None,
    valid_cell_lines=None,
    inhibition_value=1,
    minmaxscale=True,
    add_root_values=False,
    input_value=None,
    root_nodes=None
)

In [None]:
# Run the model on the merged test data with gradient tracking disabled.
with torch.no_grad():
    # Hand the merged per-node tensors to the model as its ground truth.
    model.set_network_ground_truth(all_test_data)
    # Update the network starting from the model's root nodes, passing the
    # merged inhibitor tensors.
    # NOTE(review): presumably the results are stored on `model` itself, since
    # no return value is captured - confirm.
    model.sequential_update(model.root_nodes, all_test_inhibitors)