In [1]:
import json
import os

# Directories holding the training annotations and images.
PUBTAB_ANN = '../../pubtabnet/anns/train/'
PUBTAB_IMG = '../../pubtabnet/imgs/train/'

ANN_PATH = PUBTAB_ANN
IMAGE_PATH = PUBTAB_IMG
IMG_FORMAT = '.png'

# The file list lives next to the annotation directory and is read as plain
# text, one annotation file name per line (despite its .json suffix).
# rstrip('/') instead of [:-1] keeps the path correct whether or not ANN_PATH
# ends with a slash.
filelist_path = ANN_PATH.rstrip('/') + '_trunc_filelist.json'

with open(filelist_path, encoding='utf-8') as file:
    # Skip blank lines so an empty entry never reaches the dataset.
    json_list = [line for line in file.read().splitlines() if line.strip()]

In [2]:
# Preview the first ten entries of the annotation file list.
json_list[:10]

['PMC6051241_006_00.json',
 'PMC3398398_014_00.json',
 'PMC2682793_006_00.json',
 'PMC6082926_006_00.json',
 'PMC2809502_009_00.json',
 'PMC3926592_006_01.json',
 'PMC548940_003_00.json',
 'PMC4394591_006_00.json',
 'PMC5954872_006_00.json',
 'PMC3546910_004_01.json']

In [3]:
from transformers import DonutProcessor, VisionEncoderDecoderModel
from transformers import VisionEncoderDecoderConfig

# Encoder input resolution and the token budget used for the target sequences.
image_size = [750, 750]
max_length = 1300 #config.decoder.max_position_embeddings

# Load the pretrained configuration and override the encoder resolution
# before the weights are instantiated.
config = VisionEncoderDecoderConfig.from_pretrained("naver-clova-ix/donut-base")
config.encoder.image_size = image_size

processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base")
model = VisionEncoderDecoderModel.from_pretrained(
    "naver-clova-ix/donut-base",
    config=config,
    ignore_mismatched_sizes=True,
)

Could not find image processor class in the image processor config or the model config. Loading based on pattern matching with the model's feature extractor configuration.


In [4]:
# Tokenizer size at this point in the notebook, before the table tokens are added.
len(processor.tokenizer)

57525

In [5]:
from transformers import AddedToken

# Task prompt token followed by the structural HTML tags of a table.
structure_tags = [
    "<table_extraction>",
    "<thead>", "</thead>", "<tbody>", "</tbody>",
    "<tr>", "</tr>", "<td>", "</td>",
    "<td ", ">",
]

# Span attributes for 0-9, interleaved colspan/rowspan per value so the
# tokens are registered in the same order as before.
span_attributes = [
    f' {attribute}="{count}"'
    for count in range(10)
    for attribute in ("colspan", "rowspan")
]

formatting_tags = ["<i>", "</i>", "<b>", "</b>", "<sup>", "</sup>", "<sub>", "</sub>"]
extra_symbols = ["1", "≥", "≤"]

new_tokens = [
    AddedToken(tag, rstrip=True, lstrip=True, normalized=False)
    for tag in structure_tags + span_attributes + formatting_tags + extra_symbols
]

processor.tokenizer.add_tokens(new_tokens, special_tokens=False)
# Grow the decoder embedding matrix to cover the enlarged vocabulary.
model.decoder.resize_token_embeddings(len(processor.tokenizer))

Embedding(57566, 1024)

In [6]:
# Turn off long-axis alignment and make the processor's target size agree
# with the encoder resolution configured above.
processor.image_processor.do_align_long_axis = False
processor.image_processor.size = image_size[::-1] # should be (width, height)


In [7]:
# Start the run with an empty message log; write_msg appends to its "outputs"
# list. UTF-8 is set explicitly because write_msg reads the file back as UTF-8.
with open("msg.json", 'w', encoding="utf-8") as out:
    json.dump({'outputs': []}, out, ensure_ascii=False, indent=4)

def write_msg(msg, path="msg.json"):
    """Append ``msg`` to the ``outputs`` list stored in a JSON log file.

    Args:
        msg: JSON-serialisable value to append.
        path: Log file to update. It must already contain a JSON object with
            an ``outputs`` list. Defaults to ``msg.json``.

    Raises:
        TypeError: If ``msg`` is not JSON-serialisable; the file is left
            untouched in that case.
    """
    with open(path, encoding="utf-8") as f:
        json_data = json.load(f)

    json_data['outputs'].append(msg)

    # Serialise before reopening the file: opening with 'w' truncates it, so a
    # failure while dumping would otherwise leave a corrupt log behind.
    payload = json.dumps(json_data, ensure_ascii=False, indent=4)

    # Write with the same encoding used for reading; ensure_ascii=False may
    # emit non-ASCII characters that the platform default cannot encode.
    with open(path, 'w', encoding="utf-8") as out:
        out.write(payload)

In [8]:
import os
import torch
from torch.utils.data import Dataset
from PIL import Image
from torchvision import transforms

class DonutTableDataset(Dataset):
    """Dataset pairing table images with their tokenised target sequences.

    Items are read from disk on access: the annotation JSON from ``ANN_PATH``
    and the matching image from ``IMAGE_PATH`` (module-level settings), both
    encoded with the module-level ``processor``.

    Args:
        annotations: Sequence of annotation file names; a trailing ``.json``
            extension is optional.
        image_size: Size handed to ``transforms.Resize`` for ``self.resize``.
        max_length: Length every label sequence is padded or truncated to.
        shuffle: Accepted for interface compatibility; not used.
        split: ``"train"`` turns on ``random_padding`` in the processor call.
        ignore_id: Value written over padding positions in the labels.
        prompt_end_token: Accepted for interface compatibility; not used.
    """

    def __init__(
        self,
        annotations,
        image_size,
        max_length,
        shuffle = True,
        split = "train",
        ignore_id = -100,
        prompt_end_token = None,
    ):
        self.annotations = annotations

        self.image_size = image_size
        self.max_length = max_length
        self.split = split
        self.ignore_id = ignore_id

        # Not used by __getitem__ (the processor handles the image); kept so
        # the attribute stays available to existing callers.
        self.resize = transforms.Compose([transforms.Resize(image_size)])

    def __len__(self):
        """Return the number of annotation entries."""
        return len(self.annotations)

    @staticmethod
    def _sample_stem(file_name):
        """Return ``file_name`` without a trailing ``.json`` extension, if any."""
        suffix = ".json"
        if file_name.endswith(suffix):
            return file_name[:-len(suffix)]
        return file_name

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``pixel_values`` and ``labels`` tensors."""
        # List entries already end in ".json"; normalise to the bare stem so
        # the annotation path is not built as "<name>.json.json" and the image
        # path is derived from the same stem.
        stem = self._sample_stem(self.annotations[idx])

        with open(ANN_PATH + stem + ".json", encoding="utf-8") as f:
            annotation = json.load(f)

        # convert() returns an independent image, so the file handle can be
        # released as soon as the conversion is done.
        with Image.open(IMAGE_PATH + stem + IMG_FORMAT) as image:
            rgb_image = image.convert("RGB")

        # inputs
        pixel_values = processor(
            rgb_image,
            random_padding=self.split == "train",
            return_tensors="pt",
        ).pixel_values.squeeze()

        target_sequence = annotation

        # Use the instance's max_length rather than the notebook global.
        input_ids = processor.tokenizer(
            target_sequence,
            add_special_tokens=True,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )["input_ids"].squeeze(0)

        # Replace padding positions with ignore_id in the labels.
        labels = input_ids.clone()
        labels[labels == processor.tokenizer.pad_token_id] = self.ignore_id

        return dict(pixel_values=pixel_values, labels=labels)

In [9]:
# Decoding starts from the task prompt token; padding uses the tokenizer's pad id.
model.config.decoder_start_token_id = processor.tokenizer.convert_tokens_to_ids("<table_extraction>")
model.config.pad_token_id = processor.tokenizer.pad_token_id


In [10]:
# Training split built over the annotation file list.
train_dataset = DonutTableDataset(
    annotations=json_list,
    image_size=image_size,
    max_length=max_length,
)

In [11]:
from torch.utils.data import DataLoader

# Batches of 4 samples, drawn in shuffled order.
train_dataloader = DataLoader(train_dataset, batch_size=4, shuffle=True)

In [12]:
# To resume from a saved checkpoint, reload the model here (the checkpoint is
# written with model.save_pretrained, so it needs the model class), e.g.
# model = VisionEncoderDecoderModel.from_pretrained("../../pubtabnet/HTML-lre-5-checkpoint")
start_epoch = 0  # first epoch index used by the training loop
avg_size = 2000  # number of batches per running-loss report

In [13]:
import torch
from tqdm.auto import tqdm

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)
optimizer = torch.optim.AdamW(params=model.parameters(), lr=1e-5)

num_epochs = 1            # the loop runs epochs start_epoch .. num_epochs - 1
checkpoint_every = 10000  # batches between intermediate checkpoints
checkpoint_dir = "../../pubtabnet/HTML-lre-5-checkpoint"

for epoch in range(start_epoch, num_epochs):
    print("Epoch:", epoch+1)
    mean_loss = 0
    mean_smpl_loss = 0
    model.train()
    for i, batch in enumerate(tqdm(train_dataloader)):

        batch = {k: v.to(device) for k, v in batch.items()}

        pixel_values = batch["pixel_values"]
        labels = batch["labels"]

        outputs = model(pixel_values=pixel_values, labels=labels)

        loss = outputs.loss
        loss.backward()

        optimizer.step()
        optimizer.zero_grad()

        # Read the scalar once and reuse it for both accumulators.
        loss_value = loss.item()
        mean_loss += loss_value
        mean_smpl_loss += loss_value

        # Count completed batches from 1 so a report always averages exactly
        # avg_size batches, and neither a report nor a checkpoint fires after
        # the very first batch.
        step = i + 1
        if step % avg_size == 0:
            window_loss = mean_smpl_loss / avg_size
            print(str(step) + " loss: ", window_loss)
            write_msg("batch " + str(step) + " loss: " + str(window_loss))
            mean_smpl_loss = 0

        if step % checkpoint_every == 0:
            model.save_pretrained(checkpoint_dir)

    model.save_pretrained(checkpoint_dir)

    # Guard the division so an empty dataloader does not raise.
    epoch_loss = mean_loss / max(len(train_dataloader), 1)
    print("Epoch's mean loss: ", epoch_loss)

    write_msg("Epoch checkpointed: " + str(epoch+1) +" \n"+
              "Epoch's mean Loss: " + str(epoch_loss))

OutOfMemoryError: CUDA out of memory. Tried to allocate 20.00 MiB (GPU 0; 23.68 GiB total capacity; 67.04 MiB already allocated; 14.12 MiB free; 70.00 MiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [None]:
# Number of samples in the training dataset.
len(train_dataset)

In [None]:
# Save the model under a directory separate from the in-loop checkpoint path.
model.save_pretrained("../../pubtabnet/HTML-lre-5-epoch1")

In [None]:
# Save the processor as well; its tokenizer was extended with the table tokens above.
processor.save_pretrained("../../pubtabnet/Donut_PubTables_HTML_Processor")