-
Notifications
You must be signed in to change notification settings - Fork 0
/
evaluation_metric.py
114 lines (86 loc) · 4.14 KB
/
evaluation_metric.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
# This file contains evaluation metrics for the dialogue system
# [このファイルには、対話システムの評価指標が含まれています。]
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.feature_extraction.text import TfidfVectorizer
import re
from Queries import *
# --------------------------------------------------------------------------------
# Sample Werewolf-style discussion transcript.  Each non-empty line has the
# form "Speaker Name: utterance"; the metric functions below parse on the
# first colon.  Used as the default input in the commented example-usage
# calls further down.
dialogue = """
Mana Yoshida: "I saw Hina Sato near the well this morning. She was acting a bit suspicious. We should investigate her further."
Hina Sato: "I was at the well with Riku Mori, trying to refill water supplies. I didn't see anything unusual."
Riku Mori: "I observed Yumi Okada and Hina Sato at the well and tavern. Let's investigate Yumi Okada's recent actions further."
Taichi Kato: "We need to consider Riku Mori's actions and whereabouts. He might be collaborating with the werewolf."
Yumi Okada: "We should investigate Mana Yoshida's actions too. She might be deflecting suspicion by accusing Hina Sato."
Mana Yoshida: "I agree that Riku Mori's actions should be investigated. But we should also consider Yumi Okada's behavior."
Yumi Okada: "Let's not forget to investigate Yuria Shimizu's actions. She might be the werewolf's accomplice."
Hina Sato: "I think we should investigate Taichi Kato's actions as well. We don't have any information about him yet."
"""
# Turn-taking ratio metric [Turn-taking ratio indicator]
def get_turn_taking_ratio(dialogue):
    """Return the sum of each speaker's share of turns in *dialogue*.

    Each line is expected to look like ``Speaker Name: utterance``; lines
    without a colon (including blank lines) are ignored.

    Returns:
        float: the sum of per-speaker turn fractions, or 0 when the
        dialogue contains no parseable turns.

    NOTE(review): the per-speaker shares always sum to 1.0 whenever at
    least one turn exists, so this function effectively returns 0 or 1.0;
    a caller wanting a balance measure should look at the individual
    shares instead — confirm intended semantics with the author.
    """
    # Count one turn per line, keyed by the text before the first colon.
    # BUG FIX: the original counted "name:" substrings across the whole
    # dialogue, which miscounts when a name is mentioned inside an
    # utterance followed by a colon, and blank lines produced an
    # empty-string "agent" whose count matched every colon in the text.
    agent_turns = {}
    for line in dialogue.split("\n"):
        speaker, sep, _ = line.partition(":")
        name = speaker.strip()
        if sep and name:
            agent_turns[name] = agent_turns.get(name, 0) + 1
    total_turns = sum(agent_turns.values())
    if total_turns == 0:
        return 0
    return sum(turns / total_turns for turns in agent_turns.values())
# --------------------------------------------------------------------------------
# Response relevance metric [Response relevance indicator]
def calculate_response_relevance(dialogue):
    """Score how related each utterance is to the rest of the dialogue.

    Parses ``Speaker: utterance`` lines, builds a TF-IDF matrix over the
    utterances, and for each utterance sums its cosine similarity to all
    *other* utterances (the self-similarity diagonal is subtracted).

    Returns:
        tuple[list, float]: one relevance score per utterance, and the
        mean of those scores.  When fewer than two utterances are found,
        returns ([0.0], 0.0).
    """
    utterances = []
    for line in dialogue.strip().split("\n"):
        parts = line.split(":")
        if len(parts) >= 2:
            # Re-join in case the utterance itself contains colons.
            utterances.append(":".join(parts[1:]).strip())
    num_utterances = len(utterances)
    if num_utterances < 2:
        # BUG FIX: the original returned a bare list here but a
        # (list, float) tuple below, so callers unpacking the result
        # crashed on degenerate input.  Return a consistent shape.
        return [0.0], 0.0
    tfidf_matrix = TfidfVectorizer().fit_transform(utterances)
    similarity_matrix = cosine_similarity(tfidf_matrix)
    response_relevance_scores = []
    for i in range(num_utterances):
        # Total similarity to every other utterance (exclude self-match).
        relevance = sum(similarity_matrix[i]) - similarity_matrix[i][i]
        response_relevance_scores.append(relevance)
    return response_relevance_scores, sum(response_relevance_scores) / num_utterances
# Example usage
# response_relevance, avg = calculate_response_relevance(dialogue)
# print("Response Relevance:", response_relevance)
# print("Average Response Relevance:", avg)
# --------------------------------------------------------------------------------
# Conversation agreement metric [Conversation agreement indicator]
def calculate_agreement_metric(dialogue):
    """Mean pairwise TF-IDF cosine similarity across all utterances.

    Lines are parsed as ``Speaker: utterance``; the score averages the
    cosine similarity over every unordered pair of utterances.

    Returns:
        float: the average pairwise similarity, or 0.0 when fewer than
        two utterances are found.
    """
    speakers = []
    texts = []
    for raw_line in dialogue.strip().split("\n"):
        head, sep, tail = raw_line.partition(":")
        if sep:
            speakers.append(head.strip())
            texts.append(tail.strip())
    n = len(speakers)
    if n < 2:
        return 0.0
    matrix = cosine_similarity(TfidfVectorizer().fit_transform(texts))
    # Sum over the strict upper triangle: each unordered pair counted once.
    pair_total = sum(matrix[i][j] for i in range(n) for j in range(i + 1, n))
    pair_count = n * (n - 1) // 2
    if pair_count == 0:
        return 0.0
    return pair_total / pair_count
# agreement_metric = calculate_agreement_metric(dialogue)
# print("Agreement Metric:", agreement_metric*10)