In [1]:
import torch
import torch.nn as nn
import torch.optim as optim

import time, random, math, string
import os
import json

from sklearn.metrics import classification_report,f1_score
from sklearn.preprocessing import MultiLabelBinarizer
import numpy as np
SEED = 1225

# Seed every random number generator used in this notebook with the same value
# (Python's `random`, NumPy, and PyTorch on CPU and on the current CUDA device).
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    _seed_fn(SEED)

# Ask cuDNN to use deterministic algorithms.
torch.backends.cudnn.deterministic = True

from utils.data import processing_data_EUR_Lex, processing_data_wiki31k, processing_data_Amazon670k, processing_data_AmazonCat13k
from utils.plot import train_valid_loss
from utils.embedding import src_embedding_glove, tgt_embedding_glove
from model.seq2seq import Seq2Seq
from model.seq2seq_conv import Seq2Seq as Seq2SeqConv
from model.loss import DynamicHungarianLossAssignAll

In [2]:
gpu_id = 0

# Use the selected CUDA device when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device(f'cuda:{gpu_id}')
else:
    device = torch.device('cpu')

In [3]:
def init_weights(m):
    """Re-initialise, in place, every parameter reachable from module ``m``.

    A parameter whose name contains ``weight`` or ``bias`` is filled with
    samples from U(-0.1, 0.1); any other parameter is set to zero.
    """
    for param_name, tensor in m.named_parameters():
        is_weight_or_bias = any(tag in param_name for tag in ("weight", "bias"))
        if not is_weight_or_bias:
            nn.init.constant_(tensor.data, 0)
            continue
        nn.init.uniform_(tensor.data, -0.1, 0.1)
            
def count_parameters(model):
    """Return the total number of scalar elements across the trainable parameters of ``model``."""
    total = 0
    for param in model.parameters():
        # Frozen parameters (requires_grad == False) are not counted.
        if param.requires_grad:
            total += param.numel()
    return total

def train(model, iterator, optimizer, loss_func, teacher_forcing_ratio, clip = None, lambda_embedding=0.1):
    """Run one training epoch and return the mean loss per batch.

    Args:
        model: callable as ``model(src, lengths, tgt, teacher_forcing_ratio)``;
            for the Hungarian loss it must also expose ``decoder.embedding``.
        iterator: sized iterable of batches exposing ``batch.src`` (a
            ``(src, lengths)`` pair) and ``batch.tgt``.
        optimizer: optimizer that updates ``model``'s parameters.
        loss_func: an ``nn.CrossEntropyLoss`` or a
            ``DynamicHungarianLossAssignAll`` instance.
        teacher_forcing_ratio: forwarded unchanged to the model.
        clip: maximum gradient norm; ``None`` (the default) disables clipping.
        lambda_embedding: unused; kept for backward compatibility.

    Returns:
        Sum of the per-batch losses divided by ``len(iterator)``.

    Raises:
        TypeError: if ``loss_func`` is neither of the supported loss types.
    """
    model.train()

    epoch_loss = 0

    for batch in iterator:
        src, lengths = batch.src
        tgt = batch.tgt
        # src = [src_len, batch_size]
        # tgt = [tgt_len, batch_size]
        optimizer.zero_grad()

        output = model(src, lengths, tgt, teacher_forcing_ratio)

        if isinstance(loss_func, nn.CrossEntropyLoss):
            output_dim = output.shape[-1]
            # output = [tgt_len, batch_size, output_dim]
            # transform output: flatten it into 2 dims.
            output = output.view(-1, output_dim)
            tgt = tgt[1:].view(-1)  # exclude SOS
            # output = [tgt_len * batch_size, output_dim]
            # tgt = [tgt_len * batch_size]
            loss = loss_func(output, tgt)
        elif isinstance(loss_func, DynamicHungarianLossAssignAll):
            # This loss receives batch-first tensors plus the decoder embedding.
            output = output.transpose(0, 1)
            tgt = tgt[1:].transpose(0, 1)  # exclude SOS
            loss = loss_func(output, tgt, model.decoder.embedding)
        else:
            # Fail loudly instead of hitting an unbound ``loss`` below.
            raise TypeError(f"Unsupported loss function type: {type(loss_func).__name__}")

        loss.backward()

        # Clip only when a threshold was supplied; clip_grad_norm_ cannot take None.
        if clip is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip)

        optimizer.step()

        epoch_loss += loss.item()

    return epoch_loss / len(iterator)

def evaluate(model, iterator, loss_func, lambda_embedding=0.1):
    """Compute the mean per-batch loss of ``model`` over ``iterator`` without updating it.

    The model is switched to eval mode, gradients are disabled and teacher
    forcing is turned off (ratio 0).

    Args:
        model: callable as ``model(src, lengths, tgt, teacher_forcing_ratio)``;
            for the Hungarian loss it must also expose ``decoder.embedding``.
        iterator: sized iterable of batches exposing ``batch.src`` (a
            ``(src, lengths)`` pair) and ``batch.tgt``.
        loss_func: an ``nn.CrossEntropyLoss`` or a
            ``DynamicHungarianLossAssignAll`` instance.
        lambda_embedding: unused; kept for backward compatibility.

    Returns:
        Sum of the per-batch losses divided by ``len(iterator)``.

    Raises:
        TypeError: if ``loss_func`` is neither of the supported loss types.
    """
    model.eval()

    epoch_loss = 0

    with torch.no_grad():
        for batch in iterator:
            src, lengths = batch.src
            tgt = batch.tgt

            output = model(src, lengths, tgt, 0)  # turn off teacher forcing.

            if isinstance(loss_func, nn.CrossEntropyLoss):
                output_dim = output.shape[-1]
                # output = [tgt_len, batch_size, output_dim]
                # transform output: flatten it into 2 dims.
                output = output.view(-1, output_dim)
                tgt = tgt[1:].view(-1)  # exclude SOS
                # output = [tgt_len * batch_size, output_dim]
                # tgt = [tgt_len * batch_size]
                loss = loss_func(output, tgt)
            elif isinstance(loss_func, DynamicHungarianLossAssignAll):
                # This loss receives batch-first tensors plus the decoder embedding.
                output = output.transpose(0, 1)
                tgt = tgt[1:].transpose(0, 1)  # exclude SOS
                loss = loss_func(output, tgt, model.decoder.embedding)
            else:
                # Fail loudly instead of hitting an unbound ``loss`` below.
                raise TypeError(f"Unsupported loss function type: {type(loss_func).__name__}")

            epoch_loss += loss.item()

    return epoch_loss / len(iterator)

def epoch_time(start_time, end_time):
    """Tell how long an epoch took.

    Takes two timestamps in seconds and returns the span between them as a
    ``(minutes, seconds)`` pair of integers.
    """
    duration = end_time - start_time
    whole_minutes = int(duration / 60)
    leftover_seconds = int(duration - (whole_minutes * 60))
    return whole_minutes, leftover_seconds

In [4]:
def _label_set_mask(sequence, eos_idx, num_labels, mask_device):
    """Return a boolean multi-hot vector of the label ids in ``sequence`` before its first EOS."""
    eos_positions = (sequence == eos_idx).nonzero(as_tuple=True)[0]
    if eos_positions.size()[0] > 0:
        # Everything from the first EOS onwards is dropped.
        sequence = sequence[:eos_positions[0].item()]
    return torch.zeros(num_labels, dtype=torch.bool, device=mask_device).scatter_(0, sequence, 1)


def _safe_ratio(numerator, denominator):
    """Return ``numerator / denominator``, or 0.0 when the denominator is zero."""
    return numerator / denominator if denominator else 0.0


def predict_evaluate(iterator, model_result_path, SRC, TGT, conf, suffix="", *, use_final=False):
    """Decode ``iterator`` with a saved checkpoint and write predictions and micro metrics.

    A fresh model is built from ``conf``, loaded from
    ``result/{model_result_path}.pt`` (or ``..._final.pt`` when ``use_final``
    is true) and used to sample label sequences. Two files are written:

    * ``result/{model_result_path}_test{suffix}.txt`` -- one line per example
      with the true and predicted label tokens.
    * ``result/{model_result_path}_evaluation{suffix}_new.txt`` -- micro F1,
      precision and recall.

    Args:
        iterator: iterable of batches exposing ``batch.src`` (a
            ``(src, lengths)`` pair) and ``batch.tgt``.
        model_result_path: path stem, relative to ``result/``.
        SRC: source field; currently unused because source tokens are not
            written to the prediction file.
        TGT: target field providing ``vocab`` and the init/pad/eos tokens.
        conf: model configuration; ``conf.mask`` is forwarded to ``model.sample``.
        suffix: extra tag inserted into both output file names.
        use_final: load the ``_final`` checkpoint and append ``"_final"`` to ``suffix``.

    Returns:
        None.
    """
    INIT_IDX = TGT.vocab.stoi[TGT.init_token]
    PAD_IDX = TGT.vocab.stoi[TGT.pad_token]
    EOS_IDX = TGT.vocab.stoi[TGT.eos_token]

    if getattr(conf, "dl_conv", False):
        model = Seq2SeqConv(conf)
    else:
        model = Seq2Seq(conf)

    if use_final:
        suffix += "_final"
        checkpoint_path = 'result/{}_final.pt'.format(model_result_path)
    else:
        checkpoint_path = 'result/{}.pt'.format(model_result_path)

    # Map the checkpoint onto the notebook-wide ``device`` so that loading also
    # works on CPU-only hosts (a hard-coded "cuda:N" string fails there).
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    model.to(device)
    model.eval()

    predict_file = f"result/{model_result_path}_test{suffix}.txt"
    output_dim = len(TGT.vocab)

    # Per-label counts of true positives, false positives and false negatives.
    TP = torch.zeros(output_dim, dtype=torch.long, device=device)
    FP = torch.zeros(output_dim, dtype=torch.long, device=device)
    FN = torch.zeros(output_dim, dtype=torch.long, device=device)

    # Opening once with 'w' truncates any previous prediction file.
    with open(predict_file, 'w') as f, torch.no_grad():
        for batch in iterator:
            src, lengths = batch.src
            # src : [seq_len, batch_size]
            tgt_tensor = batch.tgt[1:].T  # drop the leading init token
            # tgt_tensor: [batch_size, seq_len]

            output = model.sample(src, lengths, INIT_IDX, PAD_IDX, EOS_IDX, conf.mask)
            # output = [seq_len, batch_size, output_dim]
            output = output.argmax(2).T
            # output: [batch_size, seq_len]

            predicted_rows = output.tolist()
            true_rows = tgt_tensor.tolist()

            for i, (predicted_ids, true_ids) in enumerate(zip(predicted_rows, true_rows)):
                tgt_true = [TGT.vocab.itos[idx] for idx in true_ids]
                tgt_result = [TGT.vocab.itos[idx] for idx in predicted_ids]
                # The source column is intentionally left empty: source tokens
                # are not written to the prediction file.
                f.write("{} | {} | {} \n".format("", " ".join(tgt_true), " ".join(tgt_result)))

                predicted_mask = _label_set_mask(output[i], EOS_IDX, output_dim, device)
                true_mask = _label_set_mask(tgt_tensor[i], EOS_IDX, output_dim, device)

                # for micro f1 score
                TP = TP + torch.logical_and(predicted_mask, true_mask)
                FP = FP + torch.logical_and(predicted_mask, torch.logical_not(true_mask))
                FN = FN + torch.logical_and(torch.logical_not(predicted_mask), true_mask)

    # ignore pad
    TP[PAD_IDX] = 0
    FP[PAD_IDX] = 0
    FN[PAD_IDX] = 0

    TP_sum = torch.sum(TP).item()
    FP_sum = torch.sum(FP).item()
    FN_sum = torch.sum(FN).item()

    # Metrics fall back to 0.0 when their denominator is zero (for example an
    # empty iterator or no predicted labels) instead of raising ZeroDivisionError.
    f1_micro = _safe_ratio(2 * TP_sum, 2 * TP_sum + FP_sum + FN_sum)
    precision_micro = _safe_ratio(TP_sum, TP_sum + FP_sum)
    recall_micro = _safe_ratio(TP_sum, TP_sum + FN_sum)

    # NOTE(review): the "f1_micro" label has no colon, unlike the other two; it
    # is kept byte-identical in case existing result parsers depend on it.
    evaluation = f"f1_micro{f1_micro}\n precision_micro:{precision_micro}\n recall_micro:{recall_micro}"
    with open(f"result/{model_result_path}_evaluation{suffix}_new.txt", 'w') as f:
        f.write(evaluation)

In [5]:
# compute weight
def class_balance(freq, beta=0.9):
    """Return the class-balancing weight ``(1 - beta) / (1 - beta**freq)`` for a class seen ``freq`` times."""
    numerator = 1 - beta
    denominator = 1 - beta ** freq
    return numerator / denominator

def get_class_balanced_weight(TGT, beta=0.9):
    """Build a per-token weight vector over the target vocabulary.

    The unk, pad and init tokens receive the fixed weight ``(1 - beta) * 0.2``;
    every other token receives ``class_balance`` of its frequency in
    ``TGT.vocab.freqs``. The result is a 1-D tensor on the notebook-wide
    ``device`` with one entry per vocabulary index.
    """
    special_tokens = (TGT.unk_token, TGT.pad_token, TGT.init_token)
    special_weight = (1 - beta) * 0.2

    vocab_size = len(TGT.vocab)
    weight = torch.ones(vocab_size, device=device)

    for index in range(vocab_size):
        token = TGT.vocab.itos[index]
        # NOTE(review): a non-special token whose recorded frequency is 0 would
        # make class_balance divide by zero -- confirm this cannot occur.
        if token in special_tokens:
            weight[index] = special_weight
        else:
            weight[index] = class_balance(TGT.vocab.freqs[token], beta)
    return weight

In [6]:
# Experiment configuration, parsed from JSON. The driver loop below iterates it
# as a mapping from dataset name to a list of run configurations.
with open("config/OTSeq2Set.json", 'r') as config_file:
    config = json.load(config_file)

In [7]:
class Struct:
    """Attribute bag: every keyword argument becomes an instance attribute."""

    def __init__(self, **entries):
        for key, value in entries.items():
            setattr(self, key, value)

In [None]:
# Driver: for every dataset in the config, train and evaluate each run
# configuration that is flagged as not finished.

# Maps each dataset key used in the config to its data-loading routine.
dataset_loaders = {
    "EUR-Lex": processing_data_EUR_Lex,
    "Wiki31k": processing_data_wiki31k,
    "Amazon670k": processing_data_Amazon670k,
    "AmazonCat13k": processing_data_AmazonCat13k,
}

for data_name, conf_list in config.items():

    if data_name not in dataset_loaders:
        raise Exception("Data set don't exists!")
    processing_data = dataset_loaders[data_name]

    for conf in conf_list:
        # Only configurations whose "finish" entry is explicitly false are run.
        if conf.get("finish") == False:
            con = Struct(**conf)
            print(f"dataset:{data_name},model:{con.model}:{con.model_number}")

            # load data
            # NOTE(review): this tuple names "DynamicHungarianLoss" while the loss
            # built below is keyed "DynamicHungarianLossAssignAll", so for_CE is
            # False for that loss -- confirm this is intended.
            for_CE = con.loss_func in ("CE", "DynamicHungarianLoss")
            tgt_sort = bool(getattr(con, "tgt_sort", False))

            data_kwargs = dict(include_lengths=True, batch_size=con.batch_size,
                               for_CE=for_CE, valid_split=con.valid_split, tgt_sort=tgt_sort)
            # max_src_len is forwarded only when the configuration defines it.
            if hasattr(con, "max_src_len"):
                data_kwargs["max_src_len"] = con.max_src_len
            train_iter, valid_iter, test_iter, SRC, TGT = processing_data(device, **data_kwargs)

            PAD_IDX = TGT.vocab.stoi[TGT.pad_token]
            INIT_IDX = TGT.vocab.stoi[TGT.init_token]
            EOS_IDX = TGT.vocab.stoi[TGT.eos_token]
            print(f"pad_idx:{PAD_IDX},init_idx:{INIT_IDX},eos_idx:{EOS_IDX}")

            INPUT_DIM = len(SRC.vocab)
            OUTPUT_DIM = len(TGT.vocab)
            con.src_vocab_size = INPUT_DIM
            con.tgt_vocab_size = OUTPUT_DIM

            # initialize model.
            if hasattr(con, "dl_conv") and con.dl_conv:
                model = Seq2SeqConv(con).to(device)
            else:
                model = Seq2Seq(con).to(device)
            # NOTE(review): apply() calls init_weights on every submodule and
            # init_weights itself walks named_parameters(), so nested parameters
            # appear to be initialised more than once. Left as is, because changing
            # it would alter the seeded random stream.
            model.apply(init_weights)

            # Optionally replace the freshly initialised embeddings with GloVe ones.
            if con.src_glove:
                model.encoder.embedding = src_embedding_glove(device, SRC)

            if con.tgt_embedding == "glove":
                print("tgt_embedding use: glove")
                model.decoder.embedding = tgt_embedding_glove(device, TGT)

            print(model)
            print(f'The model has {count_parameters(model):,} trainable parameters')
            if hasattr(model.decoder, "bottleneck1"):
                print(f'The bottleneck1 has {count_parameters(model.decoder.bottleneck1):,} trainable parameters')

            if hasattr(model, "dl_conv"):
                print(f'The dl_conv has {count_parameters(model.dl_conv):,} trainable parameters')

            optimizer = optim.Adam(model.parameters(), lr=con.learning_rate)

            if con.use_CosineAnnealingLR:
                scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=con.N_EPOCHS)
            else:
                scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                                 'min',
                                                                 factor=0.5,
                                                                 patience=3,
                                                                 verbose=1,
                                                                 min_lr = 1e-10)
            # loss func
            if con.loss_func == "CE":
                loss_func = nn.CrossEntropyLoss(ignore_index=PAD_IDX)
            elif con.loss_func == "DynamicHungarianLossAssignAll":
                loss_func = DynamicHungarianLossAssignAll(PAD_IDX, con.ignore_index, con.assign_all_pre, con.empty_weight, con.lambda_embedding,
                                                         con.ipot_E_non_empty, con.ipot_E_first_n_pre)
            else:
                # Without this branch an unknown name would silently reuse the loss
                # of the previous configuration (or raise NameError on the first one).
                raise ValueError(f"Unsupported loss_func in config: {con.loss_func!r}")

            model_name = "_".join([data_name, con.model, con.loss_func])
            # train
            model_dir = f"result/{model_name}"
            # makedirs also creates a missing parent "result" directory.
            os.makedirs(model_dir, exist_ok=True)
            model_result_path = f"{model_name}/{con.model_number}"
            best_valid_loss = float('inf')

            train_loss_list = []
            valid_loss_list = []
            result_data = {}
            total_time = 0
            for epoch in range(con.N_EPOCHS):

                # Only the training pass is timed.
                start_time = time.time()

                train_loss = train(model, train_iter, optimizer, loss_func, con.teacher_forcing_ratio, con.CLIP)

                end_time = time.time()

                total_time += end_time - start_time
                epoch_mins, epoch_secs = epoch_time(start_time, end_time)

                train_loss_list.append(train_loss)

                print(f"Epoch: {epoch+1:02} | Time {epoch_mins}m {epoch_secs}s| lr: {optimizer.param_groups[0]['lr']}")
                print(f"\tTrain Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f}")

                if valid_iter:
                    valid_loss = evaluate(model, valid_iter, loss_func)

                    valid_loss_list.append(valid_loss)

                    # Keep the checkpoint with the lowest validation loss so far.
                    if valid_loss < best_valid_loss:
                        best_valid_loss = valid_loss
                        torch.save(model.state_dict(), 'result/{}.pt'.format(model_result_path))

                    print(f"\tValid Loss: {valid_loss:.3f} | Valid PPL: {math.exp(valid_loss):7.3f}")

                # The plateau scheduler needs a validation loss, so it is only
                # stepped when a validation iterator exists.
                if con.use_CosineAnnealingLR:
                    scheduler.step()
                elif valid_iter:
                    scheduler.step(valid_loss)

            # Always save the weights reached after the last epoch.
            torch.save(model.state_dict(), 'result/{}_final.pt'.format(model_result_path))
            result_data['total_time'] = total_time
            result_data['train_loss_list'] = train_loss_list
            result_data['valid_loss_list'] = valid_loss_list
            result_data['n_epochs'] = con.N_EPOCHS
            with open('result/{}_stat.json'.format(model_result_path), 'w') as f:
                f.write(json.dumps(result_data))
            # draw
            train_valid_loss(train_loss_list, valid_loss_list, model_result_path)
            # test: the best-validation checkpoint exists only when validation ran;
            # the final checkpoint is evaluated in every case.
            if valid_iter:
                predict_evaluate(test_iter, model_result_path, SRC, TGT, con)
            predict_evaluate(test_iter, model_result_path, SRC, TGT, con, use_final=True)

dataset:EUR-Lex,model:seq2seq_attention:1
pad_idx:1,init_idx:2,eos_idx:3
tgt_embedding use: glove
Seq2Seq(
  (encoder): rnn_encoder(
    (embedding): Embedding(50002, 300)
    (rnn): GRU(300, 512, num_layers=2, dropout=0.2, bidirectional=True)
    (fc): Linear(in_features=1024, out_features=512, bias=True)
    (dropout): Dropout(p=0.2, inplace=False)
  )
  (decoder): rnn_decoder(
    (embedding): Embedding(3960, 300)
    (rnn): StackedGRU(
      (dropout): Dropout(p=0.2, inplace=False)
      (layers): ModuleList(
        (0): GRUCell(1324, 512)
      )
    )
    (fc_out): Linear(in_features=1836, out_features=3960, bias=True)
    (attention): bahdanau_attention(
      (linear_encoder): Linear(in_features=1024, out_features=512, bias=True)
      (linear_decoder): Linear(in_features=512, out_features=512, bias=True)
      (linear_v): Linear(in_features=512, out_features=1, bias=True)
    )
    (dropout): Dropout(p=0.2, inplace=False)
  )
)
The model has 34,824,401 trainable parameters
Ep