In [None]:
import json

import haiku as hk
import jax
import jax.numpy as jnp
import pandas as pd

from multiomics_open_research.bulk_rna_bert.downstream.pretrained import get_pretrained_downstream_model
from multiomics_open_research.bulk_rna_bert.preprocess import preprocess_rna_seq_for_bulkrnabert

In [None]:
# Load the `tcga_5_cohorts` downstream model from the local checkpoint folder.
# The loader hands back the trained parameters, an un-transformed forward
# function, the tokenizer, the downstream config and the MLM (backbone) config.
(
    parameters,
    forward_fn,
    tokenizer,
    config,
    mlm_config,
) = get_pretrained_downstream_model(
    checkpoint_directory="../checkpoints/",
    model_name="tcga_5_cohorts",
)

# Haiku needs the forward function wrapped so it can later be called as
# `forward_fn.apply(params, rng, inputs)`.
forward_fn = hk.transform(forward_fn)

In [None]:
# Read the bulk RNA-seq table, preprocess it according to `mlm_config`,
# batch-tokenize the result and pack the token ids into an int32 JAX array.
rna_seq_df = pd.read_csv("../data/tcga_sample.csv")
rna_seq_array = preprocess_rna_seq_for_bulkrnabert(
    rna_seq_df,
    mlm_config,
)
tokens_ids = tokenizer.batch_tokenize(rna_seq_array)
tokens = jnp.asarray(
    tokens_ids,
    dtype=jnp.int32,
)

In [None]:
# Inference on the first sample only. Slicing with `[:1]` (rather than `[0]`)
# keeps the leading axis, so the model still receives a batch of size one.
# A fixed PRNG seed keeps the run reproducible.
random_key = jax.random.PRNGKey(0)
first_sample_tokens = tokens[:1]
outs = forward_fn.apply(parameters, random_key, first_sample_tokens)

In [None]:
# Turn the model output into a human-readable cancer-type label.
# NOTE(review): the layout of 5_cohorts_labels_mapping.json is not visible here.
# If it is a JSON list, the class index is used directly. If it is a JSON object,
# its keys are necessarily strings, so the index is stringified before lookup.
# TODO confirm which layout the file actually uses.
with open("../data/5_cohorts_labels_mapping.json", "r", encoding="utf-8") as mapping_file:
    label_mapping = json.load(mapping_file)

# The model was run on a batch of one sample (`tokens[:1]`), so the flattened
# argmax over the logits is that sample's predicted class. Convert it to a
# plain Python int: JAX arrays are unhashable and cannot be used as dict keys.
predicted_class_index = int(outs["logits"].argmax())
if isinstance(label_mapping, dict):
    predicted_cancer_type = label_mapping[str(predicted_class_index)]
else:
    predicted_cancer_type = label_mapping[predicted_class_index]
print(f"Cancer type prediction {predicted_cancer_type}")