In [None]:
import re
from typing import Callable

import numpy as np
import torch
import torch.nn as nn

In [None]:
def unary_lambda(symbol: str) -> Callable[[str], str]:
    """Return a formatter that wraps its operand in a group and appends ``symbol``.

    Example: ``unary_lambda("*")("ab") == "(ab)*"``.
    """
    return lambda a: f"({a}){symbol}"


letters = "abcdefghijklmnopqrstuvwxyz"
digits = "0123456789"

# Regex quantifiers that can follow an operand.
quantifiers: set[str] = {"*", "+", "?"}

# Unary operators: token -> formatter. Each quantifier is registered twice:
# the greedy form under its own symbol ("*") and the lazy (non-greedy) form
# under the symbol followed by "?" ("*?"). The keys must differ, otherwise the
# second comprehension silently overwrites the greedy entries.
unary_regex: dict[str, Callable[[str], str]] = {
    **{symbol: unary_lambda(symbol) for symbol in quantifiers},
    **{f"{symbol}?": unary_lambda(f"{symbol}?") for symbol in quantifiers},  # lazy versions
}
# Binary operators: token -> formatter taking (left, right).
binary_regex: dict[str, Callable[[str, str], str]] = {
    "concat": lambda a, b: a + b,  # concatenation
    # Alternation is grouped so it keeps its meaning when later concatenated or
    # quantified (an ungrouped "a|b" followed by "c" would parse as "a|bc").
    "|": lambda a, b: f"({a}|{b})",
}
# Operators that consume a whole list of operands.
many_to_one_regex: dict[str, Callable[[list[str]], str]] = {
    # Members are joined verbatim: parentheses inside [...] are literal
    # characters, so wrapping members in groups would add "(" and ")" to the set.
    "[]": lambda a: "[" + "".join(a) + "]",  # set
    "^[]": lambda a: "[^" + "".join(a) + "]",  # not in set
    "concat_all": "".join,
}

# Regex metacharacters usable as operands, both raw and escaped (see below).
specials_regex: set[str] = {
    ".",
    "^",
    "$",
}

# Every single-token operand: specials (raw and escaped), ASCII letters in both
# cases, digits, and escaped quantifier characters.
operands_regex: set[str] = {
    *specials_regex,
    *[f"\\{s}" for s in specials_regex],
    *letters,
    *letters.upper(),
    *digits,
    *[f"\\{s}" for s in quantifiers],
}


def rpn_to_infix_regex(expression: list):
    """Convert a reverse-Polish list of regex tokens into an infix regex string.

    Tokens are looked up in the module-level operator tables:

    * ``operands_regex`` tokens are pushed onto the stack unchanged;
    * ``unary_regex`` tokens pop one operand and apply their formatter to it;
    * ``binary_regex`` tokens pop two operands (the left one was pushed first);
    * ``many_to_one_regex`` tokens collapse the whole stack, in push order,
      into a single string.

    A trailing ``"concat_all"`` is implied, so whatever remains on the stack is
    joined into one pattern. The caller's list is not modified.

    Args:
        expression: Tokens in reverse-Polish order.

    Returns:
        The infix regex string ("" for an empty expression).

    Raises:
        RuntimeError: If a token is not found in any table.
        IndexError: If an operator finds too few operands on the stack.
    """
    if len(expression) == 0:
        return ""

    # Work on a copy so appending the implicit terminator never leaks to the caller.
    tokens = list(expression)
    if tokens[-1] != "concat_all":
        tokens.append("concat_all")

    stack: list[str] = []
    for token in tokens:
        if token in many_to_one_regex:
            # Operands are consumed in push order, matching the left-to-right
            # semantics of the binary "concat" operator.
            stack = [many_to_one_regex[token](stack)]
        elif token in binary_regex:
            right = stack.pop()
            left = stack.pop()
            stack.append(binary_regex[token](left, right))
        elif token in unary_regex:
            # The formatter places the quantifier after the operand ("(a)*");
            # putting it in front would produce an invalid pattern.
            stack.append(unary_regex[token](stack.pop()))
        elif token in operands_regex:
            stack.append(token)
        else:
            raise RuntimeError(f"Operand '{token}' is unknown")

    return stack[0]

In [None]:
# The ten decimal digit characters "0" through "9".
numbers: set[str] = {str(digit) for digit in range(10)}

In [None]:
# Action vocabulary for the agent. Symbols are sorted before being numbered:
# iterating a set of strings follows hash order, which varies between
# interpreter runs (string hash randomization), so unsorted indices would not
# survive a kernel restart and saved policies or hard-coded indices would break.
# The other symbol groups stay here, commented out, so they can be re-enabled.
IDX_TO_ACTIONS = {
    i: action
    for i, action in enumerate(
        sorted(
            set().union(
                numbers,
                # operands_regex,
                # set(binary_regex.keys()),
                # set(unary_regex.keys()),
                # set(many_to_one_regex.keys()),
            )
        )
    )
}
# The terminal action always takes the last index.
IDX_TO_ACTIONS[len(IDX_TO_ACTIONS)] = "FINISH"
ACTIONS_TO_IDX = {v: k for k, v in IDX_TO_ACTIONS.items()}

ACTIONS = set(ACTIONS_TO_IDX.keys())

In [None]:
# Inspect the index -> action mapping; the last index is the terminal "FINISH" action.
IDX_TO_ACTIONS

In [None]:
# Magnitude of the sentinel score: the training setup starts best_score at -HIGHEST
# and seeds each episode's score list with -HIGHEST.
HIGHEST = 100

In [None]:
class Environment:
    """Episodic environment in which an agent builds a regex one token per step.

    The state is a fixed-length array of action indices (``max_steps`` slots);
    unused slots hold ``empty_state_idx``. Each step converts the tokens placed
    so far into a regex with ``rpn_to_infix_regex`` and rewards it by how well
    its matches over ``text`` agree with the per-character ``target`` mask.
    """

    # Legacy penalty constants; not read by the methods below (the reward uses
    # ``penalty_weights``).
    penalty = -100
    # penalty = 1
    word_penalty = -10000
    len_penalty = -1

    # Number of state slots, i.e. the most tokens one episode can place.
    max_steps = 5

    # When True, print each generated regex and the error markers. Off by
    # default because ``reward`` runs on every step of every episode.
    verbose = False

    def __init__(
        self, text: str, target: list[int], words_count: int, penalty_weights: dict = None
    ):
        """Create the environment and reset it to an empty state.

        Args:
            text: String the generated regex is matched against.
            target: 0/1 mask with one entry per character of ``text``; 1 marks
                characters the regex should cover.
            words_count: Stored as ``self.words_count``; not read by the
                methods of this class.
            penalty_weights: Optional reward-weight overrides. Keys that are not
                supplied fall back to the defaults below, so a partial dict is
                accepted.
        """
        self.text = text
        self.target = target
        self.words_count = words_count
        self.current_index = 0
        # Most recent regex string produced by ``reward`` (None until then).
        self.regexp = None

        # Padding value for unused slots: one past the largest action index.
        self.empty_state_idx = len(ACTIONS_TO_IDX)
        self.finish_action_idx = ACTIONS_TO_IDX["FINISH"]

        default_weights = {
            "f1": 10.0,
            "precision": 2.0,
            "recall": 2.0,
            "complexity": -0.5,
            "full_match": 50.0,
            "syntax_error": -100.0,
            "partial_progress": 5.0,
            "length_penalty": -0.3,
        }
        # Merge instead of replace so a partial override cannot cause KeyErrors later.
        self.penalty_weights = {**default_weights, **(penalty_weights or {})}

        self.reset()

    def reset(self):
        """Clear the state and regex history; return the (empty) state array."""
        self.step_idx = 0
        self.state = np.array([self.empty_state_idx] * self.max_steps)
        self.regex_history = []
        return self.get_state()

    def get_state(self):
        """Return the live state array (not a copy)."""
        return self.state

    def get_state_tensor(self):
        """Return the state scaled to [0, 1] as a float tensor of shape (1, max_steps).

        Padding slots map to 0; action index ``k`` maps to ``1 - k / empty_state_idx``.
        """
        return torch.FloatTensor(1 - (self.state / self.empty_state_idx)).unsqueeze(0)

    def step(self, action: int) -> tuple[np.ndarray, float, bool]:
        """Apply ``action`` and return ``(state, reward, done)``.

        ``FINISH`` ends the episode without writing to the state. Once more than
        ``max_steps`` actions have been taken, the episode also ends and the
        extra action is discarded. Otherwise the action is written into the
        next slot and the reward of the sequence so far is returned.
        """
        self.step_idx += 1
        if action == self.finish_action_idx or self.step_idx > self.max_steps:
            return self.state, self.reward(self.state[: self.step_idx]), True

        self.state[self.step_idx - 1] = action
        return self.state, self.reward(self.state[: self.step_idx]), False

    def _calculate_metrics(self, regex_str: str):
        """Score ``regex_str`` against the character-level target mask.

        Every character covered by any match of the regex in ``self.text``
        counts as a positive prediction.

        Returns:
            ``(f1, precision, recall, is_invalid)``; ``(None, None, None, True)``
            when the pattern does not compile.
        """
        try:
            matches = [m.span() for m in re.finditer(regex_str, self.text)]
        except re.error:
            # Only regex compilation errors mean "invalid pattern"; anything
            # else (e.g. KeyboardInterrupt) must propagate.
            if self.verbose:
                print("ERROR_calculate_metrics")
            return None, None, None, True

        target_mask = np.asarray(self.target, dtype=bool)
        pred_mask = np.zeros_like(target_mask)
        for start, end in matches:
            pred_mask[start:end] = True

        tp = np.logical_and(pred_mask, target_mask).sum()
        fp = np.logical_and(pred_mask, ~target_mask).sum()
        fn = np.logical_and(~pred_mask, target_mask).sum()

        precision = tp / (tp + fp + 1e-9)
        recall = tp / (tp + fn + 1e-9)
        f1 = 2 * (precision * recall) / (precision + recall + 1e-9)

        return f1, precision, recall, False

    def reward(self, regex_actions: list[int]):
        """Return the shaped reward for the action sequence ``regex_actions``.

        Padding and ``FINISH`` indices are dropped before the tokens are turned
        into a regex. Construction or compilation failures return
        ``penalty_weights["syntax_error"]``. Otherwise the weighted sum of F1,
        precision, recall and a per-action length penalty is compressed with
        ``sign(x) * log1p(|x|)``.
        """
        # Both sentinels must be excluded. A chain like ``x != a or x != b`` is
        # always true and would let the padding index reach IDX_TO_ACTIONS.
        skipped = (self.finish_action_idx, self.empty_state_idx)
        try:
            regex_symbols = [IDX_TO_ACTIONS[x] for x in regex_actions if x not in skipped]
            regex_str = rpn_to_infix_regex(regex_symbols)
        except Exception:
            # Invalid regex construction (unknown token or too few operands).
            if self.verbose:
                print("ERROR_syntax_error")
            return self.penalty_weights["syntax_error"]

        if self.verbose:
            print(regex_str)
        self.regexp = regex_str
        self.regex_history.append(regex_str)

        # Calculate base metrics
        f1, precision, recall, is_invalid = self._calculate_metrics(regex_str)
        if is_invalid:
            return self.penalty_weights["syntax_error"]

        weights = self.penalty_weights
        reward_components = {
            "f1": f1 * weights["f1"],
            "precision": precision * weights["precision"],
            "recall": recall * weights["recall"],
            # 'complexity': len(regex_str) * weights['complexity'],
            "length_penalty": len(regex_actions) * weights["length_penalty"],
        }

        # Full match bonus
        # if f1 >= 0.99:
        #     reward_components["full_match"] = weights["full_match"]

        # # Partial progress bonus (compare with previous attempts)
        # if len(self.regex_history) > 1:
        #     prev_f1 = self._calculate_metrics(self.regex_history[-2])[0] or 0
        #     reward_components["partial_progress"] = weights[
        #         "partial_progress"
        #     ] * (f1 - prev_f1)

        total_reward = sum(reward_components.values())

        # Signed log scaling keeps the reward magnitude small.
        return np.sign(total_reward) * np.log1p(np.abs(total_reward))

In [None]:
# Training runs on the CPU. To use a GPU, switch to
# torch.device("cuda" if torch.cuda.is_available() else "cpu").
device = torch.device("cpu")
device

In [None]:
# Inspect the full action vocabulary (digit symbols plus "FINISH").
ACTIONS

In [None]:
# Number of discrete actions; this is the policy network's output size.
len(ACTIONS)

In [None]:
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from tqdm import tqdm

In [None]:
class PolicyNetwork(nn.Module):
    """Feed-forward policy mapping an observation vector to action probabilities.

    Layers: observation_space -> hidden_dim -> 2*hidden_dim -> hidden_dim ->
    action_space, with leaky-ReLU activations and a softmax output.
    """

    def __init__(self, observation_space: int, action_space: int, hidden_dim: int = 128):
        super().__init__()
        self.input_layer = nn.Linear(observation_space, hidden_dim)
        self.layer1 = nn.Linear(hidden_dim, hidden_dim * 2)
        self.layer2 = nn.Linear(hidden_dim * 2, hidden_dim)
        self.output_layer = nn.Linear(hidden_dim, action_space)

    def forward(self, x):
        """Return action probabilities.

        Args:
            x: Tensor whose last dimension has size ``observation_space``.

        Returns:
            Tensor with the same leading dimensions and a last dimension of
            size ``action_space`` that sums to 1.
        """
        x = F.leaky_relu(self.input_layer(x))
        x = F.leaky_relu(self.layer1(x))
        x = F.leaky_relu(self.layer2(x))
        logits = self.output_layer(x)
        # Explicit dim: the implicit choice is deprecated and, for 3-D or larger
        # inputs, does not normalise over the action dimension.
        return F.softmax(logits, dim=-1)

In [None]:
class A2CNet(nn.Module):
    """Actor-critic network: a shared trunk feeding a policy head and a value head.

    ``forward`` returns ``(policy_logits, state_value)``. The policy head emits
    unnormalised scores of size ``output_dim``; the value head ends in a last
    dimension of size 1.
    """

    def __init__(self, input_dim: int, output_dim: int, hidden_dim: int = 128) -> None:
        super().__init__()
        wide_dim = 2 * hidden_dim

        # Shared feature extractor.
        self.body = nn.Sequential(nn.Linear(input_dim, hidden_dim), nn.ReLU())
        # Actor head: hidden -> 2*hidden -> one score per action.
        self.policy = nn.Sequential(
            nn.Linear(hidden_dim, wide_dim), nn.ReLU(), nn.Linear(wide_dim, output_dim)
        )
        # Critic head: hidden -> single value.
        self.value = nn.Sequential(nn.Linear(hidden_dim, 1))

    def forward(self, x):
        features = self.body(x)
        policy_logits = self.policy(features)
        state_value = self.value(features)
        return policy_logits, state_value

In [None]:
import random


def select_action(policy_network: nn.Module, env: "Environment", on_policy: bool = True):
    """Pick an action for the environment's current state.

    Args:
        policy_network: Module whose output is treated as action probabilities
            (e.g. ``PolicyNetwork``, whose forward ends in a softmax).
        env: Object providing ``get_state_tensor()``.
        on_policy: If True, take the most probable action (greedy); otherwise
            sample from the predicted distribution. The default is True, which
            matches the old ``on_policy=bool`` default (the ``bool`` class is
            truthy).

    Returns:
        ``(action, log_probs, action_probs)``. ``log_probs`` and
        ``action_probs`` cover every action and keep the autograd graph.
    """
    state = env.get_state_tensor().to(device)
    action_probs = policy_network(state).squeeze()
    log_probs = torch.log(action_probs)
    cpu_action_probs = action_probs.detach().cpu().numpy()
    if on_policy:
        action = np.argmax(cpu_action_probs)
    else:
        # Size the choice from the probability vector itself so both always agree.
        action = np.random.choice(len(cpu_action_probs), p=cpu_action_probs)

    return action, log_probs, action_probs


class Agent:
    """Action selection from raw (pre-softmax) logits."""

    def choose_action(self, action_logits) -> int:
        """Sample an action index with probability ``softmax(action_logits)``."""
        # ``random`` must be the module: ``from random import random`` would
        # bind the function and make ``random.choices`` an AttributeError.
        weights = F.softmax(action_logits, dim=0).detach().tolist()
        return random.choices(range(len(weights)), weights=weights)[0]

    def choose_optimal_action(self, action_logits) -> int:
        """Return the index of the most probable action."""
        # torch.argmax works on tensors that require grad, unlike np.argmax.
        probs = F.softmax(action_logits, dim=0)
        return int(torch.argmax(probs).item())

In [None]:
# Policy maps the Environment.max_steps-long state vector to one probability per action.
policy_network = PolicyNetwork(Environment.max_steps, len(ACTIONS)).to(device)

# Discount factor for computing returns (keep in sync with the training loop).
gamma = 0.99
# Learning rate 2**-13 (about 1.22e-4).
lr_policy_net = 2**-13
optimizer = torch.optim.Adam(policy_network.parameters(), lr=lr_policy_net)

In [None]:
# Toy training text and its per-character target mask: 1 where the character
# is covered by a match of "[1-9]", 0 elsewhere.
train_data = """a45b """
array = [match.span() for match in re.finditer("[1-9]", train_data)]

target = [0] * len(train_data)
for start, end in array:
    target[start:end] = [1] * (end - start)

In [None]:
# Inspect the per-character target mask for train_data.
target

In [None]:
# sum(np.bitwise_xor([0, 0, 1, 0], target))

In [None]:
# Build the environment over the toy text and its target mask (words_count=1).
env = Environment(train_data, target, 1)

In [None]:
# Training bookkeeping; best_score starts at the -HIGHEST sentinel.
best_score = -HIGHEST
# Copy: env.reset() returns the environment's live state array, which is
# mutated in place on every step, so keeping a reference would make best_env drift.
best_env = env.reset().copy()
NUM_EPISODES = 5000
# One list of per-step scores per episode (filled by the training loop).
scores = []

In [None]:
from copy import copy

# Every EVAL_INTERVAL-th episode is a greedy evaluation run with no update.
EVAL_INTERVAL = 500
cache_hits = 0


def _discounted_returns(rewards: list, discount: float) -> list:
    """Return G_t = r_t + discount * G_{t+1} for every step, computed backwards."""
    returns = []
    running_return = 0.0
    for r in reversed(rewards):
        running_return = r + discount * running_return
        returns.append(running_return)
    returns.reverse()
    return returns


loop = tqdm(range(NUM_EPISODES))
for episode in loop:
    env.reset()
    done = False
    # Each episode's score trace starts with the -HIGHEST sentinel.
    scores.append([-HIGHEST])

    on_policy = (episode + 1) % EVAL_INTERVAL == 0

    episode_reward = 0
    episode_log_probs = []
    rewards = []

    # Run one episode.
    while not done:
        action, actions_log_probabilities, _ = select_action(
            policy_network, env, on_policy
        )
        episode_log_probs.append(actions_log_probabilities[action])

        next_state, new_score, done = env.step(action)

        if new_score > best_score:
            best_score = new_score
            best_env = copy(next_state)
            # Most recent regex recorded by the environment.
            last_regex = env.regex_history[-1] if env.regex_history else None
            print(
                "Best Action",
                action,
                IDX_TO_ACTIONS[action],
                best_env,
                best_score,
                last_regex,
            )

        # The per-step reward is the environment's score for the current prefix.
        reward = new_score
        rewards.append(reward)
        episode_reward += reward

        scores[episode].append(new_score)

    if not on_policy:
        # REINFORCE: loss = -sum_t log pi(a_t | s_t) * G_t, discounted by gamma.
        returns_tensor = torch.tensor(_discounted_returns(rewards, gamma))
        total_loss = torch.stack(
            [-log_prob * Gt for log_prob, Gt in zip(episode_log_probs, returns_tensor)]
        ).sum()

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

    loop.set_postfix({"best": best_score})

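Reward design notes (draft). Metrics: precision, recall, F1.

- accuracy reward = 10 * (new_score - score)
- F1 reward = 5 * F1
- complexity penalty = -0.3 * len(regexp)
- bonus = 50 if F1 >= 0.99, else 0
- syntax penalty = -15 if the regex has invalid syntax
- convergence bonus = 2 * log(1 + env.valid_matches - env.false_positives)
- total = max(sum of the terms above, -1)

These are planning values; the weights actually applied are the `penalty_weights` of `Environment`.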

In [None]:
# Look up which symbol action index 8 maps to.
IDX_TO_ACTIONS[8]

In [None]:
# State array (action indices) recorded at the best-scoring step.
best_env

In [None]:
# Best per-step reward observed during training.
best_score

In [None]:
import matplotlib as mpl


def plot_run(scores: list[list[float]], figsize: tuple[int, int] = (20, 9)):
    """Plot every episode's score trace, coloured by episode index.

    Args:
        scores: One sequence of per-step scores per episode.
        figsize: Matplotlib figure size in inches.
    """
    if len(scores) == 0:
        # Nothing to draw (and a zero-entry colormap cannot be built).
        return
    fig, ax1 = plt.subplots(1, 1, figsize=figsize)
    cmap = plt.get_cmap("jet", len(scores))
    for i, episode_scores in enumerate(scores):
        ax1.plot(episode_scores, c=cmap(i))
    norm = mpl.colors.Normalize(vmin=0, vmax=len(scores))
    sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
    sm.set_array([])
    # Let matplotlib choose the colorbar ticks: one tick per episode (thousands
    # of episodes here) is unreadable and slow to render.
    fig.colorbar(sm, ax=ax1)
    ax1.grid()
    plt.show()


def split_scores(scores, interval: "int | None" = None):
    """Split per-episode entries into evaluation and training episodes.

    Episode ``i`` (0-based) counts as an evaluation episode when
    ``(i + 1) % interval == 0``, the same schedule the training loop uses for
    its greedy runs.

    Args:
        scores: Sequence of per-episode entries.
        interval: Evaluation period; defaults to the global ``EVAL_INTERVAL``.

    Returns:
        ``(eval_scores, train_scores)``, each preserving episode order.
    """
    if interval is None:
        interval = EVAL_INTERVAL
    eval_scores = []
    train_scores = []
    for i, episode_scores in enumerate(scores):
        if (i + 1) % interval == 0:
            eval_scores.append(episode_scores)
        else:
            train_scores.append(episode_scores)
    return eval_scores, train_scores


def plot_scores(scores: list[float], figsize: tuple[int, int] = (16, 9)):
    """Plot per-episode summary statistics of the score traces.

    For each episode entry, the best (max), worst (min), mean and first value
    are drawn as separate lines. In this notebook the training loop seeds every
    episode's trace with ``-HIGHEST``, so the first value ("QWERTY score") is
    that sentinel, and the sentinel also pulls down the worst and mean lines.

    Args:
        scores: Sequence of per-episode score sequences (each element must
            support ``max``/``min``/``np.mean`` and indexing).
        figsize: Matplotlib figure size in inches.
    """
    series = {
        "Best score": [max(episode) for episode in scores],
        "Worst score": [min(episode) for episode in scores],
        "Mean score": [np.mean(episode) for episode in scores],
        "QWERTY score": [episode[0] for episode in scores],
    }

    plt.subplots(1, 1, figsize=figsize)
    for label, values in series.items():
        plt.plot(values, label=label)

    plt.legend()
    plt.grid()
    plt.ylabel("Score")
    plt.xlabel("Episode")
    plt.title("Scores for each episode")
    plt.show()


# Summarise training: best / worst / mean / first score per episode.
plot_scores(scores)