# Tutorial for phylodynamics model selection
Based on the method developed in Perez M.F. and Gascuel O. "PhyloCNN: Improving tree representation and neural network architecture for deep learning from trees in phylodynamics and diversification studies." https://www.biorxiv.org/content/10.1101/2024.12.13.628187v1

## 1. Introduction & Requirements
This tutorial shows how to train a CNN model that classifies phylogenetic trees of viruses according to three competing epidemiological (phylodynamics) models: Birth-Death (BD), Birth-Death Exposed Infectious (BDEI) and Birth-Death with Superspreaders (BDSS).

<img src="img/Figure_BDModels.png" width="500" height="340"> 

The simulated trees were encoded by describing the neighborhood (e.g., length of outgoing branches) and main measurements (e.g., date, number of descendants) of all nodes and leaves of the phylogeny.

<img src="img/Figure 1 PhyloCNN Encoding.png" width="750" height="500"> 

## 2. Libraries and Data Loading
We first load the required Python libraries, then the phylogenetic trees simulated under each of the 3 models (BD, BDEI, BDSS) together with their respective parameter values (sampled from prior distributions). We reshape the encodings to `(samples, 1000, 18)`: each tree is stored as 1000 rows of 18 features, i.e. up to 999 node rows plus a final row holding the rescaling factor (removed in the next section).
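
As a minimal sketch of what the reshape does, the toy example below uses 4 rows of 3 features per tree instead of 1000 rows of 18 features. The sizes and the variable names `flat` and `encoded` are illustrative assumptions, and the sketch assumes that the values of each tree are laid out contiguously.

```python
import numpy as np

# Toy stand-in for the table read from CSV (assumed sizes, for illustration only):
# 2 trees, 4 rows per tree, 3 features per row.
n_trees, n_rows, n_features = 2, 4, 3
flat = np.arange(n_trees * n_rows * n_features).reshape(n_trees, -1)

# Same call pattern as in the tutorial; -1 lets NumPy infer the number of trees.
encoded = flat.reshape(-1, n_rows, n_features)

print(flat.shape)     # (2, 12)
print(encoded.shape)  # (2, 4, 3)
print(encoded[0, 1])  # second row of the first tree: values 3, 4, 5
```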


In [None]:
import pandas as pd
import tensorflow as tf
import keras
import numpy as np

from keras.models import Sequential, Model
from keras.layers import Activation, Dense
from keras.layers import Conv2D, GlobalAveragePooling2D, BatchNormalization
from keras.layers import Dense, Dropout, Activation, Flatten

from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix

# 1) Load parameters
param_train_BD = pd.read_csv('./parameters_BD.txt', sep='\t')
param_test_BD  = pd.read_csv('./testset/parameters_BD.txt', sep='\t')
param_train_BDEI = pd.read_csv('./parameters_BDEI.txt', sep='\t')
param_test_BDEI  = pd.read_csv('./testset/parameters_BDEI.txt', sep='\t')
param_train_BDSS = pd.read_csv('./parameters_BDSS.txt', sep='\t')
param_test_BDSS  = pd.read_csv('./testset/parameters_BDSS.txt', sep='\t')

# 2) Load tree encodings for BD, BDEI, BDSS models
encoding_BD    = pd.read_csv('./Encoded_trees_BD.csv', sep="\t", header=0, index_col=0).values.reshape(-1,1000,18)
encoding_test_BD = pd.read_csv('./testset/Encoded_trees_BD.csv', sep="\t", header=0, index_col=0).values.reshape(-1,1000,18)
encoding_BDEI  = pd.read_csv('./Encoded_trees_BDEI.csv', sep="\t", header=0, index_col=0).values.reshape(-1,1000,18)
encoding_test_BDEI = pd.read_csv('./testset/Encoded_trees_BDEI.csv', sep="\t", header=0, index_col=0).values.reshape(-1,1000,18)
encoding_BDSS  = pd.read_csv('./Encoded_trees_BDSS.csv', sep="\t", header=0, index_col=0).values.reshape(-1,1000,18)
encoding_test_BDSS = pd.read_csv('./testset/Encoded_trees_BDSS.csv', sep="\t", header=0, index_col=0).values.reshape(-1,1000,18)

## 3. Data Preprocessing
We will process the input to be properly formatted before feeding it to the neural network. This will involve the following steps:

### Removing Unused Rows
The last line of each encoded tree holds the “rescaling factor”, which must be removed from the encoding (it is used only to predict parameters, which this tutorial does not cover).
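
The removal can be sketched on a toy array as follows. The sizes are assumptions chosen for readability (4 lines per tree instead of 1000); only the `np.delete` call mirrors the tutorial code.

```python
import numpy as np

# Toy encoding: 2 trees, 4 lines each (3 node lines + 1 trailing rescaling line), 3 features.
enc = np.arange(24, dtype=float).reshape(2, 4, 3)

# Drop the last line of every tree; axis=1 indexes the lines within a tree.
enc_trimmed = np.delete(enc, -1, axis=1)

print(enc.shape)          # (2, 4, 3)
print(enc_trimmed.shape)  # (2, 3, 3)
```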

### Label Assignment
We create a label array **Y** for the training and test set, with:
- `0` for BD
- `1` for BDEI
- `2` for BDSS
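
A compact way to build such a label array is sketched below. The tree counts in `n_per_model` are hypothetical, and `np.repeat` is used here instead of the list comprehensions of the tutorial code; both give the same result.

```python
import numpy as np

# Hypothetical numbers of trees per model, in the order BD, BDEI, BDSS.
n_per_model = [3, 2, 4]

# Label k is repeated once for every tree simulated under model k.
labels = np.repeat(np.arange(3), n_per_model)
print(labels)  # [0 0 0 1 1 2 2 2 2]

# One-hot version of the same labels (one row per tree, one column per model).
one_hot = np.eye(3)[labels]
print(one_hot.shape)  # (9, 3)
```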

### Adding Sampling Probability
We append one extra feature column holding the tree's `sampling_proba` value, repeated for every row of that tree.
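
The following sketch shows the idea on a toy array, assuming 3 rows and 4 features per tree instead of 999 rows and 18 features. The probability values are made up for illustration.

```python
import numpy as np

# Toy encoding: 2 trees, 3 rows each, 4 features per row.
enc = np.zeros((2, 3, 4))
sampling_proba = np.array([0.25, 0.6])  # one hypothetical value per tree

# Repeat each tree's value for all of its rows and add a trailing feature axis.
extra = np.repeat(sampling_proba, enc.shape[1]).reshape(-1, enc.shape[1], 1)
enc_with_proba = np.concatenate((enc, extra), axis=2)

print(enc_with_proba.shape)      # (2, 3, 5)
print(enc_with_proba[1, :, -1])  # the value 0.6 repeated on all 3 rows of the second tree
```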

In [None]:
# Remove the rescaling factor (last row of each encoded tree; axis=1 indexes the rows)
encoding_BD=np.delete(encoding_BD, -1, axis=1)
encoding_test_BD=np.delete(encoding_test_BD, -1, axis=1)
encoding_BDEI=np.delete(encoding_BDEI, -1, axis=1)
encoding_test_BDEI=np.delete(encoding_test_BDEI, -1, axis=1)
encoding_BDSS=np.delete(encoding_BDSS, -1, axis=1)
encoding_test_BDSS=np.delete(encoding_test_BDSS, -1, axis=1)

#Add labels for each simulation (a different label for each model)
Y = [0 for i in range(len(encoding_BD))]
Y.extend([1 for i in range(len(encoding_BDEI))])
Y.extend([2 for i in range(len(encoding_BDSS))])
Y = np.array(Y)

Y_test = [0 for i in range(len(encoding_test_BD))]
Y_test.extend([1 for i in range(len(encoding_test_BDEI))])
Y_test.extend([2 for i in range(len(encoding_test_BDSS))])
Y_test = np.array(Y_test)

#Now insert an additional column with sampling proba for all nodes

samp_proba_list = np.array(param_train_BD['sampling_proba'])
encoding_BD=np.concatenate((encoding_BD,np.repeat(samp_proba_list,999).reshape(-1,999,1)),axis=2)

samp_proba_list_test = np.array(param_test_BD['sampling_proba'])
encoding_test_BD=np.concatenate((encoding_test_BD,np.repeat(samp_proba_list_test,999).reshape(-1,999,1)),axis=2)

samp_proba_list = np.array(param_train_BDEI['sampling_proba'])
encoding_BDEI=np.concatenate((encoding_BDEI,np.repeat(samp_proba_list,999).reshape(-1,999,1)),axis=2)

samp_proba_list_test = np.array(param_test_BDEI['sampling_proba'])
encoding_test_BDEI=np.concatenate((encoding_test_BDEI,np.repeat(samp_proba_list_test,999).reshape(-1,999,1)),axis=2)

samp_proba_list = np.array(param_train_BDSS['sampling_proba'])
encoding_BDSS=np.concatenate((encoding_BDSS,np.repeat(samp_proba_list,999).reshape(-1,999,1)),axis=2)

samp_proba_list_test = np.array(param_test_BDSS['sampling_proba'])
encoding_test_BDSS=np.concatenate((encoding_test_BDSS,np.repeat(samp_proba_list_test,999).reshape(-1,999,1)),axis=2)

### Padding & Ordering Leaves/Nodes
**Goal**: Give every encoded tree exactly 500 leaf rows and 500 internal-node rows. We:

1. **Identify leaves** (column 3 == 1) and **sort them** by their ages (column 1).  
2. **Identify internal nodes** (column 3 > 1) and also **sort** them by age.  
3. **Duplicate** the last internal node, so that a binary tree with *n* leaves and *n* − 1 internal nodes has *n* rows in both sets.  
4. **Pad** each set (leaves, nodes) to size 500 (with zeros if fewer than 500).  
5. **Stack** leaves and nodes into 2-channel data: `(500, feature_dim, 2)`.
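
The split, sort, pad and stack steps can be sketched on a single toy tree. This is a simplified illustration, not the tutorial's function: the helper name `split_sort_pad`, the target size of 4 and the feature values are assumptions, and the sketch leaves out the duplication of the last internal node that the full function performs.

```python
import numpy as np

def split_sort_pad(tree, size):
    """Return a (size, n_features, 2) array: channel 0 = leaves, channel 1 = internal nodes."""
    leaves = tree[tree[:, 3] == 1]
    leaves = leaves[np.argsort(leaves[:, 1])]
    nodes = tree[tree[:, 3] > 1]
    nodes = nodes[np.argsort(nodes[:, 1])]
    # Zero-pad both sets to the same fixed number of rows.
    leaves = np.pad(leaves, [(0, size - len(leaves)), (0, 0)], mode='constant')
    nodes = np.pad(nodes, [(0, size - len(nodes)), (0, 0)], mode='constant')
    return np.stack((leaves, nodes), axis=2)

# Toy tree: 3 leaves and 2 internal nodes, 4 features per row
# (column 1 = age, column 3 = 1 for leaves and > 1 for internal nodes).
tree = np.array([
    [0.0, 0.9, 0.0, 1.0],
    [0.0, 0.2, 0.0, 3.0],
    [0.0, 0.7, 0.0, 1.0],
    [0.0, 0.5, 0.0, 2.0],
    [0.0, 0.8, 0.0, 1.0],
])
padded = split_sort_pad(tree, size=4)

print(padded.shape)     # (4, 4, 2)
print(padded[:, 1, 0])  # leaf ages: 0.7, 0.8, 0.9, then one zero of padding
print(padded[:, 1, 1])  # node ages: 0.2, 0.5, then two zeros of padding
```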

In [None]:
# This function takes in the tree encodings for both training and testing datasets
# and processes them to have a uniform shape. It also pads the leaves and nodes 
# of the trees to ensure each tree has a fixed number of 500 leaves and nodes.

def encode_pad_0s_rootage(enc, enc_test):
    # Create an empty list to hold padded training encodings
    enc_pad = []
    
    # Iterate over each tree in the training dataset
    for i in range(enc.shape[0]):
        # Separate the leaves (where column 3 has value 1, which indicates leaves)
        leaves = enc[i][enc[i,:,3] == 1]
        # Sort leaves by their age (assumed to be in column 1)
        leaves = leaves[np.argsort(leaves[:, 1])]
        # Pad the leaves array with 0s until it has a maximum size of 500 leaves
        leaves = np.pad(leaves, [(0, (500 - leaves.shape[0])), (0, 0)], mode='constant')

        # Separate the nodes (where column 3 is greater than 1, indicating internal nodes)
        nodes = enc[i][enc[i,:,3] > 1]
        # Sort nodes by their age (assumed to be in column 1)
        nodes = nodes[np.argsort(nodes[:, 1])]
        # Copy the last node's value to balance the number of leaves and nodes
        nodes = np.append(nodes, nodes[-1].reshape(1, -1), axis=0)
        # Pad the nodes array with 0s to ensure a size of 500 nodes
        nodes = np.pad(nodes, [(0, (500 - nodes.shape[0])), (0, 0)], mode='constant')
        
        # Stack the leaves and nodes arrays together along axis 2 (creating 2 channels)
        enc_pad.append(np.stack((leaves, nodes), axis=2))
    
    # Now process the test dataset (same procedure as above)
    enc_pad_test = []
    for i in range(enc_test.shape[0]):
        # Extract and sort leaves
        leaves = enc_test[i][enc_test[i,:,3] == 1]
        leaves = leaves[np.argsort(leaves[:, 1])]
        # Pad leaves to ensure size of 500
        leaves = np.pad(leaves, [(0, (500 - leaves.shape[0])), (0, 0)], mode='constant')

        # Extract and sort nodes
        nodes = enc_test[i][enc_test[i,:,3] > 1]
        nodes = nodes[np.argsort(nodes[:, 1])]
        # Copy the last node's value to balance the number of leaves and nodes
        nodes = np.append(nodes, nodes[-1].reshape(1, -1), axis=0)
        # Pad nodes to ensure size of 500
        nodes = np.pad(nodes, [(0, (500 - nodes.shape[0])), (0, 0)], mode='constant')
        
        # Stack the leaves and nodes arrays together along axis 2 (creating 2 channels)
        enc_pad_test.append(np.stack((leaves, nodes), axis=2))
    
    # Convert lists to numpy arrays and return the padded training and test data
    return np.array(enc_pad), np.array(enc_pad_test)


#Change encoding to order by root age and pad with 0s
encoding_pad_BD, encoding_pad_test_BD = encode_pad_0s_rootage(encoding_BD, encoding_test_BD)
encoding_pad_BDEI, encoding_pad_test_BDEI = encode_pad_0s_rootage(encoding_BDEI, encoding_test_BDEI)
encoding_pad_BDSS, encoding_pad_test_BDSS = encode_pad_0s_rootage(encoding_BDSS, encoding_test_BDSS)

#Combine encodings from the 3 models
encoding_pad = np.concatenate((encoding_pad_BD,encoding_pad_BDEI,encoding_pad_BDSS),axis=0)
encoding_pad_test = np.concatenate((encoding_pad_test_BD,encoding_pad_test_BDEI,encoding_pad_test_BDSS),axis=0)

#Delete intermediate variables
del(encoding_BD,encoding_BDEI,encoding_BDSS,encoding_pad_BD,encoding_pad_BDEI,encoding_pad_BDSS)
del(encoding_pad_test_BD,encoding_pad_test_BDEI,encoding_pad_test_BDSS)


# One-hot encode Y (3-class classification)
Y = np.eye(3)[Y]

# Split the data into training and validation sets (30% for validation)
Y, Y_valid, encoding_pad, encoding_pad_valid = train_test_split(
    Y, encoding_pad, test_size=0.3, shuffle=True, stratify=Y
)


## 4. Building & Training the CNN (2-Generation Context)

### Model Definition
We define a CNN that processes input of shape `(500, 19, 2)`:
- 500 = number of leaves or nodes (padded)
- 19 = number of features (including the newly added sampling probability)
- 2 = channels (leaves, nodes)

This architecture was inspired by the fact that internal nodes and leaves contribute differently to the tree likelihood calculation for multi-type birth-death models (MTBD, which includes BD, BDEI and BDSS; see Equation 8 in [Zhukova et al., 2023](https://academic.oup.com/sysbio/article/72/6/1387/7273092)).
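
To make the shape bookkeeping concrete, here is a NumPy sketch of what a `(1, 19)` convolution with two groups computes on one padded tree, followed by global average pooling. It is an illustration under simplifying assumptions (random weights, no activation, no batch normalization), not the Keras implementation; the names `w_leaves` and `w_nodes` are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(500, 19, 2))   # one padded tree: (rows, features, channels)

# Hypothetical random weights: 16 filters per group, each spanning all 19 features.
w_leaves = rng.normal(size=(19, 16))
w_nodes = rng.normal(size=(19, 16))

# Each filter reduces one row of 19 features to a single number; with two groups,
# the leaf channel and the node channel are each seen only by their own 16 filters.
out_leaves = x[:, :, 0] @ w_leaves  # (500, 16)
out_nodes = x[:, :, 1] @ w_nodes    # (500, 16)

conv_out = np.concatenate((out_leaves, out_nodes), axis=1)[:, None, :]
print(conv_out.shape)  # (500, 1, 32)

# Global average pooling: average every feature map over all row positions.
pooled = conv_out.mean(axis=(0, 1))
print(pooled.shape)    # (32,)
```

Because the kernel covers the full feature axis, that axis collapses to length 1 after the first layer, which is why the later layers can use `(1, 1)` kernels.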

<img src="img/Figure_Architecture.png" width="1000" height="600"> 

<img src="img/Zhukova2023_formula.png" width="500" height="340"> 

In [None]:
# Creation of the Network Model: model definition
def build_model():
    # Initialize the Sequential model
    model = Sequential()
    
    # First convolutional layer: 
    # - Filters: 32 
    # - Kernel size: (1, 19), so each kernel spans all 19 features of one row and slides along the 500 rows
    # - Input shape: (500, 19, 2) where 500 is the number of tree leaves/nodes, 19 is the feature size, and 2 is the number of channels (leaves and nodes)
    # - Activation function: ELU (Exponential Linear Unit)
    # - Groups: 2, so that the two channels (leaves and nodes) are convolved separately, with 16 filters each
    model.add(Conv2D(filters=32, use_bias=False, kernel_size=(1, 19), input_shape=(500, 19, 2), activation='elu', groups=2))
    
    # Apply batch normalization to stabilize and speed up the training process
    model.add(BatchNormalization())
    
    # Second convolutional layer: 
    # - Filters: 32
    # - Kernel size: (1, 1), applied to each row position independently while mixing the 32 channels
    # - Activation function: ELU
    model.add(Conv2D(filters=32, use_bias=False, kernel_size=(1, 1), activation='elu'))
    
    # Apply batch normalization again
    model.add(BatchNormalization())
    
    # Third convolutional layer: 
    # - Filters: 32
    # - Kernel size: (1, 1) for further feature processing
    # - Activation function: ELU
    model.add(Conv2D(filters=32, use_bias=False, kernel_size=(1, 1), activation='elu'))
    
    # Apply batch normalization for the final time before pooling
    model.add(BatchNormalization())
    
    # Average each of the 32 feature maps over all row positions, giving one 32-value vector
    # per tree, which will be passed to the fully connected (dense) layers
    model.add(GlobalAveragePooling2D())
    
    # Fully connected (FFNN) part:
    # Dense layers with decreasing number of units, all using ELU activation:
    model.add(Dense(64, activation='elu'))   # First dense layer with 64 units
    model.add(Dense(32, activation='elu'))   # Second dense layer with 32 units
    model.add(Dense(16, activation='elu'))   # Third dense layer with 16 units
    model.add(Dense(8, activation='elu'))    # Fourth dense layer with 8 units
    
    # Output layer: 
    # - 3 output neurons, corresponding to the 3 models
    # - Activation function: softmax
    model.add(Dense(3, activation='softmax'))
    
    # Show the summary of the model structure (number of layers, shapes of outputs, etc.)
    model.summary()

    # Return the constructed model
    return model

### Compilation & Fitting
Now we compile and fit the model.
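
Training relies on early stopping with a patience window. The pure-Python sketch below illustrates the patience rule in simplified form: stop once the monitored value has not improved for `patience` consecutive epochs, and remember the best epoch. The function name and the accuracy values are hypothetical, and this is not Keras's own code.

```python
def best_epoch_with_patience(val_accuracy, patience):
    """Return (stop_epoch, best_epoch) for a list of per-epoch validation accuracies."""
    best_value, best_epoch, waited = float('-inf'), -1, 0
    for epoch, value in enumerate(val_accuracy):
        if value > best_value:
            # New best value: remember it and reset the waiting counter.
            best_value, best_epoch, waited = value, epoch, 0
        else:
            waited += 1
            if waited >= patience:
                return epoch, best_epoch
    # Training ran to the end without exhausting the patience window.
    return len(val_accuracy) - 1, best_epoch

# Hypothetical validation accuracies over 7 epochs.
history = [0.50, 0.62, 0.61, 0.70, 0.69, 0.68, 0.66]
print(best_epoch_with_patience(history, patience=3))  # (6, 3)
```

With `restore_best_weights=True`, the weights kept at the end are those of the best epoch (epoch 3 in this toy run), not those of the epoch at which training stopped.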

In [None]:
from keras import losses

# Initialize the model using the build_model function that was previously defined
estimator = build_model()

# Compile the model:
# - Loss function: categorical_crossentropy is used to measure the error between the predicted probability distribution and the true distribution for multi-class classification tasks.
# - Optimizer: 'Adam' is used to minimize the loss function efficiently
# - Metrics: Accuracy is used to track the model's performance during training
estimator.compile(loss=keras.losses.categorical_crossentropy, optimizer = 'Adam', metrics=['accuracy'])

# Early stopping callback to prevent overfitting:
# - monitor: monitor the validation accuracy during training
# - patience: stop training if the validation accuracy doesn't improve for 100 consecutive epochs
# - mode: 'max' indicates that the monitored quantity should be maximized, so training stops once the validation accuracy has stopped increasing
# - restore_best_weights: restore the weights from the best epoch (the one with the highest validation accuracy)
early_stop = keras.callbacks.EarlyStopping(monitor='val_accuracy', patience=100, mode='max', restore_best_weights=True)

# Custom callback to display training progress:
# - Print a dot for every epoch (or newline every 100 epochs) to indicate progress in training
class PrintD(keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs):
        if epoch % 100 == 0:  # Print a newline every 100 epochs
            print('')
        print('.', end='')  # Print a dot to indicate progress during each epoch

# Set the maximum number of epochs (iterations over the entire dataset)
EPOCHS = 1000

# Train the model using the `fit` method:
# - encoding_pad: The padded training data (inputs)
# - Y: The target values (outputs)
# - verbose: set to 1 to print progress during training
# - epochs: The number of times to iterate over the entire dataset
# - validation_data: the held-out validation inputs and labels (used to monitor the validation accuracy)
# - batch_size: the number of samples per gradient update
# - callbacks: list of callbacks to be used during training (early stopping and progress display)
history = estimator.fit(encoding_pad, Y, verbose=1, epochs=EPOCHS, validation_data=(encoding_pad_valid, Y_valid), batch_size=1, callbacks=[early_stop, PrintD()])

### Evaluate the trained model
We evaluate our classifier on the test set, which was not seen by the network during training. We report the results as a confusion matrix.
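
A confusion matrix can be summarized with per-class and overall accuracy, as sketched below. The counts are hypothetical and are not results from this tutorial; the layout (rows = true model, columns = predicted model) follows the convention of `sklearn.metrics.confusion_matrix`.

```python
import numpy as np

# Hypothetical confusion matrix: rows = true model, columns = predicted model,
# both in the order BD, BDEI, BDSS.
cm = np.array([
    [90,  6,  4],
    [ 5, 85, 10],
    [ 3, 12, 85],
])

# Fraction of trees of each true model that were assigned to the right model.
per_class_accuracy = cm.diagonal() / cm.sum(axis=1)
# Fraction of all trees that were assigned to the right model.
overall_accuracy = cm.diagonal().sum() / cm.sum()

print(per_class_accuracy)  # 0.90, 0.85 and 0.85 for BD, BDEI and BDSS
print(overall_accuracy)    # 260 / 300
```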

In [None]:
# Evaluate on test set
predicted_test = np.array(estimator.predict(encoding_pad_test))
pred_cat = [i.argmax() for i in predicted_test]

# Confusion matrix
print(confusion_matrix(Y_test, pred_cat))

## 5. Predicting Empirical (Real) Data
Our trained network can now be used to predict the most likely epidemiological model on real datasets.
We will use the phylogenetic tree from [Rasmussen et al. (2017)](https://journals.plos.org/ploscompbiol/article?id=10.1371/journal.pcbi.1005448) with 200 HIV-1 sequences collected as part of the [Swiss HIV Cohort Study (2010)](https://academic.oup.com/ije/article/39/5/1179/799735).

<img src="img/HIV_tree.png" width="500" height="340"> 
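
The network returns one softmax vector per tree, with entries in the same order as the labels (0 = BD, 1 = BDEI, 2 = BDSS). A minimal sketch of how such an output can be turned into a model name is given below; the probabilities are made up for illustration and are not the result for this dataset.

```python
import numpy as np

model_names = ['BD', 'BDEI', 'BDSS']  # same order as the labels 0, 1, 2

# Hypothetical softmax output for one tree; real values come from estimator.predict.
predicted = np.array([[0.05, 0.15, 0.80]])

for probs in predicted:
    best = int(np.argmax(probs))
    print(model_names[best], float(probs[best]))  # BDSS 0.8
```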

In [None]:
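# Load the data
# NOTE: `path` is not defined earlier in this tutorial; set it to the folder that contains Encoded_Trees/
encoding_Zurich = pd.read_csv(path + '/Encoded_Trees/Encoded_Zurich.csv', sep="\t", header=0, index_col=0).values.reshape(-1,1000,18)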

### Preprocess according to the procedures done above.

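# Delete the rescaling factor (not used here)
encoding_Zurich=np.delete(encoding_Zurich, -1, axis=1)

# Append the sampling-probability column, as done for the simulated trees, so that each row
# has the 19 features the network expects. `samp_proba_Zurich` is a placeholder: this tutorial
# does not state the value for this dataset, so it must be set before running this cell.
encoding_Zurich=np.concatenate((encoding_Zurich,np.full((encoding_Zurich.shape[0],999,1),samp_proba_Zurich)),axis=2)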

# Format the encoding, separating leaves and nodes in two channels.
def encode_pad_0s_rootage(enc):
    # Create an empty list to hold the padded encodings
    enc_pad = []
    
    # Iterate over each tree in the empirical dataset
    for i in range(enc.shape[0]):
        # Separate the leaves (where column 3 has value 1, which indicates leaves)
        leaves = enc[i][enc[i,:,3] == 1]
        # Sort leaves by their age (assumed to be in column 1)
        leaves = leaves[np.argsort(leaves[:, 1])]
        # Pad the leaves array with 0s until it has a maximum size of 500 leaves
        leaves = np.pad(leaves, [(0, (500 - leaves.shape[0])), (0, 0)], mode='constant')

        # Separate the nodes (where column 3 is greater than 1, indicating internal nodes)
        nodes = enc[i][enc[i,:,3] > 1]
        # Sort nodes by their age (assumed to be in column 1)
        nodes = nodes[np.argsort(nodes[:, 1])]
        # Copy the last node's value to balance the number of leaves and nodes
        nodes = np.append(nodes, nodes[-1].reshape(1, -1), axis=0)
        # Pad the nodes array with 0s to ensure a size of 500 nodes
        nodes = np.pad(nodes, [(0, (500 - nodes.shape[0])), (0, 0)], mode='constant')
        
        # Stack the leaves and nodes arrays together along axis 2 (creating 2 channels)
        enc_pad.append(np.stack((leaves, nodes), axis=2))
    # Convert the list to a numpy array and return the padded data
    return np.array(enc_pad)

#Change encoding to order by root age and pad with 0s
encoding_pad_Zurich = encode_pad_0s_rootage(encoding_Zurich)

# predict values for the empirical dataset
predicted_emp = np.array(estimator.predict(encoding_pad_Zurich))

# Print the results
print(predicted_emp)