In [2]:
import pickle
import math
import numpy as np
import matplotlib.pyplot as plt
from os.path import join
from s2and.data import ANDData
%matplotlib inline



In [6]:
from tqdm import tqdm

In [2]:
# load s2and production model -> Clusterer obj
with open("data/production_model.pickle", "rb") as _pkl_file:
    prod_model = pickle.load(_pkl_file)
    clusterer = prod_model['clusterer']

RuntimeError: module compiled against API version 0xe but this version of numpy is 0xd

In [3]:
def create_model_and_eval(X_train, y_train, X_val, y_val):
    pairwise_model = PairwiseModeler(
        n_iter=25, monotone_constraints=featurization_info.lightgbm_monotone_constraints
    )
    pairwise_model.fit(X_train, y_train, X_val, y_val)
    y_pred = pairwise_model.predict_proba(X_val)
    acc = np.sum((y_pred[:, 1]>0.5).astype(int) == y_val) / len(y_pred)
    return acc

In [3]:
# load dataset in "train" mode

from s2and.model import PairwiseModeler
from s2and.featurizer import FeaturizationInfo, featurize
from s2and.eval import pairwise_eval

RuntimeError: module compiled against API version 0xe but this version of numpy is 0xd

In [8]:
# Strategy 1: Impute to 0

def strategy1(train, val):
    X_train, y_train, _ = train
    X_val, y_val, _ = val

    X_train = np.nan_to_num(X_train)
    X_val = np.nan_to_num(X_val)

    acc = create_model_and_eval(X_train, y_train, X_val, y_val)
    return acc

In [9]:
# Strategy 2: Impute to -1

def strategy2(train, val):
    X_train, y_train, _ = train
    X_val, y_val, _ = val

    X_train = np.nan_to_num(X_train, nan=-1)
    X_val = np.nan_to_num(X_val, nan=-1)

    acc = create_model_and_eval(X_train, y_train, X_val, y_val)
    return acc

In [10]:
# Strategy 3: Do not impute; skip in split decisions

def strategy3(train, val):
    X_train, y_train, _ = train
    X_val, y_val, _ = val

    acc = create_model_and_eval(X_train, y_train, X_val, y_val)
    return acc

In [11]:
# Strategy 4: Set to large negative value

def strategy4(train, val):
    X_train, y_train, _ = train
    X_val, y_val, _ = val

    X_train = np.nan_to_num(X_train, nan=-10000)
    X_val = np.nan_to_num(X_val, nan=-10000)

    acc = create_model_and_eval(X_train, y_train, X_val, y_val)
    return acc

In [12]:
# Strategy 5: Mean imputation

def strategy5(train, val):
    X_train, y_train, _ = train
    X_val, y_val, _ = val

    col_mean = np.nanmean(X_train, axis=0)
    inds = np.where(np.isnan(X_train))
    X_train[inds] = np.take(col_mean, inds[1])
    inds = np.where(np.isnan(X_val))
    X_val[inds] = np.take(col_mean, inds[1])

    acc = create_model_and_eval(X_train, y_train, X_val, y_val)
    return acc

In [15]:
acc_per_strategy = [[]]*5
datasets = ["pubmed", "qian", "zbmath"]
strategy_fns = [strategy1, strategy2, strategy3, strategy4, strategy5]

for dataset_name in datasets:
    print(f"Evaluating over {dataset_name}:\n")
    parent_dir = f"data/{dataset_name}"
    dataset = ANDData(
        signatures=join(parent_dir, f"{dataset_name}_signatures.json"),
        papers=join(parent_dir, f"{dataset_name}_papers.json"),
        mode="train",
        specter_embeddings=join(parent_dir, f"{dataset_name}_specter.pickle"),
        clusters=join(parent_dir, f"{dataset_name}_clusters.json"),
        block_type="s2",
        train_pairs_size=100000,
        val_pairs_size=10000,
        test_pairs_size=10000,
        name=dataset_name,
        n_jobs=8,
    )

    featurization_info = FeaturizationInfo()
    # the cache will make it faster to train multiple times - it stores the features on disk for you
    train, val, test = featurize(dataset, featurization_info, n_jobs=8, use_cache=True)
    
    for i, strategy_fn in tqdm(enumerate(strategy_fns)):
        acc = strategy_fn(train, val)
        acc_per_strategy[i].append(acc)

2022-11-01 15:51:38,969 - s2and - INFO - loading papers


Evaluating over pubmed:



2022-11-01 15:51:41,146 - s2and - INFO - loaded papers
2022-11-01 15:51:41,147 - s2and - INFO - loading signatures
2022-11-01 15:51:42,588 - s2and - INFO - loaded signatures
2022-11-01 15:51:42,589 - s2and - INFO - loading clusters
2022-11-01 15:51:42,593 - s2and - INFO - loaded clusters, loading specter
2022-11-01 15:51:42,617 - s2and - INFO - loaded specter, loading cluster seeds
2022-11-01 15:51:42,618 - s2and - INFO - loaded cluster seeds
2022-11-01 15:51:42,619 - s2and - INFO - making signature to cluster id
2022-11-01 15:51:42,622 - s2and - INFO - made signature to cluster id
2022-11-01 15:51:42,622 - s2and - INFO - loading name counts
2022-11-01 15:51:58,313 - s2and - INFO - loaded name counts
2022-11-01 15:51:58,721 - s2and - INFO - preprocessing papers
Preprocessing papers 1/2: 100%|████████████████████████████| 134623/134623 [00:24<00:00, 5548.88it/s]
Preprocessing papers 2/2: 100%|████████████████████████████| 134623/134623 [00:47<00:00, 2816.45it/s]
2022-11-01 15:53:14,574 


  0%|                                                         | 0/25 [00:00<?, ?trial/s, best loss=?][A
  4%|█▎                              | 1/25 [00:45<18:11, 45.49s/trial, best loss: -0.99978438881883][A
  8%|██▍                           | 2/25 [01:03<14:16, 37.22s/trial, best loss: -0.9998516009462936][A
 12%|███▌                          | 3/25 [01:23<11:42, 31.93s/trial, best loss: -0.9998516009462936][A
 16%|████▊                         | 4/25 [01:34<09:02, 25.85s/trial, best loss: -0.9998516009462936][A
 20%|██████                        | 5/25 [02:04<09:01, 27.06s/trial, best loss: -0.9998516009462936][A
 24%|███████▏                      | 6/25 [02:29<08:22, 26.46s/trial, best loss: -0.9998516009462936][A
 28%|████████▍                     | 7/25 [02:46<07:03, 23.53s/trial, best loss: -0.9998516009462936][A
 32%|█████████▌                    | 8/25 [02:56<05:34, 19.66s/trial, best loss: -0.9998516009462936][A
 36%|██████████▊                   | 9/25 [03:33<06:36

1it [09:27, 567.05s/it]


  0%|                                                         | 0/25 [00:00<?, ?trial/s, best loss=?][A
  4%|█▏                            | 1/25 [00:47<19:05, 47.71s/trial, best loss: -0.9997817269523959][A
  8%|██▍                           | 2/25 [01:12<15:37, 40.75s/trial, best loss: -0.9998772214107227][A
 12%|███▌                          | 3/25 [01:33<12:47, 34.90s/trial, best loss: -0.9998772214107227][A
 16%|████▊                         | 4/25 [01:45<09:48, 28.05s/trial, best loss: -0.9998772214107227][A
 20%|██████                        | 5/25 [02:17<09:46, 29.33s/trial, best loss: -0.9998772214107227][A
 24%|███████▏                      | 6/25 [02:45<09:05, 28.71s/trial, best loss: -0.9998772214107227][A
 28%|████████▍                     | 7/25 [03:02<07:36, 25.35s/trial, best loss: -0.9998772214107227][A
 32%|█████████▌                    | 8/25 [03:13<05:58, 21.11s/trial, best loss: -0.9998772214107227][A
 36%|██████████▊                   | 9/25 [03:52<07:03

2it [19:25, 576.35s/it]


  0%|                                                         | 0/25 [00:00<?, ?trial/s, best loss=?][A
  4%|█▏                            | 1/25 [00:50<20:14, 50.59s/trial, best loss: -0.9998898652762851][A
  8%|██▍                            | 2/25 [01:15<16:28, 43.00s/trial, best loss: -0.999929793272798][A
 12%|███▋                           | 3/25 [01:37<13:24, 36.59s/trial, best loss: -0.999929793272798][A
 16%|████▉                          | 4/25 [01:50<10:18, 29.46s/trial, best loss: -0.999929793272798][A
 20%|██████▏                        | 5/25 [02:23<10:13, 30.67s/trial, best loss: -0.999929793272798][A
 24%|███████▍                       | 6/25 [02:51<09:23, 29.65s/trial, best loss: -0.999929793272798][A
 28%|████████▋                      | 7/25 [03:09<07:54, 26.36s/trial, best loss: -0.999929793272798][A
 32%|█████████▉                     | 8/25 [03:20<06:10, 21.79s/trial, best loss: -0.999929793272798][A
 36%|███████████▏                   | 9/25 [04:02<07:2

3it [31:52, 627.56s/it]


  0%|                                                         | 0/25 [00:00<?, ?trial/s, best loss=?][A
  4%|█▏                            | 1/25 [00:47<19:07, 47.80s/trial, best loss: -0.9997817269523959][A
  8%|██▍                           | 2/25 [01:12<15:41, 40.94s/trial, best loss: -0.9998772214107227][A
 12%|███▌                          | 3/25 [01:34<12:51, 35.06s/trial, best loss: -0.9998772214107227][A
 16%|████▊                         | 4/25 [01:46<09:51, 28.17s/trial, best loss: -0.9998772214107227][A
 20%|██████                        | 5/25 [02:18<09:49, 29.47s/trial, best loss: -0.9998772214107227][A
 24%|███████▏                      | 6/25 [02:45<09:07, 28.80s/trial, best loss: -0.9998772214107227][A
 28%|████████▍                     | 7/25 [03:03<07:36, 25.37s/trial, best loss: -0.9998772214107227][A
 32%|█████████▌                    | 8/25 [03:14<05:58, 21.10s/trial, best loss: -0.9998772214107227][A
 36%|██████████▊                   | 9/25 [03:52<06:58

4it [41:48, 618.06s/it]


  0%|                                                         | 0/25 [00:00<?, ?trial/s, best loss=?][A
  4%|█▏                            | 1/25 [00:47<18:51, 47.14s/trial, best loss: -0.9998971854089792][A
  8%|██▍                           | 2/25 [01:12<15:32, 40.53s/trial, best loss: -0.9998971854089792][A
 12%|███▌                          | 3/25 [01:32<12:37, 34.45s/trial, best loss: -0.9998971854089792][A
 16%|████▊                         | 4/25 [01:44<09:43, 27.77s/trial, best loss: -0.9998971854089792][A
 20%|██████                        | 5/25 [02:15<09:34, 28.74s/trial, best loss: -0.9998971854089792][A
 24%|███████▏                      | 6/25 [02:42<08:54, 28.11s/trial, best loss: -0.9998971854089792][A
 28%|████████▍                     | 7/25 [02:59<07:27, 24.88s/trial, best loss: -0.9998971854089792][A
 32%|█████████▌                    | 8/25 [03:10<05:52, 20.74s/trial, best loss: -0.9998971854089792][A
 36%|██████████▊                   | 9/25 [03:49<06:57

5it [52:46, 633.20s/it]
2022-11-01 16:46:05,936 - s2and - INFO - loading papers


Evaluating over qian:



2022-11-01 16:46:07,541 - s2and - INFO - loaded papers
2022-11-01 16:46:07,544 - s2and - INFO - loading signatures
2022-11-01 16:46:07,793 - s2and - INFO - loaded signatures
2022-11-01 16:46:07,797 - s2and - INFO - loading clusters
2022-11-01 16:46:07,812 - s2and - INFO - loaded clusters, loading specter
2022-11-01 16:46:08,053 - s2and - INFO - loaded specter, loading cluster seeds
2022-11-01 16:46:08,056 - s2and - INFO - loaded cluster seeds
2022-11-01 16:46:08,058 - s2and - INFO - making signature to cluster id
2022-11-01 16:46:08,062 - s2and - INFO - made signature to cluster id
2022-11-01 16:46:08,064 - s2and - INFO - loading name counts
2022-11-01 16:46:30,513 - s2and - INFO - loaded name counts
2022-11-01 16:46:30,906 - s2and - INFO - preprocessing papers
Preprocessing papers 1/2: 100%|██████████████████████████████| 59545/59545 [00:50<00:00, 1175.33it/s]
Preprocessing papers 2/2: 100%|██████████████████████████████| 59545/59545 [00:35<00:00, 1674.23it/s]
2022-11-01 16:48:06,324 


  0%|                                                         | 0/25 [00:00<?, ?trial/s, best loss=?][A
  4%|█▏                            | 1/25 [01:36<38:41, 96.71s/trial, best loss: -0.9363821261939171][A
  8%|██▍                           | 2/25 [02:11<29:55, 78.05s/trial, best loss: -0.9363821261939171][A
 12%|███▌                          | 3/25 [02:38<23:03, 62.88s/trial, best loss: -0.9363821261939171][A
 16%|████▊                         | 4/25 [03:00<17:40, 50.51s/trial, best loss: -0.9363821261939171][A
 20%|██████                        | 5/25 [03:52<17:00, 51.02s/trial, best loss: -0.9370904254901344][A
 24%|███████▏                      | 6/25 [04:51<16:55, 53.47s/trial, best loss: -0.9370904254901344][A
 28%|████████▍                     | 7/25 [05:21<13:56, 46.45s/trial, best loss: -0.9370904254901344][A
 32%|█████████▌                    | 8/25 [05:39<10:41, 37.75s/trial, best loss: -0.9370904254901344][A
 36%|██████████▊                   | 9/25 [07:05<13:55

1it [23:18, 1398.67s/it]


  0%|                                                         | 0/25 [00:00<?, ?trial/s, best loss=?][A
  4%|█▏                            | 1/25 [01:25<34:10, 85.44s/trial, best loss: -0.9331074447027774][A
  8%|██▍                           | 2/25 [01:43<24:59, 65.21s/trial, best loss: -0.9331074447027774][A
 12%|███▌                          | 3/25 [02:04<18:59, 51.82s/trial, best loss: -0.9331074447027774][A
 16%|████▊                         | 4/25 [02:15<13:54, 39.76s/trial, best loss: -0.9331074447027774][A
 20%|██████                        | 5/25 [02:45<12:17, 36.89s/trial, best loss: -0.9331074447027774][A
 24%|███████▏                      | 6/25 [03:17<11:09, 35.22s/trial, best loss: -0.9331074447027774][A
 28%|████████▍                     | 7/25 [03:33<08:52, 29.56s/trial, best loss: -0.9331074447027774][A
 32%|█████████▌                    | 8/25 [03:43<06:43, 23.73s/trial, best loss: -0.9331074447027774][A
 36%|██████████▊                   | 9/25 [04:31<08:16

2it [35:51, 1205.06s/it]


  0%|                                                         | 0/25 [00:00<?, ?trial/s, best loss=?][A
  4%|█▏                            | 1/25 [00:50<20:07, 50.33s/trial, best loss: -0.9360726011687822][A
  8%|██▍                           | 2/25 [01:09<15:40, 40.89s/trial, best loss: -0.9360726011687822][A
 12%|███▌                          | 3/25 [01:29<12:46, 34.86s/trial, best loss: -0.9360726011687822][A
 16%|████▊                         | 4/25 [01:42<09:50, 28.10s/trial, best loss: -0.9360726011687822][A
 20%|██████                        | 5/25 [02:13<09:38, 28.93s/trial, best loss: -0.9360726011687822][A
 24%|███████▏                      | 6/25 [02:44<09:25, 29.78s/trial, best loss: -0.9360726011687822][A
 28%|████████▍                     | 7/25 [03:01<07:46, 25.92s/trial, best loss: -0.9360726011687822][A
 32%|█████████▌                    | 8/25 [03:12<06:00, 21.22s/trial, best loss: -0.9360726011687822][A
 36%|███████████▏                   | 9/25 [04:02<07:5

3it [51:57, 1133.27s/it]


  0%|                                                         | 0/25 [00:00<?, ?trial/s, best loss=?][A
  4%|█▏                            | 1/25 [01:18<31:16, 78.19s/trial, best loss: -0.9331074447027774][A
  8%|██▍                           | 2/25 [01:58<25:39, 66.93s/trial, best loss: -0.9331074447027774][A
 12%|███▌                          | 3/25 [02:38<21:29, 58.61s/trial, best loss: -0.9331074447027774][A
 16%|████▊                         | 4/25 [02:58<16:30, 47.15s/trial, best loss: -0.9331074447027774][A
 20%|██████                        | 5/25 [03:54<16:37, 49.89s/trial, best loss: -0.9331074447027774][A
 24%|███████▏                      | 6/25 [04:50<16:22, 51.73s/trial, best loss: -0.9331074447027774][A
 28%|████████▍                     | 7/25 [05:21<13:35, 45.31s/trial, best loss: -0.9331074447027774][A
 32%|█████████▌                    | 8/25 [05:40<10:35, 37.41s/trial, best loss: -0.9331074447027774][A
 36%|██████████▊                   | 9/25 [07:09<14:07

4it [1:13:20, 1178.01s/it]


  0%|                                                         | 0/25 [00:00<?, ?trial/s, best loss=?][A

Mean of empty slice



  4%|█▏                            | 1/25 [00:47<18:52, 47.17s/trial, best loss: -0.9323062633530225][A
  8%|██▍                           | 2/25 [01:05<14:44, 38.45s/trial, best loss: -0.9390429613862009][A
 12%|███▌                          | 3/25 [01:24<12:01, 32.80s/trial, best loss: -0.9390429613862009][A
 16%|████▊                         | 4/25 [01:36<09:15, 26.44s/trial, best loss: -0.9390429613862009][A
 20%|██████                        | 5/25 [02:06<09:10, 27.55s/trial, best loss: -0.9390429613862009][A
 24%|███████▏                      | 6/25 [02:36<08:56, 28.25s/trial, best loss: -0.9390429613862009][A
 28%|████████▍                     | 7/25 [02:51<07:18, 24.37s/trial, best loss: -0.9390429613862009][A
 32%|█████████▌                    | 8/25 [03:01<05:41, 20.07s/trial, best loss: -0.9390429613862009][A
 36%|██████████▊                   | 9/25 [03:48<07:26, 27.91s/trial, best loss: -0.9390429613862009][A
 40%|███████████▌                 | 10/25 [04:44<09:07

5it [1:25:55, 1031.10s/it]
2022-11-01 18:16:03,523 - s2and - INFO - loading papers


Evaluating over zbmath:



2022-11-01 18:16:03,796 - s2and - INFO - loaded papers
2022-11-01 18:16:03,797 - s2and - INFO - loading signatures
2022-11-01 18:16:03,964 - s2and - INFO - loaded signatures
2022-11-01 18:16:03,965 - s2and - INFO - loading clusters
2022-11-01 18:16:03,980 - s2and - INFO - loaded clusters, loading specter
2022-11-01 18:16:04,278 - s2and - INFO - loaded specter, loading cluster seeds
2022-11-01 18:16:04,279 - s2and - INFO - loaded cluster seeds
2022-11-01 18:16:04,280 - s2and - INFO - making signature to cluster id
2022-11-01 18:16:04,284 - s2and - INFO - made signature to cluster id
2022-11-01 18:16:04,285 - s2and - INFO - loading name counts
2022-11-01 18:16:20,164 - s2and - INFO - loaded name counts
2022-11-01 18:16:21,023 - s2and - INFO - preprocessing papers
Preprocessing papers 1/2: 100%|██████████████████████████████| 23406/23406 [00:05<00:00, 4296.25it/s]
Preprocessing papers 2/2: 100%|██████████████████████████████| 23406/23406 [00:04<00:00, 4702.58it/s]
2022-11-01 18:16:35,033 


  0%|                                                         | 0/25 [00:00<?, ?trial/s, best loss=?][A
  4%|█▏                            | 1/25 [00:39<15:55, 39.83s/trial, best loss: -0.9527854658129595][A
  8%|██▍                           | 2/25 [00:58<12:49, 33.45s/trial, best loss: -0.9565535966590955][A
 12%|███▌                          | 3/25 [01:15<10:26, 28.48s/trial, best loss: -0.9565535966590955][A
 16%|████▊                         | 4/25 [01:24<07:59, 22.83s/trial, best loss: -0.9565535966590955][A
 20%|██████                        | 5/25 [01:52<08:06, 24.32s/trial, best loss: -0.9565535966590955][A
 24%|███████▏                      | 6/25 [02:16<07:40, 24.23s/trial, best loss: -0.9565535966590955][A
 28%|████████▍                     | 7/25 [02:30<06:20, 21.12s/trial, best loss: -0.9565535966590955][A
 32%|█████████▌                    | 8/25 [02:39<04:56, 17.46s/trial, best loss: -0.9565535966590955][A
 36%|██████████▊                   | 9/25 [03:15<06:07

1it [09:28, 568.35s/it]


  0%|                                                         | 0/25 [00:00<?, ?trial/s, best loss=?][A
  4%|█▏                            | 1/25 [00:44<17:43, 44.32s/trial, best loss: -0.9536629451111548][A
  8%|██▍                           | 2/25 [01:06<14:26, 37.69s/trial, best loss: -0.9536629451111548][A
 12%|███▌                          | 3/25 [01:25<11:43, 31.98s/trial, best loss: -0.9536629451111548][A
 16%|████▊                         | 4/25 [01:36<08:59, 25.67s/trial, best loss: -0.9536629451111548][A
 20%|██████                        | 5/25 [02:06<08:59, 26.97s/trial, best loss: -0.9536629451111548][A
 24%|███████▏                      | 6/25 [02:32<08:26, 26.64s/trial, best loss: -0.9536629451111548][A
 28%|████████▍                     | 7/25 [02:47<06:57, 23.18s/trial, best loss: -0.9536629451111548][A
 32%|█████████▌                    | 8/25 [02:57<05:26, 19.22s/trial, best loss: -0.9536629451111548][A
 36%|██████████▊                   | 9/25 [03:34<06:33

2it [43:16, 1006.22s/it]


  0%|                                                         | 0/25 [00:00<?, ?trial/s, best loss=?][A
  4%|█                          | 1/25 [02:42<1:04:52, 162.19s/trial, best loss: -0.9485618281973782][A
  8%|██▎                          | 2/25 [03:05<46:09, 120.42s/trial, best loss: -0.9506062990717722][A
 12%|███▌                          | 3/25 [03:25<33:08, 90.41s/trial, best loss: -0.9506062990717722][A
 16%|████▊                         | 4/25 [03:36<23:20, 66.71s/trial, best loss: -0.9506062990717722][A
 20%|██████                        | 5/25 [04:08<18:45, 56.25s/trial, best loss: -0.9506062990717722][A
 24%|███████▏                      | 6/25 [04:34<14:55, 47.12s/trial, best loss: -0.9506062990717722][A
 28%|████████▍                     | 7/25 [04:50<11:19, 37.74s/trial, best loss: -0.9506062990717722][A
 32%|█████████▌                    | 8/25 [05:00<08:19, 29.41s/trial, best loss: -0.9506062990717722][A
 36%|██████████▊                   | 9/25 [05:39<08:37

3it [54:56, 914.44s/it] 


  0%|                                                         | 0/25 [00:00<?, ?trial/s, best loss=?][A
  4%|█▏                            | 1/25 [00:45<18:03, 45.15s/trial, best loss: -0.9536629451111548][A
  8%|██▍                           | 2/25 [01:07<14:43, 38.40s/trial, best loss: -0.9536629451111548][A
 12%|███▌                          | 3/25 [01:26<11:57, 32.61s/trial, best loss: -0.9536629451111548][A
 16%|████▊                         | 4/25 [01:37<09:07, 26.07s/trial, best loss: -0.9536629451111548][A
 20%|██████                        | 5/25 [02:07<09:04, 27.21s/trial, best loss: -0.9536629451111548][A
 24%|███████▏                      | 6/25 [02:33<08:27, 26.73s/trial, best loss: -0.9536629451111548][A
 28%|████████▍                     | 7/25 [02:48<06:58, 23.24s/trial, best loss: -0.9536629451111548][A
 32%|█████████▌                    | 8/25 [02:58<05:27, 19.25s/trial, best loss: -0.9536629451111548][A
 36%|██████████▊                   | 9/25 [03:35<06:33

4it [1:06:30, 848.37s/it]


  0%|                                                         | 0/25 [00:00<?, ?trial/s, best loss=?][A

Mean of empty slice



  4%|█▏                            | 1/25 [00:43<17:13, 43.07s/trial, best loss: -0.9527055425393021][A
  8%|██▍                           | 2/25 [01:05<14:07, 36.86s/trial, best loss: -0.9527055425393021][A
 12%|███▌                          | 3/25 [01:24<11:34, 31.55s/trial, best loss: -0.9527055425393021][A
 16%|████▊                         | 4/25 [01:35<08:51, 25.33s/trial, best loss: -0.9527055425393021][A
 20%|██████                        | 5/25 [02:04<08:49, 26.50s/trial, best loss: -0.9527055425393021][A
 24%|███████▏                      | 6/25 [02:29<08:14, 26.03s/trial, best loss: -0.9540756025730451][A
 28%|████████▍                     | 7/25 [02:44<06:49, 22.78s/trial, best loss: -0.9540756025730451][A
 32%|█████████▌                    | 8/25 [02:54<05:21, 18.89s/trial, best loss: -0.9540756025730451][A
 36%|██████████▊                   | 9/25 [03:31<06:26, 24.18s/trial, best loss: -0.9540756025730451][A
 40%|███████████▌                 | 10/25 [04:11<07:15

5it [1:21:38, 979.65s/it]


In [16]:
accs = np.array(acc_per_strategy)
accs = np.mean(accs, axis=1)
accs

array([0.94778985, 0.94778985, 0.94778985, 0.94778985, 0.94778985])

In [38]:
# acc_per_strategy[0]
accs = []
for i in range(3):
    for k,j in enumerate(range(i*5, i*5 + 5)):
        if i == 0:
            accs.append([])
        accs[k].append(acc_per_strategy[0][j])

In [39]:
accs

[[0.9973252703802767, 0.8792134831460674, 0.9782],
 [0.9973252703802767, 0.8728152309612984, 0.9728],
 [0.9986044888940574, 0.8748439450686641, 0.9776],
 [0.9973252703802767, 0.8728152309612984, 0.9728],
 [0.9981393185254099, 0.8528401997503121, 0.9742]]

In [40]:
np.mean(accs, axis=1)

array([0.95157958, 0.94764683, 0.95034948, 0.94764683, 0.94172651])

In [41]:
np.argmax(np.mean(accs, axis=1))

# Best validation performance observed for Strategy 1: impute 0 for nan

0

In [5]:
datasets = ["pubmed", "qian", "zbmath"]  
# skipping -- "inspire", "medline"
# done -- "aminer", "arnetminer", "kisti"

missing = [0.257931452991453, 0.191869137094056, 0.17408965811965812]
lengths = [300000, 281168, 300000]

for dataset_name in datasets:
    parent_dir = f"data/{dataset_name}"
    dataset = ANDData(
        signatures=join(parent_dir, f"{dataset_name}_signatures.json"),
        papers=join(parent_dir, f"{dataset_name}_papers.json"),
        mode="train",
        specter_embeddings=join(parent_dir, f"{dataset_name}_specter.pickle"),
        clusters=join(parent_dir, f"{dataset_name}_clusters.json"),
        block_type="s2",
        train_pairs_size=100000,
        val_pairs_size=100000,
        test_pairs_size=100000,
        name=dataset_name,
        n_jobs=8,
    )

    featurization_info = FeaturizationInfo()
    # the cache will make it faster to train multiple times - it stores the features on disk for you
    train, val, test = featurize(dataset, featurization_info, n_jobs=8, use_cache=True)
    
    missing_pct = (np.sum(np.isnan(train[0])) + np.sum(np.isnan(val[0])) + np.sum(np.isnan(test[0]))) / \
    (np.prod(train[0].shape) + np.prod(val[0].shape) + np.prod(test[0].shape))
    missing.append(missing_pct)
    
    total_pairs = train[0].shape[0] + val[0].shape[0] + test[0].shape[0]
    lengths.append(total_pairs)
    
    print(missing_pct, total_pairs)
    del dataset, train, val, test

print("missing_pcts:", missing)
print("missing_pct_avg", np.mean(missing))
print("n_pairs:", lengths)
print("n_pairs_mean:", np.mean(lengths))

2022-11-04 17:42:11,917 - s2and - INFO - loading papers
2022-11-04 17:42:17,794 - s2and - INFO - loaded papers
2022-11-04 17:42:17,795 - s2and - INFO - loading signatures
2022-11-04 17:42:20,395 - s2and - INFO - loaded signatures
2022-11-04 17:42:20,395 - s2and - INFO - loading clusters
2022-11-04 17:42:20,399 - s2and - INFO - loaded clusters, loading specter
2022-11-04 17:42:20,528 - s2and - INFO - loaded specter, loading cluster seeds
2022-11-04 17:42:20,528 - s2and - INFO - loaded cluster seeds
2022-11-04 17:42:20,529 - s2and - INFO - making signature to cluster id
2022-11-04 17:42:20,531 - s2and - INFO - made signature to cluster id
2022-11-04 17:42:20,531 - s2and - INFO - loading name counts
2022-11-04 17:42:36,632 - s2and - INFO - loaded name counts
2022-11-04 17:42:37,034 - s2and - INFO - preprocessing papers
Preprocessing papers 1/2: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 134623/1

0.2617637716462652 125525


2022-11-04 17:44:02,216 - s2and - INFO - loaded papers
2022-11-04 17:44:02,217 - s2and - INFO - loading signatures
2022-11-04 17:44:03,861 - s2and - INFO - loaded signatures
2022-11-04 17:44:03,861 - s2and - INFO - loading clusters
2022-11-04 17:44:03,868 - s2and - INFO - loaded clusters, loading specter
2022-11-04 17:44:04,133 - s2and - INFO - loaded specter, loading cluster seeds
2022-11-04 17:44:04,134 - s2and - INFO - loaded cluster seeds
2022-11-04 17:44:04,135 - s2and - INFO - making signature to cluster id
2022-11-04 17:44:04,137 - s2and - INFO - made signature to cluster id
2022-11-04 17:44:04,138 - s2and - INFO - loading name counts
2022-11-04 17:44:20,162 - s2and - INFO - loaded name counts
2022-11-04 17:44:20,380 - s2and - INFO - preprocessing papers
Preprocessing papers 1/2: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 59545/59545 [00:09<00:00, 6590.52it/s]
Preprocessing papers 2/

0.17877380072502025 113652


2022-11-04 17:45:05,401 - s2and - INFO - loaded papers
2022-11-04 17:45:05,402 - s2and - INFO - loading signatures
2022-11-04 17:45:05,559 - s2and - INFO - loaded signatures
2022-11-04 17:45:05,559 - s2and - INFO - loading clusters
2022-11-04 17:45:05,580 - s2and - INFO - loaded clusters, loading specter
2022-11-04 17:45:05,995 - s2and - INFO - loaded specter, loading cluster seeds
2022-11-04 17:45:05,995 - s2and - INFO - loaded cluster seeds
2022-11-04 17:45:05,996 - s2and - INFO - making signature to cluster id
2022-11-04 17:45:06,001 - s2and - INFO - made signature to cluster id
2022-11-04 17:45:06,002 - s2and - INFO - loading name counts
2022-11-04 17:45:22,060 - s2and - INFO - loaded name counts
2022-11-04 17:45:22,159 - s2and - INFO - preprocessing papers
Preprocessing papers 1/2: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 23406/23406 [00:06<00:00, 3544.16it/s]
Preprocessing papers 2/

0.4144615842850388 147724
missing_pcts: [0.257931452991453, 0.191869137094056, 0.17408965811965812, 0.2617637716462652, 0.17877380072502025, 0.4144615842850388]
missing_pct_avg 0.24648156747691521
n_pairs: [300000, 281168, 300000, 125525, 113652, 147724]
n_pairs_mean: 211344.83333333334
