In [1]:
!pip install python-doctr
!pip install "python-doctr[torch]"

Collecting python-doctr
  Downloading python_doctr-0.9.0-py3-none-any.whl.metadata (33 kB)
Collecting pypdfium2<5.0.0,>=4.11.0 (from python-doctr)
  Downloading pypdfium2-4.30.0-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (48 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m48.5/48.5 kB[0m [31m2.3 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting pyclipper<2.0.0,>=1.2.0 (from python-doctr)
  Downloading pyclipper-1.3.0.post5-cp310-cp310-manylinux_2_12_x86_64.manylinux2010_x86_64.whl.metadata (9.0 kB)
Collecting langdetect<2.0.0,>=1.0.9 (from python-doctr)
  Downloading langdetect-1.0.9.tar.gz (981 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m981.5/981.5 kB[0m [31m44.9 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting rapidfuzz<4.0.0,>=3.0.0 (from python-doctr)
  Downloading rapidfuzz-3.9.7-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collectin

In [None]:
import torch
from torch.utils.data import Dataset
from PIL import Image
import json

class CustomDataset(Dataset):
    """Image dataset backed by a JSON annotation file.

    The annotation file is parsed once, at construction time. The parsed
    object must be indexable by integer position, and every entry must provide
    the keys ``'image'`` (file name, joined to ``image_dir`` with ``/``),
    ``'boxes'`` and ``'labels'``. The ``'boxes'`` and ``'labels'`` values are
    handed back exactly as they appear in the file (no tensor conversion).

    Args:
        image_dir: Directory that contains the image files.
        annotation_file: Path to the JSON annotation file.
        transform: Optional callable applied to the RGB ``PIL.Image`` before
            it is returned.
    """

    def __init__(self, image_dir, annotation_file, transform=None):
        self.image_dir = image_dir
        # Use a context manager so the file handle is closed deterministically
        # instead of being leaked until garbage collection. JSON is UTF-8 by
        # specification, so do not depend on the platform's default encoding.
        with open(annotation_file, encoding="utf-8") as annotation_fh:
            self.annotations = json.load(annotation_fh)
        self.transform = transform

    def __len__(self):
        """Return the number of annotation entries."""
        return len(self.annotations)

    def __getitem__(self, idx):
        """Load one sample.

        Args:
            idx: Position of the entry in the annotation list.

        Returns:
            A tuple ``(image, target)`` where ``image`` is the RGB image
            (after ``transform``, if one was given) and ``target`` is a dict
            with the entry's ``"boxes"`` and ``"labels"``.
        """
        entry = self.annotations[idx]
        img_path = f"{self.image_dir}/{entry['image']}"
        # `convert` returns an independent, fully loaded copy, so the source
        # file can be closed as soon as the conversion is done.
        with Image.open(img_path) as source_image:
            image = source_image.convert("RGB")
        boxes = entry['boxes']
        labels = entry['labels']

        if self.transform:
            image = self.transform(image)

        return image, {"boxes": boxes, "labels": labels}

In [None]:
from doctr.models import ocr_predictor

# Build the OCR predictor from a DB-ResNet50 text detector and a CRNN-VGG16-BN
# text recognizer, both initialised from the published pretrained checkpoints.
model = ocr_predictor(
    det_arch='db_resnet50',
    reco_arch='crnn_vgg16_bn',
    pretrained=True,
)

In [None]:
import torch
from torch.utils.data import DataLoader
from torchvision import transforms

# Define transformations
# NOTE(review): only the image is resized here; if the annotation boxes are in
# absolute pixel coordinates they will no longer line up with the resized
# image - confirm the box convention used in the annotation file.
transform = transforms.Compose([
    transforms.Resize((1024, 1024)),
    transforms.ToTensor(),
])

# Load dataset
train_dataset = CustomDataset(image_dir='path/to/images', annotation_file='path/to/annotations.json', transform=transform)
# NOTE(review): the dataset returns `boxes`/`labels` as plain Python lists, so
# the default collate function will presumably fail (or silently transpose
# them) when images carry a different number of boxes - a custom `collate_fn`
# is likely needed; verify with a real batch.
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)

# Define optimizer and loss function
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# NOTE(review): `model` is the object returned by `ocr_predictor`, i.e. an
# end-to-end inference pipeline, and `targets` is a dict. CrossEntropyLoss
# expects a logits tensor and a class-index/probability tensor, so this call
# is not expected to work as written. docTR's detection/recognition models
# compute their own loss when called with targets (see the docTR training
# reference scripts) - TODO confirm and train those sub-models instead.
criterion = torch.nn.CrossEntropyLoss()

# Training loop
num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    # Accumulate the loss over the whole epoch: reporting only the last
    # batch's loss is noisy, and `loss` would be undefined if the loader
    # yielded no batch at all.
    running_loss = 0.0
    num_batches = 0
    for images, targets in train_loader:
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        num_batches += 1
    # NaN makes an empty epoch visible instead of printing a misleading 0.
    mean_loss = running_loss / num_batches if num_batches else float("nan")
    print(f"Epoch {epoch+1}, Loss: {mean_loss}")

# **Load custom weights**

In [None]:
import torch
from doctr.models import ocr_predictor, db_resnet50, crnn_vgg16_bn

# Paths to the fine-tuned checkpoints (placeholders - replace before running).
DET_WEIGHTS_PATH = '<path_to_pt>'
RECO_WEIGHTS_PATH = '<path_to_pt>'

# The three sections below are alternative ways of building a predictor. Each
# one rebinds `predictor`, so after running the whole cell only the last
# assignment (custom detection + custom recognition) is in effect.
#
# `weights_only=True` restricts unpickling to tensors and plain containers.
# The files are used purely as state dicts, so nothing else is needed, and it
# prevents arbitrary code execution from an untrusted checkpoint.

# Load custom detection model
det_model = db_resnet50(pretrained=False, pretrained_backbone=False)
det_params = torch.load(DET_WEIGHTS_PATH, map_location="cpu", weights_only=True)
det_model.load_state_dict(det_params)
predictor = ocr_predictor(det_arch=det_model, reco_arch="vitstr_small", pretrained=True)

# Load custom recognition model
reco_model = crnn_vgg16_bn(pretrained=False, pretrained_backbone=False)
reco_params = torch.load(RECO_WEIGHTS_PATH, map_location="cpu", weights_only=True)
reco_model.load_state_dict(reco_params)
predictor = ocr_predictor(det_arch="linknet_resnet18", reco_arch=reco_model, pretrained=True)

# Load custom detection and recognition model
det_model = db_resnet50(pretrained=False, pretrained_backbone=False)
det_params = torch.load(DET_WEIGHTS_PATH, map_location="cpu", weights_only=True)
det_model.load_state_dict(det_params)
reco_model = crnn_vgg16_bn(pretrained=False, pretrained_backbone=False)
reco_params = torch.load(RECO_WEIGHTS_PATH, map_location="cpu", weights_only=True)
reco_model.load_state_dict(reco_params)
# Both architectures are already-loaded module instances here, so no
# pretrained checkpoint is requested from the predictor factory.
predictor = ocr_predictor(det_arch=det_model, reco_arch=reco_model, pretrained=False)