In [1]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
!pip install gymnasium
!pip install ale-py



In [3]:
import os
import gymnasium as gym
import ale_py
gym.register_envs(ale_py)

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from PIL import Image
from gymnasium.utils.save_video import save_video

import copy
import random
import time

In [4]:
os.chdir('/content/drive/MyDrive/Colab Notebooks/gymnasium/atari')

In [5]:
class Breakout(gym.Wrapper):
    def __init__(self, render_mode='rgb_array_list', repeat=4, device='cpu'):
        env = gym.make('ALE/Breakout-v5', render_mode=render_mode, frameskip=1, repeat_action_probability=0.0)
        super(Breakout, self).__init__(env)

        self.image_shape = (84,84)
        self.repeat = repeat
        self.lives = 5
        self.frame_buffer = []
        self.device = device

    def step(self, action):
        total_reward = 0
        done = False

        for i in range(self.repeat):
            observation, reward, done, truncacted, info = self.env.step(action)

            total_reward += reward

            current_lives = info['lives']

            if current_lives < self.lives:
                total_reward = total_reward - 1
                self.lives = current_lives

            self.frame_buffer.append(observation)

            if done:
                break

        max_frame = np.max(self.frame_buffer[-2:], axis=0)
        max_frame = self.process_observation(max_frame)
        max_frame = max_frame.to(self.device)

        total_reward = torch.tensor(total_reward).view(1,-1).float()
        total_reward = total_reward.to(self.device)

        done = torch.tensor(done).view(1,-1)
        done = done.to(self.device)

        return max_frame, total_reward, done, info, observation

    def reset(self):
        self.frame_buffer = []

        observation, _ = self.env.reset()
        image = observation.copy()

        self.lives = 5

        observation = self.process_observation(observation)

        return observation, image

    def process_observation(self, observation):
        img = Image.fromarray(observation).resize(self.image_shape).convert("L")
        img = torch.from_numpy(np.array(img))
        img = img.unsqueeze(0).unsqueeze(0)
        img = img.to(self.device)

        return img/255.0

class ReplayMemory:
    def __init__(self, capacity, device='cpu'):
        self.capacity = capacity
        self.memory = []
        self.device = device

    def insert(self, transition):
        transition = [item.to('cpu') for item in transition]

        if len(self.memory) < self.capacity:
            self.memory.append(transition)
        else:
            self.memory.remove(self.memory[0])
            self.memory.append(transition)

    def sample(self, batch_size=64):
        batch = random.sample(self.memory, batch_size)
        batch = zip(*batch)
        return [torch.cat(items).to(self.device) for items in batch]

    def can_sample(self, batch_size):
        return len(self.memory) >= batch_size * 10

class Model(nn.Module):
    def __init__(self, nb_action=4):
        super(Model, self).__init__()
        self.conv1 = nn.Conv2d(1,32,kernel_size=(8,8), stride=(4,4))
        self.conv2 = nn.Conv2d(32,64,kernel_size=(4,4), stride=(2,2))
        self.conv3 = nn.Conv2d(64,64,kernel_size=(3,3), stride=(1,1))

        self.action_value1 = nn.Linear(3136, 1024)
        self.action_value2 = nn.Linear(1024, 1024)
        self.action_value3 = nn.Linear(1024, nb_action)

        self.state_value1 = nn.Linear(3136, 1024)
        self.state_value2 = nn.Linear(1024, 1024)
        self.state_value3 = nn.Linear(1024, 1)

        self.relu = nn.ReLU()
        self.flatten = nn.Flatten()
        self.dropout = nn.Dropout(p=0.2)

    def forward(self, x):
        x = torch.Tensor(x)
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.relu(self.conv3(x))
        x = self.flatten(x)

        s = self.dropout(self.relu(self.state_value1(x)))
        s = self.dropout(self.relu(self.state_value2(s)))
        s = self.relu(self.state_value3(s))

        a = self.dropout(self.relu(self.action_value1(x)))
        a = self.dropout(self.relu(self.action_value2(a)))
        a = self.relu(self.action_value3(a))

        output = s + (a - a.mean())

        return output

    def save_the_model(self, weights_filename='models/latest.pt'):
        if not os.path.exists("models"):
            os.makedirs("models")
        torch.save(self.state_dict(), weights_filename)

    def load_the_model(self, weights_filename='models/latest.pt'):
        try:
            self.load_state_dict(torch.load(weights_filename, map_location=device))
            print("Loaded weights file")
        except:
            print("No weights file")

def f(episode_id: int) -> bool:
    return True

class Agent:
    def __init__(self, model, device='cpu', epsilon=1.0, min_epsilon=0.1, action_size=None, learning_rate=1e-5):
        self.memory = ReplayMemory(device=device, capacity=600000)
        self.model = model
        self.target_model = copy.deepcopy(model).eval()
        self.epsilon = epsilon
        self.min_epsilon = min_epsilon
        self.epsilon_decay = 1 - (((epsilon - min_epsilon) / 5000) * 2)
        self.batch_size = 64
        self.model.to(device)
        self.target_model.to(device)
        self.gamma = 0.99
        self.action_size = action_size

        self.optimizer = optim.AdamW(model.parameters(),lr=learning_rate)

    def get_action(self, state):
        if torch.rand(1) < self.epsilon:
            return torch.randint(self.action_size, (1,1)), None
        else:
            av = self.model(state).detach()
            return torch.argmax(av, dim=1, keepdim=True), av

    def train(self, env, epochs):
        reward_list = {}
        for epoch in range(1,epochs + 1):
            print(epoch)
            reward_list[epoch] = 0
            state,_ = env.reset()
            done = False

            while not done:
                action, _ = self.get_action(state)

                next_state, reward, done, info, _ = env.step(action)
                reward_list[epoch] += reward
                self.memory.insert([state, action, reward, done, next_state])

                if self.memory.can_sample(self.batch_size):
                    state_b, action_b, reward_b, done_b, next_state_b = self.memory.sample(self.batch_size)
                    qsa_b = self.model(state_b).gather(1,action_b)
                    next_qsa_b = self.target_model(next_state_b)
                    next_qsa_b = torch.max(next_qsa_b, dim=-1, keepdim=True)[0]
                    target_b = reward_b + ~done_b * self.gamma * next_qsa_b
                    loss = F.mse_loss(qsa_b, target_b)
                    self.model.zero_grad()
                    loss.backward()
                    self.optimizer.step()

                state = next_state

            if self.epsilon > self.min_epsilon:
                self.epsilon = self.epsilon * self.epsilon_decay

            if epoch % 10 == 0:
                self.model.save_the_model()
                reward_sum = 0
                for i in range(10):
                    reward_sum += reward_list[epoch - i]
                print(reward_sum/10)

            if epoch % 100 == 0:
                self.target_model.load_state_dict(self.model.state_dict())

                save_video(
                env.render(),
                "videos",
                episode_trigger=f,
                fps=24,
                step_starting_index=0,
                episode_index=epoch)

            if epoch % 1000 == 0:
                self.model.save_the_model(f"models/model_iter_{epoch}.pt")

In [6]:
os.environ['KMP_DUPLICATE_OK'] = 'TRUE'
device = torch.device('cuda:0')

environment = Breakout(device=device)

model = Model(nb_action=4).to(device)

agent = Agent(model=model,
              device=device,
              epsilon=1,
              action_size=4,
              learning_rate=1e-5)

agent.train(env=environment, epochs=10000000)

1
2
3
4
5
6
7
8
9
10
tensor([[-3.9000]], device='cuda:0')
11
12
13
14
15
16
17
18
19
20
tensor([[-4.]], device='cuda:0')
21
22
23
24
25
26
27
28
29
30
tensor([[-3.3000]], device='cuda:0')
31
32
33
34
35
36
37
38
39
40
tensor([[-3.5000]], device='cuda:0')
41
42
43
44
45
46
47
48
49
50
tensor([[-4.]], device='cuda:0')
51
52
53
54
55
56
57
58
59
60
tensor([[-4.3000]], device='cuda:0')
61
62
63
64
65
66
67
68
69
70
tensor([[-3.2000]], device='cuda:0')
71
72
73
74
75
76
77
78
79
80
tensor([[-3.9000]], device='cuda:0')
81
82
83
84
85
86
87
88
89
90
tensor([[-4.1000]], device='cuda:0')
91
92
93
94
95
96
97
98
99
100
tensor([[-3.8000]], device='cuda:0')
101
102
103
104
105
106
107
108
109
110
tensor([[-3.1000]], device='cuda:0')
111
112
113
114
115
116
117
118
119
120
tensor([[-4.2000]], device='cuda:0')
121
122
123
124
125
126
127
128
129
130
tensor([[-3.7000]], device='cuda:0')
131
132
133
134
135
136
137
138
139
140
tensor([[-3.4000]], device='cuda:0')
141
142
143
144
145
146
147
148
149
15

KeyboardInterrupt: 