Skip to content

Commit

Permalink
Rework port forwarding unittest and example.
Browse files Browse the repository at this point in the history
  • Loading branch information
iciclespider committed Sep 6, 2020
1 parent fada718 commit 57381a2
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 49 deletions.
47 changes: 24 additions & 23 deletions examples/pod_portforward.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
Shows the functionality of portforward streaming using an nginx container.
"""

import select
import socket
import time
import urllib.request
Expand Down Expand Up @@ -53,8 +54,8 @@ def portforward_commands(api_instance):
}]
}
}
resp = api_instance.create_namespaced_pod(body=pod_manifest,
namespace='default')
api_instance.create_namespaced_pod(body=pod_manifest,
namespace='default')
while True:
resp = api_instance.read_namespaced_pod(name=name,
namespace='default')
Expand All @@ -65,40 +66,40 @@ def portforward_commands(api_instance):

pf = portforward(api_instance.connect_get_namespaced_pod_portforward,
name, 'default',
ports='80,8080:80')
for port in (80, 8080):
http = pf.socket(port)
http.settimeout(1)
http.sendall(b'GET / HTTP/1.1\r\n')
http.sendall(b'Host: 127.0.0.1\r\n')
http.sendall(b'Accept: */*\r\n')
http.sendall(b'\r\n')
response = b''
while True:
try:
response += http.recv(1024)
except socket.timeout:
break
print(response.decode('utf-8'))
http.close()
ports='80')
http = pf.socket(80)
http.setblocking(True)
http.sendall(b'GET / HTTP/1.1\r\n')
http.sendall(b'Host: 127.0.0.1\r\n')
http.sendall(b'Accept: */*\r\n')
http.sendall(b'Connection: close\r\n')
http.sendall(b'\r\n')
response = b''
while True:
select.select([http], [], [])
data = http.recv(1024)
if not data:
break
response += data
http.close()
print(response.decode('utf-8'))

# Monkey patch socket.create_connection which is used by http.client and
# urllib.request. The same can be done with urllib3.util.connection.create_connection
# if the "requests" package is used.
socket_create_connection = socket.create_connection
def kubernetes_create_connection(address, *args, **kwargs):
dns_name = address[0]
if isinstance(dns_name, bytes):
dns_name = dns_name.decode()
# Look for "<pod-name>.<namspace>.kubernetes" dns names and if found
# provide a socket that is port forwarded to the kuberntest pod.
# provide a socket that is port forwarded to the kubernetes pod.
dns_name = dns_name.split(".")
if len(dns_name) != 3 or dns_name[2] != "kubernetes":
return socket_create_connection(address, *args, **kwargs)
pf = portforward(api_instance.connect_get_namespaced_pod_portforward,
dns_name[0], dns_name[1], ports=str(address[1]))
return pf.socket(address[1])

socket_create_connection = socket.create_connection
socket.create_connection = kubernetes_create_connection

# Access the nginx http server using the "<pod-name>.<namespace>.kubernetes" dns name.
Expand All @@ -111,9 +112,9 @@ def kubernetes_create_connection(address, *args, **kwargs):

def main():
config.load_kube_config()
c = Configuration()
c = Configuration.get_default_copy()
c.assert_hostname = False
#Configuration.set_default(c)
Configuration.set_default(c)
core_v1 = core_v1_api.CoreV1Api()

portforward_commands(core_v1)
Expand Down
65 changes: 39 additions & 26 deletions kubernetes/e2e_test/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# under the License.

import json
import select
import socket
import time
import unittest
Expand Down Expand Up @@ -167,7 +168,10 @@ def test_portforward_raw(self):
api = core_v1_api.CoreV1Api(client)

name = 'portforward-raw-' + short_uuid()
pod_manifest = manifest_with_command(name, "while true;do nc -l -p 1234 -e /bin/cat; done")
pod_manifest = manifest_with_command(
name,
'for port in 1234 1235;do ((while true;do nc -l -p $port -e /bin/cat; done)&);done;sleep 60',
)
resp = api.create_namespaced_pod(body=pod_manifest,
namespace='default')
self.assertEqual(name, resp.metadata.name)
Expand All @@ -182,39 +186,48 @@ def test_portforward_raw(self):
break
time.sleep(1)

pf1234 = portforward(api.connect_get_namespaced_pod_portforward,
pf = portforward(api.connect_get_namespaced_pod_portforward,
name, 'default',
ports='1234')
sock1234 = pf1234.socket(1234)
sock1234.settimeout(1)
ports='1234,1235')
sock1234 = pf.socket(1234)
sock1235 = pf.socket(1235)
sock1234.setblocking(True)
sock1235.setblocking(True)
sent1234 = b'Test port 1234 forwarding...'
sent1235 = b'Test port 1235 forwarding...'
sock1234.sendall(sent1234)
sock1235.sendall(sent1235)
reply1234 = b''
reply1235 = b''
while True:
try:
reply1234 += sock1234.recv(1024)
except socket.timeout:
rlist = []
if sock1234.fileno() != -1:
rlist.append(sock1234)
if sock1235.fileno() != -1:
rlist.append(sock1235)
if not rlist:
break
r, _w, _x = select.select(rlist, [], [], 1)
if not r:
break
if sock1234 in r:
data = sock1234.recv(1024)
if data:
reply1234 += data
else:
sock1234.close()
if sock1235 in r:
data = sock1235.recv(1024)
if data:
reply1235 += data
else:
sock1235.close()
sock1234.close()
sock1235.close()
self.assertEqual(reply1234, sent1234)
self.assertIsNone(pf1234.error(1234))

pf9999 = portforward(api.connect_get_namespaced_pod_portforward,
name, 'default',
ports='9999:1234')
sock9999 = pf9999.socket(9999)
sock9999.settimeout(1)
sent9999 = b'Test port 9999 forwarding...'
sock9999.sendall(sent9999)
reply9999 = b''
while True:
try:
reply9999 += sock9999.recv(1024)
except socket.timeout:
break
self.assertEqual(reply9999, sent9999)
sock9999.close()
self.assertIsNone(pf9999.error(9999))
self.assertEqual(reply1235, sent1235)
self.assertIsNone(pf.error(1234))
self.assertIsNone(pf.error(1235))

resp = api.delete_namespaced_pod(name=name, body={},
namespace='default')
Expand Down

0 comments on commit 57381a2

Please sign in to comment.