In [None]:
import torch
import numpy as np
import pandas as pd
import torch.nn as nn 
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel
from utils.manager import Manager
from utils.util import load_pickle, save_pickle

In [None]:
# Experiment settings for this notebook. Manager receives the class object
# itself (no instance is created); manager.prepare() then returns the loaders
# used by the cells below.
class config:
    # --- run / hardware ---
    epochs = 10
    # "demo" matches the "MINDdemo_*" cache directories loaded further down.
    scale = "demo"
    mode = "train"
    device = 0  # presumably a CUDA device index — TODO confirm in Manager
    batch_size = 2
    batch_size_encode = 2  # presumably the batch size for encode-only loaders — confirm
    dropout_p = 0.1
    seed = 3407  # presumably seeds RNGs inside Manager — confirm
    world_size = 1

    # --- data locations; these are relative paths, so they depend on the
    # kernel's working directory ---
    data_root = "../../../Data"
    cache_root = "data/cache"
    # Loaders to build; later cells read loaders["train"], ["dev"] and ["news"].
    dataloaders = ["train", "dev", "news", "history"]

    # --- sampling sizes (exact semantics live in Manager/dataset code — TODO document) ---
    his_size = 50
    impr_size = 20
    negative_num = 4

    # --- text lengths; presumably max_* are raw truncation limits and the
    # others the lengths fed to the model — confirm ---
    max_title_length = 64
    max_abs_length = 256
    title_length = 32
    abs_length = 64

    # Pretrained LM name; the tokenizer cell loads from manager.plm_dir,
    # presumably derived from this value.
    plm = "bert"

    # News text fields to use, and the gating variant (semantics defined
    # outside this notebook — TODO document).
    enable_fields = ["title", "abs"]
    enable_gate = "weight"

    # --- single-process, non-distributed run ---
    rank = 0
    verbose = None
    distributed = False
    debug = False

manager = Manager(config, notebook=True)
# Later cells index this by loader name, so it is presumably a dict-like mapping.
loaders = manager.prepare()

In [None]:
# Tokenizer for the pretrained LM whose files Manager resolved (manager.plm_dir).
# Later cells use `t` to turn token ids back into readable tokens/text.
t = AutoTokenizer.from_pretrained(manager.plm_dir)

In [4]:
# Pull the train / dev / news loaders (and their datasets) out of `loaders`,
# then draw one batch from each for the inspection cells below.
loader_train, loader_dev, loader_news = (loaders[key] for key in ("train", "dev", "news"))
dataset_train, dataset_dev, dataset_news = (
    loader_train.dataset,
    loader_dev.dataset,
    loader_news.dataset,
)

# One iterator per loader; x / x2 / x3 hold the first train / dev / news batch.
X1, X2, X3 = iter(loader_train), iter(loader_dev), iter(loader_news)
x, x2, x3 = next(X1), next(X2), next(X3)

In [13]:
# Probe: broadcast the row vector [1, 0, 0] against a per-history-news flag
# "gate-mask sum < 16". Flagged rows show [1, 0, 0]; all others show [0, 0, 0].
# NOTE(review): the threshold 16 is not explained here — TODO document why 16.
b = torch.zeros(3)
b[0] = 1
b * x["his_gate_mask"].sum(-1, keepdim=True).lt(16)

tensor([[[0., 0., 0.],
         [1., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [1., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [1., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [1., 0., 0.],
         [1., 0., 0.],
         [1., 0., 0.],
         [1., 0., 0.],
         [1., 0., 0.],
         [1., 0., 0.],
         [1., 0., 0.],
         [1., 0., 0.],
         [1., 0., 0.],
         [1., 0., 0.],
         [1., 0., 0.],
         [1., 0., 0.],
         [1., 0., 0.],
         [1., 0., 0.],
         [1., 0., 0.],
         [1

In [9]:
# Inspect the raw candidate gate mask of the first training batch (x);
# the output below shows 0/1 values.
x["cdd_gate_mask"]

tensor([[[0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1,
          1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 0, 1, 0,
          0, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0],
         [0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0,
          1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0],
         [0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0,
          0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
      

In [None]:
# Check gate mask: for one candidate news and one history news, print each token
# next to its attention-mask (a) and gate-mask (g) value, stopping at the first [PAD].
index = (0, 0)  # presumably (sample in batch, news position) — TODO confirm
cdd_token_id = x['cdd_token_id'][index]
cdd_attn_mask = x["cdd_attn_mask"][index]
cdd_gate_mask = x["cdd_gate_mask"][index]
his_token_id = x["his_token_id"][index]
his_attn_mask = x["his_attn_mask"][index]
his_gate_mask = x["his_gate_mask"][index]

cdd_token = t.convert_ids_to_tokens(cdd_token_id)
his_token = t.convert_ids_to_tokens(his_token_id)


def print_token_masks(tokens, attn_mask, gate_mask):
    """Print one row per token: token text, attention-mask value, gate-mask value.

    Stops after the first "[PAD]" row. Iterating with zip (instead of
    range(manager.sequence_length)) can never index past the end of a sequence
    and never silently truncates one.
    """
    print("{:15} a g".format(" " * 15))
    # .tolist() yields plain Python ints for the mask columns.
    for token, attn, gate in zip(tokens, attn_mask.tolist(), gate_mask.tolist()):
        print("{:15} {} {}".format(token, attn, gate))
        if token == "[PAD]":
            break


print_token_masks(cdd_token, cdd_attn_mask, cdd_gate_mask)
print()
# The history tensors were previously extracted but never shown.
print_token_masks(his_token, his_attn_mask, his_gate_mask)

In [None]:
# Check train loader result: map the first train batch's user/news indices back
# to raw MIND ids so they can be compared against behaviors.tsv and news.tsv.
# Cache paths are built from config (relative to the working directory, like
# config.cache_root itself) rather than a hard-coded absolute home-directory path.
nid2index = load_pickle(f"{config.cache_root}/MIND/MIND{config.scale}_train/news/nid2index.pkl")
uid2index = load_pickle(f"{config.cache_root}/MIND/uid2index.pkl")
nindex2id = {v: k for k, v in nid2index.items()}
uindex2id = {v: k for k, v in uid2index.items()}

# check behaviors.tsv
# NOTE(review): the +1 presumably converts a 0-based impression index to the
# impression-ID column of behaviors.tsv — confirm against the dataset code.
print([uindex2id[i] for i in x["user_index"].tolist()], (x["impr_index"] + 1).tolist())
# check news.tsv (first 5 candidate / history news of the first sample).
# Indices with no entry in nid2index (e.g. padding, if any) are shown as
# "<missing:i>" instead of aborting the whole check with a KeyError.
print([nindex2id.get(i, f"<missing:{i}>") for i in x["cdd_idx"][0][:5].tolist()])
print(t.batch_decode(x["cdd_token_id"][0][:5], skip_special_tokens=True))

print([nindex2id.get(i, f"<missing:{i}>") for i in x["his_idx"][0][:5].tolist()])
print(t.batch_decode(x["his_token_id"][0][:5], skip_special_tokens=True))

In [None]:
# Check dev loader result: map the first dev batch's user/news indices back to
# raw MIND ids so they can be compared against behaviors.tsv and news.tsv.
# Cache paths are built from config (relative to the working directory, like
# config.cache_root itself) rather than a hard-coded absolute home-directory path.
nid2index = load_pickle(f"{config.cache_root}/MIND/MIND{config.scale}_dev/news/nid2index.pkl")
uid2index = load_pickle(f"{config.cache_root}/MIND/uid2index.pkl")
nindex2id = {v: k for k, v in nid2index.items()}
uindex2id = {v: k for k, v in uid2index.items()}

# check behaviors.tsv
# NOTE(review): the +1 presumably converts a 0-based impression index to the
# impression-ID column of behaviors.tsv — confirm against the dataset code.
print([uindex2id[i] for i in x2["user_index"].tolist()], (x2["impr_index"] + 1).tolist())
# check news.tsv (first 5 candidate / history news of the first sample).
# Indices with no entry in nid2index (e.g. padding, if any) are shown as
# "<missing:i>" instead of aborting the whole check with a KeyError.
print([nindex2id.get(i, f"<missing:{i}>") for i in x2["cdd_idx"][0][:5].tolist()])
print(t.batch_decode(x2["cdd_token_id"][0][:5], skip_special_tokens=True))

print([nindex2id.get(i, f"<missing:{i}>") for i in x2["his_idx"][0][:5].tolist()])
print(t.batch_decode(x2["his_token_id"][0][:5], skip_special_tokens=True))