# 3D Matrix Generation Notebook

## Load modules

In [92]:
import pandas as pd
import numpy as np
import igraph as ig
import chart_studio.plotly as ply
import plotly.graph_objects as go
import plotly.express as px
import plotly.io as pio
from scipy.spatial import ConvexHull
import random
import os
import pathlib
from tqdm.notebook import tqdm
import kaleido
import json
from Bio import SeqIO

## Define I/O

In [73]:
workdir = pathlib.Path('/home/mf019/longread_pangenome/expanded_dataset_analysis/')
results_dir = workdir.joinpath('output')
alns_dir = results_dir.joinpath('alignments','ava','new_ava_v3') # not used?
networks_dir = results_dir.joinpath('homology_networks', 'nucl_v5_v4')
raw_matrix = networks_dir.joinpath('fixed_nucl_v5_homology_network_alignment_matrix.tsv')
edge_matrix = 'nucl_v5_matrix_edges_repl_20250122_2111.csv'#networks_dir.joinpath('fixed_nucl_v5_pd_matrix_edges.csv')
id_mapping = results_dir.joinpath('genotyping','replicons','calls_v10','best_hits_1000bp_v10.2.csv')

In [62]:
multifasta = workdir.joinpath("assemblies/all_contigs_v5_fixed.fna")
lengths = {}
for record in SeqIO.parse(multifasta, "fasta"):
    lengths[record.id] = len(record.seq)

## Main function definitions

In [74]:
def edges_from_matrix(matrix, output, min_contig_length=750, min_weight=0.0):
    """
    Convert alignment matrix to edge list with weights calculated as Jaccard index.
    
    Args:
        matrix: Path to alignment matrix TSV
        output: Path to save edge list
        min_contig_length: Minimum contig length to include (default: 1000)
        min_weight: Minimum edge weight to include (default: 0.0)
    """
    df = pd.read_csv(matrix, sep='\t', index_col=0, header=0)
    
    source = []
    target = []
    weight = []
    length_i = []  # Store lengths for debugging
    length_j = []
    align_length = []
    
    for i in tqdm(range(len(df.index)), desc='Processing matrix rows'):
        for j in range(i+1, len(df.columns)):
            if df.iloc[i,j] > 0:
                if df.iloc[i,i] >= min_contig_length and df.iloc[j,j] >= min_contig_length:
                    score = ((2 * df.iloc[i,j]) / (df.iloc[i,i] + df.iloc[j,j]))
                    
                    if score >= min_weight:
                        source.append(df.index[i])
                        target.append(df.columns[j])
                        weight.append(score)
                        length_i.append(df.iloc[i,i])
                        length_j.append(df.iloc[j,j])
                        align_length.append(df.iloc[i,j])

    edge_df = pd.DataFrame({
        'Source': source,
        'Target': target,
        'weight': weight,
        'source_length': length_i,
        'target_length': length_j,
        'alignment_length': align_length,
        'interaction': ['interacts'] * len(source)
    })
    
    # Print some statistics
    print(f"\nEdge Statistics:")
    print(f"Total edges: {len(edge_df)}")
    print(f"Unique source nodes: {len(edge_df['Source'].unique())}")
    print(f"Unique target nodes: {len(edge_df['Target'].unique())}")
    print(f"Weight range: {edge_df['weight'].min():.3f} - {edge_df['weight'].max():.3f}")
    
    edge_df.to_csv(output, index=False)
    print(f'Saved edges to {output}')
    return edge_df

def create_igraph(edge_df):
    nodes = list(set(edge_df['Source'].to_list() + edge_df['Target'].to_list()))
    nodes = [str(node) for node in nodes]  # Convert all nodes to strings
    G = ig.Graph()
    G.add_vertices(nodes)
    edges = [(str(row['Source']), str(row['Target'])) for _, row in edge_df[['Source', 'Target']].iterrows()]
    G.add_edges(edges)
    G.es['weight'] = edge_df['weight'].tolist()
    return G, nodes  # Return both the graph and the order of nodes

def align_labels_and_groups(idmap, vertex_order):
    longread_idmap = idmap
    # Create a dictionary mapping contig names to their labels and groups
    node_dict = longread_idmap.apply(lambda x: f"{x['assembly_id']}__{x['contig_id'].split(' ')[0]}", axis=1)
    node_dict = dict(zip(node_dict, longread_idmap['plasmid_name']))
    
    aligned_labels = []
    aligned_groups = []

    print("Sample of idmap:")
    print(longread_idmap.head())

    print("\nSample of vertex_order:")
    print(vertex_order[:10])

    for vertex in vertex_order:
        if vertex in node_dict:
            group = node_dict[vertex]
            # Handle specific plasmid name mappings
            if group == 'cp32-2':
                group = 'cp32-7'
            elif group == 'cp32-9-4':
                group = 'cp32-9'
            elif group == 'cp32-1+5':
                group = 'cp32-1'            
            elif group == 'cp32-3+10':
                group = 'cp32-3'
            elif group == 'cp32-1+5':
                group = 'cp32-1'
            elif group == 'cp32-5+1' or group == 'cp32-5-1':  # Add these mappings
                group = 'cp32-5'
            elif vertex == 'URI88H_contig000014':
                group = 'lp21-cp9'
            
            aligned_groups.append(group)
            aligned_labels.append(vertex)
        else:
            print(f'Fragment found: {vertex} (assigned as unknown)')
            aligned_labels.append(vertex)
            aligned_groups.append('unknown')

    print(f"\nTotal vertices: {len(vertex_order)}")
    print(f"Aligned labels: {len(aligned_labels)}")
    print(f"Aligned groups: {len(aligned_groups)}")
    print(f"Sample of aligned labels: {aligned_labels[:10]}")
    print(f"Sample of aligned groups: {aligned_groups[:10]}")
    print(f"Unique groups: {set(aligned_groups)}")

    return aligned_labels, aligned_groups

# random colors
# def create_color_mapping(groups):
#    unique_groups = list(set(groups))
#    n_colors = len(unique_groups)
#    colors = px.colors.qualitative.Plotly * (n_colors // len(px.colors.qualitative.Plotly) + 1)
#    #colors = px.colors.qualitative.Plotly[:n_colors]
#    return {group: colors[i] for i, group in enumerate(unique_groups)}

def create_color_mapping(groups):
    color_map = {
        "cp26": "#d60000", "lp54": "#018700", "lp17": "#b500ff", "lp28-3": "#05acc6",
        "lp28-4": "#97ff00", "lp38": "#ffa52f", "cp32-7": "#ff8ec8", "cp32-4": "#79525e",
        "lp36": "#00fdcf", "cp32-6": "#afa5ff", "lp25": "#93ac83", "lp28-1": "#9a6900",
        "cp32-3": "#366962", "cp32-9": "#d3008c", "cp32-5": "#fdf490", "cp32-11": "#c86e66",
        "cp32-12": "#9ee2ff", "cp32-10": "#00c846", "cp32-3+10": "#1daf50", "lp28-2": "#a877ac",
        "lp28-6": "#b8ba01", "lp21": "#f4bfb1", "lp28-5": "#ff28fd", "cp32-8": "#f2cdff",
        "cp32-1": "#009e7c", "cp9": "#ff6200", "lp56": "#56642a", "lp28-7": "#953f1f",
        "cp32-13": "#90318e", "lp5": "#ff3464", "lp21-cp9": "#a0e491", "cp9-3": "#8c9ab1",
        "lp28-8": "#829026", "lp28-9": "#ae083f", "lp28-11": "#c677b4", "chromosome": "#9eecff",
        "cp32-1+5": "#7b4b94",  # Added new color
        "lp32-3": "#4a9375",    # Added new color
        "Unclassified": "#cccccc",    # Gray for unclassified
        "none": "hsla(0, 0.00%, 0.00%, 0.00)"
    }

    unique_groups = set(groups)
    missing = unique_groups - set(color_map.keys())
    if missing:
        print(f"\nWarning: Missing colors for groups: {missing}")
        for group in missing:
            color_map[group] = "#cccccc"  # Use gray for any unknown groups
    
    return color_map

## 3D plot output as interactive html

In [75]:
def make_3d_plot_v2(G, labels, groups, output_file, layout, lengths):
    layout = G.layout(layout, dim=3)
    Xn = [layout[k][0] for k in range(len(G.vs))]
    Yn = [layout[k][1] for k in range(len(G.vs))]
    Zn = [layout[k][2] for k in range(len(G.vs))]

    color_map = create_color_mapping(groups)
    color_values = [color_map[group] for group in groups]

    # Create a list of unique groups
    unique_groups = list(set(groups))
    print(unique_groups)

    # Create a single edge trace for all connections
    Xe, Ye, Ze = [], [], []
    for e in G.es:
        Xe += [layout[e.source][0], layout[e.target][0], None]
        Ye += [layout[e.source][1], layout[e.target][1], None]
        Ze += [layout[e.source][2], layout[e.target][2], None]

    edge_trace = go.Scatter3d(
        x=Xe,
        y=Ye,
        z=Ze,
        mode='lines',
        name='Edges',
        line=dict(color='rgb(125,125,125)', width=0.35),
        hoverinfo='none',
        showlegend=False
    )

    # Create node traces for each group
    node_traces = []
    for group in unique_groups:
        group_indices = [i for i, g in enumerate(groups) if g == group]
        node_traces.append(
            go.Scatter3d(
                x=[Xn[i] for i in group_indices],
                y=[Yn[i] for i in group_indices],
                z=[Zn[i] for i in group_indices],
                mode='markers',
                name=group,
                marker=dict(
                    symbol='circle',
                    size=6,
                    color=color_map[group],
                    line=dict(color='rgb(38, 38, 38)', width=0.35),
                ),
                #text=[labels[i] for i in group_indices],
                text=[f'Label: {labels[i]}<br>Group: {group}<br>Length: {lengths[labels[i]]}' for i in group_indices],
                hoverinfo='text'
            )
        )

    # Combine all traces
    data = [edge_trace] + node_traces

    # Create slider for edge opacity
    steps = []
    for step in np.arange(0, 1.1, 0.1):
        step = round(step, 2)
        steps.append(
            dict(
                method="update",
                args=[{"opacity": [step, *[1]*len(node_traces)]}],  # First trace is edges, rest are nodes
                label=str(step)
            )
        )

    sliders = [dict(
        active=5,
        currentvalue={"prefix": "Edge Opacity: "},
        pad={"t": 5, "b": 10},
        steps=steps
    )]

    layout3d = go.Layout(
        title="Sequence Homology between contigs across 83 long-read assemblies (3D)",
        scene=dict(
            xaxis=dict(title=''),
            yaxis=dict(title=''),
            zaxis=dict(title=''),
        ),
        margin=dict(r=0, l=0, b=0, t=100),
        hovermode='closest',
        legend=dict(
            itemsizing='constant',
            title_text='Plasmids',
            bgcolor='rgba(255,255,255,0.5)',
            bordercolor='rgba(0,0,0,0)',
            borderwidth=2
        ),
        sliders=sliders,
        #annotations=[
        #    dict(
        #        showarrow=False,
        #        text="MJF-2024",
        #        xref='paper',
        #        yref='paper',
        #        x=0,
        #        y=0.1,
        #        xanchor='left',
        #        yanchor='bottom',
        #        font=dict(
        #            size=8
        #        )
        #    )
        #]
    )

    fig = go.Figure(data=data, layout=layout3d)

    fig.update_layout(
        scene_aspectmode='data',
        autosize=True,
        uirevision=True
    )

    config = {
        'responsive': True,
        'scrollZoom': True,
    }

    img = pio.to_image(fig, 'png')
    with open(f'{pathlib.Path(output_file).stem}.png', 'wb') as outfile:
        outfile.write(img)

    pio.write_html(fig, file=output_file, full_html=False, include_plotlyjs='cdn', config=config)
    print("html written!")

## 3D plot output as png

In [76]:
def make_3d_plot_v2_png(G, labels, groups, output_file, layout):
    layout = G.layout(layout, dim=3)
    Xn = [layout[k][0] for k in range(len(G.vs))]
    Yn = [layout[k][1] for k in range(len(G.vs))]
    Zn = [layout[k][2] for k in range(len(G.vs))]

    color_map = create_color_mapping(groups)
    color_values = [color_map[group] for group in groups]

    # Create a list of unique groups
    unique_groups = list(set(groups))
    print(unique_groups)

    # Create a single edge trace for all connections
    Xe, Ye, Ze = [], [], []
    for e in G.es:
        Xe += [layout[e.source][0], layout[e.target][0], None]
        Ye += [layout[e.source][1], layout[e.target][1], None]
        Ze += [layout[e.source][2], layout[e.target][2], None]

    # Create node traces for each group
    node_traces = []
    for group in unique_groups:
        group_indices = [i for i, g in enumerate(groups) if g == group]
        node_traces.append(
            go.Scatter3d(
                x=[Xn[i] for i in group_indices],
                y=[Yn[i] for i in group_indices],
                z=[Zn[i] for i in group_indices],
                mode='markers',
                name=group,
                marker=dict(
                    symbol='circle',
                    size=6,
                    color=color_map[group],
                    line=dict(color='rgb(38, 38, 38)', width=0.35),
                ),
                #text=[labels[i] for i in group_indices],
                text=[f'Label: {labels[i]}<br>Group: {group}' for i in group_indices],
                hoverinfo='text'
            )
        )

    # Combine all traces
    data = node_traces

    # Create slider for edge opacity
    steps = []
    for step in np.arange(0, 1.1, 0.1):
        step = round(step, 2)
        steps.append(
            dict(
                method="update",
                args=[{"opacity": [step, *[1]*len(node_traces)]}],  # First trace is edges, rest are nodes
                label=str(step)
            )
        )

    layout3d = go.Layout(
        title="Sequence Homology between contigs across 83 long-read assemblies (3D)",
        scene=dict(
            xaxis=dict(title=''),
            yaxis=dict(title=''),
            zaxis=dict(title=''),
        ),
        margin=dict(r=0, l=0, b=0, t=100),
        hovermode='closest',
        legend=dict(
            itemsizing='constant',
            title_text='Plasmids',
            bgcolor='rgba(255,255,255,0.5)',
            bordercolor='rgba(0,0,0,0)',
            borderwidth=2
        ),
        #sliders=sliders,
        #annotations=[
        #    dict(
        #        showarrow=False,
        #        text="MJF-2024",
        #        xref='paper',
        #        yref='paper',
        #        x=0,
        #        y=0.1,
        #        xanchor='left',
        #        yanchor='bottom',
        #        font=dict(
        #            size=8
        #        )
        #    )
        #]
    )

    fig = go.Figure(data=data, layout=layout3d)

    fig.update_layout(
        scene_aspectmode='data',
        autosize=True,
        uirevision=True
    )

    config = {
        'responsive': True,
        'scrollZoom': True,
    }

    img = pio.to_image(fig, 'png', width=5000, height=5000, scale=1)
    with open(f'{pathlib.Path(output_file).stem}.png', 'wb') as outfile:
        outfile.write(img)

    #pio.write_html(fig, file=output_file, full_html=False, include_plotlyjs='cdn', config=config)
    print("png written!")

In [77]:
def igraph_to_coords_file(G, aligned_labels, groups):
    color_map = create_color_mapping(groups)
    color_values = [color_map[group] for group in groups]
    layout = G.layout('kk', dim=3)
    Xn = [layout[k][0] for k in range(len(G.vs))]
    Yn = [layout[k][1] for k in range(len(G.vs))]
    Zn = [layout[k][2] for k in range(len(G.vs))]
    nodes = {
            i: {
                "location": [
                    Xn[i], Yn[i], Zn[i]
                ],
                "name": aligned_labels[i],
                "color": color_values[i],
                "plasmid": groups[i],
            } for i in range(G.vcount())
     } #
    edges = [{"source": e.source, "target": e.target} for e in G.es]

    return {"nodes": nodes, "edges": edges}

## Generate Graph and Output Plots

In [78]:
if not os.path.exists(edge_matrix):
    edge_df = edges_from_matrix(raw_matrix, edge_matrix)
else:
    edge_df = pd.read_csv(edge_matrix)
    print(f'Edge Matrix read from: {edge_matrix}')

Edge Matrix read from: nucl_v5_matrix_edges_repl_20250122_2111.csv


In [79]:
G, vertex_order = create_igraph(edge_df)

In [80]:
idmap = pd.read_csv(id_mapping, sep=',', header=0)
idmap['contig_id'] = idmap['contig_id'].apply(lambda x: x.split('[')[0])

In [81]:
aligned_labels, aligned_groups = align_labels_and_groups(idmap, vertex_order)

Sample of idmap:
  assembly_id   contig_id  contig_len                    plasmid_id  \
0       B331P   contig_1       903654        B500_chromosome_ParA_2   
1       B331P  contig_10         8714                gb|CP017210.1|   
2       B331P  contig_11        15594  RS00875_MM1_plsm_lp17_ParA_X   
3       B331P  contig_12        24722   RS00040_ZS7_ZS7_lp25_ParA_X   
4       B331P  contig_13        27641         H28_B31_lp28-3_ParA_X   

  plasmid_name   strain  query_length  ref_length  overall_percent_identity  \
0   chromosome     B500        903654         380                     100.0   
1          cp9     B331          8714        8714                     100.0   
2         lp17  RS00875         15594         246                     100.0   
3         lp25  RS00040         24722         252                     100.0   
4       lp28-3      H28         27641         251                     100.0   

   query_covered_length  ref_covered_length covered_intervals  \
0               

In [82]:
label2group = dict(zip(aligned_labels, aligned_groups))

In [83]:
#layouts = [
#    'circle', 'dh', 'drl', 'fr', 'fr3d', 'graphopt',
#    'grid', 'kk', 'kk3d', 'large', 'mds', 'random',
#    'random_3d', 'rt', 'rt_circular', 'circular_3d',
#]
#layout = 'fr3d'
layout = 'kk3d'
output_file = networks_dir.joinpath(f'igraph_asm_ava_homology_nucl_v5_{layout}_w001_aln500_l1500.html')
make_3d_plot_v2(G, aligned_labels, aligned_groups, output_file, layout, lengths)
make_3d_plot_v2_png(G, aligned_labels, aligned_groups, output_file, layout)
#make_3d_plot_v3(G, aligned_labels, aligned_groups, output_file)

['lp17', 'cp32-6', 'lp25', 'cp32-10', 'cp32-4', 'cp32-9', 'cp32-5', 'lp54', 'lp28-2', 'cp32-11', 'lp21-cp9', 'cp26', 'cp9-3', 'cp32-12', 'cp32-3', 'lp28-7', 'chromosome', 'lp28-8', 'lp36', 'cp32-8', 'lp28-11', 'cp32-1', 'lp56', 'cp9', 'cp32-7', 'lp38', 'cp32-13', 'lp21', 'lp28-1', 'lp28-5', 'lp32-3', 'lp28-3', 'lp28-9', 'lp28-6', 'lp5', 'lp28-4']
html written!
['lp17', 'cp32-6', 'lp25', 'cp32-10', 'cp32-4', 'cp32-9', 'cp32-5', 'lp54', 'lp28-2', 'cp32-11', 'lp21-cp9', 'cp26', 'cp9-3', 'cp32-12', 'cp32-3', 'lp28-7', 'chromosome', 'lp28-8', 'lp36', 'cp32-8', 'lp28-11', 'cp32-1', 'lp56', 'cp9', 'cp32-7', 'lp38', 'cp32-13', 'lp21', 'lp28-1', 'lp28-5', 'lp32-3', 'lp28-3', 'lp28-9', 'lp28-6', 'lp5', 'lp28-4']
png written!


## Generate JSON for blender rendering

In [84]:
# Convert the graph and save to JSON
graph_data = igraph_to_coords_file(G, aligned_labels, aligned_groups)
with open(networks_dir.joinpath('network.json'), 'w') as f:
    json.dump(graph_data, f, indent=4)

In [89]:
def make_3d_plot_v2_mayavi(G, labels, groups, output_file, layout):
    # Generate 3D layout for the graph
    layout_coords = G.layout(layout, dim=3)
    node_positions = {i: layout_coords[i] for i in range(len(G.vs))}

    # Extract node positions
    Xn = [node_positions[k][0] for k in G.vs.indices]
    Yn = [node_positions[k][1] for k in G.vs.indices]
    Zn = [node_positions[k][2] for k in G.vs.indices]

    # Create color mapping for groups
    unique_groups = list(set(groups))
    colors = plt.cm.viridis(np.linspace(0, 1, len(unique_groups)))
    group_to_color = {group: color[:3] for group, color in zip(unique_groups, colors)}  # RGB only

    # Prepare edges
    edges = G.es
    Xe, Ye, Ze = [], [], []
    for e in edges:
        src, tgt = e.source, e.target
        Xe += [node_positions[src][0], node_positions[tgt][0], None]
        Ye += [node_positions[src][1], node_positions[tgt][1], None]
        Ze += [node_positions[src][2], node_positions[tgt][2], None]

    # Mayavi 3D plot
    mlab.figure(size=(1000, 800), bgcolor=(1, 1, 1))

    # Plot edges
    for i in range(0, len(Xe), 3):
        mlab.plot3d(
            Xe[i:i+2], Ye[i:i+2], Ze[i:i+2],
            color=(0.5, 0.5, 0.5), tube_radius=0.02, opacity=0.5
        )

    # Plot nodes
    for group in unique_groups:
        group_indices = [i for i, g in enumerate(groups) if g == group]
        mlab.points3d(
            [Xn[i] for i in group_indices],
            [Yn[i] for i in group_indices],
            [Zn[i] for i in group_indices],
            scale_factor=0.1,
            color=group_to_color[group]
        )

    # Add node labels
    for i, label in enumerate(labels):
        mlab.text3d(
            Xn[i], Yn[i], Zn[i], label,
            scale=(0.05, 0.05, 0.05), color=(0, 0, 0)
        )

    # Set plot settings
    mlab.orientation_axes()
    mlab.view(azimuth=45, elevation=75, distance=8)
    mlab.savefig(output_file, size=(1920, 1080))
    mlab.show()

    print(f"3D plot saved to {output_file}")

In [None]:

mayavi_output_file = '3d_network_graph_mayavi.png'
make_3d_plot_v2_mayavi(G, aligned_labels, aligned_groups, mayavi_output_file, 'spring', lengths)