In [1]:
import numpy as np
import torch
import torch.nn as nn

d_k = 16 # Q K 的维度
d_v = 16 # V 维度
d_embedding = 128
n_heads = 3
batch_size = 16
n_channel = 256
n_layers = 6

class ScaleDotAttention(nn.Module):
    def __init__(self):
        super(ScaleDotAttention, self).__init__()
    def forward(self,Q=None,K=None,V=None,atten_mask=None):
        scores = torch.matmul(Q,K.transpose(-1,-2))/np.sqrt(d_k)
        scores.masked_fill_(atten_mask, -np.inf)
        weights = nn.Softmax(dim=-1)(scores)
        context = torch.matmul(weights, V)
        return context, weights

class MultiHeadAttention(nn.Module):
    def __init__(self):
        super(MultiHeadAttention,self).__init__()
        self.W_Q = nn.Linear(d_embedding, d_k*n_heads)
        self.W_K = nn.Linear(d_embedding, d_k*n_heads)
        self.W_V = nn.Linear(d_embedding, d_v*n_heads)
        self.linear = nn.Linear(n_heads*d_v, d_embedding)
        self.scaledot = ScaleDotAttention()
        self.layer_norm = nn.LayerNorm(d_embedding)
    def forward(self, Q=None,K=None,V=None,atten_mask=None):
        residual = Q
        batch_size = Q.size(0)
        q = self.W_Q(Q).view(batch_size,-1,n_heads, d_k).transpose(1,2)
        k = self.W_K(K).view(batch_size,-1,n_heads, d_k).transpose(1,2)
        v = self.W_V(V).view(batch_size,-1,n_heads, d_v).transpose(1,2)

        atten_mask = atten_mask.unsqueeze(1).repeat(1,n_heads,1,1)

        context,weights = self.scaledot(Q=q,K=k,V=v,atten_mask=atten_mask)

        context = context.transpose(1,2).contiguous().view(batch_size,-1,n_heads*d_v)

        output = self.linear(context)

        output = self.layer_norm(output+residual)

        return output, weights

class FeedForward(nn.Module):
    def __init__(self):
        super(FeedForward,self).__init__()
        self.conv1 = nn.Conv1d(in_channels=d_embedding,out_channels=n_channel,kernel_size=1)
        self.conv2 = nn.Conv1d(in_channels=n_channel,out_channels=d_embedding,kernel_size=1)
        self.layer_norm = nn.LayerNorm(d_embedding)
    def forward(self, inputs):
        residual = inputs
        output = nn.ReLU()(self.conv1(inputs.transpose(1,2)))
        output = self.conv2(output).transpose(1,2)
        output = self.layer_norm(output+residual)
        return output

def get_sin_cos_pos(n_pos,embedding_dim):
    sincos = np.zeros((n_pos,embedding_dim))
    for i in range(n_pos):
        for j in range(embedding_dim):
            value = i/np.power(10000,2*(j//2)/embedding_dim)
            sincos[i,j] = value
    sincos[:,0::2] = np.sin(sincos[:,0::2])
    sincos[:,1::2] = np.cos(sincos[:,1::2])
    return torch.FloatTensor(sincos)

def get_atten_mask(seq_q,seq_k):
    batch_size = seq_q.size(0)
    len_q = seq_q.size(1)
    len_k = seq_k.size(1)
    pad_atten_mask = seq_k.data.eq(0).unsqueeze(1)
    pad_atten_mask = pad_atten_mask.expand(batch_size,len_q,len_k)
    return pad_atten_mask

def get_Mask_atten(seq):
    atten_shape = [seq.size(0),seq.size(1),seq.size(1)]
    maskatten = np.triu(np.ones(atten_shape),k=1)
    maskatten = torch.from_numpy(maskatten).byte()
    return maskatten

class DecoderLayer(nn.Module):
    def __init__(self):
        super(DecoderLayer,self).__init__()
        self.dec_self_atten = MultiHeadAttention()
        self.feedforward = FeedForward()
        self.norm1 = nn.LayerNorm(d_embedding)
        self.norm2 = nn.LayerNorm(d_embedding)
    def forward(self,dec_inputs,dec_self_atten_mask):
        dec_outputs,dec_self_atten = self.dec_self_atten(dec_inputs,dec_inputs,dec_inputs,dec_self_atten_mask)

        norm1 = self.norm1(dec_inputs+dec_outputs)
        feed = self.feedforward(norm1)

        dec_outputs = self.norm2(norm1+feed)
        return dec_outputs

class Decoder(nn.Module):
    def __init__(self,vocab_size,max_lenth):
        super(Decoder,self).__init__()
        self.src_emb = nn.Embedding(vocab_size,d_embedding)
        self.pos_emb = nn.Embedding(vocab_size,d_embedding)
        self.layers = nn.ModuleList(DecoderLayer() for i in range(n_layers))
    def forward(self,dec_inputs):
        pos_index = torch.arange(len(dec_inputs),device=dec_inputs.device).unsqueeze(-1)
        inputs_emb = self.src_emb(dec_inputs)+self.pos_emb(pos_index)
        dec_self_atten_mask = get_Mask_atten(inputs_emb).to(device)
        for layer in self.layers:
            dec_outputs = layer(inputs_emb,dec_self_atten_mask)
            inputs_emb = dec_outputs
            dec_self_atten_mask = get_Mask_atten(inputs_emb).to(device)
        return dec_outputs

class GPT(nn.Module):
    def __init__(self,corpus):
        super(GPT,self).__init__()
        self.corpos = corpus
        self.decoder = Decoder(corpus.vocab_size,corpus.seq_len)
        self.linear = nn.Linear(d_embedding,corpus.vocab_size)
    def forward(self,dec_inputs):
        dec_outputs = self.decoder(dec_inputs)
        logit = self.linear(dec_outputs)
        return logit

In [2]:
def read_data(file_path):
    with open(file_path) as f:
        data = f.readlines()
    return [i.replace('\n','').strip(' ').lower() for i in data if i.replace('\n','')!=" " and i.replace('\n','')!='']

In [3]:
from collections import Counter
import transformers as tfs
from tqdm import tqdm

In [4]:
class bertCorpus():
    def __init__(self, sentences, max_len = 100):
        self.tokenizer = tfs.BertTokenizer.from_pretrained('bert-base-uncased')
        self.sentences = sentences
        self.seq_len = max_len
        self.vocab = self.tokenizer.get_vocab()
        self.vocab_size = len(self.vocab)
        self.id2word = self.tokenizer.ids_to_tokens
    
    def make_batch(self):
        input_batch,target_batch = [],[]
#         sentence_index = torch.randperm(len(self.sentences))[:batch_size]
        for index in tqdm(range(len(self.sentences))):
            sentence = self.sentences[index]
            if len(sentence.split(' ')) > 15:
                seq = self.tokenizer(sentence, return_tensors='pt',max_length=self.seq_len, truncation=True, padding='max_length')['input_ids'][0].tolist()

                input_batch.append(seq[:-1])
                target_batch.append(seq[1:])
                
        input_batch = torch.LongTensor(input_batch)
        target_batch = torch.LongTensor(target_batch)
        return input_batch,target_batch
    
class chinesebertCorpus():
    def __init__(self, sentences, max_len = 50):
        self.tokenizer = tfs.BertTokenizer.from_pretrained('bert-base-chinese')
        self.sentences = sentences
        self.seq_len = max_len
        self.vocab = self.tokenizer.get_vocab()
        self.vocab_size = len(self.vocab)
        self.id2word = self.tokenizer.ids_to_tokens
    
    def make_batch(self):
        input_batch,target_batch = [],[]
#         sentence_index = torch.randperm(len(self.sentences))[:batch_size]
        for index in tqdm(range(len(self.sentences))):
            sentence = self.sentences[index]
            if len(sentence) > 1:
                seq = self.tokenizer(sentence, return_tensors='pt',max_length=self.seq_len, truncation=True, padding='max_length')['input_ids'][0].tolist()

                input_batch.append(seq[:-1])
                target_batch.append(seq[1:])
                
        input_batch = torch.LongTensor(input_batch)
        target_batch = torch.LongTensor(target_batch)
        return input_batch,target_batch

In [5]:
corpus = bertCorpus(read_data('./wikitext-103/wiki.train.tokens')+read_data('./wikitext-103/wiki.test.tokens')+read_data('./wikitext-103/wiki.valid.tokens')+read_data('./wikitext-103/wiki.train.txt'))

In [6]:
aa = corpus.make_batch()

100%|███████████████████████████████| 1389458/1389458 [12:07<00:00, 1909.07it/s]


In [7]:
X_train_batch = []
Y_train_batch = []

In [8]:
bsize = batch_size
for i in range(int(len(aa[0])/bsize)):
    X_train_batch.append(aa[0][i*bsize:(i+1)*bsize])
    Y_train_batch.append(aa[1][i*bsize:(i+1)*bsize])

In [9]:
import torch.optim as optim
from tqdm import tqdm
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = "mps"
model = GPT(corpus).to(device)
criterion = nn.CrossEntropyLoss(ignore_index=corpus.vocab['[PAD]'])
optimizer = optim.Adam(model.parameters(),lr=0.001)
epochs=200
train_loss = []
min_loss = np.inf
accumulate = 8
for epoch in range(epochs):
    total_train_loss = []
    model.train()
    n=0
    for x_train_batch, y_train_batch in tqdm(zip(X_train_batch,Y_train_batch)):
        optimizer.zero_grad()
        dec_inputs,target_batch = x_train_batch,y_train_batch
        dec_inputs,target_batch = dec_inputs.to(device),target_batch.to(device)
        outputs=model(dec_inputs)
        loss = criterion(outputs.view(-1,corpus.vocab_size),target_batch.view(-1))
        loss = loss/accumulate
        loss.backward()
        if (n+1)%accumulate == 0:
            optimizer.step()
            optimizer.zero_grad()

        if (n+1)%1000==0:
            print("epoch:%d----loss:%.4f"%(epoch+1,loss))
        n+=1

        total_train_loss.append(loss.item())
    train_loss.append(np.mean(total_train_loss ))
    if train_loss[-1] < min_loss:
        min_loss = train_loss[-1]
        torch.save(model, './GPT.pkl')
        print('train_loss:%.4f,best_loss:%.4f----OK'%(train_loss[-1],min_loss))
    else:
        print('train_loss:%.4f,best_loss:%.4f'%(train_loss[-1],min_loss))

1001it [02:01,  8.08it/s]

epoch:1----loss:0.9008


2001it [03:58,  8.14it/s]

epoch:1----loss:0.9041


3001it [05:56,  8.23it/s]

epoch:1----loss:0.9195


4001it [07:54,  8.13it/s]

epoch:1----loss:0.8754


5001it [09:50,  8.24it/s]

epoch:1----loss:0.8982


6001it [11:47,  8.29it/s]

epoch:1----loss:0.8404


7001it [13:45,  8.25it/s]

epoch:1----loss:0.7911


8001it [15:41,  8.30it/s]

epoch:1----loss:0.8672


9001it [17:38,  8.78it/s]

epoch:1----loss:0.8176


10001it [19:36,  8.29it/s]

epoch:1----loss:0.7931


11001it [21:33,  8.23it/s]

epoch:1----loss:0.8053


12001it [23:31,  8.10it/s]

epoch:1----loss:0.7896


13001it [25:29,  8.20it/s]

epoch:1----loss:0.7824


14001it [27:27,  8.13it/s]

epoch:1----loss:0.7850


15001it [29:25,  8.15it/s]

epoch:1----loss:0.7910


16001it [31:23,  8.06it/s]

epoch:1----loss:0.8125


17001it [33:20,  8.27it/s]

epoch:1----loss:0.7662


18001it [35:18,  8.30it/s]

epoch:1----loss:0.7742


19001it [37:15,  8.20it/s]

epoch:1----loss:0.7828


20001it [39:13,  8.11it/s]

epoch:1----loss:0.7762


21001it [41:11,  8.21it/s]

epoch:1----loss:0.7934


22001it [43:08,  8.96it/s]

epoch:1----loss:0.7789


23001it [45:06,  8.18it/s]

epoch:1----loss:0.8259


24001it [47:04,  8.25it/s]

epoch:1----loss:0.7742


25001it [49:01,  8.21it/s]

epoch:1----loss:0.7097


26001it [50:59,  8.17it/s]

epoch:1----loss:0.5860


27001it [52:56,  8.30it/s]

epoch:1----loss:0.7719


28001it [54:54,  8.09it/s]

epoch:1----loss:0.7921


29001it [56:52,  8.96it/s]

epoch:1----loss:0.7330


30001it [58:49,  8.27it/s]

epoch:1----loss:0.6813


31001it [1:00:47,  8.19it/s]

epoch:1----loss:0.7513


32001it [1:02:45,  8.33it/s]

epoch:1----loss:0.7581


33001it [1:04:44,  8.10it/s]

epoch:1----loss:0.7222


34001it [1:06:43,  8.16it/s]

epoch:1----loss:0.7482


35001it [1:08:41,  8.12it/s]

epoch:1----loss:0.7031


36001it [1:10:40,  8.11it/s]

epoch:1----loss:0.7703


37001it [1:12:38,  8.17it/s]

epoch:1----loss:0.7693


37021it [1:12:41,  8.62it/s]

In [28]:
def gen_tanxin(model,input_str,max_len=5):
    model.eval()
    input_tokens = corpus.tokenizer(input_str, return_tensors='pt', truncation=True)['input_ids'][0].tolist()[:-1]
    if len(input_tokens) == 0:
        return 'fvv'
    output_tokens = input_tokens
    with torch.no_grad():
        for _ in range(max_len):
            device = 'mps'
            inputs = torch.LongTensor(output_tokens).unsqueeze(0).to(device)
            outputs = model(inputs)
            
            logits = outputs[:,-1,:]
            
            _,next_token = torch.topk(logits,1,dim=-1)
            
            if next_token.item() == corpus.vocab["[SEP]"]:
                break
            output_tokens.append(next_token.item())
    output_str = " ".join([corpus.id2word[token] for token in output_tokens])
    return output_str

def gen_beam(model,input_str,max_len=5,beam_width=5,repetition_penalty=1.2):
    model.eval()
    input_tokens = corpus.tokenizer(input_str, return_tensors='pt', truncation=True)['input_ids'][0].tolist()[:-1]
    if len(input_tokens) == 0:
        return 'fvv'
    
    candidates = [(input_tokens,0.0)]
    
    final_result = []
    with torch.no_grad():
        for _ in range(max_len):
            new_cands = []
            for cand,cand_score in candidates:
                device = 'mps'
                inputs = torch.LongTensor(cand).unsqueeze(0).to(device)
                outputs = model(inputs)
                logits = outputs[:,-1,:]
                
                for token in set(cand):
                    logits[0,token] /= repetition_penalty
                    
                logits[0,corpus.vocab['[PAD]']] = -np.inf
                
                scores,next_tokens = torch.topk(logits,beam_width,dim=-1)
                
                for score,next_token in zip(scores.squeeze(),next_tokens.squeeze()):
                    new_cand = cand+[next_token.item()]
                    new_score = cand_score-score.item()
                    if next_token.item() == corpus.vocab["[SEP]"]:
                        final_result.append((new_cand,new_score))
                    else:
                        new_cands.append((new_cand,new_score))
        
            candidates = sorted(new_cands,key=lambda x:x[1],reverse=True)[:beam_width]
            
    if final_result:
        best_cand,_ = sorted(final_result,key=lambda x:x[1],reverse=True)[0]
    else:
        best_cand,_ = sorted(candidates,key=lambda x:x[1],reverse=True)[0]
    output_str = " ".join([corpus.id2word[token] for token in best_cand])
    return output_str    

In [29]:
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = torch.load('./GPT.pkl')
# model = GPT(corpus).to(device)
# model.load_state_dict(torch.load('./GPT.pkl'))

In [42]:
input_str = "you are the'"

In [43]:
gen_tanxin(model,input_str,max_len=10)

"[CLS] you are the ' s first episode , and < un ##k > ,"

In [44]:
gen_beam(model,input_str,max_len=10,beam_width=3,repetition_penalty=1.2)

'[CLS] you are the \' ll written , " < < [SEP]'