-
Notifications
You must be signed in to change notification settings - Fork 0
/
process_midjourney.py
127 lines (99 loc) · 3.3 KB
/
process_midjourney.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import json
from langchain.embeddings import OpenAIEmbeddings
from sklearn.metrics.pairwise import cosine_similarity
# NOTE(review): the API key is hard-coded. Replace the placeholder before
# running, and avoid committing a real key to version control.
openai_api_key = "sk-..."
embeddings = OpenAIEmbeddings(openai_api_key=openai_api_key)

# Load the records to process. The result is bound to `data` because that is
# the name the embedding loop iterates over; binding it to `with_embeddings`
# (which is reset to an empty list just below) left `data` undefined.
with open('intermediate.json') as f:
    data = json.load(f)
print("Loaded data...")

print("Computing embeddings:")
with_embeddings = []
for element in data:
    try:
        # The input file is the same file this stage rewrites, so a record
        # may already carry an embedding from an earlier run; only call the
        # embedding API for records that do not have one yet.
        if "embedding" not in element:
            text = element["prompt"]
            print(f"Computing embedding for prompt: {text}")
            element["embedding"] = embeddings.embed_query(text)
    except Exception as exc:
        # Best effort: report the failing record and keep going. Catching
        # Exception (not a bare except) lets Ctrl-C still stop the script.
        print(f"There was an error on element: {element}")
        print(f"Cause: {exc!r}")
    else:
        with_embeddings.append(element)

# Serialise before opening the file so that a serialisation failure does not
# leave a truncated intermediate.json behind.
# NOTE(review): records that failed above are not written back, so they
# disappear from intermediate.json -- confirm this is intended.
serialized_embeddings = json.dumps(with_embeddings, indent=4)
with open("intermediate.json", "w") as outfile:
    outfile.write(serialized_embeddings)
# Keep at most the first 1000 records for the pairwise comparison.
with_embeddings = with_embeddings[:1000]

# One row per record, in the same order as `with_embeddings`, so that row and
# column k of the similarity matrix correspond to with_embeddings[k].
embedding_matrix = [record["embedding"] for record in with_embeddings]

# Pairwise cosine similarity between every pair of rows.
similarities = cosine_similarity(embedding_matrix)
print(f"Length of cosine similarity matrix: {len(similarities)}")
# Graph payload that the rest of the script fills in and writes out.
#   "nodes": list of node objects
#   "links": dict mapping a source node id to a list of
#            [target id, similarity] pairs
#
# List-based layout (originally labelled "Actual input_json"), kept for
# reference:
# input_json = {
#     "nodes": [],
#     "links": []
# }
input_json = dict(nodes=[], links={})
# {
# "nodes": [
# {
# "id": "id1",
# "name": "name1",
# "val": 1
# },
# {
# "id": "id2",
# "name": "name2",
# "val": 10
# },
# ...
# ],
# "links": [
# {
# "source": "id1",
# "target": "id2"
# },
# ...
# ]
# }
# array([[1. , 0.77911004, 0.81937526, 0.80392125, 0.7655019 ],
# [0.77911004, 1. , 0.79608392, 0.81086091, 0.75359267],
# [0.81937526, 0.79608392, 1. , 0.79420443, 0.79456192],
# [0.80392125, 0.81086091, 0.79420443, 1. , 0.7746613 ],
# [0.7655019 , 0.75359267, 0.79456192, 0.7746613 , 1. ]])
# for n, element in enumerate(with_embeddings):
# print(f"Processing element: #{n}")
# input_json["nodes"].append({
# "id" : element["id"],
# "img_url" : element["image_paths"][0],
# "prompt" : element["prompt"]
# })
# similarity_m = similarities[n]
# for i, s in enumerate(similarity_m[n+1:]):
# if s > 0.8 and element["id"] != with_embeddings[i]["id"]:
# input_json["links"].append({
# "source" : element["id"],
# "target" : with_embeddings[i]["id"],
# "strength": s
# })
# Two prompts are linked when their cosine similarity exceeds this value.
SIMILARITY_THRESHOLD = 0.8

# Build the graph: one node per record, plus, for each record, the list of
# later records whose prompt embedding is similar enough to it.
for n, element in enumerate(with_embeddings):
    print(f"Processing element: #{n}")
    source_id = element["id"]
    input_json["nodes"].append({
        "id": source_id,
        "img_url": element["image_paths"][0],
        "prompt": element["prompt"],
    })

    neighbours = []
    input_json["links"][source_id] = neighbours

    # Only the columns to the right of the diagonal are scanned, so each
    # unordered pair is visited once. `start=n + 1` keeps `j` a valid index
    # into `with_embeddings`: enumerating the slice from 0 would pair row n's
    # scores with the wrong records.
    for j, score in enumerate(similarities[n][n + 1:], start=n + 1):
        target_id = with_embeddings[j]["id"]
        if score > SIMILARITY_THRESHOLD and source_id != target_id:
            # Store a plain Python float so the value serialises with json
            # whatever numeric dtype the similarity matrix uses.
            neighbours.append([target_id, float(score)])
            # List-based alternative for a "links" list instead of a dict:
            # append {"source": source_id, "target": target_id,
            #         "strength": score} per link.
# Serialise the finished graph first, then write it out in one go.
# NOTE(review): the file name says "100" while the script keeps up to 1000
# records -- confirm the name is still the intended one.
serialized_graph = json.dumps(input_json, indent=4)
with open("network_new_format_100.json", "w") as graph_file:
    graph_file.write(serialized_graph)