In [2]:
import torch
import numpy as np
import torch.nn as nn
from tqdm import tqdm
from torch.utils.data import TensorDataset,DataLoader
from torch.utils.tensorboard import SummaryWriter
# NOTE(review): `dtype` is not referenced by any visible cell (tensors below are built
# with torch.Tensor / torch.LongTensor directly); it looks like a leftover alias for
# torch's legacy CPU float tensor type — confirm nothing else uses it before removing.
dtype=torch.FloatTensor

In [17]:
# Vocabulary: the 26 lowercase letters, one class per letter.
char_arr = list('abcdefghijklmnopqrstuvwxyz')
idx2word = dict(enumerate(char_arr))                # class index -> letter
word2idx = {ch: k for k, ch in idx2word.items()}    # letter -> class index
n_class = len(word2idx)

# Training words: make_data feeds all but the last letter and targets the last one.
seq_data = ['make', 'need', 'coal', 'word', 'love', 'hate', 'live', 'home', 'hash', 'star']

n_step = len(seq_data[0]) - 1  # prefix length; assumes every word is as long as the first
n_hidden = 128                 # LSTM hidden-state size

def make_data(seq_data, vocab=None):
    """Turn equal-length words into one-hot inputs and next-letter targets.

    Each word contributes one sample. The input is its leading
    ``len(word) - 1`` letters, one-hot encoded. The target is the class
    index of its final letter.

    Args:
        seq_data: iterable of words. All words must have the same length,
            otherwise the inputs cannot be stacked into a single tensor.
        vocab: optional ``{letter: class_index}`` mapping. Defaults to the
            notebook-level ``word2idx`` with ``n_class`` one-hot columns.
            When given, the one-hot width is ``len(vocab)``.

    Returns:
        ``(x, y)``. ``x`` is a float32 tensor of shape
        ``(num_words, len(word) - 1, num_classes)``. ``y`` is an int64
        tensor of shape ``(num_words,)``.

    Raises:
        KeyError: if a word contains a letter missing from the vocabulary.
        ValueError: if the words do not all have the same length.
    """
    if vocab is None:
        vocab, depth = word2idx, n_class
    else:
        depth = len(vocab)
    # Build the identity matrix once; row i is the one-hot vector for class i.
    eye = np.eye(depth, dtype=np.float32)
    inputs, targets = [], []
    for seq in seq_data:
        inputs.append(eye[[vocab[ch] for ch in seq[:-1]]])
        targets.append(vocab[seq[-1]])
    # Stack into one ndarray first: torch.Tensor(list_of_ndarrays) is very slow
    # and emits a UserWarning.
    x = torch.from_numpy(np.array(inputs, dtype=np.float32))
    y = torch.tensor(targets, dtype=torch.long)
    return x, y
# Wrap the one-hot prefixes / next-letter targets for mini-batch training.
x, y = make_data(seq_data)
dataset = TensorDataset(x, y)
# Keyword arguments make the intent of the former positional `3, True` explicit.
# A dedicated, seeded generator makes the shuffle order reproducible across
# Restart & Run All without touching the global torch RNG.
loader = DataLoader(
    dataset,
    batch_size=3,
    shuffle=True,
    generator=torch.Generator().manual_seed(0),
)


In [15]:
class TextLSTM(nn.Module):
    """Single-layer LSTM that predicts the next letter from a one-hot prefix.

    Args:
        vocab_size: number of letter classes. This is both the one-hot input
            width and the number of output logits. Defaults to the
            notebook-level ``n_class``.
        hidden_size: LSTM hidden-state size. Defaults to the notebook-level
            ``n_hidden``.
    """

    def __init__(self, vocab_size=None, hidden_size=None):
        super().__init__()
        if vocab_size is None:
            vocab_size = n_class
        if hidden_size is None:
            hidden_size = n_hidden
        # batch_first=True lets forward() consume (batch, seq, vocab) directly
        # instead of transposing; the parameter set / state_dict is unchanged.
        self.lstm = nn.LSTM(input_size=vocab_size, hidden_size=hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, vocab_size)

    def forward(self, x):
        """Return next-letter logits.

        Args:
            x: float tensor of shape (batch, seq_len, vocab_size), e.g. the
                one-hot inputs produced by make_data.

        Returns:
            Logits of shape (batch, vocab_size), computed from the LSTM output
            at the last time step.
        """
        batch_size = x.shape[0]
        # The initial state is sized from the module's own LSTM, not the
        # global. It is allocated on x's device/dtype so the model still
        # works after .to(device).
        h0 = x.new_zeros(self.lstm.num_layers, batch_size, self.lstm.hidden_size)
        c0 = torch.zeros_like(h0)
        outputs, _ = self.lstm(x, (h0, c0))
        return self.fc(outputs[:, -1])
# Seed the global torch RNG right before building the model so the initial
# weights are reproducible on Restart & Run All.
torch.manual_seed(0)
model = TextLSTM()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

In [19]:
# Train for 50 epochs. Each epoch's sample-weighted mean training loss is
# logged to TensorBoard, with the global step = number of optimizer steps.
step = 0
writer = SummaryWriter()
model.train()
try:
    for epoch in range(50):
        loss_sum, n_seen = 0.0, 0
        # Distinct batch names so the full-data `x`, `y` tensors are not
        # silently overwritten by the last mini-batch.
        for xb, yb in tqdm(loader, desc=f"epoch {epoch}", leave=False):
            pred = model(xb)
            loss = criterion(pred, yb)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # .item() stores a plain float. Keeping the tensor would retain
            # every batch's autograd graph until the epoch ends.
            loss_sum += loss.item() * yb.shape[0]
            n_seen += yb.shape[0]
            step += 1
        if n_seen:
            # CrossEntropyLoss averages within a batch. Re-weighting by batch
            # size keeps a smaller final batch from skewing the epoch mean.
            writer.add_scalar("mean_train_loss", loss_sum / n_seen, step)
finally:
    # Flush pending events to disk and release the event file.
    writer.close()

100%|██████████| 4/4 [00:01<00:00,  2.80it/s]
100%|██████████| 4/4 [00:00<00:00, 100.13it/s]
100%|██████████| 4/4 [00:00<00:00, 90.95it/s]
100%|██████████| 4/4 [00:00<00:00, 91.02it/s]
100%|██████████| 4/4 [00:00<00:00, 91.03it/s]
100%|██████████| 4/4 [00:00<00:00, 100.11it/s]
100%|██████████| 4/4 [00:00<00:00, 100.12it/s]
100%|██████████| 4/4 [00:00<00:00, 91.03it/s]
100%|██████████| 4/4 [00:00<00:00, 77.03it/s]
100%|██████████| 4/4 [00:00<00:00, 58.89it/s]
100%|██████████| 4/4 [00:00<00:00, 83.45it/s]
100%|██████████| 4/4 [00:00<00:00, 91.02it/s]
100%|██████████| 4/4 [00:00<00:00, 100.13it/s]
100%|██████████| 4/4 [00:00<00:00, 100.15it/s]
100%|██████████| 4/4 [00:00<00:00, 100.13it/s]
100%|██████████| 4/4 [00:00<00:00, 83.45it/s]
100%|██████████| 4/4 [00:00<00:00, 100.13it/s]
100%|██████████| 4/4 [00:00<00:00, 66.60it/s]
100%|██████████| 4/4 [00:00<00:00, 71.56it/s]
100%|██████████| 4/4 [00:00<00:00, 100.14it/s]
100%|██████████| 4/4 [00:00<00:00, 100.10it/s]
100%|██████████| 4/4 [00:

In [23]:
# Sanity check on the training words themselves (there is no held-out set):
# show each prefix and the letter the model predicts should follow it.
test = [sen[:-1] for sen in seq_data]  # same prefix slice make_data feeds the model
x, y = make_data(seq_data)
model.eval()
with torch.no_grad():
    y_test = model(x).argmax(dim=1)
print(test, '->', [idx2word[n.item()] for n in y_test])

['mak', 'nee', 'coa', 'wor', 'lov', 'hat', 'liv', 'hom', 'has', 'sta'] -> ['e', 'd', 'l', 'd', 'e', 'e', 'e', 'e', 'e', 'r']
