In [None]:
import os
import argparse
from train_option import global_train_parser
import warnings
from utils.train_base import check_options, load_data, Setup_model, Out_Wordemb, Save_Emb
from utils.minibatch_processing import Generate_MiniBatch
from utils.train_class import Langage_Model_Class, Trainer

import torch
import torch.nn as nn

class SDGs_unsuper_Model(nn.Module):
    """LSTM language model with one shared embedding table and per-language output layers.

    The model owns two independent LSTMs (``lstm_fwd`` and ``lstm_bkw``) that
    read from the same embedding matrix. Output scores are the concatenation
    of one score from a projection shared by every language (``Ws_share``)
    and the scores of a language-specific projection (``Ws_i[lang]``).

    Usage contract: ``Switch_Lang`` and ``Switch_fwdbkw`` must both be called
    before the first forward pass, because ``forward`` reads ``self.lang`` and
    ``decode`` reads ``self.lstm``, and only those two methods assign them.

    NOTE(review): the model never reverses its input; presumably the caller
    feeds reversed sequences when the "bkw" LSTM is selected -- confirm.
    """

    def __init__(self, n_layer, emb_size, h_size, dr_rate, vocab_dict,*args):
        """Create the embedding table, both LSTMs and the output projections.

        Args:
            n_layer: number of stacked layers in each LSTM.
            emb_size: dimensionality of the word embeddings (LSTM input size).
            h_size: hidden size of each LSTM.
            dr_rate: dropout probability applied to the LSTM outputs before
                the output projections and, when ``n_layer > 1``, between the
                stacked LSTM layers.
            vocab_dict: vocabulary object. The attributes read here are
                ``id2vocab_input`` (its last entry's keys size the embedding
                table), ``vocab2id_input`` (its first entry must contain
                ``"<PAD>"``), ``id2vocab_output`` (one entry per language) and
                ``V_size`` (indexable by language index).
            *args: accepted for call compatibility and ignored.
        """
        super().__init__()

        self.dr_rate = dr_rate
        self.Ws_share = nn.Linear(h_size, 1, bias=False)

        # nn.LSTM only applies dropout *between* stacked layers. With a single
        # layer the value has no effect and PyTorch emits a UserWarning, so
        # pass 0 in that case; the computation is unchanged.
        lstm_dropout = dr_rate if n_layer > 1 else 0.0
        self.lstm_fwd = nn.LSTM(input_size=emb_size,hidden_size=h_size,num_layers=n_layer,batch_first=True,dropout=lstm_dropout)
        self.lstm_bkw = nn.LSTM(input_size=emb_size,hidden_size=h_size,num_layers=n_layer,batch_first=True,dropout=lstm_dropout)
        self.dropout = nn.Dropout(p=dr_rate)

        # The embedding table must cover the largest word id of the last
        # input vocabulary, hence max id + 1 rows.
        Max_Word_idx = max(vocab_dict.id2vocab_input[-1].keys())+1
        self.emb = nn.Embedding(Max_Word_idx, emb_size, padding_idx= vocab_dict.vocab2id_input[0]["<PAD>"])

        # One output projection per language. Each is one unit narrower than
        # V_size[lang] because forward() prepends the shared Ws_share score.
        self.Ws_i = nn.ModuleList([
            nn.Linear(h_size, vocab_dict.V_size[lang]-1, bias=False)
            for lang in range(len(vocab_dict.id2vocab_output))
        ])

    def __call__(self, BOS_t_id, t_lengths, *args):
        """Run the model; same positional interface as ``forward``.

        Delegates to ``nn.Module.__call__`` instead of calling ``forward``
        directly so that registered forward hooks are executed.
        """
        return super().__call__(BOS_t_id, t_lengths, *args)

    def Switch_Lang(self, lang):
        """Select which language-specific output layer ``forward`` uses.

        Args:
            lang: index into ``self.Ws_i``.
        """
        self.lang = lang

    def Switch_fwdbkw(self,type):
        """Select which LSTM ``decode`` runs.

        Args:
            type: ``"fwd"`` selects ``lstm_fwd``, ``"bkw"`` selects ``lstm_bkw``.

        Raises:
            ValueError: if ``type`` is neither ``"fwd"`` nor ``"bkw"``.
        """
        if type == "fwd":
            self.lstm = self.lstm_fwd
        elif type == "bkw":
            self.lstm = self.lstm_bkw
        else:
            # ValueError is still an Exception, so existing handlers keep working.
            raise ValueError("Invalid type: {!r} (expected 'fwd' or 'bkw')".format(type))

    def forward(self,input_id, input_id_len, *args):
        """Score every position of the input for the currently selected language.

        Args:
            input_id: tensor of word ids fed to the embedding layer; the LSTMs
                are batch-first, so presumably (batch, seq) -- confirm with caller.
            input_id_len: forwarded to ``decode``, which does not use it.
            *args: forwarded to ``decode``, which ignores them.

        Returns:
            Tensor whose last dimension (dim 2) holds the shared score at
            index 0 followed by the language-specific scores.
        """
        ht = self.decode(input_id, input_id_len, *args)
        # Each projection gets its own dropout call, i.e. an independent mask
        # while training.
        score_V = self.Ws_i[self.lang](self.dropout(ht))
        score_eos = self.Ws_share(self.dropout(ht))
        score = torch.cat([score_eos, score_V], dim=2)
        return score

    def decode(self, input_id, input_id_len, *args):
        """Embed the input ids and return the selected LSTM's per-step outputs.

        ``input_id_len`` and ``*args`` are accepted but unused: the sequence
        is not packed, so every position of ``input_id`` is processed.
        """
        input_id_emb = self.emb(input_id)
        ht, _ = self.lstm(input_id_emb)
        return  ht

    def set_device(self,is_cuda):
        """Store the tensor namespace to use: ``torch.cuda`` if ``is_cuda`` else ``torch``."""
        if is_cuda:
            self.torch = torch.cuda
        else:
            self.torch = torch


In [None]:
# Build the command-line parser by inheriting every option declared on
# global_train_parser (imported from train_option), then parse the arguments
# into `opt`, which the following cells read.
# NOTE(review): parse_args() with no argument reads sys.argv; inside a live
# Jupyter kernel sys.argv holds the kernel's own launch flags, so this cell
# presumably only works when the notebook is executed as a script -- confirm.
parser = argparse.ArgumentParser(parents=[global_train_parser])
opt = parser.parse_args()

In [None]:
# Make sure the output directory exists before anything is written into it.
if os.path.isdir(opt.save_dir):
    # Reusing a directory is allowed, but the user is told about it.
    message = 'Directory ' + "'" + opt.save_dir + "'" +' already exists.'
    warnings.warn(message)
else:
    # makedirs also creates missing parent directories (os.mkdir would fail
    # on a nested path), and exist_ok=True tolerates the directory being
    # created by someone else between the check above and this call.
    # A regular file at that path still raises FileExistsError.
    os.makedirs(opt.save_dir, exist_ok=True)

In [None]:
# Hand the parsed options to check_options (from utils.train_base).
# Presumably this validates them before any work starts -- TODO confirm what
# it checks and whether it raises or mutates `opt`.
check_options(opt)

In [None]:
# Path prefix "<save_dir>/<data>"; it is later handed to Trainer and Save_Emb.
file_name = "/".join((opt.save_dir, opt.data))
print("Save model as: ", file_name)

In [None]:
# load_data returns a (dataset, vocab_dict) pair for the data named by opt.data.
# `dataset` is then rebound to the Generate_MiniBatch object built from it with
# opt.batch_size; later cells use that mini-batched object, not the raw data.
dataset, vocab_dict = load_data(opt.data)
dataset = Generate_MiniBatch(dataset, opt.batch_size)
# batch_idx_list is an attribute of the Generate_MiniBatch object; its length
# is reported as the number of mini-batches.
print("Number of mini-batches", len(dataset.batch_idx_list))

In [None]:
# Build the network itself from the parsed hyper-parameters and the vocabulary.
lm = SDGs_unsuper_Model(
    n_layer=opt.n_layer,
    emb_size=opt.emb_size,
    h_size=opt.h_size,
    dr_rate=opt.dr_rate,
    vocab_dict=vocab_dict,
)
# Wrap it in the project's language-model class together with the number of
# input vocabularies and the first input/output vocabulary mappings.
model = Langage_Model_Class(
    lm,
    len(vocab_dict.vocab2id_input),
    vocab_dict.vocab2id_input[0],
    vocab_dict.vocab2id_output[0],
)

In [None]:
# Setup_model (from utils.train_base) receives the model, the GPU id option and
# the vocabulary, and its return value replaces `model` from here on.
# Presumably it moves the model to the selected device -- TODO confirm.
model = Setup_model(model, opt.gpuid, vocab_dict)

In [None]:
# Training: the Trainer is given the mini-batched dataset and the output path
# prefix, is configured with the optimiser type and learning rate, and is then
# run with the epoch count, stop threshold and remove_models option.
# The value returned by trainer.main is kept as `bestmodel`; presumably it is
# the best-scoring model found during training -- TODO confirm in Trainer.
trainer = Trainer(dataset, file_name)
trainer.set_optimiser(model, opt.opt_type, opt.learning_rate)
bestmodel = trainer.main(model, opt.epoch_size, opt.stop_threshold, opt.remove_models)

In [None]:
# Export step: Out_Wordemb is given the input id->word mappings and the `lm`
# attribute of the model returned by training; its result is passed to
# Save_Emb together with the embedding size and the output path prefix.
# NOTE(review): the on-disk format and final file name are decided inside
# Save_Emb and are not visible here.
print("save embeddings")
vocab2emb_list = Out_Wordemb(vocab_dict.id2vocab_input, bestmodel.lm)
Save_Emb(vocab2emb_list, opt.emb_size, file_name)
