Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added custom headers support for WebSocket connection (502) #554

Merged
merged 1 commit into from
Jun 18, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions MQTTv3.md
Original file line number Diff line number Diff line change
Expand Up @@ -98,3 +98,20 @@ public class MqttPublishSample {
}
```

## Adding custom headers for Websocket connection

The included code below is a extended basic sample that connects to a server with custom headers.

```
MqttClient client = new MqttClient("wss://<BROKER_URI>", "MyClient");
MqttConnectOptions connectOptions = new MqttConnectOptions();
Properties properties = new Properties();
properties.setProperty("X-Amz-CustomAuthorizer-Name", <SOME_VALUE>);
properties.setProperty("X-Amz-CustomAuthorizer-Signature", <SOME_VALUE>);
properties.setProperty(<SOME_VALUE>, <SOME_VALUE>);
connectOptions.setCustomWebSocketHeaders(properties);
client.connect(connectOptions);
```
Original file line number Diff line number Diff line change
Expand Up @@ -635,7 +635,7 @@ else if ((factory instanceof SSLSocketFactory) == false) {
else if (factory instanceof SSLSocketFactory) {
throw ExceptionHelper.createMqttException(MqttException.REASON_CODE_SOCKET_FACTORY_MISMATCH);
}
netModule = new WebSocketNetworkModule(factory, address, host, port, clientId);
netModule = new WebSocketNetworkModule(factory, address, host, port, clientId, options.getCustomWebSocketHeaders());
((WebSocketNetworkModule)netModule).setConnectTimeout(options.getConnectionTimeout());
break;
case MqttConnectOptions.URI_TYPE_WSS:
Expand All @@ -656,7 +656,7 @@ else if ((factory instanceof SSLSocketFactory) == false) {
}

// Create the network module...
netModule = new WebSocketSecureNetworkModule((SSLSocketFactory) factory, address, host, port, clientId);
netModule = new WebSocketSecureNetworkModule((SSLSocketFactory) factory, address, host, port, clientId, options.getCustomWebSocketHeaders());
((WebSocketSecureNetworkModule)netModule).setSSLhandshakeTimeout(options.getConnectionTimeout());
((WebSocketSecureNetworkModule)netModule).setSSLHostnameVerifier(options.getSSLHostnameVerifier());
((WebSocketSecureNetworkModule)netModule).setHttpsHostnameVerificationEnabled(options.isHttpsHostnameVerificationEnabled());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ public class MqttConnectOptions {
private int mqttVersion = MQTT_VERSION_DEFAULT;
private boolean automaticReconnect = false;
private int maxReconnectDelay = 128000;
private Properties customWebSocketHeaders = null;

/**
* Constructs a new <code>MqttConnectOptions</code> object using the
Expand Down Expand Up @@ -650,6 +651,20 @@ public Properties getDebug() {
return p;
}

/**
* Sets the Custom WebSocket Headers for the WebSocket Connection.
*
* @param props The custom websocket headers {@link Properties}
*/

public void setCustomWebSocketHeaders(Properties props) {
this.customWebSocketHeaders = props;
}

public Properties getCustomWebSocketHeaders() {
return customWebSocketHeaders;
}

public String toString() {
return Debug.dumpProperties(getDebug(), "Connection options");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@
import java.util.HashMap;
import java.util.Map;
import java.util.UUID;
import java.util.Properties;
import java.util.Set;
import java.util.Iterator;
/**
* Helper class to execute a WebSocket Handshake.
*/
Expand All @@ -52,14 +55,15 @@ public class WebSocketHandshake {
String uri;
String host;
int port;
Properties customWebSocketHeaders;


public WebSocketHandshake(InputStream input, OutputStream output, String uri, String host, int port){
public WebSocketHandshake(InputStream input, OutputStream output, String uri, String host, int port, Properties customWebSocketHeaders){
this.input = input;
this.output = output;
this.uri = uri;
this.host = host;
this.port = port;
this.customWebSocketHeaders = customWebSocketHeaders;
}


Expand Down Expand Up @@ -108,6 +112,16 @@ private void sendHandshakeRequest(String key) throws IOException{
pw.print("Sec-WebSocket-Protocol: mqtt" + LINE_SEPARATOR);
pw.print("Sec-WebSocket-Version: 13" + LINE_SEPARATOR);

if (customWebSocketHeaders != null) {
Set keys = customWebSocketHeaders.keySet();
Iterator i = keys.iterator();
while (i.hasNext()) {
String k = (String) i.next();
String value = customWebSocketHeaders.getProperty(k);
pw.print(k + ": " + value + LINE_SEPARATOR);
}
}

String userInfo = srvUri.getUserInfo();
if(userInfo != null) {
pw.print("Authorization: Basic " + Base64.encode(userInfo) + LINE_SEPARATOR);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import java.io.OutputStream;
import java.io.PipedInputStream;
import java.nio.ByteBuffer;
import java.util.Properties;

import javax.net.SocketFactory;

Expand All @@ -37,6 +38,7 @@ public class WebSocketNetworkModule extends TCPNetworkModule {
private String uri;
private String host;
private int port;
private Properties customWebsocketHeaders;
private PipedInputStream pipedInputStream;
private WebSocketReceiver webSocketReceiver;
ByteBuffer recievedPayload;
Expand All @@ -47,20 +49,21 @@ public class WebSocketNetworkModule extends TCPNetworkModule {
* Frame before passing it through to the real socket.
*/
private ByteArrayOutputStream outputStream = new ExtendedByteArrayOutputStream(this);
public WebSocketNetworkModule(SocketFactory factory, String uri, String host, int port, String resourceContext){

public WebSocketNetworkModule(SocketFactory factory, String uri, String host, int port, String resourceContext, Properties customWebsocketHeaders){
super(factory, host, port, resourceContext);
this.uri = uri;
this.host = host;
this.port = port;
this.customWebsocketHeaders = customWebsocketHeaders;
this.pipedInputStream = new PipedInputStream();

log.setResourceName(resourceContext);
}

public void start() throws IOException, MqttException {
super.start();
WebSocketHandshake handshake = new WebSocketHandshake(getSocketInputStream(), getSocketOutputStream(), uri, host, port);
WebSocketHandshake handshake = new WebSocketHandshake(getSocketInputStream(), getSocketOutputStream(), uri, host, port, customWebsocketHeaders);
handshake.execute();
this.webSocketReceiver = new WebSocketReceiver(getSocketInputStream(), pipedInputStream);
webSocketReceiver.start("webSocketReceiver");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import java.io.OutputStream;
import java.io.PipedInputStream;
import java.nio.ByteBuffer;

import java.util.Properties;
import javax.net.ssl.SSLSocketFactory;

import org.eclipse.paho.client.mqttv3.MqttException;
Expand All @@ -39,6 +39,7 @@ public class WebSocketSecureNetworkModule extends SSLNetworkModule{
private String uri;
private String host;
private int port;
private Properties customWebSocketHeaders;

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And why is it a Properties field now, instead of a Map? Properties is deprecated and the reading/writing abilities of Properties are not used in this case? Can I issue a PR to change this to Map please?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Properties is used for sslClientProps.
Sure, Map is better.
New pull request - #555

ByteBuffer recievedPayload;

/**
Expand All @@ -48,18 +49,19 @@ public class WebSocketSecureNetworkModule extends SSLNetworkModule{
*/
private ByteArrayOutputStream outputStream = new ExtendedByteArrayOutputStream(this);

public WebSocketSecureNetworkModule(SSLSocketFactory factory, String uri, String host, int port, String clientId) {
public WebSocketSecureNetworkModule(SSLSocketFactory factory, String uri, String host, int port, String clientId, Properties customWebSocketHeaders) {
super(factory, host, port, clientId);
this.uri = uri;
this.host = host;
this.port = port;
this.customWebSocketHeaders = customWebSocketHeaders;
this.pipedInputStream = new PipedInputStream();
log.setResourceName(clientId);
}

public void start() throws IOException, MqttException {
super.start();
WebSocketHandshake handshake = new WebSocketHandshake(super.getInputStream(), super.getOutputStream(), uri, host, port);
WebSocketHandshake handshake = new WebSocketHandshake(super.getInputStream(), super.getOutputStream(), uri, host, port, customWebSocketHeaders);
handshake.execute();
this.webSocketReceiver = new WebSocketReceiver(getSocketInputStream(), pipedInputStream);
webSocketReceiver.start("WssSocketReceiver");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -773,6 +773,7 @@ private NetworkModule createNetworkModule(String address, MqttConnectionOptions
}
netModule = new WebSocketNetworkModule(factory, address, host, port, this.mqttSession.getClientId());
((WebSocketNetworkModule) netModule).setConnectTimeout(options.getConnectionTimeout());
((WebSocketNetworkModule) netModule).setCustomWebSocketHeaders(options.getCustomWebSocketHeaders());
break;
case WSS:
if (port == -1) {
Expand All @@ -794,6 +795,7 @@ private NetworkModule createNetworkModule(String address, MqttConnectionOptions
netModule = new WebSocketSecureNetworkModule((SSLSocketFactory) factory, address, host, port,
this.mqttSession.getClientId());
((WebSocketSecureNetworkModule) netModule).setSSLhandshakeTimeout(options.getConnectionTimeout());
((WebSocketSecureNetworkModule) netModule).setCustomWebSocketHeaders(options.getCustomWebSocketHeaders());
// Ciphers suites need to be set, if they are available
if (wSSFactoryFactory != null) {
String[] enabledCiphers = wSSFactoryFactory.getEnabledCipherSuites(null);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@

import java.net.URI;
import java.net.URISyntaxException;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Properties;

import javax.net.SocketFactory;
Expand Down Expand Up @@ -117,7 +119,7 @@ public void setWillMessageProperties(MqttProperties willMessageProperties) {
private SocketFactory socketFactory; // SocketFactory to be used to connect
private Properties sslClientProps = null; // SSL Client Properties
private HostnameVerifier sslHostnameVerifier = null; // SSL Hostname Verifier

private Map<String, String> customWebSocketHeaders;
/**
* Returns the MQTT version.
*
Expand Down Expand Up @@ -960,6 +962,19 @@ public Properties getDebug() {
return p;
}

/**
* Sets the Custom WebSocket Headers for the WebSocket Connection.
*
* @param headers The custom websocket headers {@link Properties}
*/
public void setCustomWebSocketHeaders(Map<String, String> headers) {
this.customWebSocketHeaders = Collections.unmodifiableMap(headers);
}

public Map<String, String> getCustomWebSocketHeaders() {
return customWebSocketHeaders;
}

public String toString() {
return Debug.dumpProperties(getDebug(), "Connection options");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,15 @@ public class WebSocketHandshake {
String uri;
String host;
int port;
Map<String, String> customWebSocketHeaders;

public WebSocketHandshake(InputStream input, OutputStream output, String uri, String host, int port) {
public WebSocketHandshake(InputStream input, OutputStream output, String uri, String host, int port, Map<String, String> customWebSocketHeaders) {
this.input = input;
this.output = output;
this.uri = uri;
this.host = host;
this.port = port;
this.customWebSocketHeaders = customWebSocketHeaders;
}

/**
Expand Down Expand Up @@ -108,6 +110,12 @@ private void sendHandshakeRequest(String key) {
pw.print("Sec-WebSocket-Protocol: mqtt" + LINE_SEPARATOR);
pw.print("Sec-WebSocket-Version: 13" + LINE_SEPARATOR);

if (customWebSocketHeaders != null) {
customWebSocketHeaders.entrySet().forEach(entry ->
pw.print(entry.getKey() + ": " + entry.getValue() + LINE_SEPARATOR)
);
}

String userInfo = srvUri.getUserInfo();
if (userInfo != null) {
pw.print("Authorization: Basic " + Base64.encode(userInfo) + LINE_SEPARATOR);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@
import java.io.OutputStream;
import java.io.PipedInputStream;
import java.nio.ByteBuffer;

import java.util.Collections;
import java.util.Map;
import javax.net.SocketFactory;

import org.eclipse.paho.mqttv5.client.internal.TCPNetworkModule;
Expand All @@ -40,7 +41,8 @@ public class WebSocketNetworkModule extends TCPNetworkModule {
private PipedInputStream pipedInputStream;
private WebSocketReceiver webSocketReceiver;
ByteBuffer recievedPayload;

Map<String, String> customWebSocketHeaders;

/**
* Overrides the flush method.
* This allows us to encode the MQTT payload into a WebSocket
Expand All @@ -60,7 +62,7 @@ public WebSocketNetworkModule(SocketFactory factory, String uri, String host, in

public void start() throws IOException, MqttException {
super.start();
WebSocketHandshake handshake = new WebSocketHandshake(getSocketInputStream(), getSocketOutputStream(), uri, host, port);
WebSocketHandshake handshake = new WebSocketHandshake(getSocketInputStream(), getSocketOutputStream(), uri, host, port, customWebSocketHeaders);
handshake.execute();
this.webSocketReceiver = new WebSocketReceiver(getSocketInputStream(), pipedInputStream);
webSocketReceiver.start("webSocketReceiver");
Expand All @@ -81,7 +83,11 @@ public InputStream getInputStream() throws IOException {
public OutputStream getOutputStream() throws IOException {
return outputStream;
}


public void setCustomWebSocketHeaders(Map<String, String> customWebSocketHeaders) {
this.customWebSocketHeaders = customWebSocketHeaders;
}

/**
* Stops the module, by closing the TCP socket.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import java.io.OutputStream;
import java.io.PipedInputStream;
import java.nio.ByteBuffer;
import java.util.Map;

import javax.net.ssl.SSLSocketFactory;

Expand All @@ -40,7 +41,8 @@ public class WebSocketSecureNetworkModule extends SSLNetworkModule{
private String host;
private int port;
ByteBuffer recievedPayload;

Map<String, String> customWebSocketHeaders;

/**
* Overrides the flush method.
* This allows us to encode the MQTT payload into a WebSocket
Expand All @@ -59,7 +61,7 @@ public WebSocketSecureNetworkModule(SSLSocketFactory factory, String uri, String

public void start() throws IOException, MqttException {
super.start();
WebSocketHandshake handshake = new WebSocketHandshake(super.getInputStream(), super.getOutputStream(), uri, host, port);
WebSocketHandshake handshake = new WebSocketHandshake(super.getInputStream(), super.getOutputStream(), uri, host, port, customWebSocketHeaders);
handshake.execute();
this.webSocketReceiver = new WebSocketReceiver(getSocketInputStream(), pipedInputStream);
webSocketReceiver.start("WssSocketReceiver");
Expand All @@ -82,6 +84,10 @@ public OutputStream getOutputStream() throws IOException {
return outputStream;
}

public void setCustomWebSocketHeaders(Map<String, String> customWebSocketHeaders) {
this.customWebSocketHeaders = customWebSocketHeaders;
}

public void stop() throws IOException {
// Creating Close Frame
WebSocketFrame frame = new WebSocketFrame((byte)0x08, true, "1000".getBytes());
Expand Down