In [None]:
import random
from collections import deque

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

from assets.connectfour import ConnectFourEnv

# Detect if CUDA is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Seed the RNGs this run can draw from so results are reproducible:
# Python's `random` (used by ReplayBuffer.sample), NumPy, and torch (all devices).
# NOTE(review): if ConnectFourEnv or the action-selection helpers use another RNG,
# seed it here as well.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# Hyperparameters
learning_rate = 1e-5
gamma = 0.95  # discount factor applied to the target network's next-state value
batch_size = 32  # transitions sampled per gradient step
buffer_capacity = 100000  # max transitions kept in the replay buffer
epsilon_start = 0.9  # initial exploration rate for epsilon-greedy action selection
target_update_frequency = 100  # episodes between target-network syncs (and checkpoints)
num_episodes = 500000

net_win = 0  # shown as "Win:" in the training log
epsilon = epsilon_start

# Initialize environment, Q-network, target network, optimizer, and replay buffer
env = ConnectFourEnv(rows=6, cols=7, win_condition="four_in_a_row")
input_dim = (env.rows, env.cols)  # Board shape (rows, cols) — passed to DQN un-flattened
output_dim = env.cols  # Number of possible actions (one per column)

# Initialize Q-network and target network (the target starts as an exact copy)
q_network = DQN(input_dim, output_dim).to(device)
target_network = DQN(input_dim, output_dim).to(device)
target_network.load_state_dict(q_network.state_dict())
optimizer = optim.Adam(q_network.parameters(), lr=learning_rate)

# Load the saved model if it exists. map_location remaps stored tensors onto the
# current device, so a checkpoint written on a GPU can be resumed on a CPU-only machine.
model_path = "XXX.pth"
try:
    q_network.load_state_dict(torch.load(model_path, map_location=device))
    target_network.load_state_dict(q_network.state_dict())
    print("Loaded saved model for continued training.")
except FileNotFoundError:
    print("No saved model found. Starting training from scratch.")

# Load the opponent model (pre-trained model).
opponent_model = DQN(input_dim, output_dim).to(device)
opponent_model.load_state_dict(torch.load("XXX.pth", map_location=device))
# eval() only switches layers such as dropout/batch-norm to inference behaviour; it does
# not freeze weights. The opponent is never given to the optimizer, and disabling
# requires_grad makes explicit that no gradients are tracked for it.
opponent_model.eval()
opponent_model.requires_grad_(False)

# Initialize replay buffer
# NOTE(review): ReplayBuffer is defined in a later cell of this notebook; that cell must be
# run first (or moved above this one) for Restart Kernel -> Run All to succeed.
replay_buffer = ReplayBuffer(buffer_capacity)

####################  Training loop  ####################
# Exploration schedule: epsilon decays multiplicatively once per episode and is clamped
# at epsilon_end. (Previously epsilon stayed at epsilon_start for the whole run.)
# NOTE(review): starting values only — tune together with num_episodes.
epsilon_end = 0.05
epsilon_decay = 0.99999


def _outcome_after_move(mover_reward, mover, other):
    """Map the reward of the player whose move ended the game to a winner code.

    Returns ``mover`` for a positive reward, ``other`` for a negative one, else 3 (tie).
    Winner codes: 1 = agent, 2 = opponent, 3 = tie.

    NOTE(review): assumes env.step() rewards the player who just moved — positive for a
    win, negative for a loss (e.g. an illegal move), zero for a draw. TODO confirm
    against ConnectFourEnv.
    """
    if mover_reward > 0:
        return mover
    if mover_reward < 0:
        return other
    return 3


def play_episode(explore_eps):
    """Play one full game: the agent (moves first, epsilon-greedy) vs. the fixed opponent.

    Args:
        explore_eps: exploration rate passed to ``epsilon_greedy_action``.

    Returns:
        (transitions, winner). Each transition is an agent-perspective tuple
        ``(state, action, reward, next_state, done)`` whose next_state is the board
        after the opponent's reply (or after the agent's move if that ended the game).
        winner is 1 = agent, 2 = opponent, 3 = tie.
    """
    # Start every game from a fresh board. The original never reset the env, so `state`
    # was undefined on the first episode and a finished board on every later one.
    # NOTE(review): assumes reset() returns only the observation (old Gym-style API,
    # consistent with step() returning a 4-tuple below) — TODO confirm.
    state = env.reset().copy()
    transitions = []
    while True:
        # Agent's move (exploration vs exploitation)
        action = epsilon_greedy_action(state, explore_eps, q_network)
        next_state, reward, done, _ = env.step(action)
        # Copy boards before storing them: the original already copied the post-opponent
        # board, which suggests the env may return its internal, mutable array.
        next_state = next_state.copy()
        if done:
            transitions.append((state, action, reward, next_state, done))
            return transitions, _outcome_after_move(reward, mover=1, other=2)

        # Opponent plays with max Q (no epsilon-greedy)
        opponent_action = get_opponent_action(next_state)
        state_after_opponent, opponent_reward, done, _ = env.step(opponent_action)
        state_after_opponent = state_after_opponent.copy()
        if done:
            # Zero-sum: the opponent's terminal reward counts against the agent, so a lost
            # game is no longer recorded with only the agent's own (non-terminal) reward.
            # NOTE(review): relies on the same reward convention as _outcome_after_move.
            reward -= opponent_reward
        # Store the agent's transition after the opponent's move
        transitions.append((state, action, reward, state_after_opponent, done))
        if done:
            return transitions, _outcome_after_move(opponent_reward, mover=2, other=1)
        state = state_after_opponent


def optimize_step():
    """Sample a minibatch from replay_buffer and take one DQN gradient step.

    Returns the scalar MSE loss between Q(s, a) and the target
    r + gamma * max_a' Q_target(s', a') (bootstrap term zeroed for terminal transitions).
    """
    states, actions, rewards, next_states, dones = replay_buffer.sample(batch_size)
    # Stack through NumPy first: building a tensor from a tuple of ndarrays is very slow.
    # unsqueeze(1) adds the channel dimension the original passed to DQN.
    states = torch.as_tensor(np.asarray(states, dtype=np.float32), device=device).unsqueeze(1)
    next_states = torch.as_tensor(np.asarray(next_states, dtype=np.float32), device=device).unsqueeze(1)
    actions = torch.as_tensor(actions, dtype=torch.long, device=device).unsqueeze(1)
    rewards = torch.as_tensor(rewards, dtype=torch.float32, device=device)
    dones = torch.as_tensor(dones, dtype=torch.float32, device=device)

    q_values = q_network(states).gather(1, actions).squeeze(1)
    # The target is a fixed regression label: don't build an autograd graph through
    # the target network.
    with torch.no_grad():
        next_q_values = target_network(next_states).max(1)[0]
        target_q_values = rewards + (1 - dones) * gamma * next_q_values

    loss = nn.functional.mse_loss(q_values, target_q_values)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()


last_loss = None  # stays None until the buffer holds at least one batch
for episode in range(1, num_episodes + 1):
    episode_transitions, winner = play_episode(epsilon)
    if winner == 1:
        net_win += 1

    # Push only the agent's transitions into the replay buffer
    for transition in episode_transitions:
        replay_buffer.push(*transition)

    # Perform experience replay and Q-network update
    if len(replay_buffer) >= batch_size:
        last_loss = optimize_step()

    epsilon = max(epsilon_end, epsilon * epsilon_decay)

    if episode % target_update_frequency == 0:
        target_network.load_state_dict(q_network.state_dict())
        torch.save(q_network.state_dict(), "XXX.pth")

    if episode % 100 == 0:
        # Guard: no optimisation step may have run yet (the original raised NameError here).
        loss_str = f"{last_loss:.4f}" if last_loss is not None else "n/a"
        print(f"Episode {episode}/{num_episodes}, Win: {net_win}, Epsilon: {epsilon:.2f}, Loss: {loss_str}")

# Save the final trained model
torch.save(q_network.state_dict(), "XXX.pth")
print("Training complete and model saved.")

In [None]:
class ReplayBuffer:
    """Fixed-capacity FIFO store of (state, action, reward, next_state, done) transitions.

    Once ``capacity`` transitions are held, every push evicts the oldest one.
    """

    def __init__(self, capacity):
        # deque(maxlen=...) discards from the left automatically once full.
        self.buffer = deque(maxlen=capacity)

    def __len__(self):
        return len(self.buffer)

    def push(self, state, action, reward, next_state, done):
        """Append one transition, evicting the oldest if the buffer is at capacity."""
        transition = (state, action, reward, next_state, done)
        self.buffer.append(transition)

    def sample(self, batch_size):
        """Draw ``batch_size`` transitions uniformly at random, without replacement.

        Returns five index-aligned tuples: states, actions, rewards, next_states, dones.
        Raises ValueError (from random.sample) if batch_size exceeds len(self).
        """
        minibatch = random.sample(self.buffer, batch_size)
        # Transpose the list of transition rows into one tuple per field.
        columns = list(zip(*minibatch))
        states, actions, rewards, next_states, dones = columns
        return states, actions, rewards, next_states, dones