In [1]:
%load_ext autoreload
%autoreload 2

import random
import torch
from tqdm import tqdm
from torch.utils.data import DataLoader, IterableDataset

In [2]:
class PseudoDataset(IterableDataset):
    def __init__(self, length, skip_as_ternary) -> None:
        super().__init__()

        self.length = length
        self.skip_as_ternary = skip_as_ternary

        vocab_ = "ABCDEFG1234567_@"
        self.word2idx = {ele: idx for idx, ele in enumerate(vocab_)}
        self.idx2word = {idx: ele for idx, ele in enumerate(vocab_)}

        vocab_ = "ABCDEFG_@"
        self.label2idx = {ele: idx for idx, ele in enumerate(vocab_)}
        self.idx2label = {idx: ele for idx, ele in enumerate(vocab_)}

        self.pad = "@"

    def generate(self):
        seqlen = random.randint(int(self.length * 0.6), self.length)
        sent = list("_" * seqlen + self.pad * (self.length - seqlen))
        label = list("_" * seqlen + self.pad * (self.length - seqlen))
        names = list("ABCDEFG")
        random.shuffle(names)
        names = names[:3]
        nicks = list("1234567")
        random.shuffle(nicks)
        nicks = nicks[:3]
        locs = list(range(seqlen))
        random.shuffle(locs)

        bin_edges = [(i, i+1) for i in range(seqlen - 2)]
        ter_edges = []
        for i in range(3):
            loc1 = locs.pop()
            while loc1 - 1 not in locs:
                locs.insert(0, loc1)
                loc1 = locs.pop()
            locs.remove(loc1 - 1)
            loc2 = locs.pop()
            loc3 = locs.pop()
            sent[loc1 - 1] = names[i]
            sent[loc1], sent[loc2], sent[loc3] = nicks[i], nicks[i], nicks[i]
            label[loc1], label[loc2], label[loc3] = names[i], names[i], names[i]

            if not self.skip_as_ternary:
                bin_edges.append((loc1, loc2))
                bin_edges.append((loc2, loc3))
                bin_edges.append((loc1, loc3))
            else:
                ter_edges.append((loc1, loc2, loc3))
        
        bin_edges_real_len = len(bin_edges)
        bin_edges.extend([(0, 0) for _ in range(self.length - seqlen)])

        sent = "".join(sent)
        label = "".join(label)

        nodes = torch.tensor([self.word2idx[ele] for ele in sent])
        bin_edges = torch.tensor(bin_edges)
        ter_edges = torch.tensor(ter_edges)
        node_masks = nodes != self.word2idx[self.pad]
        bin_edge_masks = torch.zeros([len(bin_edges)]).bool()
        bin_edge_masks[:bin_edges_real_len] = True
        ter_edge_masks = torch.ones([len(ter_edges)]).bool()
        label = torch.tensor([self.label2idx[ele] for ele in label])
        return nodes, node_masks, bin_edges, bin_edge_masks, ter_edges, ter_edge_masks, label

    def __iter__(self):
        while True:
            yield self.generate()

    def to_sent(self, tensor):
        sent = [dataset.idx2word[ele] for ele in tensor.view(-1).tolist()]
        return "".join(sent)

    def to_label(self, tensor):
        label = [dataset.idx2label[ele] for ele in tensor.view(-1).tolist()]
        return "".join(label)

dataset = PseudoDataset(50, True)
dataset.generate()
train_data = DataLoader(dataset, batch_size=20)
test_data = DataLoader(dataset, batch_size=100)

for b in train_data:
    nodes, masks, bin_edges, bin_masks, ter_edges, ter_masks, label = b
    # print(ter_edge_masks)
    break

In [3]:
ter_masks.shape

torch.Size([20, 3])

In [4]:
from torch_random_fields.models import GeneralCRF
import einops
from einops.layers.torch import Rearrange
from torch_random_fields.models.constants import Inference, Training

class Model(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.embed = torch.nn.Embedding(num_embeddings=len(dataset.word2idx), embedding_dim=10)
        self.pred = torch.nn.Sequential(
            Rearrange("B T C -> B C T"),
            torch.nn.Conv1d(10, len(dataset.label2idx), 3, padding="same"),
            Rearrange("B C T -> B T C"),
        )
        # self.pred = torch.nn.Linear(10, len(dataset.label2idx))
        self.crf = GeneralCRF(
            num_states=len(dataset.label2idx), 
            feature_size=10, 
            beam_size=64,
            low_rank=10, 
            training=Training.PIECEWISE, 
            inference=Inference.BELIEF_PROPAGATION,
            support_ternary=True,
        )
    
    def forward(self, nodes, masks, bin_edges, bin_masks, ter_edges, ter_masks, targets):
        feats = self.embed(nodes)
        unaries = self.pred(feats)
        loss = self.crf(
            unaries=unaries, 
            masks=masks,
            binary_edges=bin_edges,
            binary_masks=bin_masks,
            ternary_edges=ter_edges,
            ternary_masks=ter_masks,
            targets=targets,
            node_features=feats
        )
        return loss

    def evaulate(self, nodes, masks, bin_edges, bin_masks, ter_edges, ter_masks, label):
        for b in test_data:
            pred = self(nodes, masks, bin_edges, bin_masks, ter_edges, ter_masks, None)[1]
            break
        corr = pred == label
        accu = corr[masks].sum() / masks.sum()
        return accu


pbar = tqdm(train_data)
model = Model()
opt = torch.optim.AdamW(model.parameters(), .01, weight_decay=0.01)
for i, batch in enumerate(pbar):
    nodes, masks, bin_edges, bin_masks, ter_edges, ter_masks, label = batch
    loss = model(nodes, masks, bin_edges, bin_masks, ter_edges, ter_masks, label)
    opt.zero_grad()
    loss.backward()
    opt.step()
    
    if i % 100 == 0:
        accu = model.evaulate(nodes, masks, bin_edges, bin_masks, ter_edges, ter_masks, label)
        pbar.set_description("loss: {:.4f}, accu: {:.4f}".format(loss.item(), accu.item()))
    
    if i == 800:
        break

loss: 1.8468, accu: 1.0000: : 800it [00:17, 46.99it/s]


In [5]:
bid = 1
node_len = masks[bid].sum()
print(dataset.to_sent(nodes[bid, :node_len]))
print(dataset.to_label(label[bid, :node_len]))

pred = model(nodes, masks, bin_edges, bin_masks, ter_edges, ter_masks, None)[1]
print(dataset.to_label(pred[bid, :node_len]))



32_1__2___C1_____B3__3G2_1__________
BG_C__G____C______B__B_G_C__________
BG_C__G____C______B__B_G_C__________
