In [46]:
import json
import os
import math
import numpy as np
import torch.nn.functional as F
from PIL import Image
from transformers import LayoutLMv2Processor
from datasets import load_dataset, Features, Sequence, ClassLabel, Value, Array2D, Array3D
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from transformers import LayoutLMv2ForTokenClassification


In [47]:
# Load the FUNSD form-understanding dataset by its Hub identifier.
# Later cells read its 'train' and 'test' splits and the columns
# 'image_path', 'words', 'bboxes' and 'ner_tags'.
funsd_dataset = load_dataset(path="nielsr/funsd")

In [48]:
# Tag set used for token classification; a tag's position in this list is its
# integer class id.
labels = ['O', 'B-HEADER', 'I-HEADER', 'B-QUESTION', 'I-QUESTION', 'B-ANSWER', 'I-ANSWER']

# Forward lookup (class id -> tag) and its inverse (tag -> class id).
id2label = dict(enumerate(labels))
label2id = {tag: class_id for class_id, tag in id2label.items()}

In [49]:
# Processor that turns (image, words, boxes, word labels) into model inputs.
# NOTE(review): the "no_ocr" revision presumably ships with built-in OCR
# disabled, which matches preprocess_data passing words and boxes explicitly;
# confirm against the model card.
processor = LayoutLMv2Processor.from_pretrained(
    pretrained_model_name_or_path="microsoft/layoutlmv2-base-uncased",
    revision="no_ocr",
)



In [50]:
# Explicit schema for the encoded dataset produced by `preprocess_data`.
# Column order is kept as: image, input_ids, attention_mask, token_type_ids,
# bbox, labels.
features = Features({
    'image': Array3D(dtype="int64", shape=(3, 224, 224)),
    # The three token-level id columns share the same variable-length int64 type.
    **{
        column: Sequence(feature=Value(dtype='int64'))
        for column in ('input_ids', 'attention_mask', 'token_type_ids')
    },
    'bbox': Array2D(dtype="int64", shape=(512, 4)),
    'labels': Sequence(feature=ClassLabel(names=labels)),
})

In [51]:
def preprocess_data(examples):
    """Encode one batch of raw FUNSD examples with the module-level `processor`.

    Args:
        examples: Batch mapping that provides the columns 'image_path',
            'words', 'bboxes' and 'ner_tags' (one entry per document).

    Returns:
        Whatever `processor` returns for the batch, padded to the maximum
        length and truncated.
    """
    # Every page image is opened from disk and converted to 3-channel RGB.
    images = []
    for path in examples['image_path']:
        images.append(Image.open(path).convert("RGB"))

    return processor(
        images,
        examples['words'],
        boxes=examples['bboxes'],
        word_labels=examples['ner_tags'],
        padding="max_length",
        truncation=True,
    )

In [52]:
# Encode both splits with the same settings: batched preprocessing, raw
# columns dropped, and the explicit `features` schema applied to the result.
train_dataset, test_dataset = (
    funsd_dataset[split].map(
        preprocess_data,
        batched=True,
        remove_columns=funsd_dataset[split].column_names,
        features=features,
    )
    for split in ('train', 'test')
)

# Have the datasets hand back torch tensors instead of Python lists.
train_dataset.set_format(type="torch")
test_dataset.set_format(type="torch")

# One document per batch; only the training loader shuffles.
train_loader = DataLoader(dataset=train_dataset, batch_size=1, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False)

In [53]:
# Use the GPU when PyTorch can see one, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Pretrained LayoutLMv2 with a token-classification head that has one output
# class per entry in `label2id`.
model = LayoutLMv2ForTokenClassification.from_pretrained(
    'microsoft/layoutlmv2-base-uncased',
    num_labels=len(label2id),
)

# Attach the tag <-> id lookup tables to the model config.
model.config.id2label, model.config.label2id = id2label, label2id

Some weights of LayoutLMv2ForTokenClassification were not initialized from the model checkpoint at microsoft/layoutlmv2-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [54]:
# RL Agent
class RLAgent(nn.Module):
    """Network that scores every sequence position as a candidate action.

    The state is split into a "selected" part and a "remain" part. Each part
    holds three float tensors (token ids, token-type ids, bounding boxes) that
    are embedded by separate MLPs and summed. The selected embedding then
    attends over the remaining embedding and the attended vector is mapped to
    one score per sequence position (`train_rl_agent` uses these scores as
    Q-values).

    This is the single definition of `RLAgent`; an earlier two-layer variant
    that used to precede it in the same cell was immediately shadowed by this
    class and has been removed as dead code.

    Args:
        seq_len: Number of sequence positions, i.e. the size of the last
            dimension of the id tensors and the number of output scores.
            The flattened bbox input has ``4 * seq_len`` features.
        hidden_dim: Width of every hidden/embedding layer.

    The defaults (512, 512) reproduce the original hard-coded layer shapes, so
    ``RLAgent()`` is unchanged and existing state dicts remain loadable.
    """

    def __init__(self, seq_len: int = 512, hidden_dim: int = 512):
        super().__init__()
        self.seq_len = seq_len
        self.hidden_dim = hidden_dim

        # Each position contributes 4 box coordinates once bbox is flattened.
        bbox_dim = 4 * seq_len

        # Module creation order is kept identical to the original so that a
        # seeded run initialises every parameter with the same values.
        self.selected_input_embedding = self._mlp(seq_len, hidden_dim)
        self.selected_bbox_embedding = self._mlp(bbox_dim, hidden_dim)
        self.selected_token_type_embedding = self._mlp(seq_len, hidden_dim)

        self.remain_input_embedding = self._mlp(seq_len, hidden_dim)
        self.remain_bbox_embedding = self._mlp(bbox_dim, hidden_dim)
        self.remain_token_type_embedding = self._mlp(seq_len, hidden_dim)

        self.fc1 = nn.Linear(hidden_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, seq_len)

    @staticmethod
    def _mlp(in_features: int, hidden_dim: int) -> nn.Sequential:
        """Build the Linear -> ReLU -> Linear embedding used for every input."""
        return nn.Sequential(
            nn.Linear(in_features, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )

    @staticmethod
    def _embed(input_ids, token_type_ids, bbox, input_net, bbox_net, token_type_net):
        """Embed one half of the state (selected or remaining) and sum the parts."""
        # Flatten the box coordinates per example. Leading (batch) dimensions
        # are taken from `input_ids`; with no batch dimension this is exactly
        # the original `bbox.view(-1)`.
        flat_bbox = bbox.reshape(*input_ids.shape[:-1], -1)
        return input_net(input_ids) + bbox_net(flat_bbox) + token_type_net(token_type_ids)

    def forward(self, state):
        """Score every sequence position for the given state.

        Args:
            state: Mapping with float tensors under the keys
                "selected_input_ids", "selected_token_type_ids",
                "selected_bbox", "remain_input_ids", "remain_token_type_ids"
                and "remain_bbox". Id tensors have shape ``(..., seq_len)``
                and bbox tensors ``(..., seq_len, 4)``; the leading batch
                dimensions are optional.

        Returns:
            Tensor of shape ``(..., seq_len)`` with one score per position.
        """
        selected = self._embed(
            state["selected_input_ids"],
            state["selected_token_type_ids"],
            state["selected_bbox"],
            self.selected_input_embedding,
            self.selected_bbox_embedding,
            self.selected_token_type_embedding,
        )
        remained = self._embed(
            state["remain_input_ids"],
            state["remain_token_type_ids"],
            state["remain_bbox"],
            self.remain_input_embedding,
            self.remain_bbox_embedding,
            self.remain_token_type_embedding,
        )

        # Treat every embedding feature as a length-1 vector so the attention
        # logits are the outer product of the two embeddings. Indexing from the
        # last dimension keeps this correct for batched input as well.
        selected = selected.unsqueeze(-1)
        remained = remained.unsqueeze(-1)

        attn_logits = torch.matmul(selected, remained.transpose(-2, -1))
        attention = F.softmax(attn_logits, dim=-1)

        # Only the trailing singleton is dropped; a bare squeeze() would also
        # remove a batch dimension of size 1.
        values = torch.matmul(attention, remained).squeeze(-1)

        x = torch.relu(self.fc1(values))
        return self.fc2(x)

In [55]:
def get_result(model, input_ids, mask, token_type_ids, bbox, image, target):
    """Run `model` on a single unbatched example and return its outputs.

    A leading batch dimension of size 1 is added to every input before the
    call.

    Args:
        model: Callable accepting the keyword arguments `input_ids`,
            `attention_mask`, `token_type_ids`, `bbox`, `image` and `labels`.
        input_ids: Token ids of one example (no batch dimension).
        mask: Attention mask of the example.
        token_type_ids: Token-type ids of the example.
        bbox: Bounding boxes of the example.
        image: Image tensor of the example.
        target: Labels of the example, or None to run without labels. Labels
            that already carry a batch dimension are passed through as-is.

    Returns:
        The object returned by `model(...)`.
    """
    # The original forwarded `target` without the batch dimension every other
    # input receives. Batch it the same way so the labels line up with the
    # batched inputs; labels with a different rank (already batched) and None
    # are left untouched.
    labels = target
    if labels is not None and labels.dim() == input_ids.dim():
        labels = labels.unsqueeze(0)

    outputs = model(
        input_ids=input_ids.unsqueeze(0),
        attention_mask=mask.unsqueeze(0),
        token_type_ids=token_type_ids.unsqueeze(0),
        bbox=bbox.unsqueeze(0),
        image=image.unsqueeze(0),
        labels=labels,
    )
    return outputs

In [56]:
def _strategy_probabilities(epoch, num_epochs):
    """Return normalised probabilities for the (random, max_q, step) strategies."""
    progress = epoch / num_epochs
    random_prob = max(0.1, 1 - progress)
    max_q_prob = min(0.8, progress)
    # Floating-point rounding can push the remainder slightly below zero
    # (e.g. 1 - 0.9 - 0.1 == -2.78e-17), which makes np.random.choice raise
    # "probabilities are not non-negative". Clamp it at zero.
    step_prob = max(0.0, 1 - random_prob - max_q_prob)

    probs = [random_prob, max_q_prob, step_prob]
    total = sum(probs)
    return [p / total for p in probs]


def _snapshot_state(selected_input_ids, selected_token_type_ids, selected_bbox,
                    remain_input_ids, remain_token_type_ids, remain_bbox):
    """Build the agent's state dict from float copies of the working tensors."""
    return {
        "selected_input_ids": selected_input_ids.float(),
        "selected_token_type_ids": selected_token_type_ids.float(),
        "selected_bbox": selected_bbox.float(),
        "remain_input_ids": remain_input_ids.float(),
        "remain_token_type_ids": remain_token_type_ids.float(),
        "remain_bbox": remain_bbox.float(),
    }


def train_rl_agent(agent, dataloader, num_epochs=10, learning_rate=1e-5, gamma=0.99):
    """Train `agent` with one-step Q-learning to pick tokens for the model.

    For every example the agent repeatedly chooses a sequence position. The
    chosen token (id, token type, box, label) is moved from the "remain"
    tensors into the "selected" tensors, the module-level `model` is evaluated
    on the selection via `get_result`, and the reward is the decrease of the
    model loss. Choosing a position that was already moved (marked with -100)
    yields a reward of -1 and leaves the state unchanged.

    The action is drawn from one of three strategies whose probabilities
    depend on the epoch: uniformly random position, position with the highest
    agent score, or the position equal to the current step index.

    Relies on the module-level names `model`, `device` and `get_result`.

    Args:
        agent: Module mapping a state dict to one score per sequence position.
            It is moved to `device` in place.
        dataloader: Iterable of batches holding a single example each, with the
            keys 'input_ids', 'token_type_ids', 'bbox', 'labels' and 'image'.
        num_epochs: Number of passes over `dataloader`.
        learning_rate: Adam learning rate for the agent.
        gamma: Discount factor applied to the best next-state score.

    Returns:
        The trained `agent` (same object that was passed in).
    """
    # The state tensors live on `device`, so the agent must as well; it was
    # previously left on the CPU, which fails whenever a GPU is used.
    agent.to(device)
    model.to(device)

    optimizer = optim.Adam(agent.parameters(), lr=learning_rate)
    criterion = nn.SmoothL1Loss()
    strategies = ['random', 'max_q', 'step']

    for epoch in range(num_epochs):
        # The strategy mix only depends on the epoch, so compute it once here.
        probs = _strategy_probabilities(epoch, num_epochs)

        for batch in dataloader:
            # Drop only the leading batch dimension (batch size 1); a bare
            # squeeze() would also remove any other singleton dimension.
            batch = {k: v.squeeze(0) for k, v in batch.items()}

            input_ids = batch['input_ids'].to(device)
            token_type_ids = batch['token_type_ids'].to(device)
            bbox = batch['bbox'].to(device)
            target = batch['labels'].to(device)
            image = batch['image'].to(device)
            mask = torch.zeros_like(input_ids, dtype=torch.bool)

            # Working copies the agent picks from; picked slots become -100.
            remain_input_ids = input_ids.clone()
            remain_token_type_ids = token_type_ids.clone()
            remain_bbox = bbox.clone()
            remain_target = target.clone()

            # Buffers filled in pick order. Unfilled label slots hold class
            # id 0, as in the original; sized from the batch rather than a
            # hard-coded 512.
            selected_input_ids = torch.zeros_like(input_ids)
            selected_token_type_ids = torch.zeros_like(token_type_ids)
            selected_bbox = torch.zeros_like(bbox)
            selected_target = torch.zeros_like(target)

            state = _snapshot_state(
                selected_input_ids, selected_token_type_ids, selected_bbox,
                remain_input_ids, remain_token_type_ids, remain_bbox,
            )

            # Only the scalar loss value is used, so skip autograd bookkeeping
            # for the (large) reward model.
            with torch.no_grad():
                outputs = get_result(model, selected_input_ids, mask, selected_token_type_ids,
                                     selected_bbox, image, selected_target)
            prev_loss = outputs.loss.item()
            # Defined up front so the progress print never reads an unset or
            # stale value from a previous batch.
            cur_loss = prev_loss

            for step in range(len(input_ids)):
                # Choose action selection strategy
                strategy = np.random.choice(strategies, p=probs)
                q_values = agent(state)

                if strategy == 'random':
                    action = np.random.choice(len(remain_input_ids))
                elif strategy == 'max_q':
                    action = torch.argmax(q_values).item()
                else:  # step strategy
                    action = step

                if remain_input_ids[action] == -100:
                    # Position was already consumed: penalise, keep the state.
                    reward = -1
                else:
                    selected_input_ids[step] = remain_input_ids[action]
                    selected_token_type_ids[step] = remain_token_type_ids[action]
                    selected_bbox[step] = remain_bbox[action]
                    selected_target[step] = remain_target[action]

                    mask = selected_input_ids != 0

                    # Mark the consumed position in every "remain" tensor.
                    remain_input_ids[action] = -100
                    remain_token_type_ids[action] = -100
                    remain_bbox[action] = -100
                    remain_target[action] = -100

                    with torch.no_grad():
                        outputs = get_result(model, selected_input_ids, mask, selected_token_type_ids,
                                             selected_bbox, image, selected_target)
                    cur_loss = outputs.loss.item()

                    # Positive reward when adding the token lowered the loss.
                    reward = prev_loss - cur_loss

                    state = _snapshot_state(
                        selected_input_ids, selected_token_type_ids, selected_bbox,
                        remain_input_ids, remain_token_type_ids, remain_bbox,
                    )

                    prev_loss = cur_loss

                # One-step bootstrap target from the (possibly updated) state.
                with torch.no_grad():
                    max_next_q_value = torch.max(agent(state)).item()
                target_q_value = reward + gamma * max_next_q_value

                loss = criterion(
                    q_values[action],
                    torch.tensor(target_q_value, dtype=q_values.dtype, device=q_values.device),
                )
                print(f"Step {step}, reward: {reward}, agent loss: {loss}, model loss: {cur_loss}")

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

    return agent

In [57]:
# Build a fresh agent with its default configuration and train it on the
# training loader using train_rl_agent's default hyper-parameters.
agent = RLAgent()
trained_agent = train_rl_agent(agent=agent, dataloader=train_loader)

Step 0, reward: 0.040981173515319824, agent loss: 3625.28076171875, model loss: 1.972579836845398
Step 1, reward: 0.00022530555725097656, agent loss: 2169.290771484375, model loss: 1.972354531288147
Step 2, reward: -0.033066630363464355, agent loss: 3479.0634765625, model loss: 2.0054211616516113
Step 3, reward: 0.012583732604980469, agent loss: 2857.23095703125, model loss: 1.9928374290466309
Step 4, reward: -0.0003248453140258789, agent loss: 2272.423095703125, model loss: 1.9931622743606567
Step 5, reward: -4.8041343688964844e-05, agent loss: 1077.6231689453125, model loss: 1.9932103157043457
Step 6, reward: -0.005624055862426758, agent loss: 2692.226318359375, model loss: 1.9988343715667725
Step 7, reward: -4.684925079345703e-05, agent loss: 2511.562255859375, model loss: 1.998881220817566
Step 8, reward: -0.0018848180770874023, agent loss: 2082.2294921875, model loss: 2.0007660388946533
Step 9, reward: -3.0994415283203125e-05, agent loss: 3852.589599609375, model loss: 2.000797033

KeyboardInterrupt: 