In [None]:
import pandas as pd
import numpy as np
import pathlib
import matplotlib.pyplot as plt
import seaborn as sns
import os
from tqdm import tqdm
from typing import List, Union, Optional, Callable
import pickle
from Bio import AlignIO, SeqIO
from ete3 import Tree, TreeNode
from gctree import CollapsedTree

from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
import umap
from ete3 import Tree, faces, TreeStyle, NodeStyle, TextFace, SequenceFace, COLOR_SCHEMES, CircleFace
from GCTree_preparation import *
import warnings

from Bio import Phylo
import math
import community as community_louvain
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import networkx as nx

# sklearn
from sklearn import metrics

# Silence library warnings (ete3 / Bio / sklearn are noisy).
# NOTE(review): this also hides deprecation warnings that flag upcoming API breakages.
warnings.filterwarnings("ignore")

# ---------------------------------------------------------------------------
# Configuration: input storage root, output root and the project analysed here.
# NOTE(review): absolute, machine-specific paths; adjust when running elsewhere.
# ---------------------------------------------------------------------------
path_to_storage = "/media/hieunguyen/HNSD01/storage/all_BSimons_datasets"
outdir = "/media/hieunguyen/HNSD_mini/outdir/sc_bulk_BCR_data_analysis_v0.1"

PROJECT = "220701_etc_biopsies"
path_to_main_output = f"{outdir}/tree_analysis/{PROJECT}"
path_to_01_output = os.path.join(path_to_main_output, "01_output")
# Create the output folder without spawning a shell; unlike `os.system("mkdir -p ...")`
# this raises instead of silently continuing when the folder cannot be created.
os.makedirs(path_to_01_output, exist_ok=True)

output_type = "mouse_based_output"

path_to_trees = os.path.join(path_to_storage, PROJECT, "GCtrees/v0.2", output_type)

# Location of the inferred gctree newick file, relative to each tree folder.
nk_relpath = os.path.join("02_dnapars", "gctree.out.inference.1.nk")

# Keep only the tree folders for which gctree produced a newick file.
all_tree_folder = [
    item for item in pathlib.Path(path_to_trees).glob("*")
    if os.path.isfile(os.path.join(item, nk_relpath))
]

# Folder name (clone id) -> path (str) of its newick file.
all_nk_files = {folder.name: os.path.join(folder, nk_relpath) for folder in all_tree_folder}
print(f"Number of trees: {len(all_tree_folder)}")

# Per-MID sample metadata (semicolon-separated).
path_to_metadata = f"/media/hieunguyen/HNSD01/src/sc_bulk_BCR_data_analysis/preprocessing/{PROJECT}/metadata.csv"
mid_metadata = pd.read_csv(path_to_metadata, sep=";")

# Set to True to re-run the summary analysis of all trees and re-render the tree figures.
rerun = False

path_to_04_output = os.path.join(outdir, "VDJ_output", "04_output")
# NOTE(review): not used in this section; presumably a threshold for a later step — confirm.
thres = 0.85

# Clone table with mutation rates, restricted to clones whose region was covered
# and to the current project.
clonedf = pd.read_csv(os.path.join(path_to_04_output, "full_clonedf_with_mutation_rate.csv"), index_col=[0])
clonedf = clonedf[clonedf["num_mutation"] != "region_not_covered-skip"]
clonedf = clonedf[clonedf["dataset.name"] == PROJECT]

maindf = pd.read_csv(f"{path_to_01_output}/tree_summarydf.csv")

# Reload the dictionary of tree objects saved by an earlier pipeline step.
# SECURITY: pickle.load can execute arbitrary code — only load files this pipeline wrote.
with open(f"{path_to_01_output}/saveTreeobj.pkl", "rb") as f:
    saveTreeobj = pickle.load(f)

color_path = "./hex_color.csv"

In [None]:
##### Analysis example for a single tree.
cloneid = "m30_IGHV1-82-01_IGHJ2-01_30_1.aln"

# Clone ids start with the mouse id, e.g. "m30_..." -> "m30".
mouseid = cloneid.split("_")[0]
path_to_save_tree_svg = os.path.join(path_to_01_output, mouseid)
os.makedirs(path_to_save_tree_svg, exist_ok=True)

treeobj = saveTreeobj[cloneid]
avai_mids = treeobj.seqdf["MID"].unique()
mid_color_pal = pd.read_csv(color_path, index_col=[0]).to_dict()["hex color"]

ts = treeobj.generate_tree_style(color_path=color_path)

# Legend: one coloured circle plus a label per MID present in this tree;
# the germline ("GL") is drawn in gray.
for input_mid in avai_mids:
    input_mid_col = "gray" if input_mid == "GL" else mid_color_pal[input_mid]
    ts.legend.add_face(CircleFace(10, input_mid_col), column=0)
    ts.legend.add_face(TextFace(input_mid), column=0)

idmapdf = treeobj.idmapseqdf.copy()
seqdf = treeobj.seqdf.copy()

# MID -> population lookup, built once. For duplicated MIDs the first metadata row
# wins, matching the previous filter-then-take-first lookup.
mid_to_population = (
    mid_metadata.drop_duplicates(subset="Unnamed: 0")
    .set_index("Unnamed: 0")["population"]
)
missing_mids = [mid for mid in seqdf["MID"].unique() if mid not in mid_to_population.index]
if missing_mids:
    raise KeyError(f"MIDs missing from metadata ({path_to_metadata}): {missing_mids}")
seqdf["population"] = seqdf["MID"].map(mid_to_population)
seqdf = seqdf.merge(idmapdf, on="seq")
treeobj.tree.render("%%inline", tree_style=ts)


In [None]:

# Louvain communities on the inferred tree (as an undirected graph), compared with
# the sample (MID) labels via the Rand index.
# Community drawing recipe:
# https://stackoverflow.com/questions/43541376/how-to-draw-communities-with-networkx
show_plot = False

nw_path = all_nk_files[cloneid]

# Named `phylo_tree` (not `Tree`) so the ete3 `Tree` class imported above is not shadowed.
phylo_tree = Phylo.read(nw_path, "newick")
G = Phylo.to_networkx(phylo_tree)
pos = nx.spring_layout(G, seed=42)  # fixed seed -> reproducible layout

# Figures are created only when plotting; otherwise the inline backend would
# display blank 20x20 figures.
if show_plot:
    plt.figure(figsize=(20, 20))
    nx.draw_networkx(G, pos=pos)
    plt.show()

partition = community_louvain.best_partition(G, random_state=42, resolution=1)
# `matplotlib.cm.get_cmap` was removed in Matplotlib 3.9; `pyplot.get_cmap` still takes `lut`.
cmap = plt.get_cmap("viridis", max(partition.values()) + 1)
if show_plot:
    plt.figure(figsize=(20, 20))
    nx.draw_networkx_nodes(G, pos, partition.keys(), node_size=40,
                           cmap=cmap, node_color=list(partition.values()))
    nx.draw_networkx_edges(G, pos, alpha=0.5)
    plt.show()

clusterdf = pd.DataFrame(
    {
        "seq": [clade.name for clade in partition],
        "cluster": list(partition.values()),
    }
)

# Node name -> community id, built once. setdefault keeps the first node for a
# repeated name (same as the previous first-match lookup); unnamed nodes are skipped.
seq_to_cluster = {}
for seq_name, cluster_id in zip(clusterdf["seq"], clusterdf["cluster"]):
    if seq_name is not None:
        seq_to_cluster.setdefault(seq_name, cluster_id)

# Sequences without a matching tree node are labelled "error".
seqdf["cluster"] = seqdf["seqid"].map(lambda x: seq_to_cluster.get(x, "error"))

# Rows labelled "error" are excluded from the Rand index: mixing that string with
# integer cluster ids makes sklearn's label encoding fail, and they have no cluster.
matched = seqdf["cluster"] != "error"
n_unmatched = int((~matched).sum())
if n_unmatched:
    print(f"{n_unmatched} sequence(s) of {cloneid} have no matching tree node; excluded from the Rand index.")
rand_index = (
    metrics.rand_score(seqdf.loc[matched, "cluster"].astype(int).values,
                       seqdf.loc[matched, "MID"].values)
    if matched.any() else float("nan")
)
