<a href="https://colab.research.google.com/github/harvardnlp/pytorch-struct/blob/master/notebooks/BertDependencies.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install -qqq torchtext
!pip install -qqq pytorch-transformers
!pip install -qqqU git+https://github.com/harvardnlp/pytorch-struct
!git clone -q http://github.com/srush/temp

  Building wheel for torch-struct (setup.py) ... [?25l[?25hdone
fatal: destination path 'temp' already exists and is not an empty directory.


In [0]:
import torchtext
import torch
import torch.nn as nn
from torch_struct import DependencyCRF
import torch_struct.data
import torchtext.data as data
from pytorch_transformers import AdamW, WarmupLinearSchedule
from pytorch_transformers import *

Parse the CoNLL-X dependency data.

TorchText batching setup.

In [0]:
# Name of the pretrained checkpoint, and the classes used to load its encoder
# and its tokenizer. The tokenizer is loaded here; the encoder is loaded later
# through `model_class`.
pretrained_weights = 'bert-large-cased'
model_class = BertModel
tokenizer_class = BertTokenizer
tokenizer = tokenizer_class.from_pretrained(pretrained_weights)
def batch_num(nums):
    """Pad a list of integer sequences into one zero-padded LongTensor.

    Args:
        nums: list of sequences (lists of ints), possibly of different lengths.

    Returns:
        A pair ``(out, lengths)``: ``out`` is a ``(len(nums), max_len)`` int64
        tensor holding each sequence left-aligned and padded with zeros on the
        right, and ``lengths`` is an int64 tensor with the original length of
        each sequence. An empty ``nums`` yields a ``(0, 0)`` tensor and an
        empty ``lengths``.
    """
    lengths = torch.tensor([len(seq) for seq in nums], dtype=torch.long)
    # `max()` on an empty tensor raises, so an empty batch is handled explicitly.
    max_len = int(lengths.max()) if len(nums) > 0 else 0
    out = torch.zeros(len(nums), max_len, dtype=torch.long)
    for row, seq in enumerate(nums):
        out[row, :len(seq)] = torch.tensor(seq, dtype=torch.long)
    return out, lengths
def _sentence_length_ok(example):
    """Keep an example only if ``len(example.word[0])`` is strictly between 5 and 40."""
    return 5 < len(example.word[0]) < 40


# HEAD holds the head index of every token: the raw strings are converted to
# ints per example, and a batch is padded into (tensor, lengths) by `batch_num`.
HEAD = data.RawField(preprocessing=lambda x: [int(i) for i in x],
                     postprocessing=batch_num)
HEAD.is_target = True
# WORD runs the BERT tokenizer over each sentence.
# NOTE(review): validate/train unpack a batch of this field as
# `words, mapper, _` — confirm the exact contents against SubTokenizedField.
WORD = torch_struct.data.SubTokenizedField(tokenizer)

# Both splits are read with the same column-to-field mapping.
CONLL_FIELDS = (('word', WORD), ('head', HEAD))

# The datasets are deliberately NOT called `train` / `val`: a function named
# `train` is defined further down, and sharing the name means that re-running
# this cell afterwards would replace the function with a dataset.
train_data = torch_struct.data.ConllXDataset("temp/wsj.train.conllx", CONLL_FIELDS,
                                             filter_pred=_sentence_length_ok)
train_iter = torch_struct.data.TokenBucket(train_data, 750)
val_data = torch_struct.data.ConllXDataset("temp/wsj.dev.conllx", CONLL_FIELDS,
                                           filter_pred=_sentence_length_ok)
val_iter = torchtext.data.BucketIterator(val_data,
    batch_size=20,
    device="cuda:0")

Build a BERT model that computes the arc potentials.

In [0]:
# Hidden size of the encoder: 1024 matches a BERT-large checkpoint; the 768
# left by the original author is presumably for a BERT-base checkpoint.
H = 1024 #768
class Model(nn.Module):
    """BERT encoder followed by a bilinear scorer that yields arc potentials.

    Args:
        hidden: size of the encoder's hidden states; it sizes the two
            projection matrices and the root vector.
    """

    def __init__(self, hidden):
        super().__init__()
        self.base_model = model_class.from_pretrained(pretrained_weights)
        # Only the `.weight` matrices of these layers are used in `forward`;
        # their biases are never applied.
        self.linear = nn.Linear(hidden, hidden)
        self.bilinear = nn.Linear(hidden, hidden)
        self.root = nn.Parameter(torch.rand(hidden))
        self.dropout = nn.Dropout(0.1)

    def forward(self, words, mapper):
        """Score every ordered pair of positions.

        Args:
            words: token ids passed straight to the encoder.
            mapper: tensor indexed ``(batch, encoder position, output position)``
                that is contracted with the encoder output over the encoder
                position axis.

        Returns:
            A ``(batch, N, N)`` tensor where ``N`` is the number of output
            positions minus two (the first and last position are dropped).
            The diagonal additionally carries the per-position root score.
        """
        # The encoder returns a tuple; element 0 is the per-position hidden
        # states, which is what dropout and the scorer operate on.
        out = self.dropout(self.base_model(words)[0])
        # Follow the encoder output's device rather than assuming CUDA.
        out = torch.einsum("bca,bch->bah", mapper.float().to(out.device), out)
        final2 = torch.einsum("bnh,hg->bng", out, self.linear.weight)
        final = torch.einsum("bnh,hg,bmg->bnm", out, self.bilinear.weight, final2)
        root_score = torch.einsum("bnh,h->bn", out, self.root)
        # Drop the first and last position on both axes
        # (presumably the [CLS] / [SEP] slots — TODO confirm against the field).
        final = final[:, 1:-1, 1:-1]
        N = final.shape[1]
        diag = torch.arange(N)
        final[:, diag, diag] += root_score[:, 1:-1]
        return final

# Build the parser and move it to the GPU. Using the assignment form keeps the
# cell from echoing the full module repr as its output.
# No experiment-tracking hook is registered: this notebook never imports,
# installs or initialises wandb, so calling it here fails on a fresh kernel.
model = Model(H).cuda()


Generic training loop. 

In [0]:
def validate(data_iter=None):
    """Decode every batch of an iterator and count mis-predicted edges.

    Uses the module-level ``model``. The model is put in eval mode for the
    duration of the call and is always switched back to train mode afterwards.

    Args:
        data_iter: iterable of batches exposing ``.word`` and ``.head``.
            When omitted, the module-level ``val_iter`` is used.

    Returns:
        ``(total_edges, incorrect_edges)`` as Python floats; the same two
        numbers are also printed.
    """
    if data_iter is None:
        data_iter = val_iter
    incorrect_edges = 0.0
    total_edges = 0.0
    model.eval()
    try:
        # No gradients are needed for evaluation.
        with torch.no_grad():
            for ex in data_iter:
                words, mapper, _ = ex.word
                label, lengths = ex.head
                batch, _ = label.shape

                final = model(words.cuda(), mapper)
                # Zero the scores from index `lengths[b] - 1` onwards on both axes.
                # NOTE(review): presumably this masks padding — confirm the
                # length convention expected by DependencyCRF.
                for b in range(batch):
                    final[b, lengths[b]-1:, :] = 0
                    final[b, :, lengths[b]-1:] = 0
                dist = DependencyCRF(final, lengths=lengths)
                argmax = dist.argmax
                gold = dist.struct.to_parts(label, lengths=lengths).type_as(argmax)
                # Predicted and gold parts are compared entry-wise; the summed
                # absolute difference is halved as in the original computation.
                incorrect_edges += ((argmax - gold).abs().sum() / 2.0).item()
                total_edges += gold.sum().item()
    finally:
        model.train()

    print(total_edges, incorrect_edges)
    return total_edges, incorrect_edges

def train(train_iter, val_iter, model, lr=1e-4, warmup_steps=20, t_total=2500):
    """Run one pass over ``train_iter``, maximising the gold-tree log-likelihood.

    Args:
        train_iter: iterable of training batches exposing ``.word`` and ``.head``.
        val_iter: validation iterator handed to ``validate`` periodically.
        model: module called as ``model(words, mapper)`` to produce arc scores.
        lr: learning rate for AdamW.
        warmup_steps: warm-up steps of the linear learning-rate schedule.
        t_total: total number of steps of the linear learning-rate schedule.

    Every 50 batches the mean negative log-likelihood since the last report is
    printed; every 600 batches (offset 500) ``validate`` is called.
    """
    opt = AdamW(model.parameters(), lr=lr, eps=1e-8)
    scheduler = WarmupLinearSchedule(opt, warmup_steps=warmup_steps, t_total=t_total)
    model.train()
    losses = []
    for i, ex in enumerate(train_iter):
        opt.zero_grad()
        words, mapper, _ = ex.word
        label, lengths = ex.head
        batch, _ = label.shape

        # Model
        final = model(words.cuda(), mapper)
        # Zero the scores from index `lengths[b] - 1` onwards on both axes.
        # NOTE(review): presumably this masks padding — confirm the length
        # convention expected by DependencyCRF.
        for b in range(batch):
            final[b, lengths[b]-1:, :] = 0
            final[b, :, lengths[b]-1:] = 0

        # Skip batches whose longest label sequence does not fit the score matrix.
        if not lengths.max() <= final.shape[1] + 1:
            print("fail")
            continue
        dist = DependencyCRF(final, lengths=lengths)

        labels = dist.struct.to_parts(label, lengths=lengths).type_as(final)
        # `log_prob` takes only the structure to score; the potentials are
        # already held by `dist`.
        log_prob = dist.log_prob(labels)

        loss = log_prob.sum()
        (-loss).backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        opt.step()
        scheduler.step()
        # Store a plain float so the history does not keep device tensors alive.
        losses.append(loss.item())
        if i % 50 == 1:
            print(-torch.tensor(losses).mean(), words.shape)
            losses = []
        if i % 600 == 500:
            validate(val_iter)

In [0]:
train(train_iter, val_iter, model)