In [1]:
import torch
import torch.nn as nn
import torch.utils.data as Data
import json
import collections
import re
import time
import random
import math
import sys
import numpy as np
import math
 
sys.setrecursionlimit(5000) 

device = 'cuda' if torch.cuda.is_available() else'cpu'

In [2]:
class My_data(Data.Dataset):
    def __init__(self,max_len=50,min_en_count=0,min_cn_count=0):
        self.max_len = max_len
        self.min_en_count = min_en_count
        self.min_cn_count = min_cn_count
        self.counter = None
        self.cn_itos = ['<SOS>','<EOS>','<UNK>']
        self.cn_stoi = {'<SOS>':0,'<EOS>':1,'<UNK>':2}
        self.cn_data = []
        self.en_itos = ['<SOS>','<EOS>','<UNK>']
        self.en_stoi = {'<SOS>':0,'<EOS>':1,'<UNK>':2}
        self.en_data = []
        self.num_cn_vocab = 3
        self.num_en_vocab = 3
        self.len = 0
        
    def get_raw_data_cn_en(self,file,js=False, divide=1, choose=1):
        all_cn = []
        all_en = []
        if js == False:
            with open(file) as f:
                lines = f.readlines()
                data_len = len(lines)
                k = int(data_len//divide)
                start = max((choose-1)*k,0)
                end = min(choose*k,data_len)
                for line in lines[start:end]:
                    cn,en = line.strip().split('\t')
                    en = en.lower()
                    en = re.sub(r"([,'.!?])", r" \1", en)
                    en = re.sub(r"[^a-zA-Z\u4e00-\u9fa5.,'!?]+", r" ", en)
#                     cn = re.sub(r"([.!?])", r" \1", cn)
                    cn = re.sub(r"[^a-zA-Z\u4e00-\u9fa5.,'‘’“”!?，。？！]+", r" ", cn)
                    en=en.split()
                    all_cn.append(cn)
                    all_en.append(en)
        if js == True:
            with open(file) as f:
                lines = f.readlines()
                data_len = len(lines)
                k = int(data_len//divide)
                start = max((choose-1)*k,0)
                end = min(choose*k,data_len)
                for line in lines[start:end]:
                    st = json.loads(line)
                    cn = st['chinese']
                    en = st['english'].lower()
                    en = re.sub(r"([,'.!?])", r" \1", en)
                    en = re.sub(r"[^a-zA-Z\u4e00-\u9fa5.,'!?]+", r" ", en)
#                     cn = re.sub(r"([.!?])", r" \1", cn)
                    cn = re.sub(r"[^a-zA-Z\u4e00-\u9fa5.,'‘’“”!?，。？！]+", r" ", cn)
                    en=en.split()
                    all_cn.append(cn)
                    all_en.append(en)
        return all_cn,all_en
    
    def get_cn_en_stoi_itos(self,raw_cn,raw_en):
        cnCounter = collections.Counter([tk for line in raw_cn for tk in line])
        enCounter = collections.Counter([tk for line in raw_en for tk in line])
        cnCounter = dict(filter(lambda x: x[1]>=self.min_cn_count,cnCounter.most_common()))
        enCounter = dict(filter(lambda x: x[1]>=self.min_en_count,enCounter.most_common()))
        
        for tk,_ in cnCounter.items():
            self.cn_itos.append(tk)
        for tk,_ in enCounter.items():
            self.en_itos.append(tk)
        
        self.cn_stoi = {tk:idx for idx,tk in enumerate(self.cn_itos)}
        self.en_stoi = {tk:idx for idx,tk in enumerate(self.en_itos)}
        
        self.num_cn_vocab = len(self.cn_stoi)
        self.num_en_vocab = len(self.en_stoi)
        
    def get_from_vocab(self,cn_file,en_file):
        self.cn_itos = []
        self.en_itos = []
        
        with open(cn_file) as f:
            for line in f.readlines():
                self.cn_itos.append(line.replace('\n',''))
        with open(en_file) as f:
            for line in f.readlines():
                self.en_itos.append(line.replace('\n',''))
        
        self.cn_stoi = {tk: idx for idx,tk in enumerate(self.cn_itos)}
        self.en_stoi = {tk: idx for idx,tk in enumerate(self.en_itos)}
        
        self.num_cn_vocab = len(self.cn_stoi)
        self.num_en_vocab = len(self.en_stoi)
                
        
    def get_data(self,raw_cn,raw_en):
        self.cn_data = []
        self.en_data = []
        k = [0]*len(raw_cn)
        for idx,line in enumerate(raw_cn):
            if len(line) > self.max_len:
                k[idx] = 1
                continue
            temp = []
            temp.append(0)
            for tk in line:
                if tk not in self.cn_itos:
                    tk = '<UNK>'
                temp.append(self.cn_stoi[tk])
            temp.append(1)
            self.cn_data.append(temp)
            
        for idx,line in enumerate(raw_en):
            if k[idx] == 1:
                continue
            temp = []
            temp.append(0)
            for tk in line:
                if tk not in self.en_itos:
                    tk = '<UNK>'
                temp.append(self.en_stoi[tk])
            temp.append(1)
            self.en_data.append(temp)
        self.len = len(self.cn_data)
        
    def append_data(self,raw_cn,raw_en):
        k = [0]*len(raw_cn)
        for idx,line in enumerate(raw_cn):
            if len(line) > self.max_len:
                k[idx] = 1
                continue
            temp = []
            temp.append(0)
            for tk in line:
                if tk not in self.cn_itos:
                    tk = '<UNK>'
                temp.append(self.cn_stoi[tk])
            temp.append(1)
            self.cn_data.append(temp)
            
        for idx,line in enumerate(raw_en):
            if k[idx] == 1:
                continue
            temp = []
            temp.append(0)
            for tk in line:
                if tk not in self.en_itos:
                    tk = '<UNK>'
                temp.append(self.en_stoi[tk])
            temp.append(1)
            self.en_data.append(temp)
        self.len = len(self.cn_data)
        
    def do_all(self,file,js=False):
        cn,en = self.get_raw_data_cn_en(file=file,js=js)
        self.get_cn_en_stoi_itos(cn,en)
        self.get_data(cn,en)
        
    
    def get_data_pair(self,idx):
        return self.cn_data[idx],self.en_data[idx]
    
    def __getitem__(self,idx):
        return self.get_data_pair(idx)
    
    def __len__(self):
        return self.len
        

In [3]:
def batchify(data):
    src_max_len = max(len(src) for src,_ in data)
    tgt_max_len = max(len(tgt) for _,tgt in data)
    src, tgt1, tgt2, s_mask, t_mask, loss_mask= [],[],[],[],[],[]
    
    for s, t in data:
        slen = len(s)
        tlen = len(t)
        s1 = s + [0]*(src_max_len-slen)
        t1 = t[:-1] + [0]*(tgt_max_len-tlen)
        t2 = t[1:] + [0]*(tgt_max_len-tlen)
        s1_mask = [0]*slen + [1]*(src_max_len-slen)
        t1_mask = [0]*(tlen-1) + [1]*(tgt_max_len-tlen)
        loss1_mask = [1]*(tlen-1) + [0]*(tgt_max_len-tlen)
        src.append(s1)
        tgt1.append(t1)
        tgt2.append(t2)
        s_mask.append(s1_mask)
        t_mask.append(t1_mask)
        loss_mask.append(loss1_mask)
    
    return (torch.tensor(src).long(),
            torch.tensor(tgt1).long(),
            torch.tensor(tgt2).long(),
            torch.tensor(s_mask).float(),
            torch.tensor(t_mask).float(),
            torch.tensor(loss_mask).float())

# data_iter = Data.DataLoader(dataset,batch_size,shuffle=True,collate_fn=batchify)
    

In [4]:
class Warmup:
    def __init__(self,optimizer_w, warmup_step,d_model,step_size=1,last_epoch=-1,verbose=False):
        self.lambda1 = lambda epoch_w: math.pow(d_model+1e-10,-0.5)*min(math.pow(epoch_w+1e-10,-0.5),epoch_w*math.pow(warmup_step+1e-10,-1.5))
        self.scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer_w,lr_lambda=self.lambda1,
                                                          last_epoch=last_epoch, verbose=verbose)
    
    def step(self):
        self.scheduler.step()

In [5]:
class Encoder_2(nn.Module):
    def __init__(self,num_vocab,d_model=512,nhead=2,num_layers=2,dropout=0.1):
        super(Encoder_2,self).__init__()
        self.embedding = nn.Embedding(num_vocab,d_model)
        self.MH = nn.TransformerEncoderLayer(d_model=d_model,nhead=nhead,dropout=dropout)
        self.all_MH = nn.TransformerEncoder(self.MH,num_layers)
        self.position_embedding = self.get_position_embedding(d_model,100).to(device)
                                       
    def forward(self,myinput,mask=None):
        X = self.embedding(myinput)
        X += self.position_embedding[:X.shape[1]]
        X = X.permute(1,0,2)
        #X shaep(L, Batch_size, d_model)
        out = self.all_MH(X,src_key_padding_mask=mask)
        return out
    
    def get_position_embedding(self,d_model,max_len):
        table = torch.empty(max_len,d_model)
        for position in range(max_len):
            for i in range(d_model):
                table[position,i] = position/10000**(i/d_model)
        table[:,0::2] = torch.sin(table[:,0::2])
        table[:,1::2] = torch.sin(table[:,1::2])
        return table.float()
        

In [6]:
class Decoder_2(nn.Module):
    def __init__(self,num_vocab,d_model=512,nhead=2,num_layers=2,dropout=0.1):
        super(Decoder_2,self).__init__()
        self.num_vocab = num_vocab
        self.embedding = nn.Embedding(num_vocab, d_model)
        self.MH = nn.TransformerDecoderLayer(d_model=d_model,nhead=nhead,dropout=dropout)
        self.all_MH = nn.TransformerDecoder(self.MH,num_layers)
        self.dense = nn.Linear(d_model,num_vocab)
        self.position_embedding = self.get_position_embedding(d_model,100).to(device)
        
    def forward(self,target,memory,t_mask=None,m_mask=None,tgt_mask=None):
        X = self.embedding(target)
        X += self.position_embedding[:X.shape[1]]
        X = X.permute(1,0,2)
        Y = self.all_MH(X,memory,tgt_key_padding_mask=t_mask,memory_key_padding_mask=m_mask,tgt_mask=tgt_mask)
        out = self.dense(Y.permute(1,0,2))
        return out
    
    def get_position_embedding(self,d_model,max_len):
        table = torch.empty(max_len,d_model)
        for position in range(max_len):
            for i in range(d_model):
                table[position,i] = position/10000**(i/d_model)
        table[:,0::2] = torch.sin(table[:,0::2])
        table[:,1::2] = torch.sin(table[:,1::2])
        return table.float()

In [7]:
def get_decoder_mask(L):
    return torch.from_numpy(np.triu(np.ones(L),k=1)).bool()

In [8]:
class Myplot:
    def __init__(self,step_size=10):
        self.all_loss = []
        self.cum_loss = 0
        self.step = 0
        self.step_size = step_size
        
    def forward(self,loss):
        self.cum_loss += loss
        self.step += 1
        if self.step%self.step_size == 0:
            self.all_loss.append(self.cum_loss/self.step)
            self.cum_loss = 0
            self.step = 0
            
    def show(self):
        import matplotlib.pyplot as plt
        import matplotlib.ticker as ticker
        import numpy as np
        plt.figure()
        fig, ax = plt.subplots()
        loc = ticker.MultipleLocator(base=0.2) # put ticks at regular intervals
        ax.yaxis.set_major_locator(loc)
        plt.plot(self.all_loss)
        plt.show()

In [9]:
def train(num_epochs,encoder,decoder,en_optim,de_optim,en_scheduler,de_scheduler,
          criterion,train_iter,myplt,teacher_force_ratio=0.9,clip=5):
    
    import time 
    
    start = time.time()
    per_n_steps = 300
    steps = 0
    arang_list = torch.arange(512)
    ones_list = torch.ones(512)
    
    encoder.train()
    decoder.train()
    
    for epoch in range(1,num_epochs+1):
        torch.cuda.empty_cache()
        torch.cuda.empty_cache()
        torch.cuda.empty_cache()
        torch.cuda.empty_cache()
        torch.cuda.empty_cache()
        
        sum_loss = 0
        k = 0
        
        for src, tgt1, tgt2, s_mask, t_mask, loss_mask in train_iter:
            
            torch.cuda.empty_cache()
            torch.cuda.empty_cache()
            torch.cuda.empty_cache()
            torch.cuda.empty_cache()
            torch.cuda.empty_cache()
            
            steps += 1
            
            src = src.to(device)
            tgt1 = tgt1.to(device)
            tgt2 = tgt2.to(device)
            s_mask = s_mask.to(device)
            t_mask = t_mask.to(device)

            #else do tecacher_forcing
            changes = random.choices(arang_list[1:tgt1.shape[-1]],ones_list[1:tgt1.shape[-1]],k=int(tgt1.shape[-1]*(1-teacher_force_ratio)))
            for idx in changes:
                tgt1[:,idx] = random.randint(3,decoder.num_vocab-1)
                    
                
            
            en_scheduler.step()
            de_scheduler.step()
            en_optim.zero_grad()
            de_optim.zero_grad()
            memory = encoder(src,s_mask)
            tgt_mask = get_decoder_mask(tgt2.shape[1]).to(device)
            
#             print(tgt1.shape)
#             print(memory.shape)
#             print(t_mask.shape)
#             print(s_mask.shape)
#             print(tgt_mask.shape)
            
            pred = decoder(tgt1,memory,t_mask,s_mask,tgt_mask=tgt_mask)
#             loss1 = criterion(pred,tgt2.to(device),num_classes=decoder.num_vocab,mask=loss_mask.to(device))
            loss_mean = (criterion(pred.permute(0,2,1),tgt2)*loss_mask.to(device)).mean()
#             loss_mean = loss1.mean()
            
            loss_mean.backward()
            torch.nn.utils.clip_grad_norm_(encoder.parameters(), clip)
            torch.nn.utils.clip_grad_norm_(decoder.parameters(), clip)
            en_optim.step()
            de_optim.step()
            
            sum_loss += loss_mean.item()
            
            myplt.forward(loss_mean.item())
         
            if steps%per_n_steps == 0:
                passtime = time.time() - start
                min = passtime//60
                s = passtime-min*60
                print(f'time :{min} m {s} s, epoch: {epoch} steps: {steps} ;loss {loss_mean.item()}')
        
        torch.save(encoder.state_dict(),'TRANS05en_1.pth')
        torch.save(decoder.state_dict(),'TRANS05de_1.pth')

        print(' ',sum_loss)
            

In [12]:
trainfile = 'translation2019zh/translation2019zh_train.json'
trainfile2 = 'cn-eng.txt'


trainset = My_data(min_en_count=8,min_cn_count=0)
# 3333trainset.do_all(trainfile, js=True)
# raw_cn, raw_en = trainset.get_raw_data_cn_en(trainfile,js=True,divide=50,choose=1)
# raw_cn2, raw_en2 = trainset.get_raw_data_cn_en(trainfile2,js=False)
# raw_cn3 = raw_cn+raw_cn2
# raw_en3 = raw_en+raw_en2
# trainset.get_cn_en_stoi_itos(raw_cn3,raw_en3)


cn_file = 'cn_vocab.txt'
en_file = 'en_vocab.txt'
trainset.get_from_vocab(cn_file,en_file)



In [13]:
lr = 0.03

myplt = Myplot(50)

# def __init__(self,num_vocab,d_model=512,nhead=2,num_layers=2,dropout=0.1):

encoder = Encoder_2(trainset.num_cn_vocab, d_model=512, nhead=4, num_layers=4).to(device)
decoder = Decoder_2(trainset.num_en_vocab, d_model=512, nhead=4, num_layers=4).to(device)

en_optim = torch.optim.Adam(encoder.parameters(),lr=lr)
de_optim = torch.optim.Adam(decoder.parameters(),lr=lr)

en_scheduler = Warmup(en_optim,4000,512)
de_scheduler = Warmup(de_optim,4000,512)

# class Warmup:
#     def __init__(self,optimizer, warmup_step,d_model,step_size=1,last_epoch=-1,verbose=False):

In [14]:
#模型初始化
def weights_init(m):
    if isinstance(m,(nn.Linear)):
#         nn.init.xavier_normal_(m.weight)
        nn.init.constant_(m.bias,0.0)
        nn.init.orthogonal_(m.weight)
        #正交初始化
def weights_init2(m):
    if isinstance(m,(nn.Linear)):
        nn.init.xavier_normal_(m.weight)
        nn.init.constant_(m.bias,0.0)
    

encoder.apply(weights_init2)
decoder.apply(weights_init2)
# # attentioner.apply(weights_init2)
# enc = torch.load('TRANS03en.pth')
# dec = torch.load('TRANS03de.pth')
# encoder.load_state_dict(torch.load('TRANS07en.pth',map_location=device))
# decoder.load_state_dict(torch.load('TRANS07de.pth',map_location=device))

Decoder_2(
  (embedding): Embedding(25769, 512)
  (MH): TransformerDecoderLayer(
    (self_attn): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
    )
    (multihead_attn): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
    )
    (linear1): Linear(in_features=512, out_features=2048, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
    (linear2): Linear(in_features=2048, out_features=512, bias=True)
    (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    (norm3): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    (dropout1): Dropout(p=0.1, inplace=False)
    (dropout2): Dropout(p=0.1, inplace=False)
    (dropout3): Dropout(p=0.1, inplace=False)
  )
  (all_MH): TransformerDecoder(
    (layers): ModuleList(
      (0): TransformerDecoderLayer(
        (self_at

In [15]:

criterion = nn.CrossEntropyLoss(reduction='none')

In [16]:
num_epochs = 50
clip = 5
teacher_force_ratio = 0.9

In [19]:
lr = 1

en_optim = torch.optim.Adam(encoder.parameters(),lr=lr)
de_optim = torch.optim.Adam(decoder.parameters(),lr=lr)

en_scheduler = Warmup(en_optim,8000,512)
de_scheduler = Warmup(de_optim,8000,512)

In [21]:
# raw_cn, raw_en = trainset.get_raw_data_cn_en(trainfile2,js=False)
# trainset.get_data(raw_cn,raw_en)

# # raw_cn, raw_en = trainset.get_raw_data_cn_en(trainfile,js=True,divide=100,choose=1)
# file3 = 't22'
# raw_cn, raw_en = trainset.get_raw_data_cn_en(file3,js=False)
# trainset.append_data(raw_cn,raw_en)
len(trainset)

0

In [22]:
# batch_size = 256
# train_iter = Data.DataLoader(trainset,batch_size,shuffle=True,collate_fn=batchify,num_workers=8)
# train(num_epochs,encoder,decoder,en_optim,de_optim,en_scheduler,
#       de_scheduler,criterion,train_iter,myplt,teacher_force_ratio,clip=clip)

In [None]:
####train 2v
import time 
    
start = time.time()
per_n_steps = 100
steps = 0
arang_list = torch.arange(512)
ones_list = torch.ones(512)
    
encoder.train()
decoder.train()
    
for epoch in range(1,num_epochs+1):
    
    if epoch%5 == 0 or epoch==1:
#         raw_cn, raw_en = trainset.get_raw_data_cn_en(trainfile2,js=False)
#         trainset.get_data(raw_cn,raw_en)
        raw_cn, raw_en = trainset.get_raw_data_cn_en(trainfile,js=True,divide=30,choose=random.randint(1,29))
        trainset.get_data(raw_cn,raw_en)
        
        batch_size = 256
        train_iter = Data.DataLoader(trainset,batch_size,shuffle=True,collate_fn=batchify,num_workers=8)
    
    torch.cuda.empty_cache()
    torch.cuda.empty_cache()
    torch.cuda.empty_cache()
    torch.cuda.empty_cache()
    torch.cuda.empty_cache()
        
    sum_loss = 0
    k = 0
        
    for src, tgt1, tgt2, s_mask, t_mask, loss_mask in train_iter:
            
        torch.cuda.empty_cache()
        torch.cuda.empty_cache()
        torch.cuda.empty_cache()
        torch.cuda.empty_cache()
        torch.cuda.empty_cache()
            
        steps += 1
            
        src = src.to(device)
        tgt1 = tgt1.to(device)
        tgt2 = tgt2.to(device)
        s_mask = s_mask.to(device)
        t_mask = t_mask.to(device)
        
            #else do tecacher_forcing
        changes = random.choices(arang_list[1:tgt1.shape[-1]],ones_list[1:tgt1.shape[-1]],k=int(tgt1.shape[-1]*(1-teacher_force_ratio)))
        for idx in changes:
            tgt1[:,idx] = random.randint(3,decoder.num_vocab-1)
                    
                
            
        en_scheduler.step()
        de_scheduler.step()
        en_optim.zero_grad()
        de_optim.zero_grad()
        memory = encoder(src,s_mask)
        tgt_mask = get_decoder_mask(tgt2.shape[1]).to(device)
            
        pred = decoder(tgt1,memory,t_mask,s_mask,tgt_mask=tgt_mask)
#             loss1 = criterion(pred,tgt2.to(device),num_classes=decoder.num_vocab,mask=loss_mask.to(device))
        loss_mean = (criterion(pred.permute(0,2,1),tgt2)*loss_mask.to(device)).mean()
#             loss_mean = loss1.mean()
            
        loss_mean.backward()
        torch.nn.utils.clip_grad_norm_(encoder.parameters(), clip)
        torch.nn.utils.clip_grad_norm_(decoder.parameters(), clip)
        en_optim.step()
        de_optim.step()
            
        sum_loss += loss_mean.item()
            
        myplt.forward(loss_mean.item())
         
        if steps%per_n_steps == 0:
            passtime = time.time() - start
            mins = passtime//60
            s = passtime-mins*60
            print(f'time :{mins} m {s} s, epoch: {epoch} steps: {steps} ;loss {loss_mean.item()}')
        
    torch.save(encoder.state_dict(),'TRANS05en_1.pth')
    torch.save(decoder.state_dict(),'TRANS05de_1.pth')

    print(' ',sum_loss)
            

In [None]:
myplt.show()