Commit 5bcd715

Created using Colaboratory
srush committed Oct 4, 2019
1 parent 0946e17 commit 5bcd715
Showing 1 changed file with 57 additions and 144 deletions.
notebooks/BertDependencies.ipynb — 201 changes: 57 additions & 144 deletions
@@ -4,7 +4,6 @@
"metadata": {
"colab": {
"name": "Tree.ipynb",
"version": "0.3.2",
"provenance": [],
"collapsed_sections": [],
"include_colab_link": true
@@ -31,30 +30,25 @@
"metadata": {
"id": "dENII4iGN3S4",
"colab_type": "code",
"outputId": "d03c3cdb-9d40-4b58-aaf0-ff28bd1877d3",
"outputId": "7dd16b64-34ca-46c6-d404-c9107587077a",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 139
"height": 52
}
},
"source": [
"!pip install -q torchtext\n",
"!pip install -q pytorch-transformers\n",
"!pip install -qU git+https://github.com/harvardnlp/pytorch-struct\n",
"!git clone http://github.com/srush/temp"
"!pip install -qqq torchtext\n",
"!pip install -qqq pytorch-transformers\n",
"!pip install -qqqU git+https://github.com/harvardnlp/pytorch-struct\n",
"!git clone -q http://github.com/srush/temp"
],
"execution_count": 0,
"execution_count": 1,
"outputs": [
{
"output_type": "stream",
"text": [
"\u001b[K |████████████████████████████████| 184kB 2.9MB/s \n",
"\u001b[K |████████████████████████████████| 808kB 34.1MB/s \n",
"\u001b[K |████████████████████████████████| 655kB 40.6MB/s \n",
"\u001b[K |████████████████████████████████| 1.0MB 37.7MB/s \n",
"\u001b[?25h Building wheel for sacremoses (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
" Building wheel for regex (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
" Building wheel for torch-struct (setup.py) ... \u001b[?25l\u001b[?25hdone\n"
" Building wheel for torch-struct (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
"fatal: destination path 'temp' already exists and is not an empty directory.\n"
],
"name": "stdout"
}
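Aside: the cloned srush/temp repository supplies the CoNLL-X data files (temp/wsj.train.conllx and temp/wsj.dev.conllx) read by the data-loading cell further down; the -q/-qqq flags merely silence pip and git output, as the "fatal: destination path 'temp' already exists" line shows on a re-run.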
@@ -65,33 +59,20 @@
"metadata": {
"id": "EghAJ7aJDl4V",
"colab_type": "code",
"outputId": "5463b0c0-44ba-4863-8b47-c254bf999699",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
}
"colab": {}
},
"source": [
"import torchtext\n",
"import torch\n",
"from torch_struct import DepTree, MaxSemiring\n",
"import torch.nn as nn\n",
"from torch_struct import DependencyCRF\n",
"import torch_struct.data\n",
"import torchtext.data as data\n",
"from pytorch_transformers import AdamW, WarmupLinearSchedule\n",
"\n",
"from pytorch_transformers import *\n",
"\n"
"from pytorch_transformers import *"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"100%|██████████| 213450/213450 [00:00<00:00, 5794504.92B/s]\n"
],
"name": "stderr"
}
]
"outputs": []
},
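Aside: the import change here tracks torch-struct's move from the low-level DepTree/MaxSemiring interface to the DependencyCRF distribution object. A minimal sketch of that object on toy tensors (my own example, not notebook code):

    import torch
    from torch_struct import DependencyCRF

    arc_scores = torch.randn(2, 5, 5)  # (batch, N, N) toy arc potentials
    dist = DependencyCRF(arc_scores)
    dist.partition                     # log-partition, one value per sentence
    dist.argmax                        # best tree as one-hot edges, (2, 5, 5)
    dist.marginals                     # edge marginals under the CRF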
{
"cell_type": "markdown",
@@ -103,39 +84,6 @@
"Parse the conll dependency data."
]
},
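Aside: CoNLL-X files list one token per line with tab-separated columns; the column_map = {1: 0, 6: 1} in the reader below keeps only column 1 (the word form) and column 6 (the head index, where 0 marks the root). A hypothetical two-token sentence for illustration:

    1	Dogs	_	NNS	NNS	_	2	nsubj	_	_
    2	bark	_	VBP	VBP	_	0	root	_	_

Blank lines separate sentences, which is why the reader flushes an example whenever it hits an empty line.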
{
"cell_type": "code",
"metadata": {
"id": "3I6vJEAWN_yB",
"colab_type": "code",
"colab": {}
},
"source": [
"class ConllXDataset(data.Dataset):\n",
" def __init__(self, path, fields, encoding=\"utf-8\", separator=\"\\t\", **kwargs):\n",
" examples = []\n",
" columns = [[], []]\n",
" column_map = {1 : 0, 6: 1}\n",
" with open(path, encoding=encoding) as input_file:\n",
" for line in input_file:\n",
" line = line.strip()\n",
" if line == \"\":\n",
" if columns:\n",
" examples.append(data.Example.fromlist(columns, fields))\n",
" columns = [[], []]\n",
" else:\n",
" for i, column in enumerate(line.split(separator)):\n",
" if i in column_map:\n",
" columns[column_map[i]].append(column)\n",
"\n",
" if columns:\n",
" examples.append(data.Example.fromlist(columns, fields))\n",
" super(ConllXDataset, self).__init__(examples, fields,\n",
" **kwargs)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
@@ -167,10 +115,10 @@
" postprocessing=batch_num)\n",
"WORD = torch_struct.data.SubTokenizedField(tokenizer)\n",
"HEAD.is_target = True\n",
"train = ConllXDataset(\"temp/wsj.train.conllx\", (('word', WORD), ('head', HEAD)),\n",
"train = torch_struct.data.ConllXDataset(\"temp/wsj.train.conllx\", (('word', WORD), ('head', HEAD)),\n",
" filter_pred=lambda x: 5 < len(x.word[0]) < 40)\n",
"train_iter = torch_struct.data.TokenBucket(train, 750)\n",
"val = ConllXDataset(\"temp/wsj.dev.conllx\", (('word', WORD), ('head', HEAD)),\n",
"val = torch_struct.data.ConllXDataset(\"temp/wsj.dev.conllx\", (('word', WORD), ('head', HEAD)),\n",
" filter_pred=lambda x: 5 < len(x.word[0]) < 40)\n",
"val_iter = torchtext.data.BucketIterator(val, \n",
" batch_size=20,\n",
@@ -194,57 +142,33 @@
"metadata": {
"id": "QUKVqGUsSmYY",
"colab_type": "code",
"outputId": "3c41337b-cfb3-48e3-e2db-33b28b95d27a",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 52
}
"colab": {}
},
"source": [
"model = model_class.from_pretrained(pretrained_weights)\n",
"model.cuda()\n",
"H = 1024 #768\n",
"linear = torch.rand(H, H).cuda().requires_grad_(True)\n",
"bilinear = torch.rand(H, H).cuda().requires_grad_(True)\n",
"root = torch.rand(H).cuda().requires_grad_(True)\n",
"root.data.normal_(mean=0, std=0.02)\n",
"bilinear.data.normal_(mean=0, std=0.02)\n",
"linear.data.normal_(mean=0, std=0.02)\n",
"class Model(nn.Module):\n",
" def __init__(self, hidden):\n",
" super().__init__()\n",
" self.base_model = model_class.from_pretrained(pretrained_weights)\n",
" self.linear = nn.Linear(H, H)\n",
" self.bilinear = nn.Linear(H, H)\n",
" self.root = nn.Parameter(torch.rand(H))\n",
" self.dropout = nn.Dropout(0.1)\n",
" \n",
" def forward(self, words, mapper):\n",
" out = self.dropout(self.base_model(words))\n",
" out = torch.einsum(\"bca,bch->bah\", mapper.float().cuda(), out)\n",
" final2 = torch.einsum(\"bnh,hg->bng\", out, self.linear.weight)\n",
" final = torch.einsum(\"bnh,hg,bmg->bnm\", out, self.bilinear.weight, final2)\n",
" root_score = torch.einsum(\"bnh,h->bn\", out, self.root)\n",
" final = final[:, 1:-1, 1:-1]\n",
" N = final.shape[1]\n",
" final[:, torch.arange(N), torch.arange(N)] += root_score[:, 1:-1]\n",
" return final\n",
"\n",
"opt = AdamW([linear, bilinear, root] + list(model.parameters()), lr=1e-4, eps=1e-8)\n",
"scheduler = WarmupLinearSchedule(opt, warmup_steps=20, t_total=2500)"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"100%|██████████| 521/521 [00:00<00:00, 231731.96B/s]\n",
"100%|██████████| 1338740706/1338740706 [00:27<00:00, 48483202.14B/s]\n"
],
"name": "stderr"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "xOYWpYiCI2w5",
"colab_type": "code",
"colab": {}
},
"source": [
"def potentials(words, mapper):\n",
" out = model(words)\n",
" out = torch.nn.functional.dropout(out[0], p=0.1, training=model.training)\n",
" out = torch.einsum(\"bca,bch->bah\", mapper.float().cuda(), out)\n",
" final2 = torch.einsum(\"bnh,hg->bng\", out, linear)\n",
" final = torch.einsum(\"bnh,hg,bmg->bnm\", out, bilinear, final2)\n",
" root_score = torch.einsum(\"bnh,h->bn\", out, root)\n",
" final = final[:, 1:-1, 1:-1]\n",
" N = final.shape[1]\n",
" final[:, torch.arange(N), torch.arange(N)] += root_score[:, 1:-1]\n",
" return final"
"model = Model(H)\n",
"wandb.watch(model)\n",
"model.cuda()\n"
],
"execution_count": 0,
"outputs": []
@@ -276,55 +200,56 @@
" label, lengths = ex.head\n",
" batch, _ = label.shape\n",
"\n",
" final = potentials(words.cuda(), mapper)\n",
" final = model(words.cuda(), mapper)\n",
" for b in range(batch):\n",
" final[b, lengths[b]-1:, :] = 0\n",
" final[b, :, lengths[b]-1:] = 0\n",
" out = DepTree(MaxSemiring).marginals(final, lengths=lengths)\n",
" gold = DepTree.to_parts(label, lengths=lengths).type_as(out)\n",
" dist = DependencyCRF(final, lengths=lengths)\n",
" argmax = dist.argmax\n",
" gold = dist.struct.to_parts(label, lengths=lengths).type_as(argmax)\n",
" incorrect_edges += (out[:, :].cpu() - gold[:, :].cpu()).abs().sum() / 2.0\n",
" total_edges += gold.sum()\n",
"\n",
" print(total_edges, incorrect_edges) \n",
" model.train()\n",
"\n",
"def train(train_iter):\n",
"def train(train_iter, val_iter, model):\n",
" opt = AdamW(model.parameters(), lr=1e-4, eps=1e-8)\n",
" scheduler = WarmupLinearSchedule(opt, warmup_steps=20, t_total=2500)\n",
" model.train()\n",
" losses = []\n",
" for i, ex in enumerate(train_iter):\n",
" opt.zero_grad()\n",
" words, mapper, _ = ex.word\n",
" label, lengths = ex.head\n",
" batch, _ = label.shape\n",
" # Model\n",
" final = potentials(words.cuda(), mapper)\n",
" \n",
" # Model\n",
" final = model(words.cuda(), mapper)\n",
" for b in range(batch):\n",
" final[b, lengths[b]-1:, :] = 0\n",
" final[b, :, lengths[b]-1:] = 0\n",
" \n",
" if not lengths.max() <= final.shape[1] + 1:\n",
" print(\"fail\")\n",
" continue\n",
" log_partition = DepTree().sum(final, lengths=lengths)\n",
" # Compute loss.\n",
" labels = DepTree.to_parts(label, lengths=lengths).type_as(final)\n",
" log_prob = DepTree().score(final, labels) - log_partition\n",
" \n",
" dist = DependencyCRF(final, lengths=lengths)\n",
"\n",
" labels = dist.struct.to_parts(label, lengths=lengths).type_as(final)\n",
" log_prob = dist.log_prob(final, labels)\n",
"\n",
" loss = log_prob.sum()\n",
" (-loss).backward()\n",
" torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n",
"\n",
" opt.step()\n",
" scheduler.step()\n",
" opt.zero_grad()\n",
" losses.append(loss.detach())\n",
" if i % 50 == 1: \n",
" print(-torch.tensor(losses).mean(), words.shape)\n",
" losses = []\n",
" if i % 600 == 500:\n",
" validate()\n",
" \n"
" validate(val_iter) "
],
"execution_count": 0,
"outputs": []
@@ -334,25 +259,13 @@
"metadata": {
"id": "D-HW3z1VT2MG",
"colab_type": "code",
"outputId": "38f50978-e21c-443c-bed7-79b47a45c877",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 17
}
"colab": {}
},
"source": [
"train(train_iter)"
"train(train_iter, val_iter, model)"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"tensor(3448.2988) torch.Size([58, 13])\n"
],
"name": "stdout"
}
]
"outputs": []
}
]
}
