In [56]:
import numpy as np
from pytorch_transformers.tokenization_distilbert import DistilBertTokenizer
from pytorch_transformers.modeling_distilbert import DistilBertModel
import torch
import torch.nn
from torch import optim
import random

from torch.utils.data import Dataset
import json
from torch import nn

tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
special_tokens_dict = {'additional_special_tokens': ['<PLH>']}
tokenizer.add_special_tokens(special_tokens_dict)
encoder = DistilBertModel.from_pretrained('distilbert-base-uncased')
encoder.resize_token_embeddings(len(tokenizer))  

noop = "N"
sub = "S"
insert = "I"
delete = "D"

def ld(s1, s2, subcost=1, delcost=1, inscost=1):
    operations = [["" for j in range(len(s2) + 1)] for i in range(len(s1) + 1)]
    
    matrix = np.zeros((len(s1)+1, len(s2)+1))
    
    for j in range(len(s2) + 1):
        matrix[0,j] = j
        operations[0][j] = insert
        for i in range(len(s1) + 1):
            matrix[i,0] = i
            operations[i][0] = insert
            if i > 0 and j > 0:
                subCost = matrix[i-1, j-1] if s1[i-1] == s2[j-1] else matrix[i-1, j-1] + subcost
                insertCost = matrix[i, j-1] + inscost
                deleteCost = matrix[i-1, j] + delcost
                minCost = min(subCost, insertCost, deleteCost)
                matrix[i,j] = minCost
                if minCost == 0:
                    operations[i][j] = noop
                elif minCost == deleteCost:
                    operations[i][j] = delete
                elif minCost == insertCost:
                    operations[i][j] = insert
                elif minCost == subCost:
                    operations[i][j] = sub
    i = len(s1)
    j = len(s2)
    history = []
    while j > 0 or i > 0:        
        if delcost != np.inf:
            if j == 0:
                history.append(delete)
                i -= 1
                continue
            if matrix[i-1][j-1] < matrix[i-1,j]:
                history.append(noop)
                i -= 1
                j -= 1
            else:
                history.append(delete)
                i -= 1
        elif inscost != np.inf:
            if j == 0:
                history.append(noop)
                i -= 1
                continue
            if matrix[i-1][j-1] < matrix[i,j-1]:
                history.append(noop)
                i -= 1
                j -= 1
            else:
                history.append((insert,s2[j-1]))
                #history.append(insert)
                j -= 1
    history.reverse()
    return matrix, matrix[len(s1),len(s2)], history
    


In [2]:
class QADataset(Dataset):
    def __init__(self, path):
        with open(path, "r") as infile:
            self.data = json.load(infile)

    def __len__(self):
        return len(self.data["Data"])

    def sample(self):
        key = "Question" if random.random() > 0.5 else "Answer"
        return self.data["Data"][random.randint(0, len(self.data["Data"]))]["Question"]
    
PLH = "<PLH>"
TS = "<s>"
TE = "</s>"

class PlaceholderClassifier(nn.Module):
    def __init__(self, hsz, max_placeholders=10):
        super().__init__()
        self.dense = nn.Linear(
            hsz, max_placeholders,
        )
        self.activation = nn.ReLU()

    def forward(self, input, hidden=None):
        return self.activation(self.dense(input))
    
class TokenClassifier(nn.Module):    
    def __init__(self, hsz, vsz, max_seq_len):
        super().__init__()
        
        self.dense = nn.Linear(
            hsz, vsz
        )
        self.activation = nn.ReLU()
        
    def forward(self, input, hidden=None):
        return self.activation(self.dense(input))
        
class DeletionClassifier(nn.Module):    
    def __init__(self, hsz):
        super().__init__()
        self.dense = nn.Linear(
            hsz, 2,
        )
        self.activation = nn.ReLU()
    
    def forward(self, input, hidden=None):
        return self.activation(self.dense(input))
dataset = QADataset("/virtualmachines/data/trivia_qa/qa/wikipedia-train.json")
dataset.sample()
plh = tokenizer.encode("<PLH>")[0]
print(plh)
print(len(tokenizer))

30522
30523


In [3]:
def deleted_indices_to_placeholders(deleted, max_placeholders):
    i = 0
    placeholders = []
    num_deleted = 0
    while True:
        if i == len(deleted):
            if num_deleted > 0:
                placeholders.append(min(num_deleted, max_placeholders))
            break
        if deleted[i] == 1:
            num_deleted += 1
        else:
            if num_deleted > 0:
                placeholders.append(min(num_deleted, max_placeholders))
            placeholders.append(0)
            num_deleted = 0
        i += 1
    while len(placeholders) < len(deleted):
        placeholders.append(0)
    return torch.unsqueeze(torch.LongTensor(placeholders), 0)


delete_random([tokenizer.encode("Hello, my name is "), tokenizer.encode("What is your name?")])

([[7592, 1010, 2026, 2171, 2003], [2054, 2003, 2115]],
 [[7592, 1010, 2026, 2171, 2003], [2054, 2003, 2115, 30522, 30522]],
 [[0, 0, 0, 0, 0], [0, 0, 0, 1, 1]])

In [54]:
def delete_minimal(y, y_ground, pad=True, max_placeholders=0):
    """Apply the sequence of deletions from y that give the smallest possible Levenshtein distance from y_ground
    Returns tensors of:
    - the token sequence post-deletion
    - the token sequence post-deletion, with PLH inserted at each deleted position
    - the number of placeholders inserted at each post-deletion index
    """
    batched = torch.zeros(len(y),)
    
    # if y_ground is longer than y, there is no sequence of deletes with a shorter distance
    if len(y_ground) > len(y): 
        return y
    
    # calculate LD directly against tokens
    deleted = []
    matrix, dist, edits = ld(y.numpy(),y_ground.numpy(), subcost=np.inf, inscost=np.inf)
    num_deleted = 0
    for i in range(len(edits)):
        if edits[i] == "D":
            deleted.append(1)
            num_deleted += 1
        else:
            deleted.append(0)
    placeholders = deleted_indices_to_placeholders(deleted,max_placeholders=max_placeholders)
    post_deletion, post_deletion_with_placeholders = deleted_boolean_to_tensors(y, deleted, num_deleted)    
    return post_deletion, post_deletion_with_placeholders, torch.LongTensor([deleted])#,placeholders

def pad(items, pad_len=0, pad_token=0):
    if pad_len == 0:
        pad_len = max([len(i) for i in items])
    for i in items:
        while len(i) < pad_len:
            i.append(pad_token)
    return items
    
def insert_minimal(y, y_ground,max_placeholders):
    """Apply the sequence of insertions to y resulting in the smallest possible Levenshtein distance from y_ground
    Accepts tensor of:
    - bsz * max_seq_len
    Returns tensors of:
    - size (n+1), where the value at position 0 <= i < n represents the number of PLH tags inserted at that position
        - n is the number of tokens in y
    - size(k) containing the indices of each inserted token, where k is the total number of tokens added
    - size(n+k) containing the tokens of the entire post-insertion sequence
    """

    batch_placeholders = []
    batch_inserted = []
    batch_new = []
    for batch_idx in range(len(y)):
        matrix, dist, edits = ld(y[batch_idx],y_ground[batch_idx], subcost=np.inf, delcost=np.inf)
        inserted = 0
        y_placeholders = []
        y_inserted = []
        y_new = []
        y_idx = 0
        i = 0
        while i < len(edits):
            if edits[i][0] == "I":
                accum = 0
                while i < len(edits) and edits[i][0] == "I":
                    y_inserted.append(edits[i][1])
                    y_new.append(edits[i][1])
                    accum += 1
                    i += 1
                y_placeholders.append(min(accum,max_placeholders-1))
            else:
                y_placeholders.append(0)
                y_new.append(y[batch_idx][y_idx])
                i += 1
                y_idx += 1
        batch_placeholders.append(y_placeholders)
        batch_inserted.append(y_inserted)
        batch_new.append(y_new)
    batch_placeholders = pad(batch_placeholders, pad_token=tokenizer.pad_token_id)

    batch_new = pad(batch_new, pad_token=tokenizer.pad_token_id)
    
    return batch_placeholders, batch_inserted, torch.LongTensor(batch_new)
   

[[30522, 0, 1, 2], [5, 10, 12]]

In [98]:
foo = InsertionInput([tokenizer.encode("Hello my name is")])
print(foo.ground)
print(tokenizer.decode(foo.post_deletion_with_placeholders[0]))
print(tokenizer.decode(foo.post_deletion[0]))
print(foo.num_placeholders)

[[7592, 2026, 2171, 2003]]
hello my <PLH>is
hello my is
[[0, 0, 1, 0]]


In [231]:
class InsertionInput():
    '''
    Wraps tensors of;
    - the original sequence
    - the post-deletion sequence
    - the post-deletion sequence (including placeholders)
    - boolean indicating whether the tokens at index i was deleted
    '''
    def __init__(self, ground, p_del=0.25):
        self.ground = ground
        self.p_del = p_del
        self.delete_random()
    
    def delete_random(self):
        """Deletes token(s) randomly from the passed (tokenized) string with probability p
        Accepts:
        - a list of token sequences bsz * pad_length
        Returns tensors of:
        - the token sequence post-deletion
        - the token sequence post-deletion, with PLH inserted at each deleted position
        - the number of placeholders inserted at each post-deletion index
        """
        self.deleted_indices = []
        self.post_deletion = []
        self.post_deletion_with_placeholders = []
        for entry in self.ground:
            indices = []
            post_deletion_with_placeholders = []
            post_deletion = []
            for token in entry:
                if random.random() < self.p_del:
                    indices.append(1)
                    post_deletion_with_placeholders.append(plh)
                else:
                    indices.append(0)
                    post_deletion.append(token)
                    post_deletion_with_placeholders.append(token)
            self.deleted_indices.append(indices)
            self.post_deletion_with_placeholders.append(post_deletion_with_placeholders)
            self.post_deletion.append(post_deletion)
        self.num_placeholders = []
        for entry in self.post_deletion_with_placeholders:
            num_placeholders = []
            accum = 0
            for token in entry:
                if token == plh:
                    accum += 1
                    continue
                if accum > 0:
                    num_placeholders.append(accum)
                    accum = 0
                num_placeholders.append(accum)
            if accum > 0:
                num_placeholders.append(accum)
            self.num_placeholders.append(num_placeholders)
        self.post_deletion = pad(self.post_deletion, pad_len=100)
        self.num_placeholders = pad(self.num_placeholders, pad_len=100)
        self.post_deletion_with_placeholders = pad(self.post_deletion_with_placeholders, pad_len=100)
        self.num_placeholders = torch.LongTensor(self.num_placeholders)
        self.post_deletion = torch.LongTensor(self.post_deletion)
        self.post_deletion_with_placeholders = torch.LongTensor(self.post_deletion_with_placeholders)

class DeletionInput():
    """Randomly inserts tokens into the sequence
    Wraps tensors of:
    - the post-insertion sequence
    - the indices of the inserted tokens
    """ 
    def __init__(self, insertion_input, t_classifier, p_ins=0.25, max_inserts=5):
#        if v >= self.alpha and len(y0) != 0:
#            y_del = y_ins
#        else:
        ground = insertion_input.ground.copy()
        bsz = len(ground)
        self.insertion_indices = [] 
        for batch_idx in range(bsz):
            i = 0
            insertion_indices = [0] * len(ground[batch_idx])
            inserted = 0
            while i < len(ground[batch_idx]):
                #print("iter")
                if random.random() < p_ins and inserted < max_inserts:
                    ground[batch_idx] = ground[batch_idx][:i] + [plh] + ground[batch_idx][i:]
                    insertion_indices[i-inserted] = 1
                    inserted += 1
                i += 1
            self.insertion_indices.append(insertion_indices)
        self.post_insertion = torch.LongTensor(pad(ground, pad_len=100))
        encoded = encoder(self.post_insertion)[0]
        logits = t_classifier(encoded)
        self.post_insertion = torch.argmax(logits,dim=2)
        self.post_insertion = torch.LongTensor(self.post_insertion)
        self.insertion_indices = pad(self.insertion_indices,pad_len=100)
        self.insertion_indices = torch.LongTensor(self.insertion_indices)
insertion_input = InsertionInput([tokenizer.encode("Hi my friend")])
deletion_input = DeletionInput(insertion_input, TokenClassifier(100, len(tokenizer), 20))
#tokenizer.decode(deletion_input.post_insertion[0].tolist())

RuntimeError: size mismatch, m1: [100 x 768], m2: [100 x 30523] at /pytorch/aten/src/TH/generic/THTensorMath.cpp:752

In [None]:
class Model():
    def __init__(self, vocab_size, hsz=768, lr=0.0001, max_placeholders=10, alpha=0, beta=0):
        super().__init__()
        self.p_classifier = PlaceholderClassifier(hsz, max_placeholders=max_placeholders)
        self.t_classifier = TokenClassifier(hsz, vocab_size, 20)
        self.d_classifier = DeletionClassifier(hsz)        
        self.dataset = QADataset("/virtualmachines/data/trivia_qa/qa/wikipedia-train.json")
        self.tokenizer = tokenizer
        self.encoder = encoder
        self.alpha = alpha
        self.beta = beta
        self.max_placeholders = max_placeholders             
        
        self.alpha = 0.5
        self.beta = 0.5
        self.p_loss = nn.CrossEntropyLoss()
        self.t_loss = nn.CrossEntropyLoss()
        self.d_loss = nn.CrossEntropyLoss()
        
        self.optims = {
            'p_classifier': optim.SGD(self.p_classifier.parameters(), lr=lr),
            't_classifier': optim.SGD(self.t_classifier.parameters(), lr=lr),
            'd_classifier': optim.SGD(self.d_classifier.parameters(), lr=lr),
        }
        
        self.step = 0
        self.loss = 0
                        
    def sample(self, bs=10, pad_to=30):
        '''Sample an observation and return tensor tuples:
        
        2) (a) 1(a), but with PLH replacing each deleted token. For input to the token insertion classifier 
           (b) the number of placeholders tokens to insert at each index in 1(a) (for the placeholder classifier loss)
        3) (a) the observation perturbed with insertion (tokens), for input to the token deletion classifier 
           (b) 2(a) (for the token insertion classifier loss)
        '''
        
        v = random.random()
        
        # sample an untokenized string
        y_ground = [self.dataset.sample() for i in range(bs)]
        # tokenize/encode
        encoded = [self.tokenizer.encode(s) for s in y_ground]
        
        insertion_input = InsertionInput(encoded)
        deletion_input = DeletionInput(insertion_input, self.t_classifier)
    
        return insertion_input, deletion_input
    
    def zero_grad(self):
        for optimizer in self.optims.values():
            optimizer.zero_grad()

    def update_params(self):
        for optimizer in self.optims.values():
            optimizer.step()
            
    def encode_list(self, tokens):
        if type(tokens) == list:
            return encoder(torch.LongTensor([tokens]))[0]
        return encoder(tokens)[0]
            
    def pretty_print(self, label, tokens):
        if type(tokens) == torch.Tensor:
            tokens = tokens.tolist()
        pretty = "%s \n %s" % (label, tokenizer.decode(tokens))
        pretty = pretty.replace("[PAD]","")
        print(pretty)
        
    def decode_step(self):
        with torch.no_grad():
            self.p_classifier.eval()
            self.t_classifier.eval()
            self.d_classifier.eval()

            insertion_input, deletion_input = self.sample(bs=1)
            self.pretty_print("Ground truth", insertion_input.ground[0])
            self.pretty_print("Ground truth post-deletion", insertion_input.post_deletion[0].tolist())

            step = 0
            max_steps = 10
            last = insertion_input.post_deletion
            ground = last
            
            while True:
                print("Decode step %d" % step)
            
                if step > max_steps:
                    ground = last
                    break
                
                if step > 0 and (last.size() == ground.size() and torch.all(last == ground)):
                    break
                
                step += 1    
                
                # run a deletion pass
                preds_deletes = self.d_classifier(self.encode_list(ground))
                if preds_deletes.size(1) > 0:
                    deletions = torch.argmax(preds_deletes,2)
                    assert ground.size(1) >= deletions.size(1)
                    deleted = [ground[0,i].item() for i in range(deletions.size(1)) if deletions[0,i] == 0]
                    ground = torch.LongTensor([deleted])
                    self.pretty_print("Post-deletion", deleted)
                else:
                    print("No deletions")
                    
                preds_placeholders = self.p_classifier(self.encode_list(ground))
                # then run a placeholder pass                    
                if preds_placeholders.size(1) == 0:
                    print("No placeholders")
                    continue
                    
                placeholders = torch.argmax(preds_placeholders,2)
                
                reconstructed = []
                added_placeholders = 0
                for i in range(ground.size(1)):
                    for j in range(min(self.max_placeholders, placeholders[0,i])):
                        reconstructed.append(plh)
                        added_placeholders += 1
                    reconstructed.append(ground[0,i].item())

                self.pretty_print("Post-placeholders", reconstructed)
                      
                # then an insertion pass
                preds_inserts = self.t_classifier(self.encode_list(reconstructed))
                inserts = torch.argmax(preds_inserts, 2)
                
                output = []
                for i in range(len(reconstructed)):
                    if reconstructed[i] == plh:
                        output.append(inserts[0,i].item())
                    else:
                        output.append(reconstructed[i])
                self.pretty_print("Post-insert", output)
                last = torch.LongTensor([output])
            if last is not None:
                print(last[0].tolist())
                print(tokenizer.decode(last[0].tolist()))
            else:
                print("Couldn't decode")
            
    def train_step(self):
        loss = 0
        self.zero_grad()
        self.p_classifier.train()
        self.t_classifier.train()
        self.d_classifier.train()
        
        insertion_input, deletion_input = self.sample(bs=1)
        
        preds_deletes = self.d_classifier(encoder(deletion_input.post_insertion)[0])
        preds_placeholders = self.p_classifier(encoder(insertion_input.post_deletion)[0])
        preds_inserts = self.t_classifier(encoder(insertion_input.post_deletion_with_placeholders)[0])
        
        loss = self.d_loss(torch.transpose(preds_deletes, 1, 2), deletion_input.insertion_indices)
        loss += self.p_loss(torch.transpose(preds_placeholders, 1, 2), insertion_input.num_placeholders)
        loss += self.t_loss(torch.transpose(preds_inserts,1,2), insertion_input.post_deletion)
        self.step += 1
        self.loss += loss
        if self.step % 10 == 0:
            print(self.step)
        if self.step % 50 == 0:
            print(self.step)
            print(self.loss / 50)
            self.loss = 0
            self.decode_step()
        
        loss.backward()
        self.update_params()

model = Model(len(tokenizer))
for i in range(10000):
    model.train_step()
model.decode_step()

In [43]:
def decode(self, text):
        
    self.p_classifier.eval()
    self.t_classifier.eval()
    self.d_classifier.eval()

    step = 0
    max_steps = 10
    
    y_last = None
    y_ground = torch.LongTensor([tokenizer.encode(text)])
    y_ground_enc = encoder(y_ground)[0]
    
    while True:
        print("Decode step %d" % step)
        step += 1
        if step > max_steps:
            y_ground = y_last
            break
        if y_last is not None:
            if (y_last.size() == y_ground.size() and torch.all(y_last == y_ground)):
                break         
            y_ground = y_last
                
        # deletion first
        preds_deletes = self.d_classifier(y_ground_enc).cuda()
        if preds_deletes.size(1) == 0:
            continue

        deletions = torch.argmax(preds_deletes,2)
        print(deletions)

        deleted = [y_ground[0,i].item() for i in range(deletions.size(1)) if deletions[0,i] == 0]

        y_ground = tokenizer.decode(deleted)
        print("post deleted")
        print(y_ground)

        # then placeholder
        y_ground = encoder(torch.LongTensor([deleted]))[0]

        preds_placeholders = self.p_classifier(y_ground).cuda()
        if preds_placeholders.size(1) == 0:
            y_ground = torch.LongTensor([deleted])
            continue
        placeholders = torch.argmax(preds_placeholders,2)

        reconstructed = []
        for i in range(placeholders.size(1)):
            for j in range(placeholders[0,i].item()):
                reconstructed.append(plh)
            reconstructed.append(deleted[i])

        y_ground = torch.LongTensor([reconstructed])

        # then inserts
        inserts = self.t_classifier(encoder(torch.LongTensor([reconstructed]))[0]).cuda()
        inserts = torch.argmax(inserts, 2)
        j = 0

        for i in range(len(reconstructed)):
            if reconstructed[i] == plh:
                if i < len(reconstructed):
                    reconstructed[i] = inserts[0,j].item()
                else:
                    reconstructed.append(inserts[0,j].item())
                j += 1
        y_last = torch.LongTensor([reconstructed])
    if y_last is not None:
        print(y_last[0].tolist())
        print(tokenizer.decode(y_last[0].tolist()))
    else:
        print("y_last none")
decode(model, "I are go to")
#decode(model, "She am my mother")

Decode step 0
tensor([[0, 0, 0, 0]], device='cuda:0')
post deleted
i are go to
Decode step 1
tensor([[0, 0, 0, 0]], device='cuda:0')
post deleted
##nxnxnxnx
Decode step 2
tensor([[0, 0, 0, 0]], device='cuda:0')
post deleted
##vancesnxvancesnx
Decode step 3
tensor([[0, 0, 0, 0]], device='cuda:0')
post deleted
##vancesvancesvancesnx
Decode step 4
tensor([[0, 0, 0, 0]], device='cuda:0')
post deleted
##vancesvancesvancesvances
Decode step 5
tensor([[0, 0, 0, 0]], device='cuda:0')
post deleted
midtownvancesvancesvances
Decode step 6
[27219, 26711, 26711, 26711]
midtownvancesvancesvances
