In [21]:
import tensorflow as tf
from tensorflow import keras
import cv2
import paho.mqtt.client as mqtt
import threading
import base64
import numpy as np
import mediapipe as mp
from PIL import ImageFont, ImageDraw, Image
from keras.models import load_model
import requests
import time

# Caption printing helper
def update_caption(input_text, output_text, is_final=False):
    """Print one caption line to stdout.

    Args:
        input_text: Intermediate recognized words; printed when ``is_final``
            is falsy.
        output_text: Final generated sentence; printed when ``is_final`` is
            truthy.
        is_final: Selects which of the two texts is printed.
    """
    if not is_final:
        # Intermediate words (the original note says these are shown in black).
        print(f"InputText: {input_text}")
        return

    # Final sentence (the original note says this is shown in blue).
    print(f"OutputText: {output_text}")

# MQTT
class ImageMqttPublisher:
    """Publish camera frames and recognized text to an MQTT broker.

    The MQTT client is created and driven on a daemon thread started by
    ``connect()``. Until that thread has created the client and the broker
    connection is up, every ``send_*`` method returns without doing anything.
    The publisher also subscribes to its own text topic and forwards every
    message it receives there to ``update_caption``.
    """

    def __init__(self, broker_ip="localhost", broker_port=1883, pub_topic_image="/camerapub", pub_topic_text="/textpub"):
        """Store connection settings; no network activity happens here.

        Args:
            broker_ip: Host name or address of the MQTT broker.
            broker_port: TCP port of the MQTT broker.
            pub_topic_image: Topic that base64-encoded JPEG frames are published to.
            pub_topic_text: Topic that caption text is published to (and
                subscribed to once connected).
        """
        self.broker_ip = broker_ip
        self.broker_port = broker_port
        self.pub_topic_image = pub_topic_image
        self.pub_topic_text = pub_topic_text
        self.client = None
        self.last_sent_text = ""  # last intermediate text that was published
        self.last_sent_final_output = ""  # last final sentence that was published
        self.final_output_sent = False  # True once a final sentence has been published

    def connect(self):
        """Start the MQTT network loop on a background daemon thread."""
        thread = threading.Thread(target=self._run, daemon=True)
        thread.start()

    def _run(self):
        """Create the client, register callbacks, connect, and block in the loop."""
        self.client = mqtt.Client()
        self.client.on_connect = self._on_connect
        self.client.on_disconnect = self._on_disconnect
        self.client.on_message = self._on_message  # incoming-message callback
        self.client.connect(self.broker_ip, self.broker_port)
        self.client.loop_forever()

    def _on_connect(self, client, userdata, flags, rc):
        """Handle the broker's connection response.

        A non-zero ``rc`` means the broker refused the connection; in that
        case nothing is published or subscribed and the failure is reported
        instead of a misleading "connected" message.
        """
        if rc != 0:
            print(f"MQTT broker connection refused (rc={rc})")
            return
        print("MQTT broker connected")
        self.send_initial_empty_list()
        self.client.subscribe(self.pub_topic_text)  # listen on the text topic

    def _on_disconnect(self, client, userdata, rc):
        """Report that the broker connection was closed."""
        print("MQTT broker disconnected")

    def _on_message(self, client, userdata, message):
        """MQTT message callback; delegates to ``on_message_arrived``."""
        self.on_message_arrived(message)

    def on_message_arrived(self, message):
        """Decode an incoming text message and pass it to ``update_caption``.

        Messages starting with ``FINAL_OUTPUT:`` are treated as the final
        sentence; everything else is treated as intermediate words.
        """
        final_prefix = "FINAL_OUTPUT:"
        # errors="replace" keeps a malformed payload from raising inside the
        # MQTT callback.
        message_text = message.payload.decode("utf-8", errors="replace").strip()
        if message_text.startswith(final_prefix):
            # Strip only the leading marker; the sentence itself may
            # legitimately contain the same text again.
            output_text = message_text[len(final_prefix):].strip()
            update_caption("", output_text, is_final=True)
        else:
            update_caption(message_text, "", is_final=False)

    def disconnect(self):
        """Disconnect from the broker if a client has been created."""
        if self.client:
            self.client.disconnect()

    def send_text(self, text, is_final=False):
        """Publish caption text (retained), skipping duplicates.

        Args:
            text: Text to publish on the text topic.
            is_final: True when ``text`` is the final sentence. Once a final
                sentence has been published, intermediate texts are dropped.
        """
        if self.client is None or not self.client.is_connected():
            return

        # After the final sentence has gone out, block intermediate words.
        if self.final_output_sent and not is_final:
            return

        # Duplicate suppression, tracked separately for final and intermediate text.
        if is_final:
            if text == self.last_sent_final_output:
                return
            self.last_sent_final_output = text
            self.final_output_sent = True
        else:
            if text == self.last_sent_text:
                return
            self.last_sent_text = text

        self.client.publish(self.pub_topic_text, text, retain=True)

    def send_base64(self, frame):
        """JPEG-encode ``frame``, base64-encode it, and publish it (retained)."""
        if self.client is None or not self.client.is_connected():
            return
        retval, buffer = cv2.imencode(".jpg", frame)
        if not retval:
            print("Image encoding failed")
            return
        b64_bytes = base64.b64encode(buffer)
        self.client.publish(self.pub_topic_image, b64_bytes, retain=True)

    def send_initial_empty_list(self):
        """Publish an empty retained payload on the text topic."""
        if self.client is None or not self.client.is_connected():
            return
        self.client.publish(self.pub_topic_text, "", retain=True)

# Signals a stop when no hand has been detected for `timeout_seconds`
# (i.e. only the pose, or nothing at all, was seen during that time).
class LandmarkTracker:
    """Track how long it has been since a hand was last detected.

    Args:
        timeout_seconds: Idle time, in seconds, after which ``dispose``
            reports that processing should stop. Defaults to 10.
    """

    def __init__(self, timeout_seconds=10):
        self.last_landmark_time = time.time()
        self.timeout_seconds = timeout_seconds

    def dispose(self, results):
        """Return True when the idle timeout has elapsed.

        Args:
            results: Detection result exposing ``pose_landmarks``,
                ``left_hand_landmarks`` and ``right_hand_landmarks``
                (each ``None`` when not detected).

        Returns:
            True if the current frame is idle and more than
            ``timeout_seconds`` have passed since the last non-idle frame;
            False otherwise. A non-idle frame resets the timer.
        """
        current_time = time.time()

        # Nothing was detected at all.
        no_landmarks_detected = (
            results.pose_landmarks is None
            and results.left_hand_landmarks is None
            and results.right_hand_landmarks is None
        )

        # Pose detected but neither hand. A single detected hand counts as
        # activity, matching extract_keypoints, which accepts either hand.
        pose_only_detected = bool(
            results.pose_landmarks
            and not (results.left_hand_landmarks or results.right_hand_landmarks)
        )

        if no_landmarks_detected or pose_only_detected:
            return current_time - self.last_landmark_time > self.timeout_seconds

        # Activity seen: restart the idle timer.
        self.last_landmark_time = current_time
        return False

# MediaPipe module aliases: `mp_holistic` supplies HAND_CONNECTIONS and
# `mp_drawing` supplies draw_landmarks, both used by draw_styled_landmarks.
mp_holistic = mp.solutions.holistic
mp_drawing = mp.solutions.drawing_utils

def mediapipe_detection(image, model):
    """Run ``model.process`` on a BGR frame and return the frame with the results.

    Args:
        image: Frame in BGR channel order.
        model: Object exposing ``process(rgb_image)``.

    Returns:
        A ``(bgr_image, results)`` tuple: the frame converted back to BGR
        after processing, and whatever ``model.process`` returned.
    """
    rgb_frame = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # The frame is read-only while the model runs (presumably so the model
    # can use the buffer without copying it), then made writable again.
    rgb_frame.flags.writeable = False
    results = model.process(rgb_frame)
    rgb_frame.flags.writeable = True

    bgr_frame = cv2.cvtColor(rgb_frame, cv2.COLOR_RGB2BGR)
    return bgr_frame, results

def draw_styled_landmarks(image, results):
    """Draw a reduced pose skeleton and both hands onto ``image`` in place.

    Args:
        image: Frame to draw on; ``image.shape`` must unpack to three values
            (height, width, channels) when pose landmarks are present.
        results: Detection result exposing ``pose_landmarks``,
            ``left_hand_landmarks`` and ``right_hand_landmarks``.

    Returns:
        The same ``(image, results)`` pair that was passed in.
    """
    # Pairs of pose landmark indices joined by a line.
    pose_connections = [
        (0, 2), (0, 5), (2, 7), (5, 8), (11, 12), (11, 23),
        (12, 24), (23, 24), (11, 13), (13, 15), (12, 14), (14, 16)
    ]

    if results.pose_landmarks:
        # The frame size does not change while drawing, so read it once.
        h, w, _ = image.shape
        landmarks = results.pose_landmarks.landmark
        for start_idx, end_idx in pose_connections:
            start_landmark = landmarks[start_idx]
            end_landmark = landmarks[end_idx]
            # Landmark x/y are scaled by the frame width/height to get pixels.
            x1, y1 = int(start_landmark.x * w), int(start_landmark.y * h)
            x2, y2 = int(end_landmark.x * w), int(end_landmark.y * h)
            cv2.line(image, (x1, y1), (x2, y2), (0, 255, 0), 2)
            cv2.circle(image, (x1, y1), 3, (255, 0, 255), -1)
            cv2.circle(image, (x2, y2), 3, (255, 0, 255), -1)

    if results.left_hand_landmarks:
        mp_drawing.draw_landmarks(image, results.left_hand_landmarks, mp_holistic.HAND_CONNECTIONS)

    if results.right_hand_landmarks:
        mp_drawing.draw_landmarks(image, results.right_hand_landmarks, mp_holistic.HAND_CONNECTIONS)

    return image, results

def extract_keypoints(results):
    """Flatten the selected pose points and both hands into one feature vector.

    Args:
        results: Detection result exposing ``pose_landmarks``,
            ``left_hand_landmarks`` and ``right_hand_landmarks``; each
            present value has a ``landmark`` sequence of items with ``x``
            and ``y``.

    Returns:
        A 1-D array laid out as 13 pose (x, y) pairs, then the left-hand
        (x, y) pairs, then the right-hand (x, y) pairs; a missing hand is
        filled with 42 zeros. Returns ``None`` unless the pose and at least
        one hand were detected.
    """
    any_hand = results.left_hand_landmarks or results.right_hand_landmarks
    if not (results.pose_landmarks and any_hand):
        return None

    pose_landmarks = [0, 2, 5, 7, 8, 11, 12, 13, 14, 15, 16, 23, 24]
    pose_points = results.pose_landmarks.landmark
    # Pose is guaranteed present here, so no zero-filled fallback is needed.
    pose = np.array([[pose_points[idx].x, pose_points[idx].y] for idx in pose_landmarks]).flatten()

    if results.left_hand_landmarks:
        lh = np.array([[res.x, res.y] for res in results.left_hand_landmarks.landmark]).flatten()
    else:
        lh = np.zeros(21 * 2)

    if results.right_hand_landmarks:
        rh = np.array([[res.x, res.y] for res in results.right_hand_landmarks.landmark]).flatten()
    else:
        rh = np.zeros(21 * 2)

    return np.concatenate([pose, lh, rh])
        
def send_llama(input_text, mqtt_publisher=None,
               url='https://7d8d-34-142-240-196.ngrok-free.app/generate',
               timeout=60):
    """Ask the text-generation endpoint for a sentence and publish the result.

    Args:
        input_text: Recognized words joined into one string; sent as the
            ``input`` field of the JSON request.
        mqtt_publisher: Optional publisher; when given, the generated
            sentence is published through its ``send_text`` as the final
            sentence, prefixed with ``FINAL_OUTPUT:``.
        url: Endpoint to POST to. Defaults to the address that was
            previously hard-coded.
        timeout: Seconds to wait for the endpoint before giving up, so an
            unreachable server cannot block the caller forever.

    Network failures, non-JSON responses and missing fields are reported on
    stdout; none of them raise to the caller.
    """
    instruction = "You are a patient, sick person because you were sick. You receive medical treatment at the hospital. Explain where you are sick."

    data = {
        'instruction': instruction,
        'input': input_text
    }

    try:
        response = requests.post(url, json=data, timeout=timeout)
    except requests.exceptions.RequestException as exc:
        # Connection errors and timeouts must not crash the caller, which
        # still has to release the camera and the MQTT connection.
        print(f"Request failed: {exc}")
        return

    try:
        response_data = response.json()
    except ValueError:
        # JSON decode errors raised by requests are ValueError subclasses.
        print("Failed to decode JSON response")
        print(response.text)
        return

    if not isinstance(response_data, dict):
        print("Unexpected response format")
        print(response.text)
        return

    if response.status_code != 200:
        print(f"Error: {response.status_code}")
        print(response_data.get('error'))
        return

    output_text = response_data.get('output')
    if output_text is None:
        print("Response is missing 'output'")
        print(response.text)
        return

    print(output_text)
    if mqtt_publisher:
        # Mark this as the final sentence so the publisher tracks it as such
        # and stops accepting intermediate words afterwards.
        mqtt_publisher.send_text(f"FINAL_OUTPUT: {output_text}", is_final=True)
        
def main():
    """Run the webcam sign-recognition loop and publish results over MQTT.

    Frames are read from camera 0, analysed with MediaPipe Holistic, and the
    extracted keypoints are collected into 30-frame windows for the Keras
    classifier. Each sufficiently confident prediction extends the running
    word list (at most the 10 most recent words), which is published as
    intermediate text. The loop ends when the camera fails, the landmark
    tracker reports an idle timeout, or Esc is read by ``cv2.waitKey``. The
    collected words are then sent to ``send_llama`` (skipped when no word was
    recognized). The MQTT connection and the camera are always released,
    even if an error occurs along the way.
    """
    video_capture = cv2.VideoCapture(0)
    mqtt_publisher = ImageMqttPublisher()
    mqtt_publisher.connect()

    try:
        model = keras.models.load_model('fcn-model-reference-wang2017-0100-0.9591.keras')
        # Class index -> sign label.
        actions = {0: 'CT', 1: '가끔', 2: '가능', 3: '가다', 4: '간(신체)', 5: '간호사', 6: '감사(고마움)', 7: '갑자기', 8: '걱정', 9: '건강', 10: '걷다(걸음)', 11: '검사(검진)', 12: '결과(결말)', 13: '계속', 14: '고생', 15: '곳', 16: '관절', 17: '괜찮다', 18: '그냥', 19: '그러면', 20: '그런데', 21: '근육', 22: '기능', 23: '기억', 24: '꼭', 25: '끝', 26: '나', 27: '나쁘다', 28: '나오다', 29: '날(시간)', 30: '남다', 31: '너무', 32: '노력', 33: '높다', 34: '느낌', 35: '다니다', 36: '다음', 37: '다치다', 38: '다행', 39: '달다(맛)', 40: '당뇨', 41: '더', 42: '동안(시간)', 43: '되다', 44: '듣다', 45: '디스크', 46: '따로', 47: '때(시간)', 48: '때문에', 49: '떨어지다', 50: '또(and)', 51: '마시다', 52: '마음', 53: '막다', 54: '만(어미)', 55: '만나다', 56: '만들다', 57: '만약', 58: '많다', 59: '맞다(옳다)', 60: '먹다', 61: '면(어미)', 62: '모두', 63: '모르다', 64: '몸', 65: '못하다', 66: '무엇', 67: '문제', 68: '바꾸다', 69: '바쁘다', 70: '방문', 71: '변하다', 72: '병(질병)', 73: '병원', 74: '보다(구경)', 75: '보통', 76: '부작용', 77: '부족', 78: '부탁', 79: '부터', 80: '불편', 81: '비교', 82: '사라지다', 83: '사용', 84: '사진', 85: '상관(관여)없다', 86: '상태', 87: '생각', 88: '생기다', 89: '선생님', 90: '설명', 91: '수술', 92: '술(주류)', 93: '쉽다', 94: '시간', 95: '시작', 96: '심장', 97: '심하다', 98: '아니다', 99: '아니면', 100: '아마', 101: '아직', 102: '아침', 103: '아프다', 104: '안과(전문)', 105: '안내', 106: '안녕', 107: '안되다', 108: '알다', 109: '암', 110: '약(물질)', 111: '어떻게', 112: '어렵다', 113: '어지럽다', 114: '없다', 115: '예약', 116: '오다', 117: '오래', 118: '요즘', 119: '우리(나의무리)', 120: '우선', 121: '운동', 122: '움직이다', 123: '원래', 124: '원인', 125: '원하다', 126: '이유', 127: '이해', 128: '일(업무)', 129: '입원', 130: '있다', 131: '잘', 132: '잠깐', 133: '저녁', 134: '전(시간에)', 135: '점수', 136: '정도', 137: '정상(제대로)', 138: '조금', 139: '조심', 140: '조절', 141: '좋다', 142: '주다', 143: '주사(행동)', 144: '줄이다', 145: '중(가운데)', 146: '중요', 147: '증상', 148: '지금', 149: '지내다', 150: '진행', 151: '집', 152: '참다', 153: '충분', 154: '치료', 155: '콩팥', 156: '특별', 157: '편하다', 158: '피', 159: '필요', 160: '필요없다', 161: '하다', 162: '하루', 163: '함께', 164: '항상', 165: '해(하다)보다', 166: '허리', 167: '혈당', 168: '확인', 169: '환자', 170: '회복', 171: '후', 172: '힘들다'}

        sequence = []
        sentence = []
        threshold = 0.8

        tracker = LandmarkTracker()

        with mp.solutions.holistic.Holistic(min_detection_confidence=0.5, min_tracking_confidence=0.5) as holistic:
            while True:
                # Without this exit a camera that is not open would make the
                # loop spin forever without doing any work.
                if not video_capture.isOpened():
                    print("Video capture device is not open")
                    break

                ret, frame = video_capture.read()
                if not ret:
                    print("Video capture failed")
                    break

                # Detect landmarks
                image, results = mediapipe_detection(frame, holistic)

                # Stop once the tracker reports its idle timeout
                if tracker.dispose(results):
                    print("Dispose condition met: No landmarks for 10 seconds or only pose detected")
                    break

                draw_styled_landmarks(image, results)

                # Extract keypoints
                keypoints = extract_keypoints(results)
                if keypoints is not None:
                    sequence.append(keypoints)
                    sequence = sequence[-30:]

                    if len(sequence) == 30:
                        res = model.predict(np.expand_dims(sequence, axis=0))[0]
                        best = np.argmax(res)

                        if res[best] > threshold:
                            predicted_action = actions[best]

                            # Append unless it repeats the most recent word.
                            if not sentence or predicted_action != sentence[-1]:
                                sentence.append(predicted_action)

                            if len(sentence) > 10:
                                sentence = sentence[-10:]

                            # Publish the intermediate words and start a new
                            # 30-frame window.
                            mqtt_publisher.send_text(' '.join(sentence))
                            print(sentence)
                            sequence = []

                mqtt_publisher.send_base64(image)

                if cv2.waitKey(10) == 27:
                    break

        # Only ask for a final sentence when at least one word was recognized.
        if sentence:
            send_llama(' '.join(sentence), mqtt_publisher)
    finally:
        mqtt_publisher.disconnect()

        video_capture.release()
        cv2.destroyAllWindows()

# Start the pipeline only when this file is executed directly, not when imported.
if __name__ == "__main__":
    main()

  self.client = mqtt.Client()


MQTT broker connected
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 348ms/step
['오다']
InputText: 오다
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 19ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 40ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 21ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 20ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 43ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 36ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 23ms/step
['오다', '곳']
InputText: 오다 곳
Dispose condition met: No landmarks for 10 seconds or only pose detected
여기 왔어요
OutputText: 여기 왔어요
MQTT broker disconnected
