In [1]:
import os
import sys
import nltk
import time
import math
import torch
import random
import argparse
import numpy as np
from collections import defaultdict
from torch.utils.data import Dataset, DataLoader, random_split, ConcatDataset

sys.path.append("../")
from load_pretrain_label import load_preprocess_document_labels
#from model.ide_ae_decoder import IDEDataset, IDEAEDecoder
from utils.toolbox import same_seeds, show_settings, record_settings, get_preprocess_document, get_preprocess_document_embs, get_preprocess_document_labels, get_word_embs, merge_targets
from utils.eval import retrieval_normalized_dcg_all, retrieval_precision_all, semantic_precision_all, retrieval_precision_all_v2, semantic_precision_all_v2

os.environ["TOKENIZERS_PARALLELISM"] = "false"
torch.set_num_threads(15)

In [2]:
def generate_dataset(config):
    """Build the training and validation ``IDEDataset`` objects for one experiment.

    Pipeline: preprocess the corpus, embed every document, load the decode
    targets and vocabulary, shuffle and split 80/20, mark a ``ratio`` fraction
    of the training rows as labeled, and optionally oversample labeled rows.

    Args:
        config: Experiment settings. Keys read directly here are ``'encoder'``,
            ``'target'``, ``'ratio'`` and ``'balance'``; the whole dict is also
            forwarded as keyword arguments to ``get_preprocess_document`` and
            passed to ``load_preprocess_document_labels``.

    Returns:
        Tuple ``(training_set, validation_set, vocabularys, id2token, device)``.
        Every validation row carries a ``True`` label mask.
    """
    # Data preprocessing
    unpreprocessed_corpus, preprocessed_corpus = get_preprocess_document(**config)
    print('[INFO] Load corpus done.')

    # Generating document embedding.
    # Only RuntimeError is retried (CUDA out-of-memory is reported as one);
    # anything else, including Ctrl-C, now propagates instead of looping forever.
    while True:
        try:
            doc_embs, doc_model, device = get_preprocess_document_embs(preprocessed_corpus, config['encoder'])
            break
        except RuntimeError as err:
            print('[Error] CUDA Memory Insufficient, retry after 15 secondes.')
            print('[Error] {}'.format(err))
            time.sleep(15)
    print('[INFO] Generate embedding done.')

    # Generate Decode target & Vocabulary
    if config['target'] == 'keybert' or config['target'] == 'yake':
        labels, vocabularys = load_preprocess_document_labels(config)
        label = labels[config['target']].toarray()
    else:
        labels, vocabularys = get_preprocess_document_labels(preprocessed_corpus)
        label = labels[config['target']]
        vocabularys = vocabularys[config['target']]
    print('[INFO] Load label done.')

    # generate idx to token
    id2token = dict(enumerate(vocabularys))
    print('[INFO] Generate id2token done.')

    # Shuffled 80/20 train/validation split.
    idx = np.arange(len(unpreprocessed_corpus))
    np.random.shuffle(idx)
    train_length = int(len(unpreprocessed_corpus) * 0.8)
    train_idx = idx[:train_length]
    valid_idx = idx[train_length:]

    train_unpreprocessed_corpus = list(np.array(unpreprocessed_corpus)[train_idx])
    valid_unpreprocessed_corpus = list(np.array(unpreprocessed_corpus)[valid_idx])
    train_embs = np.array(doc_embs)[train_idx]
    valid_embs = np.array(doc_embs)[valid_idx]
    train_label = np.array(label)[train_idx]
    valid_label = np.array(label)[valid_idx]

    # Generate labeled mask.
    # The draw-until-distinct loop is kept so the `random` stream (and therefore
    # seeded runs) is unchanged; the target count is clamped so that a ratio
    # above 1 can no longer spin forever looking for rows that do not exist.
    num_train = train_embs.shape[0]
    label_masks = np.zeros((num_train, 1), dtype=bool)
    num_labeled_data = min(int(num_train * config['ratio']), num_train)
    while num_labeled_data > 0:
        pick = random.randrange(0, num_train)
        if not label_masks[pick, 0]:
            label_masks[pick, 0] = True
            num_labeled_data -= 1
    print('[INFO] mask labels done.')

    # Balance data if required: every labeled row is appended
    # max(1, floor(log2(floor(1 / ratio)))) more times.
    if config['ratio'] != 1 and config['balance']:
        print('[INFO] Balance required.')
        labeled_rows = np.flatnonzero(label_masks[:, 0])
        if labeled_rows.size > 0:
            inverse_ratio = int(1 / config['ratio'])
            copies = int(math.log(inverse_ratio, 2)) if inverse_ratio >= 1 else 0
            copies = max(copies, 1)
            # Same append order as duplicating row by row (a, a, b, b, ...),
            # but with a single concatenation per array instead of one per copy.
            extra = np.repeat(labeled_rows, copies)
            extra_docs = [train_unpreprocessed_corpus[i] for i in extra]
            train_unpreprocessed_corpus.extend(extra_docs)
            train_embs = np.concatenate((train_embs, train_embs[extra]), axis=0)
            train_label = np.concatenate((train_label, train_label[extra]), axis=0)
            label_masks = np.concatenate((label_masks, label_masks[extra]), axis=0)

    training_set = IDEDataset(train_unpreprocessed_corpus, train_embs, train_label, label_masks)
    validation_set = IDEDataset(valid_unpreprocessed_corpus, valid_embs, valid_label, np.ones((valid_embs.shape[0], 1), dtype=bool))

    return training_set, validation_set, vocabularys, id2token, device

In [3]:
# Experiment settings shared by the cells below.
config = dict(
    experiment='autoencoder_testting',   # tag embedded in the result-file names
    model='VAE',
    architecture='concatenate',
    activation='sigmoid',
    dataset='20news',                    # selects the per-dataset settings in the next cell
    vocab_size=0,
    encoder='mpnet',                     # handed to get_preprocess_document_embs
    target='tf-idf-gensim',              # which label set generate_dataset decodes towards
    seed=123,
    epochs=300,                          # decoder / GAN training epochs
    ae_epochs=10,                        # autoencoder training epochs
    lr=1e-4,                             # decoder, generator, discriminator, classifier
    ae_lr=1e-4,                          # VAE optimizer
    optim='AdamW',
    scheduler=False,                     # when truthy, warmup schedulers are created
    warmup='linear',
    warmup_proportion=0.1,               # share of training steps used for warmup
    loss='listnet',
    batch_size=32,
    weight_decay=0,
    ratio=0.1,                           # fraction of training rows marked as labeled
    topk=[5, 10, 15, 20, 25, 30, 35, 40, 45, 50],   # cut-offs for the retrieval metrics
    save=False,
    threshold=0.7,
    balance=False,                       # oversample labeled rows in generate_dataset
)
# Apply the configured seed through the project helper before any data is generated.
same_seeds(config['seed'])

In [4]:
# Parameter
# dataset name -> (min_df, max_df, min_doc_word); unknown datasets leave config untouched.
_DATASET_PARAMETERS = {
    '20news': (62, 1.0, 15),
    'agnews': (425, 1.0, 15),
    'IMDB': (166, 1.0, 15),
    'wiki': (2872, 1.0, 15),
    'tweet': (5, 1.0, 15),
}
if config['dataset'] in _DATASET_PARAMETERS:
    config['min_df'], config['max_df'], config['min_doc_word'] = _DATASET_PARAMETERS[config['dataset']]

In [5]:
import sys
import random
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
import multiprocessing as mp
from collections import defaultdict
from torch.utils.data import Dataset, DataLoader
from transformers import AdamW, get_linear_schedule_with_warmup, get_constant_schedule_with_warmup, BertTokenizer, BertForMaskedLM, RobertaTokenizer, RobertaForMaskedLM, AlbertTokenizer, AlbertForMaskedLM
# from tqdm.auto import tqdm

sys.path.append("./")
from utils.loss import Singular_MythNet
from utils.eval import retrieval_normalized_dcg_all, retrieval_precision_all, semantic_precision_all, retrieval_precision_all_v2, semantic_precision_all_v2
from utils.toolbox import get_free_gpu, record_settings
from model.inference_network import ContextualInferenceNetwork

class IDEDataset(Dataset):
    """Map-style dataset yielding ``(document, embedding, target, mask)`` tuples.

    ``emb`` and ``target`` are stored as float tensors and ``mask`` as a bool
    tensor; ``corpus`` is kept as passed in and indexed directly.
    """

    def __init__(self, corpus, emb, target, mask):
        # Embeddings and targets are paired row by row, so they must line up.
        assert len(emb) == len(target)

        self.corpus = corpus
        self.emb, self.target = torch.FloatTensor(emb), torch.FloatTensor(target)
        self.mask = torch.BoolTensor(mask)

    def __len__(self):
        # Number of rows in the embedding tensor.
        return self.emb.shape[0]

    def __getitem__(self, idx):
        document = self.corpus[idx]
        embedding = self.emb[idx]
        target = self.target[idx]
        labeled = self.mask[idx]
        return document, embedding, target, labeled

class Generator(nn.Module):
    """Turns raw documents into vectors with a pretrained ``bert-base-uncased`` model.

    The vector for a document is the last hidden state of the BERT encoder at
    sequence position 0.
    """

    def __init__(self, device):
        super().__init__()
        self.device = device
        checkpoint = 'bert-base-uncased'
        self.tokenizer = BertTokenizer.from_pretrained(checkpoint)
        self.model = BertForMaskedLM.from_pretrained(checkpoint).to(device)

    def get_docvec(self, documents):
        """Tokenize ``documents`` (padded, truncated to 128 tokens) and embed them."""
        encoded = self.tokenizer(documents, return_tensors='pt', padding=True,
                                 truncation=True, max_length=128)
        encoded = encoded.to(self.device)
        hidden_states = self.model.bert(**encoded).last_hidden_state
        # Keep only the first position of every sequence.
        return hidden_states[:, 0, :]

    def forward(self, documents):
        """Alias for :meth:`get_docvec` so the module can be called directly."""
        return self.get_docvec(documents)

class Discriminator(nn.Module):
    """MLP head mapping an embedding of size ``input_dim`` to ``output_dim`` scores.

    Layout: Linear -> BatchNorm -> Sigmoid -> Dropout -> Linear -> BatchNorm,
    with a hidden width of ``4 * input_dim``.
    """

    def __init__(self, input_dim=768, output_dim=100, dropout=0.2):
        super().__init__()
        hidden_dim = input_dim * 4
        # Layers are created in forward order so parameter initialisation and
        # state-dict indices stay the same.
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.Sigmoid(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, output_dim),
            nn.BatchNorm1d(output_dim),
        ]
        self.decoder = nn.Sequential(*layers)

    def forward(self, embs):
        """Return the head's output for a batch of embeddings."""
        return self.decoder(embs)

class Classifier(nn.Module):
    """Single linear layer followed by a softmax over the last dimension."""

    def __init__(self, input_dim=768, output_dim=2):
        super().__init__()
        self.logit = nn.Linear(input_dim, output_dim)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, embs):
        """Return ``(logits, probabilities)`` for a batch of embeddings."""
        raw_scores = self.logit(embs)
        return raw_scores, self.softmax(raw_scores)
    
class MLPDecoder(nn.Module):
    """MLP that decodes a vector of size ``input_dim`` into ``output_dim`` scores.

    Layout: Linear -> BatchNorm -> Sigmoid -> Dropout -> Linear -> BatchNorm,
    with a hidden width of ``4 * input_dim``.
    """

    def __init__(self, input_dim=768, output_dim=768, dropout=0.2):
        super().__init__()
        hidden_dim = input_dim * 4
        # Layers are created in forward order so parameter initialisation and
        # state-dict indices stay the same.
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.Sigmoid(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, output_dim),
            nn.BatchNorm1d(output_dim),
        ]
        self.decoder = nn.Sequential(*layers)

    def forward(self, embs):
        """Return the decoded scores for a batch of input vectors."""
        return self.decoder(embs)
    
class VariationalAE(nn.Module):
    """Variational autoencoder built around a topic-style latent layer.

    An inference network produces posterior parameters, a softmax over the
    reparameterized sample gives ``theta``, ``theta @ beta`` is batch-normalized
    and softmaxed into ``word_dist``, and an MLP decodes ``word_dist`` into
    ``recon``.

    Args:
        config: Experiment settings; stored on the instance only.
        device: Device the prior tensors and ``beta`` are moved to.
        vocab_size: Sizes ``batch_norm`` and ``word_embedding`` (see note below).
        contextual_size: Output width of the decoder MLP.
        encoded_size: Width of ``word_dist`` and input width of the decoder MLP.
        n_components: Number of latent components (rows of ``beta``).
        hidden_sizes: Forwarded to ``ContextualInferenceNetwork``.
        activation: ``'softplus'`` or ``'relu'``; forwarded to the inference network.
        dropout: Dropout probability applied to ``theta`` only.
        learn_priors: When true the prior mean and variance become parameters.

    NOTE(review): ``batch_norm`` and ``word_embedding`` are created here but not
    referenced by any method visible in this class; ``word_embedding`` holds
    ``4 * vocab_size * vocab_size`` values.
    """
    def __init__(self, config, device, vocab_size, contextual_size=768, encoded_size=768, n_components=50, hidden_sizes=(100,100), activation='relu', dropout=0.2, learn_priors=True):
        super(VariationalAE, self).__init__()

        assert activation in ['softplus', 'relu']

        self.config = config
        self.device = device
        self.vocab_size = vocab_size
        self.contextual_size = contextual_size
        self.encoded_size = encoded_size
        self.n_components = n_components
        self.hidden_sizes = hidden_sizes
        self.activation = activation
        self.dropout = dropout
        self.learn_priors = learn_priors
        # Set to ``beta`` on every forward pass.
        self.topic_word_matrix = None

        # decoder architecture
        self.batch_norm = nn.BatchNorm1d(vocab_size)
        self.word_embedding =  nn.Parameter(torch.randn(vocab_size*4, vocab_size))
        # The dropout inside this MLP is fixed at 0.2; the ``dropout`` argument
        # does not affect it.
        self.decoder = nn.Sequential(
            nn.Linear(encoded_size, contextual_size*4),
            nn.BatchNorm1d(contextual_size*4),
            nn.Sigmoid(),
            nn.Dropout(p=0.2),
            nn.Linear(contextual_size*4, contextual_size),
            nn.BatchNorm1d(contextual_size),
        )
        
        # topic model architecture
        self.inf_net = ContextualInferenceNetwork(encoded_size, contextual_size, n_components, hidden_sizes, activation, label_size=0)
        
        # Prior mean: zeros, one entry per component.
        topic_prior_mean = 0.0
        self.prior_mean = torch.tensor([topic_prior_mean] * n_components).to(device)
        if self.learn_priors:
            self.prior_mean = nn.Parameter(self.prior_mean)

        # Prior variance: 1 - 1/n_components for every component.
        topic_prior_variance = 1. - (1. / self.n_components)
        self.prior_variance = torch.tensor([topic_prior_variance] * n_components).to(device)
        if self.learn_priors:
            self.prior_variance = nn.Parameter(self.prior_variance)

        # beta: (n_components, encoded_size); allocated uninitialized and then
        # filled by the Xavier-uniform initializer below.
        self.beta = torch.Tensor(n_components, encoded_size).to(device)
        self.beta = nn.Parameter(self.beta)
        
        nn.init.xavier_uniform_(self.beta)
        
        self.beta_batchnorm = nn.BatchNorm1d(encoded_size, affine=False)
        
        self.drop_theta = nn.Dropout(p=self.dropout)
    
    @staticmethod
    def reparameterize(mu, logvar):
        """Reparameterize the theta distribution.

        Returns ``mu + eps * exp(0.5 * logvar)`` with ``eps`` drawn from a
        standard normal of the same shape.
        """
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return eps.mul(std).add_(mu)

    def forward(self, emb, target, labels=None):
        """Forward pass.

        Returns the 7-tuple ``(prior_mean, prior_variance, posterior_mu,
        posterior_sigma, posterior_log_sigma, word_dist, recon)``.

        NOTE(review): the inference network is called as
        ``inf_net(target, emb, labels)`` while this method's positional order is
        ``(emb, target)``; ``get_theta`` declares ``(target, emb)`` instead.
        The callers visible in this file pass the same tensor for both, so the
        difference is not observable there -- confirm the intended order
        against ``ContextualInferenceNetwork``.
        """
        posterior_mu, posterior_log_sigma = self.inf_net(target, emb, labels)
        posterior_sigma = torch.exp(posterior_log_sigma)

        # generate samples from theta
        theta = F.softmax(self.reparameterize(posterior_mu, posterior_log_sigma), dim=1)
        theta = self.drop_theta(theta)

        # prodLDA
        # theta @ beta, with beta of shape (n_components, encoded_size);
        # presumably theta is (batch_size, n_components) -- TODO confirm.
        word_dist = F.softmax(self.beta_batchnorm(torch.matmul(theta, self.beta)), dim=1)
        # word_dist: softmax over dim 1 of the batch-normalized product
        
        self.topic_word_matrix = self.beta
        
        # decode
        recon = self.decoder(word_dist);
        return self.prior_mean, self.prior_variance, posterior_mu, posterior_sigma, posterior_log_sigma, word_dist, recon
    
    def get_theta(self, target, emb, labels=None):
        """Return a softmax-normalized posterior sample without tracking gradients.

        A fresh random sample is drawn on every call, and no dropout is applied.
        """
        with torch.no_grad():
            posterior_mu, posterior_log_sigma = self.inf_net(target, emb, labels)
            theta = F.softmax(self.reparameterize(posterior_mu, posterior_log_sigma), dim=1)

            return theta


class IDEAEDecoder:
    def __init__(self, config, train_set, valid_set, vocab = None, id2token=None, device=None, contextual_dim=768, encoded_dim=768, noise_dim=100, word_embeddings=None, dropout=0.2, momentum=0.99, num_data_loader_workers=mp.cpu_count(), loss_weights=None, eps=1e-8):
        """Store the settings and build the five sub-models with their optimizers.

        Sub-models: ``vae``, ``decoder``, ``generator``, ``discriminator`` and
        ``classifier``. Each gets ``<name>_optimizer`` (``vae``, ``decoder``,
        ``gen``, ``dis``, ``cls`` prefixes). The VAE uses ``config['ae_lr']``;
        all others use ``config['lr']``.

        When ``config['scheduler']`` is truthy the optimizers are rebuilt as
        AdamW regardless of ``config['optim']`` and a matching
        ``<name>_scheduler`` is attached: linear warmup/decay when
        ``config['warmup'] == 'linear'``, constant-with-warmup otherwise.
        """
        # Plain settings
        self.config = config
        self.train_set = train_set
        self.valid_set = valid_set
        self.vocab = vocab
        self.id2token = id2token
        self.device = device
        self.contextual_dim = contextual_dim
        self.encoded_dim = encoded_dim
        self.noise_dim = noise_dim
        self.word_embeddings = word_embeddings
        self.dropout = dropout
        self.momentum = momentum
        self.num_data_loader_workers = num_data_loader_workers
        self.loss_weights = loss_weights
        self.eps = eps

        # Stateless helpers
        self.relu = torch.nn.ReLU()
        self.cls_loss = torch.nn.CrossEntropyLoss()
        self.mse_loss = torch.nn.MSELoss(reduction='none')

        # model (creation order is kept: it fixes the order of random initialisation)
        vocab_size = len(vocab)
        self.vae = VariationalAE(config, device, vocab_size, contextual_dim, encoded_dim, 50, (100, 100), 'relu', 0.2, True)
        self.decoder = MLPDecoder(encoded_dim, vocab_size, 0.2)
        self.generator = Generator(device)
        self.discriminator = Discriminator(input_dim=contextual_dim, output_dim=vocab_size, dropout=dropout)
        self.classifier = Classifier(input_dim=contextual_dim, output_dim=2)

        def adamw_for(module, lr_key):
            # AdamW over one sub-model's parameters.
            return AdamW(module.parameters(), lr=config[lr_key], eps=eps)

        def adam_for(module, lr_key):
            # Plain Adam; the only variant that reads config['weight_decay'].
            return torch.optim.Adam(module.parameters(), lr=config[lr_key], betas=(self.momentum, 0.99), weight_decay=config['weight_decay'])

        # optimizer
        make_optimizer = adamw_for if config['optim'] == 'AdamW' else adam_for
        self.vae_optimizer = make_optimizer(self.vae, 'ae_lr')
        self.decoder_optimizer = make_optimizer(self.decoder, 'lr')
        self.gen_optimizer = make_optimizer(self.generator, 'lr')
        self.dis_optimizer = make_optimizer(self.discriminator, 'lr')
        self.cls_optimizer = make_optimizer(self.classifier, 'lr')

        # scheduler
        if config['scheduler']:
            num_training_steps = int(len(train_set) / config['batch_size'] * config['epochs'])
            num_warmup_steps = int(num_training_steps * config['warmup_proportion'])

            # Scheduled runs always use AdamW, replacing whatever was built above.
            self.vae_optimizer = adamw_for(self.vae, 'ae_lr')
            self.decoder_optimizer = adamw_for(self.decoder, 'lr')
            self.gen_optimizer = adamw_for(self.generator, 'lr')
            self.dis_optimizer = adamw_for(self.discriminator, 'lr')
            self.cls_optimizer = adamw_for(self.classifier, 'lr')

            if config['warmup'] == 'linear':
                def make_scheduler(optimizer):
                    return get_linear_schedule_with_warmup(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)
            else:
                def make_scheduler(optimizer):
                    return get_constant_schedule_with_warmup(optimizer, num_warmup_steps=num_warmup_steps)

            self.vae_scheduler = make_scheduler(self.vae_optimizer)
            self.decoder_scheduler = make_scheduler(self.decoder_optimizer)
            self.gen_scheduler = make_scheduler(self.gen_optimizer)
            self.dis_scheduler = make_scheduler(self.dis_optimizer)
            self.cls_scheduler = make_scheduler(self.cls_optimizer)
                
    def ae_training(self, epoch, loader):
        print("")
        print('======== Epoch {:} / {:} ========'.format(epoch + 1, self.config['ae_epochs']))
        print('AutoEncoder Training...')

        ae_train_loss = 0
        ae_train_cos = 0
        
        self.vae.train()

        for batch, (corpus, embs, labels, masks) in enumerate(loader):
            embs, masks = embs.to(self.device), masks.to(self.device)
            _, _, _, _, _, encoded, decoded = self.vae(embs, embs)
            
            # Loss weight
            cos = torch.nn.functional.cosine_similarity(torch.mean(embs, dim=0), torch.mean(decoded, dim=0), dim=0)
            # w = 1 - cos
            
            # Encode-Decode's Loss
            recon_loss = torch.mean(self.mse_loss(decoded, embs), dim=1)
            mask_loss = torch.masked_select(recon_loss, torch.flatten(~masks))     
            decoded_loss = torch.mean(mask_loss)
            print(decoded_loss)
                       
            self.vae_optimizer.zero_grad()
            decoded_loss.backward() 
            self.vae_optimizer.step()

            ae_train_loss += decoded_loss.item()
            ae_train_cos += cos

        avg_ae_train_loss = ae_train_loss / len(loader)  
        avg_ae_train_cos = ae_train_cos / len(loader)
        

        print("")
        print("  Average training loss AutoEncoder: {0:.3f}".format(avg_ae_train_loss))
        print("  Average training Cosine-Similarity AutoEncoder: {0:.3f}".format(avg_ae_train_cos))

        return avg_ae_train_loss, avg_ae_train_cos
    
    def mlp_training(self, epoch, loader):
        print("")
        print('======== Epoch {:} / {:} ========'.format(epoch + 1, self.config['epochs']))
        print('Decoder Training...')

        decode_train_loss = 0

        self.vae.eval()
        self.decoder.train()

        for batch, (corpus, embs, labels, masks) in enumerate(loader):
            embs, labels, masks = embs.to(self.device), labels.to(self.device), masks.to(self.device)

            # VAE transform
            _, _, _, _, _, encoded, _ = self.vae(embs, embs)   
            
            # Decode
            recons = self.decoder(encoded)
            
            # Decoder's LOSS
            mask_loss = torch.masked_select(Singular_MythNet(recons, labels), torch.flatten(masks))
            labeled_count = mask_loss.type(torch.float32).numel()
            if labeled_count == 0:
                continue
            else:
                decoded_loss = torch.mean(mask_loss)
            
            self.decoder_optimizer.zero_grad()
            decoded_loss.backward() 
            self.decoder_optimizer.step()

            decode_train_loss += decoded_loss.item()

        avg_decoded_train_loss = decode_train_loss / len(loader)             

        print("")
        print("  Average training loss decoder: {0:.3f}".format(avg_decoded_train_loss))

        return avg_decoded_train_loss
    
    def gen_training(self, epoch, loader):
        print("")
        print('======== Epoch {:} / {:} ========'.format(epoch + 1, self.config['epochs']))
        print('Bert GAN Training...')
        
        gen_train_loss = 0
        
        #self.vae.train()
        self.generator.train()
        self.discriminator.eval()
        self.classifier.eval()

        for batch, (corpus, embs, labels, masks) in enumerate(loader):
            real_embs, labels, masks = embs.to(self.device), labels.to(self.device), masks.to(self.device)
            cur_batch_size = embs.shape[0]
            
            # vae transform
            #real_embs_t = self.vae(real_embs)
            real_embs_t = real_embs
            
            # fake label from BERT
            noise = torch.empty(cur_batch_size, dtype=torch.long).random_(len(self.train_set))
            noise_docs = []
            noise_labels = torch.FloatTensor([])
            for i in range(cur_batch_size):     
                noise_docs.append(self.train_set[i][0])
                noise_labels = torch.cat((noise_labels, self.train_set[i][2]))
            fake_labels = torch.reshape(noise_labels, (cur_batch_size, len(self.vocab))).to(self.device)
            
            fake_embs = self.generator(noise_docs).to(self.device)

            mixed_embs = torch.cat((real_embs_t, fake_embs), dim=0)
            logits, probs = self.classifier(mixed_embs)
            recons = self.discriminator(mixed_embs)         

            recons_list = torch.split(recons, cur_batch_size)
            D_real_recons = recons_list[0]
            D_fake_recons = recons_list[1]
        
            logits_list = torch.split(logits, cur_batch_size)
            D_real_logits = logits_list[0]
            D_fake_logits = logits_list[1]
            
            probs_list = torch.split(probs, cur_batch_size)
            D_real_probs = probs_list[0]
            D_fake_probs = probs_list[1]

            # Generator's LOSS
            g_loss_d = -1 * torch.mean(torch.log(1 - D_fake_probs[:,-1] + self.eps))
            g_feat_emb = torch.mean(torch.pow(torch.mean(real_embs_t, dim=0) - torch.mean(fake_embs, dim=0), 2))
            gen_loss = g_loss_d + g_feat_emb
            

            self.gen_optimizer.zero_grad()
            gen_loss.backward()
            self.gen_optimizer.step()
            if self.config['scheduler']:
                self.gen_scheduler.step()

            gen_train_loss += gen_loss.item()

        avg_gen_train_loss = gen_train_loss / len(loader)           

        print("")
        print("  Average training loss generetor: {0:.3f}".format(avg_gen_train_loss))

        return avg_gen_train_loss
        
    def dis_training(self, epoch, loader):      
        cls_train_loss, dis_train_loss = 0, 0
        
        #self.vae.train()
        self.generator.eval()
        self.discriminator.train()
        self.classifier.train()

        for batch, (corpus, embs, labels, masks) in enumerate(loader):
            real_embs, labels, masks = embs.to(self.device), labels.to(self.device), masks.to(self.device)
            cur_batch_size = embs.shape[0]
            
            # vae transform
            #real_embs_t = self.vae(real_embs)
            real_embs_t = real_embs
            
            # fake label from BERT
            noise = torch.empty(cur_batch_size, dtype=torch.long).random_(len(self.train_set))
            noise_docs = []
            noise_labels = torch.FloatTensor([])
            for i in range(cur_batch_size):     
                noise_docs.append(self.train_set[i][0])
                noise_labels = torch.cat((noise_labels, self.train_set[i][2]))
            fake_labels = torch.reshape(noise_labels, (cur_batch_size, len(self.vocab))).to(self.device)
            
            fake_embs = self.generator(noise_docs).to(self.device)

            mixed_embs = torch.cat((real_embs_t, fake_embs), dim=0)
            logits, probs = self.classifier(mixed_embs)
            recons = self.discriminator(mixed_embs)

            recons_list = torch.split(recons, cur_batch_size)
            D_real_recons = recons_list[0]
            D_fake_recons = recons_list[1]
        
            logits_list = torch.split(logits, cur_batch_size)
            D_real_logits = logits_list[0]
            D_fake_logits = logits_list[1]
            
            probs_list = torch.split(probs, cur_batch_size)
            D_real_probs = probs_list[0]
            D_fake_probs = probs_list[1]
            
            # Classifier's Loss
            D_L_unsupervised1U = -1 * torch.mean(torch.log(1 - D_real_probs[:, -1] + self.eps))
            D_L_unsupervised2U = -1 * torch.mean(torch.log(D_fake_probs[:, -1] + self.eps))
            #D_L_unsupervised1U = self.cls_loss(D_real_logits, torch.ones(cur_batch_size, dtype=torch.long).to(self.device))
            #D_L_unsupervised2U = self.cls_loss(D_fake_logits, torch.zeros(cur_batch_size, dtype=torch.long).to(self.device)) 
            cls_loss =  D_L_unsupervised1U + D_L_unsupervised2U
            
            # Disciminator's LOSS
            recon_loss = torch.masked_select(Singular_MythNet(D_real_recons, labels), torch.flatten(masks))
            g_recon_weight =  D_fake_probs[:, 0]
            fake_recon_loss = Singular_MythNet(D_fake_recons, fake_labels) * g_recon_weight
            labeled_count = recon_loss.type(torch.float32).numel()
            
            if labeled_count == 0:
                D_L_Supervised = torch.mean(fake_recon_loss)
            else:
                D_L_Supervised = torch.mean(recon_loss) + torch.mean(fake_recon_loss)                    
            dis_loss = D_L_Supervised + cls_loss
            
            self.cls_optimizer.zero_grad()
            cls_loss.backward(retain_graph=True)
            self.cls_optimizer.step()
            if self.config['scheduler']:
                self.cls_scheduler.step()
                
            self.dis_optimizer.zero_grad()
            dis_loss.backward()
            self.dis_optimizer.step()
            if self.config['scheduler']:
                self.dis_scheduler.step()
            
            cls_train_loss += cls_loss.item()
            dis_train_loss += dis_loss.item()
        
        avg_cls_train_loss = cls_train_loss / len(loader)
        avg_dis_train_loss = dis_train_loss / len(loader)           
        
        print("  Average training loss classifier: {0:.3f}".format(avg_cls_train_loss))
        print("  Average training loss discriminator: {0:.3f}".format(avg_dis_train_loss))

        return avg_cls_train_loss, avg_dis_train_loss
        
    def ae_validation(self, loader):
        ae_val_loss = 0
        ae_val_cos = 0
        
        self.vae.eval()
        
        with torch.no_grad():
            for batch, (corpus, embs, labels, masks) in enumerate(loader):
                embs, masks = embs.to(self.device), masks.to(self.device)
                prior_mean, prior_variance, posterior_mean, posterior_variance,\
                posterior_log_variance, encoded, decoded = self.vae(embs, embs)

                # Loss weight
                cos = torch.nn.functional.cosine_similarity(torch.mean(embs, dim=0), torch.mean(decoded, dim=0), dim=0)
                w = 1 - cos
                
                # Encode-Decode's Loss
                recon_loss = torch.mean(self.mse_loss(decoded, embs), dim=1)    
                decoded_loss = torch.mean(recon_loss) * w

                ae_val_loss += decoded_loss.item()
                ae_val_cos += cos
                
            avg_ae_val_loss = ae_val_loss / len(loader)
            avg_ae_val_cos = ae_val_cos / len(loader)
        
        return avg_ae_val_loss, avg_ae_val_cos
    
    def mlp_validation(self, loader):
        self.vae.eval()
        self.decoder.eval()
        
        results = defaultdict(list)
        with torch.no_grad():
            for batch, (corpus, embs, labels, masks) in enumerate(loader):
                embs, labels = embs.to(self.device), labels.to(self.device)
                
                # VAE transform
                _, _, _, _, _, encoded, _ = self.vae(embs, embs)   

                # Decode
                recons = self.decoder(encoded)
                
                # Precision for reconstruct
                precision_scores = retrieval_precision_all(recons, labels, k=self.config['topk'])
                for k, v in precision_scores.items():
                    results['[Recon] Precision v1@{}'.format(k)].append(v)
                
                precision_scores = retrieval_precision_all_v2(recons, labels, k=self.config['topk'])
                for k, v in precision_scores.items():
                    results['[Recon] Precision v2@{}'.format(k)].append(v)

                # NDCG for reconstruct
                ndcg_scores = retrieval_normalized_dcg_all(recons, labels, k=self.config['topk'])
                for k, v in ndcg_scores.items():
                    results['[Recon] ndcg@{}'.format(k)].append(v)

        for k in results:
            results[k] = np.mean(results[k])
                
        return results
    
    def gan_validation(self, loader):
        self.vae.eval()
        self.generator.eval()
        self.classifier.eval()
        self.discriminator.eval()
        
        results = defaultdict(list)
        with torch.no_grad():
            for batch, (corpus, embs, labels, masks) in enumerate(loader):
                embs, labels = embs.to(self.device), labels.to(self.device)
                #embs_t = self.vae(embs, embs)
                embs_t = embs
                
                logits, probs = self.classifier(embs_t)
                recons = self.discriminator(embs_t)
                
                # Precision for reconstruct
                precision_scores = retrieval_precision_all(recons, labels, k=self.config['topk'])
                for k, v in precision_scores.items():
                    results['[Recon] Precision v1@{}'.format(k)].append(v)
                
                precision_scores = retrieval_precision_all_v2(recons, labels, k=self.config['topk'])
                for k, v in precision_scores.items():
                    results['[Recon] Precision v2@{}'.format(k)].append(v)

                # NDCG for reconstruct
                ndcg_scores = retrieval_normalized_dcg_all(recons, labels, k=self.config['topk'])
                for k, v in ndcg_scores.items():
                    results['[Recon] ndcg@{}'.format(k)].append(v)

        for k in results:
            results[k] = np.mean(results[k])
                
        return results
    
    def ae_fit(self):
        """Train the autoencoder, validating and logging after every 10th epoch.

        Moves ``self.vae`` to ``self.device``, wraps ``self.train_set`` and
        ``self.valid_set`` in ``DataLoader`` objects and calls
        ``self.ae_training`` once per epoch for ``config['ae_epochs']`` epochs.
        After every 10th epoch ``self.ae_validation`` is run; its loss and
        cosine similarity are printed and appended to a text file in the
        current working directory whose name is built from the run settings.

        Returns:
            None
        """
        self.vae.to(self.device)

        batch_size = self.config['batch_size']
        workers = self.num_data_loader_workers
        train_loader = DataLoader(self.train_set, batch_size=batch_size, shuffle=True, num_workers=workers)
        valid_loader = DataLoader(self.valid_set, batch_size=batch_size, shuffle=False, num_workers=workers)

        for epoch in range(self.config['ae_epochs']):
            # The values returned by the training step are not logged by this method.
            self.ae_training(epoch, train_loader)
            if (epoch + 1) % 10 != 0:
                continue

            val_loss, val_cos = self.ae_validation(valid_loader)

            # File-name fragments are kept exactly as they were (including the
            # missing leading underscore on the 'with_*' variants) so that logs
            # from earlier runs keep being appended to.
            withscheduler = 'with_scheduler' if self.config['scheduler'] else '_without_scheduler'
            withbalance = 'with_balance' if self.config['balance'] else '_without_balance'
            record_path = (
                './ae_' + self.config['experiment']
                + '_' + self.config['dataset'] + str(int(self.config['ratio'] * 100))
                + '_' + self.config['encoder']
                + '_loss_' + self.config['loss']
                + '_lr' + str(self.config['lr'])
                + '_optim' + self.config['optim'] + withscheduler + withbalance
                + '_weightdecay' + str(self.config['weight_decay'])
                + '.txt'
            )

            # Context manager: the handle is flushed and closed after each report
            # instead of leaking one open file per logging step.
            with open(record_path, 'a') as record:
                print('---------------------------------------')
                record.write('-------------------------------------------------\n')
                print("AutoEncoder Validation loss: {0:.3f}".format(val_loss))
                # The value written is the validation loss, so the label now
                # matches the printed line (it previously said "training loss").
                record.write("AutoEncoder Validation loss: {0:.3f}\n".format(val_loss))
                print("AutoEncoder validation Cosine-Similarity: {0:.3f}".format(val_cos))
                record.write("AutoEncoder validation Cosine-Similarity: {0:.3f}\n".format(val_cos))
    
    def mlp_fit(self):
        """Train the decoder, validating and logging after every 10th epoch.

        Moves ``self.decoder`` to ``self.device``, wraps ``self.train_set`` and
        ``self.valid_set`` in ``DataLoader`` objects and calls
        ``self.mlp_training`` once per epoch for ``config['epochs']`` epochs.
        After every 10th epoch the metrics returned by ``self.mlp_validation``
        and the latest training loss are printed and appended to a text file
        in the current working directory whose name is built from the run
        settings.

        Returns:
            None
        """
        self.decoder.to(self.device)

        batch_size = self.config['batch_size']
        workers = self.num_data_loader_workers
        train_loader = DataLoader(self.train_set, batch_size=batch_size, shuffle=True, num_workers=workers)
        valid_loader = DataLoader(self.valid_set, batch_size=batch_size, shuffle=False, num_workers=workers)

        for epoch in range(self.config['epochs']):
            decoded_train_loss = self.mlp_training(epoch, train_loader)
            if (epoch + 1) % 10 != 0:
                continue

            val_res = self.mlp_validation(valid_loader)

            # File-name fragments are kept exactly as they were (including the
            # missing leading underscore on the 'with_*' variants) so that logs
            # from earlier runs keep being appended to.
            withscheduler = 'with_scheduler' if self.config['scheduler'] else '_without_scheduler'
            withbalance = 'with_balance' if self.config['balance'] else '_without_balance'
            record_path = (
                './ide_semi_' + self.config['experiment']
                + '_' + self.config['dataset'] + str(int(self.config['ratio'] * 100))
                + '_' + self.config['encoder']
                + '_' + self.config['target']
                + '_loss_' + self.config['loss']
                + '_lr' + str(self.config['lr'])
                + '_optim' + self.config['optim'] + withscheduler + withbalance
                + '_weightdecay' + str(self.config['weight_decay'])
                + '.txt'
            )

            # Context manager: the handle is flushed and closed after each report
            # instead of leaking one open file per logging step.
            with open(record_path, 'a') as record:
                print('---------------------------------------')
                record.write('-------------------------------------------------\n')
                for key, val in val_res.items():
                    print(f"{key}:{val:.4f}")
                    record.write(f"{key}:{val:.4f}\n")
                print("Decoder training loss: {0:.3f}".format(decoded_train_loss))
                record.write("Decoder training loss: {0:.3f}\n".format(decoded_train_loss))

    def gan_fit(self):
        """Train generator, classifier and discriminator, logging every 10th epoch.

        Moves ``self.generator``, ``self.classifier`` and ``self.discriminator``
        to ``self.device`` and wraps ``self.train_set`` / ``self.valid_set`` in
        ``DataLoader`` objects. Each of the ``config['epochs']`` epochs calls
        ``self.gen_training`` followed by ``self.dis_training``. After every
        10th epoch the metrics returned by ``self.gan_validation`` and the
        latest three training losses are printed and appended to a text file
        in the current working directory whose name is built from the run
        settings.

        Returns:
            None
        """
        self.generator.to(self.device)
        self.classifier.to(self.device)
        self.discriminator.to(self.device)

        batch_size = self.config['batch_size']
        workers = self.num_data_loader_workers
        train_loader = DataLoader(self.train_set, batch_size=batch_size, shuffle=True, num_workers=workers)
        valid_loader = DataLoader(self.valid_set, batch_size=batch_size, shuffle=False, num_workers=workers)

        for epoch in range(self.config['epochs']):
            gen_train_loss = self.gen_training(epoch, train_loader)
            cls_train_loss, dis_train_loss = self.dis_training(epoch, train_loader)
            if (epoch + 1) % 10 != 0:
                continue

            val_res = self.gan_validation(valid_loader)

            # File-name fragments are kept exactly as they were (including the
            # missing leading underscore on the 'with_*' variants) so that logs
            # from earlier runs keep being appended to.
            withscheduler = 'with_scheduler' if self.config['scheduler'] else '_without_scheduler'
            withbalance = 'with_balance' if self.config['balance'] else '_without_balance'
            record_path = (
                './ide_gan_' + self.config['experiment']
                + '_' + self.config['dataset'] + str(int(self.config['ratio'] * 100))
                + '_' + self.config['encoder']
                + '_' + self.config['target']
                + '_loss_' + self.config['loss']
                + '_lr' + str(self.config['lr'])
                + '_optim' + self.config['optim'] + withscheduler + withbalance
                + '_weightdecay' + str(self.config['weight_decay'])
                + '.txt'
            )

            # Context manager: the handle is flushed and closed after each report
            # instead of leaking one open file per logging step.
            with open(record_path, 'a') as record:
                print('---------------------------------------')
                record.write('-------------------------------------------------\n')
                for key, val in val_res.items():
                    print(f"{key}:{val:.4f}")
                    record.write(f"{key}:{val:.4f}\n")
                print("Generator training loss: {0:.3f}".format(gen_train_loss))
                record.write("Generator training loss: {0:.3f}\n".format(gen_train_loss))
                print("Classifier training loss: {0:.3f}".format(cls_train_loss))
                record.write("Classifier training loss: {0:.3f}\n".format(cls_train_loss))
                print("Discriminator training loss: {0:.3f}".format(dis_train_loss))
                record.write("Discriminator training loss: {0:.3f}\n".format(dis_train_loss))

In [None]:
# Build the training/validation sets, vocabulary, id-to-token map and device from `config`.
# NOTE(review): `config` must already exist in the kernel; confirm it is defined in an earlier cell.
training_set, validation_set, vocabularys, id2token, device = generate_dataset(config)

Getting preprocess documents: 20news
min_df: 62 max_df: 1.0 vocabulary_size: None min_doc_word: 15




In [36]:
# Instantiate the model from the pieces returned by generate_dataset.
model = IDEAEDecoder(config, training_set, validation_set, vocabularys, id2token, device)
# Autoencoder training is currently disabled in this run.
#model.ae_fit()

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [None]:
# Run GAN training; gan_fit validates and appends metrics to a log file after every 10th epoch.
model.gan_fit()


Bert GAN Training...

  Average training loss generetor: 0.061
  Average training loss classifier: 1.749
  Average training loss discriminator: 13.628

Bert GAN Training...

  Average training loss generetor: 0.095
  Average training loss classifier: 1.278
  Average training loss discriminator: 11.809

Bert GAN Training...

  Average training loss generetor: 0.116
  Average training loss classifier: 1.306
  Average training loss discriminator: 11.340

Bert GAN Training...

  Average training loss generetor: 0.116
  Average training loss classifier: 1.066
  Average training loss discriminator: 10.708

Bert GAN Training...

  Average training loss generetor: 0.104
  Average training loss classifier: 1.007
  Average training loss discriminator: 10.264

Bert GAN Training...

  Average training loss generetor: 0.097
  Average training loss classifier: 0.938
  Average training loss discriminator: 9.821

Bert GAN Training...

  Average training loss generetor: 0.085
  Average training loss c


  Average training loss generetor: 0.051
  Average training loss classifier: 1.043
  Average training loss discriminator: 10.250

Bert GAN Training...

  Average training loss generetor: 0.048
  Average training loss classifier: 1.020
  Average training loss discriminator: 10.090

Bert GAN Training...

  Average training loss generetor: 0.062
  Average training loss classifier: 1.055
  Average training loss discriminator: 10.276

Bert GAN Training...

  Average training loss generetor: 0.047
  Average training loss classifier: 1.009
  Average training loss discriminator: 9.938

Bert GAN Training...

  Average training loss generetor: 0.071
  Average training loss classifier: 1.013
  Average training loss discriminator: 9.944

Bert GAN Training...

  Average training loss generetor: 0.051
  Average training loss classifier: 1.031
  Average training loss discriminator: 10.207

Bert GAN Training...

  Average training loss generetor: 0.050
  Average training loss classifier: 1.073
  Aver


  Average training loss generetor: 0.145
  Average training loss classifier: 1.034
  Average training loss discriminator: 9.598

Bert GAN Training...

  Average training loss generetor: 0.119
  Average training loss classifier: 1.221
  Average training loss discriminator: 10.066

Bert GAN Training...

  Average training loss generetor: 0.106
  Average training loss classifier: 1.190
  Average training loss discriminator: 10.028

Bert GAN Training...

  Average training loss generetor: 0.082
  Average training loss classifier: 1.113
  Average training loss discriminator: 9.644

Bert GAN Training...

  Average training loss generetor: 0.128
  Average training loss classifier: 1.133
  Average training loss discriminator: 9.784

Bert GAN Training...

  Average training loss generetor: 0.092
  Average training loss classifier: 1.147
  Average training loss discriminator: 9.726

Bert GAN Training...

  Average training loss generetor: 0.111
  Average training loss classifier: 1.087
  Averag


  Average training loss generetor: 0.183
  Average training loss classifier: 1.163
  Average training loss discriminator: 9.393

Bert GAN Training...

  Average training loss generetor: 0.129
  Average training loss classifier: 1.225
  Average training loss discriminator: 9.628

Bert GAN Training...

  Average training loss generetor: 0.136
  Average training loss classifier: 1.180
  Average training loss discriminator: 9.437

Bert GAN Training...

  Average training loss generetor: 0.121
  Average training loss classifier: 1.200
  Average training loss discriminator: 9.479

Bert GAN Training...

  Average training loss generetor: 0.167
  Average training loss classifier: 1.137
  Average training loss discriminator: 9.273

Bert GAN Training...

  Average training loss generetor: 0.209
  Average training loss classifier: 1.079
  Average training loss discriminator: 9.014

Bert GAN Training...

  Average training loss generetor: 0.162
  Average training loss classifier: 1.265
  Average 


  Average training loss generetor: 0.263
  Average training loss classifier: 1.105
  Average training loss discriminator: 8.588

Bert GAN Training...

  Average training loss generetor: 0.182
  Average training loss classifier: 1.280
  Average training loss discriminator: 9.240

Bert GAN Training...

  Average training loss generetor: 0.163
  Average training loss classifier: 1.244
  Average training loss discriminator: 9.276

Bert GAN Training...

  Average training loss generetor: 0.216
  Average training loss classifier: 1.140
  Average training loss discriminator: 8.682

Bert GAN Training...

  Average training loss generetor: 0.176
  Average training loss classifier: 1.267
  Average training loss discriminator: 9.058

Bert GAN Training...

  Average training loss generetor: 0.143
  Average training loss classifier: 1.201
  Average training loss discriminator: 8.917

Bert GAN Training...

  Average training loss generetor: 0.135
  Average training loss classifier: 1.198
  Average 


  Average training loss generetor: 0.188
  Average training loss classifier: 1.050
  Average training loss discriminator: 7.871

Bert GAN Training...

  Average training loss generetor: 0.289
  Average training loss classifier: 1.150
  Average training loss discriminator: 8.262

Bert GAN Training...

  Average training loss generetor: 0.227
  Average training loss classifier: 1.037
  Average training loss discriminator: 8.124

Bert GAN Training...

  Average training loss generetor: 0.199
  Average training loss classifier: 1.246
  Average training loss discriminator: 8.748

Bert GAN Training...

  Average training loss generetor: 0.224
  Average training loss classifier: 1.244
  Average training loss discriminator: 8.851

Bert GAN Training...

  Average training loss generetor: 0.254
  Average training loss classifier: 1.070
  Average training loss discriminator: 8.244

Bert GAN Training...

  Average training loss generetor: 0.218
  Average training loss classifier: 1.150
  Average 


  Average training loss generetor: 0.160
  Average training loss classifier: 1.168
  Average training loss discriminator: 8.108

Bert GAN Training...

  Average training loss generetor: 0.157
  Average training loss classifier: 1.262
  Average training loss discriminator: 8.394

Bert GAN Training...

  Average training loss generetor: 0.261
  Average training loss classifier: 1.063
  Average training loss discriminator: 7.752

Bert GAN Training...

  Average training loss generetor: 0.113
  Average training loss classifier: 1.230
  Average training loss discriminator: 8.617

Bert GAN Training...

  Average training loss generetor: 0.195
  Average training loss classifier: 1.132
  Average training loss discriminator: 7.942

Bert GAN Training...

  Average training loss generetor: 0.078
  Average training loss classifier: 1.154
  Average training loss discriminator: 8.009

Bert GAN Training...

  Average training loss generetor: 0.175
  Average training loss classifier: 1.276
  Average 


  Average training loss generetor: 0.197
  Average training loss classifier: 1.137
  Average training loss discriminator: 7.881

Bert GAN Training...

  Average training loss generetor: 0.199
  Average training loss classifier: 1.150
  Average training loss discriminator: 7.938

Bert GAN Training...

  Average training loss generetor: 0.172
  Average training loss classifier: 1.248
  Average training loss discriminator: 8.189

Bert GAN Training...

  Average training loss generetor: 0.144
  Average training loss classifier: 1.253
  Average training loss discriminator: 8.202

Bert GAN Training...

  Average training loss generetor: 0.157
  Average training loss classifier: 1.179
  Average training loss discriminator: 7.969

Bert GAN Training...

  Average training loss generetor: 0.196
  Average training loss classifier: 1.145
  Average training loss discriminator: 7.786

Bert GAN Training...

  Average training loss generetor: 0.126
  Average training loss classifier: 1.145
  Average 

In [76]:
# Scratch tensors of shape (1, 768) for experimenting with distance computations.
# NOTE(review): `input` shadows the Python builtin of the same name for the rest of the session.
input = torch.randn(1, 768)
label = torch.randn(1, 768)

In [78]:
# Average over dim 0 and keep that axis with size 1.
m_input = input.mean(dim=0, keepdim=True)
n_label = label.mean(dim=0, keepdim=True)

In [79]:
# Inspect the shape of the averaged input.
m_input.shape

torch.Size([1, 768])

In [80]:
# Pairwise p=2 (Euclidean) distance between the rows of m_input and n_label.
dis = torch.cdist(m_input, n_label, p=2)

In [63]:
# Average the distance matrix over dim 0.
# NOTE(review): this cell's execution count (63) is lower than that of the cell defining `dis` (80); re-run in order.
dis1 = torch.mean(dis, dim=0)

In [65]:
# Average the remaining axis of dis1.
# NOTE(review): this cell's execution count (65) is lower than that of the cell defining `dis` (80); re-run in order.
dis2 = torch.mean(dis1, dim=0)

In [84]:
# Drop the size-1 axes of the distance matrix and display the result.
dis.squeeze()

tensor(38.9615)

In [23]:
# Display `label`.
# NOTE(review): the recorded output is a 2x2 tensor from execution 23, which predates the (1, 768) `label` defined in execution 76; re-run to refresh it.
label

tensor([[-0.5169, -0.8715],
        [ 1.5264, -0.2267]])