In [1]:
import sys
sys.path.insert(0, '..')
sys.path.insert(0, '../../7wd-engine/')

In [2]:
import pickle
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from swd_bot.state_features import StateFeatures
from swd.entity_manager import EntityManager

In [3]:
CARDS_COUNT = EntityManager.cards_count()

In [4]:
class StatesDataset(Dataset):
    def __init__(self, states, actions):
        self.states = states
        self.actions = actions
    
    def __len__(self):
        return len(self.states)
    
    def __getitem__(self, index):
        state = self.states[index]
        action = self.actions[index]
        features = torch.tensor(flatten_features(StateFeatures.extract_state_features_dict(state)), dtype=torch.float)
        action_id = action.card_id + (0 if str(action)[0] == "B" else CARDS_COUNT)
        return features, torch.tensor(action_id, dtype=torch.long)

In [5]:
def flatten_features(x):
    output = []
    output.append(x["age"])
    output.append(x["current_player"])
    output.extend(x["tokens"])
    output.append(x["military_pawn"])
    output.extend(x["military_tokens"])
    output.append(x["game_status"])
    for i in range(2):
        output.append(x["players"][i]["coins"])
        output.extend(x["players"][i]["unbuilt_wonders"])
        output.extend(x["players"][i]["bonuses"])
    for card_id in x["cards_board"]:
        ohe = [0] * CARDS_COUNT
        if card_id >= 0:
            ohe[card_id] = 1
            # output.extend(EntityManager.card(card_id).bonuses)
        # else:
            # output.extend([0] * len(EntityManager.card(0).bonuses))
        output.extend(ohe)
    return output

In [6]:
with open("states.pkl", "rb") as f:
    states = pickle.load(f)
with open("actions.pkl", "rb") as f:
    actions = pickle.load(f)
dataset = StatesDataset(states, actions)
train_size = int(len(dataset) * 0.8)
test_size = len(dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])

In [16]:
class Net(nn.Module):
    def __init__(self, in_features: int):
        super().__init__()
        
        self.linear1 = nn.Linear(in_features, 300)
        self.linear2 = nn.Linear(300, CARDS_COUNT * 2)
    
    def forward(self, x):
        x = F.relu(self.linear1(x))
        x = F.softmax(self.linear2(x), dim=1)
        return x

In [17]:
features_number = len(dataset[0][0])
features_number

1622

In [18]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cpu')

In [19]:
model = Net(features_number).to(device)

In [20]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
batch_size = 256

In [21]:
train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True,
                                           drop_last=True,
                                           num_workers=0)
val_loader = torch.utils.data.DataLoader(test_dataset,
                                         batch_size=batch_size,
                                         shuffle=False,
                                         drop_last=False,
                                         num_workers=0)

In [22]:
for epoch in range(10):
    running_loss = 0.0
    count = 0
    for i, data in enumerate(train_loader):
        inputs, labels = data
        
        optimizer.zero_grad()

        outputs = model(inputs.to(device))
        loss = criterion(outputs, labels.to(device))
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        count += 1

    with torch.no_grad():
        correct_pred = 0
        total_pred = 0
        for i, data in enumerate(val_loader):
            inputs, labels = data
            labels = labels.to(device)

            outputs = model(inputs.to(device))
            _, predictions = torch.max(outputs, 1)
            for label, prediction in zip(labels, predictions.to(device)):
                if label == prediction:
                    correct_pred += 1
                total_pred += 1

    print(f'[{epoch + 1}] loss: {running_loss / count:.3f}, Accuracy: {round(100 * correct_pred / total_pred)}%')
    running_loss = 0.0
    count = 0

print('Finished Training')

[1] loss: 4.897, Accuracy: 15%
[2] loss: 4.837, Accuracy: 18%
[3] loss: 4.806, Accuracy: 20%
[4] loss: 4.780, Accuracy: 23%
[5] loss: 4.759, Accuracy: 25%
[6] loss: 4.739, Accuracy: 26%
[7] loss: 4.727, Accuracy: 27%
[8] loss: 4.712, Accuracy: 28%
[9] loss: 4.701, Accuracy: 28%
[10] loss: 4.694, Accuracy: 29%
Finished Training
