In [1]:
%load_ext autoreload
%autoreload 2

import random
import torch
from torch.utils.data import DataLoader, IterableDataset
from pseudo_data import LinearChainDataset

In [2]:
# Show the dataset's own description, then preview its first three samples.
print(LinearChainDataset.__doc__)
dataset = LinearChainDataset(length=25)
# zip() against range(3) stops after exactly three samples have been drawn.
for idx, (nodes, masks, targets) in zip(range(3), dataset):
    banner = f"=== {idx} ==="
    sentence = dataset.to_sent(nodes)
    labels = dataset.to_label(targets)
    print(banner)
    print(sentence)
    print(labels)


    This is a dataset simulating sequence labeling in NLP.
    An item looks like:
        number1 operator number2 = number3
    where:
    - a number can be: 0.133/-0.133/.333/3.
    - the label segments the numbers and operators, following the BMES-style in sequence labeling
    
=== 0 ===
-774.283--.632=-319@@@@@@
BMMMMMMESBMMMESBMME@@@@@@
=== 1 ===
.739*385.107=.300@@@@@@@@
BMMESBMMMMMESBMME@@@@@@@@
=== 2 ===
544./715=820.255@@@@@@@@@
BMMESBMESBMMMMME@@@@@@@@@


In [3]:
# One dataset instance backs both loaders; only the batch size differs
# (20 per training step, 100 per evaluation batch).
dataset = LinearChainDataset(length=100)
train_data, test_data = [DataLoader(dataset, batch_size=size) for size in (20, 100)]

In [4]:
import time
from tqdm import tqdm
from tabulate import tabulate
from torch_random_fields.models import LinearChainCRF
from torch_random_fields.models.constants import Inference, Learning


class Model(torch.nn.Module):
    """Sequence tagger: token embedding -> linear unary scorer -> linear-chain CRF.

    Vocabulary and label-set sizes are read from the notebook-level ``dataset``
    at construction time.
    """

    def __init__(self) -> None:
        super().__init__()
        # Embedding width; it is also the input width of the unary scorer and
        # the ``feature_size`` given to the CRF, which receives the embeddings
        # as ``node_features`` in ``forward``.
        hidden_size = 10
        num_labels = len(dataset.label2idx)
        self.embed = torch.nn.Embedding(num_embeddings=len(dataset.word2idx), embedding_dim=hidden_size)
        self.pred = torch.nn.Linear(hidden_size, num_labels)
        self.crf = LinearChainCRF(
            num_labels,
            low_rank=5,
            learning=Learning.PIECEWISE,
            inference=Inference.VITERBI,
            feature_size=hidden_size,
        )

    def forward(self, nodes, masks, targets):
        """Run the CRF on the embedded tokens and return its result unchanged.

        With ``targets`` given, callers treat the result as the training loss.
        With ``targets=None``, ``decode`` reads element ``[1]`` of the result as
        the predicted labels (presumably a ``(scores, predictions)`` pair --
        confirm against ``LinearChainCRF``).
        """
        feats = self.embed(nodes)
        unaries = self.pred(feats)
        return self.crf(
            unaries=unaries,
            masks=masks,
            node_features=feats,
            targets=targets,
        )

    def decode(self, nodes, masks):
        """Return the predicted labels for ``nodes`` (no targets supplied)."""
        return self(nodes, masks, None)[1]

    def evaluate(self, nodes, masks, targets):
        """Return the fraction of positions selected by ``masks`` predicted correctly.

        The result is a 0-dim tensor (correct count divided by mask count).
        """
        pred = self.decode(nodes, masks)
        # Accuracy is computed on ``corr[masks]`` only, so positions outside the
        # mask never influence the result and need no overwriting in ``pred``.
        corr = pred == targets
        return corr[masks].sum() / masks.sum()

    # Backward-compatible alias: the method was originally published under
    # this misspelled name and existing callers still use it.
    evaulate = evaluate


# Benchmark of CRF learning objectives: train one fresh model per objective and,
# every EVAL_EVERY updates, record the training wall-clock time and the accuracy
# on one test batch.
SEED = 0
LEARNING_BENCH_STEPS = 300
EVAL_EVERY = 100
LEARNING_METHODS = [
    Learning.PIECEWISE,
    Learning.PSEUDO_LIKELIHOOD,
    Learning.PERCEPTRON,
    Learning.EXACT_LIKELIHOOD,
]

cost_table = []
accu_table = []

for learning in LEARNING_METHODS:
    # Reset Python's and torch's global RNGs so every objective starts from the
    # same generator state instead of wherever the previous run left off.
    random.seed(SEED)
    torch.manual_seed(SEED)

    cost_row = [learning]
    accu_row = [learning]
    cost_table.append(cost_row)
    accu_table.append(accu_row)

    model = Model()
    model.crf.learning = learning
    opt = torch.optim.Adam(model.parameters(), lr=0.001)

    last_time = time.perf_counter()
    # Steps are counted from 1 so the column labelled "100" reflects exactly
    # 100 parameter updates.
    for step in tqdm(range(1, LEARNING_BENCH_STEPS + 1)):
        # NOTE(review): a fresh loader iterator is built for every batch; this
        # yields new data each step only if the dataset is a random stream --
        # confirm against pseudo_data.
        nodes, masks, targets = next(iter(train_data))
        loss = model(nodes, masks, targets)
        opt.zero_grad()
        loss.backward()
        opt.step()
        if step % EVAL_EVERY == 0:
            # Record the training cost before evaluating, and restart the timer
            # afterwards, so evaluation time is never charged to training.
            cost_row.append(time.perf_counter() - last_time)
            accu = model.evaulate(*next(iter(test_data)))
            accu_row.append(float(accu))  # store a plain float, not a tensor
            last_time = time.perf_counter()

# One header per evaluation checkpoint, plus the label for the method column.
headers = ["iteration"] + [str(step) for step in range(EVAL_EVERY, LEARNING_BENCH_STEPS + 1, EVAL_EVERY)]
print("====== cost ======")
print(tabulate(cost_table, headers=headers, floatfmt=".3f"))
print("====== accu ======")
print(tabulate(accu_table, headers=headers, floatfmt=".3f"))


100%|██████████| 301/301 [00:05<00:00, 50.85it/s] 
100%|██████████| 301/301 [00:02<00:00, 119.83it/s]
100%|██████████| 301/301 [00:05<00:00, 57.37it/s]
100%|██████████| 301/301 [00:13<00:00, 22.28it/s]

iteration            100    200    300
-----------------  -----  -----  -----
piecewise          4.137  0.889  0.868
pseudo-likelihood  1.023  0.789  0.682
perceptron         1.762  1.686  1.770
exact-likelihood   4.698  4.284  4.506
iteration            100    200    300
-----------------  -----  -----  -----
piecewise          0.876  0.887  0.911
pseudo-likelihood  0.871  0.875  0.880
perceptron         0.309  0.393  0.471
exact-likelihood   0.541  0.756  0.820





In [5]:
# Benchmark of CRF inference methods: train a single model and, every EVAL_EVERY
# updates, score it with each inference method on the same test batch, recording
# the evaluation wall-clock time and the accuracy.
SEED = 0
INFERENCE_BENCH_STEPS = 700
EVAL_EVERY = 100
INFERENCE_METHODS = [Inference.VITERBI, Inference.BATCH_MEAN_FIELD]

# Seed Python's and torch's global RNGs before the model is created.
random.seed(SEED)
torch.manual_seed(SEED)

# One row per inference method; the first cell is the method itself.
cost_table = [[inference] for inference in INFERENCE_METHODS]
accu_table = [[inference] for inference in INFERENCE_METHODS]

model = Model()
opt = torch.optim.Adam(model.parameters(), lr=0.001)

# Steps are counted from 1 so the column labelled "100" reflects exactly
# 100 parameter updates.
for step in tqdm(range(1, INFERENCE_BENCH_STEPS + 1)):
    # NOTE(review): a fresh loader iterator is built for every batch; this
    # yields new data each step only if the dataset is a random stream --
    # confirm against pseudo_data.
    nodes, masks, targets = next(iter(train_data))
    loss = model(nodes, masks, targets)
    opt.zero_grad()
    loss.backward()
    opt.step()
    if step % EVAL_EVERY == 0:
        # Fetch the test batch once, outside the timed region, so every
        # inference method is scored on identical data and the timing does not
        # include batch loading.
        test_batch = next(iter(test_data))
        for inference, cost_row, accu_row in zip(INFERENCE_METHODS, cost_table, accu_table):
            model.crf.inference = inference
            start_time = time.perf_counter()
            accu = model.evaulate(*test_batch)
            cost_row.append(time.perf_counter() - start_time)
            accu_row.append(float(accu))  # store a plain float, not a tensor

# One header per evaluation checkpoint, plus the label for the method column.
headers = ["iteration"] + [str(step) for step in range(EVAL_EVERY, INFERENCE_BENCH_STEPS + 1, EVAL_EVERY)]
print("====== cost ======")
print(tabulate(cost_table, headers=headers, floatfmt=".3f"))
print("====== accu ======")
print(tabulate(accu_table, headers=headers, floatfmt=".3f"))

100%|██████████| 701/701 [00:05<00:00, 127.06it/s]

iteration           100    200    300    400    500    600    700
----------------  -----  -----  -----  -----  -----  -----  -----
viterbi           0.020  0.021  0.023  0.022  0.021  0.022  0.023
batch-mean-field  0.017  0.017  0.017  0.017  0.017  0.017  0.017
iteration           100    200    300    400    500    600    700
----------------  -----  -----  -----  -----  -----  -----  -----
viterbi           0.866  0.892  0.909  0.919  0.927  0.932  0.948
batch-mean-field  0.868  0.877  0.889  0.901  0.914  0.913  0.928



