diff --git a/.travis.yml b/.travis.yml index 0004ae99..b374ab1c 100644 --- a/.travis.yml +++ b/.travis.yml @@ -5,6 +5,7 @@ cache: pip before_install: - sudo add-apt-repository -y ppa:ubuntu-toolchain-r/test - sudo apt-get update + - sudo apt-get install pandoc - pip install -r requirements.txt - pip install -r requirements.dev.txt - pip install coveralls diff --git a/docs/requirements.txt b/docs/requirements.txt index 2404a798..f385eec8 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -3,3 +3,5 @@ sphinx-jinja sphinxcontrib-bibtex sphinx-rtd-theme recommonmark +nbsphinx +pandoc diff --git a/docs/source/README.md b/docs/source/README.md new file mode 120000 index 00000000..fe840054 --- /dev/null +++ b/docs/source/README.md @@ -0,0 +1 @@ +../../README.md \ No newline at end of file diff --git a/docs/source/README_files b/docs/source/README_files new file mode 120000 index 00000000..81dd36e3 --- /dev/null +++ b/docs/source/README_files @@ -0,0 +1 @@ +../../README_files/ \ No newline at end of file diff --git a/docs/source/advanced.rst b/docs/source/advanced.rst new file mode 100644 index 00000000..46c20e86 --- /dev/null +++ b/docs/source/advanced.rst @@ -0,0 +1,37 @@ + + + +============================ +Advanced Usage: Semirings +============================ + +All of the distributional code is implemented through a series of +semiring objects. These are passed through dynamic programming +backends to compute the distributions. + + +Standard Semirings +=================== + +.. autoclass:: torch_struct.LogSemiring +.. autoclass:: torch_struct.StdSemiring +.. autoclass:: torch_struct.MaxSemiring + +Higher-Order Semirings +========================= +.. autoclass:: torch_struct.EntropySemiring + +Sampling Semirings +=================== + +.. autoclass:: torch_struct.SampledSemiring +.. autoclass:: torch_struct.MultiSampledSemiring + + +Dynamic Programming +=================== + +.. autoclass:: torch_struct.LinearChain +.. autoclass:: torch_struct.SemiMarkov +.. autoclass:: torch_struct.DepTree +.. autoclass:: torch_struct.CKY diff --git a/docs/source/conf.py b/docs/source/conf.py index eb69279d..2531b65a 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -30,7 +30,9 @@ 'sphinx.ext.napoleon', 'sphinxcontrib.jinja', 'sphinxcontrib.bibtex', - 'sphinx.ext.intersphinx' + 'sphinx.ext.intersphinx', + 'recommonmark', + 'nbsphinx' ] diff --git a/docs/source/index.rst b/docs/source/index.rst index 71bab2da..b8526bd3 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -1,135 +1,19 @@ - +================= PyTorch-Struct -========================================== +================= .. toctree:: - :maxdepth: 2 + :maxdepth: 3 :caption: Contents: -Introduction -============ - -A library for structured prediction. - - - -Distributional Interface -======================== - - -The main interface is through a structured distribution objects. Each -of these implement a conditional random field over a class of -structures. Roughly, these represent specialized softmax's over -exponentially sized spaces. Each distribution object takes in -log_potentials (generalized logits) and can return properties of the -distribution. The properties of interest are, - -* Partition (e.g. logsumexp) -* Marginals (e.g. softmax) -* Argmax -* Entropy -* Samples -* to_event / from_event (adapters) - - -.. autoclass:: torch_struct.StructDistribution - :members: - -Linear Chain --------------- - -.. autoclass:: torch_struct.LinearChainCRF - - -Semi-Markov --------------- - -.. autoclass:: torch_struct.SemiMarkovCRF - - -Dependency Tree ----------------- - - -.. autoclass:: torch_struct.DependencyCRF - - -Binary Tree --------------- - -.. autoclass:: torch_struct.TreeCRF - -Context-Free Grammar ---------------------- - -.. autoclass:: torch_struct.SentCFG - - - - - -Networks -=========== - -Common structured networks. - - -.. autoclass:: torch_struct.networks.TreeLSTM - -.. autoclass:: torch_struct.networks.NeuralCFG - -.. autoclass:: torch_struct.networks.SpanLSTM - - -Data -==== - -Datasets for common structured prediction tasks. - -.. autoclass:: torch_struct.data.ConllXDataset -.. autoclass:: torch_struct.data.ListOpsDataset - - -Advanced Usage: Semirings -========================= - -All of the distributional code is implemented through a series of -semiring objects. These are passed through dynamic programming -backends to compute the distributions. - - -Standard Semirings ------------------- - -.. autoclass:: torch_struct.LogSemiring -.. autoclass:: torch_struct.StdSemiring -.. autoclass:: torch_struct.MaxSemiring - -Higher-Order Semirings ----------------------- -.. autoclass:: torch_struct.EntropySemiring - -Sampling Semirings ----------------------- - -.. autoclass:: torch_struct.SampledSemiring -.. autoclass:: torch_struct.MultiSampledSemiring - - -Dynamic Programming -------------------- - -.. autoclass:: torch_struct.LinearChain -.. autoclass:: torch_struct.SemiMarkov -.. autoclass:: torch_struct.DepTree -.. autoclass:: torch_struct.CKY - + README + model + networks + advanced + refs -References -========== -.. bibliography:: refs.bib Indices and tables diff --git a/docs/source/model.ipynb b/docs/source/model.ipynb new file mode 100644 index 00000000..955d4cdf --- /dev/null +++ b/docs/source/model.ipynb @@ -0,0 +1,599 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Model" + ] + }, + { + "cell_type": "raw", + "metadata": { + "raw_mimetype": "text/restructuredtext" + }, + "source": [ + ".. toctree:: \n", + " :maxdepth: 2 \n", + " :caption: Contents: \n" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import torch_struct\n", + "import torch\n", + "import matplotlib.pyplot as plt\n", + "import matplotlib\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "matplotlib.rcParams['figure.figsize'] = (7.0, 7.0)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Linear Chain" + ] + }, + { + "cell_type": "raw", + "metadata": { + "raw_mimetype": "text/restructuredtext" + }, + "source": [ + ".. autoclass:: torch_struct.LinearChainCRF " + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAaIAAABtCAYAAADjwmW6AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAGNUlEQVR4nO3czYuddx2G8et2JjWmgm/tJi+YLGoliG1lCNWCCyskRbHbBnQhQjZWqwhS/RtEdFGEUKMLS7uIXRQJRlDBjcTGNtSmsSVEbV4qjRattGiMfl3MCBEyzuTknHyfZ871WeWcCQ83v8nMNefMyUlVIUlSl7d0D5AkzTdDJElqZYgkSa0MkSSplSGSJLUyRJKkVouzuOgt716onTs2zeLS1+2l57Z0T1jV+z74ZveEVXlukxnyuQ2dn9eN5e+8waX6R672sczi/xEt3bG5fnV0x9SvOw17t97ZPWFVRy+c6J6wKs9tMkM+t6Hz87qxHKuf8nq9dtUQ+dScJKmVIZIktTJEkqRWhkiS1MoQSZJaGSJJUitDJElqZYgkSa0MkSSplSGSJLUyRJKkVoZIktTKEEmSWhkiSVKrdYUoyb4kLyY5neThWY+SJM2PNUOUZAF4BLgP2A3sT7J71sMkSfNhPY+I9gCnq+pMVV0CngDun+0sSdK8WE+ItgFnr7h9buW+/5HkQJLjSY5f/PO/prVPkrTBTe3FClV1sKqWqmrp1vcsTOuykqQNbj0hOg/suOL29pX7JEm6busJ0dPAbUl2JbkJeAB4arazJEnzYnGtv1BVl5M8CBwFFoBDVXVy5sskSXNhzRABVNUR4MiMt0iS5pDvrCBJamWIJEmtDJEkqZUhkiS1MkSSpFaGSJLUyhBJkloZIklSK0MkSWpliCRJrQyRJKmVIZIktTJEkqRW63r37Wv10nNb2Lv1zllc+rodvXCie8Kqhnpm4LlNasjnNnR+XjeWPXvfXPVjPiKSJLUyRJKkVoZIktTKEEmSWhkiSVIrQyRJamWIJEmtDJEkqZUhkiS1MkSSpFaGSJLUyhBJkloZIklSK0MkSWpliCRJrQyRJKnVmiFKcijJq0mevxGDJEnzZT2PiL4P7JvxDknSnFozRFX1C+C1G7BFkjSHFqd1oSQHgAMAm9kyrctKkja4qb1YoaoOVtVSVS1t4q3TuqwkaYPzVXOSpFaGSJLUaj0v334c+CVwe5JzST43+1mSpHmx5osVqmr/jRgiSZpPPjUnSWpliCRJrQyRJKmVIZIktTJEkqRWhkiS1MoQSZJaGSJJUitDJElqZYgkSa0MkSSplSGSJLUyRJKkVoZIktQqVTX9iyYXgT9M6XK3AH+a0rXmiec2Gc9tMp7b5Obl7N5bVbde7QMzCdE0JTleVUvdO8bGc5uM5zYZz21ynp1PzUmSmhkiSVKrMYToYPeAkfLcJuO5TcZzm9zcn93gf0ckSdrYxvCISJK0gQ06REn2JXkxyekkD3fvGYMkO5L8PMkLSU4meah705gkWUjybJIfdW8ZiyTvTHI4yW+TnEry4e5NY5Dkyytfo88neTzJ5u5NXQYboiQLwCPAfcBuYH+S3b2rRuEy8JWq2g3cDXzec7smDwGnukeMzLeBH1fV+4E78PzWlGQb8EVgqao+ACwAD/Su6jPYEAF7gNNVdaaqLgFPAPc3bxq8qnqlqp5Z+fPfWP6msK131Tgk2Q58Ani0e8tYJHkH8FHguwBVdamq/tK7ajQWgbclWQS2ABea97QZcoi2AWevuH0Ov6FekyQ7gbuAY71LRuNbwFeBf3cPGZFdwEXgeytPaT6a5ObuUUNXVeeBbwAvA68Af62qn/Su6jPkEOk6JHk78EPgS1X1eveeoUvySeDVqvp195aRWQQ+BHynqu4C3gD8fe4akryL5Wd4dgFbgZuTfLp3VZ8hh+g8sOOK29tX7tMakmxiOUKPVdWT3XtG4h7gU0l+z/LTwB9L8oPeSaNwDjhXVf991H2Y5TDp//s48LuqulhV/wSeBD7SvKnNkEP0NHBbkl1JbmL5F3lPNW8avCRh+fn6U1X1ze49Y1FVX6uq7VW1k+V/az+rqrn9CXW9quqPwNkkt6/cdS/wQuOksXgZuDvJlpWv2XuZ4xd5LHYPWE1VXU7yIHCU5VeUHKqqk82zxuAe4DPAb5KcWLnv61V1pHGTNrYvAI+t/MB4Bvhs857Bq6pjSQ4Dz7D8StdnmeN3WPCdFSRJrYb81JwkaQ4YIklSK0MkSWpliCRJrQyRJKmVIZIktTJEkqRWhkiS1Oo/sU9lwBtdqp8AAAAASUVORK5CYII=\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "batch, N, C = 3, 10, 2\n", + "def show_chain(chain):\n", + " plt.imshow(chain.detach().sum(-1).transpose(0, 1))\n", + "\n", + "# batch, N, z_n, z_n_1\n", + "log_potentials = torch.rand(batch, N, C, C)\n", + "dist = torch_struct.LinearChainCRF(log_potentials)\n", + "show_chain(dist.argmax[0])" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAaIAAABtCAYAAADjwmW6AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAGp0lEQVR4nO3c3YtcBx3G8efpbkKymzZaKkqTYIOElEWQ6lobKyJNwRTF3jbQglLojdX6hkTB/0CsgkUINXphaS9ikaih8ULBG61d04BNYyCk2mysNlLtxs3F5uXxYkaIsCe72Z3Jb86e7+cqM7McHg7Z/c6cGcZJBABAlZuqBwAAuo0QAQBKESIAQClCBAAoRYgAAKUIEQCg1PhQDrpxMus23zqMQ6/alQ2j+3H1d2+aq57Q6B8Xbq6e0OimC6P7fMqXqxdc27q5heoJjXbc+Xb1hEYnX7utekKji7e4esKiLv7rLV2en1903FBCtG7zrXrfI18ZxqFX7fzOi9UTGn3tYy9UT2j05LHd1RMabTw6UT2h0fq50X3iI0nvOTJbPaHRL4/8vHpCo90PP1o9odGZ+9dXT1jU7PefbHxsdJ9KAgA6gRABAEoRIgBAKUIEAChFiAAApQgRAKAUIQIAlCJEAIBShAgAUIoQAQBKESIAQClCBAAoRYgAAKUIEQCg1LJCZHuP7ZO2T9neN+xRAIDuWDJEtsckPSXpAUlTkvbanhr2MABANyznFdHdkk4lOZ1kQdJzkh4c7iwAQFcsJ0RbJJ256vZs/77/Y/sx2zO2Zy5fmB/UPgDAGjewDysk2Z9kOsn02MTkoA4LAFjjlhOis5K2XXV7a/8+AABWbTkheknSDtvbba+X9JCkQ8OdBQDoivGlfiDJJduPSzoiaUzSgSTHh74MANAJS4ZIkpIclnR4yFsAAB3ENysAAEoRIgBAKUIEAChFiAAApQgRAKAUIQIAlCJEAIBShAgAUIoQAQBKESIAQClCBAAoRYgAAKUIEQCg1LK+fft6ZfKKFj5yfhiHXrX1JzdVT2h06LOfqJ7QaMN9E9UTGi18+D/VExpN3nyhesI1fWvfz6onNPrk7R+qntBo17E/VE9odPr3u6onLCpjaXyMV0QAgFKECABQihABAEoRIgBAKUIEAChFiAAApQgRAKAUIQIAlCJEAIBShAgAUIoQAQBKESIAQClCBAAoRYgAAKUIEQCgFCECAJRaMkS2D9h+0/YrN2IQAKBblvOK6MeS9gx5BwCgo5YMUZLfSnrrBmwBAHTQwN4jsv2Y7RnbM5fn5gd1WADAGjewECXZn2Q6yfTYLZODOiwAYI3jU3MAgFKECABQajkf335W0u8k7bQ9a/vR4c8CAHTF+FI/kGTvjRgCAOgmLs0BAEoRIgBAKUIEAChFiAAApQgRAKAUIQIAlCJEAIBShAgAUIoQAQBKESIAQClCBAAoRYgAAKUIEQCgFCECAJRyksEf1D4n6a8DOtxtkv45oGN1CedtZThvK8N5W7munLv3JnnXYg8MJUSDZHsmyXT1jrbhvK0M521lOG8rx7nj0hwAoBghAgCUakOI9lcPaCnO28pw3laG87ZynT93I/8eEQBgbWvDKyIAwBo20iGyvcf2SdunbO+r3tMGtrfZ/o3tV20ft/1E9aY2sT1m+2Xbv6je0ha232H7oO0/2z5he1f1pjaw/eX+7+grtp+1vaF6U5WRDZHtMUlPSXpA0pSkvbanale1wiVJX00yJekeSZ/nvF2XJySdqB7RMt+T9EKSOyV9QJy/JdneIumLkqaTvF/SmKSHalfVGdkQSbpb0qkkp5MsSHpO0oPFm0ZekjeSHO3/+7x6fxS21K5qB9tbJX1K0tPVW9rC9mZJH5f0Q0lKspDk37WrWmNc0kbb45ImJP2teE+ZUQ7RFklnrro9K/6gXhfbd0i6S9KLtUta47uSvi7pSvWQFtku6ZykH/UvaT5te7J61KhLclbStyW9LukNSW8n+VXtqjqjHCKsgu1Nkn4q6UtJ5qr3jDrbn5b0ZpI/Vm9pmXFJH5T0gyR3SZqXxPu5S7D9TvWu8GyXdLukSdsP166qM8ohOitp21W3t/bvwxJsr1MvQs8keb56T0vcK+kztv+i3mXg+2z/pHZSK8xKmk3yv1fdB9ULE67tfkmvJTmX5KKk5yV9tHhTmVEO0UuSdtjebnu9em/kHSreNPJsW73r9SeSfKd6T1sk+UaSrUnuUO//2q+TdPYZ6nIl+bukM7Z39u/aLenVwklt8bqke2xP9H9nd6vDH/IYrx7QJMkl249LOqLeJ0oOJDlePKsN7pX0iKQ/2T7Wv++bSQ4XbsLa9gVJz/SfMJ6W9LniPSMvyYu2D0o6qt4nXV9Wh79hgW9WAACUGuVLcwCADiBEAIBShAgAUIoQAQBKESIAQClCBAAoRYgAAKUIEQCg1H8BOeNrqkPyFUIAAAAASUVORK5CYII=\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "show_chain(dist.marginals[0])" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAaIAAABtCAYAAADjwmW6AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAGQUlEQVR4nO3czYtdBxnH8e/PmaY1Ed+7yQs2i7YSpLYyxGrBhRGSothtA7oQoRtrqwhS/RtEdFGEUKMLS7uIXRQJjqCCGwmNbdCmsRKiNi+VxhatNNg0+riYESJknMnk3jznzP1+Vrl3hsPDw5n5zj335KaqkCSpy9u6B5AkzTZDJElqZYgkSa0MkSSplSGSJLUyRJKkVvPTOOim3Fg3sWUah75mt91xoXuEFf3ht5u7R9CEDfl8g2Gfc0Pe3ZD3NlT/5A0u1pu50tcyjf9H9M68tz6aPRM/7iQsnjvWPcKK9m69s3sETdiQzzcY9jk35N0NeW9DdaR+zuv12hVD5KU5SVIrQyRJamWIJEmtDJEkqZUhkiS1MkSSpFaGSJLUyhBJkloZIklSK0MkSWpliCRJrQyRJKmVIZIktTJEkqRWawpRkn1JXkxyMskj0x5KkjQ7Vg1RkjngUeBeYBewP8muaQ8mSZoNa3lFtBs4WVWnquoi8CRw33THkiTNirWEaBtw+rLHZ5af+x9JHkhyNMnRt3hzUvNJkja4id2sUFUHqmqhqhZu4MZJHVaStMGtJURngR2XPd6+/JwkSddsLSF6Brg1yc4km4D7gaenO5YkaVbMr/YNVXUpyYPAIjAHHKyq41OfTJI0E1YNEUBVHQYOT3kWSdIM8pMVJEmtDJEkqZUhkiS1MkSSpFaGSJLUyhBJkloZIklSK0MkSWpliCRJrQyRJKmVIZIktTJEkqRWhkiS1GpNn759tW674wKLi8emcehrtnfrnd0jrGjx3DB3pvUb8vkGwz7nhry7Ie9tqHbvvbDi13xFJElqZYgkSa0MkSSplSGSJLUyRJKkVoZIktTKEEmSWhkiSVIrQyRJamWIJEmtDJEkqZUhkiS1MkSSpFaGSJLUyhBJkloZIklSq1VDlORgkleSPH89BpIkzZa1vCL6IbBvynNIkmbUqiGqql8Br12HWSRJM2hi7xEleSDJ0SRHz7/6r0kdVpK0wU0sRFV1oKoWqmrh5vfNTeqwkqQNzrvmJEmtDJEkqdVabt9+Avg1cHuSM0m+OP2xJEmzYn61b6iq/ddjEEnSbPLSnCSplSGSJLUyRJKkVoZIktTKEEmSWhkiSVIrQyRJamWIJEmtDJEkqZUhkiS1MkSSpFaGSJLUyhBJkloZIklSq1TV5A+anAf+PKHDvR/464SONUvc2/q4t/Vxb+s3K7v7QFXdfKUvTCVEk5TkaFUtdM8xNu5tfdzb+ri39XN3XpqTJDUzRJKkVmMI0YHuAUbKva2Pe1sf97Z+M7+7wb9HJEna2MbwikiStIENOkRJ9iV5McnJJI90zzMGSXYk+WWSF5IcT/Jw90xjkmQuyXNJftI9y1gkeXeSQ0l+n+REko91zzQGSb66/DP6fJInktzUPVOXwYYoyRzwKHAvsAvYn2RX71SjcAn4WlXtAu4GvuTersrDwInuIUbmu8BPq+qDwIdxf6tKsg14CFioqg8Bc8D9vVP1GWyIgN3Ayao6VVUXgSeB+5pnGryqermqnl3+9z9Y+qWwrXeqcUiyHfg08Fj3LGOR5F3AJ4DvA1TVxar6W+9UozEPvD3JPLAZONc8T5shh2gbcPqyx2fwF+pVSXILcBdwpHeS0fgO8HXg392DjMhO4Dzwg+VLmo8l2dI91NBV1VngW8BLwMvA36vqZ71T9RlyiHQNkrwD+DHwlap6vXueoUvyGeCVqvpN9ywjMw98BPheVd0FvAH4fu4qkryHpSs8O4GtwJYkn+udqs+QQ3QW2HHZ4+3Lz2kVSW5gKUKPV9VT3fOMxD3AZ5P8iaXLwJ9M8qPekUbhDHCmqv77qvsQS2HS//cp4I9Vdb6q3gKeAj7ePFObIYfoGeDWJDuTbGLpjbynm2cavCRh6Xr9iar6dvc8Y1FV36iq7VV1C0vn2i+qamb/Ql2rqvoLcDrJ7ctP7QFeaBxpLF4C7k6yeflndg8zfJPHfPcAK6mqS0keBBZZuqPkYFUdbx5rDO4BPg/8Lsmx5ee+WVWHG2fSxvZl4PHlPxhPAV9onmfwqupIkkPAsyzd6focM/wJC36ygiSp1ZAvzUmSZoAhkiS1MkSSpFaGSJLUyhBJkloZIklSK0MkSWpliCRJrf4DML5jrsusy7kAAAAASUVORK5CYII=\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "event = dist.to_event(torch.tensor([[0, 1, 0, 1, 1, 1, 0, 1, 0, 1, 1]]), 2)\n", + "show_chain(event[0])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## HMM" + ] + }, + { + "cell_type": "raw", + "metadata": { + "raw_mimetype": "text/restructuredtext" + }, + "source": [ + ".. autoclass:: torch_struct.HMM" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Semi-Markov" + ] + }, + { + "cell_type": "raw", + "metadata": { + "raw_mimetype": "text/restructuredtext" + }, + "source": [ + ".. autoclass:: torch_struct.SemiMarkovCRF" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAaIAAABGCAYAAACOqZBPAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAHyUlEQVR4nO3df+hVdx3H8eerrzrZWpvfOZZTaTPWwKAt+WIr1hgYTiVmRYQSZG0go4T2R4QwGKP/VtQfxShWyVaMTfqxknA4W8H+STcTddo2/U6MaU7XDK1GmevdH+fzlev13Ps92jn3c77e1wMu99xzPvd83+/7+XzP+3s/53zvVURgZmaWy7tyB2BmZsPNhcjMzLJyITIzs6xciMzMLCsXIjMzy8qFyMzMsprWxE5nj47EDfOn17a//Xsur21fbfaBD71d6/7a/LrVmWub8zSzwr/4J6fj3yrbpib+j2jslpnxwpb5te3vrutvrW1fbbblL7tq3V+bX7c6c21znmZW2B7PcSpOlBYiT82ZmVlWLkRmZpaVC5GZmWXlQmRmZllVKkSSlkl6VdK4pPVNB2VmZsNj0kIkaQR4BFgOLARWS1rYdGBmZjYcqrwjWgyMR8TBiDgNPAWsbDYsMzMbFlUK0Vzg9Y7Hh9M6MzOz/1ttFytIWitph6Qdb771Tl27NTOzS1yVQnQE6PyYhHlp3Tki4tGIGIuIsWuvGakrPjMzu8RVKUQvAjdJulHSDGAVsKnZsMzMbFhM+qGnEXFG0jpgCzACbIiIfY1HZmZmQ6HSp29HxGZgc8OxmJnZEPInK5iZWVYuRGZmlpULkZmZZeVCZGZmWTXyDa3v0Wh8REtq36/ZsPO3+F6cYckT2purv6HVzMxay4XIzMyyciEyM7OsXIjMzCwrFyIzM8uqyje0bpB0XNLeQQRkZmbDpco7oseAZQ3HYWZmQ2rSQhQRzwMnBhCLmZkNIZ8jMjOzrCp9DUQVktYCawFmcnlduzUzs0tcbe+IOr8qfDqX1bVbMzO7xHlqzszMsqpy+faTwB+AmyUdlnRv82GZmdmwmPQcUUSsHkQgZmY2nDw1Z2ZmWbkQmZlZVi5EZmaWlQuRmZll5UJkZmZZKSLq36n0JvDnSZrNBv5a+w8fLOfQDs4hv6kePziHpr0vIq4t29BIIapC0o6IGMvyw2viHNrBOeQ31eMH55CTp+bMzCwrFyIzM8sqZyF6NOPProtzaAfnkN9Ujx+cQzbZzhGZmZmBp+bMzCyzxguRpGWSXpU0Lml9yfbLJG1M27dLuqHpmC6EpPmSfi/pT5L2SfpqSZs7JZ2UtCvdHswRaz+SDkl6KcW3o2S7JH039cMeSYtyxNmLpJs7Xt9dkk5Jur+rTev6QdIGSccl7e1YNyppq6QD6X5Wj+euSW0OSFozuKjPiaEs/m9JeiWNk6clXd3juX3H3KD0yOEhSUc6xsqKHs/te/walB45bOyI/5CkXT2e24p+6CsiGrsBI8BrwAJgBrAbWNjV5svAD9LyKmBjkzFdRA5zgEVp+Upgf0kOdwK/yR3rJHkcAmb32b4CeAYQcBuwPXfMk4yrNyj+L6HV/QDcASwC9nas+yawPi2vBx4ued4ocDDdz0rLs1oS/1JgWlp+uCz+KmMucw4PAV+rMM76Hr9y5tC1/dvAg23uh363pt8RLQbGI+JgRJwGngJWdrVZCTyeln8OLJGkhuOqLCKORsTOtPx34GVgbt6oGrES+EkUtgFXS5qTO6gelgCvRcRk/zSdXUQ8D5zoWt055h8HPlXy1LuArRFxIiL+BmwFljUWaA9l8UfEsxFxJj3cBswbdFwXokcfVFHl+DUQ/XJIx8vPAU8ONKgaNV2I5gKvdzw+zPkH8bNt0uA+CVzTcFwXJU0bfhjYXrL5o5J2S3pG0gcHGlg1ATwr6Y+S1pZsr9JXbbGK3r90be8HgOsi4mhafgO4rqTNVOmPeyjeSZeZbMzlti5NL27oMT06Vfrg48CxiDjQY3vb+8EXK1Ql6d3AL4D7I+JU1+adFNNEtwDfA3416PgquD0iFgHLga9IuiN3QBdD0gzgbuBnJZunQj+cI4q5kyl56aqkB4AzwBM9mrR5zH0feD9wK3CUYmprqlpN/3dDbe4HoPlCdASY3/F4XlpX2kbSNOAq4K2G47ogkqZTFKEnIuKX3dsj4lRE/CMtbwamS5o94DD7iogj6f448DTFtEOnKn3VBsuBnRFxrHvDVOiH5NjEtGe6P17SptX9IemLwCeBz6diep4KYy6biDgWEe9ExH+BH1IeW6v7AM4eMz8DbOzVps39MKHpQvQicJOkG9NfsquATV1tNgETVwR9Fvhdr4GdQ5p//THwckR8p0eb906c15K0mOJ1bU0xlXSFpCsnlilONu/tarYJ+EK6eu424GTH9FGb9Pzrr+390KFzzK8Bfl3SZguwVNKsNG20NK3LTtIy4OvA3RHxdo82VcZcNl3nPz9NeWxVjl+5fQJ4JSIOl21sez+c1fTVEBRXY+2nuPrkgbTuGxSDGGAmxTTLOPACsCD3FRxd8d9OMXWyB9iVbiuA+4D7Upt1wD6Kq2q2AR/LHXdXDgtSbLtTnBP90JmDgEdSP70EjOWOuySPKygKy1Ud61rdDxRF8yjwH4pzDPdSnAN9DjgA/BYYTW3HgB91PPee9HsxDnypRfGPU5w7mfh9mLjq9Xpgc78x16IcfprG+R6K4jKnO4f0+LzjV1tySOsfmxj/HW1b2Q/9bv5kBTMzy8oXK5iZWVYuRGZmlpULkZmZZeVCZGZmWbkQmZlZVi5EZmaWlQuRmZll5UJkZmZZ/Q81+TX27ppu7AAAAABJRU5ErkJggg==\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "batch, N, C, K = 3, 20, 2, 6\n", + "def show_sm(chain):\n", + " plt.imshow(chain.detach().sum(1).sum(-1).transpose(0, 1))\n", + "\n", + "# batch, N, K, z_n, z_n_1\n", + "log_potentials = torch.rand(batch, N, K, C, C)\n", + "log_potentials[:, :, :3] = -1e9\n", + "dist = torch_struct.SemiMarkovCRF(log_potentials)\n", + "show_sm(dist.argmax[0])" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAaIAAABGCAYAAACOqZBPAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAImElEQVR4nO3dbYhU1x3H8e/Ph9VqRFeNqU+ksQSLvqgVEVvSkJJiVEpMSylKoTYJhNAKDbQUISCh9E1a2hctoSVtJWkJRvqQVoohsWkhFKLVio950I1YojWaRNFEWzer/76YuzKOM7ujPXfOXff3gWHu3Hvm7P/sOXP/d87cmauIwMzMLJcRuQMwM7PhzYnIzMyyciIyM7OsnIjMzCwrJyIzM8vKicjMzLIaVUal47rHxMQZ45LVd+5VJasrtQuzxiera/QHyaoCYOSFi8nqujh2ZLK6ACLhIdCIvrRfQVDC+mJU4rGb+NsWupSuwkuj0x7X6mK62EZ8eClZXZD29dD3kWRVAdD173NpK0zkv5yjNy40fUGUkogmzhjH/Rs/l6y+HQvS7gRT6vn2kmR1zfh72r3MTQfPJKvr/U9MTFYXQO/4dDutce/2JasLoOt0b7K6LkwZk6wuSJskAUafT/e/Oz+tK1ldAF1n0x1IjT2W9ijvzPxJyeo6NT/twcqt619JWl8q2+Ollts8NWdmZlk5EZmZWVZORGZmlpUTkZmZZdVWIpK0TNIbknokrSs7KDMzGz4GTUSSRgJPAMuBecBqSfPKDszMzIaHdt4RLQZ6IuJwRPQCzwIryw3LzMyGi3YS0UzgrbrHR4t1ZmZm/7dkJytIekjSTkk7z5++kKpaMzO7wbWTiI4Bs+sezyrWXSEinoyIRRGxaFx32m+Tm5nZjaudRLQDuF3SbZK6gFXA5nLDMjOz4WLQ35qLiD5Ja4EXgJHAhog4UHpkZmY2LLT1o6cRsQXYUnIsZmY2DPmXFczMLCsnIjMzy8qJyMzMsnIiMjOzrEq5QuvMUf/h+9P2JavvHhYkqyu1SHjx2FHn017OuG/S2GR1KW1odJ1LV2HvhMSXMVe678FdGp326ptn56R9yU44mu5YdGRv2qvHnp6b7oqv086m/W5jyrb2Tkl8/fchyO+IzMwsKyciMzPLyonIzMyyciIyM7OsnIjMzCyrdq7QukHSSUn7OxGQmZkNL+28I3oKWFZyHGZmNkwNmogi4mXgVAdiMTOzYcifEZmZWValXCr8nfcupqrWzMxucMkSUf2lwm+ekvYnV8zM7MblqTkzM8uqndO3NwKvAHMlHZX0YPlhmZnZcDHoT/lGxOpOBGJmZsOTp+bMzCwrJyIzM8vKicjMzLJyIjIzs6yciMzMLCtFpL9euqR3gH8NUmwq8G7yP95ZbkM1uA35DfX4wW0o260RcXOzDaUkonZI2hkRi7L88UTchmpwG/Ib6vGD25CTp+bMzCwrJyIzM8sqZyJ6MuPfTsVtqAa3Ib+hHj+4Ddlk+4zIzMwMPDVnZmaZlZ6IJC2T9IakHknrmmwfI2lTsX27pI+VHdO1kDRb0t8kvSrpgKRvNSlzl6QzknYXt/U5Yh2IpCOS9hXx7WyyXZJ+UvTDXkkLc8TZiqS5df/f3ZLOSnqkoUzl+kHSBkknJe2vWzdZ0lZJh4r77hbPXVOUOSRpTeeiviKGZvH/UNLrxTh5TtKkFs8dcMx1Sos2PCbpWN1YWdHiuQPuvzqlRRs21cV/RNLuFs+tRD8MKCJKuwEjgTeBOUAXsAeY11DmG8DPi+VVwKYyY7qONkwHFhbLE4CDTdpwF/Dn3LEO0o4jwNQBtq8AngcELAG25455kHH1NrXvJVS6H4A7gYXA/rp1PwDWFcvrgMebPG8ycLi47y6WuysS/1JgVLH8eLP42xlzmdvwGPCdNsbZgPuvnG1o2P4jYH2V+2GgW9nviBYDPRFxOCJ6gWeBlQ1lVgJPF8u/A+6WpJLjaltEHI+IXcXy+8BrwMy8UZViJfDrqNkGTJI0PXdQLdwNvBkRg31pOruIeBk41bC6fsw/DdzX5Kn3AFsj4lREnAa2AstKC7SFZvFHxIsR0Vc83AbM6nRc16JFH7Sjnf1XRwzUhmJ/+RVgY0eDSqjsRDQTeKvu8VGu3olfLlMM7jPAlJLjui7FtOGngO1NNn9a0h5Jz0ua39HA2hPAi5L+KemhJtvb6auqWEXrF13V+wHglog4Xiy/DdzSpMxQ6Y8HqL2TbmawMZfb2mJ6cUOL6dGh0gefBU5ExKEW26veDz5ZoV2SbgJ+DzwSEWcbNu+iNk30SeCnwB87HV8b7oiIhcBy4JuS7swd0PWQ1AXcC/y2yeah0A9XiNrcyZA8dVXSo0Af8EyLIlUecz8DPg4sAI5Tm9oaqlYz8LuhKvcDUH4iOgbMrns8q1jXtIykUcBE4L2S47omkkZTS0LPRMQfGrdHxNmI+KBY3gKMljS1w2EOKCKOFfcngeeoTTvUa6evqmA5sCsiTjRuGAr9UDjRP+1Z3J9sUqbS/SHp68AXgK8WyfQqbYy5bCLiRERcjIhLwC9oHlul+wAu7zO/BGxqVabK/dCv7ES0A7hd0m3FkewqYHNDmc1A/xlBXwb+2mpg51DMv/4KeC0iftyizEf7P9eStJja/7UyyVTSeEkT+pepfdi8v6HYZuBrxdlzS4AzddNHVdLy6K/q/VCnfsyvAf7UpMwLwFJJ3cW00dJiXXaSlgHfBe6NiPMtyrQz5rJp+PzzizSPrZ39V26fB16PiKPNNla9Hy4r+2wIamdjHaR29smjxbrvURvEAGOpTbP0AP8A5uQ+g6Mh/juoTZ3sBXYXtxXAw8DDRZm1wAFqZ9VsAz6TO+6GNswpYttTxNnfD/VtEPBE0U/7gEW5427SjvHUEsvEunWV7gdqSfM48CG1zxgepPYZ6EvAIeAvwOSi7CLgl3XPfaB4XfQA91co/h5qn530vx76z3qdAWwZaMxVqA2/Kcb5XmrJZXpjG4rHV+2/qtKGYv1T/eO/rmwl+2Ggm39ZwczMsvLJCmZmlpUTkZmZZeVEZGZmWTkRmZlZVk5EZmaWlRORmZll5URkZmZZORGZmVlW/wNrHltNBmq6RAAAAABJRU5ErkJggg==\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "show_sm(dist.marginals[0])" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAaIAAABtCAYAAADjwmW6AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAGLElEQVR4nO3cz4vc9R3H8deru4kxKW1tzSU/qDlYZSlNIkuMFXowhSRY6tVAe5CCF62xCMX6NxSxhyAEjR4UPaQepAS3oIVeSnQbozWmSkhbs4liamiVSBJTXz3MCAkkzjqZyfv7ne/zccrMLl9efNjd586PrJMIAIAqX6seAADoNkIEAChFiAAApQgRAKAUIQIAlCJEAIBS0+O46FJfk2VaMY5Lo8j3fvBp9YTLevfN5dUTMAZ8zU2WMzqtcznrS33M4/h/RN/wt3Ort4z8uqgzd+Jg9YTL2rpqQ/UEjAFfc5Nlf17Wxzl1yRDx1BwAoBQhAgCUIkQAgFKECABQihABAEoRIgBAKUIEAChFiAAApQgRAKAUIQIAlCJEAIBShAgAUIoQAQBKESIAQKlFhcj2Ntvv2D5i++FxjwIAdMfAENmekrRL0nZJM5J22J4Z9zAAQDcs5hHRJklHkhxNck7S85LuGu8sAEBXLCZEqyUdu+D2Qv++i9i+1/a87fnPdHZU+wAAE25kb1ZIsjvJbJLZJbpmVJcFAEy4xYTouKS1F9xe078PAIArtpgQvSbpRtvrbC+VdLekF8c7CwDQFdODPiHJedv3S5qTNCVpT5JDY18GAOiEgSGSpCT7JO0b8xYAQAfxlxUAAKUIEQCgFCECAJQiRACAUoQIAFCKEAEAShEiAEApQgQAKEWIAAClCBEAoBQhAgCUIkQAgFKECABQyklGftHZ9cvy6tzawZ+Ii2xdtaF6Ajpm7sTB6gmX1eTvhyafW1Nt2npM82+c8aU+xiMiAEApQgQAKEWIAAClCBEAoBQhAgCUIkQAgFKECABQihABAEoRIgBAKUIEAChFiAAApQgRAKAUIQIAlCJEAIBShAgAUIoQAQBKDQyR7T22P7T91tUYBADolsU8Inpa0rYx7wAAdNTAECX5s6RTV2ELAKCDRvYake17bc/bnj/50f9GdVkAwIQbWYiS7E4ym2R25XemRnVZAMCE411zAIBShAgAUGoxb99+TtJfJN1ke8H2L8Y/CwDQFdODPiHJjqsxBADQTTw1BwAoRYgAAKUIEQCgFCECAJQiRACAUoQIAFCKEAEAShEiAEApQgQAKEWIAAClCBEAoBQhAgCUIkQAgFKECABQyklGf1H7pKR/jehy10v694iu1SWc23A4t+FwbsPrytl9N8nKS31gLCEaJdvzSWard7QN5zYczm04nNvwODuemgMAFCNEAIBSbQjR7uoBLcW5DYdzGw7nNrzOn13jXyMCAEy2NjwiAgBMsEaHyPY22+/YPmL74eo9bWB7re0/2X7b9iHbO6s3tYntKduv2/5D9Za2sP0t23tt/932Ydu3VW9qA9u/6n+PvmX7OdvLqjdVaWyIbE9J2iVpu6QZSTtsz9SuaoXzkh5KMiNps6T7OLevZKekw9UjWuZ3kl5KcrOk9eL8BrK9WtIDkmaTfF/SlKS7a1fVaWyIJG2SdCTJ0STnJD0v6a7iTY2X5P0kB/r//kS9Hwqra1e1g+01ku6U9ET1lraw/U1JP5L0pCQlOZfkP7WrWmNa0rW2pyUtl3SieE+ZJodotaRjF9xeED9QvxLbN0jaKGl/7ZLWeEzSryV9Xj2kRdZJOinpqf5Tmk/YXlE9qumSHJf0W0nvSXpf0n+T/LF2VZ0mhwhXwPbXJf1e0oNJPq7e03S2fyLpwyR/rd7SMtOSbpH0eJKNkk5L4vXcAWxfp94zPOskrZK0wvbPalfVaXKIjktae8HtNf37MIDtJepF6NkkL1TvaYnbJf3U9j/Vexr4DtvP1E5qhQVJC0m+eNS9V70w4cv9WNI/kpxM8pmkFyT9sHhTmSaH6DVJN9peZ3upei/kvVi8qfFsW73n6w8nebR6T1sk+U2SNUluUO9r7ZUknf0NdbGSfCDpmO2b+ndtkfR24aS2eE/SZtvL+9+zW9ThN3lMVw+4nCTnbd8vaU69d5TsSXKoeFYb3C7p55L+Zvtg/75Hkuwr3ITJ9ktJz/Z/YTwq6Z7iPY2XZL/tvZIOqPdO19fV4b+wwF9WAACUavJTcwCADiBEAIBShAgAUIoQAQBKESIAQClCBAAoRYgAAKUIEQCg1P8BnlJfqYxcYoYAAAAASUVORK5CYII=\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "# Use -1 for segments.\n", + "event = dist.to_event(torch.tensor([[0, 1, -1, 1, -1, -1, 0, 1, 0, 1, 1]]), (2, 6))\n", + "show_sm(event[0])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Alignment" + ] + }, + { + "cell_type": "raw", + "metadata": { + "raw_mimetype": "text/restructuredtext" + }, + "source": [ + ".. autoclass:: torch_struct.AlignmentCRF" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAagAAAFECAYAAAByNKo5AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAARDklEQVR4nO3df4wtZ1kH8O9jL4gUhGKVH20DhZAmSFTqhgAiEotYkFA0xJQIFmhyQxQFAyFVEkL8RxHFnwRyhQoaUhoLSENAqEhCTKByb2npL6ClFmgtLVgDKH9A5fGPPbdZtmf3bvfM3fNuz+eTbPacmffMPNOZe76dmXffqe4OAIzmh5ZdAADMI6AAGJKAAmBIAgqAIQkoAIYkoAAY0oG9XFmdXJ3HTLe8nz0y3bImXBQA9843uvvHN0/c04DKY5Icnm5xh2u6ZU24KADunS/Pm+gSHwBDElAADElAATAkAQXAkBYKqKo6u6q+UFU3VtUFUxUFALsOqKo6IclbkzwnyROSvKiqnjBVYQCstkXOoJ6c5Mbuvqm7v5vkvUnOmaYsAFbdIgF1SpKvbnh/y2waACzsuHeSqKqDVXW4qg7n68d7bQDcVywSULcmOW3D+1Nn035Adx/q7rXuXss9BrIAgPkWCajPJHl8VZ1eVfdPcm6SS6cpC4BVt+ux+Lr7rqp6ZZKPJjkhyYXdfe1klQGw0hYaLLa7P5zkwxPVAgB3M5IEAEMSUAAMSUABMCQBBcCQ9vaJukcy6aNrKz3dwjxTF2AozqAAGJKAAmBIAgqAIQkoAIYkoAAYkoACYEgCCoAhCSgAhiSgABiSgAJgSAIKgCEJKACGJKAAGJKAAmBIAgqAIQkoAIYkoAAYkoACYEgCCoAhHVh2AaPoiZdXEy8PYNU4gwJgSAIKgCEJKACGJKAAGJKAAmBIuw6oqjqtqj5RVddV1bVV9aopCwNgtS3SzfyuJK/p7iuq6sFJjlTVZd193US1AbDCdn0G1d23dfcVs9ffTnJ9klOmKgyA1TbJPaiqekySJyW5fIrlAcDCI0lU1YOSvC/Jq7v7W3PmH0xycNH1ALBaqnv3g/xU1f2SfCjJR7v7LTtoP/GIQtMtricenMhQRwA7dqS71zZPXKQXXyV5Z5LrdxJOAHBvLHIP6ueSvCTJL1bVlbOf505UFwArbtf3oLr73+JKFgDHiZEkABiSgAJgSAIKgCEJKACGJKAAGNLCI0ks13SdCGvCP/pNpv3DX10lgVXkDAqAIQkoAIYkoAAYkoACYEgCCoAhCSgAhiSgABiSgAJgSAIKgCEJKACGJKAAGJKAAmBIAgqAIQkoAIYkoAAYkoACYEgCCoAhCSgAhrTPH/k+pWkfrD7tI+Q99B1YPc6gABiSgAJgSAIKgCEJKACGJKAAGNLCAVVVJ1TVZ6vqQ1MUBADJNGdQr0py/QTLAYC7LRRQVXVqkl9J8o5pygGAdYueQf1Fktcl+f4EtQDA3XYdUFX1vCR3dPeRY7Q7WFWHq+rwbtcFwOqp7t0NyVNVf5TkJUnuSvKAJD+a5P3d/eJtPjPl+D+DM9QRwA4d6e61zRN3HVA/sJCqZyZ5bXc/7xjtBNSuCCjgPm1uQPk7KACGNMkZ1I5X5gxql5xBAfdpzqAA2D8EFABDElAADElAATAkAQXAkA4suwCObequj/oEAvuBMygAhiSgABiSgAJgSAIKgCEJKACGJKAAGJKAAmBIAgqAIQkoAIYkoAAYkoACYEgCCoAhCSgAhiSgABiSgAJgSAIKgCEJKACGJKAAGJJHvh830z1YvTz0HVhBzqAAGJKAAmBIAgqAIQkoAIYkoAAY0kIBVVUPrapLqurzVXV9VT11qsIAWG2LdjP/yyT/3N0vrKr7J3ngBDUBwO4DqqoekuQZSV6aJN393STfnaYsAFbdIpf4Tk/y9SR/V1Wfrap3VNWJE9UFwIpbJKAOJDkzydu6+0lJ/jfJBZsbVdXBqjpcVYcXWBcAK2aRgLolyS3dffns/SVZD6wf0N2Hunutu9cWWBcAK2bXAdXdX0vy1ao6YzbprCTXTVIVACtv0V58v5PkPbMefDcledniJQHAggHV3VcmcekOgMkZSQKAIQkoAIYkoAAYkoACYEgCCoAhLdrNnH2oJ1xWTbgsgI2cQQEwJAEFwJAEFABDElAADElAATAkAQXAkAQUAEMSUAAMSUABMCQBBcCQBBQAQxJQAAxJQAEwJAEFwJAEFABDElAADElAATAkAQXAkDzyfV+Y9sHqNeFD33uFHvq+OlsKY3AGBcCQBBQAQxJQAAxJQAEwJAEFwJAWCqiq+r2quraqrqmqi6rqAVMVBsBq23VAVdUpSX43yVp3PzHJCUnOnaowAFbbopf4DiT5kao6kOSBSf5z8ZIAYIGA6u5bk/xpkq8kuS3JN7v7Y1MVBsBqW+QS30lJzklyepJHJTmxql48p93BqjpcVYd3XyYAq2aRS3zPSvIf3f317v5ekvcnedrmRt19qLvXunttgXUBsGIWCaivJHlKVT2wqirJWUmun6YsAFbdIvegLk9ySZIrklw9W9ahieoCYMVV93QjWx9zZVV7tzK2YTTz3VidLYU9d2TebSAjSQAwJAEFwJAEFABDElAADElAATCkA8sugGWYrj/a2D3bpu40OvbWwn2NMygAhiSgABiSgAJgSAIKgCEJKACGJKAAGJKAAmBIAgqAIQkoAIYkoAAYkoACYEgCCoAhCSgAhiSgABiSgAJgSAIKgCEJKACGJKAAGJJHvsMOTf0AedhOLbuAATiDAmBIAgqAIQkoAIYkoAAYkoACYEjHDKiqurCq7qiqazZMe1hVXVZVN8x+n3R8ywRg1ezkDOpdSc7eNO2CJB/v7scn+fjsPQBM5pgB1d2fTHLnpsnnJHn37PW7k7xg4roAWHG7vQf18O6+bfb6a0kePlE9AJBkgpEkururass/sq+qg0kOLroeAFbLbs+gbq+qRybJ7PcdWzXs7kPdvdbda7tcFwAraLcBdWmS82avz0vywWnKAYB1O+lmflGSTyU5o6puqarzk/xxkl+qqhuSPGv2HgAmU917N0bzdveqYHrTHm5tfGn20IodbUfm3QYykgQAQxJQAAxJQAEwJAEFwJAEFABDWngkCRjXtP2gVqxXFbsyXc/RVeo1utWWOoMCYEgCCoAhCSgAhiSgABiSgAJgSAIKgCEJKACGJKAAGJKAAmBIAgqAIQkoAIYkoAAYkoACYEgCCoAhCSgAhiSgABiSgAJgSAIKgCF55DvAZKZ7TPvYD3yf7tH26+ZvrTMoAIYkoAAYkoACYEgCCoAhCSgAhnTMgKqqC6vqjqq6ZsO0N1fV56vqc1X1gap66PEtE4BVs5MzqHclOXvTtMuSPLG7fyrJF5P8/sR1AbDijhlQ3f3JJHdumvax7r5r9vbTSU49DrUBsMKmuAf18iQfmWA5AHC3hUaSqKrXJ7kryXu2aXMwycFF1gPA6tl1QFXVS5M8L8lZ3b3luBfdfSjJodlnph4fA4D7qF0FVFWdneR1SX6hu78zbUkAsLNu5hcl+VSSM6rqlqo6P8nfJHlwksuq6sqqevtxrhOAFVPbXJ2bfmUu8QHcB0w+mvmR7l7bPNVIEgAMSUABMCQBBcCQBBQAQxJQAAxJQAEwJAEFwJAEFABDElAADElAATAkAQXAkAQUAEMSUAAMSUABMCQBBcCQBBQAQxJQAAxJQAEwJAEFwJAEFABDElAADElAATAkAQXAkAQUAEMSUAAMSUABMCQBBcCQBBQAQxJQAAxJQAEwpGMGVFVdWFV3VNU1c+a9pqq6qk4+PuUBsKp2cgb1riRnb55YVacleXaSr0xcEwAcO6C6+5NJ7pwz68+TvC5JT10UAOzqHlRVnZPk1u6+auJ6ACBJcuDefqCqHpjkD7J+eW8n7Q8mOXhv1wPAatvNGdTjkpye5KqqujnJqUmuqKpHzGvc3Ye6e62713ZfJgCr5l6fQXX31Ul+4uj7WUitdfc3JqwLgBW3k27mFyX5VJIzquqWqjr/+JcFwKqr7r3rhFdVevwB7HtTf5XXkXm3gYwkAcCQBBQAQxJQAAxJQAEwJAEFwJDu9d9BLegbSb68g3Ynz9ruZ/t9G/Z7/YltGIVtWL6J66/pFrXu0XPXspfdzHeqqg7v95En9vs27Pf6E9swCtuwfPu1fpf4ABiSgAJgSKMG1KFlFzCB/b4N+73+xDaMwjYs376sf8h7UAAw6hkUACtuaQFVVWdX1Req6saqumDO/B+uqotn8y+vqsfsfZVbq6rTquoTVXVdVV1bVa+a0+aZVfXNqrpy9vOGZdS6naq6uaquntV3eM78qqq/mu2Hz1XVmcuocytVdcaG/75XVtW3qurVm9oMtx+q6sKquqOqrtkw7WFVdVlV3TD7fdIWnz1v1uaGqjpv76q+Rx3ztuHNVfX52bHygap66Baf3fa42ytbbMMbq+rWDcfLc7f47LbfYXthi/ov3lD7zVV15RafHWIfbKu79/wnyQlJvpTksUnun+SqJE/Y1Oa3krx99vrcJBcvo9ZttuGRSc6cvX5wki/O2YZnJvnQsms9xnbcnOTkbeY/N8lHsv6HD09Jcvmyaz7GcfW1JI8efT8keUaSM5Ncs2HanyS5YPb6giRvmvO5hyW5afb7pNnrkwbahmcnOTB7/aZ527CT427J2/DGJK/dwbG27XfYsurfNP/Pkrxh5H2w3c+yzqCenOTG7r6pu7+b5L1JztnU5pwk7569viTJWVU1+V+H7VZ339bdV8xefzvJ9UlOWW5Vx8U5Sf6+1306yUOr6pHLLmoLZyX5Unfv5I/Bl6q7P5nkzk2TNx7z707ygjkf/eUkl3X3nd3930kuS3L2cSt0G/O2obs/1t13zd5+OutP3B7WFvthJ3byHXbcbVf/7Pvy15NctKdFTWhZAXVKkq9ueH9L7vnlfneb2QH/zSQ/tifV3Uuzy49PSnL5nNlPraqrquojVfWTe1rYznSSj1XVkao6OGf+TvbVKM7N1v8YR98PSfLw7r5t9vprSR4+p81+2h8vz/rZ9zzHOu6W7ZWzy5QXbnGpdT/sh59Pcnt337DF/NH3gU4Si6qqByV5X5JXd/e3Ns2+IuuXm346yV8n+ae9rm8Hnt7dZyZ5TpLfrqpnLLug3aiq+yd5fpJ/nDN7P+yHH9Dr12D2bRfbqnp9kruSvGeLJiMfd29L8rgkP5PktqxfJtuPXpTtz55G3gdJlhdQtyY5bcP7U2fT5rapqgNJHpLkv/akuh2qqvtlPZze093v3zy/u7/V3f8ze/3hJPerqpP3uMxtdfets993JPlA1i9dbLSTfTWC5yS5ortv3zxjP+yHmduPXj6d/b5jTpvh90dVvTTJ85L8xixo72EHx93SdPft3f1/3f39JH+b+bUNvR9m35m/luTirdqMvA+OWlZAfSbJ46vq9Nn/+Z6b5NJNbS5NcrSH0guT/OtWB/syzK7vvjPJ9d39li3aPOLofbOqenLW/3sPE7JVdWJVPfjo66zf4L5mU7NLk/zmrDffU5J8c8NlqJFs+X+Lo++HDTYe8+cl+eCcNh9N8uyqOml26enZs2lDqKqzk7wuyfO7+ztbtNnJcbc0m+6x/mrm17aT77BlelaSz3f3LfNmjr4P7ras3hlZ7x32xaz3hHn9bNofZv3ATpIHZP1yzY1J/j3JY5fdo2RT/U/P+iWYzyW5cvbz3CSvSPKKWZtXJrk26z18Pp3kacuue9M2PHZW21WzOo/uh43bUEneOttPVydZW3bdc7bjxKwHzkM2TBt6P2Q9TG9L8r2s3784P+v3WD+e5IYk/5LkYbO2a0neseGzL5/9u7gxycsG24Ybs35v5ui/iaM9cR+V5MPbHXcDbcM/zI71z2U9dB65eRtm7+/xHTZC/bPp7zp6/G9oO+Q+2O7HSBIADEknCQCGJKAAGJKAAmBIAgqAIQkoAIYkoAAYkoACYEgCCoAh/T8PaKp5nr2t4AAAAABJRU5ErkJggg==\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "batch, N, M = 3, 15, 20\n", + "def show_deps(tree):\n", + " plt.imshow(tree.detach())\n", + "\n", + "# batch, N, z_n, z_n_1\n", + "log_potentials = torch.rand(batch, N, M, 3)\n", + "dist = torch_struct.AlignmentCRF(log_potentials)\n", + "show_deps(dist.argmax[0])" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n" + ] + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAagAAAFECAYAAAByNKo5AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAUiElEQVR4nO3dfYxld13H8fd39rG7rXRpoUBbaTHYBI1IMyHgAxqLWJBQNca08aEIycYoCgbTFEmE+JfPz0azQgW1KUQEbQwIBTXEBCrb2pY+AC1Yy5a2W1jYlqW783C//nHvbobpnd3p3O/OfIf7fiWTuffcM9/7PXPOnc+cc889v8hMJEnqZmajG5AkaRwDSpLUkgElSWrJgJIktWRASZJaMqAkSS1tXc8ni3Mjuaiu3oW31NU6VFcKgGOFtRYLawH4wQJJzXw5M5+xfOK6BhQXAfvryl0TdbWurysFwP8W1jpcWAtgobDWoLAW1IanQSxtGv83bqKH+CRJLRlQkqSWDChJUksGlCSppYkCKiIuj4jPRsR9EXFtVVOSJK05oCJiC/CXwCuBFwBXRcQLqhqTJE23SfagXgzcl5lfyMw54D3AFTVtSZKm3SQBdT7wxSX3D4ymSZI0sdN+kkRE7I2I/RGxn0dP97NJkr5VTBJQDwIXLrl/wWjaN8nMfZk5m5mzPOlCFpIkjTdJQH0KeH5EXBwR24ErgRtr2pIkTbs1X4svMxci4g3Ah4EtwHWZeVdZZ5KkqTbRxWIz84PAB4t6kSTpBK8kIUlqyYCSJLVkQEmSWjKgJEktReb6jTsacXbCD5bVe+Fzj5TVetaxz5bVAsivfrWs1t2DygHk4evzdeu8tjNYLBwHd7Gs0umpJ+mEWzJzdvlE96AkSS0ZUJKklgwoSVJLBpQkqSUDSpLUkgElSWrJgJIktWRASZJaMqAkSS0ZUJKklgwoSVJLBpQkqSUDSpLUkgElSWrJgJIktWRASZJaMqAkSS0ZUJKklgwoSVJLW9f36Y4Anyqr9sCXnlVWa/u5u8tqAZyx45yyWpcMvlhWC+DI4GhZrUOD+bJaAIfIslp1Szk0l4OyWotllYbqOhuqWwvS2rkHJUlqyYCSJLVkQEmSWjKgJEktGVCSpJbWHFARcWFE/EdE3B0Rd0XEGysbkyRNt0lOM18A3pyZt0bEWcAtEXFTZt5d1JskaYqteQ8qMx/KzFtHtx8H7gHOr2pMkjTdSt6DioiLgBcBN1fUkyRp4itJRMSZwD8Bb8rMx8Y8vhfYO7znORmSpNWZKKAiYhvDcLo+M98/bp7M3AfsG86/zSuoSJJWZZKz+AJ4J3BPZv5RXUuSJE12zO37gZ8HfiQibht9vaqoL0nSlFvzIb7M/C8gCnuRJOkEz1qQJLVkQEmSWjKgJEktGVCSpJYMKElSSxNfSeKpWQC+Ulbt8OKTLlyxZnd/pfYyghfsPlpW6+wtZ5fVAtgxf6Ss1pkzC2W1AL6+UPd7Y3FQVwvYQt2yLmRtb/PUfga+cq1WfzrfT/tPD/egJEktGVCSpJYMKElSSwaUJKklA0qS1JIBJUlqyYCSJLVkQEmSWjKgJEktGVCSpJYMKElSSwaUJKklA0qS1JIBJUlqyYCSJLVkQEmSWjKgJEktGVCSpJbWech3gMWySoNB3eDPT8w9WFYL4H6eWVZrd2wpqwVwVuwqq3UmZ5TVArh46yNlteYW58tqARwprPVw1L0OAAYRpfUY1A36XrukMCgc9N3h43tzD0qS1JIBJUlqyYCSJLVkQEmSWjKgJEktTRxQEbElIv4nIv61oiFJkqBmD+qNwD0FdSRJOmGigIqIC4AfB95R044kSUOT7kH9CXANMCjoRZKkE9YcUBHxauBgZt5yivn2RsT+iNi/1ueSJE2fSfagvh94TUTcD7wH+JGI+IflM2XmvsyczczZCZ5LkjRl1hxQmfmWzLwgMy8CrgT+PTN/rqwzSdJU83NQkqSWSq5mnpn/CfxnRS1JksA9KElSUwaUJKklA0qS1JIBJUlqyYCSJLVUchbfU5OFtRbLKg04VlYLYH7hibJaMzFfVgvg2M66Wjtzrq4Y8HjuKKs1v31LWS2AweJZZbWes+VwWS2ALy9Uvq7gaOHVy45mbW/zhfWqr9FWu6RyD0qS1JIBJUlqyYCSJLVkQEmSWjKgJEktGVCSpJYMKElSSwaUJKklA0qS1JIBJUlqyYCSJLVkQEmSWjKgJEktGVCSpJYMKElSSwaUJKklA0qS1JIBJUlqaQOGfO+qdvDnweBrZbW+GnXDoANsPbZQVuuxrN2Etm7dVlZr29ba39uehcfrim3ZU1cLOK94CPmFxTPKan1tbr6sFsBh6upVD0e/WDjou8PHuwclSWrKgJIktWRASZJaMqAkSS0ZUJKkliYKqIg4OyLeFxGfiYh7IuKlVY1JkqbbpOcI/ynwb5n50xGxHdhV0JMkSWsPqIh4GvAy4LUAmTkHzNW0JUmadpMc4rsYeBT424j4n4h4R0TsLupLkjTlJgmorcClwF9l5ouAI8C1y2eKiL0RsT8i9k/wXJKkKTNJQB0ADmTmzaP772MYWN8kM/dl5mxmzk7wXJKkKbPmgMrMh4EvRsQlo0mXAXeXdCVJmnqTnsX3q8D1ozP4vgD84uQtSZI0YUBl5m2Ah+4kSeW8koQkqSUDSpLUkgElSWrJgJIktWRASZJamvQ0c61ovqzSgCirBRCDuv9LBgzKagEwv1BW6mjW1QL4UhYu69G67QNg17ba6zTvWdxRVuvbtn+jrBbAzoUnymo9NjhWVgvgicyyWkcrtzeg9tWwPtyDkiS1ZEBJkloyoCRJLRlQkqSWDChJUksGlCSpJQNKktSSASVJasmAkiS1ZEBJkloyoCRJLRlQkqSWDChJUksGlCSpJQNKktSSASVJasmAkiS1ZEBJklpyyPfTpm7o58y5sloAx9heVmvHTN3w2wAzUTfU+O6txZv3Yt06ZaZ2GPSZ+cXSegejrt7MzBlltQDOjLrtdwuHy2oBzGypq7V9sfZ1P8i6Qd9rB6NfmXtQkqSWDChJUksGlCSpJQNKktSSASVJammigIqIX4+IuyLizoi4ISJ2VjUmSZpuaw6oiDgf+DVgNjO/G9gCXFnVmCRpuk16iG8rcEZEbAV2AV+avCVJkiYIqMx8EPgD4AHgIeBwZn6kqjFJ0nSb5BDfHuAK4GLgOcDuiPi5MfPtjYj9EbF/7W1KkqbNJIf4Xg78b2Y+mpnzwPuB71s+U2buy8zZzJyd4LkkSVNmkoB6AHhJROyKiAAuA+6paUuSNO0meQ/qZuB9wK3Ap0e19hX1JUmachNd7jkz3wa8ragXSZJO8EoSkqSWDChJUksGlCSpJQNKktSSASVJammis/i0XrK2Ws6V1To2qP0fZyEWymodXXhGWS2As+NIWa2sXaVUDyTwtJn5slq5uK2sFsD8oK7W3K7iARiO1W2/OVP72poZfKOsVmbhSmDlv3DuQUmSWjKgJEktGVCSpJYMKElSSwaUJKklA0qS1JIBJUlqyYCSJLVkQEmSWjKgJEktGVCSpJYMKElSSwaUJKklA0qS1JIBJUlqyYCSJLVkQEmSWjKgJEktOeT7VKobbzxzsawWwGJGWa2ZucNltQCObqv7f+7MQe2Q2duoG6IdoG4twI7iocvP2VH3u9u1eGZZLYCdgyfKat1B7Wvr4ThWWq/SwgpDyLsHJUlqyYCSJLVkQEmSWjKgJEktGVCSpJZOGVARcV1EHIyIO5dMe3pE3BQR946+7zm9bUqSps1q9qDeBVy+bNq1wMcy8/nAx0b3JUkqc8qAysyPA4eWTb4CePfo9ruBnyjuS5I05db6HtR5mfnQ6PbDwHlF/UiSBBRcSSIzMyJWvDRBROwF9k76PJKk6bLWPahHIuLZAKPvB1eaMTP3ZeZsZs6u8bkkSVNorQF1I3D16PbVwL/UtCNJ0tBqTjO/AfgEcElEHIiI1wO/A/xoRNwLvHx0X5KkMqd8Dyozr1rhocuKe5Ek6QSvJCFJasmAkiS1ZEBJkloyoCRJLRlQkqSWJr6ShFQpWSirNceRsloACwtbymo9QZTVAtg6U/tS3jkYlNXaNnisrBbAIc4oq7WDo2W1AHLrt5fVeuZZ3yirBXDO146V1Xpgse51CvAo43tzD0qS1JIBJUlqyYCSJLVkQEmSWjKgJEktGVCSpJYMKElSSwaUJKklA0qS1JIBJUlqyYCSJLVkQEmSWjKgJEktGVCSpJYMKElSSwaUJKklA0qS1JIBJUlqySHf9S0sS6sNsm6Y62PFQ77PFw7RDjBX2N/W2FlWC+AsjpTVemJ73fDxALvi0bJa84+dU1YLIHZcWFZr91ztcPSPzj8wdrp7UJKklgwoSVJLBpQkqSUDSpLUkgElSWrplAEVEddFxMGIuHPJtN+PiM9ExB0R8YGIOPv0tilJmjar2YN6F3D5smk3Ad+dmd8DfA54S3FfkqQpd8qAysyPA4eWTftI5okPhXwSuOA09CZJmmIV70G9DvhQQR1Jkk6Y6EoSEfFWYAG4/iTz7AX2TvI8kqTps+aAiojXAq8GLsvMFa8pk5n7gH2jn6m99owk6VvWmgIqIi4HrgF+KDNrL8okSRKrO838BuATwCURcSAiXg/8BXAWcFNE3BYRf32a+5QkTZlT7kFl5lVjJr/zNPQiSdIJXklCktSSASVJasmAkiS1ZEBJkloyoCRJLU10JQlJa1X7mfVBLtbWI8pqLeTjZbUA5mN7Wa0di7Uf4zw2c7Ss1tGdu8tqAXx77imr9Z1bai+/ev/8A2OnuwclSWrJgJIktWRASZJaMqAkSS0ZUJKklgwoSVJLBpQkqSUDSpLUkgElSWrJgJIktWRASZJaMqAkSS0ZUJKklgwoSVJLBpQkqSUDSpLUkgElSWrJgJIkteSQ75LGqBuSPouHt1/IubJaubCjrBbA7pm63lg4UFcLePyMw2W1nnPOW8pqAfDgR8dOdg9KktSSASVJasmAkiS1ZEBJkloyoCRJLZ0yoCLiuog4GBF3jnnszRGREXHu6WlPkjStVrMH9S7g8uUTI+JC4BXAA8U9SZJ06oDKzI8Dh8Y89MfANVR+YEKSpJE1vQcVEVcAD2bm7cX9SJIErOFKEhGxC/hNhof3VjP/XmDvU30eSdJ0W8se1HcAFwO3R8T9wAXArRHxrHEzZ+a+zJzNzNm1tylJmjZPeQ8qMz8NPPP4/VFIzWbmlwv7kiRNudWcZn4D8Angkog4EBGvP/1tSZKm3Sn3oDLzqlM8flFZN5IkjXglCUlSSwaUJKklA0qS1JIBJUlqyYCSJLUUmet3Kb2IeBT4v1XMei6w2T9XtdmXYbP3Dy5DFy7Dxuve/3Mz8xnLJ65rQK1WROzf7Fee2OzLsNn7B5ehC5dh423W/j3EJ0lqyYCSJLXUNaD2bXQDBTb7Mmz2/sFl6MJl2Hibsv+W70FJktR1D0qSNOU2LKAi4vKI+GxE3BcR1455fEdEvHf0+M0RcdH6d7myiLgwIv4jIu6OiLsi4o1j5vnhiDgcEbeNvn5rI3o9mYi4PyI+Pepv/5jHIyL+bLQe7oiISzeiz5VExCVLfr+3RcRjEfGmZfO0Ww8RcV1EHIyIO5dMe3pE3BQR946+71nhZ68ezXNvRFy9fl0/qY9xy/D7EfGZ0bbygYg4e4WfPel2t15WWIa3R8SDS7aXV63wsyf9G7YeVuj/vUt6vz8iblvhZ1usg5PKzHX/ArYAnweeB2wHbgdesGyeXwb+enT7SuC9G9HrSZbh2cClo9tnAZ8bsww/DPzrRvd6iuW4Hzj3JI+/CvgQEMBLgJs3uudTbFcPM/xMRev1ALwMuBS4c8m03wOuHd2+FvjdMT/3dOALo+97Rrf3NFqGVwBbR7d/d9wyrGa72+BleDvwG6vY1k76N2yj+l/2+B8Cv9V5HZzsa6P2oF4M3JeZX8jMOeA9wBXL5rkCePfo9vuAyyIi1rHHk8rMhzLz1tHtx4F7gPM3tqvT4grg73Lok8DZEfHsjW5qBZcBn8/M1XwYfENl5seBQ8smL93m3w38xJgf/THgpsw8lJlfBW4CLj9tjZ7EuGXIzI9k5sLo7icZjrjd1grrYTVW8zfstDtZ/6O/lz8D3LCuTRXaqIA6H/jikvsHePIf9xPzjDb4w8A569LdUzQ6/Pgi4OYxD780Im6PiA9FxHeta2Ork8BHIuKWiNg75vHVrKsurmTlF2P39QBwXmY+NLr9MHDemHk20/p4HcO973FOtd1ttDeMDlNet8Kh1s2wHn4QeCQz713h8e7rwJMkJhURZwL/BLwpMx9b9vCtDA83vRD4c+Cf17u/VfiBzLwUeCXwKxHxso1uaC0iYjvwGuAfxzy8GdbDN8nhMZhNe4ptRLwVWACuX2GWztvdXwHfAXwv8BDDw2Sb0VWcfO+p8zoANi6gHgQuXHL/gtG0sfNExFbgacBX1qW7VYqIbQzD6frMfP/yxzPzscz8+uj2B4FtEXHuOrd5Upn54Oj7QeADDA9dLLWaddXBK4FbM/OR5Q9shvUw8sjxw6ej7wfHzNN+fUTEa4FXAz87CtonWcV2t2Ey85HMXMzMAfA3jO+t9XoY/c38KeC9K83TeR0ct1EB9Sng+RFx8eg/3yuBG5fNcyNw/Aylnwb+faWNfSOMju++E7gnM/9ohXmedfx9s4h4McPfd5uQjYjdEXHW8dsM3+C+c9lsNwK/MDqb7yXA4SWHoTpZ8b/F7uthiaXb/NXAv4yZ58PAKyJiz+jQ0ytG01qIiMuBa4DXZOY3VphnNdvdhln2HutPMr631fwN20gvBz6TmQfGPdh9HZywUWdnMDw77HMMz4R562jabzPcsAF2Mjxccx/w38DzNvqMkmX9/wDDQzB3ALeNvl4F/BLwS6N53gDcxfAMn08C37fRfS9bhueNert91Ofx9bB0GQL4y9F6+jQwu9F9j1mO3QwD52lLprVeDwzD9CFgnuH7F69n+B7rx4B7gY8CTx/NOwu8Y8nPvm70urgP+MVmy3Afw/dmjr8mjp+J+xzggyfb7hotw9+PtvU7GIbOs5cvw+j+k/6Gdeh/NP1dx7f/JfO2XAcn+/JKEpKkljxJQpLUkgElSWrJgJIktWRASZJaMqAkSS0ZUJKklgwoSVJLBpQkqaX/Bwv7y52fgQ5UAAAAAElFTkSuQmCC\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "show_deps(dist.marginals[0])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Dependency Trees" + ] + }, + { + "cell_type": "raw", + "metadata": { + "raw_mimetype": "text/restructuredtext" + }, + "source": [ + ".. autoclass:: torch_struct.DependencyCRF" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAZgAAAGbCAYAAAD5r4b7AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAANaUlEQVR4nO3dz6vdB5nH8c8zuWlr6ow6zGySlGkXjkORmVQutVpw0QjRUexmFhUUxk02o1YRpM7Gf0BEFyKEqhuLXcQuRIpx8MdiNsG0DWgbhVK1P1KxwzBWKtO0+Mwid6DGxHvuzH1yzrl5vaDQe87p4cOX3L7z/Z5zz63uDgDstj9b9gAA9iaBAWCEwAAwQmAAGCEwAIzYmHjS6+r6viE3Tjw1ACvkv/NSLvTLdbn7RgJzQ27M2+voxFMDsEJO9/eueJ9LZACMEBgARggMACMEBoARAgPACIEBYITAADBCYAAYITAAjBAYAEYIDAAjBAaAEQIDwAiBAWDEQoGpqvdU1c+q6smqum96FADrb9vAVNW+JF9K8t4ktyb5YFXdOj0MgPW2yBnM7Ume7O6nuvtCkgeT3D07C4B1t0hgDiV55jVfP7t12x+oquNVdaaqzrySl3drHwBratde5O/uE9292d2b+3P9bj0tAGtqkcA8l+Sm13x9eOs2ALiiRQLzoyRvrqpbquq6JPck+dbsLADW3cZ2D+juV6vqo0lOJdmX5Kvd/fj4MgDW2raBSZLufjjJw8NbANhD/CQ/ACMEBoARAgPACIEBYITAADBCYAAYITAAjBAYAEYIDAAjBAaAEQIDwAiBAWCEwAAwYqFPU2bWqfNnlz3hso4dPLLsCZfleMF6cAYDwAiBAWCEwAAwQmAAGCEwAIwQGABGCAwAIwQGgBECA8AIgQFghMAAMEJgABghMACMEBgARggMACMEBoARAgPACIEBYITAADBCYAAYITAAjBAYAEYIDAAjBAaAEQIDwAiBAWCEwAAwQmAAGCEwAIwQGABGCAwAIwQGgBECA8AIgQFghMAAMEJgABghMACM2Fj2AJJjB48se8JlnTp/dtkTLmtVjxfwh5zBADBCYAAYITAAjBAYAEYIDAAjBAaAEQIDwAiBAWCEwAAwQmAAGCEwAIwQGABGCAwAIwQGgBECA8CIbQNTVTdV1Q+q6omqeryq7r0awwBYb4v8wrFXk3yqux+tqj9P8khV/Vt3PzG8DYA1tu0ZTHc/392Pbv37b5OcS3JoehgA621HvzK5qm5OcluS05e573iS40lyQw7swjQA1tnCL/JX1euTfDPJJ7r7xUvv7+4T3b3Z3Zv7c/1ubgRgDS0UmKran4txeaC7H5qdBMBesMi7yCrJV5Kc6+7Pz08CYC9Y5AzmziQfTnJXVZ3d+ucfh3cBsOa2fZG/u/89SV2FLQDsIX6SH4ARAgPACIEBYITAADBCYAAYITAAjBAYAEYIDAAjBAaAEQIDwAiBAWCEwAAwQmAAGLGjX5nMteXYwSPLngCsMWcwAIwQGABGCAwAIwQGgBECA8AIgQFghMAAMEJgABghMACMEBgARggMACMEBoARAgPACIEBYITAADBCYAAYITAAjBAYAEYIDAAjBAaAEQIDwAiBAWCEwAAwQmAAGCEwAIwQGABGCAwAIwQGgBECA8AIgQFghMAAMEJgABghMACMEBgARggMACMEBoARAgPACIEBYITAADBCYAAYITAAjBAYAEYIDAAjBAaAEQIDwAiBAWCEwAAwQmAAGCEwAIwQGABGCAwAIwQGgBECA8CIhQNTVfuq6rGq+vbkIAD2hp2cwdyb5NzUEAD2loUCU1WHk7wvyf2zcwDYKxY9g/lCkk8n+f3gFgD2kG0DU1XvT/Lr7n5km8cdr6ozVXXmlby8awMBWE+LnMHcmeQDVfWLJA8muauqvn7pg7r7RHdvdvfm/ly/yzMBWDfbBqa7P9Pdh7v75iT3JPl+d39ofBkAa83PwQAwYmMnD+7uHyb54cgSAPYUZzAAjBAYAEYIDAAjBAaAEQIDwAiBAWCEwAAwQmAAGCEwAIwQGABGCAwAIwQGgBECA8CIHX2a8ro7df7ssidc1rGDR5Y9AWDXOYMBYITAADBCYAAYITAAjBAYAEYIDAAjBAaAEQIDwAiBAWCEwAAwQmAAGCEwAIwQGABGCAwAIwQGgBECA8AIgQFghMAAMEJgABghMACMEBgARggMACMEBoARAgPACIEBYITAADBCYAAYITAAjBAYAEYIDAAjBAaAEQIDwAiBAWCEwAAwQmAAGCEwAIwQGABGCAwAIzaWPeBqOnbwyLInAFwznMEAMEJgABghMACMEBgARggMACMEBoARAgPACIEBYITAADBCYAAYITAAjBAYAEYIDAAjBAaAEQIDwIiFAlNVb6yqk1X106o6V1XvmB4GwHpb9BeOfTHJd7r7n6rquiQHBjcBsAdsG5iqekOSdyX55yTp7gtJLszOAmDdLXKJ7JYkLyT5WlU9VlX3V9WNlz6oqo5X1ZmqOvNKXt71oQCsl0UCs5HkbUm+3N23JXkpyX2XPqi7T3T3Zndv7s/1uzwTgHWzSGCeTfJsd5/e+vpkLgYHAK5o28B096+SPFNVb9m66WiSJ0ZXAbD2Fn0X2ceSPLD1DrKnknxkbhIAe8FCgenus0k2h7cAsIf4SX4ARggMACMEBoARAgPACIEBYITAADBCYAAYITAAjBAYAEYIDAAjBAaAEQIDwAiBAWDEoh/XvyN/+/e/y6lTZyee+v/l2MEjy57AHnbq/Or9mV91vif3NmcwAIwQGABGCAwAIwQGgBECA8AIgQFghMAAMEJgABghMACMEBgARggMACMEBoARAgPACIEBYITAADBCYAAYITAAjBAYAEYIDAAjBAaAEQIDwAiBAWCEwAAwQmAAGCEwAIwQGABGCAwAIwQGgBECA8AIgQFghMAAMEJgABghMACMEBgARggMACMEBoARAgPAiOruXX/Sv6i/7LfX0V1/XmBvOXX+7LInrJVjB48se8IfOd3fy4v9n3W5+5zBADBCYAAYITAAjBAYAEYIDAAjBAaAEQIDwAiBAWCEwAAwQmAAGCEwAIwQGABGCAwAIwQGgBECA8CIhQJTVZ+sqser6idV9Y2qumF6GADrbdvAVNWhJB9Pstndb02yL8k908MAWG+LXiLbSPK6qtpIciDJ+blJAOwF2wamu59L8rkkTyd5Pslvuvu7lz6uqo5X1ZmqOvNKXt79pQCslUUukb0pyd1JbklyMMmNVfWhSx/X3Se6e7O7N/fn+t1fCsBaWeQS2buT/Ly7X+juV5I8lOSds7MAWHeLBObpJHdU1YGqqiRHk5ybnQXAulvkNZjTSU4meTTJj7f+mxPDuwBYcxuLPKi7P5vks8NbANhD/CQ/ACMEBoARAgPACIEBYITAADBCYAAYITAAjBAYAEYIDAAjBAaAEQIDwAiBAWCEwAAwYqFPUwaYcOzgkWVPWCunzp9d9oQ/cvux313xPmcwAIwQGABGCAwAIwQGgBECA8AIgQFghMAAMEJgABghMACMEBgARggMACMEBoARAgPACIEBYITAADBCYAAYITAAjBAYAEYIDAAjBAaAEQIDwAiBAWCEwAAwQmAAGCEwAIwQGABGCAwAIwQGgBECA8AIgQFghMAAMEJgABghMACMEBgARggMACMEBoARAgPAiOru3X/SqheS/HKXnu6vkvzHLj3XtcDx2hnHa2ccr525Fo7X33T3X1/ujpHA7KaqOtPdm8vesS4cr51xvHbG8dqZa/14uUQGwAiBAWDEOgTmxLIHrBnHa2ccr51xvHbmmj5eK/8aDADraR3OYABYQwIDwIiVDUxVvaeqflZVT1bVfcves8qq6qaq+kFVPVFVj1fVvcvetA6qal9VPVZV3172llVXVW+sqpNV9dOqOldV71j2plVXVZ/c+n78SVV9o6puWPamq20lA1NV+5J8Kcl7k9ya5INVdetyV620V5N8qrtvTXJHkn9xvBZyb5Jzyx6xJr6Y5Dvd/XdJ/iGO259UVYeSfDzJZne/Ncm+JPcsd9XVt5KBSXJ7kie7+6nuvpDkwSR3L3nTyuru57v70a1//20ufvMfWu6q1VZVh5O8L8n9y96y6qrqDUneleQrSdLdF7r7v5a7ai1sJHldVW0kOZDk/JL3XHWrGphDSZ55zdfPxv8wF1JVNye5Lcnp5S5ZeV9I8ukkv1/2kDVwS5IXknxt65Li/VV147JHrbLufi7J55I8neT5JL/p7u8ud9XVt6qB4f+gql6f5JtJPtHdLy57z6qqqvcn+XV3P7LsLWtiI8nbkny5u29L8lISr4v+CVX1ply86nJLkoNJbqyqDy131dW3qoF5LslNr/n68NZtXEFV7c/FuDzQ3Q8te8+KuzPJB6rqF7l4+fWuqvr6ciettGeTPNvd/3tWfDIXg8OVvTvJz7v7he5+JclDSd655E1X3aoG5kdJ3lxVt1TVdbn44ti3lrxpZVVV5eL18XPd/fll71l13f2Z7j7c3Tfn4p+t73f3Nfe3y0V196+SPFNVb9m66WiSJ5Y4aR08neSOqjqw9f15NNfgGyM2lj3gcrr71ar6aJJTufjui6929+NLnrXK7kzy4SQ/rqqzW7f9a3c/vMRN7C0fS/LA1l/4nkrykSXvWWndfbqqTiZ5NBff5flYrsGPjfFRMQCMWNVLZACsOYEBYITAADBCYAAYITAAjBAYAEYIDAAj/gez/SIJ4nTCCgAAAABJRU5ErkJggg==\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "batch, N, N = 3, 10, 10\n", + "def show_deps(tree):\n", + " plt.imshow(tree.detach())\n", + "\n", + "# batch, N, z_n, z_n_1\n", + "log_potentials = torch.rand(batch, N, N)\n", + "dist = torch_struct.DependencyCRF(log_potentials)\n", + "show_deps(dist.argmax[0])" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAZgAAAGbCAYAAAD5r4b7AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAP/ElEQVR4nO3dX4jld3nH8eeZmd1uNolJrFXqbuIuNbUGS4ksQQ14YbTVKnpT2ghK681CqRpFkFgKXvZC8Q9FhCUqtQa9iKG1ImpBvWihwTURTLL+iUnMf1yDyWazm2zWeXqxU0izv3HPdOfZ7zmT1wsCO3NOfnzyy5x57++c+ZNVFQCw2ZZGDwBgaxIYAFoIDAAtBAaAFgIDQIuVloO+YGdte/HFHYc+O0eXRy+YtO3J1dETJuWxp0dPmLaUoxesY153RUTN58dYrLR8Cjp7q3N6vn4zf7uOrx6NE/XU5Ad/y//dbS++OPZ8bH/Hoc/K8n9dNHrCpJd8//joCZNWbv3p6AmT8rwdoydMyzl+QuDp+fzLQr7ohaMnTDv+1OgFk1aPPDF6wmn++9jX171tjh8RACwygQGghcAA0EJgAGghMAC0EBgAWggMAC0EBoAWAgNAC4EBoIXAANBCYABoITAAtBAYAFrMFJjMfHNm/iQz78rM67tHAbD4zhiYzFyOiM9ExFsi4oqIeGdmXtE9DIDFNssVzFURcVdV3V1VJyLiKxHxjt5ZACy6WQKzKyLuf9bbD6y97//IzP2ZeTAzD548cmyz9gGwoDbtRf6qOlBV+6pq38oLdm7WYQFYULME5sGIuPRZb+9eex8ArGuWwHw/Ii7PzL2ZuT0iro2Ir/XOAmDRrZzpDlV1MjPfGxHfiojliPh8Vd3RvgyAhXbGwEREVNU3IuIbzVsA2EJ8Jz8ALQQGgBYCA0ALgQGghcAA0EJgAGghMAC0EBgAWggMAC0EBoAWAgNAC4EBoIXAANBipp+mvFHLjy7F737x/I5Dn5Wlk8+MnjDp/jftGD1h0t4nXjZ6wrSjx0cvmFQP/3L0hHXlBfP3eIyIWH3okdETJtVvVkdPmFbzt6t+yyZXMAC0EBgAWggMAC0EBoAWAgNAC4EBoIXAANBCYABoITAAtBAYAFoIDAAtBAaAFgIDQAuBAaCFwADQQmAAaCEwALQQGABaCAwALQQGgBYCA0ALgQGghcAA0EJgAGghMAC0EBgAWggMAC0EBoAWAgNAC4EBoIXAANBCYABoITAAtBAYAFoIDAAtBAaAFgIDQIuVjoMuH3kqzv/OoY5Dn5VH/vqPR0+YtOffHh89YdLP//Ki0RMmvfyTh0dPmLR0ycWjJ6xr9dePjZ4wbdu20QsmLV2wffSESatPHB094XQnc92bXMEA0EJgAGghMAC0EBgAWggMAC0EBoAWAgNAC4EBoIXAANBCYABoITAAtBAYAFoIDAAtBAaAFgIDQIszBiYzL83M72bmnZl5R2Zedy6GAbDYZvmFYycj4kNVdWtmXhgRP8jM/6iqO5u3AbDAzngFU1UPV9Wta39+IiIORcSu7mEALLYN/crkzNwTEVdGxC0Tt+2PiP0RETvy/E2YBsAim/lF/sy8ICK+GhEfqKojz729qg5U1b6q2rc9d2zmRgAW0EyBycxtcSouN1bVzb2TANgKZvkqsoyIz0XEoar6RP8kALaCWa5gro6Id0fEGzLzh2v//HnzLgAW3Blf5K+q/4yIPAdbANhCfCc/AC0EBoAWAgNAC4EBoIXAANBCYABoITAAtBAYAFoIDAAtBAaAFgIDQAuBAaCFwADQYkO/MnlmmZErPYc+G7//5R+PnjDpsTf94egJk17+ybtGT5h07/7LR0+YtPeffzF6wuJ55pnRCyatHj8+esKkOnly9ITTVa17kysYAFoIDAAtBAaAFgIDQAuBAaCFwADQQmAAaCEwALQQGABaCAwALQQGgBYCA0ALgQGghcAA0EJgAGghMAC0EBgAWggMAC0EBoAWAgNAC4EBoIXAANBCYABoITAAtBAYAFoIDAAtBAaAFgIDQAuBAaCFwADQQmAAaCEwALQQGABaCAwALQQGgBYCA0ALgQGgxUrHQWt1NVaPHes49FnJlZb/3LN2yS0Pj54w6dE//YPREybt/eJ9oydMOnLV7tET1nXhj35n9IRphx8dvWBSHT05esK0zNELTlfr3+QKBoAWAgNAC4EBoIXAANBCYABoITAAtBAYAFoIDAAtBAaAFgIDQAuBAaCFwADQQmAAaCEwALQQGABazByYzFzOzNsy8+udgwDYGjZyBXNdRBzqGgLA1jJTYDJzd0S8NSJu6J0DwFYx6xXMpyLiwxGx2rgFgC3kjIHJzLdFxC+r6gdnuN/+zDyYmQefqac2bSAAi2mWK5irI+LtmXlvRHwlIt6QmV967p2q6kBV7auqfdtyxybPBGDRnDEwVfWRqtpdVXsi4tqI+E5Vvat9GQALzffBANBiZSN3rqrvRcT3WpYAsKW4ggGghcAA0EJgAGghMAC0EBgAWggMAC0EBoAWAgNAC4EBoIXAANBCYABoITAAtBAYAFps6KcpzyojIjM7Dn1Wll54yegJk+qJo6MnTLr4J/O5676/umz0hEmXffXB0RPWdewVLx49YdJ5R+bzYyyfenr0hEm5ffvoCafJ4+tfp7iCAaCFwADQQmAAaCEwALQQGABaCAwALQQGgBYCA0ALgQGghcAA0EJgAGghMAC0EBgAWggMAC0EBoAWAgNAC4EBoIXAANBCYABoITAAtBAYAFoIDAAtBAaAFgIDQAuBAaCFwADQQmAAaCEwALQQGABaCAwALQQGgBYCA0ALgQGghcAA0EJgAGghMAC0EBgAWggMAC1WWo6aGbG83HLos7F6+FejJ0zK7dtHT5i0dM9DoydMuuxXR0ZPmPSz/btGT1jX5f90z+gJk5547Z7REyZdcPvh0RMm/ebn946ecJpaXV33NlcwALQQGABaCAwALQQGgBYCA0ALgQGghcAA0EJgAGghMAC0EBgAWggMAC0EBoAWAgNAC4EBoIXAANBipsBk5sWZeVNm/jgzD2Xma7uHAbDYZv2FY5+OiG9W1V9k5vaI2Nm4CYAt4IyBycyLIuL1EfE3ERFVdSIiTvTOAmDRzfIU2d6IOBwRX8jM2zLzhsw8/7l3ysz9mXkwMw+eqKc2fSgAi2WWwKxExKsj4rNVdWVEPBkR1z/3TlV1oKr2VdW+7bljk2cCsGhmCcwDEfFAVd2y9vZNcSo4ALCuMwamqh6JiPsz8xVr77omIu5sXQXAwpv1q8jeFxE3rn0F2d0R8Z6+SQBsBTMFpqp+GBH7mrcAsIX4Tn4AWggMAC0EBoAWAgNAC4EBoIXAANBCYABoITAAtBAYAFoIDAAtBAaAFgIDQAuBAaDFrD+uf2OWl2LpBRe2HPqsLM1nT1d//djoCZOWdp43esKkOvrk6AmT9vz7sdET1nXoH3eNnjDp5QeeHj1h0ondl4yeMGn7Y0dGTzhN/np53dvm8zMuAAtPYABoITAAtBAYAFoIDAAtBAaAFgIDQAuBAaCFwADQQmAAaCEwALQQGABaCAwALQQGgBYCA0ALgQGghcAA0EJgAGghMAC0EBgAWggMAC0EBoAWAgNAC4EBoIXAANBCYABoITAAtBAYAFoIDAAtBAaAFgIDQAuBAaCFwADQQmAAaCEwALQQGABaCAwALVZajloRtbracuizUUefHD1hUj1zcvSESauPPT56wrTl5dELJi3ffvfoCet65fU7R0+YdOgf9oyeMOkFP53Pj7GXHHvp6Amnqdu3r3ubKxgAWggMAC0EBoAWAgNAC4EBoIXAANBCYABoITAAtBAYAFoIDAAtBAaAFgIDQAuBAaCFwADQQmAAaDFTYDLzg5l5R2benplfzswd3cMAWGxnDExm7oqI90fEvqp6VUQsR8S13cMAWGyzPkW2EhHnZeZKROyMiIf6JgGwFZwxMFX1YER8PCLui4iHI+Lxqvr2c++Xmfsz82BmHjyxenzzlwKwUGZ5iuySiHhHROyNiJdGxPmZ+a7n3q+qDlTVvqrat33pvM1fCsBCmeUpsjdGxD1VdbiqnomImyPidb2zAFh0swTmvoh4TWbuzMyMiGsi4lDvLAAW3SyvwdwSETdFxK0R8aO1f+dA8y4AFtzKLHeqqo9GxEebtwCwhfhOfgBaCAwALQQGgBYCA0ALgQGghcAA0EJgAGghMAC0EBgAWggMAC0EBoAWAgNAC4EBoMVMP015o2r7tli97CUdhz4rSz+7b/SEactz2vnl5dELpi3N5/nKlZaH06aoiy4cPWHSKz/20OgJk372t7tHT5j0rX/9l9ETTnPVnz267m3z+UgFYOEJDAAtBAaAFgIDQAuBAaCFwADQQmAAaCEwALQQGABaCAwALQQGgBYCA0ALgQGghcAA0EJgAGghMAC0EBgAWggMAC0EBoAWAgNAC4EBoIXAANBCYABoITAAtBAYAFoIDAAtBAaAFgIDQAuBAaCFwADQQmAAaCEwALQQGABaCAwALQQGgBYCA0ALgQGgRVbV5h8083BE/GKTDveiiPjVJh3r+cD52hjna2Ocr415Ppyvl1XV703d0BKYzZSZB6tq3+gdi8L52hjna2Ocr415vp8vT5EB0EJgAGixCIE5MHrAgnG+Nsb52hjna2Oe1+dr7l+DAWAxLcIVDAALSGAAaDG3gcnMN2fmTzLzrsy8fvSeeZaZl2bmdzPzzsy8IzOvG71pEWTmcmbelplfH71l3mXmxZl5U2b+ODMPZeZrR2+ad5n5wbXH4+2Z+eXM3DF607k2l4HJzOWI+ExEvCUiroiId2bmFWNXzbWTEfGhqroiIl4TEX/nfM3kuog4NHrEgvh0RHyzqv4oIv4knLffKjN3RcT7I2JfVb0qIpYj4tqxq869uQxMRFwVEXdV1d1VdSIivhIR7xi8aW5V1cNVdevan5+IUw/+XWNXzbfM3B0Rb42IG0ZvmXeZeVFEvD4iPhcRUVUnquqxsasWwkpEnJeZKxGxMyIeGrznnJvXwOyKiPuf9fYD4RPmTDJzT0RcGRG3jF0y9z4VER+OiNXRQxbA3og4HBFfWHtK8YbMPH/0qHlWVQ9GxMcj4r6IeDgiHq+qb49dde7Na2D4f8jMCyLiqxHxgao6MnrPvMrMt0XEL6vqB6O3LIiViHh1RHy2qq6MiCcjwuuiv0VmXhKnnnXZGxEvjYjzM/NdY1ede/MamAcj4tJnvb177X2sIzO3xam43FhVN4/eM+eujoi3Z+a9cerp1zdk5pfGTpprD0TEA1X1v1fFN8Wp4LC+N0bEPVV1uKqeiYibI+J1gzedc/MamO9HxOWZuTczt8epF8e+NnjT3MrMjFPPjx+qqk+M3jPvquojVbW7qvbEqY+t71TV8+5vl7Oqqkci4v7MfMXau66JiDsHTloE90XEazJz59rj85p4Hn5hxMroAVOq6mRmvjcivhWnvvri81V1x+BZ8+zqiHh3RPwoM3+49r6/r6pvDNzE1vK+iLhx7S98d0fEewbvmWtVdUtm3hQRt8apr/K8LZ6HPzbGj4oBoMW8PkUGwIITGABaCAwALQQGgBYCA0ALgQGghcAA0OJ/AHVosauhJqG7AAAAAElFTkSuQmCC\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "show_deps(dist.marginals[0])" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAZgAAAGbCAYAAAD5r4b7AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAANn0lEQVR4nO3b34ulB33H8c+3u2tioiJtU0myofHCCiJ1U4b0IlJoim7UoL1U0CthbypEWhC99B8Qb3qzaGiL1iDEgFjrGjQiAU3cxDU1PwxBUkwU1h+IptL88tuLHdsYN50jzPc8c868XjDszJ6Hk8/DbvY9zznPVHcHAPbbHyw9AIDtJDAAjBAYAEYIDAAjBAaAEUcnnvRldUlfmssnnhoOrD/7818tPWFRjz5w2dITWMB/57/yTD9dF3tsJDCX5vL8Zf3NxFPDgXXmzLmlJyzq5FUnlp7AAu7pr7zkY14iA2CEwAAwQmAAGCEwAIwQGABGCAwAIwQGgBECA8AIgQFghMAAMEJgABghMACMEBgARggMACMEBoARAgPACIEBYITAADBCYAAYITAAjBAYAEYIDAAjBAaAEQIDwAiBAWCEwAAwQmAAGLFSYKrqpqr6XlU9VlUfnh4FwObbMzBVdSTJPyZ5W5I3JHlPVb1hehgAm22VK5jrkzzW3d/v7meS3JbkXbOzANh0qwTm6iQ/eMHXT+z+3m+pqlNVdbaqzj6bp/drHwAbat/e5O/u09290907x3LJfj0tABtqlcA8meSaF3x9fPf3AOAlrRKYbyV5XVW9tqpeluTdST4/OwuATXd0rwO6+7mq+kCSM0mOJLm1ux8cXwbARtszMEnS3V9M8sXhLQBsET/JD8AIgQFghMAAMEJgABghMACMEBgARggMACMEBoARAgPACIEBYITAADBCYAAYITAAjBAYAEYIDAAjBAaAEQIDwAiBAWCEwAAwQmAAGCEwAIwQGABGCAwAIwQGgBECA8CI6u59f9KdN13a9565Zt+fdxOcvOrE0hMA1uae/kp+0T+riz3mCgaAEQIDwAiBAWCEwAAwQmAAGCEwAIwQGABGCAwAIwQGgBECA8AIgQFghMAAMEJgABghMACMEBgARggMACMEBoARAgPACIEBYITAADBCYAAYITAAjBAYAEYIDAAjBAaAEQIDwAiBAWCEwAAwYs/AVNWtVXW+qr67jkEAbIdVrmD+KclNwzsA2DJ7Bqa7v57kZ2vYAsAW2bf3YKrqVFWdraqzP/7p8/v1tABsqH0LTHef7u6d7t654o+O7NfTArCh3EUGwAiBAWDEKrcpfybJN5K8vqqeqKr3z88CYNMd3euA7n7POoYAsF28RAbACIEBYITAADBCYAAYITAAjBAYAEYIDAAjBAaAEQIDwAiBAWCEwAAwQmAAGCEwAIwQGABGCAwAIwQGgBECA8AIgQFghMAAMEJgABghMACMEBgARggMACMEBoARAgPACIEBYMTRiSd99IHLcvKqExNPfeCd+eG5pScs6rD+uQO/yxUMACMEBoARAgPACIEBYITAADBCYAAYITAAjBAYAEYIDAAjBAaAEQIDwAiBAWCEwAAwQmAAGCEwAIwQGABGCAwAIwQGgBECA8AIgQFghMAAMEJgABghMACMEBgARggMACMEBoARAgPACIEBYMSegamqa6rqrqp6qKoerKpb1jEMgM12dIVjnkvyD919f1W9Msl9VXVndz80vA2ADbbnFUx3/6i779/9/JdJHk5y9fQwADbbKlcw/6uqrk1yXZJ7LvLYqSSnkuTSXLYP0wDYZCu/yV9Vr0hye5IPdvcvXvx4d5/u7p3u3jmWS/ZzIwAbaKXAVNWxXIjLp7v7c7OTANgGq9xFVkk+meTh7v7Y/CQAtsEqVzA3JHlfkhur6tzux9uHdwGw4fZ8k7+7705Sa9gCwBbxk/wAjBAYAEYIDAAjBAaAEQIDwAiBAWCEwAAwQmAAGCEwAIwQGABGCAwAIwQGgBECA8AIgQFghMAAMEJgABghMACMEBgARggMACMEBoARAgPACIEBYITAADBCYAAYITAAjDi69IBtc/KqE0tPWNSZH55besJiDvuf/WF2mP/eX3/yVy/5mCsYAEYIDAAjBAaAEQIDwAiBAWCEwAAwQmAAGCEwAIwQGABGCAwAIwQGgBECA8AIgQFghMAAMEJgABghMACMEBgARggMACMEBoARAgPACIEBYITAADBCYAAYITAAjBAYAEYIDAAjBAaAEQIDwIg9A1NVl1bVvVX1nap6sKo+uo5hAGy2oysc83SSG7v7qao6luTuqvr37v7m8DYANtiegenuTvLU7pfHdj96chQAm2+l92Cq6khVnUtyPsmd3X3PRY45VVVnq+rss3l6v3cCsGFWCkx3P9/dJ5IcT3J9Vb3xIsec7u6d7t45lkv2eycAG+b3uousu3+e5K4kN83MAWBbrHIX2RVV9erdz1+e5C1JHpkeBsBmW+UusiuT/HNVHcmFIH22u78wOwuATbfKXWQPJLluDVsA2CJ+kh+AEQIDwAiBAWCEwAAwQmAAGCEwAIwQGABGCAwAIwQGgBECA8AIgQFghMAAMEJgABghMACMEBgARggMACMEBoARAgPACIEBYITAADBCYAAYITAAjBAYAEYIDAAjBAaAEQIDwIijSw9gu5y86sTSE1jImR+eW3rCYg7z3/tH+6cv+ZgrGABGCAwAIwQGgBECA8AIgQFghMAAMEJgABghMACMEBgARggMACMEBoARAgPACIEBYITAADBCYAAYITAAjBAYAEYIDAAjBAaAEQIDwAiBAWCEwAAwQmAAGCEwAIwQGABGCAwAIwQGgBECA8CIlQNTVUeq6ttV9YXJQQBsh9/nCuaWJA9PDQFgu6wUmKo6nuQdST4xOweAbbHqFczHk3woya9f6oCqOlVVZ6vq7LN5el/GAbC59gxMVd2c5Hx33/f/Hdfdp7t7p7t3juWSfRsIwGZa5QrmhiTvrKrHk9yW5Maq+tToKgA23p6B6e6PdPfx7r42ybuTfLW73zu+DICN5udgABhx9Pc5uLu/luRrI0sA2CquYAAYITAAjBAYAEYIDAAjBAaAEQIDwAiBAWCEwAAwQmAAGCEwAIwQGABGCAwAIwQGgBECA8AIgQFghMAAMEJgABghMACMEBgARggMACMEBoARAgPACIEBYITAADBCYAAYcXTpAcB2OHnViaUncMC4ggFghMAAMEJgABghMACMEBgARggMACMEBoARAgPACIEBYITAADBCYAAYITAAjBAYAEYIDAAjBAaAEQIDwAiBAWCEwAAwQmAAGCEwAIwQGABGCAwAIwQGgBECA8AIgQFghMAAMEJgABghMACMOLrKQVX1eJJfJnk+yXPdvTM5CoDNt1Jgdv11d/9kbAkAW8VLZACMWDUwneTLVXVfVZ262AFVdaqqzlbV2Wfz9P4tBGAjrfoS2Zu7+8mq+pMkd1bVI9399Rce0N2nk5xOklfVH/Y+7wRgw6x0BdPdT+7+ej7JHUmunxwFwObbMzBVdXlVvfI3nyd5a5LvTg8DYLOt8hLZa5LcUVW/Of5fu/tLo6sA2Hh7Bqa7v5/kTWvYAsAWcZsyACMEBoARAgPACIEBYITAADBCYAAYITAAjBAYAEYIDAAjBAaAEQIDwAiBAWCEwAAwQmAAGCEwAIwQGABGCAwAIwQGgBECA8AIgQFghMAAMEJgABghMACMEBgARggMACMEBoAR1d37/6RVP07yn/v+xKv54yQ/Wei/fRAc5vM/zOeeHO7zd+7L+dPuvuJiD4wEZklVdba7d5besZTDfP6H+dyTw33+zv1gnruXyAAYITAAjNjGwJxeesDCDvP5H+ZzTw73+Tv3A2jr3oMB4GDYxisYAA4AgQFgxFYFpqpuqqrvVdVjVfXhpfesU1XdWlXnq+q7S29Zt6q6pqruqqqHqurBqrpl6U3rUlWXVtW9VfWd3XP/6NKb1q2qjlTVt6vqC0tvWbeqeryq/qOqzlXV2aX3vNjWvAdTVUeSPJrkLUmeSPKtJO/p7ocWHbYmVfVXSZ5K8i/d/cal96xTVV2Z5Mruvr+qXpnkviR/exj+7Kuqklze3U9V1bEkdye5pbu/ufC0tamqv0+yk+RV3X3z0nvWqaoeT7LT3Qfyh0y36Qrm+iSPdff3u/uZJLcledfCm9amu7+e5GdL71hCd/+ou+/f/fyXSR5OcvWyq9ajL3hq98tjux/b8V3jCqrqeJJ3JPnE0lv4XdsUmKuT/OAFXz+RQ/KPDP+nqq5Ncl2Se5Zdsj67LxGdS3I+yZ3dfWjOPcnHk3woya+XHrKQTvLlqrqvqk4tPebFtikwHHJV9Yoktyf5YHf/Yuk969Ldz3f3iSTHk1xfVYfiJdKqujnJ+e6+b+ktC3pzd/9Fkrcl+bvdl8oPjG0KzJNJrnnB18d3f49DYPf9h9uTfLq7P7f0niV098+T3JXkpqW3rMkNSd65+z7EbUlurKpPLTtpvbr7yd1fzye5IxfeKjgwtikw30ryuqp6bVW9LMm7k3x+4U2swe4b3Z9M8nB3f2zpPetUVVdU1at3P395Ltzk8siyq9ajuz/S3ce7+9pc+P/9q9393oVnrU1VXb57U0uq6vIkb01yoO4i3ZrAdPdzST6Q5EwuvMn72e5+cNlV61NVn0nyjSSvr6onqur9S29aoxuSvC8XvoM9t/vx9qVHrcmVSe6qqgdy4ZusO7v70N2ue0i9JsndVfWdJPcm+bfu/tLCm37L1tymDMDBsjVXMAAcLAIDwAiBAWCEwAAwQmAAGCEwAIwQGABG/A9gCmL7XTdKRwAAAABJRU5ErkJggg==\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "# Convert from 1-index standard format. (Head is 0)\n", + "event = dist.to_event(torch.tensor([[2, 3, 4, 1, 0, 4]]), None)\n", + "show_deps(event[0])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Non-Projective Dependency Trees" + ] + }, + { + "cell_type": "raw", + "metadata": { + "raw_mimetype": "text/restructuredtext" + }, + "source": [ + ".. autoclass:: torch_struct.NonProjectiveDependencyCRF" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAZgAAAGbCAYAAAD5r4b7AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAQUklEQVR4nO3db4xlBXnH8edhZtdl1+WfCsqC7CZVGwKt0K1BabWCRi1G0qZNMNGkxkjS+gfFRrFv9JVJU2MlqTUliG+k8mIljVrin1RNtUlXlz+NwkqLoLALsouIyLKyLPv0xU4TCmfYO848e+4dPp+EZGfu5eSXszvznXPvnZmsqgCAlXbM2AMAWJ0EBoAWAgNAC4EBoIXAANBivuOgc+s31JoTTuo49LKc9YK9Y08Y9IOHnj/2hGFz0/kKw7M2PjD2hEG37j157AmLWrNvOv8u84WPjz1h0FxO5/l67ImWT9nLcuD+h+Lgw4/m0G0ta9eccFJsfuflHYdelu+9+x/HnjBoy5cuHXvCoLnjDow9YdB/vuaasScMOvsz7xl7wqJe+L3p/Luc+/D9Y08YdPza/WNPGPTjB6fvi9H/ufyzi97mITIAWggMAC0EBoAWAgNAC4EBoIXAANBCYABoITAAtBAYAFoIDAAtBAaAFgIDQAuBAaCFwADQYqLAZOYbM/P2zLwjM6/oHgXA7DtiYDJzLiI+HRFviogzI+KtmXlm9zAAZtskVzCviIg7qurOqjoQEddFxMW9swCYdZMEZlNE3POkt3ctvO//ycxLM3NHZu54Yt++ldoHwIxasSf5q+qqqtpaVVvnNmxYqcMCMKMmCczuiDj9SW+ftvA+AFjUJIH5fkS8JDO3ZObaiLgkIr7UOwuAWTd/pDtU1cHMfE9EfC0i5iLimqq6tX0ZADPtiIGJiKiqGyLihuYtAKwivpMfgBYCA0ALgQGghcAA0EJgAGghMAC0EBgAWggMAC0EBoAWAgNAC4EBoIXAANBCYABokVW14gfdeMJp9fLXXLbix12uuccOjT1h0KG/fmDsCYN+ccOpY08YtHHXE2NPGHTva8desLg6ZuU/zlfCmX/7s7EnDNpzwdN+K/xUOObxsRc83W3/+vex74F7cug2VzAAtBAYAFoIDAAtBAaAFgIDQAuBAaCFwADQQmAAaCEwALQQGABaCAwALQQGgBYCA0ALgQGghcAA0EJgAGghMAC0EBgAWggMAC0EBoAWAgNAC4EBoIXAANBCYABoITAAtBAYAFoIDAAtBAaAFgIDQAuBAaCFwADQQmAAaCEwALQQGABaCAwALQQGgBYCA0CL+Y6DHjipYvclj3ccelme97V1Y08YdGD/dO569IU19oRB+0+ezq+LTvmP6TxfEREPXrx/7AmDau2asScMevDs6fy7PPl7Yy8Y8Aynajo/UgGYeQIDQAuBAaCFwADQQmAAaCEwALQQGABaCAwALQQGgBYCA0ALgQGghcAA0EJgAGghMAC0EBgAWhwxMJl5emZ+KzNvy8xbM/OyozEMgNk2yS8cOxgRH6yqmzJzY0TcmJnfqKrbmrcBMMOOeAVTVfdV1U0Lf/5VROyMiE3dwwCYbUt6DiYzN0fEORGxfeC2SzNzR2bueOJX+1ZmHQAza+LAZOZzI+KLEfH+qnr4qbdX1VVVtbWqts5t3LCSGwGYQRMFJjPXxOG4XFtV1/dOAmA1mORVZBkRn42InVX1yf5JAKwGk1zBnB8Rb4+ICzLzloX//rh5FwAz7ogvU66q70ZEHoUtAKwivpMfgBYCA0ALgQGghcAA0EJgAGghMAC0EBgAWggMAC0EBoAWAgNAC4EBoIXAANBCYABoccSfpvybWLfrYLzsw3s6Dr0stXH92BMG1bbdY08YtO/1x409YdCjJ8+NPWHQxrv2jz1hUY+sPTj2hGH/NJ3n7KKNN449YdB3fvz7Y094mnqGiriCAaCFwADQQmAAaCEwALQQGABaCAwALQQGgBYCA0ALgQGghcAA0EJgAGghMAC0EBgAWggMAC0EBoAWAgNAC4EBoIXAANBCYABoITAAtBAYAFoIDAAtBAaAFgIDQAuBAaCFwADQQmAAaCEwALQQGABaCAwALQQGgBYCA0ALgQGghcAA0EJgAGghMAC0EBgAWsx3HPTgcWtj7+te3HHoZXneLb8ce8Kg//7474w9YdAZNxwce8KgY3+eY08Y9OM/P3bsCYv6rT/5r7EnDDrwR+eOPWHQv5+9eewJgz59+T+MPeFp3vXtPYve5goGgBYCA0ALgQGghcAA0EJgAGghMAC0EBgAWggMAC0EBoAWAgNAC4EBoIXAANBCYABoITAAtBAYAFpMHJjMnMvMmzPzK52DAFgdlnIFc1lE7OwaAsDqMlFgMvO0iLgoIq7unQPAajHpFcynIuJDEXGocQsAq8gRA5OZb46IPVV14xHud2lm7sjMHQd/vW/FBgIwmya5gjk/It6SmT+JiOsi4oLM/PxT71RVV1XV1qraOr9uwwrPBGDWHDEwVfWRqjqtqjZHxCUR8c2qelv7MgBmmu+DAaDF/FLuXFXfjohvtywBYFVxBQNAC4EBoIXAANBCYABoITAAtBAYAFoIDAAtBAaAFgIDQAuBAaCFwADQQmAAaCEwALRY0k9TntShNRH7XpQdh16Wt1x+09gTBv3ou68ee8Kgn//VdP5m0rNPvm/sCYM+dsp3xp6wqH9+1SvHnjBox3Xrxp4w6JjXPjj2hEFXfPAvx57wNLvuuXLR21zBANBCYABoITAAtBAYAFoIDAAtBAaAFgIDQAuBAaCFwADQQmAAaCEwALQQGABaCAwALQQGgBYCA0ALgQGghcAA0EJgAGghMAC0EBgAWggMAC0EBoAWAgNAC4EBoIXAANBCYABoITAAtBAYAFoIDAAtBAaAFgIDQAuBAaCFwADQQmAAaCEwALQQGABaCAwALQQGgBbzHQdde/yB2PKGuzoOvSz/cuVrx54w6PmHxl4wLOuEsScMuvuhjWNPGPR3P9009oRF/eKs48aeMOiRN+wfe8Kgl7717rEnDLrvnS8fe8LTHFq7+G2uYABoITAAtBAYAFoIDAAtBAaAFgIDQAuBAaCFwADQQmAAaCEwALQQGABaCAwALQQGgBYCA0ALgQGgxUSBycwTMnNbZv4oM3dm5iu7hwEw2yb9hWNXRsRXq+rPMnNtRKxv3ATAKnDEwGTm8RHx6oj4i4iIqjoQEQd6ZwEw6yZ5iGxLROyNiM9l5s2ZeXVmbnjqnTLz0szckZk7Djw0nb8GFYCjZ5LAzEfEuRHxmao6JyL2RcQVT71TVV1VVVurauvaE45d4ZkAzJpJArMrInZV1faFt7fF4eAAwKKOGJiq+llE3JOZL1t414URcVvrKgBm3qSvIntvRFy78AqyOyPiHX2TAFgNJgpMVd0SEVubtwCwivhOfgBaCAwALQQGgBYCA0ALgQGghcAA0EJgAGghMAC0EBgAWggMAC0EBoAWAgNAC4EBoMWkP65/SX79yNq4ffvmjkMvyxPnHhp7wqAzvjydux55Ucs/j2W777y5sScMWvfS48eesKjTv7x37AmD9p530tgTBj32B2eOPWFQTeGHZOXit7mCAaCFwADQQmAAaCEwALQQGABaCAwALQQGgBYCA0ALgQGghcAA0EJgAGghMAC0EBgAWggMAC0EBoAWAgNAC4EBoIXAANBCYABoITAAtBAYAFoIDAAtBAaAFgIDQAuBAaCFwADQQmAAaCEwALQQGABaCAwALQQGgBYCA0ALgQGghcAA0EJgAGghMAC0EBgAWsx3HHTNoxGnfP9Qx6GX5ZiDNfaEQetv3zP2hEF7ztk09oRBL7nm/rEnDNp87e6xJyxq+x++eOwJg9Y8un/sCYPW7fr12BMGnbqn5VP2stz98BOL3uYKBoAWAgNAC4EBoIXAANBCYABoITAAtBAYAFoIDAAtBAaAFgIDQAuBAaCFwADQQmAAaCEwALQQGABaTBSYzPxAZt6amT/MzC9k5rruYQDMtiMGJjM3RcT7ImJrVZ0VEXMRcUn3MABm26QPkc1HxLGZOR8R6yPi3r5JAKwGRwxMVe2OiE9ExN0RcV9E/LKqvv7U+2XmpZm5IzN3PP7YIyu/FICZMslDZCdGxMURsSUiTo2IDZn5tqfer6quqqqtVbV1zXOeu/JLAZgpkzxE9rqIuKuq9lbV4xFxfUS8qncWALNuksDcHRHnZeb6zMyIuDAidvbOAmDWTfIczPaI2BYRN0XEDxb+n6uadwEw4+YnuVNVfTQiPtq8BYBVxHfyA9BCYABoITAAtBAYAFoIDAAtBAaAFgIDQAuBAaCFwADQQmAAaCEwALQQGABaCAwALSb6acpLdeKpD8effuwbHYdelut++ntjTxj08lPuHXvCoDPq/rEnDPq3DeeMPWHQox87aewJi3rwohx7wqDn7G35FLRsuz/++NgTBm058edjT3iafNfBRW9zBQNAC4EBoIXAANBCYABoITAAtBAYAFoIDAAtBAaAFgIDQAuBAaCFwADQQmAAaCEwALQQGABaCAwALQQGgBYCA0ALgQGghcAA0EJgAGghMAC0EBgAWggMAC0EBoAWAgNAC4EBoIXAANBCYABoITAAtBAYAFoIDAAtBAaAFgIDQAuBAaCFwADQQmAAaCEwALTIqlr5g2bujYifrtDhnh8RD6zQsZ4NnK+lcb6WxvlammfD+Tqjql4wdENLYFZSZu6oqq1j75gVztfSOF9L43wtzbP9fHmIDIAWAgNAi1kIzFVjD5gxztfSOF9L43wtzbP6fE39czAAzKZZuIIBYAYJDAAtpjYwmfnGzLw9M+/IzCvG3jPNMvP0zPxWZt6Wmbdm5mVjb5oFmTmXmTdn5lfG3jLtMvOEzNyWmT/KzJ2Z+cqxN027zPzAwsfjDzPzC5m5buxNR9tUBiYz5yLi0xHxpog4MyLemplnjrtqqh2MiA9W1ZkRcV5EvNv5mshlEbFz7BEz4sqI+GpV/XZE/G44b88oMzdFxPsiYmtVnRURcxFxybirjr6pDExEvCIi7qiqO6vqQERcFxEXj7xpalXVfVV108KffxWHP/g3jbtqumXmaRFxUURcPfaWaZeZx0fEqyPisxERVXWgqh4ad9VMmI+IYzNzPiLWR8S9I+856qY1MJsi4p4nvb0rfMKcSGZujohzImL7uEum3qci4kMRcWjsITNgS0TsjYjPLTykeHVmbhh71DSrqt0R8YmIuDsi7ouIX1bV18dddfRNa2D4DWTmcyPiixHx/qp6eOw90yoz3xwRe6rqxrG3zIj5iDg3Ij5TVedExL6I8LzoM8jME+Pwoy5bIuLUiNiQmW8bd9XRN62B2R0Rpz/p7dMW3sciMnNNHI7LtVV1/dh7ptz5EfGWzPxJHH749YLM/Py4k6barojYVVX/d1W8LQ4Hh8W9LiLuqqq9VfV4RFwfEa8aedNRN62B+X5EvCQzt2Tm2jj85NiXRt40tTIz4/Dj4zur6pNj75l2VfWRqjqtqjbH4X9b36yqZ91Xl5Oqqp9FxD2Z+bKFd10YEbeNOGkW3B0R52Xm+oWPzwvjWfjCiPmxBwypqoOZ+Z6I+FocfvXFNVV168izptn5EfH2iPhBZt6y8L6/qaobRtzE6vLeiLh24Qu+OyPiHSPvmWpVtT0zt0XETXH4VZ43x7Pwx8b4UTEAtJjWh8gAmHECA0ALgQGghcAA0EJgAGghMAC0EBgAWvwvHUnHHvk3xm8AAAAASUVORK5CYII=\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "batch, N, N = 3, 10, 10\n", + "def show_deps(tree):\n", + " plt.imshow(tree.detach())\n", + "\n", + "# batch, N, z_n, z_n_1\n", + "log_potentials = torch.rand(batch, N, N)\n", + "dist = torch_struct.NonProjectiveDependencyCRF(log_potentials)\n", + "show_deps(dist.marginals[0])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Binary Labeled Trees" + ] + }, + { + "cell_type": "raw", + "metadata": { + "raw_mimetype": "text/restructuredtext" + }, + "source": [ + " .. autoclass:: torch_struct.TreeCRF" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAagAAAGbCAYAAACRXATDAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAXOElEQVR4nO3df6xndX3n8edrYZBA+TWOIr/WH4UlcZtK3ckou7bBxfJjYsR22S6k2WJ1M9WtSd2sadg1UeP+s25jTVqMdKpEbVwk1WLJdirM2ibUBNCRDAiKMhIaZkAGxB1U/DX2vX/cM+Zy53vnXu/3MPf9vff5SL75nu/5fL6f855zz9zXnPM98/2kqpAkqZt/ttoFSJI0iQElSWrJgJIktWRASZJaMqAkSS0du9oFTHLs8SfWcSdtHGWsnHZwlHEAjnnwR6ONpbXhX/zyM6ON9Y17TxhtLK0dYx5jX/vBaaONVd8ZJz5+/N2nOPjD72dSWzreZn7CC86p8//dfxllrA1vfGKUcQBO2bpntLG0Ntz66O7Rxrr0zAtGG0trx5jH2Kt3XznaWD/57AtGGefrn/kgzzzxyMSA8hKfJKklA0qS1JIBJUlqyYCSJLVkQEmSWpoqoJJcluTrSfYkuXZC+/OS3DS035XkJdNsT5K0fqw4oJIcA3wIuBx4OXB1kpcv6PYW4DtVdS7wQeD9K92eJGl9meYMaguwp6oeqqofA58CrljQ5wrg48Pyp4GLk0y8312SpPmmCaizgEfmvd47rJvYp6oOAgeA508aLMm2JLuS7Dr4w+9PUZYkaS1oc5NEVW2vqs1VtfnY409c7XIkSatsmoDaB5wz7/XZw7qJfZIcC5wCfHuKbUqS1olpAupLwHlJXprkOOAq4JYFfW4BrhmWrwT+rjp++Z8kqZ0Vfx1tVR1M8nbgVuAY4Iaquj/J+4BdVXUL8FHgL5LsAZ5iLsQkSVrSVN+XXlU7gB0L1r173vIPgX8/zTYkSetTm5skJEmaz4CSJLVkQEmSWmo5o+7J2VivysWjjOWMp5LU1131eZ6up5xRV5I0OwwoSVJLBpQkqSUDSpLUkgElSWrJgJIktWRASZJaMqAkSS0ZUJKklgwoSVJLBpQkqSUDSpLUkgElSWrJgJIktWRASZJaMqAkSS0ZUJKklgwoSVJLx652Ac+1MadpP7Dj3NHGOmXrntHG0s/nyW0XjjbWhjc+MdpYHhPSs3kGJUlqyYCSJLVkQEmSWjKgJEktGVCSpJYMKElSSwaUJKklA0qS1JIBJUlqyYCSJLVkQEmSWjKgJEktGVCSpJYMKElSSwaUJKmlFQdUknOS/H2Srya5P8kfTOhzUZIDSXYPj3dPV64kab2YZsLCg8B/raq7k5wEfDnJzqr66oJ+/1BVr59iO5KkdWjFZ1BV9VhV3T0sfxf4GnDWWIVJkta3UaZ8T/IS4FeAuyY0X5jkHuBR4J1Vdf8iY2wDtgEczwljlDW6MafkvvXR3aONNea09mPW9erdV4421pg24DTt0iyYOqCS/ALwGeAdVfX0gua7gRdX1feSbAU+C5w3aZyq2g5sBzg5G2vauiRJs22qu/iSbGAunD5ZVX+1sL2qnq6q7w3LO4ANSTZNs01J0vowzV18AT4KfK2q/niRPi8a+pFky7C9b690m5Kk9WOaS3z/BviPwFeSHPrg4r8D/xygqq4HrgTeluQg8APgqqry8p0kaUkrDqiq+gKQJfpcB1y30m1IktYvv0lCktSSASVJasmAkiS1ZEBJkloyoCRJLRlQkqSWDChJUksGlCSpJQNKktSSASVJasmAkiS1ZEBJkloyoCRJLY0y5bt+fl2naf9X733baGNt2n7HaGNJWn88g5IktWRASZJaMqAkSS0ZUJKklgwoSVJLBpQkqSUDSpLUkgElSWrJgJIktWRASZJaMqAkSS0ZUJKklgwoSVJLBpQkqSUDSpLUkgElSWrJgJIktWRASZJacsr3NWDM6eM37HhitLHYPt5QktYfz6AkSS0ZUJKklgwoSVJLBpQkqSUDSpLUkgElSWpp6oBK8nCSryTZnWTXhPYk+ZMke5Lcm+SV025TkrT2jfX/oF5bVU8u0nY5cN7weBXw4eFZkqRFHY1LfFcAn6g5dwKnJjnjKGxXkjTDxgioAm5L8uUk2ya0nwU8Mu/13mHdsyTZlmRXkl0/4UcjlCVJmmVjXOJ7TVXtS/JCYGeSB6rq9p93kKrazvDlOCdnY41QlyRphk19BlVV+4bn/cDNwJYFXfYB58x7ffawTpKkRU0VUElOTHLSoWXgEuC+Bd1uAX5nuJvv1cCBqnpsmu1Kkta+aS/xnQ7cnOTQWP+7qj6X5K0AVXU9sAPYCuwBngF+d8ptSpLWgakCqqoeAl4xYf3185YL+P1ptiNJWn/8JglJUksGlCSpJQNKktSSU77rWU7Zume0sZ7cduFoY23afsdoY0maDZ5BSZJaMqAkSS0ZUJKklgwoSVJLBpQkqSUDSpLUkgElSWrJgJIktWRASZJaMqAkSS0ZUJKklgwoSVJLBpQkqSUDSpLUkgElSWrJgJIktWRASZJaMqAkSS055bueM2NO0+708dL64xmUJKklA0qS1JIBJUlqyYCSJLVkQEmSWjKgJEktGVCSpJYMKElSSwaUJKklA0qS1JIBJUlqyYCSJLVkQEmSWjKgJEktGVCSpJZWHFBJzk+ye97j6STvWNDnoiQH5vV59/QlS5LWgxVPWFhVXwcuAEhyDLAPuHlC13+oqtevdDuSpPVprEt8FwPfrKp/HGk8SdI6N9aU71cBNy7SdmGSe4BHgXdW1f2TOiXZBmwDOJ4TRipLa8WY07Qf2HHuaGOdsnXPaGNJerapz6CSHAe8AfjLCc13Ay+uqlcAfwp8drFxqmp7VW2uqs0beN60ZUmSZtwYl/guB+6uqscXNlTV01X1vWF5B7AhyaYRtilJWuPGCKirWeTyXpIXJcmwvGXY3rdH2KYkaY2b6jOoJCcCvw783rx1bwWoquuBK4G3JTkI/AC4qqpqmm1KktaHqQKqqr4PPH/BuuvnLV8HXDfNNiRJ65PfJCFJasmAkiS1ZEBJkloyoCRJLRlQkqSWDChJUksGlCSpJQNKktSSASVJasmAkiS1ZEBJkloyoCRJLRlQkqSWxpryXZoZY07T/uS2C0cba8xp7aW1wDMoSVJLBpQkqSUDSpLUkgElSWrJgJIktWRASZJaMqAkSS0ZUJKklgwoSVJLBpQkqSUDSpLUkgElSWrJgJIktWRASZJaMqAkSS0ZUJKklgwoSVJLBpQkqSWnfJemMOY07bc+unu0sS4984LRxtLqOrDj3NHGOmXrntHGOho8g5IktWRASZJaMqAkSS0ZUJKklgwoSVJLBpQkqaVlBVSSG5LsT3LfvHUbk+xM8uDwfNoi771m6PNgkmvGKlyStLYt9wzqY8BlC9ZdC3y+qs4DPj+8fpYkG4H3AK8CtgDvWSzIJEmab1kBVVW3A08tWH0F8PFh+ePAGye89VJgZ1U9VVXfAXZyeNBJknSYab5J4vSqemxY/hZw+oQ+ZwGPzHu9d1h3mCTbgG0Ax3PCFGVJktaCUW6SqKoCasoxtlfV5qravIHnjVGWJGmGTRNQjyc5A2B43j+hzz7gnHmvzx7WSZJ0RNME1C3AobvyrgH+ekKfW4FLkpw23BxxybBOkqQjWu5t5jcCdwDnJ9mb5C3A/wR+PcmDwOuG1yTZnOQjAFX1FPA/gC8Nj/cN6yRJOqJl3SRRVVcv0nTxhL67gP807/UNwA0rqk6StG75TRKSpJYMKElSSwaUJKklA0qS1NI03yQhaUSXnnnBaGM9ue3C0cYC2LT9jtHGGrO2DW98YrSx7rzg06ONNa7do410KeMdY0eDZ1CSpJYMKElSSwaUJKklA0qS1JIBJUlqyYCSJLVkQEmSWjKgJEktGVCSpJYMKElSSwaUJKklA0qS1JIBJUlqyYCSJLVkQEmSWjKgJEktGVCSpJYMKElSS6mq1a7hMCdnY70qF692GZIGtz463rTjY7r0zNmawlyHu6s+z9P1VCa1eQYlSWrJgJIktWRASZJaMqAkSS0ZUJKklgwoSVJLBpQkqSUDSpLUkgElSWrJgJIktWRASZJaMqAkSS0ZUJKklgwoSVJLBpQkqaUlAyrJDUn2J7lv3ro/SvJAknuT3Jzk1EXe+3CSryTZnWTXmIVLkta25ZxBfQy4bMG6ncAvVdUvA98A/tsR3v/aqrqgqjavrERJ0nq0ZEBV1e3AUwvW3VZVB4eXdwJnPwe1SZLWsWNHGOPNwE2LtBVwW5IC/qyqti82SJJtwDaA4zlhhLIkjWXMqdWf3HbhaGNt4o7RxlI/UwVUkncBB4FPLtLlNVW1L8kLgZ1JHhjOyA4zhNd2gJOzsaapS5I0+1Z8F1+SNwGvB367qiYGSlXtG573AzcDW1a6PUnS+rKigEpyGfCHwBuq6plF+pyY5KRDy8AlwH2T+kqStNBybjO/EbgDOD/J3iRvAa4DTmLust3uJNcPfc9MsmN46+nAF5LcA3wR+Juq+txz8qeQJK05S34GVVVXT1j90UX6PgpsHZYfAl4xVXWSpHXLb5KQJLVkQEmSWjKgJEktGVCSpJYMKElSSwaUJKklA0qS1JIBJUlqyYCSJLVkQEmSWjKgJEktGVCSpJYMKElSS2NM+S5Jy7Zp+3jTtB/Yce5oY52ydc9oY2kcnkFJkloyoCRJLRlQkqSWDChJUksGlCSpJQNKktSSASVJasmAkiS1ZEBJkloyoCRJLRlQkqSWDChJUksGlCSpJQNKktSSASVJasmAkiS1ZEBJkloyoCRJLTnlu6SZNeY07U4f349nUJKklgwoSVJLBpQkqSUDSpLUkgElSWppyYBKckOS/Unum7fuvUn2Jdk9PLYu8t7Lknw9yZ4k145ZuCRpbVvOGdTHgMsmrP9gVV0wPHYsbExyDPAh4HLg5cDVSV4+TbGSpPVjyYCqqtuBp1Yw9hZgT1U9VFU/Bj4FXLGCcSRJ69A0n0G9Pcm9wyXA0ya0nwU8Mu/13mHdREm2JdmVZNdP+NEUZUmS1oKVBtSHgV8ELgAeAz4wbSFVtb2qNlfV5g08b9rhJEkzbkUBVVWPV9VPq+qfgD9n7nLeQvuAc+a9PntYJ0nSklYUUEnOmPfyN4D7JnT7EnBekpcmOQ64CrhlJduTJK0/S35ZbJIbgYuATUn2Au8BLkpyAVDAw8DvDX3PBD5SVVur6mCStwO3AscAN1TV/c/Jn0KStOYsGVBVdfWE1R9dpO+jwNZ5r3cAh92CLknSUvwmCUlSSwaUJKklA0qS1JIBJUlqySnfJYlxp2m/9dHdo4116ZkXjDbWrPEMSpLUkgElSWrJgJIktWRASZJaMqAkSS0ZUJKklgwoSVJLBpQkqSUDSpLUkgElSWrJgJIktWRASZJaMqAkSS0ZUJKklgwoSVJLBpQkqSUDSpLUkgElSWrJKd8laWRjTtN+YMe5o4015rT2R4NnUJKklgwoSVJLBpQkqSUDSpLUkgElSWrJgJIktWRASZJaMqAkSS0ZUJKklgwoSVJLBpQkqSUDSpLUkgElSWrJgJIktWRASZJaWnI+qCQ3AK8H9lfVLw3rbgLOH7qcCvy/qjpsApQkDwPfBX4KHKyqzSPVLUla45YzYeHHgOuATxxaUVX/4dBykg8AB47w/tdW1ZMrLVCStD4tGVBVdXuSl0xqSxLgt4B/O25ZkqT1btop338VeLyqHlykvYDbkhTwZ1W1fbGBkmwDtgEczwlTliVJa8OY07TP2vTx0wbU1cCNR2h/TVXtS/JCYGeSB6rq9kkdh/DaDnByNtaUdUmSZtyK7+JLcizwm8BNi/Wpqn3D837gZmDLSrcnSVpfprnN/HXAA1W1d1JjkhOTnHRoGbgEuG+K7UmS1pElAyrJjcAdwPlJ9iZ5y9B0FQsu7yU5M8mO4eXpwBeS3AN8EfibqvrceKVLktay5dzFd/Ui6980Yd2jwNZh+SHgFVPWJ0lap/wmCUlSSwaUJKklA0qS1JIBJUlqyYCSJLVkQEmSWjKgJEktGVCSpJYMKElSSwaUJKklA0qS1JIBJUlqyYCSJLU07Yy6kqQZMeY07bc+unuUcbZc+syibZ5BSZJaMqAkSS0ZUJKklgwoSVJLBpQkqSUDSpLUkgElSWrJgJIktWRASZJaMqAkSS0ZUJKklgwoSVJLBpQkqSUDSpLUkgElSWrJgJIktWRASZJaMqAkSS2lqla7hsMkeQL4xyW6bQKePArlPFdmuf5Zrh1mu/5Zrh1mu/5Zrh361v/iqnrBpIaWAbUcSXZV1ebVrmOlZrn+Wa4dZrv+Wa4dZrv+Wa4dZrN+L/FJkloyoCRJLc1yQG1f7QKmNMv1z3LtMNv1z3LtMNv1z3LtMIP1z+xnUJKktW2Wz6AkSWuYASVJaql9QCW5LMnXk+xJcu2E9ucluWlovyvJS45+lYdLck6Sv0/y1ST3J/mDCX0uSnIgye7h8e7VqHUxSR5O8pWhtl0T2pPkT4Z9f2+SV65GnZMkOX/eft2d5Okk71jQp83+T3JDkv1J7pu3bmOSnUkeHJ5PW+S91wx9HkxyzdGr+lk1TKr/j5I8MBwbNyc5dZH3HvE4e64tUvt7k+ybd2xsXeS9R/z9dDQsUv9N82p/OMnuRd67qvt+SVXV9gEcA3wTeBlwHHAP8PIFff4zcP2wfBVw02rXPdRyBvDKYfkk4BsTar8I+D+rXesR/gwPA5uO0L4V+FsgwKuBu1a75iMcR99i7j8Ettz/wK8BrwTum7fufwHXDsvXAu+f8L6NwEPD82nD8mlN6r8EOHZYfv+k+pdznK1S7e8F3rmM4+qIv59Wq/4F7R8A3t1x3y/16H4GtQXYU1UPVdWPgU8BVyzocwXw8WH508DFSXIUa5yoqh6rqruH5e8CXwPOWt2qRncF8ImacydwapIzVruoCS4GvllVS307yaqpqtuBpxasnn9sfxx444S3XgrsrKqnquo7wE7gsues0EVMqr+qbquqg8PLO4Gzj3Zdy7HIvl+O5fx+es4dqf7hd+FvATce1aJG0j2gzgIemfd6L4f/kv9Zn+EvwwHg+UelumUaLjv+CnDXhOYLk9yT5G+T/MujWtjSCrgtyZeTbJvQvpyfTwdXsfhf0M77//SqemxY/hZw+oQ+s/IzeDNzZ9uTLHWcrZa3D5cnb1jk8uos7PtfBR6vqgcXae+674H+ATXzkvwC8BngHVX19ILmu5m77PQK4E+Bzx7t+pbwmqp6JXA58PtJfm21C/p5JTkOeAPwlxOau+//n6m56zEz+X9CkrwLOAh8cpEuHY+zDwO/CFwAPMbcZbJZdDVHPnvquO9/pntA7QPOmff67GHdxD5JjgVOAb59VKpbQpINzIXTJ6vqrxa2V9XTVfW9YXkHsCHJpqNc5qKqat/wvB+4mblLGvMt5+ez2i4H7q6qxxc2dN//wOOHLpkOz/sn9Gn9M0jyJuD1wG8PIXuYZRxnR11VPV5VP62qfwL+fJGauu/7Y4HfBG5arE/HfT9f94D6EnBekpcO/xK+CrhlQZ9bgEN3Ll0J/N1ifxGOpuHa70eBr1XVHy/S50WHPi9LsoW5n0eXcD0xyUmHlpn7wPu+Bd1uAX5nuJvv1cCBeZekulj0X5Cd9/9g/rF9DfDXE/rcClyS5LThMtQlw7pVl+Qy4A+BN1TVM4v0Wc5xdtQt+Cz1N5hc03J+P62m1wEPVNXeSY1d9/2zrPZdGks9mLtT7BvM3S3zrmHd+5g76AGOZ+7yzR7gi8DLVrvmoa7XMHdJ5l5g9/DYCrwVeOvQ5+3A/czd/XMn8K9Xu+559b9sqOueocZD+35+/QE+NPxsvgJsXu26F/wZTmQucE6Zt67l/mcuRB8DfsLcZxlvYe6z1M8DDwL/F9g49N0MfGTee988HP97gN9tVP8e5j6jOXT8H7rb9kxgx5GOswa1/8VwTN/LXOicsbD24fVhv5861D+s/9ihY31e31b7fqmHX3UkSWqp+yU+SdI6ZUBJkloyoCRJLRlQkqSWDChJUksGlCSpJQNKktTS/weE0YBC5HWm3wAAAABJRU5ErkJggg==\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "batch, N, NT = 3, 20, 3\n", + "def show_tree(tree):\n", + " t = tree.detach()\n", + " plt.imshow(t[ :, : , 0] + \n", + " 2 * t[ :,:, 1] +\n", + " 3 * t[ :,:, 2])\n", + "\n", + "# batch, N, z_n, z_n_1\n", + "log_potentials = torch.rand(batch, N, N, NT)\n", + "dist = torch_struct.TreeCRF(log_potentials)\n", + "show_tree(dist.argmax[0])" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAagAAAGbCAYAAACRXATDAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAaK0lEQVR4nO3dcZSddX3n8c8nySQBJEJkjZCgIEbOYdtCaTbollosNoYcDrQ96iaHtqipEVtYcbXK6jlo3T9ca9VVYElTyYKuRY62aM42VbLWs9RTQSMnQCJoQhpLQiBqNAETkszku3/ME89lcu/MOPfLzPfOvF/nzJl7n+d3v/c7zzxzP/M888z9OSIEAEA10ya6AQAA2iGgAAAlEVAAgJIIKABASQQUAKCkGRPdQDtz5s6IefP7Umrt3nVaSh1Jmn7gSFotDRzNq3U0sZbzSqWqerGpEzfYVLmidlrmNssrldpX4n4RM/Nepo/25R2TLFiwJ6XOkzv79bO9A203WMmAmje/T5/4yitSan3kfVen1JGkOQ/9KK2W9j+TVioOHEyr5cwX3EQxMDDRLbTlmTPTakV/f1qt1F9aknn2rLxiifuFZ89Oq6W+nF+wJenIS/N+yX5mQd7X+LGP/M+UOm+7YmfHdZziAwCUREABAEoioAAAJRFQAICSCCgAQEldBZTtpba/b3ub7RvarJ9l+65m/f22z+rm+QAAU8eYA8r2dEm3SLpM0nmSVtg+b8iwlZJ+GhGvkPRJSR8d6/MBAKaWbo6gFkvaFhHbI+KwpC9IunLImCsl3dHc/pKkS131H20AAKV0E1DzJT3ecn9ns6ztmIjol7RP0ovaFbO9yvZG2xv37a35T5kAgPFT5iKJiFgTEYsiYtEL506f6HYAABOsm4DaJenMlvsLmmVtx9ieIemFkn7SxXMCAKaIbgLqO5IW2j7b9kxJyyWtGzJmnaRjb4b3Bkn/FMwxDwAYhTG/WWxE9Nu+VtLXJE2XtDYittj+sKSNEbFO0m2SPmd7m6S9GgwxAABG1NW7mUfEeknrhyy7seX2s5Le2M1zAACmpjIXSQAA0IqAAgCUREABAEpyxYvq5nhuXORLU2p9bMd9KXUk6d1/eE1arb4f5c2oq7378mplOpr4D9fTiv5vXOTNXBuHDqfVypxpNptnJE7kPT1vv/CsvNmRM2fUjRMSZyCekbe9fn7OKSl1Nv2/T+npn+1s+w5DHEEBAEoioAAAJRFQAICSCCgAQEkEFACgJAIKAFASAQUAKImAAgCUREABAEoioAAAJRFQAICSCCgAQEkEFACgJAIKAFASAQUAKImAAgCUREABAEoioAAAJSXOvVzTn5/1qrRaq394c1qtle94V1qtE4/mTTvuI/1ptdSfOO145hTmmVOO9ydur4G876OmJ//umTi1utx2du8x1kr8OhOnaVdf4j4WUbLWiY8/k1Jn2uHO+z1HUACAkggoAEBJBBQAoCQCCgBQEgEFACiJgAIAlERAAQBKIqAAACURUACAkggoAEBJBBQAoCQCCgBQEgEFACiJgAIAlERAAQBKGnNA2T7T9jdsf8/2FtvvbDPmEtv7bG9qPm7srl0AwFTRzaxa/ZLeHREP2D5Z0ndtb4iI7w0Z988RcXkXzwMAmILGfAQVEbsj4oHm9tOSHpE0P6sxAMDUljIvse2zJP26pPvbrH617QclPSHpPRGxpUONVZJWSdJsnZjRVro/fdnFabU27FqdVmvJW9+eVmvGs3lTq/ft3p9WS5nT2idOrR7PHkqrlSkOH0mumFfPs2al1VLiTPRx8GBesQOJ07TbeaVmz86r9UxSXwOdX3O6DijbL5D0d5Kuj4ihr0gPSHpZRDxje5mkL0ta2K5ORKyRtEaS5nhu4ncXANCLurqKz3afBsPp8xHx90PXR8T+iHimub1eUp/t07p5TgDA1NDNVXyWdJukRyLiEx3GvKQZJ9uLm+f7yVifEwAwdXRziu83Jf2RpIdtb2qWvV/SSyUpIlZLeoOkd9jul3RQ0vKI4PQdAGBEYw6oiPimpGH/ShYRN0u6eazPAQCYungnCQBASQQUAKAkAgoAUBIBBQAoiYACAJREQAEASiKgAAAlEVAAgJIIKABASQQUAKAkAgoAUBIBBQAoiYACAJSUMuU7fnmXz/+NtFpf3nlTWq3X3/CutFqn/rQvrZYPJE6tnjlNe39/Xq1EnpH8o92XWC9xCvNMcehwWi1nfo2Z2z5xf02bOelo5zocQQEASiKgAAAlEVAAgJIIKABASQQUAKAkAgoAUBIBBQAoiYACAJREQAEASiKgAAAlEVAAgJIIKABASQQUAKAkAgoAUBIBBQAoiYACAJREQAEASiKgAAAlMeX7JPCmBa9Oq/Wlf/urtFrL3/XutFov+OHP02pNHziaVksHD+bVShQDA7kFE6cKz5yOPv3rTBLT8n73T5w8XnHkSF6xI0n7RHT+eeQICgBQEgEFACiJgAIAlERAAQBKIqAAACURUACAkroOKNs7bD9se5PtjW3W2/anbW+z/ZDtC7t9TgDA5Jf1DwmvjYgfd1h3maSFzcdFkm5tPgMA0NF4nOK7UtJnY9B9kk6xffo4PC8AoIdlBFRIusf2d22varN+vqTHW+7vbJY9h+1Vtjfa3nhEhxLaAgD0soxTfBdHxC7bL5a0wfajEXHvL1skItZIWiNJczw3EvoCAPSwro+gImJX83mPpLslLR4yZJekM1vuL2iWAQDQUVcBZfsk2ycfuy1piaTNQ4atk/THzdV8r5K0LyJ2d/O8AIDJr9tTfPMk3W37WK2/jYiv2r5GkiJitaT1kpZJ2ibpgKS3dPmcAIApoKuAiojtks5vs3x1y+2Q9GfdPA8AYOrhnSQAACURUACAkggoAEBJTPmO51j50ovTar136+fSan3yuqvSap2wP2+adp90UlotJU4Trn3782pJ0vTpebWGmeL7l5Y543vmNO1OnKg9c1r7o4n/Yjot62vsXIcjKABASQQUAKAkAgoAUBIBBQAoiYACAJREQAEASiKgAAAlEVAAgJIIKABASQQUAKAkAgoAUBIBBQAoiYACAJREQAEASiKgAAAlEVAAgJIIKABASQQUAKAkpnzH8+aWha9Mq/WaB/8lrdbXP5w3rf3Jjz2dVssHD6fV0oEDebWkutOOR2KtRJHZ10DNr9FHsyp1/vo4ggIAlERAAQBKIqAAACURUACAkggoAEBJBBQAoCQCCgBQEgEFACiJgAIAlERAAQBKIqAAACURUACAkggoAEBJBBQAoCQCCgBQ0pgDyva5tje1fOy3ff2QMZfY3tcy5sbuWwYATAVjnrAwIr4v6QJJsj1d0i5Jd7cZ+s8RcflYnwcAMDVlneK7VNJjEfHDpHoAgCkua8r35ZLu7LDu1bYflPSEpPdExJZ2g2yvkrRKkmbrxKS2MFncd35fWq3//IMvpNW69bo3ptU64bGDabVk59WS5Nmz02rFQNpc4XJf1kuYpBmJtQYGEmvlbS9NS9wvjmZNRd+5p66PoGzPlHSFpC+2Wf2ApJdFxPmSbpL05U51ImJNRCyKiEV9mtVtWwCAHpdxiu8ySQ9ExFNDV0TE/oh4prm9XlKf7dMSnhMAMMllBNQKdTi9Z/sl9uC5BtuLm+f7ScJzAgAmua5Outo+SdLvSnp7y7JrJCkiVkt6g6R32O6XdFDS8ojIOnEJAJjEugqoiPi5pBcNWba65fbNkm7u5jkAAFMT7yQBACiJgAIAlERAAQBKIqAAACURUACAkggoAEBJBBQAoCQCCgBQEgEFACiJgAIAlERAAQBKIqAAACURUACAkhLnOAZ6w22vPDut1ru2/m1arb/68z9Mq/WCR6an1ZKkSJzC3D/Pm9o+juZNYe5peb+vp84p1N+fWCxxv0jbJzpvLY6gAAAlEVAAgJIIKABASQQUAKAkAgoAUBIBBQAoiYACAJREQAEASiKgAAAlEVAAgJIIKABASQQUAKAkAgoAUBIBBQAoiYACAJREQAEASiKgAAAlEVAAgJKY8h3owi0LX5lW69P/elNarXded11aLUk64ckDabWmOXGa9oOH0mql6s+aDl3S0bwJ5ONIwe01zJfHERQAoCQCCgBQEgEFACiJgAIAlERAAQBKIqAAACWNKqBsr7W9x/bmlmVzbW+wvbX5fGqHx17djNlq++qsxgEAk9toj6Bul7R0yLIbJH09IhZK+npz/zlsz5X0QUkXSVos6YOdggwAgFajCqiIuFfS3iGLr5R0R3P7Dkm/1+ahr5e0ISL2RsRPJW3Q8UEHAMBxunkniXkRsbu5/aSkeW3GzJf0eMv9nc2y49heJWmVJM3WiV20BQCYDFIukoiI0LBvWDGqGmsiYlFELOrTrIy2AAA9rJuAesr26ZLUfN7TZswuSWe23F/QLAMAYFjdBNQ6Sceuyrta0lfajPmapCW2T20ujljSLAMAYFijvcz8TknfknSu7Z22V0r675J+1/ZWSa9r7sv2ItufkaSI2Cvpv0n6TvPx4WYZAADDGtVFEhGxosOqS9uM3SjpT1rur5W0dkzdAQCmLN5JAgBQEgEFACiJgAIAlERAAQBK6uadJAAket/ZF6XV+ovHPpNWS5Le//5VabXmbHdaLZ8wM6/Wof68WtPzfvfv6h0QhvCRw2m1YuBoWq1OOIICAJREQAEASiKgAAAlEVAAgJIIKABASQQUAKAkAgoAUBIBBQAoiYACAJREQAEASiKgAAAlEVAAgJIIKABASQQUAKAkAgoAUBIBBQAoiYACAJREQAEASmLKd2AS+sg5v5Za7xu7bkqr9dvvuy6t1ilbnk6rNe3Is2m14tlDibXy+tLAQF6tccARFACgJAIKAFASAQUAKImAAgCUREABAEoioAAAJRFQAICSCCgAQEkEFACgJAIKAFASAQUAKImAAgCUREABAEoioAAAJRFQAICSRgwo22tt77G9uWXZx2w/avsh23fbPqXDY3fYftj2JtsbMxsHAExuozmCul3S0iHLNkj6lYj4NUk/kPRfh3n8ayPigohYNLYWAQBT0YgBFRH3Sto7ZNk9EdHf3L1P0oLnoTcAwBSWMeX7WyXd1WFdSLrHdkj664hY06mI7VWSVknSbJ2Y0BaALFfM/w9ptW7Y+r/Tan3q2uVptU44dCStlg8dTqulGRkv0w07rVSMw/TxXX3ltj8gqV/S5zsMuTgidtl+saQNth9tjsiO04TXGkma47nRTV8AgN435qv4bL9Z0uWSroqItoESEbuaz3sk3S1p8VifDwAwtYwpoGwvlfReSVdExIEOY06yffKx25KWSNrcbiwAAEON5jLzOyV9S9K5tnfaXinpZkkna/C03Sbbq5uxZ9he3zx0nqRv2n5Q0rcl/UNEfPV5+SoAAJPOiH+DiogVbRbf1mHsE5KWNbe3Szq/q+4AAFMW7yQBACiJgAIAlERAAQBKIqAAACURUACAkggoAEBJBBQAoCQCCgBQEgEFACiJgAIAlERAAQBKIqAAACURUACAkhLnEgaAkd268BVptb74+KfSar3xbden1TrhX4+m1VJ/f1ope2ZarbS+nu08DT1HUACAkggoAEBJBBQAoCQCCgBQEgEFACiJgAIAlERAAQBKIqAAACURUACAkggoAEBJBBQAoCQCCgBQEgEFACiJgAIAlERAAQBKIqAAACURUACAkggoAEBJTPkOoGdddeZvptX63L99Mq3WW666Lq3WzEOH02ppWt4xiQ8cTCrElO8AgB5DQAEASiKgAAAlEVAAgJIIKABASSMGlO21tvfY3tyy7EO2d9ne1Hws6/DYpba/b3ub7RsyGwcATG6jOYK6XdLSNss/GREXNB/rh660PV3SLZIuk3SepBW2z+umWQDA1DFiQEXEvZL2jqH2YknbImJ7RByW9AVJV46hDgBgCurmb1DX2n6oOQV4apv18yU93nJ/Z7OsLdurbG+0vfGIDnXRFgBgMhhrQN0q6RxJF0jaLenj3TYSEWsiYlFELOrTrG7LAQB63JgCKiKeioiBiDgq6W80eDpvqF2Szmy5v6BZBgDAiMYUULZPb7n7+5I2txn2HUkLbZ9te6ak5ZLWjeX5AABTz4hvFmv7TkmXSDrN9k5JH5R0ie0LJIWkHZLe3ow9Q9JnImJZRPTbvlbS1yRNl7Q2IrY8L18FAGDSGTGgImJFm8W3dRj7hKRlLffXSzruEnQAAEbCO0kAAEoioAAAJRFQAICSCCgAQElM+Q4Akla+9OK0Whue+F9ptX571aq0WjOfPpJXa1vWVPRM+Q4A6DEEFACgJAIKAFASAQUAKImAAgCUREABAEoioAAAJRFQAICSCCgAQEkEFACgJAIKAFASAQUAKImAAgCUREABAEoioAAAJRFQAICSCCgAQEkEFACgJKZ8B4Bkrz/jgrRaNz52W1qtv/gvK9Nq9Z0wK6fQNKZ8BwD0GAIKAFASAQUAKImAAgCUREABAEoioAAAJRFQAICSCCgAQEkEFACgJAIKAFASAQUAKImAAgCUREABAEoioAAAJRFQAICSRpwPyvZaSZdL2hMRv9Isu0vSuc2QUyT9LCKOmwDF9g5JT0sakNQfEYuS+gYATHKjmbDwdkk3S/rssQUR8Z+O3bb9cUn7hnn8ayPix2NtEAAwNY0YUBFxr+2z2q2zbUlvkvQ7uW0BAKa6bqd8/y1JT0XE1g7rQ9I9tkPSX0fEmk6FbK+StEqSZuvELtsCgMnhL8/51bRaH9m+Oq3WjW97W0qdo092jqFuA2qFpDuHWX9xROyy/WJJG2w/GhH3thvYhNcaSZrjudFlXwCAHjfmq/hsz5D0B5Lu6jQmInY1n/dIulvS4rE+HwBgaunmMvPXSXo0Ina2W2n7JNsnH7staYmkzV08HwBgChkxoGzfKelbks61vdP2ymbVcg05vWf7DNvrm7vzJH3T9oOSvi3pHyLiq3mtAwAms9Fcxbeiw/I3t1n2hKRlze3tks7vsj8AwBTFO0kAAEoioAAAJRFQAICSCCgAQEkEFACgJAIKAFASAQUAKImAAgCUREABAEoioAAAJRFQAICSCCgAQEkEFACgJEfUm7x2jufGRb50otsAAHTwP3b8S0qdN13+I2156LDbreMICgBQEgEFACiJgAIAlERAAQBKIqAAACURUACAkggoAEBJBBQAoCQCCgBQEgEFACiJgAIAlERAAQBKIqAAACURUACAkggoAEBJBBQAoCQCCgBQEgEFACip5JTvtn8k6YcjDDtN0o/HoZ3nSy/338u9S73dfy/3LvV2/73cu1S3/5dFxL9rt6JkQI2G7Y0RsWii+xirXu6/l3uXerv/Xu5d6u3+e7l3qTf75xQfAKAkAgoAUFIvB9SaiW6gS73cfy/3LvV2/73cu9Tb/fdy71IP9t+zf4MCAExuvXwEBQCYxAgoAEBJ5QPK9lLb37e9zfYNbdbPsn1Xs/5+22eNf5fHs32m7W/Y/p7tLbbf2WbMJbb32d7UfNw4Eb12YnuH7Yeb3ja2WW/bn262/UO2L5yIPtuxfW7Ldt1ke7/t64eMKbP9ba+1vcf25pZlc21vsL21+Xxqh8de3YzZavvq8ev6OT206/9jth9t9o27bZ/S4bHD7mfPtw69f8j2rpZ9Y1mHxw77+jQeOvR/V0vvO2xv6vDYCd32I4qIsh+Spkt6TNLLJc2U9KCk84aM+VNJq5vbyyXdNdF9N72cLunC5vbJkn7QpvdLJP2fie51mK9hh6TThlm/TNI/SrKkV0m6f6J7HmY/elKD/xBYcvtLeo2kCyVtbln2l5JuaG7fIOmjbR43V9L25vOpze1Ti/S/RNKM5vZH2/U/mv1sgnr/kKT3jGK/Gvb1aaL6H7L+45JurLjtR/qofgS1WNK2iNgeEYclfUHSlUPGXCnpjub2lyRdatvj2GNbEbE7Ih5obj8t6RFJ8ye2q3RXSvpsDLpP0im2T5/optq4VNJjETHSu5NMmIi4V9LeIYtb9+07JP1em4e+XtKGiNgbET+VtEHS0uet0Q7a9R8R90REf3P3PkkLxruv0eiw7UdjNK9Pz7vh+m9eC98k6c5xbSpJ9YCaL+nxlvs7dfyL/C/GND8M+yS9aFy6G6XmtOOvS7q/zepX237Q9j/a/vfj2tjIQtI9tr9re1Wb9aP5/lSwXJ1/QCtv/3kRsbu5/aSkeW3G9Mr34K0aPNpuZ6T9bKJc25yeXNvh9GovbPvfkvRURGztsL7qtpdUP6B6nu0XSPo7SddHxP4hqx/Q4Gmn8yXdJOnL493fCC6OiAslXSbpz2y/ZqIb+mXZninpCklfbLO6+vb/hRg8H9OT/xNi+wOS+iV9vsOQivvZrZLOkXSBpN0aPE3Wi1Zo+KOnitv+F6oH1C5JZ7bcX9AsazvG9gxJL5T0k3HpbgS2+zQYTp+PiL8fuj4i9kfEM83t9ZL6bJ82zm12FBG7ms97JN2twVMarUbz/Zlol0l6ICKeGrqi+vaX9NSxU6bN5z1txpT+Hth+s6TLJV3VhOxxRrGfjbuIeCoiBiLiqKS/6dBT9W0/Q9IfSLqr05iK275V9YD6jqSFts9ufhNeLmndkDHrJB27cukNkv6p0w/CeGrO/d4m6ZGI+ESHMS859vcy24s1+P2oEq4n2T752G0N/sF785Bh6yT9cXM136sk7Ws5JVVFx98gK2//Ruu+fbWkr7QZ8zVJS2yf2pyGWtIsm3C2l0p6r6QrIuJAhzGj2c/G3ZC/pf6+2vc0mtenifQ6SY9GxM52K6tu++eY6Ks0RvrQ4JViP9Dg1TIfaJZ9WIM7vSTN1uDpm22Svi3p5RPdc9PXxRo8JfOQpE3NxzJJ10i6phlzraQtGrz65z5J/3Gi+27p/+VNXw82PR7b9q39W9ItzffmYUmLJrrvIV/DSRoMnBe2LCu5/TUYorslHdHg3zJWavBvqV+XtFXS/5U0txm7SNJnWh771mb/3ybpLYX636bBv9Ec2/+PXW17hqT1w+1nBXr/XLNPP6TB0Dl9aO/N/eNenyr03yy//di+3jK21LYf6YO3OgIAlFT9FB8AYIoioAAAJRFQAICSCCgAQEkEFACgJAIKAFASAQUAKOn/A0SyQUomcgOVAAAAAElFTkSuQmCC\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "show_tree(dist.marginals[0])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Probabilistic Context-Free Grammars" + ] + }, + { + "cell_type": "raw", + "metadata": { + "raw_mimetype": "text/restructuredtext" + }, + "source": [ + " .. autoclass:: torch_struct.SentCFG" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Autoregressive" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + " .. autoclass:: torch_struct.Autoregressive" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Base Class" + ] + }, + { + "cell_type": "raw", + "metadata": { + "raw_mimetype": "text/restructuredtext" + }, + "source": [ + ".. autoclass:: torch_struct.StructDistribution\n", + " :members: " + ] + } + ], + "metadata": { + "celltoolbar": "Raw Cell Format", + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.1" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/docs/source/networks.rst b/docs/source/networks.rst new file mode 100644 index 00000000..81a5a2ab --- /dev/null +++ b/docs/source/networks.rst @@ -0,0 +1,24 @@ +================= +Networks and Data +================ + +Networks +================== + +Common structured networks. + + +.. autoclass:: torch_struct.networks.TreeLSTM + +.. autoclass:: torch_struct.networks.NeuralCFG + +.. autoclass:: torch_struct.networks.SpanLSTM + + +Datasets +======== + +Datasets for common structured prediction tasks. + +.. autoclass:: torch_struct.data.ConllXDataset +.. autoclass:: torch_struct.data.ListOpsDataset diff --git a/docs/source/refs.rst b/docs/source/refs.rst new file mode 100644 index 00000000..f2153d73 --- /dev/null +++ b/docs/source/refs.rst @@ -0,0 +1,4 @@ +References +========== + +.. bibliography:: refs.bib diff --git a/torch_struct/distributions.py b/torch_struct/distributions.py index c9d49af3..aeb1ebd0 100644 --- a/torch_struct/distributions.py +++ b/torch_struct/distributions.py @@ -196,6 +196,28 @@ class LinearChainCRF(StructDistribution): struct = LinearChain +class HMM(StructDistribution): + r""" + Represents hidden-markov smoothing with C hidden states. + + Event shape is of the form: + + Parameters: + transition: C X C + emission: V x C + init: C + observations: b x N between [0, V-1] + + Compact representation: N long tensor in [0, ..., C-1] + """ + + def __init__(self, transition, emission, init, observations, lengths=None): + log_potentials = HMM.struct.hmm(transition, emission, init, observations) + super().__init__(log_potentials, lengths) + + struct = LinearChain + + class SemiMarkovCRF(StructDistribution): r""" Represents a semi-markov or segmental CRF with C classes of max width K diff --git a/torch_struct/linearchain.py b/torch_struct/linearchain.py index ae27983f..0da7fdeb 100644 --- a/torch_struct/linearchain.py +++ b/torch_struct/linearchain.py @@ -1,3 +1,22 @@ +r""" + +A linear-chain dynamic program. + +Considers parameterized functions of the form :math:`f: {\cal Y} \rightarrow \mathbb{R}`. + +Combinatorial set :math:`{y_{1:N} \in \cal Y}` with each :math:`y_n \in {1, \ldots, C}` + +Function factors as :math:`f(y) = \prod_{n=1}^N \phi(n, y_n, y_n{-1})` + +Example use cases: + +* Part-of-Speech Tagging +* Sequence Labeling +* Hidden Markov Models + +""" + + import torch from .helpers import _Struct import math