## import

In [None]:
!pip install "ray[rllib]" --quiet
!pip install gym --quiet
import gym
import numpy as np
import statistics
import tensorflow as tf
import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
from uuid import uuid4 as uuid
import sys
import copy
import os
import pandas as pd
import argparse
import random
import time
import ray
from ray.rllib.agents.dqn import DQNTrainer, ApexTrainer, SimpleQTrainer, R2D2Trainer 
from ray.rllib.agents.ppo import PPOTrainer
from ray.rllib.agents.a3c import A3CTrainer
from ray.rllib.agents.pg import PGTrainer
from ray.rllib.agents.impala import ImpalaTrainer
import matplotlib.patheffects as pe
from scipy.special import softmax
import pickle
from scipy.stats import norm, truncnorm
import logging
from ray import tune
from ray.rllib.agents.trainer_template import build_trainer
import scipy.stats as stats
import matplotlib as mpl

from google.colab import drive
drive.mount('/content/drive')

def sigmoid(i):
    """Return the logistic function of ``i``, i.e. ``1 / (1 + exp(-i))``.

    The value is evaluated as ``exp(-log(1 + exp(-i)))`` using
    ``np.logaddexp`` so large-magnitude inputs do not overflow.
    Accepts scalars or NumPy arrays.
    """
    log_denominator = np.logaddexp(0, -i)
    return np.exp(-log_denominator)

# Plot colour for each of the seven bias categories, indexed by category (0..6).
politics_color_mapping = ['darkblue', 'royalblue', 'lightsteelblue', 'thistle', 'lightcoral', 'crimson', 'darkred']
# Lower-case display label for each category, in the same index order as the colours.
politics_labels = ["far left", "left", "lean left", "center", "lean right", "right", "far right"]
# Capitalised variant of the labels above, same index order.
politics_labels_capital = ["Far Left", "Left", "Lean Left", "Center", "Lean Right", "Right", "Far Right"]

## simulation

In [None]:
# NOTE(review): presumably the number of bias categories; the functions in this
# cell hard-code 7 instead of reading this name — confirm before relying on it.
n_categories = 7

def encode_politics(politics):
    """One-hot encode a scalar position into seven equal-width bins over [-1, 1].

    Each bin is 2/7 wide. Positions outside [-1, 1] (and exactly 1) are
    clamped into the first or last bin.

    Returns:
        Integer NumPy array of length 7 with a single 1 at the bin index.
    """
    bin_width = 2/7
    raw_index = int(np.floor((politics+1)/bin_width))
    one_hot = np.zeros(7, dtype=int)
    one_hot[np.clip(raw_index, 0, 6)] = 1
    return one_hot

def decode_politics(politics):
    """Map a length-7 encoding back to its bin index and bin edges.

    The bin is the position of the largest entry (``np.argmax``); each bin
    spans 2/7 of the interval starting at -1.

    Returns:
        Tuple ``(index, start, end)`` where ``start``/``end`` are the lower
        and upper edge of the selected bin.

    Raises:
        AssertionError: if ``politics`` does not have exactly 7 entries.
    """
    assert len(politics)==7
    bin_width = 2/7
    index = np.argmax(politics)
    lower_edge = index * bin_width - 1
    return (index, lower_edge, lower_edge + bin_width)

class User:
    """Simulated platform user whose political belief lies in [-1, 1].

    Args:
        belief: Fixed starting belief. When ``None`` the belief is sampled:
            an identity ("L" or "R") is drawn with probability ``p_l`` for
            "L", then the belief is drawn from that identity's normal
            distribution and clipped to [-1, 1].
        malleability: Per-user shift scale. When ``None`` it is drawn
            uniformly from [0.01, 0.1].
        polarization_factor: Per-user polarization parameter; defaults to 1
            when ``None``.
        random_seed: Stored on the instance only; it is not used to seed any
            generator. A random integer is stored when ``None``.
        mu_l, sigma_l: Mean / std-dev of the "L" belief distribution.
        mu_r, sigma_r: Mean / std-dev of the "R" belief distribution.
        p_l: Probability of sampling the "L" identity.

    Attributes set here include ``identity`` ("L"/"R", or ``None`` when an
    explicit belief was supplied), ``belief``, ``initial_belief``,
    ``engagement``, ``satisfaction``, ``open_mindedness``,
    ``partisan_click_factor``, ``is_active``, ``n_clicks`` and four empty
    history lists.
    """

    def __init__(self, belief=None, malleability=None, polarization_factor=None, random_seed=None, mu_l=-0.5, sigma_l=0.25, mu_r=0.3, sigma_r=0.3, p_l=0.55):
        self.random_seed = random_seed if random_seed is not None else np.random.randint(0, 10000)
        self.user_id = str(uuid())

        # Always define ``identity`` so dict-style consumers (``user['identity']``)
        # do not hit a KeyError for users created with an explicit belief.
        self.identity = None
        if belief is not None:
            self.belief = belief
        else:
            self.identity = np.random.choice(["L", "R"], p=[p_l, 1-p_l])
            mu, sigma = (mu_l, sigma_l) if self.identity=="L" else (mu_r, sigma_r)
            self.belief = np.clip(np.random.normal(mu, sigma), -1, 1)
            # Redraw once if the first sample landed at an extreme (|belief| >= 0.9).
            if abs(self.belief) >= 0.9:
                self.belief = np.clip(np.random.normal(mu, sigma), -1, 1)

        # Honour explicit arguments; fall back to the previous defaults otherwise.
        self.malleability = malleability if malleability is not None else np.random.uniform(0.01, 0.1)
        self.polarization_factor = polarization_factor if polarization_factor is not None else 1
        self.engagement = 0
        self.satisfaction = 1
        self.open_mindedness = 0.5
        self.partisan_click_factor = 0.5

        self.is_active = True

        self.initial_belief = self.belief
        self.content_history, self.click_history, self.belief_history, self.action_history = [],[],[],[]
        self.n_clicks = 0


class Content:
    """A piece of content with a political bias in one of seven categories.

    Args:
        bias: Either a category name (one of the keys of the mapping below,
            e.g. ``"neutral"``) or a length-7 encoding whose largest entry
            selects the category.

    Attributes:
        politics_int: Category index (0..6).
        politics_start, politics_end: Edges of the category's bin, as
            returned by ``decode_politics``.
        politics_name: Category name from the mapping below.
        politics: Position drawn uniformly from within the bin.

    Raises:
        KeyError: if ``bias`` is a string that is not a known category name.
        AssertionError: if the encoding does not have exactly 7 entries.
    """

    def __init__(self, bias):
        politics_mapping = {
          "extreme left": {'int': 0, 'center': -0.75},
          "left": {'int': 1, 'center': -0.50},
          "slight left": {'int': 2, 'center': -0.25}, 
          "neutral": {'int': 3, 'center': 0},
          "slight right": {'int': 4, 'center': 0.25},
          "right": {'int': 5, 'center': 0.50},
          "extreme right": {'int': 6, 'center': 0.75}
        }
        # isinstance (not ``type(...) == str``) so str subclasses such as
        # ``np.str_`` (e.g. from ``np.random.choice``) are treated as names.
        if isinstance(bias, str):
            politics_encoding = np.zeros(7)
            politics_encoding[politics_mapping[bias]['int']] = 1
        else:
            politics_encoding = bias
        assert len(politics_encoding)==7
        self.politics_int, self.politics_start, self.politics_end = decode_politics(politics_encoding)
        self.politics_name = next(name for name, info in politics_mapping.items() if info['int']==self.politics_int)
        # The concrete position is sampled anywhere inside the category's bin.
        self.politics = np.random.uniform(self.politics_start, self.politics_end)

def recommend_content(user, content, x=0.5, max_click_probability=0.8, probability_spread=10):
    """Show one piece of content to a user and update the user's state in place.

    Args:
        user: Dict of user state. Reads/writes the keys ``belief``,
            ``engagement``, ``satisfaction``, ``click_history``,
            ``content_history``, ``is_active`` and ``n_clicks``; reads
            ``polarization_factor``, ``malleability`` and ``open_mindedness``.
        content: Object exposing a numeric ``politics`` attribute.
        x: Not used by this function.
        max_click_probability: Upper bound that scales the click probability.
        probability_spread: Exponent applied to the sigmoid term; larger
            values make the click probability fall off faster with dissonance.

    Returns:
        Tuple ``(user, click, click_probs)``: the same (mutated) dict, the
        sampled click (0 or 1) and the probability it was sampled with.
    """

    ## opinion shift
    # Signed gap between the content's position and the user's current belief.
    dissonance = content.politics - user['belief']    
    # Damp the shift by (1 - belief^2) when belief and (dissonance - dissonance^3)
    # have the same sign; otherwise leave it undamped.
    extremes_decay = (1 - abs(user['belief'])**2) if user['belief']*(dissonance - dissonance**3) > 0 else 1
    # The (1 - d^2 / pf^2) factor turns negative once |dissonance| exceeds the
    # polarization factor, so the belief then moves away from the content.
    shift = dissonance*(1 - (dissonance**2)/(user['polarization_factor']**2))*user['malleability']*user['engagement']*extremes_decay
    user['belief'] += shift

    ## click probability
    # NOTE: ``opposition`` is computed but never used below.
    opposition = user['belief']*content.politics
    # Uses the dissonance measured BEFORE the belief shift above. The sigmoid
    # argument grows as |dissonance| shrinks; 1e-8 avoids division by zero.
    click_probs = sigmoid((user['open_mindedness'])/(abs(dissonance) + 1e-8))**probability_spread * max_click_probability
    click_probs = np.clip(click_probs, 0, 1)
    click = np.random.choice([0,1], p=[1-click_probs, click_probs]) 
    user['click_history'].append(click)

    # A click raises engagement by U(0.01, 0.1); no random draw happens otherwise.
    user['engagement'] += np.random.uniform(0.01, 0.1) if click else 0
    # Parsed as ``(1 + u) if click else (1 - u)``: satisfaction grows on a click
    # and shrinks on a non-click, each by a fresh U(0.01, 0.1) fraction.
    user['satisfaction'] *= 1 + np.random.uniform(0.01, 0.1) if click else 1 - np.random.uniform(0.01, 0.1)

    # Keep all state variables inside their valid ranges.
    user['belief'] = np.clip(user['belief'], -1, 1)
    user['satisfaction'] = np.clip(user['satisfaction'], 0, 1)
    user['engagement'] = np.clip(user['engagement'], 0, 1)

    user['content_history'].append(content.politics)

    ## user attrition
    # Below the threshold the user leaves with probability
    # 1 - satisfaction / threshold (reaching 1 as satisfaction reaches 0).
    attrition_threshold = 0.25
    if user['satisfaction'] < attrition_threshold:
        user_attrition_probs = 1 - user["satisfaction"]/attrition_threshold
        user_leaves_platform = np.random.choice([False,True], p=[1-user_attrition_probs, user_attrition_probs])
        if user_leaves_platform:
            user['is_active'] = False
    
    user['n_clicks'] += click
    return user, click, click_probs

## visualizations

### user belief distribution

In [None]:
# Histogram of user beliefs split by sampled identity.
# NOTE(review): ``users`` is not defined in the visible cells; it is assumed to
# be a list of ``User().__dict__`` dicts — confirm against the cell that builds it.
left_users = [user for user in users if user['identity']=="L"]
right_users = [user for user in users if user['identity']=="R"]

mpl.rcParams.update({'font.size': 22})

# ``User`` stores the position under 'belief'; there is no 'politics' key, so
# reading user['politics'] raised KeyError.
left_beliefs = [user['belief'] for user in left_users]
right_beliefs = [user['belief'] for user in right_users]

df = pd.DataFrame()
df['party'] = ['left' for user in left_users] + ['right' for user in right_users]
df['politics'] = left_beliefs + right_beliefs
plt.figure(figsize=(20, 10), dpi=80)

plt.hist(left_beliefs, color='darkblue', alpha=0.75, bins=28)
plt.hist(right_beliefs, color='darkred', alpha=0.75, bins=28)

plt.xlim(-1,1)
plt.xlabel('User Belief')
plt.ylabel('Number of Users')
plt.title('User Distribution')
plt.legend(labels=['Democrat', "Republican"])
plt.tight_layout()
plt.show(block=False)

### average shifts

In [None]:
# Scatter of one-step belief shifts: one subplot per user category, one colour
# per content category.
ps,starts,ends = [],[],[]
# NOTE: ``d`` is not used anywhere in this cell.
d = 0.8
# One-hot encoding and bin edges for each of the seven categories.
for i in range(7):
  p = np.zeros(7)
  p[i] = 1
  p_int, start, end = decode_politics(p)
  ps.append(p)
  starts.append(start)
  ends.append(end)
fig,axes = plt.subplots(7,figsize=(20,25), sharey=True)

cmap = mpl.colors.ListedColormap(politics_color_mapping)
polarization_factor = 0.8
for p,s,e,ax in zip(ps,starts,ends,axes):
  # 100 starting beliefs drawn uniformly from this user category's bin.
  user_beliefs = [User(np.random.uniform(s,e)).belief for _ in range(100)]

  df = pd.DataFrame()
  shifts,cs,categories = [],[],[]
  for c in range(7):
    for u in user_beliefs:
        # Fresh user per trial so every shift starts from the same belief.
        user = User(u).__dict__
        # Full engagement so the shift is not multiplied by zero.
        user['engagement'] = 1
        user['polarization_factor'] = polarization_factor
        content_politics = np.zeros(7)
        content_politics[c] = 1
        content = Content(content_politics)
        # recommend_content mutates and returns the same dict, so ``user`` and
        # ``updated_user`` refer to one object.
        updated_user, _, _ = recommend_content(user, content)
        shift = user['belief'] - updated_user['initial_belief']
        shifts.append(shift)
        cs.append(content.politics)
        categories.append(c)
  # Symmetric x-range around zero sized by the largest absolute shift.
  xmin,xmax = -np.max([abs(np.min(shifts)), abs(np.max(shifts))]), np.max([abs(np.min(shifts)), abs(np.max(shifts))])
  # Blue-to-red background gradient; its opacity is set to the largest shift.
  ax.imshow([[-1,1.], [-1,1.]], aspect='auto', cmap = "coolwarm", interpolation = 'bicubic', extent=[xmin, xmax, -1.1, 1.1], alpha=xmax)
  df['shift'] = shifts
  df['content_politics'] = cs
  df['category'] = categories
  scatter = ax.scatter(data=df, x='shift', y='content_politics', c='category', cmap=cmap)

  ax.set_title(f"{politics_labels[np.argmax(p)]} user", weight='bold')
  
  ax.set_ylabel('content politics')
  ax.set_xlabel('shift')
  ax.set_xlim(xmin,xmax)
  ax.text(xmax*0.5, 0, "shift right")
  ax.text(xmin*0.5, 0, "shift left")

plt.tight_layout(h_pad=2)
# Reserve the right-hand strip of the figure for a shared colour bar.
fig.subplots_adjust(right=0.8)
cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.7])
# The colour bar is built from the last subplot's scatter.
cbar=fig.colorbar(scatter, cax=cbar_ax)
cbar.set_ticks([0.5, 1.3, 2.2, 3, 3.8, 4.7, 5.5])
cbar.set_ticklabels(politics_labels)
cbar.set_label('content politics', labelpad=-75, y=1.05, rotation=0, weight='bold')

plt.show()

In [None]:
# Same one-step shift scatter as the previous cell, with 1000 users per
# category and a fixed x-range of +/- ``xlim``.
ps,starts,ends = [],[],[]
# NOTE: ``d`` is not used anywhere in this cell.
d = 0.8
# One-hot encoding and bin edges for each of the seven categories.
for i in range(7):
  p = np.zeros(7)
  p[i] = 1
  p_int, start, end = decode_politics(p)
  ps.append(p)
  starts.append(start)
  ends.append(end)
fig,axes = plt.subplots(7, figsize=(30,35), sharey=True)
plt.rcParams.update({'font.size': 16})
cmap = mpl.colors.ListedColormap(politics_color_mapping)
# NOTE: assigned but not applied below; each user's factor is set to the literal 1.
polarization_factor = 1.25
xlim = 0.1
for p,s,e,user_ax in zip(ps,starts,ends,axes):
  # 1000 starting beliefs drawn uniformly from this user category's bin.
  user_beliefs = [User(np.random.uniform(s,e)).belief for _ in range(1000)]
  df = pd.DataFrame()
  shifts,cs,categories = [],[],[]
  for c in range(7):
    for u in user_beliefs:
        # Fresh user per trial so every shift starts from the same belief.
        user = User(u).__dict__
        # Full engagement so the shift is not multiplied by zero.
        user['engagement'] = 1
        user['polarization_factor'] = 1
        content_politics = np.zeros(7)
        content_politics[c] = 1
        content = Content(content_politics)
        user, _, _ = recommend_content(user, content)
        shift = user['belief'] - user['initial_belief']
        shifts.append(shift)
        cs.append(content.politics)
        categories.append(c)
  # NOTE: xmin/xmax are computed but unused here; the axes use ``xlim`` instead.
  xmin,xmax = -np.max([abs(np.min(shifts)), abs(np.max(shifts))]), np.max([abs(np.min(shifts)), abs(np.max(shifts))])
  user_ax.imshow([[-1,1.], [-1,1.]], aspect='auto', cmap ="coolwarm", interpolation ='bicubic', extent=[-xlim, xlim, -1.1, 1.1], alpha=0.25)
  user_ax.set_title(f"{politics_labels[np.argmax(p)]} user", weight='bold')
  df['shift'] = shifts
  df['content_politics'] = cs
  df['category'] = categories
  scatter = user_ax.scatter(data=df, x='shift', y='content_politics', c='category', cmap=cmap)
  # Zero-shift reference line.
  user_ax.vlines(x=0, ymin=-1, ymax=1, colors='black')
  # NOTE: repeats the set_title call made a few lines above.
  user_ax.set_title(f"{politics_labels[np.argmax(p)]} user", weight='bold')

  user_ax.set_xlim(-xlim,xlim)
  user_ax.set_xticks([-xlim, 0, xlim])
  # Label only the middle row's y-axis and the bottom row's x-axis.
  if np.argmax(p)==3:
      user_ax.set_ylabel('content bias', weight='bold')
  if np.argmax(p)==6:
      user_ax.set_xlabel('user belief shift', weight='bold')

plt.tight_layout(h_pad=1)
# Reserve the right-hand strip of the figure for a shared colour bar.
fig.subplots_adjust(right=0.8)
cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.7])
cbar=fig.colorbar(scatter, cax=cbar_ax)
cbar.set_ticks([0.5, 1.3, 2.2, 3, 3.8, 4.7, 5.5])
cbar.set_ticklabels(politics_labels)
cbar.set_label('content', labelpad=-75, y=1.05, rotation=0, weight='bold')


plt.show()

In [None]:
# 7x3 grid of one-step shift scatters: one row per user category, three
# columns per row.
ps,starts,ends = [],[],[]
# NOTE: ``d`` is not used anywhere in this cell.
d = 0.8
# One-hot encoding and bin edges for each of the seven categories.
for i in range(7):
  p = np.zeros(7)
  p[i] = 1
  p_int, start, end = decode_politics(p)
  ps.append(p)
  starts.append(start)
  ends.append(end)
fig,axes = plt.subplots(7, 3, figsize=(20,25), sharey=True)
plt.rcParams.update({'font.size': 10})
cmap = mpl.colors.ListedColormap(politics_color_mapping)
polarization_factor = 1.25
for p,s,e,user_ax in zip(ps,starts,ends,axes):
  # NOTE(review): the loop iterates polarization factors 0.25 / 0.5 / 0.75, but
  # the line that applies the factor to the user is commented out below, so all
  # three columns use the User default — confirm whether that is intended.
  for ax,polarization_factor,idx,xlim in zip(user_ax, [0.25, 0.5, 0.75], [0,1,2], [0.02, 0.02, 0.02]):
      # 100 starting beliefs drawn uniformly from this user category's bin.
      user_beliefs = [User(np.random.uniform(s,e)).belief for _ in range(100)]

      df = pd.DataFrame()
      shifts,cs,categories = [],[],[]
      for c in range(7):
        for u in user_beliefs:
            # Fresh user per trial so every shift starts from the same belief.
            user = User(u).__dict__
            # Full engagement so the shift is not multiplied by zero.
            user['engagement'] = 1
            # user['polarization_factor'] = polarization_factor
            content_politics = np.zeros(7)
            content_politics[c] = 1
            content = Content(content_politics)
            user, _, _ = recommend_content(user, content)
            shift = user['belief'] - user['initial_belief']
            shifts.append(shift)
            cs.append(content.politics)
            categories.append(c)
      # NOTE: xmin/xmax are computed but unused here; the axes use ``xlim`` instead.
      xmin,xmax = -np.max([abs(np.min(shifts)), abs(np.max(shifts))]), np.max([abs(np.min(shifts)), abs(np.max(shifts))])
      ax.imshow([[-1,1.], [-1,1.]], aspect='auto', cmap ="coolwarm", interpolation ='bicubic', extent=[-xlim, xlim, -1.1, 1.1], alpha=0.25)
      df['shift'] = shifts
      df['content_politics'] = cs
      df['category'] = categories
      scatter = ax.scatter(data=df, x='shift', y='content_politics', c='category', cmap=cmap)
      # Zero-shift reference line.
      ax.vlines(x=0, ymin=-1, ymax=1, colors='black')
      # Title only the middle column of each row.
      if idx==1:
          ax.set_title(f"{politics_labels[np.argmax(p)]} user", weight='bold')

      ax.set_xlim(-xlim,xlim)
      ax.set_xticks([-xlim, 0, xlim])
      # Single shared y-label (middle row, first column) and x-label (last row, middle column).
      if np.argmax(p)==3 and idx==0:
          ax.set_ylabel('content', weight='bold')
      if np.argmax(p)==6 and idx==1:
          ax.set_xlabel('shift', weight='bold')

plt.tight_layout(h_pad=1)
# Reserve the right-hand strip of the figure for a shared colour bar.
fig.subplots_adjust(right=0.8)
cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.7])
cbar=fig.colorbar(scatter, cax=cbar_ax)
cbar.set_ticks([0.5, 1.3, 2.2, 3, 3.8, 4.7, 5.5])
cbar.set_ticklabels(politics_labels)
cbar.set_label('content', labelpad=-75, y=1.05, rotation=0, weight='bold')

plt.show()

In [None]:
# from scipy.spatial import ConvexHull
# ps,starts,ends = [],[],[]
# d = 1.25
# for i in range(7):
#   p = np.zeros(7)
#   p[i] = 1
#   p_int, start, end = decode_politics(p)
#   ps.append(p)
#   starts.append(start)
#   ends.append(end)

# cmap = mpl.colors.ListedColormap(politics_color_mapping)

# fig,axes = plt.subplots(7,figsize=(20,25), sharey=True)
# for u,ax in zip(range(7),axes):
#   all_shifts = []
#   for c in range(7):
#     df = pd.DataFrame()
#     shifts,cs,categories = [],[],[]
#     for content_politics in [starts[c], ends[c]]:
#       for user_politics in [starts[u],ends[u]]: 
#           for malleability in [0.001, 0.01, 0.1]:
#                 user = User(user_politics).__dict__
#                 user['malleability'] = malleability
#                 user['engaged'] = 1
#                 content = Content("neutral")
#                 content.politics = content_politics
#                 updated_user, _, _ = recommend_content(user, content, clip=True)
#                 shift = user['politics'] - updated_user['initial_belief']
#                 shifts.append(shift)
#                 cs.append(content.politics)
#                 categories.append(c)
#     df['shift'] = shifts
#     df['content_politics'] = cs
#     df['category'] = categories
#     # ax.scatter(data=df, x='shift', y='content_politics', c='category', cmap=cmap)

#     points = np.array([[shift,p] for shift,p in zip(shifts, cs)])

#     hull = ConvexHull(points)

#     ax.fill(points[hull.vertices,0], points[hull.vertices,1],politics_color_mapping[c])
#     all_shifts = all_shifts + shifts
#   # ax.fill_between(df['shift'], df['content_politics'])
#   xmin,xmax = -np.max([abs(np.min(all_shifts)), abs(np.max(all_shifts))]), np.max([abs(np.min(all_shifts)), abs(np.max(all_shifts))])
#   ax.imshow([[-1,1.], [-1,1.]], aspect='auto', cmap = "coolwarm", interpolation = 'bicubic', extent=[xmin, xmax, -1.1, 1.1], alpha=xmax)
#   ax.set_xlim(xmin-0.001,xmax+0.001)
# plt.show()

### click probability distributions

In [None]:
# Click probability as a function of content position, one subplot per
# fixed user belief.
# Bin edges of the seven content categories.
starts = [i * (2/7) - 1 for i in range(7)]
ends = [start + (2/7) for start in starts]
# NOTE: ``mids`` is not used in this cell.
mids = [np.mean([start,end]) for start,end in zip(starts,ends)]

fig,axes = plt.subplots(4,2, figsize=(30,20))
mpl.rcParams.update({'font.size': 12})

# Seven user beliefs are laid out column-wise over the 4x2 grid; axes[3][1]
# is never drawn on.
for user_belief,label,ax in zip([-0.75, -0.5, -0.25, 0, 0.25, 0.5, 0.75], politics_labels_capital, [axes[0][0], axes[1][0], axes[2][0], axes[3][0], axes[0][1], axes[1][1], axes[2][1]]):
    # ``beliefs`` collects the sampled content positions (x-values),
    # ``probs`` the matching click probabilities (y-values).
    beliefs, probs = [], []
    # Shade each content category's bin in its category colour.
    for start,end,color in zip(starts, ends ,politics_color_mapping):
        ax.fill_between([start, end], [1, 1], color=color, alpha=0.75)

    for i in range(7):
        for _ in range(1000):
            # Fresh user each trial: recommend_content mutates the user it is given.
            user = User(belief=user_belief)

            content_politics = np.zeros(7)
            content_politics[i] = 1
            content = Content(content_politics)

            user, click, click_probs = recommend_content(user.__dict__, content)
            beliefs.append(content.politics)
            probs.append(click_probs)
    
    ax.scatter(beliefs, probs, color='black')

    # NOTE: ``content_boundaries`` is not used in this cell.
    content_boundaries = [i * (2/7) - 1 for i in range(7)]

    ax.set_ylim(0, 1)
    ax.set_xlim(-1, 1)

    ax.set_title(f'{label} User (b={np.round(user_belief,2)})')
    ax.set_xlabel('Content Bias')
    ax.set_ylabel('Click Probability')
fig.suptitle('Click Probability Distributions', y=1.005, x=0.51, horizontalalignment='center', weight='bold')
plt.tight_layout()
plt.show()

In [None]:
# Legend strip: seven equal-width coloured bars over [-1, 1], each labelled
# with its capitalised category name.
color_mapping = ['darkblue', 'royalblue', 'lightsteelblue', 'thistle', 'lightcoral', 'crimson', 'darkred']
starts = [i * (2/7) - 1 for i in range(7)]
ends = [start + (2/7) for start in starts]
mids = [np.mean([start,end]) for start,end in zip(starts,ends)]

plt.figure(figsize=(20,10))
# Ticks at every bin edge, including the right-most edge at 1.
plt.xticks(starts + [1])
plt.bar(mids, [1] * 7, color=color_mapping, width=2/7)
for i in range(7):
  plt.text(
      x=mids[i],
      y=0.5,
      s=politics_labels_capital[i],
      horizontalalignment='center',
      color='white',
      fontdict={'fontsize':30},
  ) # path_effects=[pe.withStroke(linewidth=4, foreground="black")])
plt.yticks([])
plt.xlabel('Bias')

plt.ylim(0,1)
plt.xlim(-1,1)
plt.tight_layout()
plt.title('Content Politics')
plt.show()

## gym environment

In [None]:
def get_env():
    """Build and return the ``NewsRecommendationEnv`` gym environment class.

    The class is defined inside this factory and returned (not instantiated).
    """
    class NewsRecommendationEnv(gym.Env):
        """Single-user news recommendation environment.

        Actions are the seven content categories (``Discrete(7)``). The
        observation is the user's ``state``: a ``(horizon, 8)`` array whose
        rows are the most recent recommendations, newest first — columns 0..6
        one-hot encode the recommended category and column 7 is the click flag.

        ``env_config`` keys (all optional):
            horizon: Number of rows kept in the observation (default 100).
            target_politics: Category index or collection of indices rewarded
                in ``"manipulate"`` mode.
            mode: ``"manipulate"`` rewards clicks by category; any other
                value rewards every click with 1.
            eval: Stored on the instance; not read by this class.
            target_politics_reward: Reward for a click on a target category
                (default 1).
            other_reward: Reward for a click on a non-target category
                (default 0).
            session_length: Maximum number of steps per session (default 500).
            actor_name: Name of a Ray actor that receives finished sessions.

        Raises:
            ValueError: if ``mode`` is ``"manipulate"`` but no
                ``target_politics`` is configured.
        """

        def __init__(self, env_config):
            # Defines the length of the episode in terms of no. recommendations (we use horizon=30 in our experiments)
            self.horizon = env_config.get('horizon', 100)
            self.target_politics = env_config.get('target_politics')
            self.mode = env_config.get('mode')
            self.eval = env_config.get('eval')
            self.target_politics_reward = env_config.get('target_politics_reward', 1)
            self.other_reward = env_config.get('other_reward', 0)
            self.session_length = env_config.get('session_length', 500)

            # Fail at construction instead of with an opaque
            # "argument of type 'NoneType' is not iterable" on the first click.
            if self.mode == "manipulate" and self.target_politics is None:
                raise ValueError("mode='manipulate' requires 'target_politics' in env_config")

            self.action_space = gym.spaces.Discrete(7)
            self.observation_space = gym.spaces.Box(0, 1, shape=(self.horizon,8), dtype=np.int64) #gym.spaces.Box(0, self.session_length, shape=(14,), dtype=np.int64)

            self.current_user = None

            self.session_history_tracker = ray.get_actor(env_config['actor_name']) if 'actor_name' in env_config else None

        def _is_target(self, politics_int):
            """Return True when ``politics_int`` is (one of) the target categories."""
            targets = self.target_politics
            # Accept a bare category index as well as a collection of indices.
            if isinstance(targets, (int, np.integer)):
                return politics_int == targets
            return politics_int in targets

        def update_user_state(self, state, action, click):
            """Push a new (action one-hot, click) row on top of ``state``, dropping the oldest row."""
            record = np.zeros(8)
            record[action] = 1
            record[7] = 1 if click else 0
            state = np.concatenate([[record],state[:-1]])
            return state

        def record_session(self):
            """Send the finished session and a copy of the user to the tracker actor (blocking)."""
            ray.get(self.session_history_tracker.record_session.remote(time.time(), self.current_session, copy.deepcopy(self.current_user)))

        def reset(self, user=None):
            """Start a new session for ``user`` (a user dict) or for a freshly sampled user.

            Returns:
                The user's current ``state`` array; an all-zero
                ``(horizon, 8)`` array when the user has none yet.
            """
            if user:
                self.current_user = user
            else:
                self.current_user = User().__dict__

            if 'state' not in self.current_user:
                self.current_user['state'] = np.zeros(8*self.horizon).reshape(self.horizon, 8)
            # NOTE(review): this writes an 'engaged' key, while the dynamics in
            # recommend_content read 'engagement' — confirm which is intended.
            self.current_user['engaged'] = False
            self.current_user['is_active'] = True
            self.current_session = {
              'initial_belief': self.current_user['belief'],
              'end_belief': self.current_user['belief'],
              'clicks': 0,
              'n': 0,
              'ctr': 0,
              'rewards': 0,
              'user_id': self.current_user['user_id']
            }

            return self.current_user['state']

        def step(self, action):
            """Recommend category ``action`` to the current user.

            Returns:
                ``(state, reward, done, info)`` where ``info`` is the running
                session summary dict. The session ends when
                ``session_length`` steps have been taken or the user has left
                the platform; in the latter case the reward is -1.
            """
            reward = 0
            done = False

            content_politics = np.zeros(7)
            content_politics[action] = 1

            content = Content(content_politics)
            self.current_user, click, click_probs = recommend_content(self.current_user, content)
            if click:
                if self.mode == "manipulate":
                    reward += self.target_politics_reward if self._is_target(content.politics_int) else self.other_reward
                else:
                    reward += 1

            self.current_user['state'] = self.update_user_state(self.current_user['state'], action, click)

            self.current_session['n'] += 1
            self.current_session['clicks'] += click
            self.current_session['clicked'] = click

            if self.current_session['n']>=self.session_length:
                done = True
            if not self.current_user['is_active']:
                done = True
                reward = -1

            # Accumulate only after the attrition override so the session total
            # matches the rewards actually returned to the agent.
            self.current_session['rewards'] += reward

            if done:
                self.current_session['end_belief'] = self.current_user['belief']
                self.current_session['belief_shift'] = self.current_session['end_belief'] - self.current_session['initial_belief']
                self.current_session['ctr'] = self.current_session['clicks'] / self.current_session['n']
                if self.session_history_tracker:
                    self.record_session()

            return self.current_user['state'], reward, done, self.current_session

    return NewsRecommendationEnv

In [None]:
# Build the environment class, then instantiate it with a 10-row observation
# window and sessions capped at 500 steps.
NewsRecommendationEnv = get_env()
env = NewsRecommendationEnv({'horizon':10, 'session_length':500})
# No user is passed, so the environment samples a fresh one.
env.reset()

### here

In [None]:
# Simulate 100 users receiving uniformly random recommendations until they
# leave the platform or have seen 500 items, then plot the distributions of
# belief shift, session length and final satisfaction.
shifts = []
lengths = []
sat = []
for _ in range(100):
    user = User().__dict__
    while True:
        action = np.random.choice([0,1,2,3,4,5,6])
        content = np.zeros(7)
        content[action] = 1
        content = Content(content)
        user,click,probs = recommend_content(user, content)
        if not user['is_active'] or len(user['content_history'])>=500:
            sat.append(user['satisfaction'])
            break

    # The user dict stores the position under 'belief'; there is no
    # 'politics' key, so the previous lookup raised KeyError.
    shifts.append(user['belief'] - user['initial_belief'])
    lengths.append(len(user['content_history']))
sns.displot(shifts)
sns.displot(lengths)
sns.displot(sat)

In [None]:
# Hand-written steering policy: start 100 users at belief -0.5 and pick the
# content category from the user's current belief, until the belief reaches
# 0.8, the user leaves, or 500 items have been shown.
p = []
attrition = []
lengths = []
for _ in range(100):
  user = User(-0.5).__dict__
  sat = []
  n = 0
  # The user dict stores the position under 'belief'; there is no 'politics'
  # key, so the previous lookups raised KeyError.
  while user['belief'] < 0.8:
      sat.append(user['satisfaction'])
      if user['belief'] < -0.25:
          content = Content("extreme left")
      elif user['belief'] < 0:
          content = Content("neutral")
      elif user['belief'] < 0.4:
          content = Content("slight right")
      elif user['belief'] < 0.6:
          content = Content("right")
      else:
          content = Content("extreme right")
      user,click,click_probs = recommend_content(user, content)
      n += 1
      if not user['is_active']:
          attrition.append(1)
          # print(user['satisfaction'])
          break
      if n >= 500:
          break
  p.append(user['belief'])
  lengths.append(len(user['content_history']))
  # plt.plot([i for i in range(n)], sat)
  # plt.show()
sns.displot(p)
sns.displot(lengths)
# Number of users who left the platform.
sum(attrition)

In [None]:
# Start a new session with an explicitly supplied, freshly sampled user dict.
env.reset(User().__dict__)

In [None]:
# The user dict stores the position under 'belief'; 'politics' raised KeyError.
env.current_user['belief']

In [None]:
# Take one step with action 2 and show the shape of the returned state.
env.step(2)[0].shape

In [None]:
# Take one step with action 1 and show the full (state, reward, done, info) tuple.
env.step(1)

In [None]:
## avg shift acting randomly
# Run 1000 full sessions through the environment with uniformly random
# actions and plot the distribution of total belief shift.
shifts = []
start = time.time()
for _ in range(1000):
    state = env.reset(User().__dict__)
    done = False
    while not done:
        state, reward, done, info = env.step(np.random.choice([i for i in range(7)]))
    # The user dict stores the position under 'belief'; there is no
    # 'politics' key, so the previous lookup raised KeyError.
    shifts.append(env.current_user['belief'] - env.current_user['initial_belief'])
sns.displot(shifts, kde=True)

## random agent

In [None]:
from ray.rllib.agents.trainer import Trainer, with_common_config
from ray.rllib.utils.annotations import override
from ray.rllib.utils.typing import TrainerConfigDict


class RandomAgent(Trainer):
    """Non-learning trainer whose every action is sampled from the env's action space."""

    @classmethod
    @override(Trainer)
    def get_default_config(cls) -> TrainerConfigDict:
        """Return the common trainer config extended with this agent's two settings."""
        overrides = {
            "rollouts_per_iteration": 10,
            "framework": "tf",  # not used
        }
        return with_common_config(overrides)

    @override(Trainer)
    def _init(self, config, env_creator):
        """Create the single environment this agent rolls out in."""
        self.env = env_creator(config["env_config"])

    class RandomPolicy():
        """Minimal policy object exposing ``compute_actions`` over a batch."""

        def __init__(self, env):
            self.env = env

        def compute_actions(self, obs_batch):
            """Return a one-element list holding one sampled action per observation."""
            sampled = []
            for _ in obs_batch:
                sampled.append(self.env.action_space.sample())
            return [sampled]

    @override(Trainer)
    def get_policy(self):
        """Return a new ``RandomPolicy`` bound to this agent's environment."""
        return self.RandomPolicy(self.env)

    @override(Trainer)
    def step(self):
        """Run ``rollouts_per_iteration`` random episodes and report mean return and step count."""
        episode_returns = []
        total_steps = 0
        for _ in range(self.config["rollouts_per_iteration"]):
            obs = self.env.reset()
            episode_return = 0.0
            done = False
            while not done:
                obs, r, done, info = self.env.step(self.env.action_space.sample())
                episode_return += r
                total_steps += 1
            episode_returns.append(episode_return)
        return {
            "episode_reward_mean": np.mean(episode_returns),
            "timesteps_this_iter": total_steps,
        }

    @override(Trainer)
    def compute_single_action(self, state, full_fetch=False):
        """Sample one random action; ``state`` is ignored.

        With ``full_fetch`` the result is a three-element list: the action,
        a placeholder dict, and one-hot "q_values" marking the sampled action.
        """
        action = self.env.action_space.sample()
        if not full_fetch:
            return action
        q_values = np.zeros(self.env.action_space.n)
        q_values[action] = 1
        return [{'action': action}, {'None': 'None'}, {'q_values': q_values}]

In [None]:
def compute_single_action(state):
    """Pick the lowest-numbered content category that is still worth showing.

    Args:
        state: Flat sequence of interleaved ``content, click, content, click, ...``
            values, where ``-1`` marks an unused slot.

    Returns:
        The first category in 0..6 that has been shown fewer than 5 times or
        has collected more than 2 clicks.

    Raises:
        StopIteration: if no category satisfies either condition.
    """
    clicks = [int(v) for i,v in enumerate(state) if i%2==1 and v!=-1]
    contents = [int(v) for i,v in enumerate(state) if i%2==0 and v!=-1]

    # Total clicks per category (the per-pair debug print has been removed).
    content_clicks = {i: 0 for i in range(7)}
    for content,click in zip(contents,clicks):
        content_clicks[content] += click

    # Generator instead of a throwaway list: stop at the first match.
    return next(
        content for content in range(7)
        if contents.count(content)<5 or content_clicks[content]>2
    )

## baseline agent

In [None]:
class BaselineAgent(Trainer):
    """Non-learning heuristic trainer that recommends by observed click-through rate.

    For each category the click-through rate (CTR) over the rows of the state
    is squared and used as a score. While fewer than 25 recommendations are
    present in the state, an action is sampled from the softmax of the scores;
    afterwards the highest-scoring category is chosen.
    """

    @classmethod
    @override(Trainer)
    def get_default_config(cls) -> TrainerConfigDict:
        """Return the common trainer config extended with this agent's two settings."""
        return with_common_config({
            "rollouts_per_iteration": 10,
            "framework": "tf",  # not used
        })

    @override(Trainer)
    def _init(self, config, env_creator):
        """Create the single environment this agent rolls out in."""
        self.env = env_creator(config["env_config"])

    @staticmethod
    def _select_action(state):
        """Choose a category from a state whose rows are ``[one-hot(7), click]``.

        Rows with an all-zero one-hot part are padding and are skipped.
        Previously ``np.argmax`` mapped them to category 0, so every padding
        row was counted as an unclicked category-0 recommendation and dragged
        that category's CTR down.
        """
        clicks_by_category = [[] for _ in range(7)]
        for s in state:
            if sum(s[:7]) > 0:
                clicks_by_category[np.argmax(s[:7])].append(s[7])
        # Number of real (non-padding) recommendations in the state.
        recs_count = sum(len(clicks) for clicks in clicks_by_category)
        ctrs = [np.mean(clicks)**2 if len(clicks) > 0 else 0 for clicks in clicks_by_category]
        if recs_count < 25:
            # Exploration phase: sample in proportion to softmax(score).
            probs = softmax(ctrs)
            action = np.random.choice([0,1,2,3,4,5,6], p=probs)
        else:
            # Exploitation phase: best-scoring category.
            action = np.argmax(ctrs)
        return action

    class BaselinePolicy():
        """Minimal policy object exposing ``compute_actions`` over a batch."""

        def __init__(self, env):
            self.env = env

        def compute_actions(self, obs_batch):
            """Return a one-element list holding one action per observation."""
            return [[self.compute_single_action(obs) for obs in obs_batch]]

        def compute_single_action(self, state):
            """Choose an action for one state using the shared CTR heuristic."""
            return BaselineAgent._select_action(state)

    @override(Trainer)
    def get_policy(self):
        """Return a new ``BaselinePolicy`` bound to this agent's environment."""
        return self.BaselinePolicy(self.env)

    @override(Trainer)
    def step(self):
        """Run ``rollouts_per_iteration`` episodes and report mean return and step count."""
        rewards = []
        steps = 0
        for _ in range(self.config["rollouts_per_iteration"]):
            state = self.env.reset()
            done = False
            reward = 0.0
            while not done:
                action = self.compute_single_action(state)
                state, r, done, info = self.env.step(action)
                reward += r
                steps += 1
            rewards.append(reward)
        return {
            "episode_reward_mean": np.mean(rewards),
            "timesteps_this_iter": steps,
        }

    @override(Trainer)
    def compute_single_action(self, state, full_fetch=False):
        """Choose an action for one state; ``full_fetch`` is accepted but ignored."""
        return self._select_action(state)

## session history tracker

In [None]:
@ray.remote
class SessionHistoryTracker:
    """Ray actor that keeps finished-session summaries in an in-memory list."""

    def __init__(self):
        self.session_histories = list()

    def record_session(self, timestamp, session_history, user):
        """Stamp ``session_history`` with ``timestamp`` under "time" and store it.

        ``user`` is accepted but not stored.
        """
        stamped = session_history
        stamped["time"] = timestamp
        self.session_histories.append(stamped)

    def get_session_histories(self):
        """Return the list of all recorded session summaries."""
        return self.session_histories

## evaluation

In [None]:
def plot_results(session_histories, window):
    """Plot rolling averages of four per-session metrics on a 2x2 grid.

    Args:
        session_histories: Sequence of session dicts with the keys
            ``initial_belief``, ``end_belief``, ``rewards``, ``ctr`` and ``n``.
        window: Number of preceding sessions averaged for each plotted point;
            the first point is drawn at index ``window``.
    """
    abs_belief_shift = [abs(s['end_belief'] - s['initial_belief']) for s in session_histories]
    panels = [
        ('Average Rewards', [s['rewards'] for s in session_histories], 'Reward'),
        ('Average CTR', [s['ctr'] for s in session_histories], 'CTR'),
        ('Average Absolute Belief Shift', abs_belief_shift, 'Shift'),
        ('Average Episode Length', [s['n'] for s in session_histories], 'Steps'),
    ]
    mpl.rcParams.update({'font.size': 16})

    fig, axes = plt.subplots(2,2)
    fig.set_size_inches(35, 15)
    # axes.flat walks the grid row by row: top-left, top-right, bottom-left, bottom-right.
    for (title, series, ylabel), ax in zip(panels, axes.flat):
        x = list(range(window, len(series)))
        # Mean over the ``window`` sessions immediately before index i.
        rolling_avg = [np.mean(series[i-window:i]) for i in x]
        ax.plot(x, rolling_avg)
        ax.set_title(title)
        ax.set_ylabel(ylabel)
        ax.set_xlabel('Iteration')
    plt.tight_layout()
    plt.show()

In [None]:
def eval_trainer(trainer, users=False, n_users=100, mode=None, horizon=100, target_politics=None, ctr_window=1, session_length=500, random_cumulative_clicks=None, baseline_cumulative_clicks=None):
    """Roll out the trainer's policy on evaluation users and plot diagnostics.

    Every user gets an environment built from ``get_env()``. All users that
    are still active are stepped together, with actions computed in one batch
    by ``trainer.get_policy().compute_actions``, until every environment
    reports ``done``. The recorded per-user belief, action, click and
    "still active" histories are then plotted.

    Args:
        trainer: object exposing ``get_policy().compute_actions(obs_batch=...)``.
        users: list of user dicts (each with a 'user_id' key). The list is
            deep-copied before use. If falsy, ``n_users`` fresh
            ``User().__dict__`` dicts are generated instead.
        n_users: number of users generated when ``users`` is falsy.
        mode, horizon, target_politics, session_length: forwarded unchanged
            in each environment's config dict.
        ctr_window, random_cumulative_clicks, baseline_cumulative_clicks:
            accepted but not used.

    Returns:
        None. Figures are shown with ``plt.show()`` and the average session
        length is printed.
    """
    NewsRecommendationEnv = get_env()
    users = copy.deepcopy(users) if users else [User().__dict__ for _ in range(n_users)]
    envs = [NewsRecommendationEnv({'horizon':horizon, 'target_politics':target_politics, 'eval':True, 'mode': mode, 'session_length': session_length}) for _ in range(len(users))]
    for user, env in zip(users, envs):
        user['env'] = env

    all_users = {user['user_id']: user for user in users}
    active_users = list(all_users.keys())

    action_histories = {user['user_id']: [] for user in users}
    belief_histories = {user['user_id']: [] for user in users}
    click_histories = {user['user_id']: [] for user in users}
    attrition_histories = {user['user_id']: [] for user in users}

    states = np.array([user['env'].reset(user) for user in all_users.values()])

    # ---- rollout: `states` stays aligned with `active_users` at every pass ----
    while len(active_users) > 0:
        actions = trainer.get_policy().compute_actions(obs_batch=states)[0]
        states = []
        updated_active_users = []
        for user_id, action in zip(active_users, actions):
            user = all_users[user_id]
            env = user['env']
            state, reward, done, info = env.step(action)
            current_id = env.current_user['user_id']
            belief_histories[current_id].append(env.current_user['belief'])
            action_histories[current_id].append(action)
            click_histories[current_id].append(env.current_session['clicked'])
            all_users[current_id] = env.current_user
            if done:
                # NOTE(review): `user` is the dict fetched before the step;
                # presumably the env updates 'is_active' in place - verify.
                if not user['is_active']:
                    attrition_histories[current_id].append(0)
            else:
                updated_active_users.append(current_id)
                attrition_histories[current_id].append(1)
                states.append(state)
                user['env'] = env

        active_users = updated_active_users
        states = np.array(states)

    # ---- colour legend: one bar per political category ----
    # Index 7 ('white') is the colour used when no modal action is available.
    color_mapping = ['darkblue', 'royalblue', 'lightsteelblue', 'thistle', 'lightcoral', 'crimson', 'darkred', 'white']
    plt.xticks([i for i in range(7)], politics_labels)
    plt.xticks(rotation=45)
    plt.bar([i for i in range(7)], [1 for _ in range(7)], color=color_mapping, width=1.0)
    plt.show()

    # ---- one row of three panels per initial political category ----
    fig, axes = plt.subplots(7, 3, figsize=(40, 40))
    mpl.rcParams.update({'font.size': 30})
    for p, ax in zip(range(7), axes):
        # Users whose first recorded belief falls into category p.
        bucket_ids = [user_id for user_id, history in belief_histories.items()
                      if np.argmax(encode_politics(history[0])) == p]
        beliefs = [belief_histories[user_id] for user_id in bucket_ids]
        actions = [action_histories[user_id] for user_id in bucket_ids]
        clicks = [click_histories[user_id] for user_id in bucket_ids]
        attritions = [attrition_histories[user_id] for user_id in bucket_ids]

        # NOTE(review): the 500-step cap is hard-coded here and in the axis
        # ticks rather than derived from `session_length` - confirm intent.
        most_common_actions = []
        for step in range(500):
            actions_at_step = [action_history[step] for action_history in actions if len(action_history) > step]
            if len(actions_at_step) == 0:
                break
            # stats.mode returns an array in some SciPy versions and a scalar
            # in others; atleast_1d makes both indexable.
            modal = np.atleast_1d(stats.mode(actions_at_step).mode)
            most_common_actions.append(modal[0] if len(modal) > 0 else 7)

        # CTR up to and including each step, averaged over the bucket's users.
        # The slice ends at step + 1 so the first point covers the first click
        # instead of averaging an empty list (NaN).
        rolling_avg_ctrs = [
            np.mean([np.mean(click_history[:step + 1]) for click_history in clicks])
            for step in range(len(most_common_actions))
        ]

        # Pad to 500 entries; -1 marks steps no user in the bucket reached.
        most_common_actions.extend([-1] * (500 - len(most_common_actions)))
        colors = ["lightgrey" if action == -1 else color_mapping[action] for action in most_common_actions]

        # Panel 0: modal recommendation for the first 50 steps.
        ax[0].bar([step + 1 for step in range(50)],
                  [1.05 for _ in most_common_actions[:50]],
                  color=colors[:50], width=1.0)
        ax[0].set_xlabel("Step")
        ax[0].set_yticks([])
        ax[0].set_ylim(0, 1.05)
        ax[0].set_ylabel(f'{politics_labels_capital[p]} Users',  rotation=0, fontsize=40, labelpad=175, horizontalalignment='right')
        ax[0].set_xticks([1, 10, 20, 30, 40, 50])

        # Panel 1: modal recommendation for all steps with the CTR curve on top.
        ax[1].bar([step + 1 for step in range(len(most_common_actions))],
                  [1.05 for _ in most_common_actions], color=colors, width=1.0)
        ax[1].plot([step + 1 for step in range(len(rolling_avg_ctrs))],
                   rolling_avg_ctrs, color='green', linewidth=3,
                   path_effects=[pe.Stroke(linewidth=5, foreground='white'), pe.Normal()])
        ax[1].set_ylabel("CTR")
        ax[1].set_xlabel("Step")
        ax[1].set_ylim(0, 1.05)
        ax[1].set_xticks([1, 100, 200, 300, 400, 500])

        # Panel 2: every user's belief trajectory plus the active-user count.
        for user_b in beliefs:
            ax[2].plot([step for step in range(len(user_b))],
                       user_b, color=politics_color_mapping[p])
        ax[2].set_ylim(-1, 1)
        ax[2].set_ylabel("Belief")

        ax2 = ax[2].twinx()
        # default=0 keeps an empty bucket from raising on max() of nothing.
        max_len = max((len(attrition) for attrition in attritions), default=0)
        user_counts = [
            np.sum([attrition[step] for attrition in attritions if len(attrition) > step])
            for step in range(max_len)
        ]
        ax2.plot(list(range(max_len)), user_counts, color='black')
        ax2.set_ylabel('Active Users')
        if attritions:
            ax2.set_ylim(0, len(attritions))
        ax[2].set_xlabel('Step')

        if p == 0:
            ax[0].set_title("Average Recommendations (First 50)", pad=100)
            ax[1].set_title("Average Recommendations and CTR", pad=100)
            ax[2].set_title("Belief Shift and Attrition", pad=100)
    plt.tight_layout()
    plt.show()

    starts = [history[0] for history in belief_histories.values()]
    ends = [history[-1] for history in belief_histories.values()]

    ## start vs end belief distribution
    df = pd.DataFrame()
    df['Belief'] = starts + ends
    df['Time'] = ['Start' for _ in range(len(starts))] + ['End' for _ in range(len(ends))]
    sns.displot(data=df, x='Belief',  hue='Time', kind="kde", rug=True, height=8, aspect=15/8)

    plt.title('Belief Distribution Change')
    plt.xlim(-1, 1)
    plt.show()

    ## start vs end belief scatter (pairs sorted by starting belief)
    starts, ends = (list(t) for t in zip(*sorted(zip(starts, ends))))
    plt.figure(figsize=(25, 5), dpi=80)
    plt.scatter(starts, starts)
    plt.scatter(starts, ends)
    plt.show()

    ## belief changes, coloured by each user's initial category (indices 0-6)
    plt.figure(figsize=(25, 10), dpi=80)
    for user in belief_histories.values():
        plt.plot([step for step in range(len(user))], user, color=color_mapping[np.argmax(encode_politics(user[0]))])
    plt.ylim(-1.05, 1.05)
    plt.title('Overall Belief Shifts')
    plt.ylabel('Belief')
    plt.xlabel('Step')
    plt.show()

    ## average shifts per initial category
    plt.figure(figsize=(20, 10), dpi=80)
    shifts = []
    for p in range(7):
        bucket_shifts = [end - start for start, end in zip(starts, ends)
                         if np.argmax(encode_politics(start)) == p]
        # An empty category contributes a zero-length bar instead of NaN,
        # which would otherwise make the axis limits invalid.
        shifts.append(np.mean(bucket_shifts) if bucket_shifts else 0.0)
    max_shift = np.max(np.abs(shifts))
    shift_norm = plt.Normalize(-max_shift, max_shift)
    cmap = plt.get_cmap("coolwarm")
    colors = cmap(shift_norm(shifts))
    plt.barh(politics_labels, shifts, color=colors)
    plt.xlabel('avg shift')
    plt.xlim(-max_shift - 0.05, max_shift + 0.05)
    plt.show()

    session_lengths = [len(h) for h in belief_histories.values()]
    print(f'avg session length: {np.mean(session_lengths)}')

In [None]:
def get_cumulative_clicks(trainer, users, n_users=100):
    """Roll out the trainer's policy and return cumulative mean clicks per step.

    Args:
        trainer: object exposing ``get_policy().compute_actions(obs_batch=...)``.
        users: list of user dicts (each with a 'user_id' key). The list is
            deep-copied before use. If falsy, ``n_users`` fresh
            ``User().__dict__`` dicts are generated instead.
        n_users: number of users generated when ``users`` is falsy.

    Returns:
        A list of 500 numbers. Entry ``k`` is the running sum, over steps
        ``0..k``, of the mean 'clicked' value among users whose session
        reached that step. Steps nobody reached add nothing.
    """
    NewsRecommendationEnv = get_env()
    users = copy.deepcopy(users) if users else [User().__dict__ for _ in range(n_users)]
    envs = [NewsRecommendationEnv({'eval':True}) for _ in range(len(users))]
    for user, env in zip(users, envs):
        user['env'] = env

    all_users = {user['user_id']: user for user in users}
    active_users = list(all_users.keys())
    click_histories = {user['user_id']: [] for user in users}

    states = np.array([user['env'].reset(user) for user in all_users.values()])

    # `states` stays aligned with `active_users` at every pass.
    while len(active_users) > 0:
        actions = trainer.get_policy().compute_actions(obs_batch=states)[0]
        states = []
        updated_active_users = []
        for user_id, action in zip(active_users, actions):
            user = all_users[user_id]
            env = user['env']
            state, reward, done, info = env.step(action)
            current_id = env.current_user['user_id']
            click_histories[current_id].append(env.current_session['clicked'])
            all_users[current_id] = env.current_user
            if not done:
                updated_active_users.append(current_id)
                states.append(state)
                user['env'] = env

        active_users = updated_active_users
        states = np.array(states)

    cumulative_clicks = []
    clicks = 0
    # NOTE(review): the 500-step cap is hard-coded - confirm it matches the
    # environment's maximum session length.
    for step in range(500):
        clicks_at_step = [history[step] for history in click_histories.values() if len(history) > step]
        if clicks_at_step:
            step_clicks = np.mean(clicks_at_step)
            if not np.isnan(step_clicks):
                clicks += step_clicks
        cumulative_clicks.append(clicks)
    return cumulative_clicks

In [None]:
# Cumulative clicks obtained by `random_trainer` on the shared evaluation users.
cumulative_clicks = get_cumulative_clicks(random_trainer, eval_users)
plt.plot(list(range(len(cumulative_clicks))), cumulative_clicks)

## saving

In [None]:
def save_results(session_histories, users_histories, file_name):
    """Write ``session_histories`` to '<file_name> - session_history.csv'.

    The records are loaded into a DataFrame and saved without the index
    column. ``users_histories`` is accepted but not used.
    """
    csv_path = f'{file_name} - session_history.csv'
    pd.DataFrame(session_histories).to_csv(csv_path, index=False)

## training function

In [None]:
def _save_training_checkpoint(trainer, session_history_tracker, agent_name, label):
    """Save the trainer and pickle the tracker's session histories; return them."""
    t = int(time.time())
    trainer.save(f'drive/MyDrive/thesis-training/{agent_name}-{t}/{label}')
    session_histories = ray.get(session_history_tracker.get_session_histories.remote())
    with open(f'drive/MyDrive/thesis-training/{agent_name}-{t}/{label}.pickle', 'wb') as handle:
        pickle.dump(session_histories, handle, protocol=pickle.HIGHEST_PROTOCOL)
    return session_histories


def train(n_episodes, session_history_tracker, horizon, agent_name, eval_users, eval_interval, log_interval, trainer, n=None, target_politics=None, mode=None, metrics_window=1000, save_policy=True, session_length=100, checkpoint_interval=100):
    """Run ``trainer.train()`` repeatedly with periodic checkpoints, logs and evals.

    Args:
        n_episodes: number of ``trainer.train()`` calls to make.
        session_history_tracker: Ray actor handle exposing
            ``get_session_histories``.
        horizon, target_politics, mode, session_length: forwarded to
            ``eval_trainer``.
        agent_name: prefix of the checkpoint directory under
            'drive/MyDrive/thesis-training/'.
        eval_users: users handed to ``eval_trainer``.
        eval_interval, log_interval, checkpoint_interval: an action fires on
            iteration ``i`` when ``i > 0`` and ``i`` is a multiple of the
            corresponding interval.
        trainer: object exposing ``train()`` and ``save(path)``.
        n, save_policy: accepted but not used.
        metrics_window: rolling window passed to ``plot_results``.

    Returns:
        The session histories fetched from the tracker when the final
        checkpoint is written.
    """
    train_start = time.time()
    episode_start = time.time()

    for i in range(n_episodes):
        result = trainer.train()

        if i % checkpoint_interval == 0 and i > 0:
            # Each checkpoint goes to its own timestamped directory.
            _save_training_checkpoint(trainer, session_history_tracker, agent_name, i)

        if i % log_interval == 0 and i > 0:
            # The episode length is only reported when the result carries it.
            message = f"episode: {i} - avg reward={result['episode_reward_mean']}"
            if 'episode_len_mean' in result:
                message += f" - avg episode length={result['episode_len_mean']}"
            message += f" - {(time.time() - episode_start)/60} min"
            print(message)
            episode_start = time.time()

        if i % eval_interval == 0 and i > 0:
            eval_trainer(trainer, eval_users, mode=mode, horizon=horizon, target_politics=target_politics, session_length=session_length)

    print(f'total time: {(time.time() - train_start)/60}')

    session_histories = _save_training_checkpoint(trainer, session_history_tracker, agent_name, 'final')

    plot_results(session_histories=session_histories, window=metrics_window)

    eval_trainer(trainer, eval_users, mode=mode, horizon=horizon, target_politics=target_politics, session_length=session_length)
    return session_histories

## train

In [None]:
# Fixed pool of 1000 user attribute dicts shared by every evaluation below.
eval_users = [vars(User()) for _ in range(1000)]

In [None]:
# Settings shared by the experiment cells below.
horizon = 100
session_length = 500
num_gpus = 1
num_workers = 8

### random

In [None]:
# Restart Ray before building the agent.
ray.shutdown()
ray.init()

actor_name = f"session_history_{time.time()}"
# The tracker actor is constructed without arguments.
session_history_tracker = SessionHistoryTracker.options(name=actor_name).remote()

random_trainer = RandomAgent(
    config={
        "env": get_env(),
        "env_config": {
            "horizon": horizon,
            # presumably the env uses this name to look up the tracker - verify
            "actor_name": actor_name,
            "target_politics": None,
            "mode": None,
            "eval": False,
            "session_length": session_length,
        },
        "num_workers": 8,
        "num_gpus": 1,
        "num_gpus_per_worker": 1/8,
        "num_envs_per_worker": 50,
    })
eval_trainer(random_trainer, eval_users, mode=None, horizon=horizon, target_politics=None, session_length=session_length)

### baseline

In [None]:
def test(row):
    """Return ``row`` unchanged."""
    return row

In [None]:
# Restart Ray before building the agent.
ray.shutdown()
ray.init()

actor_name = f"session_history_tracker_{time.time()}"
# The tracker actor is constructed without arguments.
session_history_tracker = SessionHistoryTracker.options(name=actor_name).remote()

baseline_trainer = BaselineAgent(
    config={
        "env": get_env(),
        "env_config": {
            "horizon": horizon,
            # presumably the env uses this name to look up the tracker - verify
            "actor_name": actor_name,
            "target_politics": None,
            "mode": None,
            "eval": False,
        },
        "num_workers": 8,
        "num_gpus": 1,
        "num_gpus_per_worker": 1/8,
        "num_envs_per_worker": 50,
    })

eval_trainer(baseline_trainer, eval_users, mode=None, horizon=horizon, target_politics=None, session_length=session_length)

In [None]:
# Scratch check: the builtin sum over a 2-D array adds its rows element-wise.
x = np.array([
    [1, 2, 3],
    [4, 5, 6],
])
sum(x)

In [None]:
# Scratch check: indices among the first seven entries whose value is below 5.
state = np.array(list(range(7)) + [10] * 7)
first_seven = state[:7]
print(np.where(first_seven < 5))
# np.argmin(state[:7]) if len(np.where(state[:7])<5)>0 else np.argmax(state[7:])

### no manipulation

In [None]:
# Restart Ray, then train an Ape-X agent with no manipulation target.
ray.shutdown()
ray.init()
horizon = 100
mode = None
target_politics = None
actor_name = f"session_history_tracker_{time.time()}"
session_history_tracker = SessionHistoryTracker.options(name=actor_name).remote()
np.random.seed(47)

trainer = ApexTrainer(
    config={
        "env": get_env(),
        "env_config": {
            "horizon": horizon,
            "actor_name": actor_name,
            "target_politics": target_politics,
            "mode": mode,
            "eval": False,
            "session_length": session_length,
        },
        "num_workers": num_workers,
        "num_gpus": num_gpus,
        "num_envs_per_worker": 100,
        "metrics_smoothing_episodes": 10000,
        "evaluation_interval": None,
        "batch_mode": "complete_episodes",
        "seed": 47,
        # "n_step": 1,
        "model": {
            "dim": 84,
            "fcnet_activation": "tanh",
            "fcnet_hiddens": [256, 256],
            "framestack": False,
            "use_lstm": False,
        },
    }
)

# eval_interval (251) never divides an iteration index in 1..299, so only
# the evaluation at the end of train() runs.
session_histories = train(
    n_episodes=300,
    trainer=trainer,
    session_history_tracker=session_history_tracker,
    agent_name="no_manipulation",
    horizon=horizon,
    session_length=session_length,
    eval_users=eval_users,
    eval_interval=251,
    checkpoint_interval=100,
    log_interval=25,
    metrics_window=1000,
)

### extremes

In [None]:
# Restart Ray, then run a short Ape-X training targeting both extremes.
ray.shutdown()
ray.init()
horizon = 100
mode = "manipulate"
target_politics = [0, 6]
actor_name = f"session_history_tracker_{time.time()}"
# The tracker actor is constructed without arguments.
session_history_tracker = SessionHistoryTracker.options(name=actor_name).remote()
np.random.seed(47)

config = {
    "env": get_env(),
    "env_config": {
        "horizon": horizon,
        "actor_name": actor_name,
        "target_politics": target_politics,
        "mode": mode,
        "eval": False,
        "session_length": session_length,
    },
    "num_workers": num_workers,
    "num_gpus": num_gpus,
    "num_envs_per_worker": 100,
    "metrics_smoothing_episodes": 10000,
    "evaluation_interval": None,
    "batch_mode": "complete_episodes",
    "seed": 47,
    # "n_step": 1,
    "model": {
        'dim': 84,
        'fcnet_activation': 'tanh',
        'fcnet_hiddens': [256, 256],
        'framestack': False,
    },
}

trainer = ApexTrainer(config=config)

session_histories = train(n_episodes=4,
      session_history_tracker=session_history_tracker,
      horizon=horizon,
      agent_name="extremes",
      eval_users=eval_users,
      eval_interval=1,
      checkpoint_interval=2,
      log_interval=1,
      metrics_window=1000,
      session_length=session_length,
      trainer=trainer)

In [None]:
# Uncomment to re-fetch the histories from the tracker actor instead of
# reusing the `session_histories` already in scope:
# session_histories = ray.get(session_history_tracker.get_session_histories.remote())

plot_results(session_histories, 1000)

eval_trainer(
    trainer,
    eval_users,
    mode=mode,
    horizon=horizon,
    target_politics=target_politics,
    session_length=session_length,
)

In [None]:
# Restart Ray, then train an Ape-X agent targeting both extremes.
ray.shutdown()
ray.init()
horizon = 100
mode = "manipulate"
target_politics = [0,6]
actor_name = f"session_history_tracker_{time.time()}"
# The tracker actor is constructed without arguments.
session_history_tracker = SessionHistoryTracker.options(name=actor_name).remote()

trainer = ApexTrainer(
    config={
        "env": get_env(),
        "env_config": {
            "horizon": horizon,
            "actor_name": actor_name,
            "target_politics": target_politics,
            "mode": mode,
            "eval": False,
            "session_length": session_length,
        },
        "num_workers": num_workers,
        "num_gpus": num_gpus,
        "num_envs_per_worker": 100,
        "metrics_smoothing_episodes": 10000,
        "evaluation_interval": None,
        # "soft_horizon": True,
        # "horizon": 25
    })

try:
    session_histories = train(n_episodes=500,
          session_history_tracker=session_history_tracker,
          horizon=horizon,
          agent_name="extremes",
          eval_users=eval_users,
          eval_interval=100,
          log_interval=1,
          metrics_window=1000,
          session_length=session_length,
          trainer=trainer)
except Exception as e:
    # Best effort: if training fails part-way, still plot and evaluate
    # whatever the tracker has collected so far.
    print(e)
    session_histories = ray.get(session_history_tracker.get_session_histories.remote())
    plot_results(session_histories=session_histories, window=1000)
    eval_trainer(trainer, eval_users, mode=mode, horizon=horizon, target_politics=target_politics, session_length=session_length)

### center

In [None]:
# Restart Ray, then train an Ape-X agent targeting the three central categories.
ray.shutdown()
ray.init()
horizon = 100
mode = "manipulate"
target_politics = [2,3,4]
actor_name = f"session_history_{time.time()}"
# The tracker actor is constructed without arguments.
session_history_tracker = SessionHistoryTracker.options(name=actor_name).remote()
np.random.seed(47)

trainer = ApexTrainer(
    config={
        "env": get_env(),
        "env_config": {
            "horizon": horizon,
            "actor_name": actor_name,
            "target_politics": target_politics,
            "mode": mode,
            "eval": False,
            "session_length": session_length,
        },
        "num_workers": num_workers,
        "num_gpus": num_gpus,
        "num_envs_per_worker": 100,
        "metrics_smoothing_episodes": 10000,
        "evaluation_interval": None,
        "batch_mode": "complete_episodes",
        "seed": 47,
        "model": {
            'dim': 84,
            'fcnet_activation': 'tanh',
            'fcnet_hiddens': [256, 256],
            'framestack': False,
        },
    })
# eval_interval (1001) exceeds n_episodes, so only the evaluation at the end
# of train() runs.
session_histories = train(n_episodes=1000,
      session_history_tracker=session_history_tracker,
      horizon=horizon,
      agent_name="center",
      eval_users=eval_users,
      eval_interval=1001,
      checkpoint_interval=100,
      log_interval=25,
      metrics_window=1000,
      session_length=session_length,
      trainer=trainer)

### right

In [None]:
# Restart Ray, then train an Ape-X agent targeting the three right-most categories.
ray.shutdown()
ray.init()
horizon = 100
mode = "manipulate"
target_politics = [4,5,6]
actor_name = f"session_history_{time.time()}"
# The tracker actor is constructed without arguments.
session_history_tracker = SessionHistoryTracker.options(name=actor_name).remote()
np.random.seed(47)

trainer = ApexTrainer(
    config={
        "env": get_env(),
        "env_config": {
            "horizon": horizon,
            "actor_name": actor_name,
            "target_politics": target_politics,
            "mode": mode,
            "eval": False,
            "session_length": session_length,
        },
        "num_workers": num_workers,
        "num_gpus": num_gpus,
        "num_envs_per_worker": 100,
        "metrics_smoothing_episodes": 10000,
        "evaluation_interval": None,
        "batch_mode": "complete_episodes",
        "seed": 47,
        "n_step": 1,
        "model": {
            'dim': 84,
            'fcnet_activation': 'tanh',
            'fcnet_hiddens': [256, 256],
            'framestack': False,
        },
    })
try:
    session_histories = train(n_episodes=500,
          session_history_tracker=session_history_tracker,
          horizon=horizon,
          agent_name="right",
          eval_users=eval_users,
          eval_interval=50,
          log_interval=25,
          metrics_window=1000,
          session_length=session_length,
          trainer=trainer)
except Exception as e:
    # Best effort: if training fails part-way, still plot and evaluate
    # whatever the tracker has collected so far.
    print(e)
    session_histories = ray.get(session_history_tracker.get_session_histories.remote())
    plot_results(session_histories=session_histories, window=1000)
    eval_trainer(trainer, eval_users, mode=mode, horizon=horizon, target_politics=target_politics, session_length=session_length)

### neutral

In [None]:
# Restart Ray, then train an Ape-X agent targeting the center category only.
ray.shutdown()
ray.init()
horizon = 100
mode = "manipulate"
target_politics = [3]
actor_name = f"session_history_{time.time()}"
# The tracker actor is constructed without arguments.
session_history_tracker = SessionHistoryTracker.options(name=actor_name).remote()

trainer = ApexTrainer(
    config={
        "env": get_env(),
        "env_config": {
            "horizon": horizon,
            "actor_name": actor_name,
            "target_politics": target_politics,
            "mode": mode,
            "eval": False,
        },
        "num_workers": num_workers,
        "num_gpus": num_gpus,
        "num_envs_per_worker": 50,
        "metrics_smoothing_episodes": 10000,
        "evaluation_interval": 1000000000,
    })


# train() requires the tracker handle; checkpoints are labelled "neutral"
# so they are not mixed up with the "right" experiment.
# NOTE(review): session_length is not passed here, so train() falls back to
# its default - confirm that is intended.
session_histories = train(n_episodes=250,
      session_history_tracker=session_history_tracker,
      horizon=horizon,
      agent_name="neutral",
      eval_users=eval_users,
      eval_interval=100,
      log_interval=5,
      metrics_window=1000,
      trainer=trainer)

### small reward increase for extreme content

In [None]:
# Restart Ray, then train an Ape-X agent targeting both extremes with only a
# small reward gap (1 vs 0.9) between targeted and other content.
ray.shutdown()
ray.init()
horizon = 100
mode = "manipulate"
target_politics = [0, 6]
actor_name = f"session_history_tracker_{time.time()}"
# The tracker actor is constructed without arguments.
session_history_tracker = SessionHistoryTracker.options(name=actor_name).remote()
np.random.seed(47)

trainer = ApexTrainer(
    config={
        "env": get_env(),
        "env_config": {
            "horizon": horizon,
            "actor_name": actor_name,
            "target_politics": target_politics,
            "mode": mode,
            "eval": False,
            "session_length": session_length,
            "target_politics_reward": 1,
            "other_reward": 0.9,
        },
        "num_workers": num_workers,
        "num_gpus": num_gpus,
        "num_envs_per_worker": 50,
        "metrics_smoothing_episodes": 10000,
        "evaluation_interval": None,
        "batch_mode": "complete_episodes",
        "seed": 47,
        # "n_step": 1,
        "model": {
            'dim': 84,
            'fcnet_activation': 'tanh',
            'fcnet_hiddens': [256, 256],
            'framestack': False,
        },
    })

# try:
session_histories = train(n_episodes=600,
      session_history_tracker=session_history_tracker,
      horizon=horizon,
      agent_name="extremes",
      eval_users=eval_users,
      eval_interval=50,
      log_interval=25,
      metrics_window=1000,
      session_length=session_length,
      trainer=trainer)
# except Exception as e:
#     print(e)
#     session_histories = ray.get(session_history_tracker.get_session_histories.remote())
#     plot_results(session_histories=session_histories, window=1000)
#     eval_trainer(trainer, eval_users, mode=mode, horizon=horizon, target_politics=target_politics, session_length=session_length)
