Commit 46d6595: added PKM notebook implementation
Guillaume Lample committed Aug 26, 2019 (1 parent: a60f46d)
Showing 2 changed files with 328 additions and 0 deletions.
327 changes: 327 additions & 0 deletions PKM-layer.ipynb
@@ -0,0 +1,327 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Product-Key Memory (PKM)\n",
"**Minimalist implementation of a Product-Key Memory layer** https://arxiv.org/abs/1907.05242\n",
"\n",
"This notebook contains a simple implementation of a PKM layer.\n",
"<br>\n",
"Overall, the PKM layer can be seen as a network with very high capacity that maps elements from $R^d$ to $R^n$, but very efficiently.\n",
"<br>\n",
"In particular, a 12-layer transformer model that leverages a PKM layer outperforms a 24-layer model without memory, and is almost twice faster at inference.\n",
"\n",
"A more detailed implementation can be found at https://github.com/facebookresearch/XLM/tree/master/src/model/memory,\n",
"with options to make the query network more powerful, to shuffle the key indices, to compute the value scores differently\n",
"than with a softmax, etc., but the code below is much simpler and implements a configuration that worked well in our experiments (and that we used to report the majority of our results).\n",
"\n",
"#### Note: at training time, we recommend to use a different optimizer for the values, as these are learned with sparse updates. In particular, we obtained our best performance with the Adam optimizer, and a constant learning rate of 1e-4 to learn the values, independently of the optimizer / learning rate used to learn the rest of the network."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"import math\n",
"import numpy as np\n",
"import torch\n",
"from torch import nn\n",
"from torch.nn import functional as F"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def get_uniform_keys(n_keys, dim, seed):\n",
" \"\"\"\n",
" Generate random uniform keys (same initialization as nn.Linear).\n",
" \"\"\"\n",
" rng = np.random.RandomState(seed)\n",
" bound = 1 / math.sqrt(dim)\n",
" keys = rng.uniform(-bound, bound, (n_keys, dim))\n",
" return keys.astype(np.float32)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"class HashingMemory(nn.Module):\n",
"\n",
" def __init__(self, input_dim, output_dim, params):\n",
"\n",
" super().__init__()\n",
"\n",
" # global parameters\n",
" self.input_dim = input_dim\n",
" self.output_dim = output_dim\n",
" self.k_dim = params.k_dim\n",
" self.v_dim = output_dim\n",
" self.n_keys = params.n_keys\n",
" self.size = self.n_keys ** 2\n",
" self.heads = params.heads\n",
" self.knn = params.knn\n",
" assert self.k_dim >= 2 and self.k_dim % 2 == 0\n",
"\n",
" # dropout\n",
" self.input_dropout = params.input_dropout\n",
" self.query_dropout = params.query_dropout\n",
" self.value_dropout = params.value_dropout\n",
"\n",
" # initialize keys / values\n",
" self.initialize_keys()\n",
" self.values = nn.EmbeddingBag(self.size, self.v_dim, mode='sum', sparse=params.sparse)\n",
" nn.init.normal_(self.values.weight, mean=0, std=self.v_dim ** -0.5)\n",
"\n",
" # query network\n",
" self.query_proj = nn.Sequential(*filter(None, [\n",
" nn.Linear(self.input_dim, self.heads * self.k_dim, bias=True),\n",
" nn.BatchNorm1d(self.heads * self.k_dim) if params.query_batchnorm else None\n",
" ]))\n",
"\n",
" if params.query_batchnorm:\n",
" print(\"WARNING: Applying batch normalization to queries improves the performance \"\n",
" \"and memory usage. But if you use it, be sure that you use batches of \"\n",
" \"sentences with the same size at training time (i.e. without padding). \"\n",
" \"Otherwise, the padding token will result in incorrect mean/variance \"\n",
" \"estimations in the BatchNorm layer.\\n\")\n",
"\n",
" def initialize_keys(self):\n",
" \"\"\"\n",
" Create two subkey sets per head.\n",
" `self.keys` is of shape (heads, 2, n_keys, k_dim // 2)\n",
" \"\"\"\n",
" half = self.k_dim // 2\n",
" keys = nn.Parameter(torch.from_numpy(np.array([\n",
" get_uniform_keys(self.n_keys, half, seed=(2 * i + j))\n",
" for i in range(self.heads)\n",
" for j in range(2)\n",
" ])).view(self.heads, 2, self.n_keys, half))\n",
" self.keys = nn.Parameter(keys)\n",
"\n",
" def _get_indices(self, query, subkeys):\n",
" \"\"\"\n",
" Generate scores and indices for a specific head.\n",
" \"\"\"\n",
" assert query.dim() == 2 and query.size(1) == self.k_dim\n",
" bs = query.size(0)\n",
" knn = self.knn\n",
" half = self.k_dim // 2\n",
" n_keys = len(subkeys[0])\n",
"\n",
" # split query for product quantization\n",
" q1 = query[:, :half] # (bs,half)\n",
" q2 = query[:, half:] # (bs,half)\n",
"\n",
" # compute indices with associated scores\n",
" scores1 = F.linear(q1, subkeys[0], bias=None) # (bs,n_keys)\n",
" scores2 = F.linear(q2, subkeys[1], bias=None) # (bs,n_keys)\n",
" scores1, indices1 = scores1.topk(knn, dim=1) # (bs,knn)\n",
" scores2, indices2 = scores2.topk(knn, dim=1) # (bs,knn)\n",
"\n",
" # cartesian product on best candidate keys\n",
" all_scores = (\n",
" scores1.view(bs, knn, 1).expand(bs, knn, knn) +\n",
" scores2.view(bs, 1, knn).expand(bs, knn, knn)\n",
" ).view(bs, -1) # (bs,knn**2)\n",
" all_indices = (\n",
" indices1.view(bs, knn, 1).expand(bs, knn, knn) * n_keys +\n",
" indices2.view(bs, 1, knn).expand(bs, knn, knn)\n",
" ).view(bs, -1) # (bs,knn**2)\n",
"\n",
" # select best scores with associated indices\n",
" scores, best_indices = torch.topk(all_scores, k=knn, dim=1) # (bs,knn)\n",
" indices = all_indices.gather(1, best_indices) # (bs,knn)\n",
"\n",
" assert scores.shape == indices.shape == (bs, knn)\n",
" return scores, indices\n",
"\n",
" def get_indices(self, query):\n",
" \"\"\"\n",
" Generate scores and indices.\n",
" \"\"\"\n",
" assert query.dim() == 2 and query.size(1) == self.k_dim\n",
" query = query.view(-1, self.heads, self.k_dim)\n",
" bs = len(query)\n",
" outputs = [self._get_indices(query[:, i], self.keys[i]) for i in range(self.heads)]\n",
" s = torch.cat([s.view(bs, 1, self.knn) for s, _ in outputs], 1) # (bs,heads,knn)\n",
" i = torch.cat([i.view(bs, 1, self.knn) for _, i in outputs], 1) # (bs,heads,knn)\n",
" return s.view(-1, self.knn), i.view(-1, self.knn)\n",
"\n",
" def forward(self, input):\n",
" \"\"\"\n",
" Read from the memory.\n",
" \"\"\"\n",
" # input dimensions\n",
" assert input.shape[-1] == self.input_dim\n",
" prefix_shape = input.shape[:-1]\n",
" bs = np.prod(prefix_shape)\n",
"\n",
" # compute query\n",
" input = F.dropout(input, p=self.input_dropout, training=self.training) # (...,i_dim)\n",
" query = self.query_proj(input.contiguous().view(-1, self.input_dim)) # (bs,heads*k_dim)\n",
" query = query.view(bs * self.heads, self.k_dim) # (bs*heads,k_dim)\n",
" query = F.dropout(query, p=self.query_dropout, training=self.training) # (bs*heads,k_dim)\n",
" assert query.shape == (bs * self.heads, self.k_dim)\n",
"\n",
" # retrieve indices and scores\n",
" scores, indices = self.get_indices(query) # (bs*heads,knn)\n",
" scores = F.softmax(scores.float(), dim=-1).type_as(scores) # (bs*heads,knn)\n",
"\n",
" # merge heads / knn (since we sum heads)\n",
" indices = indices.view(bs, self.heads * self.knn) # (bs,heads*knn)\n",
" scores = scores.view(bs, self.heads * self.knn) # (bs,heads*knn)\n",
"\n",
" # weighted sum of values\n",
" output = self.values(indices, per_sample_weights=scores) # (bs,v_dim)\n",
" output = F.dropout(output, p=self.value_dropout, training=self.training)# (bs,v_dim)\n",
"\n",
" # reshape output\n",
" if len(prefix_shape) >= 2:\n",
" output = output.view(prefix_shape + (self.v_dim,)) # (...,v_dim)\n",
"\n",
" return output\n",
"\n",
" @staticmethod\n",
" def register_args(parser):\n",
" \"\"\"\n",
" Register memory parameters.\n",
" \"\"\"\n",
" # memory parameters\n",
" parser.add_argument(\"--sparse\", type=bool_flag, default=False,\n",
" help=\"Perform sparse updates for the values\")\n",
" parser.add_argument(\"--k_dim\", type=int, default=256,\n",
" help=\"Memory keys dimension\")\n",
" parser.add_argument(\"--heads\", type=int, default=4,\n",
" help=\"Number of memory heads\")\n",
" parser.add_argument(\"--knn\", type=int, default=32,\n",
" help=\"Number of memory slots to read / update - k-NN to the query\")\n",
" parser.add_argument(\"--n_keys\", type=int, default=512,\n",
" help=\"Number of keys\")\n",
" parser.add_argument(\"--query_batchnorm\", type=bool_flag, default=False,\n",
" help=\"Query MLP batch norm\")\n",
"\n",
" # dropout\n",
" parser.add_argument(\"--input_dropout\", type=float, default=0,\n",
" help=\"Input dropout\")\n",
" parser.add_argument(\"--query_dropout\", type=float, default=0,\n",
" help=\"Query dropout\")\n",
" parser.add_argument(\"--value_dropout\", type=float, default=0,\n",
" help=\"Value dropout\")"
]
},
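{
"cell_type": "markdown",
"metadata": {},
"source": [
"The core trick is the product decomposition of the keys: each of the $n\\_keys^2$ memory slots is identified by a pair of sub-keys $(i_1, i_2)$ with flat index $i_1 \\times n\\_keys + i_2$, so only $2 \\times n\\_keys$ sub-keys per head need to be stored and compared against. The toy cell below (an illustrative sketch with tiny, made-up sizes, not the configuration used here) reproduces the candidate-selection logic of `_get_indices` for a single query:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# toy illustration of product-key candidate selection (tiny sizes)\n",
"n_keys, knn, half = 4, 2, 3\n",
"q1 = torch.randn(1, half)  # first half of a query\n",
"q2 = torch.randn(1, half)  # second half of a query\n",
"k1 = torch.randn(n_keys, half)  # first sub-key set\n",
"k2 = torch.randn(n_keys, half)  # second sub-key set\n",
"s1, i1 = (q1 @ k1.t()).topk(knn, dim=1)  # best knn sub-keys on each side\n",
"s2, i2 = (q2 @ k2.t()).topk(knn, dim=1)\n",
"# the slot keyed by the pair (i1, i2) has flat index i1 * n_keys + i2\n",
"flat = (i1.view(-1, knn, 1) * n_keys + i2.view(-1, 1, knn)).view(1, -1)\n",
"print(flat)  # knn ** 2 candidate slots out of the n_keys ** 2 in memory"
]
},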
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"class AttrDict(dict):\n",
" def __init__(self, *args, **kwargs):\n",
" super(AttrDict, self).__init__(*args, **kwargs)\n",
" self.__dict__ = self\n",
"\n",
"\n",
"params = AttrDict({\n",
" \"sparse\": False,\n",
" \"k_dim\": 128,\n",
" \"heads\": 4,\n",
" \"knn\": 32,\n",
" \"n_keys\": 512, # the memory will have (n_keys ** 2) values\n",
" \"query_batchnorm\": True,\n",
" \"input_dropout\": 0,\n",
" \"query_dropout\": 0,\n",
" \"value_dropout\": 0,\n",
"})"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"WARNING: Applying batch normalization to queries improves the performance and memory usage. But if you use it, be sure that you use batches of sentences with the same size at training time (i.e. without padding). Otherwise, the padding token will result in incorrect mean/variance estimations in the BatchNorm layer.\n",
"\n",
"HashingMemory(\n",
" (values): EmbeddingBag(262144, 100, mode=sum)\n",
" (query_proj): Sequential(\n",
" (0): Linear(in_features=50, out_features=512, bias=True)\n",
" (1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
")\n"
]
}
],
"source": [
"device = 'cuda' # cpu / cuda\n",
"input_dim = 50\n",
"output_dim = 100\n",
"memory = HashingMemory(input_dim, output_dim, params).to(device=device)\n",
"print(memory)"
]
},
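{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sketch of the optimizer setup recommended in the introduction, assuming the `memory` module from the previous cell: a dedicated Adam optimizer with a constant learning rate of 1e-4 for the values, and a separate optimizer for everything else (the learning rate of 1e-3 here is an arbitrary placeholder). With `sparse=True`, `torch.optim.SparseAdam` would be needed for the values instead."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# separate optimizer for the values, which are learned with sparse updates\n",
"value_params = [memory.values.weight]\n",
"other_params = [p for n, p in memory.named_parameters() if n != 'values.weight']\n",
"value_optimizer = torch.optim.Adam(value_params, lr=1e-4)  # constant LR for the values\n",
"main_optimizer = torch.optim.Adam(other_params, lr=1e-3)  # placeholder LR for the rest"
]
},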
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.14277362823486328\n",
"torch.Size([2, 3, 4, 100])\n"
]
}
],
"source": [
"x = torch.randn(2, 3, 4, input_dim).to(device=device)\n",
"output = memory(x)\n",
"print(output.sum().item())\n",
"print(output.shape)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
1 change: 1 addition & 0 deletions README.md
@@ -605,6 +605,7 @@ python glue-xnli.py
## V. Product-Key Memory Layers (PKM)

XLM also implements the Product-Key Memory layer (PKM) described in [[4]](https://arxiv.org/abs/1907.05242). To add a memory to (for instance) layers 4 and 7 of an encoder, simply provide `--use_memory true --mem_enc_positions 4,7` as arguments to `train.py` (and similarly for `--mem_dec_positions` and the decoder). All memory layer parameters can be found [here](https://github.com/facebookresearch/XLM/blob/master/src/model/memory/memory.py#L225).
A minimal implementation of the PKM layer, using the same configuration as in the paper, can be found in this **[IPython notebook](https://github.com/facebookresearch/XLM/blob/master/PKM-layer.ipynb)**.
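For instance (all other training arguments elided):

```
python train.py ... --use_memory true --mem_enc_positions 4,7
```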


## Frequently Asked Questions