Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
220 changes: 105 additions & 115 deletions test/jdk/sun/security/ssl/SSLSessionImpl/ResumeChecksClient.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2018, 2023, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2018, 2025, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
Expand All @@ -23,7 +23,7 @@

/*
* @test
* @bug 8206929 8212885
* @bug 8206929 8212885 8333857
* @summary ensure that client only resumes a session if certain properties
* of the session are compatible with the new connection
* @library /javax/net/ssl/templates
Expand All @@ -47,6 +47,9 @@
import java.security.*;
import java.net.*;
import java.util.*;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

public class ResumeChecksClient extends SSLContextTemplate {
enum TestMode {
Expand All @@ -56,49 +59,60 @@ enum TestMode {
CIPHER_SUITE,
SIGNATURE_SCHEME
}
static TestMode testMode;

public static void main(String[] args) throws Exception {
new ResumeChecksClient(TestMode.valueOf(args[0])).run();
testMode = TestMode.valueOf(args[0]);
new ResumeChecksClient().test();
}

private final TestMode testMode;
public ResumeChecksClient(TestMode mode) {
this.testMode = mode;
}

private void run() throws Exception {
Server server = startServer();
server.signal();
private void test() throws Exception {
Server server = new Server();
SSLContext sslContext = createClientSSLContext();
while (!server.started) {
Thread.yield();
}
SSLSession firstSession = connect(sslContext, server.port, testMode, false);
HexFormat hex = HexFormat.of();
long firstStartTime = System.currentTimeMillis();
SSLSession firstSession = connect(sslContext, server.port, true);
System.err.println("firstStartTime = " + firstStartTime);
System.err.println("firstId = " + hex.formatHex(firstSession.getId()));
System.err.println("firstSession.getCreationTime() = " +
firstSession.getCreationTime());

server.signal();
long secondStartTime = System.currentTimeMillis();
Thread.sleep(10);
SSLSession secondSession = connect(sslContext, server.port, testMode, true);

server.go = false;
server.signal();
SSLSession secondSession = connect(sslContext, server.port, false);
System.err.println("secondStartTime = " + secondStartTime);
// Note: Ids will never match with TLS 1.3 due to spec
System.err.println("secondId = " + hex.formatHex(secondSession.getId()));
System.err.println("secondSession.getCreationTime() = " +
secondSession.getCreationTime());

switch (testMode) {
case BASIC:
// fail if session is not resumed
checkResumedSession(firstSession, secondSession);
try {
checkResumedSession(firstSession, secondSession);
} catch (Exception e) {
throw new AssertionError("secondSession did not resume: FAIL",
e);
}
System.out.println("secondSession used resumption: PASS");
break;
case VERSION_2_TO_3:
case VERSION_3_TO_2:
case CIPHER_SUITE:
case SIGNATURE_SCHEME:
// fail if a new session is not created
if (secondSession.getCreationTime() <= secondStartTime) {
throw new RuntimeException("Existing session was used");
try {
checkResumedSession(firstSession, secondSession);
System.err.println("firstSession = " + firstSession);
System.err.println("secondSession = " + secondSession);
throw new AssertionError("Second connection should not " +
"have resumed first session: FAIL");
} catch (Exception e) {
System.out.println("secondSession didn't use resumption: PASS");
}
break;
default:
throw new RuntimeException("unknown mode: " + testMode);
throw new AssertionError("unknown mode: " + testMode);
}
}

Expand Down Expand Up @@ -134,51 +148,29 @@ public boolean permits(Set<CryptoPrimitive> primitives,
}

private static SSLSession connect(SSLContext sslContext, int port,
TestMode mode, boolean second) {
boolean first) {

try {
SSLSocket sock = (SSLSocket)
sslContext.getSocketFactory().createSocket();
SSLParameters params = sock.getSSLParameters();

switch (mode) {
case BASIC:
// do nothing to ensure resumption works
break;
case VERSION_2_TO_3:
if (second) {
params.setProtocols(new String[] {"TLSv1.3"});
} else {
params.setProtocols(new String[] {"TLSv1.2"});
}
break;
case VERSION_3_TO_2:
if (second) {
params.setProtocols(new String[] {"TLSv1.2"});
} else {
params.setProtocols(new String[] {"TLSv1.3"});
}
break;
case CIPHER_SUITE:
if (second) {
params.setCipherSuites(
new String[] {"TLS_AES_256_GCM_SHA384"});
} else {
params.setCipherSuites(
new String[] {"TLS_AES_128_GCM_SHA256"});
}
break;
case SIGNATURE_SCHEME:
AlgorithmConstraints constraints =
params.getAlgorithmConstraints();
if (second) {
params.setAlgorithmConstraints(new NoSig("ecdsa"));
} else {
params.setAlgorithmConstraints(new NoSig("rsa"));
}
break;
default:
throw new RuntimeException("unknown mode: " + mode);
switch (testMode) {
case BASIC -> {} // do nothing
case VERSION_2_TO_3 -> params.setProtocols(new String[]{
first ? "TLSv1.2" : "TLSv1.3"});
case VERSION_3_TO_2 -> params.setProtocols(new String[]{
first ? "TLSv1.3" : "TLSv1.2"});
case CIPHER_SUITE -> params.setCipherSuites(
new String[]{
first ? "TLS_AES_128_GCM_SHA256" :
"TLS_AES_256_GCM_SHA384"});
case SIGNATURE_SCHEME ->
params.setAlgorithmConstraints(new NoSig(
first ? "rsa" : "ecdsa"));
default ->
throw new AssertionError("unknown mode: " +
testMode);
}
sock.setSSLParameters(params);
sock.connect(new InetSocketAddress("localhost", port));
Expand All @@ -195,7 +187,7 @@ private static SSLSession connect(SSLContext sslContext, int port,
return result;
} catch (Exception ex) {
// unexpected exception
throw new RuntimeException(ex);
throw new AssertionError(ex);
}
}

Expand Down Expand Up @@ -274,65 +266,63 @@ private static void checkResumedSession(SSLSession initSession,
}
}

private static Server startServer() {
Server server = new Server();
new Thread(server).start();
return server;
}

private static class Server extends SSLContextTemplate implements Runnable {

public volatile boolean go = true;
private boolean signal = false;
public volatile int port = 0;
public volatile boolean started = false;
private static class Server extends SSLContextTemplate {
public int port;
private final SSLServerSocket ssock;
ExecutorService threadPool = Executors.newFixedThreadPool(1);
CountDownLatch serverLatch = new CountDownLatch(1);

private synchronized void waitForSignal() {
while (!signal) {
try {
wait();
} catch (InterruptedException ex) {
// do nothing
}
}
signal = false;
}
public synchronized void signal() {
signal = true;
notify();
}

@Override
public void run() {
Server() {
try {

SSLContext sc = createServerSSLContext();
ServerSocketFactory fac = sc.getServerSocketFactory();
SSLServerSocket ssock = (SSLServerSocket)
fac.createServerSocket(0);
this.port = ssock.getLocalPort();
ssock = (SSLServerSocket) fac.createServerSocket(0);
port = ssock.getLocalPort();

waitForSignal();
started = true;
while (go) {
// Thread to allow multiple clients to connect
new Thread(() -> {
try {
System.out.println("Waiting for connection");
Socket sock = ssock.accept();
BufferedReader reader = new BufferedReader(
new InputStreamReader(sock.getInputStream()));
String line = reader.readLine();
System.out.println("server read: " + line);
PrintWriter out = new PrintWriter(
new OutputStreamWriter(sock.getOutputStream()));
out.println(line);
out.flush();
waitForSignal();
System.err.println("Server starting to accept");
serverLatch.countDown();
do {
threadPool.submit(
new ServerThread((SSLSocket) ssock.accept()));
} while (true);
} catch (Exception ex) {
ex.printStackTrace();
throw new AssertionError("Server Down", ex);
} finally {
threadPool.close();
}
}).start();

} catch (Exception e) {
throw new AssertionError(e);
}
}

static class ServerThread extends Thread {
SSLSocket sock;

ServerThread(SSLSocket s) {
this.sock = s;
System.err.println("(Server) client connection on port " +
sock.getPort());
}

public void run() {
try {
BufferedReader reader = new BufferedReader(
new InputStreamReader(sock.getInputStream()));
String line = reader.readLine();
System.out.println("server read: " + line);
PrintWriter out = new PrintWriter(
new OutputStreamWriter(sock.getOutputStream()));
out.println(line);
out.flush();
out.close();
} catch (Exception e) {
throw new AssertionError("Server thread error", e);
}
} catch (Exception ex) {
throw new RuntimeException(ex);
}
}
}
Expand Down
Loading