In [7]:
%load_ext autoreload
%autoreload 2

import random
import torch
from torch.utils.data import DataLoader, IterableDataset
from pseudo_data import LinearChainDataset

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [8]:
print(LinearChainDataset.__doc__)
dataset = LinearChainDataset(length=25)
for i, ele in enumerate(dataset):
    nodes, masks, targets = ele
    print(f"=== {i} ===")
    print(dataset.to_sent(nodes))
    print(dataset.to_label(targets))
    if i == 2:
        break


    This is a dataset simulating sequence labeling in NLP.
    An item looks like:
        number1 operator number2 = number3
    where:
    - a number can be: 0.133/-0.133/.333/3.
    - the label segments the numbers and operators, following the BMES-style in sequence labeling
    
=== 0 ===
-819.215*-567=-21.359@@@@
BMMMMMMESBMMESBMMMMME@@@@
=== 1 ===
.925+-91.131=148@@@@@@@@@
BMMESBMMMMMESBME@@@@@@@@@
=== 2 ===
-571.170/-.406=-631.@@@@@
BMMMMMMESBMMMESBMMME@@@@@


In [9]:
dataset = LinearChainDataset(length=100)
train_data = DataLoader(dataset, batch_size=20)
test_data = DataLoader(dataset, batch_size=100)

In [14]:
import time
from tqdm import tqdm
from tabulate import tabulate
from torch_random_fields.models import LinearChainCRF
from torch_random_fields.models.constants import Inference, Training


class Model(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.embed = torch.nn.Embedding(num_embeddings=len(dataset.word2idx), embedding_dim=10)
        self.pred = torch.nn.Linear(10, len(dataset.label2idx))
        self.crf = LinearChainCRF(
            len(dataset.label2idx),
            low_rank=5,
            training=Training.PIECEWISE,
            inference=Inference.VITERBI,
            feature_size=10,
        )

    def forward(self, nodes, masks, targets):
        feats = self.embed(nodes)
        unaries = self.pred(feats)
        loss = self.crf(
            unaries=unaries,
            masks=masks,
            node_features=feats,
            targets=targets,
        )
        return loss

    def decode(self, nodes, masks):
        return self(nodes, masks, None)[1]

    def evaulate(self, nodes, masks, targets):
        pred = self(nodes, masks, None)[1]
        pred.masked_fill_(~masks, dataset.label2idx[dataset.pad])
        corr = pred == targets
        accu = corr[masks].sum() / masks.sum()
        return accu


cost_table = []
accu_table = []

for training in [Training.PIECEWISE, Training.PSEUDO_LIKELIHOOD, Training.PERCEPTRON, Training.EXACT_LIKELIHOOD]:
    # for training in [Training.PERCEPTRON]:
    cost_table.append([training])
    accu_table.append([training])

    model = Model()
    model.crf.training = training
    opt = torch.optim.Adam(model.parameters(), lr=0.001)

    last_time = time.time()
    for i in tqdm(range(301)):
        nodes, masks, targets = next(iter(train_data))
        loss = model(nodes, masks, targets)
        opt.zero_grad()
        loss.backward()
        opt.step()
        if i > 0 and i % 100 == 0:
            current_time = time.time()
            cost_time = current_time - last_time
            last_time = current_time
            accu = model.evaulate(*next(iter(test_data)))
            cost_table[-1].append(cost_time)
            accu_table[-1].append(accu)

headers = ["iteration"] + [str(i * 100) for i in range(1, len(cost_table[0]) + 1)]
print("====== cost ======")
print(tabulate(cost_table, headers=headers, floatfmt=".3f"))
print("====== accu ======")
print(tabulate(accu_table, headers=headers, floatfmt=".3f"))


100%|██████████| 301/301 [00:02<00:00, 112.93it/s]
100%|██████████| 301/301 [00:02<00:00, 124.94it/s]
100%|██████████| 301/301 [00:05<00:00, 55.29it/s]
100%|██████████| 301/301 [00:14<00:00, 21.39it/s]

iteration            100    200    300
-----------------  -----  -----  -----
piecewise          0.828  0.897  0.919
pseudo-likelihood  0.742  0.771  0.861
perceptron         1.805  1.745  1.873
exact-likelihood   4.882  4.458  4.713
iteration            100    200    300
-----------------  -----  -----  -----
piecewise          0.843  0.876  0.896
pseudo-likelihood  0.101  0.801  0.894
perceptron         0.393  0.447  0.533
exact-likelihood   0.441  0.561  0.699





In [15]:
cost_table = []
accu_table = []

INFERENCE_METHODS = [Inference.VITERBI, Inference.BATCH_MEAN_FIELD]
for inference in INFERENCE_METHODS:
    cost_table.append([inference])
    accu_table.append([inference])

model = Model()
opt = torch.optim.Adam(model.parameters(), lr=0.001)

for i in tqdm(range(701)):
    nodes, masks, targets = next(iter(train_data))
    loss = model(nodes, masks, targets)
    opt.zero_grad()
    loss.backward()
    opt.step()
    if i > 0 and i % 100 == 0:
        for inference in INFERENCE_METHODS:
            last_time = time.time()
            model.crf.inference = inference
            accu = model.evaulate(*next(iter(test_data)))
            cost_time = time.time() - last_time

            cost_table[INFERENCE_METHODS.index(inference)].append(cost_time)
            accu_table[INFERENCE_METHODS.index(inference)].append(accu)

headers = ["iteration"] + [str(i * 100) for i in range(1, len(cost_table[0]) + 1)]
print("====== cost ======")
print(tabulate(cost_table, headers=headers, floatfmt=".3f"))
print("====== accu ======")
print(tabulate(accu_table, headers=headers, floatfmt=".3f"))

100%|██████████| 701/701 [00:08<00:00, 85.61it/s] 

iteration           100    200    300    400    500    600    700
----------------  -----  -----  -----  -----  -----  -----  -----
viterbi           0.025  0.076  0.025  0.023  0.021  0.021  0.023
batch-mean-field  0.024  0.066  0.022  0.017  0.018  0.017  0.016
iteration           100    200    300    400    500    600    700
----------------  -----  -----  -----  -----  -----  -----  -----
viterbi           0.849  0.874  0.922  0.939  0.948  0.954  0.958
batch-mean-field  0.850  0.865  0.892  0.900  0.918  0.923  0.938



