In [1]:
import os

import json

from transformers import BertTokenizer, BertForMaskedLM
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F

In [2]:
# Configuration cell: filesystem locations and the shared pretrained BERT
# objects that the cells below read as notebook globals.
print("Global initialization started.")
# NOTE(review): absolute, machine-specific path — adjust before running elsewhere.
WORK_DIR = "/data/disk5/private/yuc/coref/bert-tagger"
# Default list file for SentenceIterable, which opens it and splits its content
# on whitespace to obtain the input file paths. Being a bare file name, it is
# resolved against the current working directory.
FILE_LIST = "filelist.txt"
# Not referenced by any other cell visible in this notebook.
WIKI_DIR = os.path.join(WORK_DIR, "../wikipedia/text")
# Previous dump location, kept for reference:
# DUMP_DIR =  os.path.join(WORK_DIR, "playground/dump")
# Default output directory of SaveManager (progress log and *.dump files).
DUMP_DIR = os.path.join(WORK_DIR, "playground/dump_kl_para")
# Not referenced by any other cell visible in this notebook.
LOG_DIR = os.path.join(WORK_DIR, "playground/logs")

# Shared tokenizer / masked-LM model, both loaded from the 'bert-base-uncased'
# checkpoint; used as defaults by default_transform, QuestionPairIterable and
# QuestionPairConsumer.
global_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
global_model = BertForMaskedLM.from_pretrained('bert-base-uncased')
print("Global initialization completed.")

Global initialization started.


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Global initialization completed.


In [167]:
def default_transform(sentence):
    """Convert one raw text line into the sentence record used by the pipeline.

    Args:
        sentence: A line of text as read from an input file.

    Returns:
        ``None`` when the line should be skipped: it is empty or
        whitespace-only, it starts with ``"<"``, or its id tensor has 30 or
        more / 5 or fewer entries. Otherwise a dict with

        * ``"tokens"``: 1-D tensor of token ids produced by
          ``global_tokenizer`` (the sample output in this notebook shows the
          ids wrapped in 101 ... 102, so the length thresholds count those
          two ids as well);
        * ``"raw"``: the list of sub-word strings before id conversion.
    """
    # An empty string used to raise IndexError on ``sentence[0]``; treat it
    # like any other unusable line instead.
    if not sentence or sentence[0] == "<":
        return None

    text = sentence.strip()
    if not text:
        # Whitespace-only line: nothing to tokenize, so skip it without
        # invoking the tokenizer at all.
        return None

    # raw_tokens: parsed into sub-words but not yet converted to ids
    raw_tokens = global_tokenizer.tokenize(text)
    # tokens converted to ids
    encoded = global_tokenizer(raw_tokens, return_tensors="pt", is_split_into_words=True)
    tokens = torch.squeeze(encoded["input_ids"])

    length = tokens.shape[0]
    # Ignore lines that are too long as well as invalid / very short ones.
    if length >= 30 or length <= 5:
        return None

    # This is the format of a sentence.
    return {
        "tokens": tokens,
        "raw": raw_tokens
    }

# Smoke test: run the transform on one sample sentence and print the record.
ts = default_transform("To be or not to be, this is the question.")
print(ts)

{'tokens': tensor([ 101, 2000, 2022, 2030, 2025, 2000, 2022, 1010, 2023, 2003, 1996, 3160,
        1012,  102]), 'raw': ['to', 'be', 'or', 'not', 'to', 'be', ',', 'this', 'is', 'the', 'question', '.']}


In [168]:
class SentenceIterable:
    """Resumable iterator over the transformed lines of a list of text files.

    Iterating yields ``(sentence, file_id, stc_id)`` tuples, where ``sentence``
    is whatever ``transform`` returned for the line, ``file_id`` is the index
    of the file in the list file and ``stc_id`` is the line index inside that
    file. Lines for which ``transform`` returns ``None`` are skipped (a
    message is printed for each).

    The current position is kept in ``self.file_id`` / ``self.stc_id`` and is
    advanced while iterating, so a second iteration continues where the first
    one stopped rather than starting over.

    Args:
        file_path_list: Path of a text file containing whitespace-separated
            input file paths. ``None`` (default) means the notebook-level
            ``FILE_LIST``, looked up when the object is constructed.
        file_id: Index of the first file to read.
        stc_id: Index of the first line to read inside that first file.
        transform: Callable applied to every raw line. ``None`` (default)
            means the notebook-level ``default_transform``.
    """

    def __init__(self,
        file_path_list=None,
        file_id=0,
        stc_id=0,
        transform=None):
        # Resolve notebook globals lazily so that re-running the config cell
        # (or the cell defining default_transform) is picked up without having
        # to re-run this class definition as well.
        if file_path_list is None:
            file_path_list = FILE_LIST
        self.file_id = file_id
        self.stc_id = stc_id
        with open(file_path_list, "r") as f_list:
            self.file_paths = f_list.read().split()
        self.transform = default_transform if transform is None else transform
        print("SentenceIterable constructed.")

    def __iter__(self):
        return self.sentence_generator()

    def sentence_generator(self):
        """Yield ``(sentence, file_id, stc_id)`` from the current position on."""
        file_count = len(self.file_paths)
        while self.file_id < file_count:
            file_path = self.file_paths[self.file_id]
            # Read the whole file first and close it right away, so no file
            # handle stays open while the generator is suspended at a yield.
            with open(file_path) as fs:
                sentences = fs.readlines()
            sentence_count = len(sentences)
            while self.stc_id < sentence_count:
                sentence = self.transform(sentences[self.stc_id])
                if sentence is None:
                    print("sentence discarded.")
                else:
                    yield (sentence, self.file_id, self.stc_id)
                self.stc_id += 1
            # Next file starts from its first line.
            self.stc_id = 0
            self.file_id += 1

In [170]:
class QuestionPairIterable(Dataset):
    """Map-style dataset of (unmasked, masked) question pairs for one sentence.

    For every ordered pair of distinct inner positions ``(miss, mask)`` of the
    sentence (the first and last position are never touched) one sample is
    produced:

    * ``"unmasked"``: the sentence with position ``miss`` replaced by the
      missing-placeholder id;
    * ``"masked"``: the unmasked question with position ``mask`` additionally
      replaced by the mask-placeholder id.

    Args:
        sentence: Dict whose ``"tokens"`` entry is a 1-D tensor of token ids.
        mask_placeholder: Token string written at the masked position.
        miss_placeholder: Token string written at the missing position.
        tokenizer: Object providing ``convert_tokens_to_ids``; ``None``
            (default) means the notebook-level ``global_tokenizer``.
    """

    def __init__(self,
        sentence,
        mask_placeholder="[MASK]",
        miss_placeholder="[MASK]",
        tokenizer=None):
        # The one-argument form ``super(QuestionPairIterable)`` creates an
        # unbound super object and never reaches the base-class initializer;
        # the zero-argument form is the correct call.
        super().__init__()
        if tokenizer is None:
            tokenizer = global_tokenizer
        self.sentence = sentence["tokens"]
        self.miss_ph = miss_placeholder
        self.mask_ph = mask_placeholder
        self.miss_id = tokenizer.convert_tokens_to_ids(miss_placeholder)
        self.mask_id = tokenizer.convert_tokens_to_ids(mask_placeholder)
        length = len(self.sentence)
        # Positions 0 and length-1 are excluded from both roles.
        inner_positions = range(1, length - 1)
        self.index_pairs = [
            ([miss_index], [mask_index])
            for miss_index in inner_positions
            for mask_index in inner_positions
            if miss_index != mask_index
        ]

        self.start = 0
        self.end = len(self.index_pairs)
        print("QuestionPairIterable constructed.")

    def __len__(self):
        return len(self.index_pairs)

    def __getitem__(self, index):
        """Build the sample for the ``index``-th (missing, masked) pair.

        Returns a dict with ``"label"`` (the original id tensor, not a copy),
        ``"unmasked"``, ``"masked"`` (fresh tensors) and ``"miss_id"`` /
        ``"mask_id"`` (1-D tensors holding the replaced positions).
        """
        missing_indices, masked_indices = self.index_pairs[index]
        # Work on clones so the stored sentence is never modified.
        unmasked_question = self.sentence.clone()
        for missing_index in missing_indices:
            unmasked_question[missing_index] = self.miss_id
        masked_question = unmasked_question.clone()
        for masked_index in masked_indices:
            masked_question[masked_index] = self.mask_id
        return {
            "label": self.sentence,
            "unmasked": unmasked_question,
            "masked": masked_question,
            "miss_id": torch.tensor(missing_indices),
            "mask_id": torch.tensor(masked_indices)
        }
    
def test_pair_iterable():
    """Print every sample QuestionPairIterable produces for a 5-id toy sentence."""
    # Build the toy tensor with torch directly: numpy is only imported in a
    # later cell, so going through ``np`` broke this cell on a fresh kernel.
    sentence = {
        "tokens": torch.tensor([1, 2, 3, 4, 5])
    }
    dataset = QuestionPairIterable(sentence)
    for sample in dataset:
        print(sample)

# Run the smoke test for QuestionPairIterable.
test_pair_iterable()

QuestionPairIterable constructed.
{'label': tensor([1, 2, 3, 4, 5]), 'unmasked': tensor([  1, 103,   3,   4,   5]), 'masked': tensor([  1, 103, 103,   4,   5]), 'miss_id': tensor([1]), 'mask_id': tensor([2])}
{'label': tensor([1, 2, 3, 4, 5]), 'unmasked': tensor([  1, 103,   3,   4,   5]), 'masked': tensor([  1, 103,   3, 103,   5]), 'miss_id': tensor([1]), 'mask_id': tensor([3])}
{'label': tensor([1, 2, 3, 4, 5]), 'unmasked': tensor([  1,   2, 103,   4,   5]), 'masked': tensor([  1, 103, 103,   4,   5]), 'miss_id': tensor([2]), 'mask_id': tensor([1])}
{'label': tensor([1, 2, 3, 4, 5]), 'unmasked': tensor([  1,   2, 103,   4,   5]), 'masked': tensor([  1,   2, 103, 103,   5]), 'miss_id': tensor([2]), 'mask_id': tensor([3])}
{'label': tensor([1, 2, 3, 4, 5]), 'unmasked': tensor([  1,   2,   3, 103,   5]), 'masked': tensor([  1, 103,   3, 103,   5]), 'miss_id': tensor([3]), 'mask_id': tensor([1])}
{'label': tensor([1, 2, 3, 4, 5]), 'unmasked': tensor([  1,   2,   3, 103,   5]), 'masked':

In [171]:
class QuestionPairConsumer:
    """Run the masked LM on a batch of question pairs and score the difference.

    Args:
        tokenizer: Stored on the instance; not used by ``consume_question_pair``.
            ``None`` (default) means the notebook-level ``global_tokenizer``.
        model: Callable as ``model(input_ids=...)`` returning an object with a
            ``.logits`` attribute. ``None`` (default) means ``global_model``.
        measure: Callable ``measure(masked, unmasked, golden)`` returning the
            distance. ``None`` (default) means ``kl_divergence_dist``.

    All three defaults are resolved when the object is constructed rather than
    when the class is defined, so the class cell can be executed before the
    cell that defines the distance functions.
    """

    def __init__(self,
        tokenizer=None,
        model=None,
        measure=None):
        self.tokenizer = global_tokenizer if tokenizer is None else tokenizer
        self.model = global_model if model is None else model
        self.measure = kl_divergence_dist if measure is None else measure

    def consume_question_pair(self, question_pair):
        """Score one collated batch produced from QuestionPairIterable samples.

        Args:
            question_pair: Dict with ``"label"``, ``"unmasked"``, ``"masked"``
                ([B(atch), L(ength of sentence)] id tensors) and ``"miss_id"``
                ([B, n(umber of missing tokens)] position tensor).

        Returns:
            Whatever ``self.measure`` returns for
            ``(masked logits, unmasked logits, golden one-hot)``, each of
            shape [B, n, V(ocabulary size)] and taken at the missing positions.
            The model runs under ``torch.no_grad()``, so the logits carry no
            autograd history.
        """
        # [B, L]
        context = question_pair["label"]
        unmasked = question_pair["unmasked"]
        masked = question_pair["masked"]
        # [B, n]
        missing_indices = question_pair["miss_id"]

        # Pure inference: skip building the autograd graph for both passes.
        with torch.no_grad():
            # [B, L, V]
            u_logits = self.model(input_ids=unmasked).logits
            m_logits = self.model(input_ids=masked).logits

        # Original token ids at the missing positions. [B, n]
        missing_label_ids = torch.gather(context, 1, missing_indices)
        answer_shape = list(missing_indices.shape) + [u_logits.shape[2]]  # [B, n, V]

        # Golden one-hot: g_logits[b][n][label[b][n]] = 1. Allocated with the
        # logits' dtype and device so it can be combined with them directly.
        g_logits = torch.zeros(answer_shape, dtype=u_logits.dtype, device=u_logits.device)
        g_logits.scatter_(2, missing_label_ids.unsqueeze(2), 1.0)

        # Select, along the sequence axis, the vocabulary rows of the missing
        # positions from both predictions. [B, n, V]
        gather_index = missing_indices.unsqueeze(2).expand(answer_shape)
        u_logits = torch.gather(u_logits, 1, gather_index)
        m_logits = torch.gather(m_logits, 1, gather_index)

        return self.measure(m_logits, u_logits, g_logits)

In [181]:
def get_batch_size(batch):
    """Return ``shape[0]`` of the first value in ``batch``.

    Only the first value (in dict iteration order) is inspected. An empty
    dict yields ``None``.
    """
    nothing = object()
    first_value = next(iter(batch.values()), nothing)
    if first_value is nothing:
        return None
    return first_value.shape[0]

class SaveManager:
    """Collect relations and sentences in memory and dump them periodically.

    Relations are appended one by one; every ``save_interval`` relations the
    pending relations and sentences are written as JSON-lines files named
    ``relation_list_cnt_<counter>.dump`` and ``sentence_dict_cnt_<counter>.dump``
    inside ``dump_dir``, and the in-memory buffers are cleared. A
    ``progress.log`` JSON file in the same directory stores the resume point.

    Args:
        dump_dir: Output directory. ``None`` (default) means the notebook-level
            ``DUMP_DIR``, looked up when the object is constructed. The
            directory is created on the first write if it does not exist.
        counter: Number of relations already processed; rounded down to a
            multiple of ``save_interval``.
        log_interval: Print a progress line every this many relations.
        save_interval: Write the dump files every this many relations.
    """

    def __init__(self,
        dump_dir=None,
        counter=0,
        log_interval=100,
        save_interval=500):
        if dump_dir is None:
            dump_dir = DUMP_DIR
        self.sentence_dict = {}
        self.relation_list = []
        self.log_interval = log_interval
        self.save_interval = save_interval
        self.counter = counter - counter % save_interval
        self.dump_dir = dump_dir
        self.progress_path = os.path.join(self.dump_dir, "progress.log")
        self.rel_template = os.path.join(dump_dir, "relation_list_cnt_{}.dump")
        self.stc_template = os.path.join(dump_dir, "sentence_dict_cnt_{}.dump")

    def _ensure_dump_dir(self):
        """Create the dump directory (and parents) if it is missing."""
        os.makedirs(self.dump_dir, exist_ok=True)

    def load_progress(self):
        """Restore the state written by ``dump_progress``.

        Returns:
            ``(file_id, stc_id)`` from the progress file, or ``(0, 0)`` when no
            progress file exists. When a file is found, ``self.counter`` and
            ``self.save_interval`` are overwritten with the stored values.
        """
        # A leftover unconditional ``return (0, 0)`` used to make everything
        # below unreachable, so saved progress was never picked up.
        if not os.path.exists(self.progress_path):
            return (0, 0)
        with open(self.progress_path, "r") as p_log:
            progress = json.load(p_log)
        # Read every field before mutating self, so a malformed file does not
        # leave the manager half-updated.
        file_id = progress["file_id"]
        stc_id = progress["stc_id"]
        save_interval = progress["save_interval"]
        counter = progress["counter"]
        self.save_interval = save_interval
        self.counter = counter
        return (file_id, stc_id)

    def dump_progress(self, file_id, stc_id):
        """Write the resume point and the current counter to ``progress.log``."""
        progress = {
            "file_id": file_id,
            "stc_id": stc_id,
            "counter": self.counter,
            "save_interval": self.save_interval
        }
        self._ensure_dump_dir()
        with open(self.progress_path, "w") as p_log:
            p_log.write(json.dumps(progress))

    def save_sentence_list(self):
        """Write the pending sentences, sorted by id, one JSON object per line."""
        # Iterate over items(): iterating the dict itself yields keys only.
        sentence_list = [
            {"id": context_id, "context": raw_tokens}
            for context_id, raw_tokens in self.sentence_dict.items()
        ]
        sentence_list.sort(key=lambda entry: entry["id"])
        self._ensure_dump_dir()
        save_path = self.stc_template.format(self.counter)
        with open(save_path, "w") as f:
            for sentence in sentence_list:
                f.write(json.dumps(sentence)+"\n")

    def update_sentence(self, sentence, context_id):
        """Remember the raw tokens of ``sentence`` under ``context_id``."""
        self.sentence_dict[context_id] = sentence["raw"]

    def save_relation_list(self):
        """Write the pending relations, one JSON object per line."""
        self._ensure_dump_dir()
        save_path = self.rel_template.format(self.counter)
        with open(save_path, "w") as f:
            for relation in self.relation_list:
                f.write(json.dumps(relation)+"\n")

    def update_relation(self, sample, distance, context_id):
        """Append one relation; log and flush to disk at the configured intervals.

        Args:
            sample: Mapping whose ``"miss_id"`` / ``"mask_id"`` entries provide
                ``.tolist()``.
            distance: Value convertible with ``float()``.
            context_id: Id of the sentence the relation belongs to.
        """
        self.relation_list.append({
            "context": context_id,
            "missing_index": sample["miss_id"].tolist(),
            "masked_index": sample["mask_id"].tolist(),
            "distance": float(distance)
        })
        self.counter += 1
        if self.counter % self.log_interval == 0:
            print("Got example count: ", self.counter)
        if self.counter % self.save_interval == 0:
            print("Save examples.")
            # These are methods; the bare names raised NameError.
            self.save_relation_list()
            self.save_sentence_list()
            self.relation_list = []
            self.sentence_dict = {}

    def update_relation_batched(self, batch, distance, context_id):
        """Split a collated batch into single samples and record each of them.

        ``distance`` is indexed with the same position as the batch entries.
        """
        batch_size = get_batch_size(batch)
        for index in range(0, batch_size):
            relation = {
                key: batched_tensor[index]
                for key, batched_tensor in batch.items()
            }
            self.update_relation(relation, distance[index], context_id)

In [182]:
def main(max_sentences=1, max_batches=1):
    """Score question pairs sentence by sentence and hand them to SaveManager.

    Args:
        max_sentences: Stop after this many accepted sentences; ``None``
            processes every remaining sentence.
        max_batches: Per sentence, stop after this many batches of question
            pairs; ``None`` processes all of them.

    The defaults reproduce the previously hard-coded behaviour of stopping
    after the first batch of the first accepted sentence. Pass ``None`` for
    both to run over the whole input.

    Returns:
        None. Results are recorded through ``SaveManager``.
    """
    from itertools import islice

    save_interval = 500
    # Factor used to fold (file_id, stc_id) into a single integer id.
    # NOTE(review): ids stay unique only if every file has fewer than 50000
    # lines — confirm against the input data.
    file_id_stride = 50000

    save_manager = SaveManager(save_interval=save_interval)
    last_file_id, last_stc_id = save_manager.load_progress()

    sentence_dataset = SentenceIterable(file_id=last_file_id, stc_id=last_stc_id)
    consumer = QuestionPairConsumer()

    # islice(iterable, None) places no limit; islice(iterable, k) stops after k
    # items without pulling item k+1, which matches the former ``break``.
    for sentence, file_id, stc_id in islice(sentence_dataset, max_sentences):
        context_id = file_id * file_id_stride + stc_id
        question_pair_dataset = QuestionPairIterable(sentence)
        dataloader = DataLoader(question_pair_dataset, batch_size=32, num_workers=0)
        for sample_batched in islice(dataloader, max_batches):
            distance = consumer.consume_question_pair(sample_batched)
            # Re-registered for every batch on purpose: SaveManager clears its
            # sentence buffer whenever it flushes, which can happen between
            # two batches of the same sentence.
            save_manager.update_sentence(sentence, context_id)
            save_manager.update_relation_batched(sample_batched, distance, context_id)
            save_manager.dump_progress(file_id, stc_id)

In [183]:
# Run the pipeline entry point defined in the previous cell.
main()

SentenceIterable constructed.
sentence discarded.
QuestionPairIterable constructed.


In [102]:
# NOTE(review): the saved output of this cell (execution count 102) appears to
# come from an earlier revision of main() that returned a collated batch. The
# main() defined in this notebook has no return statement, so `sample` would be
# None and the subscripts below would raise TypeError on a fresh run — confirm,
# then refresh or remove this cell.
sample = main()
label = sample["label"]
unmasked = sample["unmasked"]
masked = sample["masked"]
print(sample)

0 0 0
SentenceIterable constructed.
sentence discarded.
10
QuestionPairIterable constructed.
{'label': tensor([[  101, 28506,  1516,  2849,  6590,  3351,  3231,  2005,  9874,   102],
        [  101, 28506,  1516,  2849,  6590,  3351,  3231,  2005,  9874,   102],
        [  101, 28506,  1516,  2849,  6590,  3351,  3231,  2005,  9874,   102],
        [  101, 28506,  1516,  2849,  6590,  3351,  3231,  2005,  9874,   102]]), 'unmasked': tensor([[ 101,  103, 1516, 2849, 6590, 3351, 3231, 2005, 9874,  102],
        [ 101,  103, 1516, 2849, 6590, 3351, 3231, 2005, 9874,  102],
        [ 101,  103, 1516, 2849, 6590, 3351, 3231, 2005, 9874,  102],
        [ 101,  103, 1516, 2849, 6590, 3351, 3231, 2005, 9874,  102]]), 'masked': tensor([[ 101,  103,  103, 2849, 6590, 3351, 3231, 2005, 9874,  102],
        [ 101,  103, 1516,  103, 6590, 3351, 3231, 2005, 9874,  102],
        [ 101,  103, 1516, 2849,  103, 3351, 3231, 2005, 9874,  102],
        [ 101,  103, 1516, 2849, 6590,  103, 3231, 2005, 9874

In [134]:
# NOTE(review): this cell unpacks three tensors, but consume_question_pair as
# defined in this notebook returns the single value produced by
# `self.measure(...)`; the cell appears to predate that change. It also relies
# on `sample` being a collated batch left in the kernel by an earlier cell.
consumer = QuestionPairConsumer()
g_logits, u_logits, m_logits = consumer.consume_question_pair(sample)
print(g_logits, u_logits, m_logits)

[4, 1, 30522]
tensor([[[0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.]]]) tensor([[[-6.9024, -6.7771, -6.7575,  ..., -7.0699, -6.2059, -6.7337]],

        [[-6.9024, -6.7771, -6.7575,  ..., -7.0699, -6.2059, -6.7337]],

        [[-6.9024, -6.7771, -6.7575,  ..., -7.0699, -6.2059, -6.7337]],

        [[-6.9024, -6.7771, -6.7575,  ..., -7.0699, -6.2059, -6.7337]]],
       grad_fn=<GatherBackward>) tensor([[[-6.9170, -6.7934, -6.7733,  ..., -7.1199, -6.1876, -6.8410]],

        [[-6.9155, -6.7748, -6.8467,  ..., -6.9335, -6.4507, -6.2258]],

        [[-6.6585, -6.5278, -6.6587,  ..., -6.7598, -6.3038, -6.3652]],

        [[-6.6641, -6.6089, -6.6145,  ..., -7.1140, -6.7535, -5.7576]]],
       grad_fn=<GatherBackward>)


In [142]:
# result: masked logits; target: unmasked logits; index: golden (one-hot) labels
def index_only_dist(result, target, index):
    """Weighted shortfall of ``result`` below ``target``.

    Computes ``relu(target - result) * index``, sums it over axis 2 and
    averages the sums over axis 1. The inputs are expected to share a
    3-D shape (the consumer passes [B, n, V] tensors), giving a [B] result.
    """
    shortfall = F.relu(target - result)
    per_slot = (shortfall * index).sum(dim=2)
    return per_slot.mean(dim=1)

def kl_divergence_dist(result, target, index):
    """KL divergence of softmax(result) from softmax(target).

    With ``p = softmax(target)`` and ``q = softmax(result)`` over axis 2, this
    returns ``sum(p * (log p - log q))`` over axis 2, averaged over axis 1.
    ``index`` is accepted for signature parity with the other measures and is
    not used.
    """
    log_q = F.log_softmax(result, dim=2)
    log_p = F.log_softmax(target, dim=2)
    p = F.softmax(target, dim=2)
    per_slot = (p * (log_p - log_q)).sum(dim=2)
    return per_slot.mean(dim=1)

def cross_entropy_dist(result, target, index):
    """Cross entropy of softmax(result) relative to softmax(target).

    With ``p = softmax(target)`` and ``q = softmax(result)`` over axis 2, this
    returns ``sum(-p * log q)`` over axis 2, averaged over axis 1. ``index``
    is accepted for signature parity with the other measures and is not used.
    """
    neg_p = -F.softmax(target, dim=2)
    log_q = F.log_softmax(result, dim=2)
    per_slot = (neg_p * log_q).sum(dim=2)
    return per_slot.mean(dim=1)

# Compare the three measures on the tensors left in the kernel by the
# exploratory consume_question_pair cell.
# NOTE(review): the arguments are passed as (u_logits, m_logits, g_logits),
# i.e. result=unmasked and target=masked, which is the reverse of the order
# QuestionPairConsumer uses (`measure(m_logits, u_logits, g_logits)`) —
# confirm this is intended.
print(index_only_dist(u_logits, m_logits, g_logits))
print(kl_divergence_dist(u_logits, m_logits, g_logits))
print(cross_entropy_dist(u_logits, m_logits, g_logits))

tensor([0.0502, 0.0000, 0.0000, 0.2214], grad_fn=<MeanBackward1>)
tensor([0.0186, 0.2816, 0.3189, 0.4473], grad_fn=<MeanBackward1>)
tensor([5.5009, 5.9601, 7.2967, 7.6955], grad_fn=<MeanBackward1>)


In [149]:
class MyDataset(Dataset):
    """Toy dataset: twelve references to the same tensor ``[1, 2, 3, 4]``."""

    def __init__(self):
        shared_row = torch.tensor([1, 2, 3, 4])
        # Every entry is the very same tensor object, not a copy.
        self.data = [shared_row for _ in range(12)]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index]
    
# Scratch check: a Dataset that defines __len__/__getitem__ can be iterated
# directly; each item is printed as returned by __getitem__.
dataset = MyDataset()
for sample in dataset:
    print(sample)

tensor([1, 2, 3, 4])
tensor([1, 2, 3, 4])
tensor([1, 2, 3, 4])
tensor([1, 2, 3, 4])
tensor([1, 2, 3, 4])
tensor([1, 2, 3, 4])
tensor([1, 2, 3, 4])
tensor([1, 2, 3, 4])
tensor([1, 2, 3, 4])
tensor([1, 2, 3, 4])
tensor([1, 2, 3, 4])
tensor([1, 2, 3, 4])


In [151]:
def test_dataloader():
    """Print the first batch a DataLoader (batch_size=4) builds from MyDataset."""
    loader = DataLoader(MyDataset(), batch_size = 4)
    nothing = object()
    first_batch = next(iter(loader), nothing)
    if first_batch is not nothing:
        print(first_batch)

# Run the DataLoader collation check.
test_dataloader()

tensor([[1, 2, 3, 4],
        [1, 2, 3, 4],
        [1, 2, 3, 4],
        [1, 2, 3, 4]])


In [74]:
import numpy as np

In [159]:
# Scratch check: iterating a 2-D tensor yields its rows as tensors, while
# .tolist() converts the whole tensor into nested Python lists.
a = torch.tensor([[0,1,2],[0,3,4]])
for ele in a:
    print(ele)
b = a.tolist()
print(b)

tensor([0, 1, 2])
tensor([0, 3, 4])
[[0, 1, 2], [0, 3, 4]]


In [161]:
# Scratch check: a dict passed to a function is mutated in place by .update(),
# so the caller sees the added key afterwards.
# Note: this rebinds the notebook-global name `a` used by the previous cell.
a = {1:2,3:4}
print(a.values())
def test(d):
    # Adds the entry 5 -> 6 to the caller's dict; returns nothing.
    d.update({5:6})
test(a)
print(a)

dict_values([2, 4])
{1: 2, 3: 4, 5: 6}
