Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

REPL: fix the process of reading data from the server. #421

Merged
merged 12 commits into from
Apr 26, 2023
231 changes: 194 additions & 37 deletions src/dummy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,131 @@ use erg_compiler::Compiler;
pub type EvalError = CompileError;
pub type EvalErrors = CompileErrors;

/// The instructions for communication between the client and the server.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
#[repr(u8)]
enum Inst {
/// Send from server to client. Informs the client to print data.
Print = 0x01,
/// Send from client to server. Informs the REPL server that the executable .pyc file has been written out and is ready for evaluation.
Load = 0x02,
/// Send from server to client. Represents an exception.
Exception = 0x03,
/// Send from server to client. Tells the code generator to initialize due to an error.
Initialize = 0x04,
/// Informs that the connection is to be / should be terminated.
Exit = 0x05,
/// Informs that it is not a supported instruction.
Unknown = 0x00,
}

impl From<u8> for Inst {
fn from(v: u8) -> Inst {
match v {
0x01 => Inst::Print,
0x02 => Inst::Load,
0x03 => Inst::Exception,
0x04 => Inst::Initialize,
0x05 => Inst::Exit,
_ => Inst::Unknown,
}
}
}

/// -------------------------------
/// | ins | size | data
/// -------------------------------
/// | 1 byte | 2 bytes | n bytes
/// -------------------------------
#[derive(Debug, Clone)]
struct Message {
inst: Inst,
size: usize,
mtshiba marked this conversation as resolved.
Show resolved Hide resolved
data: Option<Vec<u8>>,
}

impl Message {
fn new(inst: Inst, data: Option<Vec<u8>>) -> Self {
let size = if let Some(d) = &data { d.len() } else { 0 };
Self { inst, size, data }
}
}

#[derive(Debug)]
struct MessageStream<T: Read + Write> {
stream: T,
}

impl<T: Read + Write> MessageStream<T> {
fn new(stream: T) -> Self {
Self { stream }
}

fn send_msg(&mut self, msg: &Message) -> Result<(), std::io::Error> {
let mut write_buf = Vec::with_capacity(1024);
write_buf.extend((msg.inst as u8).to_be_bytes());
write_buf.extend((msg.size).to_be_bytes());
write_buf.extend_from_slice(&msg.data.clone().unwrap_or_default());

self.stream.write_all(&write_buf)?;

Ok(())
}

fn recv_msg(&mut self) -> Result<Message, std::io::Error> {
// read instruction, 1 byte
let mut inst_buf = [0; 1];
self.stream.read_exact(&mut inst_buf)?;

let inst: Inst = u8::from_be_bytes(inst_buf).into();

// read size, 2 bytes
let mut size_buf = [0; 2];
self.stream.read_exact(&mut size_buf)?;

let data_size = u16::from_be_bytes(size_buf) as usize;

if data_size == 0 {
return Ok(Message::new(inst, None));
}

// read data
let mut data_buf = vec![0; data_size];
self.stream.read_exact(&mut data_buf)?;

Ok(Message::new(inst, Some(data_buf)))
}
}

#[test]
fn test_message() {
use std::collections::VecDeque;

let inner = Box::<VecDeque<u8>>::default();
let mut stream = MessageStream::new(inner);

// test send_msg
stream.send_msg(&Message::new(Inst::Load, None)).unwrap();
assert_eq!(
stream.stream.as_slices(),
(&[2, 0, 0, 0, 0, 0, 0, 0, 0][..], &[][..])
);

// test recv_msg
// data field, 'A' => hex is 0x41
stream.stream.push_front(0x41);
// size field
stream.stream.push_front(0x01);
stream.stream.push_front(0x00);
// inst field
stream.stream.push_front(0x01);

let msg = stream.recv_msg().unwrap();
assert_eq!(msg.inst, Inst::Print);
assert_eq!(msg.size, 1);
assert_eq!(std::str::from_utf8(&msg.data.unwrap()).unwrap(), "A");
}

fn find_available_port() -> u16 {
let socket = SocketAddrV4::new(Ipv4Addr::LOCALHOST, 0);
TcpListener::bind(socket)
Expand All @@ -33,7 +158,7 @@ fn find_available_port() -> u16 {
#[derive(Debug)]
pub struct DummyVM {
compiler: Compiler,
stream: Option<TcpStream>,
stream: Option<MessageStream<TcpStream>>,
}

impl Default for DummyVM {
Expand Down Expand Up @@ -82,7 +207,7 @@ impl Runnable for DummyVM {
stream
.set_read_timeout(Some(Duration::from_secs(cfg.py_server_timeout)))
.unwrap();
break Some(stream);
break Some(MessageStream::new(stream));
}
Err(_) => {
if !cfg.quiet_repl {
Expand All @@ -104,15 +229,16 @@ impl Runnable for DummyVM {

fn finish(&mut self) {
if let Some(stream) = &mut self.stream {
if let Err(err) = stream.write_all("exit".as_bytes()) {
// send exit to server
if let Err(err) = stream.send_msg(&Message::new(Inst::Exit, None)) {
eprintln!("Write error: {err}");
process::exit(1);
}
let mut buf = [0; 1024];
match stream.read(&mut buf) {
Result::Ok(n) => {
let s = std::str::from_utf8(&buf[..n]).unwrap();
if s.contains("closed") && !self.cfg().quiet_repl {

// wait server exit
match stream.recv_msg() {
Result::Ok(msg) => {
if msg.inst == Inst::Exit && !self.cfg().quiet_repl {
println!("The REPL server is closed.");
}
}
Expand All @@ -121,6 +247,7 @@ impl Runnable for DummyVM {
process::exit(1);
}
}

remove_file(self.cfg().dump_pyc_filename()).unwrap_or(());
}
}
Expand Down Expand Up @@ -158,18 +285,71 @@ impl Runnable for DummyVM {
.map_err(|eart| eart.errors)?;
let (last, warns) = (arti.object, arti.warns);
let mut res = warns.to_string();

macro_rules! err_handle {
() => {
{
self.finish();
process::exit(1);

}
};
($hint:expr, $($args:expr),*) => {
{
self.finish();
eprintln!($hint, $($args)*);
process::exit(1);
}
};
}

// Tell the REPL server to execute the code
res += &match self.stream.as_mut().unwrap().write("load".as_bytes()) {
Result::Ok(_) => self.read()?,
Result::Err(err) => {
self.finish();
eprintln!("Sending error: {err}");
process::exit(1);
if let Err(err) = self
.stream
.as_mut()
.unwrap()
.send_msg(&Message::new(Inst::Load, None))
{
err_handle!("Sending error: {}", err);
};

// receive data from server
let data = match self.stream.as_mut().unwrap().recv_msg() {
Result::Ok(msg) => {
let s = match msg.inst {
Inst::Exception => {
debug_assert!(
std::str::from_utf8(msg.data.as_ref().unwrap()) == Ok("SystemExit")
);
return Err(EvalErrors::from(EvalError::system_exit()));
}
Inst::Initialize => {
self.compiler.initialize_generator();
String::from_utf8(msg.data.unwrap_or_default())
}
Inst::Print => String::from_utf8(msg.data.unwrap_or_default()),
Inst::Exit => err_handle!("Recving inst {:?} from server", msg.inst),
// `load` can only be sent from the client to the server
Inst::Load | Inst::Unknown => {
err_handle!("Recving unexpected inst {:?} from server", msg.inst)
}
};

if let Ok(ss) = s {
ss
} else {
err_handle!("Failed to parse server response data, error: {:?}", s.err());
}
}
Result::Err(err) => err_handle!("Recving error: {}", err),
};

res.push_str(&data);
// If the result of an expression is None, it will not be displayed in the REPL.
if res.ends_with("None") {
res.truncate(res.len() - 5);
}

if self.cfg().show_type {
res.push_str(": ");
res.push_str(
Expand Down Expand Up @@ -197,27 +377,4 @@ impl DummyVM {
pub fn eval(&mut self, src: String) -> Result<String, EvalErrors> {
Runnable::eval(self, src)
}

fn read(&mut self) -> Result<String, EvalErrors> {
let mut buf = [0; 1024];
match self.stream.as_mut().unwrap().read(&mut buf) {
Result::Ok(n) => {
let s = std::str::from_utf8(&buf[..n])
.expect("failed to parse the response, maybe the output is too long");
match s {
"[Exception] SystemExit" => Err(EvalErrors::from(EvalError::system_exit())),
"[Initialize]" => {
self.compiler.initialize_generator();
self.read()
}
_ => Ok(s.to_string()),
}
}
Result::Err(err) => {
self.finish();
eprintln!("Read error: {err}");
process::exit(1);
}
}
}
}
43 changes: 33 additions & 10 deletions src/scripts/repl_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,41 @@
__already_loaded = False
__ctx = {'__importlib': __importlib}

class INST:
# Informs that it is not a supported instruction.
UNKNOWN = 0x00
# Send from server to client. Informs the client to print data.
PRINT = 0x01
# Send from client to server. Informs the REPL server that the executable .pyc file has been written out and is ready for evaluation.
LOAD = 0x02
# Send from server to client. Represents an exception.
EXCEPTION = 0x03
# Send from server to client. Tells the code generator to initialize due to an error.
INITIALIZE = 0x04
# Informs that the connection is to be / should be terminated.
EXIT = 0x05

def __encode(instr, data=''):
data_bytes = data.encode()
data_len = len(data_bytes)
# one byte for inst, two bytes for size, and n bytes for data(Optional)
return instr.to_bytes(1, 'big') + data_len.to_bytes(2, 'big') + data_bytes

while True:
try:
__order = __client_socket.recv(1024).decode()
__data = __client_socket.recv(1024)
mtshiba marked this conversation as resolved.
Show resolved Hide resolved
except ConnectionResetError: # when the server was crashed
break
if __order == 'quit' or __order == 'exit': # when the server was closed successfully
__client_socket.send('closed'.encode())
__inst = int.from_bytes(__data[:1], 'big')
if __inst == INST.EXIT: # when the server was closed successfully
__client_socket.send(__encode(INST.EXIT))
break
elif __order == 'load':
elif __inst == INST.LOAD:
__sys.stdout = __io.StringIO()
__res = ''
__exc = ''
__resp_inst = INST.PRINT
__buf = []
try:
if __already_loaded:
# __MODULE__ will be replaced with module name
Expand All @@ -35,7 +58,7 @@
__res = str(exec('import __MODULE__', __ctx))
__already_loaded = True
except SystemExit:
__client_socket.send('[Exception] SystemExit'.encode())
__client_socket.send(__encode(INST.EXCEPTION, 'SystemExit'))
continue
except Exception as e:
try:
Expand All @@ -44,15 +67,15 @@
excs = __traceback.format_exception_only(e.__class__, e)
__exc = ''.join(excs).rstrip()
__traceback.clear_frames(e.__traceback__)
__client_socket.send('[Initialize]'.encode())
__resp_inst = INST.INITIALIZE
__out = __sys.stdout.getvalue()[:-1]
# assert not(__exc and __res)
if __exc or __res:
if __out and __exc or __res:
__out += '\n'
mtshiba marked this conversation as resolved.
Show resolved Hide resolved
__res = __out + __exc + __res
__client_socket.send(__res.encode())
__buf.append(__res)
__client_socket.send(__encode(__resp_inst, ''.join(__buf)))
else:
__client_socket.send('unknown operation'.encode())
__client_socket.send(__encode(INST.UNKNOWN))

__client_socket.close()
__server_socket.close()