In [1]:
import torch.nn
import math
from torch.nn.parameter import Parameter
import numpy as np


# node-level adaption model
class GraphNeuralNode(torch.nn.Module):
    def __init__(self, in_features, out_features, bias=True):
        super(GraphNeuralNode, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.FloatTensor(in_features, out_features))

        if bias:
            self.bias = Parameter(torch.FloatTensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):

        stdv = 1. / math.sqrt(self.weight.size(1))

        self.weight.data.uniform_(-stdv, stdv)

        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def forward(self, input, adj):

        support = torch.mm(input, self.weight)
        output = torch.spmm(adj, support)
        if self.bias is not None:
            return output + self.bias
        else:
            return output


# used in node-level adaption
class GNN4NodeLevel(torch.nn.Module):
    # 构建模型
    def __init__(self, nfeat, nhid, dropout):
        super(GNN4NodeLevel, self).__init__()

        self.gc1 = GraphNeuralNode(nfeat, nhid)
        self.gc2 = GraphNeuralNode(nhid, nhid)
        self.dropout = dropout

    # 前向传播
    def forward(self, x, adj):
        return self.gc1(x, adj)


# node-level adaption model
class GraphNeuralClass(torch.nn.Module):
    def __init__(self, in_features, out_features, bias=True):
        super(GraphNeuralClass, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.FloatTensor(in_features, out_features))

        if bias:
            self.bias = Parameter(torch.FloatTensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(1))

        self.weight.data.uniform_(-stdv, stdv)

        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def forward(self, input, adj, w, b):
        if w is None and b is None:
            alpha_w = alpha_b = beta_w = beta_b = 0
        else:
            alpha_w = w[0]
            beta_w = w[1]
            alpha_b = b[0]
            beta_b = b[1]

        support = torch.mm(input, self.weight * (1 + alpha_w) + beta_w)  # formula (6)
        output = torch.mm(adj, support)

        if self.bias is not None:
            return output + self.bias * (1 + alpha_b) + beta_b
        else:
            return output


# used in node-level adaption
class GNN4ClassLevel(torch.nn.Module):
    def __init__(self, nfeat, nhid, dropout):
        super(GNN4ClassLevel, self).__init__()

        self.gc1 = GraphNeuralClass(nfeat, nhid)
        self.gc2 = GraphNeuralClass(nhid, nhid)
        self.generater = torch.nn.Linear(nfeat, (nfeat + 1) * nhid * 2 + (nhid + 1) * nhid * 2)

        self.dropout = dropout

    def permute(self, input_adj, input_feat, drop_rate=0.1):
        # return input_adj

        adj_random = torch.rand(input_adj.shape).cuda() + torch.eye(input_adj.shape[0]).cuda()

        feat_random = np.random.choice(input_feat.shape[0], int(input_feat.shape[0] * drop_rate),
                                       replace=False).tolist()

        masks = torch.zeros(input_feat.shape).cuda()
        masks[feat_random] = 1

        random_tensor = torch.rand(input_feat.shape).cuda()

        return input_adj * (adj_random > drop_rate), input_feat * (1 - masks) + random_tensor * masks

    def forward(self, x, adj, w1=None, b1=None, w2=None, b2=None):
        x = torch.nn.functional.relu(self.gc1(x, adj, w1, b1))
        x = torch.nn.functional.dropout(x, self.dropout, training=self.training)
        x = self.gc2(x, adj, w2, b2)
        return x


# used in classifier
class Linear(torch.nn.Module):
    def __init__(self, in_features, out_features):
        super(Linear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.FloatTensor(in_features, out_features))
        self.bias = Parameter(torch.FloatTensor(out_features))

        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(1))

        self.weight.data.uniform_(-stdv, stdv)

        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def forward(self, input, w=None, b=None):
        if w is not None:
            return torch.mm(input, w) + b
        else:
            return torch.mm(input, self.weight) + self.bias


In [2]:
import numpy as np
import scipy.sparse as sp
import torch


def normalize(mx):
    """Row-normalize sparse matrix"""
    rowsum = np.array(mx.sum(1))
    r_inv = np.power(rowsum, -1).flatten()
    r_inv[np.isinf(r_inv)] = 0.
    r_mat_inv = sp.diags(r_inv)
    mx = r_mat_inv.dot(mx)
    return mx


def sparse_matrix2torch_sparse_tensor(sparse_mx):
    """Convert a scipy sparse matrix to a torch sparse tensor.

        Parameter:
            sparse_matrix: the scipy matrix to be conversed

        Return:
            torch sparse tensor
    """
    sparse_mx = sparse_mx.tocoo().astype(np.float32)
    indices = torch.from_numpy(
        np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64))
    values = torch.from_numpy(sparse_mx.data)
    shape = torch.Size(sparse_mx.shape)
    return torch.sparse.FloatTensor(indices, values, shape)


def l2_normalize(x):
    norm = x.pow(2).sum(1, keepdim=True).pow(1. / 2)
    out = x.div(norm)
    return out


def accuracy(output, labels):
    preds = output.max(1)[1].type_as(labels)
    correct = preds.eq(labels).double()
    correct = correct.sum()
    return correct / len(labels)


def InforNCE_Loss(anchor, sample, tau, dataset, all_negative=False, temperature_matrix=None):
    def _similarity(h1: torch.Tensor, h2: torch.Tensor):
        h1 = torch.nn.functional.normalize(h1)
        h2 = torch.nn.functional.normalize(h2)
        return h1 @ h2.t()

    assert anchor.shape[0] == sample.shape[0]

    pos_mask = torch.eye(anchor.shape[0], dtype=torch.float)

    pos_mask = pos_mask.cuda()

    neg_mask = 1. - pos_mask

    sim = _similarity(anchor, sample / temperature_matrix if temperature_matrix != None else sample) / tau
    exp_sim = torch.exp(sim) * (pos_mask + neg_mask)

    if not all_negative:
        log_prob = sim - torch.log(exp_sim.sum(dim=1, keepdim=True))
    else:
        log_prob = - torch.log(exp_sim.sum(dim=1, keepdim=True))

    loss = log_prob * pos_mask
    loss = loss.sum(dim=1) / pos_mask.sum(dim=1)

    return -loss.mean(), sim

def sparse_mx_to_torch_sparse_tensor(sparse_mx):
    """Convert a scipy sparse matrix to a torch sparse tensor."""
    sparse_mx = sparse_mx.tocoo().astype(np.float32)
    indices = torch.from_numpy(
        np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64))
    values = torch.from_numpy(sparse_mx.data)
    shape = torch.Size(sparse_mx.shape)
    return torch.sparse.FloatTensor(indices, values, shape)

def load_npz_to_sparse_graph(file_name):
    """Load a SparseGraph from a Numpy binary file.
    Parameters
    ----------
    file_name : str
        Name of the file to load.
    Returns
    -------
    sparse_graph : SparseGraph
        Graph in sparse matrix format.
    """
    with np.load(file_name) as loader:
        loader = dict(loader)
        adj_matrix = sp.csr_matrix((loader['adj_data'], loader['adj_indices'], loader['adj_indptr']),
                                   shape=loader['adj_shape'])

        if 'attr_data' in loader:
            # Attributes are stored as a sparse CSR matrix
            attr_matrix = sp.csr_matrix((loader['attr_data'], loader['attr_indices'], loader['attr_indptr']),
                                        shape=loader['attr_shape'])
        elif 'attr_matrix' in loader:
            # Attributes are stored as a (dense) np.ndarray
            attr_matrix = loader['attr_matrix']
        else:
            attr_matrix = None

        if 'labels_data' in loader:
            # Labels are stored as a CSR matrix
            labels = sp.csr_matrix((loader['labels_data'], loader['labels_indices'], loader['labels_indptr']),
                                   shape=loader['labels_shape'])
        elif 'labels' in loader:
            # Labels are stored as a numpy array
            labels = loader['labels']
        else:
            labels = None

    return adj_matrix, attr_matrix, labels

In [3]:
import json
from collections import defaultdict

import numpy
import scipy.io as sio
import scipy.sparse
import scipy.sparse as sp
import numpy as np
import torch
from sklearn import preprocessing
from scipy.sparse import coo_matrix
from numpy import ndarray



'''
                           _ooOoo_
                          o8888888o
                          88" . "88
                          (| -_- |)
                          O\  =  /O
                       ____/`---'\____
                     .'  \\|     |//  `.
                    /  \\|||  :  |||//  \
                   /  _||||| -:- |||||-  \
                   |   | \\\  -  /// |   |
                   | \_|  ''\---/''  |   |
                   \  .-\__  `-`  ___/-. /
                 ___`. .'  /--.--\  `. . __
              ."" '<  `.___\_<|>_/___.'  >'"".
             | | :  `- \`.;`\ _ /`;.`/ - ` : | |
             \  \ `-.   \_ __\ /__ _/   .-` /  /
        ======`-.____`-.___\_____/___.-`____.-'======
                           `=---='
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                 佛祖保佑       永无BUG
'''


# the name of dataset
datasets = ['Amazon_eletronics', 'dblp']
CORA_FULL = 'cora-full'
AMAZON_ELECTRONICS = 'Amazon_eletronics'
DBLP = 'dblp'
CORA_FULL_NPZ = "cora_full.npz"

# directory of dataset
dataset_dir = "../input/dataset/dataset/"

class Graph:
    """ class Graph is used to store the information of a graph,
        including adjacency_matrix, features_matrix, and etc.

    """

    def __init__(self, adjacency_matrix, features_matrix, labels,
                 train_node_index, valid_node_index, test_node_index,
                 class_train_dict, class_valid_dict, class_test_dict):

        self.adjacency_matrix: coo_matrix = adjacency_matrix
        self.features_matrix: ndarray = features_matrix
        self.labels: ndarray = labels

        self.train_node_index: list = train_node_index
        self.valid_node_index: list = valid_node_index
        self.test_node_index: list = test_node_index

        self.class_train_dict: defaultdict = class_train_dict
        self.class_valid_dict: defaultdict = class_valid_dict
        self.class_test_dict: defaultdict = class_test_dict


def data_preprocess(dataset: str) -> Graph:
    """pre-process the data before training

        Parameter:
            dataset: the name of dataset, including 'Amazon_eletronics', 'dblp', 'cora-full' and 'ogbn-arxiv'
        Return:
            graph: an object of Graph class, storing a graph data
    """
    # [class0_id, class1_id, class2_id, ...]
    train_class_list: list = list()
    test_class_list: list = list()
    valid_class_list: list = list()
    train_class_list, valid_class_list, test_class_list = json.load(
        open(dataset_dir + '{}_class_split.json'.format(dataset)))
    if dataset == "Amazon_eletronics" or dataset == 'dblp':
        # all the edges in graph is denoted by (node1[i], node2[i])
        node1: list = list()
        node2: list = list()
        for line in open(dataset_dir + "{}_network".format(dataset)):
            n1, n2 = line.strip().split("\t")
            node1.append(int(n1))
            node2.append(int(n2))
        node_number: int = max(max(node1), max(node2)) + 1

        # data_train and data_test are dicts, they have useful keys 'Index', 'Label' and 'Attributes'
        # Index: [[1,2,3...]]
        # Label: [[1],[1],[2],[2],...]
        # Attributes: matrix
        data_train: dict = sio.loadmat(dataset_dir + "{}_train.mat".format(dataset))
        data_test: dict = sio.loadmat(dataset_dir + "{}_test.mat".format(dataset))
        # label of nodes
        labels = np.zeros((node_number, 1))
        labels[data_train['Index']] = data_train["Label"]
        labels[data_test['Index']] = data_test["Label"]
        # print(labels)
        # feature matrix
        features_matrix = np.zeros((node_number, data_train["Attributes"].shape[1]))
        features_matrix[data_train['Index']] = data_train["Attributes"].toarray()
        features_matrix[data_test['Index']] = data_test["Attributes"].toarray()
        # adjacency matrix
        adjacency_matrix = sp.coo_matrix((np.ones(len(node1)), (node1, node2)), shape=(node_number, node_number))

        # all the classes in a list
        all_class_list: list = []
        for cls in labels:
            if cls[0] not in all_class_list:
                all_class_list.append(cls[0])

        # {class_id -> [node_id, node_id, ...]}
        class_dict: dict = {}
        for cls in all_class_list:
            class_dict[cls] = []
        for node_id, class_id in enumerate(labels):
            class_dict[class_id[0]].append(node_id)

        label_binarizer = preprocessing.LabelBinarizer()
        labels = label_binarizer.fit_transform(labels)
        features_matrix = torch.FloatTensor(features_matrix)
        # labels = tensor([99, 61, 99, ..., 57, 97, 34])
        labels = torch.LongTensor(np.where(labels)[1])
        adjacency_matrix = sparse_matrix2torch_sparse_tensor(
            normalize(adjacency_matrix + sp.eye(adjacency_matrix.shape[0])))
        # print(labels)
    
    elif dataset == 'cora-full':
        adjacency_matrix, features_matrix, labels = load_npz_to_sparse_graph(dataset_dir + CORA_FULL_NPZ)

        adjacency_matrix = normalize(adjacency_matrix.tocoo() + sp.eye(adjacency_matrix.shape[0]))
        adjacency_matrix = sparse_mx_to_torch_sparse_tensor(adjacency_matrix)
        features_matrix = features_matrix.todense()
        features_matrix = torch.FloatTensor(features_matrix)
        labels = torch.LongTensor(labels).squeeze()

        all_class_list = train_class_list + valid_class_list + test_class_list

        class_dict: dict = {}
        for i in all_class_list:
            class_dict[i] = []
        for node_id, cls in enumerate(labels.numpy().tolist()):
            class_dict[cls].append(node_id)
    
    # store node id
    train_node_index: list = list()
    valid_node_index: list = list()
    test_node_index: list = list()
    for idx, class_list in zip([train_node_index, valid_node_index, test_node_index],
                               [train_class_list, valid_class_list, test_class_list]):
        for class_id in class_list:
            idx.extend(class_dict[class_id])

    # {class_id => [node0_id, node1_id, ...]}
    class_train_dict = defaultdict(list)
    for one in train_class_list:
        for i, label in enumerate(labels.numpy().tolist()):
            if label == one:
                class_train_dict[one].append(i)
    class_valid_dict = defaultdict(list)
    for one in valid_class_list:
        for i, label in enumerate(labels.numpy().tolist()):
            if label == one:
                class_valid_dict[one].append(i)

    class_test_dict = defaultdict(list)
    for one in test_class_list:
        for i, label in enumerate(labels.numpy().tolist()):
            if label == one:
                class_test_dict[one].append(i)

    graph = Graph(adjacency_matrix, features_matrix, labels,
                  train_node_index, valid_node_index, test_node_index,
                  class_train_dict, class_valid_dict, class_test_dict)

    return graph




def print_content_to_file(data, file_name):
    """ This function is used for debugging

    """
    with open("./debug/" + file_name, 'w') as f:
        f.write(data)
        



In [None]:
import argparse
import json
import random
from collections import defaultdict

import numpy as np
import torch
from torch import Tensor, optim
from sklearn import metrics
from sklearn.linear_model import LogisticRegression


# the name of dataset
datasets = ['cora-full','dblp','Amazon_eletronics']
CORA_FULL = 'cora-full'
AMAZON_ELECTRONICS = 'Amazon_eletronics'
DBLP = 'dblp'

# the name of mode
TRAIN = 'train'
VALID = 'valid'
TEST = 'test'



# K and N in experiment
# k nodes for each of n classes
Ks: list = [3, 5]
Ns: list = [5, 10]

query_size = 10

# repeat times for each (n, k)
repeat_times = 5

final_results = defaultdict(dict)

# loss function
loss_function = torch.nn.CrossEntropyLoss()

# parameters
class args:
    use_cuda = True
    seed = 1234
    epochs = 2000
    test_epochs = 100
    lr = 0.05
    weight_decay = 5e-4
    hidden1 = 16
    hidden2 = 16
    dropout = 0.2


def main():
    """main function of the whole project
    """

    # parameter

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.use_cuda:
        torch.cuda.manual_seed(args.seed)

    train_and_test()


def train_and_test():
    # conduct the experiment in each dataset
    for dataset in datasets:
        graph = data_preprocess(dataset)

        adj_dense = graph.adjacency_matrix.to_dense()
        adj_dense = adj_dense.cuda()

        for n in Ns:
            for k in Ks:
                for repeat in range(repeat_times):
                    print("begin ", dataset, "n= ", n, "k= ", k)

                    # two models in class-level and node level adaption
                    node_level_model: GNN4NodeLevel = GNN4NodeLevel(nfeat=graph.features_matrix.shape[1],
                                                                    nhid=args.hidden1,
                                                                    dropout=args.dropout)
                    class_level_model: GNN4ClassLevel = GNN4ClassLevel(nfeat=args.hidden1,
                                                                       nhid=args.hidden2,
                                                                       dropout=args.dropout)
                    support_labels: Tensor = torch.zeros(n * k, dtype=torch.long)
                    query_labels: Tensor = torch.zeros(n * k, dtype=torch.long)

                    for i in range(n):
                        support_labels[i * k:(i + 1) * k] = i
                    for i in range(n):
                        query_labels[i * query_size:(i + 1) * query_size] = i

                    classifier: Linear = Linear(args.hidden1, graph.labels.max().item() + 1)
                    optimizer: optim.Adam = optim.Adam(
                        [{'params': class_level_model.parameters()},
                         {'params': classifier.parameters()},
                         {'params': node_level_model.parameters()}],
                        lr=args.lr, weight_decay=args.weight_decay)

                    if args.use_cuda:
                        node_level_model = node_level_model.cuda()
                        class_level_model = class_level_model.cuda()

                        graph.features_matrix = graph.features_matrix.cuda()
                        graph.adjacency_matrix = graph.adjacency_matrix.cuda()
                        graph.labels = graph.labels.cuda()

                        classifier = classifier.cuda()

                        support_labels = support_labels.cuda()
                        query_labels = query_labels.cuda()

                    def calculate_accuracy(epoch: int,
                                           n: int, k: int,
                                           mode: str) -> float:
                        if mode == 'train':
                            class_level_model.train()
                            optimizer.zero_grad()
                        else:
                            class_level_model.eval()

                            # first-step node representation?
                        emb_features = node_level_model(graph.features_matrix, graph.adjacency_matrix)

                        target_idx = []
                        target_graph_adj_and_feat = []
                        support_graph_adj_and_feat = []

                        pos_node_idx = []

                        if mode == 'train':
                            class_dict = graph.class_train_dict
                        elif mode == 'test':
                            class_dict = graph.class_test_dict
                        elif mode == 'valid':
                            class_dict = graph.class_valid_dict

                        K = k
                        N = n
                        Q = query_size

                        classes = np.random.choice(list(class_dict.keys()), N, replace=False).tolist()

                        pos_graph_adj_and_feat = []
                        # construct class-ego subgraphs?
                        for i in classes:
                            # sample from one specific class
                            sampled_idx = np.random.choice(class_dict[i], K + Q, replace=False).tolist()
                            pos_node_idx.extend(sampled_idx[:K])
                            target_idx.extend(sampled_idx[K:])

                            class_pos_idx = sampled_idx[:K]

                            # why k = 1?
                            if K == 1 and torch.nonzero(adj_dense[class_pos_idx, :]).shape[0] == 1:
                                pos_class_graph_adj = adj_dense[class_pos_idx, class_pos_idx].reshape([1, 1])
                                pos_graph_feat = emb_features[class_pos_idx]
                            else:
                                pos_graph_neighbors = torch.nonzero(adj_dense[class_pos_idx, :].sum(0)).squeeze()

                                pos_graph_adj = adj_dense[pos_graph_neighbors, :][:, pos_graph_neighbors]

                                pos_class_graph_adj = torch.eye(pos_graph_neighbors.shape[0] + 1, dtype=torch.float)

                                pos_class_graph_adj[1:, 1:] = pos_graph_adj

                                pos_graph_feat = torch.cat([emb_features[class_pos_idx].mean(0, keepdim=True),
                                                            emb_features[pos_graph_neighbors]], 0)

                            if dataset != 'ogbn-arxiv':
                                pos_class_graph_adj = pos_class_graph_adj.cuda()

                            pos_graph_adj_and_feat.append((pos_class_graph_adj, pos_graph_feat))

                        target_graph_adj_and_feat = []
                        for node in target_idx:
                            if torch.nonzero(adj_dense[node, :]).shape[0] == 1:
                                pos_graph_adj = adj_dense[node, node].reshape([1, 1])
                                pos_graph_feat = emb_features[node].unsqueeze(0)
                            else:
                                pos_graph_neighbors = torch.nonzero(adj_dense[node, :]).squeeze()
                                pos_graph_neighbors = torch.nonzero(adj_dense[pos_graph_neighbors, :].sum(0)).squeeze()
                                pos_graph_adj = adj_dense[pos_graph_neighbors, :][:, pos_graph_neighbors]
                                pos_graph_feat = emb_features[pos_graph_neighbors]

                            target_graph_adj_and_feat.append((pos_graph_adj, pos_graph_feat))

                        class_generate_emb = torch.stack([sub[1][0] for sub in pos_graph_adj_and_feat], 0).mean(0)

                        parameters = class_level_model.generater(class_generate_emb)

                        gc1_parameters = parameters[:(args.hidden1 + 1) * args.hidden2 * 2]
                        gc2_parameters = parameters[(args.hidden1 + 1) * args.hidden2 * 2:]

                        gc1_w = gc1_parameters[:args.hidden1 * args.hidden2 * 2].reshape(
                            [2, args.hidden1, args.hidden2])
                        gc1_b = gc1_parameters[args.hidden1 * args.hidden2 * 2:].reshape([2, args.hidden2])

                        gc2_w = gc2_parameters[:args.hidden2 * args.hidden2 * 2].reshape(
                            [2, args.hidden2, args.hidden2])
                        gc2_b = gc2_parameters[args.hidden2 * args.hidden2 * 2:].reshape([2, args.hidden2])

                        class_level_model.eval()
                        ori_emb = []
                        for i, one in enumerate(target_graph_adj_and_feat):
                            sub_adj, sub_feat = one[0], one[1]
                            ori_emb.append(class_level_model(sub_feat, sub_adj, gc1_w, gc1_b, gc2_w, gc2_b).mean(0))  # .mean(0))

                        target_embs = torch.stack(ori_emb, 0)

                        class_ego_embs = []
                        for sub_adj, sub_feat in pos_graph_adj_and_feat:
                            class_ego_embs.append(class_level_model(sub_feat, sub_adj, gc1_w, gc1_b, gc2_w, gc2_b)[0])
                        class_ego_embs = torch.stack(class_ego_embs, 0)

                        target_embs = target_embs.reshape([N, Q, -1]).transpose(0, 1)

                        support_features = emb_features[pos_node_idx].reshape([N, K, -1])
                        class_features = support_features.mean(1)
                        taus = []
                        for j in range(N):
                            taus.append(torch.linalg.norm(support_features[j] - class_features[j], -1).sum(0))
                        taus = torch.stack(taus, 0)

                        similarities = []
                        for j in range(Q):
                            class_contras_loss, similarity = InforNCE_Loss(target_embs[j],
                                                                           class_ego_embs / taus.unsqueeze(-1), tau=0.5,dataset=dataset)
                            similarities.append(similarity)

                        loss_supervised = loss_function(classifier(emb_features[graph.train_node_index]), graph.labels[graph.train_node_index])

                        loss = loss_supervised

                        labels_train = graph.labels[target_idx]
                        for j, class_idx in enumerate(classes[:N]):
                            labels_train[labels_train == class_idx] = j

                        loss += loss_function(torch.stack(similarities, 0).transpose(0, 1).reshape([N * Q, -1]), labels_train)

                        acc_train = accuracy(torch.stack(similarities, 0).transpose(0, 1).reshape([N * Q, -1]),
                                             labels_train)

                        if mode == 'valid' or mode == 'test' or (mode == 'train' and epoch % 250 == 249):
                            support_features = l2_normalize(emb_features[pos_node_idx].detach().cpu()).numpy()
                            query_features = l2_normalize(emb_features[target_idx].detach().cpu()).numpy()

                            support_labels = torch.zeros(N * K, dtype=torch.long)
                            for i in range(N):
                                support_labels[i * K:(i + 1) * K] = i

                            query_labels = torch.zeros(N * Q, dtype=torch.long)
                            for i in range(N):
                                query_labels[i * Q:(i + 1) * Q] = i

                            clf = LogisticRegression(penalty='l2',
                                                     random_state=0,
                                                     C=1.0,
                                                     solver='lbfgs',
                                                     max_iter=1000,
                                                     multi_class='multinomial')
                            clf.fit(support_features, support_labels.numpy())
                            query_ys_pred = clf.predict(query_features)

                            acc_train = metrics.accuracy_score(query_labels, query_ys_pred)

                        if mode == 'train':
                            loss.backward()
                            optimizer.step()

                        if epoch % 250 == 249 and mode == 'train':
                            print('Epoch: {:04d}'.format(epoch + 1),
                                  'loss_train: {:.4f}'.format(loss.item()),
                                  'acc_train: {:.4f}'.format(acc_train.item()))
                        return acc_train.item()

                    # begin to train and test
                    cnt: int = 0
                    valid_accuracy_best: float = 0.0
                    test_accuracy_best: list = []
                    for epoch in range(args.epochs):
                        train_accuracy: float = calculate_accuracy(
                                                                   epoch=epoch,
                                                                   n=n, k=k,
                                                                   mode=TRAIN)

                        # epoch for test and valid
                        if epoch % 50 == 0 and epoch != 0:
                            tmp_accuracies: list = []
                            for test_epoch in range(50):
                                tmp_accuracy = calculate_accuracy(
                                                                  epoch=test_epoch,
                                                                  n=n, k=k,
                                                                  mode=TEST)
                                tmp_accuracies.append(tmp_accuracy)

                            valid_accuracies: list = []
                            for valid_epoch in range(50):
                                tmp_accuracy = calculate_accuracy(
                                                                  epoch=valid_epoch,
                                                                  n=n, k=k,
                                                                  mode=VALID)
                                valid_accuracies.append(tmp_accuracy)

                            valid_accuracy = np.array(valid_accuracies).mean(axis=0)

                            print("Epoch: {:04d} Meta-valid_Accuracy: {:.4f}".format(epoch + 1, valid_accuracy))

                            if valid_accuracy > valid_accuracy_best:
                                valid_accuracy_best = valid_accuracy
                                test_accuracy_best = tmp_accuracies
                                cnt = 0
                            else:
                                cnt += 1
                                if cnt >= 10:
                                    break

                    print('Test Acc', np.array(test_accuracy_best).mean(axis=0))
                    final_results[dataset]['{}-way {}-shot {}-repeat'.format(n, k, repeat)] = [
                        np.array(test_accuracy_best).mean(axis=0)]
                    json.dump(final_results[dataset], open('./TENT-result_{}.json'.format(dataset), 'w'))

                final_accuracies: list = []
                for i in range(repeat_times):
                    final_accuracies.append(final_results[dataset]['{}-way {}-shot {}-repeat'.format(n, k, i)][0])

                final_results[dataset]['{}-way {}-shot'.format(n, k)] = [np.mean(final_accuracies)]
                final_results[dataset]['{}-way {}-shot_print'.format(n, k)] = 'acc: {:.4f}'.format(
                    np.mean(final_accuracies))

                json.dump(final_results[dataset], open('./TENT-result_{}.json'.format(dataset), 'w'))

                del node_level_model
                del class_level_model

        del graph
        del adj_dense





# return 0.0


# entry of the program
if __name__ == '__main__':
    main()


begin  cora-full n=  5 k=  3
Epoch: 0051 Meta-valid_Accuracy: 0.6396
Epoch: 0101 Meta-valid_Accuracy: 0.6464
Epoch: 0151 Meta-valid_Accuracy: 0.6156
Epoch: 0201 Meta-valid_Accuracy: 0.6528
Epoch: 0250 loss_train: 0.9675 acc_train: 0.9400
Epoch: 0251 Meta-valid_Accuracy: 0.6452
Epoch: 0301 Meta-valid_Accuracy: 0.6448
Epoch: 0351 Meta-valid_Accuracy: 0.6264
Epoch: 0401 Meta-valid_Accuracy: 0.6464
Epoch: 0451 Meta-valid_Accuracy: 0.6580
Epoch: 0500 loss_train: 0.9370 acc_train: 0.9800
Epoch: 0501 Meta-valid_Accuracy: 0.6632
Epoch: 0551 Meta-valid_Accuracy: 0.6588
Epoch: 0601 Meta-valid_Accuracy: 0.6948
Epoch: 0651 Meta-valid_Accuracy: 0.6588
Epoch: 0701 Meta-valid_Accuracy: 0.6700
Epoch: 0750 loss_train: 1.1206 acc_train: 0.9400
Epoch: 0751 Meta-valid_Accuracy: 0.6536
Epoch: 0801 Meta-valid_Accuracy: 0.6820
Epoch: 0851 Meta-valid_Accuracy: 0.6620
Epoch: 0901 Meta-valid_Accuracy: 0.6560
Epoch: 0951 Meta-valid_Accuracy: 0.6740
Epoch: 1000 loss_train: 0.9437 acc_train: 0.9600
Epoch: 1001 Met