In [1]:
from tqdm import tqdm
import torch
import pandas as pd

In [2]:
from sentence_transformers import SentenceTransformer

# Pretrained Sentence Transformer checkpoint used to embed queries and anchor terms
MODEL_NAME = "all-MiniLM-L6-v2"
model = SentenceTransformer(MODEL_NAME)

In [3]:
# Read the raw query list: one query per line, with surrounding whitespace stripped
queries = []
with open("data/US/queries.txt", "r") as query_file:
    queries.extend(raw_line.strip() for raw_line in query_file)

In [13]:
# Spot check: display the query stored at list index 578
queries[578]

'flu jabs'

# First filter: only allow queries with sufficient data

In [26]:
# NOTE(review): this reads data/US/queries_filtered.txt, the same file the last
# cell of this notebook writes, so on a fresh checkout this cell presumably
# needs a previous run's output to exist - confirm the intended run order.
with open('./data/US/queries_filtered.txt') as filtered_file:
    raw_us_queries = filtered_file.read().splitlines()

# Each line is "<id>\t<query>"; keep the integer id from the first column.
# NOTE(review): query_ids is not used anywhere in the visible cells.
query_ids = {int(row.split('\t')[0]) for row in raw_us_queries}

# Raw per-query frequency table (the cells below use its 'Query' and 'Day' columns)
q_freq_us_raw = pd.read_csv('./data/US/Q_freq.csv')

from datetime import datetime

def days_between(d1, d2):
    """Return the signed number of whole days from ``d1`` to ``d2``.

    Both arguments are date strings in ``YYYY-MM-DD`` format. The result is
    positive when ``d2`` is later than ``d1`` and negative when it is earlier.

    Raises:
        ValueError: if either string does not match ``%Y-%m-%d``.
    """
    date_format = "%Y-%m-%d"
    earlier, later = (datetime.strptime(text, date_format) for text in (d1, d2))
    return (later - earlier).days

def day_index(d):
    """Return the day offset of ``d`` (a ``YYYY-MM-DD`` string) from 2004-01-01.

    2004-01-01 itself maps to 0; earlier dates give negative values.
    """
    reference_date = '2004-01-01'
    return days_between(reference_date, d)

In [27]:
# Bounds of the required coverage window, converted to day offsets from 2004-01-01
window_start_date = '2009-09-01'
window_end_date = '2019-08-31'
start_day = day_index(window_start_date)
end_day = day_index(window_end_date)

In [28]:
# Per-query first and last entries of the frequency table.
# GroupBy.first()/last() pick the first/last non-null value of each column
# independently within a group.
# NOTE(review): this presumably relies on rows being ordered by 'Day' within
# each query - confirm against how Q_freq.csv is produced.
q_freq_by_query = q_freq_us_raw.groupby('Query')
q_freq_first_non_missing = q_freq_by_query.first().reset_index()
q_freq_last_non_missing = q_freq_by_query.last().reset_index()

In [29]:
# Queries whose first recorded day is on or before start_day ...
starts_in_time = q_freq_first_non_missing['Day'] <= start_day
qids_first_non_missing = set(q_freq_first_non_missing.loc[starts_in_time, 'Query'])

# ... and whose last recorded day is on or after end_day
ends_late_enough = q_freq_last_non_missing['Day'] >= end_day
qids_last_non_missing = set(q_freq_last_non_missing.loc[ends_late_enough, 'Query'])

# A query has sufficient data only if it satisfies both conditions
qids_with_data = qids_first_non_missing & qids_last_non_missing

# Second filter: semantic filter

In [30]:
# Encode queries in fixed-size batches
batch_size = 128

# Collect the per-batch tensors and concatenate once at the end. Calling
# torch.cat inside the loop re-copies every previously encoded row on each
# iteration, which is quadratic in the number of batches.
embedding_batches = []
for i in tqdm(range(0, len(queries), batch_size)):
    batch = queries[i : i + batch_size]
    embedding_batches.append(torch.tensor(model.encode(batch)))

# torch.cat rejects an empty sequence, so with no queries fall back to an
# empty tensor (the same value the incremental version produced).
embeddings = torch.cat(embedding_batches) if embedding_batches else torch.tensor([])

# Save embeddings
torch.save(embeddings, "data/US/queries_embeddings.pt")

100%|██████████| 177/177 [00:04<00:00, 39.44it/s]


In [31]:
def cosine_similarity(a, b):
    """Return the cosine similarity of two 1-D torch tensors as a 0-dim tensor.

    Computed as ``dot(a, b) / (||a|| * ||b||)``. Division by zero is not
    guarded, so a zero-norm input does not produce a meaningful value.
    """
    norm_product = a.norm() * b.norm()
    return a.dot(b) / norm_product

In [32]:
# Anchor embeddings for the two reference terms, "flu" and "fever"
flu_embedding = torch.tensor(model.encode("flu"))
fever_embedding = torch.tensor(model.encode("fever"))

# Score every query by the mean of its cosine similarities to the two anchors.
# Query ids are the 1-based positions of the queries in `queries`.
query_similarities = []
for query_id, query in enumerate(queries, start=1):
    query_embedding = embeddings[query_id - 1]
    flu_score = cosine_similarity(query_embedding, flu_embedding)
    fever_score = cosine_similarity(query_embedding, fever_embedding)
    # Average of the two similarities
    mean_score = (flu_score + fever_score) / 2
    query_similarities.append((query_id, query, mean_score.item()))

# Order the (id, query, score) tuples from most to least similar
query_similarities.sort(key=lambda entry: entry[2], reverse=True)

In [33]:
# Keep only queries whose id is in qids_with_data, i.e. those that passed the
# data-availability filter. The previous loop appended every query
# unconditionally, so that filter was never applied. query_similarities is
# already sorted by similarity, and this preserves its order.
filtered_queries = [
    query for query in query_similarities if query[0] in qids_with_data
]

# Only use first 1500 queries
filtered_queries = filtered_queries[:1500]

# Write the selected queries, one per line, as "<id><TAB><query>"
with open("data/US/queries_filtered.txt", "w") as out_file:
    out_file.writelines(f"{query[0]}\t{query[1]}\n" for query in filtered_queries)