In [1]:
import pandas as pd
import numpy as np
import pickle
import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from collections import Counter
# from scipy.special import softmax # ver>=1.20

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
device

'cuda'

## Load Data

In [2]:
def read_ht_data(file_path):
    """Read the Q&A CSV at ``file_path`` and discard the columns not used later.

    The first CSV column becomes the DataFrame index. The columns
    'related_topics', 'question_url', 'answers' and 'sub_category' are removed;
    pandas raises ``KeyError`` if any of them is absent.

    Returns:
        The remaining columns as a new DataFrame.
    """
    unused_columns = ['related_topics', 'question_url', 'answers', 'sub_category']
    return pd.read_csv(file_path, index_col = 0).drop(columns = unused_columns)

In [3]:
# full_data = read_ht_data('healthtap_medical_qna_dataset_1.6m.csv')

In [4]:
def drop_and_split(raw):
    """Turn the raw Q&A frame into ``[question, label]`` pairs for frequent labels.

    Rows with any missing value are dropped. The label 'question' is excluded,
    and only labels occurring more than twice among the remaining rows are kept.
    The first ten rows and the row counts before/after filtering are printed.

    Args:
        raw: DataFrame with at least the columns 'main_category' and 'question'.

    Returns:
        A list of two-element lists ``[question, label]``, grouped by label in
        the order labels first appear in the (index-sorted) frame.
    """
    # NOTE(review): the random permutation is undone by the sort_index call
    # right after it, so the rows end up ordered by index, not shuffled.
    raw = raw.reindex(np.random.permutation(raw.index))
    raw = raw.sort_index(axis='index')
    print(raw[:10])
    print("Length before treatment ", len(raw))
    raw  = raw.dropna()
    label_counter = Counter(raw['main_category'])
    # Exclude the 'question' label; the default avoids a KeyError when no row
    # carries it.
    label_counter.pop('question', None)
    pairs = []
    for label, n_rows in label_counter.items():
        if n_rows > 2:
            questions = raw.loc[raw['main_category'] == label, 'question'].values.tolist()
            pairs.extend([question, label] for question in questions)
    print("Length after treatment ", len(pairs))
    return pairs

In [9]:
def _load_pickle(path):
    """Deserialize one pickled object from ``path`` and close the file afterwards."""
    # NOTE: unpickling can execute arbitrary code; only load files you trust.
    with open(path, 'rb') as handle:
        return pickle.load(handle)


# Pre-computed train/test split stored next to the notebook.
X_train = _load_pickle('X_train.pkl')
y_train = _load_pickle('y_train.pkl')
X_test = _load_pickle('X_test.pkl')
y_test = _load_pickle('y_test.pkl')

In [10]:
class AttrDict(dict):
    """Dictionary whose entries can also be read and written as attributes."""

    def __init__(self, *args, **kwargs):
        dict.__init__(self, *args, **kwargs)
        # Using the mapping itself as the instance namespace makes
        # ``obj.key`` and ``obj['key']`` refer to the same storage.
        self.__dict__ = self

In [11]:
def lab_to_ids(label):
    """Build label <-> integer id lookup tables.

    Ids are assigned from 0 in the order each distinct label first occurs in
    ``label``. No '<UNK>' entry is added.

    Returns:
        An ``AttrDict`` with ``word2id`` (label -> id) and ``id2word``
        (id -> label).
    """
    lab_dict = AttrDict()
    distinct_labels = Counter(label).keys()
    word2id = {}
    id2word = {}
    for idx, name in enumerate(distinct_labels):
        word2id[name] = idx
        id2word[idx] = name
    lab_dict.word2id = word2id
    lab_dict.id2word = id2word
    return lab_dict

In [12]:
import re
import string
# Single punctuation characters; a token equal to one of these is discarded.
exclude = set(string.punctuation)


def sent_to_words(sents):
    """Tokenize raw sentences into lists of lower-cased alphanumeric tokens.

    Each sentence is split on spaces, with '/' and '\\' treated as separators.
    Tokens that are a single punctuation character are dropped. In the
    remaining tokens the entities ``&quot;`` / ``&apos;`` and every character
    outside ``[a-zA-Z0-9]`` are removed, so e.g. "don't" becomes "dont".
    Tokens that end up empty are discarded.

    Args:
        sents: iterable of strings.

    Returns:
        One token list per sentence. A sentence with no surviving token yields
        ``['']`` (a single empty string), not an empty list.
    """
    def _split_raw(text):
        # Slashes and backslashes act as word separators.
        text = re.sub(r"/", r" ", text)
        text = re.sub(r"\\", r" ", text)
        return re.sub(' +',' ',text).split(' ')

    def _clean_token(tok):
        # Applied in this exact order; the final pattern also removes the
        # spaces that the two preceding substitutions insert.
        for pattern, repl in (
            (r"&quot;", r""),
            (r"&apos;", r""),
            (r"([.!?])", r" "),
            (r"([_])", r" "),
            (r"[^a-zA-Z0-9.!?]+", r""),
        ):
            tok = re.sub(pattern, repl, tok)
        return tok.lower()

    tokenized = []
    for sent in sents:
        cleaned = []
        for tok in _split_raw(sent):
            if not tok or tok in exclude:
                continue
            cleaned.append(_clean_token(tok))
        # Collapse the gaps left by tokens that were cleaned down to nothing.
        joined = re.sub(' +',' ',' '.join(cleaned)).strip()
        tokenized.append(joined.split(' '))
    return tokenized

In [13]:
def get_trans_mat(file_path="symptoms-DO.tsv"):
    """Build a symptom -> disease probability matrix from a TSV association table.

    The file must have 'symptom_name' and 'disease_name' columns. Cell (i, j)
    is set to 1 when symptom i is listed with disease j and to -inf otherwise;
    a row-wise softmax then makes each row uniform over the diseases listed for
    that symptom and zero elsewhere.

    Args:
        file_path: tab-separated file to read (defaults to the original
            hard-coded "symptoms-DO.tsv").

    Returns:
        (matrix, symptom_names, disease_names): the matrix is a float64 tensor
        on ``device`` of shape (n_symptoms, n_diseases); the two lists give the
        row and column names in first-appearance order.
    """
    df = pd.read_csv(file_path, sep = '\t')
    d_set = list(df['disease_name'].unique())
    s_set = list(df['symptom_name'].unique())
    # One dict lookup per disease instead of a linear list.index() scan.
    d_index = {name: j for j, name in enumerate(d_set)}
    tran_mat = np.full((len(s_set), len(d_set)), -np.inf)
    for i, s_name in enumerate(s_set):
        d_subset = df[df['symptom_name'] == s_name]['disease_name'].unique()
        js = [d_index[x] for x in d_subset]
        tran_mat[i, js] = 1
    # NOTE(review): a symptom matching no row (e.g. a NaN name) would leave its
    # row all -inf, which softmax turns into NaN.
    tran_mat = torch.from_numpy(tran_mat).to(device)
    return torch.softmax(tran_mat, dim=1), s_set, d_set


tran_mat, row_name, col_name = get_trans_mat()

In [14]:
import numpy as np
# Vocabulary ids reserved for the padding and unknown-word tokens.
PAD_IDX = 0
UNK_IDX = 1


class QuesData(Dataset):
    """Dataset of (token-id list, label id) pairs built from raw question strings.

    Sentences are tokenized with ``sent_to_words``; those longer than the
    ``sent_len_perc``-th percentile of token counts are discarded. Without
    ``valid_vocab`` the word and label vocabularies are built from the kept
    data; otherwise they are taken from ``valid_vocab['input']`` and
    ``valid_vocab['label']``.
    """

    def __init__(self, sent, cat, sent_len_perc = 80, max_vocab_size = 25000, valid_vocab=None):
        kept_pairs, self.MAX_sent_len = self.drop_LongSents(sent, cat, sent_len_perc)
        self.d_inp = [pair[0] for pair in kept_pairs]
        self.d_lab = [pair[1] for pair in kept_pairs]
        if valid_vocab:
            # Reuse vocabularies built elsewhere (e.g. from the training split).
            self.lab_dict = valid_vocab['label']
            self.vocab = valid_vocab['input']
        else:
            self.lab_dict = lab_to_ids(self.d_lab)
            self.vocab = self.build_vocab(self.word_count(self.d_inp), max_vocab_size)

    def __len__(self):
        """Number of sentences kept after the length filter."""
        return len(self.d_inp)

    def __getitem__(self, idx):
        """Return ``(token_ids, label_id)`` for item ``idx``.

        Words missing from the vocabulary map to the '<UNK>' id; a label
        missing from the label table raises ``KeyError``.
        """
        word2id = self.vocab.word2id
        token_ids = []
        for word in self.d_inp[idx]:
            key = word if word in word2id else '<UNK>'
            token_ids.append(word2id[key])
        return token_ids, self.lab_dict.word2id[self.d_lab[idx]]

    def drop_LongSents(self, sent, labs, sent_len_perc):
        """Tokenize ``sent`` and keep only pairs no longer than the percentile cut-off.

        Prints the longest token count seen. Returns ``(kept, MAX_len)`` where
        ``kept`` is a list of ``(tokens, label)`` tuples.
        """
        tokenized = sent_to_words(sent)
        lengths = list(map(len, tokenized))
        print(max(lengths))
        MAX_len = int(np.percentile(lengths, sent_len_perc))
        kept = [(toks, lab) for toks, lab in zip(tokenized, labs) if len(toks) <= MAX_len]
        return kept, MAX_len

    def word_count(self, ins):
        """Count occurrences of every non-empty token across all sentences."""
        tally = Counter()
        for sent in ins:
            tally.update(word for word in sent if word)
        return tally

    def build_vocab(self, word_count, max_vocab_size):
        """Create word <-> id tables from the ``max_vocab_size`` most common words.

        Ids 0 and 1 are reserved for '<PAD>' and '<UNK>'; words seen fewer than
        twice are left out.
        """
        vocab = AttrDict()
        word2id = {'<PAD>': PAD_IDX, '<UNK>': UNK_IDX}
        for rank, (token, freq) in enumerate(word_count.most_common(max_vocab_size)):
            if freq >= 2:
                word2id[token] = rank + 2
        vocab.word2id = word2id
        vocab.id2word = {idx: tok for tok, idx in word2id.items()}
        return vocab

## Batchify

In [16]:
def collate_fn(data):
    """Collate ``(token_ids, label)`` samples into one padded, time-major batch.

    Samples are ordered by decreasing sequence length (the order
    ``pack_padded_sequence`` expects by default) and right-padded with 0.

    Args:
        data: sequence of ``(token_ids, label)`` pairs; it is not modified.

    Returns:
        (ques_seqs, ques_lens, trg):
            ques_seqs - CPU LongTensor of shape (max_len, batch);
            ques_lens - list of the unpadded lengths, longest first;
            trg       - tuple of labels in the same order.
    """
    def _pad_sequences(seqs):
        lens = [len(seq) for seq in seqs]
        # Integer buffer: token ids must not be stored as floats. The batch is
        # assembled on the CPU; moving it to the device is left to the caller.
        padded_seqs = torch.zeros(len(seqs), max(lens), dtype=torch.long)
        for i, seq in enumerate(seqs):
            padded_seqs[i, :lens[i]] = torch.as_tensor(seq, dtype=torch.long)
        return padded_seqs, lens

    # Sort a copy so the caller's list keeps its order.
    batch = sorted(data, key=lambda x: len(x[0]), reverse=True)
    ques_seqs, trg = zip(*batch)
    ques_seqs, ques_lens = _pad_sequences(ques_seqs)
    # (batch, seq_len) => (seq_len, batch)
    ques_seqs = ques_seqs.t().contiguous()

    return ques_seqs, ques_lens, trg
    

In [18]:
# with open('preproc_data.pkl','rb') as file:
#     fdata = pickle.load(file)
BATCH_SIZE = 128

# Training split: builds the word and label vocabularies from its own data.
train_dataset = QuesData(X_train, y_train, sent_len_perc=85)
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=collate_fn,
)

# Validation split: reuses the training vocabularies so ids agree across splits.
valid_dataset = QuesData(
    X_test,
    y_test,
    sent_len_perc=85,
    valid_vocab=dict(input=train_dataset.vocab, label=train_dataset.lab_dict),
)
valid_loader = DataLoader(
    valid_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=collate_fn,
)

773
804


# Baseline model with BiLSTM Encoder 

In [20]:
class EncoderLSTM(nn.Module):
    """Bidirectional LSTM sentence classifier.

    Token ids are embedded and run through a bidirectional LSTM; the final
    forward and backward hidden states of the last layer are concatenated and
    passed through two linear layers to produce class logits.

    Args:
        input_size: vocabulary size (number of embedding rows).
        num_classes: number of output classes.
        embed_size: embedding dimension.
        row_name, col_name, tran_mat: accepted for interface compatibility;
            not used by this class.
        hidden_size: LSTM hidden size per direction.
        n_layers: number of stacked LSTM layers.
        dropout: dropout between LSTM layers.
    """

    def __init__(self, input_size,num_classes, embed_size, row_name, col_name, tran_mat, hidden_size,n_layers = 1, dropout=0):
        super(EncoderLSTM, self).__init__()
        self.input_size = input_size
        self.embed_size = embed_size
        self.hidden_size = hidden_size
        self.dropout = dropout
        self.embedding = nn.Embedding(input_size, embed_size, padding_idx=PAD_IDX)
        self.lstm = nn.LSTM(embed_size, hidden_size, n_layers, dropout=self.dropout, bidirectional=True)
        self.fc1 = nn.Linear(hidden_size * 2, hidden_size)
        self.fc2 = nn.Linear(hidden_size,num_classes)

    def forward(self, input_seqs, input_lengths, hidden=None):
        """Return class logits of shape (batch, num_classes).

        Args:
            input_seqs: LongTensor of token ids, time-major (seq_len, batch),
                with sequences ordered by decreasing length.
            input_lengths: unpadded length of each sequence (list or tensor).
            hidden: optional initial LSTM state, forwarded to ``nn.LSTM``.
        """
        embedded = self.embedding(input_seqs)  # (seq_len, batch, embed_size)
        # pack_padded_sequence needs the lengths on the CPU, wherever the
        # caller put them.
        lengths = torch.as_tensor(input_lengths, dtype=torch.int64).cpu()
        packed = torch.nn.utils.rnn.pack_padded_sequence(embedded, lengths)
        _, (h_n, _) = self.lstm(packed, hidden)
        # h_n is ordered layer by layer with (forward, backward) inside each
        # layer, so the last two entries are the top layer's final states.
        # For n_layers == 1 this equals the previous h_n[0] / h_n[1].
        final_state = torch.cat((h_n[-2], h_n[-1]), dim=1)
        out = torch.relu(self.fc1(final_state))
        return self.fc2(out)

    def initHidden(self):
        # NOTE(review): not called anywhere in the visible code. It returns one
        # stacked zero tensor sized by the global BATCH_SIZE rather than the
        # (h_0, c_0) tuple nn.LSTM documents; confirm before using it.
        return torch.zeros(2, 2, BATCH_SIZE, self.hidden_size, device=device)

## Train

In [22]:
def train(input, label, in_lens, model, optim, criterion, max_clip_norm = 5):
    """Run one optimisation step on a single batch and return its loss.

    Args:
        input: tensor of token ids, time-major (seq_len, batch); cast to long
            and moved to ``device``.
        label: sequence of integer class ids, one per batch element.
        in_lens: unpadded sequence lengths, one per batch element.
        model: module called as ``model(input, in_lens)`` returning logits.
        optim: optimizer over the model's parameters.
        criterion: loss called as ``criterion(logits, label)``.
        max_clip_norm: gradient-norm clipping threshold.

    Returns:
        The batch loss as a Python float.
    """
    input = input.long().to(device)
    label = torch.LongTensor(label).to(device)
    # Lengths stay on the CPU: pack_padded_sequence rejects lengths that live
    # on the GPU.
    in_lens = torch.as_tensor(in_lens, dtype=torch.int64).cpu()
    model.train()
    optim.zero_grad()
    logits = model(input, in_lens)
    loss = criterion(logits, label)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_clip_norm)
    optim.step()
    return loss.item()
def freeze_layer(layer):
    """Turn off gradient tracking for every parameter of ``layer``."""
    for weight in layer.parameters():
        weight.requires_grad = False

In [23]:
def evaluate(model, data_iter):
    """Compute classification accuracy of ``model`` over ``data_iter``.

    The model is switched to eval mode and is not switched back.

    Args:
        model: module called as ``model(input, in_lens)`` returning logits of
            shape (batch, num_classes).
        data_iter: iterable of ``(input, in_lens, label)`` batches.

    Returns:
        Fraction of correctly classified samples as a float; 0.0 when
        ``data_iter`` yields no samples.
    """
    model.eval()
    correct = 0
    total = 0
    # No gradients are needed for scoring; skipping them avoids building an
    # autograd graph for every batch.
    with torch.no_grad():
        for input, in_lens, label in data_iter:
            input = input.long().to(device)
            label = torch.LongTensor(label).to(device)
            # Lengths stay on the CPU, as pack_padded_sequence requires.
            in_lens = torch.as_tensor(in_lens, dtype=torch.int64).cpu()
            output = model(input, in_lens)

            _, predicted = torch.max(output, 1)
            total += label.size(0)
            correct += (predicted == label).sum().item()

    if total == 0:
        return 0.0
    return correct / float(total)

In [2]:
# Model dimensions follow the vocabularies built from the training split.
input_size = len(train_dataset.vocab.id2word)
num_classes = len(train_dataset.lab_dict.id2word)
embed_size = 100
hidden_size = 100
model = EncoderLSTM(input_size,num_classes, embed_size, row_name, col_name, tran_mat, hidden_size,n_layers = 1, dropout=0).to(device)
# Resume from the saved checkpoint; map_location lets a checkpoint written on
# a GPU load on a CPU-only machine.
model.load_state_dict(torch.load('lstm01.pth', map_location=device))
model.train()
# freeze_layer(model.lstm)
# freeze_layer(model.embedding)
encoder_optim = optim.Adam([p for p in model.parameters() if p.requires_grad], lr=1e-3, weight_decay=1e-10)
encoder_scheduler = torch.optim.lr_scheduler.LambdaLR(encoder_optim, lr_lambda=lambda epoch: 0.95 ** epoch)
criterion = nn.CrossEntropyLoss()

accs = []        # validation accuracy after each epoch
train_loss = []  # mean training loss per epoch
count = 1
for epoch in range(30):
    total_loss = 0
    for idx, batch_data in enumerate(train_loader):
        train_input, train_lens, train_label = batch_data
        loss = train(train_input, train_label, train_lens, model, encoder_optim, criterion, 2)
        total_loss += loss
        if idx%1000 == 0:
            print('Training Loss: {}'.format(loss))

    train_loss.append((total_loss/(idx+ 1)))
    # The scheduler was created but never advanced, so the 0.95-per-epoch
    # decay never took effect; step it once per epoch.
    encoder_scheduler.step()

    train_acc = evaluate(model, train_loader)
    val_acc = evaluate(model, valid_loader)
    print("Epoch %i; Train acc: %f; Dev acc %f" %(epoch,\
                    train_acc, val_acc))
    # Checkpoint on the first epoch and whenever validation accuracy improves.
    # This replaces a bare ``except`` that relied on max([]) raising.
    if not accs or val_acc > max(accs):
        torch.save(model.state_dict(), 'lstm01.pth')
    accs.append(val_acc)

    