In [5]:
import math
import torch
from torch import nn
from torch.nn import functional as F
import collections
import random
import matplotlib.pyplot as plt
import re
%matplotlib auto

Using matplotlib backend: Qt5Agg


In [6]:
file_path=r'F:\study\ml\LM\8\timemachine.txt'
def read_time_machine():
    with open(file_path) as f:
        lines=f.readlines()
    return [re.sub('[^A-Za-z]+',' ',line).strip().lower() for line in lines]

In [7]:
def tokenize(lines,token='word'):
    if token=='word':
        return [line.split() for line in lines]
    elif token=='char':
        return [list(line) for line in lines]
    else:
        print("error : unknown token type : ",token)

In [8]:
class Vocab:
    def __init__(self,tokens=None,min_freq=0,reserved_tokens=None):
        if tokens is None:
            tokens=[]
        if reserved_tokens is None:
            reserved_tokens=[]
        counter=count_corpus(tokens)
        self._token_freqs=sorted(counter.items(),key=lambda x:x[1],reverse=True)
        self.idx_to_token=['<unk>']+reserved_tokens
        self.token_to_idx={token:idx for idx,token in enumerate(self.idx_to_token)}
        self.idx_to_token,self.token_to_idx=[],dict()
        for token,freq in self._token_freqs:
            if freq<min_freq:
                break;
            if token not in self.token_to_idx:
                self.idx_to_token.append(token)
                self.token_to_idx[token]=len(self.idx_to_token)-1
    
    def __len__(self):
        return len(self.idx_to_token)
    
    def __getitem__(self,tokens):
        if not isinstance(tokens,(list,tuple)):
            return self.token_to_idx.get(tokens,self.unk)
        return [self.__getitem__(token) for token in tokens]
    
    def to_tokens(self,indices):
        if not isinstance(indices,(list,tuple)):
            return self.idx_to_token[indices]
        return [self.idx_to_token[index] for index in indices]
    
    @property
    def unk(self):
        return 0;
    
    @property
    def token_freqs(self):
        return self._token_freqs;
    

In [9]:
def count_corpus(tokens):
    if len(tokens)==0 or isinstance(tokens[0],list):
        tokens=[token for line in tokens for token in line]
    return collections.Counter(tokens)

In [10]:
tokens=tokenize(read_time_machine())
corpus=[token for line in tokens for token in line]
vocab=Vocab(corpus)

In [11]:
def load_corpus_time_machine(max_tokens=-1):
    lines=read_time_machine()
    tokens=tokenize(lines,'char')
    vocab=Vocab(tokens)
    corpus=[vocab[token] for line in tokens for token in line]
    if max_tokens>0:
        corpus=corpus[:max_tokens]
    return corpus,vocab

In [12]:
corpus,vocab=load_corpus_time_machine()
len(corpus),len(vocab)

(170580, 27)

In [13]:
def seq_data_iter_random(corpus,batch_size,num_steps):
    corpus=corpus[random.randint(0,num_step-1):]
    num_subseqs=(len(corpus)-1)//num_steps
    initial_indices=list(range(0,num_subseqs*num_steps,num_steps))
    random.shuffle(initial_indices)

    def data(pos):
        return corpus[pos:pos+num_steps]
    
    num_batches=num_subseqs // batch_size
    for i in range(0,batch_size*num_batches,batch_size):
        initial_indices_per_batch=initial_indices[i:i+batch_size]
        X=[data(j) for j in initial_indices_per_batch]
        Y=[data(j+1) for j in initial_indices_per_batch]
        yield torch.tensor(X),torch.tensor(Y)
    

In [14]:
def seq_data_iter_sequential(corpus,batch_size,num_steps):
    offset=random.randint(0,num_steps)
    num_tokens=((len(corpus)-offset -1) //batch_size)*batch_size
    Xs=torch.tensor(corpus[offset:offset+num_tokens])
    Ys=torch.tensor(corpus[offset+1:offset+1+num_tokens])
    Xs,Ys=Xs.reshape(batch_size,-1),Ys.reshape(batch_size,-1)
    num_batches=Xs.shape[1]//num_steps
    for i in range(0,num_steps*num_batches,num_steps):
        X=Xs[:,i:i+num_steps]
        Y=Ys[:,i:i+num_steps]
        yield X,Y

In [15]:
my_seq = list(range(35))
for X, Y in seq_data_iter_sequential(my_seq, batch_size=2, num_steps=5):
    print('X: ', X, '\nY:', Y)

X:  tensor([[ 2,  3,  4,  5,  6],
        [18, 19, 20, 21, 22]]) 
Y: tensor([[ 3,  4,  5,  6,  7],
        [19, 20, 21, 22, 23]])
X:  tensor([[ 7,  8,  9, 10, 11],
        [23, 24, 25, 26, 27]]) 
Y: tensor([[ 8,  9, 10, 11, 12],
        [24, 25, 26, 27, 28]])
X:  tensor([[12, 13, 14, 15, 16],
        [28, 29, 30, 31, 32]]) 
Y: tensor([[13, 14, 15, 16, 17],
        [29, 30, 31, 32, 33]])


In [16]:
class SeqDataLoader:
    def __init__(self,batch_size,num_steps,use_random_iter,max_tokens):
        if use_random_iter:
            self.data_iter_fn=seq_data_iter_random
        else:
            self.data_iter_fn=seq_data_iter_sequential
        self.corpus,self.vocab=load_corpus_time_machine(max_tokens)
        self.batch_size,self.num_steps=batch_size,num_steps
        
    def __iter__(self):
        return self.data_iter_fn(self.corpus,self.batch_size,self.num_steps)

In [17]:
def load_data_time_machine(batch_size,num_steps,use_random_iter=False,max_tokens=10000):
    data_iter=SeqDataLoader(batch_size,num_steps,use_random_iter,max_tokens)
    return data_iter,data_iter.vocab

### 8.5

In [18]:
batch_size,num_steps=32,35
train_iter,vocab=load_data_time_machine(batch_size,num_steps)

In [19]:
X=torch.arange(10).reshape((2,5))
F.one_hot(X.T,28).shape

torch.Size([5, 2, 28])

In [20]:
def get_params(vocab_size,num_hiddens,device):
    num_inputs=num_outputs=vocab_size
    
    def normal(shape):
        return torch.randn(size=shape,device=device)*0.01
#     hidden layer
    W_xh=normal((num_inputs,num_hiddens))
    W_hh=normal((num_hiddens,num_hiddens))
    b_h=torch.zeros(num_hiddens,device=device)
#     output layer
    W_hq=normal((num_hiddens,num_outputs))
    b_q=torch.zeros(num_outputs,device=device)
    
#     add grad
    params=[W_xh,W_hh,b_h,W_hq,b_q]
    for param in params:
        param.requires_grad_(True)
    return params
    
    

In [21]:
def init_rnn_state(batch_size,num_hiddens,device):
    return (torch.zeros((batch_size,num_hiddens),device=device),)

In [22]:
def rnn(inputs,state,params):
    W_xh,W_hh,b_h,W_hq,b_q=params
    H,=state
    outputs=[]
    for X in inputs:
        H=torch.tanh(torch.mm(X,W_xh)+torch.mm(H,W_hh)+b_h)
        Y=torch.mm(H,W_hq)+b_q
        outputs.append(Y)
    return torch.cat(outputs,dim=0),(H,)

In [23]:
class RNNModuelScatch:
    def __init__(self,vocab_size,num_hiddens,device,get_params,init_state,forward_fn):
        self.vocab_size,self.num_hiddens=vocab_size,num_hiddens
        self.params=get_params(vocab_size,num_hiddens,device)
        self.init_state,self.forward_fn=init_state,forward_fn
    
    def __call__(self,X,state):
        X=F.one_hot(X.T,self.vocab_size).type(torch.float32)
        return self.forward_fn(X,state,self.params)
    
    def begin_state(self,batch_size,device):
        return self.init_state(batch_size,self.num_hiddens,device)
    

In [25]:
num_hiddens=512
net=RNNModuelScatch(len(vocab),num_hiddens,'cpu',get_params,init_rnn_state,rnn)
state=net.begin_state(X.shape[0],'cpu')
Y,new_state=net(X.to('cpu'),state)
Y.shape,len(new_state),new_state[0].shape

(torch.Size([10, 27]), 1, torch.Size([2, 512]))

In [36]:
def predict_ch8(prefix,num_preds,net,vocab,device):
    state=net.begin_state(batch_size=1,device=device)
    outputs=[vocab[prefix[0]]]
    get_input=lambda : torch.tensor([outputs[-1]],device=device).reshape((1,1))
    for y in prefix[1:]:
        _,state=net(get_input(),state)
        outputs.append(vocab[y])
    for _ in range(num_preds):
        y,state=net(get_input(),state)
        outputs.append(int(y.argmax(dim=1).reshape(1)))
    return ''.join([vocab.idx_to_token[i] for i in outputs])

In [37]:
torch.tensor([1], device='cpu').reshape((1, 1))

tensor([[1]])

In [38]:
a=lambda: torch.tensor([1], device='cpu').reshape((1, 1))

In [39]:
a()

tensor([[1]])

In [40]:
predict_ch8('time traveller ',10,net,vocab,'cpu')

'time traveller ya gzgzgzg'

In [42]:
def grad_clipping(net,theta):
    if isinstance(net,nn.Module):
        params=[p for p in net.parameters() if p.requires_grad]
    else:
        params=net.params
    norm=torch.sqrt(sum(torch.sum((p.grad**2)) for p in params))
    if norm>theta:
        for param in params:
            param.grad[:]*=theta /norm
    

In [44]:
def train_epoch_ch8(net,train_iter,loss,updater,device,use_random_iter):
    state=None
    tr_l,tr_num=[],[]
    for X,Y in train_iter:
        if state is None or use_random_iter:
            state=net.begin_state(batch_size=X.shape[0],device=device)
        else:
            if isinstance(net,nn.Module) and not isinstance(state,tuple):
                state.detach_()
            else:
                for s in state:
                    s.detach_()
        y=Y.T.reshape(-1)
        X,y=X.to(device),y.to(device)
        y_hat,state=net(X,state)
        l=loss(y_hat,y.long()).mean()
        if isinstance(updater,torch.optim.Optimizer):
            updater.zero_grad()
            l.backward()
            grad_clipping(net,l)
            updater.step()
        else:
            l.backward()
            grad_clipping(net,l)
            updater(batch_size=1)
        tr_l.append(l*y.numel())
        tr_num.append(y.numel())
    return math.exp(sum(tr_l)/sum(tr_num))
        
            
        

In [45]:
def sgd(params,lr,batch_size):
    with torch.no_grad():
        for p in params:
            p.data -=lr*p.grad.data/batch_size
            p.grad.zero_()
    

In [46]:
def train_ch8(net,train_iter,vocab,lr,num_epochs,device,use_random_iter=False):
    loss=nn.CrossEntropyLoss()
    if isinstance(net,nn.Module):
        updater=torch.optim.SGD(net.parameters(),lr)
    else:
        updater=lambda batch_size:sgd(net.parameters(),lr)
    predict=lambda prefix:predict_ch8(prefix,50,net,vocab,device)
    
    for epoch in range(num_epochs):
        ppl=train_epoch_ch8(net,train_iter,loss,updater,device,use_random_iter)
        
        if(epoch+1) %10==0:
            print(predict("time traveller"))
    print(f'困惑度 {ppl:.1f}')
    print(predict('time traveller'))
    print(predict('traveller'))
    

In [48]:
num_epochs, lr = 500, 1
train_ch8(net, train_iter, vocab, lr, num_epochs, 'cpu')

AttributeError: 'RNNModuelScatch' object has no attribute 'parameters'