Skip to content

Commit

Permalink
refactor qnlp-snake-removal
Browse files Browse the repository at this point in the history
  • Loading branch information
toumix committed Feb 11, 2020
1 parent abdf7cd commit 4fd416a
Show file tree
Hide file tree
Showing 3 changed files with 102 additions and 97 deletions.
2 changes: 1 addition & 1 deletion __init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,4 @@
from discopy.circuit import Circuit, Gate, Bra, Ket, CircuitFunctor
from discopy.pregroup import Word

__version__ = '0.2.2'
__version__ = '0.2.3'
54 changes: 27 additions & 27 deletions notebooks/qnlp-experiment.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
"Alice who loves Alice is rich.\n",
"Alice who loves Bob is rich.\n",
"\n",
"9.83 seconds to generate 10 sentences.\n"
"8.18 seconds to generate 10 sentences.\n"
]
}
],
Expand Down Expand Up @@ -235,7 +235,7 @@
" └───┘ └───┘ </pre>"
],
"text/plain": [
"<qiskit.visualization.text.TextDrawing at 0x134d04a90>"
"<qiskit.visualization.text.TextDrawing at 0x1280c8fd0>"
]
},
"execution_count": 6,
Expand Down Expand Up @@ -279,8 +279,8 @@
"from pytket.backends.ibm import AerBackend\n",
"from discopy.circuit import matrix_from_counts\n",
"\n",
"def evaluate(F, sentences, backend=AerBackend(noise_model), n_shots=2**10, seed=42):\n",
" circuits = [F(parsing[s]).to_tk().measure_all() for s in sentences]\n",
"def evaluate(params, sentences, backend=AerBackend(noise_model), n_shots=2**10, seed=42):\n",
" circuits = [F(params)(parsing[s]).to_tk().measure_all() for s in sentences]\n",
" list(map(backend.default_compilation_pass.apply, circuits))\n",
" backend.process_circuits(circuits, n_shots=n_shots, seed=seed)\n",
" return [matrix_from_counts(\n",
Expand All @@ -299,7 +299,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"3.275 seconds to compute the corpus.\n",
"3.175 seconds to compute the corpus.\n",
"\n",
"True sentences:\n",
"Alice is rich. (0.760)\n",
Expand All @@ -323,7 +323,7 @@
"from time import time\n",
"\n",
"start = time()\n",
"corpus = dict(zip(sentences, evaluate(F(params0), sentences)))\n",
"corpus = dict(zip(sentences, evaluate(params0, sentences)))\n",
"print(\"{:.3f} seconds to compute the corpus.\\n\".format(time() - start))\n",
"\n",
"delta = .1\n",
Expand Down Expand Up @@ -372,7 +372,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -381,12 +381,12 @@
"def loss(params, sentences=sentence_train):\n",
" return - np.mean(np.array([\n",
" (corpus[sentence] - scalar) ** 2\n",
" for sentence, scalar in zip(sentences, evaluate(F(params), sentences))]))"
" for sentence, scalar in zip(sentences, evaluate(params, sentences))]))"
]
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 15,
"metadata": {},
"outputs": [
{
Expand All @@ -409,29 +409,29 @@
"\n",
"print(\"\\nIs Alice who loves Bob rich?\")\n",
"print(\"Yes, she is.\"\n",
" if evaluate(F(params), ['Alice who loves Bob is rich.'])[0] > .5 + delta\n",
" if evaluate(params, ['Alice who loves Bob is rich.'])[0] > .5 + delta\n",
" else \"No, she isn't.\")"
]
},
{
"cell_type": "code",
"execution_count": 18,
"execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1 (4 seconds since start): [0.62252593 0.90506185]\n",
"Epoch 2 (8 seconds since start): [0.59053632 0.93705145]\n",
"Epoch 3 (11 seconds since start): [0.55806232 0.96952546]\n",
"Epoch 4 (15 seconds since start): [0.54260212 0.98498566]\n",
"Epoch 5 (19 seconds since start): [0.53112536 0.99646241]\n",
"Epoch 6 (22 seconds since start): [0.5233175 1.00427028]\n",
"Epoch 7 (26 seconds since start): [0.52084204 1.00674574]\n",
"Epoch 8 (30 seconds since start): [0.52248008 1.0051077 ]\n",
"Epoch 9 (33 seconds since start): [0.52288718 1.0047006 ]\n",
"Epoch 10 (38 seconds since start): [0.52284507 1.00474271]\n"
"Epoch 1 (3 seconds since start): [0.62252593 0.90506185]\n",
"Epoch 2 (6 seconds since start): [0.63414504 0.91668096]\n",
"Epoch 3 (8 seconds since start): [0.5855492 0.9652768]\n",
"Epoch 4 (11 seconds since start): [0.56230937 0.98851663]\n",
"Epoch 5 (14 seconds since start): [0.54748535 1.00334065]\n",
"Epoch 6 (16 seconds since start): [0.54391572 1.00691028]\n",
"Epoch 7 (18 seconds since start): [0.53935373 1.01147227]\n",
"Epoch 8 (20 seconds since start): [0.54198379 1.0088422 ]\n",
"Epoch 9 (22 seconds since start): [0.53921175 1.01161424]\n",
"Epoch 10 (24 seconds since start): [0.52035009 0.99275258]\n"
]
},
{
Expand All @@ -442,10 +442,10 @@
" nfev: 20\n",
" nit: 10\n",
" success: True\n",
" x: array([0.52284507, 1.00474271])"
" x: array([0.52035009, 0.99275258])"
]
},
"execution_count": 18,
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -467,14 +467,14 @@
},
{
"cell_type": "code",
"execution_count": 19,
"execution_count": 17,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Testing loss: -0.011108393780887127\n"
"Testing loss: -0.011193843558430672\n"
]
}
],
Expand All @@ -484,7 +484,7 @@
},
{
"cell_type": "code",
"execution_count": 20,
"execution_count": 18,
"metadata": {},
"outputs": [
{
Expand All @@ -499,7 +499,7 @@
"source": [
"print(\"Is Alice who loves Bob rich?\")\n",
"print(\"Yes, she is.\"\n",
" if evaluate(F(result.x), ['Alice who loves Bob is rich.'])[0] > .5 + delta\n",
" if evaluate(result.x, ['Alice who loves Bob is rich.'])[0] > .5 + delta\n",
" else \"No, she isn't.\")"
]
}
Expand Down
Loading

0 comments on commit 4fd416a

Please sign in to comment.