In [34]:
import pickle
import re
import nltk
import time
from pytorch_pretrained_bert.tokenization import BertTokenizer
from pytorch_pretrained_bert.modeling import BertPreTrainedModel, BertModel, BertConfig, BertPooler
import torch
from models.bert import Config
from keras.preprocessing.sequence import pad_sequences
from torch.nn.functional import softmax

from tqdm import tqdm
import numpy as np


def merge_mask(src, mask, tokenizer):
    mask_new = []
    i = 0
    for word in src:
        tokenized_word = tokenizer.tokenize(word)
        n_subwords = len(tokenized_word)
        if n_subwords > 1:
            if 1 in mask[i:i + n_subwords]:
                mask_new.append(1)
            else:
                mask_new.append(0)
        else:
            mask_new.append(mask[i])
        i += n_subwords

    return mask_new


def simplify(word):
    new_word = word.lower().replace('( ', '(').replace(' )', ')').replace('  ', ' ').replace(' ,', ',').strip()
    return new_word


def min_distance(str1, str2):
    matrix = [[i + j for j in range(len(str2) + 1)] for i in range(len(str1) + 1)]

    for i in range(1, len(str1) + 1):
        for j in range(1, len(str2) + 1):
            if str1[i - 1] == str2[j - 1]:
                d = 0
            else:
                d = 1
            matrix[i][j] = min(matrix[i - 1][j] + 1, matrix[i][j - 1] + 1, matrix[i - 1][j - 1] + d)

    return matrix[len(str1)][len(str2)]


def is_similar(word_1, word_2, stop=False):
    word_1 = simplify(word_1)
    word_2 = simplify(word_2)
    flag = False
    if ('(' in word_1 or '(' in word_2) and not stop:
        _word_1 = re.sub(u"\\(.*?\\)|\\{.*?}|\\[.*?]", "", word_1)
        _word_2 = re.sub(u"\\(.*?\\)|\\{.*?}|\\[.*?]", "", word_2)
        flag = is_similar(_word_1, _word_2, True)

    return word_1 == word_2 or \
           word_1 in word_2 or \
           word_2 in word_1 or \
           min_distance(word_1, word_2) <= 2 or \
           flag


def get_train_src_tar_txt(train_txt_path):
    src = []
    tar_1 = []
    tar_2 = []
    txt = ''
    try:
        txt += open(train_txt_path, 'r').read()
    except:
        txt += open(train_txt_path, 'r', encoding='utf-8').read()

    txt = txt.split('\n\n')
    for para in txt:
        sentences = para.split('\n')
        if len(sentences) < 2:
            continue
        for sid, sentence in enumerate(sentences[0:3]):
            if sid == 0:
                src.append(sentence)
            elif sid == 1:
                tar_1.append(sentence)
            elif sid == 2:
                tar_2.append(sentence)
    return src, tar_1, tar_2


def get_test_src_tar_txt(test_txt_path):
    txt = open(test_txt_path, 'r').read()
    #     txt = txt.lower()
    txt = txt.split('\n\n')
    src = []
    tar_1 = []
    tar_2 = []
    for para in txt:
        sentences = para.split('\n')
        src_sentence = ''
        if len(sentences) < 2 or len(sentences[0]) < 3 or len(sentences[1]) < 3:
            continue
        for sid, sentence in enumerate(sentences):
            if sid == 0:
                src.append(sentence)
            elif sid <= 2:
                cudic = {}
                sentence = sentence[2:]
                sentence = sentence.replace('].', '] .')
                text = re.sub('\[[^\[\]]*\]', '', sentence)
                pairs = re.findall('[^\[\] ]+\[[^\[\]]+\]', sentence)
                for pair in pairs:
                    pair = re.split('[\[\]]', pair)
                    cudic[pair[0]] = pair[1]
                words = nltk.word_tokenize(text)
                for wid, word in enumerate(words):
                    if word in cudic.keys():
                        words[wid] = cudic[word]
                new_text = ' '.join(words)
                if sid == 1:
                    tar_1.append(new_text)
                else:
                    tar_2.append(new_text)
    return src, tar_1, tar_2


def get_dcmn_data_from_gt(src_words, tar_words, abbrs, max_pad_length, max_dcmn_seq_length, tokenizer):
    if tar_words[-1] != '.':
        tar_words.append('.')
    i = 0
    j = 0
    sentences = []
    labels = []
    srcs = []
    keys = []
    key_ans = {}

    while i < len(src_words):
        if src_words[i] == tar_words[j]:
            i += 1
            j += 1
        else:
            p = i + 1
            q = j + 1

            while p < len(src_words):
                while q < len(tar_words) and tar_words[q] != src_words[p]:
                    q += 1
                if q == len(tar_words):
                    p = p + 1
                    q = j + 1
                else:
                    break
            if p - i == 1:
                pre = src_words[i]
                aft = " ".join(tar_words[j:q])
                if pre in abbrs.keys():
                    pass
                elif pre.upper() in abbrs.keys():
                    pre = pre.upper()
                elif pre.lower() in abbrs.keys():
                    pre = pre.lower()

                if pre in abbrs.keys():
                    temp = [' '.join(src_words), 'what is {} ?'.format(pre)]
                    label = -1
                    skip_cnt = 0
                    for index, u in enumerate(abbrs[pre]):
                        if index - skip_cnt >= max_pad_length - 2:
                            break
                        if len(u.split(' ')) > 10:
                            skip_cnt += 1
                            continue
                        h = u
                        temp.append(h)
                        if is_similar(u, aft):
                            label = index
                    while len(temp) < max_pad_length:
                        temp.append('[PAD]')
                    if len(tokenizer.tokenize(temp[0])) + len(tokenizer.tokenize(temp[1])) + len(
                            tokenizer.tokenize(temp[2])) >= max_dcmn_seq_length \
                            or label < 0 or label >= max_pad_length - 2:
                        pass
                    else:
                        sentences.append(temp)
                        labels.append(label)
                        keys.append(pre)
                        key_ans[pre] = label
                        srcs.append('[CLS] ' + ' '.join(src_words[:i]) + ' [MASK] ' + ' '.join(src_words[p:]))

            i = p
            j = q
    return sentences, labels, srcs, keys, key_ans


def get_dcmn_data_from_step1(src_words, masks, k_a, abbrs, max_pad_length, max_dcmn_seq_length, tokenizer):
    sentences = []
    srcs = []
    keys = []
    labels = []
    for i, mask in enumerate(masks):
        if mask == 0:
            continue
        key = src_words[i]
        if (key in abbrs.keys() and key in k_a.keys() and 0 <= k_a[key] < max_pad_length - 2) or \
                (key in abbrs.keys() and key not in k_a.keys() and len(abbrs[key]) == 1):
            temp = [' '.join(src_words), 'what is {} ?'.format(key)]
            if key in k_a.keys():
                label = k_a[key]
            else:
                label = 0

            skip_cnt = 0
            for index, u in enumerate(abbrs[key]):
                if index - skip_cnt >= max_pad_length - 2:
                    break
                if len(u.split(' ')) > 10:
                    skip_cnt += 1
                    continue
                h = u
                temp.append(h)

            while len(temp) < max_pad_length:
                temp.append('[PAD]')

            if len(tokenizer.tokenize(temp[0])) + len(tokenizer.tokenize(temp[1])) + len(
                    tokenizer.tokenize(temp[2])) >= max_dcmn_seq_length:
                continue
            sentences.append(temp)
            keys.append(key)
            labels.append(label)
            srcs.append('[CLS] ' + ' '.join(src_words[:i]) + ' [SEP] [MASK] [SEP] ' + ' '.join(src_words[i + 1:]))

    return sentences, labels, srcs, keys


def add_sep(train_srcs, train_tars, train_order):
    train_src_new = []
    train_tar_new = []
    i, j = 0, 0
    for src in train_srcs:
        src = src.split(' ')
        src_new = src
        if src_new[-1] != '.':
            src_new.append('.')
        while j >= train_order[i]:
            i += 1
            j = 0
        j += 1
        tar = train_tars[i]
        tar = tar.split(' ')
        tar_new = []
        p = 0

        for i, u in enumerate(src):
            if u == '[MASK]':
                while p < len(tar) and tar[p] != src[i - 1]:
                    tar_new.append(tar[p])
                    p += 1
                if p < len(tar):
                    tar_new.append(tar[p])
                    p += 1
                tar_new.append('[SEP]')
                while p < len(tar) and tar[p] != src[i + 1]:
                    tar_new.append(tar[p])
                    p += 1
                tar_new.append('[SEP]')
        while p < len(tar):
            tar_new.append(tar[p])
            p += 1
        train_src_new.append(src_new)
        train_tar_new.append(tar_new)

    for i, u in enumerate(train_src_new):
        tmp = []
        for j, v in enumerate(u):
            if v == '[MASK]':
                tmp.append('[SEP]')
                tmp.append('[MASK]')
                tmp.append('[SEP]')
            else:
                tmp.append(v)
        train_src_new[i] = tmp
    train_input = []
    for u in train_src_new:
        train_input.append(' '.join(u))
    train_output = []
    for u in train_tar_new:
        train_output.append(' '.join(u))

    return train_input, train_output


def get_embs(dcmn_keys, abbrs, max_pad_length):
    device = torch.device('cuda')
    bert_model = 'bert-base-cased'
    bert = BertModel.from_pretrained(bert_model)
    bert.to(device)
    tokenizer = BertTokenizer.from_pretrained(bert_model, do_lower_case=False)
    pad_tokens = ['[PAD]']
    ids = tokenizer.convert_tokens_to_ids(pad_tokens)
    inputs = [ids]
    inputs = torch.tensor(inputs)
    inputs = inputs.to(device)
    with torch.no_grad():
        _, pad_embs = bert(inputs)
    pad_embs = pad_embs.cpu().detach().numpy()[0]
    dcmn_embs = []
    for key in tqdm(dcmn_keys):
        emb_values = []
        skip_cnt = 0
        for i, value in enumerate(abbrs[key]):
            if len(value.split(' ')) > 10:
                skip_cnt += 1
                continue
            if i - skip_cnt >= max_pad_length - 2:
                break
            tokens = tokenizer.tokenize(key)
            ids = tokenizer.convert_tokens_to_ids(tokens)
            inputs = [ids]
            inputs = torch.tensor(inputs)
            inputs = inputs.to(device)
            with torch.no_grad():
                _, embs = bert(inputs)
                emb_values.append(embs.cpu().detach().numpy()[0])
        while len(emb_values) < max_pad_length - 2:
            emb_values.append(pad_embs)
        dcmn_embs.append(emb_values)
    return dcmn_embs


def seq_tokenize(input_data, tokenizer, max_seq_length):
    ids = []
    for data in tqdm(input_data):
        words = tokenizer.tokenize(data)
        ids.append(words)

    ids = pad_sequences([tokenizer.convert_tokens_to_ids(txt) for txt in ids],
                        maxlen=max_seq_length, dtype="long", value=0,
                        truncating="post", padding="post")
    masks = [[float(i != 0.0) for i in ii] for ii in ids]
    return ids, masks


def get_index(srcs):
    tar_indexs = []
    for i, src in enumerate(srcs):
        for j, u in enumerate(src):
            if u == 4:
                tar_indexs.append([0] * j + [1] + [0] * (len(src) - j - 1))
                break
    return tar_indexs


class DataGenerator():
    def __init__(self, seq_batch_size, max_pad_length=16, max_seq_length=64, cuda=True, emb_size=768):
        if cuda:
            self.device = torch.device('cuda')
        else:
            self.device = torch.device('cpu')
        self.emb_size = emb_size

        self.abbrs_path = './data/abbrs-all-cased.pkl'
        self.train_txt_path = './data/train(12809).txt'
        self.test_txt_path = './data/test(2030).txt'
        with open(self.abbrs_path, 'rb') as f:
            self.abbrs = pickle.load(f)
        self.train_src_txt, self.train_tar_1_txt, self.train_tar_2_txt = get_train_src_tar_txt(self.train_txt_path)
        self.train_src_txt = self.train_src_txt
        self.train_tar_1_txt = self.train_tar_1_txt
        self.train_tar_2_txt = self.train_tar_2_txt

        self.test_src_txt, self.test_tar_1_txt, self.test_tar_2_txt = get_test_src_tar_txt(self.test_txt_path)

        # generate data
        self.train_seq_srcs = []
        self.train_dcmn_srcs = []
        self.train_dcmn_labels = []
        self.train_keys = []
        self.train_order = []
        self.test_seq_srcs = []
        self.test_dcmn_srcs = []
        self.test_dcmn_labels = []
        self.test_keys = []
        self.test_order = []
        tokenizer = BertTokenizer.from_pretrained('bert-base-cased', do_lower_case=False)

        for i, (src, tar) in enumerate(zip(self.train_src_txt, self.train_tar_1_txt)):
            src = nltk.word_tokenize(src)
            tar = nltk.word_tokenize(tar)
            sentences, labels, srcs, keys, key_ans = get_dcmn_data_from_gt(src, tar, self.abbrs,
                                                                           max_pad_length=max_pad_length,
                                                                           max_dcmn_seq_length=max_seq_length,
                                                                           tokenizer=tokenizer)
            self.train_dcmn_srcs.extend(sentences)
            self.train_dcmn_labels.extend(labels)
            self.train_seq_srcs.extend(srcs)
            self.train_keys.extend(keys)
            self.train_order.append(len(sentences))

        with open('./data/test_mask_step2_2030.pkl', 'rb') as f:
            test_mask_step2 = pickle.load(f)
        test_mask = []

        for src, mask in zip(self.test_src_txt, test_mask_step2):
            src = nltk.word_tokenize(src)
            mask_new = merge_mask(src, mask, tokenizer)
            test_mask.append(mask_new)

        k_a = []
        for i, (src, tar) in enumerate(zip(self.test_src_txt, self.test_tar_1_txt)):
            src = nltk.word_tokenize(src)
            tar = nltk.word_tokenize(tar)
            sentences, labels, src, keys, key_ans = get_dcmn_data_from_gt(src, tar, self.abbrs,
                                                                          max_pad_length=max_pad_length,
                                                                          max_dcmn_seq_length=max_seq_length,
                                                                          tokenizer=tokenizer)
            k_a.append(key_ans)

        for i, (sts, masks, k_a) in enumerate(zip(self.test_src_txt, test_mask, k_a)):
            sts = nltk.word_tokenize(sts)
            sentences, labels, srcs, keys = get_dcmn_data_from_step1(sts, masks, k_a, self.abbrs,
                                                                     max_pad_length=max_pad_length,
                                                                     max_dcmn_seq_length=max_seq_length,
                                                                     tokenizer=tokenizer)
            self.test_order.append(len(sentences))
            self.test_keys.extend(keys)
            self.test_dcmn_srcs.extend(sentences)
            self.test_dcmn_labels.extend(labels)
            self.test_seq_srcs.extend(srcs)

        self.train_seq_srcs, self.train_tar_2_txt = add_sep(train_srcs=self.train_seq_srcs,
                                                            train_tars=self.train_tar_2_txt,
                                                            train_order=self.train_order)
        
        # self.train_embs = get_embs(self.train_keys, self.abbrs, max_pad_length)
        # self.test_embs = get_embs(self.test_keys, self.abbrs, max_pad_length)
        # with open('./data/train_embs.pkl', 'wb') as f:
        #     pickle.dump(self.train_embs, f)
        # with open('./data/test_embs.pkl', 'wb') as f:
        #     pickle.dump(self.test_embs, f)

        with open('./data/train_embs.pkl', 'rb') as f:
            self.train_embs = pickle.load(f)
        with open('./data/test_embs.pkl', 'rb') as f:
            self.test_embs = pickle.load(f)

        seq_config = Config(seq_batch_size)
        seq_tokenizer = seq_config.tokenizer
        self.train_seq_srcs_ids, self.train_seq_srcs_masks = seq_tokenize(self.train_seq_srcs, seq_tokenizer,
                                                                          max_seq_length)
        self.train_seq_tars_ids, self.train_seq_tars_masks = seq_tokenize(self.train_tar_2_txt, seq_tokenizer,
                                                                          max_seq_length)
        self.test_seq_srcs_ids, self.test_seq_srcs_masks = seq_tokenize(self.test_seq_srcs, seq_tokenizer,
                                                                        max_seq_length)
        self.cudics = pickle.load(open('./data/test_cudics.pkl', 'rb'))
        self.seq_test_tars = pickle.load(open('./data/test_tars.pkl', 'rb'))

        self.train_indices = get_index(self.train_seq_srcs_ids)
        self.test_indices = get_index(self.test_seq_srcs_ids)



    def build_dataset_eval(self):
        token_ids_srcs = self.test_seq_srcs_ids
        seq_len_src = 64
        mask_srcs = self.test_seq_srcs_masks

        cudics = self.cudics
        test_tars = self.seq_test_tars
        test_data = []
        for token_ids_src, mask_src, test_tar, cudic in zip(token_ids_srcs, mask_srcs, test_tars, cudics):
            tars = test_tar
            test_data.append((token_ids_src, int(0), seq_len_src, mask_src, tars, cudic))
        return test_data



    

In [35]:
dg = DataGenerator(1)

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15166/15166 [00:03<00:00, 4589.05it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15166/15166 [00:04<00:00, 3416.03it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2361/2361 [00:00<00:00, 4440.41it/s]


In [36]:
print(len(dg.train_seq_srcs_ids))
print(len(dg.test_seq_srcs_ids))
print(len(dg.train_indices))
# print(len())
# print(len())
# print(len())


15166
2361
15166


In [37]:

def parse_mc(input_file, answer_file, max_pad_length, dg):
    if 'dev' in input_file:
        _sentences = dg.test_dcmn_srcs
        _labels = dg.test_dcmn_labels
        key_embs = dg.test_embs
    else:
        _sentences = dg.train_dcmn_srcs
        _labels = dg.train_dcmn_labels
        key_embs = dg.train_embs

    sentences = _sentences
    labels= _labels

    q_id = [i+1 for i in range(len(labels))]
    article = [u[0] for u in sentences]
    question = [u[1] for u in sentences]
    cts = []
    for i in range(max_pad_length-2):
        cts.append([u[i+2] for u in sentences])
    y = labels
    if 'dev' in input_file:
        src_ids = dg.test_seq_srcs_ids
        src_masks = dg.test_seq_srcs_masks
        tar_ids = [None for _ in range(len(q_id))]
        tar_masks = [None for _ in range(len(q_id))]
        indices = dg.test_indices
        tars = dg.seq_test_tars
        cudics = dg.cudics
    else:
        src_ids = dg.train_seq_srcs_ids
        src_masks = dg.train_seq_srcs_masks
        tar_ids = dg.train_seq_tars_ids
        tar_masks = dg.train_seq_tars_masks
        indices = dg.train_indices
        tars = [None for _ in range(len(q_id))]
        cudics = [None for _ in range(len(q_id))]

    return article, question, cts, key_embs, y, q_id, \
            src_ids, src_masks, indices, tar_ids, tar_masks, tars, cudics


In [38]:

inputs = parse_mc('dev', None, 16, dg)


In [39]:
for u in inputs:
    print(len(u))

2361
2361
14
2361
2361
2361
2361
2361
2361
2361
2361
2031
2031


In [33]:
seq_config = Config(1)
seq_tokenizer = seq_config.tokenizer
for u in dg.test_seq_srcs_ids:
    print(seq_tokenizer.convert_ids_to_tokens(u))

['[', 'c', 'l', 's', ']', 'w', 'e', 'a', 'n', 'e', 'd', 'o', 'f', 'f', 'v', 'e', 'n', 't', 't', 'o', '[', 's', 'e', 'p', ']', '[', 'm', 'a', 's', 'k', ']', '[', 's', 'e', 'p', ']', 'a', 'n', 'd', 'w', 'a', 's', 'e', 'x', 't', 'u', 'b', 'a', 't', 'e', 'd', 'i', 'n', 't', 'h', 'e', 'a', 'f', 't', 'e', 'r', 'n', 'o', 'o']
['[', 'c', 'l', 's', ']', 'w', 'e', 'a', 'n', 'e', 'd', 'o', 'f', 'f', 'v', 'e', 'n', 't', 't', 'o', 'c', 'p', 'a', 'p', 'a', 'n', 'd', 'w', 'a', 's', 'e', 'x', 't', 'u', 'b', 'a', 't', 'e', 'd', 'i', 'n', 't', 'h', 'e', 'a', 'f', 't', 'e', 'r', 'n', 'o', 'o', 'n', 'o', 'n', '9', '-', '2', 'b', 'y', 't', 'h', 'e', '[']
['[', 'c', 'l', 's', ']', 's', 'h', 'e', 'w', 'a', 's', 'i', 'n', 't', 'u', 'b', 'a', 't', 'e', 'd', 'a', 'n', 'd', 'w', 'a', 's', 'r', 'e', 's', 'u', 's', 'c', 'i', 't', 'a', 't', 'e', 'd', 'a', 'f', 't', 'e', 'r', '1', '0', '-', '2', '2', 'm', 'i', 'n', 'u', 't', 'e', 's', 'o', 'f', '[', 's', 'e', 'p', ']', '[', 'm']
['[', 'c', 'l', 's', ']', 'h', 'a', '

['[', 'c', 'l', 's', ']', 'h', 'e', 'w', 'a', 's', 'c', 'o', 'n', 't', 'i', 'n', 'u', 'e', 'd', 'o', 'n', 'a', '[', 's', 'e', 'p', ']', '[', 'm', 'a', 's', 'k', ']', '[', 's', 'e', 'p', ']', 'a', 's', 'a', 'n', 'i', 'n', 'p', 'a', 't', 'i', 'e', 'n', 't', '.', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]']
['[', 'c', 'l', 's', ']', 's', 'h', 'e', 'r', 'e', 'q', 'u', 'i', 'r', 'e', 'd', 'b', 'l', 'o', 'o', 'd', 't', 'r', 'a', 'n', 's', 'f', 'u', 's', 'i', 'o', 'n', 'f', 'o', 'r', 'd', 'r', 'o', 'p', 'p', 'i', 'n', 'g', '[', 's', 'e', 'p', ']', '[', 'm', 'a', 's', 'k', ']', '[', 's', 'e', 'p', ']', 'f', 'r', 'o', 'm', '3']
['[', 'c', 'l', 's', ']', 's', 'h', 'e', 'r', 'e', 'q', 'u', 'i', 'r', 'e', 'd', 'b', 'l', 'o', 'o', 'd', 't', 'r', 'a', 'n', 's', 'f', 'u', 's', 'i', 'o', 'n', 'f', 'o', 'r', 'd', 'r', 'o', 'p', 'p', 'i', 'n', 'g', 'h', 'c', 't', 'f', 'r', 'o', 'm', '3', '0', 't', 'o', '2', '0', 'w', 'i', 't', 'h', '[', 's',

In [8]:
print(len(dg.train_dcmn_srcs))
print(len(dg.test_dcmn_srcs))
tot = 0
for u in dg.train_order:
    if u == 0:
        tot += 1
print(tot, len(dg.train_seq_srcs_ids)-tot, tot/len(dg.train_seq_srcs_ids))
tot = 0
for u in dg.test_order:
    if u == 0:
        tot += 1
print(tot, len(dg.test_seq_srcs_ids)-tot, tot/len(dg.test_seq_srcs_ids))

15166
2361
4300 10866 0.28352894632731107
664 1697 0.2812367640830157


In [59]:
a = torch.IntTensor([[1,2,3,4],[5,6,7,8]]) # batch_size * num_choices
print(a)
print(a.size())
a = a.unsqueeze(-1).expand(2,4,8)
print(a)
print(a.size())
b = torch.zeros(a.size())
print(b)

tensor([[1, 2, 3, 4],
        [5, 6, 7, 8]], dtype=torch.int32)
torch.Size([2, 4])
tensor([[[1, 1, 1, 1, 1, 1, 1, 1],
         [2, 2, 2, 2, 2, 2, 2, 2],
         [3, 3, 3, 3, 3, 3, 3, 3],
         [4, 4, 4, 4, 4, 4, 4, 4]],

        [[5, 5, 5, 5, 5, 5, 5, 5],
         [6, 6, 6, 6, 6, 6, 6, 6],
         [7, 7, 7, 7, 7, 7, 7, 7],
         [8, 8, 8, 8, 8, 8, 8, 8]]], dtype=torch.int32)
torch.Size([2, 4, 8])
tensor([[[0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.]]])


In [23]:
b = torch.IntTensor([[1,1,1,1],[1,1,1,1]])
b = b.unsqueeze(-1).expand(2,4,8)*2
print(b)

tensor([[[2, 2, 2, 2, 2, 2, 2, 2],
         [2, 2, 2, 2, 2, 2, 2, 2],
         [2, 2, 2, 2, 2, 2, 2, 2],
         [2, 2, 2, 2, 2, 2, 2, 2]],

        [[2, 2, 2, 2, 2, 2, 2, 2],
         [2, 2, 2, 2, 2, 2, 2, 2],
         [2, 2, 2, 2, 2, 2, 2, 2],
         [2, 2, 2, 2, 2, 2, 2, 2]]], dtype=torch.int32)


In [28]:
c = a*b
print(c)
d = torch.sum(c, dim = 1)
print(d)

tensor([[[ 2,  2,  2,  2,  2,  2,  2,  2],
         [ 4,  4,  4,  4,  4,  4,  4,  4],
         [ 6,  6,  6,  6,  6,  6,  6,  6],
         [ 8,  8,  8,  8,  8,  8,  8,  8]],

        [[10, 10, 10, 10, 10, 10, 10, 10],
         [12, 12, 12, 12, 12, 12, 12, 12],
         [14, 14, 14, 14, 14, 14, 14, 14],
         [16, 16, 16, 16, 16, 16, 16, 16]]], dtype=torch.int32)
tensor([[20, 20, 20, 20, 20, 20, 20, 20],
        [52, 52, 52, 52, 52, 52, 52, 52]])


In [29]:
e = torch.IntTensor([[1,2,3,4,5,6,7,8]])
arr = [e]
arr.extend(d)

In [30]:
print(arr)

[tensor([[1, 2, 3, 4, 5, 6, 7, 8]], dtype=torch.int32), tensor([20, 20, 20, 20, 20, 20, 20, 20]), tensor([52, 52, 52, 52, 52, 52, 52, 52])]


In [33]:
f= e*0
print(f)

tensor([[0, 0, 0, 0, 0, 0, 0, 0]], dtype=torch.int32)


In [55]:
with open('./data/test_embs.pkl','rb') as f:
    embs = pickle.load(f)

In [58]:

print(np.shape(embs))


(2140, 14, 768)


In [68]:
a = torch.Tensor([[1,0,0],[0,1,0],[0,0,0],[0,0,1]])
b = torch.Tensor([[1,2,3],[4,5,6],[7,8,9]])
print(a.size())
print(b.size())
print(torch.mm(a,b))

torch.Size([4, 3])
torch.Size([3, 3])
tensor([[1., 2., 3.],
        [4., 5., 6.],
        [0., 0., 0.],
        [7., 8., 9.]])


In [76]:

indices = torch.Tensor([[1,0,0],[0,1,0], [0,0,1]])
indices = torch.unsqueeze(indices, 2)
print(indices)
replaced_embeddings = torch.normal(0, 1, [3, 1, 100], requires_grad=True)
print(replace)
encoded_emebddings = torch.normal(0, 1, [3, 3, 100], requires_grad=True)
combined_embeddings = replaced_embeddings*indices+encoded_emebddings
y = (torch.mean(combined_embeddings)-1) ** 2
y.backward()
print(replaced_embeddings.grad)

tensor([[[1.],
         [0.],
         [0.]],

        [[0.],
         [1.],
         [0.]],

        [[0.],
         [0.],
         [1.]]])
tensor([[[-0.0023, -0.0023, -0.0023, -0.0023, -0.0023, -0.0023, -0.0023,
          -0.0023, -0.0023, -0.0023, -0.0023, -0.0023, -0.0023, -0.0023,
          -0.0023, -0.0023, -0.0023, -0.0023, -0.0023, -0.0023, -0.0023,
          -0.0023, -0.0023, -0.0023, -0.0023, -0.0023, -0.0023, -0.0023,
          -0.0023, -0.0023, -0.0023, -0.0023, -0.0023, -0.0023, -0.0023,
          -0.0023, -0.0023, -0.0023, -0.0023, -0.0023, -0.0023, -0.0023,
          -0.0023, -0.0023, -0.0023, -0.0023, -0.0023, -0.0023, -0.0023,
          -0.0023, -0.0023, -0.0023, -0.0023, -0.0023, -0.0023, -0.0023,
          -0.0023, -0.0023, -0.0023, -0.0023, -0.0023, -0.0023, -0.0023,
          -0.0023, -0.0023, -0.0023, -0.0023, -0.0023, -0.0023, -0.0023,
          -0.0023, -0.0023, -0.0023, -0.0023, -0.0023, -0.0023, -0.0023,
          -0.0023, -0.0023, -0.0023, -0.0023, -0.0023, -

In [45]:
a = torch.randint(2,(4,2,3)).long()
b = torch.rand(4,3,4)
print(a)
print(torch.matmul(a,b).size())

tensor([[[0, 1, 0],
         [0, 1, 1]],

        [[0, 1, 1],
         [0, 0, 1]],

        [[1, 0, 0],
         [1, 1, 0]],

        [[1, 0, 0],
         [0, 1, 0]]])


RuntimeError: expected scalar type Long but found Float