In [10]:
import math
import random
import matplotlib
import matplotlib.pyplot as plt
from collections import namedtuple, deque
from itertools import count

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

plt.ion()

In [11]:
# goal: use a RL agent to approximate y = sin x

In [12]:
# ground truth values

x = torch.linspace(-math.pi, math.pi, 2000)
y = torch.sin(x)
x = x.unsqueeze(-1).pow(torch.tensor([1, 2, 3]))

In [13]:
n_states = 3
n_actions = 2 

In [14]:
class Model(nn.Module):
    def __init__(self, n_states, n_actions):
        super(Model, self).__init__()
        self.layer1 = nn.Linear(n_states, 10)
        self.layer2 = nn.Linear(10, n_actions)
    
    def forward(self, x):
        x = F.relu(self.layer1(x))
        return self.layer2(x)

Transition = namedtuple('Transition', ('state', 'action', 'reward', 'next_state'))

class Memory(object):
    def __init__(self, capacity):
        self.memory = deque([], maxlen=capacity)
    def push(self, *args):
        self.memory.append(Transition(*args))
    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)
    def __len__(self):
        return len(self.memory)
    
class Agent:
    def __init__(self, n_states, n_actions):
        self.batch_size = 128
        self.gamma = 0.7
        self.epsilon = 0.99
        self.decay = 0.01
        self.lr = 1e-4
        
        self.model = Model(n_states, n_actions)
        self.memory = Memory(1000)
        self.optimizer = optim.AdamW(self.model.parameters(), lr=self.lr, amsgrad=True)
    
    def get_action(self, state):
        sample = random.random()
        epsilon = self.epsilon*self.decay
        if sample > epsilon:
            with torch.no_grad():
                return self.model(state).max(1).indices.view(1, 1)
        else:
            random_action = random.choice([0, 1])
            return torch.tensor([[random_action]], dtype=torch.long)
        
    def update(self):
        if len(self.memory) < self.batch_size:
            return
        transitions = self.memory.sample(self.batch_size)
        batch = Transition(*zip(*transitions))
        
        state_batch = torch.cat(batch.state)
        action_batch = torch.cat(batch.action)
        reward_batch = torch.cat(batch.reward)
        next_state_batch = torch.cat(batch.next_state)
        
        q_values = self.model(state_batch).gather(1, action_batch)
        with torch.no_grad():
            next_state_values = self.model(next_state).max(1).values
        q_pred = (next_state_values * self.gamma) + reward_batch
        
        criterion = nn.MSELoss() # use Huber loss here to smooth results
        loss = criterion(q_values, q_pred.unsqueeze(1)) 
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

In [15]:
# get step functon that returns next_state and reward
# increase or decrease y-value of y = sin x plot by 0.01
# goal is to have this be user-defined

def step(state, action):
    if action == 1:
        next_state = state + 0.01
    else:
        next_state = state - 0.01
    reward = torch.tensor([1], dtype=torch.long)
    return next_state, reward

In [16]:
agent = Agent(n_states, n_actions)

episodes = []
state = torch.tensor([[-3.1, 9.8, -30.0]])

for i in range(10):
    action = agent.get_action(state)
    print(state, action)
    next_state, reward = step(state, action)
    agent.memory.push(state, action, reward, next_state)
    state = next_state
    agent.update()

tensor([[ -3.1000,   9.8000, -30.0000]]) tensor([[1]])
tensor([[ -3.0900,   9.8100, -29.9900]]) tensor([[1]])
tensor([[ -3.0800,   9.8200, -29.9800]]) tensor([[1]])
tensor([[ -3.0700,   9.8300, -29.9700]]) tensor([[1]])
tensor([[ -3.0600,   9.8400, -29.9600]]) tensor([[1]])
tensor([[ -3.0500,   9.8500, -29.9500]]) tensor([[1]])
tensor([[ -3.0400,   9.8600, -29.9400]]) tensor([[1]])
tensor([[ -3.0300,   9.8700, -29.9300]]) tensor([[1]])
tensor([[ -3.0200,   9.8800, -29.9200]]) tensor([[1]])
tensor([[ -3.0100,   9.8900, -29.9100]]) tensor([[1]])


In [17]:
agent.memory.sample(10)

[Transition(state=tensor([[ -3.0200,   9.8800, -29.9200]]), action=tensor([[1]]), reward=tensor([1]), next_state=tensor([[ -3.0100,   9.8900, -29.9100]])),
 Transition(state=tensor([[ -3.0400,   9.8600, -29.9400]]), action=tensor([[1]]), reward=tensor([1]), next_state=tensor([[ -3.0300,   9.8700, -29.9300]])),
 Transition(state=tensor([[ -3.1000,   9.8000, -30.0000]]), action=tensor([[1]]), reward=tensor([1]), next_state=tensor([[ -3.0900,   9.8100, -29.9900]])),
 Transition(state=tensor([[ -3.0500,   9.8500, -29.9500]]), action=tensor([[1]]), reward=tensor([1]), next_state=tensor([[ -3.0400,   9.8600, -29.9400]])),
 Transition(state=tensor([[ -3.0900,   9.8100, -29.9900]]), action=tensor([[1]]), reward=tensor([1]), next_state=tensor([[ -3.0800,   9.8200, -29.9800]])),
 Transition(state=tensor([[ -3.0700,   9.8300, -29.9700]]), action=tensor([[1]]), reward=tensor([1]), next_state=tensor([[ -3.0600,   9.8400, -29.9600]])),
 Transition(state=tensor([[ -3.0800,   9.8200, -29.9800]]), acti

In [18]:
agent.model.state_dict()

OrderedDict([('layer1.weight',
              tensor([[-0.4282,  0.4636,  0.0773],
                      [-0.3766,  0.1119, -0.0088],
                      [ 0.2945,  0.4114,  0.2721],
                      [ 0.5460,  0.3432, -0.3135],
                      [ 0.3919,  0.2815,  0.1311],
                      [ 0.4624,  0.1530, -0.4005],
                      [-0.3883,  0.2627, -0.5057],
                      [ 0.2026,  0.0338, -0.4828],
                      [-0.4734,  0.4617,  0.3300],
                      [ 0.2952,  0.0548,  0.4848]])),
             ('layer1.bias',
              tensor([-0.2972, -0.0352,  0.4661, -0.1220,  0.3902,  0.2373, -0.3435,  0.1794,
                      -0.1015,  0.1087])),
             ('layer2.weight',
              tensor([[-0.0950, -0.2298, -0.2503, -0.2926,  0.0731, -0.1781,  0.0795,  0.0688,
                        0.0209,  0.0566],
                      [-0.0774,  0.1688, -0.0010,  0.0757,  0.2357, -0.2544,  0.1859,  0.2204,
                       -0.2

In [19]:
# plotting state_dict function should get an approximation of sin x

In [20]:
linear_layer = agent.model.layer2

In [21]:
linear_layer.weight[:, 0]

tensor([-0.0950, -0.0774], grad_fn=<SelectBackward>)

In [22]:
y_result = linear_layer.bias[0].item() + \
           linear_layer.weight[:, 0][0].item()*x + \
           linear_layer.weight[:, 1][0].item()*x**2 + \
           linear_layer.weight[:, 2][0].item()*x**3

In [23]:
y_result

tensor([[ 5.5238e+00, -2.6423e+02,  7.2433e+03],
        [ 5.5047e+00, -2.6270e+02,  7.1777e+03],
        [ 5.4858e+00, -2.6118e+02,  7.1126e+03],
        ...,
        [-1.0539e+01, -2.6118e+02, -7.5497e+03],
        [-1.0567e+01, -2.6270e+02, -7.6174e+03],
        [-1.0596e+01, -2.6423e+02, -7.6856e+03]])