diff --git a/.travis.yml b/.travis.yml index 0004ae99..b374ab1c 100644 --- a/.travis.yml +++ b/.travis.yml @@ -5,6 +5,7 @@ cache: pip before_install: - sudo add-apt-repository -y ppa:ubuntu-toolchain-r/test - sudo apt-get update + - sudo apt-get install pandoc - pip install -r requirements.txt - pip install -r requirements.dev.txt - pip install coveralls diff --git a/docs/requirements.txt b/docs/requirements.txt index 2404a798..f385eec8 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -3,3 +3,5 @@ sphinx-jinja sphinxcontrib-bibtex sphinx-rtd-theme recommonmark +nbsphinx +pandoc diff --git a/docs/source/README.md b/docs/source/README.md new file mode 120000 index 00000000..fe840054 --- /dev/null +++ b/docs/source/README.md @@ -0,0 +1 @@ +../../README.md \ No newline at end of file diff --git a/docs/source/README_files b/docs/source/README_files new file mode 120000 index 00000000..81dd36e3 --- /dev/null +++ b/docs/source/README_files @@ -0,0 +1 @@ +../../README_files/ \ No newline at end of file diff --git a/docs/source/advanced.rst b/docs/source/advanced.rst new file mode 100644 index 00000000..46c20e86 --- /dev/null +++ b/docs/source/advanced.rst @@ -0,0 +1,37 @@ + + + +============================ +Advanced Usage: Semirings +============================ + +All of the distributional code is implemented through a series of +semiring objects. These are passed through dynamic programming +backends to compute the distributions. + + +Standard Semirings +=================== + +.. autoclass:: torch_struct.LogSemiring +.. autoclass:: torch_struct.StdSemiring +.. autoclass:: torch_struct.MaxSemiring + +Higher-Order Semirings +========================= +.. autoclass:: torch_struct.EntropySemiring + +Sampling Semirings +=================== + +.. autoclass:: torch_struct.SampledSemiring +.. autoclass:: torch_struct.MultiSampledSemiring + + +Dynamic Programming +=================== + +.. autoclass:: torch_struct.LinearChain +.. autoclass:: torch_struct.SemiMarkov +.. autoclass:: torch_struct.DepTree +.. autoclass:: torch_struct.CKY diff --git a/docs/source/conf.py b/docs/source/conf.py index eb69279d..2531b65a 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -30,7 +30,9 @@ 'sphinx.ext.napoleon', 'sphinxcontrib.jinja', 'sphinxcontrib.bibtex', - 'sphinx.ext.intersphinx' + 'sphinx.ext.intersphinx', + 'recommonmark', + 'nbsphinx' ] diff --git a/docs/source/index.rst b/docs/source/index.rst index 71bab2da..b8526bd3 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -1,135 +1,19 @@ - +================= PyTorch-Struct -========================================== +================= .. toctree:: - :maxdepth: 2 + :maxdepth: 3 :caption: Contents: -Introduction -============ - -A library for structured prediction. - - - -Distributional Interface -======================== - - -The main interface is through a structured distribution objects. Each -of these implement a conditional random field over a class of -structures. Roughly, these represent specialized softmax's over -exponentially sized spaces. Each distribution object takes in -log_potentials (generalized logits) and can return properties of the -distribution. The properties of interest are, - -* Partition (e.g. logsumexp) -* Marginals (e.g. softmax) -* Argmax -* Entropy -* Samples -* to_event / from_event (adapters) - - -.. autoclass:: torch_struct.StructDistribution - :members: - -Linear Chain --------------- - -.. autoclass:: torch_struct.LinearChainCRF - - -Semi-Markov --------------- - -.. autoclass:: torch_struct.SemiMarkovCRF - - -Dependency Tree ----------------- - - -.. autoclass:: torch_struct.DependencyCRF - - -Binary Tree --------------- - -.. autoclass:: torch_struct.TreeCRF - -Context-Free Grammar ---------------------- - -.. autoclass:: torch_struct.SentCFG - - - - - -Networks -=========== - -Common structured networks. - - -.. autoclass:: torch_struct.networks.TreeLSTM - -.. autoclass:: torch_struct.networks.NeuralCFG - -.. autoclass:: torch_struct.networks.SpanLSTM - - -Data -==== - -Datasets for common structured prediction tasks. - -.. autoclass:: torch_struct.data.ConllXDataset -.. autoclass:: torch_struct.data.ListOpsDataset - - -Advanced Usage: Semirings -========================= - -All of the distributional code is implemented through a series of -semiring objects. These are passed through dynamic programming -backends to compute the distributions. - - -Standard Semirings ------------------- - -.. autoclass:: torch_struct.LogSemiring -.. autoclass:: torch_struct.StdSemiring -.. autoclass:: torch_struct.MaxSemiring - -Higher-Order Semirings ----------------------- -.. autoclass:: torch_struct.EntropySemiring - -Sampling Semirings ----------------------- - -.. autoclass:: torch_struct.SampledSemiring -.. autoclass:: torch_struct.MultiSampledSemiring - - -Dynamic Programming -------------------- - -.. autoclass:: torch_struct.LinearChain -.. autoclass:: torch_struct.SemiMarkov -.. autoclass:: torch_struct.DepTree -.. autoclass:: torch_struct.CKY - + README + model + networks + advanced + refs -References -========== -.. bibliography:: refs.bib Indices and tables diff --git a/docs/source/model.ipynb b/docs/source/model.ipynb new file mode 100644 index 00000000..acac731d --- /dev/null +++ b/docs/source/model.ipynb @@ -0,0 +1,522 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Model" + ] + }, + { + "cell_type": "raw", + "metadata": { + "raw_mimetype": "text/restructuredtext" + }, + "source": [ + ".. toctree:: \n", + " :maxdepth: 2 \n", + " :caption: Contents: \n" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "import torch_struct\n", + "import torch\n", + "import matplotlib.pyplot as plt" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Linear Chain" + ] + }, + { + "cell_type": "raw", + "metadata": { + "raw_mimetype": "text/restructuredtext" + }, + "source": [ + ".. autoclass:: torch_struct.LinearChainCRF " + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWoAAABiCAYAAAB5/Jk6AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAF5UlEQVR4nO3dz4vd5R3F8XOcMYmJ1NYfm/xAs1DLUPzFkNoKLoyQSKVuDbSLImSjbSwFsf0biuhCCkFjFxVdRBciodOFghsRU01bY6qEVJ1JLEaltVRKjD0u7r18RxmdW7zfPJ/Ofb9Wc78Xnnt4nrmH732+94eTCABQ13mtAwAAvhpFDQDFUdQAUBxFDQDFUdQAUBxFDQDFzfYx6DqvzwZt6mPosV11zcdNH3/kzT9vbB2hjAprwnp0KqwHOm8tfqL3P/zUK93nPt5H/Q1fnO9658TH/V8snDrS9PFHdm2+rnWEMiqsCevRqbAe6OzYtajDf/rPikXN1gcAFEdRA0BxFDUAFEdRA0BxFDUAFEdRA0BxFDUAFEdRA0BxFDUAFEdRA0BxFDUAFDdWUdvebfsN28dt3993KABAZ9Witj0j6WFJt0mak7TH9lzfwQAAA+OcUe+QdDzJiSRnJD0p6Y5+YwEARsYp6i2SFpfdXhoeAwCcAxP74QDbeyXtlaQN4svZAWBSxjmjPilp27LbW4fHPifJ/iTzSebP1/pJ5QOAqTdOUb8s6Urb222vk3SnpGf6jQUAGFl16yPJWdv3SFqQNCPpQJKjvScDAEgac486ySFJh3rOAgBYAZ9MBIDiKGoAKI6iBoDiKGoAKI6iBoDiKGoAKI6iBoDiKGoAKI6iBoDiKGoAKI6iBoDiJvZ91Mtddc3HWlg40sfQY9u1+bqmjz+ycKrtPFRSYU1Yj06F9UDnzXzwpfdxRg0AxVHUAFAcRQ0AxVHUAFAcRQ0AxVHUAFAcRQ0AxVHUAFAcRQ0AxVHUAFAcRQ0AxVHUAFDcqkVt+4Dt92y/di4CAQA+b5wz6t9K2t1zDgDAl1i1qJO8IOnDc5AFALAC9qgBoLiJFbXtvbYP2z58+oNPJzUsAEy9iRV1kv1J5pPMX3bJzKSGBYCpx9YHABQ3ztvznpD0oqSrbS/Zvqv/WACAkVV/3DbJnnMRBACwMrY+AKA4ihoAiqOoAaA4ihoAiqOoAaA4ihoAiqOoAaA4ihoAiqOoAaA4ihoAiqOoAaA4ihoAinOSyQ9qn5b09tcY4lJJ708ozv875qLDXHSYi85amYvLk1y20h29FPXXZftwkvnWOSpgLjrMRYe56EzDXLD1AQDFUdQAUFzVot7fOkAhzEWHuegwF501Pxcl96gBAJ2qZ9QAgKFyRW17t+03bB+3fX/rPK3Y3mb7eduv2z5qe1/rTK3ZnrH9qu1nW2dpyfY3bR+0/Vfbx2x/r3WmVmz/fPj8eM32E7Y3tM7Uh1JFbXtG0sOSbpM0J2mP7bm2qZo5K+kXSeYk3Sjp7imei5F9ko61DlHAQ5J+n+Tbkq7VlM6J7S2SfiZpPsl3JM1IurNtqn6UKmpJOyQdT3IiyRlJT0q6o3GmJpK8m+SV4d//0uDJuKVtqnZsb5X0A0mPtM7Sku2LJN0s6VFJSnImyT/apmpqVtIFtmclbZR0qnGeXlQr6i2SFpfdXtIUl9OI7SskXS/ppbZJmnpQ0n2S/ts6SGPbJZ2W9NhwG+gR25tah2ohyUlJv5b0jqR3Jf0zyR/apupHtaLGF9i+UNJTku5N8lHrPC3Yvl3Se0n+2DpLAbOSbpD0myTXS/q3pKm8lmP7Wxq84t4uabOkTbZ/1DZVP6oV9UlJ25bd3jo8NpVsn69BST+e5OnWeRq6SdIPbb+lwXbYLbZ/1zZSM0uSlpKMXl0d1KC4p9Gtkv6W5HSSTyQ9Len7jTP1olpRvyzpStvbba/T4MLAM40zNWHbGuxDHkvyQOs8LSX5ZZKtSa7Q4H/iuSRr8sxpNUn+LmnR9tXDQzslvd4wUkvvSLrR9sbh82Wn1uiF1dnWAZZLctb2PZIWNLiCeyDJ0caxWrlJ0o8l/cX2keGxXyU51DATavippMeHJzMnJP2kcZ4mkrxk+6CkVzR4l9SrWqOfUuSTiQBQXLWtDwDAF1DUAFAcRQ0AxVHUAFAcRQ0AxVHUAFAcRQ0AxVHUAFDcZwZTXGTwGrZDAAAAAElFTkSuQmCC\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "batch, N, C = 3, 10, 2\n", + "def show_chain(chain):\n", + " plt.imshow(chain.detach().sum(-1).transpose(0, 1))\n", + "\n", + "# batch, N, z_n, z_n_1\n", + "log_potentials = torch.rand(batch, N, C, C)\n", + "dist = torch_struct.LinearChainCRF(log_potentials)\n", + "show_chain(dist.argmax[0])" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWoAAABiCAYAAAB5/Jk6AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAGTElEQVR4nO3dTYzUdx3H8fenu7QUMJTGJkbAFg2hIT6kumlQEg/WA41GEuOhJHogTfDQJx8SUz16NkZjGhNssQcbesBqGkXxoIkX0xShiaW0CUEpIFrQihWkdOHrYWcztAF3THf4/WTfrxMzm/z3k3923zv7nx0mVYUkqV/XtR4gSfrvDLUkdc5QS1LnDLUkdc5QS1LnDLUkdW5yLAddsrQWLb95HIce2bIVZ5t+/lmvH7mh9QRqso+fx+dvar0AqLReAMDik9OtJ/CetadaTwDghky0nsDBf69oPYHX/3qa6dNnL/sFOpZQL1p+M2u2fmUchx7Zxs/ub/r5Zx354ntbT+Dcu5a0ngDAkc3tI3nduT5+aK37wautJ/C9nz/WegIA71u0rPUENjz3udYTOPDg41f8WB9ftZKkKzLUktQ5Qy1JnTPUktQ5Qy1JnTPUktQ5Qy1JnTPUktQ5Qy1JnTPUktQ5Qy1JnRsp1Ek2JXkpyaEkD497lCRpaM5QJ5kAHgHuBtYDW5KsH/cwSdKMUR5R3wkcqqrDVXUeeBLYPN5ZkqRZo4R6JXD0ktvHBvdJkq6CeXsyMcm2JHuT7J0+e2a+DitJC94ooT4OrL7k9qrBfW9SVduraqqqpiaXLJ2vfZK04I0S6meBtUnWJLkeuAd4eryzJEmz5nwrrqqaTnI/sAeYAHZU1YGxL5MkASO+Z2JV7QZ2j3mLJOkyfGWiJHXOUEtS5wy1JHXOUEtS5wy1JHXOUEtS5wy1JHXOUEtS5wy1JHXOUEtS5wy1JHVupP/r43918caLnPvg2XEcemQvfvMDTT//rK07f9p6Avv+dWvrCQCc2DPVegIXFlfrCQDc+5NftJ7AAx/p442aLpz6W+sJvLHtltYTqFevnGMfUUtS5wy1JHXOUEtS5wy1JHXOUEtS5wy1JHXOUEtS5wy1JHXOUEtS5wy1JHXOUEtS5wy1JHVuzlAn2ZHklSTPX41BkqQ3G+UR9ePApjHvkCRdwZyhrqrfAn+/ClskSZfhNWpJ6ty8hTrJtiR7k+y98NqZ+TqsJC148xbqqtpeVVNVNTXxjqXzdVhJWvC89CFJnRvlz/N2Ar8D1iU5luTe8c+SJM2a881tq2rL1RgiSbo8L31IUucMtSR1zlBLUucMtSR1zlBLUucMtSR1zlBLUucMtSR1zlBLUucMtSR1zlBLUucMtSR1LlU1/wdNTgJH3sYh3gmcmqc5/+88F0OeiyHPxdC1ci5urapbLveBsYT67Uqyt6qmWu/ogediyHMx5LkYWgjnwksfktQ5Qy1Jnes11NtbD+iI52LIczHkuRi65s9Fl9eoJUlDvT6iliQNdBfqJJuSvJTkUJKHW+9pJcnqJL9J8kKSA0kear2ptSQTSfYn+VnrLS0luSnJriQvJjmY5KOtN7WS5MuD74/nk+xMsrj1pnHoKtRJJoBHgLuB9cCWJOvbrmpmGvhqVa0HNgD3LeBzMesh4GDrER34LvDLqrod+BAL9JwkWQk8CExV1fuBCeCetqvGo6tQA3cCh6rqcFWdB54ENjfe1ERVnaiqfYN/v8bMN+PKtqvaSbIK+BTwaOstLSVZDnwceAygqs5X1T/armpqErgxySSwBPhz4z1j0VuoVwJHL7l9jAUcp1lJbgPuAJ5pu6Sp7wBfAy62HtLYGuAk8MPBZaBHkyxtPaqFqjoOfAt4GTgBnK6qX7VdNR69hVpvkWQZ8GPgS1X1z9Z7WkjyaeCVqvp96y0dmAQ+DHy/qu4AzgAL8rmcJCuY+Y17DfBuYGmSz7ddNR69hfo4sPqS26sG9y1ISRYxE+knquqp1nsa2gh8JsmfmLkc9okkP2o7qZljwLGqmv3tahcz4V6IPgn8sapOVtUbwFPAxxpvGoveQv0ssDbJmiTXM/PEwNONNzWRJMxchzxYVd9uvaelqvp6Va2qqtuY+Zr4dVVdk4+c5lJVfwGOJlk3uOsu4IWGk1p6GdiQZMng++UurtEnVidbD7hUVU0nuR/Yw8wzuDuq6kDjWa1sBL4A/CHJc4P7vlFVuxtuUh8eAJ4YPJg5DGxtvKeJqnomyS5gHzN/JbWfa/RVir4yUZI619ulD0nSWxhqSeqcoZakzhlqSeqcoZakzhlqSeqcoZakzhlqSercfwA7zWreE1M3WwAAAABJRU5ErkJggg==\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "show_chain(dist.marginals[0])" + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWoAAABiCAYAAAB5/Jk6AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAF1klEQVR4nO3dzYtd9R3H8c+nM8aYSFufNnmgZqGWofjEEJ/ARVNIpKJbA3YhhWzqQ0UQ9W8Q0YUIQaMLRRfRhUjoFKrgRoJRgxpTJaTWTKKYKq1i0Bj9uLj3cicydq54T35f575fqznnwrkfvpz74czv3AcnEQCgrl+0DgAA+P8oagAojqIGgOIoagAojqIGgOIoagAobrqLg67w6Vmp1V0cemQXXnys6fMPvPfmqtYRsADnxRCzqOVLfaHj+cqLPeYu3kf9S5+dK7xp7Mf9MeaO7G36/AOb11zaOgIW4LwYYha17M4/9Fk+XbSoWfoAgOIoagAojqIGgOIoagAojqIGgOIoagAojqIGgOIoagAojqIGgOIoagAojqIGgOJGKmrbW2y/a/uA7Xu6DgUAGFqyqG1PSXpY0nWSZiRttT3TdTAAQM8oV9QbJR1IcjDJcUnPSLqx21gAgIFRinqtpEMLtuf7+wAAp8DYfjjA9jZJ2yRppfgicAAYl1GuqA9LWr9ge11/30mSbE8ym2T2NJ0+rnwAMPFGKepXJV1ge4PtFZJukvR8t7EAAANLLn0kOWH7VklzkqYk7Uiyr/NkAABJI65RJ9klaVfHWQAAi+CTiQBQHEUNAMVR1ABQHEUNAMVR1ABQHEUNAMVR1ABQHEUNAMVR1ABQHEUNAMVR1ABQ3Ni+j3qhCy8+prm5vV0cemSb11za9PkH5o60nQNOxnkxxCxq2bj52A8+xhU1ABRHUQNAcRQ1ABRHUQNAcRQ1ABRHUQNAcRQ1ABRHUQNAcRQ1ABRHUQNAcRQ1ABRHUQNAcUsWte0dtj+2/fapCAQAONkoV9RPSNrScQ4AwA9YsqiTvCzp01OQBQCwCNaoAaC4sRW17W2299jec/STb8Z1WACYeGMr6iTbk8wmmT3vnKlxHRYAJh5LHwBQ3Chvz3ta0iuSLrI9b/vP3ccCAAws+eO2SbaeiiAAgMWx9AEAxVHUAFAcRQ0AxVHUAFAcRQ0AxVHUAFAcRQ0AxVHUAFAcRQ0AxVHUAFAcRQ0AxVHUAFCck4z/oPZRSf/+CYc4V9J/xhTn545ZDDGLIWYxtFxm8Zsk5y32QCdF/VPZ3pNktnWOCpjFELMYYhZDkzALlj4AoDiKGgCKq1rU21sHKIRZDDGLIWYxtOxnUXKNGgAwVPWKGgDQV66obW+x/a7tA7bvaZ2nFdvrbb9k+x3b+2zf0TpTa7anbL9h+4XWWVqy/WvbO23/0/Z+21e1ztSK7Tv7r4+3bT9te2XrTF0oVdS2pyQ9LOk6STOSttqeaZuqmROS7koyI+lKSX+Z4FkM3CFpf+sQBTwk6W9JfivpEk3oTGyvlXS7pNkkv5M0Jemmtqm6UaqoJW2UdCDJwSTHJT0j6cbGmZpI8mGS1/t/f67ei3Ft21Tt2F4n6Y+SHm2dpSXbv5J0raTHJCnJ8ST/bZuqqWlJZ9ielrRK0pHGeTpRrajXSjq0YHteE1xOA7bPl3SZpN1tkzT1oKS7JX3bOkhjGyQdlfR4fxnoUdurW4dqIclhSfdL+kDSh5L+l+TvbVN1o1pR43tsnynpWUl/TfJZ6zwt2L5e0sdJXmudpYBpSZdLeiTJZZK+kDSR93Jsn6Xef9wbJK2RtNr2zW1TdaNaUR+WtH7B9rr+volk+zT1SvqpJM+1ztPQNZJusP2+esthv7f9ZNtIzcxLmk8y+O9qp3rFPYn+IOlfSY4m+VrSc5KubpypE9WK+lVJF9jeYHuFejcGnm+cqQnbVm8dcn+SB1rnaSnJvUnWJTlfvXPixSTL8sppKUk+knTI9kX9XZskvdMwUksfSLrS9qr+62WTlumN1enWARZKcsL2rZLm1LuDuyPJvsaxWrlG0p8kvWV7b3/ffUl2NcyEGm6T9FT/YuagpFsa52kiyW7bOyW9rt67pN7QMv2UIp9MBIDiqi19AAC+h6IGgOIoagAojqIGgOIoagAojqIGgOIoagAojqIGgOK+A95AYIBMilsJAAAAAElFTkSuQmCC\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "event = dist.to_event(torch.tensor([[0, 1, 0, 1, 1, 1, 0, 1, 0, 1, 1]]), 2)\n", + "show_chain(event[0])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## HMM" + ] + }, + { + "cell_type": "raw", + "metadata": { + "raw_mimetype": "text/restructuredtext" + }, + "source": [ + ".. autoclass:: torch_struct.HMM" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Semi-Markov" + ] + }, + { + "cell_type": "raw", + "metadata": { + "raw_mimetype": "text/restructuredtext" + }, + "source": [ + ".. autoclass:: torch_struct.SemiMarkovCRF" + ] + }, + { + "cell_type": "code", + "execution_count": 64, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWoAAABACAYAAAAzmD0HAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAHXklEQVR4nO3df+xVdR3H8eerL78WmYI4Q2AlztxoS2PfkTVzbjQE1qSca7C2KN2cK7b8ozU2N+f6z1r9UXM1K6Y1p6wfFms4JGvrH0GJfUFIha+MJoRg0qByidi7P87n4uVy7v0exznnfvjyemx333PP+dxzXvfDh/f3nh/fcxURmJlZvt437ABmZjaYC7WZWeZcqM3MMudCbWaWORdqM7PMTWlipdM0PWYws4lVt+qjH3/zvNexb/f7a0gyOUyW/pws78Py8l/+w6l4S2XL1MTleR/U7Piklta+3rZt+fvYea/j1qtuqCHJ5DBZ+nOyvA/Ly/Z4hpNxvLRQ+9CHmVnmXKjNzDLnQm1mljkXajOzzFUq1JKWS3pZ0rik9U2HMjOzd01YqCWNAA8BK4BFwBpJi5oOZmZmhSqfqJcA4xFxICJOAU8Aq5qNZWZmHVUK9Tzg1a7nh9K8s0i6W9IOSTve5q268pmZXfRqO5kYEQ9HxGhEjE5lel2rNTO76FUp1IeBBV3P56d5ZmbWgiqF+nngWklXS5oGrAY2NRvLzMw6JrwpU0SclrQO2AKMABsiYm/jyczMDKh497yI2AxsbjiLmZmV8F8mmpllzoXazCxzjdyPevT6GfHclgUTNxzA9+s1a9Zkuq/2+b6XHN6H70dtZnYBc6E2M8ucC7WZWeZcqM3MMlflNqcbJB2TtKeNQGZmdrYqn6gfAZY3nMPMzPqYsFBHxJ+B4y1kMTOzErUdo+6+H/Xrb7xT12rNzC56jdyP+orLR+parZnZRc9XfZiZZc6F2swsc1Uuz3sceBa4TtIhSXc1H8vMzDqqfHHAmjaCmJlZOR/6MDPLnAu1mVnmXKjNzDLXyBcHSHod+NuAJnOAf9S+4fo5Z30uhIzgnHVzzuo+HBFXlC1opFBPRNKOiBhtfcPvkXPW50LICM5ZN+eshw99mJllzoXazCxzwyrUDw9pu++Vc9bnQsgIzlk356zBUI5Rm5lZdT70YWaWORdqM7PMNVqoJS2X9LKkcUnrS5ZPl7QxLd8u6SNN5umTcYGkP0n6q6S9kr5R0uYWSSckjaXH/W3nTDkOSnohZdhRslySfpD6c7ekxS3nu66rj8YknZR0b0+bofRl2Xd/Spotaauk/ennrD6vXZva7Je0dgg5vyvppfRv+qSky/q8duD4aCHnA5IOd/3bruzz2oF1oYWcG7syHpQ01ue1rfXnhCKikQcwArwCLASmAbuART1tvgb8OE2vBjY2lWdAzrnA4jR9CbCvJOctwO/bzlaS9SAwZ8DylcBTgIAbge1DzDoCvEZxEf/Q+xK4GVgM7Oma9x1gfZpeDzxY8rrZwIH0c1aantVyzmXAlDT9YFnOKuOjhZwPAN+sMC4G1oWmc/Ys/x5w/7D7c6JHk5+olwDjEXEgIk4BTwCretqsAh5N078ClkpSg5nOERFHImJnmv4X8CIwr80MNVoF/DwK24DLJM0dUpalwCsRMegvVFsT5d/92T3+HgU+X/LSW4GtEXE8Iv4JbKXBL3suyxkRT0fE6fR0GzC/qe1X1ac/q6hSF2ozKGeqNV8EHm9q+3VpslDPA17ten6IcwvgmTZpIJ4ALm8w00Dp0MsngO0liz8laZekpyR9rNVg7wrgaUl/kXR3yfIqfd6W1fT/D5BDXwJcGRFH0vRrwJUlbXLqU4A7Kfaaykw0PtqwLh2i2dDnUFJO/fkZ4GhE7O+zPIf+BHwy8QxJHwB+DdwbESd7Fu+k2IW/Hvgh8Nu28yU3RcRiYAXwdUk3DynHQJKmAbcBvyxZnEtfniWKfd2sr1WVdB9wGnisT5Nhj48fAdcANwBHKA4r5GwNgz9ND7s/z2iyUB8GFnQ9n5/mlbaRNAW4FHijwUylJE2lKNKPRcRvepdHxMmI+Hea3gxMlTSn5ZhExOH08xjwJMVuZLcqfd6GFcDOiDjauyCXvkyOdg4NpZ/HStpk0aeSvgJ8DvhS+qVyjgrjo1ERcTQi3omI/wE/6bP9XPpzCnA7sLFfm2H3Z7cmC/XzwLWSrk6fsFYDm3rabAI6Z9HvAP7YbxA2JR2n+hnwYkR8v0+bD3WOnUtaQtFvrf5CkTRT0iWdaYoTTHt6mm0Cvpyu/rgRONG1a9+mvp9UcujLLt3jby3wu5I2W4BlkmalXfllaV5rJC0HvgXcFhFv9mlTZXw0qud8yBf6bL9KXWjDZ4GXIuJQ2cIc+vMsTZ6ppLgKYR/FWd770rxvUww4gBkUu8fjwHPAwrbPpgI3Uezy7gbG0mMlcA9wT2qzDthLcYZ6G/DpIeRcmLa/K2Xp9Gd3TgEPpf5+ARgdQs6ZFIX30q55Q+9Lil8cR4C3KY6L3kVxPuQZYD/wB2B2ajsK/LTrtXemMToOfHUIOccpjut2xmfnSqmrgM2DxkfLOX+Rxt1uiuI7tzdnen5OXWgzZ5r/SGdMdrUdWn9O9PCfkJuZZc4nE83MMudCbWaWORdqM7PMuVCbmWXOhdrMLHMu1GZmmXOhNjPL3P8Bo8rrNAqbhkwAAAAASUVORK5CYII=\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "batch, N, C, K = 3, 20, 2, 6\n", + "def show_sm(chain):\n", + " plt.imshow(chain.detach().sum(1).sum(-1).transpose(0, 1))\n", + "\n", + "# batch, N, K, z_n, z_n_1\n", + "log_potentials = torch.rand(batch, N, K, C, C)\n", + "log_potentials[:, :, :3] = -1e9\n", + "dist = torch_struct.SemiMarkovCRF(log_potentials)\n", + "show_sm(dist.argmax[0])" + ] + }, + { + "cell_type": "code", + "execution_count": 65, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWoAAABACAYAAAAzmD0HAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAIJElEQVR4nO3dfYwUdx3H8ffn7ri2UCxHKS1PETBNE/qHlRDSmkqa1CAQU9QYAzERbZOmURIbYwymSW38rxr9Q9NoqpJW0xTiQ5U0NC1Wjf8ULBIebQtXggJSqIXw0JbH+/rHzMFyzO5t+9uZHZLPK9nc7M5vZj/32999b3ZmdlYRgZmZ1VdPtwOYmVlrLtRmZjXnQm1mVnMu1GZmNedCbWZWc31lrHRgYk9MnZ626v07ru9Qmg/vzMyxyevoO6n0dbx3IWn582N7kzMMpa+C3rPp64C0s5SiJ/316LmQfqbUUG96jkj86+3pwOvRczZtbAJEX/r24rnxaf3Zf/Dd5AypTvMuZ+NM4S9SSqGeOr2PNc9PTlrHt2be1aE0H97u781LXsfkv41JXsfErceSlj96x0ByhtMD6YXlI/vT/6iVWCTPjU0vCv0nh5LXcWZCeo7TE9PWMb4Dr8e4/5xKXsfpyekbRP9dkFbKZn33leQMqTbFy03nedeHmVnNuVCbmdWcC7WZWc25UJuZ1VxbhVrSIklvSBqUtKrsUGZmdsmohVpSL/AEsBiYAyyXNKfsYGZmlmlni3o+MBgReyPiLLAGWFpuLDMzG9ZOoZ4G7G+4fyB/7DKSHpS0WdLmY0fTzzM1M7NMxw4mRsSTETEvIuYNJJ6Ib2Zml7RTUQ8CMxruT88fMzOzCrRTqF8FbpU0S1I/sAxYV24sMzMbNuoH5CPivKSVwItAL7A6InaVnszMzIA2L8oUEeuB9SVnMTOzAj7qZ2ZWcy7UZmY1V8r1qK9TD7f3X1fGqiulU+lXy9dQ+kXmU6/BfM3x9OsO957twMXdx6Zf03rskbTf5f0b01/TMe8nr4Jxh84lr0NDadc6P39t+usx1J9eQo7PSr9m+7lbziSvo868RW1mVnMu1GZmNedCbWZWcy7UZmY1185lTldLOiJpZxWBzMzscu1sUT8FLCo5h5mZNTFqoY6IvwNHK8hiZmYFOraPuvF61G+/k37erpmZZUq5HvVNHfhQgZmZZXzWh5lZzblQm5nVXDun5z0LvALcJumApAfKj2VmZsPa+eKA5VUEMTOzYt71YWZWcy7UZmY150JtZlZziki/sP0VK5XeBv7doskk4H8df+LOc87OuRoygnN2mnO276MRcVPRjFIK9WgkbY6IeZU/8QfknJ1zNWQE5+w05+wM7/owM6s5F2ozs5rrVqF+skvP+0E5Z+dcDRnBOTvNOTugK/uozcysfd71YWZWcy7UZmY1V2qhlrRI0huSBiWtKph/jaS1+fxNkmaWmadJxhmS/irpX5J2SfpmQZt7JB2XtDW/PVp1zjzHPkk78gybC+ZL0k/y/twuaW7F+W5r6KOtkk5IenhEm670ZdF3f0qaKGmDpD35z4Emy67I2+yRtKILOX8o6fX8NX1O0oQmy7YcHxXkfEzSwYbXdkmTZVvWhQpyrm3IuE/S1ibLVtafo4qIUm5AL/AmMBvoB7YBc0a0+Trw83x6GbC2rDwtck4B5ubT44HdBTnvAZ6vOltB1n3ApBbzlwAvAALuBDZ1MWsv8BbZSfxd70tgATAX2Nnw2A+AVfn0KuDxguUmAnvznwP59EDFORcCffn040U52xkfFeR8DPh2G+OiZV0oO+eI+T8CHu12f452K3OLej4wGBF7I+IssAZYOqLNUuDpfPp3wL2SVGKmK0TEoYjYkk+fBF4DplWZoYOWAr+OzEZggqQpXcpyL/BmRLT6hGplovi7PxvH39PA5woW/QywISKORsQxYAMlftlzUc6IeCkizud3NwLTy3r+djXpz3a0Uxc6plXOvNZ8CXi2rOfvlDIL9TRgf8P9A1xZAC+2yQficeDGEjO1lO96+QSwqWD2XZK2SXpB0u2VBrskgJck/VPSgwXz2+nzqiyj+R9AHfoS4OaIOJRPvwXcXNCmTn0KcD/Zu6Yio42PKqzMd9GsbrIrqU79+SngcETsaTK/Dv0J+GDiRZKuB34PPBwRJ0bM3kL2Fv7jwE+BP1adL3d3RMwFFgPfkLSgSzlaktQP3Af8tmB2XfryMpG91631uaqSHgHOA880adLt8fEz4GPAHcAhst0Kdbac1lvT3e7Pi8os1AeBGQ33p+ePFbaR1AfcALxTYqZCksaQFelnIuIPI+dHxImIOJVPrwfGSJpUcUwi4mD+8wjwHNnbyEbt9HkVFgNbIuLwyBl16cvc4eFdQ/nPIwVtatGnkr4KfBb4cv5P5QptjI9SRcThiLgQEUPAL5o8f136sw/4ArC2WZtu92ejMgv1q8CtkmblW1jLgHUj2qwDho+ifxH4S7NBWJZ8P9WvgNci4sdN2twyvO9c0nyyfqv0H4qkcZLGD0+THWDaOaLZOuAr+dkfdwLHG97aV6nplkod+rJB4/hbAfypoM2LwEJJA/lb+YX5Y5WRtAj4DnBfRLzXpE0746NUI46HfL7J87dTF6rwaeD1iDhQNLMO/XmZMo9Ukp2FsJvsKO8j+WPfJxtwANeSvT0eBP4BzK76aCpwN9lb3u3A1vy2BHgIeChvsxLYRXaEeiPwyS7knJ0//7Y8y3B/NuYU8ETe3zuAeV3IOY6s8N7Q8FjX+5LsH8ch4BzZftEHyI6HvAzsAf4MTMzbzgN+2bDs/fkYHQS+1oWcg2T7dYfH5/CZUlOB9a3GR8U5f5OPu+1kxXfKyJz5/SvqQpU588efGh6TDW271p+j3fwRcjOzmvPBRDOzmnOhNjOrORdqM7Oac6E2M6s5F2ozs5pzoTYzqzkXajOzmvs/vAoFeuHv9dsAAAAASUVORK5CYII=\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "show_sm(dist.marginals[0])" + ] + }, + { + "cell_type": "code", + "execution_count": 71, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWoAAABiCAYAAAB5/Jk6AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAF1klEQVR4nO3dzYuddxnG8e/lTNOYiO/d5AWbRa0Esa0MsVpwYYSkKHbbgC5EyMZqFUGqf4OILooQanRhaRexiyLBCCq4kdjYRm0aKyHW5qXSaNFKpcbo7WLOMNMyOkM7J7/bnO9nNec5hx8XN8+5eOb3nDOTqkKS1NcbRgeQJP1vFrUkNWdRS1JzFrUkNWdRS1JzFrUkNTc/jUU35frazNZpLK3X4N3v+/voCAD87tdbRkfQCp4XvbzMS1yuf2S15zKNz1G/OW+vD2Tvhq+r1+bYxZOjIwCwb9utoyNoBc+LXo7Xj3mxXli1qN36kKTmLGpJas6ilqTmLGpJas6ilqTmLGpJas6ilqTmLGpJas6ilqTmLGpJas6ilqTm1lXUSfYneTrJmST3TTuUJGnZmkWdZA64H7gT2A0cSLJ72sEkSYvWc0W9BzhTVWer6jLwMHDXdGNJkpasp6i3A+dWPD4/OSZJugo27B8HJDkIHATYjH8IXJI2ynquqC8AO1c83jE59gpVdaiqFqpq4Tqu36h8kjTz1lPUjwE3JdmVZBNwN/DodGNJkpasufVRVVeS3AMcA+aAw1V1aurJJEnAOveoq+oocHTKWSRJq/CbiZLUnEUtSc1Z1JLUnEUtSc1Z1JLUnEUtSc1Z1JLUnEUtSc1Z1JLUnEUtSc1Z1JLUXKpqwxdduGVz/eLYzrVfOAP2bbt1dAQ1dOziydER2pybHWbRwZ595zjxq5ez2nNeUUtScxa1JDVnUUtScxa1JDVnUUtScxa1JDVnUUtScxa1JDVnUUtScxa1JDVnUUtScxa1JDW3ZlEnOZzk+SRPXo1AkqRXWs8V9XeB/VPOIUn6L9Ys6qr6GfDCVcgiSVqFe9SS1NyGFXWSg0lOJDlx6c//2qhlJWnmbVhRV9WhqlqoqoUb3jG3UctK0sxz60OSmlvPx/MeAn4O3JzkfJLPTD+WJGnJ/FovqKoDVyOIJGl1bn1IUnMWtSQ1Z1FLUnMWtSQ1Z1FLUnMWtSQ1Z1FLUnMWtSQ1Z1FLUnMWtSQ1Z1FLUnMWtSQ1l6ra+EWTS8AfXscS7wT+tEFx/t85i2XOYpmzWHatzOJdVXXDak9MpahfryQnqmphdI4OnMUyZ7HMWSybhVm49SFJzVnUktRc16I+NDpAI85imbNY5iyWXfOzaLlHLUla1vWKWpI00a6ok+xP8nSSM0nuG51nlCQ7k/w0yVNJTiW5d3Sm0ZLMJXkiyQ9GZxkpyVuTHEny2ySnk3xwdKZRknxx8v54MslDSTaPzjQNrYo6yRxwP3AnsBs4kGT32FTDXAG+VFW7gduBz87wLJbcC5weHaKBbwI/rKr3ALcwozNJsh34PLBQVe8F5oC7x6aajlZFDewBzlTV2aq6DDwM3DU40xBV9VxVPT75+W8svhm3j001TpIdwMeAB0ZnGSnJW4APA98GqKrLVfWXsamGmgfemGQe2AJcHJxnKroV9Xbg3IrH55nhclqS5EbgNuD42CRDfQP4MvDv0UEG2wVcAr4z2QZ6IMnW0aFGqKoLwNeAZ4HngL9W1Y/GppqObkWtV0nyJuD7wBeq6sXReUZI8nHg+ar65egsDcwD7we+VVW3AS8BM3kvJ8nbWPyNexewDdia5JNjU01Ht6K+AOxc8XjH5NhMSnIdiyX9YFU9MjrPQHcAn0jyDIvbYR9J8r2xkYY5D5yvqqXfro6wWNyz6KPA76vqUlX9E3gE+NDgTFPRragfA25KsivJJhZvDDw6ONMQScLiPuTpqvr66DwjVdVXqmpHVd3I4jnxk6q6Jq+c1lJVfwTOJbl5cmgv8NTASCM9C9yeZMvk/bKXa/TG6vzoACtV1ZUk9wDHWLyDe7iqTg2ONcodwKeA3yQ5OTn21ao6OjCTevgc8ODkYuYs8OnBeYaoquNJjgCPs/gpqSe4Rr+l6DcTJam5blsfkqRXsaglqTmLWpKas6glqTmLWpKas6glqTmLWpKas6glqbn/AIMhXHvVOLuoAAAAAElFTkSuQmCC\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "# Use -1 for segments.\n", + "event = dist.to_event(torch.tensor([[0, 1, -1, 1, -1, -1, 0, 1, 0, 1, 1]]), (2, 6))\n", + "show_sm(event[0])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Alignment" + ] + }, + { + "cell_type": "raw", + "metadata": { + "raw_mimetype": "text/restructuredtext" + }, + "source": [ + ".. autoclass:: torch_struct.AlignmentCRF" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Dependency Trees" + ] + }, + { + "cell_type": "raw", + "metadata": { + "raw_mimetype": "text/restructuredtext" + }, + "source": [ + ".. autoclass:: torch_struct.DependencyCRF" + ] + }, + { + "cell_type": "code", + "execution_count": 76, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPUAAAD4CAYAAAA0L6C7AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAJzklEQVR4nO3dzYtdBx2H8efrTNqaqvV1kxdMFr4QRFsZarXgoim0vmA3LipU0E02vlQRpLrxHxDRhQih6sZiF7ELkWIqVhduotM0qElUSq1NmhajohWFJtWfixkhJk3umZt7PDM/ng8EMnduT79M5sm5c3LnTqoKSX28bOoBkhbLqKVmjFpqxqilZoxaamZ5jIO+/rVLtWf3toUf93e/3L7wY2prevPb/zn1hEk9deo8f/rLv/JS7xsl6j27t/Hzw7sXftw7dty48GNqazp8+NjUEyZ18x2nLvs+H35LzRi11IxRS80YtdSMUUvNGLXUzKCok9yZ5LdJnkhy39ijJM1vZtRJloCvA+8D9gEfSbJv7GGS5jPkTH0z8ERVPVlV54AHgbvGnSVpXkOi3glc+PSV0+u3/Y8kB5KsJlk9++d/LWqfpA1a2IWyqjpYVStVtfKG1y0t6rCSNmhI1M8AFz6Re9f6bZI2oSFR/wJ4U5K9Sa4B7ga+P+4sSfOa+V1aVfVikk8Ch4El4FtVdXz0ZZLmMuhbL6vqYeDhkbdIWgCfUSY1Y9RSM0YtNWPUUjNGLTWTMX6W1qvy2npX9i/8uJLWHKkf83z95SVfTdQztdSMUUvNGLXUjFFLzRi11IxRS80YtdSMUUvNGLXUjFFLzRi11IxRS80YtdSMUUvNGLXUjFFLzRi11IxRS80YtdSMUUvNGLXUzKCfpaWNO3zm2NQTNuSOHTdOPUEL4plaasaopWaMWmrGqKVmjFpqxqilZoxaamZm1El2J/lJkhNJjie59/8xTNJ8hjz55EXgc1V1NMkrgceS/KiqToy8TdIcZp6pq+rZqjq6/vu/AyeBnWMPkzSfDT1NNMke4CbgyEu87wBwAOA6ti9gmqR5DL5QluQVwPeAz1TV8xe/v6oOVtVKVa1s49pFbpS0AYOiTrKNtaAfqKqHxp0k6WoMufod4JvAyar6yviTJF2NIWfqW4GPArclObb+6/0j75I0p5kXyqrqZ0D+D1skLYDPKJOaMWqpGaOWmjFqqRlfeHAkvpCfpuKZWmrGqKVmjFpqxqilZoxaasaopWaMWmrGqKVmjFpqxqilZoxaasaopWaMWmrGqKVmjFpqxqilZoxaasaopWaMWmrGqKVmjFpqxqilZoxaasaopWaMWmrGqKVmjFpqxqilZoxaamZw1EmWkjye5AdjDpJ0dTZypr4XODnWEEmLMSjqJLuADwD3jztH0tUaeqb+KvB54N+Xu0OSA0lWk6ye54WFjJO0cTOjTvJB4I9V9diV7ldVB6tqpapWtnHtwgZK2pghZ+pbgQ8leQp4ELgtyXdGXSVpbjOjrqovVNWuqtoD3A08WlX3jL5M0lz8d2qpmeWN3Lmqfgr8dJQlkhbCM7XUjFFLzRi11IxRS80YtdTMhq5+d3X4zLGFH/OOHTcu/JjSEJ6ppWaMWmrGqKVmjFpqxqilZoxaasaopWaMWmrGqKVmjFpqxqilZoxaasaopWaMWmrGqKVmjFpqxqilZoxaasaopWaMWmrGqKVmfDVRfOVP9eKZWmrGqKVmjFpqxqilZoxaasaopWaMWmpmUNRJXp3kUJLfJDmZ5N1jD5M0n6FPPvka8MOq+nCSa4DtI26SdBVmRp3kBuC9wMcAquoccG7cWZLmNeTh917gLPDtJI8nuT/J9RffKcmBJKtJVs/zwsKHShpmSNTLwDuBb1TVTcA/gPsuvlNVHayqlapa2ca1C54paaghUZ8GTlfVkfW3D7EWuaRNaGbUVfUccCrJW9Zv2g+cGHWVpLkNvfr9KeCB9SvfTwIfH2+SpKsxKOqqOgasjLxF0gL4jDKpGaOWmjFqqRmjlpoxaqmZLfVqoofPHBvluL6a6Hgf27H4Z3Z5nqmlZoxaasaopWaMWmrGqKVmjFpqxqilZoxaasaopWaMWmrGqKVmjFpqxqilZoxaasaopWaMWmrGqKVmjFpqxqilZoxaamZLvfCgLzY3Hj+2fXimlpoxaqkZo5aaMWqpGaOWmjFqqRmjlpoZFHWSzyY5nuTXSb6b5Lqxh0maz8yok+wEPg2sVNXbgCXg7rGHSZrP0Iffy8DLkywD24Ez402SdDVmRl1VzwBfBp4GngX+VlWPXHy/JAeSrCZZPc8Li18qaZAhD79fA9wF7AV2ANcnuefi+1XVwapaqaqVbVy7+KWSBhny8Pt24PdVdbaqzgMPAe8Zd5akeQ2J+mngliTbkwTYD5wcd5akeQ35mvoIcAg4Cvxq/b85OPIuSXMa9P3UVfUl4Esjb5G0AD6jTGrGqKVmjFpqxqilZoxaasaopWaMWmrGqKVmjFpqxqilZoxaasaopWaMWmrGqKVmjFpqxqilZoxaasaopWaMWmrGqKVmjFpqxqilZoxaasaopWaMWmrGqKVmjFpqJlW1+IMmZ4E/DLjr64E/LXzAeLbS3q20FbbW3s2w9Y1V9YaXescoUQ+VZLWqViYbsEFbae9W2gpba+9m3+rDb6kZo5aamTrqrfbD67fS3q20FbbW3k29ddKvqSUt3tRnakkLZtRSM5NFneTOJL9N8kSS+6baMUuS3Ul+kuREkuNJ7p160xBJlpI8nuQHU2+5kiSvTnIoyW+SnEzy7qk3XUmSz65/Hvw6yXeTXDf1potNEnWSJeDrwPuAfcBHkuybYssALwKfq6p9wC3AJzbx1gvdC5ycesQAXwN+WFVvBd7BJt6cZCfwaWClqt4GLAF3T7vqUlOdqW8GnqiqJ6vqHPAgcNdEW66oqp6tqqPrv/87a590O6dddWVJdgEfAO6fesuVJLkBeC/wTYCqOldVf5121UzLwMuTLAPbgTMT77nEVFHvBE5d8PZpNnkoAEn2ADcBR6ZdMtNXgc8D/556yAx7gbPAt9e/VLg/yfVTj7qcqnoG+DLwNPAs8LeqemTaVZfyQtlASV4BfA/4TFU9P/Wey0nyQeCPVfXY1FsGWAbeCXyjqm4C/gFs5usrr2HtEeVeYAdwfZJ7pl11qamifgbYfcHbu9Zv25SSbGMt6Aeq6qGp98xwK/ChJE+x9mXNbUm+M+2kyzoNnK6q/z7yOcRa5JvV7cDvq+psVZ0HHgLeM/GmS0wV9S+ANyXZm+Qa1i42fH+iLVeUJKx9zXeyqr4y9Z5ZquoLVbWrqvaw9nF9tKo23dkEoKqeA04lecv6TfuBExNOmuVp4JYk29c/L/azCS/sLU/xP62qF5N8EjjM2hXEb1XV8Sm2DHAr8FHgV0mOrd/2xap6eMJNnXwKeGD9L/cngY9PvOeyqupIkkPAUdb+VeRxNuFTRn2aqNSMF8qkZoxaasaopWaMWmrGqKVmjFpqxqilZv4DGAwaxAHGxvoAAAAASUVORK5CYII=\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "batch, N, N = 3, 10, 10\n", + "def show_deps(tree):\n", + " plt.imshow(tree.detach())\n", + "\n", + "# batch, N, z_n, z_n_1\n", + "log_potentials = torch.rand(batch, N, N)\n", + "dist = torch_struct.DependencyCRF(log_potentials)\n", + "show_deps(dist.argmax[0])" + ] + }, + { + "cell_type": "code", + "execution_count": 49, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPUAAAD4CAYAAAA0L6C7AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAL80lEQVR4nO3dX2zddRnH8c+np11Zx4CBKNJOtjnULCQGUglK4gUYFSGQGE1GAonc7Eb+aEgIGBNuvSBELohmAYwJREwmF8QQUKP458LFsmFgK8iYsL/IYML+s3V9vGhN5kZ3fj37fvm1j+9XQrK2Z88eur77Oz07/dYRIQB59LW9AICyiBpIhqiBZIgaSIaogWT6awztLF4U/RcsKT63f7+Lz5Sk/sPHyw89NlF+piRpvv1rRZ2/M01Olp/ZqXSNO15+18OTB3Q0jnzoO7dK1P0XLNFFP7yz+NxP/LFTfKYkLXnpveIzvfvd4jMlSVHhg7mmTp2/szh0uPjMvkVDxWdK0uTBQ8Vn/vXA0zO+jbvfQDJEDSRD1EAyRA0kQ9RAMkQNJNMoattft/2q7S227629FIDedY3adkfSw5Kuk7RK0s22V9VeDEBvmlypr5S0JSK2RsRRSU9KuqnuWgB61STqYUnbT3h5x/Tr/oftNbbHbI8d33+w1H4AZqnYA2URsTYiRiNitLN4UamxAGapSdQ7JS094eWR6dcBmIOaRP03SZfaXm57gaTVkmZ+NjmAVnX9Lq2ImLB9u6TnJHUkPRYRm6pvBqAnjb71MiKekfRM5V0AFMAzyoBkiBpIhqiBZIgaSIaogWSqHDzoo9bC7QPF504O1DlJ8/XV5U8+XfnYB8VnSpLe2VtlbByps2+tn9XWd845xWfWOMxQqvO+Pd37lSs1kAxRA8kQNZAMUQPJEDWQDFEDyRA1kAxRA8kQNZAMUQPJEDWQDFEDyRA1kAxRA8kQNZAMUQPJEDWQDFEDyRA1kAxRA8kQNZBMldNEB98+rGUPvVx87t4bVxWfKUlLf3es+MzXb7uo+ExJWvHLwSpzfbzOqZ/atqvK2Dh0qMrcKvr80f5xH+mfBqA6ogaSIWogGaIGkiFqIBmiBpIhaiCZrlHbXmr7D7Y3295k+66PYjEAvWny5JMJSXdHxAbbiyW9YPu3EbG58m4AetD1Sh0RuyNiw/Sv90salzRcezEAvZnV00RtL5N0uaT1H/K2NZLWSNJZXlRgNQC9aPxAme2zJf1K0vciYt/Jb4+ItRExGhGjC/rOKrkjgFloFLXtAU0F/UREPFV3JQBnosmj35b0qKTxiHiw/koAzkSTK/XVkm6VdI3tF6f/+0blvQD0qOsDZRHxF0kf7TeEAugZzygDkiFqIBmiBpIhaiCZKgcP1nL+xn9XmXtg5bnFZ658dHfxmZL0+o8WV5m74v4jVeb2nX9elbnH33q7+My+wTqHOk7WOCQxZj4okis1kAxRA8kQNZAMUQPJEDWQDFEDyRA1kAxRA8kQNZAMUQPJEDWQDFEDyRA1kAxRA8kQNZAMUQPJEDWQDFEDyRA1kAxRA8kQNZBMldNEY3KyygmKnb3vF58pSYvH9hefuX90pPhMSfr03buqzN33heEqcxeP760yt/PxC4vPnHy3zq59Q0PFZ/rQzNdjrtRAMkQNJEPUQDJEDSRD1EAyRA0kQ9RAMo2jtt2xvdH2r2suBODMzOZKfZek8VqLACijUdS2RyRdL+mRuusAOFNNr9Q/lnSPpMmZbmB7je0x22PH4oMiywGYva5R275B0tsR8cLpbhcRayNiNCJGBzxYbEEAs9PkSn21pBttvyHpSUnX2H686lYAetY16oi4LyJGImKZpNWSfh8Rt1TfDEBP+HdqIJlZfT91RDwv6fkqmwAogis1kAxRA8kQNZAMUQPJEDWQTJXTRGuJw4frDO4v/24YfPdo8ZmS9N5VdU79PHfTv6vM3fnV8qd+StIn/1T+ZNlORPGZknT8rX8Vnxkx4zO2uVID2RA1kAxRA8kQNZAMUQPJEDWQDFEDyRA1kAxRA8kQNZAMUQPJEDWQDFEDyRA1kAxRA8kQNZAMUQPJEDWQDFEDyRA1kAxRA8nUOU00pJiYKD528sDB4jMlyYPlf572wGu7is+UpPO2uMrcf92wosrckXVvVJm785vLis8cfnpf8ZmS5IULy888MPP1mCs1kAxRA8kQNZAMUQPJEDWQDFEDyRA1kEyjqG2fZ3ud7Vdsj9v+Yu3FAPSm6ZNPHpL0bER8y/YCSUMVdwJwBrpGbftcSV+W9B1Jioijkur88GUAZ6zJ3e/lkvZI+pntjbYfsb3o5BvZXmN7zPbYMX1QfFEAzTSJul/SFZJ+EhGXSzoo6d6TbxQRayNiNCJGB1T+udQAmmkS9Q5JOyJi/fTL6zQVOYA5qGvUEfGWpO22Pzv9qmslba66FYCeNX30+w5JT0w/8r1V0m31VgJwJhpFHREvShqtvAuAAnhGGZAMUQPJEDWQDFEDyRA1kEyd00Qtub/CaNf5HOROp/zMGv//kuKDOk/BXfLakSpzX7v9kipzL334zeIzd3y7zq7Dj1d43x7iNFHg/wZRA8kQNZAMUQPJEDWQDFEDyRA1kAxRA8kQNZAMUQPJEDWQDFEDyRA1kAxRA8kQNZAMUQPJEDWQDFEDyRA1kAxRA8lUOR3PAwvUGb64+NwYXFB8piT5SPnD/Cbf2Vt8piSpr87n4QVvvltl7srH6hzA+I/bP1V85kXrjxefKUkTKyu08PeBGd/GlRpIhqiBZIgaSIaogWSIGkiGqIFkiBpIplHUtr9ve5Ptl23/wvZZtRcD0JuuUdselnSnpNGIuExSR9Lq2osB6E3Tu9/9khba7pc0JGlXvZUAnImuUUfETkkPSNomabek9yPiNyffzvYa22O2x44eP1R+UwCNNLn7vUTSTZKWS7pY0iLbt5x8u4hYGxGjETG6oDNUflMAjTS5+/0VSf+MiD0RcUzSU5K+VHctAL1qEvU2SVfZHrJtSddKGq+7FoBeNfmaer2kdZI2SHpp+vesrbwXgB41+mbXiLhf0v2VdwFQAM8oA5IhaiAZogaSIWogGaIGkqly1OPkYL+OrLiw+NzBHe8VnylJcehI8Zl9i88uPlOSJvftrzJXE3VO0tT7dfb9zE8nis/c/s2R4jMl6c8P/7z4zCu/NvPpr1ypgWSIGkiGqIFkiBpIhqiBZIgaSIaogWSIGkiGqIFkiBpIhqiBZIgaSIaogWSIGkiGqIFkiBpIhqiBZIgaSIaogWSIGkiGqIFkHBHlh9p7JL3Z4KYfk/RO8QXqmU/7zqddpfm171zY9ZKI+NAje6tE3ZTtsYgYbW2BWZpP+86nXaX5te9c35W730AyRA0k03bU8+2H18+nfefTrtL82ndO79rq19QAymv7Sg2gMKIGkmktattft/2q7S22721rj25sL7X9B9ubbW+yfVfbOzVhu2N7o+1ft73L6dg+z/Y626/YHrf9xbZ3Oh3b35/+OHjZ9i9sn9X2TidrJWrbHUkPS7pO0ipJN9te1cYuDUxIujsiVkm6StJ35/CuJ7pL0njbSzTwkKRnI+Jzkj6vObyz7WFJd0oajYjLJHUkrW53q1O1daW+UtKWiNgaEUclPSnpppZ2Oa2I2B0RG6Z/vV9TH3TD7W51erZHJF0v6ZG2dzkd2+dK+rKkRyUpIo5GRJ0fQl5Ov6SFtvslDUna1fI+p2gr6mFJ2094eYfmeCiSZHuZpMslrW93k65+LOkeSZNtL9LFckl7JP1s+kuFR2wvanupmUTETkkPSNomabek9yPiN+1udSoeKGvI9tmSfiXpexGxr+19ZmL7BklvR8QLbe/SQL+kKyT9JCIul3RQ0lx+fGWJpu5RLpd0saRFtm9pd6tTtRX1TklLT3h5ZPp1c5LtAU0F/UREPNX2Pl1cLelG229o6suaa2w/3u5KM9ohaUdE/PeezzpNRT5XfUXSPyNiT0Qck/SUpC+1vNMp2or6b5Iutb3c9gJNPdjwdEu7nJZta+prvvGIeLDtfbqJiPsiYiQilmnq/fr7iJhzVxNJioi3JG23/dnpV10raXOLK3WzTdJVtoemPy6u1Rx8YK+/jT80IiZs3y7pOU09gvhYRGxqY5cGrpZ0q6SXbL84/bofRMQzLe6UyR2Snpj+5L5V0m0t7zOjiFhve52kDZr6V5GNmoNPGeVpokAyPFAGJEPUQDJEDSRD1EAyRA0kQ9RAMkQNJPMfJruyoZcO4JwAAAAASUVORK5CYII=\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "show_deps(dist.marginals[0])" + ] + }, + { + "cell_type": "code", + "execution_count": 89, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPUAAAD4CAYAAAA0L6C7AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAJ9ElEQVR4nO3d3YtchR3G8edpjIlGRdqmYl5ovGgFkRrLkl4ohVps4gvaSwW9EnJTIdKC6KX/gHjTm0WlLVqDoIJY2zRoRASNbuJqjdEQxGKiEKuIptKY6NOLHWGjSfbM5Jw9xx/fDyzu7gzjwybfnJ3ZnTlOIgB1fK/vAQDaRdRAMUQNFEPUQDFEDRRzRhc3eqaXZblWdHHT6MhPf/Z53xOOs+/1s/ueMGj/03/1RY74RJd1EvVyrdAv/Osubhod2bZttu8Jx9m4an3fEwZtZ5456WV8+w0UQ9RAMUQNFEPUQDFEDRRD1EAxRA0UQ9RAMUQNFEPUQDFEDRRD1EAxRA0U0yhq25tsv217v+27uh4FYHILRm17iaQ/SrpG0iWSbrZ9SdfDAEymyZF6g6T9Sd5J8oWkrZJu7HYWgEk1iXq1pPfmfXxg9Lnj2N5se8b2zFEdaWsfgDG19kBZkukkU0mmlmpZWzcLYExNoj4oae28j9eMPgdggJpE/Yqkn9i+yPaZkm6S9GS3swBMasEXHkxyzPbtkrZJWiLpwSR7Ol8GYCKNXk00ydOSnu54C4AW8BtlQDFEDRRD1EAxRA0UQ9RAMUQNFEPUQDFEDRRD1EAxRA0UQ9RAMUQNFOMkrd/o1GXL8/K2tQtfcZFsXLW+7wlAq3bmGX2aj32iyzhSA8UQNVAMUQPFEDVQDFEDxRA1UAxRA8UQNVAMUQPFEDVQDFEDxRA1UAxRA8UQNVAMUQPFLBi17QdtH7L9xmIMAnB6mhyp/yRpU8c7ALRkwaiTPC/p40XYAqAFrd2ntr3Z9oztmQ8/+rKtmwUwptaiTjKdZCrJ1MofLGnrZgGMiUe/gWKIGiimyY+0HpH0oqSLbR+wfVv3swBM6oyFrpDk5sUYAqAdfPsNFEPUQDFEDRRD1EAxRA0UQ9RAMUQNFEPUQDFEDRRD1EAxRA0UQ9RAMQs+oWMS+14/WxtXre/ipiey7f3Zvid8y5C+PqiFIzVQDFEDxRA1UAxRA8UQNVAMUQPFEDVQDFEDxRA1UAxRA8UQNVAMUQPFEDVQDFEDxTQ5Qd5a2ztsv2l7j+0tizEMwGSaPJ/6mKQ/JNlt+1xJu2xvT/Jmx9sATGDBI3WSD5LsHr3/maS9klZ3PQzAZMZ65RPb6yRdLmnnCS7bLGmzJC3X2S1MAzCJxg+U2T5H0mOS7kjy6TcvTzKdZCrJ1FIta3MjgDE0itr2Us0F/XCSx7udBOB0NHn025IekLQ3yb3dTwJwOpocqa+QdKukq2zPjt6u7XgXgAkt+EBZkhckeRG2AGgBv1EGFEPUQDFEDRRD1EAxRA0UQ9RAMUQNFEPUQDFEDRRD1EAxRA0UQ9RAMWO98sl31cZV6/ue8C3b3p/te8Jxhvg1Gpoh/Zlt2Pj5SS/jSA0UQ9RAMUQNFEPUQDFEDRRD1EAxRA0UQ9RAMUQNFEPUQDFEDRRD1EAxRA0UQ9RAMU3Oernc9su2X7O9x/Y9izEMwGSaPJ/6iKSrkhwenaf6Bdt/T/JSx9sATKDJWS8j6fDow6Wjt3Q5CsDkGt2ntr3E9qykQ5K2J9l5gutstj1je+aojrS9E0BDjaJO8mWS9ZLWSNpg+9ITXGc6yVSSqaVa1vZOAA2N9eh3kk8k7ZC0qZs5AE5Xk0e/V9o+f/T+WZKulvRW18MATKbJo98XSvqz7SWa+0fg0SRPdTsLwKSaPPr9uqTLF2ELgBbwG2VAMUQNFEPUQDFEDRRD1EAxRA0UQ9RAMUQNFEPUQDFEDRRD1EAxRA0U0+RZWujAxlXr+54waNven+17wrcM6c9sXz466WUcqYFiiBoohqiBYogaKIaogWKIGiiGqIFiiBoohqiBYogaKIaogWKIGiiGqIFiiBoopnHUoxPPv2qbk+MBAzbOkXqLpL1dDQHQjkZR214j6TpJ93c7B8Dpanqkvk/SnZK+OtkVbG+2PWN75qiOtDIOwPgWjNr29ZIOJdl1quslmU4ylWRqqZa1NhDAeJocqa+QdIPtdyVtlXSV7Yc6XQVgYgtGneTuJGuSrJN0k6Rnk9zS+TIAE+Hn1EAxY71EcJLnJD3XyRIAreBIDRRD1EAxRA0UQ9RAMUQNFEPUQDFEDRRD1EAxRA0UQ9RAMUQNFEPUQDFEDRQz1rO0gMWycdX6vid8Z3GkBoohaqAYogaKIWqgGKIGiiFqoBiiBoohaqAYogaKIWqgGKIGiiFqoBiiBoohaqCYRk+9HJ2b+jNJX0o6lmSqy1EAJjfO86l/leQ/nS0B0Aq+/QaKaRp1JP3T9i7bm090Bdubbc/YnjmqI+0tBDCWpt9+X5nkoO0fSdpu+60kz8+/QpJpSdOSdJ6/n5Z3Amio0ZE6ycHRfw9JekLShi5HAZjcglHbXmH73K/fl/QbSW90PQzAZJp8+32BpCdsf339vyb5R6erAExswaiTvCPpskXYAqAF/EgLKIaogWKIGiiGqIFiiBoohqiBYogaKIaogWKIGiiGqIFiiBoohqiBYpy0/3oGtj+U9O8WbuqHkob0umjsObWh7ZGGt6mtPT9OsvJEF3QSdVtszwzplUvZc2pD2yMNb9Ni7OHbb6AYogaKGXrU030P+Ab2nNrQ9kjD29T5nkHfpwYwvqEfqQGMiaiBYgYZte1Ntt+2vd/2XQPY86DtQ7YH8dLIttfa3mH7Tdt7bG/pec9y2y/bfm20554+93zN9hLbr9p+qu8t0tyJJm3/y/as7ZnO/j9Du09te4mkfZKulnRA0iuSbk7yZo+bfinpsKS/JLm0rx3z9lwo6cIku0evyb5L0m/7+hp57vWjVyQ5bHuppBckbUnyUh975u36vaQpSeclub7PLaM970qa6vpEk0M8Um+QtD/JO0m+kLRV0o19DhqdYujjPjfMl+SDJLtH738maa+k1T3uSZLDow+Xjt56PVrYXiPpOkn397mjD0OMerWk9+Z9fEA9/oUdOtvrJF0uaWfPO5bYnpV0SNL2JL3ukXSfpDslfdXzjvkWPNFkG4YYNRqyfY6kxyTdkeTTPrck+TLJeklrJG2w3dvdFNvXSzqUZFdfG07iyiQ/l3SNpN+N7ta1bohRH5S0dt7Ha0afwzyj+66PSXo4yeN97/lakk8k7ZC0qccZV0i6YXQfdqukq2w/1OMeSYt3oskhRv2KpJ/Yvsj2mZJukvRkz5sGZfTA1AOS9ia5dwB7Vto+f/T+WZp7kPOtvvYkuTvJmiTrNPf359kkt/S1R1rcE00OLuokxyTdLmmb5h4AejTJnj432X5E0ouSLrZ9wPZtfe7R3JHoVs0dgWZHb9f2uOdCSTtsv665f5S3JxnEj5EG5AJJL9h+TdLLkv7W1YkmB/cjLQCnZ3BHagCnh6iBYogaKIaogWKIGiiGqIFiiBoo5v9vCGG1b/1T8gAAAABJRU5ErkJggg==\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "# Convert from 1-index standard format. (Head is 0)\n", + "event = dist.to_event(torch.tensor([[2, 3, 4, 1, 0, 4]]), None)\n", + "show_deps(event[0])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Non-Projective Dependency Trees" + ] + }, + { + "cell_type": "raw", + "metadata": { + "raw_mimetype": "text/restructuredtext" + }, + "source": [ + ".. autoclass:: torch_struct.NonProjectiveDependencyCRF" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 47, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPUAAAD4CAYAAAA0L6C7AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAMU0lEQVR4nO3db2yddRnG8etau2XdGNsEoqMbbhpElhkz0+gGCgZMREUxxpgpYMAX8w/gNCoBX+hr/wY1ik5ESVhEsy1qDAGMiNGIC90Yka2gCzi2OWQIDKyMUnv7ojWZG915dvb7+bR3vp+EZO3pbu60/e455/TpcxwRApDHjLYXAFAWUQPJEDWQDFEDyRA1kExvjaFzF86Khf19xef+48DJxWdK0sJTnys+86ln5hWfKUlR5SsmzT7w7ypzY0mduSPDs4rPnNFXZ1er/E+YXnjioEYP/ssvdVuVb5GF/X265qeri8+99ftvLz5Tkt77kd8Wn/mTn59XfKYkjZwyVmXuWd87WGVu3FD+H0xJevTeM4rPnLPi6eIzJam3p/zX7KF1N096G3e/gWSIGkiGqIFkiBpIhqiBZIgaSKZR1LYvsv2w7V22r6u9FIDudYzado+kb0t6h6Tlkj5oe3ntxQB0p8mR+o2SdkXEIxExIuk2SZfUXQtAt5pE3S9pz2Fv75143/+wvdb2oO3B4adGSu0H4DgVe6IsItZHxEBEDMx9WfnzcgE00yTqfZKWHPb24on3AZiCmkR9n6QzbS+zPUvSGkm/qLsWgG51/C2tiBi1fbWkOyX1SLo5InZU3wxAVxr96mVE3C7p9sq7ACiAM8qAZIgaSIaogWSIGkiGqIFkqlx48NndJ+muj72l+NxFz9e5ON5PF5xffOayr2wvPlOSNFbnwoPzfz2nytznPjC7ytzv/+7G4jOveuBDxWdK0qLryl+ldNfeyW/jSA0kQ9RAMkQNJEPUQDJEDSRD1EAyRA0kQ9RAMkQNJEPUQDJEDSRD1EAyRA0kQ9RAMkQNJEPUQDJEDSRD1EAyRA0kQ9RAMkQNJFPlaqLzzhjW+d/5Y43RVezedEHxmSOrzi4+U5IOvrrOa3+vOum3VeZuuKb8lVol6UvnX1x85kU/Hyo+U5LuePfq4jNHbpk8XY7UQDJEDSRD1EAyRA0kQ9RAMkQNJEPUQDIdo7a9xPZvbO+0vcP2uv/HYgC60+Tkk1FJn4mIbbbnSdpq+1cRsbPybgC60PFIHRH7I2LbxJ+fkzQkqb/2YgC6c1yPqW0vlbRS0paXuG2t7UHbg8NPj5TZDsBxaxy17ZMkbZL0qYh49sjbI2J9RAxExMDchXXOTwbQWaOobc/UeNAbImJz3ZUAnIgmz35b0g8kDUXE1+uvBOBENDlSnyvpckkX2N4+8d87K+8FoEsdf6QVEb+X5P/DLgAK4IwyIBmiBpIhaiAZogaSqXLhwb4ZI1rRt7f43Otv/XDxmZJ0aNFo8Zn73lrnBJwXXl5+V0n60eA5VeYu2h5V5s6/bbj4zAff3Fd8piQd+kL5z8FYz+S3caQGkiFqIBmiBpIhaiAZogaSIWogGaIGkiFqIBmiBpIhaiAZogaSIWogGaIGkiFqIBmiBpIhaiAZogaSIWogGaIGkiFqIBmiBpKpcjXRPU+fos9tvrz43L989DvFZ0rS67Z8qPjMBXfOKz5Tkp46e2aVuSsv3lll7tbdy6vMPfnq04vP3HfV/OIzJUln/Kv8zFljk97EkRpIhqiBZIgaSIaogWSIGkiGqIFkiBpIpnHUtnts32/7lzUXAnBijudIvU7SUK1FAJTRKGrbiyW9S9JNddcBcKKaHqlvkHStpEnPTbO91vag7cGx4fIvCA6gmY5R275Y0hMRsfVYHxcR6yNiICIGZsydW2xBAMenyZH6XEnvsf1XSbdJusD2rVW3AtC1jlFHxPURsTgilkpaI+nuiLis+mYAusLPqYFkjuv3qSPiHkn3VNkEQBEcqYFkiBpIhqiBZIgaSIaogWSqXE2093np1Aei+Nw37fh48ZmSNOMUF585+8nni8+UpLHevipzn760zpU01/zsnipzb3nV6uIzz7ziD8VnStJTV5bf9fFnJz8ec6QGkiFqIBmiBpIhaiAZogaSIWogGaIGkiFqIBmiBpIhaiAZogaSIWogGaIGkiFqIBmiBpIhaiAZogaSIWogGaIGkiFqIBmiBpKpcjXRnpExzdt9qPjcvw/MKT5Tkl6+Zbj4zE/csqn4TEn63H3vrzJ3bH6d1xT/0dZzqsx9zfqR4jN9d3/xmZI0uqn81WrjGIdjjtRAMkQNJEPUQDJEDSRD1EAyRA0kQ9RAMo2itr3A9kbbD9kesl3+ZfwAFNH05JNvSLojIt5ve5akOmeBADhhHaO2PV/SeZKukKSIGJFU/nQeAEU0ufu9TNIBST+0fb/tm2wfdU6h7bW2B20Pvvhi+dMuATTTJOpeSW+QdGNErJQ0LOm6Iz8oItZHxEBEDMycWec8YgCdNYl6r6S9EbFl4u2NGo8cwBTUMeqIeFzSHttnTbzrQkk7q24FoGtNn/2+RtKGiWe+H5F0Zb2VAJyIRlFHxHZJA5V3AVAAZ5QByRA1kAxRA8kQNZAMUQPJVLmaqF8Y1czdB4rPnXPGkuIzJemfS/qKz7zrmRXFZ0rSa64t/3mVpN3fXFBl7vLPPlll7jMDryg+c/+D5b8PJKl3SRSfOTZr8ts4UgPJEDWQDFEDyRA1kAxRA8kQNZAMUQPJEDWQDFEDyRA1kAxRA8kQNZAMUQPJEDWQDFEDyRA1kAxRA8kQNZAMUQPJEDWQTJULDyrGFIcOFR87Y7T8BdwkyXXGVjF888wqc0/7Wp2L7j166fwqc5fe+HDxmY+vPrP4TEma/+fyM/e/MPltHKmBZIgaSIaogWSIGkiGqIFkiBpIhqiBZBpFbfvTtnfYftD2j23Prr0YgO50jNp2v6RPShqIiBWSeiStqb0YgO40vfvdK6nPdq+kOZL+Vm8lACeiY9QRsU/SVyU9Jmm/pIMRcdeRH2d7re1B24MjY+VPEQXQTJO73wslXSJpmaTTJc21fdmRHxcR6yNiICIGZs3gITfQliZ3v98m6dGIOBARL0raLOmcumsB6FaTqB+TtMr2HNuWdKGkobprAehWk8fUWyRtlLRN0p8m/s76ynsB6FKj36eOiC9K+mLlXQAUwBllQDJEDSRD1EAyRA0kQ9RAMlWuJnpo0Ww9/NnyV2Y87b7iIyVJIye7+Myt31pZfKYkHXrfM1XmnlLriqpjdcZ+7N57i8/87pt7is+UpIe+vLj4zNFfT/4F40gNJEPUQDJEDSRD1EAyRA0kQ9RAMkQNJEPUQDJEDSRD1EAyRA0kQ9RAMkQNJEPUQDJEDSRD1EAyRA0kQ9RAMkQNJEPUQDJEDSTjiPKXkbR9QNLuBh96qqQniy9Qz3TadzrtKk2vfafCrq+MiNNe6oYqUTdlezAiBlpb4DhNp32n067S9Np3qu/K3W8gGaIGkmk76un24vXTad/ptKs0vfad0ru2+pgaQHltH6kBFEbUQDKtRW37ItsP295l+7q29ujE9hLbv7G90/YO2+va3qkJ2z2277f9y7Z3ORbbC2xvtP2Q7SHbq9ve6Vhsf3ri++BB2z+2PbvtnY7UStS2eyR9W9I7JC2X9EHby9vYpYFRSZ+JiOWSVkm6agrverh1kobaXqKBb0i6IyJeK+n1msI72+6X9ElJAxGxQlKPpDXtbnW0to7Ub5S0KyIeiYgRSbdJuqSlXY4pIvZHxLaJPz+n8W+6/na3OjbbiyW9S9JNbe9yLLbnSzpP0g8kKSJGIqLOC3CX0yupz3avpDmS/tbyPkdpK+p+SXsOe3uvpngokmR7qaSVkra0u0lHN0i6VtVe8r2YZZIOSPrhxEOFm2zPbXupyUTEPklflfSYpP2SDkbEXe1udTSeKGvI9kmSNkn6VEQ82/Y+k7F9saQnImJr27s00CvpDZJujIiVkoYlTeXnVxZq/B7lMkmnS5pr+7J2tzpaW1Hvk7TksLcXT7xvSrI9U+NBb4iIzW3v08G5kt5j+68af1hzge1b211pUnsl7Y2I/97z2ajxyKeqt0l6NCIORMSLkjZLOqflnY7SVtT3STrT9jLbszT+ZMMvWtrlmGxb44/5hiLi623v00lEXB8RiyNiqcY/r3dHxJQ7mkhSRDwuaY/tsybedaGknS2u1MljklbZnjPxfXGhpuATe71t/E8jYtT21ZLu1PgziDdHxI42dmngXEmXS/qT7e0T7/t8RNze4k6ZXCNpw8Q/7o9IurLlfSYVEVtsb5S0TeM/FblfU/CUUU4TBZLhiTIgGaIGkiFqIBmiBpIhaiAZogaSIWogmf8As1S6RUJMrskAAAAASUVORK5CYII=\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "batch, N, N = 3, 10, 10\n", + "def show_deps(tree):\n", + " plt.imshow(tree.detach())\n", + "\n", + "# batch, N, z_n, z_n_1\n", + "log_potentials = torch.rand(batch, N, N)\n", + "dist = torch_struct.NonProjectiveDependencyCRF(log_potentials)\n", + "show_deps(dist.marginals[0])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Binary Labeled Trees" + ] + }, + { + "cell_type": "raw", + "metadata": { + "raw_mimetype": "text/restructuredtext" + }, + "source": [ + " .. autoclass:: torch_struct.TreeCRF" + ] + }, + { + "cell_type": "code", + "execution_count": 105, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAQUAAAD4CAYAAADl7fPiAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAPdklEQVR4nO3df6yeZX3H8fdnBSEwQLBa+TU1wkiY0c41BTK2wFB+NER0IQ6ybLix1JmRzWRmYVsixP3jsjgSh0GrNuCiyKarNrGhNGwJkiBSSPmlIB3B0IJUrCsqRil+98e5a851+jz0Oc+Pc55zfL+S5rl/XOe+r5vTfLjv5756fVNVSNIBv7bYHZA0XQwFSQ1DQVLDUJDUMBQkNQ5b7A70cthxR9URq44bqG39cPBLOOz5nwzbpbH4zbe+OHDb7zx01AR7ouVq0L9jTz39Es/vfTm99k1lKByx6jh+6+PvG6jtS1957cDHXbnhniF7NB5bt+4YuO1FJ62eYE+0XA36d2ztRU/33efjg6TGSKGQ5OIkjyfZmeTaHvuPSHJbt//eJG8c5XySJm/oUEiyAvgEcAlwJnBlkjPnNLsa+GFVnQbcAPzzsOeTtDBGuVNYC+ysqier6ufAF4HL5rS5DLilW/4ScEGSnl9uSJoOo4TCycDsbyt2ddt6tqmq/cA+4DW9DpZkfZLtSbbv3zf4t/SSxmtqvmisqg1Vtaaq1hx2nK/jpMUySijsBk6dtX5Kt61nmySHAccBPxjhnJImbJRQuA84PcmbkrwKuALYPKfNZuCqbvly4L/Lf6stTbWhBy9V1f4k1wBbgRXAxqp6NMlHgO1VtRn4LPDvSXYCe5kJDklTLNP4P+5jc0KdlQsGartvy2kDH/e4dTuH7ZK0rNxbd/JC7e35JnBqvmiUNB0MBUkNQ0FSw1CQ1DAUJDUMBUkNQ0FSw1CQ1DAUJDUMBUmNqZy4dT7mM3R5UkOitz4z+ISsg3LiVi0W7xQkNQwFSQ1DQVLDUJDUMBQkNQwFSQ1DQVJjlApRpyb5nyTfSvJokr/p0ea8JPuS7Oj+fHi07kqatFEGL+0H/raqHkhyDHB/km1V9a057b5eVZeOcB5JC2joO4WqeraqHuiWfwR8m4MrRElaYsYyzLmrJv3bwL09dp+T5EHgGeBDVfVon2OsB9YDHMlkKkRNakg0DDbM2aHLWgpGDoUkvw58GfhgVb0wZ/cDwBuq6sdJ1gFfAU7vdZyq2gBsgJkp3kftl6ThjPT2IcnhzATC56vqv+bur6oXqurH3fIW4PAkK0c5p6TJGuXtQ5ipAPXtqvrXPm1ef6D0fJK13fmsJSlNsVEeH34X+BPg4SQHHqr/AfgNgKr6JDP1Iz+QZD/wU+AKa0lK022UWpJ3Az3LTs1qcyNw47DnkLTwHNEoqWEoSGoYCpIahoKkhqEgqbHkZ3OelPkMiT57y+WDHZPBjyktFu8UJDUMBUkNQ0FSw1CQ1DAUJDUMBUkNQ0FSw1CQ1DAUJDUc0TgGg45+fH79OQMfc+WGe4btjjQS7xQkNQwFSY2RQyHJU0ke7srCbe+xP0k+nmRnkoeSvH3Uc0qanHF9p3B+VT3fZ98lzNR6OB04C7ip+5Q0hRbi8eEy4HM14xvAq5OcuADnlTSEcYRCAXckub8r/TbXycDTs9Z30aPmZJL1SbYn2f4SPxtDtyQNYxyPD+dW1e4krwO2JXmsqu6a70EsGydNh5HvFKpqd/e5B9gErJ3TZDdw6qz1U7ptkqbQqLUkj05yzIFl4ELgkTnNNgN/2r2FOBvYV1XPjnJeSZMz6uPDKmBTVy7yMOALVXV7kr+EX5aO2wKsA3YCLwJ/NuI5JU1QprG047E5oc7KBYvdjUW19Zkdh27Uueik1RPsyfjN59rO3jHYpLgA31j9pYHbLrX/ZoMadCj941++gRe//3TPso+OaJTUMBQkNQwFSQ1DQVLDUJDUMBQkNQwFSQ1DQVLDUJDUMBQkNRzmvAxMapboSR1335bTBm476EzZmp97605eqL0Oc5Z0aIaCpIahIKlhKEhqGAqSGoaCpIahIKkxdCgkOaMrFXfgzwtJPjinzXlJ9s1q8+HRuyxpkoaeuLWqHgdWAyRZwcy07Zt6NP16VV067HkkLaxxPT5cAPxvVX13TMeTtEjGVWD2CuDWPvvOSfIg8Azwoap6tFejruTceoAjOWpM3frVMJ8hxvOZSRnmMaP0hsFnR3bo8nQbRyn6VwHvAv6zx+4HgDdU1duAfwO+0u84VbWhqtZU1ZrDOWLUbkka0jgeHy4BHqiq5+buqKoXqurH3fIW4PAkK8dwTkkTMo5QuJI+jw5JXp+ufFSStd35fjCGc0qakJG+U+jqR74TeP+sbbNLxl0OfCDJfuCnwBU1jf9WW9IvjRQKVfUT4DVztn1y1vKNwI2jnEPSwnJEo6SGoSCpYShIahgKkhqGgqTGuIY5a4m46KTBhyPPazZnBh9qrenmnYKkhqEgqWEoSGoYCpIahoKkhqEgqWEoSGoYCpIahoKkhqEgqeEwZ/U1n1mi9205bSJ9cObnheedgqTGQKGQZGOSPUkembXthCTbkjzRfR7f52ev6to8keSqcXVc0mQMeqdwM3DxnG3XAndW1enAnd16I8kJwHXAWcBa4Lp+4SFpOgwUClV1F7B3zubLgFu65VuAd/f40YuAbVW1t6p+CGzj4HCRNEVG+U5hVVU92y1/D1jVo83JwNOz1nd12yRNqbF80djVchipnkOS9Um2J9n+Ej8bR7ckDWGUUHguyYkA3eeeHm12A6fOWj+l23YQa0lK02GUUNgMHHibcBXw1R5ttgIXJjm++4Lxwm6bpCk16CvJW4F7gDOS7EpyNfBR4J1JngDe0a2TZE2SzwBU1V7gn4D7uj8f6bZJmlIDjWisqiv77LqgR9vtwF/MWt8IbByqd5IWnMOcNRbzGY48qSHRGg+HOUtqGAqSGoaCpIahIKlhKEhqGAqSGoaCpIahIKlhKEhqGAqSGg5z1oKb1JBoZ34eD+8UJDUMBUkNQ0FSw1CQ1DAUJDUMBUkNQ0FS45Ch0KeO5L8keSzJQ0k2JXl1n599KsnDSXYk2T7OjkuajEHuFG7m4FJv24C3VNVbge8Af/8KP39+Va2uqjXDdVHSQjpkKPSqI1lVd1TV/m71G8wUeZG0DIxjmPOfA7f12VfAHUkK+FRVbeh3kCTrgfUAR3LUGLql5WA+Q5e3PrNj4LYXnbR6mO78ShgpFJL8I7Af+HyfJudW1e4krwO2JXmsu/M4SBcYGwCOzQkj1aWUNLyh3z4keR9wKfDHXYHZg1TV7u5zD7AJWDvs+SQtjKFCIcnFwN8B76qqF/u0OTrJMQeWmakj+UivtpKmxyCvJHvVkbwROIaZR4IdST7ZtT0pyZbuR1cBdyd5EPgm8LWqun0iVyFpbA75nUKfOpKf7dP2GWBdt/wk8LaReidpwTmiUVLDUJDUMBQkNQwFSQ1DQVLD2Zy1bMxn6PKgs0RPaoboSc1SPehxX/7ru/vu805BUsNQkNQwFCQ1DAVJDUNBUsNQkNQwFCQ1DAVJDUNBUsMRjfqVNOgowflMBjsfv3P9OfNoPfiIxkGva0X9rO8+7xQkNQwFSY1hy8Zdn2R3Nz/jjiTr+vzsxUkeT7IzybXj7LikyRi2bBzADV05uNVVtWXuziQrgE8AlwBnAlcmOXOUzkqavKHKxg1oLbCzqp6sqp8DXwQuG+I4khbQKN8pXNNVnd6Y5Pge+08Gnp61vqvb1lOS9Um2J9n+Ev2/GZU0WcOGwk3Am4HVwLPAx0btSFVtqKo1VbXmcI4Y9XCShjRUKFTVc1X1clX9Avg0vcvB7QZOnbV+SrdN0hQbtmzcibNW30PvcnD3AacneVOSVwFXAJuHOZ+khXPIEY1d2bjzgJVJdgHXAeclWc1MqfmngPd3bU8CPlNV66pqf5JrgK3ACmBjVT06kauQNDbpUzB6UR2bE+qsXLDY3ZDmZT5Douczyewk3Ft38kLtTa99jmiU1DAUJDUMBUkNQ0FSw1CQ1DAUJDUMBUkNQ0FSw1CQ1DAUJDWczVkak/kMXd635bSB2w46Q/O4eKcgqWEoSGoYCpIahoKkhqEgqWEoSGoYCpIag8zRuBG4FNhTVW/ptt0GnNE1eTXwf1V10EvaJE8BPwJeBvZX1Zox9VvShAwyeOlm4Ebgcwc2VNUfHVhO8jFg3yv8/PlV9fywHZS0sA4ZClV1V5I39tqXJMB7gT8Yb7ckLZZRhzn/HvBcVT3RZ38BdyQp4FNVtaHfgZKsB9YDHMlRI3ZLmm7zGbr8/PpzBm57//U3DdRu7UUv9t03aihcCdz6CvvPrardSV4HbEvyWFew9iBdYGyAmSneR+yXpCEN/fYhyWHAHwK39WtTVbu7zz3AJnqXl5M0RUZ5JfkO4LGq2tVrZ5KjkxxzYBm4kN7l5SRNkUOGQlc27h7gjCS7klzd7bqCOY8OSU5KsqVbXQXcneRB4JvA16rq9vF1XdIkDPL24co+29/XY9szwLpu+UngbSP2T9ICc0SjpIahIKlhKEhqGAqSGoaCpIazOUtTbuWGewZue/a7Lx+o3bd/enPffd4pSGoYCpIahoKkhqEgqWEoSGoYCpIahoKkhqEgqWEoSGoYCpIaqZq+OVKTfB/47pzNK4HlWD9iuV4XLN9rWw7X9Yaqem2vHVMZCr0k2b4cK0wt1+uC5Xtty/W6DvDxQVLDUJDUWEqh0Le61BK3XK8Llu+1LdfrApbQdwqSFsZSulOQtAAMBUmNJREKSS5O8niSnUmuXez+jEuSp5I8nGRHku2L3Z9RJNmYZE+SR2ZtOyHJtiRPdJ/HL2Yfh9Hnuq5Psrv7ve1Ism4x+zhuUx8KSVYAnwAuAc4Erkxy5uL2aqzOr6rVy+C9983AxXO2XQvcWVWnA3d260vNzRx8XQA3dL+31VW1pcf+JWvqQ4GZStU7q+rJqvo58EXgskXuk+aoqruAvXM2Xwbc0i3fArx7QTs1Bn2ua1lbCqFwMvD0rPVd3bbloIA7ktyfZP1id2YCVlXVs93y95gpOrxcXJPkoe7xYsk9Fr2SpRAKy9m5VfV2Zh6N/irJ7y92hyalZt59L5f33zcBbwZWA88CH1vc7ozXUgiF3cCps9ZP6bYteVW1u/vcA2xi5lFpOXkuyYkA3eeeRe7PWFTVc1X1clX9Avg0y+z3thRC4T7g9CRvSvIq4Apg8yL3aWRJjk5yzIFl4ELgkVf+qSVnM3BVt3wV8NVF7MvYHAi6zntYZr+3qa8QVVX7k1wDbAVWABur6tFF7tY4rAI2JYGZ38MXqur2xe3S8JLcCpwHrEyyC7gO+CjwH0muZuafwr938Xo4nD7XdV6S1cw8Dj0FvH/ROjgBDnOW1FgKjw+SFpChIKlhKEhqGAqSGoaCpIahIKlhKEhq/D9fpEvjXBfZOAAAAABJRU5ErkJggg==\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "batch, N, NT = 3, 20, 3\n", + "def show_tree(tree):\n", + " t = tree.detach()\n", + " plt.imshow(t[ :, : , 0] + \n", + " 2 * t[ :,:, 1] +\n", + " 3 * t[ :,:, 2])\n", + "\n", + "# batch, N, z_n, z_n_1\n", + "log_potentials = torch.rand(batch, N, N, NT)\n", + "dist = torch_struct.TreeCRF(log_potentials)\n", + "show_tree(dist.argmax[0])" + ] + }, + { + "cell_type": "code", + "execution_count": 106, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAQUAAAD4CAYAAADl7fPiAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAASYklEQVR4nO3df5BdZX3H8fdnN78kBiFSAiQBUVNa6pTopKEMKRMaRchQI4y1SZk2tCmhaqZV8AetrSiOHVvHUhUKRo1gB5G2GswMEchQZ5ARhIWGH5EgaYiwG0jQYAghv3b32z/2xNlnc2/ynHvu3b27fF4zmb33nO+e89zczWfPvfeb51FEYGZ2UMdID8DM2otDwcwSDgUzSzgUzCzhUDCzxLiRHkAtU6d2xPQZnVm1W56dln3cjr29+YMo86lMf2at8g+ZfcyyyjwulRhwq45bqja/NPu4ZZ6HzhK/Y0s8rv6Jef8WAE4+eXtW3dbuXl7a0V9zEG0ZCtNndPK9O47Lql32wb/JPu7kJ1/MH8SB/ACJvfuy6jQu/8nNPWZZsX9/dq3G5/94RF9//nHHlfixK1Fb5u8397ilnoc3vD6/duKE7NJXTzk6u/b6f/9yVt3iC+uHh18+mFmiUihIOl/SU5I2Sbqqxv6Jkm4r9v9E0puqnM/MWq/hUJDUCVwPXACcDiyRdPqQsmXASxHxVuBa4J8bPZ+ZDY8qVwpzgU0RsTki9gPfARYNqVkE3Fzc/m9ggVTmnSMzG25VQmE68Nyg+93Ftpo1EdEL7ATeWOtgkpZL6pLUtWNH/ptWZtZcbfNGY0SsjIg5ETFn6tS2GZbZa06Vf309wMxB92cU22rWSBoHvAH4ZYVzmlmLVQmFh4BZkk6VNAFYDKwZUrMGWFrcfh/wP+H/q23W1hpuXoqIXkkrgLuATmBVRGyQdA3QFRFrgG8A/yFpE7CDgeAwszamdvzFfbSmxplakFV7ycbu7OOuuuKi7NrXde/Kru3YtSevcM/e7GPGvvzOw1KiROfhpEn5h3018+8A0OSj8o9bogOzDI0fn3f+3hKt8cfkdx7SkX+R3ndM/t/XnpNel1W3/odf4pWXumt+Euh39Mws4VAws4RDwcwSDgUzSzgUzCzhUDCzhEPBzBIOBTNLOBTMLOFQMLNEW07cWsYtvzUju/aKp2/Jrr12xZ9m106cmNcy2/FyXh2Adua3WceBA9m19LdmxmHKTPJaZrxlxlBi8thsnSUmg921O7+2ry+7dFyJlvcpL76cVdf5av32bV8pmFnCoWBmCYeCmSUcCmaWcCiYWcKhYGYJh4KZJaqsEDVT0g8l/VTSBkl/W6NmvqSdktYXfz5Vbbhm1mpVmpd6gSsj4hFJU4CHJa2LiJ8OqftRRFxY4TxmNowavlKIiOcj4pHi9i7gSQ5dIcrMRpmmtDkXq0m/HfhJjd1nSXoU2Ap8NCI21DnGcmA5wCTyZ68t4/pZv5lde/Oz/5Zd+8f/8LGsuqmP5s8MrP35MylTZsZhSszeXWLGYfpLHLevzHhLKNGSnN1qfSB/rKVmRi/R5oxKPA8dmW3h/fVbwiuHgqTXA98FPhwRQxuvHwFOiYhXJC0Ebgdm1TpORKwEVsLAFO9Vx2Vmjan06YOk8QwEwi0R8b2h+yPi5Yh4pbi9Fhgv6bgq5zSz1qry6YMYWAHqyYj41zo1Jxxcel7S3OJ8XkvSrI1VeflwNvBnwOOS1hfb/h44GSAibmRg/cgPSOoF9gCLvZakWXurspbkfcBh39WIiOuA6xo9h5kNP3c0mlnCoWBmCYeCmSUcCmaWcCiYWWLUz+bcKpedPC+79gubb8iq++SVl2cfc8pT+W2w2rsvuzYoMZNyGZ0t+v1Soh04SrR7K3OW6FKfoB+mdbjGALJLY+/e/MOOy/wnfZjH5SsFM0s4FMws4VAws4RDwcwSDgUzSzgUzCzhUDCzhEPBzBIOBTNLuKOxCT735tlZdZdt/G72MVd95KLs2sl792fX5vfRAePyJ0Jlb4nOv9yuO4DI7xJUZwt+nEt0i5bqvizRKdkxcWILxuCORjPL5FAws0TlUJC0RdLjxbJwXTX2S9KXJW2S9Jikd1Q9p5m1TrNehJ0bEb+os+8CBtZ6mAWcCdxQfDWzNjQcLx8WAd+KAQ8Ax0g6cRjOa2YNaEYoBHC3pIeLpd+Gmg48N+h+NzXWnJS0XFKXpK4DlHjH18yaqhkvH+ZFRI+k44F1kjZGxL1lD+Jl48zaQ+UrhYjoKb5uB1YDc4eU9AAzB92fUWwzszZUdS3JyZKmHLwNnAc8MaRsDfDnxacQvw/sjIjnq5zXzFqn6suHacDqYr67ccC3I+JOSX8Nv146bi2wENgEvAr8RcVzmlkLqR2XdjxaU+NMLRjpYYyou7auP3JR4ZwP1np/t7bJz+zKH0SJyUU7fvVK/nHLKNNmXEL0ZbZP7ysxKW6JiWPL/N1qUok25/68f8/371zNzt4Xaw7CHY1mlnAomFnCoWBmCYeCmSUcCmaWcCiYWcKhYGYJh4KZJRwKZpZwKJhZwrM5t6l3n5Q3QzTAxRvuzq69/ePvyq7tOJDfAj+xN3/W5Y5f5bdax4ED+bX782uVOaN0mdblKDGbs0q0OZd5XPRnPg+H+e8NvlIws4RDwcwSDgUzSzgUzCzhUDCzhEPBzBIOBTNLNBwKkk4rloo7+OdlSR8eUjNf0s5BNZ+qPmQza6WGm5ci4ilgNoCkTgambV9do/RHEXFho+cxs+HVrJcPC4D/i4ifN+l4ZjZCmtXmvBi4tc6+syQ9CmwFPhoRG2oVFUvOLQeYxFFNGtZrww9+55js2lufvTa7duG/fDy79oRt+W27lGgHzp2dGMq1DueOoVWty3SU+H1cotVbEyZkFtbf1Yyl6CcA7wH+q8buR4BTIuIM4CvA7fWOExErI2JORMwZT4kprc2sqZrx8uEC4JGI2DZ0R0S8HBGvFLfXAuMlHdeEc5pZizQjFJZQ56WDpBNUXFNJmluc75dNOKeZtUil9xSK9SPfBVw+aNvgJePeB3xAUi+wB1gc7bgklZn9WqVQiIjdwBuHbLtx0O3rgOuqnMPMhpc7Gs0s4VAws4RDwcwSDgUzSzgUzCzh2ZxfYy49eV527dINa7Nrf9B1TnZtx8TMVlyAPXvya1ugVOtyGZ2dZQaRXZo9+/RhGgN8pWBmCYeCmSUcCmaWcCiYWcKhYGYJh4KZJRwKZpZwKJhZwqFgZgmHgpkl3OZsdZWZJfqW527Irv2jT1yZXXvsQyV+REu0A2t3Zvv07t355x9XYqx9/fm1kV8bZY5bh68UzCyRFQqSVknaLumJQdumSlon6eni67F1vndpUfO0pKXNGriZtUbulcJNwPlDtl0F3BMRs4B7ivsJSVOBq4EzgbnA1fXCw8zaQ1YoRMS9wI4hmxcBNxe3bwbeW+Nb3w2si4gdEfESsI5Dw8XM2kiV9xSmRcTzxe0XgGk1aqYDzw26311sM7M21ZQ3Gou1HCqt5yBpuaQuSV0H2NeMYZlZA6qEwjZJJwIUX7fXqOkBZg66P6PYdgivJWnWHqqEwhrg4KcJS4Hv16i5CzhP0rHFG4znFdvMrE3lfiR5K3A/cJqkbknLgM8D75L0NPDO4j6S5kj6OkBE7AA+CzxU/Lmm2GZmbSqrBSsiltTZtaBGbRfwV4PurwJWNTQ6Mxt2bnO2prhk5tnZtTc886Xs2isu/1B27YQde7NrO3LXOd6bf8zYtz+7VmVaooeZ25zNLOFQMLOEQ8HMEg4FM0s4FMws4VAws4RDwcwSDgUzSzgUzCzhUDCzRPv2WtqY9YlTz8yu/cqWr2TXXnbVR7Jrj3kssyW5RDuy+ktMKdKRP/M0dDZ/DIc5va8UzCzhUDCzhEPBzBIOBTNLOBTMLOFQMLOEQ8HMEkcMhTrrSH5B0kZJj0laLanm8sSStkh6XNJ6SV3NHLiZtUbOlcJNHLrU2zrgbRHxu8DPgL87zPefGxGzI2JOY0M0s+F0xFCotY5kRNwdEb3F3QcYWOTFzMaAZrQ5/yVwW519AdwtKYCvRsTKegeRtBxYDjCJo5owLBsLrnjTWdm1P956Y3bt/GWXZdUddaD3yEUF7dqdXRu5s0kDUn5LdPT15R617p5KoSDpk0AvcEudknkR0SPpeGCdpI3FlcchisBYCXC0plZal9LMGtfwpw+SLgUuBC6JOrEXET3F1+3AamBuo+czs+HRUChIOh/4OPCeiHi1Ts1kSVMO3mZgHcknatWaWfvI+Uiy1jqS1wFTGHhJsF7SjUXtSZLWFt86DbhP0qPAg8AdEXFnSx6FmTXNEd9TqLOO5Dfq1G4FFha3NwNnVBqdmQ07dzSaWcKhYGYJh4KZJRwKZpZwKJhZwrM525jx7pNmZ9des/lrWXWfuWxZ9jEnbs2fdZm+/uzS6M+vVe7s06/Uvx7wlYKZJRwKZpZwKJhZwqFgZgmHgpklHApmlnAomFnCoWBmCYeCmSXc0WivSZ97c1734zef/VL2MS/+zMeya4/92d7s2vEv7MyupTPz9/xh5oL1lYKZJRwKZpZodNm4T0vqKeZnXC9pYZ3vPV/SU5I2SbqqmQM3s9ZodNk4gGuL5eBmR8TaoTsldQLXAxcApwNLJJ1eZbBm1noNLRuXaS6wKSI2R8R+4DvAogaOY2bDqMp7CiuKVadXSTq2xv7pwHOD7ncX22qStFxSl6SuA+yrMCwzq6LRULgBeAswG3ge+GLVgUTEyoiYExFzxjOx6uHMrEENhUJEbIuIvojoB75G7eXgeoCZg+7PKLaZWRtrdNm4EwfdvYjay8E9BMySdKqkCcBiYE0j5zOz4XPEjsZi2bj5wHGSuoGrgfmSZjOw1PwW4PKi9iTg6xGxMCJ6Ja0A7gI6gVURsaElj8LMmkZ1FoweUUdrapypBSM9DLNSPvvMQ9m1V165Irt2yuMvZtdq7/6suh+/8G127ttWs9nZHY1mlnAomFnCoWBmCYeCmSUcCmaWcCiYWcKhYGYJh4KZJRwKZpZwKJhZwrM5mzXJP576e9m1n9x0U3btP125NLv2qGd3Z9XFLzrr7vOVgpklHApmlnAomFnCoWBmCYeCmSUcCmaWcCiYWSJnjsZVwIXA9oh4W7HtNuC0ouQY4FcRccgyvpK2ALuAPqA3IuY0adxm1iI5zUs3AdcB3zq4ISL+5OBtSV8EDrdW9rkR8YtGB2hmw+uIoRAR90p6U619kgS8H/jD5g7LzEZK1TbnPwC2RcTTdfYHcLekAL4aESvrHUjScmA5wCSOqjgss/Z27Vt/O7t23v8+kF277stnZ9X1/bx+m3PVUFgC3HqY/fMiokfS8cA6SRuLBWsPUQTGShiY4r3iuMysQQ1/+iBpHHAxcFu9mojoKb5uB1ZTe3k5M2sjVT6SfCewMSK6a+2UNFnSlIO3gfOovbycmbWRI4ZCsWzc/cBpkrolLSt2LWbISwdJJ0laW9ydBtwn6VHgQeCOiLizeUM3s1bI+fRhSZ3tl9bYthVYWNzeDJxRcXxmNszc0WhmCYeCmSUcCmaWcCiYWcKhYGYJz+Zs1uYefnv+7+5Pb/pmVt0VD9b/P4q+UjCzhEPBzBIOBTNLOBTMLOFQMLOEQ8HMEg4FM0s4FMws4VAws4RDwcwSimi/OVIlvQj8fMjm44CxuH7EWH1cMHYf21h4XKdExG/U2tGWoVCLpK6xuMLUWH1cMHYf21h9XAf55YOZJRwKZpYYTaFQd3WpUW6sPi4Yu49trD4uYBS9p2Bmw2M0XSmY2TBwKJhZYlSEgqTzJT0laZOkq0Z6PM0iaYukxyWtl9Q10uOpQtIqSdslPTFo21RJ6yQ9XXw9diTH2Ig6j+vTknqK5229pIUjOcZma/tQkNQJXA9cAJwOLJF0+siOqqnOjYjZY+Bz75uA84dsuwq4JyJmAfcU90ebmzj0cQFcWzxvsyNibY39o1bbhwIDK1VviojNEbEf+A6waITHZENExL3AjiGbFwE3F7dvBt47rINqgjqPa0wbDaEwHXhu0P3uYttYEMDdkh6WtHykB9MC0yLi+eL2CwwsOjxWrJD0WPHyYtS9LDqc0RAKY9m8iHgHAy+NPiTpnJEeUKvEwGffY+Xz7xuAtwCzgeeBL47scJprNIRCDzBz0P0ZxbZRLyJ6iq/bgdUMvFQaS7ZJOhGg+Lp9hMfTFBGxLSL6IqIf+Bpj7HkbDaHwEDBL0qmSJgCLgTUjPKbKJE2WNOXgbeA84InDf9eoswZYWtxeCnx/BMfSNAeDrnARY+x5a/sVoiKiV9IK4C6gE1gVERtGeFjNMA1YLQkGnodvR8SdIzukxkm6FZgPHCepG7ga+Dzwn5KWMfBf4d8/ciNsTJ3HNV/SbAZeDm0BLh+xAbaA25zNLDEaXj6Y2TByKJhZwqFgZgmHgpklHApmlnAomFnCoWBmif8HtiQZVRD59FQAAAAASUVORK5CYII=\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "show_tree(dist.marginals[0])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Probabilistic Context-Free Grammars" + ] + }, + { + "cell_type": "raw", + "metadata": { + "raw_mimetype": "text/restructuredtext" + }, + "source": [ + " .. autoclass:: torch_struct.SentCFG" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Base Class" + ] + }, + { + "cell_type": "raw", + "metadata": { + "raw_mimetype": "text/restructuredtext" + }, + "source": [ + ".. autoclass:: torch_struct.StructDistribution\n", + " :members: " + ] + } + ], + "metadata": { + "celltoolbar": "Raw Cell Format", + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.1" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/docs/source/networks.rst b/docs/source/networks.rst new file mode 100644 index 00000000..81a5a2ab --- /dev/null +++ b/docs/source/networks.rst @@ -0,0 +1,24 @@ +================= +Networks and Data +================ + +Networks +================== + +Common structured networks. + + +.. autoclass:: torch_struct.networks.TreeLSTM + +.. autoclass:: torch_struct.networks.NeuralCFG + +.. autoclass:: torch_struct.networks.SpanLSTM + + +Datasets +======== + +Datasets for common structured prediction tasks. + +.. autoclass:: torch_struct.data.ConllXDataset +.. autoclass:: torch_struct.data.ListOpsDataset diff --git a/docs/source/refs.rst b/docs/source/refs.rst new file mode 100644 index 00000000..f2153d73 --- /dev/null +++ b/docs/source/refs.rst @@ -0,0 +1,4 @@ +References +========== + +.. bibliography:: refs.bib