In [1]:
import pandas as pd

In [2]:
def read_ht_data(file_path):
    """Read the Q&A CSV and return it without the columns that go unused.

    The first CSV column becomes the index of the returned frame; the
    columns ``related_topics``, ``question_url``, ``answers`` and
    ``sub_category`` are removed.

    Parameters
    ----------
    file_path : str or file-like
        Anything ``pandas.read_csv`` accepts.

    Returns
    -------
    pandas.DataFrame
        The loaded table minus the four columns listed above.
    """
    unused_columns = ['related_topics', 'question_url', 'answers', 'sub_category']
    frame = pd.read_csv(file_path, index_col=0)
    return frame.drop(columns=unused_columns)

In [3]:
# Load the dataset CSV; read_ht_data removes the columns that are not used further down.
full_data = read_ht_data('SINGLE_PART/healthtap_medical_qna_dataset_1.6m.csv')

  mask |= (ar1 == a)


In [4]:
from collections import Counter
def drop_and_split(raw):
    """Drop incomplete rows and rare labels, returning ``[question, label]`` pairs.

    Rows containing any missing value are removed first.  The label
    ``'question'`` is always excluded (presumably a stray header row that
    ended up in the data -- TODO confirm), as is every label that occurs
    two times or fewer among the remaining rows.

    Parameters
    ----------
    raw : pandas.DataFrame
        Must provide the columns ``'question'`` and ``'main_category'``.

    Returns
    -------
    list of [question, label]
        Pairs grouped by label: labels appear in order of their first
        occurrence, questions keep their row order within each label.
    """
    print("Length before treatment ", len(raw))
    complete = raw.dropna()
    label_counter = Counter(complete['main_category'])
    # pop() with a default so that data without a 'question' label does not
    # raise KeyError.
    label_counter.pop('question', None)

    # Group once instead of building a full-length boolean mask per label.
    # sort=False is not required for correctness here (order comes from the
    # counter) but avoids needlessly sorting the group keys.
    questions_by_label = complete.groupby('main_category', sort=False)['question']

    pairs = []
    for label, occurrences in label_counter.items():
        if occurrences > 2:
            questions = questions_by_label.get_group(label).values.tolist()
            pairs.extend([question, label] for question in questions)
    print("Length after treatment ", len(pairs))
    return pairs

In [5]:
# Turn the frame into [question, label] pairs; drop_and_split removes rows with
# missing values and labels that occur two times or fewer.
fdata = drop_and_split(full_data)

Length before treatment  1605229
Length after treatment  1603996


In [7]:
class AttrDict(dict):
    """A ``dict`` whose items can also be read and written as attributes."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Use the mapping itself as the instance namespace, so that
        # ``obj.key`` and ``obj['key']`` refer to the same storage.
        self.__dict__ = self

In [8]:
def lab_to_ids(label):
    """Build forward and reverse lookup tables for a collection of labels.

    Each distinct label receives a consecutive integer id starting at 0, in
    the order in which the labels first appear in ``label``.

    Returns an ``AttrDict`` with two entries:
    ``word2id`` (label -> id) and ``id2word`` (id -> label).
    """
    distinct_labels = Counter(label)

    forward = {}
    for position, name in enumerate(distinct_labels.keys()):
        forward[name] = position
    backward = {position: name for name, position in forward.items()}

    mapping = AttrDict()
    mapping.word2id = forward
    mapping.id2word = backward
    return mapping

In [9]:
import re
import string
# Single punctuation characters; a token that is exactly one of these is skipped.
exclude = set(string.punctuation)
def sent_to_words(sents):
    """Tokenise sentences into lists of lower-cased alphanumeric words.

    Each sentence is split on spaces, ``/`` and ``\\``.  For every piece the
    literal entities ``&quot;`` and ``&apos;`` are removed, then every
    character outside ``a-z``, ``A-Z`` and ``0-9`` is deleted (so
    ``"what's"`` becomes ``"whats"`` and ``"d_e"`` becomes ``"de"``), and the
    result is lower-cased.  Pieces that end up empty are discarded.

    Parameters
    ----------
    sents : iterable of str

    Returns
    -------
    list of list of str
        One word list per input sentence, in input order.  A sentence that
        contains no alphanumeric characters yields an empty list (not a
        list holding one empty string).
    """
    non_alnum = re.compile(r"[^a-zA-Z0-9]+")

    def normalizeString(s):
        s = re.sub(r"&quot;", r"", s)
        s = re.sub(r"&apos;", r"", s)
        # Deleting every non-alphanumeric run covers '.', '!', '?', '_' and
        # all remaining punctuation in a single pass.
        s = non_alnum.sub("", s)
        return s.lower()

    tokenised = []
    for sent in sents:
        words = []
        for piece in re.split(r"[ /\\]", sent):
            # Cheap early exit for empty pieces and lone punctuation marks.
            if not piece or piece in exclude:
                continue
            word = normalizeString(piece)
            if word:
                words.append(word)
        tokenised.append(words)
    return tokenised

In [12]:
import numpy as np
# Ids assigned to the '<PAD>' and '<UNK>' entries when QuesData.build_vocab
# creates the vocabulary; regular tokens are numbered from 2 upwards.
PAD_IDX = 0
UNK_IDX = 1
class QuesData:
    """Tokenised question/label pairs together with word and label vocabularies.

    Parameters
    ----------
    data : iterable of (question, label)
        Raw question strings with their labels.
    sent_len_perc : int or float
        Percentile of the token-count distribution used as the maximum
        sentence length; longer sentences are discarded.
    max_vocab_size : int
        Upper bound on the number of regular tokens in the vocabulary (the
        two special tokens come on top).

    Attributes set by ``__init__``: ``MAX_sent_len``, ``d_inp`` (token lists),
    ``d_lab`` (labels), ``lab_dict`` (label <-> id tables) and ``vocab``
    (word <-> id tables).
    """

    def __init__(self, data, sent_len_perc = 80, max_vocab_size = 25000):
        full_input, self.MAX_sent_len = self.drop_LongSents(data, sent_len_perc)
        self.d_inp = [pair[0] for pair in full_input]
        self.d_lab = [pair[1] for pair in full_input]
        self.lab_dict = lab_to_ids(self.d_lab)

        vocab_counter = self.word_count(self.d_inp)
        self.vocab = self.build_vocab(vocab_counter, max_vocab_size)

    def __len__(self):
        """Number of (sentence, label) examples kept after length filtering."""
        return len(self.d_inp)

    def __getitem__(self, idx):
        """Return ``(tokens, token_ids, label, [label_id])`` for example ``idx``.

        Words missing from the vocabulary are mapped to the id of ``'<UNK>'``.
        """
        input_sent = self.d_inp[idx]
        labels = self.d_lab[idx]
        word2id = self.vocab.word2id
        unk_id = word2id['<UNK>']
        input_ids = [word2id.get(word, unk_id) for word in input_sent]
        label_ids = [self.lab_dict.word2id[labels]]

        return input_sent, input_ids, labels, label_ids

    def drop_LongSents(self, data, sent_len_perc):
        """Tokenise the questions and discard those above the length percentile.

        Returns ``(kept, max_len)`` where ``kept`` is a list of
        ``(tokens, label)`` tuples whose token count is at most ``max_len``
        and ``max_len`` is the ``sent_len_perc``-th percentile of all token
        counts, truncated to ``int``.  Empty input yields ``([], 0)``.
        """
        # Materialise once: the pairs are traversed twice below, which would
        # silently exhaust a one-shot iterable.
        pairs = list(data)
        if not pairs:
            # A percentile of nothing is undefined; avoid the NaN -> int crash.
            return [], 0
        inp = sent_to_words([x[0] for x in pairs])
        labs = [x[1] for x in pairs]
        sent_lens = [len(s) for s in inp]
        MAX_len = int(np.percentile(sent_lens, sent_len_perc))
        dropped = [(d, lab) for d, lab in zip(inp, labs) if len(d) <= MAX_len]
        return dropped, MAX_len

    def word_count(self, ins):
        """Count every non-empty word across the token lists in ``ins``."""
        return Counter(word for sent in ins for word in sent if word)

    def build_vocab(self, word_count, max_vocab_size):
        """Create ``word2id`` / ``id2word`` tables from word frequencies.

        ``'<PAD>'`` and ``'<UNK>'`` take ``PAD_IDX`` and ``UNK_IDX``.  The
        ``max_vocab_size`` most frequent words follow from id 2 upwards in
        descending frequency; words seen fewer than two times are left out.
        """
        vocab = AttrDict()
        vocab.word2id = {'<PAD>': PAD_IDX, '<UNK>': UNK_IDX}
        # most_common() is ordered by descending count, so the count filter
        # only trims the tail and the assigned ids stay contiguous.
        ranked = word_count.most_common(max_vocab_size)
        vocab.word2id.update({
            token: rank + 2
            for rank, (token, count) in enumerate(ranked)
            if count >= 2
        })
        vocab.id2word = {ids: word for word, ids in vocab.word2id.items()}
        return vocab

In [22]:
# Build the dataset from the first 50000 pairs; sentences with more tokens than
# the 60th percentile of sentence lengths are dropped.
cla = QuesData(fdata[:50000], sent_len_perc = 60)