In [2]:
import random
import re
from collections import deque
from typing import Callable

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

In [47]:
def unary_lambda(symbol: str) -> Callable[[str], str]:
    return lambda a: f"({a}){symbol}"


letters = "abcdefghijklmnopqrstuvwxyz"
digits = "0123456789"

quantifiers: set[str] = {"*", "+", "?"}

unary_regex: dict[str, Callable[[str], str]] = {
    **{symbol: unary_lambda(symbol) for symbol in quantifiers},
    **{symbol: unary_lambda(f"{symbol}?") for symbol in quantifiers},  # greedy versions
}
binary_regex: dict[str, Callable[[str, str], str]] = {
    "concat": lambda a, b: a + b,  # concatenation
    "|": lambda a, b: f"{a}|{b}",
}
many_to_one_regex: dict[str, Callable[[list[str]], str]] = {
    "[]": lambda a: "[" + "".join([f"({x})" for x in a]) + "]"  # set
}

specials_regex: set[str] = {
    ".",
    "^",
    "$",
}

operands_regex: set[str] = {
    *specials_regex,
    *[f"\\{s}" for s in specials_regex],
    *letters,
    *letters.upper(),
    *digits,
    *[f"\\{s}" for s in quantifiers],
}


def rpn_to_infix_regex(expression: list):
    global unary_regex, binary_regex, operands_regex, many_to_one_regex
    stack = []

    for token in expression:
        if token in many_to_one_regex:
            stack = [many_to_one_regex[token](list(reversed(stack)))]
            continue

        if token in binary_regex:
            operand2 = stack.pop()
            operand1 = stack.pop()
            stack.append(binary_regex[token](operand1, operand2))
            continue
        if token in unary_regex:
            operand = stack.pop()
            stack.append(f"({token}{operand})")
            continue
        if token in operands_regex:
            stack.append(token)
            continue
        raise RuntimeError(f"Operand '{token}' is unknown")

    return stack[0]

In [48]:
IDX_TO_ACTIONS = {
    (i): action
    for i, action in enumerate(
        set().union(
            operands_regex,
            set(binary_regex.keys()),
            set(unary_regex.keys()),
            set(many_to_one_regex.keys()),
        )
    )
}
IDX_TO_ACTIONS[len(IDX_TO_ACTIONS)] = "FINISH"
ACTIONS_TO_IDX = {v: k for k, v in IDX_TO_ACTIONS.items()}

ACTIONS = set(ACTIONS_TO_IDX.keys())

In [None]:
class Environment:
    penalty = -100
    word_penalty = -10000
    len_penalty = -1

    max_steps = 5

    def __init__(
        self, texts: list[str], targets: list[list[int]], words_counts: list[int]
    ):
        self.texts = texts
        self.targets = targets
        self.words_counts = words_counts
        self.current_index = 0

        self.empty_state_idx = len(ACTIONS_TO_IDX)
        self.finish_action_idx = ACTIONS_TO_IDX["FINISH"]

        self.reset()

    def reset(self):
        self.step_idx = 0
        self.current_index += 1
        if self.current_index >= len(self.words_counts):
            self.current_index = 0
        self.state = [self.empty_state_idx] * self.max_steps
        return self.state

    def reward(self, regexp_actions: list[int]):
        try:
            regex_symbols = [IDX_TO_ACTIONS[x] for x in regexp_actions]
            regexp = rpn_to_infix_regex(regex_symbols)
            print(regexp)
            array = [
                x.span() for x in re.finditer(regexp, self.texts[self.current_index])
            ]
        except BaseException:
            return -1000000
        bit_mask = [0 for _ in range(len(self.texts[self.current_index]))]

        for it in array:
            for i in range(it[0], it[1]):
                bit_mask[i] = 1

        return (
            sum(np.bitwise_xor(bit_mask, self.targets[self.current_index])) * self.penalty
        )  # + abs(len(array)-self.words_counts[self.current_index]) * self.word_penalty

    def step(self, action: int) -> tuple[list, float, bool]:
        self.step_idx += 1
        if (
            action == self.finish_action_idx
        ) or self.step_idx > self.max_steps:  # finish action
            reward = self.reward(self.state[: self.step_idx])
            self.reset()
            return self.state, reward, True

        self.state[self.step_idx - 1] = action
        return self.state, -1, False

In [None]:
device = torch.device("cuda")  # "cuda" if torch.cuda.is_available() else
len(ACTIONS)
device

device(type='cuda')

In [51]:
class DQN(nn.Module):
    def __init__(self, state_dim=Environment.max_steps, action_dim=len(ACTIONS)):
        super().__init__()
        self.layer = nn.Sequential(
            nn.Linear(state_dim, 64), nn.ReLU(), nn.Linear(64, action_dim)
        )

    def forward(self, x):
        return self.layer(x)

In [52]:
class ReplayMemory:
    def __init__(self, capacity):
        self.memory = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        self.memory.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        batch = random.sample(self.memory, batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)
        return (
            torch.FloatTensor(np.array(states)).to(device),
            torch.LongTensor(np.array(actions)).to(device),
            torch.FloatTensor(np.array(rewards)).to(device),
            torch.FloatTensor(np.array(next_states)).to(device),
            torch.FloatTensor(np.array(dones)).to(device),
        )

    def __len__(self):
        return len(self.memory)

In [53]:
policy_net = DQN().to(device)
target_net = DQN().to(device)
target_net.load_state_dict(policy_net.state_dict())
target_net.eval()

DQN(
  (layer): Sequential(
    (0): Linear(in_features=5, out_features=64, bias=True)
    (1): ReLU()
    (2): Linear(in_features=64, out_features=78, bias=True)
  )
)

In [None]:
optimizer = optim.Adam(policy_net.parameters(), lr=1e-2)
memory = ReplayMemory(1000)

batch_size = 32
GAMMA = 0.99

In [55]:
eps = 1
EPS_START = 1
EPS_END = 0.01
EPS_DECAY = 0.999


def select_action(state):
    global eps
    if random.random() > eps:
        with torch.no_grad():
            return policy_net(state).max(1)[1].view(1, 1) + 1
    else:
        return torch.tensor(
            [[random.randrange(1, len(ACTIONS))]], device=device, dtype=torch.long
        )

In [None]:
def optimize_model():
    if len(memory) < batch_size:
        return

    states, actions, rewards, next_states, dones = memory.sample(batch_size)

    # print(states.shape)
    # print(actions.shape)
    # print(policy_net(states).squeeze().shape)
    # print( actions.squeeze(1).shape)

    # Compute Q(s_t, a)
    current_q = torch.gather(policy_net(states).squeeze(), 1, actions.squeeze(1))

    # Compute V(s_{t+1}) using target network
    next_q = target_net(next_states).squeeze().max(1)[0].detach()
    print(rewards.shape, next_q.shape)
    expected_q = rewards + (GAMMA * next_q * (1 - dones))

    # Compute Huber loss
    loss = nn.functional.smooth_l1_loss(current_q.squeeze(), expected_q)

    # Optimize the model
    optimizer.zero_grad()
    loss.backward()
    for param in policy_net.parameters():
        param.grad.data.clamp_(-1, 1)
    optimizer.step()

In [57]:
def train(env, num_episodes=10_000):
    global eps
    for episode in range(num_episodes):
        state = env.reset()
        state = torch.FloatTensor(state).unsqueeze(0).to(device)
        total_reward = 0
        done = False

        while not done:
            # Select and perform an action
            action = select_action(state)
            next_state, reward, done = env.step(action.item())
            total_reward += reward

            # Store transition in memory
            next_state = torch.FloatTensor(next_state).unsqueeze(0).to(device)
            memory.push(
                state.cpu().numpy(),
                action.cpu().numpy(),
                reward,
                next_state.cpu().numpy(),
                done,
            )

            # Move to the next state
            state = next_state

            # Perform one step of optimization
            optimize_model()

        # Decay epsilon
        eps = max(EPS_END, eps * EPS_DECAY)

        # Update the target network
        if episode % 10 == 0:
            target_net.load_state_dict(policy_net.state_dict())

        print(f"Episode {episode}, Total Reward: {total_reward}, Epsilon: {eps:.2f}")

In [58]:
train_data = """a2b """
array = [x.span() for x in re.finditer("[1-9]", train_data)]

target = [0 for _ in range(len(train_data))]
for it in array:
    for i in range(it[0], it[1]):
        target[i] = 1

In [59]:
env = Environment([train_data], [target], [1])
train(env)

Q
Episode 0, Total Reward: -105, Epsilon: 1.00
\?
Episode 1, Total Reward: -105, Epsilon: 1.00
R
Episode 2, Total Reward: -105, Epsilon: 1.00
$
Episode 3, Total Reward: -105, Epsilon: 1.00
\+
Episode 4, Total Reward: -105, Epsilon: 1.00
torch.Size([32]) torch.Size([32])
torch.Size([32]) torch.Size([32])
torch.Size([32]) torch.Size([32])
torch.Size([32]) torch.Size([32])
\.
torch.Size([32]) torch.Size([32])
Episode 5, Total Reward: -105, Epsilon: 0.99
torch.Size([32]) torch.Size([32])
torch.Size([32]) torch.Size([32])
torch.Size([32]) torch.Size([32])
torch.Size([32]) torch.Size([32])
torch.Size([32]) torch.Size([32])
torch.Size([32]) torch.Size([32])
Episode 6, Total Reward: -1000005, Epsilon: 0.99
torch.Size([32]) torch.Size([32])
torch.Size([32]) torch.Size([32])
torch.Size([32]) torch.Size([32])
torch.Size([32]) torch.Size([32])
torch.Size([32]) torch.Size([32])
f
torch.Size([32]) torch.Size([32])
Episode 7, Total Reward: -105, Epsilon: 0.99
torch.Size([32]) torch.Size([32])
torch.S

KeyboardInterrupt: 