### GCN preprocess

In [1]:
import math
import os
import random
import time

import numpy as np
import spacy
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torchtext.data import Field, BucketIterator, RawField, Dataset
from torchtext.datasets import TranslationDataset, Multi30k, IWSLT

#### Functions
---

In [2]:
from tqdm import tqdm
from spacy import displacy
from collections import defaultdict
from torch_geometric.data import Data
from torch_geometric.utils import add_remaining_self_loops, to_undirected, to_dense_adj

# Collects the tokens of every sentence whose adjacency matrix did not come out
# with the expected shape; build_graph appends to it as a side effect.
problematic_tokens = []

def build_graph(tokens, spacy_model):
    """
    Build the dense adjacency matrix for the provided list of tokens.

    Rules:
        - tokens are connected if they are immediately linked in the dep tree
        - punctuation tokens are connected to their neighbors
        - <sos> and <eos> are added around the tokens, and linked to
            the first token and last token respectively
        - every node gets a self connection
        - all edges are made undirected

    Args:
        tokens: list of token strings for one sentence (without <sos>/<eos>).
        spacy_model: callable that parses a string and yields tokens exposing
            `.i`, `.children` and `.is_punct`.

    Returns:
        The adjacency matrix as a 2-D tensor. It is expected to be (n, n) with
        n = len(tokens) + 2; when it is not, the matrix is still returned
        unchanged (see side effects).

    Side effects:
        When the matrix shape is not (n, n) the token list is printed and its
        tokens are appended to the module-level `problematic_tokens` list.
        NOTE(review): this presumably happens when spaCy splits the joined
        sentence into a different number of tokens than `tokens` — confirm.
    """
    doc = spacy_model(" ".join(tokens))
    n = len(tokens) + 2  # <sos> + tokens + <eos>

    # <sos> -> first token and last token -> <eos>
    edges = [(0, 1), (n - 2, n - 1)]
    for parent in doc:
        parent_i = parent.i + 1  # shift by one: index 0 is <sos>
        for child in parent.children:
            edges.append((parent_i, child.i + 1))

        if parent.is_punct:
            # tie punctuation to its immediate neighbors, clamped to the
            # valid index range
            prev_i = max(parent_i - 1, 0)
            next_i = min(parent_i + 1, n - 1)
            edges.extend([(prev_i, parent_i), (parent_i, next_i)])

    # de-duplicate, then convert to a (2, num_edges) index tensor
    edge_index = torch.tensor(list(set(edges)), dtype=torch.long).t().contiguous()
    edge_index = add_remaining_self_loops(edge_index, num_nodes=n)[0]
    edge_index = to_undirected(edge_index)
    # to_dense_adj returns a batched tensor; drop only the batch axis
    M = to_dense_adj(edge_index=edge_index).squeeze(0)

    # Explicit check instead of assert-inside-bare-except: it still runs under
    # `python -O` and does not swallow unrelated exceptions.
    if tuple(M.shape) != (n, n):
        print(tokens)
        problematic_tokens.extend(tokens)
    return M


def normalize_graph(M):
    """
    Symmetrically normalize an adjacency matrix: D**(-1/2) M D**(-1/2),
    where D is the diagonal matrix of row sums of M.

    Args:
        M: square 2-D tensor (adjacency matrix).

    Returns:
        Tensor of the same shape as M with entry (i, j) equal to
        M[i, j] / sqrt(deg[i] * deg[j]). Rows/columns whose degree is zero
        are left as zeros instead of turning into NaN.
    """
    # Keep the degree vector 1-D even for a 1x1 matrix (no squeeze).
    deg = M.sum(dim=1)
    inv_sqrt_deg = deg.pow(-0.5)
    # Isolated nodes have degree 0 -> inf; zero them so inf * 0 never yields NaN.
    inv_sqrt_deg = torch.where(
        torch.isinf(inv_sqrt_deg), torch.zeros_like(inv_sqrt_deg), inv_sqrt_deg
    )
    # Broadcasting is equivalent to multiplying by diag(inv_sqrt_deg) on both
    # sides, without materializing the dense diagonal matrices.
    return inv_sqrt_deg.unsqueeze(1) * (M * inv_sqrt_deg.unsqueeze(0))


def build_graphs(dataset, spacy_model, normalize=True):
    """
    Attach a graph to every example in the torchtext dataset.

    For each example the adjacency matrix of its `src` tokens is built with
    `build_graph` (and passed through `normalize_graph` when `normalize` is
    true) and stored on the example as `grh`. A raw 'grh' field is registered
    on the dataset. The dataset is modified in place and also returned.
    """
    graph_field = RawField(postprocessing=None)
    for example in tqdm(dataset):
        adjacency = build_graph(example.src, spacy_model)
        example.grh = normalize_graph(adjacency) if normalize else adjacency
    dataset.fields['grh'] = graph_field
    return dataset


def validate_data(dataset1, dataset2):
    """
    Check that two datasets hold the same examples in the same order.

    Examples are compared position by position on their space-joined `src`
    tokens, space-joined `trg` tokens, and `grh` tensors (shape and values).

    Args:
        dataset1, dataset2: indexable, sized collections of examples exposing
            `src`, `trg` and `grh`.

    Raises:
        AssertionError: if the datasets differ in length or any example
            differs; the message names the first offending index.
    """
    size1, size2 = len(dataset1), len(dataset2)
    # Without this check a longer dataset2 would pass silently.
    if size1 != size2:
        raise AssertionError(
            "dataset sizes differ: {} vs {}".format(size1, size2))

    for i in range(size1):
        d = dataset1[i]
        d_ = dataset2[i]
        if ' '.join(d.src) != ' '.join(d_.src):
            raise AssertionError("src mismatch at example {}".format(i))
        if ' '.join(d.trg) != ' '.join(d_.trg):
            raise AssertionError("trg mismatch at example {}".format(i))
        # Compare shapes first so differently sized graphs are reported as a
        # mismatch instead of broadcasting or raising an unrelated error.
        same_shape = tuple(d.grh.shape) == tuple(d_.grh.shape)
        if not (same_shape and bool((d.grh == d_.grh).all())):
            raise AssertionError("grh mismatch at example {}".format(i))
    print("datasets are the same!!")

#### Multi30k
---
German -> English

In [3]:
# Seed every random number generator used in this notebook with the same value.
SEED = 11747

random.seed(SEED)  # Python's built-in RNG
np.random.seed(SEED)  # NumPy's global RNG
torch.manual_seed(SEED)  # PyTorch RNG
torch.cuda.manual_seed(SEED)  # PyTorch CUDA RNG
torch.backends.cudnn.deterministic = True  # request deterministic cuDNN behavior

In [4]:
# Load the German and English spaCy pipelines. spacy_de is used below both by
# tokenize_de and as the parser passed to build_graphs; spacy_en by tokenize_en.
# NOTE(review): 'de'/'en' are shortcut model names — confirm they resolve with
# the installed spaCy version.
spacy_de = spacy.load('de')
spacy_en = spacy.load('en')

In [5]:
def tokenize_de(text):
    """
    Split a German string into a list of token strings using the tokenizer
    of the `spacy_de` pipeline. Token order is preserved.
    """
    german_doc = spacy_de.tokenizer(text)
    return [token.text for token in german_doc]

def tokenize_en(text):
    """
    Split an English string into a list of token strings using the tokenizer
    of the `spacy_en` pipeline. Token order is preserved.
    """
    english_doc = spacy_en.tokenizer(text)
    return [token.text for token in english_doc]

In [6]:
# Source (German) text field: spaCy tokenization, lower-cased, wrapped in
# <sos>/<eos> markers.
SRC = Field(tokenize = tokenize_de, 
            init_token = '<sos>', 
            eos_token = '<eos>', 
            lower = True)

# Target (English) text field with the same settings.
TGT = Field(tokenize = tokenize_en, 
            init_token = '<sos>', 
            eos_token = '<eos>', 
            lower = True)

# Raw field (no postprocessing) that carries the per-example graph tensor.
GRH = RawField(postprocessing=None)

# Field layout used later to rebuild Datasets from the saved example lists.
data_fields = [('src', SRC), ('trg', TGT), ('grh', GRH)]

# Re-seed right before loading the splits.
SEED = 11747
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
# Load the Multi30k train/valid/test splits with German as src and English as trg.
train_data, valid_data, test_data = Multi30k.splits(exts = ('.de', '.en'), fields = (SRC, TGT))

In [7]:
problematic_tokens = []
%time train_data = build_graphs(train_data, spacy_de)
%time valid_data = build_graphs(valid_data, spacy_de)
%time test_data = build_graphs(test_data, spacy_de)
problematic_tokens = set(problematic_tokens)

torch.save(list(train_data), "data/Multi30k/train_data.pt")
torch.save(list(valid_data), "data/Multi30k/valid_data.pt")
torch.save(list(test_data), "data/Multi30k/test_data.pt")

100%|██████████| 29000/29000 [04:34<00:00, 105.49it/s]
  1%|          | 11/1014 [00:00<00:09, 104.86it/s]

CPU times: user 4min 35s, sys: 1.03 s, total: 4min 37s
Wall time: 4min 34s


100%|██████████| 1014/1014 [00:09<00:00, 105.53it/s]
  1%|          | 11/1000 [00:00<00:09, 100.84it/s]

CPU times: user 9.66 s, sys: 20.2 ms, total: 9.68 s
Wall time: 9.61 s


100%|██████████| 1000/1000 [00:09<00:00, 107.53it/s]


CPU times: user 9.35 s, sys: 16.2 ms, total: 9.37 s
Wall time: 9.3 s


In [8]:
# Reload the saved example lists and wrap them in torchtext Datasets using the
# (src, trg, grh) layout from data_fields.
# NOTE(review): torch.load unpickles the file — only load files produced locally.
train_data_ = Dataset(torch.load("data/Multi30k/train_data.pt"), data_fields)
valid_data_ = Dataset(torch.load("data/Multi30k/valid_data.pt"), data_fields)
test_data_ = Dataset(torch.load("data/Multi30k/test_data.pt"), data_fields)

In [9]:
# Round-trip check: each reloaded split must match its in-memory counterpart.
validate_data(train_data, train_data_)
validate_data(valid_data, valid_data_)
validate_data(test_data, test_data_)

datasets are the same!!
datasets are the same!!
datasets are the same!!


In [10]:
# Parse the first reloaded training sentence and render its dependency tree.
doc = spacy_de(" ".join(train_data_[0].src))
displacy.render(doc)

In [11]:
# Show the attributes (src, trg, grh) stored on the first reloaded example.
vars(train_data_[0])

{'src': ['zwei',
  'junge',
  'weiße',
  'männer',
  'sind',
  'im',
  'freien',
  'in',
  'der',
  'nähe',
  'vieler',
  'büsche',
  '.'],
 'trg': ['two',
  'young',
  ',',
  'white',
  'males',
  'are',
  'outside',
  'near',
  'many',
  'bushes',
  '.'],
 'grh': tensor([[0.5000, 0.4082, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.4082, 0.3333, 0.0000, 0.2887, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.5000, 0.3536, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.2887, 0.3536, 0.2500, 0.2887, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.2887, 0.3333, 0.2582, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.00

#### IWSLT
---

In [12]:
# Seed every random number generator again before processing the second dataset.
SEED = 11747

random.seed(SEED)  # Python's built-in RNG
np.random.seed(SEED)  # NumPy's global RNG
torch.manual_seed(SEED)  # PyTorch RNG
torch.cuda.manual_seed(SEED)  # PyTorch CUDA RNG
torch.backends.cudnn.deterministic = True  # request deterministic cuDNN behavior

In [13]:
# Load the German and English spaCy pipelines (same calls as in the Multi30k
# section, so this section can be run on its own after the imports and helpers).
spacy_de = spacy.load('de')
spacy_en = spacy.load('en')

In [14]:
def tokenize_de(text):
    """
    Split a German string into a list of token strings using the tokenizer
    of the `spacy_de` pipeline. Token order is preserved.
    """
    german_doc = spacy_de.tokenizer(text)
    return [token.text for token in german_doc]

def tokenize_en(text):
    """
    Split an English string into a list of token strings using the tokenizer
    of the `spacy_en` pipeline. Token order is preserved.
    """
    english_doc = spacy_en.tokenizer(text)
    return [token.text for token in english_doc]

In [15]:
# Source (German) text field: spaCy tokenization, lower-cased, wrapped in
# <sos>/<eos> markers.
SRC = Field(tokenize = tokenize_de, 
            init_token = '<sos>', 
            eos_token = '<eos>', 
            lower = True)

# Target (English) text field with the same settings.
TGT = Field(tokenize = tokenize_en, 
            init_token = '<sos>', 
            eos_token = '<eos>', 
            lower = True)

# Raw field (no postprocessing) that carries the per-example graph tensor.
GRH = RawField(postprocessing=None)

# Field layout used later to rebuild Datasets from the saved example lists.
data_fields = [('src', SRC), ('trg', TGT), ('grh', GRH)]

# Re-seed right before loading the splits.
SEED = 11747
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
# Load the IWSLT train/valid/test splits with German as src and English as trg.
train_data, valid_data, test_data = IWSLT.splits(exts = ('.de', '.en'), fields = (SRC, TGT))

In [16]:
# Source tokens containing periods that are rewritten before graph construction.
ptokens = ['st.', '3m.', 'dr.', 'sog.', 'l.a.', 'u.n.', 'a.k.', 'p.i.', 'r.c.', 'mr.', 'u.s.', 'z.q.', 'z.', 'ca.', 'mio.', 'mrd.']
# Each token maps to itself with every '.' removed (e.g. 'u.s.' -> 'us').
replacement = [w.replace('.', '') for w in ptokens]
ptokens_map = dict(zip(ptokens, replacement))
def map_src_tokens(dataset, mapper):
    """
    Rewrite the `src` token list of every example in place: tokens that are
    keys of `mapper` are replaced by the mapped value, all others are kept.
    """
    for example in dataset:
        example.src = [mapper.get(token, token) for token in example.src]

# Apply the period-stripping map to the source side of every split (in place).
map_src_tokens(train_data, ptokens_map)
map_src_tokens(valid_data, ptokens_map)
map_src_tokens(test_data, ptokens_map)

In [17]:
# Peek at the source tokens of the first training example after the mapping.
train_data[0].src

['david',
 'gallo',
 ':',
 'das',
 'ist',
 'bill',
 'lange',
 '.',
 'ich',
 'bin',
 'dave',
 'gallo',
 '.']

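Some German tokens break the dependency parsing, so we have to filter them out or rewrite them before building the graphs (the `ptokens` mapping above strips the periods from the abbreviations):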
* ['st.', '3m.', 'dr.', 'sog.', 'l.a.', 'u.n.', 'a.k.', 'p.i.', 'r.c.', 'mr.', 'u.s.', '.', 'z.q.', 'z.', 'ca.', 'mio.', 'mrd.']

In [20]:
problematic_tokens = []
%time train_data = build_graphs(train_data, spacy_de)
%time valid_data = build_graphs(valid_data, spacy_de)
%time test_data = build_graphs(test_data, spacy_de)
problematic_tokens = set(problematic_tokens)

torch.save(list(train_data), "data/ISWLT/train_data.pt")
torch.save(list(valid_data), "data/ISWLT/valid_data.pt")
torch.save(list(test_data), "data/ISWLT/test_data.pt")

100%|██████████| 196884/196884 [34:43<00:00, 94.50it/s] 
  1%|          | 10/993 [00:00<00:09, 99.30it/s]

CPU times: user 52min 1s, sys: 54.5 s, total: 52min 55s
Wall time: 34min 43s


100%|██████████| 993/993 [00:10<00:00, 92.81it/s] 
  1%|          | 11/1305 [00:00<00:12, 101.32it/s]

CPU times: user 17 s, sys: 423 ms, total: 17.4 s
Wall time: 10.7 s


100%|██████████| 1305/1305 [00:13<00:00, 94.70it/s] 


CPU times: user 21.3 s, sys: 447 ms, total: 21.7 s
Wall time: 13.8 s


In [21]:
# Among the tokens recorded as problematic, keep those ending in a period and
# propose a replacement with each '.' spelled out as 'Dot'.
# str.endswith is used instead of w[-1] so an empty token cannot raise IndexError.
ptokens = [w for w in problematic_tokens if w.endswith('.')]
replacement = [w.replace('.', 'Dot') for w in ptokens]
ptokens_map = dict(zip(ptokens, replacement))
# An empty dict here means no period-terminated token was recorded.
print(ptokens_map)

{}


In [22]:
# Reload the saved example lists and wrap them in torchtext Datasets using the
# (src, trg, grh) layout from data_fields.
# NOTE(review): torch.load unpickles the file — only load files produced locally.
train_data_ = Dataset(torch.load("data/ISWLT/train_data.pt"), data_fields)
valid_data_ = Dataset(torch.load("data/ISWLT/valid_data.pt"), data_fields)
test_data_ = Dataset(torch.load("data/ISWLT/test_data.pt"), data_fields)

In [23]:
# Round-trip check: each reloaded split must match its in-memory counterpart.
validate_data(train_data, train_data_)
validate_data(valid_data, valid_data_)
validate_data(test_data, test_data_)

datasets are the same!!
datasets are the same!!
datasets are the same!!


In [24]:
# Parse the first reloaded training sentence and render its dependency tree.
doc = spacy_de(" ".join(train_data_[0].src))
displacy.render(doc)

In [25]:
# Show the target sentence of the first reloaded example as one string.
" ".join(train_data_[0].trg)

"david gallo : this is bill lange . i 'm dave gallo ."

In [26]:
# Show the attributes (src, trg, grh) stored on the first reloaded example.
vars(train_data_[0])

{'src': ['david',
  'gallo',
  ':',
  'das',
  'ist',
  'bill',
  'lange',
  '.',
  'ich',
  'bin',
  'dave',
  'gallo',
  '.'],
 'trg': ['david',
  'gallo',
  ':',
  'this',
  'is',
  'bill',
  'lange',
  '.',
  'i',
  "'m",
  'dave',
  'gallo',
  '.'],
 'grh': tensor([[0.5000, 0.4082, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.4082, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.3333, 0.3333, 0.2887, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,

### How to Apply
---

In [27]:
# Seed every random number generator before building the iterators.
SEED = 11747
random.seed(SEED)  # Python's built-in RNG
np.random.seed(SEED)  # NumPy's global RNG
torch.manual_seed(SEED)  # PyTorch RNG
torch.cuda.manual_seed(SEED)  # PyTorch CUDA RNG
torch.backends.cudnn.deterministic = True  # request deterministic cuDNN behavior

In [28]:
# Load the spaCy pipelines that the tokenizer functions below read as globals.
spacy_de = spacy.load('de')
spacy_en = spacy.load('en')

def tokenize_de(text):
    """
    Split a German string into a list of token strings using the tokenizer
    of the `spacy_de` pipeline. Token order is preserved.
    """
    german_doc = spacy_de.tokenizer(text)
    return [token.text for token in german_doc]

def tokenize_en(text):
    """
    Split an English string into a list of token strings using the tokenizer
    of the `spacy_en` pipeline. Token order is preserved.
    """
    english_doc = spacy_en.tokenizer(text)
    return [token.text for token in english_doc]

In [29]:
# Source (German) and target (English) text fields: spaCy tokenization,
# lower-cased, wrapped in <sos>/<eos> markers.
SRC = Field(tokenize = tokenize_de, 
            init_token = '<sos>', 
            eos_token = '<eos>', 
            lower = True)
TGT = Field(tokenize = tokenize_en, 
            init_token = '<sos>', 
            eos_token = '<eos>', 
            lower = True)
# Raw field (no postprocessing) that carries the per-example graph tensor.
GRH = RawField(postprocessing=None)
# Field layout needed to rebuild Datasets from the saved example lists.
data_fields = [('src', SRC), ('trg', TGT), ('grh', GRH)]

In [30]:
# Load the preprocessed Multi30k example lists (graphs included) and wrap them
# in torchtext Datasets with the (src, trg, grh) layout.
# NOTE(review): torch.load unpickles the file — only load files produced locally.
train_data = Dataset(torch.load("data/Multi30k/train_data.pt"), data_fields)
valid_data = Dataset(torch.load("data/Multi30k/valid_data.pt"), data_fields)
test_data = Dataset(torch.load("data/Multi30k/test_data.pt"), data_fields)

In [31]:
BATCH_SIZE = 128
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# The datasets above were rebuilt with the generic Dataset(examples, fields)
# constructor, which does not define a sort key. BucketIterator needs one to
# order/group examples by length, so pass the source length explicitly.
train_iterator, valid_iterator, test_iterator = BucketIterator.splits(
    (train_data, valid_data, test_data),
    batch_size = BATCH_SIZE,
    sort_key = lambda example: len(example.src),
    device = device)