### Longitudinal predictions on the AML dataset (CloMu)

This notebook needs to be run under the [CloMu repository](https://github.com/elkebir-group/CloMu.git).

#### 1. Imports

In [13]:
import torch
import numpy as np
import pandas as pd
from scipy.special import softmax
from collections import defaultdict, deque

from CloMu import *

#### 2. Load data

In [4]:
model = torch.load('Models/realData/savedModel_AML.pt')

  model = torch.load('Models/realData/savedModel_AML.pt')


In [5]:
model

MutationModel(
  (lin1): Linear(in_features=42, out_features=5, bias=True)
  (lin2): Linear(in_features=5, out_features=22, bias=True)
)

In [6]:
AML_fitness = np.load('results/realData/fitness_AML.npy', allow_pickle=True)

In [7]:
AML_fitness

array([0.01939408, 0.03941291, 0.18423268, 0.01930839, 0.00585   ,
       0.01932359, 0.05464946, 0.0193208 , 0.02791075, 0.01935017,
       0.0193455 , 0.01934996, 0.33878654, 0.02613744, 0.0193467 ,
       0.01937685, 0.01936982, 0.01938877, 0.01932122, 0.01941568,
       0.05207818, 0.01933071], dtype=float32)

In [8]:
AML_mutations = np.load('data/realData/categoryNames.npy', allow_pickle=True)

In [9]:
AML_mutations

array(['ASXL', 'ASXL1', 'DNMT3A', 'EZH2', 'FLT3', 'FLT3-ITD', 'GATA2',
       'IDH1', 'IDH2', 'JAK2', 'KIT', 'KRAS', 'NPM1', 'NRAS', 'PTPN11',
       'RUNX1', 'SF3B1', 'SFB1', 'SRSF2', 'TP53', 'U2AF1', 'WT1',
       'ZZZZZZZZZZ', 'ZZZZZZZZZZZZZZZZ'], dtype='<U16')

In [10]:
AML_fitness_dict = dict(zip(AML_mutations[:-2], AML_fitness))

In [11]:
AML_fitness_dict

{np.str_('ASXL'): np.float32(0.019394081),
 np.str_('ASXL1'): np.float32(0.039412912),
 np.str_('DNMT3A'): np.float32(0.18423268),
 np.str_('EZH2'): np.float32(0.019308392),
 np.str_('FLT3'): np.float32(0.005849996),
 np.str_('FLT3-ITD'): np.float32(0.019323587),
 np.str_('GATA2'): np.float32(0.054649465),
 np.str_('IDH1'): np.float32(0.019320799),
 np.str_('IDH2'): np.float32(0.02791075),
 np.str_('JAK2'): np.float32(0.019350166),
 np.str_('KIT'): np.float32(0.019345503),
 np.str_('KRAS'): np.float32(0.019349964),
 np.str_('NPM1'): np.float32(0.33878654),
 np.str_('NRAS'): np.float32(0.02613744),
 np.str_('PTPN11'): np.float32(0.019346695),
 np.str_('RUNX1'): np.float32(0.019376848),
 np.str_('SF3B1'): np.float32(0.01936982),
 np.str_('SFB1'): np.float32(0.019388774),
 np.str_('SRSF2'): np.float32(0.019321222),
 np.str_('TP53'): np.float32(0.019415682),
 np.str_('U2AF1'): np.float32(0.05207818),
 np.str_('WT1'): np.float32(0.019330708)}

In [None]:
mutation_list = AML_mutations[:-2]
mutation_to_index = {m: i for i, m in enumerate(mutation_list)}

# AML-04-001
edges = np.array([
    ['FLT3-ITD', 'PTPN11'],
    ['FLT3', 'WT1'],
    ['NRAS', 'IDH1'],
    ['Root', 'SF3B1'],
    ['SF3B1', 'SRSF2'],
    ['SRSF2', 'FLT3-ITD'],
    ['SRSF2', 'FLT3'],
    ['SRSF2', 'NRAS'],
], dtype='<U13')
tree_04 = defaultdict(list)
for parent, child in edges:
    tree_04[parent].append(child)

# AML-09-001
edges = np.array([
    ['NPM1', 'FLT3-ITD'],
    ['NPM1', 'FLT3'],
    ['NPM1', 'KRAS'],
    ['Root', 'NPM1']
], dtype='<U13')
tree_09 = defaultdict(list)
for parent, child in edges:
    tree_09[parent].append(child)

# AML-83-001
edges = np.array([
    ['DNMT3A', 'IDH2'],
    ['Root', 'DNMT3A']
], dtype='<U13')
tree_83 = defaultdict(list)
for parent, child in edges:
    tree_83[parent].append(child)


#### 3. Predictions

In [15]:
def get_rank(tree, mutation_to_index):
    genotype_info = {}
    
    queue = deque()
    queue.append(('Root', []))  # (current_node, path_so_far)

    while queue:
        node, path = queue.popleft()
        
        # Clean mutation names (skip Root)
        if node != 'Root':
            path = path + [node]
        
        # Build genotype vector
        vector = np.zeros(22, dtype=int)
        for mut in set(path):  # remove duplicates
            if mut in mutation_to_index:
                idx = mutation_to_index[mut]
                vector[idx] = 1
        
        # Save both vector and path
        genotype_info[node] = {
            'vector': vector,
            'path': path
        }
        
        # Traverse children
        for child in tree.get(node, []):
            queue.append((child, path))

    # 1. First, build a set of existing edges (cleaned mutation names)
    existing_edges = set()
    for parent, child in edges:
        existing_edges.add((parent, child))

    # 2. Initialize storage
    rows = []

    for node, info in genotype_info.items():
        current_path = info['path']
        current_mutations_set = set(current_path)  # For quick lookup
        
        # Model output
        input_tensor = torch.tensor(info['vector'], dtype=torch.float32).unsqueeze(0)
        output = model(input_tensor)[0].squeeze(0).detach().numpy()
        
        for idx, mutation in enumerate(mutation_list):
            # Skip if mutation already present
            if mutation in current_mutations_set:
                continue
            
            # Check if this edge already exists
            parent = current_path[-1] if current_path else 'Root'
            if (parent, mutation) in existing_edges:
                continue
            
            # Build new extended path
            extended_path = current_path + [mutation]
            pathway_str = '->'.join(extended_path)
            
            log_energy = float(output[idx])
            
            rows.append({
                'pathway': "Root->" + pathway_str,
                'log_energy': log_energy
            })

    # 3. Convert to DataFrame
    summary_df = pd.DataFrame(rows)
    
    # 4. Compute probabilities and ranks
    summary_df['probability'] = softmax(summary_df['log_energy'])
    summary_df['rank'] = summary_df['probability'].rank(pct=True)

    return summary_df


In [16]:
rank_04 = get_rank(tree_04, mutation_to_index)
rank_09 = get_rank(tree_09, mutation_to_index)
rank_83 = get_rank(tree_83, mutation_to_index)

In [17]:
rank_04.query(f"pathway == 'Root->SF3B1->SRSF2->NRAS->WT1'")["rank"].astype(float).values[0]

np.float64(0.6936416184971098)

In [18]:
rank_04.query(f"pathway == 'Root->SF3B1->SRSF2->IDH1'")["rank"].astype(float).values[0]

np.float64(0.3872832369942196)

In [19]:
rank_09.query(f"pathway == 'Root->NPM1->FLT3->WT1'")["rank"].astype(float).values[0]

np.float64(0.3333333333333333)

In [20]:
rank_83.query(f"pathway == 'Root->DNMT3A->IDH2->NRAS'")["rank"].astype(float).values[0]

np.float64(0.9672131147540983)