In [None]:
from __future__ import annotations

import numpy as np
import pandas as pd
import tensorflow as tf
import typing as t

from pathlib import Path
from tensorflow import keras

from cdrpy.models import deepcdr
from cdrpy.data.datasets import Dataset, get_predictions
from cdrpy.data.preprocess import normalize_responses
from cdrpy.splits import load_split
from cdrpy.metrics import tf_metrics

In [None]:
# Root folder for all model inputs. The path is relative, so it resolves
# against the kernel's current working directory.
input_dir = Path("../../data/inputs/GDSCv2DepMap")

# Cell-line feature tables (expression and somatic mutations) and the drug
# feature file.
exp_path = input_dir / "DeepCDR/FeatureCellToExpression717CGCGenesTPMLogp1.csv"
mut_path = input_dir / "DeepCDR/FeatureCellToSomaticMutationsPositionEncoded716CGCGenesAll.csv"
# NOTE(review): the ".pickle" suffix suggests this file is unpickled inside
# `load_drug_features`; only point this at trusted files.
mol_path = input_dir / "DrugToConvMolFeatures.pickle"

# Response labels and the directory holding the train/val/test split files.
label_path = input_dir / "LabelsLogIC50.csv"
split_path = input_dir / "splits/mixed"

# `load_cell_features` returns three values; the third is not used here.
cell_exp_enc, cell_mut_enc, _ = deepcdr.load_cell_features(exp_path, mut_path)
# `load_drug_features` returns two encoders: drug features and adjacency.
drug_feat_enc, drug_adj_enc = deepcdr.load_drug_features(mol_path)

In [None]:
# Build the full dataset from the label file, attaching the cell and drug
# encoders loaded earlier.
dataset = Dataset.from_csv(
    label_path,
    name="GDSCv2DepMap",
    cell_encoders=[cell_exp_enc, cell_mut_enc],
    drug_encoders=[drug_feat_enc, drug_adj_enc],
)

# Load one predefined split (second argument = 1) from the split directory.
split = load_split(split_path, 1)

# Carve out the train / val / test subsets by id and normalize their
# responses in a single step using the "global" method.
train_ds, val_ds, test_ds = normalize_responses(
    dataset.select(split.train_ids, name="train"),
    dataset.select(split.val_ids, name="val"),
    dataset.select(split.test_ids, name="test"),
    norm_method="global",
)

In [None]:
# Input sizes for the model, read from the encoder shapes.
exp_dim = cell_exp_enc.shape[-1]
# NOTE(review): this reads axis 1 while the other two read the last axis;
# presumably the mutation encoder has extra trailing axes — TODO confirm.
mut_dim = cell_mut_enc.shape[1]
drug_dim = drug_feat_enc.shape[-1]

# Build the DeepCDR model from the three input sizes.
model = deepcdr.create_model(exp_dim, mut_dim, drug_dim)

In [None]:
# Feature-wise standardization layer for the cell expression features.
cell_norm = keras.layers.Normalization()

# Fit the normalization statistics on the TRAINING split only. Adapting on
# validation (or test) examples would leak held-out information into the
# preprocessing step.
# `X` holds one encoded expression row per entry in `train_ds.cell_ids`.
X = np.array(cell_exp_enc.encode(train_ds.cell_ids))
cell_norm.adapt(X)

In [None]:
# Column-wise mean of the encoded expression matrix `X`, before it is passed
# through the normalization layer.
X.mean(axis=0)

In [None]:
# Sanity check: column-wise standard deviation of the normalized features,
# rounded to two decimals for display.
np.round(np.std(cell_norm(X).numpy(), axis=0), 2)

In [None]:
# Size of the small training subset used for a quick trial fit, and the
# batch size used to feed it to the model.
MINI_SAMPLE_SIZE = 1000
MINI_BATCH_SIZE = 32

# Draw a small subset of the training data and turn it into a batched
# pipeline. The shuffle buffer matches the subset size: a buffer smaller
# than the dataset only shuffles within a sliding window, so examples never
# move far from their original position.
# NOTE(review): `sample` is not seeded here, so the subset may differ between
# runs — confirm whether `Dataset.sample` accepts a seed.
mini_ds = (
    train_ds.sample(MINI_SAMPLE_SIZE)
    .encode_tf()
    .shuffle(MINI_SAMPLE_SIZE)
    .batch(MINI_BATCH_SIZE)
)

# Configure training: Adam optimizer, mean-squared-error loss, and two
# tracked metrics. The "mse" metric reports the same quantity as the loss;
# `tf_metrics.pearson` is the project's correlation metric.
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=0.001),
    loss="mean_squared_error",
    metrics=["mse", tf_metrics.pearson],
)

In [None]:
# Quick trial run: a single epoch on the small training subset.
# Validation and callbacks are currently disabled; `val_tfds` and `callbacks`
# are not defined in the cells above, so they must be created before these
# arguments are re-enabled.
model.fit(
    mini_ds,
    epochs=1,
    # validation_data=val_tfds,
    # callbacks=callbacks,
)