In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
import random
from pymongo import MongoClient

import state as state_gen

In [27]:
class Actor(nn.Module):
    def __init__(self, state_size, action_size):
        super(Actor, self).__init__()
        self.fc1 = nn.Linear(state_size, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, action_size)
        self.init_weights()

    def init_weights(self):
        nn.init.xavier_uniform_(self.fc1.weight)
        nn.init.xavier_uniform_(self.fc2.weight)
        nn.init.xavier_uniform_(self.fc3.weight)
        nn.init.zeros_(self.fc1.bias)
        nn.init.zeros_(self.fc2.bias)
        nn.init.zeros_(self.fc3.bias)

    def forward(self, state):
        x = torch.relu(self.fc1(state))
        x = torch.relu(self.fc2(x))
        action_probs = torch.sigmoid(self.fc3(x))
        return action_probs

In [28]:
class Critic(nn.Module):
    def __init__(self, state_size):
        super(Critic, self).__init__()
        self.fc1 = nn.Linear(state_size, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 1)
        self.init_weights()

    def init_weights(self):
        nn.init.xavier_uniform_(self.fc1.weight)
        nn.init.xavier_uniform_(self.fc2.weight)
        nn.init.xavier_uniform_(self.fc3.weight)
        nn.init.zeros_(self.fc1.bias)
        nn.init.zeros_(self.fc2.bias)
        nn.init.zeros_(self.fc3.bias)

    def forward(self, state):
        x = torch.relu(self.fc1(state))
        x = torch.relu(self.fc2(x))
        value = self.fc3(x)
        return value

In [51]:
state_size = 192
action_size = 12
learning_rate = 0.0001
actor = Actor(state_size, action_size).cuda()
critic = Critic(state_size).cuda()

actor_optimizer = optim.Adam(actor.parameters(), lr=learning_rate)
critic_optimizer = optim.Adam(critic.parameters(), lr=learning_rate)

In [52]:
def simulate_reward(action):
    return random.uniform(-2, 1)

In [53]:
def train_step(state, reward, next_state, done):
    state = torch.tensor(state, dtype=torch.float32).cuda()
    next_state = torch.tensor(next_state, dtype=torch.float32).cuda()
    reward = torch.tensor(reward, dtype=torch.float32).cuda()
    done = torch.tensor(done, dtype=torch.float32).cuda()

    # Critic update
    value = critic(state)
    next_value = critic(next_state)
    target = reward + (1 - done) * 0.99 * next_value
    critic_loss = nn.functional.mse_loss(value, target.detach())
    critic_optimizer.zero_grad()
    critic_loss.backward()
    critic_optimizer.step()

    # Actor update
    action_probs = actor(state)
    action_log_probs = torch.log(action_probs)
    advantage = (target - value).detach()
    actor_loss = -action_log_probs * advantage
    actor_optimizer.zero_grad()
    actor_loss.mean().backward()
    actor_optimizer.step()

In [16]:
client = MongoClient('mongodb://localhost:27017/')
db_conn = client['benchmark_db1']

print("collecting static info...")
partial_state = dict(state_gen.getStaticInfo(db_conn))
og_state_dict = dict(state_gen.addIndexInfo(db_conn, partial_state))
og_state_vector, fields, collections = state_gen.convertToStateVector(og_state_dict) # returns a 1X192 list

collecting static info...
collection_1
['name', 'address', 'email', 'age']
collection_2
['company', 'price', 'quantity', 'in_stock', 'discount']
collection_3
['date', 'transaction_id', 'amount', 'currency']
collection_4
['username', 'password', 'last_login', 'is_active', 'role']
collection_5
['product_name', 'category', 'rating', 'review_count', 'release_date', 'discontinued']
collection_1 -> name : 456539 / 1900000
collection_1 -> address : 1899992 / 1900000
collection_1 -> email : 787976 / 1900000
collection_1 -> age : 63 / 1900000
collection_2 -> company : 939173 / 1900000
collection_2 -> price : 98673 / 1900000
collection_2 -> quantity : 50 / 1900000
collection_2 -> in_stock : 2 / 1900000
collection_2 -> discount : 1 / 1900000
collection_3 -> date : 1663 / 1823286
collection_3 -> transaction_id : 1823286 / 1823286
collection_3 -> amount : 810125 / 1823286
collection_3 -> currency : 4 / 1823286
collection_4 -> username : 484322 / 1900005
collection_4 -> password : 1900001 / 1900005


In [54]:
num_episodes = 1000
max_steps_per_episode = 10

for episode in range(num_episodes):
    state = og_state_vector

    # print(f"Episode {episode}: Initial state size: {len(state)}")

    done = False
    step_count = 0

    while not done:
        action_probs = actor(torch.tensor(state, dtype=torch.float32).cuda())
        try:
            collection = torch.argmax(action_probs[:5]).item()
            field = torch.argmax(action_probs[5:11]).item()
            index_status = action_probs[11].item()
        except Exception as e:
            print(f"Error at Episode {episode}, Step {step_count}: {e}")
            print(f"action_probs: {action_probs}")
            break

        action = (collection, field, index_status)
        reward = simulate_reward(action)
        next_state = og_state_vector

        train_step(state, reward, next_state, done)
        state = next_state

        step_count += 1
        if step_count >= max_steps_per_episode:
            done = True

        if episode%100 == 0 and step_count == max_steps_per_episode:
            print(f"Episode {episode} action_probs = {action_probs}")

Episode 0 action_probs = tensor([0.5508, 0.5426, 0.3176, 0.4840, 0.6441, 0.5467, 0.5537, 0.6944, 0.4942,
        0.4981, 0.4760, 0.3408], device='cuda:0', grad_fn=<SigmoidBackward0>)
Episode 100 action_probs = tensor([7.9492e-07, 7.8884e-03, 2.8430e-16, 5.7404e-06, 4.4713e-08, 4.0628e-10,
        1.8499e-05, 5.0890e-07, 4.0815e-09, 5.7801e-13, 3.0133e-14, 1.0943e-08],
       device='cuda:0', grad_fn=<SigmoidBackward0>)
Episode 200 action_probs = tensor([2.5217e-07, 4.3660e-03, 2.8992e-17, 2.1213e-06, 1.2536e-08, 8.5113e-11,
        7.5278e-06, 1.5869e-07, 1.0403e-09, 9.0564e-14, 3.8344e-15, 2.9236e-09],
       device='cuda:0', grad_fn=<SigmoidBackward0>)


KeyboardInterrupt: 

In [None]:
torch.save(actor.state_dict(), 'actor_model.pth')
torch.save(critic.state_dict(), 'critic_model.pth')

In [13]:
def predict_action(state):
    actor.load_state_dict(torch.load('actor_model.pth'))
    state = torch.tensor(state, dtype=torch.float32).cuda()
    action_probs = actor(state)
    collection = torch.argmax(action_probs[:5]).item()
    field = torch.argmax(action_probs[5:11]).item()
    index_status = action_probs[11].item()
    return collection, field, index_status