In [1]:
from src.edit_tagger import lcs, lcs_traceback, bi_directional_traceback, perform_edit, random_edit
import torchtext
from torchtext.data import Field, BucketIterator
from src.vocab_classes import Shared_Vocab
import torch
import torch.nn as nn
from src.BERT_style_modules import BERTStyleEncoder, BERTStyleDecoder
from src.exposed_transformer import Transformer
from torch.nn.modules.activation import MultiheadAttention
import math
import numpy as np
from src.trainers import Model_Trainer
import random

%load_ext autoreload
%autoreload 2

In [32]:
# Scratch check: multiply one random 3x3 matrix by the transpose of another.
a = torch.rand(3,3) # B x D
b = torch.rand(3,3).T # transposed, so this operand is D x B and torch.mm(a, b) is B x B
torch.mm(a,b)

tensor([[0.7083, 1.0442, 0.4933],
        [0.5611, 1.1759, 0.7063],
        [0.4835, 0.7634, 0.4113]])

In [36]:
import tensorflow as tf

In [39]:
tf_a = tf.constant([1.0,2,3,4])
tf.reduce_max(tf_a)

<tf.Tensor: id=6, shape=(), dtype=float32, numpy=4.0>

In [38]:
tf.maximum(0., tf_a)

<tf.Tensor: id=3, shape=(4,), dtype=float32, numpy=array([1., 2., 3., 4.], dtype=float32)>

In [41]:
a

tensor([[0.5794, 0.1694, 0.9146],
        [0.0839, 0.7236, 0.6458],
        [0.4473, 0.2983, 0.3775]])

In [42]:
torch.diagonal(a)

tensor([0.5794, 0.7236, 0.3775])

In [33]:
# 3x3 mask with -inf on the diagonal and zeros elsewhere. Built directly with
# torch.full so this cell does not depend on a `tensor` variable that is only
# defined in a later cell (which broke Restart & Run All with a NameError).
neg_matrix = torch.diag(torch.full((3,), float('-inf')))

In [40]:
# neg_matrix adds -inf to each diagonal entry of `a`; relu then maps those
# entries to 0, so the diagonal contributes 0 to the row-wise max below.
torch.max(torch.relu(a+neg_matrix), dim=-1)

torch.return_types.max(
values=tensor([0.9146, 0.6458, 0.4473]),
indices=tensor([2, 2, 0]))

In [27]:
torch.max(a, dim=-1, keepdim=True)

torch.return_types.max(
values=tensor([[0.9518],
        [0.9623],
        [0.9947]]),
indices=tensor([[1],
        [2],
        [0]]))

In [24]:
# `tensor` is a 0-dim tensor used only as the source for `new_full`, which
# creates a 3-element vector filled with -inf; torch.diag places that vector on
# the diagonal of a 3x3 matrix whose other entries are 0.
tensor = torch.ones(())
torch.diag(tensor.new_full((3,), float('-inf')))

tensor([[-inf, 0., 0.],
        [0., -inf, 0.],
        [0., 0., -inf]])

In [2]:
start  = "ABC"
target = "CBC"
bi_directional_traceback(start, target)

[('Insert', 'C', 0), ('Delete', 'A', 0)]

In [3]:
perform_edit("BBC",('Insert', 'C', 3))

['B', 'B', 'C', 'C']

In [4]:
random_edit("BBC",["A","B","C"])

['B', 'B', 'A', 'C']

In [5]:
def sample_random_edits(seq, edit_depth=2, num_samples=100, vocab_tokens=None):
    """Generate randomly perturbed variants of ``seq``.

    Every sample starts from a shallow copy of ``seq`` (``seq[:]``) and has
    ``random_edit`` applied to it ``edit_depth`` times in a row.

    Args:
        seq: Sequence to perturb.
        edit_depth: Number of consecutive ``random_edit`` calls per sample.
        num_samples: Number of samples to generate.
        vocab_tokens: Tokens passed to ``random_edit`` as its second argument.
            Defaults to ``["A", "B", "C", "D", "E"]``, the alphabet that used
            to be hard-coded here.

    Returns:
        A list with ``num_samples`` entries; duplicates are not removed.
    """
    if vocab_tokens is None:
        vocab_tokens = ("A", "B", "C", "D", "E")

    samples = []
    for _ in range(num_samples):
        edit_seq = seq[:]
        for _ in range(edit_depth):
            # Hand random_edit a fresh list each time, exactly as the original
            # literal did, so it can never alias the caller's token collection.
            edit_seq = random_edit(edit_seq, list(vocab_tokens))
        samples.append(edit_seq)
    return samples

In [6]:
def only_unique(list_data):
    """Remove duplicate sequences, keeping the first occurrence of each.

    The previous implementation round-tripped through a ``set``, so the order
    of the result was arbitrary and, for string tokens, could differ between
    interpreter runs. Using dict keys keeps the input order, which makes the
    derived dataset reproducible.

    Args:
        list_data: Iterable of sequences whose items are hashable.

    Returns:
        A list of lists, one per distinct sequence, in first-seen order.
    """
    return [list(key) for key in dict.fromkeys(map(tuple, list_data))]

In [7]:
def balance_actions(list_data):
    """Return a shuffled subset with equally many "Insert" and "Delete" records.

    Each record is expected to carry its edit at index 2, with the action name
    as the edit's first element. Records with any other action are dropped.
    The larger group is truncated to the size of the smaller one, keeping the
    earliest records of each group.
    """
    deletes, inserts = [], []
    for record in list_data:
        action = record[2][0]
        if action == "Delete":
            deletes.append(record)
        elif action == "Insert":
            inserts.append(record)

    keep = min(len(deletes), len(inserts))
    balanced = inserts[:keep] + deletes[:keep]
    random.shuffle(balanced)
    return balanced

In [9]:
# Build (start, target, edit) triples: perturb the target into many distinct
# starting sequences, list the edits bi_directional_traceback returns for each
# pair, then balance the resulting records between inserts and deletes.
target_seq = ["B","B","C"]
moving_starting_points = only_unique(sample_random_edits(target_seq, edit_depth=4, num_samples=1000))

edit_dataset = balance_actions([
    (start_seq, target_seq, edit)
    for start_seq in moving_starting_points
    for edit in bi_directional_traceback(start_seq, target_seq)
])

In [10]:
print(f"filtered edit dataset size: {len(edit_dataset)}")
edit_dataset[:10]

filtered edit dataset size: 444


[(['A', 'B', 'B', 'C', 'B'], ['B', 'B', 'C'], ('Delete', 'A', 0)),
 (['B', 'C', 'D', 'D'], ['B', 'B', 'C'], ('Insert', 'B', 0)),
 (['B', 'B', 'A', 'B'], ['B', 'B', 'C'], ('Insert', 'C', 2)),
 (['D', 'B', 'E'], ['B', 'B', 'C'], ('Insert', 'C', 2)),
 (['B', 'B'], ['B', 'B', 'C'], ('Insert', 'C', 2)),
 (['E', 'B', 'C', 'B', 'B'], ['B', 'B', 'C'], ('Delete', 'C', 2)),
 (['C', 'E', 'B', 'C', 'A'], ['B', 'B', 'C'], ('Insert', 'B', 2)),
 (['C', 'B', 'E', 'B', 'B'], ['B', 'B', 'C'], ('Insert', 'C', 4)),
 (['B'], ['B', 'B', 'C'], ('Insert', 'C', 1)),
 (['B', 'B', 'E', 'B', 'D'], ['B', 'B', 'C'], ('Delete', 'D', 4))]

In [11]:
class Simple_Edit_Vocab():
    """Small vocabulary mapping tokens and edit actions to integer ids.

    Token ids 0 and 1 are reserved for ``[PAD]`` and ``[EOS]``; the supplied
    tokens are numbered from 2 in the order given. Actions are fixed to
    ``Delete`` (0), ``Insert`` (1) and ``None`` (2).
    """

    def __init__(self, vocab_tokens, tokenizer_fn):
        """Build the lookup tables.

        Args:
            vocab_tokens: Iterable of token strings to register.
            tokenizer_fn: Callable turning a string into a list of tokens;
                used by ``encode_sent``.
        """
        self.tokenizer_fn = tokenizer_fn

        self.action_stoi = {"Delete":0, "Insert":1, "None":2}
        self.action_itos = ["Delete", "Insert", "None"]
        self.token_stoi = {"[PAD]":0, "[EOS]":1}
        self.token_itos = ["[PAD]","[EOS]"]
        for tok in vocab_tokens:
            # Skip repeats (and the reserved tokens) so token_itos and
            # token_stoi always describe the same id assignment.
            if tok in self.token_stoi:
                continue
            self.token_stoi[tok] = len(self.token_stoi)
            self.token_itos.append(tok)

        self.PAD = self.token_stoi["[PAD]"]
        self.EOS = self.token_stoi["[EOS]"]
        self.action_size = len(self.action_stoi)
        self.token_size = len(self.token_stoi)

    def encode_sent(self, string):
        """Tokenize ``string`` with ``tokenizer_fn`` and return its token ids."""
        words = self.tokenizer_fn(string)
        # Must go through self: the bare name `encode_list` is not defined.
        return self.encode_list(words)

    def encode_list(self, l):
        """Map a list of tokens to ids. Raises KeyError for unknown tokens."""
        return [self.token_stoi[word] for word in l]

    def encode_edit(self, edit, OOV_ids=None):
        """Encode an ``(action, token, position)`` edit as an id triple.

        A position of ``None`` is encoded as ``-1``. ``OOV_ids`` is accepted
        for interface compatibility and is not used.
        """
        action, tok, pos = edit
        if pos is None:
            pos = -1
        return (self.action_stoi[action], self.token_stoi[tok], pos)

    def decode_sent(self,ids):
        """Turn token ids back into a space-separated string."""
        # Ids index the id->token list; the class has no `stoi` attribute.
        return " ".join(self.token_itos[idx] for idx in ids)

    def decode_edit(self,edit):
        """Decode an ``(action_id, token_id, position)`` triple to names."""
        action_id, tok_id, pos = edit
        return (self.action_itos[action_id], self.token_itos[tok_id], pos)

In [12]:
vocab = Simple_Edit_Vocab(["A","B","C","D","E"],tokenizer_fn=str.split)

In [13]:
sample = edit_dataset[11]
print(sample)

start, target, edit = sample
vocab.encode_edit(edit)

(['A', 'B', 'C', 'D'], ['B', 'B', 'C'], ('Insert', 'B', 1))


(1, 3, 1)

In [40]:
class EditTransformer(nn.Module):
    """Encoder-decoder model that predicts one edit as (action, token, position).

    The source sequence is encoded, the target-side sequence is decoded against
    it, and a learned query vector (``ACT``) attends over the decoder output
    with a single attention head. The attended vector is projected to action
    and token scores; the attention weights themselves are returned as the
    per-position scores.
    """

    def __init__(self, action_vocab_size=2, token_vocab_size=1000, embed_dim=768, att_heads=8, layers=4, dim_feedforward=1024, dropout=0.1):
        """Create the sub-modules.

        Args:
            action_vocab_size: Number of action classes.
            token_vocab_size: Number of token classes (also the embedding vocabulary size).
            embed_dim: Model / embedding width.
            att_heads: Attention heads in the encoder and decoder stacks.
            layers: Number of encoder layers and of decoder layers.
            dim_feedforward: Hidden width of the feed-forward sub-layers.
            dropout: Dropout probability.
        """
        super(EditTransformer, self).__init__()
        self.action_vocab_size = action_vocab_size
        self.token_vocab_size = token_vocab_size
        self.embedding_size = embed_dim
        self.bert_encoder_model = BERTStyleEncoder(vocab_size=token_vocab_size, dim_model=embed_dim, nhead=att_heads,
                 num_encoder_layers=layers, d_feed=dim_feedforward, dropout=dropout)

        self.src_embedder = self.bert_encoder_model.embedder
        self.src_encoder = self.bert_encoder_model.encoder

        self.bert_decoder_model = BERTStyleDecoder(vocab_size=token_vocab_size, dim_model=embed_dim, nhead=att_heads,
                 num_encoder_layers=layers, d_feed=dim_feedforward, dropout=dropout)

        # NOTE(review): the target side reuses the *encoder's* embedder, so both
        # inputs share one embedding table. Presumably intentional; confirm.
        self.tgt_embedder = self.bert_encoder_model.embedder # nn.Embedding(vocab_size, embed_dim)
        self.tgt_decoder = self.bert_decoder_model.decoder

        # Learned query vector used to pool the decoder output into one edit.
        ACT = torch.tensor(np.random.rand(embed_dim)*1e-2,
                                requires_grad=True,
                                dtype=torch.float32)
        self.ACT = torch.nn.Parameter(ACT)

        self.action_attn = MultiheadAttention(embed_dim, 1, dropout=dropout)

        self.transformer = Transformer(d_model=embed_dim,
                                       nhead=att_heads,
                                       num_encoder_layers=layers,
                                       num_decoder_layers=layers,
                                       dim_feedforward=dim_feedforward,
                                       custom_encoder=self.src_encoder,
                                       custom_decoder=self.tgt_decoder)
        self.action_decoder = nn.Linear(embed_dim, action_vocab_size)
        self.token_decoder = nn.Linear(embed_dim, token_vocab_size)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # `parameters()` returns a one-shot generator. Storing the generator
        # itself meant the optimizer exhausted it and every later use (notably
        # gradient clipping) silently saw no parameters. Keep a real list.
        self.params = list(self.parameters())
        self.stats = {}
        self.init_weights()

    def init_train_params(self, vocab, lr=0.005, gamma=0.99):
        """Create the loss functions, SGD optimizer and StepLR scheduler.

        Token targets equal to ``vocab.PAD`` and position targets equal to -1
        are ignored by their respective losses.
        """
        self.action_criterion = nn.CrossEntropyLoss()
        self.token_criterion = nn.CrossEntropyLoss(ignore_index=vocab.PAD)
        self.position_criterion = nn.CrossEntropyLoss(ignore_index=-1)
        self.optimizer = torch.optim.SGD(self.parameters(), lr=lr)
        self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, 1.0, gamma=gamma)

    def init_weights(self):
        """Zero the output-head biases and draw their weights from U(-0.1, 0.1)."""
        initrange = 0.1

        for head in (self.action_decoder, self.token_decoder):
            head.bias.data.zero_()
            head.weight.data.uniform_(-initrange, initrange)

    def forward(self, src, tgt):
        """Score the next edit.

        Args:
            src: Token ids fed to the encoder.
            tgt: Token ids fed to the decoder.
            Both are indexed so that dimension 1 of the transformer output is
            the batch dimension (sequence-first layout as used in this notebook).

        Returns:
            ``(action_dist, token_dist, atts)``: action scores, token scores,
            and the attention weights of the ``ACT`` query over the decoder
            output positions.
        """
        scale = math.sqrt(self.embedding_size)
        src_emb = self.src_embedder(src) * scale
        tgt_emb = self.tgt_embedder(tgt) * scale

        output, _ = self.transformer(src_emb, tgt_emb)
        batch_size = output.shape[1]
        # One copy of the query per batch element: shape (1, batch, embed_dim).
        action_embed = self.ACT.repeat(1,batch_size,1)
        action_out, atts = self.action_attn(action_embed, output, output)
        action_dist = self.action_decoder(action_out).squeeze(0)
        token_dist = self.token_decoder(action_out).squeeze(0)
        # NOTE(review): `atts` are attention weights, not raw logits, yet they
        # are later fed to CrossEntropyLoss, and padded positions are not
        # masked in this attention call. Verify this is intended.
        atts = atts.squeeze(1)

        return action_dist, token_dist, atts

    def train_step(self, batch):
        """Run one optimization step on ``batch`` and return the summed loss.

        ``batch.target`` is the encoder input and ``batch.start`` the decoder
        input; ``init_train_params`` must have been called first.
        """
        encoder_input = batch.target
        decoder_input = batch.start
        action_id = batch.action_id
        token_id = batch.token_id
        action_pos = batch.action_pos
        batch_max_seq_len = decoder_input.shape[0]

        self.optimizer.zero_grad()
        actions, tokens, positions = self(encoder_input, decoder_input)

        action_loss = self.action_criterion(actions.view(-1, self.action_vocab_size), action_id.view(-1))
        token_loss = self.token_criterion(tokens.view(-1, self.token_vocab_size), token_id.view(-1))
        position_loss = self.position_criterion(positions.view(-1, batch_max_seq_len), action_pos.view(-1))
        loss = action_loss + token_loss + position_loss
        loss.backward()
        # Ask the module for its parameters each time so clipping always sees
        # them (an exhausted generator previously made this a no-op).
        torch.nn.utils.clip_grad_norm_(self.parameters(), 0.5)
        self.optimizer.step()
        return loss

    def data2dataset(self, data, vocab):
        """Convert ``(start, target, edit)`` records into a torchtext Dataset.

        ``[EOS]`` is appended to every start sequence. For "Delete" edits the
        token target is replaced by ``vocab.PAD`` so the token loss ignores it.
        """
        TEXT_FIELD = Field(sequential=True, use_vocab=False, pad_token=vocab.PAD)
        POS_FIELD = Field(sequential=True, use_vocab=False, pad_token=-1)

        # Same mapping for every example, so build it once outside the loop.
        example_fields = {"start":("start",TEXT_FIELD),
                          "target":("target",TEXT_FIELD),
                          "action_id":("action_id", TEXT_FIELD),
                          "token_id":("token_id", TEXT_FIELD),
                          "action_pos":("action_pos", POS_FIELD)}

        examples = []

        for (start, target, edit) in data:
            start_ids = vocab.encode_list(start) + [vocab.EOS]
            target_ids = vocab.encode_list(target)
            action_id, token_id, action_pos = vocab.encode_edit(edit)
            if action_id == vocab.action_stoi["Delete"]:
                token_id = vocab.PAD

            example = torchtext.data.Example.fromdict({"start":start_ids,
                                                       "target":target_ids,
                                                       "action_id":[action_id],
                                                       "token_id":[token_id],
                                                       "action_pos":[action_pos]},
                                                      fields=example_fields)
            examples.append(example)
        fields = {"start":TEXT_FIELD, "target":TEXT_FIELD, "action_id":TEXT_FIELD, "token_id":TEXT_FIELD, "action_pos":POS_FIELD}
        dataset = torchtext.data.Dataset(examples,fields=fields)
        return dataset

In [41]:
# Move the model to the device it selected in its constructor (CUDA when
# available, otherwise CPU) instead of hard-coding "cuda", so the notebook
# also runs on CPU-only machines.
model = EditTransformer(action_vocab_size=vocab.action_size, token_vocab_size=vocab.token_size)
model = model.to(model.device)
model.init_train_params(vocab)

In [42]:
dataset = model.data2dataset(edit_dataset, vocab)

In [43]:
# Put batches on whatever device the model's parameters live on rather than a
# hard-coded "cuda". repeat=True makes this an endless iterator.
train_iterator = BucketIterator(
    dataset,
    batch_size = 16,
    repeat=True,
    shuffle=True,
    device = next(model.parameters()).device)

In [44]:
# Print the fields of a single batch. train_iterator was built with
# repeat=True, so the loop has to break after the first batch.
for batch in train_iterator:
    print(f"start: {batch.start}")
    print(f"target: {batch.target}")
    print(f"action_id: {batch.action_id}")
    print(f"token_id: {batch.token_id}")
    print(f"action_pos: {batch.action_pos}")
    break

start: tensor([[6, 3, 4, 3, 4, 3, 3, 4, 3, 2, 4, 4, 3, 4, 3, 4],
        [3, 2, 3, 3, 3, 3, 4, 3, 4, 3, 4, 5, 6, 2, 4, 5],
        [6, 4, 4, 5, 3, 3, 5, 3, 5, 5, 3, 4, 6, 5, 5, 3],
        [1, 4, 5, 4, 3, 4, 5, 2, 4, 3, 4, 3, 4, 1, 5, 3],
        [0, 3, 1, 6, 1, 2, 1, 1, 1, 4, 5, 3, 4, 0, 4, 4],
        [0, 1, 0, 4, 0, 2, 0, 0, 0, 1, 1, 5, 1, 0, 1, 6],
        [0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 4, 0, 0, 0, 3],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1]], device='cuda:0')
target: tensor([[3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3],
        [3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3],
        [4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4]], device='cuda:0')
action_id: tensor([[1, 0, 1, 0, 1, 0, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0]], device='cuda:0')
token_id: tensor([[4, 0, 3, 0, 4, 0, 3, 4, 3, 0, 3, 0, 0, 0, 0, 0]], device='cuda:0')
action_pos: tensor([[2, 2, 1, 2, 3, 2, 0, 3, 2, 2, 2, 5, 3, 2, 1, 6]], device='cuda:0')


In [45]:
# Smoke-test the forward pass on one batch.
# Note: this rebinds the notebook-level names `start` and `target` to tensors.
it = iter(train_iterator)
batch = next(it)
start = batch.start
target = batch.target
print(start, target)
# forward(src, tgt): the target sequence goes to the encoder and the start
# sequence to the decoder, the same assignment train_step uses.
outputs = model(target, start)
# outputs[0] are the action scores, outputs[1] the token scores.
print(outputs[0].shape, outputs[1].shape)

tensor([[6, 3, 6, 2, 2, 1, 4, 3, 3, 3, 5, 3, 6, 3, 3, 6],
        [4, 3, 3, 3, 3, 0, 6, 5, 2, 3, 3, 2, 3, 5, 2, 3],
        [1, 3, 2, 4, 2, 0, 3, 1, 3, 2, 6, 2, 4, 1, 3, 2],
        [0, 4, 1, 4, 4, 0, 3, 0, 1, 3, 3, 2, 5, 0, 6, 2],
        [0, 1, 0, 2, 4, 0, 6, 0, 0, 6, 1, 3, 2, 0, 3, 4],
        [0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 1, 0, 4, 1],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0]], device='cuda:0') tensor([[3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3],
        [3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3],
        [4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4]], device='cuda:0')
torch.Size([16, 3]) torch.Size([16, 7])


In [46]:
model.train_step(batch)

tensor(5.7881, device='cuda:0', grad_fn=<AddBackward0>)

In [47]:
trainer = Model_Trainer(model, vocab)

'output_dir' not defined, training and model outputs won't be saved.


In [63]:
train_logs = trainer.train(model, train_iterator, 10000, save_interval=10000000)

HBox(children=(FloatProgress(value=0.0, max=10000.0), HTML(value='')))

KeyboardInterrupt: 

In [64]:
def predict_edit(start, target):
    """Predict the next edit that moves ``start`` towards ``target``.

    Uses the notebook-level ``model`` and ``vocab``. ``[EOS]`` is appended to
    ``start`` before encoding, mirroring how training examples are built.

    Args:
        start: List of tokens for the current sequence.
        target: List of tokens for the desired sequence.

    Returns:
        ``(action_name, token, position)`` taken from the arg-max of each of
        the model's three outputs. The token loss ignores "Delete" examples
        during training, so the token of a predicted "Delete" is not
        meaningful.
    """
    # Follow the model's actual device instead of hard-coding "cuda".
    device = next(model.parameters()).device
    start_ids = torch.tensor(vocab.encode_list(start) + [vocab.EOS], dtype=torch.long, device=device).unsqueeze(1)
    target_ids = torch.tensor(vocab.encode_list(target), dtype=torch.long, device=device).unsqueeze(1)

    model.eval()
    try:
        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            actions, tokens, positions = model(target_ids, start_ids)
    finally:
        # Restore training mode even if the forward pass raises.
        model.train()

    best_act = torch.argmax(actions, dim=-1).cpu().tolist()[0]
    best_tok = torch.argmax(tokens, dim=-1).cpu().tolist()[0]
    best_pos = torch.argmax(positions, dim=-1).cpu().tolist()[0]
    return vocab.decode_edit((best_act, best_tok, best_pos))

In [65]:
# Compare the model's single predicted edit with the edits that
# bi_directional_traceback returns for the same (start, target) pair.
start  = list("AAC")
target = list("BBC")
edits = bi_directional_traceback(start, target)
print(start, target)
print(f"Possible actions: {edits}")

predict_edit(start, target)

['A', 'A', 'C'] ['B', 'B', 'C']
Possible actions: [('Insert', 'B', 1), ('Delete', 'A', 0), ('Delete', 'A', 1)]


('Insert', 'B', 2)

In [66]:
def multi_edit_animation(start, target, rounds=15, interval=0.2):
    """Repeatedly predict and apply edits, printing every intermediate step.

    Runs exactly ``rounds`` predict/apply cycles; there is no early stop when
    the sequence already equals ``target``. ``interval`` is accepted but not
    used by the body.

    Returns:
        The sequence after the last applied edit.
    """
    print(f"TARGET: {target}")
    print(f"START : {start}")
    current = start
    step = 0
    while step < rounds:
        predicted = predict_edit(current, target)
        print(f"Predicted EDIT: {predicted}")
        current = perform_edit(current, predicted)
        print(f"STEP {step}: {current}")
        print()
        step += 1
    return current

In [67]:
multi_edit_animation(start, target)

TARGET: ['B', 'B', 'C']
START : ['A', 'A', 'C']
Predicted EDIT: ('Insert', 'B', 2)
STEP 0: ['A', 'A', 'B', 'C']

Predicted EDIT: ('Insert', 'B', 3)
STEP 1: ['A', 'A', 'B', 'B', 'C']

Predicted EDIT: ('Delete', 'C', 1)
STEP 2: ['A', 'B', 'B', 'C']

Predicted EDIT: ('Delete', 'C', 3)
STEP 3: ['A', 'B', 'B']

Predicted EDIT: ('Insert', 'C', 0)
STEP 4: ['C', 'A', 'B', 'B']

Predicted EDIT: ('Insert', 'C', 1)
STEP 5: ['C', 'C', 'A', 'B', 'B']

Predicted EDIT: ('Insert', 'C', 2)
STEP 6: ['C', 'C', 'C', 'A', 'B', 'B']

Predicted EDIT: ('Delete', 'B', 3)
STEP 7: ['C', 'C', 'C', 'B', 'B']

Predicted EDIT: ('Delete', 'C', 2)
STEP 8: ['C', 'C', 'B', 'B']

Predicted EDIT: ('Insert', 'C', 1)
STEP 9: ['C', 'C', 'C', 'B', 'B']

Predicted EDIT: ('Delete', 'C', 2)
STEP 10: ['C', 'C', 'B', 'B']

Predicted EDIT: ('Insert', 'C', 1)
STEP 11: ['C', 'C', 'C', 'B', 'B']

Predicted EDIT: ('Delete', 'C', 2)
STEP 12: ['C', 'C', 'B', 'B']

Predicted EDIT: ('Insert', 'C', 1)
STEP 13: ['C', 'C', 'C', 'B', 'B']

Pre

['C', 'C', 'B', 'B']