In [1]:
def parse_args():
    class Args:
        p=0
        data_path = '../'
        seed = 123
        dataset = 'baby'
        verbose = 5
        epoch = 2000
        batch_size = 1024
        regs = '[1e-5,1e-5,1e-2]'
        lr = 0.0005
        model_name = 'lattice'
        embed_size = 64
        feat_embed_dim = 64
        weight_size = '[64,64]'
        core = 5
        topk = 10
        lambda_coeff = 0.9
        loss_ratio=0.03
        cf_model = 'lightgcn'
        n_layers = 1
        layers = 1
        sparse = 1
        norm_type = 'sym'
        mess_dropout = '[0.1, 0.1]'
        early_stopping_patience = 7
        gpu_id = 0
        Ks = '[10, 20,50]'
        test_flag = 'part'
        shuffle='text'
    return Args()


In [2]:
import numpy as np
from sklearn.metrics import roc_auc_score

def recall(rank, ground_truth, N):
    return len(set(rank[:N]) & set(ground_truth)) / float(len(set(ground_truth)))


def precision_at_k(r, k):
    """Score is precision @ k
    Relevance is binary (nonzero is relevant).
    Returns:
        Precision @ k
    Raises:
        ValueError: len(r) must be >= k
    """
    assert k >= 1
    r = np.asarray(r)[:k]
    return np.mean(r)


def average_precision(r,cut):
    """Score is average precision (area under PR curve)
    Relevance is binary (nonzero is relevant).
    Returns:
        Average precision
    """
    r = np.asarray(r)
    out = [precision_at_k(r, k + 1) for k in range(cut) if r[k]]
    if not out:
        return 0.
    return np.sum(out)/float(min(cut, np.sum(r)))


def mean_average_precision(rs):
    """Score is mean average precision
    Relevance is binary (nonzero is relevant).
    Returns:
        Mean average precision
    """
    return np.mean([average_precision(r) for r in rs])


def dcg_at_k(r, k, method=1):
    """Score is discounted cumulative gain (dcg)
    Relevance is positive real values.  Can use binary
    as the previous methods.
    Returns:
        Discounted cumulative gain
    """
    r = np.asfarray(r)[:k]
    if r.size:
        if method == 0:
            return r[0] + np.sum(r[1:] / np.log2(np.arange(2, r.size + 1)))
        elif method == 1:
            return np.sum(r / np.log2(np.arange(2, r.size + 2)))
        else:
            raise ValueError('method must be 0 or 1.')
    return 0.


def ndcg_at_k(r, k, method=1):
    """Score is normalized discounted cumulative gain (ndcg)
    Relevance is positive real values.  Can use binary
    as the previous methods.
    Returns:
        Normalized discounted cumulative gain
    """
    dcg_max = dcg_at_k(sorted(r, reverse=True), k, method)
    if not dcg_max:
        return 0.
    return dcg_at_k(r, k, method) / dcg_max


def recall_at_k(r, k, all_pos_num):
    r = np.asfarray(r)[:k]
    if all_pos_num == 0:
        return 0
    else:
        return np.sum(r) / all_pos_num


def hit_at_k(r, k):
    r = np.array(r)[:k]
    if np.sum(r) > 0:
        return 1.
    else:
        return 0.

def mrr_at_k(r, k):
    r = np.array(r)[:k]
    #print(r)
    if np.sum(r) > 0:
        #print(1/(np.where(r==1.0)[0]+1).astype(float)[0])
        return 1/(np.where(r==1.0)[0]+1).astype(float)[0]
    else:
        return 0.

def F1(pre, rec):
    if pre + rec > 0:
        return (2.0 * pre * rec) / (pre + rec)
    else:
        return 0.

def auc(ground_truth, prediction):
    try:
        res = roc_auc_score(y_true=ground_truth, y_score=prediction)
    except Exception:
        res = 0.
    return res

In [3]:
import numpy as np
import random as rd
import scipy.sparse as sp
from time import time
import json
#from utility.parser import parse_args
args = parse_args()

class Data(object):
    def __init__(self, path, batch_size):
        self.path = path + '/%d-core' % args.core
        self.batch_size = batch_size

        train_file = path + '/%d-core/train.json' % (args.core)
        val_file = path + '/%d-core/val.json' % (args.core)
        test_file = path + '/%d-core/test.json'  % (args.core)

        #get number of users and items
        self.n_users, self.n_items = 0, 0
        self.n_train, self.n_test = 0, 0
        self.neg_pools = {}

        self.exist_users = []

        train = json.load(open(train_file))
        test = json.load(open(test_file))
        val = json.load(open(val_file))
        for uid, items in train.items():
            if len(items) == 0:
                continue
            uid = int(uid)
            self.exist_users.append(uid)
            self.n_items = max(self.n_items, max(items))
            self.n_users = max(self.n_users, uid)
            self.n_train += len(items)

        for uid, items in test.items():
            uid = int(uid)
            try:
                self.n_items = max(self.n_items, max(items))
                self.n_users = max(self.n_users, uid)
                self.n_test += len(items)
            except:
                continue

        for uid, items in val.items():
            uid = int(uid)
            try:
                self.n_items = max(self.n_items, max(items))
                self.n_users = max(self.n_users, uid)
                self.n_val += len(items)
            except:
                continue

        self.n_items += 1
        self.n_users += 1

        self.print_statistics()

        self.R = sp.dok_matrix((self.n_users, self.n_items), dtype=np.float32)
        self.R_Item_Interacts = sp.dok_matrix((self.n_items, self.n_items), dtype=np.float32)
        
        self.train_items, self.test_set, self.val_set = {}, {}, {}
        for uid, train_items in train.items():
            if len(train_items) == 0:
                continue
            uid = int(uid)
            for idx, i in enumerate(train_items):
                self.R[uid, i] = 1.

            self.train_items[uid] = train_items
        sp.save_npz(self.path + '/R.npz', self.R.tocsr())
        
        self.my_test_set=[]
        for uid, test_items in test.items():
            uid = int(uid)
            if len(test_items) == 0:
                continue
            for i in test_items:
                self.my_test_set.append([uid,i])
            try:
                self.test_set[uid] = test_items
            except:
                continue

        for uid, val_items in val.items():
            uid = int(uid)
            if len(val_items) == 0:
                continue
            try:
                self.val_set[uid] = val_items
            except:
                continue            

    def get_R_mat(self):
        t1 = time()
        R = sp.load_npz(self.path + '/R.npz')
        print('already load rating matrix', R.shape, time() - t1)
        return R
    
    def get_adj_mat(self):
        try:
            t1 = time()
            adj_mat = sp.load_npz(self.path + '/s_adj_mat.npz')
            norm_adj_mat = sp.load_npz(self.path + '/s_norm_adj_mat.npz')
            mean_adj_mat = sp.load_npz(self.path + '/s_mean_adj_mat.npz')
            print('already load adj matrix', adj_mat.shape, time() - t1)

        except Exception:
            adj_mat, norm_adj_mat, mean_adj_mat = self.create_adj_mat()
            sp.save_npz(self.path + '/s_adj_mat.npz', adj_mat)
            sp.save_npz(self.path + '/s_norm_adj_mat.npz', norm_adj_mat)
            sp.save_npz(self.path + '/s_mean_adj_mat.npz', mean_adj_mat)
        return adj_mat, norm_adj_mat, mean_adj_mat

    def create_adj_mat(self):
        t1 = time()
        adj_mat = sp.dok_matrix((self.n_users + self.n_items, self.n_users + self.n_items), dtype=np.float32)
        adj_mat = adj_mat.tolil()
        R = self.R.tolil()

        adj_mat[:self.n_users, self.n_users:] = R
        adj_mat[self.n_users:, :self.n_users] = R.T
        adj_mat = adj_mat.todok()
        print('already create adjacency matrix', adj_mat.shape, time() - t1)

        t2 = time()

        def normalized_adj_single(adj):
            rowsum = np.array(adj.sum(1))

            d_inv = np.power(rowsum, -1).flatten()
            d_inv[np.isinf(d_inv)] = 0.
            d_mat_inv = sp.diags(d_inv)

            norm_adj = d_mat_inv.dot(adj)
            # norm_adj = adj.dot(d_mat_inv)
            print('generate single-normalized adjacency matrix.')
            return norm_adj.tocoo()

        def get_D_inv(adj):
            rowsum = np.array(adj.sum(1))

            d_inv = np.power(rowsum, -1).flatten()
            d_inv[np.isinf(d_inv)] = 0.
            d_mat_inv = sp.diags(d_inv)
            return d_mat_inv

        def check_adj_if_equal(adj):
            dense_A = np.array(adj.todense())
            degree = np.sum(dense_A, axis=1, keepdims=False)

            temp = np.dot(np.diag(np.power(degree, -1)), dense_A)
            print('check normalized adjacency matrix whether equal to this laplacian matrix.')
            return temp

        norm_adj_mat = normalized_adj_single(adj_mat + sp.eye(adj_mat.shape[0]))
        mean_adj_mat = normalized_adj_single(adj_mat)

        print('already normalize adjacency matrix', time() - t2)
        return adj_mat.tocsr(), norm_adj_mat.tocsr(), mean_adj_mat.tocsr()


    def sample(self):
        if self.batch_size <= self.n_users:
            users = rd.sample(self.exist_users, self.batch_size)
        else:
            users = [rd.choice(self.exist_users) for _ in range(self.batch_size)]
        # users = self.exist_users[:]

        def sample_pos_items_for_u(u, num):
            pos_items = self.train_items[u]
            n_pos_items = len(pos_items)
            pos_batch = []
            while True:
                if len(pos_batch) == num: break
                pos_id = np.random.randint(low=0, high=n_pos_items, size=1)[0]
                pos_i_id = pos_items[pos_id]

                if pos_i_id not in pos_batch:
                    pos_batch.append(pos_i_id)
            return pos_batch

        def sample_neg_items_for_u(u, num):
            neg_items = []
            while True:
                if len(neg_items) == num: break
                neg_id = np.random.randint(low=0, high=self.n_items, size=1)[0]
                if neg_id not in self.train_items[u] and neg_id not in neg_items:
                    neg_items.append(neg_id)
            return neg_items

        def sample_neg_items_for_u_from_pools(u, num):
            neg_items = list(set(self.neg_pools[u]) - set(self.train_items[u]))
            return rd.sample(neg_items, num)

        pos_items, neg_items = [], []
        for u in users:
            pos_items += sample_pos_items_for_u(u, 1)
            neg_items += sample_neg_items_for_u(u, 1)
            # neg_items += sample_neg_items_for_u(u, 3)
        return users, pos_items, neg_items



    def print_statistics(self):
        print('n_users=%d, n_items=%d' % (self.n_users, self.n_items))
        print('n_interactions=%d' % (self.n_train + self.n_test))
        print('n_train=%d, n_test=%d, sparsity=%.5f' % (self.n_train, self.n_test, (self.n_train + self.n_test)/(self.n_users * self.n_items)))


In [4]:
#import utility.metrics as metrics
#from utility.parser import parse_args
#from utility.load_data import Data
import multiprocessing
import heapq
import torch
import pickle
import numpy as np
from time import time

cores = multiprocessing.cpu_count() // 5

args = parse_args()
Ks = eval(args.Ks)

data_generator = Data(path=args.data_path + args.dataset, batch_size=args.batch_size)
USR_NUM, ITEM_NUM = data_generator.n_users, data_generator.n_items
N_TRAIN, N_TEST = data_generator.n_train, data_generator.n_test
BATCH_SIZE = args.batch_size

def ranklist_by_heapq(user_pos_test, test_items, rating, Ks):
    item_score = {}
    for i in test_items:
        item_score[i] = rating[i]

    K_max = max(Ks)
    K_max_item_score = heapq.nlargest(K_max, item_score, key=item_score.get)

    r = []
    for i in K_max_item_score:
        if i in user_pos_test:
            r.append(1)
        else:
            r.append(0)
    auc = 0.
    return r, auc

def get_auc(item_score, user_pos_test):
    item_score = sorted(item_score.items(), key=lambda kv: kv[1])
    item_score.reverse()
    item_sort = [x[0] for x in item_score]
    posterior = [x[1] for x in item_score]

    r = []
    for i in item_sort:
        if i in user_pos_test:
            r.append(1)
        else:
            r.append(0)
    auc = metrics.auc(ground_truth=r, prediction=posterior)
    return auc

def ranklist_by_sorted(user_pos_test, test_items, rating, Ks):
    item_score = {}
    for i in test_items:
        item_score[i] = rating[i]

    K_max = max(Ks)
    K_max_item_score = heapq.nlargest(K_max, item_score, key=item_score.get)

    r = []
    for i in K_max_item_score:
        if i in user_pos_test:
            r.append(1)
        else:
            r.append(0)
    auc = get_auc(item_score, user_pos_test)
    return r, auc

def get_performance(user_pos_test, r, auc, Ks):
    precision, recall, ndcg, hit_ratio, mrr = [], [], [], [], []

    for K in Ks:
        precision.append(precision_at_k(r, K))
        recall.append(recall_at_k(r, K, len(user_pos_test)))
        ndcg.append(ndcg_at_k(r, K))
        hit_ratio.append(hit_at_k(r, K))
        mrr.append(mrr_at_k(r, K))

    return {'recall': np.array(recall), 'precision': np.array(precision),
            'ndcg': np.array(ndcg), 'hit_ratio': np.array(hit_ratio), 'auc': auc, 'mrr': np.array(mrr)}


def test_one_user(x):
    # user u's ratings for user u
    is_val = x[-1]
    rating = x[0]
    #uid
    u = x[1]
    #user u's items in the training set
    try:
        training_items = data_generator.train_items[u]
    except Exception:
        training_items = []
    #user u's items in the test set
    if is_val:
        user_pos_test = data_generator.val_set[u]
    else:
        user_pos_test = data_generator.test_set[u]

    all_items = set(range(ITEM_NUM))

    test_items = list(all_items - set(training_items))
    #test_items = list(all_items)
    if args.test_flag == 'part':
        r, auc = ranklist_by_heapq(user_pos_test, test_items, rating, Ks)
    else:
        r, auc = ranklist_by_sorted(user_pos_test, test_items, rating, Ks)

    return get_performance(user_pos_test, r, auc, Ks)


def test_torch(ua_embeddings, ia_embeddings, users_to_test, is_val, drop_flag=False, batch_test_flag=False):
    result = {'precision': np.zeros(len(Ks)), 'recall': np.zeros(len(Ks)), 'ndcg': np.zeros(len(Ks)),
              'hit_ratio': np.zeros(len(Ks)),'mrr': np.zeros(len(Ks)), 'auc': 0.}
    pool = multiprocessing.Pool(cores)

    u_batch_size = BATCH_SIZE * 2
    i_batch_size = BATCH_SIZE

    test_users = users_to_test
    n_test_users = len(test_users)
    n_user_batchs = n_test_users // u_batch_size + 1
    count = 0

    for u_batch_id in range(n_user_batchs):
        start = u_batch_id * u_batch_size
        end = (u_batch_id + 1) * u_batch_size
        user_batch = test_users[start: end]
        if batch_test_flag:
            n_item_batchs = ITEM_NUM // i_batch_size + 1
            rate_batch = np.zeros(shape=(len(user_batch), ITEM_NUM))

            i_count = 0
            for i_batch_id in range(n_item_batchs):
                i_start = i_batch_id * i_batch_size
                i_end = min((i_batch_id + 1) * i_batch_size, ITEM_NUM)

                item_batch = range(i_start, i_end)
                u_g_embeddings = ua_embeddings[user_batch]
                i_g_embeddings = ia_embeddings[item_batch]
                i_rate_batch = torch.matmul(u_g_embeddings, torch.transpose(i_g_embeddings, 0, 1))

                rate_batch[:, i_start: i_end] = i_rate_batch
                i_count += i_rate_batch.shape[1]

            assert i_count == ITEM_NUM

        else:
            item_batch = range(ITEM_NUM)
            u_g_embeddings = ua_embeddings[user_batch]
            #print(max(item_batch))
            i_g_embeddings = ia_embeddings[item_batch]
            rate_batch = torch.matmul(u_g_embeddings, torch.transpose(i_g_embeddings, 0, 1))

        rate_batch = rate_batch.detach().cpu().numpy()
        user_batch_rating_uid = zip(rate_batch, user_batch, [is_val] * len(user_batch))

        batch_result = pool.map(test_one_user, user_batch_rating_uid)
        count += len(batch_result)

        for re in batch_result:
            result['precision'] += re['precision'] / n_test_users
            result['recall'] += re['recall'] / n_test_users
            result['ndcg'] += re['ndcg'] / n_test_users
            result['hit_ratio'] += re['hit_ratio'] / n_test_users
            result['auc'] += re['auc'] / n_test_users
            result['mrr'] += re['mrr'] / n_test_users

    assert count == n_test_users
    pool.close()
    return result

n_users=19445, n_items=7050
n_interactions=139110
n_train=118551, n_test=20559, sparsity=0.00101


In [5]:
import os
import numpy as np
from time import time

import torch
import torch.nn as nn
import torch.sparse as sparse
import torch.nn.functional as F

#from utility.parser import parse_args
args = parse_args()

def build_knn_neighbourhood(adj, topk):
    knn_val, knn_ind = torch.topk(adj, topk, dim=-1)
    weighted_adjacency_matrix = (torch.zeros_like(adj)).scatter_(-1, knn_ind, knn_val)
    return weighted_adjacency_matrix

def compute_normalized_laplacian(adj):
    if adj.shape[0]==adj.shape[1]:
        rowsum = torch.sum(adj, -1)
        d_inv_sqrt = torch.pow(rowsum, -0.5)
        d_inv_sqrt[torch.isinf(d_inv_sqrt)] = 0.
        d_mat_inv_sqrt = torch.diagflat(d_inv_sqrt)
        L_norm = torch.mm(torch.mm(d_mat_inv_sqrt, adj), d_mat_inv_sqrt)
        return L_norm
    else:
        rowsum = torch.sum(adj, -1)
        d_inv_sqrt = torch.pow(rowsum, -0.5)
        d_inv_sqrt[torch.isinf(d_inv_sqrt)] = 0.
        d_mat_inv_sqrt = torch.diagflat(d_inv_sqrt)
        colsum = torch.sum(adj, 0)
        d_inv_sqrt_ = torch.pow(colsum, -0.5)
        d_inv_sqrt_[torch.isinf(d_inv_sqrt_)] = 0.
        d_mat_inv_sqrt_ = torch.diagflat(d_inv_sqrt_)
        L_norm = torch.mm(torch.mm(d_mat_inv_sqrt, adj), d_mat_inv_sqrt_)
        return L_norm


def recombine(img_feat,text_feat):
    img_feat_norm = img_feat.div(torch.norm(img_feat, p=2, dim=-1, keepdim=True))
    text_feat_norm = text_feat.div(torch.norm(text_feat, p=2, dim=-1, keepdim=True))
    rel_i = F.softmax(img_feat_norm.mm(text_feat_norm.T),dim=-1)
    rel_t = F.softmax(text_feat_norm.mm(img_feat_norm.T),dim=-1)
    _, text_indices = torch.max(rel_i, dim=1)
    text_sorted = text_feat[text_indices]
    _, img_indices = torch.max(rel_t, dim=1)
    img_sorted = img_feat[img_indices]
    return img_sorted,text_sorted


class reliability(nn.Module):
    def __init__(self,input_size, hidden_size):
        super().__init__()
        self.query_0 = nn.Linear(input_size, hidden_size, bias=False)
        self.key_0 = nn.Linear(input_size, hidden_size, bias=False)
        self.query_1 = nn.Linear(input_size, hidden_size, bias=False)
        self.key_1 = nn.Linear(input_size, hidden_size, bias=False)
    def forward(self,image_feats, text_feats):
        image_feats_norm = image_feats.div(torch.norm(image_feats, p=2, dim=-1, keepdim=True))
        text_feats_norm = text_feats.div(torch.norm(text_feats, p=2, dim=-1, keepdim=True))
        rel_i = torch.matmul(self.query_0(image_feats_norm),self.key_0(text_feats_norm).T)
        rel_t = torch.matmul(self.query_1(text_feats_norm),self.key_1(image_feats_norm).T)
        #print(rel_i.shape)
        diag_i = torch.diag(rel_i)
        #print(diag_i.shape)
        diag_t = torch.diag(rel_t)
        return diag_i.reshape((1,-1)),diag_t.reshape((1,-1))

class anomaly_encoder(nn.Module):
    def __init__(self,adj,input_size, hidden_size,encoder_layer=3):
        super().__init__()
        self.encoder_layer=encoder_layer
        self.encoder=nn.ModuleList()
        self.encoder_=nn.ModuleList()
        self.encoder.append(nn.Linear(input_size, hidden_size, bias=False))
        for i in range(encoder_layer-1):
            self.encoder.append(nn.Linear(hidden_size, hidden_size, bias=False))
        self.encoder_.append(nn.Linear(input_size, hidden_size, bias=False))
        for i in range(encoder_layer-1):
            self.encoder_.append(nn.Linear(hidden_size, hidden_size, bias=False))
        
        self.ii_adj=adj
    def forward(self,image_feats_norm, text_feats_norm):
        #self.feat_i_tmp=image_feats_norm
        #self.feat_t_tmp=text_feats_norm
        feat_i_all_embeddings = [image_feats_norm]
        feat_t_all_embeddings = [text_feats_norm]
        for i in range(self.encoder_layer):
            tmp_i_embeddings = F.leaky_relu(self.encoder[i](torch.sparse.mm(self.ii_adj, image_feats_norm)))
            image_feats_norm=tmp_i_embeddings
            feat_i_all_embeddings += [image_feats_norm]
            tmp_t_embeddings = F.leaky_relu(self.encoder_[i](torch.sparse.mm(self.ii_adj, text_feats_norm)))
            text_feats_norm=tmp_t_embeddings
            feat_t_all_embeddings += [text_feats_norm]
        feat_i_all_embeddings = torch.stack(feat_i_all_embeddings, dim=1)
        feat_i_all_embeddings = feat_i_all_embeddings.mean(dim=1, keepdim=False)
        feat_t_all_embeddings = torch.stack(feat_t_all_embeddings, dim=1)
        feat_t_all_embeddings = feat_t_all_embeddings.mean(dim=1, keepdim=False)
        return feat_i_all_embeddings, feat_t_all_embeddings

class anomaly_decoder(nn.Module):
    def __init__(self,adj,input_size, hidden_size,decoder_layer=3):
        super().__init__()
        self.decoder_layer=decoder_layer
        self.decoder=nn.ModuleList()
        self.decoder_=nn.ModuleList()
        for i in range(decoder_layer-1):
            self.decoder.append(nn.Linear(hidden_size, hidden_size, bias=False))
        self.decoder.append(nn.Linear(hidden_size, input_size, bias=False))

        for i in range(decoder_layer-1):
            self.decoder_.append(nn.Linear(hidden_size, hidden_size, bias=False))
        self.decoder_.append(nn.Linear(hidden_size, input_size, bias=False))
        self.ii_adj=adj
    def forward(self,image_enc, text_enc):
        #self.feat_i_tmp=image_enc
        #self.feat_t_tmp=text_enc
        feat_i_all_embeddings = [image_enc]
        feat_t_all_embeddings = [text_enc]
        for i in range(self.decoder_layer):
            tmp_i_embeddings = F.leaky_relu(self.decoder[i](torch.sparse.mm(self.ii_adj, image_enc)))
            image_enc=tmp_i_embeddings
            feat_i_all_embeddings += [image_enc]
            tmp_t_embeddings = F.leaky_relu(self.decoder_[i](torch.sparse.mm(self.ii_adj, text_enc)))
            text_enc=tmp_t_embeddings
            feat_t_all_embeddings += [text_enc]
        feat_i_all_embeddings = torch.stack(feat_i_all_embeddings, dim=1)
        feat_i_all_embeddings = feat_i_all_embeddings.mean(dim=1, keepdim=False)
        feat_t_all_embeddings = torch.stack(feat_t_all_embeddings, dim=1)
        feat_t_all_embeddings = feat_t_all_embeddings.mean(dim=1, keepdim=False)
        return feat_i_all_embeddings,feat_t_all_embeddings
  
def build_reliable_knn_neighbourhood(adj, topk,rel):
    adj=adj.mul(F.sigmoid(rel))
    knn_val, knn_ind = torch.topk(adj, topk, dim=-1)
    weighted_adjacency_matrix = (torch.zeros_like(adj)).scatter_(-1, knn_ind, knn_val)
    return weighted_adjacency_matrix

In [6]:
import torch

def build_sim(context):
    context_norm = context.div(torch.norm(context, p=2, dim=-1, keepdim=True))
    sim = torch.mm(context_norm, context_norm.transpose(1, 0))
    #sim.fill_diagonal_(0.)
    return sim

def build_knn_graph(adj, topk, is_sparse, norm_type):
    adj=adj.to_dense()
    device = adj.device
    knn_val, knn_ind = torch.topk(adj, topk, dim=-1)
    if is_sparse:
        tuple_list = [[row, int(col)] for row in range(len(knn_ind)) for col in knn_ind[row]]
        row = [i[0] for i in tuple_list]
        col = [i[1] for i in tuple_list]
        i = torch.LongTensor([row, col]).to(device)
        v = knn_val.flatten()
        return torch.sparse_coo_tensor(i, v, adj.shape)
    else:
        weighted_adjacency_matrix = (torch.zeros_like(adj)).scatter_(-1, knn_ind, knn_val)
        return weighted_adjacency_matrix

def get_sparse_laplacian(edge_index, edge_weight, num_nodes, normalization='none'):
    def scatter_add(src, index,dim, dim_size):
        output = torch.zeros(dim_size, dtype=src.dtype, device=src.device)
        return output.scatter_add_(dim=dim, index=index, src=src)

    row, col = edge_index[0], edge_index[1]
    deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes)

    if normalization == 'sym':
        deg_inv_sqrt = deg.pow_(-0.5)
        deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float('inf'), 0)
        edge_weight = deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]
    elif normalization == 'rw':
        deg_inv = 1.0 / deg
        deg_inv.masked_fill_(deg_inv == float('inf'), 0)
        edge_weight = deg_inv[row] * edge_weight
    return edge_index, edge_weight


def get_dense_laplacian(adj, normalization='none'):
    if normalization == 'sym':
        rowsum = torch.sum(adj, -1)
        d_inv_sqrt = torch.pow(rowsum, -0.5)
        d_inv_sqrt[torch.isinf(d_inv_sqrt)] = 0.
        d_mat_inv_sqrt = torch.diagflat(d_inv_sqrt)
        L_norm = torch.mm(torch.mm(d_mat_inv_sqrt, adj), d_mat_inv_sqrt)
    elif normalization == 'rw':
        rowsum = torch.sum(adj, -1)
        d_inv = torch.pow(rowsum, -1)
        d_inv[torch.isinf(d_inv)] = 0.
        d_mat_inv = torch.diagflat(d_inv)
        L_norm = torch.mm(d_mat_inv, adj)
    elif normalization == 'none':
        L_norm = adj
    return L_norm

def build_rel_knn_graph(adj, topk, is_sparse, norm_type,rel):
    adj=adj.to_dense()
    adj=adj.mul(F.sigmoid(rel))
    device = adj.device
    knn_val, knn_ind = torch.topk(adj, topk, dim=-1)
    if is_sparse:
        tuple_list = [[row, int(col)] for row in range(len(knn_ind)) for col in knn_ind[row]]
        row = [i[0] for i in tuple_list]
        col = [i[1] for i in tuple_list]
        i = torch.LongTensor([row, col]).to(device)
        v = knn_val.flatten()
        return torch.sparse_coo_tensor(i, v, adj.shape)
    else:
        weighted_adjacency_matrix = (torch.zeros_like(adj)).scatter_(-1, knn_ind, knn_val)
        return weighted_adjacency_matrix
        
def cal_sum_lap(weight,a_sparse_list):
    a_list=[]
    for idx,i in enumerate(a_sparse_list):
        a_list.append(weight[:, idx].unsqueeze(dim=1)*i.to_dense())
    mix_a=torch.stack(a_list).sum(dim=0)
    rowsum = torch.sum(mix_a, -1)
    d_inv_sqrt = torch.pow(rowsum, -0.5)
    d_inv_sqrt[torch.isinf(d_inv_sqrt)] = 0.
    d_mat_inv_sqrt = torch.diagflat(d_inv_sqrt)
    colsum = torch.sum(mix_a, 0)
    d_inv_sqrt_ = torch.pow(colsum, -0.5)
    d_inv_sqrt_[torch.isinf(d_inv_sqrt_)] = 0.
    d_mat_inv_sqrt_ = torch.diagflat(d_inv_sqrt_)
    L_norm = torch.mm(torch.mm(d_mat_inv_sqrt, mix_a), d_mat_inv_sqrt_)
    return L_norm


def get_dense_norm_rowandcol(adj):
    adj=adj.to_dense()
    rowsum = torch.sum(adj, -1)
    d_inv_sqrt = torch.pow(rowsum, -0.5)
    d_inv_sqrt[torch.isinf(d_inv_sqrt)] = 0.
    d_mat_inv_sqrt = torch.diagflat(d_inv_sqrt)
    colsum = torch.sum(adj, 0)
    d_inv_sqrt_ = torch.pow(colsum, -0.5)
    d_inv_sqrt_[torch.isinf(d_inv_sqrt_)] = 0.
    d_mat_inv_sqrt_ = torch.diagflat(d_inv_sqrt_)
    L_norm = torch.mm(torch.mm(d_mat_inv_sqrt, adj), d_mat_inv_sqrt_)
    return L_norm

        
def sparse_mat_merge(A,B):
    A=A.coalesce()
    B=B.coalesce()
    values_A = A.values()
    indices_A = A.indices()
    values_B = B.values()
    indices_B = B.indices()
    new_indices_B = indices_B.clone()
    new_indices_B[0, :] += norm_adj.shape[0]-B.shape[0]  # 行索引偏移
    new_indices_B[1, :] += norm_adj.shape[0]-B.shape[0]  # 列索引偏移
    mask = (indices_A[0] < norm_adj.shape[0]-B.shape[0]) | (indices_A[0] >= norm_adj.shape[0]) | \
           (indices_A[1] < norm_adj.shape[0]-B.shape[0]) | (indices_A[1] >= norm_adj.shape[0])
    values_A = values_A[mask]
    indices_A = indices_A[:, mask]
    new_values = torch.cat([values_A, values_B], dim=0)
    new_indices = torch.cat([indices_A, new_indices_B], dim=1)
    new_A = torch.sparse_coo_tensor(new_indices, new_values, A.size())
    return new_A



class PositiveLinear(nn.Module):
    def __init__(self, in_features, out_features, seed):
        super(PositiveLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.log_weight = nn.Parameter(torch.Tensor(out_features, in_features))
        self.seed = seed
        self.reset_parameters()
    
    def reset_parameters(self):
        torch.manual_seed(self.seed)
        nn.init.xavier_uniform_(self.log_weight)
    
    def forward(self, input):
        return nn.functional.linear(input, self.log_weight.exp())

class NegativeLinear(nn.Module):
    def __init__(self, in_features, out_features, seed):
        super(NegativeLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.log_weight = nn.Parameter(torch.Tensor(out_features, in_features))
        self.seed = seed
        self.reset_parameters()
    
    def reset_parameters(self):
        torch.manual_seed(self.seed)
        nn.init.xavier_uniform_(self.log_weight)
    
    def forward(self, input):
        return nn.functional.linear(input, -self.log_weight.exp())


In [7]:
class MICRO(nn.Module):
    def __init__(self, n_users, n_items, embedding_dim, weight_size, dropout_list, image_feats, text_feats,rating):
        super().__init__()
        self.num_envs = 2
        self.n_users = n_users
        self.n_items = n_items
        self.embedding_dim = embedding_dim
        self.weight_size = weight_size
        self.n_ui_layers = len(self.weight_size)
        self.weight_size = [self.embedding_dim] + self.weight_size
        self.user_embedding = nn.Embedding(n_users, self.embedding_dim)
        self.item_id_embedding = nn.Embedding(n_items, self.embedding_dim).to('cuda')
        nn.init.xavier_uniform_(self.user_embedding.weight)
        nn.init.xavier_uniform_(self.item_id_embedding.weight)
        self.rating=compute_normalized_laplacian(rating.to_dense()).to_sparse().to('cuda')
        
        if args.cf_model == 'ngcf':
            self.GC_Linear_list = nn.ModuleList()
            self.Bi_Linear_list = nn.ModuleList()
            self.dropout_list = nn.ModuleList()
            for i in range(self.n_ui_layers):
                self.GC_Linear_list.append(nn.Linear(self.weight_size[i], self.weight_size[i+1]))
                self.Bi_Linear_list.append(nn.Linear(self.weight_size[i], self.weight_size[i+1]))
                self.dropout_list.append(nn.Dropout(dropout_list[i]))
        
        
        self.image_embedding = nn.Embedding.from_pretrained(torch.Tensor(image_feats), freeze=False).to('cuda')
        self.text_embedding = nn.Embedding.from_pretrained(torch.Tensor(text_feats), freeze=False).to('cuda')
        self.image_feats_norm = self.image_embedding.weight.detach().div(torch.norm(self.image_embedding.weight.detach(), p=2, dim=-1, keepdim=True))
        self.text_feats_norm = self.text_embedding.weight.detach().div(torch.norm(self.text_embedding.weight.detach(), p=2, dim=-1, keepdim=True))

        self.image_sim = build_sim(self.image_embedding.weight.detach())
        self.text_sim = build_sim(self.text_embedding.weight.detach())
        

        #ini_image_adj = build_sim(self.image_embedding.weight.detach()) 
        #ini_image_adj = build_knn_normalized_graph(ini_image_adj, topk=args.topk, is_sparse=args.sparse, norm_type=args.norm_type)
        #ini_text_adj = build_sim(self.text_embedding.weight.detach()) 
        #ini_text_adj = build_knn_normalized_graph(ini_text_adj, topk=args.topk, is_sparse=args.sparse, norm_type=args.norm_type)
        ii=build_knn_graph(torch.sparse.mm(self.rating.T,self.rating).to_dense(), 50, True, norm_type='sym')
        ii=get_dense_norm_rowandcol(ii)
        #self.an_encoder=anomaly_encoder(ini_text_adj.clone(),ini_image_adj.clone(),image_feats.shape[1],args.embed_size).to('cuda')
        #self.an_decoder=anomaly_decoder(ini_text_adj.clone(),ini_image_adj.clone(),image_feats.shape[1],args.embed_size).to('cuda')
        
        self.an_encoder=anomaly_encoder(ii,image_feats.shape[1],image_feats.shape[1]).to('cuda')
        self.an_decoder=anomaly_decoder(ii,image_feats.shape[1],image_feats.shape[1]).to('cuda')
        
        self.image_trs = nn.Sequential(
            nn.Linear(4 * image_feats.shape[1], 2 * args.embed_size),
            nn.Tanh(),
            nn.Linear(2 * self.embedding_dim, args.embed_size, bias=False)
        )
        self.text_trs = nn.Sequential(
            nn.Linear(4 * text_feats.shape[1], 2 * args.embed_size),
            nn.Tanh(),
            nn.Linear(2 * self.embedding_dim, args.embed_size, bias=False)
        )

        self.hybird_trs = nn.Linear(4*text_feats.shape[1], args.embed_size)

        self.softmax = nn.Softmax(dim=-1)


        self.score = nn.Sequential(
            nn.Linear(text_feats.shape[1], self.embedding_dim),
            nn.Tanh(),
            nn.Linear(self.embedding_dim, 1, bias=False)
        )
        '''
        self.mon_score = nn.Sequential(
            PositiveLinear(text_feats.shape[1], self.embedding_dim),
            nn.Tanh(),
            NegativeLinear(self.embedding_dim, 1)
        )
        '''
        self.mon_score_lsit = []
        seed=[1,100,1000,10000,100000]
        for i in range(self.num_envs):
            self.mon_score_lsit.append(nn.Sequential(
                PositiveLinear(text_feats.shape[1], self.embedding_dim,seed[i]),
                nn.Tanh(),
                NegativeLinear(self.embedding_dim, 1,seed[i])
                ).to('cuda'))
        
        self.score_ = nn.Sequential(
            nn.Linear(image_feats.shape[1], self.embedding_dim),
            nn.Tanh(),
            nn.Linear(self.embedding_dim, 1, bias=False)
        )
        self.query = nn.Sequential(
            nn.Linear(self.embedding_dim, self.embedding_dim),
            nn.Tanh(),
            nn.Linear(self.embedding_dim, 1, bias=False)
        )
        self.query_1 = nn.Sequential(
            nn.Linear(self.embedding_dim, self.embedding_dim),
            nn.Tanh(),
            nn.Linear(self.embedding_dim, 1, bias=False)
        )
        self.dropout_list = nn.ModuleList()
        for i in range(self.n_ui_layers):
            self.dropout_list.append(nn.Dropout(dropout_list[i]))

        self.tau = 0.5
        self.modal_weight = nn.Parameter(torch.Tensor([[0.5, 0.5]]))
        self.softmax = nn.Softmax(dim=0)

        self.env_generators = {}
        for k in range(self.num_envs):
            self.env_generators[k] = self.generator_model()
        
    def mm(self, x, y):
        if args.sparse:
            return torch.sparse.mm(x, y)
        else:
            return torch.mm(x, y)
    
    def sim(self, z1, z2):
        z1 = F.normalize(z1)
        z2 = F.normalize(z2)
        return torch.mm(z1, z2.t())

    def batched_contrastive_loss(self, z1, z2, batch_size=256):
        device = z1.device
        num_nodes = z1.size(0)
        num_batches = (num_nodes - 1) // batch_size + 1
        f = lambda x: torch.exp(x / self.tau)
        indices = torch.arange(0, num_nodes).to(device)
        losses = []

        for i in range(num_batches):
            mask = indices[i * batch_size:(i + 1) * batch_size]
            refl_sim = f(self.sim(z1[mask], z1))  # [B, N]
            between_sim = f(self.sim(z1[mask], z2))  # [B, N]

            losses.append(-torch.log(
                between_sim[:, i * batch_size:(i + 1) * batch_size].diag()
                / (refl_sim.sum(1) + between_sim.sum(1)
                   - refl_sim[:, i * batch_size:(i + 1) * batch_size].diag())))
                   
        loss_vec = torch.cat(losses)
        return loss_vec.mean()

    import torch.nn as nn



    def generator_model(self):
        model = nn.Sequential(
            nn.Linear(1, self.embedding_dim),  # 输入维度为self.latent_dim，输出维度为1
            nn.ReLU(),  # 激活函数
            # nn.Dropout(p=0.5),  # 如果需要Dropout层，可以取消注释
            # nn.Linear(int(self.latent_dim / 2), 1),  # 如果需要额外的Dense层，可以取消注释
            # nn.ReLU(),  # 激活函数
            # nn.Dropout(p=0.5)  # 如果需要Dropout层，可以取消注释
            nn.Linear(self.embedding_dim, 1, bias=False)  # 输出层
        )
        return model

    def multi_env(self,rel):
        #logit = self.env_generators[k](rel)
        logit = rel.view(-1)
        # bias = 0.0 + 0.0001
        # eps = (bias - (1 - bias)) * torch.rand(logit.shape) + (1 - bias)
        eps = torch.rand(logit.shape, device=logit.device)
        mask_gate_input = torch.log(eps) - torch.log(1 - eps)
        mask_gate_input = (logit + mask_gate_input) / 0.2
        mask_gate_input = torch.sigmoid(mask_gate_input) + 0.5  # self.edge_bias
        return mask_gate_input

    
    def update_graph(self,image_sim,text_sim,img_rel,text_rel):
        print('time to update graph')
        image_adj = build_rel_knn_graph(image_sim, topk=args.topk, is_sparse=args.sparse, norm_type=args.norm_type,rel=img_rel)
        #print(self.image_a.to_dense())
        text_adj = build_rel_knn_graph(text_sim, topk=args.topk, is_sparse=args.sparse, norm_type=args.norm_type,rel=text_rel)
        text_original_adj = text_adj.cuda()
        image_original_adj = image_adj.cuda()
        self.synthesis_adj=cal_sum_lap(torch.tensor([[1,1]]).cuda(),[image_adj,text_adj])
        return get_dense_norm_rowandcol(image_original_adj),get_dense_norm_rowandcol(text_original_adj),self.synthesis_adj

    def forward(self, adj, step_count,build_item_graph=False):
        text_rel_list = []
        img_rel_list = []
        
        u_g_embeddings_list = []
        i_g_embeddings_list = []
        image_item_embeds_list = []
        text_item_embeds_list = []
        h_list = []
        
        img_latent,text_latent=self.an_encoder(self.image_feats_norm,self.text_feats_norm)
        img_rec,text_rec=self.an_decoder(img_latent,text_latent)
        
        #text_res=text_rec-self.text_feats_norm
        #img_res=img_rec-self.image_feats_norm
        #text_rel = self.mon_score((text_rec-self.text_feats_norm)*(text_rec-self.text_feats_norm))
        #img_rel = self.mon_score((img_rec-self.image_feats_norm)*(img_rec-self.image_feats_norm))
        
        #print('diag',diag_i.shape)
        img_rec_loss = ((img_rec-self.image_feats_norm)*(img_rec-self.image_feats_norm)).sum(dim=-1).reshape((1,-1))
        text_rec_loss = ((text_rec-self.text_feats_norm)*(text_rec-self.text_feats_norm)).sum(dim=-1).reshape((1,-1))

        for i in range(self.num_envs):
            text_rel=self.mon_score_lsit[i]((text_rec-self.text_feats_norm)*(text_rec-self.text_feats_norm))
            #text_rel=self.mon_score_lsit[i](text_rec-self.text_feats_norm)
            img_rel=self.mon_score_lsit[i]((img_rec-self.image_feats_norm)*(img_rec-self.image_feats_norm))
            #img_rel=self.mon_score_lsit[i](img_rec-self.image_feats_norm)
            text_rel_list.append(text_rel)
            img_rel_list.append(img_rel)
        #print(text_rel_list)
        #text_rel_list=self.multi_env(text_rel)
        #img_rel_list=self.multi_env(img_rel)
        
        for k in range(self.num_envs):
            text_rel = self.multi_env(text_rel_list[k])
            img_rel = self.multi_env(img_rel_list[k])
            if step_count%50==0:
                self.image_original_adj,self.text_original_adj,synthesis_adj= self.update_graph(self.image_sim,self.text_sim,
                                                                                                img_rel,text_rel)
            else:
                synthesis_adj=self.synthesis_adj.detach()
    
            image_feats = self.image_trs(torch.cat((img_rec,self.image_embedding.weight,
                                                    self.image_embedding.weight+img_rec,self.image_embedding.weight*img_rec),dim=1))
            text_feats = self.text_trs(torch.cat((text_rec,self.text_embedding.weight,
                                                    self.text_embedding.weight+text_rec,self.text_embedding.weight*text_rec),dim=1))
            if step_count%200==0:
                self.image_adj = build_sim(image_feats) 
                self.image_adj = build_knn_graph(self.image_adj, topk=args.topk, is_sparse=args.sparse, norm_type=args.norm_type)
                self.image_adj = get_dense_norm_rowandcol(self.image_adj)
                self.image_adj = (1 - args.lambda_coeff) * self.image_adj + args.lambda_coeff * self.image_original_adj
                self.text_adj = build_sim(text_feats) 
                self.text_adj = build_knn_graph(self.text_adj, topk=args.topk, is_sparse=args.sparse, norm_type=args.norm_type)
                self.text_adj = get_dense_norm_rowandcol(self.text_adj)
                self.text_adj = (1 - args.lambda_coeff) * self.text_adj + args.lambda_coeff * self.text_original_adj
            else:
                self.image_adj = self.image_adj.detach()
                self.text_adj = self.text_adj.detach()
            image_item_embeds = self.item_id_embedding.weight
            text_item_embeds = self.item_id_embedding.weight
    
            for i in range(args.layers):
                image_item_embeds = self.mm(self.image_adj, image_item_embeds)
                text_item_embeds = self.mm(self.text_adj, text_item_embeds)
    
            att = torch.cat([self.query(image_item_embeds), self.query(text_item_embeds)], dim=-1)
            weight = self.softmax(att)
            h = weight[:, 0].unsqueeze(dim=1) * (image_item_embeds) + weight[:,1].unsqueeze(dim=1) * (text_item_embeds)
            
            uii_i_emb=torch.sparse.mm(self.rating, image_item_embeds)
            uii_t_emb=torch.sparse.mm(self.rating, text_item_embeds)
            image_user_feats=torch.sparse.mm(self.rating, image_feats)
            text_user_feats=torch.sparse.mm(self.rating, text_feats)
            
            att_ = torch.cat([self.query_1(uii_i_emb), self.query_1(uii_t_emb)], dim=-1)
            modal_weight = self.softmax(att_)
            hh=modal_weight[:, 0].unsqueeze(dim=1) * uii_i_emb + modal_weight[:, 1].unsqueeze(dim=1) * uii_t_emb
            
            ego_embeddings_u=self.user_embedding.weight
            ego_embeddings_i=self.item_id_embedding.weight
            user_embeddings=[self.user_embedding.weight]
            item_embeddings=[self.item_id_embedding.weight]
            
            for i in range(self.n_ui_layers):
                side_embeddings_u = torch.sparse.mm(self.rating, ego_embeddings_i)
                side_embeddings_u=self.dropout_list[i](side_embeddings_u)
                
    
                side_embeddings_i = torch.sparse.mm(self.rating.T, ego_embeddings_u) + torch.sparse.mm(synthesis_adj, ego_embeddings_i)
                side_embeddings_i=self.dropout_list[i](side_embeddings_i)
                
                ego_embeddings_u = side_embeddings_u
                user_embeddings += [ego_embeddings_u]
                ego_embeddings_i = side_embeddings_i
                item_embeddings += [ego_embeddings_i]
            user_embeddings = torch.stack(user_embeddings, dim=1)
            user_embeddings = user_embeddings.mean(dim=1, keepdim=False)
            
            item_embeddings = torch.stack(item_embeddings, dim=1)
            item_embeddings = item_embeddings.mean(dim=1, keepdim=False)
            
            i_g_embeddings = item_embeddings
            u_g_embeddings = user_embeddings
            i_g_embeddings = i_g_embeddings + F.normalize(h, p=2, dim=1)
            u_g_embeddings = u_g_embeddings + F.normalize(hh, p=2, dim=1)
            u_g_embeddings_list.append(u_g_embeddings)
            i_g_embeddings_list.append(i_g_embeddings)
            image_item_embeds_list.append(image_item_embeds)
            text_item_embeds_list.append(text_item_embeds)
            h_list.append(h)
        final_u_g_embeddings = torch.stack(u_g_embeddings_list).mean(dim=0)
        final_i_g_embeddings = torch.stack(i_g_embeddings_list).mean(dim=0)
        final_image_item_embeds = torch.stack(image_item_embeds_list).mean(dim=0)
        final_text_item_embeds = torch.stack(text_item_embeds_list).mean(dim=0)
        final_h = torch.stack(h_list).mean(dim=0)
        #return final_u_g_embeddings, final_i_g_embeddings, final_image_item_embeds, final_text_item_embeds,\
        #    final_h, img_rec_loss,text_rec_loss
        return u_g_embeddings_list, i_g_embeddings_list, final_image_item_embeds, final_text_item_embeds,\
            final_h, img_rec_loss,text_rec_loss

In [8]:
import datetime
import math
import os
import random
import sys

from tqdm import tqdm
import time as tm
from time import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.sparse as sparse

#from utility.parser import parse_args
#from Models import LATTICE
#from utility.batch_test import *

args = parse_args()
rec_epoch=2

class Trainer(object):
    def __init__(self, data_config):
        # argument settings
        self.n_users = data_config['n_users']
        self.n_items = data_config['n_items']

        self.model_name = args.model_name
        self.mess_dropout = eval(args.mess_dropout)
        self.lr = args.lr
        self.emb_dim = args.embed_size
        self.batch_size = args.batch_size
        self.weight_size = eval(args.weight_size)
        self.n_layers = len(self.weight_size)
        self.regs = eval(args.regs)
        self.decay = self.regs[0]

        self.norm_adj = data_config['norm_adj']
        self.norm_adj = self.sparse_mx_to_torch_sparse_tensor(self.norm_adj).float().cuda()
        self.rating=data_config['rating']
        self.rating = self.sparse_mx_to_torch_sparse_tensor(self.rating).float().cuda()
        if args.shuffle=='all':
            image_feats = np.load(args.data_path+'{}/clip_shuffled_{}_image_feats.npy'.format(args.dataset,args.p))
            text_feats = np.load(args.data_path+'{}/clip_shuffled_{}_text_feats.npy'.format(args.dataset,args.p))
            print(args.data_path+'{}/clip_shuffled_{}_image_feats.npy'.format(args.dataset,args.p))
            print(args.data_path+'{}/clip_shuffled_{}_text_feats.npy'.format(args.dataset,args.p))
        elif args.shuffle=='text':
            image_feats = np.load(args.data_path+'{}/clip_shuffled_0_image_feats.npy'.format(args.dataset))
            text_feats = np.load(args.data_path+'{}/clip_shuffled_{}_text_feats.npy'.format(args.dataset,args.p))
            print(args.data_path+'{}/clip_shuffled_0_image_feats.npy'.format(args.dataset))
            print(args.data_path+'{}/clip_shuffled_{}_text_feats.npy'.format(args.dataset,args.p))
        elif args.shuffle=='image':
            image_feats = np.load(args.data_path+'{}/clip_shuffled_{}_image_feats.npy'.format(args.dataset,args.p))
            text_feats = np.load(args.data_path+'{}/clip_shuffled_0_text_feats.npy'.format(args.dataset))
            print(args.data_path+'{}/clip_shuffled_{}_image_feats.npy'.format(args.dataset,args.p))
            print(args.data_path+'{}/clip_shuffled_0_text_feats.npy'.format(args.dataset))
        else:
            image_feats = np.load(args.data_path+'{}/clip_shuffled_0_image_feats.npy'.format(args.dataset))
            text_feats = np.load(args.data_path+'{}/clip_shuffled_0_text_feats.npy'.format(args.dataset))
            print(args.data_path+'{}/clip_shuffled_0_image_feats.npy'.format(args.dataset))
            print(args.data_path+'{}/clip_shuffled_0_text_feats.npy'.format(args.dataset))
        self.model = MICRO(self.n_users, self.n_items, self.emb_dim, self.weight_size, self.mess_dropout, image_feats, text_feats, self.rating)
        self.model = self.model.cuda()
        
        self.rec_optimizer = optim.Adam([{'params': self.model.an_encoder.parameters()},{'params': self.model.an_decoder.parameters()}], lr=0.01)
        

    def set_lr_scheduler(self):
        fac = lambda epoch: 0.96 ** (epoch / 50)
        scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=fac)
        return scheduler
    def forzen_ano_train_other(self):
        for param in self.model.an_encoder.parameters():
            param.requires_grad = False
        for param in self.model.an_decoder.parameters():
            param.requires_grad = False
        optimizer = optim.Adam(filter(lambda p: p.requires_grad, self.model.parameters()), lr=self.lr)
        return optimizer
    
    def test(self, users_to_test, is_val):
        self.model.eval()
        with torch.no_grad():
            ua_embeddings, ia_embeddings, *rest = self.model(self.norm_adj,0, build_item_graph=True)
        result = test_torch(torch.stack(ua_embeddings).mean(dim=0), torch.stack(ia_embeddings).mean(dim=0), users_to_test, is_val)
        return result

    def train(self):
        training_time_list = []
        loss_loger, pre_loger, rec_loger, ndcg_loger, hit_loger,mrr_loger = [], [], [], [], [], []
        stopping_step = 0
        should_stop = False
        cur_best_pre_0 = 0.

        n_batch = data_generator.n_train // args.batch_size + 1
        best_recall = 0
        for epoch in (range(args.epoch)):
            if epoch==rec_epoch:
                self.optimizer=self.forzen_ano_train_other()
                self.lr_scheduler = self.set_lr_scheduler()
            t1 = time()
            loss, mf_loss, emb_loss, reg_loss,mf_var_loss = 0., 0., 0., 0.,0.
            l_rec_loss= 0. 
            contrastive_loss = 0.
            n_batch = data_generator.n_train // args.batch_size + 1
            f_time, b_time, loss_time, opt_time, clip_time, emb_time = 0., 0., 0., 0., 0., 0.
            sample_time = 0.
            build_item_graph = True
            for idx in (range(n_batch)):
                #print('batch num ',idx)
                self.model.train()
                torch.autograd.set_detect_anomaly(False)
                if epoch>=rec_epoch:
                    self.optimizer.zero_grad()
                self.rec_optimizer.zero_grad()
                sample_t1 = time()
                users, pos_items, neg_items = data_generator.sample()
                sample_time += time() - sample_t1
                #if idx%50==0:
                ua_embeddings, ia_embeddings, image_item_embeds, text_item_embeds, fusion_embed, img_rel,text_rel = self.model(self.norm_adj,idx, build_item_graph=build_item_graph)
                #else:
                    #ua_embeddings, ia_embeddings, image_item_embeds, text_item_embeds, fusion_embed = self.model(self.norm_adj,idx, build_item_graph=build_item_graph)
                #build_item_graph = False
                #u_g_embeddings = ua_embeddings[users]
                #pos_i_g_embeddings = ia_embeddings[pos_items]
                #neg_i_g_embeddings = ia_embeddings[neg_items]
                mf_loss_tmp_list=[]
                batch_mf_loss = 0
                batch_emb_loss = 0
                batch_reg_loss = 0
                for k in range(len(ua_embeddings)):
                    
                    u_g_embeddings = ua_embeddings[k][users]
                    pos_i_g_embeddings = ia_embeddings[k][pos_items]
                    neg_i_g_embeddings = ia_embeddings[k][neg_items]
                    en_batch_mf_loss, en_batch_emb_loss, en_batch_reg_loss = self.bpr_loss(u_g_embeddings, pos_i_g_embeddings,
                                                                              neg_i_g_embeddings)
                    batch_mf_loss+=en_batch_mf_loss
                    batch_emb_loss+=en_batch_emb_loss
                    batch_reg_loss+=en_batch_reg_loss
                    mf_loss_tmp_list.append(batch_mf_loss)

                
                batch_mf_loss /= len(ua_embeddings)
                batch_emb_loss /= len(ua_embeddings)
                batch_reg_loss /= len(ua_embeddings)
                #print(mf_loss_tmp_list)
                if len(mf_loss_tmp_list)>1:
                    batch_mf_var_loss=torch.var(torch.stack(mf_loss_tmp_list))
                else:
                    batch_mf_var_loss=0.
                #print(batch_mf_var_loss)
                
                batch_contrastive_loss = 0
                batch_contrastive_loss += self.model.batched_contrastive_loss(image_item_embeds,fusion_embed)
                batch_contrastive_loss += self.model.batched_contrastive_loss(text_item_embeds,fusion_embed)

                #batch_contrastive_loss += self.model.batched_contrastive_loss(uii_i_emb,hh)
                #batch_contrastive_loss += self.model.batched_contrastive_loss(uii_t_emb,hh)

                batch_l_rec_loss=(img_rel+text_rel).mean()
                batch_l_rec_loss*=0.1
                batch_contrastive_loss *=  args.loss_ratio
                batch_loss = batch_mf_loss + batch_emb_loss + batch_reg_loss + batch_contrastive_loss + batch_l_rec_loss +batch_mf_var_loss
                if epoch<rec_epoch:
                    batch_l_rec_loss.backward(retain_graph=False)
                    self.rec_optimizer.step()
                else:
                    
                    batch_loss.backward(retain_graph=False)
                    self.optimizer.step()
                
                loss += float(batch_loss)
                mf_loss += float(batch_mf_loss)
                emb_loss += float(batch_emb_loss)
                reg_loss += float(batch_reg_loss)
                mf_var_loss += batch_mf_var_loss
                contrastive_loss += float(batch_contrastive_loss)
                l_rec_loss +=batch_l_rec_loss
                #l_rec_loss+=0
            if epoch>=rec_epoch:
                self.lr_scheduler.step()

            del ua_embeddings, ia_embeddings, u_g_embeddings, neg_i_g_embeddings, pos_i_g_embeddings

            if math.isnan(loss) == True:
                print('ERROR: loss is nan.')
                sys.exit()

            perf_str = 'Epoch %d [%.1fs]: train==[%.5f=%.5f + %.5f + %.5f + %.5f + %.5f + %.5f]' % (
                epoch, time() - t1, loss, mf_loss, emb_loss, reg_loss, contrastive_loss, l_rec_loss,mf_var_loss)
            training_time_list.append(time() - t1)
            print(perf_str)

            if epoch<rec_epoch+5:
                continue
            
            if epoch % args.verbose != 0:
                continue


            t2 = time()
            users_to_test = list(data_generator.test_set.keys())
            users_to_val = list(data_generator.val_set.keys())
            ret = self.test(users_to_val, is_val=True)
            training_time_list.append(t2 - t1)

            t3 = time()

            loss_loger.append(loss)
            rec_loger.append(ret['recall'])
            pre_loger.append(ret['precision'])
            ndcg_loger.append(ret['ndcg'])
            hit_loger.append(ret['hit_ratio'])
            if args.verbose > 0:
                perf_str = 'Epoch %d [%.1fs + %.1fs]:  val==[%.5f=%.5f + %.5f + %.5f], recall=[%.5f, %.5f], ' \
                           'precision=[%.5f, %.5f], hit=[%.5f, %.5f], ndcg=[%.5f, %.5f], mrr=[%.5f, %.5f]' % \
                           (epoch, t2 - t1, t3 - t2, loss, mf_loss, emb_loss, reg_loss, ret['recall'][0],
                            ret['recall'][1],
                            ret['precision'][0], ret['precision'][1], ret['hit_ratio'][0], ret['hit_ratio'][1],
                            ret['ndcg'][0], ret['ndcg'][1],ret['mrr'][0], ret['mrr'][1])
                print(perf_str)

            if ret['recall'][1] > best_recall:
                best_recall = ret['recall'][1]
                test_ret = self.test(users_to_test, is_val=False)
                perf_str = 'Epoch %d [%.1fs + %.1fs]: test==[%.5f=%.5f + %.5f + %.5f], recall=[%.5f, %.5f], ' \
                           'precision=[%.5f, %.5f], hit=[%.5f, %.5f], ndcg=[%.5f, %.5f], mrr=[%.5f, %.5f]' % \
                           (epoch, t2 - t1, t3 - t2, loss, mf_loss, emb_loss, reg_loss, test_ret['recall'][0],
                            test_ret['recall'][1],
                            test_ret['precision'][0], test_ret['precision'][1], test_ret['hit_ratio'][0], test_ret['hit_ratio'][1],
                            test_ret['ndcg'][0], test_ret['ndcg'][1],test_ret['mrr'][0], test_ret['mrr'][1])
                print(perf_str)                
                stopping_step = 0
            elif stopping_step < args.early_stopping_patience:
                stopping_step += 1
                print('#####Early stopping steps: %d #####' % stopping_step)
            else:
                print('#####Early stop! #####')
                break
        print(test_ret)
        print(args.dataset,args.shuffle,args.p,tm.strftime("%a %b %d %H:%M:%S %Y", tm.localtime()),test_ret, file=open('test_ret.txt','a'))
        

    def bpr_loss(self, users, pos_items, neg_items):
        pos_scores = torch.sum(torch.mul(users, pos_items), dim=1)
        neg_scores = torch.sum(torch.mul(users, neg_items), dim=1)

        regularizer = 1./2*(users**2).sum() + 1./2*(pos_items**2).sum() + 1./2*(neg_items**2).sum()
        regularizer = regularizer / self.batch_size

        maxi = F.logsigmoid(pos_scores - neg_scores)
        mf_loss = -torch.mean(maxi)

        emb_loss = self.decay * regularizer
        reg_loss = 0.0
        return mf_loss, emb_loss, reg_loss

    def sparse_mx_to_torch_sparse_tensor(self, sparse_mx):
        """Convert a scipy sparse matrix to a torch sparse tensor."""
        sparse_mx = sparse_mx.tocoo().astype(np.float32)
        indices = torch.from_numpy(
            np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64))
        values = torch.from_numpy(sparse_mx.data)
        shape = torch.Size(sparse_mx.shape)
        return torch.sparse.FloatTensor(indices, values, shape)

def set_seed(seed):
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed) # cpu
    torch.cuda.manual_seed_all(seed)  # gpu


set_seed(args.seed)
os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_id)

config = dict()
config['n_users'] = data_generator.n_users
config['n_items'] = data_generator.n_items

plain_adj, norm_adj, mean_adj = data_generator.get_adj_mat()

rating=data_generator.get_R_mat()

config['norm_adj'] = norm_adj
config['rating'] = rating

trainer = Trainer(data_config=config)
trainer.train()


already load adj matrix (26495, 26495) 0.029147624969482422
already load rating matrix (19445, 7050) 0.005362987518310547


  return torch.sparse.FloatTensor(indices, values, shape)


../baby/clip_shuffled_0_image_feats.npy
../baby/clip_shuffled_0_text_feats.npy
time to update graph
time to update graph
time to update graph
time to update graph
time to update graph
time to update graph
Epoch 0 [22.8s]: train==[464.12085=58.86221 + 0.00187 + 0.00000 + 55.06337 + 350.19333 + 0.00000]
time to update graph
time to update graph
time to update graph
time to update graph
time to update graph
time to update graph
Epoch 1 [24.0s]: train==[127.93129=58.99597 + 0.00187 + 0.00000 + 55.06398 + 13.86947 + 0.00000]
time to update graph
time to update graph
time to update graph
time to update graph
time to update graph
time to update graph
Epoch 2 [28.5s]: train==[122.60332=54.90937 + 0.00197 + 0.00000 + 54.61876 + 13.07323 + 0.00000]
time to update graph
time to update graph
time to update graph
time to update graph
time to update graph
time to update graph
Epoch 3 [29.1s]: train==[115.96183=48.58265 + 0.00224 + 0.00000 + 54.30373 + 13.07323 + 0.00000]
time to update graph
time to