# Inference

```
docker-compose up 
```

## connection

In [1]:
from qdrant_client import QdrantClient
from qdrant_client.http.models import Distance, VectorParams

client = QdrantClient("localhost", port=6333)
client.recreate_collection(
    collection_name="business",
    vectors_config=VectorParams(size=128, distance=Distance.DOT),
)

True

In [2]:
client.get_collection(collection_name="business")

CollectionInfo(status=<CollectionStatus.GREEN: 'green'>, optimizer_status=<OptimizersStatusOneOf.OK: 'ok'>, vectors_count=0, indexed_vectors_count=0, points_count=0, segments_count=5, config=CollectionConfig(params=CollectionParams(vectors=VectorParams(size=128, distance=<Distance.DOT: 'Dot'>), shard_number=1, replication_factor=1, write_consistency_factor=1, on_disk_payload=True), hnsw_config=HnswConfig(m=16, ef_construct=100, full_scan_threshold=10000, max_indexing_threads=0, on_disk=False, payload_m=None), optimizer_config=OptimizersConfig(deleted_threshold=0.2, vacuum_min_vector_number=1000, default_segment_number=0, max_segment_size=None, memmap_threshold=None, indexing_threshold=20000, flush_interval_sec=5, max_optimization_threads=1), wal_config=WalConfig(wal_capacity_mb=32, wal_segments_ahead=0), quantization_config=None), payload_schema={})

## Populate collection

In [1]:
import pandas as pd
import dgl
import torch

In [2]:
# we can do it also as call on Neo4j
business_ids = pd.read_csv('preprocessed/business_ids.csv')
category_ids = pd.read_csv('preprocessed/category_ids.csv')
category_df = pd.read_csv('neo4j_csvs/categories.csv')
category_business_rels = pd.read_csv('neo4j_csvs/category_business_rels.csv')
review_business_rels = pd.read_csv('neo4j_csvs/review_business_rels.csv')
user_review_rels = pd.read_csv('neo4j_csvs/user_review_rels.csv')

In [5]:
graph = dgl.load_graphs('training_data/graph.dgl')
graph = graph[0][0]
edges = {}
for canonical_etype in graph.canonical_etypes:
    edges[canonical_etype] = graph.edges(etype=canonical_etype)

# we add the reversed relations
edges[('category', 'category_to_business', 'business')] = (graph.edges(etype='business_has_category')[1], graph.edges(etype='business_has_category')[0])
edges[('business', 'business_to_review', 'review')] = (graph.edges(etype='review_to_business')[1], graph.edges(etype='review_to_business')[0])
edges[('business', 'business_to_tip', 'tip')] = (graph.edges(etype='tip_to_business')[1], graph.edges(etype='tip_to_business')[0])
edges[('review', 'review_to_user', 'user')] = (graph.edges(etype='user_to_review')[1], graph.edges(etype='user_to_review')[0])
edges[('tip', 'tip_to_user', 'user')] = (graph.edges(etype='user_to_tip')[1], graph.edges(etype='user_to_tip')[0])

num_nodes_dict = {} 
for ntype in graph.ntypes:
    num_nodes_dict[ntype] = graph.nodes(ntype).shape[0]

g = dgl.heterograph(edges, num_nodes_dict = num_nodes_dict)
g.ndata['feat'] = {k: torch.tensor(v, dtype=torch.float32) for k, v in graph.ndata['feat'].items() }
del graph

  g.ndata['feat'] = {k: torch.tensor(v, dtype=torch.float32) for k, v in graph.ndata['feat'].items() }


In [6]:
business_ids # ordered as business in graph dgl

Unnamed: 0,business_id:ID
0,cc3c6f99cdb7d899625e4e7b8d171a06
1,e1a902b3497f013225f4691a03c73a5a
2,63cee85ba2c884484d2104a806628c9d
3,03e04ca2a13470c101717367fdb707a7
4,ceba4d84e8014b380329fabe76308016
...,...
150238,1f1304e276347c5faee365f9abb00799
150239,ca111c03164a5425a9c5167b3a638895
150240,ff03cc577392e0ac794cdb4ebac46c3b
150241,abb853eceedb9401aba67a2897136eff


In [7]:
category_business_rels # business -> category

Unnamed: 0,:START_ID,:END_ID,:TYPE
0,cc3c6f99cdb7d899625e4e7b8d171a06,f77ccbdb203c19d3d52b12a85f33faf5,HAS_CATEGORY
1,cc3c6f99cdb7d899625e4e7b8d171a06,5dc2a02d462a6822a49f8419cdfcf29f,HAS_CATEGORY
2,cc3c6f99cdb7d899625e4e7b8d171a06,957bd1f3ec5f6e976c4e82257f55d1fb,HAS_CATEGORY
3,cc3c6f99cdb7d899625e4e7b8d171a06,5df806cd4d51dddbb591ad8df0fd4c42,HAS_CATEGORY
4,cc3c6f99cdb7d899625e4e7b8d171a06,6e3a69b2ebe159d183c7a7b83d0bf564,HAS_CATEGORY
...,...,...,...
668587,abb853eceedb9401aba67a2897136eff,af4b3609b5a35bb1d3796fac29e1c7b6,HAS_CATEGORY
668588,d367ac4f22db972bce70ab8fe5a50e2b,b11bd467d57c86f09435a96a212e89be,HAS_CATEGORY
668589,d367ac4f22db972bce70ab8fe5a50e2b,4075aad3e4473e2827bc6988e03a6e90,HAS_CATEGORY
668590,d367ac4f22db972bce70ab8fe5a50e2b,71b82d588f8ca5360dd8f79f4452f61a,HAS_CATEGORY


In [8]:
from utils import HGT
node_dict = { ntype: g.ntypes.index(ntype) for ntype in g.ntypes }
edge_dict = { canonical_etype: g.canonical_etypes.index(canonical_etype) for canonical_etype in g.canonical_etypes }
feature_dim_dict = { ntype: g.ndata['feat'][ntype].shape[1] for ntype in g.ntypes }
model = HGT(node_dict, edge_dict, feature_dim_dict, n_hid=256, n_out=128, n_layers=4, n_heads=8, use_norm=False)
sampler = dgl.dataloading.NeighborSampler([24, 24, 24, 24])

In [None]:
model.load_state_dict(torch.load('models/best_model.pt')['model_state_dict'])

In [10]:
from qdrant_client.http.models import PointStruct
from qdrant_client.http.models import UpdateStatus

batch_size = 64
model.eval()
def split(df, chunk_size):
    for i in range(df.index.start, df.index.stop, chunk_size):
        yield list(range(i, i + chunk_size if df.index.stop > i + chunk_size else df.index.stop)), df.loc[i:i + chunk_size - 1]

for batch in split(business_ids, batch_size):
    with torch.no_grad():
        pos_batch_business_ids, batch_business_ids = batch

        pos_block_business = [blocks for _, _, blocks in dgl.dataloading.DataLoader(
            g, { 'business': pos_batch_business_ids }, sampler,
            batch_size=len(pos_batch_business_ids), shuffle=False, drop_last=False, num_workers=1)][0]
        pos_business_logits = model(pos_block_business, 'business')
        batch_categories = batch_business_ids.merge(category_business_rels, left_on='business_id:ID', right_on=':START_ID')
        businesses_its_categories = batch_categories.groupby(by=["business_id:ID"], sort=False)[':END_ID'].apply(list).reset_index()
        
        batch_business_ids = batch_business_ids['business_id:ID'].to_list()
        businesses_its_categories = businesses_its_categories[':END_ID'].to_list()
        pos_business_logits = pos_business_logits.tolist()
        points = []
        for i in range(len(pos_batch_business_ids)):
            points.append(
                PointStruct(id=pos_batch_business_ids[i], vector=pos_business_logits[i], payload={'id': batch_business_ids[i], 'categories': businesses_its_categories[i]})
            )
        operation_info = client.upsert(
            collection_name="business",
            wait=True,
            points=points
        )
        assert operation_info.status == UpdateStatus.COMPLETED


  assert input.numel() == input.storage().size(), (


In [14]:
from qdrant_client import QdrantClient

client = QdrantClient("localhost", port=6333, timeout=100)

client.create_snapshot(collection_name='business')

SnapshotDescription(name='business-4626092216545756212-2023-04-03-15-12-56.snapshot', creation_time=None, size=67364864)

In [11]:
client.get_collection(collection_name="business")

CollectionInfo(status=<CollectionStatus.GREEN: 'green'>, optimizer_status=<OptimizersStatusOneOf.OK: 'ok'>, vectors_count=64, indexed_vectors_count=0, points_count=64, segments_count=5, config=CollectionConfig(params=CollectionParams(vectors=VectorParams(size=128, distance=<Distance.DOT: 'Dot'>), shard_number=1, replication_factor=1, write_consistency_factor=1, on_disk_payload=True), hnsw_config=HnswConfig(m=16, ef_construct=100, full_scan_threshold=10000, max_indexing_threads=0, on_disk=False, payload_m=None), optimizer_config=OptimizersConfig(deleted_threshold=0.2, vacuum_min_vector_number=1000, default_segment_number=0, max_segment_size=None, memmap_threshold=None, indexing_threshold=20000, flush_interval_sec=5, max_optimization_threads=1), wal_config=WalConfig(wal_capacity_mb=32, wal_segments_ahead=0), quantization_config=None), payload_schema={})

## Recommendations

In [12]:
user_ids = pd.read_csv('preprocessed/user_ids.csv')

In [13]:
input_user_id = 'MGPQVLsODMm9ZtYQW-g_OA'
input_category_id = '386c1f850fbd5f478fb4ef8a134c1740'

In [14]:
pos_user_id = user_ids[user_ids['user_id:ID'] == input_user_id].index[0] # getting DGL id to create user embedding

In [15]:
user_ids

Unnamed: 0,user_id:ID
0,qVc8ODYU5SZjKXVBgXdI7w
1,j14WgRoU_-2ZE1aw1dXrJg
2,2WnXYQFK0hXEoTxPtV2zvg
3,SZDeASXq7o05mMNLshsdIA
4,hA5lMy-EnncsH4JoR-hFGQ
...,...
1987892,fB3jbHi3m0L2KgGOxBv6uw
1987893,68czcr4BxJyMQ9cJBm6C7Q
1987894,1x3KMskYxOuJCjRz70xOqQ
1987895,ulfGl4tdbrH05xKzh5lnog


In [16]:
already_reviewed_businesses = user_ids[user_ids['user_id:ID'] == input_user_id] \
    .merge(user_review_rels, left_on='user_id:ID', right_on=':START_ID') \
    .merge(review_business_rels, left_on=':END_ID', right_on=':START_ID', suffixes=('rev', 'business'))[':END_IDbusiness'].to_list()

In [17]:
len(already_reviewed_businesses)

127

In [20]:
pos_block_user = [blocks for _, _, blocks in dgl.dataloading.DataLoader(
    g, {'user': [pos_user_id] }, sampler,
    batch_size=batch_size, shuffle=False, drop_last=False, num_workers=1)][0]
with torch.no_grad():
    pos_user_logits = model(pos_block_user, 'user')

In [27]:
from qdrant_client.http.models import Filter, FieldCondition, MatchValue, MatchAny


search_result = client.search(
    collection_name="business",
    query_vector=pos_user_logits.tolist()[0], 
    query_filter=Filter(
        must=[
            FieldCondition(
                key="categories",
                match=MatchValue(value=input_category_id)
            )
        ],
        must_not=[
            FieldCondition(
                key="id", 
                match=MatchAny(any=already_reviewed_businesses)
            ),
        ],
    ),
    limit=100
)

In [None]:
search_result