In [None]:
import json
from tqdm import tqdm
import numpy as np

# Input label corpora: JSON objects whose values are label records with
# 'metadata' and 'Label Text' entries (see the loading cell below).
data_file_ema = 'data/output/human-rx-drug-ema.json'
data_file_fda = 'data/output/human-rx-drug-openfda.json'
# When True, also run the optional exploratory cells (FDA section word counts,
# EMA/FDA name-match tables); they are skipped otherwise.
RUN_DIAGNOSTIC = False

from sentence_transformers import SentenceTransformer
from sentence_transformers.util import cos_sim

import pickle
from matplotlib import pyplot as plt

In [None]:
#%% load data (takes a fair amount of memory)

def _read_json(path):
    """Parse and return the JSON document stored at *path*."""
    with open(path) as fh:
        return json.load(fh)


def _all_section_names(drugs):
    """Return the sorted, de-duplicated 'Label Text' section names used across *drugs*."""
    return sorted({section for drug in drugs for section in drug['Label Text']})


data_ema = _read_json(data_file_ema)
data_fda = _read_json(data_file_fda)

# Parallel tuples: record identifiers and the label records themselves.
keys_ema, drugs_ema = zip(*data_ema.items())
sections_ema = _all_section_names(drugs_ema)

keys_fda, drugs_fda = zip(*data_fda.items())
sections_fda = _all_section_names(drugs_fda)



#%% how many words are in each fda section?
if RUN_DIAGNOSTIC:
    # Key by the raw section name so every lookup in the loop hits; the
    # normalized form is only used for display. (Keying by the normalized name
    # but looking up the raw one raised KeyError for any name that changes
    # under replace/lower.)
    slens = {s: [] for s in sections_fda}
    for d in tqdm(drugs_fda):
        for s, val in d['Label Text'].items():
            # each section value is a list of text strings
            slens[s].extend(len(v.split()) for v in val)
    for s, lens in slens.items():
        # guard the empty case instead of letting np.mean warn and return nan
        mean_len = np.mean(lens) if lens else float('nan')
        print(mean_len, s.replace(' ','_').lower())


#%% find pairs of ema/fda drugs with matching brand or generic names

# Per-record brand / generic / substance names, index-aligned with drugs_ema.
bnames_ema, gnames_ema, snames_ema = (
    [d['metadata'][field] for d in drugs_ema]
    for field in ('Medicine name',
                  'International non-proprietary name (INN) / common name',
                  'Active substance')
)

# Same for drugs_fda; openFDA stores each name field as a list, so take its
# first entry, and use '' when the record has no such field at all.
bnames_fda, gnames_fda, snames_fda = (
    [d['metadata'].get(field, [''])[0] for d in drugs_fda]
    for field in ('brand_name', 'generic_name', 'substance_name')
)
find_in_list = lambda x,xlist: list(np.where([x==z for z in xlist])[0])


def _build_name_index(names):
    """Map each lower-cased, non-empty name to the ascending positions where it occurs.

    Empty names are skipped: '' is the placeholder the name lists use for FDA
    records that lack the field, and it must never count as a match.
    """
    index = {}
    for pos, name in enumerate(names):
        key = name.lower()
        if key:
            index.setdefault(key, []).append(pos)
    return index


# (brand, generic, substance) indices over the FDA name lists, built on first
# use so each lookup is a dict hit instead of re-lowercasing and scanning every
# FDA name on every call. Reset to None if the FDA name lists are rebuilt.
_fda_name_indices = None


def _get_fda_name_indices():
    """Return the cached (brand, generic, substance) FDA name indices, building them once."""
    global _fda_name_indices
    if _fda_name_indices is None:
        _fda_name_indices = (_build_name_index(bnames_fda),
                             _build_name_index(gnames_fda),
                             _build_name_index(snames_fda))
    return _fda_name_indices


def match_ema_to_fda(drug_e):
    """Return the FDA label records that share a name with one EMA record.

    A match is a case-insensitive exact equality of the EMA 'Medicine name'
    with the FDA brand name, the EMA INN / common name with the FDA generic
    name, or the EMA active substance with the FDA substance name. Empty names
    never match anything.

    Args:
        drug_e: an EMA label record with a 'metadata' dict.

    Returns:
        list of records from drugs_fda, each at most once, in drugs_fda order
        (callers truncate this list, so the order is kept well-defined).
    """
    meta = drug_e['metadata']
    ema_names = (meta['Medicine name'],
                 meta['International non-proprietary name (INN) / common name'],
                 meta['Active substance'])
    matches = set()
    for name, index in zip(ema_names, _get_fda_name_indices()):
        key = name.lower()
        if key:
            matches.update(index.get(key, ()))
    return [drugs_fda[m] for m in sorted(matches)]
    

#%% find matching fda for each ema (capped), save embeddings to disk (long computation)
# warning: do not run this cell unless you need to re-compute embeddings

# max sequence length = 512
# Shared encoder: compute_section_embedding reads this module-level `model`
# directly. Loaded through sentence_transformers from the
# 'allenai/scibert_scivocab_uncased' checkpoint; the 768-dim output size is
# assumed by the embedding helpers below (TODO confirm if the model changes).
model = SentenceTransformer('allenai/scibert_scivocab_uncased')

# This is the core code for comparing semantic similarity between sections:
#   The model will ignore anything beyond 512 tokens (~300-400 words),
#   so I slice section text into segments of <= 256 words, compute 
#   SciBERT embedding vectors for each one (dim=768), and return the average.
#   For a given label (or list of labels), I do that section by section
#   and stack the results into a matrix. For analysis purposes I'd be
#   averaging the resulting matrices across a list of labels that
#   represent the same drug (e.g. with identical generic name)

def compute_section_embedding(text, word_count=256):
    """Embed *text* as the mean of its per-segment SciBERT embeddings.

    The text is split on whitespace into consecutive segments of at most
    *word_count* words (so each fits within the model's token limit). Each
    segment is encoded with the global `model`, and the segment vectors are
    averaged.

    Args:
        text: section text to embed.
        word_count: maximum number of words per segment.

    Returns:
        1-D float array (the model's embedding dimension).
    """
    words = text.split()
    # Ceil-division chunking: exactly k*word_count words give k segments. The
    # old `1 + n // word_count` added an empty trailing segment in that case,
    # and its embedding was averaged in. Blank text still yields one ('')
    # segment, so the result is a vector, not an empty mean.
    segments = [' '.join(words[i:i + word_count])
                for i in range(0, len(words), word_count)] or ['']
    # One batched encode call instead of one call per segment.
    vecs = np.asarray(model.encode(segments), dtype=float)
    return vecs.mean(axis=0)

def compute_drug_embedding(sections, drugs):
    """Embed every requested section of every label in *drugs*.

    Args:
        sections: ordered section names; defines the row order of the result.
        drugs: label records, each with a 'Label Text' dict mapping a section
            name to a list of text strings.

    Returns:
        float array of shape (len(sections), 768, len(drugs)). Slot [s, :, d]
        is the mean embedding of the distinct texts of section s in drug d.
        It is left all-zero when drug d has no text for that section, which
        the similarity cells later mask as exact zeros.
    """
    vecs = np.zeros((len(sections), 768, len(drugs)))
    for d, drug in enumerate(drugs):
        label_text = drug['Label Text']
        for s, section in enumerate(sections):
            # dict.fromkeys de-duplicates while keeping first-seen order, so the
            # float summation order (and the result) does not depend on string
            # hash randomization the way set() ordering does.
            texts = list(dict.fromkeys(label_text.get(section, ())))
            if not texts:
                # Absent or empty section: keep zeros. Averaging an empty list
                # would write NaN, which the zero-masking downstream does not catch.
                continue
            vecs[s, :, d] = np.mean([compute_section_embedding(t) for t in texts], axis=0)
    return vecs
    

import os

all_vecs_e = []
all_vecs_f = []
max_matches = 20
# note: most ema drugs match 0-10 fda products, long tail from ~30 to ~150

# Make sure the pickles below have somewhere to go *before* the long loop, so a
# missing directory cannot throw away hours of embedding work at the very end.
os.makedirs('output', exist_ok=True)

for drug_e in tqdm(drugs_ema):
    # cap per-drug FDA matches to bound the cost of the long-tail drugs
    drugs_f = match_ema_to_fda(drug_e)[:max_matches]

    vecs_e = compute_drug_embedding(sections_ema, [drug_e])   # (n_ema_sections, 768, 1)
    vecs_f = compute_drug_embedding(sections_fda, drugs_f)    # (n_fda_sections, 768, n_matches)

    all_vecs_e.append(vecs_e)
    all_vecs_f.append(vecs_f)

# note: can't use np.savez because fda arrays are mismatched sizes, awkward
with open('output/scibert-embeddings-ema.pkl', 'wb') as f:
    pickle.dump(all_vecs_e, f)
with open('output/scibert-embeddings-openfda.pkl', 'wb') as f:
    pickle.dump(all_vecs_f, f)

# TODO: simplified output format:
#  csv with ema generic name, average of embedding vectors (all)
#  csv with fda generic name, average of embedding vectors (those with a match)

    
#%% reload embeddings from disk

def _load_pickle(path):
    """Unpickle and return the object at *path*.

    Only meant for the files this script itself writes; never point it at
    untrusted data, since unpickling can execute code.
    """
    with open(path, 'rb') as fh:
        return pickle.load(fh)


all_vecs_e = _load_pickle('output/scibert-embeddings-ema.pkl')
all_vecs_f = _load_pickle('output/scibert-embeddings-openfda.pkl')

# These three EMA parent headers carry no text of their own (it all lives in
# their sub-sections), so drop their rows from every EMA embedding matrix.
dummies = ['Clinical Particulars',
           'Pharmacological Properties',
           'Pharmaceutical Particulars']
keep_rows = [name not in dummies for name in sections_ema]
all_vecs_e = [vecs[keep_rows, :] for vecs in all_vecs_e]

# Readable FDA section names that line up with EMA's Title Case headers,
# e.g. a name like 'indications_and_usage' becomes 'Indications And Usage'.
sections_fda = [name.replace('_', ' ').title().replace('Spl', 'SPL')
                for name in sections_fda]

#%% aggregate similarity matrix via matching drugs pairwise

# For each EMA drug with at least one FDA match, compare its label-averaged
# section matrix with the averaged matrix of its FDA matches
# (rows = EMA sections, cols = FDA sections).
cosims = []
for ve, vf in zip(all_vecs_e, all_vecs_f):
    if vf.shape[-1] > 0:  # skip EMA drugs with no FDA match
        cosims.append(np.array(cos_sim(ve.mean(axis=-1), vf.mean(axis=-1))))
if not cosims:
    raise ValueError('no EMA drug has an FDA match; nothing to aggregate')

# Absent sections are stored as all-zero vectors and presumably come out of
# cos_sim as 0, so mask zeros before averaging across drugs. The 4th power is
# presumably there to stretch contrast between high and low similarities.
cosims = np.ma.masked_values(cosims, 0)
csarr = np.ma.mean(cosims**4, axis=0)

# Drop FDA columns with no data for at least one EMA row. getmaskarray also
# works when nothing is masked, where `.mask` is the scalar `nomask`.
keep_cols = ~np.ma.getmaskarray(csarr).any(axis=0)
csarr = csarr[:,keep_cols]
labels_col = [sections_fda[i] for i in np.where(keep_cols)[0]]
labels_row = [sections_ema[i] for i in np.where(keep_rows)[0]]

plt.figure(figsize=(16,12), dpi=150)
im = plt.imshow(csarr)
plt.gca().set_aspect(1)
daspect = csarr.shape[0]/csarr.shape[1]
plt.colorbar(im, fraction=0.046*daspect, pad=0.04)
plt.title('SciBERT Embedding Cosine Similarity \nEMA-FDA Matched Pairs', fontsize=16)
plt.xlabel('FDA Label Section')
plt.ylabel('EMA Label Section')
plt.xticks(range(csarr.shape[1]), labels_col, rotation=90, size='small')
plt.yticks(range(csarr.shape[0]), labels_row, size='small')
plt.tight_layout()
plt.show()

#%% aggregate similarity matrix via matching drugs all-to-all

# Compare every EMA drug with every matched FDA product set. The full stack of
# E x F similarity matrices would not fit in memory, so accumulate running
# sums. A per-cell count of pairs where both sections are present keeps the
# result a masked mean, like the pairwise cell. Otherwise absent sections
# (zero cosine) would still enter the denominator and drag averages toward 0.
# The shape is derived from the data instead of hard-coding (26, 81).
cosims = np.zeros((int(np.sum(keep_rows)), len(sections_fda)))  # sum of cos**4
counts = np.zeros_like(cosims)                                   # pairs contributing per cell
ct = 0                                                           # total pairs compared
for ve in tqdm(all_vecs_e):
    ve = ve.mean(axis=-1)
    for vf in all_vecs_f:
        if vf.shape[-1] > 0:
            cs = np.array(cos_sim(ve, vf.mean(axis=-1)))
            # same tolerance np.ma.masked_values(..., 0) uses in the pairwise cell
            present = ~np.isclose(cs, 0)
            cosims += np.where(present, cs**4, 0.0)
            counts += present
            ct += 1

csarr = np.ma.masked_where(counts == 0, cosims / np.maximum(counts, 1))

# reuse keep_cols from the pairwise cell so both plots show the same FDA columns
csarr = csarr[:,keep_cols]
labels_col = [sections_fda[i] for i in np.where(keep_cols)[0]]
labels_row = [sections_ema[i] for i in np.where(keep_rows)[0]]

plt.figure(figsize=(16,12), dpi=150)
im = plt.imshow(csarr)
plt.gca().set_aspect(1)
daspect = csarr.shape[0]/csarr.shape[1]
plt.colorbar(im, fraction=0.046*daspect, pad=0.04)
plt.title('SciBERT Embedding Cosine Similarity \nEMA-FDA Overall', fontsize=16)
plt.xlabel('FDA Label Section')
plt.ylabel('EMA Label Section')
plt.xticks(range(csarr.shape[1]), labels_col, rotation=90, size='small')
plt.yticks(range(csarr.shape[0]), labels_row, size='small')
plt.tight_layout()
plt.show()


#%% section similarity within agency: EMA

# Per-drug EMA-vs-EMA section similarity, each drug's matrix averaged over its
# labels first.
cosims = []
for vecs in all_vecs_e:
    section_means = vecs.mean(axis=-1)
    cosims.append(np.array(cos_sim(section_means, section_means)))

# Absent sections are stored as all-zero vectors, so exact-zero similarities
# are masked before averaging; values are raised to the 4th power as elsewhere.
cosims = np.ma.masked_values(cosims, 0)
csarr = np.ma.mean(cosims**4, axis=0)

labels_row = [sections_ema[idx] for idx in np.where(keep_rows)[0]]
n_rows, n_cols = csarr.shape
daspect = n_rows/n_cols

fig, ax = plt.subplots(figsize=(12,12), dpi=150)
im = ax.imshow(csarr)
ax.set_aspect(1)
plt.colorbar(im, fraction=0.046*daspect, pad=0.04)
ax.set_title('SciBERT Embedding Cosine Similarity \nEMA Only', fontsize=16)
ax.set_xlabel('EMA Label Section')
ax.set_ylabel('EMA Label Section')
ax.set_xticks(range(n_cols))
ax.set_xticklabels(labels_row, rotation=90, size='small')
ax.set_yticks(range(n_rows))
ax.set_yticklabels(labels_row, size='small')
plt.tight_layout()
plt.show()


#%% section similarity within agency: FDA

# Per-product FDA-vs-FDA section similarity (only EMA drugs that had matches).
cosims = []
for i,v in enumerate(all_vecs_f):
    if v.shape[-1]>0:
        v = v.mean(axis=-1)
        cosim = np.array(cos_sim(v,v))
        cosims.append(cosim)

# mask exact zeros (absent sections), then average cos**4 as in the other cells
cosims = np.ma.masked_values(cosims, 0)
csarr = np.ma.mean(cosims**4, axis=0)

# keep_cols is the boolean column mask from the pairwise cell; apply it to both
# axes so this plot shows the same FDA sections.
csarr = csarr[:,keep_cols][keep_cols,:]
# Convert the boolean mask to indices. Iterating the mask directly yielded
# True/False, which indexed sections_fda[1]/[0] (or failed) instead of the kept
# sections. sections_fda is already reformatted for display by the reload cell.
labels_col = [sections_fda[i] for i in np.where(keep_cols)[0]]

plt.figure(figsize=(12,12), dpi=150)
im = plt.imshow(csarr)
plt.gca().set_aspect(1)
daspect = csarr.shape[0]/csarr.shape[1]
plt.colorbar(im, fraction=0.046*daspect, pad=0.04)
plt.title('SciBERT Embedding Cosine Similarity \nFDA Only', fontsize=16)
plt.xlabel('FDA Label Section')
plt.ylabel('FDA Label Section')
plt.xticks(range(csarr.shape[1]), labels_col, rotation=90, size='small')
plt.yticks(range(csarr.shape[0]), labels_col, size='small')
plt.tight_layout()
plt.show()


#%% next goal: collapse all identical generic names? probably not needed
if RUN_DIAGNOSTIC:

    import pandas as pd

    # Lower-case every name column. Series.str.lower is used because
    # DataFrame.applymap is deprecated since pandas 2.1.
    enames = pd.DataFrame({'generic':gnames_ema, 'brand':bnames_ema})
    enames = enames.apply(lambda col: col.str.lower())
    enames['matches'] = [match_ema_to_fda(d) for d in data_ema.values()] # drugs
    # Flatten the matched FDA generic names. Some FDA records have no
    # 'generic_name' (gnames_fda defaults those to ''), so fall back to [].
    enames['matches'] = enames['matches'].apply(
        lambda x: [g for d in x for g in d['metadata'].get('generic_name', [])])
    enames['n_matches'] = enames['matches'].apply(len)
    enames['c_matches'] = enames['matches'].apply(lambda x: len(set(x)))

    fnames = pd.DataFrame({'generic':gnames_fda, 'brand':bnames_fda})
    fnames = fnames.apply(lambda col: col.str.lower())

b