In [None]:
import torch
import random
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch import Tensor
import math
import time
from torch.utils.data import Dataset

In [None]:
"""
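Input data must be a CSV file with a header row and 4 columns: ['theme', 'keyword', 'src', 'tgt'].
Rows are split on ',' without quote handling, so fields must not contain commas;
tokens inside a field are separated by single spaces. The 'theme' field must be exactly one token.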
"""
# Fall back to CPU so the notebook still runs (slowly) on machines without a GPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# Seed both RNGs: `random` drives teacher-forcing draws, torch drives weight init
# and DataLoader shuffling.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)
BATCH_SIZE = 64
TEST_BATCH_SIZE = 10
HIDDEN_SIZE = 256
N_EPOCHS = 51
CLIP = 1  # max gradient norm passed to clip_grad_norm_
LEARNING_RATE = 3e-5
ENC_DROPOUT = 0.2  # currently unused: the Encoder never applies its dropout
DEC_DROPOUT = 0.2  # currently unused: the Decoder never applies its dropout
TEACH_RATE = 1  # probability of feeding the ground-truth token during training
LAMBDA_COVERAGE = 1  # weight of the coverage-loss term
THEME_LEN = 1  # fixed at 1
KEYWORD_LEN = 1  # fixed at 1
SRC_LEN = 80
TGT_LEN = 80
VOCAB_MIN_FREQUENCY = 1  # words seen fewer times than this are mapped to <unk>
GENERATION_LEN = 80  # the model decodes GENERATION_LEN - 1 steps
CHECKPOINT_TIMES = 8  # save the model and a checkpoint every CHECKPOINT_TIMES epochs
TRAINING_DATA_PATH = '...'
TEST_DATA_PATH = '...'

In [None]:
class Lang:
    """Vocabulary built from space-separated text.

    Usage: call ``get_frequency`` on every sentence to accumulate raw counts,
    then ``get_vocab(n)`` to assign indices to every word seen at least ``n``
    times. Indices 0-3 are reserved for ``<unk>``, ``<pad>``, ``<sos>`` and
    ``<eos>``.
    """

    def __init__(self, name):
        self.name = name
        self.wordfrequency = {}
        specials = ('<unk>', '<pad>', '<sos>', '<eos>')
        self.word2index = {token: idx for idx, token in enumerate(specials)}
        self.word2count = dict.fromkeys(specials, 0)
        self.index2word = dict(enumerate(specials))
        self.n_words = len(specials)  # the four reserved tokens

    def get_vocab(self, n):
        """Register every counted word whose frequency is at least ``n``."""
        for token in list(self.wordfrequency):
            self.add_word(token, n)

    def get_frequency(self, sentence):
        """Count each token of ``sentence`` (split on single spaces)."""
        for token in sentence.split(' '):
            self.frequency(token)

    def add_word(self, word, n):
        """Index ``word`` if its frequency reaches ``n``; otherwise bump the ``<unk>`` count.

        Raises KeyError if ``word`` was never counted by ``frequency``.
        """
        freq = self.wordfrequency[word]
        if freq >= n:
            if word not in self.word2index:
                new_idx = self.n_words
                self.word2index[word] = new_idx
                self.index2word[new_idx] = word
                self.n_words = new_idx + 1
            self.word2count[word] = freq
        else:
            self.word2count['<unk>'] += 1

    def frequency(self, word):
        """Increment the raw occurrence count of ``word``."""
        self.wordfrequency[word] = self.wordfrequency.get(word, 0) + 1

In [None]:
# Build the shared vocabulary from all four columns of the training CSV.
with open(TRAINING_DATA_PATH, 'r', encoding='utf-8-sig') as r:
    train_lines = r.readlines()[1:]  # skip the header row

lang = Lang('train')  # the name is not used elsewhere in this notebook
for lineno, line in enumerate(train_lines, start=2):
    if not line.strip():
        continue  # tolerate blank lines instead of failing with IndexError
    columns = line.split(',')
    if len(columns) < 4:
        raise ValueError(
            f"{TRAINING_DATA_PATH}:{lineno}: expected 4 comma-separated columns, got {len(columns)}"
        )
    # Same stripping as get_datas so vocabulary tokens match the indexed tokens.
    fields = [col.strip(' ').strip('\n').strip(' ') for col in columns[:4]]
    lang.get_frequency(' '.join(fields))
lang.get_vocab(VOCAB_MIN_FREQUENCY)

In [None]:
def indexes_from_sentence(lang, sentence, n):
    """Map a space-separated ``sentence`` to vocabulary indices of length ``n``.

    Unknown words become ``<unk>``. Shorter sequences are right-padded with
    ``<pad>``. Longer sequences are cut to ``n - 1`` tokens followed by
    ``<eos>``. A sequence of exactly ``n`` tokens is returned as is (no ``<eos>``).
    """
    vocab = lang.word2index
    ids = [vocab[w] if w in vocab else vocab['<unk>'] for w in sentence.split(' ')]
    shortfall = n - len(ids)
    if shortfall > 0:
        ids.extend([vocab['<pad>']] * shortfall)
    elif shortfall < 0:
        ids = ids[:n - 1] + [vocab['<eos>']]
    return ids

def tensor_from_sentence(lang, sentence, n):
    """Encode ``sentence`` as a fixed-length 1-D ``torch.long`` tensor on ``device``.

    Padding/truncation rules are those of ``indexes_from_sentence``.
    """
    token_ids = indexes_from_sentence(lang, sentence, n)
    encoded = torch.tensor(token_ids, dtype=torch.long, device=device)
    return encoded.view(-1)

def tensors_from_pair(pair,  theme_len, keyword_len, src_len, tgt_len):
    """Convert a ``[theme, keyword, src, tgt]`` row of strings into four index tensors.

    Each field is encoded with the module-level vocabulary ``lang`` and fixed to
    its own length. Returns the tuple ``(theme, keyword, src, tgt)``.
    """
    lengths = (theme_len, keyword_len, src_len, tgt_len)
    return tuple(tensor_from_sentence(lang, pair[col], size) for col, size in enumerate(lengths))

In [None]:
def get_datas(data_path, theme_len=THEME_LEN, keyword_len=KEYWORD_LEN, src_len=SRC_LEN, tgt_len=TGT_LEN):
    """Load a 4-column CSV into a list of ``(theme, keyword, src, tgt)`` index tensors.

    The first line is treated as a header and skipped. Rows are split on ','
    (no quote handling), and blank lines are ignored.

    Raises:
        ValueError: if a non-blank row has fewer than 4 columns.
    """
    datas = []
    with open(data_path, 'r', encoding='utf-8-sig') as r:
        next(r, None)  # skip the header row
        for lineno, line in enumerate(r, start=2):
            if not line.strip():
                continue  # tolerate blank lines instead of failing with IndexError
            columns = line.split(',')
            if len(columns) < 4:
                raise ValueError(
                    f"{data_path}:{lineno}: expected 4 comma-separated columns, got {len(columns)}"
                )
            fields = [col.strip(' ').strip('\n').strip(' ') for col in columns[:4]]
            datas.append(tensors_from_pair(fields, theme_len, keyword_len, src_len, tgt_len))
    return datas

In [None]:
class Create_Dataset(Dataset):
    """Map-style ``Dataset`` over an in-memory sequence of samples.

    Items are returned exactly as stored. In this notebook they are the
    ``(theme, keyword, src, tgt)`` tuples produced by ``get_datas``.
    """

    def __init__(self, data):
        # Keep a reference to the caller's sequence, not a copy.
        self.datas = data

    def __getitem__(self, idx):
        return self.datas[idx]

    def __len__(self):
        return len(self.datas)

In [None]:
# Training data: index tensors from get_datas, wrapped for shuffled batching.
train_datas = get_datas(TRAINING_DATA_PATH)
train_dataset = Create_Dataset(train_datas)
train_dataloader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0,
)
# Vocabulary size, shared by every embedding table and the output projection.
DATA_SIZE = len(lang.index2word)
print(DATA_SIZE)

In [None]:
class Encoder(nn.Module):
    """Bidirectional-GRU source encoder fused with a theme/keyword feature.

    ``forward(theme, keyword, src)`` receives time-major index tensors of shape
    ``(len, batch)`` (``train``/``eval`` build them with ``permute(1, 0)``) and
    returns:
      * ``enc_out``: (src_len, batch, hidden) per-position memory for attention;
      * ``enc_hid``: (1, batch, hidden) initial decoder hidden state.

    Notes:
      * ``point`` reshapes the projected theme to (batch, 1, 1, hidden), so the
        theme length must be 1 (the config fixes THEME_LEN = 1).
      * ``self.dropout`` is constructed but never applied.
      * theme, keyword and src have three separate embedding tables of the same
        vocabulary size ``data_size``.
    """
    
    
    def __init__(self, data_size, hidden_size, dropout):
        
        super().__init__()
        self.hidden_size = hidden_size
        self.theme_embedding = nn.Embedding(data_size, hidden_size)
        self.keyword_embedding = nn.Embedding(data_size, hidden_size)
        self.src_embedding = nn.Embedding(data_size, hidden_size)
        self.rnn = nn.GRU(hidden_size, hidden_size, bidirectional=True)  # 1 layer, time-major input
        self.enc_hid_out = nn.Linear(hidden_size*2, hidden_size, bias=False)  # bi-GRU final state -> decoder state
        self.enc_a_weight = nn.Linear(hidden_size*2, hidden_size, bias=True)  # projects bi-GRU outputs in `point`
        self.enc_theme_weight = nn.Linear(hidden_size, hidden_size)  # projects theme embedding in `point`
        self.enc_keyword_weight = nn.Linear(hidden_size, hidden_size)  # projects keyword embeddings in `point`
        self.enc_weight = nn.Linear(hidden_size, 1)  # scalar score per (src position, keyword)
        self.enc_out_out = nn.Linear(hidden_size*3, hidden_size, bias=True)  # [point feature; bi-GRU output] -> hidden
        self.dropout = nn.Dropout(dropout)  # NOTE: never applied in forward
        
    def point(self, enc_theme, enc_keyword, enc_a): 
        """Theme/keyword relevance feature for every source position.

        Args:
            enc_theme: (batch, 1, hidden) theme embeddings (theme length must be 1).
            enc_keyword: (batch, keyword_len, hidden) keyword embeddings.
            enc_a: (batch, src_len, hidden*2) bi-GRU outputs.

        Returns:
            (batch, src_len, hidden): for each source position, the projected theme
            vector scaled by the sum over keywords of an unnormalised score
            ``enc_weight(tanh(W_a enc_a + W_t theme + W_k keyword))``.
            NOTE(review): the scores are not passed through softmax/sigmoid.
        """
        ps_l = enc_a.size(1)
        batch_size, pt_l, _ = enc_keyword.size()
        b1 = self.enc_a_weight(enc_a.contiguous().view(-1, self.hidden_size*2)) # batch*src_len, hidden
        b1 = b1.view(batch_size, ps_l, 1, self.hidden_size) # batch, src_len, 1, hidden
        b1 = b1.expand(batch_size, ps_l, pt_l, self.hidden_size) # batch, src_len, keyword_len, hidden
        
        
        # The view to (batch, 1, 1, hidden) below only works when theme_len == 1.
        b22 = self.enc_theme_weight(enc_theme.contiguous().view(-1, self.hidden_size)) # batch*theme_len, hidden 
        b2 = b22.view(batch_size, 1, 1, self.hidden_size) # batch, 1, 1, hidden
        b2 = b2.expand(batch_size, ps_l, pt_l, self.hidden_size) # batch, src_len, keyword_len, hidden
        
        b3 = self.enc_keyword_weight(enc_keyword.contiguous().view(-1, self.hidden_size)) # batch*keyword_len, hidden
        b3 = b3.view(batch_size, 1, pt_l, self.hidden_size) # batch, 1, keyword_len, hidden
        b3 = b3.expand(batch_size, ps_l, pt_l, self.hidden_size) # batch, src_len, keyword_len, hidden
        
        b = torch.tanh(b1 + b2 + b3) # batch, src_len, keyword_len, hidden_size (keyword_len=theme_len=1)!!!!!!

        enc_w_0 = self.enc_weight(b.view(-1, self.hidden_size)).view(batch_size, ps_l, pt_l) # batch, src_len, keyword_len
        
        # (batch, src_len, keyword_len) x (batch, keyword_len, hidden): weighted copies of the projected theme.
        return torch.bmm(enc_w_0, b22.view(batch_size, 1, self.hidden_size).expand(batch_size, pt_l, self.hidden_size)) # batch, src_len, hidden
    
    def forward(self, theme, keyword, src):
        """Encode one batch.

        Args:
            theme: (theme_len, batch) token ids.
            keyword: (keyword_len, batch) token ids.
            src: (src_len, batch) token ids.

        Returns:
            enc_out: (src_len, batch, hidden); enc_hid: (1, batch, hidden).
        """
        batch_size = src.size(1)
        src = src.view(-1, batch_size) # src_len, batch
        theme = theme.view(-1, batch_size) # theme_len, batch
        keyword = keyword.view(-1, batch_size) # keyword_len, batch
        s_l = src.size(0)
        t_l = theme.size(0)  # NOTE: currently unused
        src_emb = self.src_embedding(src) # src_len, batch, hidden
        theme_emb = self.theme_embedding(theme) # theme_len, batch, hidden
        keyword_emb = self.keyword_embedding(keyword) # keyword_len, batch, hidden
        enc_output, enc_hidden = self.rnn(src_emb) # src_len, batch, hidden*2; 2, batch, hidden
        # Concatenate the forward (even index) and backward (odd index) final states.
        enc_hidden = torch.cat([enc_hidden[0:enc_hidden.size(0):2],
                                enc_hidden[1:enc_hidden.size(0):2]], 2) # 1, batch, hidden*2
        # Splits and re-joins the two directions in the same order, so this equals enc_output.
        enc_a = torch.cat([enc_output[:,:,:self.hidden_size],
                          enc_output[:,:,self.hidden_size:]], 2) # src_len, batch, hidden*2
        enc_theme = theme_emb.transpose(0, 1) # batch, theme_len, hidden
        enc_a = enc_a.transpose(0, 1) # batch, src_len, hidden*2
        enc_keyword = keyword_emb.transpose(0, 1) # batch, keyword_len, hidden
        enc_w_2 = self.point(enc_theme, enc_keyword, enc_a) # batch, src_len, hidden

        # Fuse the theme/keyword feature with the bi-GRU output at every source position.
        concat_enc_w = torch.cat([enc_w_2, enc_a], 2).view(batch_size*s_l, self.hidden_size*3) # batch*src_len, hidden*3
        enc_out = self.enc_out_out(concat_enc_w).view(batch_size, s_l, self.hidden_size) # batch, src_len, hidden
        enc_out = enc_out.transpose(0, 1).contiguous() # src_len, batch, hidden
        
        enc_hid = self.enc_hid_out(enc_hidden) # 1, batch, hidden
        #enc_out = self.dropout(dec_out)  # NOTE(review): dead code; `dec_out` is undefined in this method
        
        return enc_out, enc_hid # src_len, batch, hidden; 1, batch, hidden

In [None]:
class Attention(nn.Module):
    """Additive attention over encoder memory, with an optional coverage input.

    score(d, e) = v^T tanh(W_q d + b_q + W_c e) for every (target, source) pair,
    normalised with softmax over source positions. When a coverage tensor is
    supplied, each source position's coverage value is projected by
    ``linear_cover``, added to the memory and passed through tanh. That modified
    memory is used both for scoring and for building the context vector.
    The output is ``linear_out([context; decoder_state])`` (no final tanh).

    NOTE(review): padded source positions are not masked out of the softmax.
    """
    
    
    def __init__(self,
                 hidden_size):
        
        super().__init__()
        
        self.hidden_size = hidden_size
        self.linear_query = nn.Linear(hidden_size, hidden_size, bias=True)  # W_q (decoder side)
        self.linear_context = nn.Linear(hidden_size, hidden_size, bias=False)  # W_c (encoder side)
        self.v = nn.Linear(hidden_size, 1, bias=False)  # scoring vector
        self.linear_out = nn.Linear(hidden_size*2, hidden_size, bias=True)  # [context; decoder state] -> hidden
        self.linear_cover = nn.Linear(1, hidden_size, bias=False)  # per-position coverage scalar -> hidden
        
    def score(self, a_d_o, a_e_o):
        """Unnormalised alignment scores.

        Args:
            a_d_o: (batch, tgt_len, hidden) decoder states. It is viewed without
                ``.contiguous()``, and ``forward`` passes a contiguous tensor.
            a_e_o: (batch, src_len, hidden) encoder memory.

        Returns:
            (batch, tgt_len, src_len) scores.
        """
        batch_size, tgt_len, _ = a_d_o.size()
        src_len = a_e_o.size(1)
        a1 = self.linear_query(a_d_o.view(-1, self.hidden_size)) # batch*tgt_len, hidden
        a1 = a1.view(batch_size, tgt_len, 1, self.hidden_size) # batch, tgt_len, 1, hidden
        a1 = a1.expand(batch_size, tgt_len, src_len, self.hidden_size) # batch, tgt_len, src_len, hidden
        
        a2 = self.linear_context(a_e_o.contiguous().view(-1, self.hidden_size)) # batch*src_len, hidden 
        a2 = a2.view(batch_size, 1, src_len, self.hidden_size)
        a2 = a2.expand(batch_size, tgt_len, src_len, self.hidden_size) # batch, tgt_len, src_len, hidden
        
        a = torch.tanh(a1 + a2) # batch, tgt_len, src_len, hidden_size

        return self.v(a.view(-1, self.hidden_size)).view(batch_size, tgt_len, src_len) # batch, tgt_len, src_len
    
    def forward(self, attn_dec_state, attn_enc_state, attn_coverage):
        """Attend from decoder states to encoder memory.

        Args:
            attn_dec_state: (tgt_len, batch, hidden) decoder GRU output.
            attn_enc_state: (src_len, batch, hidden) encoder memory.
            attn_coverage: (tgt_len, batch, src_len) accumulated attention, or None.
                The reshape below assumes tgt_len == 1 when coverage is given.

        Returns:
            attn_h2: (tgt_len, batch, hidden) attentional hidden state;
            align_vectors2: (tgt_len, batch, src_len) softmax attention weights.
        """
        d_o = attn_dec_state.permute(1, 0, 2) # batch, tgt_len, hidden
        e_o = attn_enc_state.permute(1, 0, 2) # batch, src_len, hidden
        batch_size, target_l, _= d_o.size()
        source_l = e_o.size(1)
        
        if attn_coverage is not None:
            cover = attn_coverage.view(-1).unsqueeze(1) # tgt_len*batch*src_len, 1(tgt_len=1)
            a_o = self.linear_cover(cover).view(batch_size, source_l, self.hidden_size) # batch, src_len, hidden
            e_o = e_o + a_o # batch, src_len, hidden
            e_o = torch.tanh(e_o) # batch, src_len, hidden 
        
        align = self.score(d_o.contiguous(), e_o.contiguous()) # batch, tgt_len, src_len
        align_vectors = F.softmax(align.view(batch_size*target_l, source_l), -1) # batch*tgt_len, src_len
        align_vectors = align_vectors.view(batch_size, target_l, source_l) # batch, tgt_len, src_len
        c = torch.bmm(align_vectors, e_o)# batch, tgt_len, hidden  (context vectors)

        concat_c = torch.cat([c, d_o], 2).view(batch_size*target_l, self.hidden_size*2) # batch*tgt_len, hidden*2
        attn_h = self.linear_out(concat_c).view(batch_size, target_l, self.hidden_size) # batch, tgt_len, hidden
        attn_h2 = attn_h.permute(1, 0, 2).contiguous() # tgt_len, batch, hidden
        align_vectors2 = align_vectors.permute(1, 0, 2).contiguous() # tgt_len, batch, src_len
        
        return attn_h2, align_vectors2 # tgt_len, batch, hidden; tgt_len, batch, src_len

In [None]:
class Decoder(nn.Module):
    """Single-step GRU decoder with attention and optional coverage accumulation.

    Each ``forward`` call decodes one time step; ``S2S`` calls it in a loop.
    """

    def __init__(self, hidden_size, data_size, attention, dropout):
        super().__init__()
        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(data_size, hidden_size)
        self.GRU = nn.GRU(hidden_size, hidden_size)
        self.attn = attention
        self.dec_out = nn.Linear(hidden_size, data_size, bias=False)
        # NOTE: constructed but not applied in forward (DEC_DROPOUT is effectively unused).
        self.dropout = nn.Dropout(dropout)

    def forward(self, dec_tgt, dec_hidden, attn_memory, dec_g_t, teach, coverage, dec_coverage=True):
        """Run one decoding step.

        Args:
            dec_tgt: (1, batch) ids of the previous prediction (or <sos>).
            dec_hidden: (1, batch, hidden) GRU state.
            attn_memory: (src_len, batch, hidden) encoder memory passed to ``self.attn``.
            dec_g_t: (1, batch) ground-truth ids for this step, or None.
            teach: teacher-forcing probability. When ``dec_g_t`` is given, the
                ground truth is fed with probability ``teach`` (0 -> never, 1 -> always).
            coverage: running sum of earlier attention distributions, or None.
            dec_coverage: if True, add this step's attention to the returned coverage.

        Returns:
            dec_out: (batch, vocab) unnormalised scores (tgt_len*batch rows in general);
            dec_hidden: (1, batch, hidden) new GRU state;
            dec_cov: this step's attention weights from ``self.attn``;
            coverage: updated coverage, or the input ``coverage`` if ``dec_coverage`` is False.
        """
        # `random.random()` is in [0, 1), so `<` makes teach=0 strictly "never" and
        # teach=1 strictly "always". The RNG is only drawn when ground truth is available.
        use_ground_truth = dec_g_t is not None and random.random() < teach
        dec_emb = self.embedding(dec_g_t if use_ground_truth else dec_tgt)  # 1, batch, hidden
        dec_output, dec_hidden = self.GRU(dec_emb, dec_hidden)  # tgt_len, batch, hidden; 1, batch, hidden
        dec_output_attn, dec_cov = self.attn(dec_output, attn_memory, coverage)  # tgt_len, batch, hidden; tgt_len, batch, src_len
        if dec_coverage:
            coverage = dec_cov if coverage is None else dec_cov + coverage  # tgt_len, batch, src_len
        dec_out = self.dec_out(dec_output_attn.contiguous().view(-1, self.hidden_size))  # tgt_len*batch, vocab
        return dec_out, dec_hidden, dec_cov, coverage

In [None]:
class S2S(nn.Module):
    """Encoder-decoder wrapper that unrolls the ``Decoder`` one step at a time.

    ``forward`` returns ``(decoder_outputs, attns)``:
      * ``decoder_outputs``: (gen_len - 1, batch, output_size) log-probabilities;
        row ``i`` is the prediction for ``tgt[i + 1]``.
      * ``attns["std"][i]``: step ``i``'s attention (1, batch, src_len).
      * ``attns["coverage"][i]``: attention accumulated over steps ``< i``
        (zeros at step 0). This is the coverage term of the loss
        ``sum(min(attention_i, coverage_i))``.
    """

    def __init__(self, encoder, decoder, output_size, gen_len):
        super().__init__()
        self.gen_len = gen_len - 1  # number of output rows (the <sos> input is never predicted)
        self.encoder = encoder
        self.decoder = decoder
        self.output_size = output_size

    def forward(self, theme, keyword, src, tgt, teach):
        """Encode the inputs and greedily decode ``tgt_len - 1`` steps.

        Args:
            theme, keyword, src: (len, batch) token ids forwarded to the encoder.
            tgt: (tgt_len, batch) target ids. ``tgt[i]`` is offered to the decoder
                as ground truth at step ``i >= 1``. ``tgt_len - 1`` must not exceed
                ``gen_len - 1``.
            teach: teacher-forcing probability forwarded to the decoder.
        """
        attns = {"std": [], "coverage": []}
        tgt_len, batch_size = tgt.size()
        encoder_output, encoder_hidden = self.encoder(theme, keyword, src)  # src_len, batch, hidden; 1, batch, hidden
        decoder_hidden = encoder_hidden.view(1, batch_size, -1)  # 1, batch, hidden
        # Allocate on the inputs' device instead of relying on a global `device`.
        decoder_outputs = torch.ones(self.gen_len, batch_size, self.output_size, device=tgt.device)
        # Every sequence starts from <sos> (index 2 among Lang's reserved tokens).
        current_tgt = torch.full((1, batch_size), 2, dtype=torch.long, device=tgt.device)
        coverage = None
        for i in range(tgt_len - 1):
            g_t = None if i == 0 else tgt[i].view(1, batch_size)
            prev_coverage = coverage
            decoder_output, decoder_hidden, attn_std, coverage = self.decoder(
                current_tgt, decoder_hidden, encoder_output, g_t, teach, prev_coverage)
            attns["std"].append(attn_std)
            # Record coverage from *before* this step. The post-update value
            # (attn + previous) would make min(attn, coverage) == attn. Its sum is
            # constant because softmax rows sum to 1, so the coverage loss would
            # have no gradient.
            attns["coverage"].append(
                torch.zeros_like(attn_std) if prev_coverage is None else prev_coverage)
            log_probs = F.log_softmax(decoder_output.view(1, batch_size, -1), -1)  # 1, batch, output
            decoder_outputs[i] = log_probs
            current_tgt = log_probs.max(2)[1].view(1, batch_size)  # greedy pick feeds the next step
        return decoder_outputs, attns

In [None]:
# Assemble the model: encoder, attention shared into the decoder, then the S2S wrapper.
encoder = Encoder(data_size=DATA_SIZE, hidden_size=HIDDEN_SIZE, dropout=ENC_DROPOUT)
attention = Attention(hidden_size=HIDDEN_SIZE)
decoder = Decoder(
    hidden_size=HIDDEN_SIZE,
    data_size=DATA_SIZE,
    attention=attention,
    dropout=DEC_DROPOUT,
)

model = S2S(encoder=encoder, decoder=decoder, output_size=DATA_SIZE, gen_len=GENERATION_LEN).to(device)

In [None]:
def init_weights(m: nn.Module):
    """Re-initialise every parameter of ``m`` and its submodules.

    Parameters whose name contains 'weight' are drawn from N(0, 0.01^2). All
    others (biases) are set to 0. ``named_parameters()`` already recurses into
    submodules, so one call on the root module covers the whole model.
    """
    for name, param in m.named_parameters():
        if 'weight' in name:
            nn.init.normal_(param.data, mean=0, std=0.01)
        else:
            nn.init.constant_(param.data, 0)


# Call once on the root. `model.apply(init_weights)` would visit every submodule
# and, because init_weights recurses itself, re-draw each parameter once per
# ancestor module: the same distribution, but redundant work.
init_weights(model)

In [None]:
n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print('# generator parameters:', n_trainable)

In [None]:
PAD_IDX = lang.word2index['<pad>']
# Summed NLL over target tokens; positions holding <pad> contribute nothing.
criterion = nn.NLLLoss(reduction="sum", ignore_index=PAD_IDX)

In [None]:
def train(model, dataset, optimizer, criterion, clip, teach):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        model: the ``S2S`` model, called as ``model(theme, keyword, src, tgt, teach)``.
        dataset: iterable of batches (here a ``DataLoader``) yielding
            ``(theme, keyword, src, tgt)`` tensors shaped ``(batch, len)``.
        optimizer: optimiser stepped once per batch.
        criterion: token-level loss on log-probabilities (``NLLLoss`` here).
        clip: max gradient norm for ``clip_grad_norm_``.
        teach: teacher-forcing probability forwarded to the model.

    Returns:
        The sum over batches of (criterion loss + LAMBDA_COVERAGE * coverage loss),
        divided by the number of batches. Returns 0.0 when ``dataset`` yields no batches.
    """
    model.train()
    epoch_loss = 0.0
    n_batches = 0

    for theme, keyword, src, tgt in dataset:
        # The model expects time-major inputs: (len, batch).
        theme = theme.permute(1, 0)
        keyword = keyword.permute(1, 0)
        src = src.permute(1, 0)
        tgt = tgt.permute(1, 0)

        optimizer.zero_grad()
        outputs, attns = model(theme, keyword, src, tgt, teach)

        # Coverage loss: sum over steps of min(attention_t, coverage_t). It only
        # penalises anything if coverage_t excludes the attention of step t itself.
        covloss = 0
        for step_attn, step_cov in zip(attns.get("std", None), attns.get("coverage", None)):
            covloss += torch.min(step_attn, step_cov).sum()
        covloss *= LAMBDA_COVERAGE

        # outputs: (steps, batch, vocab) log-probs. Targets are tgt shifted by one step.
        log_probs = outputs.contiguous().view(-1, outputs.size(-1))
        gold = tgt.contiguous()[1:].view(-1)
        loss = criterion(log_probs, gold) + covloss

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()
        epoch_loss += loss.item()
        n_batches += 1

    return epoch_loss / n_batches if n_batches else 0.0

In [None]:
def epoch_time(start_time, end_time):
    """Split ``end_time - start_time`` (seconds) into ``(minutes, seconds)``.

    Both parts are truncated toward zero with ``int``.
    """
    seconds_total = end_time - start_time
    whole_minutes = int(seconds_total / 60)
    leftover = seconds_total - whole_minutes * 60
    return whole_minutes, int(leftover)

In [None]:
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE)
# Multiply the LR by 0.2 once the training loss has stopped improving for more
# than `patience` epochs.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, factor=0.2, patience=8, threshold=0.0001, cooldown=1, eps=1e-9
)
plot_loss = []

for epoch in range(N_EPOCHS):
    print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
    print(optimizer.param_groups[-1]['lr'])

    t_begin = time.time()
    train_loss = train(model, train_dataloader, optimizer, criterion, CLIP, TEACH_RATE)
    t_end = time.time()
    epoch_mins, epoch_secs = epoch_time(t_begin, t_end)

    print(f'Epoch: {epoch+1:02} | Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.2f}')

    plot_loss.append(train_loss)
    scheduler.step(train_loss)

    # Periodically save both the pickled model and a resumable checkpoint.
    if epoch % CHECKPOINT_TIMES == 0:
        model_name = f"..."
        para_name = f"..."
        torch.save(model, model_name)
        checkpoint = {
            "net": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "epoch": epoch+1,
            "scheduler": scheduler.state_dict(),
            "loss": plot_loss,
        }
        torch.save(checkpoint, para_name)

In [None]:
# Resume training from a checkpoint written by the training loop above.
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.2, patience=8, threshold=0.0001, cooldown=1, eps=1e-9)

path_checkpoint = "..."  # checkpoint path
# map_location lets a GPU-saved checkpoint be restored on the configured `device`.
checkpoint = torch.load(path_checkpoint, map_location=device)
model.load_state_dict(checkpoint['net'])
optimizer.load_state_dict(checkpoint['optimizer'])
start_epoch = checkpoint['epoch']  # stored as epoch + 1, i.e. the next epoch to run
scheduler.load_state_dict(checkpoint['scheduler'])
plot_loss = checkpoint['loss']

for epoch in range(start_epoch, N_EPOCHS):
    print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
    print(optimizer.param_groups[-1]['lr'])
    start_time = time.time()
    # Was `dataloader`, which is never defined (NameError); the training loader is `train_dataloader`.
    train_loss = train(model, train_dataloader, optimizer, criterion, CLIP, TEACH_RATE)

    end_time = time.time()
    epoch_mins, epoch_secs = epoch_time(start_time, end_time)
    print(f'Epoch: {epoch+1:02} | Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.2f}')

    plot_loss.append(train_loss)
    scheduler.step(train_loss)

    if epoch % CHECKPOINT_TIMES == 0:
        model_name = f"..."
        para_name = f"..."
        torch.save(model, model_name)
        checkpoint = {
                      "net": model.state_dict(),
                      "optimizer": optimizer.state_dict(),
                      "epoch": epoch+1,
                      "scheduler": scheduler.state_dict(),
                      "loss": plot_loss}
        torch.save(checkpoint, para_name)

In [None]:
"""# test load
model = torch.load("")
model.eval()"""

In [None]:
test_datas = get_datas(TEST_DATA_PATH)
test_dataset = Create_Dataset(test_datas)
# Evaluate with the dedicated TEST_BATCH_SIZE and keep file order (shuffle=False)
# so the generated sentences line up with the rows of TEST_DATA_PATH.
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=TEST_BATCH_SIZE, shuffle=False, num_workers=0)

In [None]:
def eval(model, dataset):
    """Greedy-decode every batch of ``dataset`` and return the generated text.

    NOTE: this name shadows Python's built-in ``eval`` in the notebook; it is
    kept so existing callers keep working.

    Args:
        model: the ``S2S`` model. It is run with ``teach=0`` (no teacher forcing).
            ``tgt`` is still passed because the model decodes ``tgt_len - 1`` steps.
        dataset: iterable of ``(theme, keyword, src, tgt)`` batches shaped ``(batch, len)``.

    Returns:
        One list per batch, holding one string per sample. Each string is the
        predicted tokens, each followed by a space, stopping after the first
        ``<eos>`` (inclusive).
    """
    model.eval()
    teach = 0
    result = []

    with torch.no_grad():  # inference only: no autograd graph needed
        for theme, keyword, src, tgt in dataset:
            theme = theme.permute(1, 0)
            keyword = keyword.permute(1, 0)
            src = src.permute(1, 0)
            tgt = tgt.permute(1, 0)
            outputs, _ = model(theme, keyword, src, tgt, teach)
            # (steps, batch) argmax ids -> one list of ids per sample. This sizes the
            # output by the real batch instead of a fixed TEST_BATCH_SIZE.
            pred_ids = outputs.contiguous().max(2)[1].t().tolist()

            ans = []
            for sample_ids in pred_ids:
                sent = ""
                for idx in sample_ids:
                    word = lang.index2word[idx]
                    sent += word
                    sent += ' '
                    if word == '<eos>':
                        break
                ans.append(sent)
            result.append(ans)
    return result

In [None]:
test_res = eval(model, test_dataloader)
# `eval` returns one list of sentences per batch. The old loop iterated the
# undefined name `ans` (NameError).
for batch_sentences in test_res:
    for sentence in batch_sentences:
        print(sentence)
        print('\n')

In [None]:
"""res = [[] for i in range(TEST_BATCH_SIZE)]
for i in test_res:
    for j in range(len(i)):
        tt = lang.index2word[i[j].cpu().numpy().tolist()]
        res[j].append(tt)
ans = []
for word in res:
    sent = ""
    for itera in word:
        sent += itera
        sent += ' '
        if itera == '<eos>':
            break
    ans.append(sent)
for sentence in ans:
    print(sentence)
    print('\n')"""