In [42]:
with open('story.txt','r') as f:
    text = f.read()

In [43]:
char2index = {}
index2char = {}

In [44]:
char2index['#'] = 0
index2char[0] = '#'

In [45]:
all_chars = set(text)
for i,c in enumerate(all_chars):
    char2index[c] = i+1
    index2char[i+1] = c

In [46]:
vocab_size = len(char2index)
hidden_dim = 60

In [47]:
import torch
import torch.nn as nn
import torch.optim as optim

In [48]:
class MLP(nn.Module):
    
    def __init__(self,vocab_size,embed_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size,embed_dim)
        self.seq = nn.Sequential(
            nn.Linear(embed_dim,embed_dim),
            nn.ReLU(),
            nn.Linear(embed_dim,embed_dim),
            nn.ReLU(),
            nn.Linear(embed_dim,vocab_size),
        )
        self.sigmoid = nn.Sigmoid()
        
    def forward(self,seq):
        # seq batch seq_len
        embedding = self.embedding(seq) # batch seq_len embedding_dim
        last = embedding[:,-1,:] # batch 1 embedding_dim
        last = last.squeeze(1) # batch embedding_dim
        logits = self.seq(last) # batch vocab_size
        probs = self.sigmoid(logits) # batch vocab_size
        return probs
        

In [49]:
context_len = 30

In [50]:
all_x = []
all_y = []
x = [0]*context_len
for c in text:
    index = char2index[c]
    y = index
    all_x.append(x)
    all_y.append(y)
    x = x[1:]+[index]
all_x = torch.tensor(all_x)
all_y = torch.tensor(all_y)

In [51]:
mlp = MLP(vocab_size,hidden_dim)

In [52]:
op = optim.Adam(mlp.parameters(),lr=0.0001)

In [53]:
criterion = nn.CrossEntropyLoss()

In [54]:
def get_data(n):
    idxs = torch.randint(0,all_x.shape[0],(n,))
    xx = all_x[idxs]
    yy = all_y[idxs]
    return xx,yy

In [55]:
get_data(10)

(tensor([[ 60,  81,  98,   3,  71,  21,  56,  95,   7, 116,  94, 102,  17,  69,
           29,  85,  45,  11, 119,   7,  23, 118,  17,  69,  82,  14, 118,  56,
           50,  83],
         [ 56,  10, 100,  81, 116,  94, 101,  55, 106,  20, 102,  51,  60,  81,
           98,   3,  71,  21,  56,  95,   7, 116,  94, 102,  17,  69,  29,  85,
           45,  11],
         [  9,  75,  95,   5,  46, 131,  70, 112, 122,  81,  79,  82,  59,  74,
           36, 102,  97,  33,  20,  81,  93,  12,  84, 121, 105,  25,  11, 117,
           45,  19],
         [  7, 116,  94,  51,  60, 108,  90,  59,  81,   6,   1,  64, 124,  23,
          113, 130,  40, 102,  67,  66,  81,  43, 114,  17,  49,  77, 103,  73,
           81,  97],
         [ 54,  58,  22,  11, 109, 116,  94,  93,  12,  84, 121, 102, 110,  72,
           45,  41,  27,  30,  14, 118,  48,  63,  96,  42,  81,  57, 104,  75,
           39,  24],
         [ 62,  81,  35,   8,  84, 121, 102, 130,  40,  67,  66,   7, 116,  94,
           51, 

In [56]:
batch_n = 10

In [62]:
for i in range(100000):
    xx,yy = get_data(batch_n)
    probs = mlp(xx)
    loss =  criterion(probs,yy)
    if i % 1000 == 0:
        print(loss)
    op.zero_grad()
    loss.backward()
    op.step()

tensor(3.9410, grad_fn=<NllLossBackward0>)
tensor(3.9542, grad_fn=<NllLossBackward0>)
tensor(3.9147, grad_fn=<NllLossBackward0>)
tensor(3.9255, grad_fn=<NllLossBackward0>)
tensor(3.9135, grad_fn=<NllLossBackward0>)
tensor(3.9292, grad_fn=<NllLossBackward0>)
tensor(3.9135, grad_fn=<NllLossBackward0>)
tensor(3.9304, grad_fn=<NllLossBackward0>)
tensor(3.9122, grad_fn=<NllLossBackward0>)
tensor(3.9216, grad_fn=<NllLossBackward0>)
tensor(3.9446, grad_fn=<NllLossBackward0>)
tensor(3.9360, grad_fn=<NllLossBackward0>)
tensor(3.9160, grad_fn=<NllLossBackward0>)
tensor(3.9268, grad_fn=<NllLossBackward0>)
tensor(3.9242, grad_fn=<NllLossBackward0>)
tensor(3.9614, grad_fn=<NllLossBackward0>)
tensor(3.9378, grad_fn=<NllLossBackward0>)
tensor(3.9481, grad_fn=<NllLossBackward0>)
tensor(3.9135, grad_fn=<NllLossBackward0>)
tensor(3.9254, grad_fn=<NllLossBackward0>)
tensor(3.9211, grad_fn=<NllLossBackward0>)
tensor(3.9210, grad_fn=<NllLossBackward0>)
tensor(3.9210, grad_fn=<NllLossBackward0>)
tensor(3.93

In [63]:
start = '你'
index = char2index[start]
xx = [index]
for _ in range(100):
    input_x = torch.tensor(xx)
    probs = mlp(input_x.view(1,-1)).view(-1) # vocabsize
    choice = torch.multinomial(probs,1)
    xx = xx + [choice.item()]

In [64]:
print(''.join([index2char[index] for index in xx]))

你想的事半的知识在的第一篇关的的文的路的必要的阶的路线的角度的阶的阶的，就踏的阶的知的，就很容易形的必要知识的角的知的的时的，详细写的的角的那的那么的必的阶的那么的角度，因的的时的时候有你想象的时的第的
