In [44]:
from collections.abc import Callable, Sequence

import numpy as np
import scqubits as sc
import torch
from sympy import Expr, Symbol, lambdify, symbols

def sympy_to_pytorch(
    expr: Expr, variables: Sequence[Symbol]
) -> Callable[..., torch.Tensor]:
    """Compile a SymPy expression into a function that evaluates with PyTorch.

    The expression is lambdified with a namespace in which common elementary
    functions (``cos``, ``sin``, ``exp``, ``sqrt``, ...) resolve to their
    ``torch`` counterparts, so evaluating it on tensors stays inside PyTorch
    (autograd is preserved). Functions not in that mapping fall back to the
    NumPy implementations, as a plain ``'numpy'`` lambdify would use.

    Args:
        expr: The SymPy expression to convert.
        variables: Symbols in the order the returned function accepts them
            as positional arguments.

    Returns:
        A callable ``f(*values)``. Any argument that ends up inside a mapped
        function such as ``cos`` must be a ``torch.Tensor``, because
        ``torch.cos`` and friends do not accept plain Python scalars.
    """
    # Keys are the names lambdify's printer emits for these functions.
    torch_namespace = {
        "sin": torch.sin,
        "cos": torch.cos,
        "tan": torch.tan,
        "asin": torch.asin,
        "acos": torch.acos,
        "atan": torch.atan,
        "sinh": torch.sinh,
        "cosh": torch.cosh,
        "tanh": torch.tanh,
        "exp": torch.exp,
        "log": torch.log,
        "sqrt": torch.sqrt,
        "Abs": torch.abs,
        "sign": torch.sign,
    }
    # The dict is listed first so its entries take precedence over NumPy.
    # lambdify also registers dict keys as user functions, so the printer
    # emits them unqualified and they are looked up in this dict.
    return lambdify(variables, expr, modules=[torch_namespace, "numpy"])


zp_yaml = """# zero-pi
branches:
- ["JJ", 1,2, EJ1 = 10, EC1 = 20]
- ["JJ", 3,4, EJ2=5, EC2 = 30]
- ["L", 2,3, L1 = 0.008]
- ["L", 4,1, L2=0.1]
- ["C", 1,3, C1 = 0.02]
- ["C", 2,4, C2 = 0.4]
"""

# Build the circuit directly from the YAML string above rather than a file.
zp = sc.Circuit(zp_yaml, from_file=False)

# The first call displays the symbolic Hamiltonian (rendered as LaTeX in
# IPython). The second returns it as a SymPy expression for further use.
zp.sym_hamiltonian()
expr = zp.sym_hamiltonian(return_expr=True)

# Deterministic ordering of the free symbols. This is also the positional
# argument order of the compiled function created below.
variables = sorted(expr.free_symbols, key=lambda x: x.sort_key())

print(f"Please define the following variables: {variables}")

# Keep a handle on the compiled Hamiltonian so later cells can evaluate it:
# hamiltonian_fn(*values) with values ordered like `variables`.
hamiltonian_fn = sympy_to_pytorch(expr, variables)

print(expr)


<IPython.core.display.Latex object>

Please define the following variables: [(2πΦ_{1}), EJ1, EJ2, L1, L2, Q2, Q3, n1, n_g1, θ1, θ2, θ3]
(-EJ1*cos(θ1 - 1.0*θ3) - EJ2*cos(-(2πΦ_{1}) + θ1 + θ3) + 2.0*L1*θ2**2 + 2.0*L1*θ2*θ3 + 0.5*L1*θ3**2 + 2.0*L2*θ2**2 - 2.0*L2*θ2*θ3 + 0.5*L2*θ3**2) + (0.104284*Q2**2 - 0.075367*Q2*Q3 - 0.376835*Q2*n1 - 0.376835*Q2*n_g1 + 48.01666*Q3**2 + 0.166601*Q3*n1 + 0.166601*Q3*n_g1 + 0.416501*n1**2 + 0.833003*n1*n_g1 + 0.416501*n_g1**2)
