In [1]:
import torch
import random
import numpy as np
import os
from torch import nn
from torch.utils.data import Dataset,DataLoader
import matplotlib.pyplot as plt

# Path of the raw corpus text file; the corpus-loading cell reads it line by line.
corpus_file = '../hw2.1_corpus.txt'

In [2]:
# Load Corpus
# Every line of the corpus file becomes one sentence, stored as a list of
# single characters (character-level tokens).

corpus = []
# Decode explicitly instead of relying on the platform default encoding,
# which is not UTF-8 everywhere (e.g. Windows) and would garble or reject
# the Chinese text.  NOTE(review): assumes the corpus file is UTF-8 - confirm.
with open(corpus_file, 'r', encoding='utf-8') as f:
    for row in f:
        row = row.replace('\n', '')
        corpus.append(list(row))

In [3]:
# Collect every distinct character in first-seen order, so that the index
# assigned to each character is identical on every run.

# dict keys keep insertion order, which gives an ordered de-duplication
words = list(dict.fromkeys(w for ws in corpus for w in ws))
words_set = set(words)

In [4]:
import re

class Embedding:
    """Vocabulary plus randomly initialised embedding matrix.

    Maps tokens to integer ids and back, and owns a ``(vocab, dim)`` tensor of
    uniformly initialised vectors (``self.vectors``).

    Index layout (order matters, because trained weights are tied to it):
      1. the special tokens ``<PAD>``, ``<SOS>``, ``<EOS>``, ``<UNK>`` and the
         position tokens ``'0'`` .. ``str(maxPOS - 1)`` -- but only those that
         do NOT also occur in ``words``;
      2. every distinct entry of ``words`` in first-seen order (this is where
         a special/position token that also occurs in ``words`` gets its id).

    NOTE(review): ``range(self.maxPOS)`` registers '0'..'11' only, so the
    token ``str(maxPOS)`` ('12') has no id of its own and resolves to
    ``<UNK>``.  Left unchanged because adding it would change the vocabulary
    size that existing checkpoints were trained with.
    """

    def __init__(self,words=None,dim=300):
        """Build the vocabulary.

        Args:
            words: iterable of tokens (hashable); ``None`` means "no corpus
                tokens", leaving only the special and position tokens.
            dim: embedding dimension.
        """
        # The original crashed with a TypeError on the default words=None.
        words = [] if words is None else list(words)
        # Set membership instead of scanning the whole list for every
        # special token; resulting ids are unchanged.
        corpus_tokens = set(words)

        self.word_dict = {}
        self.word_list = []
        self.emb_dim = dim
        self.maxPOS = 12      # Maximum POS
        self.addition_words = ['<PAD>','<SOS>','<EOS>','<UNK>']

        for POS in range(self.maxPOS):
            self.addition_words.append(str(POS))

        for addition_word in self.addition_words:
            if addition_word not in corpus_tokens:
                self.word_dict[addition_word] = len(self.word_list)
                self.word_list.append(addition_word)

        for word in words:
            if word not in self.word_dict:
                self.word_dict[word] = len(self.word_list)
                self.word_list.append(word)

        self.vectors = torch.nn.init.uniform_(
                torch.empty(len(self.word_dict),dim))

    def to_index(self, word):
        """Return the id of a single token; unknown tokens map to ``<UNK>``."""
        if word not in self.word_dict:
            return self.word_dict['<UNK>']

        return self.word_dict[word]

    def tokenize(self, words):
        """Return the list of ids for a whole sentence (iterable of tokens)."""
        return [self.to_index(w) for w in words]

    def to_word(self, idx):
        """Return the token stored at id ``idx``."""
        return self.word_list[idx]

    def unTokenize(self,ids):
        """Return the list of tokens for an iterable of ids."""
        return [self.to_word(idx) for idx in ids]

    def get_vocabulary_size(self):
        """Number of rows of the embedding matrix."""
        return self.vectors.shape[0]

    def get_dim(self):
        """Number of columns (embedding dimension) of the embedding matrix."""
        return self.vectors.shape[1]

In [5]:
# Build the vocabulary / tokenizer from the corpus characters
embedder = Embedding(dim=300, words=words)

# Cache the ids of the structural tokens used throughout the notebook
PAD, SOS, EOS = (embedder.to_index(tok) for tok in ('<PAD>', '<SOS>', '<EOS>'))

In [6]:
# Build (input, target) pairs from consecutive corpus lines.
#   input  = <SOS> former <EOS> + [designated char, str(a), str(b)]
#   target = <SOS> latter <EOS>
# The designated char is target[a + b]; a zero part is omitted from the input.
# counting_table[a, b] counts how often each (a, b) split was drawn.
#
# NOTE(review): a or b can be 12 (= embedder.maxPOS), but Embedding only
# registers the position tokens '0'..'11', so '12' is tokenized as <UNK>.

# Splits never generated, so they stay unseen during training
HELD_OUT_SPLITS = {(4, 2), (2, 4)}

# Dedicated seeded generator: the dataset (and therefore the train/valid
# split built from it) is identical on every run.
pair_rng = random.Random(42)

all_set = []
counting_table = np.zeros((embedder.maxPOS+1,embedder.maxPOS+1))

for former,latter in zip(corpus[:-1],corpus[1:]):

    n = len(latter)
    if n == 0:
        # An empty target line has no character to designate
        # (randint(1, 0) would raise ValueError).
        continue

    addition_tokens = []

    former = ['<SOS>'] + former + ['<EOS>']
    latter = ['<SOS>'] + latter + ['<EOS>']

    # 1-based position inside the target; index 0 of `latter` is <SOS>
    selected_idx = pair_rng.randint(1, min(n,embedder.maxPOS))

    # Split the position into a + b, redrawing until the split is allowed
    while True:
        a = pair_rng.randint(0, selected_idx)
        b = selected_idx - a

        if (a,b) not in HELD_OUT_SPLITS:
            break

    counting_table[a,b] += 1

    addition_tokens.append(latter[selected_idx])
    if a > 0:
        addition_tokens.append(str(a))
    if b > 0:
        addition_tokens.append(str(b))

    former = former + addition_tokens

    all_set.append((former,latter))

print(counting_table)

[[    0. 53060. 33876. 24074. 18301. 13818. 14125.  7228.  5013.  3371.
   2188.  1439.   944.]
 [53167. 33951. 24179. 18326. 13813. 14373.  7309.  5108.  3298.  2226.
   1414.   934.     0.]
 [33656. 24099. 18229. 13667.     0.  7505.  5008.  3375.  2225.  1480.
    921.     0.     0.]
 [24445. 18210. 13826. 14295.  7319.  5192.  3356.  2262.  1437.   919.
      0.     0.     0.]
 [18264. 13736.     0.  7336.  5142.  3450.  2280.  1409.   906.     0.
      0.     0.     0.]
 [13800. 14103.  7429.  5081.  3443.  2217.  1447.   873.     0.     0.
      0.     0.     0.]
 [14332.  7519.  4892.  3346.  2291.  1415.   915.     0.     0.     0.
      0.     0.     0.]
 [ 7421.  4993.  3497.  2170.  1459.   880.     0.     0.     0.     0.
      0.     0.     0.]
 [ 5130.  3307.  2342.  1382.   909.     0.     0.     0.     0.     0.
      0.     0.     0.]
 [ 3375.  2219.  1400.   887.     0.     0.     0.     0.     0.     0.
      0.     0.     0.]
 [ 2230.  1413.   868.     0.     0.    

In [7]:
from sklearn.model_selection import train_test_split

# Hold out 20% of the pairs for validation; the fixed random_state keeps the
# split stable for a given all_set.
train_set, valid_set = train_test_split(all_set, random_state=42, test_size=0.2)

print(*map(len, (all_set, train_set, valid_set)))

741714 593371 148343


In [8]:
# list of tuples : [(x0,y0),(x1,y1),(x2,y2),....]
# Each x is the input token list (sentence plus control tokens), each y the target token list.
all_set[0:5]

[(['<SOS>',
   '心',
   '疼',
   '你',
   '还',
   '没',
   '挣',
   '脱',
   '思',
   '念',
   '的',
   '囚',
   '禁',
   '<EOS>',
   '在',
   '1',
   '1'],
  ['<SOS>',
   '他',
   '在',
   '你',
   '一',
   '段',
   '难',
   '忘',
   '远',
   '行',
   '最',
   '后',
   '却',
   '离',
   '去',
   '<EOS>']),
 (['<SOS>',
   '他',
   '在',
   '你',
   '一',
   '段',
   '难',
   '忘',
   '远',
   '行',
   '最',
   '后',
   '却',
   '离',
   '去',
   '<EOS>',
   '依',
   '3',
   '1'],
  ['<SOS>', '你', '无', '力', '依', '靠', '在', '我', '这', '里', '<EOS>']),
 (['<SOS>',
   '你',
   '无',
   '力',
   '依',
   '靠',
   '在',
   '我',
   '这',
   '里',
   '<EOS>',
   '被',
   '1',
   '3'],
  ['<SOS>', '隔', '着', '刚', '被', '雨', '淋', '湿', '的', '玻', '璃', '<EOS>']),
 (['<SOS>',
   '隔',
   '着',
   '刚',
   '被',
   '雨',
   '淋',
   '湿',
   '的',
   '玻',
   '璃',
   '<EOS>',
   '问',
   '2'],
  ['<SOS>', '你', '问', '了', '我', '到', '底', '爱', '在', '哪', '里', '<EOS>']),
 (['<SOS>',
   '你',
   '问',
   '了',
   '我',
   '到',
   '底',
   '爱',
   '在',
   '哪',
   '里',
   '<EOS

In [9]:
class SentDataset(Dataset):
    """Dataset of (input tokens, target tokens) pairs with a batching collate_fn."""

    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index]

    def collate_fn(self, datas):
        """Tokenize and pad a batch, and build the "focus" target.

        Returns:
            (batch_x, len_x, batch_y, len_y, batch_y_) where batch_x / batch_y
            are padded id tensors, len_x is a LongTensor of input lengths,
            len_y a plain list of target lengths, and batch_y_ is <PAD>
            everywhere except <SOS>, <EOS> and the designated position.
        """
        # longest input / target in this batch
        max_data_len = max(len(src) for src, _ in datas)
        max_label_len = max(len(tgt) for _, tgt in datas)

        batch_x, batch_y, batch_y_ = [], [], []
        len_x, len_y = [], []

        for src, tgt in datas:
            len_x.append(len(src))
            len_y.append(len(tgt))

            # Tokenize, then right-pad up to the batch maximum
            src_ids = embedder.tokenize(src) + [PAD] * (max_data_len - len(src))
            tgt_ids = embedder.tokenize(tgt) + [PAD] * (max_label_len - len(tgt))
            batch_x.append(src_ids)
            batch_y.append(tgt_ids)

            # Focus target: only <SOS>, <EOS> and the designated char are kept
            focus = [PAD] * len(tgt_ids)
            focus[0] = SOS
            focus[tgt_ids.index(EOS)] = EOS

            # Tokens after the input <EOS>: designated char, then one or two
            # position tokens whose sum is the designated target index.
            eos_at = src.index('<EOS>')
            first_pos = src[eos_at + 2]
            if eos_at + 3 < len(src):
                target_idx = int(first_pos) + int(src[eos_at + 3])
            else:
                target_idx = int(first_pos)
            focus[target_idx] = tgt_ids[target_idx]

            batch_y_.append(focus)

        return (torch.LongTensor(batch_x),
                torch.LongTensor(len_x),
                torch.LongTensor(batch_y),
                len_y,
                torch.LongTensor(batch_y_))

In [10]:
# Sanity check of SentDataset.collate_fn: decode and print one small batch

dataset = SentDataset(train_set)
dataloader = DataLoader(dataset,
                        batch_size=4,
                        shuffle=True,
                        num_workers=0,
                        collate_fn=dataset.collate_fn)

for x, x_len, y, y_len, y_ in dataloader:
    print('Sentence lenght:', x_len, y_len, '\n')

    # input, full target and focus target of every sample, as tokens
    for xi, yi, y_i in zip(x, y, y_):
        print(embedder.unTokenize(xi), embedder.unTokenize(yi), sep='\n')
        print(embedder.unTokenize(y_i), '\n')

    # a single batch is enough
    break


Sentence lenght: tensor([ 9, 12,  9, 11]) [9, 8, 13, 13] 

['<SOS>', '若', '心', '有', '城', '<EOS>', '伤', '1', '5', '<PAD>', '<PAD>', '<PAD>']
['<SOS>', '便', '不', '会', '感', '受', '伤', '疼', '<EOS>', '<PAD>', '<PAD>', '<PAD>', '<PAD>']
['<SOS>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '伤', '<PAD>', '<EOS>', '<PAD>', '<PAD>', '<PAD>', '<PAD>'] 

['<SOS>', '那', '取', '名', '回', '忆', '的', '时', '光', '<EOS>', '下', '2']
['<SOS>', '留', '下', '你', '我', '模', '样', '<EOS>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>']
['<SOS>', '<PAD>', '下', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<EOS>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>'] 

['<SOS>', '除', '非', '你', '说', '<EOS>', '憾', '4', '7', '<PAD>', '<PAD>', '<PAD>']
['<SOS>', '离', '开', '我', '你', '从', '不', '曾', '觉', '得', '遗', '憾', '<EOS>']
['<SOS>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '憾', '<EOS>'] 

['<SOS>', '就', '是', '你', '爱', '的', '宽', '容', '<EOS>', '温', '5', '<PAD>']
['<SOS>', '你', '眼', '底', '的', '

In [11]:
from torch.nn.utils.rnn import  pack_padded_sequence, pad_packed_sequence

class Encoder(nn.Module):
    """GRU encoder over padded id sequences.

    The embedding weights are initialised from the global ``embedder.vectors``.
    """

    def __init__(self, vocab_size, embedding_size, output_size):
        """
        Args:
            vocab_size: number of rows of the embedding table.
            embedding_size: input feature size of the GRU.
            output_size: hidden size of the GRU.
        """
        super(Encoder, self).__init__()

        self.vocab_size = vocab_size
        self.embedding = nn.Embedding(vocab_size,embedder.get_dim())
        self.embedding.weight = nn.Parameter(embedder.vectors)
        self.gru = nn.GRU(embedding_size, output_size,batch_first=True, bias=False)

    def forward(self, input_seqs, input_lengths, hidden=None):
        """Encode a batch.

        Args:
            input_seqs: LongTensor of padded ids, batch first.
            input_lengths: LongTensor with the true length of every sequence.
            hidden: optional initial GRU state, in the caller's batch order.

        Returns:
            (outputs, hidden), both restored to the caller's batch order.
        """
        # pack_padded_sequence is fed sequences sorted by decreasing length
        sorted_input_lengths, indices = torch.sort(input_lengths,descending=True)
        _, desorted_indices = torch.sort(indices, descending=False)
        input_seqs = input_seqs[indices]
        if hidden is not None:
            # The initial state must follow the same permutation as the
            # inputs; previously each state was paired with the wrong sequence.
            hidden = hidden[:,indices]

        # Encoder work
        embedded = self.embedding(input_seqs)
        packed = pack_padded_sequence(embedded, sorted_input_lengths.cpu().numpy(), batch_first=True)
        packed_outputs, hidden = self.gru(packed, hidden)
        outputs, output_lengths = pad_packed_sequence(packed_outputs,batch_first=True)

        # Undo the sort so results line up with the caller's batch
        outputs = outputs[desorted_indices]
        hidden = hidden[:,desorted_indices]

        return outputs, hidden

In [12]:
class Decoder(nn.Module):
    """GRU-cell decoder that unrolls from the encoder context.

    For research purposes the token input of every step is replaced by zeros
    (see ``forward_step``), so generation is driven by the hidden state only.
    """

    def __init__(self, hidden_size, output_size, teacher_forcing_ratio=0.5):
        """
        Args:
            hidden_size: size of the GRU hidden state (= encoder context size).
            output_size: number of output classes (vocabulary size).
            teacher_forcing_ratio: probability, drawn once per forward call,
                of feeding the ground-truth token as the next input.
        """
        super(Decoder, self).__init__()

        self.hidden_size = hidden_size
        self.output_size = output_size
        # Its output is discarded in forward_step; the layer is kept as-is
        # (presumably so existing checkpoints keep loading -- confirm).
        self.embedding = nn.Embedding(embedder.get_vocabulary_size(),embedder.get_dim()) # Unused
        self.embedding.weight = nn.Parameter(embedder.vectors)
        self.cell = nn.GRUCell(embedder.get_dim(), hidden_size, bias=False)
        self.clf = nn.Linear(hidden_size, output_size, bias=False)

        # Tie the classifier to the embedding matrix when the shapes allow it
        if hidden_size == embedder.vectors.T.shape[0]:
            self.clf.weight = nn.Parameter(embedder.vectors)

        self.log_softmax = nn.LogSoftmax(dim=1)  # work with NLLLoss

        self.teacher_forcing_ratio = teacher_forcing_ratio

    def forward_step(self, inputs, hidden):
        """Run one decoding step; returns (log-probabilities, new hidden)."""
        # Unused
        embedded = self.embedding(inputs)
        # For research : all x to 0
        embedded = torch.zeros_like(embedded)

        hidden = self.cell(embedded, hidden) # [B,Hidden_dim]
        clf_output = self.clf(hidden) # [B,Output_dim]
        output = self.log_softmax(clf_output)

        return output, hidden

    def forward(self, context_vector, target_vars, target_lengths):
        """Unroll the decoder.

        Args:
            context_vector: encoder hidden state; its dim 1 is the batch.
            target_vars: ground-truth ids for teacher forcing, or None.
            target_lengths: target lengths (their max fixes the number of
                steps), or None to decode up to 50 steps.

        Returns:
            (decoder_outputs, decoder_hiddens): per-step log-probabilities
            stacked on dim 2 and per-step hidden states stacked on dim 1.
        """
        batch_size = context_vector.shape[1]

        # Create the start tokens on the context's device instead of relying
        # on a notebook-level `device` variable.
        decoder_input = torch.full((batch_size,), SOS, dtype=torch.long,
                                   device=context_vector.device)
        decoder_hidden = context_vector.squeeze(0)

        if target_lengths is None:
            max_target_length = 50
        else:
            max_target_length = int(max(target_lengths))
        decoder_outputs = []
        decoder_hiddens = []

        # The random draw is always consumed (as before); teacher forcing is
        # only possible when targets exist, otherwise target_vars[:, t] fails.
        use_teacher_forcing = (random.random() < self.teacher_forcing_ratio) and target_vars is not None

        finished = None  # per-sequence flag: <EOS> already produced

        for t in range(max_target_length):

            decoder_outputs_on_t, decoder_hidden = self.forward_step(decoder_input, decoder_hidden)
            decoder_outputs.append(decoder_outputs_on_t)
            decoder_hiddens.append(decoder_hidden)

            # Take input for next GRU iteration
            if use_teacher_forcing :
                decoder_input = target_vars[:,t]
            else:
                decoder_input = decoder_outputs_on_t.argmax(-1)

            # Early stop (free-running inference only) once every sequence
            # has produced <EOS>.  The old test `self.train() == False`
            # switched the module to training mode and was never true.
            if target_lengths is None and not self.training:
                hit_eos = decoder_input == EOS
                finished = hit_eos if finished is None else (finished | hit_eos)
                if bool(finished.all()):
                    break

        # Stack output of each word at dimension 2
        decoder_outputs = torch.stack(decoder_outputs,dim=2)
        # Stack hidden of each timestep at dimension 1
        decoder_hiddens = torch.stack(decoder_hiddens,dim=1)

        return decoder_outputs, decoder_hiddens

In [13]:
class Seq2Seq(nn.Module):
    """Encoder-decoder wrapper: encodes the input and decodes from its final hidden state."""

    def __init__(self,encoder,decoder):
        super(Seq2Seq, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, input_seqs, input_lengths, target_seqs=None, target_lengths=None):
        """Run encoder then decoder.

        Args:
            input_seqs, input_lengths: forwarded to the encoder.
            target_seqs, target_lengths: forwarded to the decoder (may be None).

        Returns:
            (outputs, hiddens) exactly as returned by the decoder.
        """
        # Use the sub-modules owned by this instance; the original called the
        # notebook globals `encoder` / `decoder`, which breaks as soon as the
        # model is built from differently named or different modules.
        _, hidden = self.encoder(input_seqs, input_lengths)
        outputs, hiddens = self.decoder(hidden, target_seqs, target_lengths)
        return outputs,hiddens

In [14]:
from torch.utils.data import DataLoader
from tqdm import tqdm_notebook as tqdm

# Prefer the GPU when one is available
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Size of the encoder context / decoder hidden state
context_dim = 128

encoder = Encoder(vocab_size=embedder.get_vocabulary_size(),
                  embedding_size=embedder.get_dim(),
                  output_size=context_dim)
decoder = Decoder(hidden_size=context_dim,
                  output_size=embedder.get_vocabulary_size(),
                  teacher_forcing_ratio=0.5)
model = Seq2Seq(encoder=encoder, decoder=decoder)
model.to(device)

Seq2Seq(
  (encoder): Encoder(
    (embedding): Embedding(6575, 300)
    (gru): GRU(300, 128, bias=False, batch_first=True)
  )
  (decoder): Decoder(
    (embedding): Embedding(6575, 300)
    (cell): GRUCell(300, 128, bias=False)
    (clf): Linear(in_features=128, out_features=6575, bias=False)
    (log_softmax): LogSoftmax()
  )
)

# Training

In [None]:
# Disabled cell: the whole run_epoch definition is wrapped in a string literal,
# so nothing below executes.  The code inside relies on names created by the
# (also disabled) training cell: criterion, optimizer, focus_ratio,
# history_loss and history_acc.
'''def run_epoch(epoch,dataset,isTraining):
    
    if isTraining:
        model.train()
    else:
        model.eval()
        
    dataloader = DataLoader(dataset=dataset,
                            batch_size=32,
                            shuffle=True,
                            collate_fn=dataset.collate_fn,
                            num_workers=0)
    
    if isTraining:
        desc='Train {}'
    else:
        desc='Valid {}'
    
    trange = tqdm(enumerate(dataloader), total=len(dataloader),desc=desc.format(epoch))
    
    loss=0
    acc = 0
    
    for i,(x,x_len,y,y_len,y_) in trange:
        
        x = x.to(device)
        y = y.to(device)
        y_ = y_.to(device)
        
        # outputs : [b,emb,s] , hiddens : [b,s,hidden]
        outputs,hiddens = model(x,x_len,y,y_len)
        
        idx = y_>2
        tf_map = y_[idx] == outputs.argmax(1)[idx]
        batch_acc = tf_map.sum().cpu().float().numpy()/len(tf_map)
        acc += batch_acc
        
        batch_loss_all = criterion(outputs, y)
        batch_loss_designated = criterion(outputs, y_)
        batch_loss = (1-focus_ratio)*batch_loss_all + focus_ratio*batch_loss_designated
        
        if isTraining:
            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()
        
        loss += batch_loss.item()
        
        trange.set_postfix({'loss':loss/(i+1),'accuracy':acc/(i+1)})
        
        if isTraining:
            history_loss['train'].append(batch_loss.item())
            history_acc['train'].append(batch_acc)
        else:
            history_loss['valid'].append(batch_loss.item())
            history_acc['valid'].append(batch_acc)'''

In [None]:
# Disabled cell: the training loop is wrapped in a string literal and does not
# run.  When enabled it calls run_epoch for training and validation in each
# epoch and saves model.state_dict() to model/model.pkl.<epoch>.
'''# Training


dataset_all = SentDataset(all_set)
dataset_train = SentDataset(train_set)
dataset_valid = SentDataset(valid_set)

criterion = torch.nn.NLLLoss(ignore_index=PAD, size_average=True)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
max_epoch = 20
focus_ratio = 0.5

history_loss = {'train':[],'valid':[]}
history_acc = {'train':[],'valid':[]}


for epoch in range(max_epoch):
    
    # Training
    run_epoch(epoch,dataset=dataset_train,isTraining=True)
    
    # Validation
    run_epoch(epoch,dataset=dataset_valid,isTraining=False)
    
    # Saving
    if not os.path.exists('model'):
        os.makedirs('model')
    torch.save(model.state_dict(), 'model/model.pkl.{}'.format(epoch))'''

# Plot (Loss and acc)

In [None]:
# Disabled cell: the plotting code is wrapped in a string literal and does not
# run.  When enabled it plots the per-iteration loss and accuracy collected in
# history_loss / history_acc, one subplot per metric and one line per mode.
'''import matplotlib.pyplot as plt

modes = ['train', 'valid']
recs = [history_loss, history_acc]
names = ['Loss', 'Accuracy']

values = []
for mode in modes:
    v = []
    for rec in recs:
        v.append(rec[mode])
    values.append(v)
 
plt.figure(figsize=(32, 4))
plt.subplots_adjust(left=0.02, right=0.999)
for r, name in enumerate(names):
    plt.subplot(1, len(recs), r+1)
    for m in range(len(modes)):
        plt.plot(values[m][r])
    plt.title(name)
    plt.legend(modes)
    plt.xlabel('iteration')
    plt.show()
#plt.savefig('figure.png', dpi=100)'''

# Inference Test Data
### Define test data dataloader

In [15]:
class TestDataset(Dataset):
    """Dataset of input token lists only (no targets), used for inference."""

    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index]

    def collate_fn(self, datas):
        """Tokenize and right-pad a batch; returns (padded ids, lengths) as LongTensors."""
        len_x = [len(sent) for sent in datas]
        max_data_len = max(len_x)

        # ids of every sentence, padded with <PAD> up to the batch maximum
        batch_x = [
            [embedder.to_index(w) for w in sent] + [PAD] * (max_data_len - len(sent))
            for sent in datas
        ]

        return torch.LongTensor(batch_x), torch.LongTensor(len_x)

## Load pre-trained model

In [16]:
# Checkpoint trained without the (4,2) / (2,4) position splits
path_pkl = '../pre-train/model.pkl.2-2-additional_without42&24'
# map_location lets a checkpoint saved from GPU load on a CPU-only machine too
model.load_state_dict(torch.load(path_pkl, map_location=device))
# Inference only: never feed ground-truth tokens back into the decoder
model.decoder.teacher_forcing_ratio = 0.0
model.eval()

Seq2Seq(
  (encoder): Encoder(
    (embedding): Embedding(6575, 300)
    (gru): GRU(300, 128, bias=False, batch_first=True)
  )
  (decoder): Decoder(
    (embedding): Embedding(6575, 300)
    (cell): GRUCell(300, 128, bias=False)
    (clf): Linear(in_features=128, out_features=6575, bias=False)
    (log_softmax): LogSoftmax()
  )
)

## Build the context hidden state for a designated position

In [17]:
# Encode one hand-written input (designated char '静' with position tokens
# '3' and '2') and keep its encoder context as a reference vector.
spos_tokens = ['<SOS>', '让', '我', '拥', '有', '了', '恬', '静', '的', '<EOS>', '静','3','2']
len_context = torch.LongTensor([len(spos_tokens)])
spos_ids = torch.LongTensor([[embedder.to_index(w) for w in spos_tokens]]).to(device)
_, spos_hidden = model.encoder(spos_ids, len_context)
context_spos = spos_hidden.squeeze()  # context hidden

In [18]:
# Check the shape of the squeezed reference context
context_spos.shape

torch.Size([128])

# Deconstruct

In [19]:
import torch.nn.functional as F

def Deconstruction(model,x,x_len,cells=(),max_steps=50):
    """Decode step by step while recording the internals of the decoder GRU cell.

    The GRU update is re-implemented by hand from ``model.decoder.cell.weight_hh``
    only: the decoder feeds an all-zero input and its cell has no bias, so the
    input-to-hidden terms vanish.

    Args:
        model: Seq2Seq model providing ``encoder`` and ``decoder``.
        x: padded input ids, batch first.
        x_len: LongTensor of input lengths.
        cells: indices of context units to overwrite with the corresponding
            value of the global ``context_spos`` (default: none).
        max_steps: upper bound on the number of decoding steps, so decoding
            terminates even if some sequence never produces <EOS>.

    Returns:
        (outputs, gru_info): ``outputs`` holds the per-step log-probabilities
        stacked on dim 2; ``gru_info`` maps 'context', 'hiddens',
        'resetgates', 'updategates' and 'newgates' to the recorded tensors
        (all but 'context' stacked on dim 2).
    """

    # Encoder
    encoder_hiddens, context = model.encoder(x,x_len)
    context = context.squeeze(0)   # drop the leading layer dim -> (b,hidden)

    # replace cells in context hidden by context_spos
    for c in cells:
        context[:,c] = context_spos[c]

    # Decoder starts from the context.  No second squeeze here: it removed
    # the batch dimension for a batch of one and broke chunk(3, 1) below.
    decoder_hidden = context
    gru = model.decoder.cell

    # Collection signal
    decoder_outputs = []
    decoder_hiddens = []
    decoder_resetGates = []
    decoder_updateGates = []
    decoder_newGates = []

    finished = None  # per-sequence flag: <EOS> already produced

    for _ in range(max_steps):

        # GRU cell with zero input and no bias
        U_h = F.linear(decoder_hidden, gru.weight_hh)
        Ur_h, Uz_h, Un_h = U_h.chunk(3, 1)
        reset_gate = torch.sigmoid(Ur_h)
        update_gate = torch.sigmoid(Uz_h)
        new_gate = torch.tanh(reset_gate * Un_h)
        # same as (1 - z) * n + z * h
        decoder_hidden = new_gate + update_gate * (decoder_hidden - new_gate)

        # Classifier
        clf_output = model.decoder.clf(decoder_hidden)
        decoder_output = model.decoder.log_softmax(clf_output)

        decoder_resetGates.append(reset_gate)
        decoder_updateGates.append(update_gate)
        decoder_newGates.append(new_gate)
        decoder_outputs.append(decoder_output)
        decoder_hiddens.append(decoder_hidden)

        # Stop once every sequence has produced <EOS> at some step.  The
        # original required all sequences to emit <EOS> at the same step and
        # could loop forever otherwise.
        hit_eos = decoder_output.argmax(-1) == EOS
        finished = hit_eos if finished is None else (finished | hit_eos)
        if bool(finished.all()):
            break

    outputs = torch.stack(decoder_outputs,dim=2)             # (b,6xxx,s)

    gru_info = {
        'context':context,                                        # (b,128)
        'hiddens':torch.stack(decoder_hiddens,dim=2),             # (b,128,s)
        'resetgates':torch.stack(decoder_resetGates,dim=2),       # (b,128,s)
        'updategates':torch.stack(decoder_updateGates,dim=2),     # (b,128,s)
        'newgates':torch.stack(decoder_newGates,dim=2)            # (b,128,s)
    }

    return outputs, gru_info

## Generate validation data for one fixed condition (chosen designated word / position)

In [20]:
# Build 16 probe inputs: a random corpus sentence cut or padded to exactly
# 8 characters, followed by a random designated character and position tokens.
certain_set = []

for sent in random.sample(corpus, k=16):

    # Force a length of 8: truncate long sentences, pad short ones with '啊'
    sent = (sent + ['啊'] * max(0, 8 - len(sent)))[:8]

    # Random split of a random position (2..10) into two non-zero parts
    designated_POS = random.randint(a=2,b=10)
    designated_POS1 = random.randint(a=1,b=designated_POS-1)
    designated_POS2 = designated_POS - designated_POS1

    # Any character of any corpus sentence
    designated_word = random.choice(random.choice(corpus))

    control_signal = [designated_word, str(designated_POS1), str(designated_POS2)]

    # Override: always probe the fixed split '4' + '2'
    control_signal = [designated_word, '4', '2']

    data = ['<SOS>', *sent, '<EOS>', *control_signal]

    print(data)

    certain_set.append(data)

dataset_certain = TestDataset(certain_set)

['<SOS>', '重', '温', '几', '次', '啊', '啊', '啊', '啊', '<EOS>', '极', '4', '2']
['<SOS>', '一', '千', '杯', '不', '醉', '不', '想', '睡', '<EOS>', '黑', '4', '2']
['<SOS>', '让', '美', '好', '带', '来', '欢', '笑', '在', '<EOS>', '般', '4', '2']
['<SOS>', '是', '否', '你', '让', '我', '受', '的', '不', '<EOS>', '中', '4', '2']
['<SOS>', '为', '什', '么', '留', '下', '这', '个', '结', '<EOS>', '唉', '4', '2']
['<SOS>', '有', '了', '你', '啊', '啊', '啊', '啊', '啊', '<EOS>', '意', '4', '2']
['<SOS>', '我', '的', '舞', '台', '自', '己', '主', '宰', '<EOS>', '望', '4', '2']
['<SOS>', '唱', '给', '那', '的', '小', '孩', '啊', '啊', '<EOS>', '只', '4', '2']
['<SOS>', '夜', '山', '森', '木', '啊', '啊', '啊', '啊', '<EOS>', '爱', '4', '2']
['<SOS>', '一', '一', '检', '点', '啊', '啊', '啊', '啊', '<EOS>', '命', '4', '2']
['<SOS>', '我', '却', '不', '知', '不', '觉', '的', '啊', '<EOS>', '便', '4', '2']
['<SOS>', '呀', '啦', '嘿', '啊', '啊', '啊', '啊', '啊', '<EOS>', '开', '4', '2']
['<SOS>', '远', '的', '可', '以', '把', '过', '去', '遗', '<EOS>', '须', '4', '2']
['<SOS>', '你', '或', '者', '一', '直', '觉'

## Predict on the fixed-condition data

In [21]:
# Run Deconstruction over the probe set and keep the arg-max token ids
dataloader = DataLoader(dataset_certain,
                        batch_size=128,
                        shuffle=False,
                        num_workers=0,
                        collate_fn=dataset_certain.collate_fn)

predictions = []
trange = tqdm(dataloader, total=len(dataloader))

for x, x_len in trange:

    x = x.to(device)

    outputs, gru_info = Deconstruction(model, x, x_len)

    # arg-max over the vocabulary axis -> one row of token ids per sample
    token_ids = outputs.cpu().detach().numpy().argmax(1)
    predictions.extend(token_ids)

HBox(children=(IntProgress(value=0, max=1), HTML(value='')))




In [22]:
# Cut every prediction right after its first <EOS> (appending one when it is
# missing), then map the ids back to characters.
predictions_set = []
for p in predictions:
    p = list(p)
    p = p[:p.index(EOS) + 1] if EOS in p else p + [EOS]
    predictions_set.append(embedder.unTokenize(p))


# Show every probe input next to its prediction
for sent_id, sent_tokens in enumerate(certain_set):
    sent_in = ''.join(sent_tokens)
    sent_pred = ''.join(predictions_set[sent_id])
    print('input:\t{}\npred:\t{}\n'.format(sent_in, sent_pred))

input:	<SOS>重温几次啊啊啊啊<EOS>极42
pred:	<SOS>我是的的的极<EOS>

input:	<SOS>一千杯不醉不想睡<EOS>黑42
pred:	<SOS>我不的的的黑<EOS>

input:	<SOS>让美好带来欢笑在<EOS>般42
pred:	<SOS>我是的的的般<EOS>

input:	<SOS>是否你让我受的不<EOS>中42
pred:	<SOS>我在你在心中<EOS>

input:	<SOS>为什么留下这个结<EOS>唉42
pred:	<SOS>你哈你的唉唉<EOS>

input:	<SOS>有了你啊啊啊啊啊<EOS>意42
pred:	<SOS>我你你的的意<EOS>

input:	<SOS>我的舞台自己主宰<EOS>望42
pred:	<SOS>我在我的希望<EOS>

input:	<SOS>唱给那的小孩啊啊<EOS>只42
pred:	<SOS>我的的的我只是<EOS>

input:	<SOS>夜山森木啊啊啊啊<EOS>爱42
pred:	<SOS>我妹的的的爱<EOS>

input:	<SOS>一一检点啊啊啊啊<EOS>命42
pred:	<SOS>我是的的生命<EOS>

input:	<SOS>我却不知不觉的啊<EOS>便42
pred:	<SOS>我不的的我便<EOS>

input:	<SOS>呀啦嘿啊啊啊啊啊<EOS>开42
pred:	<SOS>我妹我我离开<EOS>

input:	<SOS>远的可以把过去遗<EOS>须42
pred:	<SOS>我不我不必须<EOS>

input:	<SOS>你或者一直觉得这<EOS>觉42
pred:	<SOS>我不我我感觉<EOS>

input:	<SOS>对折再对折轻轻把<EOS>你42
pred:	<SOS>我不我不爱你<EOS>

input:	<SOS>在我和你的痛苦中<EOS>埃42
pred:	<SOS>我你你的尘埃<EOS>



In [23]:
def _batch_mean(name):
    """Average the recorded gru_info tensor `name` over the batch, as a numpy array."""
    return gru_info[name].mean(0).detach().cpu().numpy()

updategates = _batch_mean('updategates')
resetgates = _batch_mean('resetgates')
newgates = _batch_mean('newgates')

# The contexts are kept per sample (no averaging)
context = gru_info['context'].detach().cpu().numpy()

# Number of decoding steps that were recorded
output_len = updategates.shape[1]

updategates.shape , resetgates.shape , newgates.shape , context.shape


((128, 9), (128, 9), (128, 9), (16, 128))

## context_set[i] collects the Encoder contexts (h0) of the input sentences whose designated position is i

In [24]:
# Group the encoder contexts by designated position: context_set[p] collects
# the contexts of all inputs whose trailing position tokens sum to p.
# NOTE: `x` is the last batch left over from the prediction loop and `context`
# comes from the last gru_info, so this cell only covers that final batch.
context_set = [[] for _ in range(embedder.maxPOS + 1)]

for xi,hi in zip(x,context):
    xi = xi.cpu().numpy()
    pos = 0
    w = None
    # Walk the input backwards, summing the trailing position tokens
    for token in xi[::-1]:
        if token == PAD:
            # Skip right padding; it used to end the scan at once, which
            # filed every padded sample under position 0.
            continue
        w = embedder.to_word(token)
        if w.isdigit():
            pos += int(w)
        else:
            # first non-digit token = the designated character
            break
    print(pos,w)
    context_set[pos].append(hi)

for i in range(1,len(context_set)):
    # Explicit emptiness check instead of a bare `except:`, which also hid
    # genuine stacking errors such as mismatched shapes.
    if len(context_set[i]) > 0:
        context_set[i] = np.stack(context_set[i],axis=0)
    else:
        print('pos',i,'is empty')

6 极
6 黑
6 般
6 中
6 唉
6 意
6 望
6 只
6 爱
6 命
6 便
6 开
6 须
6 觉
6 你
6 埃
pos 1 is empty
pos 2 is empty
pos 3 is empty
pos 4 is empty
pos 5 is empty
pos 7 is empty
pos 8 is empty
pos 9 is empty
pos 10 is empty
pos 11 is empty
pos 12 is empty
