In [30]:
import pandas as pd
import numpy as np
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader

from gensim.models import KeyedVectors
import nltk
from nltk.corpus import stopwords
from string import punctuation

In [3]:
nltk.download("punkt_tab")
nltk.download('stopwords')

[nltk_data] Downloading package punkt_tab to
[nltk_data]     C:\Users\Steven\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!
[nltk_data] Downloading package stopwords to
[nltk_data]     C:\Users\Steven\AppData\Roaming\nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


True

In [4]:
data_path = './data/'
df_data = pd.read_csv(data_path + 'data_processed.csv')

In [11]:
pos_genRe_idx = np.where((df_data['GEN_RE'] == 1))[0]
neg_genRe_idx = np.where((df_data['GEN_RE'] == 0))[0]
pos_30Re_idx = np.where((df_data['30_RE'] == 1))[0]
neg_30Re_idx = np.where((df_data['30_RE'] == 0))[0]

In [16]:
df_data.columns

Index(['SUBJECT_ID', 'HADM_ID', 'TEXT', 'ADMITTIME', 'DISCHTIME', 'GEN_RE',
       '30_RE'],
      dtype='object')

Download word2vec file from Kaggle: https://www.kaggle.com/datasets/alexiscorona/pubmed-and-pmc-w2v/

In [None]:
w2v = KeyedVectors.load_word2vec_format("PubMed-and-PMC-w2v.bin", binary=True)

In [25]:
def tokenize(text):
    text = nltk.word_tokenize(text)
    stop_words = set(stopwords.words('english'))
    text = [token.lower() for token in text if token.lower() not in stop_words and token not in punctuation and not token.isnumeric()]
    return text

def vectorize(text):
    vectors = [w2v[token] if token in w2v else np.random.uniform(-1, 1, (200,)).astype(np.float32) for token in text]
    return np.array(vectors)

def generate_dataset(df_data, pos_idx, neg_idx):
    num_pos = len(pos_idx)
    labels = [1] * num_pos + [0] * num_pos
    labels = torch.tensor(labels, dtype=torch.long)
    neg_idx_sample = np.random.choice(neg_idx, size=num_pos, replace=False)
    all_idx = pos_idx.tolist() + neg_idx_sample.tolist()
    datapoints = []
    max_seq_len = 0
    for idx in all_idx:
        vectors = vectorize(tokenize(df_data.loc[idx, 'TEXT']))
        if vectors.shape[0] > max_seq_len:
            max_seq_len = vectors.shape[0]
        datapoints.append(vectors)   
    return datapoints, labels, max_seq_len

In [39]:
class HFDataset(Dataset):
    def __init__(self, datapoints, labels, max_seq_len):
        self.datapoints = datapoints
        self.labels = labels
        self.max_seq_len = max_seq_len
    
    def __len__(self):
        return len(self.labels)
    
    def __getitem__(self, index):
        vectors = self.datapoints[index]
        if vectors.shape[0] < max_seq_len:
            padded_vectors = np.pad(vectors, ((0, max_seq_len - vectors.shape[0]), (0, 0)), mode='constant', constant_values=0)
        else:
            padded_vectors = vectors
        return torch.tensor(padded_vectors, dtype=torch.float), self.labels[index]

In [None]:
# class HFCNN(nn.Module):
#     def __init__(self):
#         super(HFCNN, self).__init__()
#         

In [40]:
dataset = HFDataset(datapoints, labels, max_seq_len)
train_loader = DataLoader(dataset=dataset, batch_size=32, shuffle=True)
for i, (input, target) in enumerate(train_loader):
    print(input.shape)
    print(target.shape)

torch.Size([32, 4658, 200])
torch.Size([32])
torch.Size([32, 4658, 200])
torch.Size([32])
torch.Size([32, 4658, 200])
torch.Size([32])
torch.Size([32, 4658, 200])
torch.Size([32])
torch.Size([32, 4658, 200])
torch.Size([32])
torch.Size([32, 4658, 200])
torch.Size([32])
torch.Size([32, 4658, 200])
torch.Size([32])
torch.Size([32, 4658, 200])
torch.Size([32])
torch.Size([32, 4658, 200])
torch.Size([32])
torch.Size([32, 4658, 200])
torch.Size([32])
torch.Size([32, 4658, 200])
torch.Size([32])
torch.Size([32, 4658, 200])
torch.Size([32])
torch.Size([32, 4658, 200])
torch.Size([32])
torch.Size([32, 4658, 200])
torch.Size([32])
torch.Size([32, 4658, 200])
torch.Size([32])
torch.Size([32, 4658, 200])
torch.Size([32])
torch.Size([32, 4658, 200])
torch.Size([32])
torch.Size([32, 4658, 200])
torch.Size([32])
torch.Size([32, 4658, 200])
torch.Size([32])
torch.Size([32, 4658, 200])
torch.Size([32])
torch.Size([32, 4658, 200])
torch.Size([32])
torch.Size([32, 4658, 200])
torch.Size([32])
torch.Size