In [None]:
from torchtext.data import Field, BucketIterator, TabularDataset
import torch
import random
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch import Tensor
import math
import time
import torchtext.vocab as vocab

In [None]:
# Fall back to CPU so the notebook still runs on machines without a GPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
SEED = 42  # seeds Python's `random` and torch's RNG for reproducible runs
BATCH_SIZE = 20
HIDDEN_SIZE = 512
N_EPOCHS = 50
CLIP = 1  # max gradient norm passed to clip_grad_norm_ in train()
LEARNING_RATE = 4e-4
# Dropout rates handed to the Encoder/Decoder constructors in the model-construction cell;
# they must be defined for that cell to run.
ENC_DROPOUT = 0.2
DEC_DROPOUT = 0.2
TEACH_RATE = 1  # probability of feeding the ground-truth token at each decoder step
LAMBDA_COVERAGE = 1  # weight of the coverage penalty added to the training loss
# Decoded steps + 1. train() compares gen_len - 1 decoded steps with tgt[1:], so this must
# equal the padded target length (the Field's fix_length).
GENERATION_LEN = 80
# PRE_EMBEDDING_SIZE = 200  # unused: no pretrained embeddings are loaded in this notebook
# TODO: absolute, machine-specific path — make it relative to a configurable data directory.
TRAINING_DATA_PATH = 'c:/users/derri/PycharmProjects/exp/expdata/checklist/baselinedata12_8.csv'

random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# One Field shared by source and target, so both sides use a single vocabulary.
# Every sequence is cut/padded to GENERATION_LEN tokens: train() requires the padded target
# length to match the model's generation length. truncate_first is a boolean flag; True
# drops tokens from the start of over-long sequences.
# NOTE(review): no init_token/eos_token is configured, so the Field only adds its default
# <unk>/<pad> specials, while the S2S model starts decoding from vocabulary index 2 and the
# decode cell stops on '<eos>' — confirm the CSV already contains those markers.
DATA = Field(truncate_first=True, fix_length=GENERATION_LEN)

train_data = TabularDataset(
        path=TRAINING_DATA_PATH,
        format='csv',
        fields=[('src', DATA), ('tgt', DATA)],
        skip_header=True)

DATA.build_vocab(train_data, min_freq=1)

DATA_SIZE = len(DATA.vocab)

print(DATA_SIZE)

train_dataloader = BucketIterator(
    train_data,
    batch_size=BATCH_SIZE,
    device=device,
    shuffle=True)

In [None]:
class Encoder(nn.Module):
    """Bidirectional single-layer GRU encoder with projections back to ``hidden_size``.

    Args:
        input_size: source vocabulary size.
        hidden_size: embedding size and per-direction GRU hidden size.
        dropout: probability stored in ``self.dropout`` (not applied in ``forward``).
    """

    def __init__(self,
                 input_size,
                 hidden_size,
                 dropout):

        super().__init__()
        self.hidden_size = hidden_size
        self.src_embedding = nn.Embedding(input_size, hidden_size)
        self.rnn = nn.GRU(hidden_size, hidden_size, bidirectional=True)
        # Both projections map concatenated [forward | backward] features to hidden_size.
        self.enc_hid_out = nn.Linear(hidden_size*2, hidden_size, bias=False)
        self.enc_out_out = nn.Linear(hidden_size*2, hidden_size, bias=True)
        self.dropout = nn.Dropout(dropout)

    def forward(self, src):
        """Encode a (src_len, batch) tensor of token indices.

        Returns:
            enc_out: (src_len, batch, hidden) projected per-position outputs.
            enc_hid: (1, batch, hidden) projection of the final forward and backward
                states concatenated on the feature axis.

        Raises:
            ValueError: if ``src`` is not 2-D.
        """
        if src.dim() != 2:
            raise ValueError(f"expected src of shape (src_len, batch), got {tuple(src.shape)}")

        src_emb = self.src_embedding(src)  # src_len, batch, hidden
        enc_output, enc_hidden = self.rnn(src_emb)  # src_len, batch, hidden*2; 2, batch, hidden
        # GRU stacks directions per layer: even indices are forward, odd are backward.
        enc_hidden = torch.cat([enc_hidden[0::2], enc_hidden[1::2]], 2)  # 1, batch, hidden*2
        # The GRU output already holds [forward | backward] on its last dim, so it is
        # projected directly; nn.Linear keeps the (src_len, batch) leading dims.
        enc_out = self.enc_out_out(enc_output)  # src_len, batch, hidden
        enc_hid = self.enc_hid_out(enc_hidden)  # 1, batch, hidden

        return enc_out, enc_hid

In [None]:
class Attention(nn.Module):
    """Additive (MLP) attention with an optional coverage term.

    For each decoder position t and encoder position s the score is
    ``v(tanh(linear_query(d_t) + linear_context(e_s)))``; softmax-normalised scores weight
    the encoder states into a context vector, which is concatenated with the decoder
    state and projected by ``linear_out``.
    """
    def __init__(self, hidden_size):
        
        super().__init__()
        self.hidden_size = hidden_size
        self.linear_query = nn.Linear(hidden_size, hidden_size, bias=True)  # decoder-side projection
        self.linear_context = nn.Linear(hidden_size, hidden_size, bias=False)  # encoder-side projection
        self.v = nn.Linear(hidden_size, 1, bias=False)  # reduces tanh features to one score
        self.linear_out = nn.Linear(hidden_size*2, hidden_size, bias=True)  # [context; decoder] -> hidden
        self.linear_cover = nn.Linear(1, hidden_size, bias=False)  # lifts a scalar coverage value to hidden
        
    def score(self, a_d_o, a_e_o):
        """Return unnormalised scores of shape (batch, tgt_len, src_len).

        Args:
            a_d_o: decoder states, (batch, tgt_len, hidden) — its rank is fixed by the
                unpacking on the first line.
            a_e_o: encoder states, (batch, src_len, hidden).
        """
        batch_size, tgt_len, _ = a_d_o.size()
        src_len = a_e_o.size(1)
        a1 = self.linear_query(a_d_o.view(-1, self.hidden_size)) # batch*tgt_len, hidden
        a1 = a1.view(batch_size, tgt_len, 1, self.hidden_size) # batch, tgt_len, 1, hidden
        # Broadcast every decoder position against every encoder position.
        a1 = a1.expand(batch_size, tgt_len, src_len, self.hidden_size) # batch, tgt_len, src_len, hidden
        
        
        a2 = self.linear_context(a_e_o.contiguous().view(-1, self.hidden_size)) # batch*src_len, hidden 
        a2 = a2.view(batch_size, 1, src_len, self.hidden_size) # batch, 1, src_len, hidden
        a2 = a2.expand(batch_size, tgt_len, src_len, self.hidden_size) # batch, tgt_len, src_len, hidden
        
        a = torch.tanh(a1 + a2) # batch, tgt_len, src_len, hidden_size

        # v maps each hidden vector to a single scalar score.
        return self.v(a.view(-1, self.hidden_size)).view(batch_size, tgt_len, src_len) # batch, tgt_len, src_len
    
    def forward(self, attn_dec_state, attn_enc_state, attn_coverage):
        """Attend over the encoder states for each decoder position.

        Args:
            attn_dec_state: decoder output, (tgt_len, batch, hidden); permuted to
                batch-first below.
            attn_enc_state: encoder outputs, (src_len, batch, hidden).
            attn_coverage: accumulated attention from earlier steps, or None. When given
                it must hold exactly batch*src_len elements (it is reshaped to
                (batch, src_len, hidden) after projection), i.e. one target step.

        Returns:
            attn_h2: (tgt_len, batch, hidden) attentional hidden state.
            align_vectors2: (tgt_len, batch, src_len) softmax attention weights.
        """

        d_o = attn_dec_state.permute(1, 0, 2) # batch, tgt_len, hidden
        e_o = attn_enc_state.permute(1, 0, 2) # batch, src_len, hidden
        batch_size, target_l, _= d_o.size()
        source_l = e_o.size(1)
        if attn_coverage is not None:
            # Flattening assumes coverage is laid out (1, batch, src_len) — the layout this
            # module returns for align_vectors2 — so the order matches (batch, src_len).
            cover = attn_coverage.view(-1).unsqueeze(1) # tgt_len*batch*src_len, 1(tgt_len=1)
            a_o = self.linear_cover(cover).view(batch_size, source_l, self.hidden_size) # batch, src_len, hidden
            # Coverage is folded into the encoder states themselves: the adjusted e_o is
            # used both for scoring and as the values in the context bmm below.
            e_o = e_o + a_o # batch, src_len, hidden
            e_o = torch.tanh(e_o) # batch, src_len, hidden 
        align = self.score(d_o.contiguous(), e_o.contiguous()) # batch, tgt_len, src_len
        # NOTE(review): no padding mask is applied before the softmax, so padded source
        # positions can receive attention weight.
        align_vectors = F.softmax(align.view(batch_size*target_l, source_l), -1) # batch*tgt_len, src_len
        align_vectors = align_vectors.view(batch_size, target_l, source_l) # batch, tgt_len, src_len
        c = torch.bmm(align_vectors, e_o)# batch, tgt_len, hidden

        # concatenate context and decoder state, then project back to hidden size
        concat_c = torch.cat([c, d_o], 2).view(batch_size*target_l, self.hidden_size*2) # batch*tgt_len, hidden*2
        attn_h = self.linear_out(concat_c).view(batch_size, target_l, self.hidden_size) # batch, tgt_len, hidden
        attn_h2 = attn_h.permute(1, 0, 2).contiguous() # tgt_len, batch, hidden
        align_vectors2 = align_vectors.permute(1, 0, 2).contiguous() # tgt_len, batch, src_len
        return attn_h2, align_vectors2 # tgt_len, batch, hidden; tgt_len, batch, src_len

In [None]:
class Decoder(nn.Module):
    """Single-step GRU decoder with attention and optional coverage accumulation.

    Args:
        hidden_size: embedding and GRU hidden size.
        output_size: target vocabulary size.
        attention: module called as ``attention(dec_output, memory, coverage)`` that
            returns ``(attn_hidden, attn_weights)``.
        dropout: probability stored in ``self.dropout`` (not applied in ``forward``).
    """

    def __init__(self, hidden_size, output_size, attention, dropout):

        super().__init__()
        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(output_size, hidden_size)
        self.GRU = nn.GRU(hidden_size, hidden_size)
        self.attn = attention
        self.dec_out = nn.Linear(hidden_size, output_size, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, dec_tgt, dec_hidden, attn_memory, dec_g_t, teach, coverage, dec_coverage=True):
        """Run one decoding step.

        Args:
            dec_tgt: previous predicted token indices (S2S passes shape (1, batch)).
            dec_hidden: GRU hidden state, (1, batch, hidden).
            attn_memory: encoder outputs forwarded to the attention module.
            dec_g_t: ground-truth token indices for this step, or None.
            teach: teacher-forcing probability; when ``dec_g_t`` is given it is embedded
                instead of ``dec_tgt`` with this probability.
            coverage: running sum of attention from earlier steps, or None.
            dec_coverage: when False, ``coverage`` is returned unchanged.

        Returns:
            dec_out: (tgt_len*batch, output_size) unnormalised logits.
            dec_hidden: updated GRU state, (1, batch, hidden).
            dec_cov: this step's attention weights from the attention module.
            coverage: ``coverage`` plus this step's attention (or just this step's
                attention when ``coverage`` was None).
        """
        # random.random() is in [0, 1), so a strict `<` makes teach=1 always force and
        # teach=0 never force. No random draw happens without a ground-truth token.
        use_ground_truth = dec_g_t is not None and random.random() < teach
        dec_emb = self.embedding(dec_g_t if use_ground_truth else dec_tgt)  # 1, batch, hidden
        dec_output, dec_hidden = self.GRU(dec_emb, dec_hidden)  # tgt_len, batch, hidden; 1, batch, hidden
        dec_output_attn, dec_cov = self.attn(dec_output, attn_memory, coverage)  # tgt_len, batch, hidden; tgt_len, batch, src_len
        if dec_coverage:
            coverage = dec_cov if coverage is None else dec_cov + coverage  # tgt_len, batch, src_len
        dec_out = self.dec_out(dec_output_attn.contiguous().view(-1, self.hidden_size))  # tgt_len*batch, output

        return dec_out, dec_hidden, dec_cov, coverage

In [None]:
class S2S(nn.Module):
    """Encoder-decoder wrapper that unrolls the attentional decoder one step at a time.

    Args:
        encoder: module mapping ``src`` to ``(memory, hidden)``; ``hidden`` is reshaped
            to (1, batch, -1) as the decoder's initial state.
        decoder: module called per step as
            ``decoder(current_input, hidden, memory, ground_truth, teach, coverage)``.
        output_size: vocabulary size (width of each step's log-probabilities).
        gen_len: generation length; the output buffer holds ``gen_len - 1`` steps.
        start_idx: vocabulary index fed as the first decoder input (default 2).
    """

    # Class-level fallback so models saved whole with torch.save(model, ...) before
    # start_idx became a constructor argument still unpickle and run.
    start_idx = 2

    def __init__(self, encoder, decoder, output_size, gen_len, start_idx=2):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.output_size = output_size
        self.gen_len = gen_len-1
        self.start_idx = start_idx

    def forward(self, src, tgt, teach):
        """Decode ``tgt_len - 1`` steps.

        Args:
            src: source token indices passed to the encoder.
            tgt: (tgt_len, batch) target indices; ``tgt[i]`` is offered to the decoder
                as ground truth at step i >= 1.
            teach: teacher-forcing probability forwarded to the decoder.

        Returns:
            decoder_outputs: (gen_len - 1, batch, output_size) log-probabilities. Only
                the first ``tgt_len - 1`` rows are written; any remaining rows keep the
                placeholder value 1.
            attns: dict with "std" (list of per-step attention weights) and "coverage"
                (list of running attention sums, each including that step).
        """
        tgt_len, batch_size = tgt.size()
        # Follow the inputs' device instead of relying on a notebook-level global.
        dev = tgt.device
        encoder_output, encoder_hidden = self.encoder(src)  # src_len, batch, hidden; 1, batch, hidden
        decoder_hidden = encoder_hidden.view(1, batch_size, -1)  # 1, batch, hidden
        decoder_outputs = torch.ones(self.gen_len, batch_size, self.output_size, device=dev)
        current_tgt = torch.full((1, batch_size), self.start_idx, dtype=torch.long, device=dev)
        attns = {"std": [], "coverage": []}
        coverage = None
        for i in range(tgt_len - 1):
            # Step 0 has no ground truth; later steps may be teacher-forced with tgt[i].
            g_t = None if i == 0 else tgt[i].view(1, batch_size)
            step_logits, decoder_hidden, step_attn, coverage = self.decoder(
                current_tgt, decoder_hidden, encoder_output, g_t, teach, coverage)
            attns["std"].append(step_attn)
            attns["coverage"].append(coverage)
            log_probs = F.log_softmax(step_logits.view(1, batch_size, -1), -1)  # 1, batch, output
            decoder_outputs[i] = log_probs
            # Greedy choice becomes the next step's input when not teacher-forced.
            current_tgt = log_probs.argmax(2).view(1, batch_size)

        return decoder_outputs, attns

In [None]:
# Assemble encoder -> attention -> decoder (construction order unchanged so parameter
# initialisation draws from the RNG in the same order), then move the model to `device`.
encoder = Encoder(input_size=DATA_SIZE, hidden_size=HIDDEN_SIZE, dropout=ENC_DROPOUT)
attention = Attention(hidden_size=HIDDEN_SIZE)
decoder = Decoder(hidden_size=HIDDEN_SIZE, output_size=DATA_SIZE,
                  attention=attention, dropout=DEC_DROPOUT)

model = S2S(encoder, decoder, output_size=DATA_SIZE, gen_len=GENERATION_LEN).to(device)

In [None]:
def init_weights(m: nn.Module):
    """Re-initialise every parameter reachable from ``m`` via ``named_parameters()``.

    Parameters whose name contains 'weight' are drawn from a normal distribution with
    mean 0 and std 0.01; every other parameter (the biases) is set to 0.
    """
    for param_name, tensor in m.named_parameters():
        if 'weight' not in param_name:
            nn.init.constant_(tensor.data, 0)
            continue
        nn.init.normal_(tensor.data, mean=0, std=0.01)

# init_weights already walks every parameter through named_parameters(), so one call on
# the root initialises each parameter exactly once. Module.apply would also invoke it on
# every submodule, re-drawing nested parameters once per ancestor.
init_weights(model)

In [None]:
# Count only parameters that the optimizer will update.
n_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print('# generator parameters:', n_trainable_params)

In [None]:
# Index of the padding token; targets equal to it are excluded from the summed NLL loss.
PAD_IDX = DATA.vocab.stoi['<pad>']
criterion = nn.NLLLoss(reduction="sum", ignore_index=PAD_IDX)

In [None]:
def _coverage_loss(std, cov, tgt, pad_idx):
    """Coverage penalty: sum over steps t of sum(min(attention_t, coverage before t)).

    ``std[t]`` is step t's attention and ``cov[t]`` the running attention sum *including*
    step t; both are assumed to be (1, batch, src_len), as the S2S model appends them.
    Steps whose target token equals ``pad_idx`` contribute nothing.
    """
    total = 0
    # Coverage before step t is cov[t - 1]; at t = 0 it is zero, so min(a, 0) adds nothing.
    for t in range(1, len(std)):
        step = torch.min(std[t], cov[t - 1])
        # Step t predicts tgt[t + 1]; skip the penalty where that target is padding,
        # mirroring how the criterion ignores padded positions.
        if pad_idx is not None and t + 1 < tgt.size(0):
            keep = (tgt[t + 1] != pad_idx).to(step.dtype).view(1, -1, 1)
            step = step * keep
        total = total + step.sum()
    return total


def train(model, dataset, optimizer, criterion, clip, teach, lambda_coverage=None):
    """Train ``model`` for one pass over ``dataset`` and return the mean per-batch loss.

    Each batch's loss is ``criterion(log_probs, tgt[1:])`` plus
    ``lambda_coverage * coverage_penalty``. The penalty compares each step's attention
    with the coverage accumulated over the *previous* steps. Comparing it with coverage
    that already includes the current step would reduce to min(a_t, a_t + c) = a_t, a
    constant-sum term with no gradient.

    Args:
        model: called as ``model(src, tgt, teach)``; returns ``(log_probs, attns)`` with
            ``attns["std"]`` / ``attns["coverage"]`` per-step lists.
        dataset: iterable of batches with ``.src`` / ``.tgt`` (len, batch) tensors; must
            support ``len()``.
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss on (N, vocab) log-probs vs (N,) targets; its ``ignore_index``
            (if any) also masks the coverage penalty.
        clip: max gradient norm for ``clip_grad_norm_``.
        teach: teacher-forcing probability forwarded to the model.
        lambda_coverage: coverage penalty weight; None uses the global LAMBDA_COVERAGE.

    Returns:
        Sum of per-batch losses divided by ``len(dataset)``.
    """
    if lambda_coverage is None:
        lambda_coverage = LAMBDA_COVERAGE
    pad_idx = getattr(criterion, "ignore_index", None)
    epoch_loss = 0
    model.train()
    for batch in dataset:
        # NOTE(review): only the first two source tokens are fed to the model — confirm
        # this is intended (eval() slices the same way).
        src = batch.src[0:2, :]  # (s_len, batch)
        tgt = batch.tgt  # (t_len, batch)
        optimizer.zero_grad()
        outputs, attns = model(src, tgt, teach)

        std = attns.get("std") or []
        cov = attns.get("coverage") or []
        covloss = lambda_coverage * _coverage_loss(std, cov, tgt, pad_idx)

        log_probs = outputs.contiguous().view(-1, outputs.size(-1))
        gold = tgt[1:].contiguous().view(-1)  # step i predicts tgt[i + 1]
        loss = criterion(log_probs, gold) + covloss
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()
        epoch_loss += loss.item()

    return epoch_loss / len(dataset)

In [None]:
# AdamW over all model parameters; the scheduler multiplies the LR by 0.2 once the metric
# passed to scheduler.step() stops decreasing for `patience` epochs.
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.2,
    patience=8,
    threshold=0.0001,
    threshold_mode='rel',
    cooldown=3,
    min_lr=0,
    eps=1e-9,
    verbose=False,
)
plot_loss = []  # per-epoch training loss history
def epoch_time(start_time: float,
               end_time: float):
    """Split the elapsed time between two ``time.time()`` stamps into minutes and seconds.

    Args:
        start_time: start timestamp in seconds (a float, as returned by time.time()).
        end_time: end timestamp in seconds.

    Returns:
        (minutes, seconds) as ints, each truncated toward zero.
    """
    elapsed = end_time - start_time
    minutes = int(elapsed / 60)
    seconds = int(elapsed - minutes * 60)
    return minutes, seconds

for epoch in range(N_EPOCHS):
    print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
    print(optimizer.param_groups[-1]['lr'])
    start_time = time.time()
    train_loss = train(model, train_dataloader, optimizer, criterion, CLIP, TEACH_RATE)
    end_time = time.time()
    epoch_mins, epoch_secs = epoch_time(start_time, end_time)
    print(f'Epoch: {epoch+1:02} | Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.2f}')

    plot_loss.append(train_loss)
    scheduler.step(train_loss)
    # Checkpoint every 10 epochs and always after the final epoch, so the last trained
    # weights are never lost.
    if epoch % 10 == 0 or epoch == N_EPOCHS - 1:
        model_name = f"/opt/work/g1/s1910195/pt512/baselinenew_{epoch}.pt"
        # "_para_" naming matches the files the resume cell loads and writes.
        para_name = f"/opt/work/g1/s1910195/pt512/baselinenew_para_{epoch}.pt"
        torch.save(model, model_name)
        checkpoint = {
                      "net": model.state_dict(),
                      "optimizer": optimizer.state_dict(),
                      "epoch": epoch+1,  # next epoch to run when resuming
                      "scheduler": scheduler.state_dict(),
                      "loss": plot_loss}
        torch.save(checkpoint, para_name)

In [None]:
# Resume training: rebuild the optimizer and scheduler with the same hyper-parameters so
# their saved state dicts can be loaded into them below.
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.2,
    patience=8,
    threshold=0.0001,
    threshold_mode='rel',
    cooldown=3,
    min_lr=0,
    eps=1e-9,
    verbose=False,
)

def epoch_time(start_time: float,
               end_time: float):
    """Split the elapsed time between two ``time.time()`` stamps into minutes and seconds.

    Re-defined here so this resume cell also works when the first training cell was
    skipped.

    Args:
        start_time: start timestamp in seconds (a float, as returned by time.time()).
        end_time: end timestamp in seconds.

    Returns:
        (minutes, seconds) as ints, each truncated toward zero.
    """
    elapsed = end_time - start_time
    minutes = int(elapsed / 60)
    seconds = int(elapsed - minutes * 60)
    return minutes, seconds

path_checkpoint = "/opt/work/g1/s1910195/pt512/baselinenew_para_12.pt" # checkpoint path
# NOTE(review): this cell writes checkpoints only when epoch % 10 == 0 or at the last
# epoch — confirm the epoch-12 file above actually exists.
# map_location restores tensors onto `device`, so a GPU-trained checkpoint also loads on a
# CPU-only machine. torch.load unpickles its input: only load trusted files.
checkpoint = torch.load(path_checkpoint, map_location=device)
model.load_state_dict(checkpoint['net'])
optimizer.load_state_dict(checkpoint['optimizer'])
start_epoch = checkpoint['epoch']  # saved as epoch+1, i.e. the next epoch to run
scheduler.load_state_dict(checkpoint['scheduler'])
plot_loss = checkpoint['loss']

for epoch in range(start_epoch, N_EPOCHS):
    print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
    print(optimizer.param_groups[0]['lr'])
    start_time = time.time()
    train_loss = train(model, train_dataloader, optimizer, criterion, CLIP, TEACH_RATE)
    end_time = time.time()
    epoch_mins, epoch_secs = epoch_time(start_time, end_time)
    print(f'Epoch: {epoch+1:02} | Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.2f}')

    plot_loss.append(train_loss)
    scheduler.step(train_loss)

    # Checkpoint every 10 epochs and always after the final epoch.
    if epoch % 10 == 0 or epoch == N_EPOCHS - 1:
        model_name = f"/opt/work/g1/s1910195/pt512/baselinenew_{epoch}.pt"
        para_name = f"/opt/work/g1/s1910195/pt512/baselinenew_para_{epoch}.pt"
        torch.save(model, model_name)
        checkpoint = {
                      "net": model.state_dict(),
                      "optimizer": optimizer.state_dict(),
                      "epoch": epoch+1,
                      "scheduler": scheduler.state_dict(),
                      "loss": plot_loss}
        torch.save(checkpoint, para_name)

In [None]:
# Restore a model saved whole via torch.save(model, ...), placing its tensors on `device`.
# torch.load unpickles arbitrary objects: only point it at trusted files.
model = torch.load(
    'c:/users/derri/PycharmProjects/exp/pt/baselinenew_40.pt',
    map_location=device,
)
model.eval()

In [None]:
# Evaluation split: parsed with the same DATA field (and therefore the training
# vocabulary), iterated in file order without shuffling.
test_data = TabularDataset(
    path='c:/users/derri/PycharmProjects/exp/finaltest/base/s4.csv',
    format='csv',
    skip_header=True,
    fields=[('src', DATA), ('tgt', DATA)],
)

test_dataloader = BucketIterator(test_data, batch_size=BATCH_SIZE, device=device, shuffle=False)

In [None]:
def eval(model, dataset):
    """Greedy-decode every batch of ``dataset`` with teacher forcing disabled.

    Note: this name shadows Python's builtin ``eval`` in the notebook namespace.

    Args:
        model: called as ``model(src, tgt, teach)``, returning ``(log_probs, attns)`` with
            log_probs shaped (steps, batch, vocab).
        dataset: iterable of batches exposing ``.src`` / ``.tgt`` (len, batch) tensors.

    Returns:
        LongTensor (steps, total_examples) of argmax token indices; batches are
        concatenated along dim 1 in iteration order.

    Raises:
        ValueError: if ``dataset`` yields no batches.
    """
    model.eval()
    teach = 0  # disable teacher forcing: the model only sees its own predictions
    predictions = []
    with torch.no_grad():  # inference only; no autograd graph is needed
        for batch in dataset:
            # Mirrors train(): only the first two source tokens are used.
            src = batch.src[0:2, :]  # (len, batch)
            tgt = batch.tgt  # (len, batch)
            outputs, _ = model(src, tgt, teach)
            predictions.append(outputs.argmax(2))

    if not predictions:
        raise ValueError("dataset yielded no batches")
    return torch.cat(predictions, dim=1)

In [None]:
gen_text = eval(model, test_dataloader)
# Shape of the returned prediction tensor (previously referenced an undefined name `sss`).
print(gen_text.size())

In [None]:
# One token list per decoded example, sized from the prediction tensor instead of a
# hard-coded batch of 20.
res = [[] for _ in range(gen_text.size(1))]
# gen_text rows are decoding steps; column j belongs to example j.
for step_ids in gen_text:
    for j, token_id in enumerate(step_ids.tolist()):
        res[j].append(DATA.vocab.itos[token_id])

for tokens in res:
    sent = ""
    for tok in tokens:
        sent += tok + ' '
        if tok == '<eos>':
            break  # keep the first end-of-sentence token, drop everything after it
    print(sent)