@@ -782,6 +782,7 @@ static ZSTD_frameSizeInfo ZSTD_findFrameSizeInfo(const void* src, size_t srcSize
782782 ip += 4 ;
783783 }
784784
785+ frameSizeInfo .nbBlocks = nbBlocks ;
785786 frameSizeInfo .compressedSize = (size_t )(ip - ipstart );
786787 frameSizeInfo .decompressedBound = (zfh .frameContentSize != ZSTD_CONTENTSIZE_UNKNOWN )
787788 ? zfh .frameContentSize
@@ -825,6 +826,48 @@ unsigned long long ZSTD_decompressBound(const void* src, size_t srcSize)
825826 return bound ;
826827}
827828
829+ size_t ZSTD_decompressionMargin (void const * src , size_t srcSize )
830+ {
831+ size_t margin = 0 ;
832+ unsigned maxBlockSize = 0 ;
833+
834+ /* Iterate over each frame */
835+ while (srcSize > 0 ) {
836+ ZSTD_frameSizeInfo const frameSizeInfo = ZSTD_findFrameSizeInfo (src , srcSize );
837+ size_t const compressedSize = frameSizeInfo .compressedSize ;
838+ unsigned long long const decompressedBound = frameSizeInfo .decompressedBound ;
839+ ZSTD_frameHeader zfh ;
840+
841+ FORWARD_IF_ERROR (ZSTD_getFrameHeader (& zfh , src , srcSize ), "" );
842+ if (ZSTD_isError (compressedSize ) || decompressedBound == ZSTD_CONTENTSIZE_ERROR )
843+ return ERROR (corruption_detected );
844+
845+ if (zfh .frameType == ZSTD_frame ) {
846+ /* Add the frame header to our margin */
847+ margin += zfh .headerSize ;
848+ /* Add the checksum to our margin */
849+ margin += zfh .checksumFlag ? 4 : 0 ;
850+ /* Add 3 bytes per block */
851+ margin += 3 * frameSizeInfo .nbBlocks ;
852+
853+ /* Compute the max block size */
854+ maxBlockSize = MAX (maxBlockSize , zfh .blockSizeMax );
855+ } else {
856+ assert (zfh .frameType == ZSTD_skippableFrame );
857+ /* Add the entire skippable frame size to our margin. */
858+ margin += compressedSize ;
859+ }
860+
861+ assert (srcSize >= compressedSize );
862+ src = (const BYTE * )src + compressedSize ;
863+ srcSize -= compressedSize ;
864+ }
865+
866+ /* Add the max block size back to the margin. */
867+ margin += maxBlockSize ;
868+
869+ return margin ;
870+ }
828871
829872/*-*************************************************************
830873 * Frame decoding
@@ -850,7 +893,7 @@ static size_t ZSTD_copyRawBlock(void* dst, size_t dstCapacity,
850893 if (srcSize == 0 ) return 0 ;
851894 RETURN_ERROR (dstBuffer_null , "" );
852895 }
853- ZSTD_memcpy (dst , src , srcSize );
896+ ZSTD_memmove (dst , src , srcSize );
854897 return srcSize ;
855898}
856899
@@ -928,6 +971,7 @@ static size_t ZSTD_decompressFrame(ZSTD_DCtx* dctx,
928971
929972 /* Loop on each block */
930973 while (1 ) {
974+ BYTE * oBlockEnd = oend ;
931975 size_t decodedSize ;
932976 blockProperties_t blockProperties ;
933977 size_t const cBlockSize = ZSTD_getcBlockSize (ip , remainingSrcSize , & blockProperties );
@@ -937,16 +981,34 @@ static size_t ZSTD_decompressFrame(ZSTD_DCtx* dctx,
937981 remainingSrcSize -= ZSTD_blockHeaderSize ;
938982 RETURN_ERROR_IF (cBlockSize > remainingSrcSize , srcSize_wrong , "" );
939983
984+ if (ip >= op && ip < oBlockEnd ) {
985+ /* We are decompressing in-place. Limit the output pointer so that we
986+ * don't overwrite the block that we are currently reading. This will
987+ * fail decompression if the input & output pointers aren't spaced
988+ * far enough apart.
989+ *
990+ * This is important to set, even when the pointers are far enough
991+ * apart, because ZSTD_decompressBlock_internal() can decide to store
992+ * literals in the output buffer, after the block it is decompressing.
993+ * Since we don't want anything to overwrite our input, we have to tell
994+ * ZSTD_decompressBlock_internal to never write past ip.
995+ *
996+ * See ZSTD_allocateLiteralsBuffer() for reference.
997+ */
998+ oBlockEnd = op + (ip - op );
999+ }
1000+
9401001 switch (blockProperties .blockType )
9411002 {
9421003 case bt_compressed :
943- decodedSize = ZSTD_decompressBlock_internal (dctx , op , (size_t )(oend - op ), ip , cBlockSize , /* frame */ 1 , not_streaming );
1004+ decodedSize = ZSTD_decompressBlock_internal (dctx , op , (size_t )(oBlockEnd - op ), ip , cBlockSize , /* frame */ 1 , not_streaming );
9441005 break ;
9451006 case bt_raw :
1007+ /* Use oend instead of oBlockEnd because this function is safe to overlap. It uses memmove. */
9461008 decodedSize = ZSTD_copyRawBlock (op , (size_t )(oend - op ), ip , cBlockSize );
9471009 break ;
9481010 case bt_rle :
949- decodedSize = ZSTD_setRleBlock (op , (size_t )(oend - op ), * ip , blockProperties .origSize );
1011+ decodedSize = ZSTD_setRleBlock (op , (size_t )(oBlockEnd - op ), * ip , blockProperties .origSize );
9501012 break ;
9511013 case bt_reserved :
9521014 default :
0 commit comments