Skip to content

Commit

Permalink
Start implementing LZMA2 decoder.
Browse files Browse the repository at this point in the history
  • Loading branch information
gendx committed Nov 16, 2017
1 parent 1d7ba4e commit a4c1b71
Show file tree
Hide file tree
Showing 8 changed files with 601 additions and 98 deletions.
7 changes: 4 additions & 3 deletions fuzz/fuzz_targets/roundtrip.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
#![no_main]
#[macro_use] extern crate libfuzzer_sys;
#[macro_use]
extern crate libfuzzer_sys;
extern crate lzma;

use lzma::error::Result;

fn round_trip(x: &[u8]) -> Result<Vec<u8>> {
let mut compressed: Vec<u8> = Vec::new();
lzma::compress(&mut std::io::BufReader::new(x), &mut compressed)?;
lzma::lzma_compress(&mut std::io::BufReader::new(x), &mut compressed)?;
let mut bf = std::io::BufReader::new(compressed.as_slice());

let mut decomp: Vec<u8> = Vec::new();
lzma::decompress(&mut bf, &mut decomp).expect("Can't decompress what we just compressed");
lzma::lzma_decompress(&mut bf, &mut decomp).expect("Can't decompress what we just compressed");
Ok(decomp)
}

Expand Down
243 changes: 168 additions & 75 deletions src/decode/decoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,42 +4,21 @@ use decode::lzbuffer;
use decode::rangecoder;
use byteorder::{LittleEndian, ReadBytesExt};

pub struct Decoder<'a, R, W>
where
R: 'a + io::BufRead,
W: 'a + io::Write,
{
pub struct LZMAParams {
// most lc significant bits of previous byte are part of the literal context
lc: u32, // 0..8
lp: u32, // 0..4
// context for literal/match is plaintext offset modulo 2^pb
pb: u32, // 0..4
dict_size: u32,
unpacked_size: Option<u64>,
rangecoder: rangecoder::RangeDecoder<'a, R>,
literal_probs: Vec<Vec<u16>>,
pos_slot_decoder: Vec<rangecoder::BitTree>,
align_decoder: rangecoder::BitTree,
pos_decoders: [u16; 115],
output: lzbuffer::LZBuffer<'a, W>,
is_match: [u16; 192], // true = LZ, false = literal
is_rep: [u16; 12],
is_rep_g0: [u16; 12],
is_rep_g1: [u16; 12],
is_rep_g2: [u16; 12],
is_rep_0long: [u16; 192],
state: usize,
rep: [usize; 4],
len_decoder: rangecoder::LenDecoder,
rep_len_decoder: rangecoder::LenDecoder,
}

impl<'a, R, W> Decoder<'a, R, W>
where
R: io::BufRead,
W: io::Write,
{
// Read LZMA header and initialize decoder
pub fn from_stream(stream: &'a mut R, output: &'a mut W) -> error::Result<Self> {
impl LZMAParams {
pub fn read_header<R>(stream: &mut R) -> error::Result<LZMAParams>
where
R: io::BufRead,
{
// Properties
let props = try!(stream.read_u8().or_else(|e| {
Err(error::Error::LZMAError(
Expand All @@ -50,7 +29,7 @@ where
let mut pb = props as u32;
if pb >= 225 {
return Err(error::Error::LZMAError(format!(
"LZMA header invalid properties: {} should be < 225",
"LZMA header invalid properties: {} must be < 225",
pb
)));
}
Expand Down Expand Up @@ -91,55 +70,162 @@ where

info!("Unpacked size: {:?}", unpacked_size);

// Decoder
let decoder = Decoder {
let params = LZMAParams {
lc: lc,
lp: lp,
pb: pb,
dict_size: dict_size,
unpacked_size: unpacked_size,
rangecoder: try!(rangecoder::RangeDecoder::new(stream).or_else(|e| {
Err(error::Error::LZMAError(
format!("LZMA stream too short: {}", e),
))
})),
literal_probs: vec![vec![0x400; 0x300]; 1 << (lc + lp)],
pos_slot_decoder: vec![rangecoder::BitTree::new(6); 4],
align_decoder: rangecoder::BitTree::new(4),
pos_decoders: [0x400; 115],
output: lzbuffer::LZBuffer::from_stream(output, dict_size as usize),
is_match: [0x400; 192],
is_rep: [0x400; 12],
is_rep_g0: [0x400; 12],
is_rep_g1: [0x400; 12],
is_rep_g2: [0x400; 12],
is_rep_0long: [0x400; 192],
state: 0,
rep: [0; 4],
len_decoder: rangecoder::LenDecoder::new(),
rep_len_decoder: rangecoder::LenDecoder::new(),
};

Ok(decoder)
Ok(params)
}
}


pub struct DecoderState<LZB>
where
LZB: lzbuffer::LZBuffer,
{
pub output: LZB,
// most lc significant bits of previous byte are part of the literal context
pub lc: u32, // 0..8
pub lp: u32, // 0..4
// context for literal/match is plaintext offset modulo 2^pb
pub pb: u32, // 0..4
unpacked_size: Option<u64>,
literal_probs: Vec<Vec<u16>>,
pos_slot_decoder: Vec<rangecoder::BitTree>,
align_decoder: rangecoder::BitTree,
pos_decoders: [u16; 115],
is_match: [u16; 192], // true = LZ, false = literal
is_rep: [u16; 12],
is_rep_g0: [u16; 12],
is_rep_g1: [u16; 12],
is_rep_g2: [u16; 12],
is_rep_0long: [u16; 192],
state: usize,
rep: [usize; 4],
len_decoder: rangecoder::LenDecoder,
rep_len_decoder: rangecoder::LenDecoder,
}

// Initialize decoder with accumulating buffer
pub fn new_accum<'a, W>(
output: lzbuffer::LZAccumBuffer<'a, W>,
lc: u32,
lp: u32,
pb: u32,
unpacked_size: Option<u64>,
) -> DecoderState<lzbuffer::LZAccumBuffer<'a, W>>
where
W: io::Write,
{
DecoderState {
output: output,
lc: lc,
lp: lp,
pb: pb,
unpacked_size: unpacked_size,
literal_probs: vec![vec![0x400; 0x300]; 1 << (lc + lp)],
pos_slot_decoder: vec![rangecoder::BitTree::new(6); 4],
align_decoder: rangecoder::BitTree::new(4),
pos_decoders: [0x400; 115],
is_match: [0x400; 192],
is_rep: [0x400; 12],
is_rep_g0: [0x400; 12],
is_rep_g1: [0x400; 12],
is_rep_g2: [0x400; 12],
is_rep_0long: [0x400; 192],
state: 0,
rep: [0; 4],
len_decoder: rangecoder::LenDecoder::new(),
rep_len_decoder: rangecoder::LenDecoder::new(),
}
}

// Initialize decoder with circular buffer
pub fn new_circular<'a, W>(
output: &'a mut W,
params: LZMAParams,
) -> error::Result<DecoderState<lzbuffer::LZCircularBuffer<'a, W>>>
where
W: io::Write,
{
// Decoder
let decoder = DecoderState {
output: lzbuffer::LZCircularBuffer::from_stream(output, params.dict_size as usize),
lc: params.lc,
lp: params.lp,
pb: params.pb,
unpacked_size: params.unpacked_size,
literal_probs: vec![vec![0x400; 0x300]; 1 << (params.lc + params.lp)],
pos_slot_decoder: vec![rangecoder::BitTree::new(6); 4],
align_decoder: rangecoder::BitTree::new(4),
pos_decoders: [0x400; 115],
is_match: [0x400; 192],
is_rep: [0x400; 12],
is_rep_g0: [0x400; 12],
is_rep_g1: [0x400; 12],
is_rep_g2: [0x400; 12],
is_rep_0long: [0x400; 192],
state: 0,
rep: [0; 4],
len_decoder: rangecoder::LenDecoder::new(),
rep_len_decoder: rangecoder::LenDecoder::new(),
};

Ok(decoder)
}

impl<LZB> DecoderState<LZB>
where
LZB: lzbuffer::LZBuffer,
{
pub fn reset_state(&mut self, lc: u32, lp: u32, pb: u32) {
self.lc = lc;
self.lp = lp;
self.pb = pb;
self.literal_probs = vec![vec![0x400; 0x300]; 1 << (lc + lp)];
self.pos_slot_decoder = vec![rangecoder::BitTree::new(6); 4];
self.align_decoder = rangecoder::BitTree::new(4);
self.pos_decoders = [0x400; 115];
self.is_match = [0x400; 192];
self.is_rep = [0x400; 12];
self.is_rep_g0 = [0x400; 12];
self.is_rep_g1 = [0x400; 12];
self.is_rep_g2 = [0x400; 12];
self.is_rep_0long = [0x400; 192];
self.state = 0;
self.rep = [0; 4];
self.len_decoder = rangecoder::LenDecoder::new();
self.rep_len_decoder = rangecoder::LenDecoder::new();
}

pub fn set_unpacked_size(&mut self, unpacked_size: Option<u64>) {
self.unpacked_size = unpacked_size;
}

pub fn process(mut self) -> error::Result<()> {
pub fn process<'a, R: io::BufRead>(
&mut self,
rangecoder: &mut rangecoder::RangeDecoder<'a, R>,
) -> error::Result<()> {
loop {
if let Some(_) = self.unpacked_size {
if self.rangecoder.is_finished_ok()? {
if rangecoder.is_finished_ok()? {
break;
}
}

let pos_state = self.output.len() & ((1 << self.pb) - 1);

// Literal
if !self.rangecoder.decode_bit(
if !rangecoder.decode_bit(
// TODO: assumes pb = 2 ??
&mut self.is_match[(self.state << 4) +
pos_state],
&mut self.is_match[(self.state << 4) + pos_state],
)?
{
let byte: u8 = self.decode_literal()?;
let byte: u8 = self.decode_literal(rangecoder)?;
debug!("Literal: {}", byte);
self.output.append_literal(byte)?;

Expand All @@ -158,11 +244,11 @@ where
// LZ
let mut len: usize;
// Distance is repeated from LRU
if self.rangecoder.decode_bit(&mut self.is_rep[self.state])? {
if rangecoder.decode_bit(&mut self.is_rep[self.state])? {
// dist = rep[0]
if !self.rangecoder.decode_bit(&mut self.is_rep_g0[self.state])? {
if !rangecoder.decode_bit(&mut self.is_rep_g0[self.state])? {
// len = 1
if !self.rangecoder.decode_bit(
if !rangecoder.decode_bit(
&mut self.is_rep_0long[(self.state << 4) +
pos_state],
)?
Expand All @@ -176,10 +262,10 @@ where
// dist = rep[i]
} else {
let idx: usize;
if !self.rangecoder.decode_bit(&mut self.is_rep_g1[self.state])? {
if !rangecoder.decode_bit(&mut self.is_rep_g1[self.state])? {
idx = 1;
} else {
if !self.rangecoder.decode_bit(&mut self.is_rep_g2[self.state])? {
if !rangecoder.decode_bit(&mut self.is_rep_g2[self.state])? {
idx = 2;
} else {
idx = 3;
Expand All @@ -193,7 +279,7 @@ where
self.rep[0] = dist
}

len = self.rep_len_decoder.decode(&mut self.rangecoder, pos_state)?;
len = self.rep_len_decoder.decode(rangecoder, pos_state)?;
// update state (rep)
self.state = if self.state < 7 { 8 } else { 11 };
// New distance
Expand All @@ -202,14 +288,14 @@ where
self.rep[3] = self.rep[2];
self.rep[2] = self.rep[1];
self.rep[1] = self.rep[0];
len = self.len_decoder.decode(&mut self.rangecoder, pos_state)?;
len = self.len_decoder.decode(rangecoder, pos_state)?;

// update state (match)
self.state = if self.state < 7 { 7 } else { 10 };
self.rep[0] = self.decode_distance(len)?;
self.rep[0] = self.decode_distance(rangecoder, len)?;

if self.rep[0] == 0xFFFF_FFFF {
if self.rangecoder.is_finished_ok()? {
if rangecoder.is_finished_ok()? {
break;
}
return Err(error::Error::LZMAError(String::from(
Expand Down Expand Up @@ -238,7 +324,10 @@ where
Ok(())
}

fn decode_literal(&mut self) -> error::Result<u8> {
fn decode_literal<'a, R: io::BufRead>(
&mut self,
rangecoder: &mut rangecoder::RangeDecoder<'a, R>,
) -> error::Result<u8> {
let def_prev_byte = 0u8;
let prev_byte = self.output.last_or(def_prev_byte) as usize;

Expand All @@ -253,7 +342,7 @@ where
while result < 0x100 {
let match_bit = (match_byte >> 7) & 1;
match_byte <<= 1;
let bit = self.rangecoder.decode_bit(
let bit = rangecoder.decode_bit(
&mut probs[((1 + match_bit) << 8) + result],
)? as usize;
result = (result << 1) ^ bit;
Expand All @@ -264,16 +353,20 @@ where
}

while result < 0x100 {
result = (result << 1) ^ (self.rangecoder.decode_bit(&mut probs[result])? as usize);
result = (result << 1) ^ (rangecoder.decode_bit(&mut probs[result])? as usize);
}

Ok((result - 0x100) as u8)
}

fn decode_distance(&mut self, length: usize) -> error::Result<usize> {
fn decode_distance<'a, R: io::BufRead>(
&mut self,
rangecoder: &mut rangecoder::RangeDecoder<'a, R>,
length: usize,
) -> error::Result<usize> {
let len_state = if length > 3 { 3 } else { length };

let pos_slot = self.pos_slot_decoder[len_state].parse(&mut self.rangecoder)? as usize;
let pos_slot = self.pos_slot_decoder[len_state].parse(rangecoder)? as usize;
if pos_slot < 4 {
return Ok(pos_slot);
}
Expand All @@ -282,14 +375,14 @@ where
let mut result = (2 ^ (pos_slot & 1)) << num_direct_bits;

if pos_slot < 14 {
result += self.rangecoder.parse_reverse_bit_tree(
result += rangecoder.parse_reverse_bit_tree(
num_direct_bits,
&mut self.pos_decoders,
result - pos_slot,
)? as usize;
} else {
result += (self.rangecoder.get(num_direct_bits - 4)? as usize) << 4;
result += self.align_decoder.parse_reverse(&mut self.rangecoder)? as usize;
result += (rangecoder.get(num_direct_bits - 4)? as usize) << 4;
result += self.align_decoder.parse_reverse(rangecoder)? as usize;
}

Ok(result)
Expand Down
Loading

0 comments on commit a4c1b71

Please sign in to comment.