Add ASMR environment
bubble-07 committed Nov 19, 2023
1 parent 43e8112 commit 16d2f3f
Showing 17 changed files with 620 additions and 31 deletions.
61 changes: 61 additions & 0 deletions core/ASMRnet.py
@@ -0,0 +1,61 @@

from typing import List, Optional, Tuple
import torch
import torch.nn as nn

from dataclasses import dataclass
from typing import Callable

@dataclass
class ASMRNetConfig:
feature_maps: int
layers: int

def reset_model_weights(m):
reset_parameters = getattr(m, "reset_parameters", None)
if callable(reset_parameters):
m.reset_parameters()

class ResidualBlock(nn.Module):
    def __init__(self, features):
        super().__init__()
        self.block = nn.Sequential(
            nn.Linear(features, features),
            nn.LeakyReLU(),
            nn.Linear(features, features))
        self.leaky_relu = nn.LeakyReLU()

    def forward(self, x):
        residual = x
        out = self.block(x)
        out += residual
        # apply the activation after the skip connection
        out = self.leaky_relu(out)
        return out


class ASMRNet(nn.Module):
def __init__(self, config: ASMRNetConfig, input_shape: torch.Size, output_shape: torch.Size) -> None:
super().__init__()
        # unpack the single (flattened) input dimension
        self.input_features, = input_shape

self.input_layer = nn.Linear(self.input_features, config.feature_maps)

self.res_blocks = nn.Sequential(
*[ResidualBlock(config.feature_maps) for _ in range(config.layers)])

self.policy_head = nn.Sequential(
nn.Linear(config.feature_maps, output_shape[0])
# we use cross-entropy loss so no need for softmax
)

self.value_head = nn.Sequential(
nn.Linear(config.feature_maps, 1)
)

self.config = config

def forward(self, x):
x = self.input_layer(x)
x = self.res_blocks(x)
policy = self.policy_head(x)
value = self.value_head(x)
return policy, value
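
As a quick sanity check of the new network, here is a minimal usage sketch. The config values and tensor shapes below are illustrative assumptions, not taken from the commit; only the ASMRNet/ASMRNetConfig interface comes from core/ASMRnet.py above.

import torch
from core.ASMRnet import ASMRNet, ASMRNetConfig

# Hypothetical sizes: a 58-dim flattened state and a 145-dim policy.
config = ASMRNetConfig(feature_maps=128, layers=4)
net = ASMRNet(config, input_shape=torch.Size((58,)), output_shape=torch.Size((145,)))

states = torch.randn(32, 58)          # batch of flattened environment states
policy_logits, values = net(states)   # shapes (32, 145) and (32, 1)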
5 changes: 2 additions & 3 deletions core/algorithms/mcts.py
@@ -174,6 +174,8 @@ def evaluate(self, evaluation_fn: Callable) -> Tuple[torch.Tensor, Optional[torc
# choose next action with PUCT scores
actions = self.choose_action()

legal_actions = self.env.get_legal_actions()

# make a step in the environment with the chosen actions
self.env.step(actions)

@@ -209,8 +211,6 @@ def evaluate(self, evaluation_fn: Callable) -> Tuple[torch.Tensor, Optional[torc
with torch.no_grad():
policy_logits, values = evaluation_fn(self.env)

legal_actions = self.env.get_legal_actions()

policy_logits = (policy_logits * legal_actions) + (torch.finfo(torch.float32).min * (~legal_actions))
# store the policy
self.p_vals[self.env_indices, self.cur_nodes] = torch.softmax(policy_logits, dim=1)
@@ -306,4 +306,3 @@ def load_subtree(self, actions: torch.Tensor):
self.max_depths -= 1
self.max_depths.clamp_(min=1)


8 changes: 7 additions & 1 deletion core/train/trainer.py
@@ -124,7 +124,13 @@ def training_step(self):
# consistent with open_spiel implementation
policy_logits = (policy_logits * legal_actions) + (torch.finfo(torch.float32).min * (~legal_actions))

policy_loss = self.config.policy_factor * torch.nn.functional.cross_entropy(policy_logits, target_policy)
        policy_loss = self.config.policy_factor * torch.nn.functional.cross_entropy(policy_logits, target_policy, label_smoothing=0.01)
        # multiply by 2 since most other implementations have values ranging from -1 to 1, whereas ours range from 0 to 1
# this makes values loss a bit more comparable
value_loss = torch.nn.functional.mse_loss(values.flatten() * 2, target_value * 2)
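
The masking above (the same pattern appears in core/algorithms/mcts.py) pushes illegal logits to the most negative float32 value so they receive essentially zero probability under softmax and cross-entropy. A standalone sketch of the pattern with made-up tensors:

import torch
import torch.nn.functional as F

policy_logits = torch.tensor([[1.0, 2.0, 0.5, -0.3]])
legal_actions = torch.tensor([[True, False, True, True]])
target_policy = torch.tensor([[0.6, 0.0, 0.3, 0.1]])   # e.g. normalized visit counts

# Illegal entries become float32 min, so softmax assigns them ~0 probability.
masked = (policy_logits * legal_actions) + (torch.finfo(torch.float32).min * (~legal_actions))
loss = F.cross_entropy(masked, target_policy)           # soft-target cross-entropy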
2 changes: 1 addition & 1 deletion core/utils/custom_activations.py
@@ -28,4 +28,4 @@ def load_activation(activation: str) -> Optional[torch.nn.Module]:
if activation != '':
logging.warn(f'Warning: activation {activation} not found')
logging.warn(f'No activation will be applied to value head')
return None
return None
Empty file added envs/ASMR/__init__.py
Empty file.
26 changes: 26 additions & 0 deletions envs/ASMR/collector.py
@@ -0,0 +1,26 @@
from typing import Optional
import torch
from core.train.collector import Collector
from core.algorithms.evaluator import Evaluator

class ASMRCollector(Collector):
def __init__(self,
evaluator: Evaluator,
episode_memory_device: torch.device
) -> None:
super().__init__(evaluator, episode_memory_device)

def assign_rewards(self, terminated_episodes, terminated):
episodes = []
for episode in terminated_episodes:
episode_with_rewards = []
moves = len(episode)
for (inputs, visits, legal_actions) in episode:
episode_with_rewards.append((inputs, visits, torch.tensor(moves, dtype=torch.float32, requires_grad=False, device=inputs.device), legal_actions))
moves -= 1
episodes.append(episode_with_rewards)
return episodes

def postprocess(self, terminated_episodes):
inputs, probs, rewards, legal_actions = zip(*terminated_episodes)
return list(zip(inputs, probs, rewards, legal_actions))
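
assign_rewards labels each step of a finished episode with the number of moves remaining in that episode: the first step gets len(episode) and the last step gets 1. A tiny worked illustration of the countdown (the episode contents are irrelevant to the reward values):

# For a 3-step episode, the rewards attached above are 3.0, 2.0, 1.0.
moves = 3
rewards = []
for _ in range(moves):
    rewards.append(float(moves))
    moves -= 1
# rewards == [3.0, 2.0, 1.0]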
174 changes: 174 additions & 0 deletions envs/ASMR/env.py
@@ -0,0 +1,174 @@

from dataclasses import dataclass
from typing import Optional, Tuple
import torch
from core.env import Env, EnvConfig
from .torchscripts import push_actions, get_legal_actions, get_rewards, generate_random_games

@dataclass
class ASMREnvConfig(EnvConfig):
# The dimension of one of the matrices in the set
# (matrices are matrix_dim x matrix_dim)
matrix_dim: int = 2
# The minimum number of matrices in the starting set
min_initial_set_size: int = 3
# The maximum number of matrices in the starting set
max_initial_set_size: int = 7
# The standard deviation of the normal distribution which
# is used to generate elements of the initial and
# target matrices
normal_std_dev: float = 0.1
# The maximum number of turns to take
max_num_turns: int = 6
# The discount factor
discount_factor: float = 0.01

class ASMREnv(Env):
def __init__(self,
parallel_envs: int,
config: ASMREnvConfig,
device: torch.device,
debug=False
) -> None:
self.matrix_dim = config.matrix_dim
self.min_initial_set_size = config.min_initial_set_size
self.max_initial_set_size = config.max_initial_set_size
self.normal_std_dev = config.normal_std_dev
self.max_num_turns = config.max_num_turns
self.discount_factor = config.discount_factor

# Derived information

self.max_ending_set_size = self.max_num_turns + self.max_initial_set_size
self.max_policy_dim_size = self.max_ending_set_size - 1

# The total number of matrices, including the active
# set and the target matrix
self.total_num_matrices = self.max_ending_set_size + 1

# The size of the state vector
self.state_vector_size = self.total_num_matrices * self.matrix_dim * self.matrix_dim
self.state_vector_size += 1 # for the current set size tracker
self.state_vector_size += 1 # for the current turn counter

# The policy encompasses every index-pair, together with a
# probability of stopping
self.policy_vector_size = (self.max_policy_dim_size ** 2) + 1

super().__init__(
parallel_envs=parallel_envs,
config=config,
device=device,
num_players=1,
state_shape=torch.Size((self.state_vector_size, )),
policy_shape=torch.Size((self.policy_vector_size,)),
value_shape=torch.Size((1,)),
debug=debug
)

if self.debug:
self.get_legal_actions_ts = get_legal_actions
self.push_actions_ts = push_actions
self.get_rewards_ts = get_rewards
self.generate_random_games_ts = generate_random_games
else:
self.get_legal_actions_ts = torch.jit.trace(get_legal_actions, ( # type: ignore
self.states,
self.max_policy_dim_size,
))

self.push_actions_ts = torch.jit.trace(push_actions, ( # type: ignore
self.states,
torch.zeros((self.parallel_envs, ), dtype=torch.int64, device=device)
))

self.get_rewards_ts = torch.jit.trace(get_rewards, ( # type: ignore
self.states,
self.discount_factor
))
self.generate_random_games_ts = torch.jit.trace(generate_random_games, ( # type: ignore
self.parallel_envs,
self.matrix_dim,
self.min_initial_set_size,
self.max_initial_set_size,
self.normal_std_dev,
self.total_num_matrices
))

self.saved_states = self.states.clone()

def reset(self, seed: Optional[int] = None) -> int:
self.states.zero_()
self.terminated.fill_(True)
return self.reset_terminated_states(seed)

def reset_terminated_states(self, seed: Optional[int] = None) -> int:
if seed is not None:
torch.manual_seed(seed)
else:
seed = 0
# Zeros the states which are terminated
self.states *= torch.logical_not(self.terminated).view(self.parallel_envs, 1)

# Find the total number of terminated states
num_terminated_states = torch.sum(self.terminated)

if num_terminated_states > 0:
# Re-initialize the terminated states
random_games = self.generate_random_games_ts(
num_terminated_states,
self.matrix_dim,
self.min_initial_set_size,
self.max_initial_set_size,
self.normal_std_dev,
self.total_num_matrices,
self.states.get_device()
)

self.states[self.terminated] = random_games

# Clears the terminated mask, since presumably, all states have
# been correctly reset
self.terminated.zero_()
return seed

    def next_turn(self):
        # No additional per-turn state updates are needed for this
        # single-player environment.
        return

def get_rewards(self, player_ids: Optional[torch.Tensor] = None) -> torch.Tensor:
return self.get_rewards_ts(self.states, self.total_num_matrices, self.matrix_dim, self.discount_factor)

def update_terminated(self) -> None:
self.terminated = self.is_terminal()

    def is_terminal(self):
        # An environment is terminal once it has run out of turns, or once a
        # "terminate" action has already been recorded in self.terminated.
        out_of_turns = self.states[:, -1] >= self.max_num_turns
        return torch.logical_or(self.terminated, out_of_turns)

def get_legal_actions(self) -> torch.Tensor:
# Gets the legal actions for the current state
return self.get_legal_actions_ts(self.states, self.max_policy_dim_size) # type: ignore

def push_actions(self, actions) -> None:
# Updates the state in response to an action
self.states, terminate_actions = self.push_actions_ts(self.states, actions,
self.total_num_matrices, self.matrix_dim,
self.max_policy_dim_size) # type: ignore
self.terminated = torch.logical_or(self.terminated, terminate_actions)

def save_node(self) -> torch.Tensor:
return self.states.clone()

def load_node(self, load_envs: torch.Tensor, saved: torch.Tensor):
load_envs_expnd = load_envs.view(self.parallel_envs, 1)
self.states = saved.clone() * load_envs_expnd + self.states * (~load_envs_expnd)
self.update_terminated()

    def print_state(self, action=None) -> None:
        assert self.parallel_envs == 1
        # minimal debug rendering of the single environment's state
        print(self.states, action)
7 changes: 7 additions & 0 deletions envs/ASMR/tester.py
@@ -0,0 +1,7 @@
from core.test.tester import Tester

class ASMRTester(Tester):
def add_evaluation_metrics(self, episodes):
if self.history is not None:
for _ in episodes:
self.history.add_evaluation_data({}, log=self.log_results)
