# Generate IBS-based Neighbor-Joining tree

## Sheep samples

Collect the latest SMARTER genotype dataset and unpack it into a folder:

```bash
unzip SMARTER-OA-OAR3-top-0.4.10.zip -d SMARTER-OA-OAR3-top-0.4.10
```

Next, start by generating an IBS matrix of the sheep samples using `plink`. Variants with a 
missing call rate above 10% are removed from the whole dataset (`--geno 0.1`):

```bash
plink --chr-set 26 no-xy no-mt --allow-no-sex --bfile SMARTER-OA-OAR3-top-0.4.10/SMARTER-OA-OAR3-top-0.4.10 \
    --geno 0.1 --Z-genome --genome full --out SMARTER-OA-OAR3-top-0.4.10/ibs_output
```

Now, let's do the python stuff:

In [1]:
import json

import pandas as pd
import numpy as np
from scipy.spatial.distance import pdist, squareform
from scipy.cluster.hierarchy import linkage, to_tree
import plotly.graph_objects as go
from ete3 import Tree, TreeStyle, NodeStyle, TextFace

from src.features.utils import get_project_dir
from src.features.smarterdb import global_connection, SampleSheep

In [2]:
# Open the project database connection.
# NOTE(review): presumably required by the SampleSheep query used later in this
# notebook — confirm in src.features.smarterdb.
conn = global_connection()

In [3]:
# Read the PLINK pairwise report, keeping only the two sample IDs of each
# pair and their IBS similarity (DST)
dtype_dict = {'IID1': str, 'IID2': str, 'DST': float}
ibs_data = pd.read_csv(
    get_project_dir() / "notebooks/results/SMARTER-OA-OAR3-top-0.4.10/ibs_output.genome.gz",
    dtype=dtype_dict,
    usecols=list(dtype_dict),  # same three columns, in the same order
    sep=r'\s+',  # the report is whitespace-aligned, not tab/comma separated
)

In [4]:
# Turn the pairwise records into a square matrix.
# PLINK's .genome report lists each pair of samples only once, so a sample ID
# can appear in IID1 only, in IID2 only, or in both. Pivoting on IID1/IID2
# alone gives a table whose rows and columns are *different* sample sets;
# reindex both axes on the union of the IDs so that cell (i, j) really is the
# pair (sample i, sample j) and no sample is dropped.
sample_ids = pd.Index(ibs_data['IID1']).union(pd.Index(ibs_data['IID2']))
ibs_pivot = ibs_data.pivot(index='IID1', columns='IID2', values='DST').reindex(
    index=sample_ids, columns=sample_ids
)

# This will transform the IBS values to a distance matrix: 0 means identical,
# 1 means different. Cells without a record are still NaN at this point, so
# they are not converted into a spurious distance of 1.
ibs_pivot = 1 - ibs_pivot

# A pair may be stored as (a, b) or (b, a) depending on the sample order in
# the PLINK file: take whichever orientation is present.
ibs_pivot = ibs_pivot.combine_first(ibs_pivot.T)

# Keep only the strict upper triangle; the following step mirrors it to
# rebuild the full symmetric matrix.
# NOTE: pairs absent from the report (none are expected unless PLINK was run
# with --min/--max) end up with distance 0.
upper_triangular_matrix = np.triu(ibs_pivot.fillna(0).to_numpy(), k=1)
np.fill_diagonal(upper_triangular_matrix, 0)
individuals = ibs_pivot.index.values

The matrix I have contains only the upper triangle, so I need to mirror it to get the full matrix. This can be done by adding the transposed matrix to the original one. The diagonal must then be subtracted to avoid counting it twice; in this case, however, it is already zero, since I called `np.fill_diagonal` with zeros.

In [None]:
# Mirror the upper triangle into the lower one; the np.diag(...) term removes
# the diagonal that the sum would otherwise count twice
distance_matrix = upper_triangular_matrix + upper_triangular_matrix.T - np.diag(upper_triangular_matrix.diagonal())

def is_symmetric(matrix, tol=1e-8):
    """Return True when *matrix* is square and equal to its transpose.

    Args:
        matrix: array-like to check.
        tol: absolute tolerance passed to ``np.allclose`` (NumPy's default
            relative tolerance still applies on top of it).

    Returns:
        bool: False for anything that is not a 2-D square array, otherwise
        the result of the element-wise comparison with the transpose.
    """
    matrix = np.asarray(matrix)

    # Without this guard a non-square input is either broadcast against its
    # transpose (a 1xN array of equal values would be reported as symmetric)
    # or makes np.allclose raise.
    if matrix.ndim != 2 or matrix.shape[0] != matrix.shape[1]:
        return False

    return bool(np.allclose(matrix, matrix.T, atol=tol))

# Sanity check: the mirrored matrix must compare equal to its own transpose
is_symmetric(distance_matrix)

In [6]:
# Build the tree with SciPy's agglomerative hierarchical clustering.
# NOTE(review): method='average' is average linkage (UPGMA), which is not the
# neighbor-joining algorithm this notebook's title refers to — confirm which
# one is intended.
# squareform() converts the square symmetric matrix into the condensed vector
# that linkage() takes as input
condensed_dist = squareform(distance_matrix)
Z = linkage(condensed_dist, method='average')
# rd=False: return only the root node of the tree
tree = to_tree(Z, rd=False)

In [7]:
# Delete the intermediates that are not read again below to free memory;
# `distance_matrix`, `individuals` and `tree` are deliberately kept
del dtype_dict, ibs_data, ibs_pivot, upper_triangular_matrix, condensed_dist, Z

Generate a tree structure

In [8]:
def get_newick(node, newick, parentdist, leaf_names):
    """Serialise the subtree rooted at *node* to a Newick string.

    The string is assembled back to front: *newick* is the text that has
    already been produced for everything that follows this subtree.

    Args:
        node: tree node exposing ``is_leaf()``, ``get_left()``,
            ``get_right()``, ``id`` and ``dist``.
        newick: suffix built so far; pass ``""`` for the root call.
        parentdist: ``dist`` of the parent node, used to derive the branch
            length as ``parentdist - node.dist``.
        leaf_names: sequence mapping a leaf ``id`` to its label.

    Returns:
        str: Newick text for this subtree followed by *newick*.
    """
    branch_length = parentdist - node.dist

    if node.is_leaf():
        return f"{leaf_names[node.id]}:{branch_length:f}{newick}"

    # An empty suffix means this is the outermost call: just close the tree
    closing = f"):{branch_length:f}{newick}" if newick else ");"

    left_text = get_newick(node.get_left(), closing, node.dist, leaf_names)
    both_text = get_newick(node.get_right(), "," + left_text, node.dist, leaf_names)
    return "(" + both_text

# Start from the root with an empty suffix; leaf labels are looked up in
# `individuals` by node id
newick_tree = get_newick(tree, "", tree.dist, individuals)

Now that you have the NJ tree in Newick format, you can use the ete3 library to visualize it and add breed labels. Get full breed name from the database:

In [9]:
# Fetch only the smarter_id and breed fields of every sheep sample and load
# the JSON dump into a DataFrame
breed_info = pd.DataFrame(
    data=json.loads(
        SampleSheep.objects.all().fields(smarter_id=True, breed=True).to_json()
    )
)

# Each `_id` arrives as a mapping with an "$oid" key: keep just that value
# and use it as the index
breed_info["_id"] = breed_info["_id"].map(lambda oid: oid["$oid"])
breed_info = breed_info.set_index("_id")

# Lookup table: smarter_id -> breed
label_dict = dict(zip(breed_info["smarter_id"], breed_info["breed"]))

Visualize the NJ tree with breed labels:

In [11]:
# Parse the Newick string into an ete3 tree
t = Tree(newick_tree)

# Swap each tip's individual ID for its breed; tips whose ID is not in
# label_dict keep their current name
for tip in t.iter_leaves():
    tip.name = label_dict.get(tip.name, tip.name)
# Style the tree
def render_tree(t, outfile="nj_tree_with_breeds.png", width=800):
    """Render tree *t* to an image file with leaf names shown.

    Every leaf gets the same node style (dark red, size 10).

    Args:
        t: ete3 tree to draw; the style of its leaves is modified in place.
        outfile: path of the image to write. Defaults to the file name this
            notebook has always used.
        width: image width in pixels.
    """
    ts = TreeStyle()
    ts.show_leaf_name = True  # Show the leaf names (breed labels) on the tree

    # One style object shared by every leaf
    nstyle = NodeStyle()
    nstyle["fgcolor"] = "darkred"
    nstyle["size"] = 10

    for leaf in t.iter_leaves():
        leaf.set_style(nstyle)

    # Render to a file rather than calling t.show(tree_style=ts), which
    # opens the tree in an external application
    t.render(outfile, w=width, units="px", tree_style=ts)

# Draw the relabelled tree to the PNG file
render_tree(t)

In [12]:
# Helper function to convert hierarchical tree to plotly tree format
def add_nodes(node, pos=None, x=0, y=0, width=400, labels=None):
    """Assign an (x, y) position to *node* and to all of its descendants.

    Each child is placed one level below its parent (``y - 1``) and shifted
    horizontally by half of the current *width*; the width is halved at
    every level.

    Args:
        node: tree node exposing ``id``, ``left`` and ``right``.
        pos: dict to fill in; a new one is created when omitted.
        x: horizontal position of *node*.
        y: vertical position of *node*.
        width: horizontal spread available at this level.
        labels: unused; kept so existing calls keep working. Defaults to
            None instead of a shared mutable ``{}``.

    Returns:
        dict: *pos*, mapping node id to its ``(x, y)`` tuple.
    """
    if pos is None:
        pos = {}
    pos[node.id] = (x, y)

    half = width / 2
    # Check each child on its own so a node with a single child does not fail
    for child, shift in ((node.left, -half), (node.right, half)):
        if child:
            add_nodes(child, pos=pos, x=x + shift, y=y - 1, width=half)

    return pos

# Convert hierarchical tree to positions: the root is placed at (0, 0) and the
# default width of 400 is used for the first level
positions = add_nodes(tree)

In [None]:
# Prepare the edge list
def get_edges(tree, pos):
    """Collect the parent-to-child edges of the subtree rooted at *tree*.

    Args:
        tree: node exposing ``id``, ``left`` and ``right``.
        pos: not read here; only forwarded to the recursive calls.

    Returns:
        list: ``(parent_id, child_id)`` tuples; for every node the edge to
        the left child and the left subtree's edges come before the right
        ones.
    """
    collected = []
    for child in (tree.left, tree.right):
        if not child:
            continue
        collected.append((tree.id, child.id))
        collected += get_edges(child, pos)
    return collected

edges = get_edges(tree, positions)

# Split the stored (x, y) tuples into two parallel lists, in dict order
x_nodes = [xy[0] for xy in positions.values()]
y_nodes = [xy[1] for xy in positions.values()]

# Ids below len(individuals) are labelled with the matching individual;
# every other node gets a generic label
node_labels = [
    individuals[node_id] if node_id < len(individuals) else f'Node {node_id}'
    for node_id in positions
]

# Each edge contributes its two end points followed by None, which keeps
# consecutive edges from being joined in the line trace
x_edges = []
y_edges = []
for parent_id, child_id in edges:
    (parent_x, parent_y), (child_x, child_y) = positions[parent_id], positions[child_id]
    x_edges.extend([parent_x, child_x, None])
    y_edges.extend([parent_y, child_y, None])

# Grey line trace for the edges, without hover information
edge_trace = go.Scatter(
    x=x_edges,
    y=y_edges,
    mode='lines',
    hoverinfo='none',
    line={'width': 1, 'color': 'gray'},
)

# Marker trace for the nodes, each with its label drawn above it
node_trace = go.Scatter(
    x=x_nodes,
    y=y_nodes,
    mode='markers+text',
    text=node_labels,
    textposition="top center",
    hoverinfo='text',
    marker={
        'showscale': False,
        'color': 'lightblue',
        'size': 10,
        'line_width': 2,
    },
)

# Edges first, then nodes, so the trace order matches the original figure
fig = go.Figure(data=[edge_trace, node_trace])

# No legend, no margins, no grid or zero lines on either axis
fig.update_layout(
    showlegend=False,
    hovermode='closest',
    margin={'b': 0, 'l': 0, 'r': 0, 't': 0},
    xaxis={'showgrid': False, 'zeroline': False},
    yaxis={'showgrid': False, 'zeroline': False},
)

# Show interactive tree
fig.show()

In [None]:
# Cell output: inspect the x coordinates assigned to the nodes
x_nodes