In [1]:
%reload_ext autoreload
%autoreload 2
from urllib.request import urlopen
import torch
from PIL import Image
from yacgo.model import EfficientFormerV2

img = Image.open(
    urlopen('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'))

device = "cuda" if torch.cuda.is_available() else "mps"
# print(device)

# model = timm.create_model('efficientformerv2_s0.snap_dist_in1k', pretrained=True, img_size=19, in_chans=12)
# model = model.eval().to(device)
board_size = 19
model = EfficientFormerV2(
    depths=[2, 2, 6, 4],
    in_chans=12, # num of game state channels
    img_size=board_size,
    embed_dims=(32, 64, 96, 172),
    downsamples=(False, True, True, True),
    num_vit=2,
    mlp_ratios=(4, 4, (4, 3, 3, 3, 4, 4), (4, 3, 3, 4)),
    num_classes=board_size**2 + 1  
)

inputs = torch.randn(1, 12, board_size, board_size)
print(inputs.shape)
_ = model(inputs)

torch.Size([1, 12, 19, 19])


In [2]:
model = model.to(device)

In [3]:
c = 1_000
chans = 12
bs = [256]
import time

data = torch.randn(bs[0], chans, board_size, board_size).to(device)
for _ in range(100):
    _ = model(data)
    
for b in bs:
    data = torch.randn(b, chans, board_size, board_size).to(device)
    start = time.time()
    for _ in range(c):
        _ = model(data)
    end = time.time()
    print(f"Batch size: {b}")
    print("=====================================")
    print(f"Average inference latency per batch: {(end - start) / c: .4f} seconds")
    print(f"average inference latency per example: {(end - start) / c / b: .7f} seconds")
    print(f"1000 inferences in {end - start:.2f} seconds")
    print()

Batch size: 256
Average inference latency per batch:  0.0151 seconds
average inference latency per example:  0.0000590 seconds
1000 inferences in 15.11 seconds



In [19]:
c = 1_000
bs = 
import time
img = Image.open(
    urlopen('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'))
img = transforms(img)
img = [img for _ in range(bs)]
img = torch.stack(img).to(device)


for _ in range(100):
    _ = model(img)

start = time.time()
for _ in range(c):
    _ = model(img)
end = time.time()

print(f"Average inference latency per batch: {(end - start) / c: .4f} seconds")
print(f"average inference latency per example: {(end - start) / c / bs: .7f} seconds")
print(f"1000 inferences in {end - start:.2f} seconds")

Average inference latency per batch:  0.0296 seconds
average inference latency per example:  0.0018483 seconds
1000 inferences in 29.57 seconds


In [None]:
# https://stackoverflow.com/questions/73764403/how-can-a-zmq-server-process-requests-in-batches

import msgpack
import multiprocessing as mp
import numpy as np
import time
import uuid
import zmq

def computation(inputs):
  time.sleep(1)  # Simulate constant GPU overhead.
  results = np.zeros((len(inputs), 8))
  return results

def server(port, batch=10):
  context = zmq.Context.instance()
  socket = context.socket(zmq.ROUTER)
  socket.bind(f'tcp://*:{port}')
  while True:
    inputs = np.empty((batch, 64))
    addresses = []
    for i in range(batch):
      address, empty, payload = socket.recv_multipart()
      inputs[i] = unpack(payload)
      addresses.append(address)
    print('Collected request batch.', flush=True)
    results = computation(inputs)
    for i, address in enumerate(addresses):
      payload = pack(results[i])
      socket.send_multipart([address, b'', payload])
    print('Send response batch.', flush=True)

def client(ports):
  context = zmq.Context.instance()
  socket = context.socket(zmq.REQ)
  socket.setsockopt(zmq.IDENTITY, uuid.uuid4().bytes)
  for port in ports:
    socket.connect(f'tcp://localhost:{port}')
  while True:
    input_ = np.zeros(64)
    socket.send(pack(input_))
    result = unpack(socket.recv())

def pack(array):
  return msgpack.packb((array.shape, str(array.dtype), array.tobytes()))

def unpack(buffer):
  shape, dtype, value = msgpack.unpackb(buffer)
  return np.frombuffer(value, dtype).reshape(shape)

if __name__ == '__main__':
  num_clients = 100
  num_servers = 3
  ports = list(range(5550, 5550 + num_servers))
  for port in ports:
    mp.Process(target=server, args=(port,)).start()
  for _ in range(num_clients):
    mp.Process(target=client, args=(ports,)).start()

In [None]:
import torch
from yacgo.models import ViTWrapper, InferenceLocal
