forked from haakonn/szbase32
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
3 changed files
with
323 additions
and
78 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,95 +1,170 @@ | ||
package com.sandinh.szbase32 | ||
|
||
// TODO: don't break on invalid z-base-32 input. | ||
|
||
import java.io.ByteArrayOutputStream | ||
import java.io.OutputStream | ||
|
||
/** | ||
* Implements the <a href="http://philzimmermann.com/docs/human-oriented-base-32-encoding.txt"> | ||
* z-base-32 encoding</a>. | ||
* | ||
* @author Haakon Nilsen | ||
*/ | ||
* Implements the <a href="http://philzimmermann.com/docs/human-oriented-base-32-encoding.txt"> | ||
* z-base-32 encoding</a>. | ||
* @author Haakon Nilsen, Bui Viet Thanh | ||
*/ | ||
object ZBase32 { | ||
private[this] val encTable = "ybndrfg8ejkmcpqxot1uwisza345h769".toArray | ||
/** Computed by running the following code: | ||
* {{{ | ||
* def toDecTable(e: String): Array[Byte] = { | ||
* val t = Array.fill[Byte](128)(-1) | ||
* def fill(e: String) = e.zipWithIndex.foreach { | ||
* case (c, i) => t(c) = i.toByte | ||
* } | ||
* fill(e.toLowerCase) | ||
* fill(e.toUpperCase) | ||
* t.reverse.dropWhile(_ == -1).reverse | ||
* } | ||
* def fmt(v: Byte): String = { | ||
* var s = s"$v," | ||
* while(s.length < 4) s = " " + s | ||
* s | ||
* } | ||
* def pretty(t: Array[Byte]) = t.zipWithIndex.foreach { | ||
* case (v, i) => | ||
* print(fmt(v)) | ||
* if (i % 16 == 15) println() | ||
* } | ||
* val encTable = "ybndrfg8ejkmcpqxot1uwisza345h769" | ||
* pretty(toDecTable(encTable)) | ||
* }}} | ||
*/ | ||
private[this] val decTable = Array[Byte]( | ||
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, | ||
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, | ||
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, | ||
-1, 18, -1, 25, 26, 27, 30, 29, 7, 31, -1, -1, -1, -1, -1, -1, | ||
-1, 24, 1, 12, 3, 8, 5, 6, 28, 21, 9, 10, -1, 11, 2, 16, | ||
13, 14, 4, 22, 17, 19, -1, 20, 15, 0, 23, -1, -1, -1, -1, -1, | ||
-1, 24, 1, 12, 3, 8, 5, 6, 28, 21, 9, 10, -1, 11, 2, 16, | ||
13, 14, 4, 22, 17, 19, -1, 20, 15, 0, 23 | ||
) | ||
|
||
val encTable = "ybndrfg8ejkmcpqxot1uwisza345h769".toList | ||
val decTable = encTable.map(c => (c, encTable indexOf c)).toMap | ||
/** | ||
* BASE32 characters are 5 bits in length. | ||
* They are formed by taking a block of five octets to form a 40-bit string, | ||
* which is converted into eight BASE32 characters. | ||
*/ | ||
private[this] val BitsPerEncodedByte = 5 | ||
private[this] val BytesPerEncodedBlock = 8 | ||
private[this] val BytesPerUnencodedBlock = 5 | ||
/** Mask used to extract 8 bits, used in decoding bytes */ | ||
private[this] val Mask8Bit = 0xff | ||
/** Mask used to extract 5 bits, used when encoding Base32 bytes */ | ||
private[this] val Mask5Bit = 0x1f | ||
|
||
def encode(src: Seq[Byte]) = { | ||
val oddLength = src.size % 5 | ||
val evenLength = src.size - oddLength | ||
val buf = new StringBuilder | ||
def d(x : Int) = src(x) & 0xff | ||
def add(b: Int) = buf += encTable(b & 0x1f) | ||
for (i <- 0 until evenLength by 5) { | ||
val b1 = d(i) | ||
val b2 = d(i + 1) | ||
val b3 = d(i + 2) | ||
val b4 = d(i + 3) | ||
val b5 = d(i + 4) | ||
add(b1 >> 3) | ||
add((b1 << 2) | (b2 >> 6)) | ||
add(b2 >> 1) | ||
add((b2 << 4) | (b3 >> 4)) | ||
add((b3 << 1) | (b4 >> 7)) | ||
add(b4 >> 2) | ||
add((b4 << 3) | (b5 >> 5)) | ||
add(b5) | ||
} | ||
if (oddLength > 0) { | ||
val b1 = d(evenLength) | ||
lazy val b2 = d(evenLength + 1) | ||
lazy val b3 = d(evenLength + 2) | ||
add(b1 >> 3) | ||
if (oddLength == 1) add(b1 << 2) | ||
if (oddLength > 1) { | ||
add((b1 << 2) | (b2 >> 6)) | ||
add(b2 >> 1) | ||
} | ||
if (oddLength == 2) add(b2 << 4) | ||
if (oddLength > 2) add((b2 << 4) | (b3 >> 4)) | ||
if (oddLength == 3) add(b3 << 1) | ||
if (oddLength == 4) { | ||
val b4 = d(evenLength + 3) | ||
add((b3 << 1) | (b4 >> 7)) | ||
add(b4 >> 2) | ||
add(b4 << 3) | ||
/** Encodes a byte[] containing binary data, into a z-base-32 string */ | ||
def encode(in: Array[Byte]): String = { | ||
val out = new StringBuilder | ||
// Inside the loop, output is only appended after every 5th input byte (one full unencoded block). | ||
// This variable helps track that. | ||
var modulus = 0 | ||
// Place holder for the bytes we're dealing with for our encoding logic. | ||
// Bitwise operations store and extract the encoding from this variable. | ||
var lbitWorkArea = 0L | ||
for (b <- in) { | ||
modulus = (modulus + 1) % BytesPerUnencodedBlock | ||
lbitWorkArea = (lbitWorkArea << 8) + b // BitPerByte | ||
if (b < 0) lbitWorkArea += 256 | ||
if (modulus == 0) { // we have enough bytes to create our output | ||
for (i <- 35 to 0 by -5) { | ||
out += encTable((lbitWorkArea >> i).toInt & Mask5Bit) | ||
} | ||
} | ||
} | ||
buf.toString | ||
|
||
modulus match { | ||
case 1 => // Only 1 octet; take top 5 bits then remainder | ||
out += encTable((lbitWorkArea >> 3).toInt & Mask5Bit) // 8-1*5 = 3 | ||
out += encTable((lbitWorkArea << 2).toInt & Mask5Bit) // 5-3=2 | ||
case 2 => // 2 octets = 16 bits to use | ||
out += encTable((lbitWorkArea >> 11).toInt & Mask5Bit) // 16-1*5 = 11 | ||
out += encTable((lbitWorkArea >> 6).toInt & Mask5Bit) // 16-2*5 = 6 | ||
out += encTable((lbitWorkArea >> 1).toInt & Mask5Bit) // 16-3*5 = 1 | ||
out += encTable((lbitWorkArea << 4).toInt & Mask5Bit) // 5-1 = 4 | ||
case 3 => // 3 octets = 24 bits to use | ||
out += encTable((lbitWorkArea >> 19).toInt & Mask5Bit) | ||
out += encTable((lbitWorkArea >> 14).toInt & Mask5Bit) | ||
out += encTable((lbitWorkArea >> 9).toInt & Mask5Bit) | ||
out += encTable((lbitWorkArea >> 4).toInt & Mask5Bit) | ||
out += encTable((lbitWorkArea << 1).toInt & Mask5Bit) | ||
case 4 => // 4 octets = 32 bits to use | ||
out += encTable((lbitWorkArea >> 27).toInt & Mask5Bit) | ||
out += encTable((lbitWorkArea >> 22).toInt & Mask5Bit) | ||
out += encTable((lbitWorkArea >> 17).toInt & Mask5Bit) | ||
out += encTable((lbitWorkArea >> 12).toInt & Mask5Bit) | ||
out += encTable((lbitWorkArea >> 7).toInt & Mask5Bit) | ||
out += encTable((lbitWorkArea >> 2).toInt & Mask5Bit) | ||
out += encTable((lbitWorkArea << 3).toInt & Mask5Bit) | ||
case _ => // case 0 => no leftovers to process | ||
} | ||
|
||
out.result() | ||
} | ||
def decode(src: String) = { | ||
/*We allow whitespace and dashes in the z-base-32 input, but remove | ||
* it before decoding:*/ | ||
val d = src.replaceAll("([\\s-])", "") | ||
|
||
/** Decodes a String containing characters in the z-base-32 alphabet. | ||
* @note Silently ignores every character `c` in `in` that is not in the z-base-32 alphabet (upper or lower case). | ||
* A character is kept exactly when `0 <= c < decTable.length && decTable(c) >= 0`. */ | ||
def decode(in: String): Array[Byte] = { | ||
val out = new ByteArrayOutputStream | ||
val oddLength = d.size % 8 | ||
val evenLength = d.size - oddLength - 1 | ||
for (i <- 0 to evenLength by 8) { | ||
val b = Range(0, 8).map(j => j -> decTable(d(i + j))).toMap | ||
out.write((b(0) << 3) | (b(1) >> 2)) | ||
out.write((b(1) << 6) | (b(2) << 1) | (b(3) >> 4)) | ||
out.write((b(3) << 4) | (b(4) >> 1)) | ||
out.write((b(4) << 7) | (b(5) << 2) | (b(6) >> 3)) | ||
out.write((b(6) << 5) | b(7)) | ||
} | ||
if (oddLength > 1) { // oddLength ∈ {2,4,5,7} | ||
val e = evenLength + 1 | ||
def b(x: Int) = decTable(d(x)) | ||
out.write((b(e) << 3) | (b(e + 1) >> 2)) | ||
if (oddLength > 3) { | ||
out.write((b(e + 1) << 6) | (b(e + 2) << 1) | (b(e + 3) >> 4)) | ||
} | ||
if (oddLength > 4) { | ||
out.write((b(e + 3) << 4) | (b(e + 4) >> 1)) | ||
} | ||
if (oddLength > 6) { | ||
out.write((b(e + 4) << 7) | (b(e + 5) << 2) | (b(e + 6) >> 3)) | ||
// Inside the loop, output is only written after every 8th accepted character (one full encoded block). | ||
// This variable helps track that. | ||
var modulus = 0 | ||
// Place holder for the bytes we're dealing with for our decoding logic. | ||
// Bitwise operations store and extract the decoding from this variable. | ||
var lbitWorkArea = 0L | ||
for (c <- in) { | ||
if (c >= 0 && c < decTable.length) { | ||
val i = decTable(c) | ||
if (i >= 0) { | ||
modulus = (modulus + 1) % BytesPerEncodedBlock | ||
// collect decoded bytes | ||
lbitWorkArea = (lbitWorkArea << BitsPerEncodedByte) + i | ||
if (modulus == 0) { // we can output the 5 bytes | ||
out write ((lbitWorkArea >> 32) & Mask8Bit).toByte | ||
out write ((lbitWorkArea >> 24) & Mask8Bit).toByte | ||
out write ((lbitWorkArea >> 16) & Mask8Bit).toByte | ||
out write ((lbitWorkArea >> 8) & Mask8Bit).toByte | ||
out write (lbitWorkArea & Mask8Bit).toByte | ||
} | ||
} | ||
} | ||
} | ||
|
||
// Flush an incomplete final block: write only the whole bytes it holds and drop the leftover bits | ||
modulus match { | ||
case 2 => // 10 bits, drop 2 and output 1 byte | ||
out write ((lbitWorkArea >> 2) & Mask8Bit).toByte | ||
case 3 => // 15 bits, drop 7 and output 1 byte | ||
out write ((lbitWorkArea >> 7) & Mask8Bit).toByte | ||
case 4 => // 20 bits = 2*8 + 4 | ||
lbitWorkArea >>= 4 // drop 4 bits | ||
out write ((lbitWorkArea >> 8) & Mask8Bit).toByte | ||
out write (lbitWorkArea & Mask8Bit).toByte | ||
case 5 => // 25bits = 3*8 + 1 | ||
lbitWorkArea >>= 1 | ||
out write ((lbitWorkArea >> 16) & Mask8Bit).toByte | ||
out write ((lbitWorkArea >> 8) & Mask8Bit).toByte | ||
out write (lbitWorkArea & Mask8Bit).toByte | ||
case 6 => // 30bits = 3*8 + 6 | ||
lbitWorkArea >>= 6 | ||
out write ((lbitWorkArea >> 16) & Mask8Bit).toByte | ||
out write ((lbitWorkArea >> 8) & Mask8Bit).toByte | ||
out write (lbitWorkArea & Mask8Bit).toByte | ||
case 7 => // 35 = 4*8 +3 | ||
lbitWorkArea >>= 3 | ||
out write ((lbitWorkArea >> 24) & Mask8Bit).toByte | ||
out write ((lbitWorkArea >> 16) & Mask8Bit).toByte | ||
out write ((lbitWorkArea >> 8) & Mask8Bit).toByte | ||
out write (lbitWorkArea & Mask8Bit).toByte | ||
case _ => // if modulus < 2, nothing to do | ||
} | ||
|
||
out.toByteArray | ||
} | ||
|
||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
package com.sandinh.szbase32 | ||
|
||
import org.scalatest.{FlatSpec, Matchers} | ||
import szbase32.{ZBase32 => Old} | ||
|
||
import scala.util.Random | ||
import ZBase32._ | ||
|
||
import scala.annotation.tailrec | ||
import scala.collection.immutable.Stream | ||
|
||
class ZBase32Spec extends FlatSpec with Matchers { | ||
"ZBase32" should "bytes -> encode -> decode == bytes, include case bytes.length == 0" in { | ||
def test(bytesLen: Int) = { | ||
val bytes = new Array[Byte](bytesLen) | ||
Random.nextBytes(bytes) | ||
decode(encode(bytes)) shouldEqual bytes // should contain theSameElementsInOrderAs | ||
} | ||
test(0) | ||
for (_ <- 0 to 20) test(1 + Random.nextInt(100)) | ||
} | ||
|
||
it should | ||
"""do NOT NEED to: s -> decode -> encode == s | ||
| because there are >= 1 ways to encode a byte array.""".stripMargin in { | ||
for((s, s2, bytes) <- Seq( | ||
("rj", "re", Array[Byte](34)), | ||
("4ramr45", "4ramr4a", Array[Byte](-47,48,-78,107))) | ||
) { | ||
encode(decode(s)) shouldEqual s2 // don't need == s | ||
decode(s) shouldEqual bytes | ||
decode(s) shouldEqual decode(s2) | ||
} | ||
} | ||
|
||
private val encTbl = "ybndrfg8ejkmcpqxot1uwisza345h769" | ||
def rndChars: Stream[Char] = { | ||
def c = encTbl charAt (Random nextInt 32) | ||
Stream continually c | ||
} | ||
|
||
it should "en/decode same as in the old implementation" in { | ||
for (_ <- 0 to 20) { | ||
val bytes = new Array[Byte](Random.nextInt(100)) | ||
Random.nextBytes(bytes) | ||
encode(bytes) shouldEqual Old.encode(bytes) | ||
|
||
val s = rndChars.take(Random.nextInt(100)).mkString | ||
decode(s) shouldEqual Old.decode(s) | ||
} | ||
} | ||
|
||
it should "`decode` case-insensitive" in { | ||
for (_ <- 0 to 20) { | ||
val s = rndChars.take(Random.nextInt(100)).mkString | ||
decode(s) shouldEqual decode(s.toUpperCase()) | ||
} | ||
} | ||
|
||
it should "`decode` don't break on invalid z-base-32 input" in { | ||
def invalidChar(c: Char) = !encTbl.contains(c.toLower) | ||
@tailrec def invalidInput(): String = { | ||
val r = Random.alphanumeric.take(Random.nextInt(100)) | ||
if (r exists invalidChar) r.mkString | ||
else invalidInput() | ||
} | ||
|
||
for (_ <- 0 to 20) { | ||
val s = invalidInput() | ||
// println(s) | ||
// println(s.map(c => if (invalidChar(c)) " " else c).mkString) | ||
decode(s) shouldEqual decode(s filterNot invalidChar) | ||
} | ||
} | ||
} |
Oops, something went wrong.