In [1]:
import socket
import dill
import torch
import concurrent
import io
import queue
import threading

class FileAsyncResult:
    def __init__(self, filepath):
        self.event = threading.Event()
        self.filepath = filepath
        self.exception = None
        
    def set_exception(self, exception):
        self.exception = exception
        self.event.set()

    def set_result(self, sockfile):
        with open(self.filepath) as file:
            size = dill.load(sockfile)
            while size > 0:
                cnt_read = min(size, 10 * 1024 * 1024)
                file.write(sockfile.read(cnt_read))
                size -= cnt_read
        self.event.set()

class BasicAsyncResult:
    def __init__(self):
        self.event = threading.Event()
        self.result = None
        self.exception = None
        
    def set_exception(self, exception):
        self.exception = exception
        self.event.set()
        
    def set_result(self, sockfile):
        self.result = dill.load(sockfile)
        self.event.set()
        
class RemoteRunnerServer:
    def __init__(self):
        self.requests_queue = queue.Queue()

    def connection_handler(self, sock):
        with sock, sock.makefile('rwb') as sockfile:
            while True:
                result, typ, rpc_request = self.requests_queue.get()
                
                # GLOBALS
                # sockfile.write(b'0')
                # dill.dump({
                #     "func4": func4
                # }, sockfile)
            
                # RPC COMMAND
                sockfile.write(typ) # b'1'
                dill.dump(rpc_request, sockfile)                          
                sockfile.flush()
                
                typ = sockfile.read(1)
                if typ == b'0':
                    result.set_result(sockfile)
                elif typ == b'1':
                    result.set_exception(dill.load(sockfile))
                else:
                    raise Exception("unknown type " + str(typ))
    
    def send_globals(self, globls):
        assert isinstance(globls, dict)
        
        result = BasicAsyncResult()
        self.requests_queue.put((result, b'0', globls))
        result.event.wait()
        return result.result     
    
    def rpc_simple(self, func, *args, **kwargs):
        assert callable(func)
                                
        result = BasicAsyncResult()
        self.requests_queue.put((result, b'1', ((func, args, kwargs, 1))))
        result.event.wait()
        if result.exception:
            raise result.exception
        return result.result
        
    def rpc_file(self, file, func, *args, **kwargs):
        assert callable(func)
        assert isinstance(file, str)
        result = FileAsyncResult(file)
        self.requests_queue.put((result, b'1', ((func, args, kwargs, 1))))
        result.event.wait()
        if result.exception:
            raise result.exception
        return True
    
    def host_server(self):
        HOST = '0.0.0.0' 
        PORT = 65231

        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            s.bind((HOST, PORT))
            # listen for incoming connections
            s.listen()
            print(f"Server is listening on {HOST}:{PORT}")
            # wait for a client to connect
            
            with concurrent.futures.ThreadPoolExecutor(4, "ServerConnectionHandler") as executor:
                while True:
                    conn, addr = s.accept()
                    print(f"Connected by {addr}")
                    self.connection_handler(conn)
                    # executor.submit(self.handle_connection, conn)
                        # receive data from the client
    
    def run(self):
        self.thread = threading.Thread(target=self.host_server)
        self.thread.start()

def is_sendable(x):
    if callable(x):
        return True
    if str(type(x)) != "<class 'module'>":
        return False
    return 'built-in' not in str(x) 

def get_sendable_globals(globs):
    res = {}
    for k,v in globs.items():
        if k in ['exit', 'open', 'quit', 'get_ipython']: continue
        if not is_sendable(v): continue
        res[k] = v
    return res
        
server = RemoteRunnerServer()
server.run()
server

Server is listening on 0.0.0.0:65231


<__main__.RemoteRunnerServer at 0x7f9cc4e6d4b0>

In [9]:
import dill

In [11]:
globs = {}
exec("from imports import *", globs)
globs

{'__builtins__': {'__name__': 'builtins',
  '__doc__': "Built-in functions, exceptions, and other objects.\n\nNoteworthy: None is the `nil' object; Ellipsis represents `...' in slices.",
  '__package__': '',
  '__loader__': _frozen_importlib.BuiltinImporter,
  '__spec__': ModuleSpec(name='builtins', loader=<class '_frozen_importlib.BuiltinImporter'>, origin='built-in'),
  '__build_class__': <function __build_class__>,
  '__import__': <function __import__>,
  'abs': <function abs(x, /)>,
  'all': <function all(iterable, /)>,
  'any': <function any(iterable, /)>,
  'ascii': <function ascii(obj, /)>,
  'bin': <function bin(number, /)>,
  'breakpoint': <function breakpoint>,
  'callable': <function callable(obj, /)>,
  'chr': <function chr(i, /)>,
  'compile': <function compile(source, filename, mode, flags=0, dont_inherit=False, optimize=-1, *, _feature_version=-1)>,
  'delattr': <function delattr(obj, name, /)>,
  'dir': <function dir>,
  'divmod': <function divmod(x, y, /)>,
  'eval': <

In [13]:
import distributed
distributed

<module 'distributed' from '/home/misha-sh/micromamba/envs/pytorch-env/lib/python3.10/site-packages/distributed/__init__.py'>

In [10]:
import dill
import pymorphy3
dill.dumps(pymorphy3, byref=False)

b'\x80\x04\x951\x00\x00\x00\x00\x00\x00\x00\x8c\ndill._dill\x94\x8c\x0e_import_module\x94\x93\x94\x8c\tpymorphy3\x94\x85\x94R\x94.'

In [8]:
import torch
def func(x):
    return torch.Tensor(10) # x * 2 #torch.Tensor((1000, 1000, 1000))

server.send_globals(globals())# get_sendable_globals(globals()))
server.rpc_simple(func, 100)

tensor([-9.8635e-22,  4.5646e-41, -9.8635e-22,  4.5646e-41,  4.4842e-44,
         0.0000e+00,  1.1210e-43,  0.0000e+00,  1.3183e-27,  3.0882e-41])

In [51]:

   
import dill
dill.loads(dill.dumps(get_sendable_globals(globals())))


{'torch': <module 'torch' from '/home/misha-sh/micromamba/envs/pytorch-env/lib/python3.10/site-packages/torch/__init__.py'>,
 'func': <function __main__.func(x)>,
 '_3': <module 'torch' from '/home/misha-sh/micromamba/envs/pytorch-env/lib/python3.10/site-packages/torch/__init__.py'>,
 '_5': <module 'torch' from '/home/misha-sh/micromamba/envs/pytorch-env/lib/python3.10/site-packages/torch/__init__.py'>,
 '_6': module,
 'is_module': <function __main__.is_module(x)>,
 'get_modules_dict': <function __main__.get_modules_dict(globs)>,
 'dill': <module 'dill' from '/home/misha-sh/micromamba/envs/pytorch-env/lib/python3.10/site-packages/dill/__init__.py'>,
 '_35': <module 'dill' from '/home/misha-sh/micromamba/envs/pytorch-env/lib/python3.10/site-packages/dill/__init__.py'>,
 'is_sendable': <function __main__.is_sendable(x)>,
 'get_sendable_globals': <function __main__.get_sendable_globals(globs)>}

In [33]:
get_modules_dict(globals())

{'torch': <module 'torch' from '/home/misha-sh/micromamba/envs/pytorch-env/lib/python3.10/site-packages/torch/__init__.py'>,
 '_3': <module 'torch' from '/home/misha-sh/micromamba/envs/pytorch-env/lib/python3.10/site-packages/torch/__init__.py'>,
 '_5': <module 'torch' from '/home/misha-sh/micromamba/envs/pytorch-env/lib/python3.10/site-packages/torch/__init__.py'>,
 'dill': <module 'dill' from '/home/misha-sh/micromamba/envs/pytorch-env/lib/python3.10/site-packages/dill/__init__.py'>}

In [97]:
torch.Tensor([1, 2, 3]).numpy().save()

array([1., 2., 3.], dtype=float32)

In [None]:
from enum import Enum

# class FRAME(Enum):
#     FUNCTION = 1,
#     ARGS = 2,
#     KWARGS = 3,
#     TENSOR = 4

# def decode_frame(f):
#     typ = int.from_bytes(f.read(4), byteorder='little')
#     decoded = dill.load(f)
#     #     decoded = torch.load(io.BytesIO(f.read(size))
#     return FRAME(typ), decoded
    
# def send_func(f, func)
#     f.write(FRAME.FUNCTION)
#     dill.dump(func, f)
# def send_args(f, args):
#     f.write(FRAME.ARGS)
#     dill.dump(func, f)
# def send_kwargs(f, args):
#     f.write(FRAME.KWARGS)
#     dill.dump(func, f)
# def send_tensor(f, tensor):
#     f.write(FRAME.TENSOR)
#     torch.save(tensor, f)

    