<a href="https://colab.research.google.com/github/bluetinue/Country_Name/blob/main/%E5%9F%BA%E4%BA%8EGRU%E7%9A%84seq2seq%E7%9A%84%E8%8B%B1%E8%AF%91%E6%B3%95%E6%A1%88%E4%BE%8B.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
#@title 导包
from google.colab import drive
drive.mount('/content/drive')

# 用于正则表达式
import re
# 用于构建网络结构和函数的torch工具包
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
# torch中预定义的优化方法工具包
import torch.optim as optim
import time
# 用于随机生成数据
import random
import matplotlib.pyplot as plt

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [7]:
#@title 全局变量

#Index of the start-of-sequence (SOS) marker token
SOS_TOKEN = 0
#Index of the end-of-sequence (EOS) marker token, appended to every encoded sentence
EOS_TOKEN = 1

#Maximum sentence length
MAX_LENGTH = 10

#Tab-separated English/French sentence-pair corpus on the mounted Google Drive
data_path = "/content/drive/MyDrive/NLP/data/eng-fra-v2.txt"

In [29]:
#@title 文本清洗工具函数
def normal2String(s):
  """Normalize a raw sentence into lowercase, space-separated ASCII tokens.

  Processing steps:
    1. Lowercase the text and trim surrounding whitespace.
    2. Fold accented letters to their base letter (e.g. "é" -> "e").
    3. Surround each of the punctuation marks . ! ? with spaces so they
       become standalone tokens.
    4. Replace every run of characters other than ASCII letters and . ! ?
       with a single space.

  Args:
    s: The raw sentence.

  Returns:
    The cleaned sentence, with tokens separated by single spaces and no
    leading/trailing whitespace. An empty input yields an empty string.
  """
  import unicodedata  # stdlib; local import keeps the function self-contained

  lowered = s.lower().strip()
  # Decompose accented letters into base letter + combining mark and drop the
  # marks. Without this, the ASCII-only filter below would replace the whole
  # accented letter with a space and split the word in two.
  decomposed = unicodedata.normalize("NFD", lowered)
  unaccented = "".join(ch for ch in decomposed if unicodedata.category(ch) != "Mn")
  # Isolate sentence punctuation as separate tokens.
  spaced = re.sub(r"([.!?])", r" \1 ", unaccented)
  # Collapse everything that is not a letter or kept punctuation (including
  # the runs of spaces produced above) into a single space.
  cleaned = re.sub(r"[^a-zA-Z.!?]+", r" ", spaced)
  return cleaned.strip()

In [45]:
#@title 对原始数据进行预处理
#构建出[[英文,法文，.....的列表对]]
def wash_data(path=None):
  """Load the parallel corpus and build the vocabularies for both languages.

  Each line of the corpus is expected to hold tab-separated sentences; the
  first field is treated as English and the second as French. Every field is
  cleaned with ``normal2String``. Blank lines and lines without a second
  field are skipped, since they cannot form a sentence pair.

  Args:
    path: Optional path of the corpus file. When omitted, the module-level
      ``data_path`` is used.

  Returns:
    A 7-tuple ``(english_word2index, english_index2word, english_word_n,
    fres_word2index, fres_index2word, fres_word_n, my_pairs)`` where the
    ``*_word2index`` dicts map token -> index, the ``*_index2word`` dicts are
    their inverses, the ``*_word_n`` values are the vocabulary sizes, and
    ``my_pairs`` is the list of normalized sentence lists.
  """
  # Resolve the default lazily so ``data_path`` is only needed when no path is given.
  corpus_path = data_path if path is None else path

  my_pairs = []
  with open(corpus_path, "r", encoding="utf-8") as fr:
    for line in fr:
      fields = line.strip().split("\t")
      # A usable record needs at least the English and the French sentence.
      if len(fields) < 2:
        continue
      my_pairs.append([normal2String(s) for s in fields])

  # Both vocabularies start with the start/end markers at indices 0 and 1.
  english_word2index = {"SOS": 0, "EOS": 1}
  fres_word2index = {"SOS": 0, "EOS": 1}

  # Assign each unseen token the next free index (the current vocabulary size).
  # Sentences are split on a single space so that the tokens match exactly
  # what a consumer obtains from ``sentence.split(" ")``.
  for pair in my_pairs:
    for sentence, word2index in ((pair[0], english_word2index),
                                 (pair[1], fres_word2index)):
      for word in sentence.split(" "):
        if word not in word2index:
          word2index[word] = len(word2index)

  # Inverse lookup tables: index -> token.
  english_index2word = {v: k for k, v in english_word2index.items()}
  fres_index2word = {v: k for k, v in fres_word2index.items()}

  return (english_word2index, english_index2word, len(english_word2index),
          fres_word2index, fres_index2word, len(fres_word2index),
          my_pairs)

In [59]:
# Run the preprocessing once and unpack the vocabularies, their sizes and the
# cleaned sentence pairs into module-level names.
(
  english_word2index,
  english_index2word,
  english_word_n,
  french_word2index,
  french_index2word,
  french_word_n,
  my_pairs,
) = wash_data()

In [62]:
#@title 构建数据源对象
class SeqDataset(Dataset):
  """Dataset that turns sentence pairs into pairs of token-index tensors.

  Args:
    my_pairs: Sequence of pairs; element 0 of each pair is the source
      (English) sentence and element 1 the target (French) sentence, both as
      space-separated token strings.
    src_word2index: Optional token -> index mapping for the source side.
      Defaults to the module-level ``english_word2index``.
    tgt_word2index: Optional token -> index mapping for the target side.
      Defaults to the module-level ``french_word2index``.
    eos_token: Optional index appended to every encoded sentence. Defaults
      to the module-level ``EOS_TOKEN``.
  """
  def __init__(self,my_pairs,src_word2index=None,tgt_word2index=None,eos_token=None):
    super().__init__()
    self.my_pairs = my_pairs
    self.sample_len = len(my_pairs)
    # ``None`` means "look up the module-level default when an item is read".
    self.src_word2index = src_word2index
    self.tgt_word2index = tgt_word2index
    self.eos_token = eos_token

  def __len__(self):
    """Return the number of sentence pairs."""
    return self.sample_len

  def _encode(self, sentence, word2index):
    """Map a space-separated sentence to a 1-D long tensor ending with EOS."""
    eos = EOS_TOKEN if self.eos_token is None else self.eos_token
    ids = [word2index[word] for word in sentence.split(" ")]
    ids.append(eos)
    return torch.tensor(ids,dtype=torch.long)

  def __getitem__(self, index):
    """Return ``(source_tensor, target_tensor)`` for the pair at ``index``.

    Out-of-range indices are clamped into ``[0, len - 1]`` instead of raising.
    A token missing from the vocabulary raises ``KeyError``.
    """
    index = min(max(0,index),self.sample_len-1)

    src_sentence = self.my_pairs[index][0]
    tgt_sentence = self.my_pairs[index][1]

    # Fall back to the notebook-level vocabularies when none were injected.
    src_vocab = english_word2index if self.src_word2index is None else self.src_word2index
    tensor_x = self._encode(src_sentence, src_vocab)

    tgt_vocab = french_word2index if self.tgt_word2index is None else self.tgt_word2index
    tensor_y = self._encode(tgt_sentence, tgt_vocab)
    return tensor_x,tensor_y


In [74]:
def use_dataset():
  """Create a shuffled, batch-size-1 DataLoader over the module-level ``my_pairs``."""
  return DataLoader(dataset=SeqDataset(my_pairs),batch_size=1,shuffle=True)

In [76]:
# Build the DataLoader that the cells below iterate over.
my_dataloader = use_dataset()

In [75]:
#@title 构建基于GRU的编码器
class EncodeGru(nn.Module):
  """GRU encoder: embeds token indices and runs them through a single-layer GRU.

  Args:
    vocb_size: Number of entries in the embedding table (vocabulary size).
    hidden_size: Size of both the embedding vectors and the GRU hidden state.
  """
  def __init__(self,vocb_size,hidden_size):
    super().__init__()
    self.vocb_size = vocb_size
    self.hidden_size = hidden_size

    # Embedding layer: token index -> dense vector of size hidden_size.
    self.embedding = nn.Embedding(vocb_size,hidden_size)

    # GRU layer; batch_first=True means inputs/outputs are (batch, seq, feature).
    self.gru = nn.GRU(hidden_size,hidden_size,batch_first=True)

  def forward(self,vocb_size,hidden):
    """Encode a batch of token-index sequences.

    Args:
      vocb_size: Despite its name, this is the tensor of token indices that
        is fed to the embedding layer (the name is kept so existing keyword
        callers keep working).
      hidden: Initial GRU hidden state, e.g. from ``inithidden()``.

    Returns:
      ``(output, hidden)`` as returned by the GRU: the per-step outputs and
      the final hidden state.
    """
    # Look up the word vectors, then run them through the GRU.
    embedded = self.embedding(vocb_size)
    output,hidden = self.gru(embedded,hidden)
    return output,hidden

  def inithidden(self, batch_size=1):
    """Return an all-zero initial hidden state of shape (1, batch_size, hidden_size).

    The tensor is created on the same device and with the same dtype as the
    module's parameters, so it stays usable after ``.to(device)`` or a dtype
    cast of the module.
    """
    weight = self.embedding.weight
    return torch.zeros(1,batch_size,self.hidden_size,
                       device=weight.device,dtype=weight.dtype)

In [77]:
# Smoke test: push a single sample through the encoder and inspect the shapes.
vocb_size = english_word_n
hidden_size = 256
encoder = EncodeGru(vocb_size,hidden_size)
for x,y in my_dataloader:
  # Start from a fresh zero hidden state.
  hidden = encoder.inithidden()
  output,hidden = encoder(x,hidden)
  # With batch_first=True the output is (batch, seq_len, hidden_size).
  print(output.shape)
  print(hidden.shape)
  # One sample is enough for a shape check.
  break

torch.Size([1, 6, 256])
torch.Size([1, 1, 256])


In [None]:
#@title 构建基于GRU的解码器

In [None]:
#@title 构建模型评估函数

In [None]:
#@title 构建模型测试函数