In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

### Data Loading

In [21]:
import csv
tsv_file = open("data/MINDsmall_train/news.tsv", "r")
read_tsv = csv.reader(tsv_file, delimiter="\t")

In [19]:
import tqdm as tq

In [17]:
from dataclasses import dataclass

In [18]:
@dataclass
class News:
    id: str
    category: str
    subcategory: str
    title: str
    abstract: str

In [22]:
news = [News(id=row[0], category=row[1], subcategory=row[2], title=row[3], abstract=row[4]) for row in tq.tqdm(read_tsv)]

51282it [00:00, 91125.01it/s]


In [23]:
len(news)

51282

In [24]:
tsv_file.close()

In [25]:
news[:2]

[News(id='N55528', category='lifestyle', subcategory='lifestyleroyals', title='The Brands Queen Elizabeth, Prince Charles, and Prince Philip Swear By', abstract="Shop the notebooks, jackets, and more that the royals can't live without."),
 News(id='N19639', category='health', subcategory='weightloss', title='50 Worst Habits For Belly Fat', abstract='These seemingly harmless habits are holding you back and keeping you from shedding that unwanted belly fat for good.')]

In [42]:
tsv_file = open("data/MINDsmall_train/behaviors.tsv", "r")
read_tsv = csv.reader(tsv_file, delimiter="\t")

In [40]:
row

['5',
 'U8125',
 '11/12/2019 4:11:21 PM',
 'N10078 N56514 N14904 N33740',
 'N39985-0 N36050-0 N16096-0 N8400-1 N22407-0 N60408-0 N61497-0 N47412-0 N41220-0 N1940-0 N724-0 N11363-0 N261-0 N33883-0 N36807-0 N11967-0 N17896-0 N13486-0 N10413-0 N54274-0 N4247-0 N27497-0 N38512-0 N30253-0 N45389-0 N20015-0 N20678-0 N54003-0 N35850-0 N33261-0 N32010-0 N57426-0 N7419-0 N50023-0 N36446-0 N26940-0 N28495-0 N19318-0 N4936-0 N28414-0 N25108-0 N32791-0 N23563-0 N39317-0 N16166-0 N37058-0 N64851-0 N46992-0 N57327-0 N12995-0 N58363-0 N53084-0 N11094-0 N36436-0 N305-0 N58241-0 N33212-0 N6975-0 N58114-0 N3344-0 N25406-0 N4741-0 N33885-0 N20915-0 N44941-0 N57319-0 N36532-0 N61822-0 N20527-0']

In [33]:
from typing import List

@dataclass
class Impression:
    id: str
    user_id: str
    datetime: str
    past_clicked: List
    clicked_news: List
    non_clicked_news: List

In [43]:
impressions = [
    Impression(id=row[0], user_id=row[1], datetime=row[2], past_clicked=row[3].split(), clicked_news=[i.split("-")[0] for i in row[-1].split() if i.endswith("-1")], non_clicked_news=[i.split("-")[0] for i in row[-1].split() if i.endswith("-0")])
    for row in tq.tqdm(read_tsv)
]

156965it [00:05, 28199.42it/s]


In [48]:
tsv_file.close()

In [44]:
len(impressions)

156965

In [46]:
impressions[:2]

[Impression(id='1', user_id='U13740', datetime='11/11/2019 9:05:58 AM', past_clicked=['N55189', 'N42782', 'N34694', 'N45794', 'N18445', 'N63302', 'N10414', 'N19347', 'N31801'], clicked_news=['N55689'], non_clicked_news=['N35729']),
 Impression(id='2', user_id='U91836', datetime='11/12/2019 6:11:30 PM', past_clicked=['N31739', 'N6072', 'N63045', 'N23979', 'N35656', 'N43353', 'N8129', 'N1569', 'N17686', 'N13008', 'N21623', 'N6233', 'N14340', 'N48031', 'N62285', 'N44383', 'N23061', 'N16290', 'N6244', 'N45099', 'N58715', 'N59049', 'N7023', 'N50528', 'N42704', 'N46082', 'N8275', 'N15710', 'N59026', 'N8429', 'N30867', 'N56514', 'N19709', 'N31402', 'N31741', 'N54889', 'N9798', 'N62612', 'N2663', 'N16617', 'N6087', 'N13231', 'N63317', 'N61388', 'N59359', 'N51163', 'N30698', 'N34567', 'N54225', 'N32852', 'N55833', 'N64467', 'N3142', 'N13912', 'N29802', 'N44462', 'N29948', 'N4486', 'N5398', 'N14761', 'N47020', 'N65112', 'N31699', 'N37159', 'N61101', 'N14761', 'N3433', 'N10438', 'N61355', 'N21164

In [50]:
from collections import Counter

In [51]:
# number of unique users
user_count = Counter(imp.user_id for imp in impressions)

In [54]:
user_count.most_common(20)

[('U32146', 62),
 ('U15740', 44),
 ('U20833', 41),
 ('U51286', 40),
 ('U44201', 40),
 ('U79449', 37),
 ('U30304', 37),
 ('U57047', 36),
 ('U47521', 36),
 ('U56120', 35),
 ('U79210', 35),
 ('U63482', 34),
 ('U27166', 34),
 ('U85878', 34),
 ('U72280', 33),
 ('U68925', 33),
 ('U21954', 33),
 ('U43884', 33),
 ('U67455', 32),
 ('U83337', 32)]

In [55]:
[imp for imp in impressions if imp.user_id == "U83337"]

[Impression(id='877', user_id='U83337', datetime='11/12/2019 8:42:41 PM', past_clicked=['N28936', 'N51591', 'N37083', 'N7632', 'N8424', 'N15446', 'N23468', 'N20207', 'N20290', 'N30012', 'N30441', 'N18005', 'N63650', 'N26002', 'N7857', 'N47685', 'N53387', 'N16215', 'N46337', 'N55846', 'N29177', 'N31686', 'N52096', 'N45746', 'N25436', 'N6511', 'N20959', 'N34124', 'N58045', 'N35375', 'N54304', 'N10343', 'N46039', 'N36300', 'N28818', 'N62620', 'N36551', 'N19591', 'N45286', 'N52946', 'N29088', 'N16251', 'N719', 'N871', 'N30145', 'N46978', 'N25889', 'N17725', 'N36699', 'N6290', 'N63906', 'N20722', 'N49103', 'N11677', 'N54822', 'N7563', 'N48582', 'N4306', 'N60442', 'N54581', 'N33503', 'N19615', 'N9653', 'N18164', 'N51305', 'N56403', 'N25634', 'N35218', 'N47525', 'N28016', 'N21815', 'N51248', 'N28134', 'N25155', 'N58615', 'N49153', 'N14566', 'N27911', 'N848', 'N8275', 'N62746', 'N56514', 'N11101', 'N49647', 'N16965', 'N13893', 'N4415', 'N4069', 'N63513', 'N57109', 'N49341', 'N4490', 'N14934', 

### Model Building

In [56]:
def dropout_mask(x, sz, p):
    "Return a dropout mask of the same type as `x`, size `sz`, with probability `p` to nullify an element."
    return x.new(*sz).bernoulli_(1-p).div_(1-p)


class EmbeddingDropout(nn.Module):
    "Apply dropout with probabily `embed_p` to an embedding layer `emb`."

    def __init__(self, emb: nn.Embedding, embed_p: float = 0.):
        super().__init__()
        self.emb = emb
        self.embed_p = embed_p

    def forward(self, words: torch.Tensor, scale: float = None):
        if self.training and self.embed_p != 0:
            size = (self.emb.weight.size(0),1)
            mask = dropout_mask(self.emb.weight.data, size, self.embed_p)
            masked_embed = self.emb.weight * mask
        else: 
            masked_embed = self.emb.weight
        
        if scale: 
            masked_embed.mul_(scale)
        
        return F.embedding(words, masked_embed, 
            self.emb.padding_idx or -1, self.emb.max_norm,
            self.emb.norm_type, self.emb.scale_grad_by_freq, 
            self.emb.sparse)

In [58]:
ids_to_news = list(set([n.id for n in news]))

In [59]:
ids_to_news[:3]

['N54952', 'N48981', 'N18437']

In [61]:
len(ids_to_news)

51282

In [60]:
news_to_ids = {n:i for i, n in enumerate(ids_to_news)}

In [63]:
news[:2]

[News(id='N55528', category='lifestyle', subcategory='lifestyleroyals', title='The Brands Queen Elizabeth, Prince Charles, and Prince Philip Swear By', abstract="Shop the notebooks, jackets, and more that the royals can't live without."),
 News(id='N19639', category='health', subcategory='weightloss', title='50 Worst Habits For Belly Fat', abstract='These seemingly harmless habits are holding you back and keeping you from shedding that unwanted belly fat for good.')]

In [64]:
news_content = {news_to_ids[n.id]: n.title + " " + n.abstract for n in news}

In [65]:
len(news_content)

51282

In [66]:
news_content[0]

'Surry County Schools investigating after reports of inappropriate comments about the president during improv School officials are investigating reports of inappropriate comments made during a North Surry High School Improv Club performance Wednesday. The North Surry High School Improv Club performed before school hours during a Late Start Wednesday. The club performed for for approximately 45 students, the school said. Sign up for our Newsletters During a performance, team members and a sponsor started an improvisation about jobs at the White House....'

In [67]:
len(impressions)

156965

more than one positive samples

In [74]:
click_count = Counter([len(i.clicked_news) for i in impressions])

In [77]:
click_count.most_common()

[(1, 113888),
 (2, 25571),
 (3, 9263),
 (4, 3975),
 (5, 1957),
 (6, 942),
 (7, 515),
 (8, 296),
 (9, 198),
 (10, 117),
 (11, 81),
 (12, 46),
 (13, 38),
 (14, 22),
 (15, 17),
 (16, 10),
 (18, 9),
 (17, 6),
 (19, 2),
 (21, 2),
 (26, 2),
 (25, 1),
 (23, 1),
 (35, 1),
 (27, 1),
 (22, 1),
 (20, 1),
 (24, 1),
 (31, 1)]

In [80]:
imp = impressions[0]

In [84]:
train_ds = [(
    [news_to_ids[n] for n in imp.past_clicked],
    [news_to_ids[n] for n in imp.clicked_news],
    [news_to_ids[n] for n in imp.non_clicked_news]
) for imp in impressions]

In [None]:
# randomly sample 4 neg in non-clicked news
num_neg = 4

In [88]:
train_ds[0]

([27575, 36477, 40324, 13851, 10724, 6883, 27655, 41485, 33531],
 [28961],
 [27997])

In [91]:
seq_len_count = Counter([len(e[0]) for e in train_ds])

In [92]:
seq_len_count.most_common()

[(4, 6308),
 (5, 6101),
 (6, 5591),
 (3, 5442),
 (7, 5341),
 (8, 4636),
 (9, 4270),
 (10, 4233),
 (11, 4073),
 (2, 3695),
 (12, 3665),
 (13, 3519),
 (14, 3365),
 (0, 3238),
 (15, 3176),
 (16, 2942),
 (17, 2828),
 (19, 2710),
 (18, 2694),
 (20, 2428),
 (22, 2335),
 (21, 2236),
 (25, 2230),
 (23, 2228),
 (1, 2162),
 (24, 2101),
 (26, 1957),
 (27, 1862),
 (30, 1832),
 (28, 1630),
 (29, 1596),
 (32, 1563),
 (31, 1563),
 (33, 1534),
 (34, 1497),
 (37, 1390),
 (45, 1381),
 (36, 1312),
 (41, 1306),
 (39, 1274),
 (35, 1237),
 (40, 1183),
 (38, 1153),
 (43, 1094),
 (46, 1085),
 (42, 1065),
 (44, 904),
 (47, 886),
 (52, 870),
 (51, 858),
 (48, 828),
 (58, 801),
 (57, 777),
 (49, 772),
 (50, 766),
 (54, 743),
 (55, 733),
 (56, 713),
 (60, 659),
 (53, 658),
 (63, 649),
 (62, 594),
 (61, 541),
 (65, 538),
 (64, 517),
 (70, 512),
 (59, 507),
 (71, 488),
 (76, 468),
 (81, 445),
 (67, 442),
 (68, 430),
 (77, 417),
 (69, 408),
 (82, 396),
 (66, 386),
 (101, 378),
 (75, 369),
 (83, 356),
 (73, 355),
 (7

what to do with the one having no past clicks? 3238 of them - removing from training for now

In [94]:
len([len(e[0]) for e in train_ds if len(e[0]) == 0])

3238

In [104]:
len([1 for e in train_ds if len(e[0])>= 250])

533

In [105]:
max_seq_len = 250

In [106]:
train_ds = [e for e in train_ds if len(e[0]) > 0 and len(e[0]) < max_seq_len]

In [107]:
len(train_ds)

153194

### Tokenization

In [125]:
import sys

In [126]:
sys.path.append("/home/chris/Documents/ML/projects/")

In [127]:
from underscore.nlp import AlphaTokenizer

In [128]:
tokenizer = AlphaTokenizer()

In [134]:
news_content[0]

'Surry County Schools investigating after reports of inappropriate comments about the president during improv School officials are investigating reports of inappropriate comments made during a North Surry High School Improv Club performance Wednesday. The North Surry High School Improv Club performed before school hours during a Late Start Wednesday. The club performed for for approximately 45 students, the school said. Sign up for our Newsletters During a performance, team members and a sponsor started an improvisation about jobs at the White House....'

In [141]:
def tokenize(text: str) -> List[str]:
    return [t.strip() for t in list(tokenizer( [text.lower()] ))[0]]

In [143]:
print( tokenize(news_content[0]) )

['surry', 'county', 'schools', 'investigating', 'after', 'reports', 'of', 'inappropriate', 'comments', 'about', 'the', 'president', 'during', 'improv', 'school', 'officials', 'are', 'investigating', 'reports', 'of', 'inappropriate', 'comments', 'made', 'during', 'a', 'north', 'surry', 'high', 'school', 'improv', 'club', 'performance', 'wednesday', '.', 'the', 'north', 'surry', 'high', 'school', 'improv', 'club', 'performed', 'before', 'school', 'hours', 'during', 'a', 'late', 'start', 'wednesday', '.', 'the', 'club', 'performed', 'for', 'for', 'approximately', '45', 'students', ',', 'the', 'school', 'said', '.', 'sign', 'up', 'for', 'our', 'newsletters', 'during', 'a', 'performance', ',', 'team', 'members', 'and', 'a', 'sponsor', 'started', 'an', 'improvisation', 'about', 'jobs', 'at', 'the', 'white', 'house', '....']


In [144]:
news_content = {i: tokenize(text) for i, text in tq.tqdm(news_content.items())}

100%|██████████| 51282/51282 [00:02<00:00, 19265.66it/s]


In [146]:
max([len(tokens) for i, tokens in news_content.items()])

595

In [154]:
len([1 for i, tokens in news_content.items() if len(tokens) > 200])

62

In [156]:
len([1 for i, tokens in news_content.items() if len(tokens) < 1])

0

In [197]:
# to token ids
news_content = {i: [w2i.get(w, 1) for w in tokens] for i, tokens in news_content.items()}

In [155]:
max_text_len = 200

In [157]:
word_count = Counter()
for i, tokens in news_content.items():
    word_count.update(tokens)

In [158]:
len(word_count)

55564

In [159]:
i2w = [w for w, _ in word_count.most_common()]

In [160]:
i2w[:3]

['the', '.', ',']

In [161]:
PAD = "<PAD>"
UNK = "<UNK>"

i2w.insert(0, UNK)
i2w.insert(0, PAD)

In [162]:
i2w[:3]

['<PAD>', '<UNK>', 'the']

In [163]:
w2i = {w: i for i, w in enumerate(i2w)}

In [164]:
vocab_size = len(w2i)

### Glove Embedding

In [110]:
!head -n1 /home/chris/Documents/ML/data/GloVe_small/glove.6B.100d.txt

the -0.038194 -0.24487 0.72812 -0.39961 0.083172 0.043953 -0.39141 0.3344 -0.57545 0.087459 0.28787 -0.06731 0.30906 -0.26384 -0.13231 -0.20757 0.33395 -0.33848 -0.31743 -0.48336 0.1464 -0.37304 0.34577 0.052041 0.44946 -0.46971 0.02628 -0.54155 -0.15518 -0.14107 -0.039722 0.28277 0.14393 0.23464 -0.31021 0.086173 0.20397 0.52624 0.17164 -0.082378 -0.71787 -0.41531 0.20335 -0.12763 0.41367 0.55187 0.57908 -0.33477 -0.36559 -0.54857 -0.062892 0.26584 0.30205 0.99775 -0.80481 -3.0243 0.01254 -0.36942 2.2167 0.72201 -0.24978 0.92136 0.034514 0.46745 1.1079 -0.19358 -0.074575 0.23353 -0.052062 -0.22044 0.057162 -0.15806 -0.30798 -0.41625 0.37972 0.15006 -0.53212 -0.2055 -1.2526 0.071624 0.70565 0.49744 -0.42063 0.26148 -1.538 -0.30223 -0.073438 -0.28312 0.37104 -0.25217 0.016215 -0.017099 -0.38984 0.87424 -0.72569 -0.51058 -0.52028 -0.1459 0.8278 0.27062


In [122]:
import numpy as np
from typing import Dict

def read_glove(fname: str) -> Dict[str, np.ndarray]:
    w2v = {}
    with open(fname, "r") as f:
        for line in f:
            tokens = line.split()
            w = tokens[0]
            v = np.array([float(n) for n in tokens[1:]], dtype=np.float32)
            w2v[w] = v
    return w2v

In [123]:
glove_w2v = read_glove("/home/chris/Documents/ML/data/GloVe_small/glove.6B.100d.txt")

In [124]:
len(glove_w2v)

400000

In [165]:
embed_size = 100

In [245]:
glove_emb = np.random.uniform(low=-1., high=1., size=(vocab_size, embed_size))

In [248]:
glove_emb = glove_emb.astype(np.float32)

In [250]:
# zero for pad
glove_emb[0] = np.zeros_like(glove_emb[0])

num_existing_words = 0
for w, v in glove_w2v.items():
    word_id = w2i.get(w, None)
    if word_id:
        glove_emb[word_id] = v
        num_existing_words += 1

In [170]:
num_existing_words

48470

In [173]:
glove_emb[w2i["manager"]] @ glove_emb[w2i["management"]]

22.427108878476844

In [174]:
glove_emb[w2i["manager"]] @ glove_emb[w2i["scientist"]]

9.946951709414234

In [243]:
glove_emb[w2i["manager"]].dtype

dtype('float64')

In [251]:
emb = nn.Embedding.from_pretrained(torch.from_numpy(glove_emb), freeze=False, padding_idx=0)

In [269]:
emb.weight[0]

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0.], grad_fn=<SelectBackward>)

In [252]:
word_emb = EmbeddingDropout(emb, embed_p=0.0)

In [183]:
batch = train_ds[:4]

In [184]:
batch[0]

([27575, 36477, 40324, 13851, 10724, 6883, 27655, 41485, 33531],
 [28961],
 [27997])

In [185]:
seqs, pos, negs = zip(*batch)

In [187]:
len(seqs)

4

In [339]:
negs[1]

[30198, 47614, 14391, 50090, 12206, 32576, 39841, 34932, 40367, 18252]

In [342]:
np.random.choice(negs[1], size=4).tolist()

[39841, 34932, 39841, 12206]

In [343]:
import random

def DataLoader(dataset: List, num_negs: int = 4):
    for example in dataset:
        seq, pos_labels, neg_labels = example
        
        for pos in pos_labels:
            # shuffle the order
            random.shuffle(seq)
            negs = np.random.choice(
                neg_labels, 
                size=num_negs if num_negs <= len(neg_labels) else len(neg_labels), 
                replace=False
            ).tolist()
            
            yield (seq, pos, negs)

In [344]:
train_it = DataLoader(train_ds, num_negs=4)

In [353]:
seq, pos, negs = next(train_it)

In [354]:
pos

16987

In [355]:
negs

[50562, 37756, 1873, 26940]

In [356]:
x = seq + [pos] + negs

In [358]:
print( x )

[28818, 18403, 29802, 19823, 16473, 50511, 3108, 39018, 49918, 40468, 31503, 48719, 49677, 23977, 26158, 37134, 16987, 50562, 37756, 1873, 26940]


In [359]:
seq_token_ids = [torch.tensor(news_content[news_id], dtype=torch.int64) for news_id in x] 

In [360]:
seq_token_ids = nn.utils.rnn.pad_sequence(seq_token_ids, batch_first=True, padding_value=0)

In [361]:
seq_token_ids.shape

torch.Size([21, 107])

In [362]:
seq_token_ids[1]

tensor([   86,    27,     2,   583,  1020,    73,     7,   161, 20258,    13,
          443,   143,     2,   443,   143,  1020,    13,   133,    39,   569,
         4335, 12298,    45,     3,   104,     5,  5060,  4560,    17,    59,
            6,  1206,    23,   682,  4148,   139,  3098,     4,  1981,     4,
           10,   130,   714,     3,  4210,   122,  1868,    10,   514,    21,
            6,   412,    70,   839,    13,     2,   583,  1290,   520,     9,
            2,    40,     3,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0])

In [363]:
seq_emb = word_emb(seq_token_ids)

In [364]:
seq_emb.shape

torch.Size([21, 107, 100])

In [365]:
seq_emb[1, -1]

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0.], grad_fn=<SelectBackward>)

In [366]:
proj = nn.Linear(in_features=100, out_features=16*16, bias=False)

In [367]:
seq_emb = proj(seq_emb)

In [368]:
seq_emb.shape

torch.Size([21, 107, 256])

In [369]:
seq_emb[1, -1]

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 

### Multi-head self-attention

In [214]:
query = torch.tensor([[[1., 2., 3., 1., 2., 3.], [3., 0., 0., 3., 0., 0.]]])

In [215]:
key = torch.tensor([[[1., 2., 3., 1., 2., 3.], [3., 0., 0., 3., 0., 0.]]])

In [216]:
value = torch.tensor([[[10., 1., 10., 10., 1., 10.], [1., 10., 1., 1., 10., 1.]]])

In [217]:
multihead_attn = nn.MultiheadAttention(embed_dim=6, num_heads=1)

In [221]:
attn_out, attn_w = multihead_attn(query.transpose(0, 1), key.transpose(0, 1), value.transpose(0, 1))

In [224]:
attn_out = attn_out.transpose(0, 1)

In [226]:
attn_w

tensor([[[0.9709, 0.0291],
         [0.3827, 0.6173]]], grad_fn=<DivBackward0>)

In [227]:
attn_out

tensor([[[ 0.7450,  2.8015,  1.3443, -2.0017,  3.4129, -2.6062],
         [-0.3155,  0.5638,  0.7060, -0.6808,  0.1740, -1.0112]]],
       grad_fn=<TransposeBackward0>)

In [370]:
head_size = 16
num_heads = 16

multihead_attn = nn.MultiheadAttention(embed_dim=head_size*num_heads, num_heads=num_heads)

> key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
          If a BoolTensor is provided, the positions with the value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.

In [371]:
seq_emb.shape

torch.Size([21, 107, 256])

In [372]:
seq_emb = seq_emb.transpose(0, 1)
attn_out, attn_w = multihead_attn(seq_emb, seq_emb, seq_emb, key_padding_mask=(seq_token_ids == 0))

In [373]:
attn_out = attn_out.transpose(0, 1)

In [374]:
attn_out.shape

torch.Size([21, 107, 256])

In [375]:
attn_w.shape

torch.Size([21, 107, 107])

In [377]:
len(seq), len(negs)

(16, 4)

In [303]:
# additive attention
class AdditiveAttention(nn.Module):
    
    def __init__(self, input_size: int, attention_size: int):
        super().__init__()
        self.proj = nn.Linear(input_size, attention_size, bias=True)
        self.register_parameter("query", nn.Parameter(torch.ones(attention_size, 1)))
        
    def forward(self, x):
        "x of shape (b, n, d)"
        x_proj = self.proj(x).tanh() # (b, n, a)
        attn_w = (x_proj @ self.query).squeeze(-1).softmax(-1) # (b, n)
        attn_out = (attn_w.unsqueeze(-1) * x).sum(1)
        return attn_out

In [319]:
additive_attn = AdditiveAttention(3, 4)

In [320]:
additive_attn

AdditiveAttention(
  (proj): Linear(in_features=3, out_features=4, bias=True)
)

In [321]:
additive_attn.query.data.shape

torch.Size([4, 1])

In [323]:
additive_attn.query.data = torch.tensor([1., 2., 3., 4.]).reshape(4, 1)

In [324]:
additive_attn.query

Parameter containing:
tensor([[1.],
        [2.],
        [3.],
        [4.]], requires_grad=True)

In [325]:
additive_attn.proj.weight.data = torch.ones_like(additive_attn.proj.weight.data)

In [326]:
additive_attn.proj.weight

Parameter containing:
tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]], requires_grad=True)

In [327]:
x = torch.tensor([[[1., 2., 3.], [1., 0., 0.]]])

In [328]:
x.shape

torch.Size([1, 2, 3])

In [329]:
additive_attn(x)

tensor([[1.0000, 1.8625, 2.7938]], grad_fn=<SumBackward1>)

In [380]:
attn_out.shape

torch.Size([21, 107, 256])

In [381]:
additive_attn = AdditiveAttention(256, attention_size=200)

In [386]:
x_pooled = additive_attn(attn_out)

In [387]:
# one vector for each news
x_pooled.shape

torch.Size([21, 256])

In [397]:
# split
seq_pooled = x_pooled[:len(seq)]
label_pooled = x_pooled[len(seq):]

In [407]:
seq_pooled.shape, label_pooled.shape

(torch.Size([16, 1, 256]), torch.Size([5, 256]))

In [392]:
news_self_attn = nn.MultiheadAttention(embed_dim=head_size*num_heads, num_heads=num_heads)

In [398]:
seq_pooled = seq_pooled.unsqueeze(0).transpose(0, 1)

In [399]:
seq_pooled.shape

torch.Size([16, 1, 256])

In [400]:
seq_encoded, _ = news_self_attn(seq_pooled, seq_pooled, seq_pooled)

In [401]:
seq_encoded.shape

torch.Size([16, 1, 256])

In [402]:
seq_encoded = seq_encoded.transpose(0, 1)

In [403]:
user_additive_attn = AdditiveAttention(256, 200)

In [404]:
user_encoded = user_additive_attn(seq_encoded)

In [405]:
user_encoded.shape

torch.Size([1, 256])

In [417]:
logits = user_encoded @ label_pooled.t()

In [418]:
logits = logits.squeeze(0)

In [419]:
logits

tensor([-0.0063, -0.0219, -0.0187,  0.0014, -0.0059],
       grad_fn=<SqueezeBackward1>)

In [422]:
pos_prob = F.log_softmax(logits, dim=-1)[0]

In [423]:
pos_prob

tensor(-1.6055, grad_fn=<SelectBackward>)