In [1]:
from pathlib import Path
from pymilvus import MilvusClient
import os
import numpy as np
import pandas as pd
import torch
import tqdm
import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint
import torch.nn.functional as F
from data_modules.mind_recsys_data import MINDRecSysDataModule
from data_modules.mind_component import load_news_data, load_history_data
from modules.llama_decoder import LlamaDecoderForNextArticle
from modules.res_vqvae import RVQVAE
from modules.lstur import LSTUR
from data_modules.indices_data import SeqVQVAEDataModule
os.environ['CUDA_VISIBLE_DEVICES']  = '2'

In [2]:
from modules.aspect_enc import AspectRepr


seqvqvae = LlamaDecoderForNextArticle.load_from_checkpoint(
        '/home/users1/hardy/hardy/project/vae/src/checkpoints/seqvqvae_std_sts-epoch=09-val_loss=4.4361.ckpt',
        codebook_size=768,
        hidden_size=768,
        intermediate_size=2048,
        num_hidden_layers=10,
        num_attention_heads=12,
        max_position_embeddings=4090)
seqvqvae.eval()

rvqvae = RVQVAE.load_from_checkpoint('checkpoints/rvqvae_std_sts-epoch=12-val_loss=0.72483.ckpt', 
        codebook_dim=512, 
        codebook_sizes=[414],
        num_quantizers=1,
        encoder_hidden_size=128,
        decoder_hidden_size=128,
        input_size=1024)
rvqvae.eval()

RVQVAE(
  (bn): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (encoder): Sequential(
    (0): Linear(in_features=1024, out_features=128, bias=True)
    (1): ReLU()
    (2): Linear(in_features=128, out_features=512, bias=True)
  )
  (rvq_layer): ResidualVectorQuantizer(
    (quantizers): ModuleList(
      (0): VectorQuantizer(
        (codebook): Embedding(414, 512)
      )
    )
  )
  (decoder): Sequential(
    (0): Linear(in_features=512, out_features=128, bias=True)
    (1): ReLU()
    (2): Linear(in_features=128, out_features=1024, bias=True)
  )
  (loss_fn): MSELoss()
)

In [3]:
def load_aspect_vectors(path: Path):
    """
    Load aspect vectors from a given path.
    """
    data = {}
    with open(path, 'r') as f:
        for line in f:
            parts = line.strip().split()
            nid = int(parts[0])
            vector = [float(x) for x in parts[1:]]
            data[nid] = np.array(vector, dtype=np.float32)
    return data

dev_data_path='/home/users1/hardy/hardy/datasets/mind/MINDlarge_dev'
news = load_news_data(Path(dev_data_path), 'dev')
behavior = load_history_data(Path(dev_data_path), 'dev', news, fix_history=False)
aspect_vector = load_aspect_vectors(Path('/home/users1/hardy/hardy/project/vae/outputs/mind/dev_mind_std_sts_aspect_vectors.txt'))

In [4]:
code_dict_path = '/home/users1/hardy/hardy/project/vae/outputs/mind/dev_mind_std_sts_code_dict.txt'
with open(code_dict_path, 'r') as f:
    code_dict = {}
    for line in f:
        parts = line.strip().split()
        nid = parts[0]
        code = int(parts[1])
        code_dict[nid] = code

In [5]:
for idx, row in tqdm.tqdm(behavior.iterrows(), total=len(behavior)):
    history_indices = []
    for article_id in row['history'].split():
        article_id = article_id[1:]
        index = code_dict[article_id]
        history_indices.append(str(index))
    behavior.loc[idx, 'history_indices'] = ' '.join(history_indices)


  0%|          | 0/359723 [00:00<?, ?it/s]

100%|██████████| 359723/359723 [00:55<00:00, 6505.59it/s]


In [6]:
seqvqvae_data_module = SeqVQVAEDataModule(
    test_df = behavior,
    batch_size=4,
    max_len=10000,
    overlap=0
)
seqvqvae_data_module.setup('test')

In [7]:
dataloader = seqvqvae_data_module.test_dataloader()

In [8]:
results = []
for batch in tqdm.tqdm(dataloader):
    batch = [x.cuda() for x in batch]
    outputs = seqvqvae(*batch)
    last_hidden_state = outputs.last_hidden_state
    logits = seqvqvae.lm_head(last_hidden_state)
    prob_dists = torch.nn.Softmax(dim=-1)(logits[:,-1,:])
    predicted_indices = torch.topk(prob_dists, k=25, dim=-1).indices
    results.append(predicted_indices.cpu().numpy())

100%|██████████| 89931/89931 [16:41<00:00, 89.77it/s]


In [12]:
flat_results = np.concatenate(results, axis=0)
# Save the flat_results numpy array to a file
np.save('flat_results.npy', flat_results)

In [13]:
flat_results

array([[140, 198, 345, ..., 201,  89, 400],
       [140, 141,   5, ..., 184,  79, 237],
       [140, 131, 321, ..., 109, 102, 390],
       ...,
       [177, 140, 198, ..., 345, 237, 293],
       [140, 345, 243, ..., 390, 196,  46],
       [140,  42, 198, ..., 233, 390,  46]], shape=(359723, 25))

In [None]:
# Flatten the results list to match the number of rows in the behavior DataFrame
# Each element in results is a numpy array of shape (batch_size, topk)
# We'll concatenate along axis 0 to get (num_behaviors, topk)


assert flat_results.shape[0] == len(behavior)


In [None]:
client = MilvusClient("candidates.db")

In [None]:
def load_grouped_articles(file_path):
    # Placeholder for loading grouped articles
    columns = ['date', 'articles']
    df = pd.read_csv(file_path, sep='\t', header=None, names=columns)
    return df
# Generate all articles given a date
df = load_grouped_articles('/home/users1/hardy/hardy/datasets/mind/MINDlarge_dev/grouped_behaviors.tsv')


In [None]:


if client.has_collection(collection_name="dev_std"):
    client.drop_collection(collection_name="dev_std")
client.create_collection(
    collection_name="dev_std",
    dimension=1024,  # The vectors we will use in this demo has 768 dimensions
)

In [None]:
df['articles_split'] = df['articles'].apply(lambda x: x.split())
i = 0
for idx, row in tqdm.tqdm(df.iterrows(), total=len(df)):
    date = row['date']
    articles = row['articles_split']
    vectors = []
    for article in articles:
        article_id = int(article[1:])  # Remove the leading 'N'
        if article_id in aspect_vector:
            vector = aspect_vector[article_id]
            client.insert(
                collection_name=f"dev_std",
                data={"vector": vector, "id": i, "date": date, "article_id": article},
            )
            i += 1
        else:
            print(f"Article {article} not found in aspect_vector")

100%|██████████| 1/1 [16:53<00:00, 1013.69s/it]


In [28]:
# Fetch the first 10 data entries from the Milvus collection and show the total length
first_10 = client.query(
    collection_name="dev_std",
    output_fields=["id", "vector", "date", "article_id"],
    limit=10
)
total_count = client.get_collection_stats("dev_std")["row_count"]

print("First 10 entries:", first_10)
print("Total number of entries:", total_count)

First 10 entries: data: ["{'id': 0, 'date': '2019-11-15', 'article_id': 'N77675', 'vector': [np.float32(1.9952806), np.float32(0.73241216), np.float32(0.15720135), np.float32(1.8058813), np.float32(-1.9448398), np.float32(0.35916188), np.float32(-1.5362413), np.float32(0.18632725), np.float32(1.0178922), np.float32(-1.8010275), np.float32(0.63838404), np.float32(-0.42765114), np.float32(1.9597986), np.float32(-0.19001888), np.float32(-1.0058689), np.float32(-1.0423645), np.float32(-2.1629148), np.float32(1.0925298), np.float32(-1.6098374), np.float32(-0.3003936), np.float32(2.2687006), np.float32(-0.22901349), np.float32(0.68278825), np.float32(0.4674465), np.float32(1.0500894), np.float32(1.0680149), np.float32(1.3114581), np.float32(0.2847222), np.float32(-0.7773736), np.float32(1.3578644), np.float32(0.67630523), np.float32(-0.8773436), np.float32(1.9708133), np.float32(1.796619), np.float32(-0.4289525), np.float32(0.112669), np.float32(1.0401715), np.float32(-0.13447803), np.float3

In [29]:
codebooks = rvqvae.rvq_layer.quantizers[0].codebook.weight.data

In [30]:
selected_indices = codebooks[flat_results]

In [31]:
selected_indices[0].shape

torch.Size([25, 512])

In [32]:
all_target_vectors = []
for i in tqdm.tqdm(range(selected_indices.shape[0])):
    target_vectors = rvqvae.decoder(selected_indices[i])
    all_target_vectors.append(target_vectors.detach().cpu().numpy())

100%|██████████| 359723/359723 [01:24<00:00, 4251.78it/s]


In [33]:
all_target_vectors_list = [[vec for vec in vecs] for vecs in all_target_vectors]

In [34]:
all_target_vectors_list[0][0]

array([-0.02059989,  0.08748009,  0.09058606, ...,  0.07643966,
        0.0637296 ,  0.10439934], shape=(1024,), dtype=float32)

In [37]:
all_unique_results = []
for i in tqdm.tqdm(range(len(all_target_vectors_list))):
    res = client.search(collection_name="dev_std",
                        search_params={"metric_type": "COSINE"}, 
                        anns_field="vector", 
                        data=all_target_vectors_list[i],
                        limit=1,output_fields=["id", "vector", "date", "article_id"])
                        # Flatten the results and collect unique article ids
    unique_ids = set()
    unique_results = []
    for hits in res:
        for hit in hits:
            if hit['id'] not in unique_ids:
                unique_ids.add(hit['id'])
                unique_results.append(hit)
    all_unique_results.append(unique_results)

100%|██████████| 359723/359723 [2:39:51<00:00, 37.51it/s]    


In [38]:
import pickle

with open('all_unique_results.pkl', 'wb') as f:
    pickle.dump(all_unique_results, f)

In [36]:
unique_ids = set()
unique_results = []
for hits in res:
    for hit in hits:
        if hit['id'] not in unique_ids:
            unique_ids.add(hit['id'])
            unique_results.append(hit)
unique_results

[{'id': 2170, 'distance': 0.39792799949645996, 'entity': {'id': 2170, 'vector': [-0.6384049654006958, 2.076688289642334, 0.7393279075622559, 2.311272382736206, -1.626744270324707, -0.07993075251579285, -0.9772454500198364, -0.05791698023676872, 1.2966983318328857, 1.3938915729522705, -0.8651633858680725, -0.3869216740131378, 1.145029067993164, -0.20852120220661163, -2.0616304874420166, -1.6757508516311646, -1.3861435651779175, 1.799403429031372, -2.152791976928711, 0.520451545715332, 1.1117682456970215, 0.3833470642566681, 0.5036287903785706, 0.24297428131103516, 2.0764448642730713, 0.91532963514328, 2.449953556060791, 1.133056402206421, -0.44914284348487854, 0.5628463625907898, 0.2401438057422638, -1.7925243377685547, 1.8111801147460938, 2.60597825050354, -0.3138400912284851, 0.8238584399223328, 0.21261107921600342, -0.5870572328567505, -1.2036000490188599, -0.2583828866481781, -0.7368589639663696, -0.3308340907096863, 1.3608065843582153, 1.0512926578521729, -2.106748342514038, 0.1655

In [22]:
flat_results.shape

(359723, 25)

In [184]:
lstur_dataloader = data_module.test_dataloader()

In [185]:
batch = next(iter(lstur_dataloader))

In [186]:
batch['user_history'].shape

torch.Size([1, 31, 103, 1])

In [170]:
data_module.data['test']

Unnamed: 0,impression_id,user_id,timestamp,history,impressions,history_category,history_subcategory,candidates,candidates_category,candidates_subcategory,labels,history_text,candidates_text,user_id_class
0,1,U134050,2019-11-15 08:55:22,N12246 N128820 N119226 N4065 N67770 N33446 N10...,N91737-0 N30206-0 N54368-0 N117802-0 N18190-0 ...,"[7, 3, 2, 9, 4, 1, 2, 1, 1, 1, 1, 5, 5, 5, 2, ...","[67, 6, 21, 23, 54, 85, 35, 59, 138, 1, 188, 9...","[N91737, N30206, N54368, N117802, N18190, N122...","[13, 1, 2, 12, 1, 7, 1, 11, 12, 7, 4, 4, 0, 4,...","[135, 32, 35, 48, 101, 62, 1, 71, 37, 53, 8, 8...","[0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]",[How I Got My Job: Becoming a Chef and Food St...,[13 Reasons Why's Christian Navarro Slams Disn...,88392
1,2,U254959,2019-11-15 11:42:35,N68431 N107520 N129836 N114848 N89408 N23264 N...,N119999-0 N24958-0 N104054-0 N33901-0 N9250-0 ...,"[2, 4, 2, 1, 2, 2, 3, 2, 13, 11, 2, 2, 4, 11, ...","[35, 31, 21, 32, 35, 35, 6, 35, 106, 87, 72, 2...","[N119999, N24958, N104054, N33901, N9250, N333...","[1, 8, 2, 9, 10, 6, 2, 2, 2, 2, 3, 9, 3, 1, 1,...","[6, 284, 72, 77, 65, 145, 82, 35, 82, 5, 16, 2...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",[Iowa woman killed by explosion at gender reve...,[I've been writing about tiny homes for a year...,510067
2,3,U499841,2019-11-15 09:08:21,N63858 N26834 N6379 N85484 N15229 N65119 N1047...,N18190-0 N89764-0 N91737-0 N54368-0 N49978-1 N...,"[4, 13, 1, 4, 9, 2, 2, 9, 4, 4, 10, 4, 2, 4, 4...","[8, 106, 59, 8, 19, 35, 5, 77, 8, 8, 134, 31, ...","[N18190, N89764, N91737, N54368, N49978, N2916...","[1, 0, 13, 2, 4, 4, 12, 1]","[101, 0, 135, 35, 8, 8, 48, 101]","[0, 0, 0, 0, 1, 0, 0, 0]","[NFL winners, losers: Cowboys need to rebound,...",[30 Best Black Friday Deals from Costco Costco...,206675
3,4,U107107,2019-11-15 05:50:31,N34519 N33286 N36085 N92079 N114848 N102674 N3...,N122944-1 N18190-0 N55801-0 N59297-0 N128045-0...,"[7, 2, 7, 2, 1, 2, 2, 8, 14, 7, 13, 8, 7, 13, ...","[67, 35, 62, 35, 32, 5, 21, 25, 130, 53, 106, ...","[N122944, N18190, N55801, N59297, N128045, N29...","[7, 1, 5, 8, 10, 4, 2, 4, 1, 2]","[62, 101, 9, 14, 134, 8, 35, 8, 1, 21]","[1, 0, 0, 0, 0, 0, 1, 0, 0, 0]",[Popeyes announces return date of its chicken ...,[The Real Reason McDonald's Keeps the Filet-O-...,439134
4,5,U492344,2019-11-15 05:02:25,N109183 N48453 N85005 N45706 N98923 N46069 N35...,N64785-0 N82503-0 N32993-0 N122944-0 N29160-0 ...,"[2, 9, 2, 2, 13, 12, 2, 1, 1, 4, 4, 1, 1, 2, 1...","[5, 44, 5, 5, 199, 48, 35, 117, 32, 54, 54, 1,...","[N64785, N82503, N32993, N122944, N29160, N628...","[1, 3, 1, 7, 4, 11, 14, 1, 1, 8, 6, 2, 10, 9, ...","[1, 17, 32, 62, 8, 46, 55, 1, 101, 14, 18, 35,...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",[Egypt uncovers 'huge cache' of ancient sealed...,"[Archie's Photo Album: Prince Harry, Duchess M...",534281
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
363459,376467,U199558,2019-11-15 08:12:35,N90686 N73122 N37206 N12575 N96616 N114714,N122640-0 N29160-0 N54368-1 N18190-0,"[1, 1, 9, 6, 2, 9]","[32, 1, 44, 18, 35, 23]","[N122640, N29160, N54368, N18190]","[4, 4, 2, 1]","[8, 8, 35, 101]","[0, 0, 1, 0]",[Woman Spots Deadly Animal Hiding In Photo Of ...,[This was uglier than a brawl. And Myles Garre...,537769
363460,376468,U356824,2019-11-15 06:10:05,N31305 N7742 N45909 N13422 N116312 N110755 N62...,N122640-0 N18190-0 N55801-0 N69938-0 N12384-0 ...,"[2, 2, 11, 5, 4, 4, 1, 4, 5, 12, 12, 2, 5, 11,...","[5, 2, 152, 9, 22, 8, 1, 22, 9, 48, 37, 21, 9,...","[N122640, N18190, N55801, N69938, N12384, N291...","[4, 1, 5, 1, 1, 4, 7, 7, 1, 2, 2, 4]","[8, 101, 9, 1, 32, 8, 62, 53, 1, 35, 21, 31]","[0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0]",[The Latest: Powerful typhoon makes landfall i...,[This was uglier than a brawl. And Myles Garre...,37405
363461,376469,U484114,2019-11-15 15:05:47,N122359 N104081 N79583 N60911 N20131 N29446 N2...,N46555-0 N28863-0 N129416-0 N112536-0 N64957-0...,"[13, 2, 4, 4, 12, 4, 2, 1, 2, 2, 4, 10, 10, 2,...","[135, 21, 54, 54, 163, 54, 35, 1, 5, 35, 8, 11...","[N46555, N28863, N129416, N112536, N64957, N21...","[0, 2, 14, 5, 2, 0, 11, 2, 1, 8, 4, 1]","[0, 5, 55, 9, 35, 0, 46, 21, 6, 76, 22, 101]","[0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0]","[Robert Forster, Oscar-Nommed Star of 'Jackie ...",[Federal Prosecutors Probe Giuliani's Links to...,454323
363462,376470,U719484,2019-11-15 15:30:41,N51171 N115826 N52641 N454 N32534 N127923 N443...,N56784-0 N28863-0 N26553-0 N2110-0 N99846-0 N1...,"[2, 8, 3, 12, 11, 12, 4, 11, 2, 1, 3, 11, 2, 1...","[5, 14, 17, 48, 87, 48, 39, 46, 35, 156, 16, 4...","[N56784, N28863, N26553, N2110, N99846, N19831...","[0, 2, 12, 1, 11, 4, 13, 8, 0, 2, 2, 0, 6, 7, ...","[0, 5, 48, 101, 46, 39, 51, 76, 0, 21, 35, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, ...","[Strong quake in Philippines kills one, injure...",[Louisiana governor's race is the latest test ...,0


In [63]:
seqvqvae_data_module = SeqVQVAEDataModule(
        test_file=Path('../outputs/mind/dev_mind_std_sts_histories_indices.csv'),
        batch_size=1,
    )
seqvqvae_data_module.setup('test')

In [146]:
test_file=pd.read_csv('../outputs/mind/dev_mind_std_sts_histories_indices.csv')
test_file

Unnamed: 0,impression_id,user_id,history-1,history_indices
0,359674,U66319,N10721 N128129 N28406 N118998 N38884 N96764 N1...,322 328 288 334 140 131 5 381 140 15 217 337 1...
1,271466,U714534,,
2,348598,U231191,N7742 N90185 N58034 N48215 N58477 N48215 N1167...,56 140 140 201 158 201 116 243 198 1 140 248 2...
3,257817,U167725,N80865 N77603 N14898 N75485,21 26 243 113
4,73530,U723351,N109183 N121551 N70847 N25818 N44644 N13604 N2...,140 198 140 140 140 334 273 140 177 140
...,...,...,...,...
255985,351546,U671917,N14678 N85697 N84528 N36467 N124989 N63723 N10...,131 140 15 281 302 167 224 196 184 118 15 390 ...
255986,141597,U569120,,
255987,167941,U561652,N33917 N27334 N48992 N16416 N12934 N52372 N659...,102 219 140 270 394 58 237 140 219 243 324 140...
255988,77488,U657119,N111348 N41632 N75821 N116144 N41156,140 184 322 359 350


In [157]:
raw_histories = test_file['history-1'].iloc[0].split()

In [161]:
raw_histories

['N10721',
 'N128129',
 'N28406',
 'N118998',
 'N38884',
 'N96764',
 'N123633',
 'N42703',
 'N38313',
 'N33177',
 'N87446',
 'N127659',
 'N122729',
 'N15642',
 'N107017',
 'N87210',
 'N10184',
 'N18850',
 'N101263',
 'N5468',
 'N107732',
 'N50489',
 'N18690',
 'N2653',
 'N64554',
 'N65685',
 'N19959',
 'N108219',
 'N5106',
 'N116725',
 'N51471',
 'N45225',
 'N46056',
 'N17072',
 'N25561',
 'N117102',
 'N36062',
 'N76664',
 'N64642',
 'N21044',
 'N59520',
 'N27352',
 'N75078',
 'N48862',
 'N129836',
 'N56204',
 'N73963',
 'N35758',
 'N60457',
 'N76928',
 'N1596',
 'N13872',
 'N117074',
 'N113936',
 'N94492',
 'N53325',
 'N15448',
 'N8991',
 'N11878',
 'N104737',
 'N14504',
 'N113678',
 'N51649',
 'N126244',
 'N84975',
 'N104277',
 'N110158',
 'N30755',
 'N7764',
 'N79026',
 'N97862',
 'N64035',
 'N94180',
 'N111058',
 'N58522',
 'N82386']

In [154]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
indices = torch.tensor(list(map(int, test_file['history_indices'].iloc[0].split())), device=device, dtype=torch.long)
batch = [
    indices.unsqueeze(0),
    torch.ones(1, indices.shape[0], device=device, dtype=torch.long),
]

In [None]:
outputs = seqvqvae(*batch)
last_hidden_state = outputs.last_hidden_state
logits = seqvqvae.lm_head(last_hidden_state)
prob_dists = torch.nn.Softmax(dim=-1)(logits[:,-1,:])
predicted_indices = torch.topk(prob_dists, k=25, dim=-1).indices

In [103]:
predicted_indices

tensor([[140, 177, 131, 337, 198,  15, 196, 217, 322, 328, 104, 390,   3, 200,
         270,   5, 237, 219, 194, 381, 233, 395, 202, 160, 366]],
       device='cuda:0')

In [None]:
# codebooks = rvqvae.rvq_layer.quantizers[0].codebook.weight.data.shape


torch.Size([25, 512])

In [112]:
target_vectors = rvqvae.decoder(selected_codebook_vectors.unsqueeze(0)).squeeze(0)

In [53]:
aspect_vectors['train']['std'].keys()

dict_keys([88753, 45436, 23144, 86255, 93187, 75236, 99744, 5771, 124534, 51947, 59220, 17957, 40259, 42222, 46520, 40599, 13152, 22273, 107267, 30547, 42639, 54460, 117551, 79856, 72751, 81543, 5149, 10616, 50737, 108378, 75778, 21935, 93333, 9663, 40432, 78823, 81055, 35648, 67829, 16695, 5219, 9580, 70495, 71593, 41917, 35373, 126961, 73751, 25725, 108072, 57904, 1387, 35968, 22632, 16951, 120810, 127985, 3018, 55720, 58905, 128011, 13709, 33970, 38405, 129080, 96395, 117941, 43403, 13822, 126443, 122146, 76450, 9204, 80991, 93054, 58373, 56151, 33576, 103042, 110329, 116735, 81976, 115791, 83188, 12731, 56873, 24595, 35617, 83764, 120909, 1846, 54026, 46094, 107732, 13286, 31493, 26284, 11743, 35480, 106490, 73263, 110516, 116159, 10650, 46741, 89260, 43508, 115846, 54286, 125142, 104819, 12905, 70182, 40895, 74585, 59956, 13998, 76098, 112099, 27552, 116074, 28677, 40542, 52012, 72564, 89205, 13636, 77389, 37416, 108765, 84064, 66386, 101245, 113097, 60629, 10493, 76727, 50596, 72

In [45]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
all_articles = [{"vector": torch.tensor(aspect_vectors['dev']['std'][int(article[1:])], dtype=torch.float32, device=device), "id": i} for i, article in enumerate(df['articles'].iloc[0].split())]

In [46]:
res = client.insert(collection_name="dev_std", data=all_articles)

In [126]:
target_vectors.detach().cpu()
target_vectors_list = target_vectors.detach().cpu().numpy().tolist()

In [128]:
len(target_vectors_list[0])

1024

In [137]:

res = client.search(collection_name="dev_std",
                    search_params={"metric_type": "COSINE"}, 
                    anns_field="vector", 
                    data=[target_vectors_list[1]],
                    limit=25,output_fields=["id", "vector"])
                    # Flatten the results and collect unique article ids
unique_ids = set()
unique_results = []
for hits in res:
    for hit in hits:
        if hit['id'] not in unique_ids:
            unique_ids.add(hit['id'])
            unique_results.append(hit)
unique_results

[{'id': 4518, 'distance': 0.8436839580535889, 'entity': {'id': 4518, 'vector': [0.34095245599746704, 1.2260475158691406, 0.07550328969955444, 1.4739519357681274, -1.8527588844299316, 0.22991850972175598, -2.0177347660064697, -0.5618963837623596, 1.331736445426941, 0.04184151813387871, -0.13242104649543762, -0.31928515434265137, 1.3633745908737183, -1.301343321800232, -1.62325119972229, -2.8027701377868652, -1.8981997966766357, 1.5251870155334473, -1.7935230731964111, -0.013952008448541164, 1.446311354637146, 0.09082905948162079, 1.4313489198684692, -0.6509618759155273, 2.986189842224121, 2.3696346282958984, 1.076655387878418, 1.1636009216308594, -0.390546590089798, 0.7045164704322815, 0.5331809520721436, -0.9112767577171326, 1.5487021207809448, 1.0980111360549927, 0.12718217074871063, 0.7745473384857178, 0.6850557923316956, -0.6460841298103333, -0.6740808486938477, -1.041019082069397, -0.6027694940567017, 0.18014110624790192, -0.10369299352169037, 1.4295170307159424, -1.654136300086975

In [142]:
test_file['history-1'].iloc[0]

'N10721 N128129 N28406 N118998 N38884 N96764 N123633 N42703 N38313 N33177 N87446 N127659 N122729 N15642 N107017 N87210 N10184 N18850 N101263 N5468 N107732 N50489 N18690 N2653 N64554 N65685 N19959 N108219 N5106 N116725 N51471 N45225 N46056 N17072 N25561 N117102 N36062 N76664 N64642 N21044 N59520 N27352 N75078 N48862 N129836 N56204 N73963 N35758 N60457 N76928 N1596 N13872 N117074 N113936 N94492 N53325 N15448 N8991 N11878 N104737 N14504 N113678 N51649 N126244 N84975 N104277 N110158 N30755 N7764 N79026 N97862 N64035 N94180 N111058 N58522 N82386'

tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1]], device='cuda:0')