In [16]:
%load_ext autoreload
%autoreload 2

import random
import torch
from torch.utils.data import DataLoader, IterableDataset

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [17]:
class PseudoDataset(IterableDataset):
    def __init__(self, length=25) -> None:
        super().__init__()

        self.length = length

        vocab_ = [str(i) for i in range(10)] + [".", "+", "-", "*", "/", "=", "^", "$", "@"]
        self.word2idx = {ele: idx for idx, ele in enumerate(vocab_)}
        self.idx2word = {idx: ele for idx, ele in enumerate(vocab_)}

        vocab_ = ["B", "M", "E", "S", "@"]
        self.label2idx = {ele: idx for idx, ele in enumerate(vocab_)}
        self.idx2label = {idx: ele for idx, ele in enumerate(vocab_)}

        self.pad = "@"

    def rand_word(self):
        max_num = 10**(self.length // 8)
        x = random.randint(0, 1)
        ret = ""
        if x == 1:
            ret += "-"
        else:
            ret += ""
        x = random.randint(0, 3)
        if x == 0:
            ret += str(random.randint(1, max_num))
        elif x == 1:
            ret += str(random.randint(1, max_num))
            ret += "."
        elif x == 2:
            ret += str(random.randint(1, max_num))
            ret += "."
            ret += str(random.randint(1, max_num))
        else:
            ret += "."
            ret += str(random.randint(1, max_num))
        return ret

    def label_from_length(self, num):
        if num == 1:
            return "S"
        else:
            return "B" + "M" * (num - 2) + "E"

    def generate(self):
        x1 = self.rand_word()
        x2 = self.rand_word()
        y = self.rand_word()
        op = "+-*/"[random.randint(0, 3)]
        sent = ["^", x1, op, x2, "=", y, "$"]
        label = [self.label_from_length(len(ele)) for ele in sent]
        sent = "".join(sent)
        sent += (self.length - len(sent)) * self.pad
        label = "".join(label)
        label += (self.length - len(label)) * self.pad
        sent = torch.tensor([self.word2idx[ele] for ele in sent])
        label = torch.tensor([self.label2idx[ele] for ele in label])
        masks = sent != self.word2idx[self.pad]
        return sent, masks, label

    def __iter__(self):
        while True:
            yield self.generate()

dataset = PseudoDataset(length=75)
train_data = DataLoader(dataset, batch_size=20)
test_data = DataLoader(dataset, batch_size=100)

In [22]:
import time
from tqdm import tqdm
from tabulate import tabulate
from torch_random_fields.models import LinearChainCRF
from torch_random_fields.models.constants import Inference, Training

class Model(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.embed = torch.nn.Embedding(num_embeddings=len(dataset.word2idx), embedding_dim=10)
        self.pred = torch.nn.Linear(10, len(dataset.label2idx))
        self.crf = LinearChainCRF(
            len(dataset.label2idx),
            low_rank=5,
            training=Training.PIECEWISE,
            inference=Inference.VITERBI,
            feature_size=10,
        )

    def forward(self, nodes, masks, targets):
        feats = self.embed(nodes)
        unaries = self.pred(feats)
        loss = self.crf(unaries=unaries, targets=targets, masks=masks, node_features=feats)
        return loss

    def decode(self, nodes, masks):
        unaries = self.pred(self.embed(nodes))
        return self.crf(unaries, masks=masks, node_features=None)[1]

    def evaulate(self, nodes, masks, targets):
        pred = self(nodes, masks, None)[1]
        pred.masked_fill_(~masks, dataset.label2idx[dataset.pad])
        corr = pred == targets
        accu = corr[masks].sum() / masks.sum()
        return accu


cost_table = []
accu_table = []

for training in [Training.PIECEWISE, Training.PSEUDO_LIKELIHOOD, Training.PERCEPTRON, Training.EXACT_LIKELIHOOD]:
# for training in [Training.PERCEPTRON]:
    cost_table.append([training])
    accu_table.append([training])

    model = Model()
    model.crf.training = training
    opt = torch.optim.Adam(model.parameters(), lr=0.001)

    last_time = time.time()
    for i in tqdm(range(301)):
        nodes, masks, targets = next(iter(train_data))
        loss = model(nodes, masks, targets)
        opt.zero_grad()
        loss.backward()
        opt.step()
        if i > 0 and i % 100 == 0:
            current_time = time.time()
            cost_time = current_time - last_time
            last_time = current_time
            accu = model.evaulate(*next(iter(test_data)))
            cost_table[-1].append(cost_time)
            accu_table[-1].append(accu)

headers = ["iteration"] + [str(i * 100) for i in range(1, len(cost_table[0]) + 1)]
print("====== cost ======")
print(tabulate(cost_table, headers=headers, floatfmt=".3f"))
print("====== accu ======")
print(tabulate(accu_table, headers=headers, floatfmt=".3f"))


100%|██████████| 301/301 [00:01<00:00, 159.37it/s]
100%|██████████| 301/301 [00:01<00:00, 162.33it/s]
100%|██████████| 301/301 [00:03<00:00, 77.85it/s]
100%|██████████| 301/301 [00:06<00:00, 44.12it/s]

iteration            100    200    300
-----------------  -----  -----  -----
piecewise          0.597  0.614  0.661
pseudo-likelihood  0.599  0.669  0.571
perceptron         1.267  1.342  1.241
exact-likelihood   2.237  2.361  2.208
iteration            100    200    300
-----------------  -----  -----  -----
piecewise          0.795  0.816  0.847
pseudo-likelihood  0.681  0.840  0.855
perceptron         0.287  0.370  0.496
exact-likelihood   0.285  0.708  0.737





In [23]:
cost_table = []
accu_table = []

INFERENCE_METHODS = [Inference.VITERBI, Inference.BATCH_MEAN_FIELD]
for inference in INFERENCE_METHODS:
    cost_table.append([inference])
    accu_table.append([inference])

model = Model()
opt = torch.optim.Adam(model.parameters(), lr=0.001)

for i in tqdm(range(701)):
    nodes, masks, targets = next(iter(train_data))
    loss = model(nodes, masks, targets)
    opt.zero_grad()
    loss.backward()
    opt.step()
    if i > 0 and i % 100 == 0:
        for inference in INFERENCE_METHODS:
            last_time = time.time()
            model.crf.inference = inference
            accu = model.evaulate(*next(iter(test_data)))
            cost_time = time.time() - last_time

            cost_table[INFERENCE_METHODS.index(inference)].append(cost_time)
            accu_table[INFERENCE_METHODS.index(inference)].append(accu)

headers = ["iteration"] + [str(i * 100) for i in range(1, len(cost_table[0]) + 1)]
print("====== cost ======")
print(tabulate(cost_table, headers=headers, floatfmt=".3f"))
print("====== accu ======")
print(tabulate(accu_table, headers=headers, floatfmt=".3f"))

100%|██████████| 701/701 [00:04<00:00, 155.13it/s]

iteration           100    200    300    400    500    600    700
----------------  -----  -----  -----  -----  -----  -----  -----
viterbi           0.064  0.017  0.017  0.017  0.017  0.019  0.017
batch-mean-field  0.016  0.014  0.014  0.014  0.014  0.015  0.014
iteration           100    200    300    400    500    600    700
----------------  -----  -----  -----  -----  -----  -----  -----
viterbi           0.848  0.871  0.894  0.909  0.934  0.959  0.969
batch-mean-field  0.850  0.877  0.888  0.930  0.955  0.962  0.964



