In [2]:
import torch
import numpy as np
import pandas as pd
import torch.nn as nn 
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel
from utils.manager import Manager
from utils.util import load_pickle, save_pickle

In [3]:
class config:
    # Plain attribute container handed to Manager at the bottom of this cell.
    # NOTE(review): the code that consumes these values is not visible in this
    # notebook; any meaning noted below is inferred from names — confirm
    # against utils.manager before relying on it.
    epochs = 10
    scale = "demo"  # presumably selects the dataset variant; the loaders log caches named "MINDdemo_*"
    mode = "train"
    device = 0
    batch_size = 2
    batch_size_encode = 2
    dropout_p = 0.1
    seed = 3407
    world_size = 1

    # Locations are relative paths; what they are resolved against is decided
    # by the project code — TODO confirm.
    data_root = "../../../Data"
    cache_root = "data/cache"
    # Names of the loaders requested from the Manager; later cells index the
    # prepared result with "train", "dev" and "news".
    dataloaders = ["train", "dev", "news", "history"]

    his_size = 50
    impr_size = 20
    negative_num = 4

    max_title_length = 64
    max_abs_length = 256
    # 32 + 64 = 96, which equals the 'sequence_length' the Manager logs —
    # presumably the two enabled fields are concatenated.
    title_length = 32
    abs_length = 64

    plm = "bert"

    enable_fields = ["title", "abs"]
    enable_gate = "weight"

    rank = 0
    verbose = None
    distributed = False
    debug = False

# Build the project Manager from the config class (notebook=True) and let it
# prepare the loaders; later cells look them up by name.
manager = Manager(config, notebook=True)
loaders = manager.prepare()

[2022-02-11 03:13:23,868] INFO (Manager) Hyper Parameters are:
{'scale': 'demo', 'batch_size': 2, 'batch_size_encode': 2, 'dropout_p': 0.1, 'seed': 3407, 'dataloaders': ['train', 'dev', 'news', 'history'], 'his_size': 50, 'impr_size': 20, 'negative_num': 4, 'title_length': 32, 'abs_length': 64, 'plm': 'bert', 'enable_fields': ['title', 'abs'], 'enable_gate': 'weight', 'verbose': None, 'sequence_length': 96}

[2022-02-11 03:13:23,872] INFO (MIND_Train) Loading Cache at MINDdemo_train
[2022-02-11 03:13:25,412] INFO (MIND_Dev) Loading Cache at MINDdemo_dev
[2022-02-11 03:13:26,751] INFO (MIND_News) Loading Cache at MINDdemo_dev


In [4]:
# Tokenizer loaded from manager.plm_dir; the cells below use it to turn token
# ids from the loaders back into readable tokens/text.
t = AutoTokenizer.from_pretrained(manager.plm_dir)
# Model loading is left disabled; none of the visible checks use a model.
# m = AutoModel.from_pretrained(manager.plm_dir).to(0)

In [6]:
# Pull out the three loaders that the checks below inspect.
loader_train, loader_dev, loader_news = (
    loaders[split] for split in ("train", "dev", "news")
)

# Underlying datasets, kept at hand for ad-hoc inspection.
dataset_train, dataset_dev, dataset_news = (
    loader_train.dataset,
    loader_dev.dataset,
    loader_news.dataset,
)

# One iterator per loader, then the first batch of each:
# x <- train, x2 <- dev, x3 <- news.
X1, X2, X3 = iter(loader_train), iter(loader_dev), iter(loader_news)
x, x2, x3 = next(X1), next(X2), next(X3)

In [7]:
# Decode the cdd_token_id rows of the first sample in the train batch back to
# text; special tokens are kept, so the [CLS]/[SEP]/[PAD] layout is visible.
t.batch_decode(x["cdd_token_id"][0])

["[CLS] what you need to know about the c8 corvette's new dual - clutch transmission [SEP] the new corvette has an eight - speed tremec dct. we weren't crazy about it in the pre - production c8 we drove, but engineers tell us the final version will be better. [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]",
 '[CLS] halle berry shares photo of 6 - pack abs on instagram and her fans are freaking out [SEP] the 53 - year - old star worked hard for those washboard abs while training for her new movie, " bruised. " [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]',
 "[CLS] usa to

In [5]:
# Gate-mask sanity check: print every token of one news item from the train
# batch next to its attention-mask value (a) and gate-mask value (g).
index = (0, 0)  # presumably (sample in batch, news item within the sample)
cdd_token_id, cdd_attn_mask, cdd_gate_mask = (
    x[field][index] for field in ("cdd_token_id", "cdd_attn_mask", "cdd_gate_mask")
)
his_token_id, his_attn_mask, his_gate_mask = (
    x[field][index] for field in ("his_token_id", "his_attn_mask", "his_gate_mask")
)

# Map ids back to word pieces so the masks can be read per token.
cdd_token, his_token = (
    t.convert_ids_to_tokens(ids) for ids in (cdd_token_id, his_token_id)
)

# Header row: a blank 15-character token column followed by the mask labels.
line = " " * 15 + " a g"
print(line)
for i in range(manager.sequence_length):
    token = cdd_token[i]
    line = f"{token:15} {cdd_attn_mask[i]} {cdd_gate_mask[i]}"
    print(line)
    # The first padding row is still printed; everything after it is skipped.
    if token == "[PAD]":
        break

                a g
[CLS]           1 0
what            1 1
you             1 1
need            1 1
to              1 1
know            1 1
about           1 1
the             1 1
c               1 1
##8             1 1
corvette        1 1
'               1 1
s               1 1
new             1 1
dual            1 1
-               1 1
clutch          1 1
transmission    1 1
[SEP]           1 0
the             1 0
new             1 0
corvette        1 0
has             1 1
an              1 1
eight           1 1
-               1 0
speed           1 1
tre             1 1
##me            1 1
##c             1 1
dc              1 1
##t             1 1
.               1 1
we              1 1
weren           1 1
'               1 0
t               1 1
crazy           1 1
about           1 0
it              1 1
in              1 1
the             1 0
pre             1 1
-               1 0
production      1 1
c               1 0
##8             1 0
we              1 0
drove           1 1


In [34]:
# check train loader result
# Map the indices in the first train batch back to the original user/news ids
# so they can be compared by hand with the raw data files.
#
# The cache files are located through the notebook config rather than a
# machine-specific absolute path, so the check runs from any checkout.
# NOTE(review): assumes config.cache_root resolves from the kernel's working
# directory — confirm.
mind_cache = f"{config.cache_root}/MIND"
nid2index = load_pickle(f"{mind_cache}/MIND{config.scale}_train/news/nid2index.pkl")
uid2index = load_pickle(f"{mind_cache}/uid2index.pkl")
# Invert the id -> index maps into index -> id.
nindex2id = {v: k for k, v in nid2index.items()}
uindex2id = {v: k for k, v in uid2index.items()}

# check behaviors.tsv
# impr_index is shifted by one before printing — presumably the ids in
# behaviors.tsv are 1-based; verify against the file.
print([uindex2id[i] for i in x["user_index"].tolist()], (x["impr_index"] + 1).tolist())
# check news.tsv
# First five cdd_* entries of sample 0: original news ids, then decoded text.
print([nindex2id[i] for i in x["cdd_idx"][0][:5].tolist()])
print(t.batch_decode(x["cdd_token_id"][0][:5], skip_special_tokens=True))

# Same for the first five his_* entries of sample 0.
print([nindex2id[i] for i in x["his_idx"][0][:5].tolist()])
print(t.batch_decode(x["his_token_id"][0][:5], skip_special_tokens=True))

['U63949', 'U28675'] [1633, 526]
['N33159', 'N29128', 'N55689', 'N59981', 'N47229']
["sheriff's deputy charged with murder after man shot, killed a sheriff's deputy is in jail, charged with murder after a deadly shooting in athens, authorities said.", "flight attendants have a secret language you didn't know about here are some phrases only flight attendant use and what they actually mean.", 'charles rogers, former michigan state football, detroit lions star, dead at 38 charles rogers, the former michigan state football star whom the detroit lions selected with the second overall pick in 2003 nfl draft, has died.', "tamron hall talks leaving the today show, jokes about megyn kelly's reported multimillion payout tamron hall talks losing the today show", 'these are the best harley - davidson motorcycles you can get for the cheap these are the best harleys to buy right now.']
['N1150', 'N29177', 'N29898', 'N46392', 'N54772']
['a texas mom is going to prison after putting her son through u

In [33]:
# check dev loader result
# Map the indices in the first dev batch back to the original user/news ids
# so they can be compared by hand with the raw data files.
#
# The cache files are located through the notebook config rather than a
# machine-specific absolute path, so the check runs from any checkout.
# NOTE(review): assumes config.cache_root resolves from the kernel's working
# directory — confirm.
mind_cache = f"{config.cache_root}/MIND"
nid2index = load_pickle(f"{mind_cache}/MIND{config.scale}_dev/news/nid2index.pkl")
uid2index = load_pickle(f"{mind_cache}/uid2index.pkl")
# Invert the id -> index maps into index -> id.
nindex2id = {v: k for k, v in nid2index.items()}
uindex2id = {v: k for k, v in uid2index.items()}

# check behaviors.tsv
# impr_index is shifted by one before printing — presumably the ids in
# behaviors.tsv are 1-based; verify against the file.
print([uindex2id[i] for i in x2["user_index"].tolist()], (x2["impr_index"] + 1).tolist())
# check news.tsv
# First five cdd_* entries of sample 0: original news ids, then decoded text.
print([nindex2id[i] for i in x2["cdd_idx"][0][:5].tolist()])
print(t.batch_decode(x2["cdd_token_id"][0][:5], skip_special_tokens=True))

# Same for the first five his_* entries of sample 0.
print([nindex2id[i] for i in x2["his_idx"][0][:5].tolist()])
print(t.batch_decode(x2["his_token_id"][0][:5], skip_special_tokens=True))

['U80234', 'U80234'] [1, 1]
['N28682', 'N48740', 'N31958', 'N34130', 'N6916']
["browns apologize to mason rudolph, call myles garrett's actions'unacceptable'players often defend their teammates when commenting on an incident like the one we saw thursday, but the immediate reaction from mayfield and others proves just how far garrett crossed the line.", "i've been writing about tiny homes for a year and finally spent 2 nights in a 300 - foot home to see what it's all about i stayed in a tiny house for three days to see what the fuss was all about, and i was surprised by what i saw.", 'opinion : colin kaepernick is about to get what he deserves : a chance the end may be near for the 3 - year - old saga of colin kaepernick as the quarterback is scheduled to work out for teams on saturday.', "the kardashians face backlash over'insensitive'family food fight in kuwtk clip kardashian's face backlash over family food fight", "then and now : what all your favorite'90s stars are doing today thes