# FROM CHATGPT

# In [6]:
import torch
import torch.nn as nn
import torch.optim as optim

# In [7]:
class QNetwork(nn.Module):
    """Three-layer MLP mapping a state vector to one Q-value per action.

    Architecture: input_dim -> 128 -> 64 -> output_dim, with ReLU after each
    hidden layer and a linear (unactivated) output head.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        # The attribute names fc1..fc3 determine the state_dict keys; keep them stable.
        self.fc1 = nn.Linear(input_dim, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, output_dim)

    def forward(self, x):
        """Return raw Q-values for ``x`` (last dimension must equal input_dim)."""
        relu = nn.functional.relu
        for hidden in (self.fc1, self.fc2):
            x = relu(hidden(x))
        return self.fc3(x)
        
class QLearning:
    """Q-learning agent trained on n-step bootstrapped returns over a rollout buffer.

    Transitions are accumulated in ``states`` / ``actions`` / ``rewards``
    (directly or via :meth:`store_transition`). :meth:`update` performs one
    gradient step on the whole buffer and then clears it.
    """

    def __init__(self, input_dim, output_dim, lr, gamma):
        """
        Args:
            input_dim: Size of the state feature vector fed to the Q-network.
            output_dim: Number of discrete actions (one Q-value per action).
            lr: Adam learning rate.
            gamma: Discount factor applied to future rewards.
        """
        self.q_net = QNetwork(input_dim, output_dim)
        self.optimizer = optim.Adam(self.q_net.parameters(), lr=lr)
        self.loss_fn = nn.MSELoss()
        self.gamma = gamma
        # Rollout buffer: index i holds (state_i, action_i, reward_i).
        self.states = []
        self.actions = []
        self.rewards = []

    def store_transition(self, state, action, reward):
        """Append one (state, action, reward) step to the rollout buffer."""
        self.states.append(state)
        self.actions.append(action)
        self.rewards.append(reward)

    def select_action(self, state):
        """Return the greedy action index ``argmax_a Q(state, a)`` as an int.

        No exploration is performed here.
        """
        state = torch.tensor(state, dtype=torch.float32)
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            q_values = self.q_net(state)
        return torch.argmax(q_values).item()

    @staticmethod
    def _discounted_targets(rewards, gamma, bootstrap):
        """Return n-step targets G_i = r_i + gamma * G_{i+1}, with G_T = bootstrap."""
        returns = []
        running = float(bootstrap)
        for reward in reversed(rewards):
            running = reward + gamma * running
            returns.append(running)
        returns.reverse()
        return torch.tensor(returns, dtype=torch.float32)

    def update(self, next_state=None, done=False):
        """Take one gradient step toward n-step returns, then clear the buffer.

        Args:
            next_state: State reached after the last buffered action. Its max
                Q-value bootstraps the tail of every return. If omitted, the
                last buffered state is used instead (the original behaviour,
                which is only an approximation of the successor state).
            done: If True, the rollout ended in a terminal state and no
                bootstrap value is added.

        Does nothing when the buffer is empty.
        """
        if not self.rewards:
            return

        states = torch.tensor(self.states, dtype=torch.float32)
        actions = torch.tensor(self.actions, dtype=torch.int64)

        if done:
            bootstrap = 0.0
        else:
            tail_state = self.states[-1] if next_state is None else next_state
            with torch.no_grad():
                tail_q = self.q_net(torch.tensor(tail_state, dtype=torch.float32))
            bootstrap = tail_q.max().item()

        # Each target counts its own reward exactly once, followed by the
        # discounted rewards of later steps and finally the bootstrap.
        targets = self._discounted_targets(self.rewards, self.gamma, bootstrap)

        self.optimizer.zero_grad()
        q_values = self.q_net(states)
        # squeeze(1) rather than squeeze(): keeps shape (T,) even when T == 1,
        # so MSELoss compares matching shapes instead of broadcasting.
        action_values = q_values.gather(1, actions.view(-1, 1)).squeeze(1)
        loss = self.loss_fn(action_values, targets)
        loss.backward()
        self.optimizer.step()

        self.states = []
        self.actions = []
        self.rewards = []

# In [11]:
# NOTE(review): output_dim=1 gives the agent a single action, so select_action
# can only ever return 0 — confirm this configuration is intended.
hej = QLearning(input_dim=1, output_dim=1, lr=0.01, gamma=0.9)