In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
from torch import nn
import numpy as np
from tqdm import tqdm
from torch import Tensor

In [4]:
from ner_ehr.models import LSTMNERTagger
from ner_ehr.losses import cross_entropy

In [5]:
batch_size = 2
embedding_dim = 5
vocab_size = 7
hidden_size = 9
num_classes = 13
seq_length = 11
embedding_weights = torch.randn((vocab_size, embedding_dim))
bidirectional = True

lstm = LSTMNERTagger(
    embedding_dim=embedding_dim, 
    vocab_size=vocab_size,
    hidden_size=hidden_size, 
    num_classes=num_classes, 
    bidirectional=bidirectional,
    embedding_weights=None)

In [6]:
rng = np.random.default_rng(42)
X = rng.choice(vocab_size, size=(batch_size, seq_length))

y = rng.choice(num_classes, size=vocab_size, p = [.88,] + [.01]*12)
Y = [[y[i] for i in x] for x in X]

X = torch.tensor(X)
Y = torch.tensor(Y)

print(X.shape, Y.shape)

Y_hat = lstm(X) # output of lstm cells, final hidden state, final cell state
Y_hat.shape

torch.Size([2, 11]) torch.Size([2, 11])


torch.Size([2, 11, 13])

In [7]:
adam = torch.optim.Adam(lstm.parameters(), lr=.001)
t = tqdm(range(1000), leave=False, position=0)
for i in t:
    Y_hat = lstm(X)
    loss = cross_entropy(Y_hat=Y_hat, Y=Y)
    adam.zero_grad()
    loss.backward()
    adam.step()
    t.set_description(f"loss: {loss.item()}")

                                                                               

In [8]:
lstm.embed.state_dict()

OrderedDict([('weight',
              tensor([[ 0.2157,  1.4937,  1.4773, -2.8541,  0.6943],
                      [-0.8613,  0.7176,  0.1704, -0.0129, -2.0453],
                      [ 1.7008, -1.5388, -1.2685, -0.5501, -0.6152],
                      [ 0.1853, -0.4560, -0.8930, -1.1458,  0.5767],
                      [-0.9381, -0.8464,  0.0548,  0.1583, -0.5692],
                      [ 1.2096, -0.9094,  0.9693,  0.1707, -2.7603],
                      [ 0.0352, -0.0475, -0.1072, -0.0768, -1.1837]]))])

In [9]:
with torch.no_grad():
    lstm.eval()
    Y_hat = lstm(X)

In [10]:
torch.argmax(nn.functional.softmax(Y_hat, dim=-1), dim=-1)

tensor([[5, 0, 0, 0, 0, 0, 5, 0, 0, 5, 0],
        [0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0]])

In [11]:
Y

tensor([[5, 0, 0, 0, 0, 0, 5, 0, 0, 5, 0],
        [0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0]])

In [18]:
(Y==torch.argmax(nn.functional.softmax(Y_hat, dim=-1), dim=-1)).sum()

tensor(249)