In [1]:
import numpy as np
from pytorch_transformers.tokenization_distilbert import DistilBertTokenizer
from pytorch_transformers.modeling_distilbert import DistilBertModel
import torch
import torch.nn
from torch import optim
import random

from torch.utils.data import Dataset
import json
from torch import nn

tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
special_tokens_dict = {'additional_special_tokens': ['<PLH>']}
tokenizer.add_special_tokens(special_tokens_dict)
encoder = DistilBertModel.from_pretrained('distilbert-base-uncased')
encoder.resize_token_embeddings(len(tokenizer))  

noop = "N"
sub = "S"
insert = "I"
delete = "D"

def ld(s1, s2, subcost=1, delcost=1, inscost=1):
    operations = [["" for j in range(len(s2) + 1)] for i in range(len(s1) + 1)]
    
    matrix = np.zeros((len(s1)+1, len(s2)+1))
    
    for j in range(len(s2) + 1):
        matrix[0,j] = j
        operations[0][j] = insert
        for i in range(len(s1) + 1):
            matrix[i,0] = i
            operations[i][0] = insert
            if i > 0 and j > 0:
                subCost = matrix[i-1, j-1] if s1[i-1] == s2[j-1] else matrix[i-1, j-1] + subcost
                insertCost = matrix[i, j-1] + inscost
                deleteCost = matrix[i-1, j] + delcost
                minCost = min(subCost, insertCost, deleteCost)
                matrix[i,j] = minCost
                if minCost == 0:
                    operations[i][j] = noop
                elif minCost == deleteCost:
                    operations[i][j] = delete
                elif minCost == insertCost:
                    operations[i][j] = insert
                elif minCost == subCost:
                    operations[i][j] = sub
    i = len(s1)
    j = len(s2)
    history = []
    while j > 0 or i > 0:        
        if delcost != np.inf:
            if j == 0:
                history.append(delete)
                i -= 1
                continue
            if matrix[i-1][j-1] < matrix[i-1,j]:
                history.append(noop)
                i -= 1
                j -= 1
            else:
                history.append(delete)
                i -= 1
        elif inscost != np.inf:
            if j == 0:
                history.append(noop)
                i -= 1
                continue
            if matrix[i-1][j-1] < matrix[i,j-1]:
                history.append(noop)
                i -= 1
                j -= 1
            else:
                history.append((insert,s2[j-1]))
                #history.append(insert)
                j -= 1
    history.reverse()
    return matrix, matrix[len(s1),len(s2)], history
    


In [2]:
class QADataset(Dataset):
    def __init__(self, path):
        with open(path, "r") as infile:
            self.data = json.load(infile)

    def __len__(self):
        return len(self.data["Data"])

    def sample(self):
        return self.data["Data"][random.randint(0, len(self.data["Data"]))]["Question"]
    
PLH = "<PLH>"
TS = "<s>"
TE = "</s>"

class PlaceholderClassifier(nn.Module):
    def __init__(self, hsz, max_placeholders=10):
        super().__init__()
        self.dense = nn.Linear(
            hsz, max_placeholders,
        )
        self.activation = nn.ReLU()

    def forward(self, input, hidden=None):
        return self.activation(self.dense(input))
    
class TokenClassifier(nn.Module):    
    def __init__(self, hsz, vsz, max_seq_len):
        super().__init__()
        
        self.dense = nn.Linear(
            hsz, vsz
        )
        self.activation = nn.ReLU()
        
    def forward(self, input, hidden=None):
        return self.activation(self.dense(input))
        
class DeletionClassifier(nn.Module):    
    def __init__(self, hsz):
        super().__init__()
        self.dense = nn.Linear(
            hsz, 2,
        )
        self.activation = nn.ReLU()
    
    def forward(self, input, hidden=None):
        return self.activation(self.dense(input))
dataset = QADataset("/virtualmachines/data/trivia_qa/qa/wikipedia-train.json")
dataset.sample()
plh = tokenizer.encode("<PLH>")[0]
print(plh)
print(len(tokenizer))

30522
30523


In [3]:
def deleted_indices_to_placeholders(deleted, max_placeholders):
    print("max_placeholders")
    print(max_placeholders)
    i = 0
    placeholders = []
    num_deleted = 0
    while True:
        if i == len(deleted):
            if num_deleted > 0:
                placeholders.append(min(num_deleted, max_placeholders))
            break
        if deleted[i] == 1:
            num_deleted += 1
        else:
            if num_deleted > 0:
                placeholders.append(min(num_deleted, max_placeholders))
            placeholders.append(0)
            num_deleted = 0
        i += 1
    while len(placeholders) < len(deleted):
        placeholders.append(0)
    return torch.unsqueeze(torch.LongTensor(placeholders), 0)

def apply_deletions(y, deleted, pad=False):
    batch_post_deletion = []
    batch_post_deletion_with_placeholders = []
    for batch_index in range(len(y)):
        j = 0
        post_deletion = []
        post_deletion_with_placeholders = []
        for i in range(len(y[batch_index])):
            if deleted[batch_index][i] == 0:
                post_deletion.append(y[batch_index][i])
                post_deletion_with_placeholders.append(y[batch_index][i])
                j += 1
            else:
                post_deletion_with_placeholders.append(plh)
        batch_post_deletion.append(post_deletion)
        batch_post_deletion_with_placeholders.append(post_deletion_with_placeholders)
    return batch_post_deletion, batch_post_deletion_with_placeholders

def delete_random(y, p=0.25, pad=True):
    """Deletes token(s) randomly from the passed (tokenized) string with probability p
    Accepts:
    - a list of token sequences bsz * pad_length
    Returns tensors of:
    - the token sequence post-deletion
    - the token sequence post-deletion, with PLH inserted at each deleted position
    - the number of placeholders inserted at each post-deletion index
    """
    batch_deletions = []
    for i in range(len(y)):
        deletions = []
        for i in range(len(y[i])):
            if random.random() < p:
                deletions.append(1)
            else:
                deletions.append(0)
        batch_deletions.append(deletions)
    post_deletion, post_deletion_with_placeholders = apply_deletions(y, batch_deletions)

    return post_deletion, post_deletion_with_placeholders, batch_deletions#,placeholders

delete_random([tokenizer.encode("Hello, my name is "), tokenizer.encode("What is your name?")])

([[7592, 1010, 2026, 2171, 2003], [2054, 2003, 2115]],
 [[7592, 1010, 2026, 2171, 2003], [2054, 2003, 2115, 30522, 30522]],
 [[0, 0, 0, 0, 0], [0, 0, 0, 1, 1]])

In [4]:
def delete_minimal(y, y_ground, pad=True, max_placeholders=0):
    """Apply the sequence of deletions from y that give the smallest possible Levenshtein distance from y_ground
    Returns tensors of:
    - the token sequence post-deletion
    - the token sequence post-deletion, with PLH inserted at each deleted position
    - the number of placeholders inserted at each post-deletion index
    """
    batched = torch.zeros(len(y),)
    
    # if y_ground is longer than y, there is no sequence of deletes with a shorter distance
    if len(y_ground) > len(y): 
        return y
    
    # calculate LD directly against tokens
    deleted = []
    matrix, dist, edits = ld(y.numpy(),y_ground.numpy(), subcost=np.inf, inscost=np.inf)
    num_deleted = 0
    for i in range(len(edits)):
        if edits[i] == "D":
            deleted.append(1)
            num_deleted += 1
        else:
            deleted.append(0)
    placeholders = deleted_indices_to_placeholders(deleted,max_placeholders=max_placeholders)
    post_deletion, post_deletion_with_placeholders = deleted_boolean_to_tensors(y, deleted, num_deleted)    
    return post_deletion, post_deletion_with_placeholders, torch.LongTensor([deleted])#,placeholders

def pad(items, pad_len=0, pad_token=0):
    if pad_len == 0:
        pad_len = max([len(i) for i in items])
    for i in items:
        while len(i) < pad_len:
            i.append(pad_token)
    return items
    
def insert_minimal(y, y_ground,max_placeholders):
    """Apply the sequence of insertions to y resulting in the smallest possible Levenshtein distance from y_ground
    Accepts tensor of:
    - bsz * max_seq_len
    Returns tensors of:
    - size (n+1), where the value at position 0 <= i < n represents the number of PLH tags inserted at that position
        - n is the number of tokens in y
    - size(k) containing the indices of each inserted token, where k is the total number of tokens added
    - size(n+k) containing the tokens of the entire post-insertion sequence
    """

    batch_placeholders = []
    batch_inserted = []
    batch_new = []
    for batch_idx in range(len(y)):
        matrix, dist, edits = ld(y[batch_idx],y_ground[batch_idx], subcost=np.inf, delcost=np.inf)
        inserted = 0
        y_placeholders = []
        y_inserted = []
        y_new = []
        y_idx = 0
        i = 0
        while i < len(edits):
            if edits[i][0] == "I":
                accum = 0
                while i < len(edits) and edits[i][0] == "I":
                    y_inserted.append(edits[i][1])
                    y_new.append(edits[i][1])
                    accum += 1
                    i += 1
                y_placeholders.append(min(accum,max_placeholders-1))
            else:
                y_placeholders.append(0)
                y_new.append(y[batch_idx][y_idx])
                i += 1
                y_idx += 1
        batch_placeholders.append(y_placeholders)
        batch_inserted.append(y_inserted)
        batch_new.append(y_new)
    batch_placeholders = pad(batch_placeholders, pad_token=tokenizer.pad_token_id)
    #batch_inserted = pad(batch_inserted)
    batch_new = pad(batch_new, pad_token=tokenizer.pad_token_id)
    
    return batch_placeholders, batch_inserted, torch.LongTensor(batch_new)

y1 = tokenizer.encode("My name is Nick, what is your name")
y2 = tokenizer.encode("My name is Nick")
y3 = tokenizer.encode("What's the dog bro you say now")
y4 = tokenizer.encode("What's the")

#delete_minimal(torch.LongTensor([y1]), torch.LongTensor([y2]))
#insert_minimal([y2,y4], [y1,y3])
#y = torch.LongTensor([tokenizer.encode("Hi dude My name is Nick what name")])
#y_ground = torch.LongTensor([tokenizer.encode("My name is Nick, what is your name")])

#placeholders, inserted, new = insert_minimal(y, y_ground)


In [None]:
class DatasetSampler():
    def __init__(self, tokenizer, encoder, alpha=0, beta=0, max_placeholders=2):
        self.dataset = QADataset("/virtualmachines/data/trivia_qa/qa/wikipedia-train.json")
        self.tokenizer = tokenizer
        self.encoder = encoder
        self.alpha = alpha
        self.beta = beta
        self.max_placeholders = max_placeholders
            
    def encode_and_pad(self, string, pad_length):
        string = self.tokenizer.encode(string)
        while len(string) < pad_length:
            string.append(self.tokenizer.pad_token_id)
        return torch.LongTensor([string])
    
    '''
        Encode the passed strings and pad to the specified length
    '''
    def encode_and_pad_batch(self, strings):
        encoded = [self.tokenizer.encode(s) for s in strings]
        pad_length = max([len(s) for s in encoded])
        padded = torch.zeros(len(strings), pad_length, dtype=torch.int64) # because self.tokenizer.pad_token_id == 0
        for i in range(len(encoded)):
            for j in range(len(encoded[i])):
                padded[i,j] = encoded[i][j]
        return padded
        
    def sample(self, bs=10, pad_to=30, t_classifier=None):
        '''Sample an observation and return tensor tuples:
        1) (a) the observation perturbed with deletions, for input to the placeholder classifier
           (b) the indices of the deleted tokens (i.e. where placeholders should be inserted, for the placeholder classifier loss) 
        2) (a) 1(a), but with PLH replacing each deleted token. For input to the token insertion classifier 
           (b) the number of placeholders tokens to insert at each index in 1(a) (for the placeholder classifier loss)
        3) (a) the observation perturbed with insertion (tokens), for input to the token deletion classifier 
           (b) 2(a) (for the token insertion classifier loss)
        '''
        u = random.random()
        v = random.random()
        # sample a pair of (untokenized) strings 
        y_ground = [self.dataset.sample() for i in range(bs)]
        y0 = []
        
        # first, pad/encode y0 and y_ground
        y_ground = [self.tokenizer.encode(s) for s in y_ground]
        
        # randomly choose between LD-minimal deletion and random deletion
        # Returns tensors of:
        # - the token sequence post-deletion
        # - the token sequence post-deletion, with PLH inserted at each deleted position
        # - whether a given index was deleted or not
        if len(y0) == 0:
            y_ins, y_ins_with_placeholders, y_ins_p = delete_random(y_ground)
        elif u >= self.alpha:
            y_ins, y_ins_with_placeholders, y_ins_p = delete_random(y0)
        else:
            y_ins, y_ins_with_placeholders, y_ins_p = delete_minimal(y0, y_ground)

        # Returns tensors of:
        # - (1, n+1) - the number of PLH tags at each index
        # - (1, k) - the inserted tokens
        # - (1, n+k) - the entire post-insertion sequence
        # where n is the length of the original sequence and k is the number of tokens added 
        y_placeholders, y_inserted, y_ins_prime = insert_minimal(y_ins, y_ground, self.max_placeholders)

        # input to the deletion classifier will be randomly chosen between:
        # - the input to the placeholder classifier (i.e. the deleted input)
        # - the output from applying the token classifier to y_placeholders
        if v >= self.alpha and len(y0) != 0:
            y_del = y_ins
        else:
            with torch.no_grad():
                enc = self.encoder(torch.LongTensor(pad(y_ins_with_placeholders, pad_token=0)))[0]
                logits = t_classifier(enc)
                y_del = torch.argmax(logits,dim=2)
                y_del = torch.LongTensor(y_del)
                
        y_ins_pad_len = max(max([len(i) for i in y_placeholders]), max([len(i) for i in y_ins]), max([len(i) for i in y_ins_prime]))
        y_ins = torch.LongTensor(pad(y_ins, pad_len=y_ins_pad_len))
        y_placeholders = torch.LongTensor(pad(y_placeholders, pad_len=y_ins_pad_len))
        
        y_ins_p = torch.LongTensor(pad(y_ins_p, pad_token=0))
        
        y_del = torch.LongTensor(pad(y_del, pad_len=y_ins_p.size(1), pad_token=0))
        
        y_ins_prime = torch.LongTensor(pad(y_ins_prime, pad_token=0, pad_len=y_ins_pad_len))
        
        #print("y_placeholders")
        #print(y_placeholders.size())
        #print("y_ins_prim")
        #print(y_ins_prime.size())
        #print("y_ins")
        #print(y_ins.size())
        
        assert y_ins_prime.size(1) == y_ins.size(1)
        
        return ((y_del, self.encoder(y_del)[0], y_ins_p), 
               (y_ins, self.encoder(y_ins)[0], y_placeholders), 
               (y_ins_prime, self.encoder(y_ins_prime)[0], y_ins), y_ground)
                
class Model():
    def __init__(self, sampler, vocab_size, hsz=768, lr=0.0001):
        super().__init__()
        self.sampler = sampler
        self.p_classifier = PlaceholderClassifier(hsz, max_placeholders=sampler.max_placeholders)
        self.t_classifier = TokenClassifier(hsz, vocab_size, 20)
        self.d_classifier = DeletionClassifier(hsz)        
        self.alpha = 0.5
        self.beta = 0.5
        self.p_loss = nn.CrossEntropyLoss()
        self.t_loss = nn.CrossEntropyLoss()
        self.d_loss = nn.CrossEntropyLoss()
        
        self.optims = {
            'p_classifier': optim.SGD(self.p_classifier.parameters(), lr=lr),
            't_classifier': optim.SGD(self.t_classifier.parameters(), lr=lr),
            'd_classifier': optim.SGD(self.d_classifier.parameters(), lr=lr),
        }
        
        self.step = 0
        self.loss = 0
    
    def zero_grad(self):
        for optimizer in self.optims.values():
            optimizer.zero_grad()

    def update_params(self):
        for optimizer in self.optims.values():
            optimizer.step()
            
            
    def decode_step(self):
        with torch.no_grad():
            self.p_classifier.eval()
            self.t_classifier.eval()
            self.d_classifier.eval()

            (y_del,y_del_enc,_), (y_ins,_,_), (y_ins_prime, _,_), y_ground = self.sampler.sample(bs=1, t_classifier=self.t_classifier)
            print("y_ground")
            print(tokenizer.decode(y_ground[0]))
            print("y_del")
            print(tokenizer.decode(y_del[0].tolist()))
            print("y_ins")
            print(tokenizer.decode(y_ins[0].tolist()))

            y_ground = y_ins_prime
            y_last = None

            step = 0
            max_steps = 10

            while True:
                print("Decode step %d" % step)
                step += 1
                if step > max_steps:
                    y_ground = y_last
                    break
                if y_last is not None:
                    if (y_last.size() == y_ground.size() and torch.all(y_last == y_ground)):
                        break         
                    y_ground = y_last
                
                # deletion first
                preds_deletes = self.d_classifier(y_del_enc).cuda()
                if preds_deletes.size(1) == 0:
                    continue

                deletions = torch.argmax(preds_deletes,2)
                
                deleted = [y_ground[0,i].item() for i in range(deletions.size(1)) if deletions[0,i] == 0]

                y_ground = tokenizer.decode(deleted)

                # then placeholder
                y_ground = encoder(torch.LongTensor([deleted]))[0]

                preds_placeholders = self.p_classifier(y_ground).cuda()
                if preds_placeholders.size(1) == 0:
                    y_ground = torch.LongTensor([deleted])
                    continue
                placeholders = torch.argmax(preds_placeholders,2)

                reconstructed = []
                for i in range(placeholders.size(1)):
                    for j in range(placeholders[0,i].item()):
                        reconstructed.append(plh)
                    reconstructed.append(deleted[i])

                y_ground = torch.LongTensor([reconstructed])

                # then inserts
                inserts = self.t_classifier(encoder(torch.LongTensor([reconstructed]))[0]).cuda()
                inserts = torch.argmax(inserts, 2)
                j = 0

                for i in range(len(reconstructed)):
                    if reconstructed[i] == plh:
                        if i < len(reconstructed):
                            reconstructed[i] = inserts[0,j].item()
                        else:
                            reconstructed.append(inserts[0,j].item())
                        j += 1

                y_last = torch.LongTensor([reconstructed])
            if y_last is not None:
                print(y_last[0].tolist())
                print(tokenizer.decode(y_last[0].tolist()))
            else:
                print("y_last none")
            
    def train_step(self):
        loss = 0
        self.zero_grad()
        self.p_classifier.train()
        self.t_classifier.train()
        self.d_classifier.train()
        
        (y_del,y_del_enc,y_del_out), (y_ins,y_ins_enc,y_ins_out), (y_ins_prime, y_ins_prime_enc, y_ins_prime_out), _ = self.sampler.sample(bs=1, t_classifier=self.t_classifier)
        
        preds_deletes = self.d_classifier(y_del_enc)
        preds_placeholders = self.p_classifier(y_ins_enc)
        preds_inserts = self.t_classifier(y_ins_prime_enc)
        
        loss = self.d_loss(torch.transpose(preds_deletes, 1, 2), y_del_out)
        loss += self.p_loss(torch.transpose(preds_placeholders, 1, 2), y_ins_out)
        loss += self.t_loss(torch.transpose(preds_inserts,1,2), y_ins_prime_out)
        self.step += 1
        self.loss += loss
        if self.step % 10 == 0:
            print(self.step)
        if self.step % 50 == 0:
            print(self.step)
            print(self.loss / 50)
            self.loss = 0
            self.decode_step()
        
        loss.backward()
        self.update_params()

sampler = DatasetSampler(tokenizer, encoder, max_placeholders=5)

model = Model(sampler, len(tokenizer))
for i in range(10000):
    model.train_step()
#model.decode_step()

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
50
tensor(12.6091, grad_fn=<DivBackward0>)
y_ground
what was adopted as the official motto of the united states in 1956?
y_del
##virեberriesե [unused168]aneous walk regattaaneousφ regatta special topicsե
y_ins
what was as the official motto of united states in [PAD] [PAD] [PAD] [PAD]
Decode step 0
torch.Size([1, 14, 2])
tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], device='cuda:0')
tensor([[ 2054,  2001,  4233,  2004,  1996,  2880, 12652,  1997,  1996,  2142,
          2163,  1999,  3838,  1029]])
torch.Size([1, 14])
Decode step 1
torch.Size([1, 14, 2])
tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], device='cuda:0')
tensor([[ 3646, 17538,  2054,  3646,  2001, 17538, 17538,  4233, 17538, 17538,
          2004, 17538, 17538,  3646, 17538,  1996, 17538, 17538, 17538, 17538,
          2880, 29726, 17538, 12652, 17538,  1997, 17538, 17538,

Decode step 7
torch.Size([1, 12, 2])
tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], device='cuda:0')
tensor([[12966, 12966,   585, 12966, 12966,   585, 12966, 12966, 17538, 12966,
         12966,   585, 12966, 12966,   585, 12966,   585, 17538, 14540, 14540,
           585,   585, 14540,   585, 14540, 14540, 14540, 28782, 12966,   585,
           585, 12966, 12966,   585,   585, 12966, 17538]])
torch.Size([1, 37])
Decode step 8
torch.Size([1, 12, 2])
tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], device='cuda:0')
tensor([[28782, 23003, 12966, 23003, 23003, 12966, 23003, 23003,   585, 17538,
         17780, 12966,   585,   585, 12966,   585,   585,   585, 17538, 28782,
         12966, 17780, 17780, 12966, 17780, 17780, 17538, 17780, 17780, 12966,
         17780, 17780, 12966, 17780, 17780,   585]])
torch.Size([1, 36])
Decode step 9
torch.Size([1, 12, 2])
tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], device='cuda:0')
tensor([[12966, 17538, 28782, 17538, 17538, 23003, 17538, 17538, 12966

Decode step 9
torch.Size([1, 24, 2])
tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
       device='cuda:0')
tensor([[3646, 3646, 3646, 3646, 3646, 3646, 3646, 3646, 3646, 3646, 3646, 3646,
         3646, 3646, 3646, 3646, 3646, 3646, 3646, 3646, 3646, 3646, 3646, 3646,
         3646, 3646, 3646, 3646, 3646, 3646, 3646, 3646, 3646, 3646, 3646, 3646,
         3646, 3646, 3646, 3646, 3646, 3646, 3646, 3646, 3646, 3646, 3646, 3646,
         3646, 3646, 3646, 3646, 3646, 3646, 3646, 3646, 3646, 3646, 3646, 3646,
         3646, 3646, 3646, 3646, 3646, 3646, 3646, 3646, 3646, 3646, 3646, 3646]])
torch.Size([1, 72])
Decode step 10
[3646, 3646, 3646, 3646, 3646, 3646, 3646, 3646, 3646, 3646, 3646, 3646, 3646, 3646, 3646, 3646, 3646, 3646, 3646, 3646, 3646, 3646, 3646, 3646, 3646, 3646, 3646, 3646, 3646, 3646, 3646, 3646, 3646, 3646, 3646, 3646, 3646, 3646, 3646, 3646, 3646, 3646, 3646, 3646, 3646, 3646, 3646, 3646, 3646, 3646, 3646, 3646, 3646, 3646, 3646, 36

Decode step 3
torch.Size([1, 17, 2])
tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], device='cuda:0')
tensor([[23001,  9826, 21061,  9826,  9826,  2054,  2744, 21061, 23001,  9826,
         28782, 14884, 17538, 21061, 14884, 14884, 17538, 12966, 12966,  2003,
         12966, 12966, 23003, 23001, 23001,  1978, 14884, 23001, 10911, 23001,
         15183, 17538, 23001, 23001, 23003]])
torch.Size([1, 35])
Decode step 4
torch.Size([1, 17, 2])
tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], device='cuda:0')
tensor([[12966,  3646, 23001,  3646, 12966, 12966,  9826,  2345,  3646,  2345,
         21061, 12966,  2345,  3646,  9826, 14540, 14540, 14540,  9826, 14540,
         14540,  2054, 14540, 14540,  2744,  2345, 14540, 21061,  2345, 14540,
         23001,  2345,  2345, 14540,  9826, 14540, 14540, 14540, 28782, 14540,
         14540, 14540, 14884, 14540, 14540, 14540, 17538, 14540, 14540, 14540,
         21061, 17780, 14540, 17780, 14884, 17780, 17780, 14540, 14884

Decode step 1
torch.Size([1, 22, 2])
tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
       device='cuda:0')
tensor([[ 2054,  2001,  1996,  2171,  1997,  1996,  1058, 15472,  2278, 20135,
          2008,  3975,  1999,  2048,  2125, 29549,  1996,  3023,  2218,  1997,
         12686,  1999,  3301,  1029]])
torch.Size([1, 24])
Decode step 2
torch.Size([1, 22, 2])
tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
       device='cuda:0')
tensor([[21709,  2054, 30499,  2001, 20968,  1996,  2171,  1997,  1996,  1058,
         15472,  2278, 20135,  2008,  3975,  1999,  2048,  2125, 27742, 29549,
          1996,  3023, 20968,  2218,  1997, 12686,  1999]])
torch.Size([1, 27])
Decode step 3
[21709, 2054, 30499, 2001, 20968, 1996, 2171, 1997, 1996, 1058, 15472, 2278, 20135, 2008, 3975, 1999, 2048, 2125, 27742, 29549, 1996, 3023]
##rot what陽 wasberries the name of the vlcc tanker that split in two offencia binoculars the coast
451
452
453
454


Decode step 1
[1000, 1000, 1000, 2085, 2003, 1996, 3467, 1997, 2256, 27648, 1000, 1000, 2003, 1037, 2240, 2013, 2029, 2377, 1029, 1000]
" " " now is the winter of our discontent " " is a line from which play? "
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
800
tensor(11.5725, grad_fn=<DivBackward0>)
y_ground
' that's livin'alright'was the theme song to which tv programme?
y_del
[PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]
y_ins
' that s livin'alright'was the theme song to tv programme? [PAD] [PAD]
Decode step 0
torch.Size([1, 18, 2])
tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
       device='cuda:0')
tensor([[ 1005,  2008,  1005,  1055, 22135,  2378,  1005, 10303,  1005,  2001,
          1996,  4323,  2299,  2000,  2029,  2694,  4746,  1029]])
torch.Si

1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1200
tensor(11.2303, grad_fn=<DivBackward0>)
y_ground
who led the first expedition to successfully circumnavigate the earth between 1519 and 1522, and was killed during the voyage?
y_del
[PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]
y_ins
who first expedition successfully ciumnavigate the earth between 1519 1522, and was during the voyage [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]
Decode step 0
torch.Size([1, 29, 2])
tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0]], device='cuda:0')
tensor([[ 2040,  2419,  1996,  2034,  5590,  2000,  5147, 25022, 11890,  2819,

In [36]:
def decode(self, text):
        
    self.p_classifier.eval()
    self.t_classifier.eval()
    self.d_classifier.eval()

    step = 0
    max_steps = 10
    
    y_last = None
    y_ground = torch.LongTensor([tokenizer.encode(text)])
    y_ground_enc = encoder(y_ground)[0]
    
    while True:
        print("Decode step %d" % step)
        step += 1
        if step > max_steps:
            y_ground = y_last
            break
        if y_last is not None:
            if (y_last.size() == y_ground.size() and torch.all(y_last == y_ground)):
                break         
            y_ground = y_last
                
        # deletion first
        preds_deletes = self.d_classifier(y_ground_enc).cuda()
        if preds_deletes.size(1) == 0:
            continue

        deletions = torch.argmax(preds_deletes,2)
        print(deletions)

        deleted = [y_ground[0,i].item() for i in range(deletions.size(1)) if deletions[0,i] == 0]

        y_ground = tokenizer.decode(deleted)
        print("post deleted")
        print(y_ground)

        # then placeholder
        y_ground = encoder(torch.LongTensor([deleted]))[0]

        preds_placeholders = self.p_classifier(y_ground).cuda()
        if preds_placeholders.size(1) == 0:
            y_ground = torch.LongTensor([deleted])
            continue
        placeholders = torch.argmax(preds_placeholders,2)

        reconstructed = []
        for i in range(placeholders.size(1)):
            for j in range(placeholders[0,i].item()):
                reconstructed.append(plh)
            reconstructed.append(deleted[i])

        y_ground = torch.LongTensor([reconstructed])

        # then inserts
        inserts = self.t_classifier(encoder(torch.LongTensor([reconstructed]))[0]).cuda()
        inserts = torch.argmax(inserts, 2)
        j = 0

        for i in range(len(reconstructed)):
            if reconstructed[i] == plh:
                if i < len(reconstructed):
                    reconstructed[i] = inserts[0,j].item()
                else:
                    reconstructed.append(inserts[0,j].item())
                j += 1

        y_last = torch.LongTensor([reconstructed])
    if y_last is not None:
        print(y_last[0].tolist())
        print(tokenizer.decode(y_last[0].tolist()))
    else:
        print("y_last none")
decode(model, "I are go to the mall")
decode(model, "She am my mother")

Decode step 0
tensor([[0, 0, 0, 0, 0, 0]], device='cuda:0')
post deleted
i are go to the mall
Decode step 1
[1045, 2024, 2175, 2000, 1996, 6670]
i are go to the mall
Decode step 0
tensor([[0, 0, 0, 0]], device='cuda:0')
post deleted
she am my mother
Decode step 1
[2016, 2572, 2026, 2388]
she am my mother
