<a href="https://colab.research.google.com/github/harvardnlp/pytorch-struct/blob/master/notebooks/BertTagger.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install -q torchtext
!pip install -q pytorch-transformers
!pip install -qU git+https://github.com/harvardnlp/pytorch-struct

[K     |████████████████████████████████| 184kB 9.9MB/s 
[K     |████████████████████████████████| 655kB 59.7MB/s 
[K     |████████████████████████████████| 808kB 52.9MB/s 
[K     |████████████████████████████████| 1.0MB 62.6MB/s 
[?25h  Building wheel for regex (setup.py) ... [?25l[?25hdone
  Building wheel for sacremoses (setup.py) ... [?25l[?25hdone
  Building wheel for torch-struct (setup.py) ... [?25l[?25hdone


In [2]:
import torchtext
import torch
from torch_struct import LinearChain, MaxSemiring
import torch_struct.data
from pytorch_transformers import *


100%|██████████| 213450/213450 [00:00<00:00, 303952.00B/s]


Set up the data and batching.

In [3]:
model_class, tokenizer_class, pretrained_weights = BertModel, BertTokenizer, 'bert-base-cased'
tokenizer = tokenizer_class.from_pretrained(pretrained_weights)
# Words are split with the BERT tokenizer; tags use a plain field with
# <bos>/<eos> markers and lengths returned alongside each batch.
WORD = torch_struct.data.SubTokenizedField(tokenizer)
UD_TAG = torchtext.data.Field(init_token="<bos>", eos_token="<eos>", include_lengths=True)

# Examples are kept only when len(ex.word[0]) < MAX_LEN
# (presumably the sub-token sequence -- confirm against SubTokenizedField).
MAX_LEN = 50
# Size argument passed to TokenBucket; presumably a per-batch token budget -- confirm.
BATCH_TOKENS = 1500

# Download and load the default data (the third UDPOS column is ignored).
# The training split is called `train_data`, not `train`, so that the
# `train()` training-loop function defined further down does not clobber it
# (and re-running this cell does not clobber that function).
train_data, val, test = torchtext.datasets.UDPOS.splits(
    fields=(('word', WORD), ('udtag', UD_TAG), (None, None)),
    filter_pred=lambda ex: len(ex.word[0]) < MAX_LEN,
)

UD_TAG.build_vocab(train_data.udtag)
train_iter = torch_struct.data.TokenBucket(train_data, BATCH_TOKENS)

downloading en-ud-v2.zip


en-ud-v2.zip: 100%|██████████| 688k/688k [00:00<00:00, 1.55MB/s]


extracting
error


Set up the BERT encoder and a simple one-layer tagging head: a linear projection to per-tag scores plus a tag-transition matrix.

In [0]:
from pytorch_transformers import AdamW, WarmupLinearSchedule

# Pretrained encoder, moved to the GPU.
model = model_class.from_pretrained(pretrained_weights)
model.cuda()

# Size of the tag vocabulary (includes the <bos>/<eos> markers added by UD_TAG).
C = len(UD_TAG.vocab)
# Encoder width, read from the model config rather than hard-coded
# (768 for bert-base), so other checkpoints work unchanged.
H = model.config.hidden_size

# Tagging head, initialised N(0, 0.02). Sampling happens before grad is
# enabled, so there is no need to mutate `.data` on a grad-requiring leaf.
# Draw order (linear, then transition) matches the previous initialisation.
linear = torch.empty(H, C, device="cuda").normal_(mean=0, std=0.02).requires_grad_(True)
transition = torch.empty(C, C, device="cuda").normal_(mean=0, std=0.02).requires_grad_(True)

# One optimizer over the head and the whole encoder.
opt = AdamW([linear, transition] + list(model.parameters()), lr=1e-4, eps=1e-8)
# Warmup for 20 steps over a 2500-step schedule (see pytorch_transformers docs).
scheduler = WarmupLinearSchedule(opt, warmup_steps=20, t_total=2500)
def potentials(words, mapper, dropout=0.1):
    """Compute linear-chain edge potentials for a batch of sentences.

    Runs the global encoder `model` on `words` and applies dropout while the
    model is in training mode. It then pools sub-token states onto word
    positions with `mapper` and projects them to per-tag emission scores
    with `linear`. Finally it combines those scores with the `transition`
    matrix.

    Args:
        words: token ids passed straight to `model`
            (assumed (batch, n_subtokens) -- TODO confirm).
        mapper: alignment contracted as "bca,bch->bah" against the encoder
            output, i.e. (batch, n_subtokens, n_words).
        dropout: dropout rate applied to encoder outputs when `model.training`.

    Returns:
        Tensor of shape (batch, n_words - 1, C, C). Axis -2 carries the
        emission score of position n + 1. Axis -1 receives the emission
        score of position 0, and only on the first edge. `transition` is
        added to every edge.
    """
    hidden = model(words)[0]
    hidden = torch.nn.functional.dropout(hidden, p=dropout, training=model.training)
    # Pool sub-token states onto word positions; `.to(hidden)` matches the
    # encoder output's device and dtype instead of assuming float32 on CUDA.
    word_states = torch.einsum("bca,bch->bah", mapper.to(hidden), hidden)
    emissions = torch.einsum("bnh,hc->bnc", word_states, linear)
    batch, N, n_tags = emissions.shape
    # unsqueeze (rather than view) works regardless of memory layout.
    vals = emissions[:, 1:N].unsqueeze(-1) + transition.view(1, 1, n_tags, n_tags)
    # Fold the first position's emissions into the first edge only, so they
    # are counted exactly once per sequence.
    vals[:, 0, :, :] += emissions[:, 0].unsqueeze(-2)
    return vals

Generic training and validation loops.

In [0]:
# Validation batches: 20 sentences each, materialised on the first GPU.
val_iter = torchtext.data.BucketIterator(
    val,
    batch_size=20,
    device="cuda:0",
)

def validate(itera, max_batches=None):
    """Print tagging error counts of the current model over `itera`.

    For each batch it takes the max-semiring (highest-scoring) structure
    from `LinearChain(MaxSemiring).marginals` and compares it against the
    gold parts built with `LinearChain.to_parts`. It then prints two
    tensors:

    * an error count: half the L1 difference between predicted and gold
      parts after summing out the last tag axis;
    * the sum of all predicted part indicators (the number of scored edges).

    Args:
        itera: iterable of batches exposing `.word` = (words, mapper, _)
            and `.udtag` = (label, lengths).
        max_batches: if given, stop after this many batches (quick check);
            None (default) evaluates the whole iterator.

    The model is switched to eval mode for the pass and always switched back
    to train mode afterwards, even if the loop is interrupted.
    """
    incorrect_edges = 0
    total = 0
    model.eval()
    try:
        for i, ex in enumerate(itera):
            if max_batches is not None and i >= max_batches:
                break
            words, mapper, _ = ex.word
            label, lengths = ex.udtag
            final = potentials(words.cuda(), mapper)
            # NOTE(review): torch_struct derives marginals via autograd, so this
            # loop is deliberately not wrapped in torch.no_grad() -- confirm
            # before changing.
            argmax = LinearChain(MaxSemiring).marginals(final, lengths=lengths)
            gold = LinearChain.to_parts(label.transpose(0, 1), C,
                                        lengths=lengths).type_as(argmax)
            # A wrong tag shows up twice (missing gold + spurious prediction),
            # hence the division by 2.
            incorrect_edges += (argmax.sum(-1) - gold.sum(-1)).abs().sum() / 2.0
            total += argmax.sum()
        print(incorrect_edges, total)
    finally:
        # `potentials` keys dropout off `model.training`; without this,
        # an interrupt here would leave dropout disabled for later training.
        model.train()

def train(log_every=50, validate_every=600):
    """Fine-tune the encoder and CRF head on `train_iter`.

    For each batch it computes edge potentials and maximises the summed
    log-likelihood of the gold tag sequences under `LinearChain`. It then
    clips gradients and steps both the optimizer and the LR schedule.

    Args:
        log_every: print the mean negated batch log-likelihood (and the shape
            of the last `words` tensor) on batch indices i with
            i % log_every == 1. Must be > 1 to ever log.
        validate_every: run `validate(val_iter)` on batch indices i with
            i % validate_every == 1. Must be > 1 to ever validate.

    Batches whose tag lengths exceed what the potentials can cover are
    skipped with a "fail" message.
    """
    model.train()
    # Clip every tensor the optimizer updates -- the head (`linear`,
    # `transition`) as well as the encoder -- not just model.parameters().
    params = [p for group in opt.param_groups for p in group["params"]]
    losses = []
    for i, ex in enumerate(train_iter):
        opt.zero_grad()
        words, mapper, _ = ex.word
        label, lengths = ex.udtag

        final = potentials(words.cuda(), mapper)
        # `final` holds final.shape[1] edges, i.e. sequences of at most
        # final.shape[1] + 1 positions.
        if lengths.max() > final.shape[1] + 1:
            print("fail")
            continue

        log_partition = LinearChain().sum(final, lengths=lengths)
        labels = LinearChain.to_parts(label.transpose(0, 1), C, lengths=lengths) \
                            .type_as(final)
        log_prob = LinearChain().score(final, labels) - log_partition
        loss = log_prob.sum()
        (-loss).backward()

        torch.nn.utils.clip_grad_norm_(params, 1.0)
        opt.step()
        scheduler.step()
        losses.append(loss.detach())
        if i % log_every == 1:
            print(-torch.tensor(losses).mean(), words.shape)
            losses = []
        if i % validate_every == 1:
            validate(val_iter)

In [16]:
# Run fine-tuning. train() prints the mean negated batch log-likelihood every
# 50 batches and validation error/total counts every 600 batches.
train()

tensor(4135.5098) torch.Size([116, 13])
tensor(12609., device='cuda:0') tensor(12790., device='cuda:0')
tensor(1918.0920) torch.Size([300, 5])
tensor(689.8613) torch.Size([75, 20])
tensor(464.5420) torch.Size([88, 17])
tensor(316.9471) torch.Size([117, 13])
tensor(306.2497) torch.Size([68, 22])
tensor(258.1658) torch.Size([46, 33])
tensor(186.1516) torch.Size([58, 26])
tensor(174.9373) torch.Size([47, 32])
tensor(171.7171) torch.Size([147, 11])
tensor(141.1298) torch.Size([31, 49])
tensor(120.0469) torch.Size([37, 41])
tensor(124.4972) torch.Size([127, 12])
tensor(902., device='cuda:0') tensor(12704., device='cuda:0')
tensor(94.6860) torch.Size([300, 4])
tensor(91.4593) torch.Size([65, 23])
tensor(79.3262) torch.Size([43, 35])
tensor(83.9711) torch.Size([35, 43])
tensor(63.0983) torch.Size([60, 25])
tensor(62.3029) torch.Size([57, 26])
tensor(62.7254) torch.Size([111, 14])
tensor(46.1658) torch.Size([50, 30])
tensor(63.7202) torch.Size([62, 24])
tensor(41.7853) torch.Size([107, 14])
te

KeyboardInterrupt: ignored