# Simple implementation of our NRMS model

## Forord

Denne notebook er et forsøg på at lave en sammenhængende og forståelig
gennemgang af vores projekt fra A og næsten helt til Z.

1) Initialization
2) Data preprocessing
3) Dataset class and dataloader
4) Simple implementation of NRMS
5) Basic training loop
6) Metrics


## TODO

- Make separate file with all constants defined and columns
- Address duplicate values in the behaviors data frame (removed for now)
- Find out if we should reshape (flatten) the 5 inview samples in our batches into separate training examples

## 1 Initialization

In [1]:
## IMPORTS

# System imports
import sys
import os
import gc
from pathlib import Path

# RecSys / eb-nerd imports
from ebrec.utils._behaviors import (create_binary_labels_column, truncate_history, sampling_strategy_wu2019)
from ebrec.utils._articles import create_article_id_to_value_mapping
from ebrec.utils._articles_behaviors import map_list_article_id_to_value
#from ebrec.models.newsrec.dataloader import NRMSDataLoader
#from ebrec.evaluation import MetricEvaluator, AucScore, NdcgScore, MrrScore

# Standard data science imports
import polars as pl
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim

from ebrec.utils._articles_behaviors import map_list_article_id_to_value
from ebrec.utils._python import (
    repeat_by_list_values_from_matrix,
    create_lookup_objects,
)

from myDataloader import NewsrecDataset, create_dataloader

################################################################################
## VARIABLES
################################################################################

# Number of clicked articles kept in each user's history
HISTORY_SIZE = 30
# Negative-to-positive sampling ratio (the K of the NRMS paper)
NPRATIO = 4
# Seed used wherever randomness is involved, for reproducibility
SEED = 123
# Number of samples per batch handed to the DataLoader
BATCH_SIZE = 32

  from .autonotebook import tqdm as notebook_tqdm


## 2 Data preprocessing

Before we preprocess the data into the joined data frame used for training,
we briefly explore it so as to get a broad overview of the columns etc.

Then we preprocess the data and finally join the impression logs (behaviors)
with the users' browsed articles (history).

### 2.1 Load data

We scan the parquet files lazily (`pl.scan_parquet`), so nothing is loaded into
memory until we collect it during processing. Then we print the columns of each
of the training data frames as well as the articles data frame, because we want
to organize the columns afterwards in the following way:
1. Key columns
2. Essential columns (denoted USE_COLUMNS below)
3. Detail columns (the rest of the columns that are not essential, but would 
    probably improve the model if we implement them)

In [2]:
################################################################################
## SCAN / LOAD DATA
################################################################################

# Insert your own data path, or set the EBNERD_DATA_PATH environment variable
# (the hard-coded path below is only used as a fallback).
DATA_PATH = Path(
    os.environ.get(
        "EBNERD_DATA_PATH",
        "~/Documents/Studie/main/kurser/08_Fall_2024/02456_Deep_Learning/final_project/DeepLearning02456/tmp/Data",
    )
).expanduser()
DATA_SIZE = "ebnerd_small"     # [ebnerd_demo, ebnerd_small, ebnerd_large]
EMBEDDING_PATH = "Ekstra_Bladet_contrastive_vector"

print(f"Attempting to scan data from \n\t{DATA_PATH / DATA_SIZE}")

# Scan data lazily (nothing is loaded into memory until the frames are collected)
df_behaviors_train = pl.scan_parquet(
    DATA_PATH.joinpath(DATA_SIZE, "train", "behaviors.parquet")
)
df_history_train = pl.scan_parquet(
    DATA_PATH.joinpath(DATA_SIZE, "train", "history.parquet")
)
# NOTE: the former chained `= df_behaviors =` assignments were removed; they
# left a stray `df_behaviors` global that ended up holding the *history* frame.
df_behaviors_val = pl.scan_parquet(
    DATA_PATH.joinpath(DATA_SIZE, "validation", "behaviors.parquet")
)
df_history_val = pl.scan_parquet(
    DATA_PATH.joinpath(DATA_SIZE, "validation", "history.parquet")
)
df_articles = pl.scan_parquet(DATA_PATH.joinpath("articles.parquet"))

# Load the contrastive vectors (still lazy)
article_embeddings = pl.scan_parquet(DATA_PATH.joinpath(EMBEDDING_PATH, "contrastive_vector.parquet"))

# Add the embeddings as a new column to the articles data frame and collect it
# (df_articles is an eager DataFrame from here on).
df_articles = df_articles.join(article_embeddings, on="article_id", how="left").collect()

print("Data scanned successfully...")

## Quick exploration: print the schema of every frame we loaded.
# The frames are gathered in a dict so a single loop can label and print them.
scanned_dataframes = {
    "Behaviors Train": df_behaviors_train,
    "History Train": df_history_train,
    "Articles": df_articles,
    "Article embeddings": article_embeddings,
}

for frame_name, frame in scanned_dataframes.items():
    print(f"\nColumns for {frame_name}:")
    print("-" * (len(frame_name) + 14))  # underline the heading

    frame_schema = frame.schema
    # Width of the longest column name, so the ":" separators line up
    name_width = max(map(len, frame_schema.keys()))

    for column_name, column_type in frame_schema.items():
        print(f"{column_name:<{name_width}} : {column_type}")

print("\nAll schemas printed successfully.")

Attempting to scan data from 
	/Users/adamax/Documents/Studie/main/kurser/08_Fall_2024/02456_Deep_Learning/final_project/DeepLearning02456/tmp/Data/ebnerd_small
Data scanned successfully...

Columns for Behaviors Train:
-----------------------------
impression_id          : UInt32
article_id             : Int32
impression_time        : Datetime(time_unit='us', time_zone=None)
read_time              : Float32
scroll_percentage      : Float32
device_type            : Int8
article_ids_inview     : List(Int32)
article_ids_clicked    : List(Int32)
user_id                : UInt32
is_sso_user            : Boolean
gender                 : Int8
postcode               : Int8
age                    : Int8
is_subscriber          : Boolean
session_id             : UInt32
next_read_time         : Float32
next_scroll_percentage : Float32

Columns for History Train:
---------------------------
user_id                 : UInt32
impression_time_fixed   : List(Datetime(time_unit='us', time_zone=None))
scr

That was the very quick overview of the data. 

### 2.2 Preprocess data
Now we want to make the setup for joining the data frames together into one
data frame for training.

1. First we organize the columns from each data frame based on whether they are essential
to make our model work:
- Key columns
- Use columns (essential columns)
- Detail columns (other columns) 
2. Then we do some processing of the columns in the data frames to clean up
the data a bit, before joining them.
3. We join the two data frames

#### 2.2.1 - Organizing columns

In [3]:
################################################################################
## SETUP FOR CONSTRUCTING OUR TRAINING DATA FRAME (USING POLARS)
################################################################################

## (1) Organizing the columns

HISTORY_PREFIX = "his"
IMPRESSION_PREFIX = "imp"
ARTICLE_PREFIX = "art"

# Organizing the columns in the data frames based on their importance to make it work.
# The dtypes in the comments are taken from the schemas printed above.
IMPRESSION_KEY_COLUMNS = [
    "user_id",                  # UInt32
    "impression_id",            # UInt32
]

IMPRESSION_USE_COLUMNS = [
    "impression_time",          # Datetime(time_unit='us', time_zone=None)
    "article_ids_clicked",      # List(Int32)
    "article_ids_inview",       # List(Int32)
]

IMPRESSION_DETAIL_COLUMNS = [
    "read_time",              # Float32
    "scroll_percentage",      # Float32
    "device_type",            # Int8
    "is_sso_user",            # Boolean
    "gender",                 # Int8
    "postcode",               # Int8
    "age",                    # Int8
    "is_subscriber",          # Boolean
    "session_id",             # UInt32
    "next_read_time",         # Float32
    "next_scroll_percentage", # Float32
]

HISTORY_KEY_COLUMNS = [
    "user_id",                # UInt32
]

HISTORY_USE_COLUMNS = [
    "article_id_fixed",       # List(Int32)
]

# Every entry needs its trailing comma: without them Python silently
# concatenates adjacent string literals into ONE bogus column name.
# "article_id_fixed" is not listed here because it is already a USE column.
HISTORY_DETAIL_COLUMNS = [
    "impression_time_fixed",    # List(Datetime(time_unit='us', time_zone=None))
    "scroll_percentage_fixed",  # List(Float32)
    "read_time_fixed",          # List(Float32)
]

# OBS! Articles not used for now
# TODO - Add article columns
ARTICLE_KEY_COLUMNS = [
    "article_id",  # Int32
]

ARTICLE_USE_COLUMNS = [
    "title",              # String
    "subtitle",           # String
    "published_time",     # Datetime(time_unit='us', time_zone=None)
    "category",           # Int16
    "subcategory",        # List(Int16)
    "category_str",       # String
    "sentiment_score",    # Float32
    "sentiment_label",    # String
]

ARTICLE_DETAIL_COLUMNS = [
    "body",               # String
    "last_modified_time", # Datetime(time_unit='us', time_zone=None)
    "premium",            # Boolean
    "image_ids",          # List(Int64)
    "article_type",       # String
    "url",                # String
    "ner_clusters",       # List(String)
    "entity_groups",      # List(String)
    "topics",             # List(String)
    "total_inviews",      # Int32
    "total_pageviews",    # Int32
    "total_read_time",    # Float32
]


## Here I gather the columns I want to use for my training data frame
COLUMNS_FROM_HISTORY = HISTORY_KEY_COLUMNS + HISTORY_USE_COLUMNS
COLUMNS_FROM_BEHAVIORS = IMPRESSION_KEY_COLUMNS + IMPRESSION_USE_COLUMNS

print(f"\nColumns to be used from History: {COLUMNS_FROM_HISTORY}")
print(f"Columns to be used from Behaviors: {COLUMNS_FROM_BEHAVIORS}")


Columns to be used from History: ['user_id', 'article_id_fixed']
Columns to be used from Behaviors: ['user_id', 'impression_id', 'impression_time', 'article_ids_clicked', 'article_ids_inview']


#### 2.2.2 - Preprocessing

**Choices for history data frame**
- Selecting only essential columns 
- Renaming article_id_fixed to his_article_ids
- Truncating number of articles in history
- TODO: apply exponential decay from utils

In [11]:
# (2) Now we start the actual processing of the data frames

## HISTORY DATA FRAME
# Keep only the key/use columns, prefix the clicked-article list with "his_"
# to mark its origin, cut every history to HISTORY_SIZE entries (padding with 0)
# and finally materialize the lazy query.
df_history_train = df_history_train.select(COLUMNS_FROM_HISTORY)
df_history_train = df_history_train.rename(
    {"article_id_fixed": f"{HISTORY_PREFIX}_article_ids"}
)
# Optional extra column with the history length before truncation:
# .with_columns(pl.col("his_article_ids").list.len().alias("his_num_articles"))
df_history_train = truncate_history(
    df_history_train,
    column="his_article_ids",
    history_size=HISTORY_SIZE,
    padding_value=0,
    enable_warning=False,
)
df_history_train = df_history_train.collect()

# Sanity check of the result
print(f"shape of df_history_train after processing: {df_history_train.shape}")
df_history_train.head(5)


shape of df_history_train after processing: (15143, 2)


user_id,his_article_ids
u32,list[i32]
13538,"[9767342, 9767751, … 9769366]"
14241,"[9763401, 9763250, … 9767852]"
20396,"[9763634, 9763401, … 9769679]"
34912,"[9766722, 9759476, … 9770882]"
37953,"[9762836, 9763942, … 9769306]"


**Choices for behaviors data frame**
- Selecting essential columns
- Filter out impressions with more than one clicked article
- Negative sampling with a positive:negative ratio of 1:4 (I actually think it should be higher)
- Make a binary label column called labels
- (Not done for now) Expand article_ids_inview and rename it to article_id

In [12]:
## BEHAVIORS DATA FRAME
df_behaviors_train = (
    df_behaviors_train
    .select(COLUMNS_FROM_BEHAVIORS)  # selecting the columns we want to keep
    # Keep only impressions with exactly one clicked article. The native
    # `list.len()` expression runs inside polars' engine, unlike the previous
    # row-by-row Python UDF `map_elements(lambda x: len(x))`, and it needs no
    # temporary "length" column that has to be dropped again afterwards.
    .filter(pl.col("article_ids_clicked").list.len() == 1)
    .collect()                       # Collecting the data frame into memory
    .pipe(sampling_strategy_wu2019, npratio=NPRATIO, shuffle=True, clicked_col="article_ids_clicked",
          inview_col="article_ids_inview", with_replacement=False, seed=SEED)   # down-sampling
    .pipe(create_binary_labels_column, clicked_col="article_ids_clicked",      # creating the binary labels column
          inview_col="article_ids_inview")
)

# WE WILL NOT USE THIS FOR NOW (one row per in-view article instead of lists)
#     .rename({"article_ids_inview": "article_id"}) # renaming to article id because we expand this column now
#     .explode("article_id") # expanding the article ids in view
#     .with_columns(click=pl.col("article_id").is_in(pl.col("article_ids_clicked")).cast(pl.Int8))
#     .drop(["article_ids_clicked", "length", "labels"])
#     .with_columns(pl.col("article_id").cast(pl.Int32))

print(f"shape of df_behaviors_train after processing: {df_behaviors_train.shape}")
df_behaviors_train.head(5)

shape of df_behaviors_train after processing: (231731, 6)


user_id,impression_id,impression_time,article_ids_clicked,article_ids_inview,labels
u32,u32,datetime[μs],list[i64],list[i64],list[i8]
139836,149474,2023-05-24 07:47:53,[9778657],"[9778728, 9778669, … 9778657]","[0, 0, … 1]"
143471,150528,2023-05-24 07:33:25,[9778623],"[9778669, 9778769, … 9778623]","[0, 0, … 1]"
151570,153068,2023-05-24 07:09:04,[9778669],"[9772866, 9776259, … 9693002]","[0, 0, … 0]"
151570,153070,2023-05-24 07:13:14,[9778628],"[9430567, 9778628, … 9525589]","[0, 1, … 0]"
151570,153071,2023-05-24 07:11:08,[9777492],"[9131971, 9335113, … 9778623]","[0, 0, … 0]"


### 2.3 Joining the data frames

In [13]:
# Attach each user's truncated click history to every one of their impressions;
# the left join keeps impressions even when no history row exists for the user.
df_train = df_behaviors_train.join(df_history_train, how="left", on="user_id")

print(f"shape of joined train data frame: {df_train.shape}")
df_train.head(5)  # Quick check that it worked

shape of joined train data frame: (231731, 7)


user_id,impression_id,impression_time,article_ids_clicked,article_ids_inview,labels,his_article_ids
u32,u32,datetime[μs],list[i64],list[i64],list[i8],list[i32]
139836,149474,2023-05-24 07:47:53,[9778657],"[9778728, 9778669, … 9778657]","[0, 0, … 1]","[0, 0, … 9765156]"
143471,150528,2023-05-24 07:33:25,[9778623],"[9778669, 9778769, … 9778623]","[0, 0, … 1]","[9767557, 9768062, … 9770989]"
151570,153068,2023-05-24 07:09:04,[9778669],"[9772866, 9776259, … 9693002]","[0, 0, … 0]","[9770620, 9770594, … 9770829]"
151570,153070,2023-05-24 07:13:14,[9778628],"[9430567, 9778628, … 9525589]","[0, 1, … 0]","[9770620, 9770594, … 9770829]"
151570,153071,2023-05-24 07:11:08,[9777492],"[9131971, 9335113, … 9778623]","[0, 0, … 0]","[9770620, 9770594, … 9770829]"


We will save this data frame, so we don't have to process it again.

In [4]:
# Save df_train in parquet format so the preprocessing above does not have to
# be redone. An existing file is never overwritten: delete it to force a re-save.
df_train_file_name = "df_train_basic"
df_train_path = DATA_PATH / f"{df_train_file_name}.parquet"

if df_train_path.exists():
    print(f"The preprocessed data frame for training is already saved as {df_train_file_name}.parquet")
else:
    print("File did not exist... saving now.")
    df_train.write_parquet(df_train_path)

The preprocessed data frame for training is already saved as df_train_basic.parquet


## 3 Dataset class and dataloader

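For now we focus on setting up the dataloader. Instead of the NRMSDataLoader
provided in the ebnerd repository, we use our own `create_dataloader` from
`myDataloader`.

We will need to
- create a mapping from article id to the contrastive vector embedding
- build a dataloader that yields batches of (history embeddings, candidate embeddings, labels)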

In [5]:
df_train = pl.read_parquet(DATA_PATH / f"{df_train_file_name}.parquet")  # reload the saved, preprocessed training frame

In [6]:
# Map every article_id to its contrastive-vector embedding
article_mapping = create_article_id_to_value_mapping(
    df=df_articles,
    value_col="contrastive_vector",
    article_col="article_id",
)

# How many articles ended up in the lookup
print(f"Number of articles in the mapping: {len(article_mapping)}")

# Show only the first five entries as a sanity check (zip stops after range(5))
for i, (key, value) in zip(range(5), article_mapping.items()):
    print(f"Article ID: {key}, Encoded Value: {value}")

Number of articles in the mapping: 125541
Article ID: 3000022, Encoded Value: shape: (768,)
Series: '' [f32]
[
	-0.012159
	0.057097
	0.018299
	-0.038884
	-0.010863
	-0.04567
	0.030678
	0.028279
	-0.010959
	0.02433
	0.005196
	0.055502
	…
	-0.034803
	0.001971
	0.000293
	0.016569
	0.0356
	0.056305
	0.026861
	-0.046941
	0.029988
	-0.000547
	-0.037465
	0.025883
	0.013574
]
Article ID: 3000063, Encoded Value: shape: (768,)
Series: '' [f32]
[
	0.034482
	0.033533
	0.054598
	-0.023163
	0.009087
	-0.02119
	0.069386
	0.002645
	0.009205
	0.075007
	0.02691
	0.032995
	…
	0.052563
	0.057967
	-0.006074
	-0.002025
	0.026897
	0.049568
	-0.023337
	-0.000465
	0.007559
	-0.035644
	-0.0047
	-0.011909
	-0.023085
]
Article ID: 3000613, Encoded Value: shape: (768,)
Series: '' [f32]
[
	-0.014638
	0.030934
	0.036163
	0.039489
	-0.030487
	-0.051596
	-0.025907
	0.042978
	-0.031724
	0.025653
	-0.091189
	-0.002512
	…
	-0.085461
	-0.037983
	0.015554
	-0.003445
	-0.002479
	0.064819
	-0.057671
	-0.072227
	0.034131
	-0.

In [7]:
# Now we can define the dataloader for the training.
# history_size / batch_size come from the shared constants so they cannot drift
# from the truncation applied during preprocessing (truncate_history used HISTORY_SIZE).
train_dataloader = create_dataloader(
    df=df_train,
    history_column="his_article_ids",
    article_dict=article_mapping,
    history_size=HISTORY_SIZE,
    embedding_dim=768,              # length of each contrastive vector
    unknown_representation="zeros",
    eval_mode=False,
    batch_size=BATCH_SIZE,
)

In [8]:
# Print the first batch of the training data loader
for i, batch in enumerate(train_dataloader):
    X_his, X_pred, y = batch  # batch has two elements: X and y (X is )
    print(f"Batch {i}")
    print(f"Embeddings: {np.shape(X_his)}")
    print(f"Embeddings: {np.shape(X_pred)}")
    print(f"Labels: {np.shape(y)}")
    if i == 0:  # Print only the first batch for brevity
        break

Batch 0
Embeddings: torch.Size([32, 30, 768])
Embeddings: torch.Size([32, 5, 768])
Labels: torch.Size([32, 5, 1])


In [12]:
# Flatten the batch so that every in-view candidate becomes its own row.
# The dataloader yields X_his (B, H, D), X_pred (B, K, D) and y (B, K, 1)
# (32x30x768, 32x5x768, 32x5x1 above). Negative dimension indices work for these
# 3-D tensors as well as for 4-D ones; the former `.size(3)` raised IndexError
# on 3-D input. `reshape` also copes with non-contiguous tensors, unlike `view`.

# X_his stays (B, H, D), e.g. 32x30x768
X_his = X_his.reshape(-1, X_his.size(-2), X_his.size(-1))

# X_pred becomes (B*K, 1, D), e.g. 160x1x768
X_pred = X_pred.reshape(-1, 1, X_pred.size(-1))

# y becomes (B*K, 1), e.g. 160x1
y = y.reshape(-1, 1)

print(f"New shape of X_his: {X_his.shape}")
print(f"New shape of X_pred: {X_pred.shape}")
print(f"New shape of y: {y.shape}")

New shape of X_his: torch.Size([32, 30, 768])
New shape of X_pred: torch.Size([160, 1, 768])
New shape of y: torch.Size([160, 1])


So our dataloader works, and the batches seem reasonable. We should decide on
whether to reshape the batches such that the 5 inview articles become 5 distinct
training examples, i.e. flattening (merging) dimensions 32 and 5 into one dimension of 160.

## 4 Implementation of NRMS

#### 4.1 Structure

**0. Hyperparameters stored in a class**

**1. Class:NRMSmodel (based on nn.Module)**
- Should take care of building the NewsEncoder, UserEncoder
- Should store things such as optimizer and loss
- Compute click probabilities and the negative log-likelihood loss
- Score the click probability of the positive sample with a softmax over the positive and its negative samples
- Implement the forward pass of the whole model

**2. Class: NewsEncoder**

- Should be able to take one article and encode it with multi-head self-attention (MHSA),
followed by additive attention
- output should be the encoded article

**3. Class: UserEncoder**

- Should be able to take all of the articles in the history, apply MHSA to them
and then additive attention to obtain the user representation


**Considerations**
- How to work with the dims we have from dataloader?

In [None]:
# Define hyperparameters for the model

class hparams_nrms:
    """Hyperparameters for the NRMS model, stored as class attributes.

    Model code reads them via attribute access (``hparams.<name>``).
    """

    # INPUT DIMENSIONS:
    history_size: int = HISTORY_SIZE  # number of clicked articles per user (30)
    embedding_dim: int = 768          # length of one contrastive article vector
    # Per-article input size read by NewsEncoder and NRMS.score. Each article is
    # a single pre-computed vector, so this equals embedding_dim.
    title_size: int = 768
    batch_size: int = BATCH_SIZE      # kept in sync with the DataLoader (32)
    # CONTEXTUAL DIMENSIONS:
    npratio: int = NPRATIO
    inview_sample_size: int = NPRATIO + 1  # 1 positive + NPRATIO negatives
    # MODEL ARCHITECTURE
    head_num: int = 16
    head_dim: int = 16
    attention_hidden_dim: int = 200   # Multi-head self-attention (mhsa)
    # NewsEncoder outputs head_num * head_dim features and UserEncoder's MHA uses
    # this as embed_dim, so it must equal head_num * head_dim (16 * 16 = 256).
    newsencoder_output_dim: int = 256
    additive_attention_dim: list[int] = [512, 512, 512]  # second attention layer after mhsa
    # MODEL OPTIMIZER:
    optimizer: str = "adam"
    loss: str = "cross_entropy_loss"
    dropout: float = 0.2
    learning_rate: float = 0.001
    seed: int = SEED                  # same seed as the data pipeline (123)

In [None]:
class NRMSmodel(nn.Module):
    """Wrapper that builds the NRMS network together with its loss and optimizer.

    Args:
        hparams: hyperparameter namespace read via attribute access (e.g.
            ``hparams_nrms``). Needs ``optimizer``, ``loss`` and
            ``learning_rate``, plus whatever ``NewsEncoder`` / ``NRMS`` read.
    """

    def __init__(self, hparams):
        super().__init__()
        self.hparams = hparams

        # BUILD MODEL: the full NRMS network and its single-article scorer
        self.model, self.scorer = self._buildNRMS()

        # LOSS FUNCTION
        self.criterion = self._get_loss(self.hparams.loss)

        # OPTIMIZER: created after the model so its parameters are registered
        self.optimizer = self._get_opt(
            optimizer=self.hparams.optimizer, lr=self.hparams.learning_rate
        )

    def _buildNRMS(self):
        """Build the shared news encoder and the NRMS network around it.

        Returns:
            ``(model, scorer)``: the ``NRMS`` module and its ``score`` method,
            which rates one candidate article per sample with a sigmoid.
        """
        self.newsencoder = NewsEncoder(self.hparams)
        model = NRMS(self.hparams, self.newsencoder)
        return model, model.score

    def _get_loss(self, loss):
        """Return the loss module named by ``loss``; raise ValueError if unknown."""
        if loss == "cross_entropy_loss":
            return nn.CrossEntropyLoss()
        raise ValueError(f"Loss function {loss} not supported")

    def _get_opt(self, optimizer, lr):
        """Return an optimizer over this module's parameters; raise ValueError if unknown."""
        if optimizer == "adam":
            return optim.Adam(self.parameters(), lr=lr)
        raise ValueError(f"Optimizer {optimizer} not supported")

    def forward(self, his_input_title, pred_input_title):
        """Delegate to the wrapped NRMS network."""
        return self.model(his_input_title, pred_input_title)

In [None]:
class NRMS(nn.Module):
    """NRMS click predictor: scores candidate articles against a user representation.

    Args:
        hparams: hyperparameter namespace; ``score`` reads ``hparams.title_size``.
        newsencoder: module mapping ``(N, title_size)`` article vectors to
            ``(N, D)`` representations; it is shared with the ``UserEncoder``.
    """

    def __init__(self, hparams, newsencoder):
        super().__init__()
        self.hparams = hparams
        self.newsencoder = newsencoder
        self.userencoder = UserEncoder(hparams, newsencoder)

    def forward(self, his_input_title, pred_input_title):
        """Return click probabilities for every candidate of each impression.

        Args:
            his_input_title: history article vectors, passed to ``self.userencoder``.
            pred_input_title: candidate article vectors of shape
                ``(batch, num_titles, title_size)``.

        Returns:
            ``(batch, num_titles)`` softmax probabilities over the candidates.
            The previous version returned ``argmax`` indices, which are not
            differentiable and so cannot be trained with a loss.
        """
        batch_size, num_titles, title_size = pred_input_title.size()

        # u vector; assumed (batch, D) -- TODO confirm once the UserEncoder's
        # attention pooling layer is implemented.
        user_present = self.userencoder(his_input_title)

        # r vectors: fold the candidate axis into the batch axis, encode all
        # candidates in one pass, then unfold (what TimeDistributed would do).
        news_present = self.newsencoder(pred_input_title.reshape(-1, title_size))
        news_present = news_present.reshape(batch_size, num_titles, -1)

        # Dot product of every candidate with the user, softmax over candidates
        scores = torch.matmul(news_present, user_present.unsqueeze(-1)).squeeze(-1)
        return torch.softmax(scores, dim=-1)

    def score(self, his_input_title, pred_input_title_one):
        """Return the sigmoid click probability of one candidate per sample.

        Returns:
            Tensor of shape ``(batch, 1)``.
        """
        # Encode the user history (u vector)
        user_present = self.userencoder(his_input_title)

        # Encode the single predicted title (r vector for one article)
        pred_title_one_reshape = pred_input_title_one.reshape(-1, self.hparams.title_size)
        news_present_one = self.newsencoder(pred_title_one_reshape)

        # Row-wise dot product. The old matmul of a (B, D) matrix with a
        # (B, D, 1) tensor broadcast to a (B, B) cross-sample matrix.
        pred_one = (news_present_one * user_present).sum(dim=-1, keepdim=True)
        return torch.sigmoid(pred_one)


In [None]:
class NewsEncoder(nn.Module):
    """Encode pre-computed article vectors into ``head_num * head_dim`` features.

    Each article is one ``hparams.title_size``-dimensional vector. It first goes
    through multi-head self-attention as a length-1 sequence. It then goes
    through an MLP with Linear -> ReLU -> BatchNorm1d -> Dropout per hidden
    layer, followed by a final Linear -> ReLU.

    Args:
        hparams: namespace providing ``title_size``, ``head_num``, ``head_dim``
            and ``dropout``.
        units_per_layer: hidden-layer widths of the MLP; ``None`` (default)
            means ``[512, 512, 512]``.
    """

    def __init__(self,
                 hparams,
                 units_per_layer=None):
        super().__init__()
        # Avoid a shared mutable default argument
        if units_per_layer is None:
            units_per_layer = [512, 512, 512]

        self.hparams = hparams
        self.document_vector_dim = hparams.title_size
        self.output_dim = hparams.head_num * hparams.head_dim

        self.multihead_attention = nn.MultiheadAttention(embed_dim=self.document_vector_dim, num_heads=hparams.head_num)

        layers = []
        input_dim = self.document_vector_dim
        for units in units_per_layer:
            layers.append(nn.Linear(input_dim, units))
            layers.append(nn.ReLU())
            layers.append(nn.BatchNorm1d(units))
            layers.append(nn.Dropout(hparams.dropout))
            input_dim = units

        layers.append(nn.Linear(input_dim, self.output_dim))
        layers.append(nn.ReLU())

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Encode a 2-D tensor of articles ``(num_articles, title_size)``.

        Returns:
            Tensor of shape ``(num_articles, head_num * head_dim)``.
        """
        # Give every article its own length-1 sequence: (1, N, E) in the default
        # batch_first=False layout. Passing the 2-D tensor directly is treated
        # by MultiheadAttention as ONE unbatched sequence of N tokens, which
        # would let articles from different samples attend to each other.
        seq = x.unsqueeze(0)
        attn_output, _ = self.multihead_attention(seq, seq, seq)
        # Feed the attention output (previously computed and then discarded)
        # through the MLP.
        return self.model(attn_output.squeeze(0))

In [None]:
class UserEncoder(nn.Module):
    """Build a user representation from the articles in the user's history.

    Args:
        hparams: namespace providing ``newsencoder_output_dim`` and ``head_num``
            (plus whatever ``AttentionLayer2`` reads).
        newsencoder: module mapping ``(N, doc_dim)`` to ``(N, newsencoder_output_dim)``.
    """

    def __init__(self, hparams, newsencoder):
        super().__init__()
        # Shared news encoder, applied to every clicked article in the history
        self.newsencoder = newsencoder
        self.newsencoder_output_dim = hparams.newsencoder_output_dim
        # batch_first=True: the encoded history is (batch, history, dim). With
        # the default batch_first=False, attention would run across the batch
        # axis instead of across the history of a single user.
        self.multihead_attention = nn.MultiheadAttention(
            embed_dim=self.newsencoder_output_dim,
            num_heads=hparams.head_num,
            batch_first=True,
        )
        # NOTE(review): AttentionLayer2 is not defined in this notebook. It is
        # presumably an additive-attention pooling (batch, history, dim) -> (batch, dim).
        # TODO confirm.
        self.attention_layer = AttentionLayer2(hparams)

    def forward(self, x):
        """Encode a history tensor of shape ``(batch, history_size, doc_dim)``."""
        batch_size, history_size, doc_dim = x.size()

        # Encode the news history: fold history into the batch axis for the
        # news encoder, then unfold (replaces the undefined TimeDistributed wrapper)
        encoded_news = self.newsencoder(x.reshape(-1, doc_dim))
        encoded_news = encoded_news.reshape(batch_size, history_size, -1)

        # Self-attention across the articles of each user's history
        attn_output, _ = self.multihead_attention(encoded_news, encoded_news, encoded_news)

        # Pool the attended history into the user representation
        user_representation = self.attention_layer(attn_output)

        return user_representation
