In [1]:
import json
import os

PUBTAB_ANN = '../../pubtabnet/anns/train/'
PUBTAB_IMG = '../../pubtabnet/imgs/train/'

ANN_PATH = PUBTAB_ANN
IMAGE_PATH = PUBTAB_IMG
IMG_FORMAT = '.png'

with open(ANN_PATH[:-1] + '_trunc_filelist.json') as file:
    json_list = file.read().splitlines()
    
json_list = [item for item in json_list]

In [2]:
json_list[:10]

['PMC5056244_003_01-HTML.json',
 'PMC5578000_007_00-HTML.json',
 'PMC3923559_003_00-HTML.json',
 'PMC2116997_001_00-HTML.json',
 'PMC5106812_003_01-HTML.json',
 'PMC5379740_003_00-HTML.json',
 'PMC4860401_004_00-HTML.json',
 'PMC3571919_003_00-HTML.json',
 'PMC5116148_001_00-HTML.json',
 'PMC4361879_003_00-HTML.json']

In [3]:
from transformers import DonutProcessor, VisionEncoderDecoderModel
from transformers import VisionEncoderDecoderConfig

config = VisionEncoderDecoderConfig.from_pretrained("naver-clova-ix/donut-base")

image_size = [750, 750]
max_length = 1300


config.encoder.image_size = image_size


processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base")
model = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base", config=config, ignore_mismatched_sizes=True)

2023-10-28 21:16:19.410250: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2023-10-28 21:16:19.410276: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2023-10-28 21:16:19.410293: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
Could not find image processor class in the image processor config or the model config. Loading based on pattern matching with the model's feature extractor configuration.


In [4]:
len(processor.tokenizer)

57525

In [5]:
from transformers import AddedToken

new_tokens  = ["<table_extraction>"]
new_tokens += ["<thead>", "</thead>", "<tbody>", "</tbody>"]
new_tokens += ["<tr>", "</tr>", "<td>", "</td>"]

new_tokens += ["<td ", ">"]
for i in range(1, 11):
    new_tokens +=['colspan="'+str(i)+'"']
    new_tokens +=['rowspan="'+str(i)+'"']

new_tokens += ["<i>", "</i>","<b>", "</b>", "<sup>", "</sup>", "<sub>", "</sub>"]

new_tokens += ["1", "≥","≤"]

new_tokens = [AddedToken(tag, rstrip = True, lstrip=True, normalized=False) for tag in new_tokens]

processor.tokenizer.add_tokens(new_tokens, special_tokens = False)
model.decoder.resize_token_embeddings(len(processor.tokenizer))

Embedding(57566, 1024)

In [6]:
import torch
from torch import nn, Tensor
import math

class PositionalEncoding(nn.Module):

    def __init__(self,  max_len: int, d_model: int, dropout: float = 0.05):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        self.offset = 2
        
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.pos_enc = pe.squeeze(1)

    def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0):
        """`input_ids' shape is expected to be [bsz x seqlen]."""

        bsz, seq_len = input_ids.shape[:2]
        
        pos_start = past_key_values_length
        pos_end = past_key_values_length+seq_len
        
        positions = self.pos_enc[pos_start:pos_end].expand(bsz, -1, -1)
        
        return self.dropout(positions)

In [7]:
model.decoder.model.decoder.embed_positions = PositionalEncoding(4098, 1024)
config.decoder.max_position_embeddings = 4096

In [8]:
processor.image_processor.size = image_size[::-1] # should be (width, height)
processor.image_processor.do_align_long_axis = False


In [9]:
model

VisionEncoderDecoderModel(
  (encoder): DonutSwinModel(
    (embeddings): DonutSwinEmbeddings(
      (patch_embeddings): DonutSwinPatchEmbeddings(
        (projection): Conv2d(3, 128, kernel_size=(4, 4), stride=(4, 4))
      )
      (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): DonutSwinEncoder(
      (layers): ModuleList(
        (0): DonutSwinStage(
          (blocks): ModuleList(
            (0-1): 2 x DonutSwinLayer(
              (layernorm_before): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
              (attention): DonutSwinAttention(
                (self): DonutSwinSelfAttention(
                  (query): Linear(in_features=128, out_features=128, bias=True)
                  (key): Linear(in_features=128, out_features=128, bias=True)
                  (value): Linear(in_features=128, out_features=128, bias=True)
                  (dropout): Dropout(p=0.0, inplace=False)
           

In [10]:
with open("msg.json", 'w') as out:
        json.dump({'outputs': []}, out, ensure_ascii=False, indent=4)

def write_msg(msg):
    with open("msg.json", encoding="utf-8") as f:
        json_data = json.load(f)
    
    with open("msg.json", 'w') as out:
        json_data['outputs'].append(msg)
        json.dump(json_data, out, ensure_ascii=False, indent=4)

In [11]:
import os
import torch
from torch.utils.data import Dataset
from PIL import Image
from torchvision import transforms

class DonutTableDataset(Dataset):
    def __init__(
        self,
        annotations,
        image_size,
        max_length,
        shuffle = True,
        split = "train",
        ignore_id = -100,
        prompt_end_token = None,
    ):            
        self.annotations = annotations
        
        
        self.image_size = image_size
        self.max_length = max_length
        self.split = split
        self.ignore_id = ignore_id
        
        self.resize = transforms.Compose([transforms.Resize(image_size)])
        
        
        
    def __len__(self):
        return len(self.annotations)
    
    
    def __getitem__(self, idx):
        
        file_name = self.annotations[idx]
        
        with open(ANN_PATH + file_name, encoding="utf-8") as f:
            annotation = json.load(f)
        
        image = Image.open(IMAGE_PATH + file_name[:-10] + IMG_FORMAT)
        
        
        # inputs
        pixel_values = processor(image.convert("RGB"), random_padding=self.split == "train", return_tensors="pt").pixel_values.squeeze()
        pixel_values = pixel_values.squeeze()
        
        target_sequence = "<s>"+annotation+"</s>"
        
        input_ids = processor.tokenizer(
            target_sequence,
            add_special_tokens = False,
            max_length= max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )["input_ids"].squeeze(0)

        labels = input_ids.clone()
        labels[labels == processor.tokenizer.pad_token_id] = self.ignore_id
        
        encoding = dict(pixel_values=pixel_values,
                        labels=labels)
        
        return encoding

In [12]:
model.config.pad_token_id = processor.tokenizer.pad_token_id
model.config.decoder_start_token_id = processor.tokenizer.convert_tokens_to_ids(["<table_extraction>"])[0] 


In [13]:
train_dataset = DonutTableDataset(json_list,
                             max_length = max_length,
                             image_size = image_size)

In [14]:
from torch.utils.data import DataLoader

train_dataloader = DataLoader(train_dataset, batch_size=4, shuffle=True)

In [15]:
#model = VisionEncoderDecoderModel.from_pretrained("../../pubtabnet/HTML-lre-5-epoch1")
start_epoch = 0
avg_size = 1000

In [16]:
import torch
from tqdm.auto import tqdm

device = 'cuda' if torch.cuda.is_available() else 'cpu' 
model.to(device)
optimizer = torch.optim.AdamW(params=model.parameters(), lr=1e-4)


for epoch in range(start_epoch, 1):
    print("Epoch:", epoch+1)
    mean_loss = 0
    mean_smpl_loss = 0 
    model.train()
    for i, batch in enumerate(tqdm(train_dataloader)):
        
        batch = {k: v.to(device) for k, v in batch.items()}
        
        pixel_values = batch["pixel_values"]
        labels = batch["labels"]
        
        
        outputs = model(pixel_values=pixel_values, labels=labels)
        
        loss = outputs.loss
        loss.backward()
        
        optimizer.step()
        optimizer.zero_grad()
        
        mean_loss += loss.item()   
        mean_smpl_loss += loss.item() 
        if i % avg_size == 0:
            print(str(i) + " loss: ", mean_smpl_loss/avg_size)
            write_msg("batch " + str(i) +" loss: "+ str(mean_smpl_loss/avg_size))
            mean_smpl_loss = 0 
        
        if i % 10000 == 0:
            model.save_pretrained("../../pubtabnet/HTML-lre-5-checkpoint")
    
    model.save_pretrained("../../pubtabnet/HTML-lre-5-checkpoint")
    print("Epoch's mean loss: ", mean_loss/len(train_dataloader))
    
    write_msg("Epoch checkpointed: " + str(epoch+1) +" \n"+
              "Epoch's mean Loss: " + str(mean_loss/len(train_dataloader)))

Epoch: 1


  0%|          | 0/125194 [00:00<?, ?it/s]

0 loss:  0.01856515884399414


KeyboardInterrupt: 

In [None]:
len(train_dataset)

In [None]:
model.save_pretrained("../../pubtabnet/HTML-lre-5-epoch1")

In [None]:
processor.save_pretrained("../../pubtabnet/Donut_PubTables_HTML_Processor")