In [None]:
import math
import argparse

import numpy as np
import torch
torch.manual_seed(0)
torch.cuda.manual_seed_all(0)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from numpy import random
from torch.nn.parameter import Parameter
from OpenAttMultiGL.utils.dataset import dataset
from OpenAttMultiGL.utils.process import * 
from OpenAttMultiGL.layers.hdmi.gcn import GCN
from OpenAttMultiGL.model.GATNE.utils import *
from OpenAttMultiGL.model.GATNE.evaluate import evaluate
#from mGCN_Toolbox.model.GATNE.walk import *
#from OpenAttMultiGL.model.X_GOAL.evaluate import *

def combine_att(h_list):
    att_act1 = nn.Tanh()
    att_act2 = nn.Softmax(dim=-1)
    h_combine_list = []
    for i, h in enumerate(h_list):
        h = w_list[i](h)
        h = y_list[i](h)
        h_combine_list.append(h)
    score = torch.cat(h_combine_list, -1)
    score = att_act1(score)
    score = att_act2(score)
    score = torch.unsqueeze(score, -1)
    h = torch.stack(h_list, dim=1)
    h = score * h
    h = torch.sum(h, dim=1)
    return h

def embed(seq, adj_list, sparse):
    global w_list
    global y_list
    gcn_list = nn.ModuleList([GCN(ft_size, hid_units) for _ in range(n_networks)])
    w_list = nn.ModuleList([nn.Linear(hid_units, hid_units, bias=False) for _ in range(n_networks)])
    y_list = nn.ModuleList([nn.Linear(hid_units, 1) for _ in range(n_networks)])
    h_1_list = []
    for i, adj in enumerate(adj_list):
        h_1 = torch.squeeze(gcn_list[i](seq, adj, sparse))
        h_1_list.append(h_1)
    h = combine_att(h_1_list)
    return h.detach()





def get_batches(pairs, neighbors, batch_size):
    n_batches = (len(pairs) + (batch_size - 1)) // batch_size

    for idx in range(n_batches):
        x, y, t, neigh = [], [], [], []
        for i in range(batch_size):
            index = idx * batch_size + i
            if index >= len(pairs):
                break
            x.append(pairs[index][0])
            y.append(pairs[index][1])
            t.append(pairs[index][2])
            neigh.append(neighbors[pairs[index][0]])
        yield torch.tensor(x), torch.tensor(y), torch.tensor(t), torch.tensor(neigh)


class GATNEModel(nn.Module):
    def __init__(
        self, num_nodes, embedding_size, embedding_u_size, edge_type_count, dim_a, features
    ):
        super(GATNEModel, self).__init__()
        self.num_nodes = num_nodes
        self.embedding_size = embedding_size
        self.embedding_u_size = embedding_u_size
        self.edge_type_count = edge_type_count
        self.dim_a = dim_a

        self.features = None
        if features is not None:
            self.features = features
            feature_dim = self.features.shape[-1]
            self.embed_trans = Parameter(torch.FloatTensor(feature_dim, embedding_size))
            self.u_embed_trans = Parameter(torch.FloatTensor(edge_type_count, feature_dim, embedding_u_size))
        else:
            self.node_embeddings = Parameter(torch.FloatTensor(num_nodes, embedding_size))
            self.node_type_embeddings = Parameter(
                torch.FloatTensor(num_nodes, edge_type_count, embedding_u_size)
            )
        self.trans_weights = Parameter(
            torch.FloatTensor(edge_type_count, embedding_u_size, embedding_size)
        )
        self.trans_weights_s1 = Parameter(
            torch.FloatTensor(edge_type_count, embedding_u_size, dim_a)
        )
        self.trans_weights_s2 = Parameter(torch.FloatTensor(edge_type_count, dim_a, 1))

        self.reset_parameters()

    def reset_parameters(self):
        if self.features is not None:
            self.embed_trans.data.normal_(std=1.0 / math.sqrt(self.embedding_size))
            self.u_embed_trans.data.normal_(std=1.0 / math.sqrt(self.embedding_size))
        else:
            self.node_embeddings.data.uniform_(-1.0, 1.0)
            self.node_type_embeddings.data.uniform_(-1.0, 1.0)
        self.trans_weights.data.normal_(std=1.0 / math.sqrt(self.embedding_size))
        self.trans_weights_s1.data.normal_(std=1.0 / math.sqrt(self.embedding_size))
        self.trans_weights_s2.data.normal_(std=1.0 / math.sqrt(self.embedding_size))

    def forward(self, train_inputs, train_types, node_neigh):
        if self.features is None:
            node_embed = self.node_embeddings[train_inputs]
            node_embed_neighbors = self.node_type_embeddings[node_neigh]
        else:
            node_embed = torch.mm(self.features[train_inputs], self.embed_trans)
            node_embed_neighbors = torch.einsum('bijk,akm->bijam', self.features[node_neigh], self.u_embed_trans)
        node_embed_tmp = torch.diagonal(node_embed_neighbors, dim1=1, dim2=3).permute(0, 3, 1, 2)
        node_type_embed = torch.sum(node_embed_tmp, dim=2)

        trans_w = self.trans_weights[train_types]
        trans_w_s1 = self.trans_weights_s1[train_types]
        trans_w_s2 = self.trans_weights_s2[train_types]

        attention = F.softmax(
            torch.matmul(
                torch.tanh(torch.matmul(node_type_embed, trans_w_s1)), trans_w_s2
            ).squeeze(2),
            dim=1,
        ).unsqueeze(1)
        node_type_embed = torch.matmul(attention, node_type_embed)
        node_embed = node_embed + torch.matmul(node_type_embed, trans_w).squeeze(1)

        last_node_embed = F.normalize(node_embed, dim=1)

        return last_node_embed


class NSLoss(nn.Module):
    def __init__(self, num_nodes, num_sampled, embedding_size):
        super(NSLoss, self).__init__()
        self.num_nodes = num_nodes
        self.num_sampled = num_sampled
        self.embedding_size = embedding_size
        self.weights = Parameter(torch.FloatTensor(num_nodes, embedding_size))
        self.sample_weights = F.normalize(
            torch.Tensor(
                [
                    (math.log(k + 2) - math.log(k + 1)) / math.log(num_nodes + 1)
                    for k in range(num_nodes)
                ]
            ),
            dim=0,
        )

        self.reset_parameters()

    def reset_parameters(self):
        self.weights.data.normal_(std=1.0 / math.sqrt(self.embedding_size))

    def forward(self, input, embs, label):
        n = input.shape[0]
        log_target = torch.log(
            torch.sigmoid(torch.sum(torch.mul(embs, self.weights[label]), 1))
        )
        negs = torch.multinomial(
            self.sample_weights, self.num_sampled * n, replacement=True
        ).view(n, self.num_sampled)
        noise = torch.neg(self.weights[negs])
        sum_log_sampled = torch.sum(
            torch.log(torch.sigmoid(torch.bmm(noise, embs.unsqueeze(2)))), 1
        ).squeeze()

        loss = log_target + sum_log_sampled
        return -loss.sum() / n


def train_model(network_data, feature_dic,dataset):
    vocab, index2word, train_pairs = generate(network_data, args.num_walks, args.walk_length, args.schema, file_name, args.window_size, args.num_workers, args.walk_file)

    edge_types = list(network_data.keys())

    num_nodes = len(index2word)
    edge_type_count = len(edge_types)
    epochs = args.epoch
    batch_size = args.batch_size
    embedding_size = args.dimensions
    embedding_u_size = args.edge_dim
    u_num = edge_type_count
    num_sampled = args.negative_samples
    dim_a = args.att_dim
    att_head = 1
    neighbor_samples = args.neighbor_samples

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    neighbors = generate_neighbors(network_data, vocab, num_nodes, edge_types, neighbor_samples)

    features = None
    if feature_dic is not None:
        feature_dim = len(list(feature_dic.values())[0])
        print('feature dimension: ' + str(feature_dim))
        features = np.zeros((num_nodes, feature_dim), dtype=np.float32)
        for key, value in feature_dic.items():
            if key in vocab:
                features[vocab[key].index, :] = np.array(value)
        features = torch.FloatTensor(features).to(device)

    model = GATNEModel(
        num_nodes, embedding_size, embedding_u_size, edge_type_count, dim_a, features
    )
    nsloss = NSLoss(num_nodes, num_sampled, embedding_size)

    model.to(device)
    nsloss.to(device)

    optimizer = torch.optim.Adam(
        [{"params": model.parameters()}, {"params": nsloss.parameters()}], lr=1e-4
    )
    best_micro = 0
    best_macro = 0
    best_score = 0
    test_score = (0.0, 0.0, 0.0)
    patience = 0
    sparse = True
    labels = torch.FloatTensor(dataset.gcn_labels)
    idx_train = torch.LongTensor(dataset.train_id)
    idx_val = torch.LongTensor(dataset.valid_id)
    idx_test = torch.LongTensor(dataset.test_id)
    macro = []
    micro = []
    k1_list = []
    sim_list = []
    for epoch in range(epochs):
        random.shuffle(train_pairs)
        batches = get_batches(train_pairs, neighbors, batch_size)

        data_iter = tqdm(
            batches,
            desc="epoch %d" % (epoch),
            total=(len(train_pairs) + (batch_size - 1)) // batch_size,
            bar_format="{l_bar}{r_bar}",
        )
        avg_loss = 0.0
        for i, data in enumerate(data_iter):
            optimizer.zero_grad()
            embs = model(data[0].to(device), data[2].to(device), data[3].to(device),)
            loss = nsloss(data[0].to(device), embs, data[1].to(device))
            loss.backward()
            optimizer.step()
            #print('embs: ', embs)
            avg_loss += loss.item()

            if i % 5000 == 0:
                post_fix = {
                    "epoch": epoch,
                    "iter": i,
                    "avg_loss": avg_loss / (i + 1),
                    "loss": loss.item(),
                }
                data_iter.write(str(post_fix))
        
        model.eval()
        features = torch.FloatTensor(preprocessed_features)
        gcn_adj_list = [normalize_adj(adj) for adj in dataset.gcn_adj_list]
        adj_list = [sparse_mx_to_torch_sparse_tensor(adj) for adj in gcn_adj_list]
        embeds = embed(features, adj_list, sparse)
        
        macro_f1s, micro_f1s, k1, sim = evaluate(embeds, idx_train, idx_val, idx_test, labels)
        f1_macro = np.mean(macro_f1s)
        f1_micro = np.mean(micro_f1s)
        
        macro.append(f1_macro)
        micro.append(f1_micro)
        k1_list.append(k1)
        sim_list.append(sim)
    #return average_micro,average_macro,average_sim,average_nmi
    return macro,micro,k1_list,sim_list


if __name__ == "__main__":
    args = parse_args()
    file_name = args.input
    t = dataset('imdb')
    if args.features is not None:
        feature_dic = load_feature_data(args.features)
    else:
        feature_dic = None
        
    preprocessed_features = preprocess_features(t.features)
    ft_size = preprocessed_features[0].shape[1] 
    hid_units = 128
    n_networks = len(t.adj_list)
    
    #embeds = embed(features, adj_list, self.args.sparse)
    # Write down data in format required for model training
    #f = open("OpenAttMultiGL/data/GATNE/Amazon/testt.txt","a")
    #d = dict()
    #for i in range(len(t.sequence_adj)):
        #d[i] = []
        #for j in range(len(t.sequence_adj[i])):
            #for l in t.test_id:
                #if j == l:
                    #for k in range(len(t.sequence_adj[i][j])):
                        #f.write(str(i))
                        #f.write(' ')
                        #f.write(str(j))
                        #f.write(' ')
                        #f.write(str(k))
                        #f.write(' ')
                        #f.write(str(int(t.sequence_adj[i][j][k])))
                        #f.write('\n')
    #f.close()
    #training_data_by_type = load_training_data("OpenAttMultiGL/data/GATNE/"+file_name + "/train.txt")
    valid_true_data_by_edge, valid_false_data_by_edge = load_testing_data(
        "OpenAttMultiGL/data/GATNE/"+file_name + "/valid.txt"
    )
    testing_true_data_by_edge, testing_false_data_by_edge = load_testing_data(
        "OpenAttMultiGL/data/GATNE/"+file_name + "/test.txt"
    )
    
    #c = t.sequence_adj[0][]
    d = dict()
    for i in range(len(t.sequence_adj)):
        d[i] = []
        for j in range(len(t.sequence_adj[i])):
            for l in t.train_id:
                if j == l:
                    for k in range(len(t.sequence_adj[i][j])):
                        if t.sequence_adj[i][j][k] == 1:
                            e = (str(j),str(k))
                            d[i].append(e)
    #micro,macro,sim,nmi = train_model(d, feature_dic,t)
    micro,macro,nmi,sim = train_model(d, feature_dic,t)
    
    print("Final score: \n")
    print('Micro: {:.4f} ({:.4f})'.format(np.mean(micro),np.std(micro)))
    print('Macro: {:.4f} ({:.4f})'.format(np.mean(macro),np.std(macro)))
    print('Sim: {:.4f} ({:.4f})'.format(np.mean(sim),np.std(sim)))
    print('NMI: {:.4f} ({:.4f})'.format(np.mean(nmi),np.std(nmi)))
    #print('SIM: ', sim)
    #print('NMI: ', nmi)




Generating random walks for layer 0


15640it [00:05, 2970.74it/s]  

Generating random walks for layer 1



43300it [00:01, 35621.30it/s]


Finish generating the walks
Saving walks for layer 0


100%|█████████████████████████████████| 15640/15640 [00:00<00:00, 656645.79it/s]


Saving walks for layer 1


100%|█████████████████████████████████| 43300/43300 [00:00<00:00, 699196.00it/s]


Counting vocab for layer 0


100%|█████████████████████████████████| 15640/15640 [00:00<00:00, 903528.98it/s]


Counting vocab for layer 1


100%|█████████████████████████████████| 43300/43300 [00:00<00:00, 818605.51it/s]


Generating training pairs for layer 0


100%|██████████████████████████████████| 15640/15640 [00:00<00:00, 87329.45it/s]


Generating training pairs for layer 1


100%|██████████████████████████████████| 43300/43300 [00:00<00:00, 86585.30it/s]


Generating neighbors for layer 0


100%|██████████████████████████████████| 1619/1619 [00:00<00:00, 1779501.62it/s]


Generating neighbors for layer 1


100%|██████████████████████████████████| 8360/8360 [00:00<00:00, 1871098.26it/s]
epoch 0:   0%|| 40/31312 [00:00<02:35, 201.58it/s]

{'epoch': 0, 'iter': 0, 'avg_loss': 4.161587238311768, 'loss': 4.161587238311768}


epoch 0:  16%|| 5031/31312 [00:23<01:59, 219.17it/s]

{'epoch': 0, 'iter': 5000, 'avg_loss': 3.173034992343877, 'loss': 2.4089150428771973}


epoch 0:  32%|| 10031/31312 [00:49<01:49, 194.33it/s]

{'epoch': 0, 'iter': 10000, 'avg_loss': 2.809113660009846, 'loss': 2.3913214206695557}


epoch 0:  48%|| 15013/31312 [01:45<02:54, 93.35it/s] 

{'epoch': 0, 'iter': 15000, 'avg_loss': 2.6470494246007634, 'loss': 2.266148090362549}


epoch 0:  64%|| 20006/31312 [02:57<01:59, 94.35it/s] 

{'epoch': 0, 'iter': 20000, 'avg_loss': 2.5322273216881244, 'loss': 2.189408540725708}


epoch 0:  80%|| 25005/31312 [06:01<03:27, 30.37it/s]  

{'epoch': 0, 'iter': 25000, 'avg_loss': 2.4330014767472656, 'loss': 1.9068797826766968}


epoch 0:  96%|| 30010/31312 [08:26<00:20, 63.43it/s]

{'epoch': 0, 'iter': 30000, 'avg_loss': 2.342067594595907, 'loss': 1.82365083694458}


epoch 0: 100%|| 31312/31312 [09:07<00:00, 57.22it/s]


	[Classification] Macro-F1: 0.4141 (0.0180) | Micro-F1: 0.4394 (0.0037)
	[Clustering] NMI: 0.0057 | 0.0002
	[Similarity] [5,10,20,50,100] : [0.4521,0.4276,0.4097,0.3932,0.3842]


epoch 1:   0%|| 5/31312 [00:00<50:03, 10.42it/s]  

{'epoch': 1, 'iter': 0, 'avg_loss': 1.851701259613037, 'loss': 1.851701259613037}


epoch 1:  16%|| 5005/31312 [01:37<11:31, 38.03it/s] 

{'epoch': 1, 'iter': 5000, 'avg_loss': 1.7149731583701113, 'loss': 1.5067921876907349}


epoch 1:  32%|| 10018/31312 [02:39<03:21, 105.63it/s]

{'epoch': 1, 'iter': 10000, 'avg_loss': 1.6571510872260629, 'loss': 1.3738837242126465}


epoch 1:  48%|| 15026/31312 [03:39<01:42, 159.42it/s]

{'epoch': 1, 'iter': 15000, 'avg_loss': 1.6047319251452674, 'loss': 1.3692924976348877}


epoch 1:  64%|| 20005/31312 [04:17<01:27, 129.08it/s]

{'epoch': 1, 'iter': 20000, 'avg_loss': 1.5564500445169507, 'loss': 1.3314712047576904}


epoch 1:  80%|| 25021/31312 [05:03<01:08, 92.42it/s] 

{'epoch': 1, 'iter': 25000, 'avg_loss': 1.5123617521672577, 'loss': 1.2375247478485107}


epoch 1:  96%|| 30027/31312 [06:16<00:09, 140.06it/s]

{'epoch': 1, 'iter': 30000, 'avg_loss': 1.4712961483969655, 'loss': 1.2669026851654053}


epoch 1: 100%|| 31312/31312 [06:25<00:00, 81.12it/s] 


	[Classification] Macro-F1: 0.4251 (0.0110) | Micro-F1: 0.4529 (0.0040)
	[Clustering] NMI: 0.0050 | 0.0005
	[Similarity] [5,10,20,50,100] : [0.4508,0.4227,0.3995,0.3828,0.374]


epoch 2:   0%|| 9/31312 [00:00<14:01, 37.19it/s]  

{'epoch': 2, 'iter': 0, 'avg_loss': 1.259683609008789, 'loss': 1.259683609008789}


epoch 2:  16%|| 5023/31312 [00:40<03:21, 130.76it/s]

{'epoch': 2, 'iter': 5000, 'avg_loss': 1.1882793993574217, 'loss': 1.0462530851364136}


epoch 2:  32%|| 10030/31312 [01:17<02:07, 167.56it/s]

{'epoch': 2, 'iter': 10000, 'avg_loss': 1.1615398831932966, 'loss': 0.9207583665847778}


epoch 2:  48%|| 15030/31312 [01:53<01:35, 169.87it/s]

{'epoch': 2, 'iter': 15000, 'avg_loss': 1.1369225733534956, 'loss': 1.0339769124984741}


epoch 2:  64%|| 20024/31312 [02:33<01:09, 163.36it/s]

{'epoch': 2, 'iter': 20000, 'avg_loss': 1.1138905060570823, 'loss': 1.1325461864471436}


epoch 2:  80%|| 25015/31312 [03:11<00:51, 123.25it/s]

{'epoch': 2, 'iter': 25000, 'avg_loss': 1.0934046220123317, 'loss': 0.8924440145492554}


epoch 2:  96%|| 30028/31312 [03:56<00:06, 190.55it/s]

{'epoch': 2, 'iter': 30000, 'avg_loss': 1.0736286996233515, 'loss': 0.8835773468017578}


epoch 2: 100%|| 31312/31312 [04:02<00:00, 129.02it/s]


	[Classification] Macro-F1: 0.4530 (0.0058) | Micro-F1: 0.4633 (0.0049)
	[Clustering] NMI: 0.0075 | 0.0007
	[Similarity] [5,10,20,50,100] : [0.454,0.431,0.4089,0.3895,0.3795]


epoch 3:   0%|| 20/31312 [00:00<05:21, 97.41it/s]

{'epoch': 3, 'iter': 0, 'avg_loss': 0.8716562986373901, 'loss': 0.8716562986373901}


epoch 3:  16%|| 5040/31312 [00:38<01:56, 225.95it/s]

{'epoch': 3, 'iter': 5000, 'avg_loss': 0.9337514494567174, 'loss': 0.9951727390289307}


epoch 3:  32%|| 10011/31312 [01:07<02:29, 142.38it/s]

{'epoch': 3, 'iter': 10000, 'avg_loss': 0.9200489818900839, 'loss': 0.9295198917388916}


epoch 3:  48%|| 15018/31312 [01:46<01:41, 161.02it/s]

{'epoch': 3, 'iter': 15000, 'avg_loss': 0.9069330028419248, 'loss': 0.8537870645523071}


epoch 3:  64%|| 20020/31312 [02:26<01:26, 130.14it/s]

{'epoch': 3, 'iter': 20000, 'avg_loss': 0.8943198235873776, 'loss': 0.8057821989059448}


epoch 3:  80%|| 25012/31312 [03:14<00:59, 105.62it/s]

{'epoch': 3, 'iter': 25000, 'avg_loss': 0.8825019435149986, 'loss': 0.9892697334289551}


epoch 3:  96%|| 30026/31312 [03:55<00:08, 156.79it/s]

{'epoch': 3, 'iter': 30000, 'avg_loss': 0.8714637545771465, 'loss': 0.7882325053215027}


epoch 3: 100%|| 31312/31312 [04:05<00:00, 127.61it/s]


	[Classification] Macro-F1: 0.4279 (0.0110) | Micro-F1: 0.4507 (0.0043)
	[Clustering] NMI: 0.0054 | 0.0002
	[Similarity] [5,10,20,50,100] : [0.4492,0.423,0.4035,0.3865,0.3781]


epoch 4:   0%|| 6/31312 [00:00<22:37, 23.07it/s]  

{'epoch': 4, 'iter': 0, 'avg_loss': 0.8000925779342651, 'loss': 0.8000925779342651}


epoch 4:  16%|| 5014/31312 [00:40<04:11, 104.57it/s]

{'epoch': 4, 'iter': 5000, 'avg_loss': 0.7917295569206472, 'loss': 0.7259659767150879}


epoch 4:  32%|| 10016/31312 [01:21<02:33, 138.47it/s]

{'epoch': 4, 'iter': 10000, 'avg_loss': 0.7845868310944079, 'loss': 0.8781912922859192}


epoch 4:  48%|| 15027/31312 [02:03<01:33, 173.55it/s]

{'epoch': 4, 'iter': 15000, 'avg_loss': 0.7766036276586548, 'loss': 0.8053305149078369}


epoch 4:  64%|| 20002/31312 [02:47<01:27, 128.57it/s]

{'epoch': 4, 'iter': 20000, 'avg_loss': 0.7684882514048001, 'loss': 0.8197945952415466}


epoch 4:  80%|| 25024/31312 [03:33<00:41, 153.24it/s]

{'epoch': 4, 'iter': 25000, 'avg_loss': 0.7611321695761358, 'loss': 0.7535653114318848}


epoch 4:  96%|| 30024/31312 [04:21<00:10, 120.92it/s]

{'epoch': 4, 'iter': 30000, 'avg_loss': 0.7542750011700033, 'loss': 0.7324659824371338}


epoch 4: 100%|| 31312/31312 [04:31<00:00, 115.31it/s]


	[Classification] Macro-F1: 0.4242 (0.0029) | Micro-F1: 0.4449 (0.0024)
	[Clustering] NMI: 0.0075 | 0.0001
	[Similarity] [5,10,20,50,100] : [0.4468,0.4158,0.3972,0.3792,0.3697]


epoch 5:   0%|| 10/31312 [00:00<10:29, 49.73it/s] 

{'epoch': 5, 'iter': 0, 'avg_loss': 0.6961605548858643, 'loss': 0.6961605548858643}


epoch 5:  16%|| 5022/31312 [00:40<03:01, 145.14it/s]

{'epoch': 5, 'iter': 5000, 'avg_loss': 0.7058140904551576, 'loss': 0.6508585214614868}


epoch 5:  32%|| 10035/31312 [01:25<01:56, 183.22it/s]

{'epoch': 5, 'iter': 10000, 'avg_loss': 0.6995878808618295, 'loss': 0.5227296352386475}


epoch 5:  48%|| 15024/31312 [02:03<01:44, 155.79it/s]

{'epoch': 5, 'iter': 15000, 'avg_loss': 0.6952548046865731, 'loss': 0.680739164352417}


epoch 5:  64%|| 20016/31312 [02:44<01:30, 124.58it/s]

{'epoch': 5, 'iter': 20000, 'avg_loss': 0.6902375023717624, 'loss': 0.6451594233512878}


epoch 5:  80%|| 25020/31312 [03:20<00:40, 156.29it/s]

{'epoch': 5, 'iter': 25000, 'avg_loss': 0.6858015487176992, 'loss': 0.670167088508606}


epoch 5:  96%|| 30022/31312 [04:00<00:12, 103.51it/s]

{'epoch': 5, 'iter': 30000, 'avg_loss': 0.6817379499667049, 'loss': 0.6106115579605103}


epoch 5: 100%|| 31312/31312 [04:09<00:00, 125.51it/s]


	[Classification] Macro-F1: 0.4126 (0.0025) | Micro-F1: 0.4337 (0.0018)
	[Clustering] NMI: 0.0087 | 0.0003
	[Similarity] [5,10,20,50,100] : [0.4555,0.426,0.4065,0.3866,0.3748]


epoch 6:   0%|| 15/31312 [00:00<06:29, 80.34it/s]

{'epoch': 6, 'iter': 0, 'avg_loss': 0.5853804349899292, 'loss': 0.5853804349899292}


epoch 6:  16%|| 5004/31312 [00:40<02:38, 166.19it/s]

{'epoch': 6, 'iter': 5000, 'avg_loss': 0.6505394384542053, 'loss': 0.6364444494247437}


epoch 6:  32%|| 10023/31312 [01:14<01:35, 223.51it/s]

{'epoch': 6, 'iter': 10000, 'avg_loss': 0.6473237150264447, 'loss': 0.6569089293479919}


epoch 6:  48%|| 15040/31312 [01:37<01:11, 227.93it/s]

{'epoch': 6, 'iter': 15000, 'avg_loss': 0.64405349966415, 'loss': 0.800261378288269}


epoch 6:  64%|| 20025/31312 [01:58<00:48, 233.63it/s]

{'epoch': 6, 'iter': 20000, 'avg_loss': 0.6400775616541248, 'loss': 0.7908459901809692}


epoch 6:  80%|| 25040/31312 [02:20<00:27, 225.35it/s]

{'epoch': 6, 'iter': 25000, 'avg_loss': 0.6366297959599522, 'loss': 0.6721832752227783}


epoch 6:  96%|| 30042/31312 [02:43<00:05, 226.98it/s]

{'epoch': 6, 'iter': 30000, 'avg_loss': 0.6337860926664383, 'loss': 0.6649888753890991}


epoch 6: 100%|| 31312/31312 [02:48<00:00, 185.90it/s]


	[Classification] Macro-F1: 0.4087 (0.0095) | Micro-F1: 0.4252 (0.0030)
	[Clustering] NMI: 0.0074 | 0.0004
	[Similarity] [5,10,20,50,100] : [0.4542,0.4238,0.3996,0.3831,0.3726]
