// Host-side weight-format parsing: an MSB-first bit reader (`bitstream`) and
// the 16-byte weight-header codec (`weight_header`) that is built on it.
//
// Upstream these live in `bootstrap/src/host/bitstream.rs` and
// `bootstrap/src/host/weight_header.rs`; `bootstrap/src/host/mod.rs` declares
// both modules and re-exports their public names next to its existing
// `csr_map`, `driver`, `irq` and `mmio` entries (those modules are not part of
// this listing).

pub mod bitstream {
    //! MSB-first bit-level reader over a borrowed byte slice.

    /// Failure modes of [`BitstreamReader`].
    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    pub enum BitError {
        /// A read or skip asked for more bits than the stream still holds.
        EndOfStream,
    }

    impl std::fmt::Display for BitError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            match self {
                BitError::EndOfStream => write!(f, "end of bitstream"),
            }
        }
    }

    impl std::error::Error for BitError {}

    /// Cursor that walks a byte slice bit by bit, most significant bit first.
    ///
    /// Invariant: `bit_pos` is always in `0..=7`, and it is `0` whenever
    /// `byte_pos == data.len()` (stream exhausted).
    #[derive(Debug, Clone)]
    pub struct BitstreamReader<'a> {
        /// Backing bytes; never modified.
        data: &'a [u8],
        /// Index of the byte that holds the next unread bit.
        byte_pos: usize,
        /// Offset of the next unread bit inside that byte, counted from the MSB.
        bit_pos: u8,
    }

    impl<'a> BitstreamReader<'a> {
        /// Creates a reader positioned on the first (most significant) bit of `data`.
        pub fn new(data: &'a [u8]) -> Self {
            Self {
                data,
                byte_pos: 0,
                bit_pos: 0,
            }
        }

        /// Number of bits that have not been read yet.
        pub fn bits_remaining(&self) -> usize {
            let bytes_left = self.data.len().saturating_sub(self.byte_pos);
            bytes_left
                .saturating_mul(8)
                .saturating_sub(usize::from(self.bit_pos))
        }

        /// Number of bytes touched so far; a partially read byte counts as consumed.
        pub fn bytes_consumed(&self) -> usize {
            if self.bit_pos > 0 {
                self.byte_pos + 1
            } else {
                self.byte_pos
            }
        }

        /// Returns `true` once every bit has been read.
        pub fn is_empty(&self) -> bool {
            self.bits_remaining() == 0
        }

        /// Fails with `EndOfStream` unless at least `bits` unread bits are left.
        fn ensure_available(&self, bits: usize) -> Result<(), BitError> {
            if bits > self.bits_remaining() {
                Err(BitError::EndOfStream)
            } else {
                Ok(())
            }
        }

        /// Reads the next bit.
        ///
        /// # Errors
        ///
        /// Returns [`BitError::EndOfStream`] when no bits are left; the cursor
        /// is not moved in that case.
        pub fn read_bit(&mut self) -> Result<bool, BitError> {
            let byte = *self
                .data
                .get(self.byte_pos)
                .ok_or(BitError::EndOfStream)?;
            // `bit_pos` counts from the MSB, so bit 0 is the 0x80 position.
            let bit = (byte >> (7 - self.bit_pos)) & 1;
            self.bit_pos += 1;
            if self.bit_pos == 8 {
                self.bit_pos = 0;
                self.byte_pos += 1;
            }
            Ok(bit != 0)
        }

        /// Reads `count` bits and returns them as an unsigned integer whose
        /// least significant bit is the last bit read.
        ///
        /// `count == 0` yields `Ok(0)` without moving the cursor. A `count`
        /// above 64 is accepted, but bits shifted past the top of the `u64`
        /// are discarded, so only the last 64 bits read are represented.
        ///
        /// # Errors
        ///
        /// Returns [`BitError::EndOfStream`] if fewer than `count` bits are
        /// left; the cursor is not moved in that case.
        pub fn read_bits(&mut self, count: u8) -> Result<u64, BitError> {
            self.ensure_available(usize::from(count))?;
            let mut result: u64 = 0;
            for _ in 0..count {
                result = (result << 1) | u64::from(self.read_bit()?);
            }
            Ok(result)
        }

        /// Reads the next 8 bits as one byte (the cursor need not be byte-aligned).
        ///
        /// # Errors
        ///
        /// Returns [`BitError::EndOfStream`] if fewer than 8 bits are left.
        pub fn read_u8(&mut self) -> Result<u8, BitError> {
            // Exactly eight bits were read, so the value always fits in a `u8`.
            Ok(self.read_bits(8)? as u8)
        }

        /// Reads two bytes and combines them as a little-endian `u16`.
        ///
        /// # Errors
        ///
        /// Returns [`BitError::EndOfStream`] if fewer than 16 bits are left;
        /// nothing is consumed in that case.
        pub fn read_u16_le(&mut self) -> Result<u16, BitError> {
            // Check up front so a short stream does not lose its first byte.
            self.ensure_available(16)?;
            let bytes = [self.read_u8()?, self.read_u8()?];
            Ok(u16::from_le_bytes(bytes))
        }

        /// Reads four bytes and combines them as a little-endian `u32`.
        ///
        /// # Errors
        ///
        /// Returns [`BitError::EndOfStream`] if fewer than 32 bits are left;
        /// nothing is consumed in that case.
        pub fn read_u32_le(&mut self) -> Result<u32, BitError> {
            // Check up front so a short stream does not lose its leading bytes.
            self.ensure_available(32)?;
            let bytes = [
                self.read_u8()?,
                self.read_u8()?,
                self.read_u8()?,
                self.read_u8()?,
            ];
            Ok(u32::from_le_bytes(bytes))
        }

        /// Discards the unread bits of a partially read byte, if any, so the
        /// next read starts on a byte boundary.
        pub fn align_to_byte(&mut self) {
            if self.bit_pos > 0 {
                self.bit_pos = 0;
                self.byte_pos += 1;
            }
        }

        /// Advances the cursor by `count` bits without decoding them.
        ///
        /// # Errors
        ///
        /// Returns [`BitError::EndOfStream`] if fewer than `count` bits are
        /// left; the cursor is not moved in that case.
        pub fn skip_bits(&mut self, count: usize) -> Result<(), BitError> {
            self.ensure_available(count)?;
            // Split into whole bytes and leftover bits so no intermediate
            // value grows beyond `byte_pos + count / 8 + 1`.
            let bit_offset = usize::from(self.bit_pos) + count % 8;
            self.byte_pos += count / 8 + bit_offset / 8;
            self.bit_pos = (bit_offset % 8) as u8;
            Ok(())
        }
    }

    #[cfg(test)]
    mod tests {
        use super::*;

        /// Packs a sequence of 0/1 values MSB-first into bytes, zero-padding
        /// the final byte on the right.
        fn from_bits(bits: &[u8]) -> Vec<u8> {
            let mut bytes = Vec::new();
            let mut byte: u8 = 0;
            for (i, &b) in bits.iter().enumerate() {
                byte = (byte << 1) | (b & 1);
                if (i + 1) % 8 == 0 {
                    bytes.push(byte);
                    byte = 0;
                }
            }
            if bits.len() % 8 != 0 {
                byte <<= 8 - (bits.len() % 8);
                bytes.push(byte);
            }
            bytes
        }

        #[test]
        fn empty_reader() {
            let r = BitstreamReader::new(&[]);
            assert!(r.is_empty());
            assert_eq!(r.bits_remaining(), 0);
        }

        #[test]
        fn read_single_bits() {
            let data = from_bits(&[1, 0, 1, 1, 0, 1, 0, 0]);
            assert_eq!(data, vec![0b10110100]);
            let mut r = BitstreamReader::new(&data);
            assert!(r.read_bit().unwrap());
            assert!(!r.read_bit().unwrap());
            assert!(r.read_bit().unwrap());
            assert!(r.read_bit().unwrap());
            assert!(!r.read_bit().unwrap());
            assert!(r.read_bit().unwrap());
            assert!(!r.read_bit().unwrap());
            assert!(!r.read_bit().unwrap());
        }

        #[test]
        fn read_bits_3() {
            let data = [0b11100010];
            let mut r = BitstreamReader::new(&data);
            assert_eq!(r.read_bits(3).unwrap(), 0b111);
            assert_eq!(r.read_bits(3).unwrap(), 0b000);
            assert_eq!(r.read_bits(2).unwrap(), 0b10);
        }

        #[test]
        fn read_bits_zero() {
            let data = [0xFF];
            let mut r = BitstreamReader::new(&data);
            assert_eq!(r.read_bits(0).unwrap(), 0);
            assert_eq!(r.bits_remaining(), 8);
        }

        #[test]
        fn read_u8() {
            let data = [0xAB, 0xCD];
            let mut r = BitstreamReader::new(&data);
            assert_eq!(r.read_u8().unwrap(), 0xAB);
            assert_eq!(r.read_u8().unwrap(), 0xCD);
        }

        #[test]
        fn read_u16_le() {
            let data = [0x34, 0x12];
            let mut r = BitstreamReader::new(&data);
            assert_eq!(r.read_u16_le().unwrap(), 0x1234);

            // A short stream fails without losing the byte that is present.
            let short = [0x7E];
            let mut r = BitstreamReader::new(&short);
            assert_eq!(r.read_u16_le(), Err(BitError::EndOfStream));
            assert_eq!(r.bits_remaining(), 8);
            assert_eq!(r.read_u8().unwrap(), 0x7E);
        }

        #[test]
        fn read_u32_le() {
            let data = [0x78, 0x56, 0x34, 0x12];
            let mut r = BitstreamReader::new(&data);
            assert_eq!(r.read_u32_le().unwrap(), 0x12345678);

            // A short stream fails without consuming anything.
            let short = [0x01, 0x02, 0x03];
            let mut r = BitstreamReader::new(&short);
            assert_eq!(r.read_u32_le(), Err(BitError::EndOfStream));
            assert_eq!(r.bits_remaining(), 24);
        }

        #[test]
        fn end_of_stream_bit() {
            let data = [0xFF];
            let mut r = BitstreamReader::new(&data);
            for _ in 0..8 {
                r.read_bit().unwrap();
            }
            assert!(matches!(r.read_bit(), Err(BitError::EndOfStream)));
        }

        #[test]
        fn end_of_stream_bits() {
            let data = [0xFF];
            let mut r = BitstreamReader::new(&data);
            assert!(matches!(r.read_bits(9), Err(BitError::EndOfStream)));
        }

        #[test]
        fn bits_remaining() {
            let data = [0xFF, 0x00];
            let mut r = BitstreamReader::new(&data);
            assert_eq!(r.bits_remaining(), 16);
            r.read_bits(3).unwrap();
            assert_eq!(r.bits_remaining(), 13);
        }

        #[test]
        fn bytes_consumed() {
            let data = [0xFF, 0x00, 0xAA];
            let mut r = BitstreamReader::new(&data);
            assert_eq!(r.bytes_consumed(), 0);
            r.read_bits(4).unwrap();
            assert_eq!(r.bytes_consumed(), 1);
            r.read_bits(4).unwrap();
            assert_eq!(r.bytes_consumed(), 1);
            r.read_bits(8).unwrap();
            assert_eq!(r.bytes_consumed(), 2);
        }

        #[test]
        fn align_to_byte() {
            let data = [0xFF, 0x00];
            let mut r = BitstreamReader::new(&data);
            r.read_bits(3).unwrap();
            r.align_to_byte();
            assert_eq!(r.bits_remaining(), 8);
            assert_eq!(r.read_u8().unwrap(), 0x00);
        }

        #[test]
        fn skip_bits() {
            let data = [0b10101010, 0b01010101];
            let mut r = BitstreamReader::new(&data);
            r.skip_bits(8).unwrap();
            assert_eq!(r.read_bits(8).unwrap(), 0b01010101);

            // Skips that start or end inside a byte land on the right bit.
            let mut r = BitstreamReader::new(&data);
            r.skip_bits(3).unwrap();
            assert_eq!(r.read_bits(5).unwrap(), 0b01010);
            r.skip_bits(4).unwrap();
            assert_eq!(r.read_bits(4).unwrap(), 0b0101);
            assert!(r.is_empty());

            // Skipping past the end fails and leaves the cursor where it was.
            let mut r = BitstreamReader::new(&data);
            r.read_bits(3).unwrap();
            assert_eq!(r.skip_bits(14), Err(BitError::EndOfStream));
            assert_eq!(r.bits_remaining(), 13);
        }

        #[test]
        fn cross_byte_boundary() {
            let data = [0b11001100, 0b10101010];
            let mut r = BitstreamReader::new(&data);
            assert_eq!(r.read_bits(4).unwrap(), 0b1100);
            assert_eq!(r.read_bits(8).unwrap(), 0b11001010);
            assert_eq!(r.read_bits(4).unwrap(), 0b1010);
        }

        #[test]
        fn error_display() {
            assert!(BitError::EndOfStream.to_string().contains("end"));
        }
    }
}

pub mod weight_header {
    //! Codec for the 16-byte binary weight-format header.
    //!
    //! Byte layout (multi-byte fields are little-endian):
    //!
    //! | bytes    | field                                        |
    //! |----------|----------------------------------------------|
    //! | `0..4`   | magic, must equal `MAGIC`                    |
    //! | `4`      | version, must equal `VERSION`                |
    //! | `5`      | bits_per_weight                              |
    //! | `6`      | flags (bit 0: CRC present)                   |
    //! | `7`      | reserved; written as 0, ignored when decoding |
    //! | `8..10`  | layers                                       |
    //! | `10..12` | neurons_per_layer                            |
    //! | `12..16` | checksum_seed                                |

    use super::bitstream::{BitError, BitstreamReader};

    /// Reasons a byte buffer is rejected by [`WeightHeader::decode`].
    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    pub enum HeaderError {
        /// The first four bytes do not spell `MAGIC`.
        BadMagic,
        /// The version byte differs from `VERSION`; carries the value found.
        UnsupportedVersion(u8),
        /// The buffer ended before a full header could be read.
        BitstreamError(BitError),
        /// A size field holds a value outside its accepted range.
        InvalidDimension { field: &'static str, value: u32 },
    }

    impl std::fmt::Display for HeaderError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            match self {
                HeaderError::BadMagic => write!(f, "bad magic number"),
                HeaderError::UnsupportedVersion(v) => write!(f, "unsupported version: {v}"),
                HeaderError::BitstreamError(e) => write!(f, "bitstream error: {e}"),
                HeaderError::InvalidDimension { field, value } => {
                    write!(f, "invalid {field}: {value}")
                }
            }
        }
    }

    impl std::error::Error for HeaderError {
        /// Exposes the wrapped [`BitError`] so error-chain walkers can reach it.
        fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
            match self {
                HeaderError::BitstreamError(inner) => Some(inner),
                HeaderError::BadMagic
                | HeaderError::UnsupportedVersion(_)
                | HeaderError::InvalidDimension { .. } => None,
            }
        }
    }

    impl From<BitError> for HeaderError {
        fn from(e: BitError) -> Self {
            HeaderError::BitstreamError(e)
        }
    }

    /// Magic number stored little-endian in header bytes `0..4`.
    pub const MAGIC: u32 = 0x54325700;
    /// The only header version `decode` accepts and `new` produces.
    pub const VERSION: u8 = 1;
    /// Encoded header length in bytes.
    pub const HEADER_SIZE: usize = 16;

    /// In-memory form of the weight-format header.
    #[derive(Debug, Clone, PartialEq, Eq)]
    pub struct WeightHeader {
        pub version: u8,
        pub layers: u16,
        pub neurons_per_layer: u16,
        pub bits_per_weight: u8,
        pub flags: u8,
        pub checksum_seed: u32,
    }

    impl WeightHeader {
        /// Builds a header for `layers` x `neurons` with the defaults
        /// `version = VERSION`, `bits_per_weight = 2`, `flags = 0`,
        /// `checksum_seed = 0`.
        pub fn new(layers: u16, neurons: u16) -> Self {
            Self {
                version: VERSION,
                layers,
                neurons_per_layer: neurons,
                bits_per_weight: 2,
                flags: 0,
                checksum_seed: 0,
            }
        }

        /// Builder step: replaces the checksum seed.
        pub fn with_checksum_seed(mut self, seed: u32) -> Self {
            self.checksum_seed = seed;
            self
        }

        /// Builder step: replaces the per-weight bit width.
        ///
        /// The value is stored unchecked; `decode` rejects encoded headers
        /// whose width is `0` or greater than `32`.
        pub fn with_bits_per_weight(mut self, bits: u8) -> Self {
            self.bits_per_weight = bits;
            self
        }

        /// `layers * neurons_per_layer`, widened to `u64` so it cannot overflow.
        pub fn total_weights(&self) -> u64 {
            u64::from(self.layers) * u64::from(self.neurons_per_layer)
        }

        /// Bytes needed to store `total_weights()` values of `bits_per_weight`
        /// bits each, rounded up to a whole byte.
        pub fn total_weight_bytes(&self) -> u64 {
            let total_bits = self.total_weights() * u64::from(self.bits_per_weight);
            (total_bits + 7) / 8
        }

        /// Whether flag bit 0 is set.
        pub fn has_crc(&self) -> bool {
            (self.flags & 0x01) != 0
        }

        /// Whether weights use 2 bits each.
        pub fn is_ternary(&self) -> bool {
            self.bits_per_weight == 2
        }

        /// Serialises the header into its 16-byte wire form.
        ///
        /// Fields are written as-is, without validation, so a header holding
        /// values that `decode` rejects still encodes.
        pub fn encode(&self) -> [u8; HEADER_SIZE] {
            let mut buf = [0u8; HEADER_SIZE];
            buf[0..4].copy_from_slice(&MAGIC.to_le_bytes());
            buf[4] = self.version;
            buf[5] = self.bits_per_weight;
            buf[6] = self.flags;
            buf[7] = 0; // reserved
            buf[8..10].copy_from_slice(&self.layers.to_le_bytes());
            buf[10..12].copy_from_slice(&self.neurons_per_layer.to_le_bytes());
            buf[12..16].copy_from_slice(&self.checksum_seed.to_le_bytes());
            buf
        }

        /// Parses and validates a header from the first `HEADER_SIZE` bytes of
        /// `data`; any bytes after that are not examined.
        ///
        /// # Errors
        ///
        /// * [`HeaderError::BitstreamError`] if `data` is shorter than `HEADER_SIZE`.
        /// * [`HeaderError::BadMagic`] if the magic number does not match.
        /// * [`HeaderError::UnsupportedVersion`] if the version is not `VERSION`.
        /// * [`HeaderError::InvalidDimension`] if `layers` or
        ///   `neurons_per_layer` is zero, or `bits_per_weight` is outside `1..=32`.
        pub fn decode(data: &[u8]) -> Result<Self, HeaderError> {
            if data.len() < HEADER_SIZE {
                return Err(HeaderError::BitstreamError(BitError::EndOfStream));
            }
            let mut r = BitstreamReader::new(data);

            let magic = r.read_u32_le()?;
            if magic != MAGIC {
                return Err(HeaderError::BadMagic);
            }
            let version = r.read_u8()?;
            if version != VERSION {
                return Err(HeaderError::UnsupportedVersion(version));
            }

            let bits_per_weight = r.read_u8()?;
            let flags = r.read_u8()?;
            let _reserved = r.read_u8()?;
            let layers = r.read_u16_le()?;
            let neurons = r.read_u16_le()?;
            let checksum_seed = r.read_u32_le()?;

            if layers == 0 {
                return Err(HeaderError::InvalidDimension {
                    field: "layers",
                    value: 0,
                });
            }
            if neurons == 0 {
                return Err(HeaderError::InvalidDimension {
                    field: "neurons_per_layer",
                    value: 0,
                });
            }
            if bits_per_weight == 0 || bits_per_weight > 32 {
                return Err(HeaderError::InvalidDimension {
                    field: "bits_per_weight",
                    value: u32::from(bits_per_weight),
                });
            }

            Ok(Self {
                version,
                layers,
                neurons_per_layer: neurons,
                bits_per_weight,
                flags,
                checksum_seed,
            })
        }
    }

    #[cfg(test)]
    mod tests {
        use super::*;

        #[test]
        fn new_header() {
            let h = WeightHeader::new(4, 128);
            assert_eq!(h.version, 1);
            assert_eq!(h.layers, 4);
            assert_eq!(h.neurons_per_layer, 128);
            assert_eq!(h.bits_per_weight, 2);
        }

        #[test]
        fn total_weights() {
            let h = WeightHeader::new(3, 256);
            assert_eq!(h.total_weights(), 768);
        }

        #[test]
        fn total_weight_bytes_ternary() {
            let h = WeightHeader::new(1, 32);
            assert_eq!(h.total_weight_bytes(), 8);
        }

        #[test]
        fn total_weight_bytes_4bit() {
            let h = WeightHeader::new(1, 16).with_bits_per_weight(4);
            assert_eq!(h.total_weight_bytes(), 8);
        }

        #[test]
        fn has_crc_flag() {
            let mut h = WeightHeader::new(1, 1);
            assert!(!h.has_crc());
            h.flags = 0x01;
            assert!(h.has_crc());
        }

        #[test]
        fn is_ternary() {
            let h = WeightHeader::new(1, 1);
            assert!(h.is_ternary());
            let h2 = WeightHeader::new(1, 1).with_bits_per_weight(4);
            assert!(!h2.is_ternary());
        }

        #[test]
        fn encode_decode_roundtrip() {
            let h = WeightHeader::new(8, 512).with_checksum_seed(0xDEADBEEF);
            let encoded = h.encode();
            let decoded = WeightHeader::decode(&encoded).unwrap();
            assert_eq!(decoded, h);
        }

        #[test]
        fn decode_bad_magic() {
            let mut buf = [0u8; 16];
            buf[0..4].copy_from_slice(&0x00000000u32.to_be_bytes());
            assert!(matches!(
                WeightHeader::decode(&buf),
                Err(HeaderError::BadMagic)
            ));
        }

        #[test]
        fn decode_bad_version() {
            let h = WeightHeader::new(1, 1);
            let mut buf = h.encode();
            buf[4] = 99;
            assert!(matches!(
                WeightHeader::decode(&buf),
                Err(HeaderError::UnsupportedVersion(99))
            ));
        }

        #[test]
        fn decode_zero_layers() {
            let h = WeightHeader::new(1, 1);
            let mut buf = h.encode();
            buf[8..10].copy_from_slice(&0u16.to_le_bytes());
            assert!(matches!(
                WeightHeader::decode(&buf),
                Err(HeaderError::InvalidDimension {
                    field: "layers",
                    ..
                })
            ));
        }

        #[test]
        fn decode_zero_neurons() {
            let h = WeightHeader::new(1, 1);
            let mut buf = h.encode();
            buf[10..12].copy_from_slice(&0u16.to_le_bytes());
            assert!(matches!(
                WeightHeader::decode(&buf),
                Err(HeaderError::InvalidDimension {
                    field: "neurons_per_layer",
                    ..
                })
            ));
        }

        #[test]
        fn decode_too_short() {
            let buf = [0u8; 8];
            let err = WeightHeader::decode(&buf).unwrap_err();
            assert!(matches!(err, HeaderError::BitstreamError(_)));
            // The wrapped bitstream error is reachable through `source`.
            assert!(std::error::Error::source(&err).is_some());
            assert!(std::error::Error::source(&HeaderError::BadMagic).is_none());
        }

        #[test]
        fn encode_header_size() {
            let h = WeightHeader::new(2, 64);
            assert_eq!(h.encode().len(), HEADER_SIZE);
        }

        #[test]
        fn builder_chain() {
            let h = WeightHeader::new(4, 256)
                .with_bits_per_weight(4)
                .with_checksum_seed(0x12345678);
            assert_eq!(h.bits_per_weight, 4);
            assert_eq!(h.checksum_seed, 0x12345678);
            let encoded = h.encode();
            let decoded = WeightHeader::decode(&encoded).unwrap();
            assert_eq!(decoded.bits_per_weight, 4);
        }

        #[test]
        fn error_display() {
            let e = HeaderError::BadMagic;
            assert!(e.to_string().contains("magic"));
            let e = HeaderError::UnsupportedVersion(5);
            assert!(e.to_string().contains("version"));
            let e = HeaderError::InvalidDimension {
                field: "layers",
                value: 0,
            };
            assert!(e.to_string().contains("layers"));
        }
    }
}

// Re-exports matching the lines this change adds to `bootstrap/src/host/mod.rs`.
pub use bitstream::{BitError, BitstreamReader};
pub use weight_header::{HeaderError, WeightHeader, HEADER_SIZE, MAGIC, VERSION};

// docs/NOW.md (last updated 2026-05-24) -- wave-log entries recorded with this change:
wave-66 -- host weight format header parser (R-HS-14, Closes #845) + +- **WHERE** (host-only, additive): new `bootstrap/src/host/weight_header.rs` with `WeightHeader` (16-byte binary header: magic, version, layers, neurons_per_layer, bits_per_weight, flags, checksum_seed); encode/decode; validation; builder pattern; 15 inline tests. Also includes `bitstream.rs` from W65. All pass. 846 total. + +## wave-65 -- host bitstream reader for weight format parsing (R-HS-13, Closes #842) + +- **WHERE** (host-only, additive): new `bootstrap/src/host/bitstream.rs` with `BitstreamReader<'a>`; MSB-first bit reading; `read_bit`, `read_bits(n)`, `read_u8/u16_le/u32_le`; `align_to_byte`, `skip_bits`; `bits_remaining`, `bytes_consumed`; `BitError`; 15 inline tests. All pass. 831 total. + ## docs-readme-bitnet-rtt -- README.md aligned with post-W45 state (doc-only, Closes #805) - **WHERE** (doc-only, repo-root): updated `README.md` (+110 lines). Added four new System Status rows (BitNet HLS / Host stack / R-TT track / Chips) and a brand-new section `## BitNet HLS Pipeline & R-TT Reproducibility Track` documenting the 9/9 RTL pipeline, the host stack CLIs (`host-smoke`, `host-poll-vs-irq`), the R-TT track CLIs (`tt-manifest`, `tt-profile`, `tt-conform`), the three chip submodules under `chips/`, and a test-coverage summary (365/366 integration). Cross-links to `docs/NOW.md` as the live wave log. This is a housekeeping commit between waves (W45 merged at `7f463018`, W46 R-TT-3 next). Zero edits to code, kernel, spec, RTL, tests, `.gitmodules`, or `chips/`.