In [1]:
!pip install highway-env gymnasium torch numpy tqdm
!sudo apt-get update
!sudo apt-get install build-essential swig
!pip install gymnasium[box2d]

Collecting highway-env
  Downloading highway_env-1.10.1-py3-none-any.whl.metadata (16 kB)
Collecting gymnasium
  Downloading gymnasium-1.0.0-py3-none-any.whl.metadata (9.5 kB)
Downloading highway_env-1.10.1-py3-none-any.whl (104 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m105.0/105.0 kB[0m [31m3.6 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading gymnasium-1.0.0-py3-none-any.whl (958 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m958.1/958.1 kB[0m [31m22.2 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: gymnasium, highway-env
  Attempting uninstall: gymnasium
    Found existing installation: gymnasium 0.29.0
    Uninstalling gymnasium-0.29.0:
      Successfully uninstalled gymnasium-0.29.0
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
kaggle-environments 1.16.10 requires gymnasiu

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class AlphaZeroNetwork(nn.Module):
    def __init__(self, input_shape, n_residual_layers=10, n_actions=5):
        super(AlphaZeroNetwork, self).__init__()
        self.input_shape = input_shape
        self.n_actions = n_actions

        # Convolution đầu tiên
        self.conv_layer = nn.Conv2d(input_shape[2], 128, kernel_size=11, padding=5)
        self.batch_norm = nn.BatchNorm2d(128)

        # Residual layers
        self.residual_layers = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(128, 128, kernel_size=7, padding=3),
                nn.BatchNorm2d(128),
                nn.ReLU(),
                nn.Conv2d(128, 128, kernel_size=3, padding=1),
                nn.BatchNorm2d(128)
            )
            for _ in range(n_residual_layers)
        ])

        # Value head
        self.value_conv = nn.Conv2d(128, 1, kernel_size=1)
        self.value_bn = nn.BatchNorm2d(1)
        self.fc_input_size = input_shape[0] * input_shape[1]

        self.value_fc1 = nn.Linear(self.fc_input_size, 128)
        self.value_dropout = nn.Dropout(0.1)  # Thêm Dropout
        self.value_fc2 = nn.Linear(128, 1)

        # Policy head
        self.policy_conv = nn.Conv2d(128, 2, kernel_size=1)
        self.policy_bn = nn.BatchNorm2d(2)
        self.policy_fc = nn.Linear(self.fc_input_size * 2, n_actions)
        self.policy_dropout = nn.Dropout(0.1)  # Thêm Dropout

    def forward(self, x):
        # Convolution đầu tiên và batch normalization
        x = F.relu(self.batch_norm(self.conv_layer(x)))

        # Residual layers
        for residual in self.residual_layers:
            residual_x = x
            x = residual(x) + residual_x
            x = F.relu(x)

        # Value head
        value = F.relu(self.value_bn(self.value_conv(x)))
        value = value.reshape(value.size(0), -1)
        value = F.relu(self.value_fc1(value))
        value = self.value_dropout(value)  # Áp dụng Dropout
        value = torch.tanh(self.value_fc2(value))

        # Policy head
        policy = F.relu(self.policy_bn(self.policy_conv(x)))
        policy = policy.reshape(policy.size(0), -1)
        policy = self.policy_fc(self.policy_dropout(policy))  # Áp dụng Dropout
        policy = F.softmax(policy, dim=1)

        return policy, value


In [3]:
class KinematicToGridWrapper:
    def __init__(self):
        # Define grid parameters
        self.x_range = (-30, 90)  # meters relative to ego
        self.y_range = (-10, 10)   # meters relative to ego

        # Calculate grid size based on 1m per cell
        self.grid_size = (
            self.x_range[1] - self.x_range[0],  # 120 cells for x
            self.y_range[1] - self.y_range[0]   # 20 cells for y
        )

        # Car dimensions
        self.car_length = 5  # meters
        self.car_width = 2   # meters

    def get_car_footprint(self, x, y, heading):
        """Calculate which cells a car occupies given its center and heading"""
        occupied_cells = []

        cos_h = np.cos(heading)
        sin_h = np.sin(heading)

        corners_car = [
            (-self.car_length/2, -self.car_width/2),
            (self.car_length/2, -self.car_width/2),
            (self.car_length/2, self.car_width/2),
            (-self.car_length/2, self.car_width/2)
        ]

        corners_world = [
            (x + dx*cos_h - dy*sin_h, y + dx*sin_h + dy*cos_h)
            for dx, dy in corners_car
        ]

        min_x = min(x[0] for x in corners_world)
        max_x = max(x[0] for x in corners_world)
        min_y = min(x[1] for x in corners_world)
        max_y = max(x[1] for x in corners_world)

        for cell_x in range(int(min_x), int(max_x) + 1):
            for cell_y in range(int(min_y), int(max_y) + 1):
                if self.point_in_rotated_rect(
                    cell_x + 0.5, cell_y + 0.5,
                    x, y, heading,
                    self.car_length, self.car_width
                ):
                    occupied_cells.append((cell_x, cell_y))

        return occupied_cells

    def world_to_grid(self, x, y):
        """Convert world coordinates to grid coordinates"""
        grid_x = int(x - self.x_range[0])
        # Flip y-axis to maintain correct orientation
        grid_y = int(self.grid_size[1] - (y - self.y_range[0]) - 1)
        return grid_x, grid_y

    def point_in_rotated_rect(self, px, py, rect_x, rect_y, rect_angle, length, width):
        dx = px - rect_x
        dy = py - rect_y

        cos_h = np.cos(-rect_angle)
        sin_h = np.sin(-rect_angle)

        rotated_x = dx * cos_h - dy * sin_h
        rotated_y = dx * sin_h + dy * cos_h

        return (abs(rotated_x) <= length/2) and (abs(rotated_y) <= width/2)

    def process_observation(self, obs, left_bound, right_bound):
        """
        Process vehicle observations and return separate ego info and occupancy grid
        obs: list of [x, y, vx, vy, heading] for each vehicle (ego first)
        """
        # Extract ego vehicle state
        ego_x, ego_y, ego_vx, ego_vy, ego_heading = obs[0]

        # Initialize grid
        grid = np.zeros((self.grid_size[0], self.grid_size[1], 3), dtype=np.float32)

        to_left =  ego_y - left_bound
        to_right = right_bound - ego_y

        left = int(self.grid_size[1]/2 -1 - to_left)
        right = int(self.grid_size[1]/2 +1 + to_right)

        if left >= 0:
            grid[:, :left + 1, 0] = 2
            grid[:, :left + 1, 2] = ego_vy
        if right < self.grid_size[1]:
            grid[:, right:, 0] = 2
            grid[:, right:, 2] = ego_vy

        # Place ego vehicle
        ego_cells = self.get_car_footprint(0, 0, ego_heading)
        for cell_x, cell_y in ego_cells:
            grid_x, grid_y = self.world_to_grid(cell_x, cell_y)

            if (0 <= grid_x < self.grid_size[0] and
                0 <= grid_y < self.grid_size[1]):
                grid[grid_x, grid_y, 0] = 1
                grid[grid_x, grid_y, 1] = 0
                grid[grid_x, grid_y, 2] = 0

        # Process other vehicles
        for vehicle in obs[1:]:
            x, y, vx, vy, heading = vehicle

            # Get relative position
            rel_x = x - ego_x
            rel_y = y - ego_y

            # Get relative velocities
            rel_vx = vx - ego_vx
            rel_vy = vy - ego_vy

            # Get relative heading

            # Skip if vehicle center is out of range
            if (rel_x < self.x_range[0] or rel_x > self.x_range[1] or
                rel_y < self.y_range[0] or rel_y > self.y_range[1]):
                continue

            # Get all cells occupied by this vehicle
            occupied_cells = self.get_car_footprint(rel_x, rel_y, heading)

            # Convert to grid coordinates and update grid
            for cell_x, cell_y in occupied_cells:
                grid_x, grid_y = self.world_to_grid(cell_x, cell_y)

                if (0 <= grid_x < self.grid_size[0] and
                    0 <= grid_y < self.grid_size[1]):
                    grid[grid_x, grid_y, 0] = 2
                    grid[grid_x, grid_y, 1] = rel_vx
                    grid[grid_x, grid_y, 2] = rel_vy

        return np.array(grid).transpose(2,0,1)
converter = KinematicToGridWrapper()

In [4]:
import numpy as np

import numpy as np

def init_stack_of_planes(env, history_length=5):
    """
    Initialize a stack of planes representing the grid and velocity information.

    Args:
        env: The environment instance, expected to have the method `unwrapped.observation_type.observe()`
             to return the observation.
        history_length (int): Number of past frames to stack.

    Returns:
        A numpy array of shape (history_length + 2, grid_height, grid_width) representing the stack of planes.
    """
    # Initialize the KinematicToGridWrapper
    converter = KinematicToGridWrapper()

    # Collect the initial observations and process each to form planes
    grid_planes_list = []
    for _ in range(history_length):
        obs = env.unwrapped.observation_type.observe()
        grid_planes = converter.process_observation(obs, -2, 4 * (4 - 1) + 2)[0]
        grid_planes_list.append(grid_planes)

    # Convert list to numpy array with shape (history_length, grid_height, grid_width)
    stack_of_planes = np.stack(grid_planes_list, axis=0)

    # Add two additional planes: obs[1] and obs[2] from the last observation
    last_obs = converter.process_observation(env.unwrapped.observation_type.observe(), -2, 4 * (4 - 1) + 2)
    plane_6 = last_obs[1]  # Second plane of the latest observation
    plane_7 = last_obs[2]  # Third plane of the latest observation

    # Concatenate the new planes to the stack
    stack_of_planes = np.concatenate(
        [stack_of_planes, np.expand_dims(plane_6, axis=0), np.expand_dims(plane_7, axis=0)], axis=0
    )

    return stack_of_planes


import numpy as np

def get_stack_of_planes(env, old_state, history_length=5):
    """
    Update the stack of planes based on the latest observation.

    Args:
        env: The environment instance, expected to have the method `unwrapped.observation_type.observe()`
             to return the observation.
        old_state: A numpy array of shape (history_length + 2, grid_height, grid_width) representing the previous state.
        history_length (int): Number of past frames to stack.

    Returns:
        A numpy array of shape (history_length + 2, grid_height, grid_width) representing the updated stack of planes.
    """
    # Initialize the KinematicToGridWrapper
    converter = KinematicToGridWrapper()

    # Get new observation and process it to form multiple planes
    new_obs = converter.process_observation(env.unwrapped.observation_type.observe(), -2, 4 * (4 - 1) + 2)
    new_obs_plane_0 = new_obs[0]  # First plane
    new_obs_plane_1 = new_obs[1]  # Second plane
    new_obs_plane_2 = new_obs[2]  # Third plane

    # Shift old_state to remove the oldest frame and append the new one
    stack_of_planes = np.roll(old_state, shift=-1, axis=0)
    stack_of_planes[-3] = new_obs_plane_0  # Update the third-last plane
    stack_of_planes[-2] = new_obs_plane_1  # Update the second-last plane
    stack_of_planes[-1] = new_obs_plane_2  # Update the last plane

    return stack_of_planes


In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
duration = 15

In [6]:
import numpy as np

def softmax_policy(policy, available_actions):
    """
    Áp dụng softmax cho các xác suất trong policy dựa trên available_actions.

    :param policy: Dictionary chứa 5 hành động [0, 1, 2, 3, 4] với các xác suất tương ứng.
    :param available_actions: Danh sách các hành động có thể thực hiện (subset của [0, 1, 2, 3, 4]).
    :return: Dictionary chứa xác suất mới cho từng hành động (softmax áp dụng với các hành động khả dụng).
    """
    # Lấy các giá trị xác suất tương ứng với available_actions
    available_probs = np.array([policy[action] for action in available_actions])

    # Áp dụng softmax chỉ trên available_probs
    softmax_probs = available_probs / np.sum(available_probs)

    # Cập nhật xác suất mới
    updated_policy = {action: 0.0 for action in policy}  # Khởi tạo tất cả xác suất bằng 0
    for action, prob in zip(available_actions, softmax_probs):
        updated_policy[action] = prob

    return updated_policy

In [7]:
import numpy as np
import gymnasium as gym
import highway_env
config = {
    "observation": {
        "type": "Kinematics",
        "vehicles_count": 8,
        "features": ["x", "y", "vx", "vy", "heading"],
        "absolute": True,
        "normalize": False,
        "order": "sorted",
    },
    "action": {
        "type": "DiscreteMetaAction",
        "target_speeds": np.linspace(10, 30, 5)
    },
    "lanes_count": 4,
    "vehicles_density": 1.6+np.random.rand(),
}
env = gym.make("highway-fast-v0", config=config, render_mode='rgb_array')
env.reset()

(array([[133.61937 ,   4.      ,  25.      ,   0.      ,   0.      ],
        [142.21725 ,   8.      ,  22.122345,   0.      ,   0.      ],
        [151.6266  ,   0.      ,  22.895647,   0.      ,   0.      ],
        [160.57358 ,   0.      ,  22.1782  ,   0.      ,   0.      ],
        [168.86444 ,   0.      ,  21.171537,   0.      ,   0.      ],
        [177.8137  ,   0.      ,  22.887634,   0.      ,   0.      ],
        [187.06297 ,   8.      ,  23.109152,   0.      ,   0.      ],
        [196.75107 ,   4.      ,  23.425312,   0.      ,   0.      ]],
       dtype=float32),
 {'speed': 25,
  'crashed': False,
  'action': 3,
  'rewards': {'collision_reward': 0.0,
   'right_lane_reward': 0.3333333333333333,
   'high_speed_reward': 0.5,
   'on_road_reward': 1.0}})

In [8]:
import numpy as np
import copy

class MCTSNode:
    def __init__(self, env, parent, parent_action, prior_prob, c_puct=2.5):
        self.env = copy.deepcopy(env)
        self.parent = parent #parent node
        self.parent_action = parent_action
        self.children = {} #parent.children[action] = child
        self._n = 0
        self._W = 0
        self._P = prior_prob #Xác suất thực hiện hành động parent_action tại parent_node
        self.c_puct = c_puct
        min_speed = self.env.unwrapped.road.vehicles[0].target_speeds[0]
        max_speed = self.env.unwrapped.road.vehicles[0].target_speeds[-1]
        self.speed_bonus = (self.env.unwrapped.road.vehicles[0].speed - min_speed)/(max_speed - min_speed)
        self.collision = 0
        self.brake_penalty = 0
        if self.parent_action==4:
            self.brake_penalty = 1
        if self.env.unwrapped.road.vehicles[0].crashed:
            self.collision = 1 + 2*self.speed_bonus
        if self.parent is None:
            self.stack_of_planes = init_stack_of_planes(self.env)
        else:
            self.stack_of_planes = get_stack_of_planes(self.env, self.parent.stack_of_planes)
    def pucb_score(self):
        """
        Tính PUCB của node
        """
        if self._n == 0:
            Q = 0
        else:
            Q = self._W / self._n

        return Q + self.c_puct * self._P * np.sqrt(self.parent._n) / (1 + self._n) + 0.5*self.speed_bonus - 0.4*self.collision - 0.2*self.brake_penalty
    def select(self):
        """
        Chọn node có UCB lớn nhất
        """
        if not self.children:  # Nếu không có node con
            return None  # Hoặc raise Exception("No children nodes to select from")
        return max(self.children.values(), key=lambda child: child.pucb_score())

    def expand(self, action_priors):
        """
        Mở rộng cây bằng cách tạo node con
        action_priors là một dictionary chứa các xác suất prior của các action
        """
        for action, prob in action_priors.items():
            if action not in self.children and prob>0:
                #print(f"expanded {action} with {prob}")
                copy_env = copy.deepcopy(self.env)
                copy_env.step(action)
                self.children[action] = MCTSNode(copy_env, self, action, prob)

    def is_leaf(self):
        """
        Kiểm tra node có phải là leaf không
        """
        return self.children == {}
    def backpropagate(self, result):
        """
        Cập nhật visit count n và tổng điểm W
        new Q = new W/ new n
        """
        self._n += 1
        self._W += result
    def backpropagate_recursive(self, result):
        """
        Cập nhật toàn bộ đường đi từ node hiện tại đến root
        """
        if self.parent:
            self.parent.backpropagate_recursive(result)
        self.backpropagate(result)

In [9]:
import copy

class MCTS:
    def __init__(self, root, network, c_puct=3.5, n_simulations=10, min_average_speed=23, duration=12):
        self.c_puct = c_puct
        self.root = root
        self._network = network.to(device)
        self._n_simulations = n_simulations
        self.ego_init_position = root.env.unwrapped.road.vehicles[0].position[0]
        self.min_average_speed = min_average_speed
        self.duration = duration
    def traverse_to_leaf(self):
        node = self.root
        while not node.is_leaf():
            node = node.select()
        return node

    def rollout(self):
        leaf_node = self.traverse_to_leaf()
        truncated = leaf_node.env.unwrapped._is_truncated() # True nếu hoàn thành episode
        crashed = leaf_node.env.unwrapped.road.vehicles[0].crashed # True nếu ego-vehicle xảy ra va chạm
        leaf_state = leaf_node.stack_of_planes
        state_tensor = torch.tensor(leaf_state, dtype=torch.float32).unsqueeze(0).to(device)
        predicted_policy, predicted_value = self._network(state_tensor)
        predicted_policy = {action: prob for action, prob in enumerate(predicted_policy.squeeze().tolist())}
        #print(predicted_policy)
        available_actions = leaf_node.env.unwrapped.get_available_actions()
        #print(available_actions)
        updated_policy = softmax_policy(predicted_policy, available_actions)
        #print(updated_policy)
        predicted_value = predicted_value.item()
        if not truncated and not crashed:
            leaf_node.expand(updated_policy)
        elif truncated:
            ego_last_position = leaf_node.env.unwrapped.road.vehicles[0].position[0]
            ego_average_speed = (ego_last_position - self.ego_init_position)/(self.duration-1)
            confidence_score = ego_average_speed / self.min_average_speed
            if confidence_score >=1.0:
                predicted_value = 1.0
            else:
                predicted_value = 0.0
        elif crashed:
            predicted_value = -1.0
        leaf_node.backpropagate_recursive(predicted_value)

    def move_to_new_root(self, action):
        """
        Chuyển gốc của cây MCTS tới node con tương ứng với hành động được chọn.
        """
        if action in self.root.children:
            self.root = self.root.children[action]  # Di chuyển gốc tới node con
            self.root.parent = None  # Ngắt liên kết với node cha để giảm bộ nhớ
        else:
            # Nếu node con không tồn tại, khởi tạo lại cây tại node gốc mới
            raise ValueError("Hành động không có trong cây hiện tại.")

In [10]:
import torch
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler
from torch.optim.lr_scheduler import ReduceLROnPlateau

import torch

class AlphaZeroTrainer:
    def __init__(self, network, env, c_puct=2, n_simulations=10, learning_rate=0.001, batch_size=32, epochs=10):
        self.network = network  # AlphaZeroNetwork
        self.env = env
        self.c_puct = c_puct
        self.n_simulations = n_simulations
        self.optimizer = optim.Adam(self.network.parameters(), lr=learning_rate)
        self.batch_size = batch_size
        self.epochs = epochs
        self.scheduler = ReduceLROnPlateau(self.optimizer, mode='min', factor=0.5, patience=3, verbose=True)
        self.training_data = []  # Lưu trữ dữ liệu huấn luyện dạng (state, policy, value, action)
        self.action_list = []

    def self_play(self, seed=21):
        """
        Tạo dữ liệu huấn luyện thông qua self-play với MCTS.
        """
        # Khởi tạo lại môi trường và trạng thái ban đầu
        self.env.reset(seed=seed)
        state = init_stack_of_planes(env)
        done = self.env.unwrapped._is_truncated() or self.env.unwrapped._is_terminated()

        # Tạo gốc của cây MCTS
        root_node = MCTSNode(self.env, parent=None, parent_action=None, prior_prob=1.0, c_puct=self.c_puct)
        mcts = MCTS(root=root_node, network=self.network, c_puct=self.c_puct, n_simulations=self.n_simulations, duration=duration)

        while not done:
            # Thực hiện MCTS rollout để tính xác suất hành động
            state = get_stack_of_planes(env, state)
            for _ in range(self.n_simulations):
                mcts.rollout()
            # Thu thập xác suất hành động và giá trị của trạng thái hiện tại
            action_probs = {action: 0.0 for action in range(5)}  # Khởi tạo xác suất của tất cả hành động là 0
            for action, child in root_node.children.items():
                action_probs[action] = child._n / (root_node._n - 1)
            state_tensor = torch.tensor(state, dtype=torch.float32).unsqueeze(0)
            #print(state_tensor.shape)
            predicted_value = root_node._W / root_node._n if root_node._n > 0 else 0

            # Lưu trữ dữ liệu huấn luyện


            # Chọn hành động dựa trên xác suất từ MCTS
            action = max(action_probs, key=action_probs.get)
            self.action_list.append(action)
            self.env.step(action)
            #print(f"action chosen: {action}")
            self.training_data.append((state_tensor, action_probs, predicted_value, action))
            #(env.unwrapped.road.vehicles[0].target_lane_index[2])

            # Di chuyển gốc của MCTS đến node con tương ứng với hành động được chọn
            if action in root_node.children:
                mcts.move_to_new_root(action)
                root_node = mcts.root  # Cập nhật root_node cho vòng lặp kế tiếp
            else:
                raise ValueError("Action không tồn tại trong cây MCTS.")

            # Cập nhật trạng thái và kiểm tra điều kiện kết thúc
            done = self.env.unwrapped._is_truncated() or self.env.unwrapped._is_terminated()
        print("end self-play")

    def train(self):
        self.network.to(device)
        states, policies, values, actions = zip(*self.training_data)
        states = torch.cat(states).to(device)
        policies = torch.tensor([list(policy.values()) for policy in policies], dtype=torch.float32).to(device)
        values = torch.tensor(values, dtype=torch.float32).unsqueeze(1).to(device)
        actions = torch.tensor(actions, dtype=torch.long).to(device)

        # Weighted sampling
        class_counts = torch.bincount(actions)
        class_weights = 1.0 / (class_counts.float() + 1e-6)
        sample_weights = class_weights[actions]

        sampler = WeightedRandomSampler(sample_weights, len(sample_weights), replacement=True)
        dataset = TensorDataset(states, policies, values, actions)
        dataloader = DataLoader(dataset, batch_size=self.batch_size, sampler=sampler)

        # Khởi tạo các danh sách lưu trữ loss
        self.policy_losses = []
        self.value_losses = []
        self.total_losses = []

        for epoch in range(self.epochs):
            epoch_policy_loss = 0
            epoch_value_loss = 0
            epoch_total_loss = 0
            batch_count = 0

            for state_batch, policy_batch, value_batch, action_batch in dataloader:
                # Move batches to device
                state_batch = state_batch.to(device)
                policy_batch = policy_batch.to(device)
                value_batch = value_batch.to(device)

                # Forward pass
                predicted_policy, predicted_value = self.network(state_batch)

                # Losses
                policy_loss = F.kl_div(
                    F.log_softmax(predicted_policy, dim=-1),
                    policy_batch,
                    reduction='batchmean'
                )
                #policy_loss = F.cross_entropy(predicted_policy, policy_batch)
                value_loss = F.mse_loss(predicted_value, value_batch)
                loss = 0.9*policy_loss + 0.1*value_loss

                # Backward pass
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                # Cộng dồn loss
                epoch_policy_loss += policy_loss.item()
                epoch_value_loss += value_loss.item()
                epoch_total_loss += loss.item()
                batch_count += 1

            # Tính loss trung bình cho mỗi epoch
            avg_policy_loss = epoch_policy_loss / batch_count
            avg_value_loss = epoch_value_loss / batch_count
            avg_total_loss = epoch_total_loss / batch_count

            # Lưu loss vào danh sách
            self.policy_losses.append(avg_policy_loss)
            self.value_losses.append(avg_value_loss)
            self.total_losses.append(avg_total_loss)

            print(f"Epoch {epoch + 1}/{self.epochs}, value loss: {avg_value_loss}, policy loss: {avg_policy_loss}, Loss: {avg_total_loss}")
            self.scheduler.step(avg_total_loss)


    def save_model(self, path="alphazero_model.pth"):
        torch.save(self.network.state_dict(), path)

    def load_model(self, path="alphazero_model.pth"):
        self.network.load_state_dict(torch.load(path))


In [11]:
def env_init(duration):
    config = {
        "observation": {
            "type": "Kinematics",
            "vehicles_count": 8,
            "features": ["x", "y", "vx", "vy", "heading"],
            "absolute": True,
            "normalize": False,
            "order": "sorted",
        },
        "action": {
            "type": "DiscreteMetaAction",
            "target_speeds": np.linspace(10, 30, 5)
        },
        "lanes_count": 4,
        "vehicles_density": 1.6+np.random.rand(),
        "duration": duration,
    }
    env = gym.make("highway-fast-v0", config=config, render_mode='rgb_array')
    env.reset()
    return env

In [12]:
env = env_init(15)

In [13]:
network = AlphaZeroNetwork(input_shape=(120,20,7), n_residual_layers=10)

In [14]:
trainer = AlphaZeroTrainer(network, env, c_puct=2.0, n_simulations=150, learning_rate=0.001, batch_size=64, epochs=50)



In [15]:
trainer.load_model('/kaggle/input/alphazero/pytorch/ver0/1/alphazero_model (25).pth')

  self.network.load_state_dict(torch.load(path))


In [16]:
# def evaluate(network, seed):
#     action_list = []
#     speed_list = []
#     env = env_init(40)
#     state = init_stack_of_planes(env)
#     trainer = AlphaZeroTrainer(network, env, c_puct=3.5, n_simulations=15, learning_rate=0.001, batch_size=64, epochs=30)
#     trainer.network.eval()
#     ego_position_list = []
#     ego_position_list.append(env.unwrapped.road.vehicles[0].position[0])
#     while not env.unwrapped._is_terminated() and not env.unwrapped._is_truncated():
#         obs = env.unwrapped.observation_type.observe()
#         state = get_stack_of_planes(env, state)
#         state_tensor = torch.tensor(state, dtype=torch.float32).unsqueeze(0)
#         predicted_policy, predicted_value = trainer.network(state_tensor.to(device))
#         available_actions = env.unwrapped.get_available_actions()
#         predicted_policy = {action: prob for action, prob in enumerate(predicted_policy.squeeze().tolist())}
#         updated_policy = softmax_policy(predicted_policy, available_actions)
#         action = max(updated_policy, key=updated_policy.get)
#         env.render()
#         print(f"action chosen: {action}")
#         env.step(action)

# for seed in range(0, 20):
#     evaluate(network=network ,seed=seed)

In [17]:
import pickle
def training_pipeline(trainer, iterations=3, init_duration=20, duration_increase=5):
    trainer.network.train()
    for i in range(iterations):
        trainer.training_data = []
        env = env_init(duration=init_duration + duration_increase*i)
        trainer.env = env
        print(f"iter: {i}")
        j = 0
        while len(trainer.training_data)<2500:
            print(f"self-play: {j}")
            trainer.self_play(seed=j)
            j = j+1

In [18]:
# training_pipeline(trainer, iterations=1, init_duration=20, duration_increase=0)

In [19]:
# training_data = trainer.training_data

In [20]:
# import pickle
# with open("training_data.pkl", "wb") as f:
#     pickle.dump(training_data, f)

In [21]:
import pickle
with open("/kaggle/input/training-data-10000-samples/training_data (1).pkl", "rb") as f:
    training_data1 = pickle.load(f)
with open("/kaggle/input/training-data-10000-samples/training_data.pkl", "rb") as f:
    training_data2 = pickle.load(f)
with open("/kaggle/input/training-data-10000-samples/training_data (7).pkl", "rb") as f:
    training_data3 = pickle.load(f)
with open("/kaggle/input/training-data-10000-samples/training_data (8).pkl", "rb") as f:
    training_data4 = pickle.load(f)

In [22]:
training_data = training_data1+training_data2+training_data3+training_data4
len(training_data)

10008

In [23]:
trainer.training_data = training_data

In [24]:
trainer.train()

Epoch 1/50, value loss: 0.2422394486749248, policy loss: 0.729133094571958, Loss: 0.6804437147583932
Epoch 2/50, value loss: 0.2472449842911617, policy loss: 0.7347677399398415, Loss: 0.6860154479931874
Epoch 3/50, value loss: 0.2442529568816446, policy loss: 0.7269720951463007, Loss: 0.6787001657637821
Epoch 4/50, value loss: 0.24280762568021277, policy loss: 0.7224093144107017, Loss: 0.6744491271911912
Epoch 5/50, value loss: 0.24379478167196747, policy loss: 0.7254879394913935, Loss: 0.677318605647725
Epoch 6/50, value loss: 0.248140097138988, policy loss: 0.7348231043025946, Loss: 0.686154788087128
Epoch 7/50, value loss: 0.2495961639152211, policy loss: 0.7308000899424218, Loss: 0.6826796797430439
Epoch 8/50, value loss: 0.24199747934842566, policy loss: 0.7204470854655952, Loss: 0.6726021068111346
Epoch 9/50, value loss: 0.24131480666102878, policy loss: 0.7214051997585661, Loss: 0.6733961390082244
Epoch 10/50, value loss: 0.24281564734543964, policy loss: 0.7254589442994185, Los

In [25]:
trainer.save_model()