In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
# Prefer the GPU whenever CUDA is usable; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [3]:
class RLAgent(nn.Module):
    """Scores a selected/remaining state dict, conditioned on a step value.

    Sub-modules
    -----------
    E_s_input, E_r_input : Linear(512, 512) -> ReLU -> Linear(512, 512)
        Encoders for ``state["selected_input_ids"]`` / ``state["remain_input_ids"]``.
    E_s_bbox, E_r_bbox : Linear(2048, 512) -> ReLU -> Linear(512, 512)
        Encoders for the flattened ``state["selected_bbox"]`` / ``state["remain_bbox"]``.
    E_remained : Linear(512, 512) -> ReLU -> Linear(512, 512)
        Produces the value vector from the combined "remained" features.
    step_processor : Linear(1, 128) -> ReLU -> Linear(128, 128) -> ReLU -> Linear(128, 512)
        Embeds the scalar step.
    fc1, fc2 : Linear(512, 512)
        Output head applied to ``score + step_embedding``.
    """

    @staticmethod
    def _make_encoder(in_features):
        """Return a Linear(in_features, 512) -> ReLU -> Linear(512, 512) stack."""
        return nn.Sequential(
            nn.Linear(in_features, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
        )

    def __init__(self):
        super().__init__()

        # Creation order is kept identical to the hand-written version so that
        # seeded initialisation and state_dict keys do not change.
        self.E_s_input = self._make_encoder(512)
        self.E_s_bbox = self._make_encoder(2048)
        self.E_r_input = self._make_encoder(512)
        self.E_r_bbox = self._make_encoder(2048)
        self.E_remained = self._make_encoder(512)

        self.step_processor = nn.Sequential(
            nn.Linear(1, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, 512),
        )

        self.fc1 = nn.Linear(512, 512)
        self.fc2 = nn.Linear(512, 512)

    def forward(self, state, step):
        """Compute the network output for one state and step.

        Args:
            state: mapping with tensor entries ``"selected_input_ids"``,
                ``"selected_bbox"``, ``"remain_input_ids"`` and ``"remain_bbox"``.
                Each is cast to float32 and moved to this module's device.
                The ``*_input_ids`` tensors need a last dimension of 512; each
                ``*_bbox`` tensor is flattened to 1-D and must therefore hold
                exactly 2048 elements.
                NOTE(review): presumably these are unbatched, i.e. ids of shape
                (512,) and boxes of shape (512, 4) -- confirm against the caller.
            step: scalar step value; wrapped into a (1, 1) float32 tensor.

        Returns:
            The output of ``fc2``. For 1-D id tensors of length 512 this has
            shape (1, 512).
        """
        # Follow the module's own parameters instead of a notebook-level global,
        # so the agent keeps working after ``agent.to(other_device)``.
        model_device = self.fc1.weight.device

        def _fetch(key):
            return state[key].to(device=model_device, dtype=torch.float32)

        selected_input_ids = _fetch("selected_input_ids")
        selected_bbox = _fetch("selected_bbox")
        remain_input_ids = _fetch("remain_input_ids")
        remain_bbox = _fetch("remain_bbox")

        # reshape(-1) flattens like view(-1) but also accepts non-contiguous
        # tensors (e.g. a transposed box tensor), where view(-1) would raise.
        selected = self.E_s_input(selected_input_ids) + self.E_s_bbox(selected_bbox.reshape(-1))
        remained = self.E_r_input(remain_input_ids) + self.E_r_bbox(remain_bbox.reshape(-1))
        v = self.E_remained(remained)

        # Add a trailing-but-one axis so the matmul below forms pairwise
        # products; for 1-D features this yields a (512, 512) outer product.
        selected = selected.unsqueeze(1)
        remained = remained.unsqueeze(1)

        attn_logits = torch.matmul(selected, remained.transpose(-2, -1))
        attention = F.softmax(attn_logits, dim=-1)

        # Weight the value vector by the attention rows and drop size-1 axes.
        score = torch.matmul(attention, v).squeeze()

        step_tensor = torch.tensor([step], dtype=torch.float32, device=model_device).view(-1, 1)
        step_output = self.step_processor(step_tensor)

        # (512,) + (1, 512) broadcasts to (1, 512) for unbatched input.
        combined = score + step_output

        x = torch.relu(self.fc1(combined))
        return self.fc2(x)

In [4]:
# Build the network first, then move its parameters onto the selected device.
agent = RLAgent()
agent = agent.to(device)