In [1]:
import copy
import json
import random

import torch
import torch.nn as nn
import torch.optim as optim
from pymongo import MongoClient

from index import create_single_field_index, delete_single_field_index, reset_index_config
from state import getStaticInfo, addIndexInfo, convertToStateVector, getQueryMetrics, saveAsJSON

In [2]:
class Actor(nn.Module):
    """Policy network: a three-layer MLP producing one score per action slot.

    Layout is state_size -> 256 -> 128 -> action_size with ReLU on the two
    hidden layers. The output layer is squashed with tanh, so every returned
    score lies in [-1, 1].
    """

    def __init__(self, state_size, action_size):
        super().__init__()
        # The attribute names fc1/fc2/fc3 become the keys of state_dict(), so
        # they must stay stable for saved checkpoints to keep loading.
        self.fc1 = nn.Linear(state_size, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, action_size)
        self.init_weights()

    def init_weights(self):
        """Re-initialise every layer: Xavier-uniform weights, zero biases."""
        layers = (self.fc1, self.fc2, self.fc3)
        for layer in layers:
            nn.init.xavier_uniform_(layer.weight)
        for layer in layers:
            nn.init.zeros_(layer.bias)

    def forward(self, state):
        """Return the tanh-squashed output scores for ``state``."""
        hidden = self.fc1(state).relu()
        hidden = self.fc2(hidden).relu()
        return self.fc3(hidden).tanh()

In [3]:
class Critic(nn.Module):
    """State-value network: a three-layer MLP (state_size -> 256 -> 128 -> 1).

    The hidden layers use ReLU. The last layer is left linear, so the value
    estimate it returns is unbounded.
    """

    def __init__(self, state_size):
        super().__init__()
        # fc1/fc2/fc3 are the state_dict() key prefixes; keep them stable so
        # saved checkpoints remain loadable.
        self.fc1 = nn.Linear(state_size, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 1)
        self.init_weights()

    def init_weights(self):
        """Re-initialise every layer: Xavier-uniform weights, zero biases."""
        for weight in (self.fc1.weight, self.fc2.weight, self.fc3.weight):
            nn.init.xavier_uniform_(weight)
        for bias in (self.fc1.bias, self.fc2.bias, self.fc3.bias):
            nn.init.zeros_(bias)

    def forward(self, state):
        """Return the scalar value estimate for ``state`` (last dimension has size 1)."""
        features = nn.functional.relu(self.fc1(state))
        features = nn.functional.relu(self.fc2(features))
        return self.fc3(features)

In [12]:
# Network dimensions and the step size shared by both optimisers.
# NOTE(review): action_size presumably has to cover every (collection, field)
# pair that env_step walks — confirm against field_list.
state_size, action_size = 192, 30
learning_rate = 1e-4

# Both networks are created on the CPU and then moved to the current CUDA device.
actor = Actor(state_size, action_size).cuda()
critic = Critic(state_size).cuda()

In [13]:
# One independent Adam optimiser per network, both using the same learning rate.
actor_optimizer, critic_optimizer = (
    optim.Adam(network.parameters(), lr=learning_rate)
    for network in (actor, critic)
)

In [None]:
# Optional cell: resume training from previously saved weights instead of the
# fresh initialisation. Both checkpoint files must already exist in the working
# directory; they are the files written by the torch.save calls in this notebook.
actor.load_state_dict(torch.load('actor_model.pth'))
critic.load_state_dict(torch.load('critic_model.pth'))

In [14]:
# Connect to the MongoDB server on localhost and select the database that
# env_step creates/deletes indexes on and measures queries against.
# NOTE(review): host, port and database name are hard-coded — confirm they match
# the target environment before running.
client = MongoClient('mongodb://localhost:27017/')
db_conn = client['benchmark_db1']

In [None]:
# Build the episode-start state directly from the live database.
print("collecting static info...")
# Deep copies keep this notebook's snapshots independent of whatever objects the
# helpers return.
partial_state = copy.deepcopy(getStaticInfo(db_conn))
og_state_dict = copy.deepcopy(addIndexInfo(db_conn, partial_state))
# Persist the snapshot (presumably so a later run can load it from disk instead
# of re-querying — confirm what path saveAsJSON writes to).
saveAsJSON(og_state_dict)
# Also yields the collection and per-collection field orderings that env_step iterates over.
og_state_vector, collection_list, field_list = convertToStateVector(og_state_dict) # returns a 1X192 list

In [15]:
# Read the saved static info instead of generating it on the fly
# (state.json is presumably the file produced by saveAsJSON — confirm).
# JSON text is UTF-8, so the encoding is pinned explicitly rather than left to
# the platform default, which differs between operating systems.
with open('state.json', 'r', encoding='utf-8') as file:
    og_state_dict = json.load(file)
# Derive the state vector plus the collection / field orderings used by env_step.
og_state_vector, collection_list, field_list = convertToStateVector(og_state_dict)

In [16]:
def env_step(action, state, state_dict):
    """Apply one action vector to the database and score the result.

    Every (collection, field) pair is visited in the order given by the
    notebook-level ``collection_list`` / ``field_list``, consuming one entry of
    ``action`` per pair:

    * entry > 0 -> ``create_single_field_index`` is called for that field
    * entry < 0 -> ``delete_single_field_index`` is called for that field
    * otherwise -> the field is left alone

    When the helper returns anything other than 0, ``state_dict`` is refreshed
    through ``addIndexInfo`` and ``state`` is rebuilt from it.

    Returns:
        tuple: ``(reward, state, state_dict)`` where ``reward`` is the negated
        ``executionTimeMillis`` value reported by ``getQueryMetrics``.
    """
    slot = 0  # position of the current field inside the flat action list
    for coll_pos, collection_name in enumerate(collection_list):
        for field_name in field_list[coll_pos]:
            action_value = action[slot]
            slot += 1

            # Pick the index operation from the sign of the action entry.
            if action_value > 0:
                change_index = create_single_field_index
            elif action_value < 0:
                change_index = delete_single_field_index
            else:
                continue

            # A result of 0 means nothing changed (index already present when
            # creating, or absent when deleting), so the state is still valid.
            if change_index(db_conn, collection_name, field_name) == 0:
                continue

            state_dict = addIndexInfo(db_conn, state_dict)
            state, _, _ = convertToStateVector(state_dict)

    reward = -getQueryMetrics(db_conn)['executionTimeMillis']
    print(reward)
    return reward, state, state_dict

In [17]:
def train_step(state, reward, next_state, done, gamma=0.99):
    """Run one actor-critic update from a single transition.

    Uses the notebook-level ``actor``, ``critic``, ``actor_optimizer`` and
    ``critic_optimizer``.

    Args:
        state: State vector the action was chosen from (sequence of numbers).
        reward: Scalar reward obtained for the transition.
        next_state: State vector observed after the action.
        done: Truthy if the transition ended the episode; the bootstrap term
            of the TD target is dropped in that case.
        gamma: Discount factor applied to the next-state value.

    The actor emits tanh scores in [-1, 1], and env_step turns a positive score
    into "create index" and a negative one into "delete index". Taking
    ``log`` of the raw scores therefore yields NaN for every negative entry.
    Instead each score is mapped to a Bernoulli "create" probability
    ``p = (score + 1) / 2`` and the log-likelihood of the decision that was
    actually executed (``p`` if score > 0, else ``1 - p``) is used.
    """
    # Follow the device the networks live on rather than assuming CUDA.
    device = next(critic.parameters()).device
    state = torch.tensor(state, dtype=torch.float32, device=device)
    next_state = torch.tensor(next_state, dtype=torch.float32, device=device)
    reward = torch.tensor(reward, dtype=torch.float32, device=device)
    done = torch.tensor(done, dtype=torch.float32, device=device)

    # Critic update: regress V(state) towards the one-step TD target.
    value = critic(state)
    with torch.no_grad():  # the target is a constant for the regression
        target = reward + (1.0 - done) * gamma * critic(next_state)
    critic_loss = nn.functional.mse_loss(value, target)
    critic_optimizer.zero_grad()
    critic_loss.backward()
    critic_optimizer.step()

    # Advantage uses the pre-update value estimate and carries no gradient.
    advantage = (target - value).detach()

    # Actor update: policy-gradient step on the executed decisions.
    scores = actor(state)
    eps = 1e-6  # keeps log() finite when tanh saturates at +/-1
    create_prob = ((scores + 1.0) / 2.0).clamp(eps, 1.0 - eps)
    taken_prob = torch.where(scores > 0, create_prob, 1.0 - create_prob)
    action_log_probs = torch.log(taken_prob)
    actor_loss = -(action_log_probs * advantage).mean()
    actor_optimizer.zero_grad()
    actor_loss.backward()
    actor_optimizer.step()

In [18]:
# Training driver: runs num_episodes episodes of max_steps_per_episode steps.
# Each episode starts from the original index configuration, checkpoints both
# networks when it finishes, and always restores the database indexes.
num_episodes = 30
max_steps_per_episode = 100
total_steps = float(num_episodes * max_steps_per_episode)
steps = 0
# Feed the actor on whatever device it was placed on.
device = next(actor.parameters()).device
for episode in range(num_episodes):
    state_dict = copy.deepcopy(og_state_dict)
    state = og_state_vector[:]

    done = False
    step_count = 0

    try:
        while not done:
            print(f"Progress : {round(float((steps*100)/total_steps), 2)}%")
            steps = steps + 1

            # Action selection needs no autograd graph; train_step recomputes
            # the actor output itself when it builds the loss.
            with torch.no_grad():
                action_probs = actor(torch.tensor(state, dtype=torch.float32, device=device))
            action = action_probs.tolist()
            reward, next_state, next_state_dict = env_step(action, state, state_dict)

            # Decide termination BEFORE the update so the last transition of
            # the episode reaches train_step with done=True; previously the
            # flag was always False there and the (1 - done) term was dead.
            step_count += 1
            done = step_count >= max_steps_per_episode

            train_step(state, reward, next_state, done)

            # Carry BOTH halves of the environment state forward. The dict
            # returned by env_step used to be dropped, so every step was handed
            # the episode-start dict.
            state = next_state
            state_dict = next_state_dict

        torch.save(actor.state_dict(), 'actor_model.pth')
        torch.save(critic.state_dict(), 'critic_model.pth')
    finally:
        # Restore the original indexes even if the episode is interrupted, so
        # the next run does not start from a database that disagrees with
        # og_state_dict.
        reset_index_config(db_conn, og_state_dict)

Progress : 0.0%


In [None]:
# Manual checkpoint: write the current actor/critic weights to the same files
# the training loop saves to (overwrites any existing checkpoints).
torch.save(actor.state_dict(), 'actor_model.pth')
torch.save(critic.state_dict(), 'critic_model.pth')

In [13]:
def predict_action(state):
    """Load the saved actor weights and decode its output for ``state``.

    Args:
        state: State vector (sequence of numbers) to evaluate.

    Returns:
        tuple: ``(collection, field, index_status)`` where ``collection`` is
        the argmax over output entries 0-4, ``field`` is the argmax over
        entries 5-10 (relative to position 5) and ``index_status`` is the raw
        value of entry 11.

    NOTE(review): this decoding treats the output as [collection | field |
    status], whereas env_step reads one signed entry per (collection, field)
    pair — confirm which convention the trained model follows.
    """
    # Use the device the actor already lives on; map_location lets a
    # checkpoint saved on a GPU load on a CPU-only machine as well.
    device = next(actor.parameters()).device
    actor.load_state_dict(torch.load('actor_model.pth', map_location=device))
    state = torch.tensor(state, dtype=torch.float32, device=device)
    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        action_probs = actor(state)
    collection = torch.argmax(action_probs[:5]).item()
    field = torch.argmax(action_probs[5:11]).item()
    index_status = action_probs[11].item()
    return collection, field, index_status

In [None]:
# Debug aid: show the stored entry for the 'currency' field of 'collection_3'.
print(og_state_dict['collection_3']['currency'])
# Restore the database indexes to the configuration recorded in og_state_dict.
# dict(...) passes a shallow copy, so only the top-level mapping is protected
# from modification by the helper.
reset_index_config(db_conn, dict(og_state_dict))