# Introduction to Statistical Mechanics (ME346A)

In [None]:
# Introduction to Statistical Mechanics (Winter, 2025)
# Wei Cai, Myung Chul Kim
# Virial Coefficients and Graph Counting Problem

Reference: Barker et al., "Fifth Virial Coefficients," *J. Chem. Phys.* (1966),
[doi:10.1063/1.1726606](https://doi.org/10.1063/1.1726606)

## Graph Counting

In [None]:
import numpy as np
import matplotlib.pyplot as plt
plt.rcParams['text.usetex'] = False
import networkx as nx
from networkx import graph_atlas_g
from itertools import permutations
import sympy as sp

### Graph Plotting Functions

In [None]:
# Graph plotting
def create_grid_pos(n):
    """
    Lay out nodes 0..n-1 row by row on a square grid inside the unit square.

    The grid has ceil(sqrt(n)) columns. Node i goes to column i % size and
    row i // size. Row 0 is at the top (y = 1) and the last column at x = 1.

    Parameters
    ----------
    n : int
        Number of nodes.

    Returns
    -------
    dict
        Maps node index -> (x, y) position, suitable for ``nx.draw``.
    """
    if n <= 1:
        # A 1x1 grid has zero spacing to normalise by (the old code divided by
        # zero and produced NaN positions); centre the lone node instead.
        return {i: (0.5, 0.5) for i in range(n)}

    grid_size = int(np.ceil(np.sqrt(n)))
    step = grid_size - 1  # >= 1 because n >= 2
    pos = {}
    for i in range(n):
        row, col = divmod(i, grid_size)
        pos[i] = (col / step, 1 - row / step)
    return pos

In [None]:
def plot_graphs(n_node_graphs, n_dofs=None, save_dir='graphs', n_cols=4):
    """
    Draw each graph on its own panel of a grid figure and save it as a PNG.

    Parameters
    ----------
    n_node_graphs : list of networkx graphs
        Graphs to draw, one per panel, titled ``Graph 1``, ``Graph 2``, ...
    n_dofs : sequence, optional
        Same length as ``n_node_graphs``. When given, ``n_dofs[i]`` is printed
        under panel ``i`` as ``DOF: ...``.
    save_dir : str
        Output path stem. The figure is written to ``f"{save_dir}.png"``
        (despite the name this is a file name, not a directory).
    n_cols : int
        Number of panels per row (default 4).

    Nothing is drawn or saved when ``n_node_graphs`` is empty.
    """
    n_graphs = len(n_node_graphs)
    if n_graphs == 0:
        # plt.subplots rejects zero rows, so there is nothing to draw
        print(f"No graphs to plot; skipping {save_dir}.png")
        return

    n_rows = (n_graphs - 1) // n_cols + 1  # ceil(n_graphs / n_cols)

    # squeeze=False keeps `axes` 2-D for any grid shape, so flatten() is safe
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 4 * n_rows), squeeze=False)
    axes = axes.flatten()

    for i, (ax, g) in enumerate(zip(axes, n_node_graphs)):
        grid_pos = create_grid_pos(len(g.nodes()))
        nx.draw(g, grid_pos, ax=ax, with_labels=True, node_color='lightblue',
                node_size=500, font_size=10, font_weight='bold')

        ax.set_title(f'Graph {i+1}')
        if n_dofs is not None:
            ax.text(0.5, -0.1, f'DOF: {n_dofs[i]}',
                    horizontalalignment='center',
                    transform=ax.transAxes)

        ax.set_axis_off()

    # Remove the unused panels at the end of the last row
    for ax in axes[n_graphs:]:
        fig.delaxes(ax)

    fig.tight_layout()
    fig.savefig(f"{save_dir}.png")
    # Close (not just clear) so figures do not accumulate across calls
    plt.close(fig)

### Graph Counting Functions

In [None]:
# Enumerate the distinct labellings of each graph (their count is its multiplicity)
def get_unique_adjs(graph):
    """
    Return every distinct labelled adjacency matrix of ``graph``.

    Relabelling nodes by a permutation ``perm`` maps adjacency matrix ``A`` to
    ``A[perm][:, perm]``, which is the same as ``P @ A @ P.T`` with ``P`` the
    permutation matrix. The distinct results over all n! permutations are the
    labelled graphs isomorphic to ``graph``.

    Parameters
    ----------
    graph : networkx graph

    Returns
    -------
    numpy.ndarray
        Float array of shape (k, n, n), one matrix per distinct labelling,
        in no particular order. ``k`` is the graph's multiplicity.
    """
    # to_numpy_array always returns a dense ndarray. adjacency_matrix().todense()
    # yields np.matrix on older networkx/scipy, and its rows cannot be turned
    # into hashable tuples.
    adj = nx.to_numpy_array(graph)
    n_nodes = adj.shape[0]
    unique_matrices = set()

    for perm in permutations(range(n_nodes)):
        relabelled = adj[np.ix_(perm, perm)]
        unique_matrices.add(tuple(map(tuple, relabelled)))

    return np.array(list(unique_matrices))

### **TODO: Graph Indicator Functions**

In [None]:
# Function to determine connectivity
def is_connected(graph):
    """
    Report whether every node of ``graph`` can be reached from every other
    node along its edges (delegates to ``networkx.is_connected``).
    """
    connected = nx.is_connected(graph)
    return connected

In [None]:
# Function to determine 2-connectivity
def is_2_connected(graph):
    """
    Determine 2-connected (biconnected) graphs.

    A graph counts as 2-connected when it is connected and deleting any one
    node (with its incident edges) leaves the remaining nodes connected, i.e.
    it has no articulation point. Under this definition the single-edge graph
    on two nodes is 2-connected, so n_nodes = 2 yields the lone f(r_12) term.
    Graphs with fewer than two nodes return False.

    Returns a real bool. A sentinel such as ``NotImplemented`` would be truthy
    in ``if is_2_connected(graph)`` and accept every graph.
    """
    nodes = list(graph.nodes())
    if len(nodes) < 2 or not nx.is_connected(graph):
        return False

    for removed in nodes:
        remaining = graph.subgraph([node for node in nodes if node != removed])
        if not nx.is_connected(remaining):
            # `removed` is an articulation point
            return False
    return True

### Symbolic Expression Functions

In [None]:
# Generate Latex expressions
def gen_expr(unique_adjs_list, n_dofs):
    """
    Build the symbolic sum of Mayer-function products over graphs.

    Parameters
    ----------
    unique_adjs_list : sequence of arrays of shape (k, n, n)
        For each graph, the stack of its distinct labelled adjacency matrices
        (as returned by ``get_unique_adjs``). The first matrix is used as the
        representative labelling. The arrays are not modified.
    n_dofs : sequence of int
        Multiplicity of each graph. Must have the same length as
        ``unique_adjs_list``.

    Returns
    -------
    sympy.Expr
        sum over graphs of n_dofs[g] * prod_{edges (i, j), i < j} f(r_ij),
        with node indices shown 1-based (``r_{12}``, ...).

    Raises
    ------
    ValueError
        If the two sequences differ in length.
    """
    if len(unique_adjs_list) != len(n_dofs):
        raise ValueError("unique_adjs_list and n_dofs must have the same length")

    f = sp.Function('f')
    expr = sp.Integer(0)
    for unique_adjs, multiplicity in zip(unique_adjs_list, n_dofs):
        # All labellings are isomorphic, so any one of them represents the graph
        representative = np.asarray(unique_adjs)[0]
        # np.triu(k=1) returns a new array holding only the strict upper
        # triangle: each undirected edge appears once as (i, j) with i < j,
        # and the caller's matrices stay intact.
        edges = np.argwhere(np.triu(representative, k=1) == 1)
        bonds = [f(sp.Symbol(f'r_{{{i + 1}{j + 1}}}')) for i, j in edges]
        expr += multiplicity * sp.Mul(*bonds)

    return expr

In [None]:
def latexify_expr(expr, n_nodes=4):
    r"""
    Render the n-th virial coefficient B_n as a displayable LaTeX string.

    Parameters
    ----------
    expr : sympy expression
        Sum over 2-connected graphs, each term a multiplicity times a product
        of Mayer functions f(r_ij) (as built by ``gen_expr``).
    n_nodes : int
        Order n of the coefficient (number of graph nodes); must be >= 2.

    Returns
    -------
    str
        ``\mathcal{B}_n = -\frac{1}{n (n-2)! V} \int ... \int <terms>``.
        A sum of several terms is wrapped in ``\Big[ ... \Big]`` with one
        term per line.

    Raises
    ------
    ValueError
        If ``n_nodes < 2``.

    Notes
    -----
    The Mayer cluster expansion gives
    B_n = -(n-1)/(n! V) * (integral of the sum over labelled 2-connected graphs).
    Because (n-1)/n! = 1/(n (n-2)!), the prefactor is 1/(2V), 1/(3V), 1/(8V)
    and 1/(30V) for n = 2..5. The simpler -1/(nV) is only correct for n <= 3.
    """
    from math import factorial

    if n_nodes < 2:
        raise ValueError("n_nodes must be at least 2")

    # A sum is split into its graph terms; anything else is a single term
    terms = expr.args if isinstance(expr, sp.Add) else [expr]
    latex_terms = [sp.latex(term) for term in terms]

    denominator = n_nodes * factorial(n_nodes - 2)
    integrals = ' '.join(rf'\int_V d^3\mathbf{{r}}_{{{k}}}' for k in range(1, n_nodes + 1))
    # Braced subscripts keep multi-digit indices (e.g. B_{10}) together
    prefix = (rf'\mathcal{{B}}_{{{n_nodes}}} = '
              rf'-\frac{{1}}{{{denominator}V}} '
              rf'{integrals}\,\,\, ')

    if len(latex_terms) == 1:
        return prefix + latex_terms[0]

    # Each further term starts a new line, indented by an \hspace that grows
    # with n_nodes (the width of the integral prefix)
    line_break = rf' \\ \hspace{{{3 * n_nodes + 7}em}}'
    body = latex_terms[0] + ''.join(f'{line_break}+ {term}' for term in latex_terms[1:])
    return prefix + r'\Big[' + body + r'\Big]'

### Play with Graphs

In [None]:
# Set number of nodes
n_nodes = 4 # you may change this
assert n_nodes > 1, "Number of nodes must be greater than 1"
# networkx's graph atlas stops at 7 nodes. A larger value would leave the graph
# lists below empty, and the plotting/expression cells would then fail.
assert n_nodes <= 7, "graph_atlas_g() only contains graphs with up to 7 nodes"
print("Number of nodes:", n_nodes)

In [None]:
# Generate all graphs: keep only the atlas entries with exactly n_nodes nodes
n_node_graphs = [
    atlas_graph for atlas_graph in graph_atlas_g()
    if atlas_graph.number_of_nodes() == n_nodes
]
print("Total number of graphs:", len(n_node_graphs))
plot_graphs(n_node_graphs, save_dir=f'graphs_n{n_nodes}')

In [None]:
# Collect connected graphs (filter the full list with is_connected)
n_node_graphs_connected = [
    candidate for candidate in n_node_graphs if is_connected(candidate)
]
print("Total number of connected graphs:", len(n_node_graphs_connected))
plot_graphs(n_node_graphs_connected, save_dir=f'graphs_n{n_nodes}_c')

In [None]:
# Collect 2-connected graphs: only these enter the virial coefficient below
n_node_graphs_2_connected = [
    candidate for candidate in n_node_graphs_connected if is_2_connected(candidate)
]
print("Total number of 2-connected graphs:", len(n_node_graphs_2_connected))
plot_graphs(n_node_graphs_2_connected, save_dir=f'graphs_n{n_nodes}_2c')

In [None]:
# Obtain degree of freedom of each graph: its multiplicity, i.e. the number of
# distinct labelled adjacency matrices returned by get_unique_adjs
unique_adjs_list = [get_unique_adjs(graph) for graph in n_node_graphs_2_connected]
n_dofs = [len(unique_adjs) for unique_adjs in unique_adjs_list]
for i, n_dof in enumerate(n_dofs):
    print(f"Graph {i+1} has multiplicity of {n_dof}")
plot_graphs(n_node_graphs_2_connected, n_dofs=n_dofs, save_dir=f'graphs_n{n_nodes}_2c')

In [None]:
# Print out Latex expression of Virial coefficient
from IPython.display import Math, display

# Symbolic sum over 2-connected graphs -> LaTeX string -> typeset output
vir_expr = gen_expr(unique_adjs_list, n_dofs)
latex_expr = latexify_expr(vir_expr, n_nodes=n_nodes)
display(Math(latex_expr))