In [1]:
import dgl
import numpy as np
import torch
import torch.nn as nn
import tqdm
from dgl.nn.pytorch import RelGraphConv
from mydataset import JDDataset
from args import args
from utils import get_subset_g
from model import LinkPredict
from trainer import Trainer
from utils import calc_mrr
import logging
import torch.nn.functional as F
from torchmetrics.regression import TweedieDevianceScore

if __name__ == "__main__":
    # Evaluate saved per-domain predictions against their label matrices using
    # LinkPredict.calc_metrics, and print the result for each domain.
    #
    # NOTE(review): the captured output of this notebook cell shows argparse
    # exiting with status 2 under ipykernel_launcher.py. `args` is presumably
    # parsed at import time in args.py, so Jupyter's own command-line flags are
    # rejected. Consider `parser.parse_known_args()` there for notebook runs.
    # TODO confirm against args.py.

    # Build the graph dataset selected on the command line.
    data = JDDataset(reverse=False, name=args.data_name,
                     raw_dir=f'../data/{args.data_name}/{args.task}',
                     train_path=args.train_path, eval_path=args.eval_path,
                     test_path=args.test_path)
    g = data[0]
    num_nodes = g.num_nodes()
    num_rels = data.num_rels

    # Derive per-type node counts from edge endpoint ids (fetch edges once).
    # Assumes source ids are position nodes 0..P-1 and destination ids are
    # skill nodes numbered P..P+S-1, so S = max(dst) - max(src). TODO confirm
    # against JDDataset's id layout.
    src, dst = g.edges()
    pos_num_nodes = (src.max() + 1).item()
    skill_num_nodes = (dst.max() - src.max()).item()

    # Optional pretrained entity embeddings, enabled by --bias yes.
    if args.bias == 'yes':
        entity2embedding = torch.load(
            f'../data/{args.data_name}/{args.task}/entity2embedding.pt')
    else:
        entity2embedding = None

    # NOTE(review): time_embedding is loaded here but never passed to
    # LinkPredict below (only the `args.time` flag is). Verify whether the model
    # is meant to receive it.
    if args.time == "yes":
        time_embedding = torch.load(
            f'../data/{args.data_name}/{args.task}/time_embedding.pt')
    else:
        time_embedding = None

    # Regression loss, selected by --rg_loss_fn.
    rg_loss_fn = {
        "l1": F.l1_loss,
        "mse": F.mse_loss,
        "tweedie": TweedieDevianceScore(1.5)
    }

    # Output activation of the regression head, selected by --rg_activate_fn.
    rg_activate_fn = {
        "elu": nn.ELU(),
        "softplus": nn.Softplus(),
        "relu": nn.ReLU(),
        "sigmoid": nn.Sigmoid(),
        "leakyrelu": nn.LeakyReLU()
    }

    model = LinkPredict(num_nodes, pos_num_nodes, skill_num_nodes, num_rels, cross_attn=args.cross_attn, embedding=entity2embedding, time=args.time,
                        rg_weight=args.rg_weight, lp_weight=args.lp_weight, rank_weight=args.rank_weight, con_weight=args.con_weight,
                        gaussian=args.gaussian, bias=args.bias, initial_embedding=args.initial_embedding,
                        rg_loss_fn=rg_loss_fn[args.rg_loss_fn], rg_activate_fn=rg_activate_fn[args.rg_activate_fn]).to(args.device)

    # NOTE(review): machine-specific absolute path, unlike the '../data/...'
    # paths above. Consider making it configurable.
    output_root = '/code/chenxi02/AAAI/data'
    # Loop variable is `domain` so it no longer shadows the dataset `data`.
    for domain in ['Dai', 'Fin', 'IT', 'Man']:
        pred = torch.load(f'{output_root}/{domain}/dates/STAA/output.pt')
        label = torch.load(f'{output_root}/{domain}/dates/7/matrix.pt')
        metrics = model.calc_metrics(pred, label)
        # The original bare `print` never emitted anything; report the result.
        print(domain, metrics)



usage: ipykernel_launcher.py [-h] [--data_name DATA_NAME] [--task TASK]
                             [--rg_loss_fn RG_LOSS_FN] [--device DEVICE]
                             [--train_path TRAIN_PATH] [--test_path TEST_PATH]
                             [--eval_path EVAL_PATH] [--mode MODE]
                             [--time TIME] [--fix_emb FIX_EMB]
                             [--cross_attn CROSS_ATTN] [--gaussian GAUSSIAN]
                             [--bias BIAS] [--owner_id OWNER_ID]
                             [--rg_activate_fn RG_ACTIVATE_FN]
                             [--num_epochs NUM_EPOCHS] [--eval_step EVAL_STEP]
                             [--sample_size SAMPLE_SIZE] [--lr LR]
                             [--rg_weight RG_WEIGHT] [--lp_weight LP_WEIGHT]
                             [--rank_weight RANK_WEIGHT]
                             [--con_weight CON_WEIGHT] [--seed SEED] [--k K]
                             [--initial_embedding INITIAL_EMBEDDING]
ipykernel_launc

SystemExit: 2

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)
