diff --git a/espflash/Cargo.toml b/espflash/Cargo.toml index 18c987a6..54e8004b 100644 --- a/espflash/Cargo.toml +++ b/espflash/Cargo.toml @@ -32,6 +32,7 @@ strum = "0.21.0" strum_macros = "0.21.1" csv = "1.1.6" regex = "1.5.4" +flate2 = "1" [dev-dependencies] pretty_assertions = "0.7.1" diff --git a/espflash/src/chip/mod.rs b/espflash/src/chip/mod.rs index ea392556..7f3a2613 100644 --- a/espflash/src/chip/mod.rs +++ b/espflash/src/chip/mod.rs @@ -9,6 +9,8 @@ use crate::{ use std::{io::Write, str::FromStr}; +use crate::flash_target::{Esp32Target, Esp8266Target, FlashTarget, RamTarget}; +use crate::flasher::SpiAttachParams; pub use esp32::Esp32; pub use esp32c3::Esp32c3; pub use esp8266::Esp8266; @@ -139,6 +141,17 @@ impl Chip { Chip::Esp8266 => Esp8266::SPI_REGISTERS, } } + + pub fn ram_target(&self) -> Box { + Box::new(RamTarget::new()) + } + + pub fn flash_target(&self, spi_params: SpiAttachParams) -> Box { + match self { + Chip::Esp8266 => Box::new(Esp8266Target::new()), + _ => Box::new(Esp32Target::new(*self, spi_params)), + } + } } impl FromStr for Chip { diff --git a/espflash/src/flash_target/esp32.rs b/espflash/src/flash_target/esp32.rs new file mode 100644 index 00000000..b93872b8 --- /dev/null +++ b/espflash/src/flash_target/esp32.rs @@ -0,0 +1,109 @@ +use crate::connection::Connection; +use crate::elf::{FirmwareImage, RomSegment}; +use crate::error::Error; +use crate::flash_target::{begin_command, block_command_with_timeout, FlashTarget}; +use crate::flasher::{Command, SpiAttachParams, FLASH_SECTOR_SIZE, FLASH_WRITE_SIZE}; +use crate::Chip; +use flate2::write::{ZlibDecoder, ZlibEncoder}; +use flate2::Compression; +use indicatif::{ProgressBar, ProgressStyle}; +use std::io::Write; + +pub struct Esp32Target { + chip: Chip, + spi_attach_params: SpiAttachParams, +} + +impl Esp32Target { + pub fn new(chip: Chip, spi_attach_params: SpiAttachParams) -> Self { + Esp32Target { + chip, + spi_attach_params, + } + } +} + +impl FlashTarget for Esp32Target { + fn begin(&mut self, connection: &mut Connection, _image: &FirmwareImage) -> Result<(), Error> { + let spi_params = self.spi_attach_params.encode(); + connection.with_timeout(Command::SpiAttach.timeout(), |connection| { + connection.command(Command::SpiAttach as u8, spi_params.as_slice(), 0) + })?; + Ok(()) + } + + fn write_segment( + &mut self, + connection: &mut Connection, + segment: RomSegment, + ) -> Result<(), Error> { + let addr = segment.addr; + let mut encoder = ZlibEncoder::new(Vec::new(), Compression::best()); + encoder.write_all(&segment.data)?; + let compressed = encoder.finish()?; + let block_count = (compressed.len() + FLASH_WRITE_SIZE - 1) / FLASH_WRITE_SIZE; + let erase_count = (segment.data.len() + FLASH_SECTOR_SIZE - 1) / FLASH_SECTOR_SIZE; + + // round up to sector size + let erase_size = (erase_count * FLASH_SECTOR_SIZE) as u32; + + begin_command( + connection, + Command::FlashDeflateBegin, + erase_size, + block_count as u32, + FLASH_WRITE_SIZE as u32, + addr, + self.chip != Chip::Esp32, + )?; + + let chunks = compressed.chunks(FLASH_WRITE_SIZE); + + let (_, chunk_size) = chunks.size_hint(); + let chunk_size = chunk_size.unwrap_or(0) as u64; + let pb_chunk = ProgressBar::new(chunk_size); + pb_chunk.set_style( + ProgressStyle::default_bar() + .template("[{elapsed_precise}] {bar:40.cyan/blue} {pos:>7}/{len:7} {msg}") + .progress_chars("#>-"), + ); + + // decode the chunks to see how much data the device will have to save + let mut decoder = ZlibDecoder::new(Vec::new()); + let mut decoded_size = 0; + + for (i, block) in chunks.enumerate() { + decoder.write_all(block)?; + decoder.flush()?; + let size = decoder.get_ref().len() - decoded_size; + decoded_size = decoder.get_ref().len(); + + pb_chunk.set_message(format!("segment 0x{:X} writing chunks", addr)); + block_command_with_timeout( + connection, + Command::FlashDeflateData, + block, + 0, + 0xff, + i as u32, + Command::FlashDeflateData.timeout_for_size(size as u32), + )?; + pb_chunk.inc(1); + } + + pb_chunk.finish_with_message(format!("segment 0x{:X}", addr)); + + Ok(()) + } + + fn finish(&mut self, connection: &mut Connection, reboot: bool) -> Result<(), Error> { + connection.with_timeout(Command::FlashDeflateEnd.timeout(), |connection| { + connection.write_command(Command::FlashDeflateEnd as u8, &[1][..], 0) + })?; + if reboot { + connection.reset() + } else { + Ok(()) + } + } +} diff --git a/espflash/src/flash_target/esp8266.rs b/espflash/src/flash_target/esp8266.rs new file mode 100644 index 00000000..da81c69d --- /dev/null +++ b/espflash/src/flash_target/esp8266.rs @@ -0,0 +1,89 @@ +use crate::connection::Connection; +use crate::elf::{FirmwareImage, RomSegment}; +use crate::error::Error; +use crate::flash_target::{begin_command, block_command, FlashTarget}; +use crate::flasher::{get_erase_size, Command, FLASH_WRITE_SIZE}; +use indicatif::{ProgressBar, ProgressStyle}; + +pub struct Esp8266Target; + +impl Esp8266Target { + pub fn new() -> Self { + Esp8266Target + } +} + +impl FlashTarget for Esp8266Target { + fn begin(&mut self, connection: &mut Connection, _image: &FirmwareImage) -> Result<(), Error> { + begin_command( + connection, + Command::FlashBegin, + 0, + 0, + FLASH_WRITE_SIZE as u32, + 0, + false, + ) + } + + fn write_segment( + &mut self, + connection: &mut Connection, + segment: RomSegment, + ) -> Result<(), Error> { + let addr = segment.addr; + let block_count = (segment.data.len() + FLASH_WRITE_SIZE - 1) / FLASH_WRITE_SIZE; + + let erase_size = get_erase_size(addr as usize, segment.data.len()) as u32; + + begin_command( + connection, + Command::FlashBegin, + erase_size, + block_count as u32, + FLASH_WRITE_SIZE as u32, + addr, + false, + )?; + + let chunks = segment.data.chunks(FLASH_WRITE_SIZE); + + let (_, chunk_size) = chunks.size_hint(); + let chunk_size = chunk_size.unwrap_or(0) as u64; + let pb_chunk = ProgressBar::new(chunk_size); + pb_chunk.set_style( + ProgressStyle::default_bar() + .template("[{elapsed_precise}] {bar:40.cyan/blue} {pos:>7}/{len:7} {msg}") + .progress_chars("#>-"), + ); + + for (i, block) in chunks.enumerate() { + pb_chunk.set_message(format!("segment 0x{:X} writing chunks", addr)); + let block_padding = FLASH_WRITE_SIZE - block.len(); + block_command( + connection, + Command::FlashData, + block, + block_padding, + 0xff, + i as u32, + )?; + pb_chunk.inc(1); + } + + pb_chunk.finish_with_message(format!("segment 0x{:X}", addr)); + + Ok(()) + } + + fn finish(&mut self, connection: &mut Connection, reboot: bool) -> Result<(), Error> { + connection.with_timeout(Command::FlashEnd.timeout(), |connection| { + connection.write_command(Command::FlashEnd as u8, &[1][..], 0) + })?; + if reboot { + connection.reset() + } else { + Ok(()) + } + } +} diff --git a/espflash/src/flash_target/mod.rs b/espflash/src/flash_target/mod.rs new file mode 100644 index 00000000..6d538d9b --- /dev/null +++ b/espflash/src/flash_target/mod.rs @@ -0,0 +1,135 @@ +mod esp32; +mod esp8266; +mod ram; + +use crate::connection::Connection; +use crate::elf::{FirmwareImage, RomSegment}; +use crate::error::Error; +use crate::flasher::{checksum, Command, Encoder, CHECKSUM_INIT, FLASH_WRITE_SIZE}; +use bytemuck::{bytes_of, Pod, Zeroable}; +pub use esp32::Esp32Target; +pub use esp8266::Esp8266Target; +pub use ram::RamTarget; +use std::mem::size_of; +use std::time::Duration; + +pub trait FlashTarget { + fn begin(&mut self, connection: &mut Connection, image: &FirmwareImage) -> Result<(), Error>; + fn write_segment( + &mut self, + connection: &mut Connection, + segment: RomSegment, + ) -> Result<(), Error>; + fn finish(&mut self, connection: &mut Connection, reboot: bool) -> Result<(), Error>; +} + +#[derive(Zeroable, Pod, Copy, Clone, Debug)] +#[repr(C)] +struct BeginParams { + size: u32, + blocks: u32, + block_size: u32, + offset: u32, + encrypted: u32, +} + +fn begin_command( + connection: &mut Connection, + command: Command, + size: u32, + blocks: u32, + block_size: u32, + offset: u32, + supports_encrypted: bool, +) -> Result<(), Error> { + let params = BeginParams { + size, + blocks, + block_size, + offset, + encrypted: 0, + }; + + let bytes = bytes_of(¶ms); + let data = if !supports_encrypted { + // The ESP32 and ESP8266 do not take the `encrypted` field, so truncate the last + // 4 bytes of the slice where it resides. + let end = bytes.len() - 4; + &bytes[0..end] + } else { + bytes + }; + + connection.with_timeout(command.timeout_for_size(size), |connection| { + connection.command(command as u8, data, 0)?; + Ok(()) + }) +} + +#[derive(Zeroable, Pod, Copy, Clone, Debug)] +#[repr(C)] +struct BlockParams { + size: u32, + sequence: u32, + dummy1: u32, + dummy2: u32, +} + +fn block_command( + connection: &mut Connection, + command: Command, + data: &[u8], + padding: usize, + padding_byte: u8, + sequence: u32, +) -> Result<(), Error> { + block_command_with_timeout( + connection, + command, + data, + padding, + padding_byte, + sequence, + command.timeout_for_size(data.len() as u32), + ) +} + +fn block_command_with_timeout( + connection: &mut Connection, + command: Command, + data: &[u8], + padding: usize, + padding_byte: u8, + sequence: u32, + timout: Duration, +) -> Result<(), Error> { + let params = BlockParams { + size: (data.len() + padding) as u32, + sequence, + dummy1: 0, + dummy2: 0, + }; + + let length = size_of::() + data.len() + padding; + + let mut check = checksum(data, CHECKSUM_INIT); + + for _ in 0..padding { + check = checksum(&[padding_byte], check); + } + + connection.with_timeout(timout, |connection| { + connection.command( + command as u8, + (length as u16, |encoder: &mut Encoder| { + encoder.write(bytes_of(¶ms))?; + encoder.write(data)?; + let padding = &[padding_byte; FLASH_WRITE_SIZE][0..padding]; + encoder.write(padding)?; + Ok(()) + }), + check as u32, + )?; + Ok(()) + }) +} diff --git a/espflash/src/flash_target/ram.rs b/espflash/src/flash_target/ram.rs new file mode 100644 index 00000000..7e46932f --- /dev/null +++ b/espflash/src/flash_target/ram.rs @@ -0,0 +1,82 @@ +use crate::connection::Connection; +use crate::elf::{FirmwareImage, RomSegment}; +use crate::error::Error; +use crate::flash_target::{begin_command, block_command, FlashTarget}; +use crate::flasher::Command; +use bytemuck::{bytes_of, Pod, Zeroable}; + +#[derive(Zeroable, Pod, Copy, Clone)] +#[repr(C)] +struct EntryParams { + no_entry: u32, + entry: u32, +} + +pub struct RamTarget { + entry: Option, +} + +impl RamTarget { + pub fn new() -> Self { + RamTarget { entry: None } + } +} + +impl FlashTarget for RamTarget { + fn begin(&mut self, _connection: &mut Connection, image: &FirmwareImage) -> Result<(), Error> { + self.entry = Some(image.entry()); + Ok(()) + } + + fn write_segment( + &mut self, + connection: &mut Connection, + segment: RomSegment, + ) -> Result<(), Error> { + const MAX_RAM_BLOCK_SIZE: usize = 0x1800; + + let padding = 4 - segment.data.len() % 4; + let block_count = + (segment.data.len() + padding + MAX_RAM_BLOCK_SIZE - 1) / MAX_RAM_BLOCK_SIZE; + + begin_command( + connection, + Command::MemBegin, + segment.data.len() as u32, + block_count as u32, + MAX_RAM_BLOCK_SIZE as u32, + segment.addr, + false, + )?; + + for (i, block) in segment.data.chunks(MAX_RAM_BLOCK_SIZE).enumerate() { + let block_padding = if i == block_count - 1 { padding } else { 0 }; + block_command( + connection, + Command::MemData, + block, + block_padding, + 0, + i as u32, + )?; + } + Ok(()) + } + + fn finish(&mut self, connection: &mut Connection, reboot: bool) -> Result<(), Error> { + if reboot { + let params = match self.entry { + Some(entry) if entry > 0 => EntryParams { no_entry: 0, entry }, + _ => EntryParams { + no_entry: 1, + entry: 0, + }, + }; + connection.with_timeout(Command::MemEnd.timeout(), |connection| { + connection.write_command(Command::MemEnd as u8, bytes_of(¶ms), 0) + }) + } else { + Ok(()) + } + } +} diff --git a/espflash/src/flasher.rs b/espflash/src/flasher.rs index 018b334c..c3612cbf 100644 --- a/espflash/src/flasher.rs +++ b/espflash/src/flasher.rs @@ -1,22 +1,21 @@ use bytemuck::{__core::time::Duration, bytes_of, Pod, Zeroable}; -use indicatif::{ProgressBar, ProgressStyle}; use serial::{BaudRate, SerialPort}; use strum_macros::Display; -use std::{mem::size_of, thread::sleep}; +use std::thread::sleep; +use crate::elf::RomSegment; use crate::{ chip::Chip, connection::Connection, elf::FirmwareImage, encoder::SlipEncoder, error::RomError, Error, PartitionTable, }; -type Encoder<'a> = SlipEncoder<'a, Box>; +pub(crate) type Encoder<'a> = SlipEncoder<'a, Box>; -const MAX_RAM_BLOCK_SIZE: usize = 0x1800; -const FLASH_SECTOR_SIZE: usize = 0x1000; +pub(crate) const FLASH_SECTOR_SIZE: usize = 0x1000; const FLASH_BLOCK_SIZE: usize = 0x100; const FLASH_SECTORS_PER_BLOCK: usize = FLASH_SECTOR_SIZE / FLASH_BLOCK_SIZE; -const FLASH_WRITE_SIZE: usize = 0x400; +pub(crate) const FLASH_WRITE_SIZE: usize = 0x400; // register used for chip detect const CHIP_DETECT_MAGIC_REG_ADDR: u32 = 0x40001000; @@ -30,7 +29,7 @@ const SYNC_TIMEOUT: Duration = Duration::from_millis(100); #[derive(Copy, Clone, Debug)] #[allow(dead_code)] #[repr(u8)] -enum Command { +pub(crate) enum Command { FlashBegin = 0x02, FlashData = 0x03, FlashEnd = 0x04, @@ -43,6 +42,10 @@ enum Command { SpiSetParams = 0x0B, SpiAttach = 0x0D, ChangeBaud = 0x0F, + FlashDeflateBegin = 0x10, + FlashDeflateData = 0x11, + FlashDeflateEnd = 0x12, + FlashMd5 = 0x13, } impl Command { @@ -64,7 +67,9 @@ impl Command { } match self { Command::FlashBegin => calc_timeout(ERASE_REGION_TIMEOUT_PER_MB, size), - Command::FlashData => calc_timeout(ERASE_WRITE_TIMEOUT_PER_MB, size), + Command::FlashData | Command::FlashDeflateData => { + calc_timeout(ERASE_WRITE_TIMEOUT_PER_MB, size) + } _ => self.timeout(), } } @@ -109,7 +114,7 @@ impl FlashSize { #[derive(Copy, Clone)] #[repr(C)] -struct SpiAttachParams { +pub struct SpiAttachParams { clk: u8, q: u8, d: u8, @@ -341,64 +346,6 @@ impl Flasher { }) } - fn block_command( - &mut self, - command: Command, - data: &[u8], - padding: usize, - padding_byte: u8, - sequence: u32, - ) -> Result<(), Error> { - let params = BlockParams { - size: (data.len() + padding) as u32, - sequence, - dummy1: 0, - dummy2: 0, - }; - - let length = size_of::() + data.len() + padding; - - let mut check = checksum(data, CHECKSUM_INIT); - - for _ in 0..padding { - check = checksum(&[padding_byte], check); - } - - self.connection - .with_timeout(command.timeout_for_size(data.len() as u32), |connection| { - connection.command( - command as u8, - (length as u16, |encoder: &mut Encoder| { - encoder.write(bytes_of(¶ms))?; - encoder.write(data)?; - let padding = &[padding_byte; FLASH_WRITE_SIZE][0..padding]; - encoder.write(padding)?; - Ok(()) - }), - check as u32, - )?; - Ok(()) - }) - } - - fn mem_finish(&mut self, entry: u32) -> Result<(), Error> { - let params = EntryParams { - no_entry: (entry == 0) as u32, - entry, - }; - self.connection - .with_timeout(Command::MemEnd.timeout(), |connection| { - connection.write_command(Command::MemEnd as u8, bytes_of(¶ms), 0) - }) - } - - fn flash_finish(&mut self, reboot: bool) -> Result<(), Error> { - self.connection - .with_timeout(Command::FlashEnd.timeout(), |connection| { - connection.write_command(Command::FlashEnd as u8, &[(!reboot) as u8][..], 0) - }) - } - fn enable_flash(&mut self, spi_attach_params: SpiAttachParams) -> Result<(), Error> { match self.chip { Chip::Esp8266 => { @@ -523,31 +470,24 @@ impl Flasher { pub fn load_elf_to_ram(&mut self, elf_data: &[u8]) -> Result<(), Error> { let image = FirmwareImage::from_data(elf_data).map_err(|_| Error::InvalidElf)?; + let mut target = self.chip.ram_target(); + target.begin(&mut self.connection, &image)?; + if image.rom_segments(self.chip).next().is_some() { return Err(Error::ElfNotRamLoadable); } for segment in image.ram_segments(self.chip) { - let padding = 4 - segment.data.len() % 4; - let block_count = - (segment.data.len() + padding + MAX_RAM_BLOCK_SIZE - 1) / MAX_RAM_BLOCK_SIZE; - self.begin_command( - Command::MemBegin, - segment.data.len() as u32, - block_count as u32, - MAX_RAM_BLOCK_SIZE as u32, - segment.addr, + target.write_segment( + &mut self.connection, + RomSegment { + addr: segment.addr, + data: segment.data.into(), + }, )?; - - for (i, block) in segment.data.chunks(MAX_RAM_BLOCK_SIZE).enumerate() { - let block_padding = if i == block_count - 1 { padding } else { 0 }; - self.block_command(Command::MemData, block, block_padding, 0, i as u32)?; - } } - self.mem_finish(image.entry())?; - - Ok(()) + target.finish(&mut self.connection, true) } /// Load an elf image to flash and execute it @@ -557,56 +497,20 @@ impl Flasher { bootloader: Option>, partition_table: Option, ) -> Result<(), Error> { - self.enable_flash(self.spi_params)?; - let mut image = FirmwareImage::from_data(elf_data).map_err(|_| Error::InvalidElf)?; image.flash_size = self.flash_size(); + let mut target = self.chip.flash_target(self.spi_params); + target.begin(&mut self.connection, &image)?; + for segment in self .chip .get_flash_segments(&image, bootloader, partition_table) { - let segment = segment?; - let addr = segment.addr; - let block_count = (segment.data.len() + FLASH_WRITE_SIZE - 1) / FLASH_WRITE_SIZE; - - let erase_size = match self.chip { - Chip::Esp8266 => get_erase_size(addr as usize, segment.data.len()) as u32, - _ => segment.data.len() as u32, - }; - - self.begin_command( - Command::FlashBegin, - erase_size, - block_count as u32, - FLASH_WRITE_SIZE as u32, - addr, - )?; - - let chunks = segment.data.chunks(FLASH_WRITE_SIZE); - - let (_, chunk_size) = chunks.size_hint(); - let chunk_size = chunk_size.unwrap_or(0) as u64; - let pb_chunk = ProgressBar::new(chunk_size); - pb_chunk.set_style( - ProgressStyle::default_bar() - .template("[{elapsed_precise}] {bar:40.cyan/blue} {pos:>7}/{len:7} {msg}") - .progress_chars("#>-"), - ); - - for (i, block) in chunks.enumerate() { - pb_chunk.set_message(format!("segment 0x{:X} writing chunks", addr)); - let block_padding = FLASH_WRITE_SIZE - block.len(); - self.block_command(Command::FlashData, block, block_padding, 0xff, i as u32)?; - pb_chunk.inc(1); - } - - pb_chunk.finish_with_message(format!("segment 0x{:X}", addr)); + target.write_segment(&mut self.connection, segment?)?; } - self.flash_finish(false)?; - - self.connection.reset()?; + target.finish(&mut self.connection, true)?; Ok(()) } @@ -630,7 +534,7 @@ impl Flasher { } } -fn get_erase_size(offset: usize, size: usize) -> usize { +pub(crate) fn get_erase_size(offset: usize, size: usize) -> usize { let sector_count = (size + FLASH_SECTOR_SIZE - 1) / FLASH_SECTOR_SIZE; let start_sector = offset / FLASH_SECTOR_SIZE; @@ -646,7 +550,7 @@ fn get_erase_size(offset: usize, size: usize) -> usize { } } -const CHECKSUM_INIT: u8 = 0xEF; +pub(crate) const CHECKSUM_INIT: u8 = 0xEF; pub fn checksum(data: &[u8], mut checksum: u8) -> u8 { for byte in data { diff --git a/espflash/src/lib.rs b/espflash/src/lib.rs index 295ff78f..4f6237d0 100644 --- a/espflash/src/lib.rs +++ b/espflash/src/lib.rs @@ -4,6 +4,7 @@ mod connection; mod elf; mod encoder; mod error; +mod flash_target; mod flasher; mod partition_table;