In [1]:
from model import *
from utils import *
from torch_geometric.loader import DataLoader
from torch_geometric.loader import ClusterData, ClusterLoader, NeighborSampler
import torch.nn.functional as F

import pandas as pd
import numpy as np
from tqdm import tqdm
import random
import pickle
import os
from sklearn.metrics import f1_score

from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import matplotlib.pyplot as plt

from functools import partial
from sklearn.preprocessing import OneHotEncoder
from sklearn.cluster import KMeans




%load_ext autoreload
%autoreload 2

In [2]:

def masked_edge_index(edge_index, edge_mask):
    if isinstance(edge_index, Tensor):
        return edge_index[:, edge_mask]
    else:
        return print('Error')

def one_hot_encoding(l):
    label_types = torch.unique(l).tolist()
    new_labels = []
    for i in range(0, len(l)):
        tmp = []
        for j in range(0, len(label_types)):
            tmp.append(0.)
        tmp[l[i].item()] = 1.
        new_labels.append(tmp)
    return torch.tensor(new_labels)     

def load_files(node_file_path, links_file_path, label_file_path, embedding_file_path, dataset):
    colors = pd.read_csv(node_file_path, sep='\t', header = None)
    colors = colors.dropna(axis=1,how='all')
    labels = pd.read_csv(label_file_path, sep='\t', header = None)
    links = pd.read_csv(links_file_path, sep='\t', header = None)
    labels.rename(columns = {0: 'node', 1: 'label'}, inplace = True)
    source_nodes_with_labels = labels['node'].values.tolist()
    labels = torch.tensor(labels['label'].values)
    colors.rename(columns = {0: 'node', 1: 'color'}, inplace = True)
    links.rename(columns = {0: 'node_1', 1: 'relation_type', 2: 'node_2'}, inplace = True)
    if dataset == 'complex' or dataset == 'simple':
        embedding = pd.read_csv(embedding_file_path, sep='\t', header = None)
        embedding_number = len(embedding.columns)-2
        if embedding_number == 3:
            embedding.rename(columns = {0: 'index', 1: 'second embedding', 2: 'first embedding', 3: 'labels'}, inplace = True)
        elif embedding_number == 4:
            embedding.rename(columns = {0: 'index', 1: 'third embedding', 2: 'second embedding', 3: 'first embedding', 4: 'labels'}, inplace = True)
        elif embedding_number == 5:
            embedding.rename(columns = {0: 'index', 1: 'fourth embedding', 2: 'third embedding', 3: 'second embedding', 4: 'first_embdding', 5: 'labels'}, inplace = True)
        elif embedding_number == 2:
            embedding.rename(columns = {0: 'index', 1: 'first embedding', 2: 'labels'}, inplace = True)
        return labels, colors, links, embedding
    else:
        labels_multi  = one_hot_encoding(labels)
        # for i in range(0, len(labels)):
        #     if labels[i].item() == 0:
        #         labels[i] = 1
        #     else:
        #         labels[i] = 0
        return labels, colors, links, source_nodes_with_labels, labels_multi

def splitting_node_and_labels(lab, feat, src, dataset):
    if dataset == 'complex' or dataset == 'simple':
        node_idx = torch.tensor(feat['node'].values)
    else:
        node_idx = torch.tensor(src)
    train_split = int(len(node_idx)*0.8)
    test_split = len(node_idx) - train_split
    train_idx = node_idx[:train_split]
    test_idx = node_idx[-test_split:]

    train_y = lab[:train_split]
    test_y = lab[-test_split:]
    return node_idx, train_idx, train_y, test_idx, test_y

def get_node_features(colors):
    node_features = pd.get_dummies(colors)
    
    node_features.drop(["node"], axis=1, inplace=True)
    
    x = node_features.to_numpy().astype(np.float32)
    x = np.flip(x, 1).copy()
    x = torch.from_numpy(x) 
    return x

def get_edge_index_and_type_no_reverse(links):
    edge_index = links.drop(['relation_type'], axis=1)
    edge_index = torch.tensor([list(edge_index['node_1'].values), list(edge_index['node_2'].values)])
    
    edge_type = links['relation_type']
    edge_type = torch.tensor(edge_type)
    return edge_index, edge_type

In [35]:

def mpgnn_train(model, optimizer, data):
    model.train()
    optimizer.zero_grad()
    weight_loss = torch.tensor([1., 100.])
    out = model(data.x, data.edge_index, data.edge_type)
    loss = F.nll_loss(out[data.train_idx].squeeze(-1), data.train_y)#, weight = weight_loss)
    #loss = F.cross_entropy(out[data.train_idx], data.train_y)
    loss.backward()
    optimizer.step()
    return float(loss)

@torch.no_grad()
def mpgnn_test(model, data):
    model.eval()
    pred = model(data.x, data.edge_index, data.edge_type)#.argmax(dim=-1)
    train_predictions = torch.argmax(pred[data.train_idx], 1).tolist()
    test_predictions = torch.argmax(pred[data.test_idx], 1).tolist()
    train_y = data.train_y.tolist()
    test_y = data.test_y.tolist()
    # train_acc = (train_predictions == train_y).float().mean()
    # test_acc = (test_predictions == test_y).float().mean()
    f1_train = f1_score(train_predictions, train_y, average='macro')
    f1_test_macro = f1_score(test_predictions, test_y, average = 'macro')
    f1_test_micro = f1_score(test_predictions, test_y, average = 'micro')
    return f1_train, f1_test_micro, f1_test_macro

def mpgnn_parallel(data_mpgnn, input_dim, hidden_dim, num_rel, output_dim, ll_output_dim, metapath):
    metapath=[0, 1, 2]
    mpgnn_model = MPNet(input_dim, hidden_dim, num_rel, output_dim, ll_output_dim, len(metapath), metapath)
    print(mpgnn_model)
    # for name, param in mpgnn_model.named_parameters():
    #     print(name, param, param.size())
    mpgnn_optimizer = torch.optim.Adam(mpgnn_model.parameters(), lr=0.01, weight_decay=0.0005)
    best_macro, best_micro = 0., 0.
    for epoch in tqdm(range(1, 100)):
        loss = mpgnn_train(mpgnn_model, mpgnn_optimizer, data_mpgnn)
        train_acc, f1_test_micro, f1_test_macro = mpgnn_test(mpgnn_model, data_mpgnn)
        if f1_test_macro > best_macro:
            best_macro = f1_test_micro
        if f1_test_micro > best_micro:
            best_micro = f1_test_micro
    return best_micro

def mpgnn_parallel_multiple(data_mpgnn, input_dim, hidden_dim, num_rel, output_dim, ll_output_dim, metapaths):
    #metapaths = [[2, 0]]#, [3, 1]]
    #metapaths = [[1, 4, 2, 0], [1, 0], [1, 5, 3, 0]]
    #metapaths = [[4, 3, 0], [1, 0], [0, 4, 2]]
    #metapaths = [[2, 4, 0], [0, 3, 4], [0, 1]]
    metapaths = [[2,0],[3,1]] #IMDB
    metapaths = [[0,2],[1,3]] #IMDB
    
    mpgnn_model = MPNetm(input_dim, hidden_dim, num_rel, output_dim, ll_output_dim, len(metapaths), metapaths)
    print(mpgnn_model)
    # for name, param in mpgnn_model.named_parameters():
    #     print(name, param, param.size())
    mpgnn_optimizer = torch.optim.Adam(mpgnn_model.parameters(), lr=0.01, weight_decay=0.0005)
    best_macro, best_micro = 0., 0.
    for epoch in range(1, 1000):
        loss = mpgnn_train(mpgnn_model, mpgnn_optimizer, data_mpgnn)
        train_acc, f1_test_micro, f1_test_macro = mpgnn_test(mpgnn_model, data_mpgnn)
        print(epoch, 'loss: ', loss, 'train acc: ', train_acc, 'micro: ', f1_test_micro)
        if f1_test_macro > best_macro:
            best_macro = f1_test_micro
        if f1_test_micro > best_micro:
            best_micro = f1_test_micro
    return best_micro

def main(node_file_path, link_file_path, label_file_path, embedding_file_path, metapath_length, pickle_filename, input_dim, hidden_dim, num_rel, output_dim, ll_output_dim, dataset):
    # Obtain true 0|1 labels for each node, feature matrix (1-hot encoding) and links among nodes
    if dataset == 'complex' or dataset == 'simple':
        sources = []
        true_labels, features, edges, embedding = load_files(node_file_path, link_file_path, label_file_path, embedding_file_path, dataset)
    else: 
        true_labels, features, edges, sources, labels_multi = load_files(node_file_path, link_file_path, label_file_path, embedding_file_path, dataset)
    # Get features' matrix
    x = get_node_features(features)
    # Get edge_index and types
    edge_index, edge_type = get_edge_index_and_type_no_reverse(edges)

    # Split data into train and test
    node_idx, train_idx, train_y, test_idx, test_y = splitting_node_and_labels(true_labels, features, sources, dataset)
    #node_idx, train_idx, train_y, test_idx, test_y = splitting_node_and_labels(labels_multi, features, sources, dataset)

    # Dataset for MPGNN
    data_mpgnn = Data()
    data_mpgnn.x = x
    data_mpgnn.edge_index = edge_index
    data_mpgnn.edge_type = edge_type
    data_mpgnn.train_idx = train_idx
    data_mpgnn.test_idx = test_idx
    data_mpgnn.train_y = train_y
    data_mpgnn.test_y = test_y
    data_mpgnn.num_nodes = node_idx.size(0)
    # Variables
    if sources:
        source_nodes_mask = sources
    else:
        source_nodes_mask = []
    metapath = []

    # Dataset for score function
    data = Data()
    data.x = x
    data.edge_index = edge_index
    data.edge_type = edge_type
    data.labels = true_labels
    data.labels = data.labels.unsqueeze(-1)
    data.num_nodes = x.size(0)
    data.bags = torch.empty(1)
    data.bag_labels = torch.empty(1)

    # All possible relations
    relations = torch.unique(data.edge_type).tolist()
    mp = []
    mpgnn_f1_micro = mpgnn_parallel_multiple(data_mpgnn, input_dim, hidden_dim, num_rel, output_dim, ll_output_dim, mp)
    print(mpgnn_f1_micro)



In [36]:
COMPLEX = True
COMPLEX = "synthetic_multi"
COMPLEX = "IMDB"

metapath_length= 3
tot_rel=5

if COMPLEX == True:
    input_dim = 6
    ll_output_dim = 2
    dataset = "complex"
    folder= "data/" + dataset + "/length_m_" + str(metapath_length) + "__tot_rel_" + str(tot_rel) + "/"
elif COMPLEX == False:
    input_dim = 6
    ll_output_dim = 2
    dataset = "simple"
    folder= "data/" + dataset + "/length_m_" + str(metapath_length) + "__tot_rel_" + str(tot_rel) + "/"
elif COMPLEX == 'IMDB':
    tot_rel=4
    input_dim = 3066
    ll_output_dim = 3
    dataset = 'IMDB' ## 5
    folder= "data/" + dataset + "/"
elif COMPLEX == 'DBLP':
    input_dim = 4231
    tot_rel=6
    ll_output_dim = 4
    dataset = 'DBLP' ## 7
    folder= "data/" + dataset + "/"
elif COMPLEX == 'synthetic_multi':
    input_dim=6
    tot_rel=5
    ll_output_dim=2
    dataset = 'tot_rel_5'
    folder="data/synthetic_multi/" + dataset + "/"

node_file= folder + "node.dat"
link_file= folder + "link.dat"
label_file= folder + "label.dat"
embedding_file = folder + "embedding.dat"
# Define the filename for saving the variables
pickle_filename = folder + "iteration_variables.pkl"
# mpgnn variables
hidden_dim = 32
num_rel = tot_rel
output_dim = 64


In [37]:
main(node_file, link_file, label_file, embedding_file, metapath_length, pickle_filename, input_dim, hidden_dim, num_rel, output_dim, ll_output_dim, dataset)

MPNetm(
  (layers_list): ModuleList(
    (0): ModuleList(
      (0): CustomRGCNConv(3066, 32, num_relations=4)
      (1): CustomRGCNConv(32, 32, num_relations=4)
    )
    (1): ModuleList(
      (0): CustomRGCNConv(3066, 32, num_relations=4)
      (1): CustomRGCNConv(32, 32, num_relations=4)
    )
  )
  (fc1): Linear(in_features=64, out_features=32, bias=True)
  (fc2): Linear(in_features=32, out_features=3, bias=True)
  (log_softmax): LogSoftmax(dim=1)
)
1 loss:  1.0967216491699219 train acc:  0.17719647804962432 micro:  0.4100467289719626
2 loss:  1.0876092910766602 train acc:  0.40422099510786486 micro:  0.4672897196261683
3 loss:  1.065207839012146 train acc:  0.48554254958078014 micro:  0.5397196261682243
4 loss:  1.0188186168670654 train acc:  0.4999179517148802 micro:  0.5829439252336449
5 loss:  0.9465958476066589 train acc:  0.5304791681743052 micro:  0.5817757009345794
6 loss:  0.84529048204422 train acc:  0.669466400790118 micro:  0.6051401869158879
7 loss:  0.717055559158325

101 loss:  0.5213574171066284 train acc:  0.2645664207429261 micro:  0.1658878504672897
102 loss:  5.079266548156738 train acc:  0.5426563362390558 micro:  0.45794392523364486
103 loss:  3.767960786819458 train acc:  0.5336021171642022 micro:  0.4369158878504673
104 loss:  2.9091877937316895 train acc:  0.8645329911190339 micro:  0.5549065420560748
105 loss:  0.48006579279899597 train acc:  0.8027086819535193 micro:  0.5619158878504673
106 loss:  0.8859720826148987 train acc:  0.790343855388613 micro:  0.5537383177570093
107 loss:  0.8613075613975525 train acc:  0.9124661160592696 micro:  0.6039719626168224
108 loss:  0.29657870531082153 train acc:  0.9528146804410461 micro:  0.6191588785046729
109 loss:  0.1436212807893753 train acc:  0.9507220724292275 micro:  0.602803738317757
110 loss:  0.1406417340040207 train acc:  0.9416836029992233 micro:  0.5911214953271028
111 loss:  0.17102518677711487 train acc:  0.9370000621336573 micro:  0.5887850467289719
112 loss:  0.18828287720680237 t

202 loss:  0.0031989291310310364 train acc:  1.0 micro:  0.5992990654205608
203 loss:  0.003158093662932515 train acc:  1.0 micro:  0.5992990654205608
204 loss:  0.0031182279344648123 train acc:  1.0 micro:  0.5981308411214953
205 loss:  0.003085486823692918 train acc:  1.0 micro:  0.5969626168224299
206 loss:  0.0030639031901955605 train acc:  1.0 micro:  0.5957943925233645
207 loss:  0.0030541103333234787 train acc:  1.0 micro:  0.5957943925233645
208 loss:  0.0030547003261744976 train acc:  1.0 micro:  0.5957943925233645
209 loss:  0.003056885441765189 train acc:  1.0 micro:  0.5969626168224299
210 loss:  0.003064032644033432 train acc:  1.0 micro:  0.5957943925233645
211 loss:  0.0030712569132447243 train acc:  1.0 micro:  0.594626168224299
212 loss:  0.003076865104958415 train acc:  1.0 micro:  0.594626168224299
213 loss:  0.0030839608516544104 train acc:  1.0 micro:  0.5934579439252337
214 loss:  0.003085966920480132 train acc:  1.0 micro:  0.594626168224299
215 loss:  0.00309232

311 loss:  0.002735845511779189 train acc:  1.0 micro:  0.5642523364485982
312 loss:  0.0027314440812915564 train acc:  1.0 micro:  0.5654205607476636
313 loss:  0.0027252330910414457 train acc:  1.0 micro:  0.5654205607476636
314 loss:  0.0027227397076785564 train acc:  1.0 micro:  0.5642523364485982
315 loss:  0.0027297120541334152 train acc:  1.0 micro:  0.5677570093457944
316 loss:  0.0027313397731631994 train acc:  1.0 micro:  0.5630841121495327
317 loss:  0.0027313667815178633 train acc:  1.0 micro:  0.5654205607476636
318 loss:  0.0027277835179120302 train acc:  1.0 micro:  0.5665887850467289
319 loss:  0.0027272335719317198 train acc:  1.0 micro:  0.5665887850467289
320 loss:  0.002726905746385455 train acc:  1.0 micro:  0.5654205607476636
321 loss:  0.0027295267209410667 train acc:  1.0 micro:  0.5642523364485982
322 loss:  0.002729845931753516 train acc:  1.0 micro:  0.5665887850467289
323 loss:  0.0027283714152872562 train acc:  1.0 micro:  0.5630841121495327
324 loss:  0.00

420 loss:  0.0026769577525556087 train acc:  1.0 micro:  0.5677570093457944
421 loss:  0.0026777025777846575 train acc:  1.0 micro:  0.5654205607476636
422 loss:  0.002675762167200446 train acc:  1.0 micro:  0.5689252336448598
423 loss:  0.0026824523229151964 train acc:  1.0 micro:  0.5665887850467289
424 loss:  0.0026775638107210398 train acc:  1.0 micro:  0.5642523364485982
425 loss:  0.002679102122783661 train acc:  1.0 micro:  0.5665887850467289
426 loss:  0.002674645744264126 train acc:  1.0 micro:  0.5665887850467289
427 loss:  0.002671360271051526 train acc:  1.0 micro:  0.5689252336448598
428 loss:  0.002671268070116639 train acc:  1.0 micro:  0.5665887850467289
429 loss:  0.0026688070502132177 train acc:  1.0 micro:  0.5677570093457944
430 loss:  0.002674225950613618 train acc:  1.0 micro:  0.5677570093457944
431 loss:  0.0026703188195824623 train acc:  1.0 micro:  0.5654205607476636
432 loss:  0.0026744252536445856 train acc:  1.0 micro:  0.5677570093457944
433 loss:  0.00266

529 loss:  0.002572221914306283 train acc:  1.0 micro:  0.5677570093457944
530 loss:  0.002570608165115118 train acc:  1.0 micro:  0.5677570093457944
531 loss:  0.002571535762399435 train acc:  1.0 micro:  0.5665887850467289
532 loss:  0.002569034928455949 train acc:  1.0 micro:  0.5677570093457944
533 loss:  0.0025711683556437492 train acc:  1.0 micro:  0.5665887850467289
534 loss:  0.0025660349056124687 train acc:  1.0 micro:  0.5665887850467289
535 loss:  0.0025649964809417725 train acc:  1.0 micro:  0.5677570093457944
536 loss:  0.002563971560448408 train acc:  1.0 micro:  0.5665887850467289
537 loss:  0.0025636604987084866 train acc:  1.0 micro:  0.5677570093457944
538 loss:  0.0025682009290903807 train acc:  1.0 micro:  0.5665887850467289
539 loss:  0.002562473528087139 train acc:  1.0 micro:  0.5677570093457944
540 loss:  0.0025629715528339148 train acc:  1.0 micro:  0.5677570093457944
541 loss:  0.0025572164449840784 train acc:  1.0 micro:  0.5677570093457944
542 loss:  0.00255

638 loss:  0.0024521048180758953 train acc:  1.0 micro:  0.5654205607476636
639 loss:  0.0024509842041879892 train acc:  1.0 micro:  0.5665887850467289
640 loss:  0.0024510787334293127 train acc:  1.0 micro:  0.5642523364485982
641 loss:  0.0024498996790498495 train acc:  1.0 micro:  0.5642523364485982
642 loss:  0.0024538678117096424 train acc:  1.0 micro:  0.5642523364485982
643 loss:  0.0024499648716300726 train acc:  1.0 micro:  0.5642523364485982
644 loss:  0.00244991946965456 train acc:  1.0 micro:  0.5654205607476636
645 loss:  0.0024445748422294855 train acc:  1.0 micro:  0.5654205607476636
646 loss:  0.0024423617869615555 train acc:  1.0 micro:  0.5642523364485982
647 loss:  0.0024446630850434303 train acc:  1.0 micro:  0.5642523364485982
648 loss:  0.0024401997216045856 train acc:  1.0 micro:  0.5654205607476636
649 loss:  0.002439499367028475 train acc:  1.0 micro:  0.5642523364485982
650 loss:  0.0024387193843722343 train acc:  1.0 micro:  0.5654205607476636
651 loss:  0.00

747 loss:  0.0023145992308855057 train acc:  1.0 micro:  0.5619158878504673
748 loss:  0.002313013421371579 train acc:  1.0 micro:  0.5642523364485982
749 loss:  0.0023122376296669245 train acc:  1.0 micro:  0.5630841121495327
750 loss:  0.0023096525110304356 train acc:  1.0 micro:  0.5642523364485982
751 loss:  0.002310462761670351 train acc:  1.0 micro:  0.5619158878504673
752 loss:  0.0023077428340911865 train acc:  1.0 micro:  0.5630841121495327
753 loss:  0.002306740963831544 train acc:  1.0 micro:  0.5619158878504673
754 loss:  0.0023042394313961267 train acc:  1.0 micro:  0.5619158878504673
755 loss:  0.0023021353408694267 train acc:  1.0 micro:  0.5630841121495327
756 loss:  0.0023010920267552137 train acc:  1.0 micro:  0.5619158878504673
757 loss:  0.002299490850418806 train acc:  1.0 micro:  0.5642523364485982
758 loss:  0.0022989038843661547 train acc:  1.0 micro:  0.5619158878504673
759 loss:  0.002296022605150938 train acc:  1.0 micro:  0.5642523364485982
760 loss:  0.0022

856 loss:  0.0021645228844136 train acc:  1.0 micro:  0.5665887850467289
857 loss:  0.0021631221752613783 train acc:  1.0 micro:  0.5654205607476636
858 loss:  0.0021632693242281675 train acc:  1.0 micro:  0.5665887850467289
859 loss:  0.002161406446248293 train acc:  1.0 micro:  0.5665887850467289
860 loss:  0.0021609619725495577 train acc:  1.0 micro:  0.5654205607476636
861 loss:  0.002159892115741968 train acc:  1.0 micro:  0.5665887850467289
862 loss:  0.0021588474046438932 train acc:  1.0 micro:  0.5654205607476636
863 loss:  0.0021565305069088936 train acc:  1.0 micro:  0.5654205607476636
864 loss:  0.0021554273553192616 train acc:  1.0 micro:  0.5665887850467289
865 loss:  0.0021542359609156847 train acc:  1.0 micro:  0.5654205607476636
866 loss:  0.0021536608692258596 train acc:  1.0 micro:  0.5677570093457944
867 loss:  0.002153020352125168 train acc:  1.0 micro:  0.5642523364485982
868 loss:  0.002151221502572298 train acc:  1.0 micro:  0.5654205607476636
869 loss:  0.002151

965 loss:  0.0020805620588362217 train acc:  1.0 micro:  0.5607476635514018
966 loss:  0.0020805865060538054 train acc:  1.0 micro:  0.5607476635514018
967 loss:  0.0020793562289327383 train acc:  1.0 micro:  0.5630841121495327
968 loss:  0.0020793296862393618 train acc:  1.0 micro:  0.5607476635514018
969 loss:  0.002077927114441991 train acc:  1.0 micro:  0.5607476635514018
970 loss:  0.002076574368402362 train acc:  1.0 micro:  0.5619158878504673
971 loss:  0.002076938049867749 train acc:  1.0 micro:  0.5607476635514018
972 loss:  0.0020752884447574615 train acc:  1.0 micro:  0.5619158878504673
973 loss:  0.0020754628349095583 train acc:  1.0 micro:  0.5607476635514018
974 loss:  0.002075008349493146 train acc:  1.0 micro:  0.5619158878504673
975 loss:  0.0020734616555273533 train acc:  1.0 micro:  0.5607476635514018
976 loss:  0.002073641400784254 train acc:  1.0 micro:  0.5607476635514018
977 loss:  0.0020718532614409924 train acc:  1.0 micro:  0.5630841121495327
978 loss:  0.0020

In [None]:
 loss:  0.37939170002937317 train acc:  0.5448191013562286 micro:  0.316588785046729
184 loss:  0.38017627596855164 train acc:  0.5451542144218879 micro:  0.31542056074766356
185 loss:  0.3799729347229004 train acc:  0.5438081095474381 micro:  0.31542056074766356
186 loss:  0.3803214132785797 train acc:  0.5458356860530773 micro:  0.308411214953271
187 loss:  0.3802218735218048 train acc:  0.5440949772791855 micro:  0.3189252336448598
188 loss:  0.38121509552001953 train acc:  0.5456981889684798 micro:  0.3119158878504673
189 loss:  0.37983420491218567 train acc:  0.545561294016143 micro:  0.31542056074766356
190 loss:  0.3799597918987274 train acc:  0.5433254904561581 micro:  0.33060747663551404
191 loss:  0.3819093108177185 train acc:  0.5459046606321211 micro:  0.308411214953271
192 loss:  0.3824063837528229 train acc:  0.5441477485509431 micro:  0.3235981308411215
193 loss:  0.38051867485046387 train acc:  0.5456296662730763 micro:  0.3142523364485981
194 loss:  0.3802700638771057 train acc:  0.545357078934816 micro:  0.3107476635514019