In [1]:
# Run the shared setup notebook before anything else in this notebook.
# NOTE(review): the names used below but never defined here (STO_EMB_DIM,
# DBFILEPATH, STO_CLNAME, get_new_collection, embed_sto, truncate_text) are
# presumably defined by global.ipynb — confirm.
%run ./global.ipynb

In [2]:
import json

from pymilvus import (
    MilvusClient,
    CollectionSchema, 
    FieldSchema, 
    DataType
)

In [3]:
# Primary key column; auto_id=True asks Milvus to assign the INT64 ids.
_id_field = FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True)

# One embedding per record for the title and one for the question body,
# both with STO_EMB_DIM dimensions.
_vector_fields = [
    FieldSchema(name=_vec_name, dtype=DataType.FLOAT_VECTOR, dim=STO_EMB_DIM)
    for _vec_name in ("vector_title", "vector")
]

# Every remaining column is stored as VARCHAR; each pair is (column, max_length).
_varchar_limits = (
    ("question_title", 10000),
    ("question", 10000),
    ("url", 1000),
    ("answers", 10000),
    ("num_vote", 8),
    ("num_answer", 8),
)
_text_fields = [
    FieldSchema(name=_col_name, dtype=DataType.VARCHAR, max_length=_limit)
    for _col_name, _limit in _varchar_limits
]

# Column order is kept as: id, the two vectors, then the text columns.
fields = [_id_field, *_vector_fields, *_text_fields]
schema = CollectionSchema(
    fields=fields,
    auto_id=True,
    description="Schema for stackoverflow",
)

In [4]:
# Open the Milvus database at DBFILEPATH; `client` is reused by the index and
# insert cells further down.
client = MilvusClient(DBFILEPATH)
# get_new_collection is not defined in this notebook (presumably it comes from
# global.ipynb). NOTE(review): it is not handed `client`, and the cell output
# reports "Created new connection", so it presumably manages its own
# connection — confirm.
get_new_collection(STO_CLNAME, dimension=STO_EMB_DIM, schema=schema, auto_id=True)

Created new connection using: b0442bf37fcf4a41ab80b668b19c6778
Successfully created collection: stackoverflow


In [5]:
# Register one FLAT / COSINE index per embedding field, then create them
# together with a single create_index call.
index_params = client.prepare_index_params()
for _vector_field in ("vector", "vector_title"):
    index_params.add_index(
        field_name=_vector_field,
        metric_type="COSINE",
        index_type="FLAT",
        # Index names follow the "<field>_index" pattern.
        index_name=f"{_vector_field}_index",
        # NOTE(review): nlist is normally an IVF-family build parameter and is
        # presumably ignored for a FLAT index — confirm whether it is needed.
        params={"nlist": 128},
    )

client.create_index(collection_name=STO_CLNAME, index_params=index_params)

Successfully created an index on collection: stackoverflow
Successfully created an index on collection: stackoverflow


In [6]:
# Load the records to index. The encoding is given explicitly so decoding does
# not depend on the platform's locale default (JSON text is UTF-8).
with open("stackoverflow.json", "r", encoding="utf-8") as f:
    docs = json.load(f)


num_docs = len(docs)


# Attach two embeddings to every record in place: one for the question body
# ("vector") and one for the title ("vector_title").
for i, doc in enumerate(docs):
    print(f"processing {i}/{num_docs}")
    # embed_sto is called with a one-element list and its first result is kept;
    # truncate_text is applied to the text before embedding.
    doc["vector"] = embed_sto([truncate_text(doc["question"])])[0]
    doc["vector_title"] = embed_sto([truncate_text(doc["question_title"])])[0]

processing 0/131
262
19


Token indices sequence length is longer than the specified maximum sequence length for this model (651 > 512). Running this sequence through the model will result in indexing errors
text was truncated: 651 -> 300


processing 1/131
651
10
processing 2/131
223
34
processing 3/131
170
26


text was truncated: 604 -> 300


processing 4/131
604
19


text was truncated: 3373 -> 300


processing 5/131
3373
24
processing 6/131
155
16


text was truncated: 1235 -> 300


processing 7/131
1235
12
processing 8/131
71
21


text was truncated: 373 -> 300


processing 9/131
373
12


text was truncated: 1622 -> 300


processing 10/131
1622
35


text was truncated: 411 -> 300


processing 11/131
411
11
processing 12/131
259
29


text was truncated: 525 -> 300


processing 13/131
525
13


text was truncated: 358 -> 300


processing 14/131
358
15


text was truncated: 1250 -> 300


processing 15/131
1250
17


text was truncated: 460 -> 300


processing 16/131
460
14
processing 17/131
298
21
processing 18/131
232
36
processing 19/131
127
20


text was truncated: 2207 -> 300


processing 20/131
2207
8
processing 21/131
221
6


text was truncated: 362 -> 300


processing 22/131
362
12
processing 23/131
289
30
processing 24/131
178
19
processing 25/131
204
16


text was truncated: 351 -> 300


processing 26/131
351
7
processing 27/131
211
19
processing 28/131
123
11


text was truncated: 525 -> 300


processing 29/131
525
8


text was truncated: 307 -> 300


processing 30/131
307
12
processing 31/131
262
11


text was truncated: 572 -> 300


processing 32/131
572
12
processing 33/131
88
17
processing 34/131
145
32
processing 35/131
292
16


text was truncated: 508 -> 300


processing 36/131
508
14


text was truncated: 4378 -> 300


processing 37/131
4378
15
processing 38/131
95
9


text was truncated: 616 -> 300


processing 39/131
616
9


text was truncated: 599 -> 300


processing 40/131
599
10


text was truncated: 875 -> 300


processing 41/131
875
10


text was truncated: 323 -> 300


processing 42/131
323
33


text was truncated: 944 -> 300


processing 43/131
944
14


text was truncated: 1169 -> 300


processing 44/131
1169
20


text was truncated: 826 -> 300


processing 45/131
826
21
processing 46/131
270
14


text was truncated: 433 -> 300


processing 47/131
433
21
processing 48/131
239
10


text was truncated: 1241 -> 300


processing 49/131
1241
14
processing 50/131
122
32


text was truncated: 3128 -> 300


processing 51/131
3128
22


text was truncated: 1012 -> 300


processing 52/131
1012
30
processing 53/131
297
15


text was truncated: 614 -> 300


processing 54/131
614
25


text was truncated: 498 -> 300


processing 55/131
498
17


text was truncated: 538 -> 300


processing 56/131
538
28


text was truncated: 402 -> 300


processing 57/131
402
15


text was truncated: 309 -> 300


processing 58/131
309
45
processing 59/131
110
35
processing 60/131
211
42
processing 61/131
182
19


text was truncated: 463 -> 300


processing 62/131
463
45
processing 63/131
81
14


text was truncated: 483 -> 300


processing 64/131
483
11
processing 65/131
147
14


text was truncated: 375 -> 300


processing 66/131
375
19


text was truncated: 540 -> 300


processing 67/131
540
17
processing 68/131
111
25


text was truncated: 1098 -> 300


processing 69/131
1098
20


text was truncated: 1163 -> 300


processing 70/131
1163
31
processing 71/131
284
11


text was truncated: 351 -> 300


processing 72/131
351
17
processing 73/131
110
14
processing 74/131
73
14
processing 75/131
59
16
processing 76/131
73
8
processing 77/131
130
17
processing 78/131
92
16
processing 79/131
190
47


text was truncated: 509 -> 300


processing 80/131
509
17
processing 81/131
232
12
processing 82/131
189
10
processing 83/131
52
18
processing 84/131
53
13


text was truncated: 800 -> 300


processing 85/131
800
17


text was truncated: 504 -> 300


processing 86/131
504
40


text was truncated: 977 -> 300


processing 87/131
977
33
processing 88/131
84
18
processing 89/131
176
18
processing 90/131
101
17
processing 91/131
150
18
processing 92/131
118
9


text was truncated: 717 -> 300


processing 93/131
717
12
processing 94/131
131
17


text was truncated: 301 -> 300


processing 95/131
301
13


text was truncated: 2958 -> 300


processing 96/131
2958
35
processing 97/131
12
17


text was truncated: 403 -> 300


processing 98/131
403
21
processing 99/131
76
9
processing 100/131
128
8
processing 101/131
126
17
processing 102/131
28
11


text was truncated: 337 -> 300


processing 103/131
337
14
processing 104/131
115
10


text was truncated: 350 -> 300


processing 105/131
350
20
processing 106/131
95
37


text was truncated: 422 -> 300


processing 107/131
422
23


text was truncated: 345 -> 300


processing 108/131
345
24


text was truncated: 2952 -> 300


processing 109/131
2952
45
processing 110/131
50
12


text was truncated: 493 -> 300


processing 111/131
493
9
processing 112/131
122
10
processing 113/131
245
11
processing 114/131
45
19
processing 115/131
71
12
processing 116/131
67
16
processing 117/131
41
13
processing 118/131
48
19
processing 119/131
95
23
processing 120/131
38
29


text was truncated: 776 -> 300


processing 121/131
776
17
processing 122/131
41
16
processing 123/131
88
10
processing 124/131
136
17


text was truncated: 1469 -> 300


processing 125/131
1469
15


text was truncated: 1109 -> 300


processing 126/131
1109
14
processing 127/131
226
12
processing 128/131
84
9
processing 129/131
44
30
processing 130/131
129
12


In [7]:
def schema_check(record, fields):
    """Print every mismatch between one record and the collection fields.

    The function only reports problems on stdout; it does not modify
    ``record`` and does not raise for a record that fails a check.

    Checks performed:
      * keys of ``record`` that are not the name of any field;
      * fields (other than ``"id"``) that are absent from ``record``;
      * INT64 fields whose value is not an ``int``;
      * VARCHAR fields whose value is not a ``str`` or is longer than the
        field's ``max_length``;
      * FLOAT_VECTOR fields whose length differs from the field's ``dim``.

    Args:
        record: mapping of column name to value for a single row.
        fields: iterable of field descriptors exposing ``name``, ``dtype``
            and, depending on the dtype, ``params["max_length"]`` or ``dim``.

    Returns:
        None.
    """
    excess_columns = set(record.keys()) - {f.name for f in fields}
    if excess_columns:
        print(f"additional columns: {excess_columns}")
    for f in fields:
        column_name = f.name
        # The "id" column is skipped: the record is not expected to carry it.
        if column_name == "id":
            continue
        if column_name not in record:
            print(f"missing: {column_name}")
            continue
        value = record[column_name]
        if f.dtype == DataType.INT64:
            if not isinstance(value, int):
                print(f"wrong data type: {column_name}")
        elif f.dtype == DataType.VARCHAR:
            if not isinstance(value, str):
                print(f"wrong data type: {column_name}")
            # The length is only checked for real strings, so a value without
            # a length (e.g. an int) is reported instead of crashing in len().
            # NOTE(review): this compares the character count; confirm whether
            # the database measures max_length in bytes.
            elif f.params["max_length"] < len(value):
                print(f"too long string: {column_name} ({len(value)})")
        elif f.dtype == DataType.FLOAT_VECTOR:
            if len(value) != f.dim:
                print(f"dim mismatch {column_name}: {len(value)} != {f.dim}")

In [8]:
import copy


# Build the insert payload from a deep copy so the entries of `docs` are not
# modified by the clean-up below.
insert_data = copy.deepcopy(docs)
for elem in insert_data:
    # "num_view" has no matching field in `fields`, so it is removed.
    elem.pop("num_view")
    # "answers" is declared as a VARCHAR field; store the str() form of the value.
    elem["answers"] = str(elem["answers"])
    # Keep at most the first 10000 characters of the question.
    elem["question"] = elem["question"][:10000]
    # Report (on stdout) anything that still does not match the field list.
    schema_check(elem, fields)

In [10]:
# Insert all prepared records into the STO_CLNAME collection in one call and
# keep the returned summary in `res`.
res = client.insert(collection_name=STO_CLNAME, data=insert_data)

In [11]:
# Show the insert summary; the output below lists 'insert_count' and the 'ids'
# assigned to the inserted rows.
res

{'insert_count': 131,
 'ids': [451482685656268800, 451482685656268801, 451482685656268802, 451482685656268803, 451482685656268804, 451482685656268805, 451482685656268806, 451482685656268807, 451482685656268808, 451482685656268809, 451482685656268810, 451482685656268811, 451482685656268812, 451482685656268813, 451482685656268814, 451482685656268815, 451482685656268816, 451482685656268817, 451482685656268818, 451482685656268819, 451482685656268820, 451482685656268821, 451482685656268822, 451482685656268823, 451482685656268824, 451482685656268825, 451482685656268826, 451482685656268827, 451482685656268828, 451482685656268829, 451482685656268830, 451482685656268831, 451482685656268832, 451482685656268833, 451482685656268834, 451482685656268835, 451482685656268836, 451482685656268837, 451482685656268838, 451482685656268839, 451482685656268840, 451482685656268841, 451482685656268842, 451482685656268843, 451482685656268844, 451482685656268845, 451482685656268846, 451482685656268847, 451482685