In [None]:
import torch
from util.util import (
    create_model,
    create_gaussian_diffusion
)


In [None]:
def create_model_and_diffusion(
    *,
    model_arch="s2s_CAT",
    learn_sigma=False,
    model_channels=128,
    in_channel=128,
    out_channel=128,
    dropout=0.1,
    vocab_size=30522,
    config_name="bert-base-uncased",
    logits_mode=1,
    init_pretrained=False,
    token_emb_type="random",
    diffusion_steps=2000,
    noise_schedule="sqrt",
    training_mode="s2s",
):
    """Build the denoising model and the Gaussian diffusion that drives it.

    Called with no arguments this reproduces the configuration this notebook
    builds before loading the GENIE checkpoint into the model. ``model_arch``
    and ``learn_sigma`` are needed by both ``create_model`` and
    ``create_gaussian_diffusion``; taking each once guarantees both objects are
    built with the same value.

    Keyword Args:
        model_arch: architecture name, forwarded to both factories.
        learn_sigma: forwarded to both factories.
        model_channels, in_channel, out_channel, dropout, vocab_size,
        config_name, logits_mode, init_pretrained, token_emb_type:
            forwarded unchanged to ``create_model``. ``in_channel`` is expected
            to match the last dimension of the ``input_shape`` handed to the
            sampler (the notebook builds it from its own ``in_channel``).
        diffusion_steps: forwarded to ``create_gaussian_diffusion`` as ``steps``.
        noise_schedule, training_mode: forwarded to ``create_gaussian_diffusion``.

    Returns:
        ``(model, diffusion)`` exactly as returned by the two factories.
    """
    model = create_model(
        model_channels=model_channels,
        learn_sigma=learn_sigma,
        dropout=dropout,
        model_arch=model_arch,
        in_channel=in_channel,
        out_channel=out_channel,
        vocab_size=vocab_size,
        config_name=config_name,
        logits_mode=logits_mode,
        init_pretrained=init_pretrained,
        token_emb_type=token_emb_type,
    )
    diffusion = create_gaussian_diffusion(
        steps=diffusion_steps,
        learn_sigma=learn_sigma,
        sigma_small=False,
        noise_schedule=noise_schedule,
        use_kl=False,
        predict_xstart=False,
        rescale_timesteps=True,
        rescale_learned_sigmas=True,
        model_arch=model_arch,
        training_mode=training_mode,
    )
    return model, diffusion

In [None]:
# Prefer the GPU when one is present. The sampling cells further down call
# .cuda() on their inputs directly, so pinning the model to CPU here leaves it
# on a different device from the tensors it is fed.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model, diffusion = create_model_and_diffusion()
model.to(device)
model.eval()

In [None]:
from Genie_Generate import load_states_from_checkpoint
# Load the saved GENIE checkpoint; only its `.model_dict` is used (next cell).
# NOTE(review): "GENIE_ckpt-500w" is presumably a path relative to the notebook's
# working directory — confirm.
model_saved_state = load_states_from_checkpoint("GENIE_ckpt-500w")

In [None]:
# Copy the checkpoint's model weights into the freshly built model.
model.load_state_dict(model_saved_state.model_dict)

In [None]:
# Sampler used by every generation call below, invoked as
# sample_fn(model, input_shape, clip_denoised=..., denoised_fn=..., ...).
sample_fn = diffusion.p_sample_loop


In [None]:
# The denoiser's word-embedding module; denoised_fn_round snaps sampled vectors
# onto the rows of its `.weight`.
emb_model = model.word_embedding

In [None]:
import os
import numpy as np
from tqdm import tqdm
from functools import partial
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
from data_util.s2s_data_util import S2S_dataset

In [None]:
def denoised_fn_round(model, text_emb, t):
    """Snap every vector in ``text_emb`` onto its nearest vocabulary embedding.

    Passed to the sampler as ``denoised_fn`` (via ``functools.partial`` with the
    embedding module bound). Each vector is replaced by the embedding row of the
    vocabulary token at the smallest L2 distance from it.

    Args:
        model: embedding module exposing ``.weight`` of shape (vocab, dim) and
            callable on a LongTensor of token ids (e.g. ``torch.nn.Embedding``).
        text_emb: tensor whose last dimension is ``dim``; any leading shape is
            accepted and preserved.
        t: timestep argument supplied by the sampler; unused here.

    Returns:
        Tensor with the same shape as ``text_emb``, on ``text_emb``'s device,
        holding the nearest vocabulary embeddings.
    """
    vocab_emb = model.weight  # (vocab, dim)
    old_shape = text_emb.shape
    old_device = text_emb.device

    # Flatten all leading dims so every vector is one column of the distance matrix.
    flat = text_emb.reshape(-1, text_emb.size(-1)).to(vocab_emb.device)  # (N, dim)

    # ||e - x||^2 = ||e||^2 + ||x||^2 - 2 e.x for every (token, vector) pair, as a
    # (vocab, N) matrix; avoids materialising a (vocab, N, dim) difference tensor.
    dist = (
        (vocab_emb ** 2).sum(-1, keepdim=True)
        + (flat ** 2).sum(-1).unsqueeze(0)
        - 2.0 * torch.mm(vocab_emb, flat.t())
    )
    # Cancellation in the expansion can produce tiny negatives; true distances are >= 0.
    dist = torch.clamp(dist, min=0.0)

    rounded_tokens = dist.argmin(dim=0)  # (N,) nearest token id per vector
    return model(rounded_tokens).view(old_shape).to(old_device)

In [None]:

# Scratch cell: a single sampling call outside the generation loop.
# NOTE(review): `input_shape` and `model_kwargs` are only assigned inside the
# generation loop further down, and `interval_step` in the config cell below,
# so this cell raises NameError on a fresh kernel (Restart & Run All).
sample = sample_fn(
    model,
    input_shape,
    clip_denoised=False,
    denoised_fn=partial(denoised_fn_round, emb_model.cuda()),
    model_kwargs=model_kwargs,
    top_p=-1.0,
    interval_step=interval_step,
)

In [None]:
# Tensor.cuda() returns a tensor instead of moving `sample` in place, so a bare
# `sample.cuda()` has no lasting effect; rebind the name to keep the GPU copy.
sample = sample.cuda()

In [None]:
# Scratch: put the model on CUDA to match the .cuda() tensors used by the
# neighbouring exploratory cells.
model.cuda()

In [90]:
# Sanity check: expected to match the `input_shape` passed to sample_fn,
# i.e. (batch, tgt_max_len, in_channel); recorded output is (1, 128, 128).
sample.shape

torch.Size([1, 128, 128])

In [None]:
# Project the sampled embeddings to vocabulary logits. This is inference only,
# so skip recording an autograd graph through the model's output layer.
with torch.no_grad():
    logits = model.get_logits(sample.cuda())

In [95]:
# --- Generation settings ----------------------------------------------------
# Test data is read from <data_path>/<data_name>/org_data/test.{src,tgt}.
data_name = "data"
data_path = "data/exp1/"
# Directory scanned for earlier `*gen_seed*_epoch<N>.txt` files when resuming.
generate_path = "data/outputs/"

# Number of sampling rounds over the test split, and examples per batch.
num_samples = 5
batch_size = 1

# Sequence lengths in tokens. `in_channel` is the width of each sampled
# embedding vector and must match the in_channel given to create_model.
src_max_len = 128
tgt_max_len = 128
in_channel = 128

# Forwarded to sample_fn as `interval_step`.
interval_step = 10_000

In [96]:


# Generation driver (model_arch == 's2s_CAT'): load the test split, then run up
# to `num_samples` sampling rounds over it, resuming after the last round that
# was already saved under `generate_path`.


def _read_stripped_lines(path):
    """Return every line of a UTF-8 text file with surrounding whitespace stripped."""
    with open(path, "r", encoding="utf-8") as ifile:
        return [line.strip() for line in tqdm(ifile)]


def _last_saved_epoch(output_dir):
    """Return the highest N among `*gen_seed*_epoch<N>.txt` files in output_dir, or 0.

    A missing directory, or one with no matching files, means nothing has been
    generated yet.
    """
    if not os.path.isdir(output_dir):
        return 0
    epochs = [
        int(entry.name.split('_epoch')[-1].split('.txt')[0])
        for entry in os.scandir(output_dir)
        # Match on the file name only, so a directory path containing
        # "gen_seed" does not make every file look like an output.
        if entry.is_file() and "gen_seed" in entry.name
    ]
    return max(epochs, default=0)


# bert tokenizer (same vocabulary as config_name in create_model_and_diffusion)
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
print("-------------------------------------------------------------")
print("start generate query from test dataset, for every passage, we generate ", num_samples, " querys...")
print("-------------------------------------------------------------")

print("***** load " + data_name + " test src dataset*****")
test_src_path = os.path.join(data_path, data_name + "/org_data/test.src")
src = _read_stripped_lines(test_src_path)

print("***** load " + data_name + " test tgt dataset*****")
test_tgt_path = os.path.join(data_path, data_name + "/org_data/test.tgt")
tgt = _read_stripped_lines(test_tgt_path)

# Source and target files are paired line by line.
if len(src) != len(tgt):
    raise ValueError("test.src has " + str(len(src)) + " lines but test.tgt has " + str(len(tgt)))

# Single-process run: generate for the whole split (no rank/world_size sharding).
start_idx = 0
end_idx = len(src)
src_data_piece = src[start_idx:end_idx]
tgt_data_piece = tgt[start_idx:end_idx]

print('generation for ', len(src_data_piece), " src text from idx ", start_idx, " to ", end_idx)

test_dataset = S2S_dataset(src_data_piece, tgt_data_piece, tokenizer, src_maxlength=src_max_len,
                           tgt_maxlength=tgt_max_len)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, drop_last=False,
                             num_workers=8, collate_fn=S2S_dataset.get_collate_fn())

if generate_path is None:
    # exit() would shut down the Jupyter kernel; fail this cell instead.
    raise ValueError("generate_path must be set to store generated samples")
os.makedirs(generate_path, exist_ok=True)
# 0 when the directory is new or empty (previously left undefined -> NameError).
epoch_num = _last_saved_epoch(generate_path)
if epoch_num >= num_samples:
    print("all ", num_samples, " epochs already generated under ", generate_path)

# Keep the denoiser, its word-embedding table and all inputs on one device.
model.to(device)
round_to_vocab = partial(denoised_fn_round, emb_model.to(device))

each_sample_list = []
for epoch in range(num_samples - epoch_num):
    epoch_id = epoch + 1 + epoch_num
    each_sample_list = []
    print("-------------------------------------------------------------")
    print("start sample ", epoch_id, " epoch...")
    print("-------------------------------------------------------------")

    for batch in tqdm(test_dataloader):
        # s2s: condition on the source sequence, sample the target embeddings.
        src_input_ids = batch['src_input_ids']
        tgt_input_ids = batch['tgt_input_ids']
        src_attention_mask = batch['src_attention_mask']
        input_shape = (src_input_ids.shape[0], tgt_max_len, in_channel)
        model_kwargs = {'src_input_ids': src_input_ids.to(device),
                        'src_attention_mask': src_attention_mask.to(device)}

        # Inference only: no autograd graph for sampling or the logits projection.
        with torch.no_grad():
            sample = sample_fn(
                model,
                input_shape,
                clip_denoised=False,
                denoised_fn=round_to_vocab,
                model_kwargs=model_kwargs,
                top_p=-1.0,
                interval_step=interval_step,
            )
            print("sample result shape: ", sample.shape)
            print('decoding for e2e... ')
            logits = model.get_logits(sample.to(device))
        # Greedy decode: the single most likely vocabulary id at each position.
        sample_id_list = torch.topk(logits, k=1, dim=-1).indices

        # Decode row by row so batch_size > 1 works too.
        for src_ids, tgt_ids, gen_ids in zip(src_input_ids, tgt_input_ids, sample_id_list):
            sentence = tokenizer.decode(gen_ids.squeeze(-1))
            print("src text: ", tokenizer.decode(src_ids))
            print("tgt text: ", tokenizer.decode(tgt_ids))
            print("generated query: ", sentence)
            each_sample_list.append(sentence)

    # Persist each epoch; the file name matches the `*gen_seed*_epoch<N>.txt`
    # pattern that _last_saved_epoch reads when resuming.
    out_path = os.path.join(generate_path, "rank0_gen_seed_101" + "_num" + str(num_samples) +
                            "_epoch" + str(epoch_id) + ".txt")
    with open(out_path, 'w', encoding="utf-8") as f:
        for sentence in each_sample_list:
            f.write(sentence + '\n')


-------------------------------------------------------------
start generate query from dev dataset, for every passage, we generate  5  querys...
-------------------------------------------------------------
***** load data test src dataset*****


1it [00:00, 11949.58it/s]


***** load data dev tgt dataset*****


1it [00:00, 8439.24it/s]


generation for  1  src text from idx  0  to  1
-------------------------------------------------------------
start sample  1  epoch...
-------------------------------------------------------------


100%|██████████| 1/1 [00:00<00:00,  1.48it/s]


sample result shape:  torch.Size([1, 128, 128])
decoding for e2e... 
src text:  [CLS] a : hello! good morning, sir. how are you today? b : [MASK] [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]
tgt text:  [CLS] i am great! [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PA

100%|██████████| 1/1 [00:00<00:00,  1.39it/s]


sample result shape:  torch.Size([1, 128, 128])
decoding for e2e... 
src text:  [CLS] a : hello! good morning, sir. how are you today? b : [MASK] [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]
tgt text:  [CLS] i am great! [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PA

100%|██████████| 1/1 [00:00<00:00,  1.42it/s]


sample result shape:  torch.Size([1, 128, 128])
decoding for e2e... 
src text:  [CLS] a : hello! good morning, sir. how are you today? b : [MASK] [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]
tgt text:  [CLS] i am great! [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PA

100%|██████████| 1/1 [00:00<00:00,  1.42it/s]


sample result shape:  torch.Size([1, 128, 128])
decoding for e2e... 
src text:  [CLS] a : hello! good morning, sir. how are you today? b : [MASK] [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]
tgt text:  [CLS] i am great! [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PA

100%|██████████| 1/1 [00:00<00:00,  1.36it/s]

sample result shape:  torch.Size([1, 128, 128])
decoding for e2e... 
src text:  [CLS] a : hello! good morning, sir. how are you today? b : [MASK] [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]
tgt text:  [CLS] i am great! [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PA




In [None]:
# Decoded generations collected during the most recent epoch of the generation loop.
each_sample_list

In [None]:
# Scratch: vocabulary logits for the current `sample`; the last dimension indexes
# token ids (the generation loop takes topk over it and decodes the indices).
# NOTE(review): presumably (batch, seq_len, vocab) — displaying the whole tensor is
# very large; consider `.shape` or a small slice instead.
model.get_logits(sample)