In [1]:
import json
import os
import sys
from tqdm.auto import tqdm
from transformers import VisionEncoderDecoderModel, VisionEncoderDecoderConfig


sys.path.insert(0, '../src')
from modeling_pos_donut import PosDonutModel


import torch
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from PIL import Image
from torchvision import transforms


os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

ANN_PATH = '../../aux/data/anns/train/'
IMAGE_PATH = '../../aux/data/imgs/train/'

PROCESSORS_PATH = "../../aux/processors/"
MODELS_PATH = "../../aux/models/"

IMG_FORMAT = '.png'

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
os.system("ls ../../aux/data")

anns
imgs
PubTabNet_2.0.0.jsonl
pubtabnet.tar.gz


0

In [3]:
json_list = os.listdir(ANN_PATH)

aux_list = []

for json_item in json_list:
    if json_item[-6] != "L":
        aux_list.append(json_item[:-5])

json_list = aux_list

In [4]:
json_list[:10]

['PMC3549936_002_00',
 'PMC3112512_009_00',
 'PMC3772909_007_00',
 'PMC5294720_004_00',
 'PMC3446317_008_00',
 'PMC5071235_005_00',
 'PMC1421390_001_00',
 'PMC5251219_004_01',
 'PMC3973966_003_00',
 'PMC4147206_002_00']

In [5]:
from transformers import DonutProcessor
from transformers import VisionEncoderDecoderConfig

config = VisionEncoderDecoderConfig.from_pretrained("naver-clova-ix/donut-base")


#width height
image_size = [640, 640]
max_length = 960


config.encoder.image_size = image_size


processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base")



In [6]:
config.decoder.max_position_encoddings = 2048
model = PosDonutModel.from_pretrained("naver-clova-ix/donut-base", config=config, ignore_mismatched_sizes=True)

In [7]:
model.config.decoder

MBartConfig {
  "activation_dropout": 0.0,
  "activation_function": "gelu",
  "add_cross_attention": true,
  "add_final_layer_norm": true,
  "attention_dropout": 0.0,
  "bos_token_id": 0,
  "classifier_dropout": 0.0,
  "d_model": 1024,
  "decoder_attention_heads": 16,
  "decoder_ffn_dim": 4096,
  "decoder_layerdrop": 0.0,
  "decoder_layers": 4,
  "dropout": 0.1,
  "encoder_attention_heads": 16,
  "encoder_ffn_dim": 4096,
  "encoder_layerdrop": 0.0,
  "encoder_layers": 12,
  "eos_token_id": 2,
  "forced_eos_token_id": 2,
  "init_std": 0.02,
  "is_decoder": true,
  "is_encoder_decoder": false,
  "max_position_embeddings": 1536,
  "max_position_encoddings": 2048,
  "model_type": "mbart",
  "num_hidden_layers": 12,
  "pad_token_id": 1,
  "scale_embedding": true,
  "transformers_version": "4.41.0",
  "use_cache": true,
  "vocab_size": 57525
}

In [8]:
len(processor.tokenizer)

57525

In [9]:
from transformers import AddedToken

new_tokens  = ["<table_extraction>"]
new_tokens += ["<thead>", "</thead>", "<tbody>", "</tbody>"]
new_tokens += ["<tr>", "</tr>", "<td>", "</td>"]

new_tokens += ["<td ", ">"]
for i in range(1, 11):
    new_tokens +=['colspan="'+str(i)+'"']
    new_tokens +=['rowspan="'+str(i)+'"']

new_tokens += ["<i>", "</i>","<b>", "</b>", "<sup>", "</sup>", "<sub>", "</sub>"]

new_tokens = [AddedToken(tag, rstrip = True, lstrip=True, normalized=False) for tag in new_tokens]

processor.tokenizer.add_tokens(new_tokens, special_tokens = False)
model.decoder.resize_token_embeddings(len(processor.tokenizer))

Embedding(57563, 1024)

In [10]:
processor.image_processor.size = image_size[::-1] # should be (width, height)
processor.image_processor.do_align_long_axis = False


In [11]:
model

VisionEncoderDecoderModel(
  (encoder): DonutSwinModel(
    (embeddings): DonutSwinEmbeddings(
      (patch_embeddings): DonutSwinPatchEmbeddings(
        (projection): Conv2d(3, 128, kernel_size=(4, 4), stride=(4, 4))
      )
      (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): DonutSwinEncoder(
      (layers): ModuleList(
        (0): DonutSwinStage(
          (blocks): ModuleList(
            (0-1): 2 x DonutSwinLayer(
              (layernorm_before): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
              (attention): DonutSwinAttention(
                (self): DonutSwinSelfAttention(
                  (query): Linear(in_features=128, out_features=128, bias=True)
                  (key): Linear(in_features=128, out_features=128, bias=True)
                  (value): Linear(in_features=128, out_features=128, bias=True)
                  (dropout): Dropout(p=0.0, inplace=False)
           

In [12]:
with open("msg.json", 'w') as out:
        json.dump({'outputs': []}, out, ensure_ascii=False, indent=4)

def write_msg(msg):
    with open("msg.json", encoding="utf-8") as f:
        json_data = json.load(f)
    
    with open("msg.json", 'w') as out:
        json_data['outputs'].append(msg)
        json.dump(json_data, out, ensure_ascii=False, indent=4)

In [13]:
import os
import torch
from torch.utils.data import Dataset
from PIL import Image
from torchvision import transforms

class DonutTableDataset(Dataset):
    def __init__(
        self,
        annotations,
        image_size,
        max_length,
        shuffle = True,
        split = "train",
        ignore_id = -100,
        prompt_end_token = None,
    ):            
        self.annotations = annotations
        
        
        self.image_size = image_size
        self.max_length = max_length
        self.split = split
        self.ignore_id = ignore_id
        
        self.resize = transforms.Compose([transforms.Resize(image_size)])
        
        
        
    def __len__(self):
        return len(self.annotations)
    
    
    def __getitem__(self, idx):
        
        file_name = self.annotations[idx]
        
        with open(ANN_PATH + file_name +"-HTML.json", encoding="utf-8") as f:
            annotation = json.load(f)
        
        image = Image.open(IMAGE_PATH + file_name + IMG_FORMAT)
        
        
        # inputs
        pixel_values = processor(image.convert("RGB"), random_padding=self.split == "train", return_tensors="pt").pixel_values.squeeze()
        pixel_values = pixel_values.squeeze()
        
        target_sequence = annotation+"</s>"
        
        input_ids = processor.tokenizer(
            target_sequence,
            add_special_tokens = False,
            max_length= max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )["input_ids"].squeeze(0)

        labels = input_ids.clone()
        labels[labels == processor.tokenizer.pad_token_id] = self.ignore_id
        
        encoding = dict(pixel_values=pixel_values,
                        labels=labels)
        
        return encoding

In [14]:
model.config.pad_token_id = processor.tokenizer.pad_token_id
model.config.decoder_start_token_id = processor.tokenizer.convert_tokens_to_ids(["<table_extraction>"])[0] 


In [15]:
train_dataset = DonutTableDataset(json_list,
                             max_length = max_length,
                             image_size = image_size)

In [16]:
from torch.utils.data import DataLoader

train_dataloader = DataLoader(train_dataset, batch_size=1, num_workers=1, shuffle=True)

In [17]:
#model = VisionEncoderDecoderModel.from_pretrained("../../pubtabnet/HTML-lre-5-epoch1")
start_epoch = 0
avg_size = 1000

In [18]:
import torch
from tqdm.auto import tqdm

device = 'cuda' if torch.cuda.is_available() else 'cpu' 
model.to(device)
optimizer = torch.optim.AdamW(params=model.parameters(), lr=5e-5)


for epoch in range(start_epoch, 3):
    print("Epoch:", epoch+1)
    mean_loss = 0
    mean_smpl_loss = 0 
    model.train()
    for i, batch in enumerate(tqdm(train_dataloader)):
        
        batch = {k: v.to(device) for k, v in batch.items()}
        
        pixel_values = batch["pixel_values"]
        labels = batch["labels"]
        
        
        outputs = model(pixel_values=pixel_values, labels=labels)
        
        loss = outputs.loss
        loss.backward()
        
        optimizer.step()
        optimizer.zero_grad()
        
        mean_loss += loss.item()   
        mean_smpl_loss += loss.item() 
        if i % avg_size == 0:
            print(str(i) + " loss: ", mean_smpl_loss/avg_size)
            write_msg("batch " + str(i) +" loss: "+ str(mean_smpl_loss/avg_size))
            mean_smpl_loss = 0 
        
        if i % 10000 == 0:
            model.save_pretrained("../../pubtabnet/HTML-lre-5-checkpoint")
    
    model.save_pretrained("../../pubtabnet/HTML-lre-5-checkpoint")
    print("Epoch's mean loss: ", mean_loss/len(train_dataloader))
    
    write_msg("Epoch checkpointed: " + str(epoch+1) +" \n"+
              "Epoch's mean Loss: " + str(mean_loss/len(train_dataloader)))

Epoch: 1


  0%|                                                | 0/500777 [00:00<?, ?it/s]

0 loss:  0.01389599323272705


  0%|                                    | 27/500777 [00:14<75:06:42,  1.85it/s]


KeyboardInterrupt: 

In [None]:
len(train_dataset)

In [None]:
model.save_pretrained("../../pubtabnet/modelos/HTML-lre-5-epoch3")

In [None]:
processor.save_pretrained("../../pubtabnet/Donut_PubTables_HTML_Processor")