Skip to content

Commit

Permalink
add some test cases
Browse files Browse the repository at this point in the history
  • Loading branch information
giabao committed Apr 28, 2018
1 parent 7dccaa6 commit c076eb6
Show file tree
Hide file tree
Showing 3 changed files with 323 additions and 78 deletions.
231 changes: 153 additions & 78 deletions szbase32/src/com/sandinh/szbase32/ZBase32.scala
@@ -1,95 +1,170 @@
package com.sandinh.szbase32

// TODO: don't break on invalid z-base-32 input.

import java.io.ByteArrayOutputStream
import java.io.OutputStream

/**
* Implements the <a href="http://philzimmermann.com/docs/human-oriented-base-32-encoding.txt">
* z-base-32 encoding</a>.
*
* @author Haakon Nilsen
*/
* Implements the <a href="http://philzimmermann.com/docs/human-oriented-base-32-encoding.txt">
* z-base-32 encoding</a>.
* @author Haakon Nilsen, Bui Viet Thanh
*/
object ZBase32 {
private[this] val encTable = "ybndrfg8ejkmcpqxot1uwisza345h769".toArray
/** compute from the following code:
* {{{
* def toDecTable(e: String): Array[Byte] = {
* val t = Array.fill[Byte](128)(-1)
* def fill(e: String) = e.zipWithIndex.foreach {
* case (c, i) => t(c) = i.toByte
* }
* fill(e.toLowerCase)
* fill(e.toUpperCase)
* t.reverse.dropWhile(_ == -1).reverse
* }
* def fmt(v: Byte): String = {
* var s = s"$v,"
* while(s.length < 4) s = " " + s
* s
* }
* def pretty(t: Array[Byte]) = t.zipWithIndex.foreach {
* case (v, i) =>
* print(fmt(v))
* if (i % 16 == 15) println()
* }
* val encTable = "ybndrfg8ejkmcpqxot1uwisza345h769"
* pretty(toDecTable(encTable))
* }}}
*/
private[this] val decTable = Array[Byte](
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, 18, -1, 25, 26, 27, 30, 29, 7, 31, -1, -1, -1, -1, -1, -1,
-1, 24, 1, 12, 3, 8, 5, 6, 28, 21, 9, 10, -1, 11, 2, 16,
13, 14, 4, 22, 17, 19, -1, 20, 15, 0, 23, -1, -1, -1, -1, -1,
-1, 24, 1, 12, 3, 8, 5, 6, 28, 21, 9, 10, -1, 11, 2, 16,
13, 14, 4, 22, 17, 19, -1, 20, 15, 0, 23
)

val encTable = "ybndrfg8ejkmcpqxot1uwisza345h769".toList
val decTable = encTable.map(c => (c, encTable indexOf c)).toMap
/**
* BASE32 characters are 5 bits in length.
* They are formed by taking a block of five octets to form a 40-bit string,
* which is converted into eight BASE32 characters.
*/
private[this] val BitsPerEncodedByte = 5
private[this] val BytesPerEncodedBlock = 8
private[this] val BytesPerUnencodedBlock = 5
/** Mask used to extract 8 bits, used in decoding bytes */
private[this] val Mask8Bit = 0xff
/** Mask used to extract 5 bits, used when encoding Base32 bytes */
private[this] val Mask5Bit = 0x1f

def encode(src: Seq[Byte]) = {
val oddLength = src.size % 5
val evenLength = src.size - oddLength
val buf = new StringBuilder
def d(x : Int) = src(x) & 0xff
def add(b: Int) = buf += encTable(b & 0x1f)
for (i <- 0 until evenLength by 5) {
val b1 = d(i)
val b2 = d(i + 1)
val b3 = d(i + 2)
val b4 = d(i + 3)
val b5 = d(i + 4)
add(b1 >> 3)
add((b1 << 2) | (b2 >> 6))
add(b2 >> 1)
add((b2 << 4) | (b3 >> 4))
add((b3 << 1) | (b4 >> 7))
add(b4 >> 2)
add((b4 << 3) | (b5 >> 5))
add(b5)
}
if (oddLength > 0) {
val b1 = d(evenLength)
lazy val b2 = d(evenLength + 1)
lazy val b3 = d(evenLength + 2)
add(b1 >> 3)
if (oddLength == 1) add(b1 << 2)
if (oddLength > 1) {
add((b1 << 2) | (b2 >> 6))
add(b2 >> 1)
}
if (oddLength == 2) add(b2 << 4)
if (oddLength > 2) add((b2 << 4) | (b3 >> 4))
if (oddLength == 3) add(b3 << 1)
if (oddLength == 4) {
val b4 = d(evenLength + 3)
add((b3 << 1) | (b4 >> 7))
add(b4 >> 2)
add(b4 << 3)
/** Encodes a byte[] containing binary data, into a z-base-32 string */
def encode(in: Array[Byte]): String = {
val out = new StringBuilder
// Writes to the buffer only occur after every 3/5 reads when encoding.
// This variable helps track that.
var modulus = 0
// Place holder for the bytes we're dealing with for our encoding logic.
// Bitwise operations store and extract the encoding from this variable.
var lbitWorkArea = 0L
for (b <- in) {
modulus = (modulus + 1) % BytesPerUnencodedBlock
lbitWorkArea = (lbitWorkArea << 8) + b // BitPerByte
if (b < 0) lbitWorkArea += 256
if (modulus == 0) { // we have enough bytes to create our output
for (i <- 35 to 0 by -5) {
out += encTable((lbitWorkArea >> i).toInt & Mask5Bit)
}
}
}
buf.toString

modulus match {
case 1 => // Only 1 octet; take top 5 bits then remainder
out += encTable((lbitWorkArea >> 3).toInt & Mask5Bit) // 8-1*5 = 3
out += encTable((lbitWorkArea << 2).toInt & Mask5Bit) // 5-3=2
case 2 => // 2 octets = 16 bits to use
out += encTable((lbitWorkArea >> 11).toInt & Mask5Bit) // 16-1*5 = 11
out += encTable((lbitWorkArea >> 6).toInt & Mask5Bit) // 16-2*5 = 6
out += encTable((lbitWorkArea >> 1).toInt & Mask5Bit) // 16-3*5 = 1
out += encTable((lbitWorkArea << 4).toInt & Mask5Bit) // 5-1 = 4
case 3 => // 3 octets = 24 bits to use
out += encTable((lbitWorkArea >> 19).toInt & Mask5Bit)
out += encTable((lbitWorkArea >> 14).toInt & Mask5Bit)
out += encTable((lbitWorkArea >> 9).toInt & Mask5Bit)
out += encTable((lbitWorkArea >> 4).toInt & Mask5Bit)
out += encTable((lbitWorkArea << 1).toInt & Mask5Bit)
case 4 => // 4 octets = 32 bits to use
out += encTable((lbitWorkArea >> 27).toInt & Mask5Bit)
out += encTable((lbitWorkArea >> 22).toInt & Mask5Bit)
out += encTable((lbitWorkArea >> 17).toInt & Mask5Bit)
out += encTable((lbitWorkArea >> 12).toInt & Mask5Bit)
out += encTable((lbitWorkArea >> 7).toInt & Mask5Bit)
out += encTable((lbitWorkArea >> 2).toInt & Mask5Bit)
out += encTable((lbitWorkArea << 3).toInt & Mask5Bit)
case _ => // case 0 => no leftovers to process
}

out.result()
}
def decode(src: String) = {
/*We allow whitespace and dashes in the z-base-32 input, but remove
* it before decoding:*/
val d = src.replaceAll("([\\s-])", "")

/** Decodes a String containing characters in the z-base-32 alphabet.
* @note Silence ignore all character `c` in `in` that `!encTable.contains(c)`.
* That condition is same as `0 <= c < decTable.length && decTable(c) >= 0` */
def decode(in: String): Array[Byte] = {
val out = new ByteArrayOutputStream
val oddLength = d.size % 8
val evenLength = d.size - oddLength - 1
for (i <- 0 to evenLength by 8) {
val b = Range(0, 8).map(j => j -> decTable(d(i + j))).toMap
out.write((b(0) << 3) | (b(1) >> 2))
out.write((b(1) << 6) | (b(2) << 1) | (b(3) >> 4))
out.write((b(3) << 4) | (b(4) >> 1))
out.write((b(4) << 7) | (b(5) << 2) | (b(6) >> 3))
out.write((b(6) << 5) | b(7))
}
if (oddLength > 1) { // oddLength ∈ {2,4,5,7}
val e = evenLength + 1
def b(x: Int) = decTable(d(x))
out.write((b(e) << 3) | (b(e + 1) >> 2))
if (oddLength > 3) {
out.write((b(e + 1) << 6) | (b(e + 2) << 1) | (b(e + 3) >> 4))
}
if (oddLength > 4) {
out.write((b(e + 3) << 4) | (b(e + 4) >> 1))
}
if (oddLength > 6) {
out.write((b(e + 4) << 7) | (b(e + 5) << 2) | (b(e + 6) >> 3))
// Writes to the buffer only occur after every 4/8 reads when decoding.
// This variable helps track that.
var modulus = 0
// Place holder for the bytes we're dealing with for our decoding logic.
// Bitwise operations store and extract the decoding from this variable.
var lbitWorkArea = 0L
for (c <- in) {
if (c >= 0 && c < decTable.length) {
val i = decTable(c)
if (i >= 0) {
modulus = (modulus + 1) % BytesPerEncodedBlock
// collect decoded bytes
lbitWorkArea = (lbitWorkArea << BitsPerEncodedByte) + i
if (modulus == 0) { // we can output the 5 bytes
out write ((lbitWorkArea >> 32) & Mask8Bit).toByte
out write ((lbitWorkArea >> 24) & Mask8Bit).toByte
out write ((lbitWorkArea >> 16) & Mask8Bit).toByte
out write ((lbitWorkArea >> 8) & Mask8Bit).toByte
out write (lbitWorkArea & Mask8Bit).toByte
}
}
}
}

// we ignore partial bytes, i.e. only multiples of 8 count
modulus match {
case 2 => // 10 bits, drop 2 and output 1 byte
out write ((lbitWorkArea >> 2) & Mask8Bit).toByte
case 3 => // 15 bits, drop 7 and output 1 byte
out write ((lbitWorkArea >> 7) & Mask8Bit).toByte
case 4 => // 20 bits = 2*8 + 4
lbitWorkArea >>= 4 // drop 4 bits
out write ((lbitWorkArea >> 8) & Mask8Bit).toByte
out write (lbitWorkArea & Mask8Bit).toByte
case 5 => // 25bits = 3*8 + 1
lbitWorkArea >>= 1
out write ((lbitWorkArea >> 16) & Mask8Bit).toByte
out write ((lbitWorkArea >> 8) & Mask8Bit).toByte
out write (lbitWorkArea & Mask8Bit).toByte
case 6 => // 30bits = 3*8 + 6
lbitWorkArea >>= 6
out write ((lbitWorkArea >> 16) & Mask8Bit).toByte
out write ((lbitWorkArea >> 8) & Mask8Bit).toByte
out write (lbitWorkArea & Mask8Bit).toByte
case 7 => // 35 = 4*8 +3
lbitWorkArea >>= 3
out write ((lbitWorkArea >> 24) & Mask8Bit).toByte
out write ((lbitWorkArea >> 16) & Mask8Bit).toByte
out write ((lbitWorkArea >> 8) & Mask8Bit).toByte
out write (lbitWorkArea & Mask8Bit).toByte
case _ => // if modulus < 2, nothing to do
}

out.toByteArray
}

}
75 changes: 75 additions & 0 deletions szbase32/test/src/com/sandinh/szbase32/ZBase32Spec.scala
@@ -0,0 +1,75 @@
package com.sandinh.szbase32

import org.scalatest.{FlatSpec, Matchers}
import szbase32.{ZBase32 => Old}

import scala.util.Random
import ZBase32._

import scala.annotation.tailrec
import scala.collection.immutable.Stream

class ZBase32Spec extends FlatSpec with Matchers {
"ZBase32" should "bytes -> encode -> decode == bytes, include case bytes.length == 0" in {
def test(bytesLen: Int) = {
val bytes = new Array[Byte](bytesLen)
Random.nextBytes(bytes)
decode(encode(bytes)) shouldEqual bytes // should contain theSameElementsInOrderAs
}
test(0)
for (_ <- 0 to 20) test(1 + Random.nextInt(100))
}

it should
"""do NOT NEED to: s -> decode -> encode == s
| because there are >= 1 ways to encode a byte array.""".stripMargin in {
for((s, s2, bytes) <- Seq(
("rj", "re", Array[Byte](34)),
("4ramr45", "4ramr4a", Array[Byte](-47,48,-78,107)))
) {
encode(decode(s)) shouldEqual s2 // don't need == s
decode(s) shouldEqual bytes
decode(s) shouldEqual decode(s2)
}
}

private val encTbl = "ybndrfg8ejkmcpqxot1uwisza345h769"
def rndChars: Stream[Char] = {
def c = encTbl charAt (Random nextInt 32)
Stream continually c
}

it should "en/decode same as in the old implementation" in {
for (_ <- 0 to 20) {
val bytes = new Array[Byte](Random.nextInt(100))
Random.nextBytes(bytes)
encode(bytes) shouldEqual Old.encode(bytes)

val s = rndChars.take(Random.nextInt(100)).mkString
decode(s) shouldEqual Old.decode(s)
}
}

it should "`decode` case-insensitive" in {
for (_ <- 0 to 20) {
val s = rndChars.take(Random.nextInt(100)).mkString
decode(s) shouldEqual decode(s.toUpperCase())
}
}

it should "`decode` don't break on invalid z-base-32 input" in {
def invalidChar(c: Char) = !encTbl.contains(c.toLower)
@tailrec def invalidInput(): String = {
val r = Random.alphanumeric.take(Random.nextInt(100))
if (r exists invalidChar) r.mkString
else invalidInput()
}

for (_ <- 0 to 20) {
val s = invalidInput()
// println(s)
// println(s.map(c => if (invalidChar(c)) " " else c).mkString)
decode(s) shouldEqual decode(s filterNot invalidChar)
}
}
}

0 comments on commit c076eb6

Please sign in to comment.