# Document Layout Clustering - embedding generation

Here we generate embeddings for a sample of pages from each document using a general-purpose, pretrained VGG16 model.

In [1]:
import glob
import os
import io
import time
import typing as t

import fitz
from PIL import Image
import torch
import numpy as np
from torchvision import transforms
from tqdm.auto import tqdm
import pandas as pd
from dotenv import load_dotenv

# Load variables from a local .env file (python-dotenv) into the process environment.
load_dotenv()

# VGG16 from the torchvision hub repo pinned at v0.10.0, with pretrained weights.
model = torch.hub.load('pytorch/vision:v0.10.0', 'vgg16', pretrained=True)
# Directory holding the local PDF corpus; `os.getenv` yields None when the variable is unset.
CORPUS_PATH = os.getenv("CORPUS_PATH")

# Show the resolved path as the cell output.
CORPUS_PATH

Using cache found in /Users/kalyan/.cache/torch/hub/pytorch_vision_v0.10.0


'/Users/kalyan/Documents/CPR/policy-search/data/corpus/content'

In [9]:
# Collect the corpus PDFs (both lower- and upper-case extensions).
# `glob.glob` returns matches in arbitrary order, so sort the result: later
# cells slice this list, and the slice should select the same documents on
# every run and every machine.
pdf_files = sorted(
    glob.glob(f"{CORPUS_PATH}/*.pdf") + glob.glob(f"{CORPUS_PATH}/*.PDF")
)
len(pdf_files)

2824

In [4]:
# S3
# Name of the bucket whose objects are listed further down in this cell.
S3_BUCKET="cpr-cclw-cpd-docs-temp"
# NOTE(review): this import sits outside the notebook's first (imports) cell —
# consider moving it there so dependencies are visible in one place.
import boto3

def s3_glob(bucket_name, suffix = None):
    """
    List the objects in an S3 bucket, optionally filtered by key suffix.

    Pages through `list_objects_v2` using continuation tokens until the
    listing is no longer truncated.

    Args:
        bucket_name: name of the bucket to list.
        suffix: if truthy, only keys for which `key.endswith(suffix)` holds are
            kept; otherwise every key is kept.

    Returns:
        dict mapping each retained object key to its `LastModified` value.
        Empty if the listing contains no objects.
    """
    s3_client = boto3.client('s3')
    shortlisted_files = dict()
    list_kwargs = {'Bucket': bucket_name}

    while True:
        s3_result = s3_client.list_objects_v2(**list_kwargs)

        # 'Contents' is missing from the response when no objects are returned.
        for obj in s3_result.get('Contents', []):
            key = obj['Key']
            if not suffix or key.endswith(suffix):
                shortlisted_files[key] = obj['LastModified']

        if not s3_result.get('IsTruncated'):
            break
        # Ask for the next page of the same listing.
        list_kwargs['ContinuationToken'] = s3_result['NextContinuationToken']

    return shortlisted_files


# List every object in the bucket (the ".pdf" suffix filter is currently disabled).
# NOTE(review): this rebinds `pdf_files` to the dict returned by `s3_glob`,
# whereas the local glob produces a list of paths — confirm that the cells
# consuming `pdf_files` can handle both.
pdf_files = s3_glob(S3_BUCKET)#, ".pdf")

# Local
# pdf_files = glob.glob(f"{CORPUS_PATH}/*.pdf") + glob.glob(f"{CORPUS_PATH}/*.PDF")
len(pdf_files)


ClientError: An error occurred (AccessDenied) when calling the ListObjectsV2 operation: Access Denied

## 1. Define pages to process for each document

We sample a fixed number of pages from each document so that we don't have to process every page.

In [67]:
def page_locs_from_doc(doc: "fitz.Document", n: int) -> t.List[int]:
    """
    Get indices for `n` evenly spaced pages from the doc, excluding the last page.

    Returns all the pages excluding the last if that leaves `n` or fewer pages.
    A single-page document is the exception: its only page is returned, so the
    document still contributes an embedding.

    Args:
        doc: the document; only `len(doc)` (its page count) is used.
        n: maximum number of page indices to return.

    Returns:
        Ascending list of zero-based page indices (empty for an empty document).
    """
    n_pages = len(doc)

    # Nothing to exclude from an empty or single-page document.
    if n_pages <= 1:
        return list(range(n_pages))

    # Candidate pages are 0 .. n_pages - 2, i.e. everything but the last page.
    n_candidates = n_pages - 1

    if n < n_candidates:
        # Evenly spaced positions over the candidates, truncated to int indices.
        return [int(i) for i in np.linspace(0, n_candidates - 1, n).tolist()]

    return list(range(n_candidates))

# Upper bound on the number of pages sampled from each document.
MAX_PAGES_PER_DOC = 6    

## 2. Run processing

In [68]:
def page_to_image(page: 'Page') -> Image:
    """
    Render a PDF page and return it as a PIL image.

    The page is rasterised with `get_pixmap()`, encoded to JPEG bytes, and
    decoded again from an in-memory buffer.
    """
    jpeg_buffer = io.BytesIO(page.get_pixmap().pil_tobytes(format="JPEG"))
    return Image.open(jpeg_buffer)

def image_to_emb_vector(img: "Image.Image", model: torch.nn.Module) -> torch.Tensor:
    """
    Embed a single image with the feature-extractor part of `model`.

    The image is resized, centre-cropped to 224x224, converted to a tensor and
    normalised, then passed through `model.features` with gradients disabled.

    Args:
        img: the PIL image to embed. Non-RGB images are converted to RGB first.
        model: a model exposing a `features` sub-module (e.g. torchvision VGG16).
            It is moved to the GPU in place when CUDA is available.

    Returns:
        1-D tensor: the flattened `model.features` output for this image.
    """
    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # Normalize is configured with three channel statistics, so the input must be RGB.
    if img.mode != "RGB":
        img = img.convert("RGB")

    input_tensor = preprocess(img)
    input_batch = input_tensor.unsqueeze(0) # create a mini-batch as expected by the model

    # move the input and model to GPU for speed if available
    if torch.cuda.is_available():
        input_batch = input_batch.to('cuda')
        model.to('cuda')

    with torch.no_grad():
        # Only the `features` sub-module is run, not the full `model(...)` forward pass.
        fl_embedding = model.features(input_batch)

    # Drop the batch dimension and flatten to a vector.
    return fl_embedding[0].reshape(-1)


In [74]:
# One record per processed page: source file, page index and embedding values.
embedding_store = []

# NOTE(review): only the first 9 files are processed — presumably a limit used
# while developing; confirm and remove the slice for a full run.
for f_name in tqdm(pdf_files[0:9]):
    doc = fitz.open(f_name)
    try:
        page_idxs_to_process = page_locs_from_doc(doc, MAX_PAGES_PER_DOC)

        for idx in page_idxs_to_process:
            input_image = page_to_image(doc.load_page(idx))
            emb = image_to_emb_vector(input_image, model)

            embedding_store.append(
                {
                    "filename": f_name,
                    "page_num": idx,
                    "embedding": emb.tolist()
                }
            )
    finally:
        # Release the document even if a page fails to render or embed.
        doc.close()

  0%|          | 0/9 [00:00<?, ?it/s]

### Save embeddings

In [83]:
# Build the frame once, then write it out in both formats.
embeddings_df = pd.DataFrame(embedding_store)
embeddings_df.to_pickle("vgg16_embeddings.pkl")
embeddings_df.to_csv("vgg16_embeddings.tsv", sep="\t")