In [10]:
import collections

from kaggle_environments import make
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import random
import numpy as np
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [11]:
class Net(nn.Module):
    def __init__(self, dim_in, hidden_dim, dim_out):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(dim_in, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, dim_out)
    
    def forward(self, input):

        x = self.fc1(input)
        x = F.relu(x)
        return F.relu(self.fc2(x))

In [76]:
Experience = collections.namedtuple('Experience',
           field_names=['state', 'action', 'reward', 'next_state'])
class ExperienceReplay:
  def __init__(self, capacity):
      self.buffer = []
      self.capacity = capacity
  def __len__(self):
      return len(self.buffer)
  def append(self, experience):
      if len(self.buffer) >= self.capacity:
          self.buffer.pop()
      self.buffer.append(experience)

  def sample(self, batch_size, device):
      indices = np.random.choice(len(self.buffer), batch_size,
                replace=False)
      zipped = list(zip(*[self.buffer[i] for i in indices]))
      return torch.tensor(zipped[0], dtype = torch.float).to(device), torch.tensor(zipped[1], dtype = torch.long).to(device), torch.tensor(zipped[2], dtype = torch.float).to(device), torch.tensor(zipped[3], dtype = torch.float).to(device)


In [77]:
def takeAction(actionList, device,epsilon):
        if np.random.random() < epsilon:
            return torch.tensor(np.random.choice(len(actionList))).to(device)
        else:
            return torch.argmax(actionList).to(device)

In [78]:
def changeReward(reward):
    if reward == None:
        return -10
    else:
        return reward

In [79]:
def generateEpisodes(amount, model, replayBuffer, env,device,epsilon):
    batchReward = 0
    with torch.no_grad(): 
        for _ in range(amount):
            done = False
            trainer = env.train([None, "random"])
            obs = trainer.reset()
            while not done:
                env.render(mode="ansi")
                tensor = torch.tensor(obs.board, dtype = torch.float).to(device)
                res = model(tensor)
                action = takeAction(res, device,epsilon)
                old_obs = obs
                obs, reward, done, info = trainer.step(action.item())
                reward = changeReward(reward)
                exp = Experience(old_obs.board, action, reward, obs.board)
                replayBuffer.append(exp)
                batchReward+=reward
    return batchReward

In [101]:
def train(model, qModel, replayBuffer, optimizer, loss_function, device,batchSize, alpha, gamma):
    batch = zip(*replayBuffer.sample(batchSize, device))
    optimizer.zero_grad()
    #das sollte auch besser gehen
    for state, action, reward, next_state in batch:
        value = model(state)[action]
        qValue = max(qModel(next_state))
        target = alpha * (reward+ gamma * qValue)
        loss = loss_function(value, target)
        loss.backward()
        optimizer.step()

In [102]:
def train2(hiddenDim, episodes, batchSize, device):
    env = make("connectx", {"rows": 3, "columns": 4, "inarow": 4})
    env.render()
    model = Net(env.configuration.columns*env.configuration.rows, hiddenDim, env.configuration.rows).to(device)
    qModel = Net(env.configuration.columns*env.configuration.rows, hiddenDim, env.configuration.rows).to(device)
    loss_function = nn.MSELoss()
    optimizer = optim.SGD(params = model.parameters(), lr=0.01)
    buffer = ExperienceReplay(episodes)
    idx = 0
    while True:
        batchReward = generateEpisodes(episodes, model, buffer, env,device,0.1)
        train(model, qModel, buffer, optimizer,loss_function,device,batchSize, 0.01, 0.9)
        if idx%100 ==0:
            qModel = model
        if idx % 10 == 0:
            print("idx: " + str(idx) + " meanReward: " +  str(batchReward))
            torch.save(model.state_dict(), "model_state")
        idx+=1
    

In [103]:
train2(100, 100, 25, device)

idx: 0 meanReward: -900
idx: 10 meanReward: -1000


KeyboardInterrupt: 