## Data

In [115]:
%load_ext autoreload
%autoreload 2   

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [116]:
# Generate simple data for experimenting a CRF model

import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
import torch
np.random.seed(0)

n_samples = 1000
n_features = 768
n_states = 3

X = np.random.randn(n_samples, n_features) # Sample like bert features
y = np.random.randint(n_states, size=n_samples) # Random labels

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5)




## torchcrf

In [491]:
features = emissions[:,1]
unary_features = features.view(B, C).unsqueeze(1)
transitions.unsqueeze(0).shape


torch.Size([1, 2, 2])

In [618]:
import torch
from torchcrf import CRF
torch.manual_seed(0)

C = 2  
N = 3  
B = 1

emissions = torch.randn(B, N, C) # emission scores = unary potentials
transitions = torch.randn(C, C) # transition scores = binary potentials
tags = torch.tensor([[0, 1, 1]], dtype=torch.long)  # (B, N)
print("emissions", emissions)
print('transitions', transitions)
print('tags', tags)

model = CRF(C, batch_first=True)
mask = torch.ones_like(tags, dtype=torch.uint8)

log_num = model._compute_score(emissions.transpose(0,1), mask=mask.transpose(0,1), tags=tags.transpose(0,1))
log_denom = model._compute_normalizer(emissions.transpose(0,1), mask.transpose(0,1))

nll = log_num - log_denom
print("nll: ", -nll.sum())
# print('nll by torch crf :', -model(emissions, tags))

emissions tensor([[[ 1.5410, -0.2934],
         [-2.1788,  0.5684],
         [-1.0845, -1.3986]]])
transitions tensor([[ 0.4033,  0.8380],
        [-0.7193, -0.4033]])
tags tensor([[0, 1, 1]])
nll:  tensor(1.1161, grad_fn=<NegBackward0>)


## torch_struct

In [633]:
import torch_struct
from torch_struct import LinearChainCRF
print("B:", B, "N:", N, "C:", C)
lengths = torch.tensor([3])

# Emission scores (output of encoder) # (B, N, C) (torchcrf) -> (B, N, C, 1) (torch_struct)
emissions_struct = emissions.view(B, N, C, 1)
print("emissions_struct", emissions_struct)
print("emissions_struct.shape", emissions_struct.shape)

# Transition scores                   # (C, C) (torchcrf) -> (1, 1, C, C) (torch_struct)
transitions_struct = transitions.view(1, 1, C, C)
print("transitions_struct", transitions_struct)
print("transitions_struct.shape", transitions_struct.shape)

# Score 
score = emissions_struct[:, 1:N] + transitions_struct
print("score.shape", score.shape)

dist = torch_struct.LinearChainCRF(score)
labels_edges = LinearChainCRF.struct.to_parts(tags.view(B, N), C, lengths=lengths).type_as(dist.log_potentials)
print('nll', -dist.log_prob(labels_edges))

B: 1 N: 3 C: 2
emissions_struct tensor([[[[ 1.5410],
          [-0.2934]],

         [[-2.1788],
          [ 0.5684]],

         [[-1.0845],
          [-1.3986]]]])
emissions_struct.shape torch.Size([1, 3, 2, 1])
transitions_struct tensor([[[[ 0.4033,  0.8380],
          [-0.7193, -0.4033]]]])
transitions_struct.shape torch.Size([1, 1, 2, 2])
score.shape torch.Size([1, 2, 2, 2])
nll tensor([3.8205], grad_fn=<NegBackward0>)
