In [1]:
from fastcda import FastCDA
from dgraph_flex import DgraphFlex

# FastCDA helper object; the notebook uses it for the Tetrad version query,
# edge-list -> lavaan conversion, SEM fitting (semopy) and graph annotation
fc = FastCDA()

In [2]:
# show the version of Tetrad used by FastCDA (rendered as this cell's output)
fc.getTetradVersion()

'7.6.3-0'

In [3]:
def read_edge_list(text: str, edge_marks: tuple = ("-->",)):
    """
    Extract edges from a (Tetrad-style) numbered edge listing.

    Each non-blank line is expected to look like ``"<n>. <src> <mark> <dst>"``;
    the leading enumeration token is optional. Example input::

        1. GridB_secs --> GridA_secs
        2. GridB_secs_ --> GridB_secs
        3. SpatialSpan3_perc_accuracy --- SpatialSpan4_perc_accuracy

    Args:
        text: the edge listing. Tokens may be separated by any whitespace
            (spaces, repeated spaces, tabs).
        edge_marks: edge types to keep. The default keeps only directed
            ``-->`` edges, so undirected ``---`` edges (line 3 above) are
            dropped; pass e.g. ``("-->", "---")`` to keep both.

    Returns:
        List of edge strings normalized to ``"src mark dst"`` (single spaces),
        in input order. Lines that do not match the pattern are ignored.
    """
    edges = []
    for line in text.splitlines():
        parts = line.split()  # any whitespace; blank lines give []
        # drop the optional "<n>." enumeration prefix
        if len(parts) == 4:
            parts = parts[1:]
        if len(parts) == 3 and parts[1] in edge_marks:
            edges.append(" ".join(parts))
    return edges

In [4]:
def get_ancestors_dag(edge_list, target_node):
    """
    Return the edges of the subgraph induced by ``target_node`` and its ancestors.

    Args:
        edge_list: edge strings of the form ``"src --> dst"`` (as produced by
            ``read_edge_list``). Only ``-->`` entries are used; anything else is
            ignored. Tokens may be separated by any whitespace.
        target_node: node whose ancestors are wanted.

    Returns:
        List of ``"u --> v"`` strings for every directed edge whose endpoints
        are both in ``{target_node} | ancestors(target_node)``. Returns ``[]``
        when ``target_node`` does not occur in any directed edge (previously
        this raised a networkx error).
    """
    import networkx as nx

    G = nx.DiGraph()
    for edge in edge_list:
        parts = edge.split()
        if len(parts) == 3 and parts[1] == "-->":
            G.add_edge(parts[0], parts[2])
    # ensure the target exists so nx.ancestors() does not raise for an absent
    # node; an isolated target simply has no ancestors (no-op if already present)
    G.add_node(target_node)
    keep = nx.ancestors(G, target_node) | {target_node}
    # G.subgraph(keep) is the induced subgraph: all edges between kept nodes
    return [f"{u} --> {v}" for u, v in G.subgraph(keep).edges]

In [5]:
# K=3: subject ids and CDA edge listings for each of the three groups
# (K presumably is the number of subject clusters group_1..group_3 — TODO confirm)
edge_texts = {}  # group name -> numbered edge listing text ("n. src --> dst" per line)
groups = {}  # group name -> list of subject ids (csv file stems read by read_group_data)
# group 1: subject ids
groups['group_1'] =   ['R26_lagged', 'R2_lagged', 'R19_lagged', 'R17_lagged', 'R9_lagged', 'R10_lagged', 'R21_lagged', 'R22_lagged', 'R4_lagged', 'R11_lagged']

# group 1: directed edges; names ending in "_" are presumably the lagged
# copies of the same variable — TODO confirm against the CDA run
edge_texts['group_1'] = """
1. GridA_secs --> GridB_secs
2. GridA_secs --> SpatialSpan3_perc_accuracy
3. GridA_secs_ --> GridA_secs
4. GridB_secs --> pain_cope
5. SpatialSpan3_perc_accuracy --> SpatialSpan4_perc_accuracy
6. SpatialSpan4_perc_accuracy_ --> get_better
7. SpatialSpan5_perc_accuracy --> SpatialSpan3_perc_accuracy
8. SpatialSpan5_perc_accuracy --> SpatialSpan4_perc_accuracy
9. SpatialSpan5_perc_accuracy_ --> SpatialSpan5_perc_accuracy
10. SpatialSpan5_perc_accuracy_ --> feel_burden
11. accomplish --> stressed
12. accomplish_ --> accomplish
13. depressed --> SpatialSpan5_perc_accuracy
14. depressed --> fatigue
15. depressed --> feel_burden
16. depressed --> feel_direction
17. depressed --> pain_interf
18. depressed --> pain_worry
19. depressed --> serious
20. depressed --> stressed
21. depressed_ --> depressed
22. depressed_ --> pain_cope
23. fatigue --> SpatialSpan3_perc_accuracy
24. fatigue --> SpatialSpan4_perc_accuracy
25. fatigue --> accomplish
26. fatigue --> feel_burden
27. fatigue_ --> SpatialSpan5_perc_accuracy
28. fatigue_ --> fatigue
29. feel_burden --> pain_interf
30. feel_burden --> serious
31. feel_burden_ --> feel_burden
32. feel_burden_ --> pain_interf
33. feel_burden_ --> rested
34. feel_burden_ --> serious
35. feel_direction --> feel_burden
36. feel_direction --> get_better
37. feel_direction --> pain_cope
38. feel_direction_ --> feel_direction
39. feel_support --> depressed
40. feel_support --> fatigue
41. feel_support --> feel_burden
42. feel_support --> feel_direction
43. feel_support --> pain_intens
44. feel_support --> trust_provid
45. feel_support_ --> feel_support
46. get_better --> SpatialSpan5_perc_accuracy
47. get_better_ --> get_better
48. get_better_ --> rested
49. pain_cope --> accomplish
50. pain_cope --> feel_burden
51. pain_cope --> pain_strategy
52. pain_cope_ --> feel_direction
53. pain_cope_ --> pain_cope
54. pain_cope_ --> pain_strategy
55. pain_cope_ --> pain_worry
56. pain_intens --> pain_mind
57. pain_intens --> trust_provid
58. pain_intens_ --> pain_intens
59. pain_intens_ --> trust_provid
60. pain_interf --> pain_intens
61. pain_interf --> pain_mind
62. pain_interf --> pain_worry
63. pain_interf --> stressed
64. pain_interf_ --> pain_interf
65. pain_mind --> stressed
66. pain_mind_ --> pain_interf
67. pain_mind_ --> stressed
68. pain_strategy --> pain_intens
69. pain_strategy_ --> pain_cope
70. pain_strategy_ --> pain_strategy
71. pain_worry --> pain_intens
72. pain_worry --> pain_mind
73. pain_worry_ --> fatigue
74. pain_worry_ --> feel_direction
75. pain_worry_ --> pain_worry
76. pain_worry_ --> serious
77. rested --> depressed
78. rested --> fatigue
79. rested --> feel_direction
80. rested --> pain_interf
81. rested_ --> accomplish
82. rested_ --> get_better
83. rested_ --> pain_interf
84. rested_ --> rested
85. serious --> pain_mind
86. serious --> pain_worry
87. serious --> stressed
88. serious_ --> depressed
89. serious_ --> rested
90. serious_ --> serious
91. stressed_ --> depressed
92. stressed_ --> feel_burden
93. stressed_ --> pain_worry
94. stressed_ --> stressed
95. trust_provid --> SpatialSpan3_perc_accuracy
96. trust_provid_ --> depressed
97. trust_provid_ --> feel_support
98. trust_provid_ --> pain_intens
99. trust_provid_ --> trust_provid
"""



In [6]:
# group 2: subject ids
groups['group_2'] =   ['R15_lagged', 'R20_lagged', 'R1_lagged', 'R3_lagged', 'R25_lagged', 'R13_lagged', 'R18_lagged']

# group 2: numbered directed edge listing parsed by read_edge_list
edge_texts['group_2'] = """
1. GridA_secs --> GridB_secs
2. GridB_secs_ --> trust_provid
3. SpatialSpan4_perc_accuracy --> SpatialSpan3_perc_accuracy
4. SpatialSpan5_perc_accuracy --> GridB_secs
5. SpatialSpan5_perc_accuracy --> SpatialSpan3_perc_accuracy
6. SpatialSpan5_perc_accuracy --> SpatialSpan4_perc_accuracy
7. SpatialSpan5_perc_accuracy_ --> GridB_secs
8. SpatialSpan5_perc_accuracy_ --> SpatialSpan4_perc_accuracy
9. SpatialSpan5_perc_accuracy_ --> SpatialSpan5_perc_accuracy
10. depressed_ --> depressed
11. depressed_ --> feel_burden
12. fatigue --> accomplish
13. fatigue --> feel_support
14. fatigue_ --> fatigue
15. feel_burden --> depressed
16. feel_burden --> pain_intens
17. feel_burden --> pain_interf
18. feel_burden --> pain_worry
19. feel_burden --> stressed
20. feel_burden_ --> feel_burden
21. feel_burden_ --> feel_support
22. feel_burden_ --> trust_provid
23. feel_direction --> get_better
24. feel_direction --> pain_worry
25. feel_direction --> rested
26. feel_direction --> trust_provid
27. feel_direction_ --> accomplish
28. feel_direction_ --> feel_direction
29. feel_direction_ --> get_better
30. feel_support --> accomplish
31. feel_support --> depressed
32. feel_support_ --> feel_support
33. feel_support_ --> pain_strategy
34. get_better --> pain_mind
35. get_better --> pain_strategy
36. get_better --> stressed
37. get_better_ --> feel_direction
38. get_better_ --> get_better
39. pain_cope_ --> SpatialSpan5_perc_accuracy
40. pain_cope_ --> feel_support
41. pain_cope_ --> pain_cope
42. pain_cope_ --> pain_strategy
43. pain_intens --> get_better
44. pain_intens --> pain_interf
45. pain_intens --> pain_mind
46. pain_intens_ --> pain_intens
47. pain_interf_ --> feel_burden
48. pain_interf_ --> pain_interf
49. pain_mind_ --> pain_mind
50. pain_strategy --> feel_support
51. pain_strategy --> pain_cope
52. pain_strategy_ --> pain_strategy
53. pain_worry --> pain_cope
54. pain_worry --> pain_mind
55. pain_worry_ --> pain_worry
56. rested --> fatigue
57. rested_ --> fatigue
58. rested_ --> rested
59. serious --> feel_burden
60. serious --> pain_worry
61. serious_ --> accomplish
62. serious_ --> depressed
63. serious_ --> feel_support
64. serious_ --> serious
65. stressed --> depressed
66. stressed --> fatigue
67. stressed --> rested
68. stressed_ --> stressed
69. stressed_ --> trust_provid
70. trust_provid --> SpatialSpan5_perc_accuracy
71. trust_provid --> feel_burden
72. trust_provid --> pain_interf
73. trust_provid --> pain_strategy
74. trust_provid --> rested
75. trust_provid_ --> feel_support
76. trust_provid_ --> pain_worry
77. trust_provid_ --> stressed
78. trust_provid_ --> trust_provid
"""

In [7]:
# group 3: subject ids
groups['group_3'] =   ['R14_lagged', 'R16_lagged', 'R28_lagged', 'R8_lagged', 'R6_lagged', 'R23_lagged', 'R24_lagged']

# group 3: numbered directed edge listing parsed by read_edge_list
edge_texts['group_3'] = """
1. GridA_secs --> GridB_secs
2. SpatialSpan3_perc_accuracy --> GridB_secs
3. SpatialSpan4_perc_accuracy --> SpatialSpan3_perc_accuracy
4. SpatialSpan4_perc_accuracy_ --> SpatialSpan3_perc_accuracy
5. SpatialSpan4_perc_accuracy_ --> SpatialSpan5_perc_accuracy
6. SpatialSpan4_perc_accuracy_ --> feel_burden
7. SpatialSpan4_perc_accuracy_ --> pain_strategy
8. SpatialSpan5_perc_accuracy --> GridB_secs
9. SpatialSpan5_perc_accuracy --> SpatialSpan3_perc_accuracy
10. SpatialSpan5_perc_accuracy --> SpatialSpan4_perc_accuracy
11. SpatialSpan5_perc_accuracy --> pain_worry
12. SpatialSpan5_perc_accuracy_ --> GridB_secs
13. SpatialSpan5_perc_accuracy_ --> SpatialSpan4_perc_accuracy
14. SpatialSpan5_perc_accuracy_ --> SpatialSpan5_perc_accuracy
15. SpatialSpan5_perc_accuracy_ --> trust_provid
16. accomplish --> fatigue
17. accomplish_ --> accomplish
18. depressed_ --> depressed
19. depressed_ --> get_better
20. depressed_ --> pain_strategy
21. fatigue --> feel_direction
22. fatigue --> pain_interf
23. fatigue --> rested
24. fatigue_ --> fatigue
25. fatigue_ --> get_better
26. fatigue_ --> rested
27. feel_burden --> depressed
28. feel_burden_ --> feel_burden
29. feel_burden_ --> pain_cope
30. feel_burden_ --> pain_worry
31. feel_burden_ --> rested
32. feel_direction --> depressed
33. feel_direction --> trust_provid
34. feel_direction_ --> GridB_secs
35. feel_direction_ --> feel_direction
36. feel_direction_ --> feel_support
37. feel_direction_ --> pain_interf
38. feel_support --> GridB_secs
39. feel_support --> accomplish
40. feel_support --> depressed
41. feel_support --> fatigue
42. feel_support --> feel_direction
43. feel_support --> get_better
44. feel_support --> pain_intens
45. feel_support --> pain_strategy
46. feel_support --> rested
47. feel_support_ --> feel_support
48. get_better --> pain_cope
49. get_better_ --> depressed
50. get_better_ --> get_better
51. get_better_ --> pain_strategy
52. pain_cope --> feel_direction
53. pain_cope_ --> GridB_secs
54. pain_cope_ --> pain_cope
55. pain_cope_ --> pain_strategy
56. pain_intens --> pain_mind
57. pain_intens --> pain_worry
58. pain_intens --> serious
59. pain_intens_ --> feel_direction
60. pain_intens_ --> trust_provid
61. pain_interf --> feel_burden
62. pain_interf --> pain_intens
63. pain_interf --> pain_mind
64. pain_interf --> pain_strategy
65. pain_interf --> stressed
66. pain_interf_ --> fatigue
67. pain_interf_ --> feel_direction
68. pain_interf_ --> pain_interf
69. pain_interf_ --> pain_strategy
70. pain_mind --> get_better
71. pain_mind_ --> feel_support
72. pain_strategy --> SpatialSpan5_perc_accuracy
73. pain_strategy --> get_better
74. pain_strategy --> pain_cope
75. pain_strategy --> pain_worry
76. pain_strategy_ --> fatigue
77. pain_strategy_ --> pain_cope
78. pain_strategy_ --> pain_mind
79. pain_worry --> feel_burden
80. pain_worry --> serious
81. pain_worry_ --> pain_intens
82. pain_worry_ --> pain_worry
83. rested --> feel_burden
84. rested --> pain_intens
85. rested --> pain_worry
86. rested --> stressed
87. rested_ --> accomplish
88. rested_ --> rested
89. serious --> depressed
90. serious_ --> get_better
91. serious_ --> pain_strategy
92. serious_ --> serious
93. stressed --> feel_burden
94. stressed --> pain_cope
95. stressed --> pain_mind
96. stressed --> serious
97. stressed_ --> pain_interf
98. stressed_ --> stressed
99. trust_provid --> pain_worry
100. trust_provid_ --> SpatialSpan5_perc_accuracy
101. trust_provid_ --> pain_intens
102. trust_provid_ --> stressed
103. trust_provid_ --> trust_provid
"""

In [8]:
# Start the group -> parsed edge list mapping with group_1 only;
# the per-group loop further below (re)parses every group into it.
group = 'group_1'
edge_lists = {group: read_edge_list(edge_texts[group])}

In [13]:
import pandas as pd
def read_group_data(ids: list[str], add_id: bool = False, data_dir: str = "data_raw") -> pd.DataFrame:
    """
    Read the per-subject csv files ``<data_dir>/<id>.csv`` and stack them row-wise.

    Args:
        ids: subject ids; each id is the file name (without ``.csv``) of one file.
        add_id: if True, insert an ``'id'`` column as the first column, holding
            the id of the file each row came from.
        data_dir: directory containing the csv files (default ``"data_raw"``).

    Returns:
        A single DataFrame with a fresh 0..n-1 index, rows in ``ids`` order.
        An empty DataFrame when ``ids`` is empty.
    """
    import os

    frames = []
    for subject_id in ids:  # not `id`, which would shadow the builtin
        df = pd.read_csv(os.path.join(data_dir, f"{subject_id}.csv"))
        if add_id:
            df.insert(0, 'id', subject_id)
        frames.append(df)
    if not frames:
        return pd.DataFrame()
    # concatenate once; growing a frame with concat inside the loop is quadratic
    return pd.concat(frames, ignore_index=True)


In [16]:
%matplotlib inline
import matplotlib.pyplot as plt
# Fit the SEM and render graphs for every CDA group.
# For each group: parse its edge list, skip it if TARGET_NODE never appears,
# fit the SEM on the group's pooled data, save the full graph and the
# TARGET_NODE ancestor subgraph, and collect the data for a combined csv.
TARGET_NODE = 'pain_interf'
group_keys = ['group_1', 'group_2', 'group_3']
group_frames = []  # per-group frames (with 'group' and 'id' columns), concatenated once below
for group in group_keys:
    edge_lists[group] = read_edge_list(edge_texts[group])

    # every node that appears in a directed edge of this group
    nodes = set()
    for edge in edge_lists[group]:
        parts = edge.split()
        if len(parts) == 3 and parts[1] == "-->":
            nodes.update((parts[0], parts[2]))
    if TARGET_NODE not in nodes:
        print(f"'{TARGET_NODE}' not in edges for {group}, skipping...")
        continue
    print(f"Processing {group} ({len(edge_lists[group])} edges)")

    # read the group's data once; the id column is only needed for the combined csv
    data_with_id = read_group_data(groups[group], add_id=True)
    data = data_with_id.drop(columns='id')  # measurement columns only, used for the SEM
    data.to_csv(f"{group}_data.csv", index=False)
    data_with_id.insert(0, 'group', group)
    group_frames.append(data_with_id)

    # the SEM is specified from the group's full edge set
    lavaan_model = fc.edges_to_lavaan(edge_lists[group])
    sem_results = fc.run_semopy(lavaan_model, data)

    # full graph annotated with the SEM estimates
    # (no plt.figure() here: it only produced empty "0 Axes" figures)
    dg = DgraphFlex()
    dg.add_edges(edge_lists[group])
    fc.add_sem_results_to_graph(dg, sem_results['estimates'])
    dg.save_graph(f"k-3_{group}_graph_pain_full")

    # subgraph of TARGET_NODE and its ancestors, annotated with the same estimates
    ancestor_edge_list = get_ancestors_dag(edge_lists[group], TARGET_NODE)
    dg = DgraphFlex()
    dg.add_edges(ancestor_edge_list)
    fc.add_sem_results_to_graph(dg, sem_results['estimates'])
    dg.save_graph(f"{group}_graph_{TARGET_NODE}")
    dg.show_graph()

# pooled data of all processed groups (an empty frame if every group was skipped)
all_group_data = pd.concat(group_frames, ignore_index=True) if group_frames else pd.DataFrame()
all_group_data.to_csv("all_group_data.csv", index=False)

Processing group_1 with edges:
Processing group_2 with edges:
Processing group_3 with edges:


<Figure size 640x480 with 0 Axes>

<Figure size 640x480 with 0 Axes>

<Figure size 640x480 with 0 Axes>

In [None]:
# Inspect one group on its own: fit the SEM on the group's full edge set and
# plot the complete annotated graph. `ancestor_edge_list` and `sem_results`
# are reused by the next cell.
group = 'group_1'
edge_lists[group] = read_edge_list(edge_texts[group])

# edges among pain_interf and its ancestors
ancestor_edge_list = get_ancestors_dag(edge_lists[group], 'pain_interf')
print(f"Edges: {ancestor_edge_list}")

# SEM specification and data for the whole group (all edges, not only ancestors)
lavaan_model = fc.edges_to_lavaan(edge_lists[group])
data = read_group_data(groups[group])
sem_results = fc.run_semopy(lavaan_model, data)

# complete graph annotated with the SEM estimates
dg = DgraphFlex()
dg.add_edges(edge_lists[group])
fc.add_sem_results_to_graph(dg, sem_results['estimates'])
dg.show_graph()

In [None]:
# Plot only pain_interf and its ancestors, annotated with the same SEM
# estimates (uses `ancestor_edge_list` and `sem_results` from the cell above).
dg2 = DgraphFlex()
dg2.add_edges(ancestor_edge_list)
fc.add_sem_results_to_graph(dg2, sem_results['estimates'])
dg2.show_graph()

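## Heatmap with subjects divided by CDA group

The cells below create a heatmap version of the results in which the subjects (rows) are ordered by, and visually divided into, their CDA groups.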



In [None]:
# Order the subjects within each group by their numeric R id, so that e.g.
# 'R2_lagged' comes before 'R10_lagged'. The ids carry a suffix, so the old
# `int(x[1:])` raised ValueError ('26_lagged' is not an int).
import re


def _subject_number(subject_id: str) -> int:
    """Return the first run of digits in *subject_id* as an int ('R26_lagged' -> 26).

    Raises:
        ValueError: if the id contains no digits.
    """
    match = re.search(r"\d+", subject_id)
    if match is None:
        raise ValueError(f"subject id {subject_id!r} contains no number")
    return int(match.group())


groups_sorted = {group: sorted(groups[group], key=_subject_number) for group in group_keys}
for group, subjects in groups_sorted.items():
    print(f"{group}: {subjects}")

In [None]:
import matplotlib.colors as mcolors
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.preprocessing import StandardScaler

def cmap_red_green():
    """
    Build a diverging red -> white -> green colormap named 'red_white_green'.

    The colour stops sit at 0.0 (red), 0.5 (white) and 1.0 (green) of the
    normalized colour range. When the plot uses vmin=-1 and vmax=1, the value
    -1 maps to red, 0 maps to white and +1 maps to green.

    Returns:
        A matplotlib LinearSegmentedColormap. It is not registered globally.
    """
    red, white, green = (1, 0, 0), (1, 1, 1), (0, 1, 0)
    stops = [(0.0, red), (0.5, white), (1.0, green)]
    return mcolors.LinearSegmentedColormap.from_list('red_white_green', stops)

In [None]:
from typing import Optional


def create_heatmap(df: pd.DataFrame,
                   cmap,  # colormap
                   title: str = "Heatmap",
                   cbar_label: str = 'Standardized R',
                   xtick_labelsize: int = 18,
                   ytick_labelsize: int = 18,
                   title_fontsize: int = 30,
                   grid_fontsize: int = 12,
                   grid_alpha: float = .5,
                   grid_linewidth: int = 1,

                   # horizontal line arguments
                   horiz_linewidth: int = 4,
                   horiz_color: str = 'purple',
                   horiz_rows: Optional[list] = None,
                   horiz_labels: Optional[list] = None,

                   # vertical line arguments
                   verti_linewidth: int = 4,
                   verti_color: str = 'black',
                   verti_labels: Optional[list] = None,

                   # reorder_rows
                   reorder_rows: Optional[list] = None,

                   # colour scale / annotation
                   vmin: float = -1,
                   vmax: float = 1,
                   annot_threshold: float = 0.6,
                   ):
    """
    Draw an annotated heatmap of ``df`` and display it.

    The first column of ``df`` supplies the row labels. Every remaining column
    becomes one heatmap column. Each cell is annotated with its value to one
    decimal place. NaN cells are left blank; these include rows that
    ``reorder_rows`` introduces for labels missing from ``df``.

    Args:
        df: data frame whose first column holds the row labels; the other
            columns must be numeric.
        cmap: matplotlib colormap for the cells.
        title, cbar_label: figure title and colorbar label.
        xtick_labelsize, ytick_labelsize, title_fontsize: font sizes.
        grid_fontsize: font size of the in-cell annotations.
        grid_alpha, grid_linewidth: appearance of the cell grid lines.
        horiz_linewidth, horiz_color: appearance of the horizontal separators.
        horiz_rows: currently unused; kept for backward compatibility.
        horiz_labels: row labels after which a horizontal separator is drawn
            (between that row and the next). Unknown labels are ignored.
        verti_linewidth, verti_color: appearance of the vertical separators.
        verti_labels: column labels to the right of which a vertical separator
            is drawn. Unknown labels are ignored.
        reorder_rows: if given, rows are reindexed to this label order.
        vmin, vmax: data range mapped onto the colormap (default -1..1).
        annot_threshold: cells with |value| above this get white text.

    Returns:
        None. The figure is shown with ``plt.show()``.
    """
    horiz_labels = horiz_labels or []
    verti_labels = verti_labels or []

    # 1. split df into row/column labels and a numeric matrix
    df_copy = df.set_index(df.columns[0])
    if reorder_rows:
        df_copy = df_copy.reindex(reorder_rows)
    col_labels = list(df_copy.columns)
    row_labels = list(df_copy.index)
    data = df_copy.to_numpy(dtype=float)
    n_rows, n_cols = data.shape

    # 2. figure sized to the grid
    fig_width = max(8, n_cols * 0.8)
    fig_height = max(6, n_rows * 0.6)
    fig, ax = plt.subplots(figsize=(fig_width, fig_height))
    im = ax.imshow(data, cmap=cmap, vmin=vmin, vmax=vmax, interpolation='nearest')

    # 3. annotate every non-NaN cell; white text on strongly coloured cells
    for i in range(n_rows):
        for j in range(n_cols):
            value = data[i, j]
            if np.isnan(value):
                continue
            color = "white" if abs(value) > annot_threshold else "black"
            ax.text(j, i, f"{value:.1f}", ha="center", va="center",
                    color=color, fontsize=grid_fontsize)

    # 4. tick labels
    ax.set_xticks(np.arange(n_cols))
    ax.set_yticks(np.arange(n_rows))
    ax.set_xticklabels(col_labels)
    ax.set_yticklabels(row_labels)
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")

    # 5. cell grid drawn on minor ticks placed at the cell borders
    ax.set_xticks(np.arange(n_cols + 1) - .5, minor=True)
    ax.set_yticks(np.arange(n_rows + 1) - .5, minor=True)
    ax.grid(which="minor", color="black", linestyle='-',
            linewidth=grid_linewidth, alpha=grid_alpha)
    ax.tick_params(which="minor", bottom=False, left=False)

    # 6. horizontal separators after the named rows
    for label in horiz_labels:
        if label in row_labels:
            y = row_labels.index(label) + 0.5
            ax.hlines(y=y, xmin=-0.5, xmax=n_cols - 0.5,
                      color=horiz_color, linewidth=horiz_linewidth, linestyle='-')

    # 7. vertical separators to the right of the named columns
    for label in verti_labels:
        if label in col_labels:
            x = col_labels.index(label) + 0.5
            ax.vlines(x=x, ymin=-0.5, ymax=n_rows - 0.5,
                      color=verti_color, linewidth=verti_linewidth, linestyle='-')

    # --- font sizes ---
    ax.set_title(title, fontsize=title_fontsize)
    ax.tick_params(axis='x', labelsize=xtick_labelsize)
    ax.tick_params(axis='y', labelsize=ytick_labelsize)

    # add the colorbar BEFORE tight_layout so its space is taken into account
    fig.colorbar(im, ax=ax, label=cbar_label, shrink=1.2)
    fig.tight_layout()
    plt.show()

In [None]:
# create the list of horizontal labels which are the last element of each group in groups_sorted
horiz_labels = []
for group in group_keys:
    horiz_labels.append(groups_sorted[group][-1])

horiz_labels

In [None]:
# create the list of sorted groups
reorder_rows = []
for group in group_keys:
    reorder_rows.extend(groups_sorted[group])
    
reorder_rows

In [None]:
# Load the dataframe from the CSV file you uploaded
data_df = pd.read_csv('cpcrun1_effectsize_grid_pain_interf_pmax_10.csv')

# reorder_rows list to reorder the dataframe rows. The column in the dataframe is the first column

data_df = data_df.set_index(data_df.columns[0])
data_df = data_df.reindex(reorder_rows)
# rename the index to be a column again
data_df = data_df.reset_index()
pass

In [None]:

# Draw the heatmap. Vertical separators go to the right of the named variables;
# horizontal separators follow the last subject of every CDA group.
create_heatmap(
    data_df,
    cmap_red_green(),
    verti_labels=['pain_strategy', 'rested', 'get_better_'],
    horiz_labels=horiz_labels,
)
