diff --git a/src/com/trilead/ssh2/crypto/cipher/CipherInputStream.java b/src/com/trilead/ssh2/crypto/cipher/CipherInputStream.java index c9055ab5..851cc284 100644 --- a/src/com/trilead/ssh2/crypto/cipher/CipherInputStream.java +++ b/src/com/trilead/ssh2/crypto/cipher/CipherInputStream.java @@ -1,65 +1,33 @@ package com.trilead.ssh2.crypto.cipher; +import java.io.BufferedInputStream; import java.io.IOException; import java.io.InputStream; /** * CipherInputStream. - * + * * @author Christian Plattner, plattner@trilead.com * @version $Id: CipherInputStream.java,v 1.1 2007/10/15 12:49:55 cplattne Exp $ */ public class CipherInputStream { - BlockCipher currentCipher; - InputStream bi; - byte[] buffer; - byte[] enc; - int blockSize; - int pos; - - /* - * We cannot use java.io.BufferedInputStream, since that is not available in - * J2ME. Everything could be improved alot here. - */ - - final int BUFF_SIZE = 2048; - byte[] input_buffer = new byte[BUFF_SIZE]; - int input_buffer_pos = 0; - int input_buffer_size = 0; + private BlockCipher currentCipher; + private final BufferedInputStream bi; + private byte[] buffer; + private byte[] enc; + private int blockSize; + private int pos; public CipherInputStream(BlockCipher tc, InputStream bi) { - this.bi = bi; - changeCipher(tc); - } - - private int fill_buffer() throws IOException - { - input_buffer_pos = 0; - input_buffer_size = bi.read(input_buffer, 0, BUFF_SIZE); - return input_buffer_size; - } - - private int internal_read(byte[] b, int off, int len) throws IOException - { - if (input_buffer_size < 0) - return -1; - - if (input_buffer_pos >= input_buffer_size) - { - if (fill_buffer() <= 0) - return -1; + if (bi instanceof BufferedInputStream) { + this.bi = (BufferedInputStream) bi; + } else { + this.bi = new BufferedInputStream(bi); } - - int avail = input_buffer_size - input_buffer_pos; - int thiscopy = (len > avail) ? avail : len; - - System.arraycopy(input_buffer, input_buffer_pos, b, off, thiscopy); - input_buffer_pos += thiscopy; - - return thiscopy; + changeCipher(tc); } public void changeCipher(BlockCipher bc) @@ -76,7 +44,7 @@ private void getBlock() throws IOException int n = 0; while (n < blockSize) { - int len = internal_read(enc, n, blockSize - n); + int len = bi.read(enc, n, blockSize - n); if (len < 0) throw new IOException("Cannot read full block, EOF reached."); n += len; @@ -134,11 +102,32 @@ public int readPlain(byte[] b, int off, int len) throws IOException int n = 0; while (n < len) { - int cnt = internal_read(b, off + n, len - n); + int cnt = bi.read(b, off + n, len - n); if (cnt < 0) throw new IOException("Cannot fill buffer, EOF reached."); n += cnt; } return n; } + + public int peekPlain(byte[] b, int off, int len) throws IOException + { + if (pos != blockSize) + throw new IOException("Cannot read plain since crypto buffer is not aligned."); + int n = 0; + + bi.mark(len); + try { + while (n < len) { + int cnt = bi.read(b, off + n, len - n); + if (cnt < 0) + throw new IOException("Cannot fill buffer, EOF reached."); + n += cnt; + } + } finally { + bi.reset(); + } + + return n; + } } diff --git a/src/com/trilead/ssh2/crypto/cipher/CipherOutputStream.java b/src/com/trilead/ssh2/crypto/cipher/CipherOutputStream.java index cf0db4af..9bb8ba0e 100644 --- a/src/com/trilead/ssh2/crypto/cipher/CipherOutputStream.java +++ b/src/com/trilead/ssh2/crypto/cipher/CipherOutputStream.java @@ -1,6 +1,7 @@ package com.trilead.ssh2.crypto.cipher; +import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.OutputStream; @@ -18,6 +19,8 @@ public class CipherOutputStream byte[] enc; int blockSize; int pos; + private boolean recordingOutput; + private final ByteArrayOutputStream recordingOutputStream = new ByteArrayOutputStream(); /* * We cannot use java.io.BufferedOutputStream, since that is not available @@ -86,6 +89,17 @@ public void changeCipher(BlockCipher bc) enc = new byte[blockSize]; pos = 0; } + + public void startRecording() { + recordingOutput = true; + } + + public byte[] getRecordedOutput() { + recordingOutput = false; + byte[] recordedOutput = recordingOutputStream.toByteArray(); + recordingOutputStream.reset(); + return recordedOutput; + } private void writeBlock() throws IOException { @@ -100,6 +114,10 @@ private void writeBlock() throws IOException internal_write(enc, 0, blockSize); pos = 0; + + if (recordingOutput) { + recordingOutputStream.write(enc, 0, blockSize); + } } public void write(byte[] src, int off, int len) throws IOException diff --git a/src/com/trilead/ssh2/crypto/digest/MessageMac.java b/src/com/trilead/ssh2/crypto/digest/MessageMac.java index 8d959717..7d77c5a8 100644 --- a/src/com/trilead/ssh2/crypto/digest/MessageMac.java +++ b/src/com/trilead/ssh2/crypto/digest/MessageMac.java @@ -2,6 +2,7 @@ package com.trilead.ssh2.crypto.digest; import javax.crypto.Mac; +import javax.crypto.ShortBufferException; import javax.crypto.spec.SecretKeySpec; import java.security.GeneralSecurityException; import java.util.ArrayList; @@ -10,12 +11,27 @@ public final class MessageMac extends MAC { private final Mac messageMac; + private boolean encryptThenMac = false; + private final byte[] buffer; + private final int outSize; public MessageMac(String type, byte[] key) { super(type, key); try { messageMac = Mac.getInstance(Hmac.getHmac(type).getAlgorithm()); + + int macSize = messageMac.getMacLength(); + + if (type.endsWith("-96")) { + outSize = 12; + buffer = new byte[macSize]; + } else { + outSize = macSize; + buffer = null; + } + + encryptThenMac = Hmac.getHmac(type).isEtm(); messageMac.init(new SecretKeySpec(key, type)); } catch (GeneralSecurityException ex) { throw new IllegalArgumentException("Could not create Mac", ex); @@ -55,32 +71,48 @@ public final void update(byte[] packetdata, int off, int len) } public final void getMac(byte[] out, int off) { - byte[] mac = messageMac.doFinal(); - System.arraycopy(mac, off, out, 0, mac.length - off); + try { + if (buffer != null) { + messageMac.doFinal(buffer, 0); + System.arraycopy(buffer, 0, out, off, out.length - off); + } else { + messageMac.doFinal(out, off); + } + } catch (ShortBufferException e) { + throw new IllegalStateException(e); + } } public final int size() { - return messageMac.getMacLength(); + return outSize; + } + + public final boolean isEncryptThenMac() + { + return encryptThenMac; } - private enum Hmac { - HMAC_MD5_96("hmac-md5-96", "HmacMD5", 16), - HMAC_MD5("hmac-md5", "HmacMD5", 16), - HMAC_SHA1_96("hmac-sha1-96", "HmacSHA1", 20), - HMAC_SHA1("hmac-sha1", "HmacSHA1", 20), - HMAC_SHA2_256("hmac-sha2-256", "HmacSHA256", 32), - HMAC_SHA2_512("hmac-sha2-512", "HmacSHA512", 64); + HMAC_MD5_96("hmac-md5-96", "HmacMD5", 16,false), + HMAC_MD5("hmac-md5", "HmacMD5", 16,false), + HMAC_SHA1_96("hmac-sha1-96", "HmacSHA1", 20,false), + HMAC_SHA1("hmac-sha1", "HmacSHA1", 20,false), + HMAC_SHA2_256("hmac-sha2-256", "HmacSHA256", 32,false), + HMAC_SHA2_512("hmac-sha2-512", "HmacSHA512", 64,false), + HMAC_SHA2_256_ETM("hmac-sha2-256-etm@openssh.com", "HmacSHA256", 32,true), + HMAC_SHA2_512_ETM("hmac-sha2-512-etm@openssh.com", "HmacSHA512", 64,true); private String type; private String algorithm; private int length; + private boolean isEtm; - Hmac(String type, String algorithm, int length) { + Hmac(String type, String algorithm, int length,boolean isEtm) { this.type = type; this.algorithm = algorithm; this.length = length; + this.isEtm = isEtm; } public String getType() { @@ -94,6 +126,10 @@ public String getAlgorithm() { public int getLength() { return length; } + + public boolean isEtm() { + return isEtm; + } private static Hmac getHmac(String type) { for (Hmac hmac : values()) { diff --git a/src/com/trilead/ssh2/transport/TransportConnection.java b/src/com/trilead/ssh2/transport/TransportConnection.java index 77d167d8..e0b87350 100644 --- a/src/com/trilead/ssh2/transport/TransportConnection.java +++ b/src/com/trilead/ssh2/transport/TransportConnection.java @@ -10,7 +10,7 @@ import com.trilead.ssh2.crypto.cipher.CipherInputStream; import com.trilead.ssh2.crypto.cipher.CipherOutputStream; import com.trilead.ssh2.crypto.cipher.NullCipher; -import com.trilead.ssh2.crypto.digest.MAC; +import com.trilead.ssh2.crypto.digest.MessageMac; import com.trilead.ssh2.log.Logger; import com.trilead.ssh2.packets.Packets; @@ -37,13 +37,13 @@ public class TransportConnection /* Depends on current MAC and CIPHER */ - MAC send_mac; + MessageMac send_mac; byte[] send_mac_buffer; int send_padd_blocksize = 8; - MAC recv_mac; + MessageMac recv_mac; byte[] recv_mac_buffer; @@ -74,7 +74,7 @@ public TransportConnection(InputStream is, OutputStream os, SecureRandom rnd) this.rnd = rnd; } - public void changeRecvCipher(BlockCipher bc, MAC mac) + public void changeRecvCipher(BlockCipher bc, MessageMac mac) { cis.changeCipher(bc); recv_mac = mac; @@ -85,7 +85,7 @@ public void changeRecvCipher(BlockCipher bc, MAC mac) recv_padd_blocksize = 8; } - public void changeSendCipher(BlockCipher bc, MAC mac) + public void changeSendCipher(BlockCipher bc, MessageMac mac) { if ((bc instanceof NullCipher) == false) { @@ -125,7 +125,9 @@ public void sendMessage(byte[] message, int off, int len, int padd) throws IOExc else if (padd > 64) padd = 64; - int packet_len = 5 + len + padd; /* Minimum allowed padding is 4 */ + boolean encryptThenMac = send_mac != null && send_mac.isEncryptThenMac(); + + int packet_len = (encryptThenMac ? 1 : 5) + len + padd; /* Minimum allowed padding is 4 */ int slack = packet_len % send_padd_blocksize; @@ -137,7 +139,7 @@ else if (padd > 64) if (packet_len < 16) packet_len = 16; - int padd_len = packet_len - (5 + len); + int padd_len = packet_len - ((encryptThenMac ? 1 : 5) + len); if (useRandomPadding) { @@ -169,22 +171,37 @@ else if (padd > 64) */ } - send_packet_header_buffer[0] = (byte) ((packet_len - 4) >> 24); - send_packet_header_buffer[1] = (byte) ((packet_len - 4) >> 16); - send_packet_header_buffer[2] = (byte) ((packet_len - 4) >> 8); - send_packet_header_buffer[3] = (byte) ((packet_len - 4)); + int payloadLength = encryptThenMac ? packet_len : packet_len - 4; + send_packet_header_buffer[0] = (byte) (packet_len >> 24); + send_packet_header_buffer[1] = (byte) (payloadLength >> 16); + send_packet_header_buffer[2] = (byte) (payloadLength >> 8); + send_packet_header_buffer[3] = (byte) (payloadLength); send_packet_header_buffer[4] = (byte) padd_len; - cos.write(send_packet_header_buffer, 0, 5); + if (send_mac != null && send_mac.isEncryptThenMac()) { + cos.writePlain(send_packet_header_buffer, 0, 4); + cos.startRecording(); + cos.write(send_packet_header_buffer, 4, 1); + } else { + cos.write(send_packet_header_buffer, 0, 5); + } cos.write(message, off, len); cos.write(send_padding_buffer, 0, padd_len); if (send_mac != null) { send_mac.initMac(send_seq_number); - send_mac.update(send_packet_header_buffer, 0, 5); - send_mac.update(message, off, len); - send_mac.update(send_padding_buffer, 0, padd_len); + + + if (send_mac.isEncryptThenMac()) { + send_mac.update(send_packet_header_buffer, 0, 4); + byte[] encryptedMessage = cos.getRecordedOutput(); + send_mac.update(encryptedMessage, 0, encryptedMessage.length); + } else { + send_mac.update(send_packet_header_buffer, 0, 5); + send_mac.update(message, off, len); + send_mac.update(send_padding_buffer, 0, padd_len); + } send_mac.getMac(send_mac_buffer, 0); cos.writePlain(send_mac_buffer, 0, send_mac_buffer.length); @@ -227,47 +244,48 @@ public int peekNextMessageLength() throws IOException public int receiveMessage(byte buffer[], int off, int len) throws IOException { - if (recv_packet_header_present == false) - { - cis.read(recv_packet_header_buffer, 0, 5); - } - else - recv_packet_header_present = false; + final int packetLength; + final int payloadLength; + + if (recv_mac != null && recv_mac.isEncryptThenMac()) { + cis.readPlain(recv_packet_header_buffer, 0, 4); + packetLength = getPacketLength(recv_packet_header_buffer, true); - int packet_length = ((recv_packet_header_buffer[0] & 0xff) << 24) - | ((recv_packet_header_buffer[1] & 0xff) << 16) | ((recv_packet_header_buffer[2] & 0xff) << 8) - | ((recv_packet_header_buffer[3] & 0xff)); + recv_mac.initMac(recv_seq_number); + recv_mac.update(recv_packet_header_buffer, 0, 4); - int padding_length = recv_packet_header_buffer[4] & 0xff; + cis.peekPlain(buffer, off, packetLength + recv_mac_buffer.length); + System.arraycopy(buffer, off + packetLength, recv_mac_buffer, 0, recv_mac_buffer.length); - if (packet_length > TransportManager.MAX_PACKET_SIZE || packet_length < 12) - throw new IOException("Illegal packet size! (" + packet_length + ")"); + recv_mac.update(buffer, off, packetLength); + recv_mac.getMac(recv_mac_buffer_cmp, 0); - int payload_length = packet_length - padding_length - 1; + checkMacMatches(recv_mac_buffer, recv_mac_buffer_cmp); - if (payload_length < 0) - throw new IOException("Illegal padding_length in packet from remote (" + padding_length + ")"); + cis.read(recv_packet_header_buffer, 4, 1); + } else { + cis.read(recv_packet_header_buffer, 0, 5); + packetLength = getPacketLength(recv_packet_header_buffer, false); + } - if (payload_length >= len) - throw new IOException("Receive buffer too small (" + len + ", need " + payload_length + ")"); + int paddingLength = recv_packet_header_buffer[4] & 0xff; - cis.read(buffer, off, payload_length); - cis.read(recv_padding_buffer, 0, padding_length); + payloadLength = calculatePayloadLength(len, packetLength, paddingLength); - if (recv_mac != null) - { + cis.read(buffer, off, payloadLength); + cis.read(recv_padding_buffer, 0, paddingLength); + + if (recv_mac != null) { cis.readPlain(recv_mac_buffer, 0, recv_mac_buffer.length); - recv_mac.initMac(recv_seq_number); - recv_mac.update(recv_packet_header_buffer, 0, 5); - recv_mac.update(buffer, off, payload_length); - recv_mac.update(recv_padding_buffer, 0, padding_length); - recv_mac.getMac(recv_mac_buffer_cmp, 0); + if (!recv_mac.isEncryptThenMac()) { + recv_mac.initMac(recv_seq_number); + recv_mac.update(recv_packet_header_buffer, 0, 5); + recv_mac.update(buffer, off, payloadLength); + recv_mac.update(recv_padding_buffer, 0, paddingLength); + recv_mac.getMac(recv_mac_buffer_cmp, 0); - for (int i = 0; i < recv_mac_buffer.length; i++) - { - if (recv_mac_buffer[i] != recv_mac_buffer_cmp[i]) - throw new IOException("Remote sent corrupt MAC."); + checkMacMatches(recv_mac_buffer, recv_mac_buffer_cmp); } } @@ -275,10 +293,42 @@ public int receiveMessage(byte buffer[], int off, int len) throws IOException if (log.isEnabled()) { - log.log(90, "Received " + Packets.getMessageName(buffer[off] & 0xff) + " " + payload_length + log.log(90, "Received " + Packets.getMessageName(buffer[off] & 0xff) + " " + payloadLength + " bytes payload"); } - return payload_length; + return payloadLength; + } + + private static int calculatePayloadLength(int bufferLength, int packetLength, int paddingLength) throws IOException { + int payloadLength = packetLength - paddingLength - 1; + + if (payloadLength < 0) + throw new IOException("Illegal padding_length in packet from remote (" + paddingLength + ")"); + + if (payloadLength >= bufferLength) + throw new IOException("Receive buffer too small (" + bufferLength + ", need " + payloadLength + ")"); + + return payloadLength; + } + + private static void checkMacMatches(byte[] buf1, byte[] buf2) throws IOException { + int difference = 0; + for (int i = 0; i < buf1.length; i++) { + difference |= buf1[i] ^ buf2[i]; + } + if (difference != 0) + throw new IOException("Remote sent corrupt MAC."); + } + + private static int getPacketLength(byte[] packetHeader, boolean isEtm) throws IOException { + int packetLength = ((packetHeader[0] & 0xff) << 24) + | ((packetHeader[1] & 0xff) << 16) | ((packetHeader[2] & 0xff) << 8) + | ((packetHeader[3] & 0xff)); + + if (packetLength > 35000 || packetLength < (isEtm ? 8 : 12)) + throw new IOException("Illegal packet size! (" + packetLength + ")"); + + return packetLength; } } diff --git a/src/com/trilead/ssh2/transport/TransportManager.java b/src/com/trilead/ssh2/transport/TransportManager.java index 4bf0d561..3f19bae4 100644 --- a/src/com/trilead/ssh2/transport/TransportManager.java +++ b/src/com/trilead/ssh2/transport/TransportManager.java @@ -21,7 +21,7 @@ import com.trilead.ssh2.crypto.Base64; import com.trilead.ssh2.crypto.CryptoWishList; import com.trilead.ssh2.crypto.cipher.BlockCipher; -import com.trilead.ssh2.crypto.digest.MAC; +import com.trilead.ssh2.crypto.digest.MessageMac; import com.trilead.ssh2.log.Logger; import com.trilead.ssh2.packets.PacketDisconnect; import com.trilead.ssh2.packets.Packets; @@ -615,12 +615,12 @@ public void forceKeyExchange(CryptoWishList cwl, DHGexParameters dhgex) throws I km.initiateKEX(cwl, dhgex); } - public void changeRecvCipher(BlockCipher bc, MAC mac) + public void changeRecvCipher(BlockCipher bc, MessageMac mac) { tc.changeRecvCipher(bc, mac); } - public void changeSendCipher(BlockCipher bc, MAC mac) + public void changeSendCipher(BlockCipher bc, MessageMac mac) { tc.changeSendCipher(bc, mac); }