In [None]:
import random
from collections import deque

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

class Environment:
    """Toy 2-D environment: push a point out of the radius-10 disk.

    The state is an (x, y) position drawn uniformly from [0, 1)^2 on reset.
    Each step adds the action to the position using NumPy broadcasting, so a
    length-1 action shifts both coordinates by the same amount. Every step
    costs -1; the episode ends with reward +10 once the point's Euclidean
    norm exceeds 10.
    """

    def __init__(self):
        self.actionSpace = 1        # 1-D continuous action
        self.observationSpace = 2   # 2-D state space of (x,y)
        self.state = self.reset()

    def reset(self):
        """Sample a new start position and return a copy of it."""
        self.state = np.random.rand(self.observationSpace)
        # Hand out a copy so callers can never alias the environment's state.
        return self.state.copy()

    def step(self, action):
        """Apply ``action`` and return ``(state, reward, done, info)``.

        The returned state is a fresh array; states returned by earlier
        calls are never modified afterwards.
        """
        # Rebind instead of `+=`: the old in-place update silently rewrote
        # every state array previously returned (and stored in replay memory).
        self.state = self.state + action

        done = bool(np.linalg.norm(self.state) > 10)
        reward = 10 if done else -1

        return self.state.copy(), reward, done, {}
    
    
    
class PolicyNetwork(nn.Module):
    """Deterministic MLP mapping a state vector to an action vector.

    Layout: Linear(inputDim, 128) -> ReLU -> Linear(128, 128) -> ReLU ->
    Linear(128, actionDim). The output layer has no activation, so the
    produced actions are unbounded.
    """

    def __init__(self, inputDim, actionDim):
        super().__init__()
        hidden = 128
        layers = [
            nn.Linear(inputDim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, actionDim),
        ]
        # Attribute name ``fc`` keeps state_dict keys as fc.0.*, fc.2.*, fc.4.*.
        self.fc = nn.Sequential(*layers)

    def forward(self, x):
        """Return the raw action for each row of ``x``."""
        out = self.fc(x)
        return out
    
class Agent:
    """Deterministic-policy agent with a FIFO replay buffer.

    NOTE(review): ``train`` regresses the policy onto actions the policy
    itself produced earlier. Rewards, next states and done flags are never
    used, and neither is ``gamma``, so no return is being optimised. There
    is also no exploration noise in ``selectAction``. Replace the loss with
    an actual RL objective (e.g. REINFORCE with a stochastic policy, or
    DDPG/TD3 with a critic) before expecting the agent to improve.
    """

    def __init__(self, stateDim, actionDim, lr = 1e-3, gamma = 0.99, verbose = False):
        """
        Args:
            stateDim: length of the state vector fed to the policy.
            actionDim: length of the action vector the policy outputs.
            lr: Adam learning rate.
            gamma: discount factor (stored; unused by the current loss).
            verbose: if True, ``selectAction`` prints every chosen action.
        """
        self.stateDim = stateDim
        self.actionDim = actionDim
        self.gamma = gamma
        self.verbose = verbose

        self.policy = PolicyNetwork(stateDim, actionDim)
        self.optimizer = optim.Adam(self.policy.parameters(), lr = lr)
        # deque: evicting the oldest transition is O(1); list.pop(0) is O(n).
        self.memory = deque()
        self.memoryCapcity = 10000  # (sic) attribute name kept for compatibility
        self.batchSize = 64

    def selectAction(self, state):
        """Return the policy's action for ``state`` as a 1-D float32 ndarray."""
        stateTensor = torch.as_tensor(np.asarray(state, dtype=np.float32)).unsqueeze(0)
        with torch.no_grad():
            action = self.policy(stateTensor).squeeze(0).numpy()
        # Debug output is opt-in; printing every step flooded the notebook.
        if self.verbose:
            print(action)
        return action

    def storeTransition(self, state, action, reward, nextState, done):
        """Append one transition, evicting the oldest once at capacity.

        States and the action are copied, so later in-place changes made by
        the caller or the environment cannot rewrite stored history.
        """
        if len(self.memory) >= self.memoryCapcity:
            self.memory.popleft()
        self.memory.append(
            (np.array(state), np.array(action), reward, np.array(nextState), done)
        )

    def train(self):
        """Take one gradient step on a random minibatch.

        Does nothing until at least ``batchSize`` transitions are stored.
        NOTE(review): see the class docstring -- the loss ignores rewards.
        """
        if len(self.memory) < self.batchSize:
            return

        batch = random.sample(self.memory, self.batchSize)
        states, actions, _rewards, _nextStates, _dones = zip(*batch)

        # Stack into a single ndarray first: building a tensor directly from
        # a tuple of ndarrays is a slow path that PyTorch warns about.
        states = torch.as_tensor(np.asarray(states), dtype=torch.float32)
        actions = torch.as_tensor(np.asarray(actions), dtype=torch.float32)

        predictedActions = self.policy(states)
        loss = nn.MSELoss()(predictedActions, actions)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

if __name__ == "__main__":
    env = Environment()
    agent = Agent(env.observationSpace, env.actionSpace)

    numEpisodes = 1000
    # Cap episode length: the policy can emit zero or negative actions and
    # never leave the radius-10 disk, which made the former `while True`
    # loop spin forever (the recorded run had to be interrupted by hand).
    maxStepsPerEpisode = 500
    for episode in range(numEpisodes):
        state = env.reset()
        totalReward = 0

        for stepIndex in range(maxStepsPerEpisode):
            action = agent.selectAction(state)
            nextState, reward, done, _ = env.step(action)

            # A truncated final step keeps done=False: hitting the step cap
            # is not a true terminal state of the environment.
            agent.storeTransition(state, action, reward, nextState, done)
            agent.train()

            state = nextState
            totalReward += reward

            if done:
                print(f"Episode {episode + 1}, Total Reward: {totalReward}")
                break
        else:
            print(
                f"Episode {episode + 1}, Total Reward: {totalReward} "
                f"(truncated after {maxStepsPerEpisode} steps)"
            )




[0.1398287]
[0.14014347]
[0.13800913]
[0.13441713]
[0.13150525]
[0.13010076]
[0.1305226]
[0.13193597]
[0.13298088]
[0.13315536]
[0.13233028]
[0.1312689]
[0.12959461]
[0.12833416]
[0.12789857]
[0.12761392]
[0.12756684]
[0.12655677]
[0.12516138]
[0.12378313]
[0.12269446]
[0.12228385]
[0.12188169]
[0.12148073]
[0.12150824]
[0.12110818]
[0.1204211]
[0.11973781]
[0.11905845]
[0.11838289]
[0.11771133]
[0.1170434]
[0.11637932]
[0.11571898]
[0.11506259]
[0.11440966]
[0.11376272]
[0.113148]
[0.11257879]
[0.1119965]
[0.11139316]
[0.11083025]
[0.11030442]
[0.10978115]
[0.10926035]
[0.10874205]
[0.10822613]
[0.10771291]
[0.10720181]
[0.10669325]
[0.10618721]
[0.10568348]
Episode 1, Total Reward: -41
[0.12888233]
[0.13301584]
[0.13384448]
[0.13232268]
[0.12838392]
[0.12604731]
[0.12494207]
[0.12402511]
[0.12379366]
[0.12345691]
[0.12258793]
[0.12160802]
[0.34016773]
[0.1864148]
[0.01335073]
[-0.05915293]
[-0.03655826]
[0.03046271]
[0.1136506]
[0.19412474]
[0.25277466]
[0.27467874]
[0.25475857]
[0.1

KeyboardInterrupt: 