In [1]:
%load_ext autoreload
%autoreload 2

import random
import torch
from tqdm import tqdm
from torch.utils.data import DataLoader, IterableDataset
from pseudo_data import SkipChainDataset

In [2]:
dataset = SkipChainDataset(50, True)
print(dataset.__doc__)
train_data = DataLoader(dataset, batch_size=20)
test_data = DataLoader(dataset, batch_size=100)

for i, batch in enumerate(train_data):
    nodes, masks, bin_edges, bin_masks, ter_edges, ter_masks, label = batch
    print(f"=== {i} ===")
    print(dataset.to_sent(nodes))
    print(dataset.to_label(label))
    if i == 3:
        break


    This is a dataset simulating a general graph.
    An item looks like:
        _7____________5___E2__A5_____7__5____22___D7_@@@@@
         |                           |             |
         D                           D             D
    where:
    - 7 is labeled as "D" since one of 7s is a successor of a "D"
    - similarly, 5 is labeled as "A" and 2 is labeled as "E"
    - 3 numbers (7, 5, 2) occurs in an item, and each occurs 3 times,
        
    For graphical models, we can build one ternary factor or 3 binary factors:
            7                   7
           / \         or       |
          7---7              7_/ \_7
    
=== 0 ===
_______6_2F6__4__C44_A2____2_6_@@@@@@@@@@@@@@@@@@@
_______F_A_F__C___CC__A____A_F_@@@@@@@@@@@@@@@@@@@
=== 1 ===
______C3_7___3_5E5_____5___3_F77______@@@@@@@@@@@@
_______C_F___C_E_E_____E___C__FF______@@@@@@@@@@@@
=== 2 ===
_7__6_E6__B3___3___6G7______7_______3__@@@@@@@@@@@
_G__E__E___B___B___E_G______G_______B__@@@@@@@@@@@
=== 3 ===
_____2_3

In [3]:
from torch_random_fields.models import GeneralCRF
import einops
from einops.layers.torch import Rearrange
from torch_random_fields.models.constants import Inference, Learning


class Model(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.embed = torch.nn.Embedding(num_embeddings=len(dataset.word2idx), embedding_dim=10)
        self.pred = torch.nn.Sequential(
            Rearrange("B T C -> B C T"),
            torch.nn.Conv1d(10, len(dataset.label2idx), 3, padding="same"),
            Rearrange("B C T -> B T C"),
        )
        # self.pred = torch.nn.Linear(10, len(dataset.label2idx))
        self.crf = GeneralCRF(
            num_states=len(dataset.label2idx),
            feature_size=10,
            beam_size=64,
            low_rank=10,
            learning=Learning.PIECEWISE,
            inference=Inference.BELIEF_PROPAGATION,
            support_ternary=True,
        )

    def forward(self, nodes, masks, bin_edges, bin_masks, ter_edges, ter_masks, targets):
        feats = self.embed(nodes)
        unaries = self.pred(feats)
        loss = self.crf(
            unaries=unaries,
            masks=masks,
            binary_edges=bin_edges,
            binary_masks=bin_masks,
            ternary_edges=ter_edges,
            ternary_masks=ter_masks,
            targets=targets,
            node_features=feats,
        )
        return loss

    def decode(self, nodes, masks, bin_edges, bin_masks, ter_edges, ter_masks):
        return self(nodes, masks, bin_edges, bin_masks, ter_edges, ter_masks, None)[1]

    def evaulate(self, nodes, masks, bin_edges, bin_masks, ter_edges, ter_masks, label):
        for b in test_data:
            pred = self(nodes, masks, bin_edges, bin_masks, ter_edges, ter_masks, None)[1]
            break
        corr = pred == label
        accu = corr[masks].sum() / masks.sum()
        return accu


pbar = tqdm(train_data)
model = Model()
opt = torch.optim.AdamW(model.parameters(), .01, weight_decay=0.01)
for i, batch in enumerate(pbar):
    nodes, masks, bin_edges, bin_masks, ter_edges, ter_masks, label = batch
    loss = model(nodes, masks, bin_edges, bin_masks, ter_edges, ter_masks, label)
    opt.zero_grad()
    loss.backward()
    opt.step()

    if i % 100 == 0:
        accu = model.evaulate(nodes, masks, bin_edges, bin_masks, ter_edges, ter_masks, label)
        pbar.set_description("loss: {:.4f}, accu: {:.4f}".format(loss.item(), accu.item()))

    if i == 400:
        break

loss: 1.8042, accu: 1.0000: : 400it [00:09, 42.75it/s]


In [4]:
bid = 1
node_len = masks[bid].sum()
print(dataset.to_sent(nodes[bid, :node_len]))
print(dataset.to_label(label[bid, :node_len]))

pred = model.decode(nodes, masks, bin_edges, bin_masks, ter_edges, ter_masks)
print(dataset.to_label(pred[bid, :node_len]))


_5___73___5___A7____________E3_B5_____3____7______
_B___AE___B____A_____________E__B_____E____A______
_B___AE___B____A_____________E__B_____E____A______
