In [None]:
from tqdm import tqdm
import random
import torch
import torch.nn as nn
from torch.distributions import Categorical
import numpy as np
import pandas as pd
import os

def set_random_seed(seed):
    '''
    Seed the random number generators used in this file.

    @input:
    - seed: value handed, in this order, to Python's `random`, NumPy,
            torch (CPU) and torch (all CUDA devices)
    '''
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

def check_folder_exist(fpath):
    '''
    Make sure the directory `fpath` exists, creating it (one level only) if needed.

    @input:
    - fpath: path of the directory to check/create

    Prints a message when the path already exists or when creation fails;
    a failed creation is reported, not raised (best effort).
    '''
    if os.path.exists(fpath):
        print("dir \"" + fpath + "\" existed")
        return
    try:
        os.mkdir(fpath)
    except OSError:
        # Only file-system failures are tolerated here (missing parent, no
        # permission, ...); anything else, e.g. KeyboardInterrupt, propagates.
        print("error when creating \"" + fpath + "\"")

def setup_path(fpath, is_dir = True):
    '''
    Create every directory along a "/"-separated path, one level at a time.

    @input:
    - fpath: "/"-separated path
    - is_dir: when False the last component is treated as a file name and
              is not created
    '''
    components = fpath.split("/")
    if not is_dir:
        components = components[:-1]
    prefix = ""
    for component in components:
        prefix += component
        # An empty component comes from a leading "/", a doubled "//" or a
        # trailing "/". It names no new directory, so there is nothing to
        # create; checking it would try mkdir("") for absolute paths.
        if component:
            check_folder_exist(prefix)
        prefix += "/"


# Data Related            
def repeat_n_core(df, user_col_id, item_col_id, n_core, user_counts, item_counts):
    '''
    Iterative n_core filter

    Repeatedly drops records whose user or item currently has fewer than
    n_core records, until a full pass removes nothing.

    @input:
    - df: [UserID, ItemID, ...]
    - user_col_id: position of the user column in a row of df.values
    - item_col_id: position of the item column in a row of df.values
    - n_core: number of core, clamped to [5, 100]
    - user_counts: {uid: frequency}; decremented IN PLACE as records are removed
    - item_counts: {iid: frequency}; decremented IN PLACE as records are removed

    @output:
    - pd.DataFrame built from the surviving rows, with df's column names.
      If the filter is judged not to converge, the original rows are returned.
    '''
    print("N-core is set to [5,100]")
    n_core = min(max(n_core, 5),100) # 5 <= n_core <= 100
    print("Filtering " + str(n_core) + "-core data")
    original_size = len(df)
    kept_rows = df.values
    last_removed = original_size  # number of records removed in the previous pass
    iteration = 0

    # each pass drops the records that violate n_core given the current counts
    while last_removed != 0:
        iteration += 1
        print("Iteration " + str(iteration))
        removed = 0
        survivors = []
        for row in tqdm(kept_rows):
            user, item = row[user_col_id], row[item_col_id]
            if user_counts[user] < n_core or item_counts[item] < n_core:
                # counts are updated immediately, so later rows in the same
                # pass already see the reduced frequencies
                user_counts[user] -= 1
                item_counts[item] -= 1
                removed += 1
            else:
                survivors.append(row)
        kept_rows = survivors
        print("Number of removed record: " + str(removed))
        if removed > last_removed + 10000:
            print("Not converging, will use original data")
            # Honour the message: fall back to the unfiltered rows instead of
            # returning a partially filtered set.
            kept_rows = df.values
            break
        last_removed = removed
    print("Size change: " + str(original_size) + " --> " + str(len(kept_rows)))
    return pd.DataFrame(kept_rows, columns=df.columns)

def run_multicore(df, user_key = "user_id", item_key = "item_id", n_core = 10, auto_core = False, filter_rate = 0.2):
    '''
    Count user/item frequencies, optionally pick n_core automatically, then
    run the iterative n-core filter.

    @input:
    - df: pd.DataFrame, col:[UserID,ItemID,...]
    - user_key: name of the user column
    - item_key: name of the item column
    - n_core: number of core
    - auto_core: automatically find n_core, set to True will ignore n_core
                 (n_core is kept when no frequency crosses the threshold)
    - filter_rate: proportion of removal for user/item, require auto_core = True

    @output:
    - filtered pd.DataFrame as returned by repeat_n_core
    '''
    print(f"Filter {n_core if not auto_core else 'auto'}-core data.")
    uCounts = df[user_key].value_counts().to_dict() # {user_id: count}
    iCounts = df[item_key].value_counts().to_dict() # {item_id: count}

    # automatically find n_core based on filter rate
    if auto_core:
        print("Automatically find n_core that filter " + str(100*filter_rate) + "% of user/item")

        nCoreCounts = dict() # {frequency: [#user, #item]}
        for c in iCounts.values():
            nCoreCounts.setdefault(c, [0, 0])[1] += 1
        for c in uCounts.values():
            nCoreCounts.setdefault(c, [0, 0])[0] += 1

        # find n_core for: filtered data < filter_rate * length(data)
        removal_budget = filter_rate * len(df)
        userToRemove = 0 # number of user records to remove
        itemToRemove = 0 # number of item records to remove
        for c,counts in sorted(nCoreCounts.items()):
            userToRemove += counts[0] * c # #user * #core
            itemToRemove += counts[1] * c # #item * #core
            if userToRemove > removal_budget or itemToRemove > removal_budget:
                n_core = c
                print("Autocore = " + str(n_core))
                break
    else:
        print("n_core = " + str(n_core))

    # repeat_n_core works on positional columns of df.values: resolve the
    # positions from the keys instead of assuming users/items sit in
    # columns 0 and 1.
    user_col_id = df.columns.get_loc(user_key)
    item_col_id = df.columns.get_loc(item_key)
    return repeat_n_core(df, user_col_id, item_col_id, n_core, uCounts, iCounts)

def padding_and_clip(sequence, max_len, padding_direction = 'left'):
    '''
    Force a sequence to exactly max_len elements.

    @input:
    - sequence: list of IDs (must be a list whenever padding is required)
    - max_len: target length
    - padding_direction: 'left' pads with 0 on the left and keeps the LAST
                         max_len elements; any other value pads on the right
                         and keeps the FIRST max_len elements

    @output:
    - sequence of length max_len
    '''
    pad_left = padding_direction == 'left'
    missing = max_len - len(sequence)
    if missing > 0:
        zeros = [0] * missing
        sequence = zeros + sequence if pad_left else sequence + zeros
    if pad_left:
        # Slice from an explicit start: `sequence[-max_len:]` would return the
        # whole sequence when max_len == 0.
        return sequence[len(sequence) - max_len:]
    return sequence[:max_len]

def get_onehot_vocab(meta_df, features):
    '''
    Build a one-hot vector for every distinct value of each feature.

    @input:
    - meta_df: pd.DataFrame holding the feature columns
    - features: [feature_name]

    @output:
    - vocab: {feature_name: {feature_value: one-hot np.array}}, vector length
             is the number of distinct values, positions follow the order of
             meta_df[feature].unique()
    '''
    print('build vocab for onehot features')
    vocab = {}
    for feature in tqdm(features):
        distinct_values = list(meta_df[feature].unique())
        dim = len(distinct_values)
        feature_vocab = {}
        for position, value in enumerate(distinct_values):
            one_hot = np.zeros(dim)
            one_hot[position] = 1
            feature_vocab[value] = one_hot
        vocab[feature] = feature_vocab
    return vocab

def get_multihot_vocab(meta_df, features, sep = ','):
    '''
    Build a one-hot vector for every distinct token of each multi-valued feature.

    @input:
    - meta_df: pd.DataFrame holding the feature columns
    - features: [feature_name]
    - sep: separator between tokens inside one cell

    @output:
    - vocab: {feature_name: {token: one-hot np.array}}; each cell is converted
             with str() and split on sep (tokens are not stripped), positions
             follow the order of first appearance
    '''
    print('build vocab for multihot features:')
    vocab = {}
    for feature in features:
        print(f'\t{feature}')
        # token -> position, assigned at first sight
        positions = {}
        for cell in tqdm(meta_df[feature]):
            for token in str(cell).split(sep):
                positions.setdefault(token, len(positions))
        dim = len(positions)
        feature_vocab = {}
        for token, position in positions.items():
            one_hot = np.zeros(dim)
            one_hot[position] = 1
            feature_vocab[token] = one_hot
        vocab[feature] = feature_vocab
    return vocab

def get_ID_vocab(meta_df, features):
    '''
    Map every distinct value of each feature to an integer ID starting at 1.

    @input:
    - meta_df: pd.DataFrame holding the feature columns
    - features: [feature_name]

    @output:
    - vocab: {feature_name: {feature_value: ID}}, IDs follow the order of
             meta_df[feature].unique()
    '''
    print('build vocab for encoded ID features')
    vocab = {}
    for feature in tqdm(features):
        distinct_values = list(meta_df[feature].unique())
        vocab[feature] = {value: encoded for encoded, value in enumerate(distinct_values, start=1)}
    return vocab

def get_multiID_vocab(meta_df, features, sep = ','):
    '''
    Map every distinct token of each multi-valued feature to an integer ID
    starting at 1.

    @input:
    - meta_df: pd.DataFrame holding the feature columns
    - features: [feature_name]
    - sep: separator between tokens inside one cell

    @output:
    - vocab: {feature_name: {token: ID}}; each cell is converted with str()
             and split on sep, IDs follow the order of first appearance
    '''
    print('build vocab for encoded ID features')
    vocab = {}
    for feature in features:
        print(f'\t{feature}:')
        encoded = {}
        for cell in tqdm(meta_df[feature]):
            for token in str(cell).split(sep):
                # a new token gets the next free ID (1-based)
                encoded.setdefault(token, len(encoded) + 1)
        vocab[feature] = encoded
    return vocab


def show_batch(batch):
    '''
    Print every entry of a batch dict.

    @input:
    - batch: {name: value}; tensors are shown with their shape and their
             first 5 elements, anything else is printed as is
    '''
    for name, value in batch.items():
        if not torch.is_tensor(value):
            print(f"{name}: {value}")
            continue
        print(f"{name}: size {value.shape}, \n\tfirst 5 {value[:5]}")


def wrap_batch(batch, device):
    '''
    Build feed_dict from batch data and move data to device

    @input:
    - batch: {name: value}; numpy arrays, numpy scalars, tensors and lists
             are converted to tensors, every other value is left untouched
    - device: target device passed to Tensor.to

    @output:
    - the same dict object, updated in place; float64 tensors are cast to
      float32 before being moved
    '''
    for k,val in batch.items():
        if isinstance(val, np.ndarray):
            tensor = torch.from_numpy(val)
        elif isinstance(val, np.generic):
            # numpy scalars (np.float64, np.int64, ...) are not accepted by
            # torch.from_numpy
            tensor = torch.tensor(val)
        elif torch.is_tensor(val):
            tensor = val
        elif isinstance(val, list):
            tensor = torch.tensor(val)
        else:
            continue
        # Compare the dtype rather than the type string: a double tensor that
        # already lives on a GPU reports "torch.cuda.DoubleTensor".
        if tensor.dtype == torch.float64:
            tensor = tensor.float()
        batch[k] = tensor.to(device)
    return batch



# Model Related 

def _zero_padding_row(embedding):
    '''Zero the padding row of an embedding-like module, if it declares one.'''
    padding_idx = getattr(embedding, 'padding_idx', None)
    # Indexing the weight with None selects the WHOLE matrix, so a module
    # without padding index must be skipped or all its weights get erased.
    if padding_idx is None:
        return
    with torch.no_grad():
        embedding.weight[padding_idx].fill_(0.)


def init_weights(m):
    '''
    Initialise one module in place; the module kind is detected from the
    name of its type.

    @input:
    - m: module whose type name contains
        - 'Linear': Xavier-normal weight, bias ~ N(0, 0.01) when present
        - 'Embedding': Xavier-normal weight (printed afterwards), padding row
                       reset to zero when padding_idx is set
        - 'ModuleDict': every value gets a Xavier-normal weight and a zeroed
                        padding row when it has a padding_idx
      any other module is left untouched
    '''
    type_name = str(type(m))
    if 'Linear' in type_name:
        nn.init.xavier_normal_(m.weight, gain=1.)
        if m.bias is not None:
            nn.init.normal_(m.bias, mean=0.0, std=0.01)
    elif 'Embedding' in type_name:
        nn.init.xavier_normal_(m.weight, gain=1.0)
        print("embedding: " + str(m.weight.data))
        _zero_padding_row(m)
    elif 'ModuleDict' in type_name:
        for sub_module in m.values():
            nn.init.xavier_normal_(sub_module.weight, gain=1.)
            _zero_padding_row(sub_module)


def get_regularization(*modules):
    '''
    Sum, over every parameter of every given module, the mean of its
    squared entries.

    @input:
    - modules: objects exposing .parameters()

    @output:
    - scalar tensor, or the int 0 when there is no parameter at all
    '''
    total = 0
    every_param = (param for module in modules for param in module.parameters())
    for param in every_param:
        total = torch.mean(param * param) + total
    return total


def soft_update(target, source, tau):
    '''
    Blend the parameters of source into target, in place:
    target <- (1 - tau) * target + tau * source

    @input:
    - target: module that is updated
    - source: module that is read; parameters are paired with target's in
              iteration order
    - tau: interpolation weight given to source
    '''
    keep = 1.0 - tau
    for dst, src in zip(target.parameters(), source.parameters()):
        blended = dst.data * keep + src.data * tau
        dst.data.copy_(blended)

def hard_update(target, source):
    '''
    Overwrite the parameters of target with those of source, in place.

    @input:
    - target: module that is updated
    - source: module that is read; parameters are paired with target's in
              iteration order
    '''
    paired_params = zip(target.parameters(), source.parameters())
    for dst, src in paired_params:
        dst.data.copy_(src.data)

def sample_categorical_action(action_prob, candidate_ids, slate_size, with_replacement = True,
                              batch_wise = False, return_idx = False):
    '''
    Sample a slate of candidates per batch row from a categorical distribution.

    @input:
    - action_prob: (B, L)
    - candidate_ids: gathered along dim 1 when batch_wise, otherwise indexed
                     directly with the sampled positions
                     NOTE(review): documented upstream as (B, L) or (1, L);
                     direct indexing looks like it expects a 1-D tensor — confirm
    - slate_size: K
    - with_replacement: sample with replacement
    - batch_wise: do batch wise candidate selection
    - return_idx: also return the sampled positions

    @output:
    - action: detached candidate ids selected by the sampled positions
    - indices: (B, K) detached sampled positions, only when return_idx
    '''
    if not with_replacement:
        # one draw without replacement per batch row
        row_draws = [torch.multinomial(prob, slate_size, replacement = False).view(1,-1)
                     for prob in action_prob]
        indices = torch.cat(row_draws, dim = 0)
    else:
        # (K, B) independent draws, transposed to (B, K)
        draws = Categorical(action_prob).sample(sample_shape = (slate_size,))
        indices = torch.transpose(draws, 0, 1)

    if batch_wise:
        action = torch.gather(candidate_ids, 1, indices)
    else:
        action = candidate_ids[indices]

    action = action.detach()
    if not return_idx:
        return action
    return action, indices.detach()



# Learning               

class LinearScheduler(object):
    '''
    Linear interpolation from initial_p to final_p over schedule_timesteps
    steps; the value stays at final_p afterwards.

    Code used in DQN: https://github.com/dxyang/DQN_pytorch/blob/master/utils/schedules.py
    '''

    def __init__(self, schedule_timesteps, final_p, initial_p=1.0):
        '''
        - schedule_timesteps: number of steps needed to reach final_p
        - final_p: value at the end of the schedule
        - initial_p: value at step 0
        '''
        self.initial_p          = initial_p
        self.final_p            = final_p
        self.schedule_timesteps = schedule_timesteps

    def value(self, t):
        """Return the scheduled value at step t."""
        progress = float(t) / self.schedule_timesteps
        if progress > 1.0:
            progress = 1.0
        span = self.final_p - self.initial_p
        return self.initial_p + progress * span


class SinScheduler(object):
    '''
    Sine-shaped interpolation from initial_p to final_p over
    schedule_timesteps steps (quarter period of a sine); the value stays at
    final_p afterwards.

    Code used in DQN: https://github.com/dxyang/DQN_pytorch/blob/master/utils/schedules.py
    '''

    def __init__(self, schedule_timesteps, final_p, initial_p=1.0):
        '''
        - schedule_timesteps: number of steps needed to reach final_p
        - final_p: value at the end of the schedule
        - initial_p: value at step 0
        '''
        self.initial_p          = initial_p
        self.final_p            = final_p
        self.schedule_timesteps = schedule_timesteps

    def value(self, t):
        """Return the scheduled value at step t."""
        progress = min(float(t) / self.schedule_timesteps, 1.0)
        eased = np.sin(progress * np.pi * 0.5)
        span = self.final_p - self.initial_p
        return self.initial_p + eased * span


In [None]:
# Load the rating log, the item profiles and the user profiles (comma
# separated files expected in the working directory).
# NOTE(review): the variables are named `music_*` while the files are
# `movie_*` — confirm which naming is intended.
music_rating, music_profile, user_profile = (
    pd.read_table(csv_path, sep = ',')
    for csv_path in ('movie_rating_kuaisim.csv', 'movie_profiles_kuaisim.csv', 'profiles.csv')
)
# distinct item ids of the profile table, in order of appearance
items = list(music_profile['item_id'].unique())
# distinct item ids of the rating log, sorted
check_items = pd.Series(list(music_rating['item_id'].unique())).sort_values().tolist()

In [None]:
import os
import pickle
import argparse
import numpy as np
import pandas as pd
import torch
import torch.utils.data as data
from torch.utils.data import Dataset
from tqdm import tqdm


def worker_init_func(worker_id):
    '''
    Record the worker id on the dataset object of the current DataLoader worker.

    @input:
    - worker_id: id stored as `dataset.worker_id`

    Relies on data.get_worker_info() returning a worker description.
    '''
    data.get_worker_info().dataset.worker_id = worker_id

#############################################################################
#                              Dataset Class                                #
#############################################################################

class BaseReader(Dataset):
    '''
    Base dataset: loads the train/val/test tables and exposes the table of
    the currently selected phase.

    - phase: one of ["train", "val", "test"]
    - data: {phase: pd.DataFrame}
    - n_worker: number of workers for the dataset loader
    '''

    @staticmethod
    def parse_data_args(parser):
        '''
        Register the reader options on parser and return it:
        - train_file (required)
        - val_file
        - test_file
        - n_worker
        - data_separator
        '''
        option_specs = (
            ('--train_file', dict(type=str, required=True, help='train data file_path')),
            ('--val_file', dict(type=str, default='', help='val data file_path')),
            ('--test_file', dict(type=str, default='', help='test data file_path')),
            ('--n_worker', dict(type=int, default=4, help='number of worker for dataset loader')),
            ('--data_separator', dict(type=str, default='\t', help='separator of csv file')),
        )
        for flag, spec in option_specs:
            parser.add_argument(flag, **spec)
        return parser

    def log(self):
        '''Print the reader parameters followed by the reader statistics.'''
        print("Reader params:")
        print(f"\tn_worker: {self.n_worker}")
        statistics = self.get_statistics()
        for stat_name in statistics:
            print(f"\t{stat_name}: {statistics[stat_name]}")

    def __init__(self, args):
        '''
        Start in the "train" phase and load the data files.

        - args: parsed options, see parse_data_args
        '''
        self.phase = "train"
        self.n_worker = args.n_worker
        self._read_data(args)

    def _read_data(self, args):
        '''
        Fill self.data; a missing val file falls back to the train table and
        a missing test file falls back to the val table.
        '''
        self.data = dict()
        print("Loading data files", end = '\r')

        def load(file_path):
            return pd.read_table(file_path, sep = args.data_separator)

        self.data['train'] = load(args.train_file)
        if len(args.val_file) > 0:
            self.data['val'] = load(args.val_file)
        else:
            self.data['val'] = self.data['train']
        if len(args.test_file) > 0:
            self.data['test'] = load(args.test_file)
        else:
            self.data['test'] = self.data['val']

    def get_statistics(self):
        '''Return {'length': number of records in the current phase}.'''
        return {'length': len(self)}

    def set_phase(self, phase):
        '''Select which table is served; phase is "train", "val" or "test".'''
        assert phase in ["train", "val", "test"]
        self.phase = phase

    def get_train_dataset(self):
        '''Switch to the train phase and return the reader itself.'''
        self.set_phase("train")
        return self

    def get_eval_dataset(self, phase = 'val'):
        '''Switch to the given evaluation phase and return the reader itself.'''
        self.set_phase(phase)
        return self

    def __len__(self):
        current_table = self.data[self.phase]
        return len(current_table)

    def __getitem__(self, idx):
        # left to the concrete readers; the base class serves nothing
        return None

In [None]:
import numpy as np
import pandas as pd
from tqdm import tqdm
import math


class KRMBSeqReader(BaseReader):
    '''
    KuaiRand Multi-Behavior Data Reader

    Serves one interaction record per sample, together with the user and
    item meta features and the (left padded) interaction history that
    precedes the record.
    '''

    @staticmethod
    def parse_data_args(parser):
        '''
        args:
        - user_meta_file
        - item_meta_file
        - max_hist_seq_len
        - val_holdout_per_user
        - test_holdout_per_user
        - meta_file_sep
        - from BaseReader:
            - train_file
            - val_file
            - test_file
            - n_worker
        '''
        parser = BaseReader.parse_data_args(parser)
        parser.add_argument('--user_meta_file', type=str, required=True,
                            help='user raw feature file_path')
        parser.add_argument('--item_meta_file', type=str, required=True,
                            help='item raw feature file_path')
        parser.add_argument('--max_hist_seq_len', type=int, default=100,
                            help='maximum history length in the sample')
        parser.add_argument('--val_holdout_per_user', type=int, default=5,
                            help='number of holdout records for val set')
        parser.add_argument('--test_holdout_per_user', type=int, default=5,
                            help='number of holdout records for test set')
        parser.add_argument('--meta_file_sep', type=str, default=',',
                            help='separater of user/item meta csv file')
        return parser

    def log(self):
        '''Print the BaseReader report followed by the holdout sizes.'''
        super().log()
        print(f"\tval_holdout_per_user: {self.val_holdout_per_user}")
        print(f"\ttest_holdout_per_user: {self.test_holdout_per_user}")

    def __init__(self, args):
        '''
        - max_hist_seq_len
        - val_holdout_per_user
        - test_holdout_per_user
        - from BaseReader:
            - phase
            - n_worker
        '''
        print("initiate KuaiRandMultiBehaior sequence reader")
        # must be set before BaseReader.__init__, which triggers _read_data
        self.max_hist_seq_len = args.max_hist_seq_len
        self.val_holdout_per_user = args.val_holdout_per_user
        self.test_holdout_per_user = args.test_holdout_per_user
        super().__init__(args)

    def _read_data(self, args):
        '''
        - log_data: pd.DataFrame
        - data: {'train': [row_id], 'val': [row_id], 'test': [row_id]}
        - users: [user_id]
        - user_id_vocab: {user_id: encoded_user_id}
        - user_meta: {user_id: {feature_name: feature_value}}
        - user_vocab: {feature_name: {feature_value: one-hot vector}}
        - selected_user_features
        - items: [item_id], sorted
        - itemns: [item_id], in order of appearance
        - item_id_vocab: {item_id: encoded_item_id}
        - item_meta: {item_id: {feature_name: feature_value}}
        - item_vocab: {feature_name: {feature_value: one-hot vector}}
        - selected_item_features: [feature_name]
        - padding_item_meta: {feature_name: zero vector}
        - user_history: {uid: [row_id]}
        - response_list: [response_type]
        - response_dim: number of response types
        - padding_response: {response_type: 0}
        - response_neg_sample_rate: {response_type: weight}
        '''

        # read data_file
        print("Loading data files")
        self.log_data = pd.read_table(args.train_file, sep = args.data_separator)

        print("Load item meta data")
        item_meta_file = pd.read_csv(args.item_meta_file, sep = args.meta_file_sep)
        self.item_meta = item_meta_file.set_index('item_id').to_dict('index')
        print("Load user meta data")
        user_meta_file = pd.read_csv(args.user_meta_file, sep = args.meta_file_sep)
        self.user_meta = user_meta_file.set_index('user_id').to_dict('index')

        # user list, item list, user history
        self.users = list(self.log_data['user_id'].unique())
        self.itemns = list(self.log_data['item_id'].unique())  # unsorted item list
        self.items = pd.Series(self.itemns).sort_values().tolist()  # sorted item list
        self.user_history = self._build_user_history()

        # id reindex, 0 is left for padding
        self.user_id_vocab = {uid: i+1 for i,uid in enumerate(self.users)}
        self.item_id_vocab = {iid: i+1 for i,iid in enumerate(self.items)}

        # selected meta features
        self.selected_item_features = ['title', 'genres']
        self.selected_user_features = ['name', 'gender', 'age']

        # meta feature vocabulary, {feature_name: {feature_value: one-hot/multi-hot vector}}
        self.user_vocab = get_onehot_vocab(user_meta_file, self.selected_user_features)
        self.item_vocab = get_onehot_vocab(item_meta_file, self.selected_item_features[:-1])
        self.item_vocab.update(get_multihot_vocab(item_meta_file, ['genres']))
        self.padding_item_meta = {f: np.zeros_like(list(v_dict.values())[0]) \
                                  for f, v_dict in self.item_vocab.items()}

        # response meta
        self.response_list = ['is_like','is_follow','is_comment','is_forward','is_hate']
        self.response_dim = len(self.response_list)
        self.padding_response = {resp: 0. for resp in self.response_list}
        self.response_neg_sample_rate = self.get_response_weights()

        # {'train': [row_id], 'val': [row_id], 'test': [row_id]}
        self.data = self._sequence_holdout(args)

    def _build_user_history(self):
        '''
        Group the row ids of log_data by user in a single pass.

        @output:
        - {uid: [row_id]}, row ids kept in log order
        '''
        history = {uid: [] for uid in self.users}
        for row_id, uid in zip(self.log_data.index, self.log_data['user_id']):
            history.setdefault(uid, []).append(row_id)
        return history

    def _sequence_holdout(self, args):
        '''
        Holdout validation and test set from log_data

        For every user the last test_holdout_per_user records go to test,
        the val_holdout_per_user records before them go to val and the rest
        goes to train. Users left with less than 60% of their records for
        training are dropped from all three sets.

        @output:
        - {'train': [row_id], 'val': [row_id], 'test': [row_id]}
        '''
        val_n = args.val_holdout_per_user
        test_n = args.test_holdout_per_user
        print(f"sequence holdout for users (-1, {val_n}, {test_n})")
        if val_n == 0 and test_n == 0:
            return {"train": self.log_data.index, "val": [], "test": []}
        data = {"train": [], "val": [], "test": []}
        for u in tqdm(self.users):
            user_rows = self.user_history[u]
            n_train = len(user_rows) - val_n - test_n
            if n_train < 0.6 * len(user_rows):
                continue
            val_end = n_train + val_n
            data['train'].append(user_rows[:n_train])
            data['val'].append(user_rows[n_train:val_end])
            # Slice from the explicit start: `[-test_n:]` would select ALL the
            # rows of the user when test_n == 0.
            data['test'].append(user_rows[val_end:])
        for k,v in data.items():
            # np.concatenate refuses an empty list (no user kept)
            data[k] = np.concatenate(v) if len(v) > 0 else np.array([], dtype=int)
        return data

    def get_response_weights(self):
        '''
        Ratio #positive / #negative of every response type in log_data.

        @output:
        - {response_type: ratio}, the 'is_hate' ratio is stored negated
          (__getitem__ negates it again)

        NOTE(review): counts[1] / counts[0] assumes every response column
        holds both values 0 and 1 — confirm against the data.
        '''
        ratio = {}
        for f in self.response_list:
            counts = self.log_data[f].value_counts()
            ratio[f] = float(counts[1]) / counts[0]
        ratio['is_hate'] *= -1
        return ratio


    ###########################
    #        Iterator         #
    ###########################

    def __getitem__(self, idx):
        '''
        sample getter

        train batch after collate (presumably, given default collation):
        {
            'user_id': (B,)
            'item_id': (B,)
            'is_like', 'is_follow', ...: (B,), one entry per response type
            'uf_{feature}': (B,F_dim(feature)), user features
            'if_{feature}': (B,F_dim(feature)), item features
            'history': (B,max_H)
            'history_length': (B,)
            'history_if_{feature}': (B, max_H, F_dim(feature))
            'history_{response}': (B, max_H)
            'loss_weight': (B, n_response)
        }
        '''
        row_id = self.data[self.phase][idx]
        # NOTE(review): row ids are index labels used positionally; the two
        # coincide for the default RangeIndex produced by read_table.
        row = self.log_data.iloc[row_id]

        user_id = row['user_id'] # raw user ID
        item_id = row['item_id'] # raw item ID

        # user, item, responses
        record = {
            'user_id': self.user_id_vocab[user_id], # encoded user ID
            'item_id': self.item_id_vocab[item_id], # encoded item ID
        }
        for f in self.response_list:
            record[f] = row[f]

        # positives weigh 1, negatives weigh the response's sampling rate;
        # the stored 'is_hate' rate is negated here
        loss_weight = []
        for f in self.response_list:
            if record[f] == 1:
                loss_weight.append(1.)
            elif f == 'is_hate':
                loss_weight.append(-self.response_neg_sample_rate[f])
            else:
                loss_weight.append(self.response_neg_sample_rate[f])
        record["loss_weight"] = np.array(loss_weight)

        # meta features
        record.update(self.get_user_meta_data(user_id))
        record.update(self.get_item_meta_data(item_id))

        # history features (max_H,): the user's records that precede this one
        H_rowIDs = [rid for rid in self.user_history[user_id] if rid < row_id][-self.max_hist_seq_len:]
        history, hist_length, hist_meta, hist_response = self.get_user_history(H_rowIDs)
        record['history'] = np.array(history)
        record['history_length'] = hist_length
        # hist_meta keys are already like 'history_if_title', 'history_if_genres'
        record.update(hist_meta)
        # hist_response keys are already like 'history_is_like', ...
        record.update(hist_response)

        return record

    def get_user_meta_data(self, user_id):
        '''
        @input:
        - user_id: raw user ID
        @output:
        - user_meta_record: {'uf_{feature_name}: one-hot vector'}
        '''
        user_feature_dict = self.user_meta[user_id]
        user_meta_record = {f'uf_{f}': self.user_vocab[f][user_feature_dict[f]]\
                            for f in self.selected_user_features}
        return user_meta_record

    def get_item_meta_data(self, item_id):
        '''
        @input:
        - item_id: raw item ID
        @output:
        - item_meta_record: {'if_{feature_name}': vector}, one-hot for the
          single-valued features, multi-hot (sum of the tag vectors) for
          'genres'; an all-zero vector when no tag is known
        '''
        item_feature_dict = self.item_meta[item_id]

        # single-valued features; genres is multi-hot and handled below
        item_meta_record = {
            f'if_{f}': self.item_vocab[f][item_feature_dict[f]]
            for f in self.selected_item_features
            if f != "genres" and f in self.item_vocab
        }

        # Multi-hot encoding for genres.
        genre_vocab = self.item_vocab["genres"]
        # str() mirrors how the vocabulary is built from the raw cells, so a
        # non-string cell (e.g. a missing value) does not crash the split
        genres_val = str(item_feature_dict.get("genres", ""))
        vecs = []
        for raw_tag in genres_val.split(","):
            tag = raw_tag.strip()
            if not tag:
                continue
            if tag in genre_vocab:
                vecs.append(genre_vocab[tag])
            elif raw_tag in genre_vocab:
                # vocabulary tokens are not stripped when they are built
                vecs.append(genre_vocab[raw_tag])
        if vecs:
            item_meta_record["if_genres"] = np.sum(vecs, axis=0)
        else:
            item_meta_record["if_genres"] = np.zeros_like(next(iter(genre_vocab.values())))

        return item_meta_record


    def get_user_history(self, H_rowIDs):
        '''
        @input:
        - H_rowIDs: [row_id] of the history records, oldest first
        @output:
        - history: [encoded item ID], left padded with 0 to max_hist_seq_len
        - L: number of real history records
        - hist_meta: {'history_if_{feature_name}': (max_H, F_dim) array}
        - history_response: {'history_{response}': (max_H,) float array}
        '''
        max_H = self.max_hist_seq_len
        L = len(H_rowIDs)
        if L == 0:
            # nothing to look up: every history field is pure padding
            history = [0] * max_H
            hist_meta = {}
            for f, v_dict in self.item_vocab.items():
                pad_vec = np.zeros_like(next(iter(v_dict.values())))
                hist_meta[f'history_if_{f}'] = np.tile(pad_vec, (max_H, 1))
            history_response = {
                f'history_{resp}': np.zeros(max_H, dtype=float)
                for resp in self.response_list
            }
            return history, 0, hist_meta, history_response

        # L > 0
        H = self.log_data.iloc[H_rowIDs]
        item_ids = [self.item_id_vocab[iid] for iid in H['item_id']]
        history = padding_and_clip(item_ids, max_H)
        pad_rows = max_H - L

        # build per-item meta then pad to (max_H, feat_dim)
        meta_list = [self.get_item_meta_data(iid) for iid in H['item_id']]
        hist_meta = {}
        for f in self.item_vocab.keys():  # e.g., 'title', 'genres'
            feat_stack = np.stack([m[f'if_{f}'] for m in meta_list], axis=0)  # (L, feat_dim)
            if pad_rows > 0:
                pad = np.tile(np.zeros_like(feat_stack[0]), (pad_rows, 1))  # (pad_rows, feat_dim)
                feat_stack = np.concatenate([pad, feat_stack], axis=0)      # (max_H, feat_dim)
            hist_meta[f'history_if_{f}'] = feat_stack

        history_response = {}
        for resp in self.response_list:
            arr = np.array(H[resp], dtype=float)  # (L,)
            if pad_rows > 0:
                arr = np.concatenate([np.zeros(pad_rows, dtype=float), arr], axis=0)  # (max_H,)
            history_response[f'history_{resp}'] = arr

        return history, L, hist_meta, history_response


    def get_statistics(self):
        '''
        - raw_data_size
        - data_size: [#train, #val, #test]
        - n_user
        - n_item
        - max_seq_len
        - user_features, user_feature_dims
        - item_features, item_feature_dims
        - feedback_type, feedback_size
        - feedback_negative_sample_rate
        '''
        stats = {}
        stats["raw_data_size"] = len(self.log_data)
        stats["data_size"] = [len(self.data['train']), len(self.data['val']), len(self.data['test'])]
        stats["n_user"] = len(self.users)
        stats["n_item"] = len(self.items)
        stats["max_seq_len"] = self.max_hist_seq_len
        stats["user_features"] = self.selected_user_features
        stats["user_feature_dims"] = {f: len(list(v_dict.values())[0]) for f, v_dict in self.user_vocab.items()}
        stats["item_features"] = self.selected_item_features
        stats["item_feature_dims"] = {f: len(list(v_dict.values())[0]) for f, v_dict in self.item_vocab.items()}
        stats["feedback_type"] = self.response_list
        stats["feedback_size"] = self.response_dim
        stats["feedback_negative_sample_rate"] = self.response_neg_sample_rate
        return stats

In [None]:
import torch
import torch.nn as nn
import numpy as np

class BaseModel(nn.Module):
    '''
    Common scaffolding of the models: argument parsing, logging, parameter
    report, batch wrapping, checkpointing and training hooks.

    Subclasses implement _define_params, get_forward and get_loss.
    NOTE(review): save_checkpoint/load_from_checkpoint read self.optimizer,
    which is never set in this class — presumably attached by the trainer.
    '''

    #############################
    #     Optional Overwrite    #
    #############################

    @staticmethod
    def parse_model_args(parser):
        '''
        args:
        - model_path
        - loss
        - l2_coef
        '''
        parser.add_argument('--model_path', type=str, default='',
                            help='Model save path.')
        parser.add_argument('--loss', type=str, default='bce',
                            help='loss type')
        parser.add_argument('--l2_coef', type=float, default=0.,
                            help='coefficient of regularization term')
        return parser

    def log(self):
        '''Print the model parameters.'''
        print("Model params")
        print("\tmodel_path = " + str(self.model_path))
        print("\tloss_type = " + str(self.loss_type))
        print("\tl2_coef = " + str(self.l2_coef))
        print("\tdevice = " + str(self.device))

    def __init__(self, *input_args):
        '''
        @input (positional, in this order):
        - args: parsed options, see parse_model_args
        - reader_stats: statistics dict of the data reader
        - device: device the batches are moved to
        '''
        args, reader_stats, device = input_args
        nn.Module.__init__(self)
        self.display_name = "BaseModel"
        self.reader_stats = reader_stats
        self.model_path = args.model_path
        self.loss_type = args.loss
        self.l2_coef = args.l2_coef
        # NOTE(review): True when 0 < l2_coef < 1, which reads inverted for a
        # flag named "no_reg" — confirm against the code using it.
        self.no_reg = 0. < args.l2_coef < 1.
        self.device = device

        self._define_params(args, reader_stats)

        self.sigmoid = nn.Sigmoid()

    def get_regularization(self, *modules):
        '''Regularization term of the given modules (module-level helper).'''
        return get_regularization(*modules)

    @staticmethod
    def _param_size(param_shape):
        '''Number of entries of a parameter of the given shape; zero-sized dimensions are ignored.'''
        if len(param_shape) > 1:
            size = 1
            for dim in param_shape:
                if dim > 0:
                    size = size * int(dim)
            return size
        if len(param_shape) == 1:
            return param_shape[0]
        return 1

    def _report_params(self, trainable):
        '''Print the parameters whose requires_grad equals trainable and return their total size.'''
        sizes = []
        for name, param in self.named_parameters():
            if param.requires_grad != trainable:
                continue
            param_shape = list(param.size())
            print(" var {:3}: {:15} {}".format(len(sizes), str(param_shape), name))
            sizes.append(self._param_size(param_shape))
        return np.sum(sizes)

    def show_params(self):
        '''Print every parameter (fixed ones first, then trainable ones) and both totals.'''
        print(f"All parameters for {self.display_name}========================")
        num_fixed_params = self._report_params(trainable=False)
        num_params = self._report_params(trainable=True)
        print("Total number of trainable params {}".format(num_params))
        print("Total number of fixed params {}".format(num_fixed_params))

    def do_forward_and_loss(self, feed_dict: dict) -> dict:
        '''
        Called during training
        '''
        out_dict = self.forward(feed_dict)
        return self.get_loss(feed_dict, out_dict)


    def forward(self, feed_dict: dict, return_prob=True) -> dict:
        '''
        Run get_forward; when return_prob is set and the output holds "preds",
        add "probs" = sigmoid(preds).
        '''
        out_dict = self.get_forward(feed_dict)
        # only add probs if this model actually produced "preds"
        if return_prob and "preds" in out_dict:
            out_dict["probs"] = self.sigmoid(out_dict["preds"])
        return out_dict

    def wrap_batch(self, batch):
        '''
        Build feed_dict from batch data and move data to self.device

        Numpy arrays, numpy scalars, tensors and lists are converted to
        tensors (float64 is cast to float32); every other value is left
        untouched. The dict is updated in place and returned.
        '''
        for k,val in batch.items():
            if isinstance(val, np.ndarray):
                tensor = torch.from_numpy(val)
            elif isinstance(val, np.generic):
                # numpy scalars are not accepted by torch.from_numpy
                tensor = torch.tensor(val)
            elif torch.is_tensor(val):
                tensor = val
            elif isinstance(val, list):
                tensor = torch.tensor(val)
            else:
                continue
            # Compare the dtype rather than the type string: a double tensor
            # already on a GPU reports "torch.cuda.DoubleTensor".
            if tensor.dtype == torch.float64:
                tensor = tensor.float()
            batch[k] = tensor.to(self.device)
        return batch

    def save_checkpoint(self):
        '''Save model state, optimizer state and reader_stats to model_path + ".checkpoint".'''
        torch.save({
            "model_state_dict": self.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            "reader_stats": self.reader_stats
        }, self.model_path + ".checkpoint")
        print("Model (checkpoint) saved to " + self.model_path + ".checkpoint")


    def load_from_checkpoint(self, model_path='', with_optimizer=True):
        '''
        Restore the model from model_path + ".checkpoint".

        @input:
        - model_path: checkpoint prefix, self.model_path when empty
        - with_optimizer: also restore the optimizer state when it is stored

        Raises KeyError when the checkpoint holds no 'model_state_dict'.
        '''
        if len(model_path) == 0:
            model_path = self.model_path
        ckpt_file = model_path + ".checkpoint"
        print("Load (checkpoint) from", ckpt_file)

        try:
            # First try: safest way (PyTorch 2.6+ default)
            checkpoint = torch.load(ckpt_file, map_location=self.device, weights_only=True)
        except TypeError:
            # PyTorch < 2.6 doesn't know weights_only
            checkpoint = torch.load(ckpt_file, map_location=self.device)
        except Exception as e:
            # SECURITY: weights_only=False unpickles arbitrary objects; only
            # load checkpoints from a trusted source.
            print("⚠️ weights_only=True failed:", e)
            print("Retrying with weights_only=False (trusted checkpoint)...")
            checkpoint = torch.load(ckpt_file, map_location=self.device, weights_only=False)

        # Restore fields
        if "reader_stats" in checkpoint:
            self.reader_stats = checkpoint["reader_stats"]
            print("reader_stats loaded:", self.reader_stats)

        if "model_state_dict" in checkpoint:
            self.load_state_dict(checkpoint["model_state_dict"])
        else:
            raise KeyError("No 'model_state_dict' in checkpoint!")

        if with_optimizer and "optimizer_state_dict" in checkpoint:
            self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

        self.model_path = model_path
        print("Checkpoint loaded successfully ✅")

    def actions_before_train(self, info):  # e.g. initialization
        pass

    def actions_after_train(self, info):  # e.g. compression
        pass

    def actions_before_epoch(self, info): # e.g. expectation update
        pass

    def actions_after_epoch(self, info): # e.g. prunning
        pass

    #############################
    #   Require Implementation  #
    #############################

    def _define_params(self, args, reader_stats):  # the model components and parameters
        pass

    def get_forward(self, feed_dict: dict) -> dict:  # the forward function
        pass

    def get_loss(self, feed_dict: dict, out_dict: dict):  # the loss function
        pass

In [None]:
import torch.nn as nn

class DNN(nn.Module):
    """
    Plain feed-forward network: a stack of hidden blocks followed by a linear
    prediction layer.

    Each hidden block is Linear -> ReLU -> [Dropout] -> [LayerNorm].

    @input:
    - in_dim: size of the last dimension of the input
    - hidden_dims: iterable of hidden layer widths (may be empty)
    - out_dim: size of the output
    - dropout_rate: a Dropout layer is added to each hidden block only if > 0
    - do_batch_norm: if True a LayerNorm (not BatchNorm, despite the name) is
      appended to each hidden block
    """

    def __init__(self, in_dim, hidden_dims, out_dim = 1, dropout_rate = 0., do_batch_norm = True):
        super(DNN, self).__init__()
        self.in_dim = in_dim
        layers = []

        # hidden layers; `prev_dim` tracks the width feeding the next Linear
        prev_dim = in_dim
        for hidden_dim in hidden_dims:
            layers.append(nn.Linear(prev_dim, hidden_dim))
            prev_dim = hidden_dim

            layers.append(nn.ReLU())
            if dropout_rate > 0:
                layers.append(nn.Dropout(dropout_rate))
            if do_batch_norm:
                layers.append(nn.LayerNorm([hidden_dim]))

        # prediction layer
        layers.append(nn.Linear(prev_dim, out_dim))

        self.layers = nn.Sequential(*layers)

    def forward(self, inputs):
        """
        @input:
            `inputs`, [bsz, in_dim] (any leading dimensions are flattened into bsz)
        @output:
            `logit`, [bsz, out_dim]
        """
        # reshape (instead of view) also accepts non-contiguous inputs, e.g. the
        # result of a transpose or a slice; for contiguous inputs it is identical.
        inputs = inputs.reshape(-1, self.in_dim)
        logit = self.layers(inputs)
        return logit

In [None]:
from matplotlib.pyplot import axes, axis
import torch
import torch.nn as nn



class KRMBUserResponse(BaseModel):
    '''
    KuaiRand Multi-Behavior user response model.

    The user state is the concatenation of
    - the last output position of a causally-masked transformer run over the
      history (item encoding + positional embedding, concatenated with the
      encoded past feedback), size 2*enc_dim, and
    - the user encoding (ID + features), size enc_dim.
    A DNN maps the state to one enc_dim vector per feedback type, which is
    matched against candidate item encodings to produce per-feedback scores.
    '''

    @staticmethod
    def parse_model_args(parser):
        '''
        args:
        - user_latent_dim
        - item_latent_dim
        - enc_dim
        - attn_n_head
        - transformer_d_forward
        - transformer_n_layer
        - state_hidden_dims
        - scorer_hidden_dims
        - dropout_rate
        - from BaseModel:
            - model_path
            - loss
            - l2_coef
        '''
        parser = BaseModel.parse_model_args(parser)

        parser.add_argument('--user_latent_dim', type=int, default=16,
                            help='user latent embedding size')
        parser.add_argument('--item_latent_dim', type=int, default=16,
                            help='item latent embedding size')
        parser.add_argument('--enc_dim', type=int, default=32,
                            help='item encoding size')
        parser.add_argument('--attn_n_head', type=int, default=4,
                            help='number of attention heads in transformer')
        parser.add_argument('--transformer_d_forward', type=int, default=64,
                            help='forward layer dimension in transformer')
        parser.add_argument('--transformer_n_layer', type=int, default=2,
                            help='number of encoder layers in transformer')
        parser.add_argument('--state_hidden_dims', type=int, nargs='+', default=[128],
                            help='hidden dimensions')
        parser.add_argument('--scorer_hidden_dims', type=int, nargs='+', default=[128],
                            help='hidden dimensions')
        parser.add_argument('--dropout_rate', type=float, default=0.1,
                            help='dropout rate in deep layers')
        return parser

    def log(self):
        '''Print this model's hyper-parameters, then delegate to the parent log().'''
        print("KRMBUserResponse params:")
        print(f"\tuser_latent_dim: {self.user_latent_dim}")
        print(f"\titem_latent_dim: {self.item_latent_dim}")
        print(f"\tenc_dim: {self.enc_dim}")
        print(f"\tattn_n_head: {self.attn_n_head}")
        print(f"\tscorer_hidden_dims: {self.scorer_hidden_dims}")
        print(f"\tdropout_rate: {self.dropout_rate}")
        print(f"\tstate_dim: {self.state_dim}")
        super().log()

    def __init__(self, args, reader_stats, device):
        # These plain attributes are assigned before the parent constructor runs
        # because _define_params() below reads some of them (e.g. dropout_rate).
        self.user_latent_dim = args.user_latent_dim
        self.item_latent_dim = args.item_latent_dim
        self.enc_dim = args.enc_dim
        self.attn_n_head = args.attn_n_head
        self.scorer_hidden_dims = args.scorer_hidden_dims
        self.dropout_rate = args.dropout_rate
        super().__init__(args, reader_stats, device)
        # element-wise BCE on raw logits (the sigmoid is applied inside the loss)
        self.bce_loss = nn.BCEWithLogitsLoss(reduction = 'none')
        self.state_dim = 3*args.enc_dim

    def to(self, device):
        '''
        Move the model to `device`, including the plain tensors (attention mask,
        position indices, behavior weights) that are stored as ordinary
        attributes and therefore are not moved by the parent `to()`.
        '''
        new_self = super(KRMBUserResponse, self).to(device)
        new_self.attn_mask = new_self.attn_mask.to(device)
        new_self.pos_emb_getter = new_self.pos_emb_getter.to(device)
        new_self.behavior_weight = new_self.behavior_weight.to(device)
        return new_self

    def _define_params(self, args, reader_stats):
        '''
        Build all sub-modules.

        Keys read from reader_stats: 'user_feature_dims', 'item_feature_dims',
        'n_user', 'n_item', 'feedback_type', 'feedback_size', 'max_seq_len'.
        '''
        stats = reader_stats

        self.user_feature_dims = stats['user_feature_dims'] # {feature_name: dim}
        self.item_feature_dims = stats['item_feature_dims'] # {feature_name: dim}

        # user embedding
        self.uIDEmb = nn.Embedding(stats['n_user']+1, args.user_latent_dim)
        # plain dict for lookup by feature name; each module is also registered
        # through add_module so that its parameters belong to the model
        self.uFeatureEmb = {}
        for f,dim in self.user_feature_dims.items():
            embedding_module = nn.Linear(dim, args.user_latent_dim)
            self.add_module(f'UFEmb_{f}', embedding_module)
            self.uFeatureEmb[f] = embedding_module

        # item embedding
        self.iIDEmb = nn.Embedding(stats['n_item']+1, args.item_latent_dim)
        self.iFeatureEmb = {}
        for f,dim in self.item_feature_dims.items():
            embedding_module = nn.Linear(dim, args.item_latent_dim)
            self.add_module(f'IFEmb_{f}', embedding_module)
            self.iFeatureEmb[f] = embedding_module

        # feedback embedding
        self.feedback_types = stats['feedback_type']
        self.feedback_dim = stats['feedback_size']
        self.xtr_dim = 2*self.feedback_dim
        self.feedbackEncoder = nn.Linear(self.feedback_dim, args.enc_dim)
        self.set_behavior_hyper_weight(torch.ones(self.feedback_dim))

        # item embedding kernel encoder
        self.itemEmbNorm = nn.LayerNorm(args.item_latent_dim)
        self.userEmbNorm = nn.LayerNorm(args.user_latent_dim)
        self.itemFeatureKernel = nn.Linear(args.item_latent_dim, args.enc_dim)
        self.userFeatureKernel = nn.Linear(args.user_latent_dim, args.enc_dim)
        self.encDropout = nn.Dropout(self.dropout_rate)
        self.encNorm = nn.LayerNorm(args.enc_dim)

        # positional embedding
        self.max_len = stats['max_seq_len']
        self.posEmb = nn.Embedding(self.max_len, args.enc_dim)
        self.pos_emb_getter = torch.arange(self.max_len, dtype = torch.long)
        # True above the diagonal: position i may not attend to positions > i
        self.attn_mask = ~torch.tril(torch.ones((self.max_len,self.max_len), dtype=torch.bool))

        # sequence encoder
        encoder_layer = nn.TransformerEncoderLayer(d_model=2*args.enc_dim, dim_feedforward = args.transformer_d_forward,
                                                   nhead=args.attn_n_head, dropout = args.dropout_rate,
                                                   batch_first = True)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=args.transformer_n_layer)

        # DNN state encoder
        self.stateNorm = nn.LayerNorm(args.enc_dim)

        # DNN scorer
        # NOTE(review): the scorer's hidden sizes come from args.state_hidden_dims,
        # while scorer_hidden_dims is only stored and logged -- confirm this is intended.
        self.scorer_hidden_dims = args.scorer_hidden_dims
        self.scorer = DNN(3*args.enc_dim, args.state_hidden_dims, self.feedback_dim * args.enc_dim,
                          dropout_rate = args.dropout_rate, do_batch_norm = True)

    def set_behavior_hyper_weight(self, weight):
        '''Set the per-feedback loss weights; `weight` must hold feedback_dim values.'''
        self.behavior_weight = weight.view(-1)
        assert len(self.behavior_weight) == self.feedback_dim

    def get_forward(self, feed_dict: dict):
        '''
        This is used during simulator training
        When serving as a simulator, it calls encode_state() + get_pointwise_score()
        @input:
        - feed_dict: {
            'user_id': (B,)
            'uf_{feature_name}': (B,feature_dim), the user features
            'item_id': (B,), the target item
            'if_{feature_name}': (B,feature_dim), the target item features
            'history': (B,max_H)
            'history_if_{feature_name}': (B,max_H,feature_dim), the history item features
            ... (irrelevant input)
        }
        @output:
        - out_dict: {'preds': (B,-1,n_feedback), 'state': (B,1,3*enc_dim), 'reg': scalar}
        '''
        B = feed_dict['user_id'].shape[0]

        # target item
        # (B, -1, enc_dim)
        item_enc, item_reg = self.get_item_encoding(feed_dict['item_id'],
                                          {k[3:]:v for k,v in feed_dict.items() if k[:3] == 'if_'}, B)

        # (B, -1, 1, enc_dim)
        item_enc = item_enc.view(B,-1,1,self.enc_dim)

        # user encoding
        state_encoder_output = self.encode_state(feed_dict, B)
        # (B, 1, 3*enc_dim)
        user_state = state_encoder_output['state'].view(B,1,3*self.enc_dim)
        # (B, -1, n_feedback), (B, -1)
        behavior_scores, point_scores = self.get_pointwise_scores(user_state, item_enc, B)

        # regularization terms
        reg = self.get_regularization(self.feedbackEncoder,
                                      self.itemFeatureKernel, self.userFeatureKernel,
                                      self.posEmb, self.transformer, self.scorer)
        reg = reg + state_encoder_output['reg'] + item_reg
        return {'preds': behavior_scores, 'state': user_state, 'reg': reg}

    def encode_state(self, feed_dict, B):
        '''
        @input:
        - feed_dict: {
            'user_id': (B,)
            'uf_{feature_name}': (B,feature_dim), the user features
            'history': (B,max_H)
            'history_if_{feature_name}': (B,max_H,feature_dim), the history item features
            'history_{feedback_type}': (B,max_H), the history feedback
            ... (irrelevant input)
        }
        - B: batch size
        @output:
        - out_dict:{
            'output_seq': (B,max_H,2*enc_dim)
            'state': (B,3*enc_dim)
            'reg': scalar
        }
        '''
        # user history
        # (B, max_H, enc_dim)
        history_enc, history_reg = self.get_item_encoding(feed_dict['history'],
                                             {f:feed_dict[f'history_if_{f}'] for f in self.iFeatureEmb}, B)
        history_enc = history_enc.view(B, self.max_len, self.enc_dim)
        # (1, max_H, enc_dim)
        pos_emb = self.posEmb(self.pos_emb_getter).view(1,self.max_len,self.enc_dim)
        # (B, max_H, enc_dim)
        seq_enc_feat = self.encNorm(self.encDropout(history_enc + pos_emb))
        # (B, max_H, enc_dim)
        feedback_emb = self.get_response_embedding(feed_dict, B)
        # (B, max_H, 2*enc_dim)
        seq_enc = torch.cat((seq_enc_feat, feedback_emb), dim = -1)
        # (B, max_H, 2*enc_dim)
        output_seq = self.transformer(seq_enc, mask = self.attn_mask)
        # (B, 2*enc_dim), the last position summarizes the history
        hist_enc = output_seq[:,-1,:].view(B,2*self.enc_dim)
        # user features
        # (B, enc_dim), scalar
        user_enc, user_reg = self.get_user_encoding(feed_dict['user_id'],
                                          {k[3:]:v for k,v in feed_dict.items() if k[:3] == 'uf_'}, B)
        # (B, enc_dim)
        user_enc = self.encNorm(self.encDropout(user_enc)).view(B,self.enc_dim)
        # (B, 3*enc_dim)
        state = torch.cat([hist_enc,user_enc], 1)
        return {'output_seq': output_seq, 'state': state, 'reg': user_reg + history_reg}

    def get_user_encoding(self, user_ids, user_features, B):
        '''
        @input:
        - user_ids: (B,), encoded user id
        - user_features: {feature_name: (B, feature_dim)}
        - B: batch size
        @output:
        - encoding: (B, enc_dim)
        - reg: scalar, mean squared value of the user ID embeddings
        '''
        # (B, 1, u_latent_dim)
        user_id_emb = self.uIDEmb(user_ids).view(B,1,self.user_latent_dim)
        # [(B, 1, u_latent_dim)] * (n_user_feature+1)
        user_feature_emb = [user_id_emb]
        for f,fEmbModule in self.uFeatureEmb.items():
            user_feature_emb.append(fEmbModule(user_features[f]).view(B,1,self.user_latent_dim))
        # (B, n_user_feature+1, u_latent_dim)
        combined_user_emb = torch.cat(user_feature_emb, 1)
        combined_user_emb = self.userEmbNorm(combined_user_emb)
        # (B, enc_dim)
        encoding = self.userFeatureKernel(combined_user_emb).sum(1)
        # regularization
        reg = torch.mean(user_id_emb * user_id_emb)
        return encoding, reg

    def get_item_encoding(self, item_ids, item_features, B):
        '''
        @input:
        - item_ids: (B,) or (B,H), encoded item id
        - item_features: {feature_name: (B,feature_dim) or (B,H,feature_dim)}
        - B: batch size
        @output:
        - encoding: (B, 1, enc_dim) or (B, H, enc_dim)
        - reg: scalar, mean squared value of the item ID embeddings
        '''
        # (B, 1, i_latent_dim) or (B, H, i_latent_dim)
        item_id_emb = self.iIDEmb(item_ids).view(B,-1,self.item_latent_dim)
        L = item_id_emb.shape[1]
        # [(B, 1, i_latent_dim)] * (n_item_feature+1) or [(B, H, i_latent_dim)] * (n_item_feature+1)
        item_feature_emb = [item_id_emb]
        for f,fEmbModule in self.iFeatureEmb.items():
            f_dim = self.item_feature_dims[f]
            item_feature_emb.append(fEmbModule(item_features[f].view(B,L,f_dim)).view(B,-1,self.item_latent_dim))
        # (B, 1, n_item_feature+1, i_latent_dim) or (B, H, n_item_feature+1, i_latent_dim)
        combined_item_emb = torch.cat(item_feature_emb, -1).view(B, L, -1, self.item_latent_dim)
        combined_item_emb = self.itemEmbNorm(combined_item_emb)
        # (B, 1, enc_dim) or (B, H, enc_dim)
        encoding = self.itemFeatureKernel(combined_item_emb).sum(2)
        encoding = encoding.view(B, -1, self.enc_dim)
        encoding = self.encNorm(encoding)
        # regularization
        reg = torch.mean(item_id_emb * item_id_emb)
        return encoding, reg

    def get_response_embedding(self, feed_dict, B):
        '''
        Encode the history feedback.
        @input:
        - feed_dict: {'history_{feedback_type}': (B,max_H), ...}
        - B: batch size
        @output:
        - resp_emb: (B, max_H, enc_dim)
        '''
        resp_list = []
        for f in self.feedback_types:
            # (B, max_H)
            resp = feed_dict[f'history_{f}'].view(B, self.max_len)
            resp_list.append(resp)
        # (B, max_H, n_feedback)
        # Stack on a new last axis so that entry [b, h, :] holds all feedback
        # signals of history position h. (Concatenating on the last axis and then
        # viewing as (B, max_H, n_feedback) would mix feedback types and positions.)
        # NOTE: checkpoints trained with the earlier cat+view layout should be retrained.
        combined_resp = torch.stack(resp_list, dim = -1)
        # (B, max_H, enc_dim)
        resp_emb = self.feedbackEncoder(combined_resp)
        return resp_emb

    def get_loss(self, feed_dict: dict, out_dict: dict):
        """
        @input:
        - feed_dict: {'user_id': (B,), '{feedback_type}': targets, 'loss_weight': (B,-1,n_feedback), ...}
        - out_dict: {"preds": logits of shape (B,-1,n_feedback), "reg": scalar}
        @output:
        - out_dict with two added entries:
            'loss': weighted loss + l2_coef * reg
            'behavior_loss': {feedback_type: unweighted mean BCE (float)}

        Loss terms implemented:
        - BCE

        Raises NotImplementedError for any other loss type.
        """
        B = feed_dict['user_id'].shape[0]
        # (B, -1, n_feedback)
        preds = out_dict['preds'].view(B,-1,self.feedback_dim)
        # {feedback_type: (B, -1)}
        targets = {f:feed_dict[f].view(B,-1).to(torch.float) for f in self.feedback_types}
        # (B, -1, n_feedback)
        loss_weight = feed_dict['loss_weight'].view(B,-1,self.feedback_dim)

        if self.loss_type == 'bce':
            behavior_loss = {}
            loss = 0
            for i,fb in enumerate(self.feedback_types):
                # feedback types with zero weight do not contribute to the loss
                if self.behavior_weight[i] == 0:
                    continue
                Y = targets[fb].view(-1)
                P = preds[:,:,i].view(-1)
                W = loss_weight[:,:,i].view(-1)
                # (B*L,)
                # bce_loss is BCEWithLogitsLoss and applies the sigmoid itself,
                # so it must be fed the raw logits (not sigmoid(P)).
                point_loss = self.bce_loss(P, Y)
                behavior_loss[fb] = torch.mean(point_loss).item()
                weighted_loss = torch.mean(point_loss * W)
                loss = self.behavior_weight[i] * weighted_loss + loss
        else:
            raise NotImplementedError(f"Unsupported loss type: {self.loss_type}")
        out_dict['loss'] = loss + self.l2_coef * out_dict['reg']
        out_dict['behavior_loss'] = behavior_loss
        return out_dict


    def get_pointwise_scores(self, user_state, item_enc, B):
        '''
        Get user-item pointwise interaction scores
        @input:
        - user_state: (B, state_dim) or (B, 1, state_dim)
        - item_enc: (B, -1, 1, enc_dim) for batch-wise candidates or (1, -1, 1, enc_dim) for universal candidates
        - B: batch size
        @output:
        - behavior_scores: (B, -1, n_feedback), one score per feedback type
        - point_scores: (B, -1), the mean of behavior_scores over feedback types
        '''
        # scoring
        # (B, 1, n_feedback, enc_dim)
        behavior_attn = self.scorer(user_state).view(B,1,self.feedback_dim,self.enc_dim)
        # (B, 1, n_feedback, enc_dim)
        behavior_attn = self.stateNorm(behavior_attn)
        # (B, -1, n_feedback)
        point_scores = (behavior_attn * item_enc).mean(dim = -1).view(B,-1,self.feedback_dim)
        return point_scores, torch.mean(point_scores, dim = -1)

In [None]:
from tqdm import tqdm
from time import time
import torch
from torch.utils.data import DataLoader
import argparse
import numpy as np
import os
from sklearn.metrics import roc_auc_score

def do_eval(model, reader, args):
    """
    Evaluate `model` on the validation phase of `reader`.

    @input:
    - model: provides `feedback_types` and `do_forward_and_loss(batch)`, whose
      output holds 'loss' and 'preds' (indexed as [:, :, feedback_index])
    - reader: dataset passed to the DataLoader; switched to phase "val"
    - args: only `args.device` is read
    @output:
    - val_report: {
        'loss': (mean, min, max) of the per-batch losses,
        'auc': {feedback_type: ROC-AUC over all validation samples}
    }
    """
    reader.set_phase("val")

    # One sample per batch; the progress bar is advanced by the same amount so
    # that it ends exactly at len(reader).
    eval_batch_size = 1
    eval_loader = DataLoader(reader, batch_size=eval_batch_size, shuffle=False, pin_memory=True,
                             num_workers=1, persistent_workers=True)

    val_report = {'loss': [], 'auc': {}}
    Y_dict = {f: [] for f in model.feedback_types}
    P_dict = {f: [] for f in model.feedback_types}
    # the context manager closes the progress bar even if evaluation fails
    with torch.no_grad(), tqdm(total = len(reader)) as pbar:
        for batch_data in eval_loader:
            wrapped_batch = wrap_batch(batch_data, device = args.device)
            out_dict = model.do_forward_and_loss(wrapped_batch)
            val_report['loss'].append(out_dict['loss'].item())
            for j,f in enumerate(model.feedback_types):
                Y_dict[f].append(wrapped_batch[f].view(-1).detach().cpu().numpy())
                P_dict[f].append(out_dict['preds'][:,:,j].view(-1).detach().cpu().numpy())
            pbar.update(eval_batch_size)
    val_report['loss'] = (np.mean(val_report['loss']), np.min(val_report['loss']), np.max(val_report['loss']))
    for f in model.feedback_types:
        val_report['auc'][f] = roc_auc_score(np.concatenate(Y_dict[f]),
                                             np.concatenate(P_dict[f]))
    return val_report

In [None]:
import logging
from typing import Any, Dict, Optional, List, Tuple
from yacs.config import CfgNode
import os
import math


# logger
def set_logger(log_file, name="default"):
    """
    Set logger.

    Creates the folders `output`, `output/log` and `output/message` (relative to
    the current working directory) if needed and attaches a single INFO-level
    file handler writing to `output/log/<log_file>` (the file is truncated).
    Handlers previously attached to the logger are closed and removed, so
    calling this repeatedly does not leak open log files.

    Args:
        log_file (str): log file name, placed under `output/log`
        name (str): logger name
    Returns:
        logging.Logger: the configured logger
    """

    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)

    formatter = logging.Formatter(
        "%(asctime)s - %(levelname)s - %(module)s - %(funcName)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )
    # Print one sample record to show what the log format looks like.
    print(formatter.format(logging.LogRecord(
        name='root',
        level=logging.INFO,
        pathname=None,
        lineno=None,
        msg='Test message',
        args=None,
        exc_info=None,
    )))

    output_folder = "output"
    log_folder = os.path.join(output_folder, "log")
    message_folder = os.path.join(output_folder, "message")
    for folder in (output_folder, log_folder, message_folder):
        os.makedirs(folder, exist_ok=True)

    # Detach AND close the old handlers before opening the new file; merely
    # resetting `logger.handlers` would leave their file streams open.
    for old_handler in list(logger.handlers):
        logger.removeHandler(old_handler)
        old_handler.close()

    log_file = os.path.join(log_folder, log_file)
    handler = logging.FileHandler(log_file, mode="w")
    handler.setLevel(logging.INFO)
    handler.setFormatter(formatter)
    logger.addHandler(handler)
    print(logger)
    return logger

def load_cfg(cfg_file: str, new_allowed: bool = True) -> CfgNode:
    """
    Read a config file into a CfgNode.

    Args:
        cfg_file (str): config file path
        new_allowed (bool): value passed to `set_new_allowed` on the loaded
            config (whether new keys may be added later)
    Returns:
        CfgNode: the loaded config
    """
    with open(cfg_file, "r") as stream:
        loaded = CfgNode.load_cfg(stream)
    loaded.set_new_allowed(new_allowed)
    return loaded

def add_variable_to_config(cfg: CfgNode, name: str, value: Any) -> CfgNode:
    """
    Store `value` under key `name` in `cfg`.

    The config is defrosted for the assignment and frozen again afterwards, so
    it is left frozen regardless of its state on entry.

    Args:
        cfg (CfgNode): config, modified in place
        name (str): variable name
        value (Any): variable value
    Returns:
        CfgNode: the same `cfg` object
    """
    cfg.defrost()  # make the node mutable
    cfg[name] = value
    cfg.freeze()  # lock it again
    return cfg

def ensure_dir(dir_path):
    """
    Make sure the directory exists, if it does not exist, create it
    (including missing parent directories).

    Uses `exist_ok=True` instead of a separate existence check, so a directory
    created concurrently between "check" and "create" does not raise.

    Args:
        dir_path (str): The directory path.
    Raises:
        FileExistsError: if `dir_path` exists but is not a directory.
    """
    os.makedirs(dir_path, exist_ok=True)

def calculate_entropy(movie_types):
    """
    Shannon entropy (base 2) of the label distribution in `movie_types`.

    Args:
        movie_types (list): hashable labels; repeated labels are counted
    Returns:
        The entropy in bits; 0 for an empty list.
    """
    # label -> number of occurrences, in first-seen order
    counts = {}
    for label in movie_types:
        counts[label] = counts.get(label, 0) + 1

    n_total = len(movie_types)

    result = 0
    for occurrences in counts.values():
        share = occurrences / n_total
        result -= share * math.log2(share)
    return result


def get_entropy(inters, data):
    """
    Entropy of the genres returned by `data.get_genres_by_id(inters)`.

    Args:
        inters: item ids, forwarded unchanged to `data.get_genres_by_id`
        data: object providing `get_genres_by_id`
    Returns:
        The result of `calculate_entropy` on those genres.
    """
    return calculate_entropy(data.get_genres_by_id(inters))

In [None]:
import csv
import os
import math
import numpy as np


class Data:
    """
    Data class for loading data from local files.

    `config` must provide the keys "item_path", "user_path" and "rating_path",
    each pointing to a CSV file with a header line. All three files are loaded
    when the object is constructed.
    """

    def __init__(self, config):
        self.config = config
        self.items = {}  # item_id -> {"name", "genre", "description", "inter_cnt", "mention_cnt"}
        self.users = {}  # 1-based row number -> user profile dict
        self.db = None  # must be set externally before search_items() is used
        self.tot_relationship_num = 0
        self.netwerk_density = 0.0  # attribute name (sic) kept unchanged for compatibility
        self.role_id = -1
        self.interrating = {}  # user_id -> [(item_id, rating), ...]
        self.user_ratings = {}  # user_id -> [rating, ...]
        self.item_ratings = {}  # item_id -> [rating, ...]
        self.load_items(config["item_path"])
        self.load_users(config["user_path"])
        self.load_interactions_rating(config["rating_path"])

    def load_items(self, file_path):
        """
        Load items from local file.

        Each row is expected to hold: item_id, title, genre, description.
        Blank rows (e.g. a trailing empty line) are skipped.
        """
        with open(file_path, "r", newline="") as file:
            reader = csv.reader(file)
            next(reader, None)  # Skip the header line (tolerates an empty file)
            for row in reader:
                if not row:
                    continue
                item_id, title, genre, description = row
                self.items[int(item_id)] = {
                    "name": title.strip(),
                    "genre": genre,
                    "description": description.strip(),
                    "inter_cnt": 0,
                    "mention_cnt": 0,
                }

    def load_users(self, file_path):
        """
        Load users from local file.

        Each row is expected to hold: user_id, name, gender, age, traits,
        status, interest, feature. Blank rows are skipped.
        NOTE: users are keyed by their 1-based position in the file; the
        user_id column itself is not used.
        """
        cnt = 1
        with open(file_path, "r", newline="") as file:
            reader = csv.reader(file)
            next(reader, None)  # Skip the header line (tolerates an empty file)
            for row in reader:
                if not row:
                    continue
                user_id, name, gender, age, traits, status, interest, feature = row
                self.users[cnt] = {
                    "name": name,
                    "gender": gender,
                    "age": int(age),
                    "traits": traits,
                    "status": status,
                    "interest": interest,
                    "feature": feature,
                }
                cnt += 1

    def load_interactions_rating(self, file_path):
        """
        Load user-item interactions (with rating) from local file.

        Each row is expected to hold: user_id, item_id, rating (all integers).
        Blank rows are skipped. Fills
        - self.interrating: {user_id: [(item_id, rating), ...], ...}
        - self.user_ratings / self.item_ratings: the ratings per user / item
        - self.user_avg_rating / self.item_avg_rating: their averages
        """
        with open(file_path, "r", newline="") as file:
            reader = csv.reader(file)
            next(reader, None)  # Skip the header line (tolerates an empty file)
            for row in reader:
                if not row:
                    continue
                user_id, item_id, rating = row
                user_id = int(user_id)
                item_id = int(item_id)
                rating = int(rating)
                self.interrating.setdefault(user_id, []).append((item_id, rating))
                self.user_ratings.setdefault(user_id, []).append(rating)
                self.item_ratings.setdefault(item_id, []).append(rating)

        # Compute and store average historical rating for each user
        self.user_avg_rating = {
            uid: sum(ratings) / len(ratings)
            for uid, ratings in self.user_ratings.items() if ratings
        }

        # Compute and store average historical rating for each item
        self.item_avg_rating = {
            iid: sum(ratings) / len(ratings)
            for iid, ratings in self.item_ratings.items() if ratings
        }

    def get_full_items(self):
        """Return all item ids."""
        return list(self.items.keys())

    def get_inter_popular_items(self):
        """
        Get the (up to 3) most popular items based on the number of interactions.
        """
        ids = sorted(
            self.items.keys(), key=lambda x: self.items[x]["inter_cnt"], reverse=True
        )[:3]
        return self.get_item_names(ids)

    def add_inter_cnt(self, item_names):
        """Increase the interaction counter of every item matched by `item_names`."""
        item_ids = self.get_item_ids(item_names)
        print("item ids:", item_ids)
        for item_id in item_ids:
            self.items[item_id]["inter_cnt"] += 1

    def add_mention_cnt(self, item_names):
        """Increase the mention counter of every item matched by `item_names`."""
        item_ids = self.get_item_ids(item_names)
        for item_id in item_ids:
            self.items[item_id]["mention_cnt"] += 1

    def get_mention_popular_items(self):
        """
        Get the (up to 3) most popular items based on the number of mentions.
        """
        ids = sorted(
            self.items.keys(), key=lambda x: self.items[x]["mention_cnt"], reverse=True
        )[:3]
        return self.get_item_names(ids)

    def get_item_names(self, item_ids):
        """Return the item names wrapped in angle brackets, e.g. "<Title>"."""
        return ["<" + self.items[item_id]["name"] + ">" for item_id in item_ids]

    def get_item_ids(self, item_names):
        """
        Get item ids from item names by substring match.

        For each given string, the first item whose name is contained in that
        string is used (so "<Title>" matches "Title"); strings matching no item
        are dropped.
        """
        item_ids = []
        for item in item_names:
            for item_id, item_info in self.items.items():
                if item_info["name"] in item:
                    item_ids.append(item_id)
                    break
        return item_ids

    def get_item_ids_exact(self, item_names):
        """
        Get item ids from item names by exact match.

        Unlike get_item_ids(), which accepts any string that contains an item
        name, the given string must be equal to the item name; strings matching
        no item are dropped.
        """
        item_ids = []
        for item in item_names:
            for item_id, item_info in self.items.items():
                if item_info["name"] == item:
                    item_ids.append(item_id)
                    break
        return item_ids

    def get_full_users(self):
        """Return all user keys."""
        return list(self.users.keys())

    def get_user_profile(self, user_id):
        """
        Return the user profile as a formatted string for the given user_id.
        """
        user = self.users.get(user_id)
        if not user:
            return f"User ID {user_id} not found."

        profile = (
            f"User ID: {user_id};; Name: {user['name']}\n"
            f"Gender: {user['gender']};; Age: {user['age']};; Status: {user['status']}\n"
            f"Traits: {user['traits']};; Interest: {user['interest']};; Feature: {user['feature']}\n")
        return profile

    def get_user_names(self, user_ids):
        """Return the names of the given users."""
        return [self.users[user_id]["name"] for user_id in user_ids]

    def get_user_ids(self, user_names):
        """
        Get user keys from user names (exact match, first hit per name);
        names matching no user are dropped.
        """
        user_ids = []
        for user in user_names:
            for user_id, user_info in self.users.items():
                if user_info["name"] == user:
                    user_ids.append(user_id)
                    break
        return user_ids

    def get_user_num(self):
        """
        Return the number of users.
        """
        return len(self.users.keys())

    def get_item_num(self):
        """
        Return the number of items.
        """
        return len(self.items.keys())

    def search_items(self, item, k=50):
        """
        Search similar items through `self.db.similarity_search`.

        `self.db` is None after construction and must be assigned before this
        method is called.
        Args:
            item: str, item name
            k: int, number of similar items to return
        """
        docs = self.db.similarity_search(item, k)
        item_names = [doc.page_content for doc in docs]
        return item_names

    def get_item_description_by_id(self, item_ids):
        """
        Get description of items by item id.
        """
        return [self.items[item_id]["description"] for item_id in item_ids]

    def get_item_description_by_name(self, item_names):
        """
        Get description of items by item name.

        Surrounding spaces and angle brackets are stripped from each given name
        before the exact comparison; unknown names yield an empty string.
        """
        item_descriptions = []
        for item in item_names:
            found = False
            for item_id, item_info in self.items.items():
                if item_info["name"] == item.strip(" <>"):
                    item_descriptions.append(item_info["description"])
                    found = True
                    break
            if not found:
                item_descriptions.append("")
        return item_descriptions

    def get_genres_by_id(self, item_ids):
        """
        Get genre of items by item id.

        The genre field is split on '|' and the result is one flat list.
        """
        return [
            genre
            for item_id in item_ids
            for genre in self.items[item_id]["genre"].split('|')
        ]

    def hit_at_k(self, ground_truth, predicted, k):
        """Return 1 if any of the top-k predicted are relevant, else 0."""
        return int(bool(set(ground_truth) & set(predicted[:k])))

    def ndcg_at_k(self, ground_truth, predicted, k):
        """Compute NDCG@k for a single user (binary relevance)."""
        def dcg(rel):
            return np.sum([(2**r - 1) / np.log2(i + 2) for i, r in enumerate(rel)])

        rel = [1 if item in ground_truth else 0 for item in predicted[:k]]
        # ideal ranking: all relevant items (at most k) placed first
        n_relevant = min(len(ground_truth), k)
        ideal_rel = [1] * n_relevant + [0] * (k - n_relevant)
        dcg_score = dcg(rel)
        idcg_score = dcg(ideal_rel)
        return dcg_score / idcg_score if idcg_score > 0 else 0.0

    def mse(self, ground_truth_ratings, predicted_ratings, items=None):
        """
        Compute MSE for ratings (on items in both sets).

        Both rating arguments are dicts {item: rating}; `items` optionally
        restricts the evaluation further. Returns NaN if nothing is comparable.
        """
        if items is None:
            items = set(ground_truth_ratings.keys()) & set(predicted_ratings.keys())
        else:
            items = set(items) & set(ground_truth_ratings.keys()) & set(predicted_ratings.keys())
        if not items:
            return np.nan
        errors = [(ground_truth_ratings[i] - predicted_ratings[i]) ** 2 for i in items]
        return np.mean(errors)

    def rmse(self, ground_truth_ratings, predicted_ratings, items=None):
        """Compute RMSE for ratings."""
        return np.sqrt(self.mse(ground_truth_ratings, predicted_ratings, items))

    def safe_log(self, x):
        """Numerically safe log."""
        return math.log(max(x, 1e-15))

    def ordered_probit_probs(self, pred_int, K, taus=None):
        """
        Compute ordered probit class probabilities for a predicted integer rating.

        pred_int : int
            The predicted integer rating (e.g., 1..K).
        K : int
            Number of rating classes (e.g., 5 for 1-5 stars). K == 1 yields the
            single probability 1.0.
        taus : list or array, optional
            Thresholds separating the ordered categories (K-1 values).
            If None, uses equally spaced thresholds [1.5, 2.5, ..., K-0.5].

        Returns a numpy array of K probabilities that sums to 1; every class
        gets at least 1e-15 before normalization.
        """

        if taus is None:
            taus = np.array([1.5 + i for i in range(K-1)])  # default thresholds

        assert len(taus) == K-1

        def Phi(z):
            return 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))  # Normal CDF

        # Class k covers the interval between consecutive boundaries; the outer
        # boundaries are infinite, which also covers the K == 1 case (no taus).
        bounds = [-np.inf, *taus, np.inf]
        probs = []
        for lower, upper in zip(bounds[:-1], bounds[1:]):
            p_lower = 0.0 if lower == -np.inf else Phi((lower - pred_int))
            p_upper = 1.0 if upper == np.inf else Phi((upper - pred_int))
            probs.append(max(p_upper - p_lower, 1e-15))

        probs = np.array(probs)
        probs /= probs.sum()  # normalize
        return probs

In [None]:
class BaseModel:
    """Common interface of the recommender models: store the config and
    declare the ranking methods that concrete models must provide."""

    # NOTE(review): KRMBUserResponse, defined earlier in this file, also derives
    # from a class called `BaseModel`; if both definitions end up in the same
    # namespace, this one rebinds the name -- confirm that is intended.

    def __init__(self, config, n_users, n_items):
        # n_users / n_items are accepted for a uniform constructor signature
        # but are not stored here.
        self.items = None
        self.config = config

    def get_full_sort_items(self, user_id, *args, **kwargs):
        """Return the sorted item list for one user (must be overridden)."""
        raise NotImplementedError

    def _sort_full_items(self, user_id, *args, **kwargs):
        """Sort the items for one user (must be overridden)."""
        raise NotImplementedError

In [None]:
from typing import Union, List, Optional
import torch
import torch.nn as nn


class MF(BaseModel, nn.Module):
    """Matrix-factorisation model: a rating is the dot product of a user and an item embedding."""

    def __init__(self, config, n_users, n_items):
        BaseModel.__init__(self, config, n_users, n_items)
        nn.Module.__init__(self)
        self.config = config
        self.n_users, self.n_items = n_users, n_items
        self.embedding_size = config["embedding_size"]

        # Re-seeds the global torch RNG so the embedding initialisation below
        # is reproducible for a given config.
        torch.manual_seed(config["seed"])

        # Each table is allocated with one row more than the entity count.
        self.user_embedding = nn.Embedding(self.n_users + 1, self.embedding_size)
        self.item_embedding = nn.Embedding(self.n_items + 1, self.embedding_size)
        for table in (self.user_embedding, self.item_embedding):
            nn.init.xavier_uniform_(table.weight)

    def forward(self, user, item):
        """Return the predicted rating for each (user, item) index pair."""
        user_vectors = self.user_embedding(user)
        item_vectors = self.item_embedding(item)
        # Row-wise dot product: multiply elementwise, then reduce over dim 1.
        return torch.sum(user_vectors * item_vectors, dim=1)

    def get_full_sort_items(self, user, items):
        """Score ``items`` for ``user`` and return them as a list, best first."""
        scores = self.forward(user, items)
        ranked = self._sort_full_items(user, scores, items)
        return ranked.tolist()

    def _sort_full_items(self, user, predicted_ratings, items):
        """Reorder ``items`` by ``predicted_ratings`` in descending order."""
        ranking = torch.sort(predicted_ratings, descending=True).indices
        return items[ranking]

In [None]:
import math
from typing import List, Tuple, Iterable, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


class LightGCN(BaseModel, nn.Module):
    """
    LightGCN for recommendation.
    Reference: He et al., "LightGCN: Simplifying and Powering Graph Convolution Network for Recommendation"

    The model holds one embedding table for users and one for items, plus a
    fixed, symmetrically normalized user-item adjacency matrix (``self.Graph``)
    built once at construction time. Scores are dot products of the propagated
    embeddings.
    """

    def __init__(
        self,
        config,
        n_users: int,
        n_items: int,
        interactions: Iterable[Tuple[int, int]],
        n_layers: int = 3,
        embedding_dim: int = 64,
        device: Optional[torch.device] = None,
    ):
        """
        config: mapping; only the optional "seed" key is read here (default 2023)
        n_users / n_items: sizes of the two embedding tables
        interactions: iterable of (user_id, item_id) pairs (duplicates okay, handled as multi-edges of weight 1)
        n_layers: number of propagation steps
        embedding_dim: width of the user/item embeddings
        device: target device; when None, CUDA is used if available, else CPU
        """
        BaseModel.__init__(self, config, n_users, n_items)
        nn.Module.__init__(self)

        self.n_users = n_users
        self.n_items = n_items
        self.n_layers = n_layers
        self.embedding_dim = embedding_dim
        # Honour an explicitly requested device; auto-detect only when none was
        # given (previously the argument was accepted but ignored).
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = torch.device(device)

        # Re-seeds the global torch RNG so embedding init is reproducible.
        torch.manual_seed(config.get("seed", 2023))

        # Embeddings
        self.user_embedding = nn.Embedding(n_users, embedding_dim)
        self.item_embedding = nn.Embedding(n_items, embedding_dim)
        nn.init.xavier_uniform_(self.user_embedding.weight)
        nn.init.xavier_uniform_(self.item_embedding.weight)

        # Build normalized adjacency once. It is a plain attribute (not a
        # registered buffer), so it is moved to the device explicitly here.
        self.Graph = self._build_normalized_adj(interactions).to(self.device)

        self.to(self.device)

    # ---------- Graph utilities ----------
    def _build_normalized_adj(self, interactions: Iterable[Tuple[int, int]]) -> torch.Tensor:
        """
        Build the symmetrically normalized adjacency matrix:
            A_hat = D^{-1/2} * A * D^{-1/2}
        for the bipartite user-item graph where nodes = users + items
        Size: (n_nodes, n_nodes) with n_nodes = n_users + n_items

        Pairs whose ids fall outside [0, n_users) / [0, n_items) are skipped.
        Returns a coalesced sparse COO tensor; if no valid pair remains, an
        identity matrix is returned instead.
        """
        n_nodes = self.n_users + self.n_items

        # Collect COO edges (user <-> item, undirected)
        rows = []
        cols = []
        vals = []

        for u, i, *rest in interactions:
            # Allow (u,i) or (u,i,rating) — rating ignored here (implicit 1)
            if u < 0 or u >= self.n_users or i < 0 or i >= self.n_items:
                continue
            i_offset = self.n_users + i  # item node index in unified graph
            # u -> i
            rows.append(u); cols.append(i_offset); vals.append(1.0)
            # i -> u
            rows.append(i_offset); cols.append(u); vals.append(1.0)

        if len(rows) == 0:
            # Empty graph fallback (identity to avoid NaNs)
            indices = torch.arange(n_nodes, dtype=torch.long)
            indices = torch.stack([indices, indices], dim=0)
            values = torch.ones(n_nodes, dtype=torch.float32)
            return torch.sparse_coo_tensor(indices, values, (n_nodes, n_nodes)).coalesce()

        indices = torch.tensor([rows, cols], dtype=torch.long)
        values = torch.tensor(vals, dtype=torch.float32)
        # coalesce() sums the values of repeated (row, col) entries, so a pair
        # that appears several times ends up with a larger edge weight.
        A = torch.sparse_coo_tensor(indices, values, (n_nodes, n_nodes)).coalesce()

        # Degree vector d = sum of rows
        deg = torch.sparse.sum(A, dim=1).to_dense()  # (n_nodes,)
        # Avoid divide-by-zero
        deg = torch.clamp(deg, min=1e-12)
        d_inv_sqrt = torch.pow(deg, -0.5)

        # Normalize values: for each edge (i,j), val *= d^-1/2[i] * d^-1/2[j]
        row, col = A.indices()
        norm_vals = A.values() * d_inv_sqrt[row] * d_inv_sqrt[col]

        A_hat = torch.sparse_coo_tensor(A.indices(), norm_vals, A.size()).coalesce()
        return A_hat

    # ---------- Embedding propagation ----------
    def propagate(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Perform K-layer LightGCN propagation and return final user & item embeddings.
        E^(0) = concat([U, I])  -> shape (n_users + n_items, d)
        E_final = 1/(K+1) * sum_{k=0..K} E^(k)

        Returns (Eu, Ei) with shapes (n_users, d) and (n_items, d).
        """
        E_u0 = self.user_embedding.weight
        E_i0 = self.item_embedding.weight
        E0 = torch.cat([E_u0, E_i0], dim=0)  # (n_users + n_items, d)

        all_layers = [E0]
        x = E0
        for _ in range(self.n_layers):
            x = torch.sparse.mm(self.Graph, x)  # LightGCN propagation (no weights, no nonlinearity)
            all_layers.append(x)

        E = torch.stack(all_layers, dim=0).mean(dim=0)  # layer-wise average
        Eu, Ei = torch.split(E, [self.n_users, self.n_items], dim=0)
        return Eu, Ei

    # ---------- Scoring ----------
    def forward(self, users: torch.Tensor, items: torch.Tensor) -> torch.Tensor:
        """
        Predict scores for (users, items). Supports:
          - users: (B,), items: (B,)  -> elementwise scores
          - users: (B,), items: (N,)  -> broadcast users for all N items (returns (B,N))

        NOTE: when both inputs are 1-D and have the same length the elementwise
        branch is always taken, so a (B, B) score matrix cannot be requested
        through this method. The full propagation runs on every call.
        """
        users = users.to(self.device)
        items = items.to(self.device)

        Eu, Ei = self.propagate()

        if users.dim() == 1 and items.dim() == 1 and users.shape[0] == items.shape[0]:
            u_emb = Eu[users]                   # (B, d)
            i_emb = Ei[items]                   # (B, d)
            return (u_emb * i_emb).sum(dim=1)   # (B,)

        # Broadcast to (B, N)
        u_emb = Eu[users]                       # (B, d)
        i_emb = Ei[items]                       # (N, d)
        scores = u_emb @ i_emb.t()              # (B, N)
        return scores

    def predict(self, user_ids: List[int], item_ids: Optional[List[int]] = None) -> np.ndarray:
        """
        Convenience method to get scores as numpy.
        - If item_ids is None: scores for all items for each user -> shape (B, n_items)
        - Else: scores for user_ids x item_ids -> (B, len(item_ids))

        NOTE: the shapes above hold when the number of users differs from the
        number of items being scored; for equal lengths ``forward`` returns
        elementwise scores of shape (B,). Switches the module to eval mode and
        leaves it there.
        """
        self.eval()
        with torch.no_grad():
            u = torch.tensor(user_ids, dtype=torch.long, device=self.device)
            if item_ids is None:
                items = torch.arange(self.n_items, dtype=torch.long, device=self.device)
                scores = self.forward(u, items)        # (B, n_items)
            else:
                items = torch.tensor(item_ids, dtype=torch.long, device=self.device)
                scores = self.forward(u, items)        # (B, len(item_ids))
        return scores.detach().cpu().numpy()

    # ---------- Ranking API (BaseModel) ----------
    def get_full_sort_items(self, user_id: int, seen_items: Optional[Iterable[int]] = None, top_k: Optional[int] = None) -> List[int]:
        """
        Rank all items for a given user (descending by predicted score).
        Optionally drop previously seen items.

        seen_items: ids to push to the bottom of the ranking (out-of-range ids are ignored)
        top_k: if given, only the first top_k item ids are returned
        """
        scores = self.predict([user_id], None).ravel()  # (n_items,)
        if seen_items is not None:
            # push seen items to bottom
            seen_items = [i for i in seen_items if 0 <= i < self.n_items]
            scores[np.array(seen_items, dtype=np.int64)] = -np.inf

        order = np.argsort(-scores)  # descending
        if top_k is not None:
            order = order[:top_k]
        return order.tolist()

    def _sort_full_items(self, user_id: int, predicted_ratings: torch.Tensor, items: torch.Tensor):
        """Reorder ``items`` by ``predicted_ratings`` in descending order."""
        # Not used here; keeping for BaseModel compatibility
        _, idx = torch.sort(predicted_ratings, descending=True)
        return items[idx]


class BPRLoss(nn.Module):
    """
    Bayesian Personalized Ranking loss over (user, positive, negative) embedding
    triples, plus a batch-averaged squared-L2 penalty scaled by ``reg``.
    """
    def __init__(self, reg: float = 1e-4):
        super().__init__()
        self.reg = reg

    def forward(self, u_emb, pos_emb, neg_emb):
        """Return ``mean(-logsigmoid(pos - neg)) + reg * penalty`` as a scalar tensor."""
        # Per-row dot products give the positive and negative scores.
        score_pos = (u_emb * pos_emb).sum(dim=1)
        score_neg = (u_emb * neg_emb).sum(dim=1)
        ranking_term = -F.logsigmoid(score_pos - score_neg).mean()

        # Squared L2 norm of each embedding batch, averaged over the batch size.
        batch_size = u_emb.shape[0]
        sq_user = u_emb.norm(2).pow(2)
        sq_pos = pos_emb.norm(2).pow(2)
        sq_neg = neg_emb.norm(2).pow(2)
        penalty = (sq_user + sq_pos + sq_neg) / batch_size

        return ranking_term + self.reg * penalty

In [None]:
import os
from typing import Optional, Tuple, Iterable, List, Dict
import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset


class InteractionDataset(Dataset):
    """
    One sample per user: that user's row of the interaction matrix, returned
    as a dense ``torch.float32`` vector.

    The matrix may be anything exposing both ``toarray`` and ``tocsr``
    (handled as sparse) or any array-like convertible by ``np.asarray``.
    """
    def __init__(self, user_item_matrix):
        # Duck-type the sparse case instead of importing scipy here.
        looks_sparse = all(hasattr(user_item_matrix, attr) for attr in ("toarray", "tocsr"))
        if looks_sparse:
            self.mat = user_item_matrix.tocsr()
        else:
            self.mat = np.asarray(user_item_matrix, dtype=np.float32)
        self.is_sparse = looks_sparse
        self.n_users, self.n_items = self.mat.shape

    def __len__(self):
        return self.n_users

    def __getitem__(self, idx):
        if not self.is_sparse:
            # astype() copies, so the tensor does not alias the stored matrix.
            return torch.from_numpy(self.mat[idx].astype(np.float32))
        # getrow() yields a 1 x n_items matrix; densify and drop the row axis.
        dense_row = self.mat.getrow(idx).toarray()
        return torch.from_numpy(dense_row.astype(np.float32).squeeze(0))


class MultiVAE(BaseModel, nn.Module):
    """
    MultiVAE model (variational autoencoder for collaborative filtering).
    - encoder: two-layer MLP -> mu, logvar
    - decoder: linear mapping from latent z to item logits

    Layer sizes and dropout can be overridden through the config keys
    "vae_hidden_dim", "vae_latent_dim" and "vae_dropout"; the constructor
    arguments act as fallbacks when a key is absent.
    """
    def __init__(
        self,
        config: dict,
        n_users: int,
        n_items: int,
        hidden_dim: int = 600,
        latent_dim: int = 200,
        dropout=0.5,
        device: Optional[torch.device] = None,
    ):
        BaseModel.__init__(self, config, n_users, n_items)
        nn.Module.__init__(self)

        # Use the requested device, otherwise CUDA when available, else CPU.
        self.device = device or (torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"))
        self.n_users = n_users
        self.n_items = n_items

        # model dims
        self.hidden_dim = config.get("vae_hidden_dim", hidden_dim)
        self.latent_dim = config.get("vae_latent_dim", latent_dim)
        self.dropout = config.get("vae_dropout", dropout)

        # encoder
        self.encoder_fc1 = nn.Linear(n_items, self.hidden_dim)
        self.encoder_fc2 = nn.Linear(self.hidden_dim, self.hidden_dim)
        self.mu_layer = nn.Linear(self.hidden_dim, self.latent_dim)
        self.logvar_layer = nn.Linear(self.hidden_dim, self.latent_dim)

        # decoder (linear + bias to produce logits over items)
        self.decoder = nn.Linear(self.latent_dim, n_items, bias=True)

        self.act = nn.Tanh()
        self.drop = nn.Dropout(self.dropout)

        # initialization: Xavier-uniform weights and zero biases for every Linear
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0.0)

        self.to(self.device)

    # ---------- VAE ops ----------
    def encode(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        x: (B, n_items) input user vectors
        returns mu, logvar (each (B, latent_dim))

        Dropout is applied after each hidden activation, so the output is only
        deterministic while the module is in eval mode.
        """
        h = self.drop(self.act(self.encoder_fc1(x)))
        h = self.drop(self.act(self.encoder_fc2(h)))
        mu = self.mu_layer(h)
        logvar = self.logvar_layer(h)
        return mu, logvar

    def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        """Draw z = mu + eps * std with eps ~ N(0, I) and std = exp(0.5 * logvar)."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """
        Returns logits for items: (B, n_items)
        """
        logits = self.decoder(z)
        return logits

    def forward(self, x: torch.Tensor, sample: bool = True) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        One forward pass:
        - encodes x -> mu, logvar
        - draws z (reparameterization) when training and ``sample`` is True
        - decodes logits
        Returns: logits, mu, logvar
        """
        mu, logvar = self.encode(x)
        if self.training and sample:
            z = self.reparameterize(mu, logvar)
        else:
            z = mu  # use mean for deterministic inference
        logits = self.decode(z)
        return logits, mu, logvar

    # ---------- Loss / ELBO ----------
    def loss_function(
        self,
        logits: torch.Tensor,
        input_batch: torch.Tensor,
        mu: torch.Tensor,
        logvar: torch.Tensor,
        anneal: float = 1.0,
    ) -> Tuple[torch.Tensor, float, float]:
        """
        Multinomial likelihood version used in MultiVAE:
        recon_loss = - sum_j x_j * log_softmax(logits)_j

        KL = -0.5 * sum(1 + logvar - mu^2 - exp(logvar))
        total_loss = recon_loss + anneal * KL
        Both terms are averaged over the batch.
        Returns (loss, recon_loss_scalar, kl_scalar)
        """
        # reconstruction: log softmax and weighted by counts
        log_softmax = F.log_softmax(logits, dim=1)  # (B, n_items)
        # input_batch may be counts (0/1 or counts). multiply and sum per sample
        recon = -torch.sum(log_softmax * input_batch, dim=1)  # (B,)
        recon_loss = torch.mean(recon)

        # KL
        kl = -0.5 * torch.sum(1.0 + logvar - mu.pow(2) - logvar.exp(), dim=1)  # (B,)
        kl_loss = torch.mean(kl)

        loss = recon_loss + anneal * kl_loss
        return loss, recon_loss.item(), kl_loss.item()


    # ---------- Embeddings helper (for ranking similar to LightGCN propagate) ----------
    def get_user_latent(self, x: torch.Tensor) -> torch.Tensor:
        """
        Return the latent mean mu for x (no sampling of z).
        x: (B, n_items)
        returns mu: (B, latent_dim)

        The encoder's dropout is still active in training mode; call
        ``self.eval()`` first for a deterministic result.
        """
        mu, logvar = self.encode(x)
        return mu

    # ---------- Prediction / Ranking (BaseModel methods) ----------
    def predict(self, user_vectors: torch.Tensor) -> np.ndarray:
        """
        Predict item scores for a batch of user vectors.
        user_vectors: torch.Tensor (B, n_items) on cpu or device
        returns: numpy array (B, n_items) of logits (higher = more recommended)

        Switches the module to eval mode and leaves it there.
        """
        self.eval()
        with torch.no_grad():
            user_vectors = user_vectors.to(self.device)
            logits, mu, logvar = self.forward(user_vectors, sample=False)
            # Return logits (no softmax) so they can be sorted. move to cpu
            return logits.detach().cpu().numpy()

    def get_full_sort_items(self, user_id: int, user_item_matrix, seen_items: Optional[Iterable[int]] = None, top_k: Optional[int] = None) -> List[int]:
        """
        Rank all items for a given user.
        - user_item_matrix: the full user-item matrix (csr or numpy) to extract the user's vector
        - seen_items: optional iterable to mask (push to bottom); out-of-range ids are ignored
        - top_k: if given, only the first top_k item ids are returned
        """
        # build user vector
        if hasattr(user_item_matrix, "getrow"):
            vec = user_item_matrix.getrow(user_id).toarray().astype(np.float32)
        else:
            vec = np.asarray(user_item_matrix[user_id], dtype=np.float32)

        # The sparse path yields a (1, n_items) row while the dense path yields
        # (n_items,). Flatten first so predict() always receives the documented
        # (B, n_items) batch rather than a (1, 1, n_items) tensor.
        user_vec = torch.from_numpy(vec.reshape(-1)).to(self.device)

        scores = self.predict(user_vec.unsqueeze(0)).ravel()  # (n_items,)

        if seen_items is not None:
            seen = [i for i in seen_items if 0 <= i < self.n_items]
            scores[np.array(seen, dtype=np.int64)] = -np.inf

        order = np.argsort(-scores)
        if top_k is not None:
            order = order[:top_k]
        return order.tolist()

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader


class FMModel(BaseModel, nn.Module):
    """Factorization-machine style scorer for (user, item) pairs.

    score = global bias + user bias + item bias + <user factors, item factors>
    """

    def __init__(self, config, n_users, n_items, n_factors=16):
        BaseModel.__init__(self, config, n_users, n_items)
        nn.Module.__init__(self)
        self.n_users, self.n_items, self.n_factors = n_users, n_items, n_factors

        # First-order terms: one scalar per user and per item.
        self.user_bias = nn.Embedding(n_users, 1)
        self.item_bias = nn.Embedding(n_items, 1)

        # Second-order terms: latent factor vectors.
        self.user_embedding = nn.Embedding(n_users, n_factors)
        self.item_embedding = nn.Embedding(n_items, n_factors)

        # Single learnable offset shared by every prediction.
        self.global_bias = nn.Parameter(torch.zeros(1))

        # Small random factors, zero biases.
        for factors in (self.user_embedding, self.item_embedding):
            nn.init.normal_(factors.weight, std=0.01)
        for bias in (self.user_bias, self.item_bias):
            nn.init.zeros_(bias.weight)

    def forward(self, user_ids, item_ids):
        """Return a flat tensor with one score per (user_ids[k], item_ids[k]) pair."""
        first_order = self.user_bias(user_ids) + self.item_bias(item_ids)
        factor_product = self.user_embedding(user_ids) * self.item_embedding(item_ids)
        pairwise = factor_product.sum(dim=1, keepdim=True)
        return (self.global_bias + first_order + pairwise).view(-1)

    def get_full_sort_items(self, user_id):
        """Score every item index for ``user_id`` and return the indices, best first."""
        # Run on whatever device the parameters currently live on.
        device = next(self.parameters()).device
        all_items = torch.arange(self.n_items, device=device)
        repeated_user = torch.tensor([user_id], device=device).expand(self.n_items)
        ranking = torch.argsort(self.forward(repeated_user, all_items), descending=True)
        return ranking.tolist()

    def _sort_full_items(self, user_id):
        """Alias of ``get_full_sort_items``."""
        return self.get_full_sort_items(user_id)

In [None]:
from __future__ import annotations
import math
import random
from typing import List, Dict, Tuple, Iterable

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from collections import defaultdict

# ------------------------------
# Utility: set all seeds
# ------------------------------
def set_seed(seed: int = 42):
    """Seed Python's ``random``, NumPy, and torch (CPU and all CUDA devices) with ``seed``."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

def sasrec_pointwise_step(model, batch, device, logit_clip=20.0):
    """Compute the pointwise (binary cross-entropy) training loss for one batch.

    Args:
        model: callable on the sequence tensor; must expose
            ``item_embedding.weight`` and ``l2_emb``.
        batch: 5-tuple ``(users, seqs, pos_items, neg_items, mask)``; the
            ``users`` entry is not used here.
        device: device the tensors are moved to.
        logit_clip: logits are clamped to ``[-logit_clip, logit_clip]``.

    Returns:
        Scalar loss tensor. When no position is valid, a fresh zero tensor with
        ``requires_grad=True`` is returned instead.
    """
    _, seqs, pos_items, neg_items, mask = batch
    seqs = seqs.to(device)
    pos_items = pos_items.to(device)
    neg_items = neg_items.to(device)
    mask = mask.to(device).bool()

    hidden = model(seqs)
    item_table = model.item_embedding.weight

    def clipped_logits(targets):
        # Dot product of each position's hidden state with its target embedding.
        return (hidden * item_table[targets]).sum(dim=-1).clamp(-logit_clip, logit_clip)

    pos_logits = clipped_logits(pos_items)
    neg_logits = clipped_logits(neg_items)

    # A position counts only if it has a real target (> 0) and is unmasked.
    valid = (pos_items > 0) & mask
    if valid.sum() == 0:
        return torch.tensor(0.0, device=device, requires_grad=True)

    pos_selected = pos_logits[valid]
    neg_selected = neg_logits[valid]
    positive_term = F.binary_cross_entropy_with_logits(pos_selected, torch.ones_like(pos_selected))
    negative_term = F.binary_cross_entropy_with_logits(neg_selected, torch.zeros_like(neg_selected))
    loss = positive_term + negative_term

    # L2 penalty on the item table, leaving out row 0 (the padding embedding).
    if model.l2_emb > 0:
        loss = loss + model.l2_emb * (item_table[1:].norm(p=2) ** 2) / 2

    return loss

# ------------------------------
# Dataset & Collate
# ------------------------------
class SASRecDataset(Dataset):
    """Per-user item histories for SASRec training.

    Each sample is ``(user_id, recent_items)`` where ``recent_items`` is the
    tail of that user's history, at most ``max_seq_len`` long. Padding is not
    done here.

    Args:
        user2items: dict mapping user_id -> list of interacted item ids
            (presumably already in time order; this class does not sort)
        n_items: item count, stored as-is and not used by this class
        max_seq_len: keep only the latest ``max_seq_len`` items of a history
        min_seq_len: users with fewer interactions than this are left out
    """
    def __init__(self,
                 user2items: Dict[int, List[int]],
                 n_items: int,
                 max_seq_len: int = 50,
                 min_seq_len: int = 2):
        self.n_items = n_items
        self.max_seq_len = max_seq_len
        self.user2items = user2items

        eligible = []
        for uid, history in user2items.items():
            if len(history) >= min_seq_len:
                eligible.append(uid)
        self.users = eligible

    def __len__(self):
        return len(self.users)

    def __getitem__(self, idx: int):
        uid = self.users[idx]
        # Negative slice start keeps the most recent interactions.
        recent = self.user2items[uid][-self.max_seq_len:]
        return uid, recent


def _pad_sequence(seq: list, max_len: int) -> list:
    """Left-pad sequence with 0 to max_len."""
    seq = seq[-max_len:]
    return [0] * (max_len - len(seq)) + seq

def _build_pos_items(seq: list) -> list:
    """Next-item targets, last position padded with 0."""
    return seq[1:] + [0]


def _sample_negatives(seq: list, n_items: int) -> list:
    """Sample negatives for each position, avoiding seq items."""
    user_set = set(seq)
    negatives = []
    for _ in range(len(seq)):
        neg = random.randint(1, n_items - 1)  # avoid 0
        while neg in user_set:
            neg = random.randint(1, n_items - 1)
        negatives.append(neg)
    return negatives


def sasrec_collate(batch, n_items: int, max_seq_len: int):
    """Turn a list of ``(user, sequence)`` samples into padded training tensors.

    Returns a 5-tuple ``(users, seqs, pos_items, neg_items, mask)``. The first
    four are ``torch.long``; ``mask`` is ``torch.float`` with 1 where the
    padded sequence holds a non-zero id and 0 elsewhere.
    """
    users, raw_seqs = zip(*batch)

    # Each stage runs over the whole batch before the next one starts, so the
    # random negatives are drawn in batch order after all padding is done.
    padded = []
    for raw in raw_seqs:
        padded.append(_pad_sequence(raw, max_seq_len))

    targets = list(map(_build_pos_items, padded))

    negatives = []
    for row in padded:
        negatives.append(_sample_negatives(row, n_items))

    attend = [[int(item != 0) for item in row] for row in padded]

    def as_long(rows):
        return torch.tensor(rows, dtype=torch.long)

    return (
        as_long(users),
        as_long(padded),
        as_long(targets),
        as_long(negatives),
        torch.tensor(attend, dtype=torch.float),
    )


def build_user2items(train_data):
    """Group item ids by user from an iterable of 3-field records.

    Records are sorted by (field 0, field 2) and field 1 is appended to the
    list for field 0, so each user's items follow ascending field-2 order.

    NOTE(review): field 2 is used as the within-user ordering key; confirm
    with the caller that it is a timestamp/step rather than a rating.

    Returns a ``defaultdict(list)`` mapping user -> list of item ids.
    """
    def user_then_third(record):
        return record[0], record[2]

    user2items = defaultdict(list)
    for record in sorted(train_data, key=user_then_third):
        user, item, _ = record
        user2items[user].append(item)
    return user2items


# ------------------------------
# Model
# ------------------------------
class SASRec(nn.Module, BaseModel):
    """Self-attentive sequential recommender built from ``nn.TransformerEncoderLayer`` blocks.

    Item id 0 is the padding id. Hyperparameters are read from ``config``:
    "hidden_units" (128), "max_seq_len" (50), "num_heads" (2),
    "num_blocks" (2), "dropout_rate" (0.2) and "l2_emb" (0.0); the values in
    parentheses are the defaults used when a key is absent.
    """

    def __init__(self, config, n_users: int, n_items: int):
        nn.Module.__init__(self)  # initialize nn.Module
        BaseModel.__init__(self, config, n_users, n_items)  # keep BaseModel logic

        self.n_items = n_items
        self.hidden_units = int(config.get("hidden_units", 128))
        self.max_seq_len = int(config.get("max_seq_len", 50))
        self.num_heads = int(config.get("num_heads", 2))
        self.num_blocks = int(config.get("num_blocks", 2))
        self.dropout_rate = float(config.get("dropout_rate", 0.2))
        self.l2_emb = float(config.get("l2_emb", 0.0))

        # Embeddings (n_items + 1 rows: row 0 is the padding id)
        self.item_embedding = nn.Embedding(n_items + 1, self.hidden_units, padding_idx=0)
        # One learned vector per position; sequences longer than max_seq_len
        # cannot be embedded.
        self.position_embedding = nn.Embedding(self.max_seq_len, self.hidden_units)

        # Transformer blocks
        self.blocks = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=self.hidden_units,
                nhead=self.num_heads,
                dim_feedforward=self.hidden_units * 4,
                dropout=self.dropout_rate,
                activation="gelu",
                batch_first=True,
            ) for _ in range(self.num_blocks)
        ])
        self.dropout = nn.Dropout(self.dropout_rate)
        self.layer_norm = nn.LayerNorm(self.hidden_units)

        self._reset_parameters()

    def _reset_parameters(self):
        """Initialise both embedding tables with N(0, 0.02) and keep the padding row at zero."""
        nn.init.normal_(self.item_embedding.weight, std=0.02)
        nn.init.normal_(self.position_embedding.weight, std=0.02)
        # normal_ overwrites every row, including the padding_idx row that
        # nn.Embedding had zeroed at construction; restore it to zeros.
        with torch.no_grad():
            self.item_embedding.weight[self.item_embedding.padding_idx].zero_()

    def _causal_mask(self, L: int, device=None):
        """Return an (L, L) boolean mask that is True strictly above the diagonal."""
        # True means masked in PyTorch
        mask = torch.triu(torch.ones(L, L, device=device), diagonal=1).bool()
        return mask

    def forward(self, item_seq: torch.Tensor) -> torch.Tensor:
        """
        Encode sequence safely, with padding & causal masks.
        Args:
            item_seq: (B, L) padded with 0 on left; L must not exceed max_seq_len
        Returns:
            seq_out: (B, L, H)
        """
        B, L = item_seq.shape
        device = item_seq.device
        pos_ids = torch.arange(L, device=device).unsqueeze(0).expand(B, L)

        # Embeddings
        x = self.item_embedding(item_seq) + self.position_embedding(pos_ids)
        x = self.layer_norm(x)
        x = self.dropout(x)

        key_padding_mask = (item_seq == 0)  # True = pad
        attn_mask = self._causal_mask(L, device=device)  # True = masked

        for blk in self.blocks:
            x = blk(
                x,
                src_mask=attn_mask,
                src_key_padding_mask=key_padding_mask
            )
            # Safety: replace NaN/inf. With left padding plus the causal mask,
            # a padded query position can have every key masked out.
            if torch.isnan(x).any() or torch.isinf(x).any():
                x = torch.nan_to_num(x, nan=0.0, posinf=1e4, neginf=-1e4)

        return x

    @torch.no_grad()
    def get_full_sort_items(self, user_id, user_seq: torch.Tensor):
        """Return sorted item IDs by score for a single user.

        Switches the module to eval mode and leaves it there. ``user_id`` is
        not used in the computation.

        Args:
            user_seq: (1, L)
        Returns:
            torch.Tensor of sorted item ids (desc); id 0 (padding) is scored
            -1e9 so it sorts last.
        """
        self.eval()
        seq_out = self.forward(user_seq)[:, -1, :]  # (1, H)
        all_item_emb = self.item_embedding.weight  # (n_items+1, H)
        scores = torch.matmul(seq_out, all_item_emb.t()).squeeze(0)  # (n_items+1,)
        # Avoid recommending padding id 0
        scores[0] = -1e9
        return torch.argsort(scores, descending=True)

    def _sort_full_items(self, user_id, *args, **kwargs):
        """Unsupported for SASRec; always raises NotImplementedError."""
        raise NotImplementedError("Use get_full_sort_items(user_id, user_seq)")

In [None]:
import torch
import pandas as pd
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
import numpy as np
import random
from collections import Counter
from sklearn.model_selection import train_test_split
from scipy.stats import spearmanr
from torch.utils.data import DataLoader


class Recommender:
    """
    Recommender System class
    """

    def __init__(self, config, logger, data):
        """Build the recommender, its model and its bookkeeping containers.

        Args:
            config: mapping read here for "page_size", "items_per_page",
                "rec_random_k", "rec_model", "epoch_num" and (for every
                model except "FM") "lr".
            logger: logger object stored on the instance.
            data: dataset wrapper; get_user_num(), get_item_num() and
                get_full_users() are called on it here.
        Raises:
            ValueError: if config["rec_model"] is not a supported model name.
        """
        self.data = data
        self.config = config
        self.logger = logger
        self.page_size = config["page_size"]
        self.items_per_page = config["items_per_page"]
        self.random_k = config["rec_random_k"]
        self.train_data = []
        # The train_* methods compare these against None, so they must exist
        # even when create_train_data() has not been called yet.
        self.temp_data = None
        self.val_data = None
        self.test_data = None
        self.train_matrix = None
        self.val_matrix = None
        self.test_matrix = None
        self.n_layers = 3
        self.embedding_dim = 4

        model_name = config["rec_model"]
        if model_name == "MF":
            self.model = MF(config, self.data.get_user_num(), self.data.get_item_num())
        elif model_name == "LightGCN":
            self.model = LightGCN(config, self.data.get_user_num(), self.data.get_item_num(), self.train_data, n_layers=self.n_layers, embedding_dim=self.embedding_dim)
        elif model_name == "SASRec":
            self.model = SASRec(config, self.data.get_user_num(), self.data.get_item_num())
        elif model_name == "MultiVAE":
            self.model = MultiVAE(config, self.data.get_user_num(), self.data.get_item_num())
        elif model_name == "FM":
            self.model = FMModel(config, self.data.get_user_num(), self.data.get_item_num())
        else:
            raise ValueError(f"Unknown model: {config['rec_model']}")

        self.criterion = nn.MSELoss()
        # No default optimizer is built for FM here; train_fm() creates its own.
        if model_name != "FM":
            self.optimizer = optim.Adam(self.model.parameters(), lr=config["lr"], weight_decay=1e-5)

        self.epoch_num = config["epoch_num"]
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Per-user interaction bookkeeping.
        self.record = {}
        self.round_record = {}
        self.positive = {}
        self.interaction_dict = {}
        self.inter_df = None
        self.inter_num = 0
        for user in self.data.get_full_users():
            self.record[user] = []
            self.positive[user] = []
            self.round_record[user] = []

        # Column-wise stores filled by add_user() / add_review().
        self.user_data = {
            "user": [],
            "N_expose": [],
            "N_view": [],
            "N_like": [],
            "N_exit": [],
            "S_sat": []
        }
        self.rating_feeling = {
            "User": [],
            "Rating": [],
            "Feelings": []
        }

    def sample_bpr_triples(self,
                           user_pos_items: List[List[int]],
                           n_items: int,
                           batch_size: int,
                           device: torch.device,
                           ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        user_pos_items: list of lists; for each user u, a list of items they've interacted with (positive)
        returns (users, pos_items, neg_items) tensors of length batch_size
        """
        users = []
        pos = []
        neg = []
        for _ in range(batch_size):
           # sample a user with at least one positive
           while True:
               u = np.random.randint(0, len(user_pos_items))
               if len(user_pos_items[u]) > 0:
                  break

           i = np.random.choice(user_pos_items[u])
           # sample a negative item

           while True:
               j = np.random.randint(0, n_items)
               if j not in user_pos_items[u]:
                  break

           users.append(u)
           pos.append(i)
           neg.append(j)

        return (
        torch.tensor(users, dtype=torch.long, device=self.device),
        torch.tensor(pos, dtype=torch.long, device=self.device),
        torch.tensor(neg, dtype=torch.long, device=self.device),
        )

    def train_lightgcn_bpr(self,
        reg: float = 1e-4,
        log: bool = True):
        """
        Train LightGCN with BPR loss on self.train_data.

        self.train_data (and self.val_data, when set) hold (u, i, ...) tuples;
        entries whose indices fall outside the model's user/item range are
        ignored. After each epoch a cheap top-10 hit proxy is computed on the
        validation positives. When no such proxy is available the negative
        mean training loss is used for model selection instead. The best
        state dict is written to <checkpoint_path>/best_lightGCN_model.pth
        and reloaded once training finishes.

        Args:
            reg: regularisation weight passed to BPRLoss.
            log: if True, print the per-epoch loss and validation proxy.
        Raises:
            ValueError: if self.train_data yields no in-range positive pair.
        """
        cfg = self.config
        batch_size = cfg['batch_size']

        # Index tensors are sampled on self.device, so keep the model there
        # too (train_sasrec/train_fm do the same).
        self.model.to(self.device)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=cfg["lr"])
        self.criterion = BPRLoss(reg=reg)

        n_users = self.model.n_users
        n_items = self.model.n_items

        def positives_per_user(interactions):
            # user index -> list of in-range positive item indices
            per_user = [[] for _ in range(n_users)]
            for u, i, *rest in interactions:
                if 0 <= u < n_users and 0 <= i < n_items:
                    per_user[u].append(i)
            return per_user

        user_pos = positives_per_user(self.train_data)
        if not any(user_pos):
            raise ValueError("train_lightgcn_bpr: no in-range positive interactions to train on")

        # Validation positives do not change between epochs; build them once.
        val_data = getattr(self, "val_data", None)
        val_pos = positives_per_user(val_data) if val_data is not None else None

        # Basic validation proxy: positives ranked in the top-k (cheap & optional).
        def quick_val_topk_hits(k: int = 10) -> float:
            if val_pos is None:
                return -1.0
            hits = 0
            total = 0
            for u, targets in enumerate(val_pos):
                if not targets:
                    continue
                recs = self.model.get_full_sort_items(u, seen_items=set(user_pos[u]), top_k=k)
                target_set = set(targets)
                hits += len([r for r in recs if r in target_set])
                total += min(k, len(target_set))
            return hits / total if total > 0 else -1.0

        best_metric = float("-inf")

        # Build full checkpoint file path
        ckpt_file = os.path.join(cfg['checkpoint_path'], "best_lightGCN_model.pth")
        os.makedirs(cfg['checkpoint_path'], exist_ok=True)

        # One epoch covers roughly every positive once.
        n_steps = max(1, sum(len(v) for v in user_pos) // max(1, batch_size))

        for epoch in range(1, self.epoch_num + 1):
            self.model.train()

            losses = []
            for _ in range(n_steps):
                users, pos_items, neg_items = self.sample_bpr_triples(user_pos, n_items, batch_size, self.device)
                Eu, Ei = self.model.propagate()
                u_emb = Eu[users]
                p_emb = Ei[pos_items]
                n_emb = Ei[neg_items]
                loss = self.criterion(u_emb, p_emb, n_emb)
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                losses.append(loss.item())

            mean_loss = float(np.mean(losses))
            val_metric = quick_val_topk_hits(k=10)
            if log:
                print(f"[Epoch {epoch:3d}] BPR Loss: {mean_loss:.4f} | Val@10: {val_metric:.4f}")

            # The proxy is a constant -1.0 when it cannot be computed, which
            # would freeze the "best" model at epoch 1; fall back to the
            # training loss in that case, as train_fm/train_sasrec do.
            metric = val_metric if val_metric >= 0 else -mean_loss

            # Save checkpoint if the selection metric improves
            if metric > best_metric:
                best_metric = metric
                torch.save({
                    "epoch": epoch,
                    "model_state_dict": self.model.state_dict(),
                    "n_users": self.data.get_user_num(),
                    "n_items": self.data.get_item_num(),
                    "n_layers": self.n_layers,
                    "embedding_dim": self.embedding_dim,
                    "metric": metric,
                }, ckpt_file)
                self.logger.info(f"Best model updated at epoch {epoch}, saved to {ckpt_file}")

        # Reload the best weights for inference, but only if this run saved any.
        if best_metric > float("-inf"):
            checkpoint = torch.load(ckpt_file, map_location=self.device)
            self.model.load_state_dict(checkpoint['model_state_dict'])

    def bce_sampled_loss(self, seq_out: torch.Tensor,
        pos_items: torch.Tensor,
        neg_items: torch.Tensor,
        item_embedding: nn.Embedding,
        mask: torch.Tensor,
        l2_emb: float = 0.0) -> torch.Tensor:
        """Binary cross-entropy on sampled positives/negatives per position.
        Args:
            seq_out: (B, L, H)
            pos_items: (B, L) next-item ids (0 where no target)
            neg_items: (B, L) sampled negatives (0 where no target)
            item_embedding: embedding module (to fetch item vectors)
            mask: (B, L) boolean, True where a target exists (i.e., pos_items > 0)
            l2_emb: weight decay on item embeddings (regularizes pos/neg lookups)
        """
        B, L, H = seq_out.shape

        pos_vecs = item_embedding(pos_items)  # (B, L, H)
        neg_vecs = item_embedding(neg_items)  # (B, L, H)

        pos_logits = (seq_out * pos_vecs).sum(-1)  # (B, L)
        neg_logits = (seq_out * neg_vecs).sum(-1)  # (B, L)

        # Targets: pos -> 1, neg -> 0
        pos_loss = F.binary_cross_entropy_with_logits(pos_logits[mask], torch.ones_like(pos_logits[mask]))
        neg_loss = F.binary_cross_entropy_with_logits(neg_logits[mask], torch.zeros_like(neg_logits[mask]))
        loss = pos_loss + neg_loss

        if l2_emb > 0:
           reg = (pos_vecs[mask].pow(2).sum() + neg_vecs[mask].pow(2).sum()) / mask.sum().clamp_min(1)
           loss = loss + l2_emb * reg
        return loss

    def train_sasrec(self, grad_clip=1.0, logit_clip=20.0):
        """Train the SASRec model with the point-wise sampled loss.

        Sequences are built from self.train_data. When self.val_data exists
        and is non-empty it is scored after every epoch and the negative
        validation loss selects the best model; otherwise (or when every
        validation batch is skipped) the negative training loss is used. The
        best state dict is written to <checkpoint_path>/best_SASRec_model.pth
        and reloaded at the end if this run saved one.

        Args:
            grad_clip: max gradient norm; pass None to disable clipping.
            logit_clip: forwarded to sasrec_pointwise_step.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

        n_items_global = int(self.data.get_item_num())
        max_seq_len = self.model.max_seq_len

        def collate(batch):
            return sasrec_collate(batch, n_items=n_items_global, max_seq_len=max_seq_len)

        def make_loader(interactions, shuffle):
            # Build user->items dict, dataset and loader for one split.
            dataset = SASRecDataset(build_user2items(interactions), n_items=n_items_global, max_seq_len=max_seq_len)
            return DataLoader(
                dataset,
                batch_size=self.config['batch_size'],
                shuffle=shuffle,
                collate_fn=collate,
            )

        def batch_loss(batch):
            # Loss for one collated batch, or None if it has no target position.
            users, seqs, pos, neg, _umask = batch
            users, seqs, pos, neg = (t.to(self.device) for t in (users, seqs, pos, neg))
            mask = pos > 0
            if mask.sum() == 0:
                return None
            return sasrec_pointwise_step(self.model, (users, seqs, pos, neg, mask), device=self.device, logit_clip=logit_clip)

        train_loader = make_loader(self.train_data, shuffle=True)

        # Validation is optional: skip it when no split has been created.
        val_data = getattr(self, "val_data", None)
        val_loader = None
        if val_data is not None and len(val_data) > 0:
            val_loader = make_loader(val_data, shuffle=False)

        # Checkpoint setup
        ckpt_file = os.path.join(self.config['checkpoint_path'], "best_SASRec_model.pth")
        os.makedirs(self.config['checkpoint_path'], exist_ok=True)
        best_metric = -float("inf")
        checkpoint_saved = False

        for epoch in range(1, self.epoch_num + 1):
            self.model.train()
            running = 0.0
            n_steps = 0

            for batch in train_loader:
                loss = batch_loss(batch)
                if loss is None:
                    continue  # skip batch with no valid positions

                self.optimizer.zero_grad(set_to_none=True)
                loss.backward()
                if grad_clip is not None:
                    nn.utils.clip_grad_norm_(self.model.parameters(), grad_clip)
                self.optimizer.step()

                running += loss.item()
                n_steps += 1

            avg_loss = running / max(1, n_steps)
            print(f"Epoch {epoch}/{self.epoch_num} - train loss: {avg_loss:.4f}")

            # Default selection metric; replaced below when validation ran.
            metric = -avg_loss

            if val_loader is not None:
                self.model.eval()
                val_total = 0.0
                n_val_steps = 0
                with torch.no_grad():
                    for batch in val_loader:
                        loss = batch_loss(batch)
                        if loss is None:
                            continue
                        val_total += loss.item()
                        n_val_steps += 1
                # An empty validation pass would otherwise score a perfect 0.0.
                if n_val_steps > 0:
                    val_loss = val_total / n_val_steps
                    metric = -val_loss
                    print(f"Epoch {epoch}, Val Loss: {val_loss:.4f}")

            # Save best model
            if metric > best_metric:
                best_metric = metric
                torch.save({
                    "epoch": epoch,
                    "model_state_dict": self.model.state_dict(),
                    "metric": metric,
                }, ckpt_file)
                checkpoint_saved = True
                print(f"Best model updated at epoch {epoch}, saved to {ckpt_file}")

        # Reload the best weights, but never a stale file from an earlier run.
        if checkpoint_saved:
            checkpoint = torch.load(ckpt_file, map_location=self.device)
            self.model.load_state_dict(checkpoint["model_state_dict"])


    def train_fm(self):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.config['lr'])
        self.criterion = nn.MSELoss()

        best_metric = -math.inf
        ckpt_file = os.path.join(self.config['checkpoint_path'], "best_FM_model.pth")
        os.makedirs(self.config['checkpoint_path'], exist_ok=True)

        # 🔹 Wrap datasets into DataLoaders
        train_loader = DataLoader(self.train_data, batch_size=self.config['batch_size'], shuffle=True)
        val_loader = None
        if self.val_data is not None:
           val_loader = DataLoader(self.val_data, batch_size=self.config['batch_size'], shuffle=False)


        for epoch in range(self.epoch_num):
           self.model.train()
           total_loss = 0

           for user, item, rating in train_loader:  # must be DataLoader
              user, item, rating = user.to(self.device), item.to(self.device), rating.float().to(self.device)

              self.optimizer.zero_grad()
              preds = self.model(user, item)
              loss = self.criterion(preds, rating)
              loss.backward()
              self.optimizer.step()

              total_loss += loss.item()

           avg_loss = total_loss / len(train_loader)
           print(f"Epoch {epoch+1}/{self.epoch_num}, Train Loss: {avg_loss:.4f}")

           metric = None
           if val_loader is not None:
              self.model.eval()
              with torch.no_grad():
                  val_loss = 0
                  for user, item, rating in val_loader:
                      user, item, rating = user.to(self.device), item.to(self.device), rating.to(self.device)
                      preds = self.model(user, item)
                      val_loss += self.criterion(preds, rating).item()
                  val_loss /= len(val_loader)

              metric = -val_loss
              print(f"Epoch {epoch+1}, Val Loss: {val_loss:.4f}")
           else:
              metric = -avg_loss  # fallback if no validation set

           # save best model
           if metric > best_metric:
              best_metric = metric
              torch.save({
                "epoch": epoch+1,
                "model_state_dict": self.model.state_dict(),
                "n_users": self.model.n_users,
                "n_items": self.model.n_items,
                "n_factors": self.model.n_factors,
                "metric": metric,
                }, ckpt_file)
              print(f"Best model updated at epoch {epoch+1}, saved to {ckpt_file}")

        checkpoint = torch.load(ckpt_file, map_location=self.device)
        self.model.load_state_dict(checkpoint["model_state_dict"])


    def train_multivae(self,
        weight_decay: float = 0.0,
        anneal_cap: float = 0.2,
        total_anneal_steps: int = 200000,
        patience: int = 100,
        verbose: bool = True,):
        """
        Train MultiVAE on self.train_matrix with KL annealing.

        When self.val_matrix is set, the validation loss is computed after
        every epoch; the best model is written to
        <checkpoint_path>/best_MultiVAE_model.pth, training stops early after
        `patience` epochs without improvement, and the best weights are
        reloaded at the end. Without validation data no checkpoint is written
        and the final weights are kept.

        Args:
            weight_decay: weight decay passed to Adam.
            anneal_cap: maximum beta for KL weighting.
            total_anneal_steps: number of optimization steps over which to
                ramp beta from 0 to anneal_cap; <= 0 uses anneal_cap directly.
            patience: epochs without validation improvement before stopping.
            verbose: if True, print per-epoch training/validation statistics.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Batches are moved to self.device below, so the model must be there too.
        self.model.to(self.device)
        use_amp = torch.cuda.is_available()

        batch_size = self.config['batch_size']
        dataset = InteractionDataset(self.train_matrix)
        loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=False)

        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.config['lr'], weight_decay=weight_decay)
        scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=10, gamma=0.5)
        update_count = 0
        best_val_loss = float("inf")
        checkpoint_saved = False

        # Build full checkpoint file path
        ckpt_file = os.path.join(self.config['checkpoint_path'], "best_MultiVAE_model.pth")
        os.makedirs(self.config['checkpoint_path'], exist_ok=True)

        wait = 0

        # AMP scaler (for mixed precision training)
        scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

        # validation dataset (create once, not per epoch)
        val_matrix = getattr(self, "val_matrix", None)
        val_loader, val_dataset = None, None
        if val_matrix is not None:
            val_dataset = InteractionDataset(val_matrix)
            val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

        def kl_weight(step):
            # Linear ramp of the KL weight, capped at anneal_cap.
            if total_anneal_steps > 0:
                return min(anneal_cap, step / total_anneal_steps)
            return anneal_cap

        # Defined up front so logging/validation work even if the loader is empty.
        anneal = kl_weight(update_count)

        for epoch in range(1, self.epoch_num + 1):
            self.model.train()
            epoch_loss, epoch_recon, epoch_kl, n_batches = 0.0, 0.0, 0.0, 0

            for batch in loader:
                batch = batch.to(self.device).float()
                assert batch.shape[1] == self.model.n_items, "Batch dimension mismatch!"

                anneal = kl_weight(update_count)

                self.optimizer.zero_grad()

                with torch.cuda.amp.autocast(enabled=use_amp):
                    logits, mu, logvar = self.model(batch, sample=True)
                    # clamp logvar for numerical stability
                    logvar = torch.clamp(logvar, min=-10, max=10)
                    loss, recon_l, kl_l = self.model.loss_function(logits, batch, mu, logvar, anneal)

                scaler.scale(loss).backward()
                # Gradients still carry the loss scale here; unscale them
                # first so max_norm applies to their true magnitude.
                scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=5.0)
                scaler.step(self.optimizer)
                scaler.update()

                epoch_loss += loss.item()
                epoch_recon += recon_l
                epoch_kl += kl_l
                update_count += 1
                n_batches += 1

            scheduler.step()
            denom = max(1, n_batches)  # avoid ZeroDivisionError on an empty loader
            avg_train_elbo = epoch_loss / denom

            if verbose:
                print(f"[Epoch {epoch}] Train ELBO: {avg_train_elbo:.4f} | "
                      f"Recon: {epoch_recon / denom:.4f} | KL: {epoch_kl / denom:.4f} | "
                      f"Anneal: {anneal:.4f}")

            # ---------- Validation ----------
            if val_loader is not None:
                self.model.eval()
                val_losses = []
                with torch.no_grad():
                    for vb in val_loader:
                        vb = vb.to(self.device).float()
                        logits, mu, logvar = self.model(vb, sample=False)
                        logvar = torch.clamp(logvar, min=-10, max=10)
                        vloss, vrec, vkl = self.model.loss_function(logits, vb, mu, logvar, anneal)
                        # weight by batch size to get a per-row average below
                        val_losses.append(vloss.item() * len(vb))

                val_loss = float(np.sum(val_losses) / max(1, len(val_dataset)))

                if verbose:
                    print(f"  -> Val ELBO: {val_loss:.4f}")

                # save best with early stopping
                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    wait = 0
                    torch.save({
                        "epoch": epoch,
                        "model_state_dict": self.model.state_dict(),
                        "optimizer": self.optimizer.state_dict(),
                        "n_users": self.data.get_user_num(),
                        "n_items": self.data.get_item_num(),
                        "hidden_dim": self.config["vae_hidden_dim"],
                        "latent_dim": self.config["vae_latent_dim"],
                        "config": self.config,
                    }, ckpt_file)
                    checkpoint_saved = True
                    self.logger.info(f"Best model updated at epoch {epoch}, saved to {ckpt_file}")
                else:
                    wait += 1
                    if wait >= patience:
                        print(f"Early stopping triggered at epoch {epoch}")
                        break

        # Reload the best model only if this run wrote one; otherwise the
        # file is missing or stale and the final weights are kept.
        if checkpoint_saved:
            checkpoint = torch.load(ckpt_file, map_location=self.device)
            self.model.load_state_dict(checkpoint['model_state_dict'])

    def save_model(self, path):
        torch.save(self.model.state_dict(), path)

    def load_model(self, path):
        self.model.load_state_dict(torch.load(path))

    def swap_items(self, lst, page_size, random_k):
        """Swap the tail slots of each front page with its mirrored back page.

        The list is cut to a whole number of pages. Page p counted from the
        front is paired with page p counted from the back, and the last
        `random_k` slots of the two pages are exchanged position by position.
        With an odd page count the middle page is left untouched.

        Both page ends are exclusive here. Treating the front end as an
        inclusive index shifted the front slots by one and, for
        random_k == page_size, produced index -1 on the first page, which
        wraps around to the end of the list.

        Args:
            lst: ranked item list. NOTE(review): assumes slicing returns a
                copy, as for a Python list - confirm for other sequence types.
            page_size: number of items per page.
            random_k: slots to swap per page; capped at page_size so a swap
                never reaches into a neighbouring page.
        Returns:
            The truncated list with the swaps applied.
        """
        total_pages = len(lst) // page_size
        lst = lst[: total_pages * page_size]
        swaps_per_page = min(random_k, page_size)

        for page in range(total_pages // 2):
            front_end = (page + 1) * page_size            # exclusive end of the front page
            back_end = (total_pages - page) * page_size   # exclusive end of its mirror page

            for k in range(1, swaps_per_page + 1):
                lst[front_end - k], lst[back_end - k] = (
                    lst[back_end - k],
                    lst[front_end - k],
                )

        return lst

    def add_random_items(self, user, item_ids):
        item_ids = self.swap_items(item_ids, self.page_size, self.random_k)
        return item_ids

    def ordered_probit_loglik(self, y_true, y_pred_int, K=5, taus=None):
        """
        Compute log-likelihood for ordered probit model given integer predictions.

        y_true : list or array
           True ratings (1..K).
        y_pred_int : list or array
           Predicted integer ratings (1..K).
        K : int
           Number of rating categories (default 5).
        taus : list or array, optional
           Thresholds (default: equally spaced).
        """

        assert len(y_true) == len(y_pred_int), "Mismatch in true vs predicted length"
        ll = 0.0
        for t, p in zip(y_true, y_pred_int):
           probs = self.data.ordered_probit_probs(p, K, taus)
           ll += self.data.safe_log(probs[t-1])  # subtract 1 for 0-based index
        avg_ll = ll / len(y_true)
        return ll, avg_ll

    def update_user_interactions(self, user_id, new_items):
        """
        Updates the directory of user_id and interacted_items.
        - interaction_dict: dict mapping user_id -> set of interacted item ids
        - user_id: int or str
        - new_items: iterable of item ids (list, set, etc)

        After calling, interaction_dict[user_id] contains all unique interacted items.
        """
        # Ensure the user's interaction set exists
        if user_id not in self.interaction_dict:
          self.interaction_dict[user_id] = set()

        new_items = set(new_items) - self.interaction_dict[user_id]
        self.interaction_dict[user_id].update(new_items)

    def get_full_manual_items(self, user_id, gt_ratio, rd_ratio, total_items, read = None, heard = None):
        """
        Get a list of manual items for a given user.
        """
        gtruth_items = self.data.interrating[user_id]
        gt_items = [item for item, rating in gtruth_items]
        rd_items = self.data.get_full_items()

        # Remove any gt_items from rd_items to avoid duplicates
        rd_items = list(set(rd_items) - set(gt_items))

        # 1. Determine counts
        total_ratio = gt_ratio + rd_ratio
        gt_count = round(total_items * gt_ratio / total_ratio)
        rd_count = total_items - gt_count

        # Make sure we don't try to sample more than available
        gt_count = min(gt_count, len(gt_items))
        rd_count = min(rd_count, len(rd_items))

        # 2. Randomly sample
        chosen_gt = random.sample(gt_items, gt_count) if gt_count > 0 else []
        chosen_rd = random.sample(rd_items, rd_count) if rd_count > 0 else []

        # 3. Combine and shuffle if desired
        final_items = chosen_gt + chosen_rd

        # items discriptions
        sorted_item_names = self.data.get_item_names(final_items)
        description = self.data.get_item_description_by_id(final_items)
        eb_item = [
            sorted_item_names[i]
            + ";;"
            + description[i]
            + ";; Genre: "
            + self.data.get_genres_by_id([final_items[i]])[0]
            for i in range(len(sorted_item_names))
        ]
        return final_items, eb_item, chosen_gt

    def get_full_sort_items(self, user, random=False):
        """
        Get a list of sorted items for a given user.
        """
        items = self.data.get_full_items()
        user_tensor = torch.tensor(user)
        items_tensor = torch.tensor(items)
        sorted_items = self.model.get_full_sort_items(user_tensor, items_tensor)
        if self.random_k > 0 and random == True:
            sorted_items = self.add_random_items(user, sorted_items)
        sorted_items = [item for item in sorted_items if item not in self.record[user]]
        sorted_item_names = self.data.get_item_names(sorted_items)
        description = self.data.get_item_description_by_id(sorted_items)
        items = [
            sorted_item_names[i]
            + ";;"
            + description[i]
            + ";; Genre: "
            + self.data.get_genres_by_id([sorted_items[i]])[0]
            for i in range(len(sorted_item_names))
        ]
        return sorted_items, items

    def get_item(self, idx):
        item_name = self.data.get_item_names([idx])[0]
        description = self.data.get_item_description_by_id([idx])[0]
        item = item_name + ";;" + description
        return item

    def get_search_items(self, item_name):
        return self.data.search_items(item_name)

    def get_inter_num(self):
        return self.inter_num

    def update_history_by_name(self, user_id, item_names):
        """
        Update the history of a given user.
        """
        item_names = [item_name.strip(" <>'\"") for item_name in item_names]
        item_ids = self.data.get_item_ids(item_names)
        self.record[user_id].extend(item_ids)

    def update_history_by_id(self, user_id, item_ids):
        """
        Update the history of a given user.
        """
        self.record[user_id].extend(item_ids)

    def update_positive(self, user_id, item_names):
        """
        Update the positive history of a given user.
        """
        item_ids = self.data.get_item_ids(item_names)
        if len(item_ids) == 0:
            return
        self.positive[user_id].extend(item_ids)
        self.inter_num += len(item_ids)

    def update_positive_by_id(self, user_id, item_id):
        """
        Update the history of a given user.
        """
        self.positive[user_id].append(item_id)

    def save_interaction(self):
        """
        Save the interaction history to a csv file.
        """
        inters = []
        users = self.data.get_full_users()
        for user in users:
            for item in self.positive[user]:
                new_row = {"user_id": user, "item_id": item, "rating": 1}
                inters.append(new_row)

            for item in self.record[user]:
                if item in self.positive[user]:
                    continue
                new_row = {"user_id": user, "item_id": item, "rating": 0}
                inters.append(new_row)

        df = pd.DataFrame(inters)
        df.to_csv(
            self.config["interaction_path"],
            index=False,
        )

        self.inter_df = df

    def add_train_data(self, user, item, label):
        self.train_data.append((user, item, label))

    def clear_train_data(self):
        self.train_data = []

    def add_user(self, user_id, N_expose, N_view, N_like, N_exit, S_sat):
        self.user_data["user"].append(user_id)
        self.user_data["N_expose"].append(N_expose)
        self.user_data["N_view"].append(N_view)
        self.user_data["N_like"].append(N_like)
        self.user_data["N_exit"].append(N_exit)
        self.user_data["S_sat"].append(S_sat)

    def add_review(self, user_id, rating, feelings):
        self.rating_feeling["User"].append(user_id)
        self.rating_feeling["Rating"].append(rating)
        self.rating_feeling["Feelings"].append(feelings)

    def satisfaction_metrics(self):
        sm_df = pd.DataFrame(self.user_data)
        if len(sm_df) == 0:
           return None  # no data yet

        metrics = {}
        sm_df["view_ratio"] = sm_df["N_view"] / sm_df["N_expose"]
        sm_df["like_ratio"] = sm_df["N_like"] / sm_df["N_expose"]

        metrics["P_view"] = sm_df["view_ratio"].mean()
        metrics["N_like"] = sm_df["N_like"].mean()
        metrics["P_like"] = sm_df["like_ratio"].mean()
        metrics["N_exit"] = sm_df["N_exit"].mean()
        metrics["S_sat"] = sm_df["S_sat"].mean()

        return metrics

    def get_entropy(
        self,
    ):
        """Return calculate_entropy() of each user's recorded genres, averaged over all users in self.record."""
        entropy_sum = 0
        for shown_items in self.record.values():
            shown_genres = self.data.get_genres_by_id(shown_items)
            entropy_sum += calculate_entropy(shown_genres)

        return entropy_sum / len(self.record)

    def check_train_data(self):
        """
        Print or inspect the training data.
        """
        print("Training Data:")
        for user, item, label in self.train_data:
            print(f"User: {user}, Item: {item}, Label: {label}")

    def create_train_data(self):
        """
        Rebuild the train/validation/test splits from self.data.interrating.

        Every (item, rating) pair of every user becomes a
        (user, item, float(rating)) triple. The triples are split 80/10/10
        with a fixed random_state, and a dense float32 rating matrix is built
        per split (self.train_matrix, self.val_matrix, self.test_matrix).
        Each matrix has one row per user index up to the largest one in that
        split and one column per item reported by self.data.get_item_num().
        """
        self.clear_train_data()  # start from an empty training list
        for user, interactions in self.data.interrating.items():
            for item, rating in interactions:
                self.add_train_data(user, item, float(rating))  # keep exact rating

        # 80% train, 20% held out ...
        self.train_data, self.temp_data = train_test_split(self.train_data, test_size=0.2, random_state=2025)
        # ... and the held-out part halved into validation and test.
        self.val_data, self.test_data = train_test_split(self.temp_data, test_size=0.5, random_state=2025)

        splits = (self.train_data, self.val_data, self.test_data)
        row_counts = [max(u for u, _i, _r in split) + 1 for split in splits]
        n_items_global = int(self.data.get_item_num())

        # Initialize the user-item matrices
        self.train_matrix, self.val_matrix, self.test_matrix = (
            np.zeros((n_rows, n_items_global), dtype=np.float32) for n_rows in row_counts
        )

        # Fill interactions, skipping item indices beyond the catalogue size
        matrices = (self.train_matrix, self.val_matrix, self.test_matrix)
        for triples, matrix in zip(splits, matrices):
            for u, i, r in triples:
                if i >= n_items_global:
                    continue
                matrix[u, i] = r



    def calculate_user_metrics(
        self, user_id, sim_recommended, all_items, threshold = 3):
        """
        Compute precision, recall, accuracy and F1 of a recommendation set for one user.

        The ground truth is the set of items in ``self.data.interrating[user_id]``
        whose rating is at least ``threshold`` and that belong to ``all_items``.

        Args:
            user_id: key into ``self.data.interrating``.
            sim_recommended: iterable of recommended item ids (duplicates are ignored).
            all_items: iterable of item ids forming the evaluation universe.
            threshold: minimum rating for an item to count as relevant.

        Returns:
            tuple: ``(precision, recall, accuracy, f1)``. All four are ``0.0``
            when the user has no recorded ratings, so the result can always be
            unpacked the same way.
        """
        if user_id not in self.data.interrating:
            # Same 4-tuple shape as the normal path. The previous dict return
            # unpacked into its key strings instead of numbers.
            return 0.0, 0.0, 0.0, 0.0

        # Build the sets first so the membership test below is O(1) per item.
        all_items = set(all_items)
        sim_recommended = set(sim_recommended)
        gt_relevant = {
            item
            for item, rating in self.data.interrating[user_id]
            if rating >= threshold and item in all_items
        }

        TP = len(gt_relevant & sim_recommended)
        FP = len(sim_recommended - gt_relevant)
        FN = len(gt_relevant - sim_recommended)
        TN = len(all_items - (gt_relevant | sim_recommended))

        precision = TP / (TP + FP) if (TP + FP) else 0.0
        recall = TP / (TP + FN) if (TP + FN) else 0.0
        f1 = (2 * precision * recall) / (precision + recall) if (precision + recall) else 0.0
        accuracy = (TP + TN) / len(all_items) if all_items else 0

        print("precision:", precision, "recall:", recall, "accuracy:", accuracy, "f1:", f1)
        return precision, recall, accuracy, f1

    def precisionandrecallk(
        self, user_id, recommended, k):
        """
        Compute precision@k and recall@k for one user.

        Relevant items are those in ``self.data.interrating[user_id]`` rated 3
        or higher. ``recommended`` is de-duplicated (first occurrence kept)
        before the top ``k`` entries are taken.

        Args:
            user_id: key into ``self.data.interrating``.
            recommended: ordered iterable of recommended item ids.
            k: cut-off rank; a non-positive ``k`` yields zero hits and zero precision.

        Returns:
            tuple: ``(precision_at_k, recall_at_k)``; ``(0.0, 0.0)`` for a user
            without recorded ratings.
        """
        if user_id not in self.data.interrating:
            # Same tuple shape as the normal path so callers can always unpack it.
            return 0.0, 0.0

        sim_recommended = list(dict.fromkeys(recommended))  # order-preserving de-dup
        gt_relevant = {
            item for item, rating in self.data.interrating[user_id] if rating >= 3
        }

        # Guard k <= 0: avoids ZeroDivisionError and a negative slice bound.
        recommended_at_k = sim_recommended[:k] if k > 0 else []
        hits = sum(1 for item in recommended_at_k if item in gt_relevant)

        precision_at_k = hits / k if k > 0 else 0.0
        recall_at_k = hits / len(gt_relevant) if gt_relevant else 0
        return precision_at_k, recall_at_k


    def calculation_of_rating(self, user_id, item_names, book_rating):
        """
        Pair a user's recorded rating for one item with a supplied rating.

        ``item_names`` is resolved to an id through ``self.data.get_item_ids``.
        The first matching entry in ``self.data.interrating[user_id]`` provides
        the recorded rating.

        Returns:
            tuple: ``(recorded_rating, book_rating)``; ``recorded_rating`` is 0
            when the user is unknown or has no entry for the item.
        """
        item_ids = self.data.get_item_ids([item_names])

        if user_id not in self.data.interrating:
            return (0, book_rating)

        # Lazily scan the user's (item, rating) pairs and stop at the first hit.
        matching_ratings = (
            rating
            for (itm, rating) in self.data.interrating[user_id]
            if itm == item_ids[0]
        )
        return (next(matching_ratings, 0), book_rating)


    def calc_mse_rmse_rating_percentages(self, rating_pairs):
        """
        Summarise how well predicted ratings agree with ground-truth ratings.

        Pairs whose ground-truth rating truncates to 0 are discarded; the
        predicted rating is not looked at by this filter. The remaining
        ratings are truncated to ``int`` before any statistic is computed.

        Args:
            rating_pairs: sliceable sequence of ``(ground_truth, predicted)`` pairs.

        Returns:
            tuple: ``(mse, rmse, gt_pct, pred_pct, loglike, ob_loglike, rho)``.
            ``gt_pct`` and ``pred_pct`` map each rating 1..5 to its percentage
            share, ``loglike`` and ``ob_loglike`` come from
            ``self.ordered_probit_loglik`` and ``rho`` is the Spearman
            correlation. When no pair survives the filter the result is
            ``(None, None, {}, {}, None, None, None)``.
        """
        print("Incoming rating_pairs:", rating_pairs[:20])  # show first 20 pairs
        print("Total pairs:", len(rating_pairs))

        # Keep only pairs with a non-zero ground-truth rating.
        filtered_pairs = []
        for gt, pred in rating_pairs:
            if int(gt) != 0:
                filtered_pairs.append((gt, pred))

        print("After filtering:", filtered_pairs[:20])
        print("Remaining pairs:", len(filtered_pairs))

        if len(filtered_pairs) == 0:
            # Nothing left to evaluate.
            return None, None, {}, {}, None, None, None

        # Integer views of both rating columns.
        gt = [int(truth) for truth, _ in filtered_pairs]
        pred = [int(guess) for _, guess in filtered_pairs]

        squared_errors = [(g - p) ** 2 for g, p in zip(gt, pred)]
        mse = np.mean(squared_errors)
        rmse = np.sqrt(mse)
        loglike, ob_loglike = self.ordered_probit_loglik(gt, pred)
        rho, _p_value = spearmanr(gt, pred)

        total = len(filtered_pairs)

        def percentages(values):
            # Share (in %) of each rating 1..5 among `values`.
            counts = Counter(values)
            return {r: counts.get(r, 0) / total * 100 for r in range(1, 6)}

        gt_pct = percentages(gt)
        pred_pct = percentages(pred)
        return mse, rmse, gt_pct, pred_pct, loglike, ob_loglike, rho


    def test_recommendations(self, user_id):
        # Get the full list of items
        all_items = self.data.get_full_items()

        # Convert the user ID to tensor
        user_tensor = torch.tensor(user_id)

        # Convert all items to tensor
        items_tensor = torch.tensor(all_items)

        # Get sorted items based on the model's prediction
        sorted_items = self.model.get_full_sort_items(user_tensor, items_tensor)

        # Filter out items that are already in the user's history
        recommended_items = [item for item in sorted_items if item not in self.record[user_id]]

        # Return the recommended items
        return recommended_items

    def evaluate(self, dataset):
        """
        Return the loss of ``self.model`` over ``dataset`` in a single batch.

        Args:
            dataset: sequence of rows whose first three entries are
                ``(user, item, label)``.

        Returns:
            float: ``self.criterion(model(users, items), labels)`` as a Python
            number. The model is switched to eval mode and left there.
        """
        def column(idx):
            # idx-th field of every row, as a plain list.
            return [row[idx] for row in dataset]

        self.model.eval()
        users = torch.tensor(column(0))
        items = torch.tensor(column(1))
        labels = torch.tensor(column(2)).float()

        with torch.no_grad():
            loss = self.criterion(self.model(users, items), labels)
        return loss.item()

    def load_checkpoint(self, path="best_model.pth", resume_training=False):
        """
        Restore model weights (and optionally optimizer state) from a checkpoint.

        Args:
            path: checkpoint file holding ``"model_state_dict"`` and, when
                resuming, ``"optimizer_state_dict"`` and ``"epoch"``.
            resume_training: when True the optimizer state is restored as well
                and the stored epoch is returned; otherwise the model is put in
                eval mode.

        Returns:
            The stored epoch when ``resume_training`` is True, else ``None``.
        """
        state = torch.load(path)
        self.model.load_state_dict(state["model_state_dict"])

        if not resume_training:
            self.model.eval()  # inference only
            return None

        # Resume: bring the optimizer back to where training stopped.
        self.optimizer.load_state_dict(state["optimizer_state_dict"])
        start_epoch = state["epoch"]
        self.logger.info(f"Resuming training from epoch {start_epoch}")
        return start_epoch

    def train_mf(self):
        """
        Train ``self.model`` on ``self.train_data`` and keep the best checkpoint.

        Each epoch runs one pass over the shuffled training triples, evaluates
        ``self.val_data`` through ``self.evaluate`` and, whenever the validation
        loss improves, saves model and optimizer state to
        ``<checkpoint_path>/best_MF_model.pth``. After the last epoch the best
        weights saved during this run are loaded back into the model.

        Prints a message and returns immediately when there is no training data.
        """
        if len(self.train_data) == 0:
            print("No training data!")
            return

        users = [x[0] for x in self.train_data]
        items = [x[1] for x in self.train_data]
        labels = [x[2] for x in self.train_data]

        dataset = torch.utils.data.TensorDataset(
            torch.tensor(users), torch.tensor(items), torch.tensor(labels))

        train_loader = torch.utils.data.DataLoader(
            dataset, batch_size=self.config["batch_size"], shuffle=True)

        self.model.train()

        best_val_loss = float("inf")
        saved_checkpoint = False  # True once this run has written ckpt_file

        # Build full checkpoint file path
        ckpt_file = os.path.join(self.config['checkpoint_path'], "best_MF_model.pth")
        os.makedirs(self.config['checkpoint_path'], exist_ok=True)

        for epoch in range(self.epoch_num):
            # self.evaluate() leaves the model in eval mode, so training mode
            # has to be restored at the start of every epoch.
            self.model.train()
            epoch_loss = 0.0

            for user, item, label in train_loader:
                self.optimizer.zero_grad()
                outputs = self.model(user, item)
                loss = self.criterion(outputs, label.float())
                loss.backward()
                self.optimizer.step()
                epoch_loss += loss.item()

            val_loss = self.evaluate(self.val_data)  # Evaluate on validation set

            self.logger.info(
                f"Epoch {epoch+1}/{self.epoch_num}, Train Loss: {epoch_loss/len(train_loader):.4f}, "
                f"Val Loss: {val_loss:.4f}")

            # Save checkpoint if validation improves
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                torch.save({
                    "epoch": epoch + 1,
                    "model_state_dict": self.model.state_dict(),
                    "optimizer_state_dict": self.optimizer.state_dict(),
                    "val_loss": val_loss,
                }, ckpt_file)
                saved_checkpoint = True
                self.logger.info(f"Best model updated at epoch {epoch+1}, saved to {ckpt_file}")

        # At the end, reload the best weights for inference. Only do so when this
        # run actually wrote the file; otherwise the load would fail or silently
        # pull in a stale checkpoint from an earlier run.
        if saved_checkpoint:
            checkpoint = torch.load(ckpt_file)
            self.model.load_state_dict(checkpoint['model_state_dict'])
        else:
            self.logger.info("No checkpoint was saved during this run; keeping current weights")

    def load_best_model(self):
        """
        Load the best saved weights for the configured recommender into ``self.model``.

        The checkpoint file name is chosen from ``self.config['rec_model']`` and
        looked up under ``self.config['checkpoint_path']``. After loading, the
        model is switched to eval mode.

        Raises:
            ValueError: if ``self.config['rec_model']`` is not a known model type.
        """
        checkpoint_names = {
            'MF': "best_MF_model.pth",
            'MultiVAE': "best_MultiVAE_model.pth",
            'LightGCN': "best_lightGCN_model.pth",
            'FM': "best_FM_model.pth",
            'SASRec': "best_SASRec_model.pth",
        }

        file_name = checkpoint_names.get(self.config['rec_model'])
        if file_name is None:
            raise ValueError(f"Unknown model type: {self.config['rec_model']}")

        # Build full checkpoint file path
        ckpt_file = os.path.join(self.config['checkpoint_path'], file_name)

        checkpoint = torch.load(ckpt_file)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model.eval()
        self.logger.info(f"Loaded best model from {ckpt_file}")

    def get_rec_discription(self, final_items):
        """
        Build one ``"name;;description;; Genre: genre"`` string per recommended item.

        Names and descriptions are fetched in bulk from ``self.data``; the genre
        is looked up item by item through ``self.data.get_genres_by_id``.

        Args:
            final_items: indexable sequence of item ids.

        Returns:
            list: the assembled strings, in the order of the returned item names.
        """
        names = self.data.get_item_names(final_items)
        descriptions = self.data.get_item_description_by_id(final_items)

        eb_item = []
        for pos, name in enumerate(names):
            text = descriptions[pos]
            genre = self.data.get_genres_by_id([final_items[pos]])[0]
            eb_item.append("".join((name, ";;", text, ";; Genre: ", genre)))
        return eb_item

    def get_full_rankings(self, use_test=False, batch_size=512):
        """
        Compute a full item ranking for every user and store it in ``self.full_rankings``.

        For the configured ``self.config['rec_model']`` ('MF', 'LightGCN',
        'MultiVAE', 'FM' or 'SASRec') every item is scored per user, items the
        user already has in ``self.train_data`` get a score of ``-inf`` so they
        sort last, and items are ordered by descending score. Any other model
        name leaves the object untouched.

        Args:
            use_test: when True, the user's ``self.test_data`` items are promoted
                to the head of the ranking. For MF/LightGCN/MultiVAE/FM the k-th
                test item is swapped into position k. For SASRec at most five
                test items (in score order) are placed first and the remaining
                items keep their score order. For MultiVAE it also selects
                ``self.test_matrix`` (instead of ``self.train_matrix``) as the
                source of the ranking's shape.
            batch_size: number of users scored per forward pass (MultiVAE and
                SASRec only).

        Returns:
            None. ``self.full_rankings`` becomes an int array with one row per user.
        """
        rec_model = self.config['rec_model']
        if rec_model not in ('MF', 'LightGCN', 'MultiVAE', 'FM', 'SASRec'):
            return

        def group_items_by_user(triples):
            # Single pass over the data instead of one full scan per user.
            grouped = {}
            for row in triples:
                grouped.setdefault(row[0], []).append(row[1])
            return grouped

        train_by_user = group_items_by_user(self.train_data)
        test_by_user = group_items_by_user(self.test_data) if use_test else {}

        def train_items_of(user, n_items):
            # Training items of `user` that are valid column indices.
            return [item for item in train_by_user.get(user, ()) if item < n_items]

        def promote_test_items(ranking, user, n_items):
            # Swap the k-th test item of `user` into position k, in place.
            test_items = [item for item in test_by_user.get(user, ()) if item < n_items]
            for idx, item in enumerate(test_items):
                positions = np.where(ranking == item)[0]
                if positions.size:
                    current_pos = positions[0]
                    ranking[idx], ranking[current_pos] = ranking[current_pos], ranking[idx]

        def store_ranking(user, scores, n_items):
            # Mask seen items, sort by descending score, optionally promote test items.
            scores[train_items_of(user, n_items)] = -np.inf
            ranking = np.argsort(-scores)
            if use_test:
                promote_test_items(ranking, user, n_items)
            self.full_rankings[user] = ranking

        if rec_model == 'MF':
            n_users = self.data.get_user_num()
            n_items = self.data.get_item_num()

            self.full_rankings = np.zeros((n_users, n_items), dtype=int)

            # no_grad + .cpu() keep the numpy conversion valid for models that
            # live on an accelerator or whose embeddings require gradients.
            with torch.no_grad():
                item_embed = self.model.item_embedding.weight[:n_items, :]
                device = item_embed.device

                for user in range(n_users):
                    user_tensor = torch.tensor([user], device=device)
                    user_embed = self.model.user_embedding(user_tensor)
                    scores = torch.matmul(user_embed, item_embed.T).squeeze(0).cpu().numpy()
                    store_ranking(user, scores, n_items)

        elif rec_model == 'LightGCN':
            n_users = self.data.get_user_num()
            n_items = self.data.get_item_num()

            self.model.eval()
            with torch.no_grad():
                # Embeddings for all users and all items from the propagation step.
                all_user_emb, all_item_emb = self.model.propagate()

                self.full_rankings = np.zeros((n_users, n_items), dtype=int)

                for user in range(n_users):
                    user_embed = all_user_emb[user].unsqueeze(0)
                    # Dot product of the user with every item.
                    scores = torch.matmul(user_embed, all_item_emb.T).squeeze(0).cpu().numpy()
                    store_ranking(user, scores, n_items)

        elif rec_model == 'MultiVAE':
            matrix = self.test_matrix if use_test else self.train_matrix
            n_users, n_items = matrix.shape

            self.model.eval()
            self.full_rankings = np.zeros((n_users, n_items), dtype=int)

            n_train_rows, width = self.train_matrix.shape

            with torch.no_grad():
                for start in range(0, n_users, batch_size):
                    end = min(start + batch_size, n_users)

                    # Model input is always the users' training interactions.
                    # Users beyond the last row of train_matrix (possible when
                    # use_test selects a taller test_matrix) get an all-zero
                    # vector instead of raising IndexError.
                    batch = np.zeros((end - start, width), dtype=np.float32)
                    known = self.train_matrix[start:min(end, n_train_rows)]
                    batch[:len(known)] = known
                    batch_users = torch.tensor(batch, dtype=torch.float32).to(self.model.device)

                    logits, _mu, _logvar = self.model(batch_users)
                    scores = logits.cpu().numpy()

                    for i, u in enumerate(range(start, end)):
                        store_ranking(u, scores[i], n_items)

        elif rec_model == 'FM':
            n_users = self.data.get_user_num()
            n_items = self.data.get_item_num()

            self.full_rankings = np.zeros((n_users, n_items), dtype=int)

            self.model.eval()
            device = next(self.model.parameters()).device

            # The item id vector is the same for every user.
            item_ids = torch.arange(n_items, device=device)

            with torch.no_grad():
                for user in range(n_users):
                    user_ids = torch.full((n_items,), user, dtype=torch.long, device=device)
                    scores = self.model(user_ids, item_ids).cpu().numpy()
                    store_ranking(user, scores, n_items)

        else:  # 'SASRec'
            n_users = self.data.get_user_num()
            n_items = self.data.get_item_num()

            self.full_rankings = np.zeros((n_users, n_items), dtype=int)
            self.model.eval()
            device = next(self.model.parameters()).device

            with torch.no_grad():
                # Item embeddings do not change between batches.
                all_item_emb = self.model.item_embedding.weight[:n_items, :]

                for start in range(0, n_users, batch_size):
                    end = min(start + batch_size, n_users)
                    batch_users = list(range(start, end))

                    # One padded training-item sequence per user in the batch.
                    batch_seqs = [
                        _pad_sequence(list(train_by_user.get(u, ())), self.model.max_seq_len)
                        for u in batch_users
                    ]
                    batch_seqs = torch.tensor(batch_seqs, dtype=torch.long, device=device)

                    seq_out = self.model(batch_seqs)
                    seq_out_last = seq_out[:, -1, :]  # representation at the last position

                    scores = torch.matmul(seq_out_last, all_item_emb.T).cpu().numpy()

                    for i, u in enumerate(batch_users):
                        scores[i, train_items_of(u, n_items)] = -np.inf  # push train items to the end
                        ranking = np.argsort(-scores[i])

                        if use_test:
                            test_items = {
                                item for item in test_by_user.get(u, ()) if item < n_items
                            }
                            # At most five test items, in score order, go first.
                            top_test_items = [item for item in ranking if item in test_items][:5]
                            forced = set(top_test_items)
                            other_items = [item for item in ranking if item not in forced]
                            ranking = np.array(top_test_items + other_items)

                        self.full_rankings[u] = ranking

In [None]:
from argparse import Namespace
import torch


class BaseRLEnvironment():
    """
    Minimal base class for simulated recommendation environments.

    Holds the shared hyperparameters (``device``, ``max_step_per_episode``,
    ``initial_temper``), declares the ``reset``/``step`` interface for
    subclasses, and provides helpers that rebuild a logged reader or user
    model and turn a reader batch into an observation dict.
    """

    @staticmethod
    def parse_model_args(parser):
        '''
        args:
        - max_step_per_episode
        - initial_temper
        '''
        parser.add_argument('--max_step_per_episode', type=int, default=100, help='max number of iteration allowed in each episode')
        parser.add_argument('--initial_temper', type=float, default=10, help='initial temper of users')
        return parser


    def __init__(self, args):
        '''
        @input:
        - args: namespace with device, max_step_per_episode and initial_temper
        '''
        self.device = args.device
        super().__init__()
        self.max_step_per_episode = args.max_step_per_episode
        self.initial_temper = args.initial_temper

    def reset(self, params):
        '''Interface placeholder; does nothing here.'''
        pass

    def step(self, action):
        '''Interface placeholder; does nothing here.'''
        pass


    def get_user_model(self, log_path, device, from_load = True):
        '''
        Rebuild the user model described by a training log.

        @input:
        - log_path: text file whose first two lines are Namespace reprs
                    (class selection, then model arguments)
        - device: device the checkpoint is mapped to and the model is moved to
        - from_load: when True, model weights are loaded from the checkpoint
        @output:
        - reader_stats: checkpoint["reader_stats"]
        - model: the constructed model on `device`
        - model_args: the evaluated second log line
        '''
        # SECURITY: eval() executes whatever the log file contains; only use
        # this on log files from a trusted source.
        # The context manager closes the file even if eval() raises.
        with open(log_path, 'r') as infile:
            class_args = eval(infile.readline()) # example: Namespace(model='KRMBUserResponse', reader='KRMBSeqReader')
            model_args = eval(infile.readline()) # model parameters in Namespace
        checkpoint = torch.load(model_args.model_path + ".checkpoint", map_location=device)
        reader_stats = checkpoint["reader_stats"]
        # The model class is resolved as <name>.<name>, i.e. attribute <name> of an object called <name>.
        modelClass = eval('{0}.{0}'.format(class_args.model))
        model = modelClass(model_args, reader_stats, device)
        if from_load:
            model.load_from_checkpoint(model_args.model_path, with_optimizer = False)
        model = model.to(device)
        return reader_stats, model, model_args

    def get_reader(self, log_path):
        '''
        Rebuild the data reader described by a training log, with hold-out disabled.

        @input:
        - log_path: text file whose first two lines are Namespace reprs
                    (class selection, then training arguments)
        @output:
        - reader: instance of the class named by class_args.reader
        - training_args: the evaluated second log line, with
                         val_holdout_per_user = test_holdout_per_user = 0 and
                         device = self.device
        '''
        # SECURITY: eval() executes whatever the log file contains; only use
        # this on log files from a trusted source.
        # The context manager closes the file even if eval() raises.
        with open(log_path, 'r') as infile:
            class_args = eval(infile.readline()) # example: Namespace(model='KRMBUserResponse', reader='KRMBSeqReader')
            training_args = eval(infile.readline()) # model parameters in Namespace
        training_args.val_holdout_per_user = 0
        training_args.test_holdout_per_user = 0
        training_args.device = self.device
        # The reader class is resolved by bare name in this module's scope.
        readerClass = eval(class_args.reader)
        reader = readerClass(training_args)
        return reader, training_args


    def get_observation_from_batch(self, sample_batch):
        '''
        extract observation from the reader's sample batch
        @input:
        - sample_batch: {
            'user_id': (B,)
            'uf_{feature}': (B,F_dim(feature)), user features
            'history': (B,max_H)
            'history_length': (B,)
            'history_if_{feature}': (B, max_H * F_dim(feature))
            'history_{response}': (B, max_H)
            ... user ground truth feedbacks are not included as observation
        }
        @output:
        - observation: {'user_profile': {'user_id': (B,),
                                         'uf_{feature_name}': (B, feature_dim)},
                        'user_history': {'history': (B, max_H),
                                         'history_if_{feature_name}': (B, max_H, feature_dim),
                                         'history_{response}': (B, max_H),
                                         'history_length': (B, )}}
        '''
        sample_batch = wrap_batch(sample_batch, device = self.device)
        # Keys are matched by substring, exactly as before: any key containing
        # 'uf_' goes to the profile, any key containing 'history_' to the history.
        profile = {'user_id': sample_batch['user_id']}
        for k,v in sample_batch.items():
            if 'uf_' in k:
                profile[k] = v
        history = {'history': sample_batch['history']}
        for k,v in sample_batch.items():
            if 'history_' in k:
                history[k] = v
        return {'user_profile': profile, 'user_history': history}


In [None]:
import numpy as np
import torch
import random
from copy import deepcopy
from argparse import Namespace
from torch.utils.data import DataLoader
from torch.distributions import Categorical
import torch.nn.functional as F


class KREnvironment_WholeSession_GPU(BaseRLEnvironment):
    '''
    KuaiRand simulated environment for consecutive list-wise recommendation
    Main interface:
    - parse_model_args: for hyperparameters
    - reset: reset online environment, monitor, and corresponding initial observation
    - step: action --> new observation, user feedbacks, and other updated information
    - get_candidate_info: obtain the entire item candidate pool
    Main Components:
    - data reader: self.reader for user profile&history sampler
    - user immediate response model: see self.get_response
    - no user leave model: see self.get_leave_signal
    - candidate item pool: self.candidate_ids, self.candidate_item_meta
    - history monitor: self.env_history, not set up until self.reset
    '''
    @staticmethod
    def parse_model_args(parser):
        '''
        Register this environment's command-line options on `parser`.

        args:
        - uirm_log_path (required)
        - slate_size (default 6)
        - episode_batch_size (default 32)
        - item_correlation (default 0)
        - single_response (flag)
        - from BaseRLEnvironment
            - max_step_per_episode
            - initial_temper

        @output:
        - parser: the same parser, with the options added
        '''
        parser = BaseRLEnvironment.parse_model_args(parser)
        parser.add_argument('--uirm_log_path', type=str, required=True,
                            help='log path for pretrained user immediate response model')
        # `default=6`, not `required=6`: a truthy `required` made the option
        # mandatory and left it without a default value.
        parser.add_argument('--slate_size', type=int, default=6,
                            help='number of item per recommendation slate')
        parser.add_argument('--episode_batch_size', type=int, default=32,
                            help='episode sample batch size')
        parser.add_argument('--item_correlation', type=float, default=0,
                            help='magnitude of item correlation')
        parser.add_argument('--single_response', action='store_true',
                            help='only include the first feedback as reward signal')
        return parser


    def __init__(self, args):
        '''
        Build the simulated environment from a logged user response model.

        Reads the two-line log at args.uirm_log_path (each line is eval-ed),
        builds the reader and the user response model through
        self.get_reader / self.get_user_model, derives the response weights,
        and caches id, meta-feature and encoding tensors for the whole
        candidate item pool.

        @input:
        - args: namespace providing device, max_step_per_episode and
                initial_temper (consumed by BaseRLEnvironment.__init__) plus
                uirm_log_path, slate_size, episode_batch_size,
                item_correlation and single_response
        '''
        super(KREnvironment_WholeSession_GPU, self).__init__(args)
        self.uirm_log_path = args.uirm_log_path
        self.slate_size = args.slate_size
        self.episode_batch_size = args.episode_batch_size
        self.rho = args.item_correlation
        self.single_response = args.single_response

        # --- load logged model config ---
        # NOTE(review): eval() executes whatever the log file contains; only
        # use log files from a trusted source.
        with open(args.uirm_log_path, 'r') as infile:
            class_args = eval(infile.readline())
            model_args = eval(infile.readline())
        print("Environment arguments: \n" + str(model_args))
        # Only these reader names and a model name containing
        # 'KRMBUserResponse' are accepted. NOTE(review): an assert is skipped
        # when Python runs with -O.
        assert (class_args.reader in ['KRMBSeqReader', 'MLSeqReader']
                and 'KRMBUserResponse' in class_args.model)

        # --- load reader ---
        # get_reader re-reads the same log file and disables hold-out.
        print("Load user sequence reader")
        reader, reader_args = self.get_reader(args.uirm_log_path)
        self.reader = reader
        print(self.reader.get_statistics())

        # --- load user response model ---
        print("Load immediate user response model")
        uirm_stats, uirm_model, uirm_args = self.get_user_model(args.uirm_log_path, args.device)
        self.immediate_response_stats = uirm_stats
        self.immediate_response_model = uirm_model
        self.max_hist_len = uirm_stats['max_seq_len']
        self.response_types = uirm_stats['feedback_type']
        self.response_dim = len(self.response_types)
        # One weight per entry of reader.get_response_weights(), in that dict's order.
        self.response_weights = torch.tensor(
            list(self.reader.get_response_weights().values()),
            dtype=torch.float,
            device=args.device
        )
        if args.single_response:
            # Keep only the first weight (set to 1); all others become 0.
            self.response_weights = torch.zeros_like(self.response_weights)
            self.response_weights[0] = 1

        print("Setup candidate item pool")

        # (n_item,)
        # Each raw id in reader.items is mapped through reader.item_id_vocab.
        self.candidate_iids = torch.tensor(
            [reader.item_id_vocab[iid] for iid in reader.items],
            device=self.device
        )

        # build item meta from reader
        candidate_meta = [reader.get_item_meta_data(iid) for iid in reader.items]
        self.candidate_item_meta = {}
        self.n_candidate = len(candidate_meta)

        # NORMALIZE KEYS HERE: force them to start with "if_"
        # Every feature is concatenated over all items and reshaped to one row
        # per candidate: (n_candidate, -1).
        for k in candidate_meta[0]:
            key = k if k.startswith("if_") else f"if_{k}"
            self.candidate_item_meta[key] = torch.FloatTensor(
                np.concatenate([meta[k] for meta in candidate_meta])
            ).view(self.n_candidate, -1).to(self.device)

        # now we can safely call the model the same way the original code did
        # NOTE(review): the meaning of the third positional argument (1) is
        # defined by get_item_encoding, which is not visible here - confirm.
        item_enc, _ = self.immediate_response_model.get_item_encoding(
            self.candidate_iids,
            {k[3:]: v for k, v in self.candidate_item_meta.items()},  # strip "if_"
            1
        )
        # Flatten to rows of width enc_dim.
        self.candidate_item_encoding = item_enc.view(-1, self.immediate_response_model.enc_dim)

        # spaces
        self.gt_state_dim = self.immediate_response_model.state_dim
        self.action_dim = self.slate_size
        self.observation_space = self.reader.get_statistics()
        self.action_space = self.n_candidate

        # Make sure the response model lives on args.device and records it.
        self.immediate_response_model.to(args.device)
        self.immediate_response_model.device = args.device

    def get_candidate_info(self, feed_dict, all_item=True):
        '''
        Return ids and meta features of candidate items.

        @input:
        - feed_dict: only read when all_item is False; must hold 'item_id'
        - all_item: True for the whole candidate pool, False for the items
                    named in feed_dict['item_id']
        @output:
        - {'item_id': ..., 'if_{feature}': ...}; with all_item=False each
          feature is indexed at (item_id - 1), presumably because encoded item
          ids start at 1 - confirm against the reader's vocabulary
        '''
        if not all_item:
            item_ids = feed_dict['item_id']
            rows = item_ids - 1
            selected = {'item_id': item_ids}
            for key, values in self.candidate_item_meta.items():
                selected[key] = values[rows]
            return selected

        # Whole pool: ids plus every "if_..." feature tensor.
        return {'item_id': self.candidate_iids, **self.candidate_item_meta}

    def reset(self, params={'empty_history': True}):
        '''
        Start a new batch of episodes and return the initial observation.

        @input:
        - params: dict-like options; 'batch_size' overrides
                  self.episode_batch_size for the sampled batch. The dict is
                  read but never modified.
        @output:
        - observation: deep copy of the first sampled batch after
                       self.get_observation_from_batch
        '''
        # Work on a private copy so that neither the caller's dict nor the
        # shared default argument is mutated.
        params = dict(params)
        params.setdefault('empty_history', False)
        BS = params.get('batch_size', self.episode_batch_size)

        self.batch_iter = iter(DataLoader(self.reader, batch_size=BS, shuffle=True,
                                          pin_memory=True, num_workers=8))
        sample_info = next(self.batch_iter)
        self.sample_batch = self.get_observation_from_batch(sample_info)
        self.current_observation = self.sample_batch
        # NOTE(review): the per-user trackers below are sized by
        # self.episode_batch_size, not by BS - confirm this is intended when
        # 'batch_size' is overridden.
        self.current_step = torch.zeros(self.episode_batch_size, device=self.device)
        self.current_sample_head_in_batch = BS

        self.current_temper = torch.ones(self.episode_batch_size, device=self.device) * self.initial_temper
        self.current_sum_reward = torch.zeros(self.episode_batch_size, device=self.device)

        self.env_history = {'step': [0.], 'leave': [], 'temper': [], 'coverage': [], 'ILD': []}

        return deepcopy(self.current_observation)


    def step(self, step_dict):
        '''
        users react to the recommendation action
        @input:
        - step_dict: {'action': (B, slate_size),
                      'action_features': (B, slate_size, item_dim) }
        @output:
        - new_observation: {'user_profile': {'user_id': (B,),
                                             'uf_{feature_name}': (B, feature_dim)},
                            'user_history': {'history': (B, max_H),
                                             'history_if_{feature_name}': (B, max_H, feature_dim),
                                             'history_{response}': (B, max_H),
                                             'history_length': (B, )}}
        - response_dict: {'immediate_response': see self.get_response@output - response_dict,
                          'done': (B,)}
        - update_info: see self.update_observation@output - update_info
        '''

        with torch.no_grad():
            action = step_dict['action']  # (B, slate_size)

            # ----- build batch for user model -----
            # notice we pass "if_..." keys here
            batch = {
                'item_id': self.candidate_iids[action],   # (B, slate_size)
                **self.current_observation['user_profile'],
                **self.current_observation['user_history'],
                **{k: v[action] for k, v in self.candidate_item_meta.items()}  # all "if_..."
            }

            response_dict = self.immediate_response_model(batch)
            response = response_dict['preds']  # (B, slate_size, n_feedback) named "probs" in your earlier code

            # some versions call it 'probs', some return 'preds'
            if 'probs' in response_dict:
               behavior_scores = response_dict['probs']
            else:
               behavior_scores = response_dict['preds']   # (B, slate_size, n_feedback)

            # sample / binarize feedback
            sampled_feedback = (torch.sigmoid(behavior_scores) > 0.5).float()

            # done + temper update uses immediate_response
            done_mask = self.get_leave_signal(
            None,
            None,
            {'immediate_response': sampled_feedback}
            )

            response_out = {
                'immediate_response': (torch.sigmoid(behavior_scores) > 0.5).float(),
                'done': done_mask,
                'coverage': len(torch.unique(action)),
                'ILD': 0.0,  # fill with your ILD if you need it
            }

            update_info = self.update_observation(None,
                                                  action,
                                                  response_out['immediate_response'],
                                                  done_mask)

            # env_history update: step, leave, temper, converage, ILD
            self.current_step += 1
            n_leave = done_mask.sum()
            self.env_history['leave'].append(n_leave.item())
            self.env_history['temper'].append(torch.mean(self.current_temper).item())
            self.env_history['coverage'].append(response_out['coverage'])
            self.env_history['ILD'].append(response_out['ILD'])

        # when users left, new users come into the running batch
        if n_leave > 0:
            final_steps = self.current_step[done_mask].detach().cpu().numpy()
            for fst in final_steps:
                self.env_history['step'].append(fst)

            if self.current_sample_head_in_batch + n_leave < self.episode_batch_size:
                # reuse previous batch if there are sufficient samples for n_leave
                head = self.current_sample_head_in_batch
                tail = self.current_sample_head_in_batch + n_leave
                for obs_key in ['user_profile', 'user_history']:
                    for k, v in self.sample_batch[obs_key].items():
                        src = v[head:tail]              # new users
                        dst = self.current_observation[obs_key][k]
                        # 🔧 if dst is flattened but src is 3D, flatten src to match
                        if dst.dim() == 2 and src.dim() == 3:
                            B, H, D = src.shape
                            src = src.view(B, H * D)
                        self.current_observation[obs_key][k][done_mask] = src
                self.current_sample_head_in_batch += n_leave
            else:
                # sample new users to fill in the blank
                sample_info = self.sample_new_batch_from_reader()
                self.sample_batch = self.get_observation_from_batch(sample_info)
                for obs_key in ['user_profile', 'user_history']:
                    for k, v in self.sample_batch[obs_key].items():
                        src = v[:n_leave]
                        dst = self.current_observation[obs_key][k]
                        # 🔧 same fix here
                        if dst.dim() == 2 and src.dim() == 3:
                            B, H, D = src.shape
                            src = src.view(B, H * D)
                        self.current_observation[obs_key][k][done_mask] = src
                self.current_sample_head_in_batch = n_leave

            self.current_step[done_mask] *= 0
            self.current_temper[done_mask] *= 0
            self.current_temper[done_mask] += self.initial_temper
        else:
            self.env_history['step'].append(self.env_history['step'][-1])

        return deepcopy(self.current_observation), response_out, update_info


    def get_response(self, step_dict):
        '''
        Sample the users' immediate feedback for the recommended slates.

        The response model scores every exposed item, the scores are penalised by
        the item's average similarity to the rest of its slate (scaled by self.rho),
        and the binary feedback is drawn from the resulting probabilities.

        @input:
        - step_dict: {'action': (B, slate_size)}, indices into self.candidate_iids
        @output:
        - response_dict: {'immediate_response': (B, slate_size, n_feedback),
                          'user_state': (B, gt_state_dim),
                          'coverage': scalar, number of distinct indices in the action,
                          'ILD': scalar, 1 - mean intra-slate similarity}
        '''
        slate = step_dict['action']
        n_distinct = len(torch.unique(slate))
        batch_size = self.episode_batch_size

        # input of the response model: exposed items + current observation + item meta
        model_input = {
            'item_id': self.candidate_iids[slate],
            **self.current_observation['user_profile'],
            **self.current_observation['user_history'],
            **{name: meta[slate] for name, meta in self.candidate_item_meta.items()},
        }
        model_output = self.immediate_response_model(model_input)

        # (B, slate_size, n_feedback)
        scores = model_output['probs']

        # L2-normalised encodings of the exposed items, (B, slate_size, item_dim)
        slate_enc = self.candidate_item_encoding[slate].view(batch_size, self.slate_size, -1)
        slate_enc = F.normalize(slate_enc, p = 2.0, dim = -1)
        # similarity of each item to its slate, (B, slate_size)
        similarity = self.get_intra_slate_similarity(slate_enc)

        # similarity penalty lowers the response probability; negatives are floored at 0
        penalty = similarity.view(batch_size, self.slate_size, 1) * self.rho
        probabilities = torch.clamp(torch.sigmoid(scores) - penalty, min = 0)

        # (B, slate_size, n_feedback)
        sampled = torch.bernoulli(probabilities).detach()

        return {'immediate_response': sampled,
                'user_state': model_output['state'],
                'coverage': n_distinct,
                'ILD': 1 - torch.mean(similarity).item()}

    def get_ground_truth_user_state(self, profile, history):
        '''
        Encode the user state with the response model's own state encoder.

        @input:
        - profile: dict of user profile entries
        - history: dict of user history entries (overrides profile on key clashes)
        @output:
        - gt_user_state: (episode_batch_size, 1, gt_state_dim)
        '''
        batch_size = self.episode_batch_size
        encoded = self.immediate_response_model.encode_state({**profile, **history}, batch_size)
        return encoded['state'].view(batch_size, 1, self.gt_state_dim)

    def get_intra_slate_similarity(self, action_item_encoding):
        '''
        Average similarity of every slate item to all items of the same slate
        (itself included). The pairwise similarity is the mean of the
        element-wise product of two encodings.

        @input:
        - action_item_encoding: (B, slate_size, enc_dim)
        @output:
        - similarity: (B, slate_size)
        '''
        # unpacking enforces a rank-3 input
        _batch, _slate, _dim = action_item_encoding.shape
        as_rows = action_item_encoding.unsqueeze(2)  # (B, L, 1, d)
        as_cols = action_item_encoding.unsqueeze(1)  # (B, 1, L, d)
        # (B, L, L): mean over the encoding dimension
        pairwise = (as_rows * as_cols).mean(dim = -1)
        # (B, L): mean over the other items of the slate
        return pairwise.mean(dim = -1)

    def get_leave_signal(self, user_state, action, response_dict):
        '''
        Update the user temper (in place) and flag the users that leave.
        A user leaves when the temper drops below 1.

        @input:
        - user_state: not used in this env
        - action: not used in this env
        - response_dict: {'immediate_response': (B, slate_size, n_feedback)}
        @process:
        - temper += clip(mean weighted slate reward - 2, -2, 0)
        @output:
        - done_mask: (B,), boolean
        '''
        weights = self.response_weights.view(1, 1, -1)
        # weighted feedback summed per item, then averaged over the slate: (B,)
        slate_reward = (response_dict['immediate_response'] * weights).sum(dim = 2).mean(dim = 1)
        # the temper never increases and loses at most 2 per step
        temper_delta = torch.clamp(slate_reward - 2, min = -2, max = 0)
        self.current_temper += temper_delta
        return self.current_temper < 1


    def update_observation(self, user_state, action, response, done_mask, update_current=True):
        '''
        Append the recommended slate and its feedback to the user histories,
        keeping only the latest max_hist_len entries.

        @input:
        - user_state: not used
        - action: (B, slate_size), indices into self.candidate_iids
        - response: (B, slate_size, n_feedback)
        - done_mask: not used
        - update_current: when True, self.current_observation['user_history'] is replaced
        @output:
        - update_info: {'slate': (B, slate_size), encoded item ids,
                        'updated_observation': {'user_profile': copy of current profile,
                                                'user_history': copy of the new history}}
        '''
        batch_size, window = self.episode_batch_size, self.max_hist_len
        slate_items = self.candidate_iids[action]
        previous = self.current_observation['user_history']

        def _roll(old, new):
            # append along the time axis and keep the most recent `window` steps
            return torch.cat((old, new), dim=1)[:, -window:]

        updated = {
            'history': _roll(previous['history'], slate_items),
            # history length grows by the slate size but is capped at the window
            'history_length': torch.clamp(previous['history_length'] + self.slate_size, max=window),
        }

        # item meta histories are stored flattened as (B, max_H * feat_dim)
        for meta_name, meta_table in self.candidate_item_meta.items():
            key = f'history_{meta_name}'
            old_meta = previous[key].view(batch_size, window, -1)
            updated[key] = _roll(old_meta, meta_table[action]).view(batch_size, -1)

        # one feedback history per feedback type of the response model
        for idx, feedback in enumerate(self.immediate_response_model.feedback_types):
            key = f'history_{feedback}'
            updated[key] = _roll(previous[key], response[:, :, idx])

        if update_current:
            self.current_observation['user_history'] = updated

        observation_copy = {
            'user_profile': deepcopy(self.current_observation['user_profile']),
            'user_history': deepcopy(updated),
        }
        return {'slate': slate_items, 'updated_observation': observation_copy}

    def sample_new_batch_from_reader(self):
        '''
        Fetch the next full batch of users from the reader.

        The current iterator (self.batch_iter) is used when it yields a batch whose
        'user_profile' has exactly episode_batch_size rows. Otherwise (iterator
        missing, exhausted, failing, or yielding an incomplete batch) a fresh
        shuffled DataLoader iterator is created and its first batch is returned.

        @output
        - sample_info: see BaseRLEnvironment.get_observation_from_batch@input - sample_batch
        '''
        need_new_iterator = False
        try:
            sample_info = next(self.batch_iter)
            if sample_info['user_profile'].shape[0] != self.episode_batch_size:
                # incomplete trailing batch of the epoch
                need_new_iterator = True
        except Exception:
            # StopIteration (epoch finished), TypeError (iterator reset to None by
            # stop()), etc. KeyboardInterrupt / SystemExit are NOT swallowed here,
            # so the run can still be interrupted while waiting for a batch.
            need_new_iterator = True
        if need_new_iterator:
            self.batch_iter = iter(DataLoader(self.reader, batch_size = self.episode_batch_size, shuffle = True,
                                              pin_memory = True, num_workers = 8))
            sample_info = next(self.batch_iter)
        return sample_info

    def stop(self):
        '''
        Drop the reference to the current batch iterator by resetting
        self.batch_iter to None.
        '''
        self.batch_iter = None

    def get_new_iterator(self, B):
        '''
        Build a fresh iterator over self.reader.

        @input:
        - B: batch size of the DataLoader
        @output:
        - iterator of a shuffled DataLoader (pinned memory, 8 workers)
        '''
        loader = DataLoader(self.reader,
                            batch_size = B,
                            shuffle = True,
                            pin_memory = True,
                            num_workers = 8)
        return iter(loader)


    def create_observation_buffer(self, buffer_size):
        '''
        Allocate an all-zero observation with buffer_size rows on self.device.

        @input:
        - buffer_size: L, scalar
        @output:
        - observation: {'user_profile': {'user_id': (L,),
                                         'uf_{feature_name}': (L, feature_dim)},
                        'user_history': {'history': (L, max_H),
                                         'history_length': (L,),
                                         'history_if_{feature_name}': (L, max_H * feature_dim),
                                         'history_{response}': (L, max_H)}}
        '''
        max_H = self.max_hist_len
        space = self.observation_space

        def _zeros(dtype, *shape):
            return torch.zeros(*shape, dtype = dtype, device = self.device)

        # ids and lengths are long tensors
        profile = {'user_id': _zeros(torch.long, buffer_size)}
        history = {'history': _zeros(torch.long, buffer_size, max_H),
                   'history_length': _zeros(torch.long, buffer_size)}

        # features and feedback are float tensors
        for name, dim in space['user_feature_dims'].items():
            profile[f'uf_{name}'] = _zeros(torch.float, buffer_size, dim)
        for name, dim in space['item_feature_dims'].items():
            # item feature histories are stored flattened over time
            history[f'history_if_{name}'] = _zeros(torch.float, buffer_size, dim * max_H)
        for name in space['feedback_type']:
            history[f'history_{name}'] = _zeros(torch.float, buffer_size, max_H)

        return {'user_profile': profile, 'user_history': history}

    def get_report(self, smoothness = 10):
        '''
        Summarise the environment history.

        @input:
        - smoothness: number of most recent records averaged per entry
        @output:
        - report: {history_key: mean of the last `smoothness` records}
        '''
        report = {}
        for name, records in self.env_history.items():
            recent = records[-smoothness:]
            report[name] = np.mean(recent)
        return report

In [None]:
from matplotlib.pyplot import axes, axis
import torch
import torch.nn as nn


class BackboneUserEncoder(BaseModel):
    '''
    KuaiRand Multi-Behavior user state encoder.

    The interaction history (item ids, item features and feedback) is encoded by a
    causally masked transformer. The output at the last position is concatenated
    with the encoding of the static user profile, giving a user state of size
    3 * enc_dim.
    '''

    @staticmethod
    def parse_model_args(parser):
        '''
        args:
        - state_user_latent_dim
        - state_item_latent_dim
        - state_transformer_enc_dim
        - state_transformer_n_head
        - state_transformer_d_forward
        - state_transformer_n_layer
        - state_dropout_rate
        - from BaseModel:
            - model_path
            - loss
            - l2_coef
        '''
        parser = BaseModel.parse_model_args(parser)

        parser.add_argument('--state_user_latent_dim', type=int, default=16,
                            help='user latent embedding size')
        parser.add_argument('--state_item_latent_dim', type=int, default=16,
                            help='item latent embedding size')
        parser.add_argument('--state_transformer_enc_dim', type=int, default=32,
                            help='item encoding size')
        parser.add_argument('--state_transformer_n_head', type=int, default=4,
                            help='number of attention heads in transformer')
        parser.add_argument('--state_transformer_d_forward', type=int, default=64,
                            help='forward layer dimension in transformer')
        parser.add_argument('--state_transformer_n_layer', type=int, default=2,
                            help='number of encoder layers in transformer')
        parser.add_argument('--state_dropout_rate', type=float, default=0.1,
                            help='dropout rate in deep layers of state encoder')
        return parser

    def __init__(self, args, reader_stats):
        # These attributes are set before the base constructor runs; presumably
        # BaseModel.__init__ triggers _define_params -- TODO confirm against BaseModel.
        self.user_latent_dim = args.state_user_latent_dim
        self.item_latent_dim = args.state_item_latent_dim
        self.enc_dim = args.state_transformer_enc_dim
        # history encoding (2 * enc_dim) + user profile encoding (enc_dim)
        self.state_dim = 3*self.enc_dim
        self.attn_n_head = args.state_transformer_n_head
        self.dropout_rate = args.state_dropout_rate
        super().__init__(args, reader_stats, args.device)

    def to(self, device):
        '''
        Move the model to `device`. attn_mask and pos_emb_getter are plain tensor
        attributes, so they are moved explicitly here.
        '''
        new_self = super(BackboneUserEncoder, self).to(device)
        new_self.attn_mask = new_self.attn_mask.to(device)
        new_self.pos_emb_getter = new_self.pos_emb_getter.to(device)
        return new_self

    def _define_params(self, args, reader_stats):
        '''
        Create all embedding, normalization and transformer modules.
        Statistics are read from self.reader_stats.
        '''
        stats = self.reader_stats

        self.user_feature_dims = stats['user_feature_dims'] # {feature_name: dim}
        self.item_feature_dims = stats['item_feature_dims'] # {feature_name: dim}

        # user embedding
        self.uIDEmb = nn.Embedding(stats['n_user']+1, args.state_user_latent_dim)
        # plain lookup dict; the modules themselves are registered via add_module
        self.uFeatureEmb = {}
        for f,dim in self.user_feature_dims.items():
            embedding_module = nn.Linear(dim, args.state_user_latent_dim)
            self.add_module(f'UFEmb_{f}', embedding_module)
            self.uFeatureEmb[f] = embedding_module

        # item embedding
        self.iIDEmb = nn.Embedding(stats['n_item']+1, args.state_item_latent_dim)
        self.iFeatureEmb = {}
        for f,dim in self.item_feature_dims.items():
            embedding_module = nn.Linear(dim, args.state_item_latent_dim)
            self.add_module(f'IFEmb_{f}', embedding_module)
            self.iFeatureEmb[f] = embedding_module

        # feedback embedding
        self.feedback_types = stats['feedback_type']
        self.feedback_dim = stats['feedback_size']
        self.feedbackEncoder = nn.Linear(self.feedback_dim, args.state_transformer_enc_dim)

        # item embedding kernel encoder
        self.itemEmbNorm = nn.LayerNorm(args.state_item_latent_dim)
        self.userEmbNorm = nn.LayerNorm(args.state_user_latent_dim)
        self.itemFeatureKernel = nn.Linear(args.state_item_latent_dim, args.state_transformer_enc_dim)
        self.userFeatureKernel = nn.Linear(args.state_user_latent_dim, args.state_transformer_enc_dim)
        self.encDropout = nn.Dropout(args.state_dropout_rate)
        self.encNorm = nn.LayerNorm(args.state_transformer_enc_dim)

        # positional embedding
        self.max_len = stats['max_seq_len']
        self.posEmb = nn.Embedding(self.max_len, args.state_transformer_enc_dim)
        self.pos_emb_getter = torch.arange(self.max_len, dtype = torch.long)
        # True above the diagonal: a position must not attend to later positions
        self.attn_mask = ~torch.tril(torch.ones((self.max_len,self.max_len), dtype=torch.bool))

        # sequence encoder over [item encoding ; feedback embedding]
        encoder_layer = nn.TransformerEncoderLayer(d_model=2*args.state_transformer_enc_dim,
                                                   dim_feedforward = args.state_transformer_d_forward,
                                                   nhead=args.state_transformer_n_head,
                                                   dropout = args.state_dropout_rate,
                                                   batch_first = True)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=args.state_transformer_n_layer)

    def get_forward(self, feed_dict: dict):
        '''
        @input:
        - feed_dict: {
            'user_id': (B,)
            'uf_{feature_name}': (B,feature_dim), the user features
            'history': (B,max_H)
            'history_if_{feature_name}': (B,max_H,feature_dim), the history item features
            'history_{response}': (B,max_H), the history feedback
        }
        @output:
        - out_dict: {'state': (B, state_dim),
                    'reg': scalar}
        '''
        B = feed_dict['user_id'].shape[0]
        # user encoding
        state_encoder_output = self.encode_state(feed_dict, B)
        # regularization terms
        reg = self.get_regularization(self.feedbackEncoder,
                                      self.itemFeatureKernel, self.userFeatureKernel,
                                      self.posEmb, self.transformer)
        reg = reg + state_encoder_output['reg']

        return {'state': state_encoder_output['state'],
                'reg': reg}

    def encode_state(self, feed_dict, B):
        '''
        @input:
        - feed_dict: {
            'user_id': (B,)
            'uf_{feature_name}': (B,feature_dim), the user features
            'history': (B,max_H)
            'history_if_{feature_name}': (B,max_H,feature_dim) or flattened (B,max_H*feature_dim),
                                         the history item features
            'history_{response}': (B,max_H), the history feedback
            ... (irrelevant input)
        }
        - B: batch size
        @output:
        - out_dict:{
            'output_seq': (B,max_H,2*enc_dim)
            'state': (B,3*enc_dim)
            'reg': scalar
        }
        '''
        # user history item encodings (B, max_H, enc_dim)
        history_features = {f: feed_dict[f'history_if_{f}'] for f in self.iFeatureEmb}
        history_enc, history_reg = self.get_item_encoding(feed_dict['history'], history_features, B)
        history_enc = history_enc.view(B, self.max_len, self.enc_dim)

        # positional encoding (1, max_H, enc_dim)
        pos_emb = self.posEmb(self.pos_emb_getter).view(1,self.max_len,self.enc_dim)

        # feedback embedding (B, max_H, enc_dim)
        feedback_emb = self.get_response_embedding({f: feed_dict[f'history_{f}'] for f in self.feedback_types}, B)

        # sequence item encoding (B, max_H, enc_dim)
        seq_enc_feat = self.encNorm(self.encDropout(history_enc + pos_emb))
        # (B, max_H, 2*enc_dim)
        seq_enc = torch.cat((seq_enc_feat, feedback_emb), dim = -1)

        # transformer output (B, max_H, 2*enc_dim)
        output_seq = self.transformer(seq_enc, mask = self.attn_mask)

        # user history encoding (B, 2*enc_dim): output at the last position
        hist_enc = output_seq[:,-1,:].view(B,2*self.enc_dim)

        # static user profile features, keyed by feature name ('uf_' prefix stripped)
        # (B, enc_dim), scalar
        user_features = {k[3:]: v for k,v in feed_dict.items() if k.startswith('uf_')}
        user_enc, user_reg = self.get_user_encoding(feed_dict['user_id'], user_features, B)
        # (B, enc_dim)
        user_enc = self.encNorm(self.encDropout(user_enc)).view(B,self.enc_dim)

        # user state (B, 3*enc_dim) combines user history and user profile features
        state = torch.cat([hist_enc,user_enc], 1)
        return {'output_seq': output_seq, 'state': state, 'reg': user_reg + history_reg}

    def get_user_encoding(self, user_ids, user_features, B):
        '''
        @input:
        - user_ids: (B,)
        - user_features: {'{feature_name}': (B, feature_dim)}
        - B: batch size
        @output:
        - encoding: (B, enc_dim)
        - reg: scalar, mean squared user id embedding
        '''
        # (B, 1, u_latent_dim)
        user_id_emb = self.uIDEmb(user_ids).view(B,1,self.user_latent_dim)
        # [(B, 1, u_latent_dim)] * (n_user_feature + 1)
        user_feature_emb = [user_id_emb]
        for f,fEmbModule in self.uFeatureEmb.items():
            user_feature_emb.append(fEmbModule(user_features[f]).view(B,1,self.user_latent_dim))
        # (B, n_user_feature+1, u_latent_dim)
        combined_user_emb = torch.cat(user_feature_emb, 1)
        combined_user_emb = self.userEmbNorm(combined_user_emb)
        # (B, enc_dim): sum over the id and feature embeddings
        encoding = self.userFeatureKernel(combined_user_emb).sum(1)
        # regularization
        reg = torch.mean(user_id_emb * user_id_emb)
        return encoding, reg

    def get_item_encoding(self, item_ids, item_features, B):
        '''
        @input:
        - item_ids: (B,) or (B,L)
        - item_features: {'{feature_name}': (B,feature_dim) or (B,L,feature_dim)}
        - B: batch size
        @output:
        - encoding: (B, 1, enc_dim) or (B, L, enc_dim)
        - reg: scalar, mean squared item id embedding
        '''
        # (B, 1, i_latent_dim) or (B, L, i_latent_dim)
        item_id_emb = self.iIDEmb(item_ids).view(B,-1,self.item_latent_dim)
        L = item_id_emb.shape[1]
        # [(B, 1, i_latent_dim)] * (n_item_feature + 1) or [(B, L, i_latent_dim)] * (n_item_feature + 1)
        item_feature_emb = [item_id_emb]
        for f,fEmbModule in self.iFeatureEmb.items():
            f_dim = self.item_feature_dims[f]
            item_feature_emb.append(fEmbModule(item_features[f].view(B,L,f_dim)).view(B,-1,self.item_latent_dim))
        # (B, 1, n_item_feature+1, i_latent_dim) or (B, L, n_item_feature+1, i_latent_dim)
        combined_item_emb = torch.cat(item_feature_emb, -1).view(B, L, -1, self.item_latent_dim)
        combined_item_emb = self.itemEmbNorm(combined_item_emb)
        # (B, 1, enc_dim) or (B, L, enc_dim): sum over the id and feature embeddings
        encoding = self.itemFeatureKernel(combined_item_emb).sum(2)
        encoding = self.encNorm(encoding.view(B, -1, self.enc_dim))
        # regularization
        reg = torch.mean(item_id_emb * item_id_emb)
        return encoding, reg

    def get_response_embedding(self, resp_dict, B):
        '''
        @input:
        - resp_dict: {'{response}': (B, max_H)}
        - B: batch size
        @output:
        - resp_emb: (B, max_H, enc_dim)
        '''
        # [(B, max_H)] * n_feedback, ordered as self.feedback_types
        resp_list = [resp_dict[f].view(B, self.max_len) for f in self.feedback_types]
        # (B, max_H, n_feedback): feedback types are stacked on a NEW last axis so
        # that row t holds all feedback signals of time step t. Concatenating along
        # the last axis and reshaping (the previous behavior) interleaved time steps
        # and feedback types whenever n_feedback > 1.
        # NOTE: checkpoints trained with the previous layout will see differently
        # ordered inputs to feedbackEncoder when n_feedback > 1.
        combined_resp = torch.stack(resp_list, dim = -1).view(B,self.max_len,self.feedback_dim)
        # (B, max_H, enc_dim)
        resp_emb = self.feedbackEncoder(combined_resp)
        return resp_emb


In [None]:
import torch
import torch.nn as nn


class BaseOnlinePolicy(BaseModel):
    '''
    Base class of online recommendation policies.

    It owns a BackboneUserEncoder that turns an observation into a user state;
    subclasses provide generate_action() and get_loss().
    '''

    @staticmethod
    def parse_model_args(parser):
        '''
        args:
        - from BackboneUserEncoder:
            - state_user_latent_dim
            - state_item_latent_dim
            - state_transformer_enc_dim
            - state_transformer_n_head
            - state_transformer_d_forward
            - state_transformer_n_layer
            - state_dropout_rate
            - from BaseModel:
                - model_path
                - loss
                - l2_coef
        '''
        return BackboneUserEncoder.parse_model_args(parser)

    def __init__(self, args, env, device):
        # slate_size comes from the environment arguments and is needed by
        # _define_params, so it is set before the base constructor runs
        self.slate_size = args.slate_size
        reader_stats = env.reader.get_statistics()
        super().__init__(args, reader_stats, device)
        self.display_name = "BaseOnlinePolicy"

    def to(self, device):
        '''
        Move the policy to `device`; the user encoder is moved through its own
        to() and its `device` attribute is updated.
        '''
        moved = super().to(device)
        encoder = self.userEncoder
        encoder.device = device
        self.userEncoder = encoder.to(device)
        return moved

    def _define_params(self, args, reader_stats):
        '''
        Create the user encoder and derive enc_dim, state_dim and action_dim.
        '''
        encoder = BackboneUserEncoder(args, reader_stats)
        self.userEncoder = encoder
        self.enc_dim = encoder.enc_dim
        self.state_dim = encoder.state_dim
        # one action entry per slate position
        self.action_dim = self.slate_size

        self.bce_loss = nn.BCEWithLogitsLoss(reduction = 'none')

    def get_forward(self, feed_dict: dict):
        '''
        Encode the observation into a user state and generate the action.

        @input:
        - feed_dict: {
            'observation': {
                'user_profile': {
                    'user_id': (B,)
                    'uf_{feature_name}': (B,feature_dim), the user features},
                'user_history': {
                    'history': (B,max_H)
                    'history_if_{feature_name}': (B,max_H,feature_dim), the history item features}},
            'candidates': {
                'item_id': (B,L) or (1,L), the target item
                'item_{feature_name}': (B,L,feature_dim) or (1,L,feature_dim), the target item features},
            'epsilon': scalar,
            'do_explore': boolean,
            'action_dim': slate size K,
            'action': (B,K),
            'response': {
                'reward': (B,),
                'immediate_response': (B,K*n_feedback)},
            'is_train': boolean
        }
        @output:
        - out_dict: {
            'state': (B,state_dim),
            'prob': (B,K),
            'action': (B,K),
            'reg': scalar}
        '''
        # observation --> user state, (B, state_dim)
        encoder_output = self.get_user_state(feed_dict['observation'])
        user_state = encoder_output['state']

        # user state --> prob, action
        out_dict = self.generate_action(user_state, feed_dict)
        out_dict['state'] = user_state
        # regularization of the encoder plus that of the action module
        out_dict['reg'] = encoder_output['reg'] + out_dict['reg']
        return out_dict

    def get_user_state(self, observation):
        '''
        Merge user profile and user history into one dict and run the user encoder.

        @input:
        - observation: {'user_profile': {...}, 'user_history': {...}}
        @output:
        - output of self.userEncoder(feed_dict, B) with B taken from 'user_id'
        '''
        feed_dict = {**observation['user_profile'], **observation['user_history']}
        batch_size = feed_dict['user_id'].shape[0]
        return self.userEncoder(feed_dict, batch_size)

    def get_loss_observation(self):
        '''Names of the loss terms to monitor.'''
        return ['loss']

    def generate_action(self, user_state, feed_dict):
        '''
        Placeholder; does nothing here and returns None.

        This function will be called in the following places:
        * OnlineAgent.run_episode_step() with {'action': None, 'response': None,
                                               'epsilon': >0, 'do_explore': True, 'is_train': False}
        * OnlineAgent.step_train() with {'action': tensor, 'response': {'reward': , 'immediate_response': },
                                         'epsilon': 0, 'do_explore': False, 'is_train': True}
        * OnlineAgent.test() with {'action': None, 'response': None,
                                   'epsilon': 0, 'do_explore': False, 'is_train': False}

        @input:
        - user_state
        - feed_dict
        @output (expected by get_forward):
        - out_dict: {'prob': (B, K),
                     'action': (B, K),
                     'reg': scalar}
        '''
        return None

    def get_loss(self, feed_dict, out_dict):
        '''
        Placeholder; does nothing here and returns None.

        @input:
        - feed_dict: same as get_forward@input-feed_dict
        - out_dict: {
            'state': (B,state_dim),
            'prob': (B,K),
            'action': (B,K),
            'reg': scalar,
            'immediate_response': (B,K),
            'reward': (B,)}
        @output
        - loss
        '''
        return None

In [None]:
import torch
import torch.nn as nn
from torch.distributions import Categorical
import numpy as np


class TwoStageOnlinePolicy(BaseOnlinePolicy):
    '''
    Two-stage slate policy.

    Stage 1 (initial ranker) scores every candidate by the dot product between a vector derived
    from the user state and the candidate item encoding, and builds an initial list of
    initial_list_size items. Stage 2 (final action) in this base class simply keeps the first
    slate_size entries of the initial list, i.e. no reranking is performed.
    '''

    @staticmethod
    def parse_model_args(parser):
        '''
        args:
        - initial_list_size
        - stage1_state2z_hidden_dims
        - stage1_pos_offset
        - stage1_neg_offset
        - initial_loss_coef
        - from BackboneUserEncoder:
            - user_latent_dim
            - item_latent_dim
            - transformer_enc_dim
            - transformer_n_head
            - transformer_d_forward
            - transformer_n_layer
            - state_hidden_dims
            - dropout_rate
            - from BaseModel:
                - model_path
                - loss
                - l2_coef
        '''
        parser = BaseOnlinePolicy.parse_model_args(parser)
        parser.add_argument('--initial_list_size', type=int, default=50,
                            help='candidate list size after initial ranker')
        parser.add_argument('--stage1_state2z_hidden_dims', type=int, nargs="+", default=[128],
                            help='hidden dimensions of state_slate encoding layers')
        parser.add_argument('--stage1_pos_offset', type=float, default=0.8,
                            help='smooth offset of positive prob')
        parser.add_argument('--stage1_neg_offset', type=float, default=0.1,
                            help='smooth offset of negative prob')
        parser.add_argument('--initial_loss_coef', type=float, default=0.1,
                            help='relative importance of training loss of initial ranker')
        return parser

    def __init__(self, args, env, device):
        # stage-1 hyperparameters are stored before the parent constructor runs because
        # the parent constructor calls _define_params(args) (see note below)
        self.initial_list_size = args.initial_list_size
        self.stage1_state2z_hidden_dims = args.stage1_state2z_hidden_dims
        self.stage1_pos_offset = args.stage1_pos_offset
        self.stage1_neg_offset = args.stage1_neg_offset
        self.initial_loss_coef = args.initial_loss_coef
        # BaseOnlinePolicy initialization:
        # - reader_stats, model_path, loss_type, l2_coef, no_reg, device, slate_size
        # - _define_params(args): userEncoder, enc_dim, state_dim, action_dim
        super().__init__(args, env, device)
        self.display_name = "TwoStageOnlinePolicy"
        # switches that enable/disable each stage's loss term in get_loss()
        self.train_initial = True
        self.train_rerank = True

    def to(self, device):
        '''
        Move the policy to the given device; delegates to the parent implementation.
        '''
        new_self = super(TwoStageOnlinePolicy, self).to(device)
        return new_self

    def _define_params(self, args):
        '''
        Default two stage policy (pointwise initial ranker + no reranking)
        '''
        # userEncoder, enc_dim, state_dim, action_dim
        super()._define_params(args)
        # p_forward: maps the user state to a vector in the item encoding space
        self.stage1State2Z = DNN(self.state_dim, args.stage1_state2z_hidden_dims, self.enc_dim,
                           dropout_rate = args.dropout_rate, do_batch_norm = True)
        self.stage1ZNorm = nn.LayerNorm(self.enc_dim)

    def generate_action(self, user_state, feed_dict):
        '''
        Run the initial ranker followed by the final-action stage.

        This function will be called in the following places:
        * OnlineAgent.run_episode_step() with {'action': None, 'response': None,
                                               'epsilon': >0, 'do_explore': True, 'is_train': False}
        * OnlineAgent.step_train() with {'action': tensor, 'response': {'reward': , 'immediate_response': },
                                         'epsilon': 0, 'do_explore': False, 'is_train': True}
        * OnlineAgent.test() with {'action': None, 'response': None,
                                   'epsilon': 0, 'do_explore': False, 'is_train': False}

        Side effect: writes 'do_batch_wise' and 'do_uniform' into feed_dict so that both stages
        share the same decisions.

        @input:
        - user_state
        - feed_dict
        @output:
        - out_dict: {'prob': (B, K),
                     'action': (B, K),
                     'initial_prob': (B, C),
                     'initial_action': (B, C),
                     'reg': scalar}
        '''
        # batch-wise candidates has shape (B,L), non-batch-wise candidates has shape (1,L)
        batch_wise = True
        if feed_dict['candidates']['item_id'].shape[0] == 1:
            batch_wise = False
        feed_dict['do_batch_wise'] = batch_wise
        # during training, candidates is always the full item set and has shape (1,L) where L=N
        if feed_dict['is_train']:
            assert not batch_wise
        # a single coin flip per call decides between uniform and policy-driven exploration
        do_uniform = np.random.random() < feed_dict['epsilon']
        feed_dict['do_uniform'] = do_uniform

        initial_out_dict = self.generate_initial_rank(user_state, feed_dict)
        out_dict = self.generate_final_action(user_state, feed_dict, initial_out_dict)
        out_dict['initial_prob'] = initial_out_dict['initial_prob']
        out_dict['initial_action'] = initial_out_dict['initial_action']
        out_dict['reg'] = initial_out_dict['reg'] + out_dict['reg']
        return out_dict

    def generate_initial_rank(self, user_state, feed_dict):
        '''
        @input:
        - user_state: (B, state_dim)
        - feed_dict: same as BaseOnlinePolicy.get_forward@feed_dict
        @output:
        - out_dict: {'initial_prob': the initial list's item probabilities, (B, C);
                                     when an action is given, the softmax is taken over the C listed
                                     items only, otherwise over all L candidates,
                     'initial_action': the initial list as candidate positions, (B, C);
                                       when an action is given, the first K columns are that action,
                     'candidate_item_enc': (1, L, enc_dim) or (B, L, enc_dim),
                     'reg': scalar}
        '''
        candidates = feed_dict['candidates']
        action_slate = feed_dict['action'] # (B, K)
        do_explore = feed_dict['do_explore']
        do_uniform = feed_dict['do_uniform']
        is_train = feed_dict['is_train']
        batch_wise = feed_dict['do_batch_wise']

        B = user_state.shape[0]
        # (1,L,enc_dim) or (B,L,enc_dim)
        # NOTE(review): the reg returned here is overwritten in both branches below, so the
        # item-encoding regularization never reaches the loss -- confirm this is intended.
        candidate_item_enc, reg = self.userEncoder.get_item_encoding(candidates['item_id'],
                                                       {k[5:]: v for k,v in candidates.items() if k != 'item_id'},
                                                                     B if batch_wise else 1)
        # (B, enc_dim)
        Z = self.stage1State2Z(user_state)
        Z = self.stage1ZNorm(Z)
        # (B, L)
        score = torch.sum(Z.view(B,1,self.enc_dim) * candidate_item_enc, dim = -1) #/ self.enc_dim
#         score = torch.clamp(score, -self.score_clip, self.score_clip)

        if is_train or torch.is_tensor(action_slate):
            # an action is provided: score it together with uniformly sampled negatives
            stage1_n_neg = self.initial_list_size - self.slate_size
            # (B, C-K)
            neg_indices = Categorical(torch.ones_like(score)).sample((stage1_n_neg,)).transpose(0,1)
            # (B, C)
            indices = torch.cat((action_slate, neg_indices), dim = 1)
            score = torch.gather(score, 1, indices)
            prob = torch.softmax(score, dim = 1)
            selected_P = prob
            initial_action = indices
            # scalar
            reg = self.get_regularization(self.stage1State2Z)
        else:
            # (B, L)
            prob = torch.softmax(score, dim = 1)
            if do_explore:
                # exploration: categorical sampling or uniform sampling
                if do_uniform:
                    indices = Categorical(torch.ones_like(prob)).sample((self.initial_list_size,)).transpose(0,1)
                else:
                    indices = Categorical(prob).sample((self.initial_list_size,)).transpose(0,1)
            else:
                # greedy: topk selection
                _, indices = torch.topk(prob, k = self.initial_list_size, dim = 1)
            # (B, C)
            indices = indices.view(-1,self.initial_list_size).detach()

            selected_P = torch.gather(prob,1,indices)
            # initial list (B, C)
            initial_action = indices
            reg = 0

        out_dict = {'initial_prob': selected_P, # (B, C)
                    'initial_action': initial_action, # (B, C)
                    'candidate_item_enc': candidate_item_enc, # (1, L, enc_dim) or (B, L, enc_dim)
                    'reg': reg}
        return out_dict

    def generate_final_action(self, user_state, feed_dict, initial_out_dict):
        '''
        No-rerank final stage: keep the first slate_size entries of the initial list.
        Both outputs are detached from the graph.

        @input:
        - user_state: (B, state_dim)
        - feed_dict: same as BaseOnlinePolicy.get_forward@input-feed_dict
        - initial_out_dict: TwoStageOnlinePolicy.generate_initial_rank@output-out_dict
        @output:
        - out_dict: {
            prob: (B, K),
            action: (B, K),
            reg: scalar
        }
        '''
        prob = initial_out_dict['initial_prob'][:,:self.slate_size].detach()
        slate_action = initial_out_dict['initial_action'][:,:self.slate_size].detach()

        # nothing is trained in this stage, so it contributes no regularization
        reg = 0
        return {'prob': prob,
                'action': slate_action,
                'reg': reg}

    def get_loss_observation(self):
        '''
        @output:
        - names of the loss terms returned by get_loss()
        '''
        return ['loss', 'initial_loss', 'rerank_loss']

    def get_loss(self, feed_dict, out_dict):
        '''
        Reward-based pointwise ranking loss
        * - Ylog(P) - (1-Y)log(1-P)
        * Y = mean(w[i] * r[i]) # the weighted user responses averaged over feedback types

        @input:
        - feed_dict: same as BaseOnlinePolicy.get_forward@input-feed_dict
        - out_dict: {
            'state': (B,state_dim),
            'initial_prob': (B,C),
            'initial_action': (B,C),
            'prob': (B,K),
            'action': (B,K),
            'reg': scalar,
            'immediate_response': (B,K*n_feedback),
            'immediate_response_weight: (n_feedback, ),
            'reward': (B,)}
        @output
        - {'initial_loss': initial ranker loss,
           'rerank_loss': reranker loss (always zero in this class),
           'loss': initial_loss_coef * initial_loss + rerank_loss + l2_coef * reg}
        '''
        B = out_dict['prob'].shape[0]
        # (B, K), the first K entries of the initial list correspond to the observed slate
        initial_prob = out_dict['initial_prob'][:,:self.slate_size]

        if self.train_initial:
            # training of initial ranker
            # (B,K,n_feedback)
            weighted_response = out_dict['immediate_response'].view(B,self.slate_size,-1) \
                                    * out_dict['immediate_response_weight'].view(1,1,-1)
            # (B,K)
            Y = torch.mean(weighted_response, dim = 2)
            initial_loss = self.get_reward_bce(initial_prob, Y)
        else:
            initial_loss = torch.tensor(0)

        if self.train_rerank:
            # this class has no reranker, so its loss is a zero of matching type
            rerank_loss = torch.zeros_like(initial_loss)
        else:
            rerank_loss = torch.tensor(0)

        # scalar
        loss = self.initial_loss_coef * initial_loss + rerank_loss + self.l2_coef * out_dict['reg']

        # report the initial ranker's own loss under 'initial_loss' (not the combined loss)
        return {'initial_loss': initial_loss, 'rerank_loss': rerank_loss, 'loss': loss}

    def get_reward_bce(self, prob, y):
        '''
        Smoothed binary cross entropy between probabilities and (soft) reward labels.

        @input:
        - prob: (B, K)
        - y: (B, K)
        @output:
        - scalar, -mean(y * log(prob + pos_offset) + (1 - y) * log(1 - prob + neg_offset))
        '''
        # (B, K)
        log_P = torch.log(prob + self.stage1_pos_offset)
        # (B, K)
        log_neg_P = torch.log(1 - prob + self.stage1_neg_offset)
        # scalar
        L = - torch.mean(y * log_P + (1-y) * log_neg_P)
        return L

In [None]:
import torch
import torch.nn as nn
from torch.distributions import Categorical
import numpy as np


class PRM(TwoStageOnlinePolicy):
    '''
    Two-stage policy with a transformer reranker.

    The initial list produced by TwoStageOnlinePolicy.generate_initial_rank is re-scored by:
    * a personalized-vector (PV) input layer that combines the user state with each listed item
      and adds a learned positional embedding,
    * a transformer encoder over the list,
    * a linear output layer whose softmax over the list gives the rerank probabilities.
    A linear head on the PV encoding (PVPred) additionally predicts the per-item label.
    '''

    @staticmethod
    def parse_model_args(parser):
        '''
        args:
        - prm_pv_input_dim
        - prm_pv_hidden_dims
        - prm_encoder_enc_dim
        - prm_encoder_n_head
        - prm_encoder_d_forward
        - prm_encoder_n_layer
        - prm_pv_loss_coef
        - from TwoStageOnlinePolicy:
            - initial_list_size
            - stage1_state2z_hidden_dims
            - stage1_pos_offset
            - stage1_neg_offset
            - initial_loss_coef
            - from BackboneUserEncoder:
                - user_latent_dim
                - item_latent_dim
                - transformer_enc_dim
                - transformer_n_head
                - transformer_d_forward
                - transformer_n_layer
                - state_hidden_dims
                - dropout_rate
                - from BaseModel:
                    - model_path
                    - loss
                    - l2_coef
        '''
        parser = TwoStageOnlinePolicy.parse_model_args(parser)
        parser.add_argument('--prm_pv_input_dim', type=int, default=32,
                            help='input size of PV module of PRM')
        parser.add_argument('--prm_pv_hidden_dims', type=int, nargs="+", default=[128],
                            help='hidden dims of PV module of PRM')
        parser.add_argument('--prm_encoder_enc_dim', type=int, default=32,
                            help='item encoding size of PRM')
        parser.add_argument('--prm_encoder_n_head', type=int, default=4,
                            help='number of attention heads in transformer of PRM')
        parser.add_argument('--prm_encoder_d_forward', type=int, default=64,
                            help='forward layer dimension in transformer of PRM')
        parser.add_argument('--prm_encoder_n_layer', type=int, default=2,
                            help='number of encoder layers in transformer of PRM')
        parser.add_argument('--prm_pv_loss_coef', type=float, default=1.0,
                            help='relative coefficient of pv loss')
        return parser

    def __init__(self, args, env, device):
        # PRM hyperparameters are stored before the parent constructor runs
        self.prm_pv_input_dim = args.prm_pv_input_dim
        self.prm_pv_hidden_dims = args.prm_pv_hidden_dims
        self.prm_encoder_enc_dim = args.prm_encoder_enc_dim
        self.prm_encoder_n_head = args.prm_encoder_n_head
        self.prm_encoder_d_forward = args.prm_encoder_d_forward
        self.prm_encoder_n_layer = args.prm_encoder_n_layer
        self.prm_pv_loss_coef = args.prm_pv_loss_coef
        # TwoStageOnlinePolicy initialization:
        # - initial_list_size, stage1_state2z_hidden_dims, stage1_pos_offset, stage1_neg_offset, initial_loss_coef
        # - reader_stats, model_path, loss_type, l2_coef, no_reg, device, slate_size
        # - _define_params(args): userEncoder, enc_dim, state_dim, action_dim
        super().__init__(args, env, device)
        self.display_name = "PRM"

    def to(self, device):
        '''
        Move the policy to the given device. The attention mask and the position index tensor are
        stored as plain attributes, so they are moved explicitly here.
        '''
        new_self = super(PRM, self).to(device)
        new_self.PV_attn_mask = new_self.PV_attn_mask.to(device)
        new_self.PV_pos_emb_getter = new_self.PV_pos_emb_getter.to(device)
        return new_self

    def _define_params(self, args):
        '''
        Two stage policy with pointwise initial ranker + PRM reranker
        (PV input layer, transformer encoder, linear output layer)
        '''
        # stage1State2Z, stage1ZNorm, userEncoder, enc_dim, state_dim, action_dim
        super()._define_params(args)

        # input layer of PRM (personalized vector model + pos emb)
        # personalized vector model
        self.PVUserInputMap = nn.Linear(self.state_dim, args.prm_pv_input_dim)
        self.PVItemInputMap = nn.Linear(self.enc_dim, args.prm_pv_input_dim)
        self.PVInputNorm = nn.LayerNorm(args.prm_pv_input_dim)
        self.PVOutput = DNN(args.prm_pv_input_dim, args.prm_pv_hidden_dims, args.prm_encoder_enc_dim,
                            dropout_rate = args.dropout_rate, do_batch_norm = True)
        # label prediction model
        self.PVPred = nn.Linear(args.prm_encoder_enc_dim, 1)
        # positional embedding
        self.PVPosEmb = nn.Embedding(self.initial_list_size, args.prm_encoder_enc_dim)
        self.PV_pos_emb_getter = torch.arange(self.initial_list_size, dtype = torch.long)
        # True strictly above the diagonal: position i is masked from positions j > i
        self.PV_attn_mask = ~torch.tril(torch.ones((self.initial_list_size,self.initial_list_size), dtype=torch.bool))

        # encoding layer of PRM (transformer)
        encoder_layer = nn.TransformerEncoderLayer(d_model=args.prm_encoder_enc_dim,
                                                   dim_feedforward = args.prm_encoder_d_forward,
                                                   nhead=args.prm_encoder_n_head, dropout = args.dropout_rate,
                                                   batch_first = True)
        self.PRMEncoder = nn.TransformerEncoder(encoder_layer, num_layers=args.prm_encoder_n_layer)

        # output layer of PRM
        self.PRMOutput = nn.Linear(args.prm_encoder_enc_dim, 1)

    def generate_final_action(self, user_state, feed_dict, initial_out_dict):
        '''
        Rerank the initial list with the PRM model and pick the final slate.

        @input:
        - user_state: (B, state_dim)
        - feed_dict: same as BaseOnlinePolicy.get_forward@input-feed_dict
        - initial_out_dict: TwoStageOnlinePolicy.generate_initial_rank@output-out_dict
        @output:
        - out_dict: {
            'prob': (B, K), rerank probabilities of the selected items,
            'action': (B, K), the given action when one is provided, otherwise the selected
                      entries of the initial list,
            'reward_pred': (B, K) label prediction when an action is provided, otherwise None,
            'reg': scalar
        }
        '''

        do_explore = feed_dict['do_explore']
        do_uniform = feed_dict['do_uniform']
        is_train = feed_dict['is_train']
        action_slate = feed_dict['action']

        # batch size
        B = user_state.shape[0]
        # initial list (B, C), the first K correspond to the observed slate if training
        initial_action = initial_out_dict['initial_action'].detach()

        # (1, L, enc_dim) for shared candidates or (B, L, enc_dim) for batch-wise candidates
        candidate_item_emb = initial_out_dict['candidate_item_enc']
        # (B, C, enc_dim)
        if candidate_item_emb.dim() == 3 and candidate_item_emb.shape[0] == B and B > 1:
            # batch-wise candidates: initial_action holds positions within each row's own
            # candidate list, so the encodings must be gathered row by row; flattening to
            # (B*L, enc_dim) would read every row from the first user's candidates
            gather_index = initial_action.unsqueeze(-1).expand(-1, -1, self.enc_dim)
            initial_item_emb = torch.gather(candidate_item_emb, 1, gather_index).detach()
        else:
            # shared candidate list: plain lookup by candidate position
            initial_item_emb = candidate_item_emb.view(-1, self.enc_dim)[initial_action].detach()

        # input layer
        # (B, 1, pv_input_dim)
        user_input = self.PVUserInputMap(user_state).view(B,1,self.prm_pv_input_dim)
        user_input = self.PVInputNorm(user_input)
        # (B, C, pv_input_dim)
        item_input = self.PVItemInputMap(initial_item_emb.view(B*self.initial_list_size,self.enc_dim))\
                            .view(B,self.initial_list_size,self.prm_pv_input_dim)
        item_input = self.PVInputNorm(item_input)
        # (B, C, pv_input_dim)
        pv_ui_input = user_input + item_input
        # (B, C, pv_enc_dim)
        pv_ui_enc = self.PVOutput(pv_ui_input).view(B,self.initial_list_size,self.prm_encoder_enc_dim)
        # positional encoding (1, C, pv_enc_dim)
        pos_emb = self.PVPosEmb(self.PV_pos_emb_getter).view(1,self.initial_list_size,self.prm_encoder_enc_dim)
        # (B, C, pv_enc_dim)
        pv_E = pv_ui_enc + pos_emb

        # PRM transformer encoder output (B, C, enc_dim)
        PRM_encoder_output = self.PRMEncoder(pv_E, mask = self.PV_attn_mask)

        # PRM reranked score (B, C)
        rerank_score = self.PRMOutput(PRM_encoder_output.view(B*self.initial_list_size,self.prm_encoder_enc_dim))\
                                    .view(B,self.initial_list_size)
        rerank_prob = torch.softmax(rerank_score, dim = 1)

        if is_train or torch.is_tensor(action_slate):
            # an action is provided: it occupies the first K positions of the initial list
            # (B, K)
            final_action = action_slate
            # (B, K)
            selected_P = rerank_prob[:,:self.slate_size]
            # label prediction (B, C)
            Y = self.PVPred(pv_E.view(B*self.initial_list_size,self.prm_encoder_enc_dim))\
                        .view(B,self.initial_list_size)
            # (B, K)
            selected_Y = Y[:,:self.slate_size]
            reg = self.get_regularization(self.PVUserInputMap, self.PVItemInputMap, self.PVOutput,
                                          self.PVPred, self.PRMEncoder, self.PRMOutput)
            # positional embeddings are regularized separately
            reg = reg + torch.mean(pos_emb * pos_emb)
        else:
            if do_explore:
                # exploration: categorical sampling or uniform sampling
                if do_uniform:
                    indices = Categorical(torch.ones_like(rerank_prob)).sample((self.slate_size,)).transpose(0,1)
                else:
                    indices = Categorical(rerank_prob).sample((self.slate_size,)).transpose(0,1)
            else:
                # greedy: topk selection
                _, indices = torch.topk(rerank_prob, k = self.slate_size, dim = 1)
            # (B, K), positions within the initial list
            indices = indices.view(-1,self.slate_size).detach()
            selected_P = torch.gather(rerank_prob,1,indices)
            # map list positions back to candidate positions
            final_action = torch.gather(initial_action,1,indices)
            selected_Y = None
            reg = 0


        return {'prob': selected_P,
                'action': final_action,
                'reward_pred': selected_Y,
                'reg': reg}

    def get_loss_observation(self):
        '''
        @output:
        - names of the loss terms returned by get_loss()
        '''
        return ['loss', 'initial_loss', 'rerank_loss', 'pv_loss']

    def get_loss(self, feed_dict, out_dict):
        '''
        Reward-based pointwise ranking loss for both stages plus the PV label regression loss
        * BCE term: - Ylog(P) - (1-Y)log(1-P)
        * Y = mean(w[i] * r[i]) # the weighted user responses averaged over feedback types
        * PV term: mean((reward_pred - Y)^2)

        @input:
        - feed_dict: same as BaseOnlinePolicy.get_forward@input-feed_dict
        - out_dict: {
            'state': (B,state_dim),
            'initial_prob': (B,C),
            'initial_action': (B,C),
            'prob': (B,K),
            'action': (B,K),
            'reward_pred': (B,K),
            'reg': scalar,
            'immediate_response': (B,K*n_feedback),
            'immediate_response_weight: (n_feedback, ),
            'reward': (B,)}
        @output
        - {'initial_loss': initial ranker loss,
           'rerank_loss': reranker loss,
           'pv_loss': label prediction loss,
           'loss': initial_loss_coef * initial_loss + rerank_loss
                   + prm_pv_loss_coef * pv_loss + l2_coef * reg}
        '''
        B = out_dict['prob'].shape[0]

        # soft labels shared by all loss terms
        # (B,K,n_feedback)
        weighted_response = out_dict['immediate_response'].view(B,self.slate_size,-1) \
                                * out_dict['immediate_response_weight'].view(1,1,-1)
        # (B,K)
        Y = torch.mean(weighted_response, dim = 2)

        if self.train_initial:
            # initial ranker loss
            initial_loss = self.get_reward_bce(out_dict['initial_prob'][:,:self.slate_size], Y)
        else:
            initial_loss = torch.tensor(0)

        if self.train_rerank:
            # reranker loss
            rerank_loss = self.get_reward_bce(out_dict['prob'], Y)
            # pv loss
            pv_loss = torch.mean((out_dict['reward_pred'] - Y).pow(2))
        else:
            rerank_loss, pv_loss = torch.tensor(0), torch.tensor(0)

        # scalar
        loss = self.initial_loss_coef * initial_loss + rerank_loss \
                    + self.prm_pv_loss_coef * pv_loss + self.l2_coef * out_dict['reg']

        # report the initial ranker's own loss under 'initial_loss' (not the combined loss)
        return {'initial_loss': initial_loss,
                'rerank_loss': rerank_loss,
                'pv_loss': pv_loss,
                'loss': loss}





In [None]:
import argparse
import torch
import torch.nn as nn
import numpy as np
from torch.distributions import Categorical


# 1) make an adapter that fixes the `_define_params` signature
class PRMAdapter(PRM):
    '''
    PRM variant whose _define_params accepts (args, reader_stats).
    '''

    def _define_params(self, args, reader_stats):
        """
        Build all parameters in a single method with the two-argument signature.

        The steps mirror, in order:
        - BaseOnlinePolicy._define_params(args, reader_stats)
        - TwoStageOnlinePolicy._define_params(args)
        - PRM._define_params(args)
        """
        # base online policy part: userEncoder, enc_dim, state_dim, action_dim
        BaseOnlinePolicy._define_params(self, args, reader_stats)

        # dropout rate: prefer args.dropout_rate, then args.state_dropout_rate, then 0.1
        fallback_rate = getattr(args, "state_dropout_rate", 0.1)
        dropout = getattr(args, "dropout_rate", fallback_rate)

        list_size = self.initial_list_size
        pv_in_dim = args.prm_pv_input_dim
        prm_dim = args.prm_encoder_enc_dim

        # two-stage part: state -> vector in the item encoding space
        self.stage1State2Z = DNN(self.state_dim, args.stage1_state2z_hidden_dims, self.enc_dim,
                                 dropout_rate=dropout, do_batch_norm=True)
        self.stage1ZNorm = nn.LayerNorm(self.enc_dim)

        # PRM part, personalized vector (PV) input layer
        self.PVUserInputMap = nn.Linear(self.state_dim, pv_in_dim)
        self.PVItemInputMap = nn.Linear(self.enc_dim, pv_in_dim)
        self.PVInputNorm = nn.LayerNorm(pv_in_dim)
        self.PVOutput = DNN(pv_in_dim, args.prm_pv_hidden_dims, prm_dim,
                            dropout_rate=dropout, do_batch_norm=True)
        # label prediction head
        self.PVPred = nn.Linear(prm_dim, 1)
        # positional embedding and the position indices used to look it up
        self.PVPosEmb = nn.Embedding(list_size, prm_dim)
        self.PV_pos_emb_getter = torch.arange(list_size, dtype=torch.long)
        # boolean mask that is True strictly above the diagonal
        lower_triangle = torch.tril(torch.ones((list_size, list_size), dtype=torch.bool))
        self.PV_attn_mask = torch.logical_not(lower_triangle)

        # PRM part, transformer encoder
        layer_config = dict(d_model=prm_dim,
                            dim_feedforward=args.prm_encoder_d_forward,
                            nhead=args.prm_encoder_n_head,
                            dropout=dropout,
                            batch_first=True)
        encoder_layer = nn.TransformerEncoderLayer(**layer_config)
        self.PRMEncoder = nn.TransformerEncoder(encoder_layer, num_layers=args.prm_encoder_n_layer)

        # PRM part, output layer
        self.PRMOutput = nn.Linear(prm_dim, 1)



In [None]:
import torch.nn.functional as F
import torch.nn as nn
import torch


class QCritic(nn.Module):
    '''
    Q-value critic: a DNN that maps the concatenation of state and action to a scalar Q value.
    '''

    @staticmethod
    def parse_model_args(parser):
        '''
        args:
        - critic_hidden_dims
        - critic_dropout_rate
        '''
        parser.add_argument('--critic_hidden_dims', type=int, nargs='+', default=[128],
                            help='hidden layer dimensions of the critic network')
        parser.add_argument('--critic_dropout_rate', type=float, default=0.1,
                            help='dropout rate in deep layers')
        return parser

    def __init__(self, args, environment, policy):
        '''
        @input:
        - args: provides critic_hidden_dims and critic_dropout_rate
        - environment: not used by this constructor
        - policy: provides state_dim and action_dim
        '''
        super().__init__()
        self.state_dim = policy.state_dim
        self.action_dim = policy.action_dim
        self.net = DNN(self.state_dim + self.action_dim, args.critic_hidden_dims, 1,
                       dropout_rate = args.critic_dropout_rate, do_batch_norm = True)

    def forward(self, feed_dict):
        '''
        @input:
        - feed_dict: {'state': (B, state_dim), 'action': (B, action_dim)}
        @output:
        - {'q': (B,), 'reg': regularization term of the critic network}
        '''
        state_emb = feed_dict['state'].view(-1, self.state_dim)
        action_emb = feed_dict['action'].view(-1, self.action_dim)
        # (B, state_dim + action_dim) --> (B, 1) --> (B,)
        Q = self.net(torch.cat((state_emb, action_emb), dim = -1)).view(-1)
        reg = get_regularization(self.net)
        return {'q': Q, 'reg': reg}

In [None]:
import torch
import torch.nn.functional as F
import random
import numpy as np


class BaseBuffer():
    '''
    Fixed-capacity replay buffer. New records overwrite the oldest slots once the
    capacity (buffer_size) is reached; sampling is uniform over the filled slots.
    '''

    @staticmethod
    def parse_model_args(parser):
        '''
        args:
        - buffer_size
        '''
        parser.add_argument('--buffer_size', type=int, default=10000,
                            help='replay buffer size')
        return parser

    def __init__(self, *input_args):
        '''
        @input:
        - input_args: (args, env, policy, critic); only args.buffer_size and args.device are read
        '''
        args, _env, _policy, _critic = input_args
        self.buffer_size = args.buffer_size
        super().__init__()
        self.device = args.device
        # next slot to write, number of filled slots, total number of records seen
        self.buffer_head = 0
        self.current_buffer_size = 0
        self.n_stream_record = 0

    def reset(self, *reset_args):
        '''
        Allocate an empty buffer of buffer_size (= L) slots.

        @input:
        - reset_args: (env, actor, ...); env provides create_observation_buffer, slate_size and
                      response_dim, actor provides state_dim and action_dim
        @output:
        - buffer: {'observation': {'user_profile': {'user_id': (L,),
                                                    'uf_{feature_name}': (L, feature_dim)},
                                   'user_history': {'history': (L, max_H),
                                                    'history_if_{feature_name}': (L, max_H * feature_dim),
                                                    'history_{response}': (L, max_H),
                                                    'history_length': (L,)}}
                   'policy_output': {'state': (L, state_dim),
                                     'action': (L, action_dim),
                                     'prob': (L, slate_size)},
                   'next_observation': same format as @output-buffer['observation'],
                   'done_mask': (L,),
                   'user_response': {'reward': (L,),
                                     'immediate_response':, (L, slate_size * response_dim)}}
        '''
        env, actor = reset_args[0], reset_args[1]
        n_slot = self.buffer_size

        def blank(dtype, *shape):
            # zero tensor of the requested dtype on the buffer's device
            return torch.zeros(*shape).to(dtype).to(self.device)

        observation = env.create_observation_buffer(n_slot)
        next_observation = env.create_observation_buffer(n_slot)
        policy_output = {'state': blank(torch.float, n_slot, actor.state_dim),
                         'action': blank(torch.long, n_slot, actor.action_dim),
                         'prob': blank(torch.float, n_slot, env.slate_size)}
        reward = blank(torch.float, n_slot)
        done = blank(torch.bool, n_slot)
        im_response = blank(torch.float, n_slot, env.response_dim * env.slate_size)
        self.buffer = {'observation': observation,
                       'policy_output': policy_output,
                       'user_response': {'reward': reward, 'immediate_response': im_response},
                       'done_mask': done,
                       'next_observation': next_observation}
        return self.buffer


    def sample(self, batch_size):
        '''
        Draw batch_size (= B) slots uniformly with replacement from the filled part of the buffer.

        Buffer: see reset@output
        @output: tuple of (observation, policy_output, user_response, done_mask, next_observation)
        - observation: {'user_profile': {'user_id': (B,),
                                         'uf_{feature_name}': (B, feature_dim)},
                        'user_history': {'history': (B, max_H),
                                         'history_if_{feature_name}': (B, max_H * feature_dim),
                                         'history_{response}': (B, max_H),
                                         'history_length': (B,)}}
        - policy_output: {'state': (B, state_dim),
                          'action': (B, action_dim),
                          'prob': (B, slate_size)},
        - user_response: {'reward': (B,),
                          'immediate_response':, (B, slate_size * response_dim)}}
        - done_mask: (B,),
        - next_observation: same format as @output - observation,
        '''
        # get indices
        indices = np.random.randint(0, self.current_buffer_size, size = batch_size)

        def pick(group):
            # select the sampled rows of every field in a feature group
            return {name: values[indices] for name, values in group.items()}

        stored = self.buffer
        # observation
        observation = {"user_profile": pick(stored["observation"]["user_profile"]),
                       "user_history": pick(stored["observation"]["user_history"])}
        # next observation
        next_observation = {"user_profile": pick(stored["next_observation"]["user_profile"]),
                            "user_history": pick(stored["next_observation"]["user_history"])}
        # policy output
        policy_output = {name: stored["policy_output"][name][indices]
                         for name in ("state", "action", "prob")}
        # user response
        user_response = {name: stored["user_response"][name][indices]
                         for name in ("reward", "immediate_response")}
        # done mask
        done_mask = stored["done_mask"][indices]
        return observation, policy_output, user_response, done_mask, next_observation

    def update(self, observation, policy_output, user_feedback, next_observation):
        '''
        Write a batch of B records starting at the current head, wrapping around to slot 0
        when the end of the buffer is reached.

        @input:
        - observation: {'user_profile': {'user_id': (B,),
                                         'uf_{feature_name}': (B, feature_dim)},
                        'user_history': {'history': (B, max_H),
                                         'history_if_{feature_name}': (B, max_H * feature_dim),
                                         'history_{response}': (B, max_H),
                                         'history_length': (B,)}}
        - policy_output: {'state': (B, state_dim),
                          'prob': (B, slate_size),
                          'action': (B, action_dim)}
        - user_feedback: {'done': (B,),
                          'immediate_response':, (B, action_dim * feedback_dim),
                          'reward': (B,)}
        - next_observation: same format as update@input-observation
        '''
        # get buffer indices to update
        B = len(user_feedback['reward'])
        end = self.buffer_head + B
        if end >= self.buffer_size:
            # fill up to the end of the buffer, then continue from slot 0
            tail = self.buffer_size - self.buffer_head
            slots = list(range(self.buffer_head, self.buffer_size)) + list(range(B - tail))
        else:
            slots = list(range(self.buffer_head, end))
        indices = torch.tensor(slots).to(torch.long).to(self.device)

        # update buffer - observation, then next observation
        for side, incoming in (('observation', observation),
                               ('next_observation', next_observation)):
            for group in ('user_profile', 'user_history'):
                target = self.buffer[side][group]
                for k,v in incoming[group].items():
                    target[k][indices] = v
        # update buffer - policy output
        for name in ('state', 'action', 'prob'):
            self.buffer['policy_output'][name][indices] = policy_output[name]
        # update buffer - user response
        response_buffer = self.buffer['user_response']
        response_buffer['immediate_response'][indices] = user_feedback['immediate_response'].view(B,-1)
        response_buffer['reward'][indices] = user_feedback['reward']
        # update buffer - done
        self.buffer['done_mask'][indices] = user_feedback['done']

        # buffer pointer
        self.buffer_head = end % self.buffer_size
        self.n_stream_record += B
        self.current_buffer_size = min(self.n_stream_record, self.buffer_size)


In [None]:
import torch
import torch.nn.functional as F
import random
import numpy as np


class HyperActorBuffer(BaseBuffer):
    '''
    Replay buffer for policies that emit both a hyper-action and an effect-action.

    Differences to BaseBuffer that are visible in this class:
    - buffer['policy_output']['action'] stores the policy's 'hyper_action' (float)
    - buffer['policy_output']['effect_action'] stores the policy's 'effect_action' (long)
    - buffer['user_response']['immediate_response'] is sized by the effect-action width
    '''

    @staticmethod
    def parse_model_args(parser):
        '''
        Register command-line arguments for the buffer.

        args:
        - nothing of its own; the parser is handed to BaseBuffer.parse_model_args
          and its result is returned
        '''
        return BaseBuffer.parse_model_args(parser)

    def reset(self, *reset_args):
        '''
        Rebuild the buffer via BaseBuffer.reset and re-allocate the tensors whose
        width depends on the hyper-actor.

        @input:
        - reset_args: (env, actor, ...); only the first two entries are read
            - env.response_dim
            - actor.effect_action_dim, actor.hyper_action_dim
        @output:
        - buffer: self.buffer, where (L = self.buffer_size)
            {'policy_output': {'action': (L, hyper_action_dim), float
                               'effect_action': (L, effect_action_dim), long
                               ...},
             'user_response': {'immediate_response': (L, response_dim * effect_action_dim), float
                               ...},
             ... every other entry is whatever BaseBuffer.reset created}
        '''
        env, actor = reset_args[0], reset_args[1]
        super().reset(env, actor)

        def blank(width, dtype):
            # all-zero storage with one row per buffer slot, moved to the buffer device
            return torch.zeros(self.buffer_size, width).to(dtype).to(self.device)

        self.buffer['user_response']['immediate_response'] = blank(env.response_dim * actor.effect_action_dim,
                                                                   torch.float)
        self.buffer['policy_output']['action'] = blank(actor.hyper_action_dim, torch.float)
        self.buffer['policy_output']['effect_action'] = blank(actor.effect_action_dim, torch.long)
        return self.buffer

    def sample(self, batch_size):
        '''
        Draw batch_size stored transitions uniformly at random (with replacement)
        from the first self.current_buffer_size slots.

        @output: tuple (observation, policy_output, user_response, done_mask, next_observation)
        - observation / next_observation: {'user_profile': {...}, 'user_history': {...}}
        - policy_output: {'state', 'action', 'hyper_action', 'effect_action'};
                         'action' and 'hyper_action' are both read from the stored 'action' tensor
        - user_response: {'reward', 'immediate_response'}
        - done_mask: rows of buffer['done_mask']
        '''
        indices = np.random.randint(0, self.current_buffer_size, size = batch_size)

        def pick(table):
            # gather the sampled rows from every tensor of a {name: tensor} table
            return {name: stored[indices] for name, stored in table.items()}

        # observation
        stored_obs = self.buffer["observation"]
        observation = {"user_profile": pick(stored_obs["user_profile"]),
                       "user_history": pick(stored_obs["user_history"])}
        # next observation
        stored_next = self.buffer["next_observation"]
        next_observation = {"user_profile": pick(stored_next["user_profile"]),
                            "user_history": pick(stored_next["user_history"])}
        # policy output (the extra 'effect_action' entry is the main change to BaseBuffer)
        stored_policy = self.buffer["policy_output"]
        policy_output = {"state": stored_policy["state"][indices],
                         "action": stored_policy["action"][indices],
                         "hyper_action": stored_policy["action"][indices],
                         "effect_action": stored_policy["effect_action"][indices]}
        # user response
        stored_response = self.buffer["user_response"]
        user_response = {"reward": stored_response["reward"][indices],
                         "immediate_response": stored_response["immediate_response"][indices]}
        # done mask
        done_mask = self.buffer["done_mask"][indices]
        return observation, policy_output, user_response, done_mask, next_observation

    def update(self, observation, policy_output, user_feedback, next_observation):
        '''
        Write a batch of B transitions into the ring buffer, starting at
        self.buffer_head and wrapping around at self.buffer_size.

        @input:
        - observation: {'user_profile': {name: tensor}, 'user_history': {name: tensor}}
        - policy_output: keys read here are 'state', 'hyper_action' and 'effect_action'
        - user_feedback: keys read here are 'reward' (its length defines B),
                         'immediate_response' (flattened to (B, -1)) and 'done'
        - next_observation: same format as @input-observation
        '''
        # get buffer indices to update
        B = len(user_feedback['reward'])
        stop = self.buffer_head + B
        if stop >= self.buffer_size:
            # fill up to the end of the buffer, then continue from slot 0
            slots = list(range(self.buffer_head, self.buffer_size)) + list(range(stop - self.buffer_size))
        else:
            slots = list(range(self.buffer_head, stop))
        indices = torch.tensor(slots).to(torch.long).to(self.device)

        # update buffer - observation and next observation
        for section, source in (('observation', observation), ('next_observation', next_observation)):
            for part in ('user_profile', 'user_history'):
                for k, v in source[part].items():
                    self.buffer[section][part][k][indices] = v

        # update buffer - policy output
        stored_policy = self.buffer['policy_output']
        stored_policy['state'][indices] = policy_output['state']
        stored_policy['action'][indices] = policy_output['hyper_action']
        stored_policy['effect_action'][indices] = policy_output['effect_action'] # main change to BaseBuffer

        # update buffer - user response
        stored_response = self.buffer['user_response']
        stored_response['immediate_response'][indices] = user_feedback['immediate_response'].view(B,-1)
        stored_response['reward'][indices] = user_feedback['reward']

        # update buffer - done
        self.buffer['done_mask'][indices] = user_feedback['done']

        # buffer pointer
        self.buffer_head = stop % self.buffer_size
        self.n_stream_record += B
        self.current_buffer_size = min(self.n_stream_record, self.buffer_size)



In [None]:
import torch

def get_retention_reward(user_feedback, reward_base = 0.7):
    '''
    Reward equal to the negated retention signal divided by 10.

    @input:
    - user_feedback: {'retention': (B,), ...}
    - reward_base: accepted but not used by the computation
    @output:
    - reward: (B,)
    '''
    retention = user_feedback['retention']
    return -retention / 10.0

def get_immediate_reward(user_feedback):
    '''
    Average per-item reward of a slate.

    @input:
    - user_feedback: {'immediate_response': (B, slate_size, n_feedback),
                      'immediate_response_weight': (n_feedback), optional
                      ... other feedbacks}
    @output:
    - reward: (B,)
    '''
    # (B, slate_size, n_feedback); weights are applied only when they are provided
    response = user_feedback['immediate_response']
    if 'immediate_response_weight' in user_feedback:
        response = response * user_feedback['immediate_response_weight'].view(1,1,-1)
    # sum over feedback types -> (B, slate_size), then average over the slate -> (B,)
    per_item = torch.sum(response, dim = 2)
    return torch.mean(per_item, dim = 1)

def get_immediate_reward_sum(user_feedback):
    '''
    Total reward of a slate (sum over items instead of the mean).

    @input:
    - user_feedback: {'immediate_response': (B, slate_size, n_feedback),
                      'immediate_response_weight': (n_feedback), optional
                      ... other feedbacks}
    @output:
    - reward: (B,)
    '''
    # (B, slate_size, n_feedback); weights are applied only when they are provided
    response = user_feedback['immediate_response']
    if 'immediate_response_weight' in user_feedback:
        response = response * user_feedback['immediate_response_weight'].view(1,1,-1)
    # sum over feedback types -> (B, slate_size), then sum over the slate -> (B,)
    per_item = torch.sum(response, dim = 2)
    return torch.sum(per_item, dim = 1)


def sum_with_cost(feedback, zero_reward_cost = 0.1):
    '''
    Row-wise sum of the feedback where every exactly-zero entry is charged
    -zero_reward_cost instead of contributing 0.

    @input:
    - feedback: (B, K)
    - zero_reward_cost: penalty magnitude for each zero entry
    @output:
    - reward: (B,)
    '''
    _n_rows, _n_cols = feedback.shape  # unpacking fails unless feedback is 2-D
    zero_mask = feedback == 0
    penalty = torch.zeros_like(feedback)
    penalty[zero_mask] = -zero_reward_cost
    return torch.sum(feedback + penalty, dim = -1)


def sigmoid_sum_with_cost(feedback, zero_reward_cost = 0.1):
    '''
    sum_with_cost() squashed through a sigmoid.

    @input:
    - feedback: (B, K)
    - zero_reward_cost: forwarded to sum_with_cost
    @output:
    - reward: (B,)
    '''
    return torch.sigmoid(sum_with_cost(feedback, zero_reward_cost))


def log_sum_with_cost(feedback, zero_reward_cost = 0.1):
    '''
    sum_with_cost() where positive totals are replaced by log(total + 1),
    followed by a sigmoid over all totals.

    @input:
    - feedback: (B, K)
    - zero_reward_cost: forwarded to sum_with_cost
    @output:
    - reward: (B,)
    '''
    total = sum_with_cost(feedback, zero_reward_cost)
    positive = total > 0
    # only positive totals are compressed; non-positive ones pass through unchanged
    total[positive] = torch.log(total[positive] + 1)
    return torch.sigmoid(total)

def mean_with_cost(feedback_dict, zero_reward_cost = 0.1):
    '''
    Row-wise mean of the feedback where every exactly-zero entry is charged
    -zero_reward_cost instead of contributing 0.

    @input:
    - feedback_dict: (B, K) feedback tensor; despite the parameter name it is used
                     directly as a tensor, like `feedback` in sum_with_cost
    - zero_reward_cost: penalty magnitude for each zero entry
    @output:
    - reward: (B,)
    '''
    # The body used to read an undefined name `feedback`, so every call raised
    # NameError; the parameter name is kept so keyword callers keep working.
    feedback = feedback_dict
    _n_rows, _n_cols = feedback.shape  # unpacking fails unless feedback is 2-D
    cost = torch.zeros_like(feedback)
    cost[feedback == 0] = -zero_reward_cost
    reward = torch.mean(feedback + cost, dim = -1)
    return reward

def mean_advance_with_cost(feedback, zero_reward_cost = 0.1, offset = 0.5):
    '''
    Row-wise mean of the feedback with a penalty on zero entries, shifted down
    by a constant offset.

    @input:
    - feedback: (B, K)
    - zero_reward_cost: penalty magnitude for each zero entry
    - offset: constant subtracted from every row mean
    @output:
    - reward: (B,)
    '''
    _n_rows, _n_cols = feedback.shape  # unpacking fails unless feedback is 2-D
    zero_mask = feedback == 0
    penalty = torch.zeros_like(feedback)
    penalty[zero_mask] = -zero_reward_cost
    return torch.mean(feedback + penalty, dim = -1) - offset

In [None]:
import time
import copy
import numpy as np
import torch
import torch.nn.functional as F
from copy import deepcopy
from tqdm import tqdm



class BaseRLAgent():
    '''
    RL Agent controls the overall learning algorithm:
    - objective functions for the policies and critics
    - design of reward function
    - how many steps to train
    - how to do exploration
    - loading and saving of models

    Main interfaces:
    - train
    '''

    @staticmethod
    def parse_model_args(parser):
        '''
        args:
        - gamma
        - reward_func
        - n_iter
        - train_every_n_step
        - start_policy_train_at_step
        - initial_epsilon
        - final_epsilon
        - elbow_epsilon
        - explore_rate
        - do_explore_in_train
        - check_episode
        - save_episode
        - save_path
        - actor_lr
        - actor_decay
        - batch_size
        '''
        # basic settings
        parser.add_argument('--gamma', type=float, default=0.95,
                            help='reward discount')
        parser.add_argument('--reward_func', type=str, default='get_retention_reward',
                            help='reward function name')
        parser.add_argument('--n_iter', type=int, nargs='+', default=[2000],
                            help='number of training iterations')
        parser.add_argument('--train_every_n_step', type=int, default=1,
                            help='number of training iterations')
        parser.add_argument('--start_policy_train_at_step', type=int, default=1000,
                            help='start timestamp for buffer sampling')

        # exploration control
        parser.add_argument('--initial_epsilon', type=float, default=0.5,
                            help='probability for using uniform exploration')
        parser.add_argument('--final_epsilon', type=float, default=0.01,
                            help='probability for using uniform exploration')
        parser.add_argument('--elbow_epsilon', type=float, default=1.0,
                            help='probability for using uniform exploration')
        parser.add_argument('--explore_rate', type=float, default=1.0,
                            help='probability of engaging exploration')
        parser.add_argument('--do_explore_in_train', action='store_true',
                            help='probability of engaging exploration')

        # monitoring
        parser.add_argument('--check_episode', type=int, default=100,
                            help='number of iterations to check output and evaluate')
        parser.add_argument('--save_episode', type=int, default=1000,
                            help='number of iterations to save models')
        parser.add_argument('--save_path', type=str, required=True,
                            help='save path for networks')

        # learning
        parser.add_argument('--actor_lr', type=float, default=1e-4,
                            help='learning rate for actor')
        parser.add_argument('--actor_decay', type=float, default=1e-4,
                            help='regularization factor for actor learning')
        parser.add_argument('--batch_size', type=int, default=64,
                            help='training batch size')

        return parser

    def __init__(self, *input_args):
        '''
        Store hyperparameters and components, then build the exploration
        scheduler and the actor optimizer.

        @input:
        - input_args: (args, env, actor, buffer)
            - args: parsed arguments (see parse_model_args) plus args.device
            - env, actor, buffer: kept as self.env, self.actor, self.buffer

        Side effect: when args.n_iter has exactly one entry, the file
        save_path + ".report" is (re)created containing a single space.
        '''
        args, env, actor, buffer = input_args

        self.device = args.device

        # hyperparameters
        self.gamma = args.gamma
        # NOTE(review): eval() executes whatever expression --reward_func contains;
        # only pass trusted values (normally the name of a reward function in this file).
        self.reward_func = eval(args.reward_func)
        # remaining settings are copied one-to-one from the parsed arguments
        for name in ('n_iter', 'train_every_n_step', 'start_policy_train_at_step',
                     'initial_epsilon', 'final_epsilon', 'elbow_epsilon',
                     'explore_rate', 'do_explore_in_train',
                     'check_episode', 'save_episode', 'save_path',
                     'actor_lr', 'actor_decay', 'batch_size'):
            setattr(self, name, getattr(args, name))

        # components
        self.env, self.actor, self.buffer = env, actor, buffer

        # controller
        schedule_steps = int(sum(args.n_iter) * args.elbow_epsilon)
        self.exploration_scheduler = LinearScheduler(schedule_steps, args.final_epsilon,
                                                     initial_p=args.initial_epsilon)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(),
                                                lr=args.actor_lr,
                                                weight_decay=args.actor_decay)

        # register modules that will be saved: (model, optimizer, file suffix)
        self.registered_models = [(self.actor, self.actor_optimizer, '_actor')]

        # a single-stage run starts with a fresh report file
        if len(self.n_iter) == 1:
            with open(self.save_path + ".report", 'w') as outfile:
                outfile.write(" ")

    def train(self):
        '''
        Run the online training loop.

        @process:
        - optionally reload saved models (see the note on the condition below)
        - action_before_train(): environment/buffer reset and warm-up steps
        - for every step i: one run_episode_step(), a step_train() when i is a
          multiple of train_every_n_step, a progress report every check_episode
          steps (printed and appended to save_path + ".report"), and save()
          every save_episode steps
        - action_after_train()
        '''
        # NOTE(review): __init__ treats len(n_iter) == 1 as a fresh run, yet models are
        # only reloaded for more than two stages -- confirm that "> 2" is intended.
        if len(self.n_iter) > 2:
            self.load()

        print("Run procedures before training")
        self.action_before_train()
        start_time = time.time()
        t = start_time  # time of the previous progress report

        # training
        print("Training:")
        step_offset = sum(self.n_iter[:-1])
        do_buffer_update = True
        observation = deepcopy(self.env.current_observation)
        # NOTE(review): only n_iter[-1] // 10 steps are executed -- confirm the divisor is intended.
        for i in tqdm(range(step_offset, step_offset + self.n_iter[-1]//10)):
            do_explore = np.random.random() < self.explore_rate if self.explore_rate < 1 else True
            # online inference
            observation = self.run_episode_step(i, self.exploration_scheduler.value(i), observation,
                                                do_buffer_update, do_explore)
            # online training
            if i % self.train_every_n_step == 0:
                self.step_train()
            # log monitor records
            if i > 0 and i % self.check_episode == 0:
                t_prime = time.time()
                # total time is measured up to now (it used to stop at the previous
                # report), and the stray ")" at the end of the message is gone
                print(f"Episode step {i}, time diff {t_prime - t}, total time diff {t_prime - start_time}")
                episode_report, train_report = self.get_report(smoothness = self.check_episode)
                log_str = f"step: {i} @ online episode: {episode_report} @ training: {train_report}\n"
                with open(self.save_path + ".report", 'a') as outfile:
                    outfile.write(log_str)
                print(log_str)
                t = t_prime

            # save model and training info
            if i % self.save_episode == 0:
                self.save()

        self.action_after_train()


    def action_before_train(self):
        '''
        Preparation executed once before the training loop:
        - env.reset() and buffer.reset(env, actor)
        - setup_monitors() creates training_history / eval_history
        - start_policy_train_at_step interaction steps with epsilon = 1.0 and
          buffer updates enabled, to build up the initial buffer
        '''
        observation = self.env.reset()
        self.buffer.reset(self.env, self.actor)

        # training monitors
        self.setup_monitors()

        warmup_iter = 0        # zero training iteration
        warmup_epsilon = 1.0   # uniform random explore before training
        prepare_step = 0
        for prepare_step, _ in enumerate(tqdm(range(self.start_policy_train_at_step)), start = 1):
            explore_now = np.random.random() < self.explore_rate
            observation = self.run_episode_step(warmup_iter, warmup_epsilon, observation,
                                                True, explore_now)
        print(f"Total {prepare_step} prepare steps")

    def setup_monitors(self):
        self.training_history = {'actor_loss': []}
        self.eval_history = {'avg_reward': [],
                             'reward_variance': [],
                             'avg_total_reward': [0.],
                             'max_total_reward': [0.],
                             'min_total_reward': [0.]}
        self.eval_history.update({f'{resp}_rate': [] for resp in self.env.response_types})
        self.current_sum_reward = torch.zeros(self.env.episode_batch_size).to(torch.float).to(self.device)


    def action_after_train(self):
        '''
        Hook called at the end of train(); this implementation only calls self.env.stop().
        '''
        self.env.stop()

    def get_report(self, smoothness = 10):
        episode_report = self.env.get_report(smoothness)
        train_report = {k: np.mean(v[-smoothness:]) for k,v in self.training_history.items()}
        train_report.update({k: np.mean(v[-smoothness:]) for k,v in self.eval_history.items()})
        return episode_report, train_report

    def run_episode_step(self, *episode_args):
        '''
        Run one step of user-env interaction
        @input:
        - episode_args: (episode_iter, epsilon, observation, do_buffer_update, do_explore)
        @process:
        - apply_policy: observation, candidate items --> policy_output
        - env.step(): policy_output['action'] --> user_feedback, updated_observation
        - reward_func(): user_feedback --> reward
        - buffer.update(observation, policy_output, user_feedback, updated_observation)
        @output:
        - next_observation
        '''
        episode_iter, epsilon, observation, do_buffer_update, do_explore = episode_args
        self.epsilon = epsilon
        is_train = False
        with torch.no_grad():
            # generate action from policy
            policy_output = self.apply_policy(observation, self.actor, epsilon, do_explore, is_train)

            # apply action on environment
            # Note: action must be indices on env.candidate_iids
            action_dict = {'action': policy_output['indices']}
            new_observation, user_feedback, update_info = self.env.step(action_dict)

            # calculate reward
            R = self.get_reward(user_feedback)
            user_feedback['reward'] = R
            self.current_sum_reward = self.current_sum_reward + R
            done_mask = user_feedback['done']
            if torch.sum(done_mask) > 0:
                self.eval_history['avg_total_reward'].append(self.current_sum_reward[done_mask].mean().item())
                self.eval_history['max_total_reward'].append(self.current_sum_reward[done_mask].max().item())
                self.eval_history['min_total_reward'].append(self.current_sum_reward[done_mask].min().item())
                self.current_sum_reward[done_mask] = 0

            # monitor update
            self.eval_history['avg_reward'].append(R.mean().item())
            self.eval_history['reward_variance'].append(torch.var(R).item())

            for i,resp in enumerate(self.env.response_types):
                self.eval_history[f'{resp}_rate'].append(user_feedback['immediate_response'][:,:,i].mean().item())
            # update replay buffer
            if do_buffer_update:
                self.buffer.update(observation, policy_output, user_feedback, update_info['updated_observation'])
        return new_observation

    def apply_policy(self, observation, actor, *input_args):
        '''
        @input:
        - observation:{'user_profile':{
                           'user_id': (B,)
                           'uf_{feature_name}': (B,feature_dim), the user features}
                       'user_history':{
                           'history': (B,max_H)
                           'history_if_{feature_name}': (B,max_H,feature_dim), the history item features}
        - actor: the actor model
        - epsilon: scalar
        - do_explore: boolean
        - is_train: boolean
        @output:
        - policy_output
        '''
        epsilon = policy_args[0]
        do_explore = policy_args[1]
        is_train = policy_args[2]
        input_dict = {'observation': observation,
                      'candidates': self.env.get_candidate_info(observation),
                      'epsilon': epsilon,
                      'do_explore': do_explore,
                      'is_train': is_train,
                      'batch_wise': False}
        out_dict = self.actor(input_dict)
        return out_dict

    def get_reward(self, user_feedback):
        user_feedback['immediate_response_weight'] = self.env.response_weights
        R = self.reward_func(user_feedback).detach()
        return R

    def step_train(self):
        '''
        @process:
        '''
        observation, policy_output, user_feedback, done_mask, next_observation = self.buffer.sample(self.batch_size)

        loss_dict = self.get_loss(observation, policy_output, user_feedback, done_mask, next_observation)

        for k in loss_dict:
            if k in self.training_history:
                try:
                    self.training_history[k].append(loss_dict[k].item())
                except:
                    self.training_history[k].append(loss_dict[k])

    def get_loss(self, observation, policy_output, user_feedback, done_mask, next_observation):
        '''
        Placeholder for the training objective; this base implementation does
        nothing and returns None.

        step_train() iterates over the returned value and reads it as a
        {name: loss} mapping, so an overriding implementation is expected to
        return such a mapping.
        '''
        pass

    def test(self):
        '''
        Placeholder; this base implementation does nothing.
        '''
        pass

    def save(self):
        for model, opt, prefix in self.registered_models:
            torch.save(model.state_dict(), self.save_path + prefix)
            torch.save(opt.state_dict(), self.save_path + prefix + "_optimizer")

    def load(self):
        for model, opt, prefix in self.registered_models:
            model.load_state_dict(torch.load(self.save_path + prefix, map_location = self.device))
            opt.load_state_dict(torch.load(self.save_path + prefix + "_optimizer", map_location = self.device))


In [None]:
import time
import copy
import torch
import torch.nn.functional as F
import numpy as np


class TD3(BaseRLAgent):
    @staticmethod
    def parse_model_args(parser):
        '''
        Register TD3's command-line arguments and return the parser.

        args:
        - critic_lr
        - critic_decay
        - target_mitigate_coef
        - everything registered by BaseRLAgent.parse_model_args
          (batch_size, actor_lr and actor_decay come from there)

        NOTE(review): __init__ also reads args.episode_batch_size, which is not
        registered here -- presumably another component adds it; verify.
        '''
        parser = BaseRLAgent.parse_model_args(parser)
        # help text used to say 'decay rate for critic'; the value is passed as the
        # critics' Adam learning rate
        parser.add_argument('--critic_lr', type=float, default=1e-4,
                            help='learning rate for critic')
        parser.add_argument('--critic_decay', type=float, default=1e-4,
                            help='decay rate for critic')
        parser.add_argument('--target_mitigate_coef', type=float, default=0.01,
                            help='mitigation factor')

        return parser


    def __init__(self, *input_args):
        '''
        Set up the actor, the two critics, their target copies and optimizers.

        @input:
        - input_args: (args, env, actor, critic, buffer)
            - critic: indexable; critic[0] and critic[1] become critic1 and critic2
        Attributes created here (in addition to those of BaseRLAgent.__init__):
        - episode_batch_size, critic_lr, critic_decay, tau
        - actor_target, critic1_target, critic2_target: deep copies of the online models
        - actor_optimizer (re-created), critic1_optimizer, critic2_optimizer

        Side effect: when n_iter has exactly one entry, save_path + ".report"
        is overwritten with the printed args.
        '''
        args, env, actor, critic, buffer = input_args
        super().__init__(args, env, actor, buffer)

        self.episode_batch_size = args.episode_batch_size
        self.batch_size = args.batch_size
        self.actor_lr, self.critic_lr = args.actor_lr, args.critic_lr
        self.actor_decay, self.critic_decay = args.actor_decay, args.critic_decay

        def critic_optimizer(model):
            # both critics share the same Adam settings
            return torch.optim.Adam(model.parameters(), lr=args.critic_lr,
                                    weight_decay=args.critic_decay)

        # actor: frozen target copy and optimizer
        self.actor_target = copy.deepcopy(self.actor)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(),
                                                lr=args.actor_lr,
                                                weight_decay=args.actor_decay)

        # twin critics: frozen target copies and optimizers
        self.critic = critic
        self.critic1 = self.critic[0]
        self.critic1_target = copy.deepcopy(self.critic1)
        self.critic1_optimizer = critic_optimizer(self.critic1)

        self.critic2 = self.critic[1]
        self.critic2_target = copy.deepcopy(self.critic2)
        self.critic2_optimizer = critic_optimizer(self.critic2)

        # interpolation weight of the soft target update
        self.tau = args.target_mitigate_coef

        # a single-stage run starts its report with the full argument list
        if len(self.n_iter) == 1:
            with open(self.save_path + ".report", 'w') as outfile:
                outfile.write(f"{args}\n")

    def action_before_train(self):
        '''
        Action before training:
        - run BaseRLAgent.action_before_train()
        - replace training_history with empty records for
          actor_loss, critic1_loss, critic2_loss, Q and next_Q
        '''
        super().action_before_train()

        # training records
        history = {}
        for name in ('actor_loss', 'critic1_loss', 'critic2_loss', 'Q', 'next_Q'):
            history[name] = []
        self.training_history = history



    def step_train(self):
        '''
        One TD3 training step.

        @process:
        - sample a batch from the buffer
        - get_td3_loss() updates the critics and the actor and returns the losses
        - losses are appended to training_history
        - soft-update the targets: target <- tau * online + (1 - tau) * target
        @output:
        - {"step_loss": (actor_loss, critic1_loss, critic2_loss)} of this step

        NOTE(review): critic_loss[0] / critic_loss[1] assume get_td3_loss returned
        two critic losses; it returns an empty list when critic_lr <= 0.
        '''
        batch = self.buffer.sample(self.batch_size)
        observation, policy_output, user_feedback, done_mask, next_observation = batch
        reward = user_feedback['reward'].view(-1)

        critic_loss, actor_loss = self.get_td3_loss(observation, policy_output, reward,
                                                    done_mask, next_observation)
        history = self.training_history
        history['actor_loss'].append(actor_loss.item())
        history['critic1_loss'].append(critic_loss[0])
        history['critic2_loss'].append(critic_loss[1])

        # Update the frozen target models
        model_pairs = ((self.critic1, self.critic1_target),
                       (self.critic2, self.critic2_target),
                       (self.actor, self.actor_target))
        for online, target in model_pairs:
            for param, target_param in zip(online.parameters(), target.parameters()):
                target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

        return {"step_loss": (history['actor_loss'][-1],
                              history['critic1_loss'][-1],
                              history['critic2_loss'][-1])}


    def get_td3_loss(self, observation, policy_output, reward, done_mask, next_observation,
                     do_actor_update = True, do_critic_update = True):
        '''
        Compute the TD3 losses and, unless disabled, apply the optimizer steps.

        @input:
        - observation: {'user_profile': {'user_id': (B,),
                                         'uf_{feature_name}': (B, feature_dim)},
                        'user_history': {'history': (B, max_H),
                                         'history_if_{feature_name}': (B, max_H, feature_dim),
                                         'history_{response}': (B, max_H),
                                         'history_length': (B, )}}
        - policy_output: the sampled policy output; apply_critic reads its
                         'state' and 'hyper_action' entries
        - reward: (B,)
        - done_mask: (B,)
        - next_observation: the same format as @input-observation
        - do_actor_update / do_critic_update: skip the respective optimizer steps when False
        @output:
        - critic_loss_list: [critic1 loss, critic2 loss] as floats; empty when the
                            critic update is skipped or critic_lr <= 0
        - actor_loss: scalar tensor, -mean(Q1(observation, actor(observation)))
        '''
        # Target Q: the smaller of the two target critics on the target actor's next action
        next_policy_output = self.apply_policy(next_observation, self.actor_target,
                                               self.epsilon, do_explore = True)
        next_q1 = self.apply_critic(next_observation, next_policy_output, self.critic1_target)
        next_q2 = self.apply_critic(next_observation, next_policy_output, self.critic2_target)
        min_next_q = torch.min(next_q1['q'], next_q2['q'])
        # r+gamma*Q' when done; r+Q when not done
        bootstrap_weight = (self.gamma * done_mask) + torch.logical_not(done_mask)
        target_Q = reward + bootstrap_weight * min_next_q.detach()

        critic_loss_list = []
        if do_critic_update and self.critic_lr > 0:
            critic_steps = ((self.critic1, self.critic1_optimizer),
                            (self.critic2, self.critic2_optimizer))
            for critic, optimizer in critic_steps:
                # Q estimate of the stored (state, action) pair
                stored_output = wrap_batch(policy_output, device = self.device)
                current_Q = self.apply_critic(observation, stored_output, critic)['q']
                # regression onto the shared target
                critic_loss = F.mse_loss(current_Q, target_Q).mean()
                critic_loss_list.append(critic_loss.item())

                # Optimize the critic
                optimizer.zero_grad()
                critic_loss.backward()
                optimizer.step()

        # Actor loss: maximize critic1's value of the current policy's action
        fresh_policy_output = self.apply_policy(observation, self.actor)
        fresh_q = self.apply_critic(observation, fresh_policy_output, self.critic1)
        actor_loss = -fresh_q['q'].mean()

        if do_actor_update and self.actor_lr > 0:
            # Optimize the actor
            self.actor_optimizer.zero_grad()
            actor_loss.backward()
            self.actor_optimizer.step()

        return critic_loss_list, actor_loss


    def apply_policy(self, observation, policy_model, epsilon = 0,
                     do_explore = False, do_softmax = True):
        '''
        Run a policy model on an observation and return its output.

        @input:
        - observation: input of policy model; also handed to
          self.env.get_candidate_info() to obtain the 'candidates' entry
        - policy_model: callable that receives a single input dict
        - epsilon: greedy epsilon, forwarded to the policy model
          (effective only when do_explore == True)
        - do_explore: exploration flag, True if adding noise to action
        - do_softmax: kept for interface compatibility; this method does
          not read it
        @output:
        - out_dict: whatever policy_model returns for the input dict
        '''
        input_dict = {
            'observation': observation,
            'candidates': self.env.get_candidate_info(observation),
            'epsilon': epsilon,
            'do_explore': do_explore,
            # the policy is always invoked with the training flag set
            'is_train': True,
            'batch_wise': False,
        }
        return policy_model(input_dict)

    def apply_critic(self, observation, policy_output, critic_model):
        '''
        Evaluate a critic on the state/action pair held in a policy output.

        @input:
        - observation: not read by this method
        - policy_output: mapping providing 'state' and 'hyper_action'
        - critic_model: callable that receives {'state': ..., 'action': ...}
        @output:
        - the value returned by critic_model
        '''
        critic_input = dict(state = policy_output['state'],
                            action = policy_output['hyper_action'])
        return critic_model(critic_input)

    def save(self):
        '''
        Write the state dicts of critic1, critic2 and the actor, each
        followed by its optimizer's, to files named self.save_path + suffix.
        '''
        # (component, file suffix) pairs, in the order they are written
        checkpoints = [
            (self.critic1, "_critic1"),
            (self.critic1_optimizer, "_critic1_optimizer"),
            (self.critic2, "_critic2"),
            (self.critic2_optimizer, "_critic2_optimizer"),
            (self.actor, "_actor"),
            (self.actor_optimizer, "_actor_optimizer"),
        ]
        for component, suffix in checkpoints:
            torch.save(component.state_dict(), self.save_path + suffix)


    def load(self):
        '''
        Restore critic1, critic2 and the actor together with their
        optimizers from the files at self.save_path + suffix (tensors mapped
        to self.device), then reset each target network to a deep copy of
        the freshly loaded model.
        '''
        def restore(model, optimizer, model_suffix, optimizer_suffix):
            # load model weights first, then optimizer state, then clone
            model_state = torch.load(self.save_path + model_suffix,
                                     map_location=self.device)
            model.load_state_dict(model_state)
            optimizer_state = torch.load(self.save_path + optimizer_suffix,
                                         map_location=self.device)
            optimizer.load_state_dict(optimizer_state)
            return copy.deepcopy(model)

        self.critic1_target = restore(self.critic1, self.critic1_optimizer,
                                      "_critic1", "_critic1_optimizer")
        self.critic2_target = restore(self.critic2, self.critic2_optimizer,
                                      "_critic2", "_critic2_optimizer")
        self.actor_target = restore(self.actor, self.actor_optimizer,
                                    "_actor", "_actor_optimizer")