Skip to content

Commit

Permalink
Added custom headers support for WebSocket connection (502)
Browse files Browse the repository at this point in the history
Signed-off-by: Vitalii <vitalii.vlasiuk@temy.co>
  • Loading branch information
vit21ik authored and vitalii-temy committed Jun 8, 2018
1 parent dbb3e9f commit d669d22
Show file tree
Hide file tree
Showing 11 changed files with 106 additions and 18 deletions.
17 changes: 17 additions & 0 deletions MQTTv3.md
Original file line number Diff line number Diff line change
Expand Up @@ -98,3 +98,20 @@ public class MqttPublishSample {
}
```

## Adding custom headers for Websocket connection

The included code below is a extended basic sample that connects to a server with custom headers.

```
MqttClient client = new MqttClient("wss://<BROKER_URI>", "MyClient");
MqttConnectOptions connectOptions = new MqttConnectOptions();
Properties properties = new Properties();
properties.setProperty("X-Amz-CustomAuthorizer-Name", <SOME_VALUE>);
properties.setProperty("X-Amz-CustomAuthorizer-Signature", <SOME_VALUE>);
properties.setProperty(<SOME_VALUE>, <SOME_VALUE>);
connectOptions.setCustomWebSocketHeaders(properties);
client.connect(connectOptions);
```
Original file line number Diff line number Diff line change
Expand Up @@ -635,7 +635,7 @@ else if ((factory instanceof SSLSocketFactory) == false) {
else if (factory instanceof SSLSocketFactory) {
throw ExceptionHelper.createMqttException(MqttException.REASON_CODE_SOCKET_FACTORY_MISMATCH);
}
netModule = new WebSocketNetworkModule(factory, address, host, port, clientId);
netModule = new WebSocketNetworkModule(factory, address, host, port, clientId, options.getCustomWebSocketHeaders());
((WebSocketNetworkModule)netModule).setConnectTimeout(options.getConnectionTimeout());
break;
case MqttConnectOptions.URI_TYPE_WSS:
Expand All @@ -656,7 +656,7 @@ else if ((factory instanceof SSLSocketFactory) == false) {
}

// Create the network module...
netModule = new WebSocketSecureNetworkModule((SSLSocketFactory) factory, address, host, port, clientId);
netModule = new WebSocketSecureNetworkModule((SSLSocketFactory) factory, address, host, port, clientId, options.getCustomWebSocketHeaders());
((WebSocketSecureNetworkModule)netModule).setSSLhandshakeTimeout(options.getConnectionTimeout());
((WebSocketSecureNetworkModule)netModule).setSSLHostnameVerifier(options.getSSLHostnameVerifier());
((WebSocketSecureNetworkModule)netModule).setHttpsHostnameVerificationEnabled(options.isHttpsHostnameVerificationEnabled());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ public class MqttConnectOptions {
private int mqttVersion = MQTT_VERSION_DEFAULT;
private boolean automaticReconnect = false;
private int maxReconnectDelay = 128000;
private Properties customWebSocketHeaders = null;

/**
* Constructs a new <code>MqttConnectOptions</code> object using the
Expand Down Expand Up @@ -650,6 +651,20 @@ public Properties getDebug() {
return p;
}

/**
* Sets the Custom WebSocket Headers for the WebSocket Connection.
*
* @param props The custom websocket headers {@link Properties}
*/

public void setCustomWebSocketHeaders(Properties props) {
this.customWebSocketHeaders = props;
}

public Properties getCustomWebSocketHeaders() {
return customWebSocketHeaders;
}

public String toString() {
return Debug.dumpProperties(getDebug(), "Connection options");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@
import java.util.HashMap;
import java.util.Map;
import java.util.UUID;
import java.util.Properties;
import java.util.Set;
import java.util.Iterator;
/**
* Helper class to execute a WebSocket Handshake.
*/
Expand All @@ -52,14 +55,15 @@ public class WebSocketHandshake {
String uri;
String host;
int port;
Properties customWebSocketHeaders;


public WebSocketHandshake(InputStream input, OutputStream output, String uri, String host, int port){
public WebSocketHandshake(InputStream input, OutputStream output, String uri, String host, int port, Properties customWebSocketHeaders){
this.input = input;
this.output = output;
this.uri = uri;
this.host = host;
this.port = port;
this.customWebSocketHeaders = customWebSocketHeaders;
}


Expand Down Expand Up @@ -108,6 +112,16 @@ private void sendHandshakeRequest(String key) throws IOException{
pw.print("Sec-WebSocket-Protocol: mqtt" + LINE_SEPARATOR);
pw.print("Sec-WebSocket-Version: 13" + LINE_SEPARATOR);

if (customWebSocketHeaders != null) {
Set keys = customWebSocketHeaders.keySet();
Iterator i = keys.iterator();
while (i.hasNext()) {
String k = (String) i.next();
String value = customWebSocketHeaders.getProperty(k);
pw.print(k + ": " + value + LINE_SEPARATOR);
}
}

String userInfo = srvUri.getUserInfo();
if(userInfo != null) {
pw.print("Authorization: Basic " + Base64.encode(userInfo) + LINE_SEPARATOR);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import java.io.OutputStream;
import java.io.PipedInputStream;
import java.nio.ByteBuffer;
import java.util.Properties;

import javax.net.SocketFactory;

Expand All @@ -37,6 +38,7 @@ public class WebSocketNetworkModule extends TCPNetworkModule {
private String uri;
private String host;
private int port;
private Properties customWebsocketHeaders;
private PipedInputStream pipedInputStream;
private WebSocketReceiver webSocketReceiver;
ByteBuffer recievedPayload;
Expand All @@ -47,20 +49,21 @@ public class WebSocketNetworkModule extends TCPNetworkModule {
* Frame before passing it through to the real socket.
*/
private ByteArrayOutputStream outputStream = new ExtendedByteArrayOutputStream(this);
public WebSocketNetworkModule(SocketFactory factory, String uri, String host, int port, String resourceContext){

public WebSocketNetworkModule(SocketFactory factory, String uri, String host, int port, String resourceContext, Properties customWebsocketHeaders){
super(factory, host, port, resourceContext);
this.uri = uri;
this.host = host;
this.port = port;
this.customWebsocketHeaders = customWebsocketHeaders;
this.pipedInputStream = new PipedInputStream();

log.setResourceName(resourceContext);
}

public void start() throws IOException, MqttException {
super.start();
WebSocketHandshake handshake = new WebSocketHandshake(getSocketInputStream(), getSocketOutputStream(), uri, host, port);
WebSocketHandshake handshake = new WebSocketHandshake(getSocketInputStream(), getSocketOutputStream(), uri, host, port, customWebsocketHeaders);
handshake.execute();
this.webSocketReceiver = new WebSocketReceiver(getSocketInputStream(), pipedInputStream);
webSocketReceiver.start("webSocketReceiver");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import java.io.OutputStream;
import java.io.PipedInputStream;
import java.nio.ByteBuffer;

import java.util.Properties;
import javax.net.ssl.SSLSocketFactory;

import org.eclipse.paho.client.mqttv3.MqttException;
Expand All @@ -39,6 +39,7 @@ public class WebSocketSecureNetworkModule extends SSLNetworkModule{
private String uri;
private String host;
private int port;
private Properties customWebSocketHeaders;
ByteBuffer recievedPayload;

/**
Expand All @@ -48,18 +49,19 @@ public class WebSocketSecureNetworkModule extends SSLNetworkModule{
*/
private ByteArrayOutputStream outputStream = new ExtendedByteArrayOutputStream(this);

public WebSocketSecureNetworkModule(SSLSocketFactory factory, String uri, String host, int port, String clientId) {
public WebSocketSecureNetworkModule(SSLSocketFactory factory, String uri, String host, int port, String clientId, Properties customWebSocketHeaders) {
super(factory, host, port, clientId);
this.uri = uri;
this.host = host;
this.port = port;
this.customWebSocketHeaders = customWebSocketHeaders;
this.pipedInputStream = new PipedInputStream();
log.setResourceName(clientId);
}

public void start() throws IOException, MqttException {
super.start();
WebSocketHandshake handshake = new WebSocketHandshake(super.getInputStream(), super.getOutputStream(), uri, host, port);
WebSocketHandshake handshake = new WebSocketHandshake(super.getInputStream(), super.getOutputStream(), uri, host, port, customWebSocketHeaders);
handshake.execute();
this.webSocketReceiver = new WebSocketReceiver(getSocketInputStream(), pipedInputStream);
webSocketReceiver.start("WssSocketReceiver");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -773,6 +773,7 @@ private NetworkModule createNetworkModule(String address, MqttConnectionOptions
}
netModule = new WebSocketNetworkModule(factory, address, host, port, this.mqttSession.getClientId());
((WebSocketNetworkModule) netModule).setConnectTimeout(options.getConnectionTimeout());
((WebSocketNetworkModule) netModule).setCustomWebSocketHeaders(options.getCustomWebSocketHeaders());
break;
case WSS:
if (port == -1) {
Expand All @@ -794,6 +795,7 @@ private NetworkModule createNetworkModule(String address, MqttConnectionOptions
netModule = new WebSocketSecureNetworkModule((SSLSocketFactory) factory, address, host, port,
this.mqttSession.getClientId());
((WebSocketSecureNetworkModule) netModule).setSSLhandshakeTimeout(options.getConnectionTimeout());
((WebSocketSecureNetworkModule) netModule).setCustomWebSocketHeaders(options.getCustomWebSocketHeaders());
// Ciphers suites need to be set, if they are available
if (wSSFactoryFactory != null) {
String[] enabledCiphers = wSSFactoryFactory.getEnabledCipherSuites(null);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@

import java.net.URI;
import java.net.URISyntaxException;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Properties;

import javax.net.SocketFactory;
Expand Down Expand Up @@ -117,7 +119,7 @@ public void setWillMessageProperties(MqttProperties willMessageProperties) {
private SocketFactory socketFactory; // SocketFactory to be used to connect
private Properties sslClientProps = null; // SSL Client Properties
private HostnameVerifier sslHostnameVerifier = null; // SSL Hostname Verifier

private Map<String, String> customWebSocketHeaders;
/**
* Returns the MQTT version.
*
Expand Down Expand Up @@ -960,6 +962,19 @@ public Properties getDebug() {
return p;
}

/**
* Sets the Custom WebSocket Headers for the WebSocket Connection.
*
* @param headers The custom websocket headers {@link Properties}
*/
public void setCustomWebSocketHeaders(Map<String, String> headers) {
this.customWebSocketHeaders = Collections.unmodifiableMap(headers);
}

public Map<String, String> getCustomWebSocketHeaders() {
return customWebSocketHeaders;
}

public String toString() {
return Debug.dumpProperties(getDebug(), "Connection options");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,15 @@ public class WebSocketHandshake {
String uri;
String host;
int port;
Map<String, String> customWebSocketHeaders;

public WebSocketHandshake(InputStream input, OutputStream output, String uri, String host, int port) {
public WebSocketHandshake(InputStream input, OutputStream output, String uri, String host, int port, Map<String, String> customWebSocketHeaders) {
this.input = input;
this.output = output;
this.uri = uri;
this.host = host;
this.port = port;
this.customWebSocketHeaders = customWebSocketHeaders;
}

/**
Expand Down Expand Up @@ -108,6 +110,12 @@ private void sendHandshakeRequest(String key) {
pw.print("Sec-WebSocket-Protocol: mqtt" + LINE_SEPARATOR);
pw.print("Sec-WebSocket-Version: 13" + LINE_SEPARATOR);

if (customWebSocketHeaders != null) {
customWebSocketHeaders.entrySet().forEach(entry ->
pw.print(entry.getKey() + ": " + entry.getValue() + LINE_SEPARATOR)
);
}

String userInfo = srvUri.getUserInfo();
if (userInfo != null) {
pw.print("Authorization: Basic " + Base64.encode(userInfo) + LINE_SEPARATOR);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@
import java.io.OutputStream;
import java.io.PipedInputStream;
import java.nio.ByteBuffer;

import java.util.Collections;
import java.util.Map;
import javax.net.SocketFactory;

import org.eclipse.paho.mqttv5.client.internal.TCPNetworkModule;
Expand All @@ -40,7 +41,8 @@ public class WebSocketNetworkModule extends TCPNetworkModule {
private PipedInputStream pipedInputStream;
private WebSocketReceiver webSocketReceiver;
ByteBuffer recievedPayload;

Map<String, String> customWebSocketHeaders;

/**
* Overrides the flush method.
* This allows us to encode the MQTT payload into a WebSocket
Expand All @@ -60,7 +62,7 @@ public WebSocketNetworkModule(SocketFactory factory, String uri, String host, in

public void start() throws IOException, MqttException {
super.start();
WebSocketHandshake handshake = new WebSocketHandshake(getSocketInputStream(), getSocketOutputStream(), uri, host, port);
WebSocketHandshake handshake = new WebSocketHandshake(getSocketInputStream(), getSocketOutputStream(), uri, host, port, customWebSocketHeaders);
handshake.execute();
this.webSocketReceiver = new WebSocketReceiver(getSocketInputStream(), pipedInputStream);
webSocketReceiver.start("webSocketReceiver");
Expand All @@ -81,7 +83,11 @@ public InputStream getInputStream() throws IOException {
public OutputStream getOutputStream() throws IOException {
return outputStream;
}


public void setCustomWebSocketHeaders(Map<String, String> customWebSocketHeaders) {
this.customWebSocketHeaders = customWebSocketHeaders;
}

/**
* Stops the module, by closing the TCP socket.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import java.io.OutputStream;
import java.io.PipedInputStream;
import java.nio.ByteBuffer;
import java.util.Map;

import javax.net.ssl.SSLSocketFactory;

Expand All @@ -40,7 +41,8 @@ public class WebSocketSecureNetworkModule extends SSLNetworkModule{
private String host;
private int port;
ByteBuffer recievedPayload;

Map<String, String> customWebSocketHeaders;

/**
* Overrides the flush method.
* This allows us to encode the MQTT payload into a WebSocket
Expand All @@ -59,7 +61,7 @@ public WebSocketSecureNetworkModule(SSLSocketFactory factory, String uri, String

public void start() throws IOException, MqttException {
super.start();
WebSocketHandshake handshake = new WebSocketHandshake(super.getInputStream(), super.getOutputStream(), uri, host, port);
WebSocketHandshake handshake = new WebSocketHandshake(super.getInputStream(), super.getOutputStream(), uri, host, port, customWebSocketHeaders);
handshake.execute();
this.webSocketReceiver = new WebSocketReceiver(getSocketInputStream(), pipedInputStream);
webSocketReceiver.start("WssSocketReceiver");
Expand All @@ -82,6 +84,10 @@ public OutputStream getOutputStream() throws IOException {
return outputStream;
}

public void setCustomWebSocketHeaders(Map<String, String> customWebSocketHeaders) {
this.customWebSocketHeaders = customWebSocketHeaders;
}

public void stop() throws IOException {
// Creating Close Frame
WebSocketFrame frame = new WebSocketFrame((byte)0x08, true, "1000".getBytes());
Expand Down

0 comments on commit d669d22

Please sign in to comment.