## Notes
- 215서버 가상환경 bkms에서 작업

In [1]:
import os

import nltk
import numpy as np
import pandas as pd
import psycopg2
import psycopg2.extras as extras
import torch
from nltk import word_tokenize, sent_tokenize
from torch.nn.functional import normalize
from tqdm.notebook import tqdm

## Load

In [2]:
%%time
# Load the abstract embeddings from the local .npy file into a NumPy array.
emb = np.load('abstract_embedding.npy')

CPU times: user 728 µs, sys: 356 ms, total: 357 ms
Wall time: 356 ms


In [4]:
# Display the dimensions of the loaded embedding array.
emb.shape

(644800, 768)

In [3]:
%%time
# Convert the array into nested Python lists (one inner list per row) so that
# each row can later be stored as a single DataFrame cell.
emb_list = emb.tolist()

CPU times: user 14.7 s, sys: 4.42 s, total: 19.1 s
Wall time: 18.9 s


In [6]:
# Display the number of embedding rows after the list conversion.
len(emb_list)

644800

In [4]:
%%time
# Load the citation table from the project PostgreSQL server into a DataFrame.

# Connection settings come from the environment so that credentials are never
# stored in the notebook. Non-secret settings fall back to the project
# defaults; the password has no fallback and must be provided via PGPASSWORD
# (a missing value raises KeyError before any connection is attempted).
# NOTE(review): values are interpolated into the DSN unquoted, so a password
# containing spaces or quotes would need libpq-style quoting.
connection_info = "host={host} dbname={dbname} user={user} password={password} port={port}".format(
    host=os.environ.get("PGHOST", "147.47.200.145"),
    dbname=os.environ.get("PGDATABASE", "teamdb16"),
    user=os.environ.get("PGUSER", "team16"),
    password=os.environ["PGPASSWORD"],
    port=os.environ.get("PGPORT", "34543"),
)
conn = psycopg2.connect(connection_info)

try:
    # Read the whole table into a pandas DataFrame.
    # NOTE(review): the query has no ORDER BY, so the row order is whatever the
    # server returns. Later cells attach embeddings by position — confirm that
    # this order matches the order used when the embeddings were generated.
    server_df = pd.read_sql('SELECT * FROM citation_data', conn)

except psycopg2.Error as e:
    # Report the database error, then re-raise: swallowing it would leave
    # `server_df` undefined and the failure would only surface later as a
    # confusing NameError in a downstream cell.
    print("DB error: ", e)
    raise

finally:
    # Always release the database connection, whether or not the query succeeded.
    conn.close()



CPU times: user 1.89 s, sys: 1.29 s, total: 3.18 s
Wall time: 11.7 s


## DF + Embedding 합치기

In [5]:
# Attach the embeddings as a new column; values are assigned by position
# (element i of emb_list goes to row i of server_df).
# NOTE(review): assumes the rows of the .npy file are in the same order as the
# rows returned by the SQL query — confirm, since nothing here checks it.
server_df['embedding'] = emb_list

In [6]:
# Keep only the two columns used from here on: the id and its embedding.
df = server_df[['id', 'embedding']]

In [7]:
# Preview the first rows of the id/embedding frame.
df.head()

Unnamed: 0,id,embedding
0,55323df545cec66b6f9e0cd9,"[0.03598175197839737, -0.0326111875474453, 0.0..."
1,55323df545cec66b6f9e0ce2,"[0.04400291666388512, -0.004605427850037813, 0..."
2,55323df845cec66b6f9e0dfd,"[0.004515576176345348, -0.012031725607812405, ..."
3,55323df845cec66b6f9e0e0d,"[0.0014496600488200784, -0.05030417442321777, ..."
4,55323dfa45cec66b6f9e0e97,"[0.03331814706325531, -0.014083134941756725, 0..."


## FOS가 Database인 row만 추출

In [8]:
# Select the rows whose 'fos' value is exactly 'Database'. The boolean mask is
# computed on server_df (which still has the 'fos' column) and applied to df,
# which was derived from server_df by column selection.
df_db = df[server_df['fos']=='Database']
# Display the selected rows.
df_db

Unnamed: 0,id,embedding
81475,53e99bc0b7602d970246566f,"[0.033884674310684204, -0.013931008987128735, ..."
81476,53e9a0aeb7602d9702998956,"[-0.005794344004243612, -0.01751815341413021, ..."
81477,53e9a32eb7602d9702c3b5f2,"[0.042893629521131516, 0.02281208708882332, 0...."
81478,53e9a5e2b7602d9702f0f666,"[0.04637527838349342, -0.04265362024307251, 0...."
81479,53e9a64ab7602d9702f773e2,"[0.02971530146896839, -0.027375876903533936, 0..."
...,...,...
101282,5e56424393d709897ccf850c,"[-0.01572621427476406, -0.028321608901023865, ..."
101284,5e5e190b93d709897ce49b5e,"[0.031522300094366074, -0.012752724811434746, ..."
101285,5e67652c91e011e07454bac1,"[0.0552104115486145, 0.0031533159781247377, 0...."
101286,5e68b99d91e0115a6fd942b7,"[-0.02662317454814911, -0.024163400754332542, ..."


## FOS가 Machine learning인 row만 추출 (20000개 샘플링)

In [9]:
# Select the rows whose 'fos' value is exactly 'Machine learning', then draw
# 20000 of them at random; random_state=42 fixes the seed for the draw.
df_ml = df[server_df['fos']=='Machine learning'].sample(n=20000, random_state=42)
# Display the sampled rows.
df_ml

Unnamed: 0,id,embedding
309493,53e9a69fb7602d9702fd2ba2,"[0.006964611820876598, 0.04274142161011696, 0...."
389858,5bdc316717c44a1f58a06f35,"[-0.0011072474298998713, -0.020074525848031044..."
384300,5a4aef9017c44a2190f79790,"[0.04698636010289192, 0.043299369513988495, 0...."
393082,5c790c78f56def97985832e2,"[0.06247767433524132, 0.04102949798107147, 0.0..."
396894,5decef723a55ac3b267b6df8,"[-0.016124337911605835, -0.004462655168026686,..."
...,...,...
297057,53e99e7fb7602d9702743649,"[0.008742188103497028, -0.02152179554104805, 0..."
393493,5c8746794895d9cbc6f9bdfd,"[0.029180074110627174, 0.01242499053478241, 0...."
358436,558b1bd1e4b031bae1fb8a03,"[0.018051421269774437, -0.003930770326405764, ..."
293015,53e99bfeb7602d97024a7d0e,"[0.023785250261425972, -0.007056978065520525, ..."


## Milvus

In [17]:
from pymilvus import connections, FieldSchema, CollectionSchema, DataType, Collection, utility

# Notebook-wide settings for the Milvus cells.
# NOTE(review): TOP_K is not referenced in the visible cells; search() declares
# its own default of 3 — confirm whether the two are meant to be linked.
TOP_K = 3
# Name of the Milvus collection that the cells below create and fill.
COLLECTION_NAME = 'team16_project_db'

def create_collection(collection_name, dim):
    """Create a Milvus collection of string ids and float vectors, with an index.

    This is destructive: if a collection called ``collection_name`` already
    exists, it is dropped before the new one is created.

    Args:
        collection_name: Name of the collection to (re)create.
        dim: Dimensionality of the vectors stored in the ``embedding`` field.

    Returns:
        The newly created ``Collection``, with an IVF_FLAT / L2 index defined
        on its ``embedding`` field.
    """
    if utility.has_collection(collection_name):
        utility.drop_collection(collection_name)

    fields = [
        # Primary key: caller-supplied string id (auto_id disabled), up to 500 chars.
        # The keyword was previously misspelled as `descrition`, so the
        # description was never applied to this field.
        FieldSchema(name='id', dtype=DataType.VARCHAR, description='ids', max_length=500, is_primary=True, auto_id=False),
        FieldSchema(name='embedding', dtype=DataType.FLOAT_VECTOR, description='embedding vectors', dim=dim)
    ]
    schema = CollectionSchema(fields=fields, description='Team 16 Project')
    collection = Collection(name=collection_name, schema=schema)

    # Create an IVF_FLAT index on the vector field, using L2 distance and
    # nlist=2048.
    index_params = {
        'metric_type':'L2',
        'index_type':"IVF_FLAT",
        'params':{"nlist":2048}
    }
    collection.create_index(field_name="embedding", index_params=index_params)

    print("\ncollection created:", collection_name)
    return collection

def search(collection, search_vectors, top_k=3):
    """Run a vector search on the ``embedding`` field and return the top hits.

    The search is issued with the L2 metric and ``nprobe`` set to 16. Only the
    hits of the first query vector are returned; results for any further
    vectors in ``search_vectors`` are ignored.

    Args:
        collection: Object exposing a ``search(**kwargs)`` method.
        search_vectors: Query vectors, passed through as the ``data`` argument.
        top_k: Maximum number of hits requested per query vector.

    Returns:
        A ``(ids, distances)`` pair of lists for the first query vector.
    """
    hits_per_query = collection.search(
        data=search_vectors,
        anns_field='embedding',
        param={"metric_type": 'L2', "params": {"nprobe": 16}},
        limit=top_k,
    )

    # Only the first query's hits are unpacked.
    first_query_hits = hits_per_query[0]
    result_id = []
    result_dis = []
    for hit in first_query_hits:
        result_id.append(hit.id)
        result_dis.append(hit.distance)

    return result_id, result_dis

In [11]:
def has_collection(name):
    """Return the result of ``utility.has_collection(name)`` for the given collection name."""
    return utility.has_collection(name)

def load_collection(collection):
    """Call ``load()`` on the given collection; returns nothing."""
    collection.load()

def release_collection(collection):
    """Call ``release()`` on the given collection; returns nothing."""
    collection.release()

def drop_collection(name):
    """Drop the Milvus collection called ``name`` and print a confirmation line."""
    Collection(name).drop()
    print(f"\nDrop collection: {name}")
    
def drop_index(collection):
    """Call ``drop_index()`` on the collection, then print a confirmation line."""
    collection.drop_index()
    confirmation = "\nDrop index sucessfully"
    print(confirmation)
    
def list_collections():
    """Print a heading followed by the value of ``utility.list_collections()``."""
    print("\nlist collections:")
    collection_names = utility.list_collections()
    print(collection_names)
    
def get_entity_num(collection):
    """Print a heading followed by the collection's ``num_entities`` value.

    Nothing is returned; the count is only written to standard output.
    """
    print("\nThe number of entity:")
    entity_count = collection.num_entities
    print(entity_count)
    
def set_properties(collection, ttl_seconds=1800):
    """Set the ``collection.ttl.seconds`` property on the given collection.

    Args:
        collection: Object exposing ``set_properties(properties=...)``.
        ttl_seconds: Value written to ``collection.ttl.seconds``. Defaults to
            1800, the value that was previously hard-coded.

    NOTE(review): per the Milvus TTL feature, entities presumably expire this
    many seconds after insertion — confirm that a 30-minute default is intended
    for data that should persist.
    """
    collection.set_properties(properties={"collection.ttl.seconds": ttl_seconds})

In [12]:
# Connect to the Milvus server (host and port are hard-coded here).
connections.connect(host='147.47.200.145', port='39530')

In [18]:
# Drop the collection if it already exists, so the run starts from empty.
# (create_collection() performs the same existence check and drop itself.)
if has_collection(COLLECTION_NAME):
    drop_collection(COLLECTION_NAME)
    
# Create the collection with 768-dimensional vectors, matching the second
# dimension of the embedding array loaded earlier.
collection = create_collection(COLLECTION_NAME, 768)

# Set the collection-level TTL property via set_properties().
set_properties(collection)


collection created: team16_project_db


In [19]:
%%time
# Insert the Database-FOS rows (id + embedding columns) into the collection.
# Inserting 20000 or more rows in a single call fails with a
# RESOURCE_EXHAUSTED error, so each insert is kept below that size.
collection.insert(df_db)

CPU times: user 2.87 s, sys: 6.19 ms, total: 2.88 s
Wall time: 4.74 s


(insert count: 19791, delete count: 0, upsert count: 0, timestamp: 441962994957549571, success count: 19791, err count: 0)

In [20]:
%%time
# Flush the collection (the original note describes this as flushing it from
# memory) before reading the entity count.
collection.flush()
# Print the number of entities now reported by the collection.
get_entity_num(collection)


The number of entity:
19791
CPU times: user 4.85 ms, sys: 726 µs, total: 5.58 ms
Wall time: 713 ms


In [22]:
# Print the names of all collections currently known to the connected server.
list_collections()


list collections:
['team16_project', 'team16_project_ml', 'team16_project_db']
