In [1]:
!ls ../input/

openigraphdata


In [2]:
import numpy as np
import pandas as pd
import re
import csv
import umap
import seaborn as sns
import os
from PIL import Image
import nltk
from nltk import bigrams
import itertools
import cv2
import collections
import spacy

import torch
import torchvision
import pickle

import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision.datasets import ImageFolder
from collections import Counter
from itertools import combinations
from nltk import everygrams

from transformers import BertModel
from transformers import BertTokenizer
from torchvision import transforms, models
from skimage import data, io, filters
from torch.utils.data import DataLoader, Dataset
import torch.optim.lr_scheduler as lr_scheduler
from nltk.corpus import stopwords


import matplotlib.pyplot as plt
%matplotlib inline

In [3]:
nltk.download('stopwords')

[nltk_data] Downloading package stopwords to /usr/share/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


True

In [4]:
stop_words = list(stopwords.words('english'))
modals = ['can', 'could', 'may', 'might', 'must', 'will', 'would', 'should']
print(len(stop_words))
stop_words.remove('no')
stop_words.remove('not')

stop_words = stop_words + ['cm','mm', 'thank'] + modals
print(len(stop_words))

179
188


In [5]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [6]:
def get_word_vocab(word_list,top_k=500):
    
    if top_k is None:
        word_vocab = list(set(word_list))
    else:
        counts = Counter(word_list)
        most_common = counts.most_common(top_k)
        word_vocab = [t[0] for t in most_common]
    word2id = {}
    id2word = {}
    for i,w in enumerate(word_vocab):
        word2id[w] = i
        id2word[i] = w
    return word_vocab,word2id,id2word

def preprocess(df_col):
    
    df_col = df_col.apply(lambda x : x.lower())
    df_col = df_col.apply(lambda x : re.sub(r"x-xxxx","",x))
    df_col = df_col.apply(lambda x : re.sub(r"xxxx","",x))
    df_col = df_col.apply(lambda x : re.sub(r"year-old", "", x))
    df_col = df_col.apply(lambda x : re.sub(r"2 images","",x))
    df_col = df_col.apply(lambda x : re.sub('[^A-Za-z]+', ' ', x))
    df_col = df_col.apply(lambda x : re.sub(' +', ' ', x))
    df_col = df_col.apply(lambda x : x.strip())
    return df_col

def get_word_edge_index(cap_list,n=3):
    ngram_list = []
    for cap in cap_list:
        if(len(cap) >= n):
            ngram_list.append(list(everygrams(cap,n,n)))
    ngram_list = [ngram for l in ngram_list for ngram in l]
    edges = set()
    
    for ngram in ngram_list:
        comb = list(combinations(ngram,2))
        for c in comb:
            if(c[0] != c[1]):
                edges.add(tuple(sorted(c)))
    return np.array(list(edges))

def get_openi_words(openi_df,stop_words=stop_words):
#     openi_df['Target'] = openi_df['Target'].apply(lambda x : x.lower())

    openi_df['Target'] = preprocess(openi_df['Target'])
    nlp = spacy.load('en_core_web_sm')
    openi_targets = openi_df['Target'].values
    preproc_targets = []
    for caption in openi_targets:
        doc = nlp(caption)
        w = []
        for token in doc:
            w.append(token.lemma_.lower())   
        w =  [item for item in w if item.isalpha()]
        w = [item for item in w if item not in stop_words]
        preproc_targets.append(w)
        
    # remove dulplicate keywords in the report/caption
    keywords = [set(tar) for tar in preproc_targets]
    keywords_string = [','.join(k) for k in keywords]
    preproc_targets_string = [','.join(tar) for tar in preproc_targets]
    openi_df['Prepoc_Target'] = preproc_targets_string
    openi_df['Keywords'] = keywords_string
    openi_words = [w for tar in preproc_targets for w in tar]
    len(openi_words)
    return openi_df,openi_words,preproc_targets


def get_openi_keywords_val_test(openi_df,openi_words):
#     openi_df['Target'] = openi_df['Target'].apply(lambda x : x.lower())
    openi_df['Target'] = preprocess(openi_df['Target'])
    nlp = spacy.load('en_core_web_sm')
    openi_targets = openi_df['Target'].values
    preproc_targets = []
    for caption in openi_targets:
        doc = nlp(caption)
        w = []
        for token in doc:
            w.append(token.lemma_.lower())   
        w =  [item for item in w if item.isalpha()]
        w = [item for item in w if item not in stop_words]
        w = [item for item in w if item in openi_words]
        preproc_targets.append(w)
        
    # remove dulplicate keywords in the report/caption
    keywords = [set(tar) for tar in preproc_targets]
    keywords_string = [','.join(k) for k in keywords]
    preproc_targets_string = [','.join(tar) for tar in preproc_targets]
    openi_df['Prepoc_Target'] = preproc_targets_string
    openi_df['Keywords'] = keywords_string
    return openi_df,preproc_targets


In [7]:
word_dir = '../input/openigraphdata/Mesh_Keywords/openi_word_embed/'

## Load Data

In [8]:
train_df = pd.read_csv(word_dir + 'train_reports_with_keywords.csv')
val_df = pd.read_csv(word_dir + 'val_reports_with_keywords.csv')
test_df = pd.read_csv(word_dir + 'test_reports_with_keywords.csv')

In [9]:
print(train_df.shape)
print(val_df.shape)
print(test_df.shape)

(2566, 12)
(366, 12)
(733, 12)


In [10]:
_,train_words,_ = get_openi_words(train_df)
openi_word_vocab,openi_word2id,openi_id2word = get_word_vocab(train_words,top_k=None)
print(len(openi_word_vocab))


1512


In [11]:
openi_word_vocab[:10]

['suggestion',
 'erosion',
 'single',
 'well',
 'kub',
 'recently',
 'basal',
 'entirely',
 'dystrophy',
 'view']

In [12]:
train_df,train_targets = get_openi_keywords_val_test(train_df,openi_word_vocab)
train_df.head()

Unnamed: 0,Image,Frontal,Lateral,Comparison,Indication,Findings,Impression,MESH,Problems,Target,Problem_List,Prepoc_Problems,Prepoc_Target,Keywords
0,CXR3403_IM-1647-1001;CXR3403_IM-1647-2001,CXR3403_IM-1647-1001,CXR3403_IM-1647-2001,None.,XXXX-year-old female. Left chest pain. Right r...,the cardiomediastinal silhouette is normal in ...,negative for acute abnormality,normal,normal,the cardiomediastinal silhouette is normal in ...,['normal'],normal,"cardiomediastinal,silhouette,normal,size,conto...","size,effusion,contour,pleural,abnormality,pneu..."
1,CXR810_IM-2343-1001;CXR810_IM-2343-2001,CXR810_IM-2343-1001,CXR810_IM-2343-2001,None.,chest pain,normal heart size mediastinal and aortic conto...,no evidence of active cardiopulmonary disease,Calcified Granuloma/scattered;Spine/degenerative,Calcified Granuloma;Spine,normal heart size mediastinal and aortic conto...,"['calcified granuloma', 'spine']",calcified granuloma;spine,"normal,heart,size,mediastinal,aortic,contour,n...","pulmonary,visible,cardiopulmonary,contour,aort..."
2,CXR2474_IM-1003-1001,CXR2474_IM-1003-1001,,,Shortness of breath.,patchy interstitial infiltrates have developed...,bibasilar patchy airspace disease with bilater...,Infiltrate/lung/lower lobe/bilateral/interstit...,Infiltrate;Costophrenic Angle;Aorta;Airspace D...,patchy interstitial infiltrates have developed...,"['infiltrate', 'costophrenic angle', 'aorta', ...",infiltrate;costophrenic angle;aorta;airspace d...,"patchy,interstitial,infiltrate,develop,low,lob...","bilateral,pulmonary,infiltrate,pleural,normal,..."
3,CXR632_IM-2213-1001;CXR632_IM-2213-1002,CXR632_IM-2213-1001,CXR632_IM-2213-1002,Chest x-XXXX and XXXX.,"XXXX-year-old male with shortness of breath, X...",cardiac and mediastinal silhouette are unremar...,no acute cardiopulmonary abnormality,normal,normal,cardiac and mediastinal silhouette are unremar...,['normal'],normal,"cardiac,mediastinal,silhouette,unremarkable,lu...","lung,cardiopulmonary,cardiac,effusion,abnormal..."
4,CXR3101_IM-1453-1001;CXR3101_IM-1453-3001,CXR3101_IM-1453-1001,CXR3101_IM-1453-3001,None.,Multiple myeloma. Bone marrow transplant.,the heart size and pulmonary vascularity appea...,no evidence of active disease,"Spine/degenerative;Stents/abdomen;Aorta, Thora...","Spine;Stents;Aorta, Thoracic;Calcified Granuloma",the heart size and pulmonary vascularity appea...,"['spine', 'stents', 'aorta thoracic', 'calcifi...",spine;stents;aorta thoracic;calcified granuloma,"heart,size,pulmonary,vascularity,appear,within...","limit,pulmonary,pleural,see,normal,change,gran..."


In [13]:
val_df,_ = get_openi_keywords_val_test(val_df,openi_word_vocab)
val_df.head()

Unnamed: 0,Image,Frontal,Lateral,Comparison,Indication,Findings,Impression,MESH,Problems,Target,Problem_List,Prepoc_Problems,Prepoc_Target,Keywords
0,CXR3518_IM-1717-1001;CXR3518_IM-1717-2001,CXR3518_IM-1717-1001,CXR3518_IM-1717-2001,None available.,"XXXX year old with prostate cancer, no chest c...",the heart and mediastinal silhouettes are with...,no acute visualized cardiopulmonary abnormality,Osteophyte/thoracic vertebrae/degenerative/mul...,Osteophyte,the heart and mediastinal silhouettes are with...,['osteophyte'],osteophyte,"heart,mediastinal,silhouette,within,normal,lim...","limit,visualize,cardiopulmonary,mediastinal,no..."
1,CXR2553_IM-1059-1001;CXR2553_IM-1059-2001,CXR2553_IM-1059-1001,CXR2553_IM-1059-2001,,"XXXX-year-old male with XXXX, 786.2",there are several small calcified granulomas t...,no evidence of active disease,Calcified Granuloma/lung/multiple/small;Spine/...,Calcified Granuloma;Spine,there are several small calcified granulomas t...,"['calcified granuloma', 'spine']",calcified granuloma;spine,"several,small,calcify,granuloma,lung,otherwise...","limit,pulmonary,contour,pleural,mediastinal,no..."
2,CXR3156_IM-1486-1001;CXR3156_IM-1486-2001,CXR3156_IM-1486-1001,CXR3156_IM-1486-2001,None.,XXXX-year-old woman with positive PPD..,the lungs are clear bilaterally specifically n...,no acute cardiopulmonary abnormality specifica...,normal,normal,the lungs are clear bilaterally specifically n...,['normal'],normal,"lung,clear,bilaterally,specifically,no,evidenc...","visualize,cardiopulmonary,tuberculous,pleural,..."
3,CXR3883_IM-1971-1001;CXR3883_IM-1971-12012,CXR3883_IM-1971-1001,CXR3883_IM-1971-12012,None.,The patient is a XXXX-year-old female with lef...,no pneumothorax pleural effusion or airspace c...,no acute cardiopulmonary abnormality,normal,normal,no pneumothorax pleural effusion or airspace c...,['normal'],normal,"no,pneumothorax,pleural,effusion,airspace,cons...","limit,pulmonary,cardiopulmonary,effusion,size,..."
4,CXR330_IM-1577-0001-0001;CXR330_IM-1577-0001-0002,CXR330_IM-1577-0001-0002,CXR330_IM-1577-0001-0001,None available.,XXXX-year-old woman with hepatic encephalopath...,heart size mediastinal contour and pulmonary v...,no acute cardiopulmonary abnormality,normal,normal,heart size mediastinal contour and pulmonary v...,['normal'],normal,"heart,size,mediastinal,contour,pulmonary,vascu...","limit,pulmonary,visualize,cardiopulmonary,cont..."


In [14]:
test_df,_ = get_openi_keywords_val_test(test_df,openi_word_vocab)
test_df.head()

Unnamed: 0,Image,Frontal,Lateral,Comparison,Indication,Findings,Impression,MESH,Problems,Target,Problem_List,Prepoc_Problems,Prepoc_Target,Keywords
0,CXR1852_IM-0554-1001;CXR1852_IM-0554-2001,CXR1852_IM-0554-1001,CXR1852_IM-0554-2001,,Chest pain,lungs are clear no focal consolidation effusio...,no acute cardiopulmonary disease,normal,normal,lungs are clear no focal consolidation effusio...,['normal'],normal,"lung,clear,no,focal,consolidation,effusion,pne...","lung,cardiopulmonary,effusion,contour,pneumoth..."
1,CXR2277_IM-0864-1001;CXR2277_IM-0864-2001,CXR2277_IM-0864-1001,CXR2277_IM-0864-2001,None Available.,XXXX-year-old female with chest pain.,the lungs are clear without evidence of focal ...,no radiographic evidence of acute cardiopulmon...,normal,normal,the lungs are clear without evidence of focal ...,['normal'],normal,"lung,clear,without,evidence,focal,airspace,dis...","limit,cardiopulmonary,contour,pleural,mediasti..."
2,CXR1015_IM-0001-1001;CXR1015_IM-0001-2001;CXR1...,CXR1015_IM-0001-1001;CXR1015_IM-0013-1001,CXR1015_IM-0001-2001;CXR1015_IM-0013-2001,XXXX,"XXXX-year-old female, COPD exacerbation, short...",streaky and patchy bibasilar opacities triangu...,bibasilar opacities right greater than left fe...,Opacity/lung/base/bilateral/patchy/streaky;Tec...,Opacity;Technical Quality of Image Unsatisfact...,streaky and patchy bibasilar opacities triangu...,"['opacity', 'technical quality of image unsati...",opacity;technical quality of image unsatisfact...,"streaky,patchy,bibasilar,opacity,density,proje...","project,pulmonary,pleural,see,streaky,definite..."
3,CXR1992_IM-0649-1001;CXR1992_IM-0649-4004,CXR1992_IM-0649-4004,CXR1992_IM-0649-1001,"XXXX, XXXX",Seizure,borderline heart size the lungs are hyperexpan...,findings of chronic obstructive pulmonary disease,Cardiomegaly/borderline;Lung/hyperdistention;L...,"Cardiomegaly;Lung;Lung, Hyperlucent;Pulmonary ...",borderline heart size the lungs are hyperexpan...,"['cardiomegaly', 'lung', 'lung hyperlucent', '...",cardiomegaly;lung;lung hyperlucent;pulmonary d...,"borderline,heart,size,lung,hyperexpande,hyperl...","compatible,pulmonary,contour,pleural,aortic,sp..."
4,CXR979_IM-2466-1001;CXR979_IM-2466-2001,CXR979_IM-2466-1001,CXR979_IM-2466-2001,None.,XXXX-year-old female. Chest pain. Prior lumpec...,the cardiomediastinal silhouette is normal in ...,negative for acute abnormality,Lung/hyperdistention;Surgical Instruments/thor...,Lung;Surgical Instruments,the cardiomediastinal silhouette is normal in ...,"['lung', 'surgical instruments']",lung;surgical instruments,"cardiomediastinal,silhouette,normal,size,conto...","compatible,contour,pleural,normal,size,chest,l..."


In [15]:
train_df.to_csv('train_reports_with_keywords.csv',index=False)
val_df.to_csv('val_reports_with_keywords.csv',index=False)
test_df.to_csv('test_reports_with_keywords.csv',index=False)


In [16]:
openi_word_edges_bigram = get_word_edge_index(train_targets,n=2)
openi_word_edges_trigram = get_word_edge_index(train_targets,n=3)
openi_word_edges_fourgram = get_word_edge_index(train_targets,n=4)

print(openi_word_edges_bigram.shape)
print(openi_word_edges_trigram.shape)
print(openi_word_edges_fourgram.shape)


(11980, 2)
(23564, 2)
(33534, 2)


In [17]:
np.save('openi_word_edges_bigram.npy', openi_word_edges_bigram)
np.save('openi_word_edges_trigram.npy', openi_word_edges_trigram)
np.save('openi_word_edges_fourgram.npy', openi_word_edges_fourgram)

In [18]:
with open('openi_word_vocab.pkl', 'wb') as f:
    pickle.dump(openi_word_vocab, f)
    
    
with open('openi_word2id.pkl', 'wb') as f:
    pickle.dump(openi_word2id, f)
    

with open('openi_id2word.pkl', 'wb') as f:
    pickle.dump(openi_id2word, f)

## Generate Initial Word Embeddings for Knowledge graph

In [19]:
words = np.array(openi_word_vocab)

print(words.shape)

(1512,)


In [20]:
class MLDataset(Dataset):
    def __init__(self, x, max_length=20):
        self.x = x
        self.max_length = max_length
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    def __len__(self):
        return self.x.shape[0]

    def __getitem__(self, idx):
        input_text = self.x[idx]

        encoded_dict = self.tokenizer.encode_plus(
          text=input_text,  # the sentence to be encoded
          add_special_tokens=True,  # Add [CLS] and [SEP]
          max_length = self.max_length,  # maximum length of a sentence
          padding='max_length',  # Add [PAD]s
          return_attention_mask = True,  # Generate the attention mask
          return_tensors = 'pt',  # ask the function to return PyTorch tensors
          truncation=True        #if text length > max_length it will truncate
         )

        return encoded_dict['input_ids'], encoded_dict['token_type_ids'], encoded_dict['attention_mask'],input_text
    
    
    
class Embedding(nn.Module):
    def __init__(self):
        super(Embedding,self).__init__()
        base_model = BertModel.from_pretrained('bert-base-uncased')
        bert_model = nn.Sequential(*list(base_model.children())[0:])
        self.bert_embedding = bert_model[0]
    def forward(self, input_ids, token_type_ids, attention_mask):
        embeddings = self.bert_embedding(input_ids=input_ids, token_type_ids=token_type_ids,position_ids=None)
        return embeddings.mean(axis=1)

In [21]:
bert_model = Embedding()
bert_model.to(device)

Downloading:   0%|          | 0.00/570 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/420M [00:00<?, ?B/s]

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Embedding(
  (bert_embedding): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
)

In [22]:
dataset = MLDataset(words)
dataloader = DataLoader(dataset, batch_size = 4, shuffle=False)

Downloading:   0%|          | 0.00/226k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

In [23]:
word_embeddings = []
with torch.no_grad():
    bert_model.eval()
    for (input_tokens,segment_ids,attention_mask,input_word) in dataloader:
        input_tokens,segment_ids,attention_mask = input_tokens.to(device), segment_ids.to(device), attention_mask.to(device)
        input_tokens = input_tokens.squeeze(1)
        segment_ids = segment_ids.squeeze(1)
        attention_mask = attention_mask.squeeze(1)
        embed = bert_model(input_tokens, segment_ids, attention_mask)
        word_embeddings.append(embed)
        
word_embeddings = torch.cat(word_embeddings).cpu().numpy()
print(word_embeddings.shape)

(1512, 768)


In [24]:
np.save('openi_word_embeddings.npy', word_embeddings)
