<a href="https://colab.research.google.com/github/harvardnlp/pytorch-struct/blob/master/notebooks/BertDependencies.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
# Colab setup: install into the kernel's environment (%pip rather than !pip).
# NOTE(review): the data pipeline below uses the legacy torchtext.data API
# (Dataset, RawField, BucketIterator, Example), which newer torchtext releases
# removed -- pin a compatible torchtext version for reproducible runs.
%pip install -q torchtext
%pip install -q pytorch-transformers==1.2.0
%pip install -qU git+https://github.com/harvardnlp/pytorch-struct
# Fetch the WSJ CoNLL-X splits over HTTPS; skip if already cloned so the cell can be re-run.
!test -d temp || git clone https://github.com/srush/temp

[K     |████████████████████████████████| 184kB 2.9MB/s 
[K     |████████████████████████████████| 808kB 34.1MB/s 
[K     |████████████████████████████████| 655kB 40.6MB/s 
[K     |████████████████████████████████| 1.0MB 37.7MB/s 
[?25h  Building wheel for sacremoses (setup.py) ... [?25l[?25hdone
  Building wheel for regex (setup.py) ... [?25l[?25hdone
  Building wheel for torch-struct (setup.py) ... [?25l[?25hdone


In [0]:
import torchtext
import torch
from torch_struct import DepTree, MaxSemiring
import torch_struct.data
import torchtext.data as data
from pytorch_transformers import AdamW, WarmupLinearSchedule

from pytorch_transformers import *



100%|██████████| 213450/213450 [00:00<00:00, 5794504.92B/s]


Parse the CoNLL-X dependency data: each sentence becomes one example holding its words and their head indices.

In [0]:
class ConllXDataset(data.Dataset):
    """torchtext Dataset read from a CoNLL-X style, tab-separated file.

    Sentences are separated by blank lines. For every token line, column 1
    (the word form) is collected into the first field and column 6 (the head
    index) into the second, so each sentence becomes one ``Example`` built
    with ``data.Example.fromlist([words, heads], fields)``.

    Args:
        path: file to read.
        fields: field specification passed to ``Example.fromlist`` and the
            base ``Dataset`` (here ``(('word', WORD), ('head', HEAD))``).
        encoding: text encoding of the file.
        separator: column separator within a token line.
        **kwargs: forwarded to ``data.Dataset.__init__`` (e.g. ``filter_pred``).
    """

    def __init__(self, path, fields, encoding="utf-8", separator="\t", **kwargs):
        examples = []
        # Source column index -> position in `columns` (0 = word, 1 = head).
        column_map = {1: 0, 6: 1}
        columns = [[], []]
        with open(path, encoding=encoding) as input_file:
            for line in input_file:
                line = line.strip()
                if line == "":
                    # Sentence boundary. `columns` is a list of two lists and so
                    # is always truthy; check its contents so that consecutive or
                    # trailing blank lines do not produce empty examples.
                    if any(columns):
                        examples.append(data.Example.fromlist(columns, fields))
                    columns = [[], []]
                else:
                    for i, column in enumerate(line.split(separator)):
                        if i in column_map:
                            columns[column_map[i]].append(column)

            # Flush the last sentence when the file does not end with a blank line.
            if any(columns):
                examples.append(data.Example.fromlist(columns, fields))
        super(ConllXDataset, self).__init__(examples, fields, **kwargs)

TorchText setup: fields, train/validation datasets, and batching iterators.

In [0]:
# One pretrained checkpoint is shared by the tokenizer and the encoder.
pretrained_weights = 'bert-large-cased'
model_class = BertModel
tokenizer_class = BertTokenizer
tokenizer = tokenizer_class.from_pretrained(pretrained_weights)
def batch_num(nums):
    """Right-pad a batch of integer lists into a dense LongTensor.

    Args:
        nums: sequence of lists of ints (here, one list of head indices per
            sentence).

    Returns:
        (out, lengths): ``out`` is a ``(len(nums), max_len)`` LongTensor padded
        with 0 on the right; ``lengths`` is a ``(len(nums),)`` LongTensor with
        each list's original length. An empty batch yields a ``(0, 0)`` tensor
        and an empty ``lengths`` tensor instead of failing on ``max()``.
    """
    lengths = torch.tensor([len(seq) for seq in nums], dtype=torch.long)
    max_len = int(lengths.max()) if len(nums) > 0 else 0
    out = torch.zeros(len(nums), max_len, dtype=torch.long)
    for row, seq in enumerate(nums):
        out[row, :len(seq)] = torch.tensor(seq, dtype=torch.long)
    return out, lengths
# Head indices arrive as strings; convert to ints, then pad each batch with batch_num.
HEAD = data.RawField(preprocessing=lambda x: [int(i) for i in x],
                     postprocessing=batch_num)
HEAD.is_target = True
# Words are sub-tokenized with the BERT tokenizer loaded above.
WORD = torch_struct.data.SubTokenizedField(tokenizer)
FIELDS = (('word', WORD), ('head', HEAD))

# Exclusive bounds on the length of example.word[0] (kept: 6..39).
MIN_WORDS, MAX_WORDS = 5, 40


def keep_sentence(example):
    """Return True if ``len(example.word[0])`` lies strictly between MIN_WORDS and MAX_WORDS."""
    # NOTE(review): example.word[0] is presumably the word list produced by
    # SubTokenizedField -- confirm.
    return MIN_WORDS < len(example.word[0]) < MAX_WORDS


# Distinct dataset names: `train` is later rebound to the training-loop function,
# so re-running this cell under the old name would break `train(train_iter)`.
train_dataset = ConllXDataset("temp/wsj.train.conllx", FIELDS, filter_pred=keep_sentence)
# 750 is passed straight to TokenBucket (presumably a per-batch token budget -- confirm).
train_iter = torch_struct.data.TokenBucket(train_dataset, 750)
val_dataset = ConllXDataset("temp/wsj.dev.conllx", FIELDS, filter_pred=keep_sentence)
val_iter = torchtext.data.BucketIterator(val_dataset,
                                         batch_size=20,
                                         device="cuda:0")

Build a BERT encoder plus a biaffine scorer that together compute the arc potentials for the parser.

In [0]:
# Pretrained BERT encoder, fine-tuned jointly with the scorer parameters below.
model = model_class.from_pretrained(pretrained_weights)
model.cuda()
# Width of the encoder output: read it from the checkpoint's config instead of
# hard-coding 1024 (bert-large) / 768 (bert-base) so the scorer always matches.
H = model.config.hidden_size


def init_scorer_param(*shape):
    """Return a trainable CUDA leaf tensor of `shape` drawn from N(0, 0.02**2)."""
    return torch.empty(*shape, device="cuda").normal_(mean=0, std=0.02).requires_grad_(True)


linear = init_scorer_param(H, H)
bilinear = init_scorer_param(H, H)
root = init_scorer_param(H)

# The optimizer updates both the scorer tensors and every encoder parameter.
opt = AdamW([linear, bilinear, root] + list(model.parameters()), lr=1e-4, eps=1e-8)
# Learning-rate schedule: 20 warmup steps, t_total=2500 (see pytorch_transformers docs).
scheduler = WarmupLinearSchedule(opt, warmup_steps=20, t_total=2500)

100%|██████████| 521/521 [00:00<00:00, 231731.96B/s]
100%|██████████| 1338740706/1338740706 [00:27<00:00, 48483202.14B/s]


In [0]:
def potentials(words, mapper):
    """Score every (head, dependent) pair for a batch with a BERT biaffine scorer.

    Uses the globals ``model`` (encoder), ``linear``, ``bilinear`` and ``root``.

    Args:
        words: sub-token ids fed to ``model`` (presumably (batch, n_subtokens),
            already on the model's device -- TODO confirm).
        mapper: tensor contracted with the encoder output via ``"bca,bch->bah"``
            to pool sub-token states into per-position states; presumably
            (batch, n_subtokens, n_positions) -- confirm against SubTokenizedField.

    Returns:
        (batch, N, N) scores where N = n_positions - 2 (first and last positions
        dropped). Pairwise biaffine scores fill the matrix and the per-position
        ``root`` score is added on the diagonal.
    """
    # model(...) returns a tuple; element 0 is presumably the per-token hidden states.
    hidden = model(words)[0]
    hidden = torch.nn.functional.dropout(hidden, p=0.1, training=model.training)
    # Match mapper to the encoder output's dtype and device instead of forcing float32/CUDA.
    hidden = torch.einsum("bca,bch->bah", mapper.to(hidden), hidden)
    projected = torch.einsum("bnh,hg->bng", hidden, linear)
    scores = torch.einsum("bnh,hg,bmg->bnm", hidden, bilinear, projected)
    root_score = torch.einsum("bnh,h->bn", hidden, root)
    # Drop the first and last positions (presumably the [CLS]/[SEP] markers).
    scores = scores[:, 1:-1, 1:-1]
    # Put root scores on the diagonal out-of-place (no in-place write into a view).
    return scores + torch.diag_embed(root_score[:, 1:-1])

Training and validation loops.

In [0]:
def validate():
    """Decode the validation set with DepTree(MaxSemiring) and report head errors.

    For each batch in ``val_iter``, padded positions are zeroed and the
    highest-scoring tree from ``DepTree(MaxSemiring).marginals`` is compared
    with the gold parts from ``DepTree.to_parts``. A wrongly attached word shows
    up twice in ``|pred - gold|`` (one missing gold arc, one extra predicted
    arc), hence the division by 2; this assumes every word has exactly one head
    in both trees.

    Prints ``total_edges incorrect_edges`` as Python floats.

    Returns:
        Attachment accuracy ``1 - incorrect / total`` (0.0 if there are no gold
        arcs). The model is put back in training mode even if an error occurs.
    """
    incorrect_edges = 0.0
    total_edges = 0.0
    model.eval()
    try:
        for ex in val_iter:
            words, mapper, _ = ex.word
            label, lengths = ex.head
            batch, _ = label.shape

            # No encoder graph is needed for evaluation.
            with torch.no_grad():
                final = potentials(words.cuda(), mapper)
                for b in range(batch):
                    final[b, lengths[b]-1:, :] = 0
                    final[b, :, lengths[b]-1:] = 0
            # torch_struct obtains marginals by differentiating the tree DP
            # (assumed -- see its docs), so give it a grad-requiring leaf that
            # does not drag the BERT graph along.
            final.requires_grad_(True)
            out = DepTree(MaxSemiring).marginals(final, lengths=lengths)
            gold = DepTree.to_parts(label, lengths=lengths).type_as(out)
            # .item() keeps the running totals free of autograd history; adding
            # tensors here would keep every batch's graph alive.
            incorrect_edges += (out - gold).abs().sum().item() / 2.0
            total_edges += gold.sum().item()
    finally:
        model.train()

    print(total_edges, incorrect_edges)
    return 1.0 - incorrect_edges / total_edges if total_edges else 0.0

def train(train_iter):
    """Fine-tune BERT and the biaffine scorer with the tree-CRF log-likelihood.

    For each batch: score all arcs with ``potentials``, zero padded positions,
    and minimise ``-(DepTree().score(gold) - DepTree().sum(...))`` summed over
    the batch. Gradients of every optimized tensor are clipped to a global norm
    of 1.0 before ``opt.step()`` / ``scheduler.step()``.

    Prints the mean negative log-likelihood of recent batches together with
    ``words.shape`` whenever ``i % 50 == 1``, and calls ``validate()`` whenever
    ``i % 600 == 500``. Batches whose max length exceeds ``final.shape[1] + 1``
    print "fail" and are skipped.
    """
    model.train()
    # Clip exactly what the optimizer updates, including the scorer tensors
    # (linear, bilinear, root) that are not part of model.parameters().
    params = [p for group in opt.param_groups for p in group["params"]]
    nll_window = []
    for i, ex in enumerate(train_iter):
        opt.zero_grad()
        words, mapper, _ = ex.word
        label, lengths = ex.head
        batch, _ = label.shape

        final = potentials(words.cuda(), mapper)
        if not lengths.max() <= final.shape[1] + 1:
            print("fail")
            continue
        # Zero scores in rows/columns at or beyond each sentence's length - 1.
        for b in range(batch):
            final[b, lengths[b]-1:, :] = 0
            final[b, :, lengths[b]-1:] = 0

        log_partition = DepTree().sum(final, lengths=lengths)
        labels = DepTree.to_parts(label, lengths=lengths).type_as(final)
        log_prob = DepTree().score(final, labels) - log_partition
        nll = -log_prob.sum()
        nll.backward()
        torch.nn.utils.clip_grad_norm_(params, 1.0)
        opt.step()
        scheduler.step()

        # Store plain floats so the log buffer holds no device tensors.
        nll_window.append(nll.item())
        if i % 50 == 1:
            print(sum(nll_window) / len(nll_window), words.shape)
            nll_window = []
        if i % 600 == 500:
            validate()
        


In [0]:
# Iterate over the bucketed training data; validation is triggered from inside train().
train(train_iter=train_iter)

tensor(3448.2988) torch.Size([58, 13])
