diff --git a/src/main/java/com/mindee/CommandLineInterface.java b/src/main/java/com/mindee/CommandLineInterface.java index 9430a1b56..2a14051f2 100644 --- a/src/main/java/com/mindee/CommandLineInterface.java +++ b/src/main/java/com/mindee/CommandLineInterface.java @@ -11,6 +11,7 @@ import com.mindee.product.custom.CustomV1; import com.mindee.product.invoice.InvoiceV4; import com.mindee.product.invoicesplitter.InvoiceSplitterV1; +import com.mindee.product.multireceiptsdetector.MultiReceiptsDetectorV1; import com.mindee.product.passport.PassportV1; import com.mindee.product.receipt.ReceiptV4; import java.io.File; @@ -79,7 +80,7 @@ public static void main(String[] args) { System.exit(exitCode); } - @Command(name = "invoice", description = "Invokes the invoice API") + @Command(name = "invoice", description = "Invokes the Invoice API") void invoiceMethod( @Parameters(index = "0", paramLabel = "", scope = ScopeType.LOCAL) File file @@ -87,7 +88,7 @@ void invoiceMethod( System.out.println(standardProductOutput(InvoiceV4.class, file)); } - @Command(name = "receipt", description = "Invokes the receipt API") + @Command(name = "receipt", description = "Invokes the Expense Receipt API") void receiptMethod( @Parameters(index = "0", paramLabel = "", scope = ScopeType.LOCAL) File file @@ -95,7 +96,15 @@ void receiptMethod( System.out.println(standardProductOutput(ReceiptV4.class, file)); } - @Command(name = "passport", description = "Invokes the passport API") + @Command(name = "multi-receipt-detector", description = "Invokes the Multi Receipts Detector API") + void multiReceiptDetectorMethod( + @Parameters(index = "0", paramLabel = "", scope = ScopeType.LOCAL) + File file + ) throws IOException { + System.out.println(standardProductOutput(MultiReceiptsDetectorV1.class, file)); + } + + @Command(name = "passport", description = "Invokes the Passport API") void passportMethod( @Parameters(index = "0", paramLabel = "", scope = ScopeType.LOCAL) File file @@ -103,7 +112,7 @@ void passportMethod( System.out.println(standardProductOutput(PassportV1.class, file)); } - @Command(name = "invoice-splitter", description = "Invokes the invoice-splitter API") + @Command(name = "invoice-splitter", description = "Invokes the Invoice Splitter API") void invoiceSplitterMethod( @Parameters(index = "0", paramLabel = "", scope = ScopeType.LOCAL) File file @@ -111,7 +120,7 @@ void invoiceSplitterMethod( System.out.println(standardProductAsyncOutput(InvoiceSplitterV1.class, file)); } - @Command(name = "custom", description = "Invokes a builder API") + @Command(name = "custom", description = "Invokes a Custom API") void customMethod( @Option( names = {"-a", "--account"}, diff --git a/src/main/java/com/mindee/MindeeClient.java b/src/main/java/com/mindee/MindeeClient.java index 33b12ebbf..780990d64 100644 --- a/src/main/java/com/mindee/MindeeClient.java +++ b/src/main/java/com/mindee/MindeeClient.java @@ -492,8 +492,7 @@ private byte[] getSplitFile( PageOptions pageOptions ) throws IOException { byte[] splitFile; - boolean isPDF = InputSourceUtils.isPdf(localInputSource.getFilename()); - if (pageOptions == null || !isPDF) { + if (pageOptions == null || !localInputSource.isPdf()) { splitFile = localInputSource.getFile(); } else { splitFile = pdfOperation.split( diff --git a/src/main/java/com/mindee/extraction/ImageExtractor.java b/src/main/java/com/mindee/extraction/ImageExtractor.java index db8dabdd5..a3e5cafd8 100644 --- a/src/main/java/com/mindee/extraction/ImageExtractor.java +++ b/src/main/java/com/mindee/extraction/ImageExtractor.java @@ -6,9 +6,10 @@ import com.mindee.input.InputSourceUtils; import com.mindee.input.LocalInputSource; import com.mindee.parsing.standard.PositionData; +import com.mindee.pdf.PDFUtils; +import com.mindee.pdf.PdfPageImage; import java.awt.image.BufferedImage; import java.io.ByteArrayInputStream; -import java.io.File; import java.io.IOException; import java.util.ArrayList; import java.util.List; @@ -18,17 +19,16 @@ * Extract sub-images from an image. */ public class ImageExtractor { - private final BufferedImage bufferedImage; + private final List pageImages; private final String filename; + private final String saveFormat; /** * Init from a path. * @param filePath Path to the file. */ public ImageExtractor(String filePath) throws IOException { - File file = new File(filePath); - this.filename = file.getName(); - this.bufferedImage = ImageIO.read(file); + this(new LocalInputSource(filePath)); } /** @@ -37,29 +37,79 @@ public ImageExtractor(String filePath) throws IOException { */ public ImageExtractor(LocalInputSource source) throws IOException { this.filename = source.getFilename(); - ByteArrayInputStream input = new ByteArrayInputStream(source.getFile()); - this.bufferedImage = ImageIO.read(input); + this.pageImages = new ArrayList<>(); + + if (source.isPdf()) { + this.saveFormat = "jpg"; + List pdfPageImages = PDFUtils.pdfToImages(source); + for (PdfPageImage pdfPageImage : pdfPageImages) { + this.pageImages.add(pdfPageImage.getImage()); + } + } else { + String[] splitName = InputSourceUtils.splitNameStrict(this.filename); + this.saveFormat = splitName[1].toLowerCase(); + + ByteArrayInputStream input = new ByteArrayInputStream(source.getFile()); + this.pageImages.add(ImageIO.read(input)); + } + } + + /** + * @return The number of pages in the file. + */ + public int getPageCount() { + return this.pageImages.size(); } /** * Extract images from a list of fields having position data. + * Use this when the input file is a PDF with multiple pages. * @param fields List of Fields to extract. + * @param pageIndex The page index to extract, begins at 0. * @return A list of {@link ExtractedImage}. */ - public List extractImages(List fields) { - return extractImages(fields, this.filename); + public List extractImagesFromPage( + List fields, + int pageIndex + ) { + return extractImagesFromPage(fields, pageIndex, this.filename); } /** * Extract images from a list of fields having position data. + * Use this when the input file is a PDF with multiple pages. * @param fields List of Fields to extract. - * @param filename The base output filename. + * @param pageIndex The page index to extract, begins at 0. + * @param outputName The base output filename, must have an image extension. * @return A list of {@link ExtractedImage}. */ - public List extractImages(List fields, String filename) { + public List extractImagesFromPage( + List fields, + int pageIndex, + String outputName + ) { + String filename; + if (this.getPageCount() > 1) { + String[] splitName = InputSourceUtils.splitNameStrict(outputName); + filename = splitName[0] + "." + this.saveFormat; + } else { + filename = outputName; + } + return extractFromPage(fields, pageIndex, filename); + } + + private List extractFromPage( + List fields, + int pageIndex, + String outputName + ) { + String[] splitName = InputSourceUtils.splitNameStrict(outputName); + String filename = String.format("%s_page-%3s.%s", splitName[0], pageIndex + 1, splitName[1]) + .replace(" ", "0"); + List extractedImages = new ArrayList<>(); for (int i = 0; i < fields.size(); i++) { - ExtractedImage extractedImage = extractImage(fields.get(i), filename, i+1); + ExtractedImage extractedImage = extractImage(fields.get(i), pageIndex, i+1, filename); if (extractedImage != null) { extractedImages.add(extractedImage); } @@ -71,9 +121,15 @@ public List extractImages(Listnull if the field does not have valid position data. */ - public ExtractedImage extractImage(FieldT field, String filename, int index) { + public ExtractedImage extractImage( + FieldT field, + int pageIndex, + int index, + String filename + ) { String[] splitName = InputSourceUtils.splitNameStrict(filename); String saveFormat = splitName[1].toLowerCase(); Polygon boundingBox = field.getBoundingBox(); @@ -84,27 +140,29 @@ public ExtractedImage extractImage(FieldT field, S String fieldFilename = splitName[0] + String.format("_%3s", index).replace(" ", "0") + "." - + splitName[1]; - return new ExtractedImage(extractImage(bbox), fieldFilename, saveFormat); + + saveFormat; + return new ExtractedImage(extractImage(bbox, pageIndex), fieldFilename, saveFormat); } /** * Extract an image from a field having position data. * @param field The field to extract. * @param index The index to use for naming the extracted image. + * @param pageIndex The page index to extract, begins at 0. * @return The {@link ExtractedImage}, or null if the field does not have valid position data. */ - public ExtractedImage extractImage(FieldT field, int index) { - return extractImage(field, this.filename, index); + public ExtractedImage extractImage(FieldT field, int pageIndex, int index) { + return extractImage(field, pageIndex, index, this.filename); } - private BufferedImage extractImage(Bbox bbox) { - int width = this.bufferedImage.getWidth(); - int height = this.bufferedImage.getHeight(); + private BufferedImage extractImage(Bbox bbox, int pageIndex) { + BufferedImage image = this.pageImages.get(pageIndex); + int width = image.getWidth(); + int height = image.getHeight(); int minX = (int) Math.round(bbox.getMinX() * width); int maxX = (int) Math.round(bbox.getMaxX() * width); int minY = (int) Math.round(bbox.getMinY() * height); int maxY = (int) Math.round(bbox.getMaxY() * height); - return this.bufferedImage.getSubimage(minX, minY, maxX - minX, maxY - minY); + return image.getSubimage(minX, minY, maxX - minX, maxY - minY); } } diff --git a/src/main/java/com/mindee/input/LocalInputSource.java b/src/main/java/com/mindee/input/LocalInputSource.java index e5b4914fb..7000a6759 100644 --- a/src/main/java/com/mindee/input/LocalInputSource.java +++ b/src/main/java/com/mindee/input/LocalInputSource.java @@ -42,4 +42,8 @@ public LocalInputSource(String fileAsBase64, String filename) { this.file = Base64.getDecoder().decode(fileAsBase64.getBytes()); this.filename = filename; } + + public boolean isPdf() { + return InputSourceUtils.isPdf(this.filename); + } } diff --git a/src/test/java/com/mindee/extraction/ImageExtractorTest.java b/src/test/java/com/mindee/extraction/ImageExtractorTest.java index 1aa203aa0..6e8e97c47 100644 --- a/src/test/java/com/mindee/extraction/ImageExtractorTest.java +++ b/src/test/java/com/mindee/extraction/ImageExtractorTest.java @@ -2,7 +2,9 @@ import com.fasterxml.jackson.databind.JavaType; import com.fasterxml.jackson.databind.ObjectMapper; +import com.mindee.MindeeException; import com.mindee.input.LocalInputSource; +import com.mindee.parsing.common.Page; import com.mindee.parsing.common.PredictResponse; import com.mindee.product.barcodereader.BarcodeReaderV1; import com.mindee.product.barcodereader.BarcodeReaderV1Document; @@ -51,20 +53,26 @@ public void givenAnImage_shouldExtractPositionFields() throws IOException { "src/test/resources/products/multi_receipts_detector/default_sample.jpg" ); PredictResponse response = getMultiReceiptsPrediction("complete"); - MultiReceiptsDetectorV1Document prediction = response.getDocument().getInference().getPrediction(); + MultiReceiptsDetectorV1 inference = response.getDocument().getInference(); ImageExtractor extractor = new ImageExtractor(image); - List subImages = extractor.extractImages(prediction.getReceipts()); - for (int i = 0; i < subImages.size(); i++) { - ExtractedImage extractedImage = subImages.get(i); - Assertions.assertNotNull(extractedImage.getImage()); - extractedImage.writeToFile("src/test/resources/output/"); - - LocalInputSource source = extractedImage.asInputSource(); - Assertions.assertEquals( - String.format("default_sample_%3s.jpg", i + 1).replace(" ", "0"), - source.getFilename() + Assertions.assertEquals(1, extractor.getPageCount()); + + for (Page page : inference.getPages()) { + List subImages = extractor.extractImagesFromPage( + page.getPrediction().getReceipts(), page.getPageId() ); + for (int i = 0; i < subImages.size(); i++) { + ExtractedImage extractedImage = subImages.get(i); + Assertions.assertNotNull(extractedImage.getImage()); + extractedImage.writeToFile("src/test/resources/output/"); + + LocalInputSource source = extractedImage.asInputSource(); + Assertions.assertEquals( + String.format("default_sample_page-001_%3s.jpg", i + 1).replace(" ", "0"), + source.getFilename() + ); + } } } @@ -72,18 +80,63 @@ public void givenAnImage_shouldExtractPositionFields() throws IOException { public void givenAnImage_shouldExtractValueFields() throws IOException { String imagePath = "src/test/resources/products/barcode_reader/default_sample.jpg"; PredictResponse response = getBarcodeReaderPrediction("complete"); - BarcodeReaderV1Document prediction = response.getDocument().getInference().getPrediction(); + BarcodeReaderV1 inference = response.getDocument().getInference(); ImageExtractor extractor = new ImageExtractor(imagePath); - List codes1D = extractor.extractImages(prediction.getCodes1D(), "barcodes_1D.png"); - for (ExtractedImage extractedImage : codes1D) { - Assertions.assertNotNull(extractedImage.getImage()); - extractedImage.writeToFile("src/test/resources/output/"); + Assertions.assertEquals(1, extractor.getPageCount()); + + for (Page page : inference.getPages()) { + List codes1D = extractor.extractImagesFromPage( + page.getPrediction().getCodes1D(), page.getPageId(), "barcodes_1D.png" + ); + for (int i = 0; i < codes1D.size(); i++) { + ExtractedImage extractedImage = codes1D.get(i); + Assertions.assertNotNull(extractedImage.getImage()); + LocalInputSource source = extractedImage.asInputSource(); + Assertions.assertEquals( + String.format("barcodes_1D_page-001_%3s.png", i + 1).replace(" ", "0"), + source.getFilename() + ); + extractedImage.writeToFile("src/test/resources/output/"); + } + List codes2D = extractor.extractImagesFromPage( + page.getPrediction().getCodes2D(), page.getPageId(),"barcodes_2D.png" + ); + for (ExtractedImage extractedImage : codes2D) { + Assertions.assertNotNull(extractedImage.getImage()); + extractedImage.writeToFile("src/test/resources/output/"); + } } - List codes2D = extractor.extractImages(prediction.getCodes2D(), "barcodes_2D.png"); - for (ExtractedImage extractedImage : codes2D) { - Assertions.assertNotNull(extractedImage.getImage()); - extractedImage.writeToFile("src/test/resources/output/"); + } + + @Test + public void givenAPdf_shouldExtractPositionFields() throws IOException { + LocalInputSource image = new LocalInputSource( + "src/test/resources/products/multi_receipts_detector/multipage_sample.pdf" + ); + PredictResponse response = getMultiReceiptsPrediction("multipage_sample"); + MultiReceiptsDetectorV1 inference = response.getDocument().getInference(); + + ImageExtractor extractor = new ImageExtractor(image); + Assertions.assertEquals(2, extractor.getPageCount()); + + for (Page page : inference.getPages()) { + List subImages = extractor.extractImagesFromPage( + page.getPrediction().getReceipts(), + page.getPageId() + ); + + for (int i = 0; i < subImages.size(); i++) { + ExtractedImage extractedImage = subImages.get(i); + Assertions.assertNotNull(extractedImage.getImage()); + extractedImage.writeToFile("src/test/resources/output/"); + + LocalInputSource source = extractedImage.asInputSource(); + Assertions.assertEquals( + String.format("multipage_sample_page-%3s_%3s.jpg", page.getPageId() + 1, i + 1).replace(" ", "0"), + source.getFilename() + ); + } } } } diff --git a/src/test/resources b/src/test/resources index 86d4fd574..249c253aa 160000 --- a/src/test/resources +++ b/src/test/resources @@ -1 +1 @@ -Subproject commit 86d4fd57411bd96353a6b4f16a0b447cfbd9f1ae +Subproject commit 249c253aa93cd9dd10e235dc5ec0cacb2536f7ed