# Determine mislabelling using NJ tree and IBS distance matrix

In this example, we demonstrate how to detect mislabelled samples using a Neighbor-Joining
(NJ) tree built from an IBS (identity-by-state) distance matrix. For illustration, the
distance matrix is simulated with random values. We will use the following steps:

## Step 1: Build the NJ Tree from the IBS Distance Matrix

In [None]:
import numpy as np
from Bio import Phylo
from Bio.Phylo.TreeConstruction import DistanceTreeConstructor, DistanceMatrix
import pandas as pd

np.random.seed(42)  # For reproducibility

# Example IBS distance matrix (replace with actual data)
# Assuming 'distance_df' is your IBS-based distance matrix with individuals in rows/columns
# distance_df = pd.read_csv("ibs_distance_matrix.csv", index_col=0)
# In this example, we simulate a distance matrix
samples = [f'Sample {i}' for i in range(10)]
n_samples = len(samples)  # Derive the matrix size from the sample list instead of repeating 10
distance_matrix = np.random.rand(n_samples, n_samples)
np.fill_diagonal(distance_matrix, 0)  # Set diagonal to 0 (distance to self is zero)

# np.random.rand draws every cell independently, so d(i, j) != d(j, i) at this point.
# Mirror the lower triangle into the upper one so the simulated matrix is a valid
# (symmetric) distance matrix. The lower triangle itself is left untouched, so the
# values handed to the tree builder below are the same as before mirroring.
distance_matrix = np.tril(distance_matrix) + np.tril(distance_matrix, k=-1).T

# Convert distance matrix to lower triangular format:
# row i keeps columns 0..i (diagonal included), i.e. i + 1 values
lower_triangle_matrix = [
    row[:i + 1].tolist() for i, row in enumerate(distance_matrix)
]

# Convert distance matrix into BioPython DistanceMatrix
matrix = DistanceMatrix(names=samples, matrix=lower_triangle_matrix)

# Build NJ tree
constructor = DistanceTreeConstructor()
nj_tree = constructor.nj(matrix)

# Visualize the tree (optional)
Phylo.draw_ascii(nj_tree)

## Step 2: Group Individuals Based on NJ Tree

In [None]:
# Traverse the NJ tree and extract clusters based on the maximum pairwise distance within clades
def extract_clusters_by_total_distance(tree, distance_threshold):
    """Return groups of sample names whose clade is tighter than a threshold.

    Despite the function name, the criterion is the *maximum* pairwise
    distance inside a clade, not the total: a clade is reported as a cluster
    when the largest ``tree.distance`` between any two of its terminals is
    less than or equal to ``distance_threshold``.

    Args:
        tree: Tree object providing ``get_nonterminals()`` and
            ``distance(a, b)``; its clades must provide ``get_terminals()``
            and its terminals a ``name`` attribute.
        distance_threshold: Largest pairwise distance allowed inside a cluster.

    Returns:
        A list of clusters, each a list of terminal names. Clusters that are
        subsets of another reported cluster are removed via
        ``remove_subset_clusters``.
    """
    # Local import keeps this cell self-contained
    from itertools import combinations

    clusters = []

    # Traverse the internal nodes of the tree
    for clade in tree.get_nonterminals():
        # Get all terminal nodes (samples) under this clade
        terminals = clade.get_terminals()
        if len(terminals) < 2:
            continue  # Ignore clades with less than two terminals

        # Visit each unordered pair once: evaluating both (a, b) and (b, a)
        # plus the (a, a) self-pairs more than doubles the number of
        # tree.distance calls without changing the maximum.
        # The running maximum starts at 0, so it is never negative.
        max_pairwise_distance = 0
        for term1, term2 in combinations(terminals, 2):
            distance = tree.distance(term1, term2)
            if distance > max_pairwise_distance:
                max_pairwise_distance = distance

        # If the maximum pairwise distance is below the threshold, consider it a cluster
        if max_pairwise_distance <= distance_threshold:
            clusters.append([term.name for term in terminals])

    # Remove clusters that are subsets of other clusters
    return remove_subset_clusters(clusters)

# Helper function to drop every cluster that is contained in another cluster
def remove_subset_clusters(clusters):
    """Keep only clusters that are not subsets of an already kept cluster.

    Clusters are visited from largest to smallest (ties keep their input
    order), so a cluster is only ever compared against clusters of equal or
    greater size. Exact duplicates count as subsets and are dropped too.

    Args:
        clusters: Iterable of clusters, each a sized collection of hashable
            members.

    Returns:
        A new list holding the surviving clusters, largest first.
    """
    kept = []
    kept_member_sets = []  # set view of each kept cluster, built once

    for candidate in sorted(clusters, key=len, reverse=True):
        members = set(candidate)
        if any(members <= existing for existing in kept_member_sets):
            continue  # Already covered by a kept cluster
        kept.append(candidate)
        kept_member_sets.append(members)

    return kept

# Example usage:
# The threshold is compared against the largest tree path length (tree.distance)
# between any two samples of a clade; clades exceeding it are not reported.
distance_threshold = 0.3  # Define your threshold for maximum pairwise distance within clusters
clusters = extract_clusters_by_total_distance(nj_tree, distance_threshold)
# 'clusters' is a list of lists of sample names
print("Unique Clusters (after removing subsets):", clusters)

In [None]:
# Print the NJ tree as ASCII art again (same call as at the end of Step 1)
Phylo.draw_ascii(nj_tree)

## Step 3: Compare Genetic Clusters with Declared Breed Labels

In [None]:
# 'assignment_df' holds one row per sample with its declared breed
assignment_df = pd.DataFrame({
    'sample': samples,
    'breed': [
        'Breed A', 'Breed A', 'Breed B', 'Breed B', 'Breed A',
        'Breed A', 'Breed C', 'Breed B', 'Breed C', 'Breed C',
    ],
})

# Lookup table: sample name -> ID of the first cluster containing it.
# Cluster IDs start from 1; setdefault keeps the earliest cluster if a sample
# appears in more than one.
cluster_id_by_sample = {}
for cluster_id, cluster in enumerate(clusters, 1):
    for member in cluster:
        cluster_id_by_sample.setdefault(member, cluster_id)

# One entry per row of assignment_df; samples outside every cluster get None
cluster_assignments = [
    cluster_id_by_sample.get(sample) for sample in assignment_df['sample']
]

# Store the cluster IDs next to the declared breeds
assignment_df['detected_cluster'] = cluster_assignments

# Show the resulting table (notebook cell output)
assignment_df

In [None]:
# Function to label clusters with the most common (mode) breed and detect mislabelled samples
def check_and_label_clusters(clusters, assignment_df):
    """Label each cluster with its most frequent breed and flag disagreeing samples.

    Args:
        clusters: List of clusters, each a list of sample names. Cluster IDs
            are assigned by position, starting from 1.
        assignment_df: DataFrame with a 'sample' column and a 'breed' column
            holding the declared breed of each sample.

    Returns:
        A tuple ``(mislabelled, cluster_labels)``:
        ``mislabelled`` lists, in cluster order, the samples whose declared
        breed differs from the mode breed of their cluster;
        ``cluster_labels`` maps every cluster ID to its mode breed, or to
        ``None`` when none of the cluster's samples has a declared breed in
        ``assignment_df``.
    """
    # Build the sample -> declared breed lookup once instead of filtering the
    # whole DataFrame for every sample. setdefault keeps the first row when a
    # sample is listed more than once.
    declared_breed_by_sample = {}
    for sample, breed in zip(assignment_df['sample'], assignment_df['breed']):
        declared_breed_by_sample.setdefault(sample, breed)

    mislabelled = []
    cluster_labels = {}

    for cluster_id, cluster in enumerate(clusters, 1):
        # Get the declared breeds for all samples in the cluster
        cluster_breeds = assignment_df[assignment_df['sample'].isin(cluster)]['breed']

        # Count how often each breed occurs in the cluster
        breed_counts = cluster_breeds.value_counts()
        if breed_counts.empty:
            # No declared breed available for this cluster: idxmax() would
            # raise on the empty counts, so record the cluster as unlabelled.
            cluster_labels[cluster_id] = None
            continue

        # The most frequent breed in the cluster labels the cluster
        mode_breed = breed_counts.idxmax()
        cluster_labels[cluster_id] = mode_breed

        # Flag individuals whose breed doesn't match the mode
        for sample in cluster:
            if sample not in declared_breed_by_sample:
                continue  # No declared breed to compare against
            if declared_breed_by_sample[sample] != mode_breed:
                mislabelled.append(sample)

    return mislabelled, cluster_labels

# Function to write each cluster's label next to the samples it contains
def update_assignment_with_clusters(assignment_df, clusters, cluster_labels):
    """Add a 'cluster_label' column holding the label of each sample's cluster.

    The DataFrame is modified in place and the same object is returned.

    Args:
        assignment_df: DataFrame with a 'sample' column.
        clusters: List of clusters, each a list of sample names; the cluster
            at position ``k`` (counting from 1) has ID ``k``.
        cluster_labels: Mapping from cluster ID to the label to store.

    Returns:
        ``assignment_df`` with the 'cluster_label' column set. Samples that
        belong to no cluster keep ``None``; if a sample is in several
        clusters, the label of the last one wins.
    """
    # Start from an empty column so unclustered samples stay None
    assignment_df['cluster_label'] = None
    sample_column = assignment_df['sample']

    for position, members in enumerate(clusters, start=1):
        in_cluster = sample_column.isin(members)
        assignment_df.loc[in_cluster, 'cluster_label'] = cluster_labels[position]

    return assignment_df

# Find mislabelled individuals and assign cluster labels
# (cluster IDs in 'cluster_labels' are 1-based positions in 'clusters')
mislabelled_breeds, cluster_labels = check_and_label_clusters(clusters, assignment_df)

# Update the DataFrame to store the assigned cluster labels.
# NOTE: update_assignment_with_clusters modifies assignment_df in place and returns it,
# so 'updated_assignment_df' and 'assignment_df' refer to the same object.
updated_assignment_df = update_assignment_with_clusters(assignment_df, clusters, cluster_labels)

# Output
print("Mislabelled or crossbreed samples:", mislabelled_breeds)
print("Cluster labels (mode breed):", cluster_labels)
print("Updated DataFrame with Cluster Labels:")
# Trailing bare expression: rendered as the notebook cell's output
updated_assignment_df
