<a href="https://colab.research.google.com/github/harvardnlp/pytorch-struct/blob/master/notebooks/Unsupervised_CFG.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install -q torchtext
!pip install -q pytorch-transformers
!pip install -qU git+https://github.com/harvardnlp/pytorch-struct@cky

  Building wheel for torch-struct (setup.py) ... [?25l[?25hdone


In [0]:
import torchtext
import torch
from torch_struct import CKY, MaxSemiring, StdSemiring
import torch_struct.data

In [0]:
# Download and the load default data.
WORD = torchtext.data.Field(include_lengths=True)
UD_TAG = torchtext.data.Field(init_token="<bos>", eos_token="<eos>", include_lengths=True)

# Download and the load default data.
train, val, test = torchtext.datasets.UDPOS.splits(
    fields=(('word', WORD), ('udtag', UD_TAG), (None, None)), 
    filter_pred=lambda ex: 5 < len(ex.word) < 30
)

WORD.build_vocab(train.word, min_freq=3)
UD_TAG.build_vocab(train.udtag)
train_iter = torchtext.data.BucketIterator(train, 
    batch_size=20,
    device="cuda:0")


In [0]:
H = 256
T = 30
NT = 30
    
class Res(torch.nn.Module):
    def __init__(self, H):
        super().__init__()
        self.u1 = torch.nn.Linear(H, H)
        self.u2 = torch.nn.Linear(H, H)

        self.v1 = torch.nn.Linear(H, H)
        self.v2 = torch.nn.Linear(H, H)
        self.w = torch.nn.Linear(H, H)
        
    def forward(self, y):     
         y = self.w(y)
         y = y + torch.relu(self.v1(torch.relu(self.u1(y))))
         return y + torch.relu(self.v2(torch.relu(self.u2(y))))

params = []
def param(*size):
    p = torch.zeros(*size).cuda().requires_grad_(True)
    params.append(p)
    return p

# Params
word_emb = param(len(WORD.vocab), H)
term_emb = param(T, H)
nonterm_emb = param(NT, H)
nonterm_emb_c = param(NT+T, NT+T, H)
root_emb = param(NT, H)
s_emb = param(1, H)
mlp1 = Res(H)
mlp2 = Res(H)
mlp1.cuda()
mlp2.cuda()
all_params = params + list(mlp1.parameters()) + list(mlp2.parameters())
for p in all_params:
    if p.dim()> 1:
        torch.nn.init.xavier_uniform_(p) 

opt = torch.optim.Adam(all_params, lr=0.001, betas=[0.75, 0.999])

In [0]:
def terms(words):
    return torch.einsum("bnh,th->bnt", word_emb[words], mlp1(term_emb)).log_softmax(-2)
    
def rules(b):
    return torch.einsum("sh,tuh->stu", nonterm_emb, nonterm_emb_c).view(NT, -1).log_softmax(-1).view(1, NT, NT+T, NT+T).expand(b, NT, NT +T, NT+T)
    
def roots(b):
    return torch.einsum("ah,th->t", s_emb, mlp2(root_emb)).log_softmax(-1).view(1, NT).expand(b, NT)

In [0]:
def train():
    #model.train()
    losses = []
    for epoch in range(10):
        for i, ex in enumerate(train_iter):
            opt.zero_grad()
            words, lengths = ex.word 

            N, batch = words.shape
            words = words.long()
            params = terms(words.transpose(0, 1)), rules(batch), roots(batch)

            log_partition = CKY().sum(params, lengths=lengths, _autograd=True)
            loss = log_partition.mean()
            (-loss).backward()
            losses.append(loss.detach())
            torch.nn.utils.clip_grad_norm_(all_params, 3.0)
            opt.step()

            if i % 100 == 1:            
                print(-torch.tensor(losses).mean(), words.shape)
                losses = []

In [7]:
train()

tensor(57.9222) torch.Size([29, 20])
tensor(48.2354) torch.Size([28, 20])
tensor(43.9638) torch.Size([26, 20])
tensor(43.6623) torch.Size([24, 20])
tensor(42.9041) torch.Size([29, 20])
tensor(41.5489) torch.Size([27, 20])
tensor(41.1781) torch.Size([26, 20])
tensor(38.9022) torch.Size([26, 20])
tensor(39.2924) torch.Size([28, 20])
tensor(39.6077) torch.Size([29, 20])
tensor(39.4375) torch.Size([28, 20])
tensor(37.8449) torch.Size([27, 20])
tensor(37.7898) torch.Size([25, 20])
tensor(37.6564) torch.Size([28, 20])
tensor(37.9299) torch.Size([24, 20])
tensor(38.0394) torch.Size([29, 20])
tensor(36.7634) torch.Size([27, 20])
tensor(36.8474) torch.Size([25, 20])
tensor(36.4528) torch.Size([28, 20])
tensor(36.6790) torch.Size([29, 20])
tensor(35.7165) torch.Size([28, 20])
tensor(35.1808) torch.Size([27, 20])
tensor(35.5481) torch.Size([25, 20])
tensor(35.9955) torch.Size([29, 20])
tensor(36.2696) torch.Size([28, 20])
tensor(37.3094) torch.Size([28, 20])
tensor(34.9490) torch.Size([22, 20])
t

In [8]:
for i, ex in enumerate(train_iter):
    opt.zero_grad()
    words, lengths = ex.word 

    N, batch = words.shape
    words = words.long()
    params = terms(words.transpose(0, 1)), rules(batch), roots(batch)

    tree = CKY(MaxSemiring).marginals(params, lengths=lengths, _autograd=True)
    print(tree)
    break

(tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [1., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 1., 0.],
         [0., 0., 0.,  ..., 0., 1., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 1., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        ...,

        [[0., 0., 0.,  ..., 0., 1., 0.],
         [0., 0., 0.,  ..., 0., 1., 0.],
         [0., 0., 0.,  ..., 0., 1., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,

In [0]:
def split(spans):
    batch, N = spans.shape[:2]
    splits = []
    for b in range(batch):
        cover = spans[b].nonzero()
        left = {i: [] for i in range(N)}
        right = {i: [] for i in range(N)}
        batch_split = {}
        for i in range(cover.shape[0]):
            i, j, A = cover[i].tolist()
            left[i].append((A, j, j - i + 1))
            right[j].append((A, i, j - i + 1))
        for i in range(cover.shape[0]):
            i, j, A = cover[i].tolist()
            B = None
            for B_p, k, a_span in left[i]:
                for C_p, k_2, b_span in right[j]:
                    if k_2 == k + 1 and a_span + b_span == j - i + 1:
                        B, C = B_p, C_p
                        k_final = k
                        break
            if j > i:
                batch_split[(i, j)] =k
        splits.append(batch_split)
    return splits 

In [0]:
splits = split(spans)

In [30]:
splits[0]

{(0, 2): 24,
 (0, 24): 24,
 (1, 2): 2,
 (3, 24): 24,
 (4, 8): 24,
 (4, 24): 24,
 (5, 8): 8,
 (6, 7): 8,
 (6, 8): 8,
 (9, 24): 24,
 (10, 24): 24,
 (11, 23): 24,
 (11, 24): 24,
 (12, 23): 23,
 (13, 23): 23,
 (14, 19): 23,
 (14, 21): 23,
 (14, 22): 23,
 (14, 23): 23,
 (15, 19): 19,
 (16, 19): 19,
 (17, 18): 19,
 (17, 19): 19,
 (20, 21): 21}

In [45]:
!pip install pydot-ng



In [0]:
spans, _ = CKY().from_parts(tuple((t.cpu() for t in tree)))
import networkx as nx
from networkx.drawing.nx_agraph import write_dot, graphviz_layout
import matplotlib.pyplot as plt
G = nx.DiGraph()
for n in spans[0].nonzero():
    G.add_node((n[0], n[1]), label="n[3]")
for k, v in splits[0].items():
    G.add_edge(k, (k[0],v))
    G.add_edge(k, (v+1, k[1]))


# run "dot -Tpng test.dot >test.png"
#write_dot(G,'test.dot')

# same layout using matplotlib with no labels
##plt.title('draw_networkx')
#pos =graphviz_layout(G, prog='dot')
#nx.draw(G, pos, with_labels=False, arrows=True)
#plt.savefig('nx_test.png')