In [315]:
import pandas as pd
import html5lib
from bs4 import BeautifulSoup
import json
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.optim as optim
import os
from transformers import BertTokenizer, BertModel
import warnings
import json
import numpy as np
import re
warnings.filterwarnings('ignore')

In [317]:
# Input directories (relative paths, so they resolve against the kernel's
# working directory): HTML tables and their JSON counterparts.
tables_path = "Downloads/generated_tables/tables/"
json_path = "Downloads/generated_tables/metadata/"

## Solution Overview
The task: Given the structured nature of HTML tables and the hierarchical nature of JSON, this task involves understanding and transforming structured data from one format to another.
Upon quick inspection I see three possible approaches to the problem, each with pros and cons. 
1. Rule based - the HTML structure appears identical for all tables, and is not overly complex. A simple Python script to automate the task of extracting elements from HTML and converting them to JSON would likely do the job. However, this approach would not scale if the data source was changed from HTML to, say, a table embedded in a PDF, or if the table structure was changed dramatically.
2. LLM approach - let the LLM do the heavy lifting by providing it with a few examples (few-shot learning) of HTML-JSON pairs, and then give it the HTML to produce JSON. This could be followed by a validation step to ensure the JSON format is valid.
3. A third approach is to use a transformer-based encoder-decoder sequence-to-sequence model to embed the HTML text from the table and reconstruct it as JSON. We can leverage the JSON and HTML pairs to train the model. The encoder is a BERT embedding + an LSTM layer that converts the HTML table to a context vector. The decoder converts the context vector via an LSTM and a fully connected layer to the tokenized JSON. We can then calculate the loss by comparing the generated tokens to the actual tokens representing the JSON dictionary. While BERT and LSTM layers are a robust solution for handling the structured and variable-length nature of the data involved, this is a bit overkill since we are not dealing with semantic table interpretation, where we would need a transformer-based model to perform cell entity or column type annotation. This is also the riskiest approach in terms of being able to train a model that can do the task accurately by minimizing a reconstruction error, since certain elements of the JSON structure might be more important than others. Thus, weighing all generated tokens equally in the loss calculation is risky. A better choice of loss might be the Jaccard similarity coefficient score. Further, we are relying on the model to generate not only the JSON data values but the keys and the schema as well.

Approach #3 could be improved upon by encoding four components of a table to form its embedding sequence: row headers, column headers, metadata, and context, as opposed to a single fixed length embedding. Serializing vectors of four components into a sequence of vectors as the embedding of the table could likely preserve the local information in the table. We could use these four vectors to customize the loss function by employing multifield evaluation, where each field is evaluated separately against its JSON component and then the weighted sum of the scores is taken. 

### 1. Rule Based

In [None]:
# Sample table used to try out the rule-based extraction below.
file_name = os.path.join(tables_path, '1_table.html')

In [None]:
def get_table_id(html_file):
    """Return the text between the first <caption>...</caption> pair in a file.

    Args:
        html_file: Path of the HTML file to read.

    Returns:
        The caption's inner text as a string, or None when the document
        contains no single-line ``<caption>`` element.
    """
    # Read the file that was passed in; the previous version ignored the
    # argument and always opened the notebook-level `file_name`.
    with open(html_file, 'r') as f:
        soup = BeautifulSoup(f, 'html.parser')

    # Serialise the whole document instead of only `soup.contents[0]` so an
    # empty file or a leading doctype/whitespace node does not break the lookup.
    text_elements = str(soup)

    start_delim = "<caption>"
    end_delim = "</caption>"

    # Non-greedy so only the first caption is captured.
    pattern = re.escape(start_delim) + r"(.*?)" + re.escape(end_delim)
    match = re.search(pattern, text_elements)
    return match.group(1) if match else None

In [None]:
def get_data(html_file):
    """Convert one HTML table file into its JSON string representation.

    The first column of the table is used as the row-header index. The last
    row is treated as the footer; its index cell holds the creation text,
    whose first whitespace-separated token is taken as the creation date.

    Args:
        html_file: Path of the HTML file holding the table.

    Returns:
        A JSON string with ``body``, ``footer`` and ``header`` sections.

    Raises:
        ValueError: If the file has no ``<caption>`` to derive the table id from.
    """
    table_df = pd.read_html(html_file)[0]
    indexed = table_df.set_index(table_df.columns[0])

    # `.tolist()` yields native Python scalars; numpy integer scalars are not
    # JSON serialisable and would make `json.dumps` raise TypeError.
    cols = table_df.columns[1:].tolist()
    rows = indexed.index[:-1].tolist()
    creation_text = indexed.index[-1]
    creation_date = creation_text.split()[0]
    # Flatten the body cells row by row (footer row excluded).
    content = indexed[:-1].values.ravel().tolist()

    table_id = get_table_id(html_file)
    if table_id is None:
        raise ValueError(f"No <caption> found in {html_file}")

    table_json = {
        "body": {
            "content": content,  # was misspelled `concent`, which raised NameError
            "headers": {
                "col": cols,
                "row": rows
            }
        },
        "footer": {
            # NOTE(review): this key ends with a colon - confirm it matches the target schema.
            "table_creation_date:": creation_date,
            "text": creation_text
        },
        "header": {
            # Second whitespace-separated token of the caption.
            "table_id": table_id.split()[1],
            "text": table_id
        }
    }

    return json.dumps(table_json)

### 3. BERT Encoder-Decoder

In [353]:
# Shared WordPiece tokenizer for the pretrained 'bert-base-uncased' checkpoint.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

In [374]:
class HTMLTextExtractionDataset(Dataset):
    """Pairs of (table text, target JSON), tokenised for BERT.

    Each item is a tuple of two 1-D LongTensors of length ``max_len``: the
    token ids of the table's text and the token ids of the serialised JSON.
    """

    def __init__(self, html_path, json_path, max_len=128):
        """Index both directories and eagerly load every pair into memory.

        Args:
            html_path: Directory containing the HTML table files.
            json_path: Directory containing the JSON files.
            max_len: Length every token sequence is padded/truncated to.
        """
        # Keep the directories on the instance; `load_data` previously read
        # notebook globals of the same name instead of these arguments.
        self.html_path = html_path
        self.json_path = json_path
        # os.listdir returns entries in arbitrary order, so sort both listings
        # to make the pairing deterministic.
        # NOTE(review): assumes matching files sort into the same position in
        # both directories (e.g. shared numeric prefix) - confirm.
        self.html_files = sorted(os.listdir(html_path))
        self.json_files = sorted(os.listdir(json_path))
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        self.max_len = max_len
        self.data = self.load_data()

    def load_data(self):
        """Read every (html, json) pair and return a list of (text, dict) tuples."""
        data = []
        # zip stops at the shorter listing; unmatched trailing files are ignored.
        for html_file, json_file in zip(self.html_files, self.json_files):
            with open(os.path.join(self.html_path, html_file), 'r') as f:
                soup = BeautifulSoup(f, 'html.parser')
                # Text content of the first <table>, markup stripped.
                text_elements = soup.table.text

            with open(os.path.join(self.json_path, json_file), 'r') as f:
                json_data = json.load(f)

            data.append((text_elements, json_data))
        return data

    def __len__(self):
        """Number of loaded (table, json) pairs."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return (table_token_ids, json_token_ids) tensors for item ``idx``."""
        html_table, json_dict = self.data[idx]
        tokens = self.tokenizer.encode(html_table, padding='max_length', max_length=self.max_len, truncation=True)
        json_tokens = self.tokenizer.encode(json.dumps(json_dict), padding='max_length', max_length=self.max_len, truncation=True, add_special_tokens=True)
        return torch.tensor(tokens), torch.tensor(json_tokens)

In [375]:
# `html_path` was never defined in this notebook; the tables directory is `tables_path`.
dataset = HTMLTextExtractionDataset(tables_path, json_path)

In [376]:
# Batches of 64 (table_ids, json_ids) pairs, reshuffled every epoch.
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)

In [322]:
class TableToJSONModel(nn.Module):
    """BERT embeddings followed by an encoder LSTM, a decoder LSTM and a linear head.

    Args:
        hidden_size: Width of both LSTM layers.
        output_size: Number of features produced per token position.
    """

    def __init__(self, hidden_size, output_size):
        super(TableToJSONModel, self).__init__()
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        # 768 is the hidden width of bert-base.
        self.encoder = nn.LSTM(768, hidden_size, batch_first=True)
        # The decoder keeps the hidden width so that `fc` (hidden_size -> output_size)
        # receives the input width it was declared with. Previously the decoder
        # emitted `output_size` features, which only matched `fc` when
        # hidden_size == output_size.
        self.decoder = nn.LSTM(hidden_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x, attention_mask=None):
        """Map a batch of token ids to per-position feature vectors.

        Args:
            x: LongTensor of token ids, shape (batch, seq_len).
            attention_mask: Optional mask forwarded to BERT.

        Returns:
            Tensor of shape (batch, seq_len, output_size).
        """
        bert_output = self.bert(x, attention_mask=attention_mask)
        # The final (hidden, cell) states of both LSTMs are not used.
        encoder_output, _ = self.encoder(bert_output.last_hidden_state)
        decoder_output, _ = self.decoder(encoder_output)
        return self.fc(decoder_output)

In [403]:
def criteria(logits, target):
    """Token-level cross-entropy between predicted logits and target ids.

    Args:
        logits: Float tensor of shape (..., num_classes).
        target: Long tensor of class ids whose shape matches ``logits``
            without the last dimension.

    Returns:
        Scalar tensor: mean cross-entropy over every position. Padding
        positions are included in the mean.
    """
    # Take the class count from the logits themselves rather than from a
    # notebook-level `vocab_size` that is only defined in a later cell.
    num_classes = logits.size(-1)
    # reshape (not view) so non-contiguous inputs are accepted too.
    return F.cross_entropy(logits.reshape(-1, num_classes), target.reshape(-1))

In [404]:
# Number of entries in the tokenizer vocabulary (the number of output classes).
vocab_size = len(tokenizer)
# Feature width passed to the model as both hidden_size and output_size.
hidden_dim = 256
# Standalone projection from model features to vocabulary logits; it is
# applied to the model output in the training and evaluation loops.
linear_layer = nn.Linear(hidden_dim, vocab_size)

## Train

In [405]:
# The model emits hidden_dim features per token; linear_layer maps them to vocabulary logits.
model = TableToJSONModel(hidden_size=hidden_dim, output_size=hidden_dim)

In [406]:
# `linear_layer` lives outside the model, so its weights must be handed to the
# optimizer explicitly; otherwise the vocabulary projection stays at its
# random initialisation for the whole run.
trainable_params = list(model.parameters()) + list(linear_layer.parameters())
for param in trainable_params:
    param.requires_grad = True

optimizer = optim.Adam(trainable_params, lr=0.001)
n_epochs = 5

In [None]:
for epoch in range(n_epochs):
    total_loss = 0.0
    model.train()
    for tables, json_dicts in dataloader:
        # Mask out padding positions for BERT.
        attention_mask = (tables != tokenizer.pad_token_id).long()
        optimizer.zero_grad()
        outputs = model(tables, attention_mask=attention_mask)
        logits = linear_layer(outputs)
        loss = criteria(logits, json_dicts)
        # No `loss.requires_grad_(True)`: the loss already tracks gradients, and
        # forcing the flag would only hide a detached graph.
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    # Report the epoch mean rather than whichever batch happened to come last.
    average_loss = total_loss / max(len(dataloader), 1)
    print(f'Epoch {epoch+1}/{n_epochs}, Loss: {average_loss}')

## Evaluate

In [None]:
def evaluate_model(model, dataloader):
    """Print the mean per-batch loss of ``model`` over ``dataloader``.

    The model is switched to eval mode and gradients are disabled while the
    batches are scored. Nothing is returned.
    """
    model.eval()
    running_loss = 0

    with torch.no_grad():
        for batch_tables, batch_targets in dataloader:
            # Positions holding the pad token are masked out for BERT.
            pad_mask = (batch_tables != tokenizer.pad_token_id).long()
            features = model(batch_tables, attention_mask=pad_mask)
            batch_loss = criteria(linear_layer(features), batch_targets)
            running_loss += batch_loss.item()

    average_loss = running_loss / len(dataloader)
    print(f'Average Loss: {average_loss}')

In [None]:
# NOTE(review): this scores the same dataloader that was used for training, not a held-out split.
evaluate_model(model, dataloader)

## Inference

In [368]:
def convert_embedding_to_string(outputs):
    """Decode each element of ``outputs`` into a string.

    For every element, the argmax over the last dimension is taken as the
    token ids and decoded with the notebook-level ``tokenizer``.

    Returns:
        A list with one decoded string per element of ``outputs``.
    """
    return [tokenizer.decode(item.argmax(dim=-1)) for item in outputs]

In [None]:
def predict_and_decode(model, html_text, tokenizer, max_len=128):
    """Run the model on one table's text and decode the prediction.

    Args:
        model: Trained model producing per-token features.
        html_text: Text content of the table.
        tokenizer: Tokenizer used to encode the input and decode the output.
        max_len: Length the input is padded/truncated to.

    Returns:
        The decoded prediction parsed as JSON when it is valid JSON,
        otherwise the raw decoded string.

    Note:
        Puts ``model`` into eval mode and uses the notebook-level
        ``linear_layer`` to project features to vocabulary logits, the same
        way the training and evaluation loops do.
    """
    token_ids = tokenizer.encode(html_text, padding='max_length', max_length=max_len, truncation=True)
    # BERT expects a (batch, seq_len) tensor, so add a batch dimension of 1.
    inpt = torch.tensor(token_ids).unsqueeze(0)
    attention_mask = (inpt != tokenizer.pad_token_id).long()

    model.eval()
    with torch.no_grad():
        outputs = model(inpt, attention_mask=attention_mask)
        # The model emits hidden features; project them to vocabulary logits
        # before taking the argmax.
        logits = linear_layer(outputs)

    predicted_ids = logits.argmax(dim=-1).squeeze(0).tolist()
    output_string = tokenizer.decode(predicted_ids, skip_special_tokens=True)

    # The uncased WordPiece vocabulary cannot always round-trip JSON exactly,
    # so fall back to the raw string when parsing fails.
    try:
        return json.loads(output_string)
    except json.JSONDecodeError:
        return output_string

In [None]:
# TODO: replace with the HTML file to test. A bare `name = # comment` is a
# SyntaxError, so default to the sample table to keep the notebook runnable.
html_inference_file = os.path.join(tables_path, '1_table.html')

In [None]:
# Parse the chosen file; BeautifulSoup reads the whole handle up front, so
# the remaining steps do not need the file to stay open.
with open(html_inference_file, 'r') as f:
    soup = BeautifulSoup(f, 'html.parser')

# Text content of the first <table>, then run it through the model.
text_elements = soup.table.text
prediction = predict_and_decode(model, text_elements, tokenizer)
print(prediction)