In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data.dataloader import default_collate

from torch.utils.data import (
    Dataset, TensorDataset, DataLoader, 
    RandomSampler, SequentialSampler
)

from torch.utils.data.dataset import random_split

import torch.optim as optim
import torchtext

In [86]:
import numpy as np
import pandas as pd

from collections import Counter
from pprint import pprint

from tqdm import trange, tqdm

In [3]:
YEAR = 2019
data_path = "../data/aspect_data"
judge_wv_path = f"judgements_word_vec_{YEAR}"

In [4]:
disease_wv = pd.read_pickle(f"{data_path}/{judge_wv_path}/disease_qe_word_vec_full_{YEAR}.pickle")
print(len(disease_wv.items()))
gene_wv = pd.read_pickle(f"{data_path}/{judge_wv_path}/gene_qe_word_vec_full_{YEAR}.pickle")
print(len(disease_wv.items()))

12996
12996


In [5]:
disease_wv["2019_26_NCT02195336"][:25]

[1, 1, 3, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0]

In [6]:
gene_wv["2019_26_NCT02195336"][:25]

[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]

In [7]:
ct_fields = [
    "trec_doc_id", "pm_rel_desc", 
    "gene1_annotation_desc", "gene_wv", "gene_score", 
    "disease_desc", "disease_wv", "disease_score",
]

ct_scores = pd.read_pickle(f"{data_path}/ct_judgement_scores_{YEAR}.pickle")
ct_scores.head(n=2)

Unnamed: 0,trec_topic_number,trec_doc_id,pm_rel_desc,disease_desc,gene1_annotation_desc,gene1_name,gene2_annotation_desc,gene2_name,gene3_annotation_desc,gene3_name,demographics_desc,gene_score,disease_score
9,1,NCT00119249,Human PM,Exact,Missing Variant,BRAF (E586K),,,,,Matches,1,3
18,1,NCT00304525,Human PM,Exact,Missing Variant,BRAF (E586K),,,,,Not Discussed,1,3


In [8]:
ct_scores[ct_scores.duplicated(subset="trec_doc_id", keep=False)].sort_values("trec_doc_id")[:2]

Unnamed: 0,trec_topic_number,trec_doc_id,pm_rel_desc,disease_desc,gene1_annotation_desc,gene1_name,gene2_annotation_desc,gene2_name,gene3_annotation_desc,gene3_name,demographics_desc,gene_score,disease_score
5086,15,NCT00002667,Human PM,More General,Missing Gene,KRAS (G12V),Missing Variant,high tumor mutational burden,,,Matches,1,2
4405,13,NCT00002667,Human PM,More General,Missing Gene,EZR-ROS1 fusion,,,,,Matches,0,2


In [9]:
ct_scores["pm_rel_desc"].value_counts()

Human PM    5713
Name: pm_rel_desc, dtype: int64

In [10]:
ct_scores["year"] = YEAR
ct_scores["uniq_id"] = ct_scores["year"].astype(str) + "_" + \
    ct_scores["trec_topic_number"].astype(str) + "_" + \
    ct_scores["trec_doc_id"]

In [11]:
ct_scores["disease_wv"] = ct_scores["uniq_id"].map(disease_wv)
ct_scores["gene_wv"] = ct_scores["uniq_id"].map(gene_wv)

In [12]:
len(ct_scores[ct_scores["trec_doc_id"] == "NCT00936221"]["disease_wv"].iloc[0])

222

In [13]:
len(ct_scores[ct_scores["trec_doc_id"] == "NCT00936221"]["gene_wv"].iloc[0])

495

In [14]:
ct_scores[ct_fields].head()

Unnamed: 0,trec_doc_id,pm_rel_desc,gene1_annotation_desc,gene_wv,gene_score,disease_desc,disease_wv,disease_score
9,NCT00119249,Human PM,Missing Variant,"[0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",1,Exact,"[0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",3
18,NCT00304525,Human PM,Missing Variant,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",1,Exact,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",3
23,NCT00405587,Human PM,Different Variant,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",0,More General,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",2
33,NCT00811759,Human PM,Missing Variant,"[0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",1,More Specific,"[0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",2
37,NCT00936221,Human PM,Missing Variant,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",1,Exact,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",3


In [15]:
data_path = "../data/aspect_data"
judge_wv_path = f"judgements_word_vec_{YEAR}"
doc_wv = pd.read_pickle(f"{data_path}/{judge_wv_path}/doc_wv_all_baas_scibert_{YEAR}.pickle")
print(len(doc_wv["NCT02286219"]))
doc_wv["NCT02286219"]

768


array([ 8.61359527e-04, -2.81754971e-01, -3.29208463e-01,  5.31943500e-01,
       -8.72500520e-03, -3.49188805e-01,  2.48892456e-01,  7.87907660e-01,
        1.43921152e-01,  1.55066639e-01,  2.37071052e-01, -3.80850285e-01,
       -2.42039144e-01,  3.59762400e-01,  2.01381400e-01, -2.44752035e-01,
        4.21687871e-01,  9.38035697e-02,  3.28805745e-01, -1.43198758e-01,
       -3.05172920e-01,  4.26285982e-01,  1.48630455e-01, -1.51338905e-01,
        6.13931835e-01, -4.52057160e-02,  1.22143395e-01, -6.45003170e-02,
        8.38697404e-02, -3.25944722e-01,  1.89856514e-01, -2.35023245e-01,
        1.33091887e-03, -5.52369654e-01, -2.22136244e-01,  1.87645003e-01,
        1.67359799e-01,  1.55393884e-01, -3.17962915e-01,  4.62974794e-02,
       -2.56724596e-01, -3.80310148e-01,  1.27255231e-01, -2.26876631e-01,
        6.29170537e-01, -3.92039955e-01, -2.94103086e-01,  1.16876900e-01,
       -1.88982651e-01, -5.53722680e-02,  3.08715373e-01, -7.02668950e-02,
        2.11927772e-01,  

In [16]:
len(doc_wv.keys())

3707

In [17]:
df = pd.read_pickle(f"../data/aspect_data/trials_judgement_combined_full_{YEAR}.pickle")
print(df.shape)
df.head()

(5713, 16)


Unnamed: 0,score,id,brief_summary,brief_title,minimum_age,gender,primary_outcome,detailed_description,keywords,official_title,intervention_type,intervention_name,intervention_browse,condition_browse,inclusion,exclusion
0,5.55757,NCT02286219,"To evaluate the safety, tolerability, dose lim...","Phase 1, Multiple Ascending Dose Study of Anti...",6570,All,Adverse events as assessed by NCI Common Toxic...,"This is a Phase 1, open-label, multiple dose e...",,"Phase 1, Multiple Ascending Dose Study of Anti...",Drug,FS102,Immunoglobulins Immunoglobulin G,,For more information regarding BMS clinical tr...,ary concurrent medical conditions: 1. Hyperten...
1,5.55757,NCT02342587,Although anti-HER2 (human epidermal growth fac...,Safety and Clinical Activity of Lapatinib in P...,7300,All,overall response rate 4 weeks after treatment ...,,Lapatinib,Safety and Clinical Activity of Lapatinib in P...,Drug,Lapatinib,Lapatinib,Neoplasms,- Age > 19 - Written informed consent - ECOG 0...,"- Uncontrolled symptomatic brain metastasis,Un..."
2,5.55757,NCT02386501,"This is a Phase 1b, multicenter, open-label, d...",Dose Escalation Study of ADXS31-164 in Subject...,6570,All,Number of patients with dose-limiting toxiciti...,,,A Phase 1b Dose Escalation Study of ADXS31-164...,Drug,ADXS31-164,,,- HER2 Positive - Have histological or cytolog...,- Is newly diagnosed with a curative treatment...
3,5.55757,NCT02394496,Based on these results it can be envisioned th...,Overcoming Endocrine Resistance in Metastatic ...,6570,Female,Progression Free Survival Defined as the time ...,In presence of ER hypersensitivity even a smal...,Fulvestrant Lapatinib Aromatase Inhibitor meta...,A Randomized Trial With Factorial Design Compa...,Drug Drug Drug Drug,Fulvestrant Lapatinib Aromatase Inhibitors Pla...,Lapatinib Fulvestrant Aromatase Inhibitors,Breast Neoplasms,1. Provision of written informed consent 2. Hi...,1. Previous therapy with Fulvestrant and/or La...
4,5.55757,NCT02396108,The current standard of care for stage I-III H...,Dose-confirmation Study of ASLAN001 Combined W...,7665,All,Pathologic complete response rate Post neoadju...,Breast cancer is the leading cause of cancer d...,,Phase Ib Dose-confirmation Study of ASLAN001 C...,Drug,Paclitaxel + Carboplatin + ASLAN001,Paclitaxel Albumin-Bound Paclitaxel Carboplatin,Breast Neoplasms,- Age ≥ 21 years - Karnofsky performance statu...,- Concurrent administration of any other tumor...


In [18]:
set(doc_wv.keys()) - set(df["id"].unique())

set()

In [19]:
len(set(df["id"].unique()))

3707

In [20]:
len(ct_scores["trec_doc_id"].unique())

3707

In [21]:
len(set(ct_scores["trec_doc_id"].unique()) - set(df["id"].unique()))

0

In [22]:
len(set(df["id"].unique()) - set(ct_scores["trec_doc_id"].unique()))

0

In [23]:
len(doc_wv.items())

3707

### Combine Data

In [24]:
df_qrels = pd.read_pickle("../data/trials_topics_combined_all_years.pickle")
df_qrels["uniq_id"] = df_qrels["year"].astype(str) + "_" + \
    df_qrels["topic"].astype(str) + "_" + \
    df_qrels["id"]
df_qrels = df_qrels[df_qrels["year"] == YEAR].copy()
print(df_qrels.shape)
print(df_qrels.id.nunique())
df_qrels.head()

(12996, 24)
8567


Unnamed: 0,score,id,brief_summary,brief_title,minimum_age,gender,primary_outcome,detailed_description,keywords,official_title,...,inclusion,exclusion,topic,_,label,disease,gene,age,year,uniq_id
27207,1.0,NCT02195336,"To date, there are no methods to reliably sele...",dMR During First Line Treatment of Non Squamou...,6570,female,"dMRT changes during treatment Baseline, Day 8,...","A Non Interventional, open-label, single arm, ...",,Diffusion MR (dMRT) During First Line Treatmen...,...,1. Written informed consent obtained prior to ...,"1. Mixed, non-small cell and small cell tumour...",26,0,0,squamous cell lung cancer,FGFR1 amplification,,2019,2019_26_NCT02195336
27208,1.0,NCT01623115,Alirocumab (SAR236553/REGN727) is a fully huma...,Efficacy and Safety of Alirocumab (SAR236553/R...,6570,female,Percent Change From Baseline in Calculated LDL...,The maximum study duration was planned to be 8...,PCSK9,"A Randomized, Double-Blind, Placebo-Controlled...",...,- Participants with heterozygous familial hype...,"- Age < 18 years or legal age of adulthood, wh...",35,0,0,familial hypercholesterolemia,PCSK9,,2019,2019_35_NCT01623115
27209,1.0,NCT02898857,"Every year in France, 30.000 deaths are due to...",Chemoresistance and Involvement of the NOTCH P...,6570,male,analysed by immunochemistry (IHC) 1 day The bi...,"Every year in France, 30.000 deaths are due to...",,CHEMO RESISTANCE TO PLATINUM COMPOUNDS AND NOT...,...,"patients of both sex, aged 75 years or younger...",patients having received more than one line of...,15,0,2,lung adenocarcinoma,"KRAS (G12V), high tumor mutational burden",,2019,2019_15_NCT02898857
27210,1.0,NCT02898857,"Every year in France, 30.000 deaths are due to...",Chemoresistance and Involvement of the NOTCH P...,6570,male,analysed by immunochemistry (IHC) 1 day The bi...,"Every year in France, 30.000 deaths are due to...",,CHEMO RESISTANCE TO PLATINUM COMPOUNDS AND NOT...,...,"patients of both sex, aged 75 years or younger...",patients having received more than one line of...,27,0,0,non-small cell lung cancer,KRAS (G12C),,2019,2019_27_NCT02898857
27211,1.0,NCT01605617,The investigators are studying two FDA-approve...,Trial of Percutaneous Tibial Nerve Stimulation...,6570,male,Number of Urinary Voids Per 24 Hours After 12 ...,Overactive bladder (OAB) is described as urina...,,Prospective Randomized Trial of Percutaneous T...,...,- Female ages > 18 and < 100 years old without...,- Has had PTNS modulation in the past - Has a ...,8,0,0,bladder cancer,FGFR3 (S249C),,2019,2019_8_NCT01605617


In [25]:
df_qrels["id"].apply(lambda x: len(x)).unique()

array([11])

In [26]:
print(len(ct_scores["uniq_id"].unique()))
print(len(df_qrels["uniq_id"].unique()))

5713
12996


In [27]:
uniq_ids = set(ct_scores["uniq_id"].unique()) - (set(ct_scores["uniq_id"].unique()) - set(df_qrels["uniq_id"]))
print(len(uniq_ids))
uniq_ids

5686


{'2019_16_NCT02915666',
 '2019_37_NCT01337765',
 '2019_15_NCT00837135',
 '2019_16_NCT03839342',
 '2019_6_NCT02588261',
 '2019_18_NCT02038348',
 '2019_19_NCT03166904',
 '2019_1_NCT02416232',
 '2019_34_NCT02296190',
 '2019_20_NCT02797964',
 '2019_2_NCT02360579',
 '2019_7_NCT00533949',
 '2019_2_NCT01910181',
 '2019_13_NCT03856411',
 '2019_27_NCT02743923',
 '2019_1_NCT02399943',
 '2019_24_NCT00972686',
 '2019_20_NCT02953457',
 '2019_7_NCT00039182',
 '2019_15_NCT00940381',
 '2019_18_NCT02601079',
 '2019_15_NCT01986166',
 '2019_5_NCT00709761',
 '2019_16_NCT03829410',
 '2019_6_NCT00601848',
 '2019_25_NCT02587650',
 '2019_7_NCT03399487',
 '2019_5_NCT03550755',
 '2019_24_NCT01151007',
 '2019_20_NCT02000622',
 '2019_3_NCT03040791',
 '2019_30_NCT01550380',
 '2019_21_NCT01396148',
 '2019_27_NCT01384994',
 '2019_25_NCT01941927',
 '2019_22_NCT03663062',
 '2019_34_NCT02170961',
 '2019_18_NCT02155621',
 '2019_39_NCT02632045',
 '2019_13_NCT03322540',
 '2019_6_NCT00027690',
 '2019_36_NCT03654716',
 '201

In [28]:
5713 - 5686 

27

In [29]:
doc_wv.keys()

dict_keys(['NCT02286219', 'NCT02342587', 'NCT02386501', 'NCT02394496', 'NCT02396108', 'NCT02451553', 'NCT02473653', 'NCT02593708', 'NCT02658084', 'NCT02675829', 'NCT02716116', 'NCT02880007', 'NCT02892123', 'NCT02901301', 'NCT02952729', 'NCT03032107', 'NCT03043313', 'NCT03084926', 'NCT03125200', 'NCT03134638', 'NCT03330561', 'NCT03364348', 'NCT03365882', 'NCT03368196', 'NCT03438396', 'NCT03448042', 'NCT03457896', 'NCT03469531', 'NCT03493854', 'NCT03550755', 'NCT03602079', 'NCT03680560', 'NCT03725436', 'NCT03786107', 'NCT03816553', 'NCT03832855', 'NCT03847168', 'NCT03853915', 'NCT03916094', 'NCT00027690', 'NCT00029003', 'NCT00040794', 'NCT00049543', 'NCT00068497', 'NCT00068653', 'NCT00072631', 'NCT00079066', 'NCT00085553', 'NCT00087412', 'NCT00088959', 'NCT00091156', 'NCT00091806', 'NCT00097227', 'NCT00101920', 'NCT00118157', 'NCT00125372', 'NCT00130520', 'NCT00217698', 'NCT00226239', 'NCT00230126', 'NCT00265317', 'NCT00266877', 'NCT00288054', 'NCT00294762', 'NCT00312377', 'NCT00315185',

In [30]:
len(set(df_qrels["uniq_id"]) - set(ct_scores["uniq_id"].unique()))

7310

In [31]:
df_qrels_subset = df_qrels[
    (df_qrels["uniq_id"].isin(ct_scores["uniq_id"].unique())) &
    (df_qrels["year"])
]
print(df_qrels_subset.shape)
df_qrels_subset.head()

(5686, 24)


Unnamed: 0,score,id,brief_summary,brief_title,minimum_age,gender,primary_outcome,detailed_description,keywords,official_title,...,inclusion,exclusion,topic,_,label,disease,gene,age,year,uniq_id
27208,1.0,NCT01623115,Alirocumab (SAR236553/REGN727) is a fully huma...,Efficacy and Safety of Alirocumab (SAR236553/R...,6570,female,Percent Change From Baseline in Calculated LDL...,The maximum study duration was planned to be 8...,PCSK9,"A Randomized, Double-Blind, Placebo-Controlled...",...,- Participants with heterozygous familial hype...,"- Age < 18 years or legal age of adulthood, wh...",35,0,0,familial hypercholesterolemia,PCSK9,,2019,2019_35_NCT01623115
27209,1.0,NCT02898857,"Every year in France, 30.000 deaths are due to...",Chemoresistance and Involvement of the NOTCH P...,6570,male,analysed by immunochemistry (IHC) 1 day The bi...,"Every year in France, 30.000 deaths are due to...",,CHEMO RESISTANCE TO PLATINUM COMPOUNDS AND NOT...,...,"patients of both sex, aged 75 years or younger...",patients having received more than one line of...,15,0,2,lung adenocarcinoma,"KRAS (G12V), high tumor mutational burden",,2019,2019_15_NCT02898857
27213,1.0,NCT00075270,The purpose of this study is to determine the ...,Paclitaxel With / Without GW572016 (Lapatinib)...,6570,male,Time to Progression as Evaluated by the Invest...,,metastatic breast cancer ErbB1 kinase inhibito...,"A Randomized, Multicenter, Double-Blind, Place...",...,- Signed Informed Consent - Able to swallow an...,- Prior treatment regimens for advanced or met...,4,0,0,gastric cancer,ERBB2 amplification,,2019,2019_4_NCT00075270
27214,1.0,NCT00075270,The purpose of this study is to determine the ...,Paclitaxel With / Without GW572016 (Lapatinib)...,6570,female,Time to Progression as Evaluated by the Invest...,,metastatic breast cancer ErbB1 kinase inhibito...,"A Randomized, Multicenter, Double-Blind, Place...",...,- Signed Informed Consent - Able to swallow an...,- Prior treatment regimens for advanced or met...,5,0,0,cervical cancer,ERBB2 (S310Y),,2019,2019_5_NCT00075270
27215,1.0,NCT00075270,The purpose of this study is to determine the ...,Paclitaxel With / Without GW572016 (Lapatinib)...,6570,female,Time to Progression as Evaluated by the Invest...,,metastatic breast cancer ErbB1 kinase inhibito...,"A Randomized, Multicenter, Double-Blind, Place...",...,- Signed Informed Consent - Able to swallow an...,- Prior treatment regimens for advanced or met...,19,0,0,colon cancer,ERBB2 amplification,,2019,2019_19_NCT00075270


In [32]:
df_qrels_subset.label.unique()

array([0, 2, 1])

In [33]:
df_qrels_subset["label"].loc[df_qrels_subset["label"] >= 1] = 1

A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: http://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  self._setitem_with_indexer(indexer, value)
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: http://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  """Entry point for launching an IPython kernel.


In [34]:
df_qrels_subset.label.unique()

array([0, 1])

In [35]:
df_qrels_subset.label.value_counts()

0    3502
1    2184
Name: label, dtype: int64

In [36]:
df_qrels_subset.head()

Unnamed: 0,score,id,brief_summary,brief_title,minimum_age,gender,primary_outcome,detailed_description,keywords,official_title,...,inclusion,exclusion,topic,_,label,disease,gene,age,year,uniq_id
27208,1.0,NCT01623115,Alirocumab (SAR236553/REGN727) is a fully huma...,Efficacy and Safety of Alirocumab (SAR236553/R...,6570,female,Percent Change From Baseline in Calculated LDL...,The maximum study duration was planned to be 8...,PCSK9,"A Randomized, Double-Blind, Placebo-Controlled...",...,- Participants with heterozygous familial hype...,"- Age < 18 years or legal age of adulthood, wh...",35,0,0,familial hypercholesterolemia,PCSK9,,2019,2019_35_NCT01623115
27209,1.0,NCT02898857,"Every year in France, 30.000 deaths are due to...",Chemoresistance and Involvement of the NOTCH P...,6570,male,analysed by immunochemistry (IHC) 1 day The bi...,"Every year in France, 30.000 deaths are due to...",,CHEMO RESISTANCE TO PLATINUM COMPOUNDS AND NOT...,...,"patients of both sex, aged 75 years or younger...",patients having received more than one line of...,15,0,1,lung adenocarcinoma,"KRAS (G12V), high tumor mutational burden",,2019,2019_15_NCT02898857
27213,1.0,NCT00075270,The purpose of this study is to determine the ...,Paclitaxel With / Without GW572016 (Lapatinib)...,6570,male,Time to Progression as Evaluated by the Invest...,,metastatic breast cancer ErbB1 kinase inhibito...,"A Randomized, Multicenter, Double-Blind, Place...",...,- Signed Informed Consent - Able to swallow an...,- Prior treatment regimens for advanced or met...,4,0,0,gastric cancer,ERBB2 amplification,,2019,2019_4_NCT00075270
27214,1.0,NCT00075270,The purpose of this study is to determine the ...,Paclitaxel With / Without GW572016 (Lapatinib)...,6570,female,Time to Progression as Evaluated by the Invest...,,metastatic breast cancer ErbB1 kinase inhibito...,"A Randomized, Multicenter, Double-Blind, Place...",...,- Signed Informed Consent - Able to swallow an...,- Prior treatment regimens for advanced or met...,5,0,0,cervical cancer,ERBB2 (S310Y),,2019,2019_5_NCT00075270
27215,1.0,NCT00075270,The purpose of this study is to determine the ...,Paclitaxel With / Without GW572016 (Lapatinib)...,6570,female,Time to Progression as Evaluated by the Invest...,,metastatic breast cancer ErbB1 kinase inhibito...,"A Randomized, Multicenter, Double-Blind, Place...",...,- Signed Informed Consent - Able to swallow an...,- Prior treatment regimens for advanced or met...,19,0,0,colon cancer,ERBB2 amplification,,2019,2019_19_NCT00075270


In [37]:
doc_ids = []
X_gene = []
X_disease = []
X_doc_wv = [] # Take the last 11 chars since trec_doc_id all have 11 chars.
y = []
counter = 0
for key in uniq_ids:
    doc_ids.append(key)
    X_gene.append(ct_scores[ct_scores["uniq_id"] == key]["gene_wv"].item())
    X_disease.append(ct_scores[ct_scores["uniq_id"] == key]["disease_wv"].item())
    X_doc_wv.append(doc_wv[key[-11:]])
#     print(df_qrels_subset[df_qrels_subset["uniq_id"] == key]["label"].item())
    y.append(df_qrels_subset[df_qrels_subset["uniq_id"] == key]["label"].item())
    counter += 1
#     if counter > 2000:
#         break

  if __name__ == '__main__':
  # Remove the CWD from sys.path while we load stuff.
  del sys.path[0]


## Baseline Model

Here we simply concatenate the vectors and pass it into a dense layer.

In [38]:
np.array([0, 1], dtype=np.float)

array([0., 1.])

In [39]:
def num_correctly_classified(preds, labels):
    pred_flat = np.argmax(preds, axis=1).flatten()
    labels_flat = labels.flatten()
    return np.sum(pred_flat == labels_flat)

In [100]:
class DatasetWithStringID(Dataset):
    def __init__(
        self, X_doc_wv, y, doc_ids,
        loss_func):

        """
        We utilise a custom DataSet class to enough that we can pass IDs
        in and out of the batch at evaluation time.

        Parameters
        ----------
        doc_ids: ID of docs (clinical trials) that are tagged by year, trec_doc_id and topic
        loss_func: ["entropy", "bce"]; different formats required given type of data
        """
        
        self.X_doc_wv = X_doc_wv
        self.y = y
        self.doc_ids = doc_ids
        
        self.loss_func = loss_func

    def __len__(self):
        return len(self.X_doc_wv)

    def __getitem__(self, index):
        """ Returns one sample of data"""
        # CrossEntropyLoss
        if self.loss_func == "entropy":
            y_tensor = torch.tensor(self.y[index], dtype=torch.long)
        
        # BCELoss
        elif self.loss_func == "bce":
            y = np.array([0, 1], dtype=np.float) if (self.y[index] == 1) else np.array([1, 0],  dtype=np.float)  
            y_tensor = torch.tensor(y, dtype=torch.float)
            
        else:
            raise ValueError("loss_func needs to specified correctly!")
        ret_arr = [
            torch.tensor(self.X_doc_wv[index], dtype=torch.float),
            y_tensor,
            self.doc_ids[index]
        ]
        return ret_arr
    
    @classmethod
    def values_count(cls, x):
        return Counter(x)

class DatasetWithFields(DatasetWithStringID):
    def __init__(
        self, X_doc_wv, y, 
        doc_ids, loss_func,
        X_gene, X_disease):

        """
        We utilise a custom DataSet class to enough that we can pass IDs
        in and out of the batch at evaluation time.

        Parameters
        ----------
        doc_ids: ID of docs (clinical trials) that are tagged by year, trec_doc_id and topic
        loss_func: ["entropy", "bce"]; different formats required given type of data
        """
        super().__init__(X_doc_wv, y, doc_ids, loss_func)

        self.X_gene = X_gene
        self.X_disease = X_disease

    def __getitem__(self, index):
        """ Returns one sample of data"""
#         print(f"index is: {index}")
#         print("printing X_doc_wv[index]")
        X = np.concatenate((self.X_gene[index], self.X_disease[index], self.X_doc_wv[index]), axis=0)
#         print(X_doc_wv[index])
#         print(torch.FloatTensor(X))
#         print("Finished printing")
#         X = torch.tensor(X, dtype=torch.float32)

        # CrossEntropyLoss
        if self.loss_func == "entropy":
            y_tensor = torch.tensor(self.y[index], dtype=torch.long)
        
        # BCELoss
        elif self.loss_func == "bce":
            y = np.array([0, 1], dtype=np.float) if (self.y[index] == 1) else np.array([1, 0],  dtype=np.float)  
            y_tensor = torch.tensor(y, dtype=torch.float)
            
        else:
            raise ValueError("loss_func needs to specified correctly!")
        ret_arr = [
            torch.tensor(X, dtype=torch.float),
            y_tensor,
            self.doc_ids[index]
        ]
        return ret_arr

def id_collate(batch):
    new_batch = []
    doc_ids = []

    for b in batch:
        new_batch.append(b[:-1])
        doc_ids.append(b[-1])
    return default_collate(new_batch), doc_ids

In [101]:
class denseNetSoftmax(nn.Module):
    """
    Simple neural net with 3 dense layers.

    We use this for testing the denseNetDataLoader.py
    """
    def __init__(self, wv_dim, layer1_dim, layer2_dim, num_labels):
        super(denseNet, self).__init__()
        # an affine operation: y = Wx + b
        self.fc1 = torch.nn.Linear(wv_dim, layer1_dim)
        self.fc2 = torch.nn.Linear(layer1_dim, layer2_dim)
        self.fc3 = torch.nn.Linear(layer2_dim, num_labels)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        return F.softmax(x, dim=0)

class denseNet(nn.Module):
    """
    Simple neural net with 3 dense layers.

    We use this for testing the denseNetDataLoader.py
    """
    def __init__(
        self, wv_dim, 
        layer1_dim, layer2_dim,
        layer3_dim, layer4_dim,
        layer5_dim, num_labels):
        super(denseNet, self).__init__()
        # an affine operation: y = Wx + b
        self.fc1 = torch.nn.Linear(wv_dim, layer1_dim)
        self.fc2 = torch.nn.Linear(layer1_dim, layer2_dim)
        self.fc3 = torch.nn.Linear(layer2_dim, layer3_dim)
        self.fc4 = torch.nn.Linear(layer3_dim, layer4_dim)
        self.fc5 = torch.nn.Linear(layer4_dim, layer5_dim)
        self.fc6 = torch.nn.Linear(layer5_dim, num_labels)
        self.drop1 = torch.nn.Dropout(p=0.5)
        self.drop2 = torch.nn.Dropout(p=0.5)
        self.drop3 = torch.nn.Dropout(p=0.5)
        
        self.fc_layers = [
            self.fc1, self.fc2, self.fc3,
            self.fc4, self.fc5, self.fc6
        ]
        
        self.init_weights()
        
    def init_weights(self):
        init_range = 0.5
        [fc.weight.data.uniform_(-init_range, init_range) for fc in self.fc_layers]
#         [fc.bias.data.zero_() for fc in self.fc_layers]
        
    def forward(self, x):
        x = F.gelu(self.fc1(x))
        x = F.gelu(self.fc2(x))
        x = F.gelu(self.drop1(x))
        x = F.gelu(self.fc3(x))
        x = F.gelu(self.drop2(x))
        x = F.tanh(self.fc4(x))
        x = F.relu(self.drop3(x))
        x = F.gelu(self.fc5(x))
        return self.fc6(x)
    
class nnModelBCELoss:
    def __init__(self, model):
        self.model = model
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        
        self.criterion = torch.nn.BCELoss().to(self.device)        
        self.optimizer = torch.optim.AdamW(model.parameters(), lr=4.0)
        self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, 1, gamma=0.9)

    def train(self, training_data, batch_size, collate_fn, epochs):
        trainDataLoader = DataLoader(
            training_data,
            batch_size=batch_size,
            collate_fn=collate_fn
        )

        train_loss = 0
        for epoch in trange(epochs, desc="EPOCHS"):
            epoch_iterator = tqdm(trainDataLoader, desc="Iteration")
            train_acc = 0
            
            pred_array = None
            labels_array = None
            
            for step, batch in enumerate(epoch_iterator):
                self.model.train()
                self.model.zero_grad()

#                 print(f"batch[0][0]: {batch[0][0]}")
#                 print(f"batch[0][1]: {batch[0][1]}")
#                 print(f"batch[1]: {batch[1]}")
                training_vec_batch = batch[0][0].to(self.device)
                label_vec_batch = batch[0][1].to(self.device)
                doc_ids_batch = batch[1]
                
                output = self.model(training_vec_batch)
#                 print(f"output: {output}")
#                 print(f"label_vec_batch: {label_vec_batch}")
#                 print(f"output.shape: {output.shape}")
#                 print(f"label_vec_batch.shape: {label_vec_batch.shape}")
                loss = self.criterion(output, label_vec_batch)
                loss.backward()
                train_loss += loss.item()
                # Need to compute the entire array before calculating 
                # argmax predictions etc.
                print(f"output.argmax(1): {output.argmax(1)}")
                print(f"label_vec_batch.argmax(1): {label_vec_batch.argmax(1)}")
                train_acc += (output.argmax(dim=1) == label_vec_batch.argmax(dim=1)).sum().item()
                print(f"train_acc: {train_acc}")
                self.optimizer.step()
            self.scheduler.step()
            print(f'Loss: {train_loss:.4f}(train)\t|\tAcc: {train_acc/len(training_data) * 100:.1f}%(train)')
        return train_loss, train_acc

    def evaluate(self, testing_data, batch_size, collate_fn):
        eval_loss = 0
        eval_acc = 0

        testDataLoader = DataLoader(
            testing_data,
            batch_size=batch_size,
            collate_fn=collate_fn
        )

        self.model.eval()
        self.model.to(self.device)
        for batch in tqdm(testDataLoader, desc="EVALUATING"):
            with torch.no_grad():
                testing_vec_batch = batch[0][0].to(self.device)
                label_vec_batch = batch[0][1].to(self.device)
                doc_ids_batch = batch[1]

                output = self.model(testing_vec_batch)
                loss = self.criterion(output, label_vec_batch)
                eval_loss += loss.item()
                print(f"output.argmax(1): {output.argmax(1)}")
                print(f"label_vec_batch.argmax(1): {label_vec_batch.argmax(1)}")
                eval_acc += (output.argmax(dim=1) == label_vec_batch.argmax(dim=1)).sum().item()
        print(f'Loss: {eval_loss:.4f}(test)\t|\tAcc: {eval_acc/len(testing_data) * 100:.1f}%(test)')
        return eval_loss, eval_acc

In [210]:
class nnModelCrossEntropyLoss:
    def __init__(self, model):
        self.model = model
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        
#         self.criterion = torch.nn.BCELoss().to(self.device)        
        self.criterion = torch.nn.CrossEntropyLoss().to(self.device)
        self.optimizer = torch.optim.AdamW(
            model.parameters(), 
            lr=4e-2)
        self.scheduler = torch.optim.lr_scheduler.StepLR(
            self.optimizer, 
            step_size=15, 
            gamma=0.5
        )

    def train(self, training_data, batch_size, collate_fn, epochs):
        trainDataLoader = DataLoader(
            training_data,
            batch_size=batch_size,
            collate_fn=collate_fn
        )

        train_loss = 0
        for epoch in trange(epochs, desc="EPOCHS"):
            epoch_iterator = tqdm(trainDataLoader, desc="Iteration")
            train_acc = 0
            
            pred_array = None
            labels_array = None
            
            for step, batch in enumerate(epoch_iterator):
                self.model.train()
                self.model.zero_grad()

#                 print(f"batch[0][0]: {batch[0][0]}")
#                 print(f"batch[0][1]: {batch[0][1]}")
#                 print(f"batch[1]: {batch[1]}")
                training_vec_batch = batch[0][0].to(self.device)
                label_vec_batch = batch[0][1].to(self.device)
                doc_ids_batch = batch[1]
                
                output = self.model(training_vec_batch)
#                 print(f"output: {output}")
#                 print(f"label_vec_batch: {label_vec_batch}")
#                 print(f"output.shape: {output.shape}")
#                 print(f"label_vec_batch.shape: {label_vec_batch.shape}")
                loss = self.criterion(output, label_vec_batch)
                loss.backward()
                train_loss += loss.item()

                if pred_array is None:
                    pred_array = output.detach().cpu().numpy()
                else:
                    pred_array = np.append(
                    pred_array,
                    output.detach().cpu().numpy(), 
                    axis=0
                )
                if labels_array is None:
                    labels_array = label_vec_batch.detach().cpu().numpy()
                else:
                    labels_array = np.append(
                    labels_array,
                    label_vec_batch.detach().cpu().numpy(), 
                    axis=0
                )
                self.optimizer.step()
            self.scheduler.step()
            
            print(f"pred_array: {pred_array}")
            print(f"labels_array: {labels_array}")
            train_acc = num_correctly_classified(pred_array, labels_array)
            
            print(f'Loss: {train_loss:.4f}(train)\t|\tAcc: {train_acc/len(training_data) * 100:.1f}%(train)')
        return train_loss, train_acc

    def evaluate(self, testing_data, batch_size, collate_fn):
        eval_loss = 0
        eval_acc = 0
        
        pred_array = None
        labels_array = None

        testDataLoader = DataLoader(
            testing_data,
            batch_size=batch_size,
            collate_fn=collate_fn
        )

        self.model.eval()
        self.model.to(self.device)
        for batch in tqdm(testDataLoader, desc="EVALUATING"):
            with torch.no_grad():
                testing_vec_batch = batch[0][0].to(self.device)
                label_vec_batch = batch[0][1].to(self.device)
                doc_ids_batch = batch[1]

                output = self.model(testing_vec_batch)
                loss = self.criterion(output, label_vec_batch)
                eval_loss += loss.item()
                
                if pred_array is None:
                    pred_array = output.detach().cpu().numpy()
                else:
                    pred_array = np.append(
                    pred_array,
                    output.detach().cpu().numpy(), 
                    axis=0
                )
                if labels_array is None:
                    labels_array = label_vec_batch.detach().cpu().numpy()
                else:
                    labels_array = np.append(
                    labels_array,
                    label_vec_batch.detach().cpu().numpy(), 
                    axis=0
                )
        
        print(f"pred_array: {pred_array}")
        print(f"labels_array: {labels_array}")
        eval_acc = num_correctly_classified(pred_array, labels_array)
        print(f'Loss: {eval_loss:.4f}(test)\t|\tAcc: {eval_acc/len(testing_data) * 100:.1f}%(test)')
        return eval_loss, eval_acc

In [211]:
# count = 0
# doc_wv_arr = []
# for k, v in doc_wv.items():
#     doc_wv_arr.append(v)
#     count += 1
#     if count > 120:
#         break
# doc_ids = list(ct_scores["uniq_id"][:50])
# X_gene = list(ct_scores["gene_wv"][:50])
# X_disease = list(ct_scores["disease_wv"][:50])
# X_doc_wv = doc_wv_arr[:50]
# y = list(df_qrels_subset["label"][:50])

loss_func = "entropy"
use_fields = True

if use_fields:
    training_data = DatasetWithFields(
        X_doc_wv, y, doc_ids,
        loss_func,
        X_gene, X_disease
    )
else:
    training_data = DatasetWithStringID(
        X_doc_wv, y, doc_ids,
        loss_func,
    )

training_data

<__main__.DatasetWithFields at 0x131e74e10>

In [212]:
len_train = len(training_data)
len_test = int(len(training_data)*0.2)
print(len_train)
print(len_test)
train_, test_ = random_split(training_data, [len_train - len_test, len_test])

5686
1137


In [213]:
[test_.dataset[i][1].item() for i in range(len(test_))]

[0,
 0,
 0,
 0,
 1,
 0,
 0,
 1,
 0,
 1,
 1,
 1,
 1,
 0,
 1,
 0,
 1,
 1,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 1,
 1,
 0,
 0,
 1,
 1,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 1,
 1,
 1,
 1,
 0,
 1,
 0,
 1,
 0,
 0,
 1,
 0,
 1,
 0,
 0,
 1,
 0,
 1,
 1,
 0,
 1,
 1,
 0,
 0,
 0,
 1,
 1,
 0,
 0,
 1,
 0,
 0,
 0,
 1,
 0,
 1,
 0,
 1,
 0,
 0,
 1,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 1,
 0,
 1,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 1,
 0,
 0,
 0,
 1,
 0,
 0,
 1,
 1,
 0,
 1,
 0,
 0,
 0,
 0,
 1,
 1,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 1,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 1,
 1,
 1,
 0,
 0,
 0,
 1,
 1,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 1,
 0,
 1,
 0,
 0,
 0,
 1,
 1,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 1,
 0,
 0,
 1,
 1,
 1,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 1,
 0,
 1,
 1,
 1,
 0,
 1,
 1,
 0,
 1,
 0,
 1,
 0,
 0,
 0,
 1,
 1,
 0,
 0,
 1,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 1,
 0,
 1,
 0,
 0,
 0,
 1,
 0,
 1,
 1,
 1,
 0,
 0,


In [223]:
count_test = DatasetWithStringID.values_count([test_.dataset[i][1].item() for i in range(len(test_))])
count_test

Counter({0: 702, 1: 435})

In [224]:
count_test[0]/(count_test[0] + count_test[1])

0.6174142480211082

In [225]:
count_train = DatasetWithStringID.values_count([train_.dataset[i][1].item() for i in range(len(train_))])
count_train


Counter({0: 2808, 1: 1741})

In [226]:
count_train[0]/(count_train[0] + count_train[1])

0.6172785227522533

In [216]:
trainDataLoader = DataLoader(
    train_,
    batch_size=8,
    shuffle=True,
    collate_fn=id_collate
)
c = 0
for batch in trainDataLoader:
    print(batch)
    training_vec_batch = batch[0][0]
    label_vec_batch = batch[0][1]
    doc_ids_batch = batch[1]
    print("-----")
    print("-----")
    print(training_vec_batch)
    print("-----")
    print("-----")
    print(label_vec_batch)
    print("-----")
    print("-----")
    print(doc_ids_batch)
    c += 1
    if c > 1:
        break

([tensor([[ 0.0000,  0.0000,  0.0000,  ...,  0.0551,  0.0975, -0.6466],
        [ 0.0000,  0.0000,  0.0000,  ..., -0.2515, -0.0928, -1.0609],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.1147, -0.0653, -0.8359],
        ...,
        [ 0.0000,  0.0000,  0.0000,  ..., -0.1615, -0.4006, -0.5722],
        [ 0.0000,  0.0000,  0.0000,  ..., -0.1154, -0.0185, -0.8133],
        [ 0.0000,  0.0000,  0.0000,  ..., -0.1704,  0.0801, -1.0206]]), tensor([0, 1, 0, 0, 0, 0, 1, 0])], ['2019_13_NCT02540824', '2019_1_NCT02583516', '2019_19_NCT00536809', '2019_10_NCT02038010', '2019_10_NCT03618667', '2019_39_NCT03676504', '2019_6_NCT02364609', '2019_13_NCT03215511'])
-----
-----
tensor([[ 0.0000,  0.0000,  0.0000,  ...,  0.0551,  0.0975, -0.6466],
        [ 0.0000,  0.0000,  0.0000,  ..., -0.2515, -0.0928, -1.0609],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.1147, -0.0653, -0.8359],
        ...,
        [ 0.0000,  0.0000,  0.0000,  ..., -0.1615, -0.4006, -0.5722],
        [ 0.0000,  0.0000,  0.0000,  ..

In [217]:
if use_fields:
    vec_dim = len(gene_wv["2019_26_NCT02195336"]) + \
        len(disease_wv["2019_26_NCT02195336"]) + \
        len(doc_wv["NCT02286219"])
else:
    vec_dim = len(doc_wv["NCT02286219"])
print(vec_dim)

1485


In [218]:
l_one = 512
l_two = 256
l_three = 128
l_four = 128
l_five = 64

dNet = denseNet(vec_dim, l_one, l_two, l_three, l_four, l_five, num_labels=2)
dNet

denseNet(
  (fc1): Linear(in_features=1485, out_features=512, bias=True)
  (fc2): Linear(in_features=512, out_features=256, bias=True)
  (fc3): Linear(in_features=256, out_features=128, bias=True)
  (fc4): Linear(in_features=128, out_features=128, bias=True)
  (fc5): Linear(in_features=128, out_features=64, bias=True)
  (fc6): Linear(in_features=64, out_features=2, bias=True)
  (drop1): Dropout(p=0.5, inplace=False)
  (drop2): Dropout(p=0.5, inplace=False)
  (drop3): Dropout(p=0.5, inplace=False)
)

In [219]:
batch_size = 16
epochs = 5

if loss_func == "bce":
    model = nnModelBCELoss(dNet)
    model.train(
        train_, batch_size=batch_size, 
        collate_fn=id_collate, epochs=epochs
    )
elif loss_func == "entropy":
    model = nnModelCrossEntropyLoss(dNet)
    model.train(
        train_, batch_size=batch_size, 
        collate_fn=id_collate, epochs=epochs
    )    




EPOCHS:   0%|          | 0/5 [00:00<?, ?it/s][A[A[A



Iteration:   0%|          | 0/285 [00:00<?, ?it/s][A[A[A[A



Iteration:   2%|▏         | 7/285 [00:00<00:04, 61.11it/s][A[A[A[A



Iteration:   5%|▍         | 13/285 [00:00<00:04, 60.02it/s][A[A[A[A



Iteration:   7%|▋         | 19/285 [00:00<00:04, 59.66it/s][A[A[A[A



Iteration:   9%|▉         | 25/285 [00:00<00:04, 59.19it/s][A[A[A[A



Iteration:  11%|█         | 32/285 [00:00<00:04, 60.27it/s][A[A[A[A



Iteration:  14%|█▎        | 39/285 [00:00<00:04, 60.72it/s][A[A[A[A



Iteration:  16%|█▌        | 46/285 [00:00<00:03, 61.37it/s][A[A[A[A



Iteration:  19%|█▊        | 53/285 [00:00<00:03, 61.10it/s][A[A[A[A



Iteration:  21%|██        | 59/285 [00:00<00:03, 59.39it/s][A[A[A[A



Iteration:  23%|██▎       | 65/285 [00:01<00:03, 58.63it/s][A[A[A[A



Iteration:  25%|██▍       | 71/285 [00:01<00:03, 57.64it/s][A[A[A[A



Iteration:  27%|██▋       | 77/285 [00:01<00:03, 

pred_array: [[ 1.8755137  19.139454  ]
 [ 3.4382713   9.924626  ]
 [ 4.119235    2.4981227 ]
 ...
 [ 0.20663142 -0.31521603]
 [ 0.21571717 -0.26732635]
 [ 0.2440415  -0.2748758 ]]
labels_array: [1 1 0 ... 0 0 0]
Loss: 225.0679(train)	|	Acc: 60.9%(train)






Iteration:   5%|▍         | 14/285 [00:00<00:04, 65.25it/s][A[A[A[A



Iteration:   7%|▋         | 21/285 [00:00<00:04, 63.95it/s][A[A[A[A



Iteration:  10%|▉         | 28/285 [00:00<00:04, 64.04it/s][A[A[A[A



Iteration:  12%|█▏        | 35/285 [00:00<00:03, 63.92it/s][A[A[A[A



Iteration:  15%|█▍        | 42/285 [00:00<00:03, 63.06it/s][A[A[A[A



Iteration:  17%|█▋        | 48/285 [00:00<00:03, 62.03it/s][A[A[A[A



Iteration:  19%|█▉        | 55/285 [00:00<00:03, 62.54it/s][A[A[A[A



Iteration:  22%|██▏       | 62/285 [00:00<00:03, 62.89it/s][A[A[A[A



Iteration:  24%|██▍       | 69/285 [00:01<00:03, 61.64it/s][A[A[A[A



Iteration:  26%|██▋       | 75/285 [00:01<00:03, 61.11it/s][A[A[A[A



Iteration:  29%|██▉       | 82/285 [00:01<00:03, 61.80it/s][A[A[A[A



Iteration:  31%|███       | 89/285 [00:01<00:03, 62.11it/s][A[A[A[A



Iteration:  34%|███▎      | 96/285 [00:01<00:03, 62.30it/s][A[A[A[A



Iteration:  36%|███▌ 

pred_array: [[ 0.25809532 -0.28377545]
 [ 0.25439015 -0.28374594]
 [ 0.22496578 -0.31513527]
 ...
 [ 0.25665906 -0.27958873]
 [ 0.24504764 -0.2885799 ]
 [ 0.25652978 -0.27959874]]
labels_array: [1 1 0 ... 0 0 0]
Loss: 415.0455(train)	|	Acc: 62.1%(train)






Iteration:   5%|▍         | 14/285 [00:00<00:04, 63.64it/s][A[A[A[A



Iteration:   7%|▋         | 21/285 [00:00<00:04, 63.10it/s][A[A[A[A



Iteration:   9%|▉         | 27/285 [00:00<00:04, 61.69it/s][A[A[A[A



Iteration:  12%|█▏        | 33/285 [00:00<00:04, 60.00it/s][A[A[A[A



Iteration:  14%|█▍        | 40/285 [00:00<00:04, 60.17it/s][A[A[A[A



Iteration:  16%|█▌        | 46/285 [00:00<00:03, 60.02it/s][A[A[A[A



Iteration:  18%|█▊        | 52/285 [00:00<00:03, 59.07it/s][A[A[A[A



Iteration:  20%|██        | 58/285 [00:00<00:03, 59.07it/s][A[A[A[A



Iteration:  22%|██▏       | 64/285 [00:01<00:03, 58.41it/s][A[A[A[A



Iteration:  25%|██▍       | 70/285 [00:01<00:03, 58.25it/s][A[A[A[A



Iteration:  27%|██▋       | 76/285 [00:01<00:03, 58.05it/s][A[A[A[A



Iteration:  29%|██▉       | 82/285 [00:01<00:03, 57.21it/s][A[A[A[A



Iteration:  31%|███       | 88/285 [00:01<00:03, 57.14it/s][A[A[A[A



Iteration:  33%|███▎ 

pred_array: [[ 0.26575163 -0.28927565]
 [ 0.2760313  -0.27226132]
 [ 0.26020324 -0.2899117 ]
 ...
 [ 0.236883   -0.25729537]
 [ 0.21084486 -0.2627654 ]
 [ 0.1655479  -0.28141847]]
labels_array: [1 1 0 ... 0 0 0]
Loss: 605.3024(train)	|	Acc: 62.0%(train)






Iteration:   4%|▎         | 10/285 [00:00<00:05, 46.78it/s][A[A[A[A



Iteration:   5%|▌         | 15/285 [00:00<00:05, 46.06it/s][A[A[A[A



Iteration:   7%|▋         | 20/285 [00:00<00:05, 46.36it/s][A[A[A[A



Iteration:   9%|▉         | 25/285 [00:00<00:05, 46.39it/s][A[A[A[A



Iteration:  11%|█         | 30/285 [00:00<00:05, 45.89it/s][A[A[A[A



Iteration:  12%|█▏        | 35/285 [00:00<00:05, 45.94it/s][A[A[A[A



Iteration:  14%|█▍        | 40/285 [00:00<00:05, 46.07it/s][A[A[A[A



Iteration:  16%|█▌        | 45/285 [00:00<00:05, 46.27it/s][A[A[A[A



Iteration:  18%|█▊        | 50/285 [00:01<00:05, 45.71it/s][A[A[A[A



Iteration:  19%|█▉        | 55/285 [00:01<00:04, 46.14it/s][A[A[A[A



Iteration:  21%|██        | 60/285 [00:01<00:04, 46.54it/s][A[A[A[A



Iteration:  23%|██▎       | 65/285 [00:01<00:04, 46.80it/s][A[A[A[A



Iteration:  25%|██▍       | 70/285 [00:01<00:04, 45.72it/s][A[A[A[A



Iteration:  26%|██▋  

pred_array: [[ 0.24592765 -0.26689982]
 [ 0.24658765 -0.26704773]
 [ 0.21580091 -0.27800608]
 ...
 [ 0.241982   -0.26028034]
 [ 0.22890987 -0.28132024]
 [ 0.24200407 -0.26025808]]
labels_array: [1 1 0 ... 0 0 0]
Loss: 795.0110(train)	|	Acc: 62.1%(train)






Iteration:   4%|▎         | 10/285 [00:00<00:05, 49.61it/s][A[A[A[A



Iteration:   5%|▌         | 15/285 [00:00<00:05, 49.13it/s][A[A[A[A



Iteration:   7%|▋         | 20/285 [00:00<00:05, 48.89it/s][A[A[A[A



Iteration:   9%|▉         | 25/285 [00:00<00:05, 48.91it/s][A[A[A[A



Iteration:  11%|█         | 30/285 [00:00<00:05, 48.25it/s][A[A[A[A



Iteration:  12%|█▏        | 35/285 [00:00<00:05, 48.49it/s][A[A[A[A



Iteration:  14%|█▍        | 40/285 [00:00<00:05, 47.94it/s][A[A[A[A



Iteration:  16%|█▌        | 45/285 [00:00<00:05, 47.00it/s][A[A[A[A



Iteration:  18%|█▊        | 50/285 [00:01<00:05, 46.90it/s][A[A[A[A



Iteration:  19%|█▉        | 55/285 [00:01<00:04, 47.19it/s][A[A[A[A



Iteration:  21%|██        | 60/285 [00:01<00:04, 47.03it/s][A[A[A[A



Iteration:  23%|██▎       | 65/285 [00:01<00:04, 46.62it/s][A[A[A[A



Iteration:  25%|██▍       | 70/285 [00:01<00:04, 47.01it/s][A[A[A[A



Iteration:  26%|██▋  

pred_array: [[ 0.2521434  -0.27039933]
 [ 0.2351875  -0.29202878]
 [ 0.25216806 -0.270431  ]
 ...
 [ 0.22601123 -0.24229547]
 [ 0.22705714 -0.23977377]
 [ 0.22607967 -0.24228774]]
labels_array: [1 1 0 ... 0 0 0]
Loss: 985.2901(train)	|	Acc: 62.0%(train)





In [220]:
model.evaluate(
    test_, batch_size=8, 
    collate_fn=id_collate,
)




EVALUATING:   0%|          | 0/143 [00:00<?, ?it/s][A[A[A


EVALUATING:  33%|███▎      | 47/143 [00:00<00:00, 468.99it/s][A[A[A


EVALUATING: 100%|██████████| 143/143 [00:00<00:00, 481.06it/s][A[A[A

pred_array: [[ 0.23735522 -0.25363603]
 [ 0.23735522 -0.25363603]
 [ 0.23735522 -0.25363603]
 ...
 [ 0.23735522 -0.25363603]
 [ 0.23735522 -0.25363603]
 [ 0.23735522 -0.25363603]]
labels_array: [0 0 0 ... 1 1 1]
Loss: 97.0651(test)	|	Acc: 59.4%(test)





(97.06509572267532, 675)

In [50]:
break

SyntaxError: 'break' outside loop (<ipython-input-50-6aaf1f276005>, line 4)

## Sandbox

In [None]:
torch.tensor([
    [1.45, 6.55, 5.64],
    [111.7, 1100.50, 10000.13]
]).argmax(dim=1)

In [None]:
import torch
from torch import nn

m = nn.Sigmoid()
loss = nn.BCELoss()
input = torch.randn(3, requires_grad=True)
target = torch.empty(3).random_(2)
output = loss(m(input), target)
output.backward()

In [None]:
print(input)
print(target)
print(output)