# In [None]:
import io
import socket
import select
import time
import requests
import torch
from PIL import Image
from io import BytesIO
from transformers import BitsAndBytesConfig

from transformers import AutoProcessor, AutoModelForVision2Seq
from transformers.image_utils import load_image

# Device that holds the model weights; model_out moves its inputs here too.
DEVICE = "cuda:0"

# 4-bit NF4 weights with double quantization; matmuls are computed in float16.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16
)

processor = AutoProcessor.from_pretrained("HuggingFaceM4/idefics2-8b")
# Place every module on DEVICE at load time instead of chaining `.to(DEVICE)`:
# transformers rejects `.to()` on bitsandbytes-quantized models (always for
# 8-bit, and for 4-bit with older bitsandbytes releases), and
# `device_map="auto"` may spread the weights over devices other than the one
# the inputs are sent to.
model = AutoModelForVision2Seq.from_pretrained(
    "HuggingFaceM4/idefics2-8b",
    quantization_config=quantization_config,
    device_map={"": DEVICE},
)
print("model moved to device")


def model_out(image, text_input):
    """Ask the module-level model about one image and return its decoded output.

    Builds a single user turn holding an image placeholder and *text_input*,
    renders it through the processor's chat template, moves the processed
    tensors to ``DEVICE`` and generates at most 100 new tokens.

    Returns:
        The result of ``processor.batch_decode`` formatted as a string.
        NOTE(review): ``batch_decode`` presumably returns a list, so this is
        the list's printed form (brackets and quotes included) - confirm the
        client expects that.
    """
    user_turn = {
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": f"{text_input}"},
        ],
    }
    chat_prompt = processor.apply_chat_template([user_turn], add_generation_prompt=True)

    batch = processor(text=chat_prompt, images=[image], return_tensors="pt")
    on_device = {}
    for name, tensor in batch.items():
        on_device[name] = tensor.to(DEVICE)

    output_ids = model.generate(**on_device, max_new_tokens=100)
    decoded = processor.batch_decode(output_ids, skip_special_tokens=True)
    return f"{decoded}"

_SIZE_FIELD_BYTES = 16    # width of each ASCII length header on the wire
_RECV_CHUNK_BYTES = 4096  # most bytes pulled from a client per readable event


def _parse_size(field):
    """Decode one length header; raise ValueError unless it is an integer >= 0."""
    size = int(field.decode().strip())
    if size < 0:
        raise ValueError(f"negative size field: {size}")
    return size


def _split_request(buffer):
    """Return ``(text, image_bytes)`` once *buffer* holds a whole request.

    Returns None while more bytes are still needed. Raises ValueError when a
    length header is malformed or the text cannot be decoded.
    """
    text_start = _SIZE_FIELD_BYTES
    if len(buffer) < text_start:
        return None
    text_end = text_start + _parse_size(buffer[:text_start])

    image_start = text_end + _SIZE_FIELD_BYTES
    if len(buffer) < image_start:
        return None
    image_end = image_start + _parse_size(buffer[text_end:image_start])
    if len(buffer) < image_end:
        return None

    text = buffer[text_start:text_end].decode()
    return text, bytes(buffer[image_start:image_end])


def _drop_client(sock, pending):
    """Forget the bytes buffered for *sock* and close it."""
    pending.pop(sock, None)
    sock.close()


def _service_client(sock, pending, output_port):
    """Read what *sock* has available and answer it once its request is complete."""
    try:
        chunk = sock.recv(_RECV_CHUNK_BYTES)
    except BlockingIOError:
        # Nothing to read after all; wait for the next select() round.
        return
    except OSError as e:
        print(f"Socket error: {e}")
        _drop_client(sock, pending)
        return

    buffer = pending[sock]
    if not chunk:
        # An empty read means the peer closed its end.
        if buffer:
            print("Connection closed unexpectedly.")
        else:
            print("Connection closed by the client.")
        _drop_client(sock, pending)
        return

    buffer.extend(chunk)
    try:
        request = _split_request(buffer)
    except ValueError as e:
        # The stream cannot be resynchronised after a bad header, so the
        # connection is dropped instead of being left half-parsed.
        print(f"ValueError: {e}")
        _drop_client(sock, pending)
        return
    if request is None:
        return  # request still incomplete; keep buffering

    text_input, image_bytes = request
    print(f"Text received: {text_input}")
    print(len(image_bytes))
    try:
        # Reply to the peer of *this* socket, not to whichever client
        # happened to be accepted most recently.
        client_ip = sock.getpeername()[0]
        with open("rcx.jpg", "wb") as f:
            f.write(image_bytes)
        with Image.open("rcx.jpg") as image:
            print("Image received and processed.")
            output_data = model_out(image, text_input)
        send_output(output_data, client_ip, output_port)
    except OSError as e:
        # Covers unreadable images and a client that is not listening on
        # output_port; one bad request must not stop the server.
        print(f"Request failed: {e}")
    finally:
        _drop_client(sock, pending)


def receive_data(port=12345, output_port=12346):
    """Serve image + prompt requests over TCP and send back the model's answer.

    Listens on all interfaces at *port*. Each client sends exactly one request
    laid out as: a 16-byte ASCII text length, the text, a 16-byte ASCII image
    length, then the image bytes. The image is written to ``rcx.jpg``, passed
    to ``model_out`` together with the text, and the result is delivered with
    ``send_output`` to the client's IP on *output_port*. The client's
    connection is closed afterwards.

    Bytes are buffered per connection, so partial reads and several
    simultaneous clients cannot corrupt each other's parsing state.

    Args:
        port: TCP port to accept requests on.
        output_port: TCP port on the client that receives the reply.

    The loop runs until an exception (for example KeyboardInterrupt)
    propagates; all sockets are closed on the way out.
    """
    listener = socket.socket()
    pending = {}  # client socket -> bytearray of bytes received so far
    try:
        listener.bind(('0.0.0.0', port))
        listener.listen(1)
        listener.setblocking(False)
        print("Waiting for connection...")

        while True:
            readable, _, _ = select.select([listener, *pending], [], [])
            for sock in readable:
                if sock is not listener:
                    _service_client(sock, pending, output_port)
                    continue
                try:
                    conn, addr = listener.accept()
                except BlockingIOError:
                    # The pending connection vanished before accept().
                    continue
                # Whether accepted sockets inherit non-blocking mode is
                # platform-dependent, so set it explicitly.
                conn.setblocking(False)
                print(f"Connection from {addr}")
                pending[conn] = bytearray()
    finally:
        for conn in list(pending):
            conn.close()
        listener.close()

def send_output(output_data, client_ip, port):
    """Deliver *output_data* to ``client_ip:port`` over a fresh TCP connection.

    The text is encoded with ``str.encode`` defaults and sent in full; the
    socket is closed when the ``with`` block exits.
    """
    destination = (client_ip, port)
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as out_conn:
        out_conn.connect(destination)
        payload = output_data.encode()
        out_conn.sendall(payload)
        print(f"Sent output data to {client_ip}:{port}")


# Request port and the port handed to receive_data for sending replies.
port, output_port = 12345, 12346
receive_data(port, output_port)