In [39]:
from openai import OpenAI
from random import randint, choice
from os import environ
from pathlib import Path
from json import loads, dumps
environ["OPENAI_API_KEY"] = Path("~/.openaiapikey").expanduser().read_text().strip()

openaiClient = OpenAI()
def gpt_3_5_turbo_completion(query):
    answer = openaiClient.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=[
            {
                "role": "system",
                "content": query
            }
        ],
        seed = randint(0, 1000000)
    )
    return answer.choices[0].message.content

def gpt_4_turbo_completion(query):
    answer = openaiClient.chat.completions.create(
        model="gpt-4-turbo",
        messages=[
            {
                "role": "system",
                "content": query
            }
        ],
        seed = randint(0, 1000000)
    )
    return answer.choices[0].message.content

def tryRecieveAnswer(query, completionFunction = gpt_4_turbo_completion, answerConversion = lambda x: x, maxTries = 10):
    tryNumber = 0
    while tryNumber < maxTries:
        answer = completionFunction(query)
        try:
            answer = answerConversion(answer)
            return (answer, True)
        except:
            pass
        tryNumber += 1
    print(f"Failed to recieve answer for query: {query}. Last answer: {answer}")
    return (None, False)

def listAnswerConversion(answer):
    result = loads(answer)
    assert isinstance(result, list)
    for item in result:
        assert isinstance(item, str)
    return result

In [60]:
class DerivationNetwork:
    def __init__(self):
        self.equations = []
        self.derivations = []
class Equation:
    def __init__(self, leftSide, rightSide):
        self.leftSide = leftSide
        self.rightSide = rightSide
    def asString(self):
        return f"{self.leftSide} = {self.rightSide}"
class Derivation:
    def __init__(self, fromEquations, toEquations):
        self.fromEquations = fromEquations
        self.toEquations = toEquations
        self.symbols = dict()
def addEquation(network, leftSide, rightSide):
    equation = Equation(leftSide, rightSide)
    network.equations.append(equation)
    return (leftSide, equation, rightSide)
def addFromThisFollowsThat(network, fromEquations, toEquations):
    derivation = Derivation(fromEquations, toEquations)
    network.derivations.append(derivation)
    return derivation
def addSymbolDefinition(derivation, symbol, definition):
    derivation.symbols[symbol] = definition

In [89]:
nw = DerivationNetwork()
showSymbolDefinition = False
ham = False
leftSchroed, schroedEq, rightSchroed = addEquation(nw, "i\\hbar\\frac{\\partial}{\\partial t}\\psi(x,t)", "-\\frac{\\hbar^2}{2m}\\Delta\\psi(x,t)+V(x)\\psi(x,t)")

leftSeperation, seperationEq, rightSeperation = addEquation(nw, "\\psi(x,t)", "\\varphi(x)\\chi(t)")

leftTimeDepSchroed, timeDepSchroed, varSepConst= addEquation(nw, "\\frac{i\\hbar}{\\chi(t)}\\frac{d\\chi(t)}{dt}", "E")
leftSpatialSchroed, spatialSchroed, _ = addEquation(nw, "-\\frac{\\hbar^2}{2m \\varphi(x)}\\Delta\\varphi(x) + V(x)", "E")
addFromThisFollowsThat(nw, [schroedEq, seperationEq], [timeDepSchroed, spatialSchroed])

leftEnergyAssertion, energyAssertion, _ = addEquation(nw, "\\hbar\\omega", varSepConst)

leftSolutionOfChi, solutionOfChi, rightSolutionOfChi = addEquation(nw, "\\chi(t)", "Ae^{-i\\omega t}")
addFromThisFollowsThat(nw, [timeDepSchroed, energyAssertion], [solutionOfChi])

if ham:
    leftDiffEqOfPhi, diffEqOfPhi, rightDiffEqOfPhi = addEquation(nw, "-\\frac{\\hbar^2}{2m}\\Delta\\varphi(x) + V\\varphi(x)", "\\hbar\\omega\\varphi(x)")
    addFromThisFollowsThat(nw, [spatialSchroed, energyAssertion], [diffEqOfPhi])

    leftHamDef, hamDef, rightHamDef = addEquation(nw, "H", "-\\frac{\\hbar^2}{2m}\\Delta + V(x)")
    _, hamDiff, rightHamDif = addEquation(nw, leftDiffEqOfPhi, "H\\varphi(x)")
    addFromThisFollowsThat(nw, [hamDef], [hamDiff])

_, statSol, rightStatSol = addEquation(nw, leftSeperation, "\\varphi(x)e^{-i\\omega t}")
addFromThisFollowsThat(nw, [seperationEq, solutionOfChi], [statSol])

leftLinComb, linComb, rightLinComb = addEquation(nw, "\\psi(x,t)", "\\sum_n \\psi_n(x, t)")
linCombDeriv = addFromThisFollowsThat(nw, [schroedEq], [linComb])

_, solution, rightSolution = addEquation(nw, leftLinComb, "\\sum_n c_n\\varphi_n(x)e^{-i\\omega_n t}")
addFromThisFollowsThat(nw, [linComb, statSol], [solution])

trueNw = nw

In [88]:
def isDerivationTrue(derivation):
    toEquationsString = ("the equation" if len(derivation.toEquations) == 1 else "the equations ") + " , ".join([f"{eq.leftSide} = {eq.rightSide}" for eq in derivation.toEquations])
    fromEquationsString = ("the equation" if len(derivation.fromEquations) == 1 else "the equations ") + " , ".join([f"{eq.leftSide} = {eq.rightSide}" for eq in derivation.fromEquations])
    #symbolDefinitionString = "{" + ", ".join([f'{symbol} : "{definition}"' for symbol, definition in derivation.symbols.items()]) + "}"
    query = f'Can {fromEquationsString} be derived from {toEquationsString}? Return Y or N without further explanation.' # The symbol definitions of the equations are {symbolDefinitionString}.'
    def answerConversion(answer):
        assert answer.strip().lower() in ["y", "n"]
        return answer.strip().lower() == "y"
    return tryRecieveAnswer(query, answerConversion = answerConversion)[0]

In [51]:
def checkDerivations(network):
    wrongDerivations = []
    trueDerivations = []
    for derivation in network.derivations:
        if isDerivationTrue(derivation):
            trueDerivations.append(derivation)
        else:
            wrongDerivations.append(derivation)
    return (trueDerivations, wrongDerivations)

In [44]:
trueNwResults = checkDerivations(trueNw)
print(len(trueNwResults[0]), len(trueNwResults[1]))

4 1


In [45]:
for derivation in trueNwResults[1]:
    print("from:")
    for eq in derivation.fromEquations:
        print(eq.asString())
    print("to:")
    for eq in derivation.toEquations:
        print(eq.asString())
    print("")

from:
i\hbar\frac{\partial}{\partial t}\psi(x,t) = -\frac{\hbar^2}{2m}\Delta\psi(x,t)+V(x)\psi(x,t)
to:
\psi(x,t) = \sum_n \psi_n(x, t)



In [46]:
nw = DerivationNetwork()
showSymbolDefinition = False
ham = False
leftSchroed, schroedEq, rightSchroed = addEquation(nw, "i\\hbar\\frac{\\partial}{\\partial x}\\psi(x,t)", "-\\frac{\\hbar^2}{2m}\\Delta\\psi(x,t)+V(x)\\psi(x,t)")

leftSeperation, seperationEq, rightSeperation = addEquation(nw, "\\chi(x,t)", "\\varphi(x)\\psi(t)")

leftTimeDepSchroed, timeDepSchroed, varSepConst= addEquation(nw, "\\frac{i\\hbar}{\\chi(t)}\\frac{d\\chi(t)}{dt}", "E")
leftSpatialSchroed, spatialSchroed, _ = addEquation(nw, "-\\frac{\\hbar^2}{m \\varphi(x)}\\Delta\\varphi(x) + V(x)", varSepConst)
addFromThisFollowsThat(nw, [schroedEq, seperationEq], [timeDepSchroed, spatialSchroed])

leftEnergyAssertion, energyAssertion, _ = addEquation(nw, "\\hbar\\omega", varSepConst)

leftSolutionOfChi, solutionOfChi, rightSolutionOfChi = addEquation(nw, "\\chi(t)", "Ae^{-\\omega t}")
addFromThisFollowsThat(nw, [timeDepSchroed, energyAssertion], [solutionOfChi])

if ham:
    leftDiffEqOfPhi, diffEqOfPhi, rightDiffEqOfPhi = addEquation(nw, "-\\frac{\\hbar^2}{2m}\\Delta\\varphi(x) + V\\varphi(x)", "\\hbar\\omega\\varphi(x)")
    addFromThisFollowsThat(nw, [spatialSchroed, energyAssertion], [diffEqOfPhi])

    leftHamDef, hamDef, rightHamDef = addEquation(nw, "H", "-\\frac{\\hbar^2}{2m}\\Delta + V(x)")
    _, hamDiff, rightHamDif = addEquation(nw, leftDiffEqOfPhi, "H\\varphi(x)")
    addFromThisFollowsThat(nw, [hamDef], [hamDiff])

_, statSol, rightStatSol = addEquation(nw, leftSeperation, "\\varphi(x)e^{-i\\omega^2 t}")
addFromThisFollowsThat(nw, [seperationEq, solutionOfChi], [statSol])

leftLinComb, linComb, rightLinComb = addEquation(nw, "\\psi(x,t)", "\\prod_n \\psi_n(x, t)")
addFromThisFollowsThat(nw, [schroedEq], [linComb])

_, solution, rightSolution = addEquation(nw, leftLinComb, "\\sum_n c_n\\varphi_n(x)e^{-i\\omega_n t}")
addFromThisFollowsThat(nw, [linComb, statSol], [solution])

falseNw = nw


In [47]:
falseNwResults = checkDerivations(falseNw)
print(len(falseNwResults[0]), len(falseNwResults[1]))

0 5
