In [105]:
import sympy as sp
from itertools import permutations
import torch

def parse_dimacs_cnf(filename: str) -> tuple[int, list[list[int]]]:
    with open(filename, "r") as file:
        clauses, num_vars = [], 0
        for line in file:
            line = line.strip()
            if not line or line.startswith("c") or "%" in line:
                continue  # skip comments
            if line.startswith("p"):
                parts = line.split()
                num_vars, num_clauses = int(parts[2]), int(parts[3])
                continue
            literals = [int(x) for x in line.split() if int(x) != 0]
            if literals:
                if len(literals) != 3:
                    raise ValueError("Give me a 3-SAT problem.")
                clauses.append(literals)
    return num_vars, clauses


def convert_clauses_to_ising(num_vars, clauses, dtype=torch.float32):
    var_symbols = sp.symbols(f"x1:{num_vars+1}")
    h_coeff = torch.zeros(num_vars, dtype=dtype)
    J_coeff = torch.zeros((num_vars, num_vars), dtype=dtype)
    P_coeff = torch.zeros((num_vars, num_vars, num_vars), dtype=dtype)
    scalar = 0
    for clause in clauses:
        term = 1
        for literal in clause:
            var = var_symbols[abs(literal) - 1]
            term *= (1 - var) if literal > 0 else (1 + var)
        expanded = sp.expand(term) / 8 #in HOIM paper, they added 1/8 to make the energy look great
        coeffs = expanded.as_coefficients_dict()
        for monomial, coeff in coeffs.items():
            if monomial == 1:
                scalar += float(coeff)
                continue
            vars_in_monomial = []
            if isinstance(monomial, sp.Symbol):
                vars_in_monomial = [str(monomial)]
            elif isinstance(monomial, sp.Mul):
                vars_in_monomial = sorted([str(arg) for arg in monomial.args])
            else:
                vars_in_monomial = [str(monomial)]
            indices = sorted([int(var[1:]) - 1 for var in vars_in_monomial])
            if len(indices) == 1:
                h_coeff[indices[0]].add_(float(coeff))
            elif len(indices) == 2:
                J_coeff[indices[0], indices[1]].add_(float(coeff))
                J_coeff[indices[1], indices[0]].add_(float(coeff))
            elif len(indices) == 3:
                key = tuple(indices)
                for subkey in permutations(key):
                    P_coeff[subkey].add_(float(coeff))
    return h_coeff, J_coeff, P_coeff, scalar

def string_to_expression(expr_string: str):
    expression = sp.parse_expr(expr_string)    
    syms = expression.as_terms()[-1]
    assert set([s.name for s in syms]).issubset({'x1', 'x2', 'x3', 'z'})
    symbol_dict = dict(zip(map(lambda x: x.name, syms), syms))
    return expression, symbol_dict

def clause_to_expression(clause, expression: sp.Expr, expr_syms, var_symbols, aux_symbols, ind):
    clause_term = expression.xreplace(
        dict([(expr_syms[f'x{i+1}'], 1+var_symbols[abs(clause[i])-1]) if clause[i] < 0 else (expr_syms[f'x{i+1}'], 1-var_symbols[abs(clause[i])-1])
         for i in range(len(clause))])
    )
    if aux_symbols is not None:
        clause_term = clause_term.replace(expr_syms['z'], aux_symbols[ind])
    return clause_term

num_vars, clauses = parse_dimacs_cnf("uf20-01.cnf")
h, J, P, scalar2 = convert_clauses_to_ising( num_vars, clauses)
# h = torch.tensor(np.h.numpy())
P
# def run_langevin_brim(tstop, dt, )

tensor([[[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         ...,
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.1250],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.1250,  0.0000]],

        [[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         ...,
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000, -0.1250],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ..., -0.1250,  0.0000,  0.0000]],

        [[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0

In [103]:
var_x1 = sp.Number
expression = sp.parse_expr("x3 * z + (2*z - x1 * z - x2*z+2*x1*x2)")
syms = expression.as_terms()[-1]
symbol_dict = dict(zip(map(lambda x: x.name, syms), syms))
expression.replace(symbol_dict['x1'], 1), expression
for expr in expression.as_terms():
    print(expr)

[(2*z, ((2.0, 0.0), (0, 0, 0, 1), ())), (x3*z, ((1.0, 0.0), (0, 0, 1, 1), ())), (-x1*z, ((-1.0, 0.0), (1, 0, 0, 1), ())), (-x2*z, ((-1.0, 0.0), (0, 1, 0, 1), ())), (2*x1*x2, ((2.0, 0.0), (1, 1, 0, 0), ()))]
[x1, x2, x3, z]


In [60]:
expression.xreplace({syms[3]: 1, syms[2]: 1, syms[0]: 1})

x2 + 2