Skip to content

Commit 5092c82

Browse files
committed
Add validations to setSSLSocketFactory
1 parent f549391 commit 5092c82

File tree

1 file changed

+23
-8
lines changed

1 file changed

+23
-8
lines changed

src/main/java/dev/gustavoavila/websocketclient/WebSocketClient.java

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -132,8 +132,10 @@ public abstract class WebSocketClient {
132132
*/
133133
private volatile Thread reconnectionThread;
134134

135-
136-
private SSLSocketFactory socketFactory = (SSLSocketFactory) SSLSocketFactory.getDefault();
135+
/**
136+
* Allows to customize the SSL Socket factory instance
137+
*/
138+
private SSLSocketFactory sslSocketFactory;
137139

138140
/**
139141
* Initialize all the variables
@@ -153,11 +155,6 @@ public WebSocketClient(URI uri) {
153155
webSocketConnection = new WebSocketConnection();
154156
}
155157

156-
157-
public void setSSLSocketFactory(SSLSocketFactory sslSocketFactory) {
158-
socketFactory = sslSocketFactory;
159-
}
160-
161158
/**
162159
* Called when the WebSocket handshake has been accepted and the WebSocket
163160
* is ready to send and receive data
@@ -311,6 +308,21 @@ public void connect() {
311308
}
312309
}
313310

311+
/**
312+
* Sets the SSL Socket factory used to create secure TCP connections
313+
* @param sslSocketFactory SSLSocketFactory
314+
*/
315+
public void setSSLSocketFactory(SSLSocketFactory sslSocketFactory) {
316+
synchronized (globalLock) {
317+
if (isRunning) {
318+
throw new IllegalStateException("Cannot set SSLSocketFactory while WebSocketClient is running");
319+
} else if (sslSocketFactory == null) {
320+
throw new IllegalStateException("SSLSocketFactory cannot be null");
321+
}
322+
this.sslSocketFactory = sslSocketFactory;
323+
}
324+
}
325+
314326
/**
315327
* Creates and starts the thread that will handle the WebSocket connection
316328
*/
@@ -654,7 +666,10 @@ private boolean createAndConnectTCPSocket() throws IOException {
654666
socket.connect(new InetSocketAddress(uri.getHost(), 80), connectTimeout);
655667
}
656668
} else if (scheme.equals("wss")) {
657-
socket = socketFactory.createSocket();
669+
if (sslSocketFactory == null) {
670+
sslSocketFactory = (SSLSocketFactory) SSLSocketFactory.getDefault();
671+
}
672+
socket = sslSocketFactory.createSocket();
658673
socket.setSoTimeout(readTimeout);
659674

660675
if (port != -1) {

0 commit comments

Comments
 (0)