In [1]:
# !pip install -q datasets seqeval

  from .autonotebook import tqdm as notebook_tqdm
  Referenced from: <E03EDA44-89AE-3115-9796-62BA9E0E2EDE> /Users/dongpochen/opt/anaconda3/envs/nlp/lib/python3.11/site-packages/torchvision/image.so
  warn(


In [None]:
# !python -m pip install 'git+https://github.com/facebookresearch/detectron2.git'

In [None]:
# !huggingface-cli login

In [None]:
import json
import os
import math
import numpy as np
from datasets import load_metric
import torch.nn.functional as F
from PIL import Image
from transformers import LayoutLMv2Processor
from datasets import load_dataset, Features, Sequence, ClassLabel, Value, Array2D, Array3D
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from transformers import LayoutLMv2ForTokenClassification


In [2]:
# FUNSD form-understanding dataset from the Hugging Face Hub (its train/test splits are used below).
funsd_dataset = load_dataset(path="nielsr/funsd")

In [3]:
# FUNSD BIO tag set, plus lookups in both directions (index <-> tag name).
labels = ['O', 'B-HEADER', 'I-HEADER', 'B-QUESTION', 'I-QUESTION', 'B-ANSWER', 'I-ANSWER']
id2label = dict(enumerate(labels))
label2id = {tag: index for index, tag in id2label.items()}

In [4]:
# LayoutLMv2 processor for the base checkpoint. NOTE(review): the "no_ocr" revision is
# presumably the variant with OCR disabled, since preprocess_data passes its own words
# and boxes — verify against the hub repo.
processor = LayoutLMv2Processor.from_pretrained(
    "microsoft/layoutlmv2-base-uncased",
    revision="no_ocr",
)



In [5]:
# Schema of the encoded splits: fixed-shape image and box arrays, variable-length
# int64 token sequences, and per-token labels drawn from `labels`.
_int64_sequence_columns = ("input_ids", "attention_mask", "token_type_ids")
features = Features({
    "image": Array3D(dtype="int64", shape=(3, 224, 224)),
    **{column: Sequence(Value(dtype="int64")) for column in _int64_sequence_columns},
    "bbox": Array2D(dtype="int64", shape=(512, 4)),
    "labels": Sequence(ClassLabel(names=labels)),
})

In [6]:
def get_result(model, input_ids, mask, token_type_ids, bbox, image, target):
    """Run ``model`` on a single unbatched example and return its raw output.

    Every input except ``target`` gets a leading batch dimension of size 1.
    ``target`` is forwarded unchanged as ``labels``. ``mask`` is passed as
    ``attention_mask``.
    """
    batched_inputs = {
        keyword: tensor.unsqueeze(0)
        for keyword, tensor in (
            ("input_ids", input_ids),
            ("attention_mask", mask),
            ("token_type_ids", token_type_ids),
            ("bbox", bbox),
            ("image", image),
        )
    }
    return model(**batched_inputs, labels=target)

In [7]:
def preprocess_data(examples):
    """Encode a batch of FUNSD examples with the global LayoutLMv2 ``processor``.

    ``examples`` is a batched mapping; the keys read are ``image_path``,
    ``words``, ``bboxes`` and ``ner_tags``. Each page is loaded from disk and
    converted to RGB. The supplied words, boxes and tag ids go to the processor
    with ``padding="max_length"`` and ``truncation=True``.
    """
    rgb_pages = []
    for image_path in examples['image_path']:
        rgb_pages.append(Image.open(image_path).convert("RGB"))

    return processor(
        rgb_pages,
        examples['words'],
        boxes=examples['bboxes'],
        word_labels=examples['ner_tags'],
        padding="max_length",
        truncation=True,
    )

In [8]:
def _encode_split(split_name):
    """Encode one FUNSD split with preprocess_data, dropping the raw columns."""
    raw_split = funsd_dataset[split_name]
    return raw_split.map(preprocess_data, batched=True, remove_columns=raw_split.column_names, features=features)


train_dataset = _encode_split('train')
test_dataset = _encode_split('test')

# Yield torch tensors when indexed/iterated.
for encoded_split in (train_dataset, test_dataset):
    encoded_split.set_format(type="torch")

# One document per batch; only the training split is shuffled.
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

In [9]:
from transformers import AutoProcessor, AutoModelForTokenClassification

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = AutoModelForTokenClassification.from_pretrained("edmondz/layoutlmv2-finetuned-funsd-test").to(device)

# model = LayoutLMv2ForTokenClassification.from_pretrained('microsoft/layoutlmv2-base-uncased', num_labels=len(label2id))

# The FUNSD label mapping is imposed on the checkpoint's classifier head. That only makes
# sense if the head has exactly one output per label; otherwise every id would be misread.
if model.config.num_labels != len(labels):
    raise ValueError(
        f"Checkpoint has {model.config.num_labels} labels, expected {len(labels)}"
    )
model.config.id2label = id2label
model.config.label2id = label2id

In [None]:
# Logging interval (in agent steps) for train_rl_agent: every step on CPU, every 10th otherwise.
# `device` is a torch.device, so compare its `.type`; comparing the device object to the
# string 'cpu' does not match, which made the CPU branch unreachable.
print_epoch = 1 if device.type == "cpu" else 10

In [10]:
# seqeval metric, loaded on first use and then reused. compute_metrics runs once per agent
# step, and reloading the metric on every call is slow and re-emits load_metric's warning.
_seqeval_metric = None


def _get_seqeval_metric():
    """Return the cached seqeval metric, loading it on the first call."""
    global _seqeval_metric
    if _seqeval_metric is None:
        _seqeval_metric = load_metric("seqeval")
    return _seqeval_metric


def compute_metrics(outputs, target):
    """Score one sequence's token predictions against its labels with seqeval.

    Args:
        outputs: model output with a ``logits`` tensor of shape (batch, seq, num_labels);
            only batch element 0 is scored.
        target: per-token label ids for that sequence (tensor, array or list);
            positions equal to -100 are ignored.

    Returns:
        dict with "precision", "recall", "f1" and "accuracy". All values are 0.0 when
        no position is labelled.
    """
    # .cpu() so this also works when the model ran on a GPU.
    predictions = np.argmax(outputs.logits.detach().cpu().numpy(), axis=2)[0]
    label_ids = target.tolist() if hasattr(target, "tolist") else list(target)

    # Keep only labelled positions and score them as ONE sequence. Giving each token its
    # own sequence would make seqeval split multi-token entities apart, distorting span
    # precision/recall/F1.
    kept = [
        (id2label[int(prediction)], id2label[int(label)])
        for prediction, label in zip(predictions, label_ids)
        if label != -100
    ]

    if not kept:
        return {"precision": 0.0, "recall": 0.0, "f1": 0.0, "accuracy": 0.0}

    true_predictions = [[prediction for prediction, _ in kept]]
    true_labels = [[label for _, label in kept]]
    results = _get_seqeval_metric().compute(
        predictions=true_predictions, references=true_labels, zero_division=0
    )
    return {
        "precision": results["overall_precision"],
        "recall": results["overall_recall"],
        "f1": results["overall_f1"],
        "accuracy": results["overall_accuracy"],
    }

In [11]:
# RL Agent
class RLAgent(nn.Module):
    """Q-network that scores every position of the original sequence as the next pick.

    The state (see ``forward``) describes the already-selected sequence and the
    remaining tokens. Each of their id, token-type-id and flattened-box tensors is
    projected by a small MLP to ``hidden_size`` features. The selected summary then
    attends over the remaining summary, and two linear layers map the result to one
    Q-value per position.

    Args:
        seq_len: number of elements in each id / token-type-id tensor of the state
            (boxes have ``seq_len * 4`` elements); also the number of actions. Default 512.
        hidden_size: width of the projections and of the hidden layer. Default 512.
    """

    def __init__(self, seq_len=512, hidden_size=512):
        super().__init__()
        # Keep this creation order: it fixes the parameter-initialisation sequence under a seed.
        self.selected_input_embedding = self._projection(seq_len, hidden_size)
        self.selected_bbox_embedding = self._projection(seq_len * 4, hidden_size)
        self.selected_token_type_embedding = self._projection(seq_len, hidden_size)

        self.remain_input_embedding = self._projection(seq_len, hidden_size)
        self.remain_bbox_embedding = self._projection(seq_len * 4, hidden_size)
        self.remain_token_type_embedding = self._projection(seq_len, hidden_size)

        self.fc1 = nn.Linear(hidden_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, seq_len)

    @staticmethod
    def _projection(in_features, out_features):
        """Two-layer MLP (Linear -> ReLU -> Linear) used for every state component."""
        return nn.Sequential(
            nn.Linear(in_features, out_features),
            nn.ReLU(),
            nn.Linear(out_features, out_features),
        )

    def forward(self, state):
        """Return one Q-value per position, shape ``(seq_len,)``.

        ``state`` holds unbatched float tensors under the keys ``selected_input_ids``,
        ``selected_token_type_ids``, ``remain_input_ids`` and ``remain_token_type_ids``
        (``seq_len`` elements each). It also holds ``selected_bbox`` and ``remain_bbox``,
        which are flattened here and must have ``seq_len * 4`` elements.
        """
        selected = (
            self.selected_input_embedding(state["selected_input_ids"])
            + self.selected_bbox_embedding(state["selected_bbox"].view(-1))
            + self.selected_token_type_embedding(state["selected_token_type_ids"])
        )
        remained = (
            self.remain_input_embedding(state["remain_input_ids"])
            + self.remain_bbox_embedding(state["remain_bbox"].view(-1))
            + self.remain_token_type_embedding(state["remain_token_type_ids"])
        )
        # NOTE(review): raw token ids are fed in unscaled, which likely contributes to the
        # very large agent losses; consider normalising them.

        # Treat each hidden feature as a length-1 vector: every feature of `selected`
        # attends over all features of `remained`.
        selected = selected.unsqueeze(1)    # (hidden, 1)
        remained = remained.unsqueeze(1)    # (hidden, 1)
        attention = F.softmax(torch.matmul(selected, remained.transpose(-2, -1)), dim=-1)
        # squeeze only the singleton feature axis so hidden_size == 1 still yields a vector.
        values = torch.matmul(attention, remained).squeeze(-1)  # (hidden,)

        return self.fc2(torch.relu(self.fc1(values)))

In [12]:
def _build_rl_state(selected_input_ids, selected_token_type_ids, selected_bbox,
                    remain_input_ids, remain_token_type_ids, remain_bbox):
    """Pack the current selection/remainder tensors as the float32 state dict RLAgent reads."""
    return {
        "selected_input_ids": selected_input_ids.float(),
        "selected_token_type_ids": selected_token_type_ids.float(),
        "selected_bbox": selected_bbox.float(),
        "remain_input_ids": remain_input_ids.float(),
        "remain_token_type_ids": remain_token_type_ids.float(),
        "remain_bbox": remain_bbox.float(),
    }


def _choose_action(q_values, taken, step, epsilon):
    """Pick the next position: 0 ([CLS]) first, then epsilon-greedy over positions not yet taken."""
    if step == 0:
        return 0
    if np.random.rand() < epsilon:
        available = torch.nonzero(~taken).flatten().tolist()
        return int(np.random.choice(available))
    # Already-picked positions can never be chosen again.
    masked = q_values.detach().masked_fill(taken, float("-inf"))
    return int(torch.argmax(masked).item())


def train_rl_agent(agent, dataloader, num_epochs=10, learning_rate=1e-5, gamma=0.99,
                   epsilon_start=1.0, epsilon_end=0.1, epsilon_decay=0.995):
    """Train ``agent`` with one-step Q-learning to choose a token order per document.

    Each batch (batch_size=1) is one episode. Starting with position 0, the agent
    repeatedly picks a not-yet-picked position of the original sequence. The picked
    token is written at the next slot of the "selected" sequence, followed by id 102,
    and zeroed out of the "remain" tensors. The reward is the drop in ``model``'s loss
    on the selected sequence plus its seqeval accuracy and F1. ``model`` itself is
    only evaluated; its parameters are not optimised.

    Relies on the notebook globals ``model``, ``device``, ``get_result``,
    ``compute_metrics`` and ``print_epoch``.

    Args:
        agent: Q-network mapping a state dict to one Q-value per position.
        dataloader: yields single-document batches with ``input_ids``,
            ``token_type_ids``, ``bbox``, ``labels``, ``image`` and ``attention_mask``.
        num_epochs: passes over ``dataloader``.
        learning_rate: Adam learning rate for the agent.
        gamma: discount factor of the TD target.
        epsilon_start, epsilon_end, epsilon_decay: epsilon-greedy schedule; epsilon is
            multiplied by ``epsilon_decay`` after every episode, floored at ``epsilon_end``.

    Returns:
        The trained ``agent`` (also updated in place).
    """
    optimizer = optim.Adam(agent.parameters(), lr=learning_rate)
    criterion = nn.SmoothL1Loss()
    model.to(device)
    # The state tensors are built on `device`, so the agent must live there too.
    agent.to(device)

    epsilon = epsilon_start

    for _ in range(num_epochs):
        for batch in dataloader:
            # Drop the leading batch dimension (batch_size=1).
            batch = {k: v.squeeze(0).to(device) for k, v in batch.items()}

            input_ids = batch['input_ids']
            token_type_ids = batch['token_type_ids']
            bbox = batch['bbox']
            target = batch['labels']
            image = batch['image']
            attention_mask = batch['attention_mask']
            seq_len = input_ids.size(0)

            remain_input_ids = input_ids.clone()
            remain_token_type_ids = token_type_ids.clone()
            remain_bbox = bbox.clone()
            remain_target = target.clone()
            # Positions already moved into the selected sequence. Tracked explicitly because
            # picked slots are zeroed in `remain_*` and so look just like padding.
            taken = torch.zeros(seq_len, dtype=torch.bool, device=device)

            selected_input_ids = torch.zeros_like(input_ids)
            selected_token_type_ids = torch.zeros_like(token_type_ids)
            selected_bbox = torch.zeros_like(bbox)
            selected_target = torch.full_like(target, -100)
            mask = torch.zeros_like(input_ids, dtype=torch.bool)

            state = _build_rl_state(selected_input_ids, selected_token_type_ids, selected_bbox,
                                    remain_input_ids, remain_token_type_ids, remain_bbox)

            # `model` is only evaluated here, so skip building autograd graphs.
            with torch.no_grad():
                outputs = get_result(model, input_ids, attention_mask, token_type_ids, bbox, image, target)
                print(f"UnSequenced Batch Loss: {outputs.loss.item()}")
                outputs = get_result(model, selected_input_ids, mask, selected_token_type_ids,
                                     selected_bbox, image, selected_target)
            # With no labelled token selected this loss is NaN; the NaN check below then
            # gives a reward of 0.
            prev_loss = outputs.loss.item()
            cur_loss = prev_loss

            # Leave the final slot free for the trailing id-102 token.
            max_steps = seq_len - 1
            step = 0
            while step < max_steps:
                q_values = agent(state)
                action = _choose_action(q_values, taken, step, epsilon)
                taken[action] = True

                # Write the picked token at `step`, with id 102 right after it.
                selected_input_ids[step + 1] = 102
                selected_target[step + 1] = -100
                selected_input_ids[step] = remain_input_ids[action]
                selected_token_type_ids[step] = remain_token_type_ids[action]
                selected_bbox[step] = remain_bbox[action]
                selected_target[step] = remain_target[action]
                mask = selected_input_ids != 0

                # Remove the picked token from the remainder.
                remain_input_ids[action] = 0
                remain_token_type_ids[action] = 0
                remain_bbox[action] = 0
                remain_target[action] = -100

                with torch.no_grad():
                    outputs = get_result(model, selected_input_ids, mask, selected_token_type_ids,
                                         selected_bbox, image, selected_target)
                metrics = compute_metrics(outputs, selected_target)
                cur_loss = outputs.loss.item()

                if math.isnan(cur_loss) or math.isnan(prev_loss):
                    reward = 0
                else:
                    reward = prev_loss - cur_loss + metrics['accuracy'] + metrics['f1']
                prev_loss = cur_loss
                step += 1

                next_state = _build_rl_state(selected_input_ids, selected_token_type_ids, selected_bbox,
                                             remain_input_ids, remain_token_type_ids, remain_bbox)
                done = step >= max_steps

                # One-step TD target, bootstrapped from the NEXT state over positions still available.
                with torch.no_grad():
                    if done:
                        target_q_value = float(reward)
                    else:
                        next_q_values = agent(next_state).masked_fill(taken, float("-inf"))
                        target_q_value = reward + gamma * torch.max(next_q_values).item()

                loss = criterion(q_values[action],
                                 torch.tensor(target_q_value, dtype=q_values.dtype, device=device))

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                if step % print_epoch == 0:
                    print(f"Step {step}, reward: {reward}, agent loss: {loss.item()}, model loss: {cur_loss}")

                state = next_state

            # Decay exploration once per episode. Decaying once per epoch kept epsilon above
            # 0.95 for the default 10 epochs, so the policy stayed almost entirely random.
            epsilon = max(epsilon_end, epsilon * epsilon_decay)

    return agent

In [13]:
# Build a fresh Q-network and train it on the FUNSD training loader.
agent = RLAgent()
trained_agent = train_rl_agent(agent=agent, dataloader=train_loader)

[W NNPACK.cpp:61] Could not initialize NNPACK! Reason: Unsupported hardware.


UnSequenced Batch Loss: 3.5593161582946777


  metric = load_metric("seqeval")


Step 1, reward: 0, agent loss: 3656.5478515625, model loss: nan
Step 2, reward: 0, agent loss: 1321.29345703125, model loss: 4.295873165130615
Step 3, reward: -0.12428760528564453, agent loss: 2365.6997901535033, model loss: 4.42016077041626
Step 4, reward: 0.025440216064453125, agent loss: 2655.9764826965334, model loss: 4.394720554351807
Step 5, reward: 0.029704570770263672, agent loss: 3078.708588848114, model loss: 4.365015983581543
Step 6, reward: 3.00687837600708, agent loss: 1264.0548983955382, model loss: 2.358137607574463
Step 7, reward: 0.9546432495117188, agent loss: 3298.5055856323243, model loss: 2.403494358062744
Step 8, reward: 1.0123977661132812, agent loss: 3108.3795584106447, model loss: 2.391096591949463
Step 9, reward: 0.9891312122344971, agent loss: 972.5921121692659, model loss: 2.401965379714966
Step 10, reward: 0.997020959854126, agent loss: 2660.8182344055176, model loss: 2.40494441986084
Step 11, reward: 1.0024187564849854, agent loss: 3355.6114763736723, mode

KeyboardInterrupt: 