<a href="https://colab.research.google.com/github/harvardnlp/pytorch-struct/blob/master/notebooks/BertTagger.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install -qqq torchtext wandb pytorch-transformers
!pip install -qqqU git+https://github.com/harvardnlp/pytorch-struct@prob

  Building wheel for torch-struct (setup.py) ... done


In [2]:
import torchtext
import torch
import torch.nn as nn
from torch_struct import LinearChainCRF
import torch_struct.data
from pytorch_transformers import *
# Hyperparameters shared across the notebook (also sent to wandb.init):
#   bert    - name of the pretrained checkpoint loaded for tokenizer and encoder
#   H       - hidden size handed to Model as its `hidden` argument
#   dropout - probability used by the nn.Dropout layer inside Model
config = dict(
    bert="bert-base-cased",
    H=768,
    dropout=0.2,
)

# Experiment tracking with Weights & Biases. To run without it, comment out
# this import together with the wandb.* calls further down (wandb.watch,
# wandb.log); otherwise log in with your own wandb account first.
import wandb
wandb.init(project="pytorch-struct-tagging", config=config)


W&B Run: https://app.wandb.ai/srush/pytorch-struct-tagging/runs/67eo4ise

Set up the data and batching.

In [3]:
# Pretrained encoder class, tokenizer class and checkpoint name (from `config`).
model_class = BertModel
tokenizer_class = BertTokenizer
pretrained_weights = config["bert"]
tokenizer = tokenizer_class.from_pretrained(pretrained_weights)

# Input field: words run through the BERT tokenizer by torch_struct.
WORD = torch_struct.data.SubTokenizedField(tokenizer)
# Target field: UD tags wrapped in <bos>/<eos>; batches also carry lengths.
UD_TAG = torchtext.data.Field(
    init_token="<bos>",
    eos_token="<eos>",
    include_lengths=True,
)

# UDPOS splits; the third column is ignored and examples whose first `word`
# component has 200 or more entries are dropped.
# NOTE: a later cell defines a function that is also called `train`, which
# rebinds this name; `train_iter` is built below before that happens.
train, val, test = torchtext.datasets.UDPOS.splits(
    fields=(
        ("word", WORD),
        ("udtag", UD_TAG),
        (None, None),
    ),
    filter_pred=lambda example: len(example.word[0]) < 200,
)

# No build_vocab call is made for WORD (presumably the BERT tokenizer supplies
# its own vocabulary - TODO confirm); only the tag vocabulary is built.
UD_TAG.build_vocab(train.udtag)

# 750 is presumably a per-batch token budget for TokenBucket - TODO confirm.
train_iter = torch_struct.data.TokenBucket(train, 750)
val_iter = torchtext.data.BucketIterator(val, batch_size=10, device="cuda:0")

error


Set up the transformer and a simple one-layer model.

In [0]:
# Number of tag classes, read from the UD_TAG vocabulary built in the data cell.
# NOTE(review): presumably this count includes the <bos>/<eos> specials that
# the Field was configured with - confirm against torchtext.
C = len(UD_TAG.vocab)


class Model(nn.Module):
    """Pretrained encoder plus a linear emission layer and a tag-transition matrix.

    ``forward`` turns a batch of sub-tokenized sentences into pairwise
    log-potentials of shape ``(batch, N - 1, classes, classes)``, which the
    rest of the notebook feeds to ``LinearChainCRF``.

    Args:
        hidden: Size of the encoder's output vectors (``config["H"]`` in this
            notebook).
        classes: Number of tag classes. Both ``linear`` and ``transition`` are
            sized from this argument.

    Note:
        Only the ``weight`` matrices of ``linear`` and ``transition`` are used
        in ``forward``; their bias terms never enter the computation.
    """

    def __init__(self, hidden, classes):
        super().__init__()
        self.base_model = model_class.from_pretrained(pretrained_weights)
        # Size the layers from the `classes` argument instead of the
        # notebook-level global `C`, so the constructor argument is honoured.
        self.linear = nn.Linear(hidden, classes)
        self.transition = nn.Linear(classes, classes)
        self.dropout = nn.Dropout(config["dropout"])

    def forward(self, words, mapper):
        """Compute linear-chain log-potentials for a batch.

        Args:
            words: Input handed unchanged to the encoder.
            mapper: Rank-3 tensor sharing its first two dimensions with the
                encoder output (einsum ``"bca,bch->bah"``); it pools encoder
                positions ``c`` into ``a`` output positions.
                Presumably a sub-token-to-word alignment - TODO confirm.

        Returns:
            Tensor ``vals`` of shape ``(batch, N - 1, classes, classes)`` with
            ``vals[b, n, i, j] = final[b, n + 1, i] + transition.weight[i, j]``
            and, for ``n == 0`` only, an extra ``final[b, 0, j]`` term, where
            ``final`` holds the per-position class scores.
        """
        encoded = self.dropout(self.base_model(words)[0])
        # Follow the encoder output's device rather than hard-coding CUDA, so
        # the module also runs on CPU or on a non-default GPU.
        mapper = mapper.float().to(encoded.device)
        pooled = torch.einsum("bca,bch->bah", mapper, encoded)
        # Class scores per position; uses linear.weight only (no bias).
        final = torch.einsum("bnh,ch->bnc", pooled, self.linear.weight)
        batch, N, n_classes = final.shape
        # Score of positions 1..N-1 broadcast against the transition matrix.
        vals = (
            final.view(batch, N, n_classes, 1)[:, 1:N]
            + self.transition.weight.view(1, 1, n_classes, n_classes)
        )
        # Position 0 has no incoming edge, so its scores are folded into the
        # first edge along the last dimension.
        vals[:, 0, :, :] += final.view(batch, N, 1, n_classes)[:, 0]
        return vals
# Build the tagger, register it with wandb, then move it to the GPU.
model = Model(hidden=config["H"], classes=C)
wandb.watch(model)
model.cuda()
# A bare None as the last expression keeps the cell from echoing the module.
None

Generic training and validation loops.

In [0]:
def validate(itera):
    """Measure how far the CRF argmax is from the gold tags on ``itera``.

    Uses the notebook-level ``model`` (not an argument). The model is put in
    eval mode for the pass and switched back to train mode afterwards, even
    if an exception interrupts the loop.

    Args:
        itera: Iterable of batches exposing ``word`` (words, mapper, _) and
            ``udtag`` (label, lengths).

    Returns:
        ``incorrect_edges / total``: the accumulated half-L1 distance between
        the argmax and gold parts (each summed over their last dimension),
        divided by the accumulated sum of the argmax parts.
    """
    incorrect_edges = 0
    total = 0
    model.eval()
    try:
        for ex in itera:
            words, mapper, _ = ex.word
            label, lengths = ex.udtag
            # No gradients are needed for evaluation, so the encoder's
            # autograd graph is not built. Only the forward pass is wrapped:
            # the CRF argmax below is left outside in case torch-struct
            # relies on autograd to compute it.
            with torch.no_grad():
                log_potentials = model(words.cuda(), mapper)
            dist = LinearChainCRF(log_potentials, lengths=lengths)
            argmax = dist.argmax
            gold = LinearChainCRF.struct.to_parts(
                label.transpose(0, 1), C, lengths=lengths
            ).type_as(argmax)
            incorrect_edges += (argmax.sum(-1) - gold.sum(-1)).abs().sum() / 2.0
            total += argmax.sum()
    finally:
        # Always hand the model back in training mode, as the caller expects.
        model.train()
    return incorrect_edges / total
    
def train(train_iter, val_iter, model, lr=1e-4, warmup_steps=20, t_total=2500):
    """Train ``model`` with the linear-chain CRF log-likelihood.

    Every batch whose index ``i`` satisfies ``i % 100 == 10`` prints the mean
    training loss since the last report, runs ``validate(val_iter)`` and
    logs both numbers to wandb.

    Args:
        train_iter: Iterable of training batches exposing ``word``
            (words, mapper, _) and ``udtag`` (label, lengths).
        val_iter: Validation batches, passed on to ``validate``.
        model: Module mapping ``(words, mapper)`` to CRF log-potentials.
        lr: Learning rate for AdamW (default matches the original 1e-4).
        warmup_steps: Warm-up steps for the linear schedule (default 20).
        t_total: Total steps of the linear schedule (default 2500).

    Note:
        ``validate`` evaluates the notebook-level ``model``, which is the
        same object as the ``model`` argument in this notebook's usage.
    """
    opt = AdamW(model.parameters(), lr=lr, eps=1e-8)
    scheduler = WarmupLinearSchedule(opt, warmup_steps=warmup_steps, t_total=t_total)

    model.train()
    losses = []
    for i, ex in enumerate(train_iter):
        opt.zero_grad()
        words, mapper, _ = ex.word
        label, lengths = ex.udtag

        # Model
        log_potentials = model(words.cuda(), mapper)
        # Skip batches whose longest tag sequence does not fit the number of
        # edges the model produced (edges = positions - 1).
        if lengths.max() > log_potentials.shape[1] + 1:
            print("fail")
            continue

        dist = LinearChainCRF(log_potentials, lengths=lengths.cuda())
        labels = LinearChainCRF.struct.to_parts(
            label.transpose(0, 1), C, lengths=lengths
        ).type_as(dist.log_potentials)

        # Maximise the summed log-probability of the gold structures.
        loss = dist.log_prob(labels).sum()
        (-loss).backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        opt.step()
        scheduler.step()

        losses.append(loss.detach())

        if i % 100 == 10:
            # Computed once and reused for both the printout and the wandb log.
            train_loss = -torch.tensor(losses).mean()
            print(train_loss, words.shape)
            val_loss = validate(val_iter)
            wandb.log({"train_loss": train_loss,
                       "val_loss": val_loss})
            losses = []
            

In [0]:
train(train_iter, val_iter, model.cuda()) 

tensor(1828.1356) torch.Size([14, 54])
tensor(608.0067) torch.Size([27, 27])
