In [5]:
%load_ext autoreload
%autoreload 2

In [100]:
import sys
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from gensim.models import Word2Vec
from constants import *

# Add the parent directory to the system path
sys.path.append("..")

from HLAN.data_util_gensim import create_vocabulary_label_pre_split, create_vocabulary
from utils import load_data_multilabel_pre_split , create_dataloaders
from HAN_model import HierarchicalAttentionNetwork

In [3]:
# Root of the local project checkout. Set the EAMC_WORKSPACE environment
# variable to run on another machine without editing this cell; the default
# keeps the original location.
workspaceFolder = os.environ.get(
    "EAMC_WORKSPACE",
    "/Users/hungryfoolish/Documents/OMSCS/BD4H/project/Explainable-Automated-Medical-Coding-master",
)

# Pre-trained word2vec model for the document vocabulary
word2vec_model_path = f"{workspaceFolder}/embeddings/processed_full.w2v"

# Pre-trained label (code) embeddings, 200- and 400-dimensional
word2vec_label_path_200 = f"{workspaceFolder}/embeddings/train_full-code-emb-mimic3-tr-200.model"
word2vec_label_path_400 = f"{workspaceFolder}/embeddings/train_full-code-emb-mimic3-tr-400.model"

# Pre-split train / dev / test CSV files
training_data_path = f"{workspaceFolder}/datasets/data/train_50_eamc.csv"
validation_data_path = f"{workspaceFolder}/datasets/data/dev_50_eamc.csv"
testing_data_path = f"{workspaceFolder}/datasets/data/test_50_eamc.csv"

# Dataset identifier; later cells append "-HAN" to it to build a name scope
dataset = "mimic3-ds-50"

# Model / data hyper-parameters
num_classes = 50
sequence_length = 2500
batch_size = 32
embed_size=100
hidden_size=100
num_sentences=100
# Each document is cut into num_sentences equally sized chunks
sentence_length=sequence_length // num_sentences
training_data_path

'/Users/hungryfoolish/Documents/OMSCS/BD4H/project/Explainable-Automated-Medical-Coding-master/datasets/data/train_50_eamc.csv'

# Data Loading

In [27]:
vocabulary_word2index_label,vocabulary_index2word_label = create_vocabulary_label_pre_split(training_data_path=training_data_path, validation_data_path=validation_data_path, testing_data_path=testing_data_path, name_scope=dataset + "-HAN") # keep a distinct name scope for each model and each dataset.

In [28]:
vocabulary_word2index, vocabulary_index2word = create_vocabulary(word2vec_model_path,name_scope=dataset + "-HAN")

cache_path: ../cache_vocabulary_label_pik/mimic3-ds-50-HAN_word_vocabulary.pik file_exists: True


In [30]:
len(vocabulary_word2index)

150854

In [31]:
len(vocabulary_word2index_label)

50

In [10]:
X, Y = load_data_multilabel_pre_split(vocabulary_word2index, vocabulary_word2index_label, data_path=training_data_path, keep_label_percent=1.0)

load_data.started...
load_data_multilabel_new.data_path: /Users/hungryfoolish/Documents/OMSCS/BD4H/project/Explainable-Automated-Medical-Coding-master/datasets/data/train_50_eamc.csv
load_data.ended...


In [11]:
len(X), len(Y)

(8066, 8066)

In [23]:
max([len(x) for x in X])

7567

### step through data loading code

The maximum sequence length (7567) is much higher than the 2500-token limit, so the longer documents are currently truncated!

In [21]:
len(Y[0]), len(Y[1])

(50, 50)

In [28]:
def pad_or_truncate_sequence(sequence, maxlen):
    """Return a copy of ``sequence`` that is exactly ``maxlen`` items long.

    Longer sequences are cut after ``maxlen`` items; shorter ones are
    right-padded with 0.

    Args:
        sequence: list of token ids.
        maxlen: target length.

    Returns:
        A new list of length ``maxlen``. The input is never modified, so the
        raw data stays intact and re-running the calling cell is safe.
    """
    # Slicing always yields a fresh object, so the caller's list is untouched
    # (the previous ``+=`` extended the input list in place when padding).
    fitted = list(sequence[:maxlen])
    # A non-positive repeat count yields an empty list, so this is a no-op
    # when the sequence is already long enough.
    fitted.extend([0] * (maxlen - len(fitted)))
    return fitted

X_padded = [pad_or_truncate_sequence(x, sequence_length) for x in X]

In [29]:
max([len(x) for x in X_padded]), min([len(x) for x in X_padded])

(2500, 2500)

In [63]:
torch.tensor(X_padded).shape

torch.Size([8066, 2500])

In [71]:
from torch.utils.data import DataLoader, TensorDataset

# Convert the padded token ids and the label vectors to PyTorch tensors
X_tensor = torch.tensor(X_padded)
Y_tensor = torch.tensor(Y)

# Pair documents with their labels. Deliberately NOT named `dataset`: that
# name holds the dataset identifier string which the vocabulary cells
# concatenate with "-HAN"; rebinding it to a TensorDataset breaks those cells
# on any later run.
train_dataset = TensorDataset(X_tensor, Y_tensor)

# Batch the paired tensors for iteration
dataloader = DataLoader(train_dataset, batch_size=batch_size)



In [72]:
# Now you can iterate over the DataLoader to get your batches
for X_batch, Y_batch in dataloader:
    # Process the batch
    print(X_batch.shape, Y_batch.shape)
    break

torch.Size([32, 2500]) torch.Size([32, 50])


### test dataloading code

In [16]:
# NOTE(review): the import below is commented out, and the import cell at the
# top of the notebook brings in `create_dataloaders`, not
# `create_train_dataloader`. This call therefore relies on a name that is not
# defined anywhere in the visible notebook — confirm where it should come from.
#from utils import create_train_dataloader

dataloader = create_train_dataloader()

cache_path: ../cache_vocabulary_label_pik/mimic3-ds-50-HAN_word_vocabulary.pik file_exists: True
load_data.started...
load_data_multilabel_new.data_path: /Users/hungryfoolish/Documents/OMSCS/BD4H/project/Explainable-Automated-Medical-Coding-master/datasets/data/train_50_eamc.csv
load_data.ended...


In [17]:
for x, y in dataloader:
    print(x.shape, y.shape)
    break

torch.Size([32, 2500]) torch.Size([32, 50])


In [18]:
y[0]

tensor([0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.])

# Model

## Step through the model

In [73]:
_embeddings = nn.Embedding(vocab_size, embed_size)


In [74]:
embedded_documents = _embeddings(X_batch)

In [75]:
embedded_documents.shape

torch.Size([32, 2500, 100])

In [76]:
# The embedded documents are laid out (batch, words, features), so the GRU
# must be told the batch axis comes first; with the default layout it would
# treat the 32 batch items as the time steps.
_word_gru = nn.GRU(embed_size, hidden_size, bidirectional=True, batch_first=True)

In [77]:
output_words, _ = _word_gru(embedded_documents)

In [78]:
output_words.shape

torch.Size([32, 2500, 200])

In [86]:
output_words = output_words.view(batch_size, num_sentences, sentence_length, 2 * hidden_size)
output_words.shape

torch.Size([32, 100, 25, 200])

In [150]:
class Attention(nn.Module):
    """Attention pooling that collapses the second-to-last axis of a tensor.

    Every vector on the last axis is passed through a linear layer and tanh,
    then scored against a learned context vector. The softmax of those scores
    (taken over the last axis of the score tensor) weights a sum of the
    original input vectors.

    Args:
        input_size: size of the last axis of the input.
        output_size: size of the projection and of the context vector.
    """

    def __init__(self, input_size, output_size):
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        # Creation order (linear layer, then context vector) is kept so the
        # random initialisation consumes the RNG in the same sequence.
        self.linear = nn.Linear(input_size, output_size)
        self.context_vector = nn.Parameter(torch.randn(output_size))

    def forward(self, input):
        """Return the attention-weighted sum of ``input`` over its axis -2."""
        projected = self.linear(input).tanh()
        scores = (projected * self.context_vector).sum(dim=-1)

        # Shift by the largest score before the softmax for numerical stability
        peak = scores.max(dim=-1, keepdim=True).values
        weights = F.softmax(scores - peak, dim=-1)

        return (input * weights.unsqueeze(-1)).sum(dim=-2)

### step through attention

In [88]:
linear = nn.Linear(2 * hidden_size, 2 * hidden_size)
hidden_representation = torch.tanh(linear(output_words))
hidden_representation.shape

torch.Size([32, 100, 25, 200])

In [92]:
context_vector = nn.Parameter(torch.randn((2 * hidden_size)))

In [113]:
attention_logits = torch.sum(hidden_representation * context_vector, dim=-1)
attention_logits.shape

torch.Size([32, 100, 25])

In [114]:
attention_logits_max, _ = torch.max(attention_logits, dim=-1, keepdim=True)
attention_logits_max.shape

torch.Size([32, 100, 1])

In [115]:
attention = F.softmax(attention_logits - attention_logits_max, dim=-1)
attention.shape

torch.Size([32, 100, 25])

In [116]:
attention.unsqueeze(-1).shape

torch.Size([32, 100, 25, 1])

In [118]:
sentence_representation = torch.sum(output_words * attention.unsqueeze(-1), dim=-2)
sentence_representation.shape

torch.Size([32, 100, 200])

### resume using attention

In [138]:
_word_attention = Attention(hidden_size * 2, hidden_size * 2)

In [139]:
output_words_attn = _word_attention(output_words)

In [140]:
output_words_attn.shape

torch.Size([32, 100, 200])

In [123]:
# Sentence representations are laid out (batch, sentences, features); without
# batch_first=True the GRU would recur over the batch axis instead of over
# the sentences.
sentence_gru = nn.GRU(2*hidden_size, 2*hidden_size, bidirectional=True, batch_first=True)

In [124]:
output_sentences, _ = sentence_gru(output_words_attn)
output_sentences.shape

torch.Size([32, 100, 400])

### Step through sentence attention code

In [142]:
input_size = 4*hidden_size
output_size = 2*hidden_size
linear = nn.Linear(input_size, output_size)
context_vector = nn.Parameter(torch.randn((output_size)))

In [143]:
input = output_sentences
input.shape

torch.Size([32, 100, 400])

In [144]:
hidden_representation = torch.tanh(linear(input))
hidden_representation.shape

torch.Size([32, 100, 200])

In [145]:
attention_logits = torch.sum(hidden_representation * context_vector, dim=-1)
attention_logits.shape

torch.Size([32, 100])

In [146]:
attention_logits_max, _ = torch.max(attention_logits, dim=-1, keepdim=True) 
attention = F.softmax(attention_logits - attention_logits_max, dim=-1)
attention.shape

torch.Size([32, 100])

In [147]:
attention.unsqueeze(-1).shape

torch.Size([32, 100, 1])

In [148]:
temp = input * attention.unsqueeze(-1)
temp.shape

torch.Size([32, 100, 400])

In [149]:
output = torch.sum(input * attention.unsqueeze(-1), dim=-2)
output.shape

torch.Size([32, 400])

### resume with sentence attention

In [151]:
sentence_attention = Attention(4*hidden_size, 2*hidden_size)

In [152]:
output_sentences_attn = sentence_attention(output_sentences)
output_sentences_attn.shape

torch.Size([32, 400])

In [154]:
fc = nn.Linear(4*hidden_size, num_classes)
output = fc(output_sentences_attn)
output.shape

torch.Size([32, 50])

## Forward pass

In [12]:
model = HierarchicalAttentionNetwork(batch_size=batch_size, vocab_size=vocab_size, embed_size=embed_size, hidden_size=hidden_size, num_sentences=num_sentences, sentence_length=sentence_length, num_classes=num_classes)

In [19]:
out = model(x)

In [20]:
out.shape

torch.Size([32, 50])

In [22]:
for name, param in model.named_parameters():
    if param.requires_grad:
        print(name, param.shape)

embeddings.weight torch.Size([150854, 100])
word_gru.weight_ih_l0 torch.Size([300, 100])
word_gru.weight_hh_l0 torch.Size([300, 100])
word_gru.bias_ih_l0 torch.Size([300])
word_gru.bias_hh_l0 torch.Size([300])
word_gru.weight_ih_l0_reverse torch.Size([300, 100])
word_gru.weight_hh_l0_reverse torch.Size([300, 100])
word_gru.bias_ih_l0_reverse torch.Size([300])
word_gru.bias_hh_l0_reverse torch.Size([300])
sentence_gru.weight_ih_l0 torch.Size([600, 200])
sentence_gru.weight_hh_l0 torch.Size([600, 200])
sentence_gru.bias_ih_l0 torch.Size([600])
sentence_gru.bias_hh_l0 torch.Size([600])
sentence_gru.weight_ih_l0_reverse torch.Size([600, 200])
sentence_gru.weight_hh_l0_reverse torch.Size([600, 200])
sentence_gru.bias_ih_l0_reverse torch.Size([600])
sentence_gru.bias_hh_l0_reverse torch.Size([600])
word_attention.context_vector torch.Size([200])
word_attention.linear.weight torch.Size([200, 200])
word_attention.linear.bias torch.Size([200])
sentence_attention.context_vector torch.Size([200])

In [23]:
model.embed_size, model.hidden_size

(100, 100)

### Step through model iteration 2

In [29]:
_embeddings = nn.Embedding(vocab_size, embed_size)
embedded_documents = _embeddings(x)
embedded_documents.shape

torch.Size([32, 2500, 100])

In [31]:
# The paper model has a single set of GRU parameters shared by both reading
# directions, hence bidirectional=False. The input is (batch, words, features),
# so batch_first=True is required for the recurrence to run over the words
# rather than over the batch items.
word_gru = nn.GRU(embed_size, hidden_size, bidirectional=False, batch_first=True)

In [33]:
output_words_left2right, _ = word_gru(embedded_documents)
output_words_left2right.shape

torch.Size([32, 2500, 100])

In [34]:
# Reverse the word order inside every sentence before the second pass.
# Feeding the unreversed input to the same GRU only reproduces the
# left-to-right outputs; the per-sentence flip applied to the outputs in the
# following cells then puts them back in word order.
embedded_documents_reversed = (
    embedded_documents
    .view(-1, num_sentences, sentence_length, embed_size)
    .flip(dims=[2])
    .reshape(-1, num_sentences * sentence_length, embed_size)
)
output_words_right2left, _ = word_gru(embedded_documents_reversed)
output_words_right2left.shape

torch.Size([32, 2500, 100])

In [36]:
# reverse the order of right2left
output_words_right2left_reshaped = output_words_right2left.view(-1, num_sentences, sentence_length, embed_size)
output_words_right2left_reshaped.shape

torch.Size([32, 100, 25, 100])

#### understand .flip

In [49]:
a = torch.tensor([
        [[0, 1, 2, 3], [4, 5, 6, 7]],
        [[8, 9, 10, 11], [12, 13, 14, 15]],
        [[100, 101, 102, 103], [104, 105, 106, 107]]
        ])
a

tensor([[[  0,   1,   2,   3],
         [  4,   5,   6,   7]],

        [[  8,   9,  10,  11],
         [ 12,  13,  14,  15]],

        [[100, 101, 102, 103],
         [104, 105, 106, 107]]])

In [50]:
a.shape

torch.Size([3, 2, 4])

In [51]:
c = a.flip(dims=[0])
c

tensor([[[100, 101, 102, 103],
         [104, 105, 106, 107]],

        [[  8,   9,  10,  11],
         [ 12,  13,  14,  15]],

        [[  0,   1,   2,   3],
         [  4,   5,   6,   7]]])

In [52]:
b = a.flip(dims=[1])
b

tensor([[[  4,   5,   6,   7],
         [  0,   1,   2,   3]],

        [[ 12,  13,  14,  15],
         [  8,   9,  10,  11]],

        [[104, 105, 106, 107],
         [100, 101, 102, 103]]])

#### resume

In [53]:
output_words_right2left_reshaped_reversed = output_words_right2left_reshaped.flip(dims=[2])
output_words_right2left_reshaped_reversed.shape

torch.Size([32, 100, 25, 100])

In [54]:
output_words_right2left_reversed = output_words_right2left_reshaped_reversed.view(batch_size, num_sentences * sentence_length, embed_size)
output_words_right2left_reversed.shape

torch.Size([32, 2500, 100])

In [55]:
output_words = torch.cat((output_words_left2right, output_words_right2left_reversed), dim=-1)
output_words.shape

torch.Size([32, 2500, 200])

#### Word level attention step through, with per label attention

In [59]:
input_size = output_size = 2 * hidden_size
linear = nn.Linear(input_size, output_size)
context_vector = nn.Parameter(torch.randn((num_classes, output_size)))

In [61]:
context_vector.shape

torch.Size([50, 200])

In [68]:
context_vector.unsqueeze(0).unsqueeze(0).shape

torch.Size([1, 1, 50, 200])

In [69]:
context_vector.unsqueeze(0).unsqueeze(0).transpose(-1, -2).shape

torch.Size([1, 1, 200, 50])

In [60]:
input_tensor = output_words
hidden_representation = torch.tanh(linear(input_tensor))
hidden_representation.shape

torch.Size([32, 2500, 200])

In [66]:
# split both to word and sentence level
hidden_representation_reshaped = hidden_representation.view(batch_size, num_sentences, sentence_length, output_size)
hidden_representation_reshaped.shape

torch.Size([32, 100, 25, 200])

trying dimensions

In [76]:
a = hidden_representation_reshaped
b = context_vector.unsqueeze(0).unsqueeze(0).transpose(-1, -2)
print(a.shape, b.shape)


torch.Size([32, 100, 25, 200]) torch.Size([1, 1, 200, 50])


In [77]:
c = torch.matmul(a, b)
c.shape

torch.Size([32, 100, 25, 50])

good, now use it

In [78]:
context_vector_expanded = context_vector.unsqueeze(0).unsqueeze(0).transpose(-1, -2)
print(hidden_representation_reshaped.shape, context_vector_expanded.shape)

torch.Size([32, 100, 25, 200]) torch.Size([1, 1, 200, 50])


In [81]:
attention_logits = torch.matmul(hidden_representation_reshaped, context_vector_expanded)
attention_logits.shape

torch.Size([32, 100, 25, 50])

In [84]:
attention_logits_max, _ = torch.max(attention_logits, dim=2, keepdim=True)
attention_logits_max.shape

torch.Size([32, 100, 1, 50])

In [85]:
# Normalise over the word axis (dim=2): each label gets its own distribution
# over the words of a sentence. This matches the axis used for the max shift
# and for the weighted sum; dim=-1 would normalise across the 50 labels instead.
attention = F.softmax(attention_logits - attention_logits_max, dim=2)
attention.shape

torch.Size([32, 100, 25, 50])

In [90]:
attention_reshaped = attention.unsqueeze(-1)
attention_reshaped.shape

torch.Size([32, 100, 25, 50, 1])

In [88]:
input_tensor_reshaped = input_tensor.view(batch_size, num_sentences, sentence_length, 1, output_size)
input_tensor_reshaped.shape

torch.Size([32, 100, 25, 1, 200])

In [92]:
temp = attention_reshaped * input_tensor_reshaped
temp.shape

torch.Size([32, 100, 25, 50, 200])

In [93]:
output = torch.sum(temp, dim=2)
output.shape

torch.Size([32, 100, 50, 200])

In [None]:
output = torch.sum(input_tensor * attention.unsqueeze(-1), dim=-2)

resume after attention word level

In [57]:
word_attention = Attention(2*hidden_size, 2*hidden_size, per_label_attention=True, num_classes=50) # need to be separate for each class depending on input parameter per_label_attention

In [58]:
word_attention.context_vector.shape

torch.Size([50, 200])

In [None]:
output_words_attn = word_attention(output_words)
output_words_attn.shape

### test Attention Word Level

In [104]:
from HAN_model import AttentionWordLevel

In [105]:
word_attention = AttentionWordLevel(input_size=hidden_size*2, output_size=hidden_size*2, per_label_attention=True, num_classes=50, num_sentences=100, sentence_length=25)

In [106]:
output_words_attn = word_attention(output_words)

In [107]:
output_words_attn.shape

torch.Size([32, 100, 50, 200])

### resume with HAN

In [108]:
# One shared parameter set for both directions (bidirectional=False). The
# input is batch-major, so batch_first=True is needed for the recurrence to
# run along the sequence axis instead of along the batch.
sentence_gru = nn.GRU(2*hidden_size, 2*hidden_size, bidirectional=False, batch_first=True)

In [111]:
output_words_attn.shape

torch.Size([32, 100, 50, 200])

In [112]:
# reshape to 3D tensor
# NOTE(review): this view merges the sentence axis (100) and the class axis
# (50) into a single 5000-step axis in sentence-major order, so consecutive
# steps are different classes of the same sentence. Presumably the GRU should
# instead run over the sentences separately for each class — confirm against
# the reference model before relying on the outputs below.
output_words_attn_reshaped = output_words_attn.view(batch_size, num_sentences*num_classes, 2*hidden_size)
output_words_attn_reshaped.shape

torch.Size([32, 5000, 200])

In [113]:
output_sentences_left2right, _ = sentence_gru(output_words_attn_reshaped)
output_sentences_left2right.shape

torch.Size([32, 5000, 200])

In [114]:
output_sentences_left2right_reshaped = output_sentences_left2right.view(batch_size, num_sentences, num_classes, 2*hidden_size)
output_sentences_left2right_reshaped.shape

torch.Size([32, 100, 50, 200])

In [115]:
# Reverse the sentence order before the second pass. Running the same GRU on
# the unreversed input just duplicates the left-to-right outputs; the flip
# over the sentence axis applied to the outputs in the following cells then
# restores the original sentence order.
output_words_attn_reversed = (
    output_words_attn_reshaped
    .view(-1, num_sentences, num_classes, 2 * hidden_size)
    .flip(dims=[1])
    .reshape(-1, num_sentences * num_classes, 2 * hidden_size)
)
output_sentences_right2left, _ = sentence_gru(output_words_attn_reversed)
output_sentences_right2left.shape

torch.Size([32, 5000, 200])

In [116]:
output_sentences_right2left_reshaped = output_sentences_right2left.view(batch_size, num_sentences, num_classes, 2*hidden_size)
output_sentences_right2left_reshaped.shape

torch.Size([32, 100, 50, 200])

In [117]:
output_sentences_right2left_reshaped_reversed = output_sentences_right2left_reshaped.flip(dims=[1])
output_sentences_right2left_reshaped_reversed.shape

torch.Size([32, 100, 50, 200])

In [118]:
output_sentences = torch.cat((output_sentences_left2right_reshaped, output_sentences_right2left_reshaped_reversed), dim=-1)
output_sentences.shape

torch.Size([32, 100, 50, 400])

### Sentence level attention

In [119]:
input_size = 4 * hidden_size
output_size = 2 * hidden_size
linear = nn.Linear(input_size, output_size)

In [120]:
input_tensor = output_sentences
input_tensor.shape

torch.Size([32, 100, 50, 400])

In [123]:
hidden_representation = torch.tanh(linear(input_tensor))
hidden_representation.shape

torch.Size([32, 100, 50, 200])

In [124]:
context_vector = nn.Parameter(torch.randn((num_classes, output_size)))
context_vector.shape

torch.Size([50, 200])

In [125]:
attention_logits = torch.sum(hidden_representation * context_vector.unsqueeze(0).unsqueeze(0), dim=-1)
attention_logits.shape

torch.Size([32, 100, 50])

In [127]:
attention_logits_max, _ = torch.max(attention_logits, dim=-2, keepdim=True)
attention_logits_max.shape

torch.Size([32, 1, 50])

In [128]:
# Normalise over the sentence axis (dim=-2): each label gets its own
# distribution over the sentences of a document. This matches the axis used
# for the max shift and for the weighted sum; dim=-1 would normalise across
# the labels instead.
attention = F.softmax(attention_logits - attention_logits_max, dim=-2)
attention.shape

torch.Size([32, 100, 50])

In [130]:
temp = attention.unsqueeze(-1) * input_tensor
temp.shape

torch.Size([32, 100, 50, 400])

In [131]:
output = torch.sum(temp, dim=1)
output.shape

torch.Size([32, 50, 400])

done with sentence attention, now the last layer

In [146]:
W = nn.Parameter(torch.randn((num_classes, hidden_size*4)))
b = nn.Parameter(torch.randn((num_classes)))

In [139]:
W.shape

torch.Size([50, 400])

In [145]:
temp = output *  W.unsqueeze(0)
temp.shape

torch.Size([32, 50, 400])

In [147]:
final_output = torch.sum(temp, dim=-1) + b
final_output.shape

torch.Size([32, 50])

# Backward pass

In [24]:
# The targets are multi-hot vectors (several labels can be 1 for a document),
# so each label needs its own sigmoid / binary cross-entropy term.
# CrossEntropyLoss applies a softmax across the labels and treats them as
# mutually exclusive, which does not fit this task.
# Assumes the model returns raw logits (no final sigmoid) — TODO confirm in HAN_model.
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=0.00005)

In [1]:
# One pass over the training batches, reporting the running mean loss
losses = []
for i, batch in enumerate(dataloader):
    x, y = batch

    # forward / backward / update
    optimizer.zero_grad()
    y_hat = model(x)
    loss = criterion(y_hat, y)
    loss.backward()
    optimizer.step()

    losses.append(loss.item())
    if i % 10:
        continue
    # only every 10th iteration reaches this point
    print('iter', i, 'loss', np.mean(losses))
print('done 1 epoch', 'loss:', np.mean(losses))

NameError: name 'dataloader' is not defined

The backward pass works. However, the loss was going up! The loss function in the original model is more complex, so the two need to be checked and compared.

# Initialization

prepare the values. first, embeddings

In [32]:
vocabulary_word2index, vocabulary_index2word = create_vocabulary(word2vec_model_path,name_scope=dataset + "-HAN")

cache_path: ../cache_vocabulary_label_pik/mimic3-ds-50-HAN_word_vocabulary.pik file_exists: True


In [6]:
word_w2v = Word2Vec.load(word2vec_model_path)

In [33]:
len(word_w2v.wv.key_to_index)

150853

In [34]:
vocab_size = len(vocabulary_word2index)
vocab_size

150854

The actual vocabulary has one extra entry, for 'PAD_ID'.

In [41]:
# initialize a tensor with random values 
bound = np.sqrt(6.0) / np.sqrt(vocab_size)
embeddings = torch.rand((vocab_size, embed_size)) * 2 * bound - bound


In [42]:
embeddings.shape

torch.Size([150854, 100])

In [43]:
# Overwrite the random rows with the pre-trained vector wherever one exists
known_words = word_w2v.wv.key_to_index
for i in range(vocab_size):
    word = vocabulary_index2word[i]
    # a vocabulary word might not exist in the word2vec model; keep its random row
    if word not in known_words:
        continue
    embeddings[i] = torch.tensor(word_w2v.wv[word])

next, W_projection for Linear layer (hidden_size*4, num_classes)

In [44]:
vocabulary_word2index_label,vocabulary_index2word_label = create_vocabulary_label_pre_split(training_data_path=training_data_path, validation_data_path=validation_data_path, testing_data_path=testing_data_path, name_scope=dataset + "-HAN") # keep a distinct name scope for each model and each dataset.

In [25]:
label_w2v_400 = Word2Vec.load(word2vec_label_path_400)

In [26]:
len(label_w2v_400.wv.key_to_index)

8686

In [93]:
bound = np.sqrt(6.0) / np.sqrt(num_classes + hidden_size * 4)
W_linear = torch.rand((num_classes, hidden_size*4)) * 2 * bound - bound

In [94]:
vocabulary_index2word_label[0]

'401.9'

In [95]:
# Replace each label's random row with its 400-d label embedding when available
label_vectors_400 = label_w2v_400.wv
for i in range(num_classes):
    label = vocabulary_index2word_label[i]
    if label not in label_vectors_400.key_to_index:
        # report labels that keep their random initialisation
        print(i, label, 'not in vocab')
        continue
    W_linear[i] = torch.tensor(label_vectors_400[label])

lastly, the word level and sentence level context vector

In [52]:
label_w2v_200 = Word2Vec.load(word2vec_label_path_200)

In [51]:
# Uniform random initialisation in [-bound, bound) for the per-label context vectors.
# NOTE(review): `bound` is not defined in this cell; it is whatever an earlier
# cell assigned last (two different formulas appear above, one for the word
# embeddings and one for the projection matrix) — confirm which value is intended here.
word_context_vector = torch.rand((num_classes, hidden_size * 2)) * 2* bound - bound
sentence_context_vector = torch.rand((num_classes, hidden_size * 2)) * 2* bound - bound

In [98]:
# Seed both context vectors of a label with its 200-d label embedding when available
label_vectors_200 = label_w2v_200.wv
for i in range(num_classes):
    label = vocabulary_index2word_label[i]
    if label not in label_vectors_200.key_to_index:
        # report labels that keep their random initialisation
        print(i, label, 'not in vocab')
        continue
    label_vector = torch.tensor(label_vectors_200[label])
    # assignment copies the values into each row, so sharing the source is safe
    word_context_vector[i] = label_vector
    sentence_context_vector[i] = label_vector

initialization

In [89]:
model = HierarchicalAttentionNetwork(vocab_size=vocab_size, embed_size=embed_size, hidden_size=hidden_size, num_sentences=num_sentences, sentence_length=sentence_length, num_classes=num_classes)

In [90]:
for n, p in model.named_parameters():
    print(n, p.shape, p.requires_grad)

W torch.Size([50, 400]) True
b torch.Size([50]) True
embeddings.weight torch.Size([150854, 100]) True
word_gru.weight_ih_l0 torch.Size([300, 100]) True
word_gru.weight_hh_l0 torch.Size([300, 100]) True
word_gru.bias_ih_l0 torch.Size([300]) True
word_gru.bias_hh_l0 torch.Size([300]) True
sentence_gru.weight_ih_l0 torch.Size([600, 200]) True
sentence_gru.weight_hh_l0 torch.Size([600, 200]) True
sentence_gru.bias_ih_l0 torch.Size([600]) True
sentence_gru.bias_hh_l0 torch.Size([600]) True
word_attention.context_vector torch.Size([50, 200]) True
word_attention.linear.weight torch.Size([200, 200]) True
word_attention.linear.bias torch.Size([200]) True
sentence_attention.context_vector torch.Size([50, 200]) True
sentence_attention.linear.weight torch.Size([200, 400]) True
sentence_attention.linear.bias torch.Size([200]) True


In [97]:
# (model parameter, pre-computed initial value) pairs to copy in place
pretrained_values = (
    (model.embeddings.weight, embeddings),
    (model.word_attention.context_vector, word_context_vector),
    (model.sentence_attention.context_vector, sentence_context_vector),
    (model.W, W_linear),
)
with torch.no_grad():
    for parameter, value in pretrained_values:
        parameter.data.copy_(value)


# Test model loading

In [101]:
model = HierarchicalAttentionNetwork(vocab_size=vocab_size, embed_size=EMBED_SIZE, hidden_size=HIDDEN_SIZE, num_sentences=NUM_SENTENCES, sentence_length=SENTENCE_LENGTH, num_classes=NUM_CLASSES)

In [102]:
# map_location="cpu" lets the checkpoint load on machines that lack the
# device it was saved from (e.g. a GPU-trained checkpoint on a laptop).
# NOTE: torch.load unpickles the file, so only load checkpoints from a trusted source.
model.load_state_dict(
    torch.load("./checkpoints/HLAN_LE_20epochs.pt", map_location="cpu")
)

<All keys matched successfully>

In [103]:
model.eval()

HierarchicalAttentionNetwork(
  (embeddings): Embedding(150854, 100)
  (word_gru): GRU(100, 100)
  (sentence_gru): GRU(200, 200)
  (word_attention): AttentionPerLabelWordLevel(
    (linear): Linear(in_features=200, out_features=200, bias=True)
  )
  (sentence_attention): AttentionPerLabelSentenceLevel(
    (linear): Linear(in_features=400, out_features=200, bias=True)
  )
)