In [None]:
import torch
from transformers import TableTransformerForObjectDetection, TableTransformerConfig
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset, Dataset
from transformers import DetrImageProcessor
from PIL import Image
import os
import json
from datetime import datetime, date

In [None]:
# --- Run configuration --------------------------------------------------------
# Take one timestamp so the date folder (`dat`) and the file-name tag (`dt`)
# can never disagree, e.g. when the cell happens to run across midnight.
run_started_at = datetime.now()
dat = run_started_at.date().isoformat()            # "YYYY-MM-DD", same text as str(date.today())
# NOTE(review): ':' in file names is not portable (Windows, scp/rsync host syntax);
# kept because existing checkpoint names already use this format.
dt = run_started_at.strftime("%Y-%m-%d_%H:%M:%S")  # tag embedded in checkpoint / metrics file names

# Seed torch's global RNG so DataLoader shuffling (and other torch randomness)
# is reproducible under Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

finetune_count = 1  # run index; becomes the last folder of finetuned_model_path

data_folder_path = "/home/jovyan/work/Sagar/p2p_august_5_to_11_data"
current_working_directory_path = "/home/jovyan/work/Sagar/table-transformer-playground"
other_supporting_data_path = os.path.join(data_folder_path, "other_supporting_data")
other_supporting_data_table_tansformer_path = os.path.join(other_supporting_data_path, "tbl-trsmf-data")

# Checkpoints and loss metrics for this run go to finetuned_models/<date>/<finetune_count>.
finetuned_model_path = os.path.join(current_working_directory_path, "finetuned_models", dat, str(finetune_count))
os.makedirs(finetuned_model_path, exist_ok=True)

dates = ['2024-08-05', '2024-08-06', '2024-08-07', '2024-08-08', '2024-08-09', '2024-08-10', '2024-08-11']
data_date_range = "5_to_11_august"

# NOTE(review): transformers/huggingface_hub may read this variable only at import
# time, and transformers is imported in the previous cell, so this can be a no-op.
# local_files_only=True on the loaders below is what keeps loading offline.
os.environ['TRANSFORMERS_OFFLINE'] = "1"

# Presumably a local copy of the "microsoft/table-transformer-detection" checkpoint.
model_name = "/home/jovyan/work/Sagar/table_transformer_detection_model"

processor = DetrImageProcessor.from_pretrained(model_name, local_files_only=True)
model = TableTransformerForObjectDetection.from_pretrained(
    pretrained_model_name_or_path=model_name,
    use_pretrained_backbone=False,  # don't fetch separate backbone weights; the checkpoint supplies them
    local_files_only=True,
    cache_dir=model_name,
)

dataset_type = "train"  # selects which split file the next cell loads

In [None]:
# Read the list of single-page records for this split. Each record is expected
# to carry "page_file_name", "date" and "tables_bboxes" (the keys
# TableDetectionDataset reads).
single_page_files_tables_path = os.path.join(
    other_supporting_data_table_tansformer_path,
    f"single_page_files_tables_list_{dataset_type}_3_sd.json",
)
with open(single_page_files_tables_path, "r", encoding="utf-8") as f:
    data_list = json.load(f)

# Shallow copy: the working list can be filtered/reordered without touching data_list.
data_list1 = data_list.copy()
len_data_list = len(data_list1)
print(len_data_list)

In [None]:
class TableDetectionDataset(torch.utils.data.Dataset):
    """Page images with table bounding boxes, served as DETR-style training samples.

    Subclasses ``torch.utils.data.Dataset`` explicitly. In this notebook the bare
    name ``Dataset`` is rebound by ``from datasets import load_dataset, Dataset``,
    which would otherwise make this a Hugging Face ``datasets.Dataset`` subclass.

    Each record in ``data`` is read for ``"page_file_name"`` (PNG file stem),
    ``"date"`` (sub-folder of ``data_folder_path``) and ``"tables_bboxes"``.

    ``__getitem__`` returns ``(pixel_values, {"class_labels", "boxes"})``. Boxes
    are normalised ``(center_x, center_y, width, height)``, which is the target
    format the Hugging Face DETR / Table Transformer loss expects.
    """

    def __init__(self, data, processor, dpi=300, table_label=1):
        """
        Args:
            data: indexable sequence of page records (see class docstring).
            processor: image processor called on each page image.
            dpi: rendering resolution used to convert non-sample box
                coordinates to pixels (default 300, the previous hard-coded value).
            table_label: class id assigned to every box.
        """
        self.dataset = data
        self.processor = processor
        self.dpi = dpi
        # NOTE(review): in the public table-transformer-detection config, id 0 is
        # "table" and id 1 is "table rotated"; confirm against this model's id2label.
        self.table_label = table_label

    def __len__(self):
        return len(self.dataset)

    @staticmethod
    def _build_targets(tables_bboxes, width_px, height_px, dpi, is_sample, table_label=1):
        """Convert raw table boxes of one page into DETR targets.

        Args:
            tables_bboxes: list of boxes. For non-sample pages each entry is a
                polygon whose values are multiplied by ``dpi`` to get pixels
                (presumably inches); indices 0-1 and 4-5 are used as the two
                opposite corners. For sample pages each entry is
                ``[x0, y0, x1, y1]`` already in pixels.
            width_px, height_px: page image size in pixels.
            dpi: pixels per coordinate unit for non-sample pages.
            is_sample: selects the sample-page layout described above.
            table_label: class id written for every box.

        Returns:
            ``(class_labels, boxes)``: a LongTensor of shape (N,) and a
            FloatTensor of shape (N, 4) holding normalised (cx, cy, w, h).
            N may be 0, in which case boxes still has shape (0, 4).
        """
        boxes = []
        for bb in tables_bboxes:
            if is_sample:
                x0, y0 = bb[0] / width_px, bb[1] / height_px
                x1, y1 = bb[2] / width_px, bb[3] / height_px
            else:
                x0 = (bb[0] * dpi) / width_px
                y0 = (bb[1] * dpi) / height_px
                x1 = (bb[4] * dpi) / width_px
                y1 = (bb[5] * dpi) / height_px
            # Order the corners so width/height are never negative.
            x_lo, x_hi = min(x0, x1), max(x0, x1)
            y_lo, y_hi = min(y0, y1), max(y0, y1)
            boxes.append([(x_lo + x_hi) / 2, (y_lo + y_hi) / 2, x_hi - x_lo, y_hi - y_lo])

        class_labels = torch.full((len(boxes),), table_label, dtype=torch.long)
        # reshape keeps an empty page at shape (0, 4), so torch.cat in the loss works.
        boxes_tensor = torch.tensor(boxes, dtype=torch.float).reshape(-1, 4)
        return class_labels, boxes_tensor

    def _get_single_item(self, idx):
        """Load, preprocess and build targets for record ``idx``."""
        record = self.dataset[idx]
        fname = record["page_file_name"] + ".png"
        image_path = os.path.join(data_folder_path, record["date"], "image_seperated_pdfs", fname)

        # Context manager closes the underlying file once the RGB copy exists.
        with Image.open(image_path) as img:
            image = img.convert("RGB")
        width_px, height_px = image.size

        class_labels, boxes = self._build_targets(
            record["tables_bboxes"],
            width_px,
            height_px,
            self.dpi,
            is_sample=fname.startswith("sample_"),
            table_label=self.table_label,
        )

        encoding = self.processor(
            images=image,
            do_resize=True,
            size={"height": 800, "width": 800},
            resample=Image.Resampling.BILINEAR,
            do_rescale=True,
            do_normalize=True,
            do_pad=True,
            return_tensors="pt"
        )

        pixel_values = encoding["pixel_values"].squeeze(0)  # drop only the batch dimension

        return pixel_values, {
            "class_labels": class_labels,
            "boxes": boxes
        }

    def __getitem__(self, idx):
        return self._get_single_item(idx)

In [None]:
def prepare_data(data_list, processor):
    """Eagerly materialise every record of ``data_list`` into memory.

    Each record goes through ``TableDetectionDataset`` (image load,
    preprocessing, target construction). Records that raise are reported and
    skipped, so one bad file does not abort the whole preparation.

    Args:
        data_list: page records understood by ``TableDetectionDataset``.
        processor: image processor forwarded to ``TableDetectionDataset``.

    Returns:
        list of ``(pixel_values, target_dict)`` tuples for the records that
        loaded successfully, in input order.
    """
    ds = TableDetectionDataset(data_list, processor)
    main_list = []
    failed_list = []

    for i in range(len(ds)):
        try:
            main_list.append(ds[i])
        except Exception as e:  # best effort: skip unreadable / malformed records
            print(f"Failed on index {i}:", e)
            failed_list.append(i)

    if failed_list:
        print(f"Skipped {len(failed_list)} of {len(ds)} records")
    return main_list

# Build the in-memory training set; records that fail to load are skipped.
dataset = prepare_data(data_list1, processor)

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Mini-batch size for the DataLoader; also embedded in checkpoint/metric file names.
batch_size = 16

def collate_fn(batch):
    """Stack same-sized images into one tensor; keep targets as a per-image list.

    Targets cannot be stacked because each image may contain a different number
    of boxes. Only the "class_labels" and "boxes" entries of each target are kept.
    """
    pixel_batch = torch.stack([sample[0] for sample in batch])

    target_list = []
    for sample in batch:
        target = sample[1]
        target_list.append({"class_labels": target["class_labels"], "boxes": target["boxes"]})

    return pixel_batch, target_list

# Reshuffle every epoch. collate_fn keeps targets as a list because pages
# contain different numbers of tables.
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)

In [None]:
def _save_checkpoint(model, save_dir, run_tag, num_epochs, bs_tag, ep_cnt):
    """Write ``model.state_dict()`` using the notebook's checkpoint naming scheme; return the path."""
    path = os.path.join(save_dir, f'ttft_{run_tag}_total_epoch_{num_epochs}_bs_{bs_tag}_epcnt_{ep_cnt}.pth')
    torch.save(model.state_dict(), path)
    print("model saved on path:", path)
    return path


def train_model(model, dataloader, num_epochs, device, lr=5e-5, save_dir=None,
                run_tag=None, checkpoint_epochs=(3, 5, 8, 10), bs_tag=None):
    """Fine-tune ``model`` with AdamW and save state_dict checkpoints.

    Args:
        model: module called as ``model(pixel_values=..., labels=...)``; its output
            must expose ``.loss``.
        dataloader: iterable of ``(pixel_values, targets)`` batches that supports
            ``len()``; ``targets`` is a list of dicts with "class_labels" and
            "boxes" tensors.
        num_epochs: number of passes over ``dataloader``.
        device: torch device to train on.
        lr: AdamW learning rate (5e-5 was previously hard-coded).
        save_dir: checkpoint folder; defaults to the notebook-level ``finetuned_model_path``.
        run_tag: tag embedded in file names; defaults to the notebook-level ``dt``.
        checkpoint_epochs: 1-based epoch numbers after which an intermediate
            checkpoint is written. The last epoch is always saved.
        bs_tag: batch size written into file names; defaults to
            ``dataloader.batch_size`` (``None`` if the loader has no such attribute).

    Returns:
        list of ``{"epoch": <0-based epoch>, "avg_loss": <float>}``, one per epoch.
        ``avg_loss`` averages only over batches that completed, and is NaN when
        none did.

    Batches that raise are reported and skipped (best effort).
    """
    if save_dir is None:
        save_dir = finetuned_model_path
    if run_tag is None:
        run_tag = dt
    if bs_tag is None:
        bs_tag = getattr(dataloader, "batch_size", None)
    checkpoint_epochs = set(checkpoint_epochs)

    model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

    loss_metrics = []
    ep_cnt = 0
    last_saved_epoch = None

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0
        num_ok_batches = 0

        for batch_idx, (pixel_values, targets) in enumerate(dataloader):
            # Clear gradients first so a batch that failed mid-way (e.g. after
            # backward) cannot leak its gradients into the next optimizer step.
            optimizer.zero_grad()
            try:
                pixel_values = pixel_values.to(device)
                targets = [{
                    "class_labels": t["class_labels"].to(device),
                    "boxes": t["boxes"].to(device)
                } for t in targets]

                outputs = model(pixel_values=pixel_values, labels=targets)
                loss = outputs.loss

                loss.backward()
                optimizer.step()

                batch_loss = loss.item()
                total_loss += batch_loss
                num_ok_batches += 1

                if batch_idx % 10 == 0:
                    print(f"\t Epoch {epoch + 1}, Batch {batch_idx}, Loss: {batch_loss:.4f}")
            except Exception as e:
                print(f"Error in batch {batch_idx}:", e)
                print(len(pixel_values), len(targets))
                continue

        # Average over batches that actually contributed; skipped batches must
        # not drag the mean down.
        avg_loss = total_loss / num_ok_batches if num_ok_batches else float("nan")
        print(f"\n Epoch {epoch+1} completed, Average Loss: {avg_loss:.4f}")
        loss_metrics.append({
            "epoch": epoch,
            "avg_loss": avg_loss
        })

        ep_cnt = epoch + 1
        if ep_cnt in checkpoint_epochs:
            _save_checkpoint(model, save_dir, run_tag, num_epochs, bs_tag, ep_cnt)
            last_saved_epoch = ep_cnt

    # Always leave a checkpoint for the final state, unless that exact file was just written.
    if last_saved_epoch != ep_cnt:
        _save_checkpoint(model, save_dir, run_tag, num_epochs, bs_tag, ep_cnt)

    return loss_metrics

In [None]:
num_epochs = 16

loss_metrics = train_model(model, dataloader, num_epochs=num_epochs, device=device)

# Persist the per-epoch average losses next to the checkpoints.
loss_metrics_path = os.path.join(
    finetuned_model_path,
    f'training_loss_metrics_{dt}_tot_epoch_{num_epochs}_bs_{batch_size}.json',
)
with open(loss_metrics_path, "w", encoding="utf-8") as f:
    json.dump(loss_metrics, f, indent=3)  # indent=3 pretty-prints the list

print("loss metrics file saved")

## Inference

In [None]:

import torch
from transformers import TableTransformerForObjectDetection, TableTransformerConfig
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset, Dataset
from transformers import DetrImageProcessor
from PIL import Image
import os
import json
from datetime import datetime, date


# --- Inference configuration ----------------------------------------------------
data_folder_path = "/home/jovyan/work/Sagar/p2p_august_5_to_11_data"
current_working_directory_path = "/home/jovyan/work/Sagar/table-transformer-playground"
other_supporting_data_path = os.path.join(data_folder_path, "other_supporting_data")
other_supporting_data_table_tansformer_path = os.path.join(other_supporting_data_path, "tbl-trsmf-data")

# Dates whose data was prepared (recorded for reference; not read by the code shown here).
dates = ['2024-08-05', '2024-08-06', '2024-08-07', '2024-08-08', '2024-08-09', '2024-08-10', '2024-08-11']
data_date_range = "5_to_11_august"

# NOTE(review): transformers may read this variable only at import time (it is
# imported at the top of this cell), so this can be a no-op; local_files_only=True
# below is what keeps model loading offline.
os.environ['TRANSFORMERS_OFFLINE'] = "1"

dataset_type = "test"

# Read the evaluation split's page records.
single_page_files_tables_path = os.path.join(
    other_supporting_data_table_tansformer_path,
    f"single_page_files_tables_list_{dataset_type}_1.json",
)
with open(single_page_files_tables_path, "r", encoding="utf-8") as f:
    data_list = json.load(f)

len_data_list = len(data_list)
print(len_data_list)

# First record, kept for manual inspection (None if the split is empty).
ele1 = data_list[0].copy() if data_list else None

model_name = "/home/jovyan/work/Sagar/table_transformer_detection_model"
# Fine-tuned weights; presumably a state_dict written via torch.save(model.state_dict(), ...).
state_dict_path = "/home/jovyan/work/Sagar/table-transformer-playground/ttft_2025-02-06_05:43:12_epoch_10_bs_16.pth"

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Build the architecture from the base checkpoint, then overlay the fine-tuned weights.
model = TableTransformerForObjectDetection.from_pretrained(
    pretrained_model_name_or_path=model_name,
    use_pretrained_backbone=False,
    local_files_only=True,
    cache_dir=model_name,
)

# weights_only=True: a state_dict holds only tensors, so refuse arbitrary pickled objects.
state_dict = torch.load(state_dict_path, map_location=device, weights_only=True)
# strict=False tolerates key mismatches, but a silent mismatch would mean the
# fine-tuned weights were not applied, so report it.
load_result = model.load_state_dict(state_dict, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print(
        f"WARNING: state_dict mismatch - {len(load_result.missing_keys)} missing keys, "
        f"{len(load_result.unexpected_keys)} unexpected keys"
    )

model.to(device)

import torch
from transformers import TableTransformerForObjectDetection, TableTransformerConfig
from PIL import Image, ImageDraw
import torchvision.transforms as T
import numpy as np
import cv2

class TableDetector:
    """Run a Table Transformer detection model on single page images.

    Images are resized to 800x800 and normalised with ImageNet mean/std before
    inference. This is presumably meant to mirror the image processor used for
    training; confirm against that processor's config. Predicted boxes are decoded
    from the model's normalised (center_x, center_y, width, height) output into
    absolute (x0, y0, x1, y1) pixel coordinates of the original image.
    """

    def __init__(self, model_path=None, confidence_threshold=0.7, detection_model=None):
        """
        Args:
            model_path: currently ignored; kept for backward compatibility.
            confidence_threshold: queries whose best class probability is not
                strictly greater than this are dropped by ``detect_tables``.
            detection_model: model to run. Defaults to the notebook-level ``model``
                loaded earlier, which must already hold the desired weights.
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.confidence_threshold = confidence_threshold

        self.model = detection_model if detection_model is not None else model
        self.model.to(self.device)
        self.model.eval()

        self.transform = T.Compose([
            T.Resize((800, 800)),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def _fix_box_coordinates(self, box):
        """
        Ensure box coordinates are in the correct order (x1 <= x2 and y1 <= y2)
        """
        x1, y1, x2, y2 = box
        return [
            min(x1, x2),  # x1
            min(y1, y2),  # y1
            max(x1, x2),  # x2
            max(y1, y2)   # y2
        ]

    def _validate_box(self, box, image_size):
        """
        Clip box coordinates to the image and drop boxes smaller than 1 px.

        Returns the clipped ``[x1, y1, x2, y2]`` or ``None`` if too small.
        """
        width, height = image_size
        x1, y1, x2, y2 = box

        x1 = max(0, min(x1, width))
        y1 = max(0, min(y1, height))
        x2 = max(0, min(x2, width))
        y2 = max(0, min(y2, height))

        # Ensure minimum box size
        if x2 - x1 < 1 or y2 - y1 < 1:
            return None

        return [x1, y1, x2, y2]

    @staticmethod
    def _best_score(score):
        """Collapse a per-class probability vector (or a scalar) to its maximum as a float."""
        return float(np.max(np.asarray(score)))

    def detect_tables(self, image_path):
        """Detect tables in the image at ``image_path``.

        Returns:
            ``(boxes, scores)`` numpy arrays: boxes are absolute ``[x0, y0, x1, y1]``
            pixels of the original image; each score row holds the per-class
            probabilities (excluding "no object") for that box.
        """
        with Image.open(image_path) as img:
            image = img.convert('RGB')
        original_size = image.size  # (width, height)

        img_tensor = self.transform(image).unsqueeze(0).to(self.device)

        with torch.no_grad():
            outputs = self.model(img_tensor)

        # Drop the trailing "no object" column; remaining columns are class probabilities.
        probas = outputs.logits.softmax(-1)[0, :, :-1]
        keep = probas.max(-1).values > self.confidence_threshold

        boxes = self._rescale_boxes(outputs.pred_boxes[0, keep].cpu(), original_size)
        scores = probas[keep].cpu().numpy()

        valid_boxes = []
        valid_scores = []
        for box, score in zip(boxes.tolist(), scores):
            box = self._validate_box(self._fix_box_coordinates(box), original_size)
            if box is not None:
                valid_boxes.append(box)
                valid_scores.append(score)

        return np.array(valid_boxes), np.array(valid_scores)

    def _rescale_boxes(self, boxes, original_size):
        """Convert normalised (cx, cy, w, h) boxes to absolute (x0, y0, x1, y1) pixels.

        ``pred_boxes`` of Hugging Face DETR-family models use the centre format.
        """
        width, height = original_size
        cx, cy, w, h = boxes.unbind(-1)
        corners = torch.stack([cx - 0.5 * w, cy - 0.5 * h, cx + 0.5 * w, cy + 0.5 * h], dim=-1)
        return corners * torch.tensor([width, height, width, height], dtype=torch.float32)

    def visualize_detections(self, image_path, output_path, boxes, scores, min_score=0.40):
        """Draw boxes whose best class probability exceeds ``min_score`` and save the image.

        Returns:
            ``output_path`` (also when nothing was detected).
        """
        with Image.open(image_path) as img:
            image = img.convert('RGB')
        draw = ImageDraw.Draw(image)

        if len(boxes) == 0:
            print("No tables detected!")
            image.save(output_path)
            return output_path

        for box, score in zip(boxes, scores):
            # Scores are per-class vectors; comparing the vector directly is ambiguous.
            best = self._best_score(score)
            if best <= min_score:
                continue
            box = np.asarray(box).astype(np.int32)
            try:
                draw.rectangle(
                    [(box[0], box[1]), (box[2], box[3])],
                    outline='red',
                    width=4
                )
                text_position = (int(box[0]), max(0, int(box[1] - 20)))
                draw.text(text_position, f'Table: {best:.2f}', fill='red')
            except Exception as e:
                print(f"Warning: Failed to draw box {box}: {str(e)}")
                continue

        image.save(output_path)
        return output_path

tbls = []  # accumulates {"bbox": ..., "confidence": ...} for each successful make_predictions() call


def make_predictions(ele1=None, image_path='./a4_table_with_text.png',
                     output_path='./a4_table_with_text_1.png'):
    """Detect tables on one image, save an annotated copy and record the result.

    Args:
        ele1: unused; kept for backward compatibility with earlier call sites.
        image_path: page image to run detection on.
        output_path: where the annotated copy of the image is written.

    Side effects:
        Appends ``{"bbox": boxes, "confidence": scores}`` to the module-level
        ``tbls`` list on success. Errors are printed, not raised.
    """
    # A threshold of 0.0 keeps essentially every query; visualize_detections
    # applies its own score cut-off when drawing.
    detector = TableDetector(
        model_path=None,
        confidence_threshold=0.0
    )

    try:
        boxes, scores = detector.detect_tables(image_path)
        print("Detected Tables:")
        for i, (box, score) in enumerate(zip(boxes, scores), start=1):
            print(f"Table {i}:")
            print(f"  Coordinates: {box}")
            print(f"  Confidence: {score.max():.2f}")

        saved_path = detector.visualize_detections(image_path, output_path, boxes, scores)
        print(f"Visualization saved to {saved_path}")
        tbls.append({
            "bbox": boxes,
            "confidence": scores
        })
    except Exception as e:
        print(f"Error during detection: {str(e)}")

if __name__ == "__main__":
    # Runs when this file / notebook executes as the top-level module.
    make_predictions(ele1=None)