In [None]:
import torch
import torch.nn as nn
from tqdm.auto import tqdm
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import TensorDataset,DataLoader
import numpy as np
# Legacy tensor-type alias; none of the visible cells read it, kept so any
# later cell that still refers to `dtype` keeps working.
dtype=torch.FloatTensor

# Fix the RNG seed so weight initialisation and DataLoader shuffling are
# reproducible under Restart Kernel -> Run All.
SEED=42
torch.manual_seed(SEED)

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device='cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
# Toy corpus: every sentence's leading words are the model input and its last
# word is the prediction target (see make_data).
sentences = ["i like dog", "i love coffee", "i hate milk"]
sentences_list=" ".join(sentences).split()
# De-duplicate while keeping first-seen order. A plain set() would make the
# word -> index mapping depend on string-hash randomisation, so the indices
# (and therefore any saved weights or logged runs) would differ between
# kernel sessions.
vocab=list(dict.fromkeys(sentences_list))
word2idx={w:i for i,w in enumerate(vocab)}
idx2word={i:w for w,i in word2idx.items()}
n_vocab=len(vocab)

def make_data(sentences):
    """Convert sentences into one-hot inputs and integer targets.

    Each sentence is split on whitespace. All words but the last become the
    input sequence (one one-hot row per word); the last word's vocabulary
    index becomes the target.

    Uses the notebook-level ``word2idx`` and ``n_vocab``.

    Args:
        sentences: iterable of whitespace-separated strings.

    Returns:
        (x, y) where ``x`` is a list of float arrays of shape
        (words_in_sentence - 1, n_vocab) and ``y`` is a list of int indices.

    Raises:
        KeyError: a word is missing from ``word2idx``.
        IndexError: a sentence contains no words.
    """
    # Row i of the identity matrix is the one-hot vector for word index i.
    # Built once instead of once per sentence; fancy indexing below copies
    # the selected rows, so samples never share memory with this matrix.
    one_hot=np.eye(n_vocab)
    x,y=[],[]
    for sen in sentences:
        idx=[word2idx[w] for w in sen.split()]
        x.append(one_hot[idx[:-1]])
        y.append(idx[-1])
    return x,y
x,y=make_data(sentences)
# Stack the per-sentence arrays into a single ndarray before handing them to
# torch: building a tensor straight from a list of ndarrays goes through a
# slow element-wise path and triggers a UserWarning.
x=torch.tensor(np.array(x),dtype=torch.float32)  # (num_sentences, words - 1, n_vocab)
y=torch.tensor(y,dtype=torch.long)               # (num_sentences,)
dataset=TensorDataset(x,y)
loader=DataLoader(dataset,batch_size=2,shuffle=True)


In [None]:
# Model hyper-parameters.
# hidden_size: width of the RNN hidden state (read by TextRNN and when building h0).
# n_step: not referenced by the visible cells; presumably the number of input
#         words per sample -- TODO confirm before relying on it.
hidden_size,n_step=10,2
class TextRNN(nn.Module):
    """Stacked ``nn.RNN`` followed by a linear layer that scores the next word.

    Args:
        vocab_size: width of the one-hot inputs and number of output classes.
            Defaults to the notebook-level ``n_vocab``.
        hidden_dim: width of the RNN hidden state. Defaults to the
            notebook-level ``hidden_size``.
        num_layers: number of stacked RNN layers (default 2).
    """
    def __init__(self,vocab_size=None,hidden_dim=None,num_layers=2):
        super().__init__()
        # Fall back to the notebook globals so the original ``TextRNN()`` call
        # keeps building exactly the same network.
        if vocab_size is None:
            vocab_size=n_vocab
        if hidden_dim is None:
            hidden_dim=hidden_size
        self.rnn=nn.RNN(input_size=vocab_size,hidden_size=hidden_dim,num_layers=num_layers)
        self.fc=nn.Linear(hidden_dim,vocab_size)
    def forward(self,hidden,x):
        """Return next-word logits for a batch of one-hot sequences.

        Args:
            hidden: initial hidden state of shape
                (num_layers, batch, hidden_dim), or ``None`` to let ``nn.RNN``
                start from zeros.
            x: inputs of shape (batch, seq_len, vocab_size).

        Returns:
            Tensor of shape (batch, vocab_size) with unnormalised scores.
        """
        # nn.RNN is built without batch_first, so it expects (seq_len, batch, features).
        x=x.transpose(0,1)
        if hidden is not None:
            # Keep the initial state on the same device as the inputs.
            hidden=hidden.to(x.device)
        out,_=self.rnn(x,hidden)
        # Only the top layer's output at the last time step feeds the classifier.
        return self.fc(out[-1])

In [None]:
model=TextRNN().to(device)
criterion=nn.CrossEntropyLoss()
optimizer=torch.optim.AdamW(model.parameters(),lr=1e-3)

writer=SummaryWriter()
step=0
n_epochs=100

model.train()
# One progress bar over epochs; an epoch is only a couple of batches, so a
# bar per epoch would just flood the cell output.
for epoch in tqdm(range(n_epochs),desc='epoch'):
    loss_record=[]

    # Distinct batch names so the notebook-level x / y tensors are not overwritten.
    for batch_x,batch_y in loader:
        batch_x,batch_y=batch_x.to(device),batch_y.to(device)
        # Zero initial hidden state, created on the same device as the model
        # and the batch (a CPU tensor here fails when training on CUDA).
        # This could also be built inside TextRNN.forward instead.
        h0=torch.zeros(model.rnn.num_layers,batch_x.shape[0],model.rnn.hidden_size,device=device)
        pred=model(h0,batch_x)
        loss=criterion(pred,batch_y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        step+=1
        # Store a plain float: keeping the loss tensor would hold on to every
        # batch's autograd graph until the epoch ends.
        loss_record.append(loss.item())
    mean_train_loss=sum(loss_record)/len(loss_record)
    writer.add_scalar('trainloss',mean_train_loss,step)

# Flush pending events and release the event file.
writer.close()


In [None]:
x,y=make_data(sentences)
# Stack first (avoids the slow list-of-ndarrays path), then move the inputs to
# the device the model lives on.
x=torch.tensor(np.array(x),dtype=torch.float32).to(device)
y=torch.tensor(y,dtype=torch.long)

# Initial hidden state: (num_layers * num_directions, batch_size, hidden_size).
h0=torch.zeros(model.rnn.num_layers,len(x),model.rnn.hidden_size,device=device)

model.eval()
with torch.no_grad():  # inference only: no autograd graph needed
    logits=model(h0,x).cpu()  # single forward pass, reused for both prints
print(logits)
predict=logits.argmax(dim=1)  # index of the highest-scoring word per sentence
print([n.item() for n in predict])
# Show the same context words make_data fed to the model (everything but the last word).
print([sen.split()[:-1] for sen in sentences], '->',[idx2word[n.item()] for n in predict])

