In [None]:
import os
# Set the DGLBACKEND environment variable to 'pytorch' for this process.
# NOTE(review): presumably this must run before `dgl` is imported anywhere — confirm.
os.environ['DGLBACKEND'] = 'pytorch'

In [None]:
from datetime import datetime
class Logger():
    """Minimal logger that echoes messages to stdout and, optionally, to a file.

    Args:
        filename: Name of the log file created inside ``path``.
        is_debug: When true, messages are only printed and nothing is written to disk.
        path: Directory that holds the log file. It is created on first write
            if it does not exist yet.
    """

    def __init__(self, filename, is_debug, path='/root/autodl-fs/model/mmssl_log/'):
        self.filename = filename
        self.path = path
        # File logging is enabled only outside debug mode.
        self.log_ = not is_debug

    def logging(self, s):
        """Print ``s`` with a timestamp and append it to the log file when enabled."""
        s = str(s)
        # Take the timestamp once so the console and the file agree.
        now = datetime.now()
        print(now.strftime('%Y-%m-%d %H:%M: '), s)
        if self.log_:
            # Opening in append mode fails if the directory is missing, so create it.
            if self.path:
                os.makedirs(self.path, exist_ok=True)
            with open(os.path.join(self.path, self.filename), 'a+') as f_log:
                f_log.write(now.strftime('%Y-%m-%d %H:%M:  ') + s + '\n')

In [None]:
import numpy as np
from sklearn.metrics import roc_auc_score

def recall(rank, ground_truth, N):
    """Fraction of distinct ground-truth items that appear in the first ``N`` of ``rank``.

    Returns 0.0 when ``ground_truth`` is empty instead of dividing by zero,
    matching how ``recall_at_k`` treats a user without positives.
    """
    truth = set(ground_truth)
    if not truth:
        return 0.
    return len(set(rank[:N]) & truth) / float(len(truth))


def precision_at_k(r, k):
    """Precision @ k: mean relevance of the first ``k`` entries of ``r``.

    Relevance is binary (nonzero is relevant).
    If ``r`` has fewer than ``k`` entries, the mean is taken over what is there.
    Returns:
        Precision @ k
    Raises:
        AssertionError: if k < 1
    """
    assert k >= 1
    top_k = np.asarray(r)[:k]
    return np.mean(top_k)


def average_precision(r, cut=None):
    """Score is average precision (area under PR curve), truncated at ``cut``.

    Relevance is binary (nonzero is relevant).
    Args:
        r: Relevance scores in rank order.
        cut: Number of leading positions to consider. Defaults to ``len(r)``.
            A value larger than ``len(r)`` is allowed; only existing positions
            are scanned.
    Returns:
        Average precision, or 0.0 when no relevant item is found within ``cut``.
    """
    r = np.asarray(r)
    if cut is None:
        cut = r.size
    # Never index past the end of r, even if the caller asks for a deeper cut.
    depth = min(cut, r.size)
    head = r[:depth]
    hits = head != 0
    if not hits.any():
        return 0.
    # precision@(k+1) for every position k, computed in one pass.
    precisions = np.cumsum(head) / np.arange(1, depth + 1)
    # Normalise by the number of relevant items, capped at the cut-off.
    return np.sum(precisions[hits])/float(min(cut, np.sum(r)))


def mean_average_precision(rs, cut=None):
    """Score is mean average precision

    Relevance is binary (nonzero is relevant).
    Args:
        rs: Iterable of relevance lists, one per query/user.
        cut: Optional cut-off forwarded to ``average_precision``; by default
            each list is scored over its full length.
    Returns:
        Mean average precision
    """
    return np.mean([average_precision(r, cut) for r in rs])


def dcg_at_k(r, k, method=1):
    """Score is discounted cumulative gain (dcg)

    Relevance is positive real values.  Can use binary
    as the previous methods.
    Args:
        r: Relevance scores in rank order.
        k: Number of leading results to consider.
        method: 0 weights positions as [1.0, 1.0, 0.6309, ...] (first item
            undiscounted); 1 weights them as [1.0, 0.6309, 0.5, ...].
    Returns:
        Discounted cumulative gain (0.0 for an empty list)
    Raises:
        ValueError: if ``method`` is not 0 or 1 and ``r`` is non-empty
    """
    # np.asfarray was removed in NumPy 2.0; this is the documented replacement.
    r = np.asarray(r, dtype=float)[:k]
    if r.size:
        if method == 0:
            return r[0] + np.sum(r[1:] / np.log2(np.arange(2, r.size + 1)))
        elif method == 1:
            return np.sum(r / np.log2(np.arange(2, r.size + 2)))
        else:
            raise ValueError('method must be 0 or 1.')
    return 0.


def ndcg_at_k(r, k, method=1):
    """Normalized discounted cumulative gain (ndcg) of ``r`` at cut-off ``k``.

    The ideal ordering is obtained by sorting ``r`` itself in descending
    order, so the normaliser only knows about the relevances present in ``r``.
    Returns:
        DCG of ``r`` divided by the DCG of the ideal ordering, or 0.0 when the
        ideal DCG is zero.
    """
    ideal = dcg_at_k(sorted(r, reverse=True), k, method)
    return dcg_at_k(r, k, method) / ideal if ideal else 0.


def recall_at_k(r, k, all_pos_num):
    """Recall @ k: relevance mass in the first ``k`` entries over ``all_pos_num``.

    Args:
        r: Relevance scores in rank order (binary in this notebook).
        k: Number of leading results to consider.
        all_pos_num: Total number of positive items for the user.
    Returns:
        ``sum(r[:k]) / all_pos_num``, or 0 when ``all_pos_num`` is 0.
    """
    if all_pos_num == 0:
        return 0
    # np.asfarray was removed in NumPy 2.0; this is the documented replacement.
    top_k = np.asarray(r, dtype=float)[:k]
    return np.sum(top_k) / all_pos_num


def hit_at_k(r, k):
    """Hit ratio @ k: 1.0 if the first ``k`` relevances sum to more than zero, else 0.0."""
    top_k = np.array(r)[:k]
    return 1. if np.sum(top_k) > 0 else 0.

def mrr_at_k(r, k):
    """Reciprocal rank of the first relevant entry among the first ``k`` of ``r``.

    Any positive relevance counts as a hit (the previous ``r == 1.0`` test
    raised IndexError for positive values other than exactly 1).
    Returns:
        ``1 / rank`` of the first hit (1-based), or 0.0 when there is none.
    """
    r = np.array(r)[:k]
    if np.sum(r) > 0:
        # A positive sum guarantees at least one positive entry.
        first_hit = np.flatnonzero(r > 0)[0]
        return 1 / float(first_hit + 1)
    else:
        return 0.


def F1(pre, rec):
    """Harmonic mean of precision ``pre`` and recall ``rec``; 0.0 unless their sum is positive."""
    total = pre + rec
    if not total > 0:
        return 0.
    return (2.0 * pre * rec) / total

def auc(ground_truth, prediction):
    """ROC AUC of ``prediction`` against ``ground_truth`` via ``roc_auc_score``.

    Any exception raised while scoring is swallowed and 0.0 is returned
    instead.
    """
    try:
        return roc_auc_score(y_true=ground_truth, y_score=prediction)
    except Exception:
        return 0.

In [None]:
def parse_args():
    """Return the notebook's hyper-parameter bundle.

    This stands in for a command-line parser: every option is a class
    attribute on ``Args`` with a hard-coded default. Several options
    (``Ks``, ``regs``, ``weight_size``, ``mess_dropout``) are stored as
    strings holding Python list literals.

    Returns:
        A fresh ``Args`` instance whose ``repr`` lists all options.
    """
    class Args:
        p=0
        dataset = 'beauty'
        verbose = 5
        core = 5
        lambda_coeff = 0.9
        early_stopping_patience = 7
        layers = 1
        mess_dropout = '[0.1, 0.1]'
        sparse = 1
        test_flag = 'part'
        metapath_threshold = 2
        sc = 1.0
        ssl_c_rate = 1.3
        ssl_s_rate = 0.8
        g_rate = 0.000029
        sample_num = 1
        sample_num_neg = 1
        sample_num_ii = 8
        sample_num_co = 2
        mask_rate = 0.75
        gss_rate = 0.85
        anchor_rate = 0.75
        feat_reg_decay = 1e-5
        ad1_rate = 0.2
        ad2_rate = 0.2
        ad_sampNum = 1
        ad_topk_multi_num = 100
        fake_gene_rate = 0.0001
        ID_layers = 1
        reward_rate = 1
        G_embed_size = 64
        model_num = 2
        negrate = 0.01
        cis = 25
        confidence = 0.5
        ii_it = 15
        isload = False
        isJustTest = False
        loadModelPath = '/home/ww/Code/work3/BSTRec/Model/retailrocket/for_meta_hidden_dim_dim__8_retailrocket_2021_07_10__18_35_32_lr_0.0003_reg_0.01_batch_size_1024_gnn_layer_[16,16,16].pth'
        title = "try_to_draw_line"
        data_path = './'
        seed = 123
        epoch = 1000
        batch_size = 128
        embed_size = 64
        D_lr = 3e-4
        topk = 10
        cf_model = 'slmrec'
        debug = False
        cl_rate = 0.03
        norm_type = 'sym'
        gpu_id = 0
        Ks = '[10, 20, 50, 5]'
        regs = '[1e-5,1e-5,1e-2]'
        lr = 0.00055
        emm = 1e-3
        L2_alpha = 1e-3
        weight_decay = 1e-4
        drop_rate = 0.2
        model_cat_rate = 0.55
        gnn_cat_rate = 0.55
        id_cat_rate = 0.36
        id_cat_rate1 = 0.36
        head_num = 4
        dgl_nei_num = 8
        weight_size = '[64, 64]'
        G_rate = 0.0001
        G_drop1 = 0.31
        G_drop2 = 0.5
        gp_rate = 1
        real_data_tau = 0.005
        ui_pre_scale = 100
        T = 1
        tau = 0.5
        geneGraph_rate = 0.1
        geneGraph_rate_pos = 2
        geneGraph_rate_neg = -1
        m_topk_rate = 0.0001
        log_log_scale = 0.00001
        point = ''
        # test_flag was previously assigned a second time here with the same value.
        shuffle='text'

        def __repr__(self):
            # Show every option so logging str(args) records the configuration
            # rather than just the object's memory address.
            options = {k: v for k, v in vars(type(self)).items() if not k.startswith('_')}
            # Per-instance overrides take precedence over class defaults.
            options.update(vars(self))
            return 'Args(' + ', '.join('%s=%r' % kv for kv in sorted(options.items())) + ')'
    return Args()


In [None]:
import numpy as np
import random as rd
import scipy.sparse as sp
from time import time
import json
#from utility.parser import parse_args
args = parse_args()

class Data(object):
    """Loads the train/val/test interaction splits and samples training triples.

    Each split is a JSON object mapping a user id (string key) to a list of
    integer item ids. The files are read from ``<path>/<core>-core/`` where
    ``core`` comes from the module-level ``args``.

    Attributes set by ``__init__``:
        n_users, n_items: max id seen + 1.
        n_train, n_test, n_val: interaction counts per split.
        exist_users: users with at least one training interaction.
        R: user x item training interaction matrix (scipy DOK, float32).
        train_items, test_set, val_set: dict user id -> list of item ids.
        my_test_set: flat list of ``[user, item]`` test pairs.
    """

    def __init__(self, path, batch_size):
        # Directory of the k-core split; adjacency caches are stored here too,
        # so it must be the same directory the JSON files are read from.
        self.path = path + '/%d-core' % args.core
        self.batch_size = batch_size

        train_file = self.path + '/train.json'
        val_file = self.path + '/val.json'
        test_file = self.path + '/test.json'

        #get number of users and items
        self.n_users, self.n_items = 0, 0
        self.n_train, self.n_test, self.n_val = 0, 0, 0
        self.neg_pools = {}

        self.exist_users = []

        with open(train_file) as f:
            train = json.load(f)
        with open(test_file) as f:
            test = json.load(f)
        with open(val_file) as f:
            val = json.load(f)

        for uid, items in train.items():
            if len(items) == 0:
                continue
            uid = int(uid)
            self.exist_users.append(uid)
            self.n_items = max(self.n_items, max(items))
            self.n_users = max(self.n_users, uid)
            self.n_train += len(items)

        # Users with no held-out items are skipped (max() of an empty list would fail).
        for items in test.values():
            if not items:
                continue
            self.n_items = max(self.n_items, max(items))
            self.n_test += len(items)

        for items in val.values():
            if not items:
                continue
            self.n_items = max(self.n_items, max(items))
            self.n_val += len(items)

        # Ids are zero-based, so counts are max id + 1.
        self.n_items += 1
        self.n_users += 1

        self.print_statistics()

        self.R = sp.dok_matrix((self.n_users, self.n_items), dtype=np.float32)
        self.R_Item_Interacts = sp.dok_matrix((self.n_items, self.n_items), dtype=np.float32)

        self.train_items, self.test_set, self.val_set = {}, {}, {}
        for uid, train_items in train.items():
            if len(train_items) == 0:
                continue
            uid = int(uid)
            for i in train_items:
                self.R[uid, i] = 1.

            self.train_items[uid] = train_items

        self.my_test_set=[]
        for uid, test_items in test.items():
            uid = int(uid)
            if len(test_items) == 0:
                continue
            for i in test_items:
                self.my_test_set.append([uid,i])
            self.test_set[uid] = test_items

        for uid, val_items in val.items():
            uid = int(uid)
            if len(val_items) == 0:
                continue
            self.val_set[uid] = val_items

    def get_adj_mat(self):
        """Return ``(adj, norm_adj, mean_adj)``, loading cached ``.npz`` files if possible.

        If loading fails for any reason the matrices are rebuilt with
        ``create_adj_mat`` and written to the cache files under ``self.path``.
        """
        try:
            t1 = time()
            adj_mat = sp.load_npz(self.path + '/s_adj_mat.npz')
            norm_adj_mat = sp.load_npz(self.path + '/s_norm_adj_mat.npz')
            mean_adj_mat = sp.load_npz(self.path + '/s_mean_adj_mat.npz')
            print('already load adj matrix', adj_mat.shape, time() - t1)

        except Exception:
            adj_mat, norm_adj_mat, mean_adj_mat = self.create_adj_mat()
            sp.save_npz(self.path + '/s_adj_mat.npz', adj_mat)
            sp.save_npz(self.path + '/s_norm_adj_mat.npz', norm_adj_mat)
            sp.save_npz(self.path + '/s_mean_adj_mat.npz', mean_adj_mat)
        return adj_mat, norm_adj_mat, mean_adj_mat

    def create_adj_mat(self):
        """Build the bipartite user-item adjacency matrix and two row-normalised variants.

        Returns:
            Tuple of CSR matrices of shape ``(n_users + n_items, n_users + n_items)``:
            the raw adjacency, the row-normalised adjacency with self-loops,
            and the row-normalised adjacency without self-loops.
        """
        t1 = time()
        adj_mat = sp.dok_matrix((self.n_users + self.n_items, self.n_users + self.n_items), dtype=np.float32)
        adj_mat = adj_mat.tolil()
        R = self.R.tolil()

        # Users occupy the first n_users rows/cols, items the rest.
        adj_mat[:self.n_users, self.n_users:] = R
        adj_mat[self.n_users:, :self.n_users] = R.T
        adj_mat = adj_mat.todok()
        print('already create adjacency matrix', adj_mat.shape, time() - t1)

        t2 = time()

        def normalized_adj_single(adj):
            # D^-1 * A, with rows of zero degree left as zeros.
            rowsum = np.array(adj.sum(1))

            d_inv = np.power(rowsum, -1).flatten()
            d_inv[np.isinf(d_inv)] = 0.
            d_mat_inv = sp.diags(d_inv)

            norm_adj = d_mat_inv.dot(adj)
            print('generate single-normalized adjacency matrix.')
            return norm_adj.tocoo()

        norm_adj_mat = normalized_adj_single(adj_mat + sp.eye(adj_mat.shape[0]))
        mean_adj_mat = normalized_adj_single(adj_mat)

        print('already normalize adjacency matrix', time() - t2)
        return adj_mat.tocsr(), norm_adj_mat.tocsr(), mean_adj_mat.tocsr()


    def sample(self):
        """Draw one training batch of ``(users, pos_items, neg_items)``.

        Users are sampled without replacement when enough training users
        exist for a full batch, otherwise with replacement. Each user gets one
        item from their training set and one item outside it.
        """
        # Compare against the users we can actually sample from: n_users is
        # max id + 1 and may exceed the number of users with training data.
        if self.batch_size <= len(self.exist_users):
            users = rd.sample(self.exist_users, self.batch_size)
        else:
            users = [rd.choice(self.exist_users) for _ in range(self.batch_size)]

        def sample_pos_items_for_u(u, num):
            # Rejection-sample `num` distinct training items of user u.
            pos_items = self.train_items[u]
            n_pos_items = len(pos_items)
            pos_batch = []
            while True:
                if len(pos_batch) == num: break
                pos_id = np.random.randint(low=0, high=n_pos_items, size=1)[0]
                pos_i_id = pos_items[pos_id]

                if pos_i_id not in pos_batch:
                    pos_batch.append(pos_i_id)
            return pos_batch

        def sample_neg_items_for_u(u, num):
            # Rejection-sample `num` distinct items user u has not interacted with in training.
            neg_items = []
            while True:
                if len(neg_items) == num: break
                neg_id = np.random.randint(low=0, high=self.n_items, size=1)[0]
                if neg_id not in self.train_items[u] and neg_id not in neg_items:
                    neg_items.append(neg_id)
            return neg_items

        pos_items, neg_items = [], []
        for u in users:
            pos_items += sample_pos_items_for_u(u, 1)
            neg_items += sample_neg_items_for_u(u, 1)
        return users, pos_items, neg_items


    def print_statistics(self):
        """Print user/item counts, train/test interaction counts and sparsity."""
        print('n_users=%d, n_items=%d' % (self.n_users, self.n_items))
        print('n_interactions=%d' % (self.n_train + self.n_test))
        print('n_train=%d, n_test=%d, sparsity=%.5f' % (self.n_train, self.n_test, (self.n_train + self.n_test)/(self.n_users * self.n_items)))


In [None]:
#import utility.metrics as metrics
#from utility.parser import parse_args
#from utility.load_data import Data
import multiprocessing
import heapq
import torch
import pickle
import numpy as np
from time import time

# Number of worker processes used for evaluation: a fifth of the CPUs, but
# never zero (multiprocessing.Pool rejects a process count below one).
cores = max(1, multiprocessing.cpu_count() // 5)

args = parse_args()
# NOTE(review): eval() of a config string; safe only while args.Ks stays a
# hard-coded literal. Prefer ast.literal_eval if it ever comes from user input.
Ks = eval(args.Ks)
print(args.data_path + args.dataset)
# Shared dataset object read by the evaluation helpers defined in this cell.
data_generator = Data(path=args.data_path + args.dataset, batch_size=args.batch_size)
USR_NUM, ITEM_NUM = data_generator.n_users, data_generator.n_items
N_TRAIN, N_TEST = data_generator.n_train, data_generator.n_test
BATCH_SIZE = args.batch_size

def ranklist_by_heapq(user_pos_test, test_items, rating, Ks):
    """Rank candidate items by score and mark which of the top ``max(Ks)`` are positives.

    Args:
        user_pos_test: Held-out positive items of the user.
        test_items: Candidate item ids to rank.
        rating: Indexable of scores, looked up by item id.
        Ks: Cut-offs; only the largest one determines the list length.
    Returns:
        ``(r, auc)`` where ``r`` is the 0/1 hit list in rank order and ``auc``
        is always 0.0 (AUC is not computed on this path).
    """
    scores = {item: rating[item] for item in test_items}
    top_items = heapq.nlargest(max(Ks), scores, key=scores.get)
    hits = [1 if item in user_pos_test else 0 for item in top_items]
    return hits, 0.

def get_auc(item_score, user_pos_test):
    """AUC of the scores in ``item_score`` against membership in ``user_pos_test``.

    Args:
        item_score: dict mapping item id -> predicted score.
        user_pos_test: Held-out positive items of the user.
    Returns:
        The value returned by ``auc`` for the 0/1 labels and scores, both
        ordered by descending score.
    """
    # Ascending sort followed by a reversal, which fixes the order among ties.
    ranked = sorted(item_score.items(), key=lambda kv: kv[1])[::-1]
    labels = [1 if item in user_pos_test else 0 for item, _ in ranked]
    scores = [score for _, score in ranked]
    return auc(ground_truth=labels, prediction=scores)

def ranklist_by_sorted(user_pos_test, test_items, rating, Ks):
    """Like ``ranklist_by_heapq`` but also computes AUC over all candidates.

    Returns:
        ``(r, auc)`` where ``r`` is the 0/1 hit list for the top ``max(Ks)``
        items and ``auc`` comes from ``get_auc`` on every candidate score.
    """
    scores = {item: rating[item] for item in test_items}
    top_items = heapq.nlargest(max(Ks), scores, key=scores.get)
    hits = [int(item in user_pos_test) for item in top_items]
    return hits, get_auc(scores, user_pos_test)

def get_performance(user_pos_test, r, auc, Ks):
    """Compute all ranking metrics for one user at every cut-off in ``Ks``.

    Args:
        user_pos_test: Held-out positive items of the user (only its length is used).
        r: 0/1 hit list in rank order.
        auc: Pre-computed AUC value, passed through unchanged.
        Ks: Iterable of cut-offs.
    Returns:
        dict with numpy arrays (one entry per K) under 'recall', 'precision',
        'ndcg', 'hit_ratio', 'mrr', plus the scalar 'auc'.
    """
    n_pos = len(user_pos_test)
    return {
        'recall': np.array([recall_at_k(r, K, n_pos) for K in Ks]),
        'precision': np.array([precision_at_k(r, K) for K in Ks]),
        'ndcg': np.array([ndcg_at_k(r, K) for K in Ks]),
        'hit_ratio': np.array([hit_at_k(r, K) for K in Ks]),
        'auc': auc,
        'mrr': np.array([mrr_at_k(r, K) for K in Ks]),
    }


def test_one_user(x):
    """Evaluate one user; designed to be mapped over a worker pool.

    Args:
        x: Sequence ``(rating, uid, ..., is_val)``: the user's score vector
            over all items, the user id, and (last element) whether to score
            against the validation set instead of the test set.
    Returns:
        The metric dict produced by ``get_performance``.
    """
    rating, u = x[0], x[1]
    is_val = x[-1]

    # Items the user interacted with in training are excluded from the candidates.
    try:
        training_items = data_generator.train_items[u]
    except Exception:
        training_items = []

    held_out = data_generator.val_set if is_val else data_generator.test_set
    user_pos_test = held_out[u]

    test_items = list(set(range(ITEM_NUM)) - set(training_items))

    # 'part' skips the AUC computation; anything else ranks and scores AUC.
    ranker = ranklist_by_heapq if args.test_flag == 'part' else ranklist_by_sorted
    r, auc = ranker(user_pos_test, test_items, rating, Ks)

    return get_performance(user_pos_test, r, auc, Ks)

In [None]:
def test_torch(ua_embeddings, ia_embeddings, users_to_test, is_val, drop_flag=False, batch_test_flag=False):
    """Score every test user against all items and average the ranking metrics.

    Args:
        ua_embeddings: User embedding tensor, indexed by user id.
        ia_embeddings: Item embedding tensor, indexed by item id.
        users_to_test: Sequence of user ids to evaluate.
        is_val: Forwarded to ``test_one_user``; selects validation vs. test positives.
        drop_flag: Unused; kept for interface compatibility.
        batch_test_flag: When true, user-item scores are computed in item
            chunks of ``BATCH_SIZE`` instead of one matmul per user batch.
    Returns:
        dict of metrics averaged over ``users_to_test`` (arrays of
        ``len(Ks)`` for precision/recall/ndcg/hit_ratio/mrr, scalar auc).
    """
    result = {'precision': np.zeros(len(Ks)), 'recall': np.zeros(len(Ks)), 'ndcg': np.zeros(len(Ks)),
              'hit_ratio': np.zeros(len(Ks)),'mrr': np.zeros(len(Ks)), 'auc': 0.}

    u_batch_size = BATCH_SIZE * 2
    i_batch_size = BATCH_SIZE

    test_users = users_to_test
    n_test_users = len(test_users)
    count = 0

    # Pool() needs at least one process even if `cores` was computed as 0.
    pool = multiprocessing.Pool(max(1, cores))
    try:
        # Stepping by batch size avoids a trailing empty batch.
        for start in range(0, n_test_users, u_batch_size):
            user_batch = test_users[start: start + u_batch_size]
            u_g_embeddings = ua_embeddings[user_batch]

            if batch_test_flag:
                # Score item chunks as tensors and concatenate; the result is
                # converted to numpy once, below, like the unbatched path.
                chunks = []
                for i_start in range(0, ITEM_NUM, i_batch_size):
                    i_end = min(i_start + i_batch_size, ITEM_NUM)
                    i_g_embeddings = ia_embeddings[i_start:i_end]
                    chunks.append(torch.matmul(u_g_embeddings, torch.transpose(i_g_embeddings, 0, 1)))
                rate_batch = torch.cat(chunks, dim=1)
                assert rate_batch.shape[1] == ITEM_NUM
            else:
                item_batch = range(ITEM_NUM)
                i_g_embeddings = ia_embeddings[item_batch]
                rate_batch = torch.matmul(u_g_embeddings, torch.transpose(i_g_embeddings, 0, 1))

            rate_batch = rate_batch.detach().cpu().numpy()

            user_batch_rating_uid = zip(rate_batch, user_batch, [is_val] * len(user_batch))

            batch_result = pool.map(test_one_user, user_batch_rating_uid)
            count += len(batch_result)

            for re in batch_result:
                result['precision'] += re['precision'] / n_test_users
                result['recall'] += re['recall'] / n_test_users
                result['ndcg'] += re['ndcg'] / n_test_users
                result['hit_ratio'] += re['hit_ratio'] / n_test_users
                result['auc'] += re['auc'] / n_test_users
                result['mrr'] += re['mrr'] / n_test_users
    finally:
        # Release worker processes even when evaluation raises.
        pool.close()
        pool.join()

    assert count == n_test_users
    return result

In [None]:
import torch
import numpy as np
from scipy.sparse import csr_matrix 

def build_sim(context):
    """Pairwise cosine similarity between the rows of ``context``.

    Each row is divided by its L2 norm and the normalised matrix is
    multiplied by its own transpose via ``torch.sparse.mm``.
    """
    row_norms = torch.norm(context, p=2, dim=-1, keepdim=True)
    unit_rows = context / row_norms
    return torch.sparse.mm(unit_rows, torch.transpose(unit_rows, 0, 1))

# def build_knn_normalized_graph(adj, topk, is_sparse, norm_type):
#     device = adj.device
#     knn_val, knn_ind = torch.topk(adj, topk, dim=-1)
#     if is_sparse:
#         tuple_list = [[row, int(col)] for row in range(len(knn_ind)) for col in knn_ind[row]]
#         row = [i[0] for i in tuple_list]
#         col = [i[1] for i in tuple_list]
#         i = torch.LongTensor([row, col]).to(device)
#         v = knn_val.flatten()
#         edge_index, edge_weight = get_sparse_laplacian(i, v, normalization=norm_type, num_nodes=adj.shape[0])
#         return torch.sparse_coo_tensor(edge_index, edge_weight, adj.shape)
#     else:
#         weighted_adjacency_matrix = (torch.zeros_like(adj)).scatter_(-1, knn_ind, knn_val)
#         return get_dense_laplacian(weighted_adjacency_matrix, normalization=norm_type)

def build_knn_normalized_graph(adj, topk, is_sparse, norm_type):
    """Keep the ``topk`` largest entries per row of ``adj`` as a k-NN graph.

    Args:
        adj: 2-D similarity tensor (one row per node).
        topk: Number of neighbours kept per row.
        is_sparse: If true, return an unweighted, unnormalised scipy CSR
            matrix with 1.0 at every kept (row, neighbour) position;
            ``norm_type`` is ignored on this path.
        norm_type: Normalisation passed to ``get_dense_laplacian`` on the dense path.
    Returns:
        ``scipy.sparse.csr_matrix`` of shape ``(n, n)`` when ``is_sparse``,
        otherwise the normalised dense tensor.
    """
    knn_val, knn_ind = torch.topk(adj, topk, dim=-1)
    if is_sparse:
        n_item, k = knn_ind.shape
        # Row i is repeated k times to pair with its k neighbour columns;
        # built in bulk rather than converting tensor elements one by one.
        row = np.repeat(np.arange(n_item), k)
        col = knn_ind.detach().cpu().numpy().reshape(-1)
        data = np.ones(n_item * k)
        return csr_matrix((data, (row, col)), shape=(n_item, n_item))
    else:
        weighted_adjacency_matrix = (torch.zeros_like(adj)).scatter_(-1, knn_ind, knn_val)
        return get_dense_laplacian(weighted_adjacency_matrix, normalization=norm_type)


def get_sparse_laplacian(edge_index, edge_weight, num_nodes, normalization='none'):
    """Normalise edge weights of a graph given in COO form.

    Args:
        edge_index: Integer tensor of shape ``(2, E)`` holding (row, col) indices.
        edge_weight: Tensor of shape ``(E,)`` with the edge weights.
        num_nodes: Number of nodes; size of the degree vector.
        normalization: 'sym' for D^-1/2 A D^-1/2, 'rw' for D^-1 A; any other
            value leaves the weights unchanged.
    Returns:
        ``(edge_index, edge_weight)`` with the normalised weights.
    """
    row, col = edge_index[0], edge_index[1]
    # Weighted out-degree per node, accumulated with torch's native scatter-add
    # (no dependency on the optional torch_scatter package).
    deg = torch.zeros(num_nodes, dtype=edge_weight.dtype, device=edge_weight.device)
    deg.scatter_add_(0, row, edge_weight)

    if normalization == 'sym':
        deg_inv_sqrt = deg.pow(-0.5)
        # Isolated nodes have zero degree; map the resulting inf to 0.
        deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float('inf'), 0)
        edge_weight = deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]
    elif normalization == 'rw':
        deg_inv = 1.0 / deg
        deg_inv.masked_fill_(deg_inv == float('inf'), 0)
        edge_weight = deg_inv[row] * edge_weight
    return edge_index, edge_weight


def get_dense_laplacian(adj, normalization='none'):
    """Normalise a dense square adjacency matrix.

    Args:
        adj: 2-D tensor.
        normalization: 'sym' for D^-1/2 A D^-1/2, 'rw' for D^-1 A, 'none' to
            return ``adj`` unchanged. D is the diagonal of row sums.
    Returns:
        The normalised matrix.
    Raises:
        ValueError: if ``normalization`` is not one of the supported names.
    """
    if normalization == 'sym':
        rowsum = torch.sum(adj, -1)
        d_inv_sqrt = torch.pow(rowsum, -0.5)
        # Rows summing to zero would give inf; treat them as isolated.
        d_inv_sqrt[torch.isinf(d_inv_sqrt)] = 0.
        d_mat_inv_sqrt = torch.diagflat(d_inv_sqrt)
        L_norm = torch.mm(torch.mm(d_mat_inv_sqrt, adj), d_mat_inv_sqrt)
    elif normalization == 'rw':
        rowsum = torch.sum(adj, -1)
        d_inv = torch.pow(rowsum, -1)
        d_inv[torch.isinf(d_inv)] = 0.
        d_mat_inv = torch.diagflat(d_inv)
        L_norm = torch.mm(d_mat_inv, adj)
    elif normalization == 'none':
        L_norm = adj
    else:
        # Previously fell through and failed with an unbound local variable.
        raise ValueError("normalization must be 'sym', 'rw' or 'none'.")
    return L_norm

In [None]:
import os
import numpy as np
from time import time
import pickle
import scipy.sparse as sp
from scipy.sparse import csr_matrix

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init

#from utility.parser import parse_args
#from utility.norm import build_sim, build_knn_normalized_graph
args = parse_args()

class MMSSL(nn.Module):
    """Multi-modal recommender: ID embeddings fused with projected image/text item features.

    Hyper-parameters that are not constructor arguments (``embed_size``,
    ``drop_rate``, ``head_num``, ``layers``, ``sparse``, ``model_cat_rate``,
    ``id_cat_rate``) are read from the module-level ``args``.

    Args:
        n_users: Number of users (rows of the user ID embedding).
        n_items: Number of items (rows of the item ID embedding).
        embedding_dim: Width of the ID embeddings.
        weight_size: List whose length sets the number of user-item propagation layers.
        dropout_list: Accepted but not used by this class.
        image_feats: Array of raw image features, one row per item.
        text_feats: Array of raw text features, one row per item.
    """

    def __init__(self, n_users, n_items, embedding_dim, weight_size, dropout_list, image_feats, text_feats):

        super().__init__()
        self.n_users = n_users
        self.n_items = n_items
        self.embedding_dim = embedding_dim
        self.weight_size = weight_size
        self.n_ui_layers = len(self.weight_size)
        self.weight_size = [self.embedding_dim] + self.weight_size

        # Linear projections of the raw modality features into the embedding space.
        self.image_trans = nn.Linear(image_feats.shape[1], args.embed_size)
        self.text_trans = nn.Linear(text_feats.shape[1], args.embed_size)
        nn.init.xavier_uniform_(self.image_trans.weight)
        nn.init.xavier_uniform_(self.text_trans.weight)
        self.encoder = nn.ModuleDict()
        self.encoder['image_encoder'] = self.image_trans
        self.encoder['text_encoder'] = self.text_trans

        self.common_trans = nn.Linear(args.embed_size, args.embed_size)
        nn.init.xavier_uniform_(self.common_trans.weight)
        self.align = nn.ModuleDict()
        self.align['common_trans'] = self.common_trans

        self.user_id_embedding = nn.Embedding(n_users, self.embedding_dim)
        self.item_id_embedding = nn.Embedding(n_items, self.embedding_dim)

        nn.init.xavier_uniform_(self.user_id_embedding.weight)
        nn.init.xavier_uniform_(self.item_id_embedding.weight)
        # Raw features are kept as plain CUDA tensors (not parameters or buffers).
        self.image_feats = torch.tensor(image_feats).float().cuda()
        self.text_feats = torch.tensor(text_feats).float().cuda()
        self.image_embedding = nn.Embedding.from_pretrained(torch.Tensor(image_feats), freeze=False)
        self.text_embedding = nn.Embedding.from_pretrained(torch.Tensor(text_feats), freeze=False)

        self.softmax = nn.Softmax(dim=-1)
        self.act = nn.Sigmoid()
        self.sigmoid = nn.Sigmoid()
        self.dropout = nn.Dropout(p=args.drop_rate)
        self.batch_norm = nn.BatchNorm1d(args.embed_size)
        # Temperature used by batched_contrastive_loss.
        self.tau = 0.5

        initializer = nn.init.xavier_uniform_
        self.weight_dict = nn.ParameterDict({
            'w_q': nn.Parameter(initializer(torch.empty([args.embed_size, args.embed_size]))),
            'w_k': nn.Parameter(initializer(torch.empty([args.embed_size, args.embed_size]))),
            'w_v': nn.Parameter(initializer(torch.empty([args.embed_size, args.embed_size]))),
            'w_self_attention_item': nn.Parameter(initializer(torch.empty([args.embed_size, args.embed_size]))),
            'w_self_attention_user': nn.Parameter(initializer(torch.empty([args.embed_size, args.embed_size]))),
            'w_self_attention_cat': nn.Parameter(initializer(torch.empty([args.head_num*args.embed_size, args.embed_size]))),
        })
        # Per-modality ID-side embeddings, refreshed on every forward pass.
        self.embedding_dict = {'user':{}, 'item':{}}

    def mm(self, x, y):
        """Matrix product ``x @ y``, using the sparse kernel when ``args.sparse`` is set."""
        if args.sparse:
            return torch.sparse.mm(x, y)
        else:
            return torch.mm(x, y)

    def sim(self, z1, z2):
        """Cosine similarity matrix between the rows of ``z1`` and the rows of ``z2``."""
        z1 = F.normalize(z1)
        z2 = F.normalize(z2)
        return torch.mm(z1, z2.t())

    def batched_contrastive_loss(self, z1, z2, batch_size=4096):
        """Contrastive loss between two views, computed over row blocks of ``batch_size``.

        For each row i the positive is ``(z1[i], z2[i])``; the denominator
        sums exp(sim / tau) over all rows of both views, excluding the
        self-similarity of ``z1[i]`` with itself.

        Returns:
            Mean loss over all rows.
        """
        device = z1.device
        num_nodes = z1.size(0)
        num_batches = (num_nodes - 1) // batch_size + 1
        f = lambda x: torch.exp(x / self.tau)
        indices = torch.arange(0, num_nodes).to(device)
        losses = []

        for i in range(num_batches):
            mask = indices[i * batch_size:(i + 1) * batch_size]
            refl_sim = f(self.sim(z1[mask], z1))
            between_sim = f(self.sim(z1[mask], z2))

            # The diagonals of the current column block pair each row with itself.
            losses.append(-torch.log(
                between_sim[:, i * batch_size:(i + 1) * batch_size].diag()
                / (refl_sim.sum(1) + between_sim.sum(1)
                   - refl_sim[:, i * batch_size:(i + 1) * batch_size].diag())))

        loss_vec = torch.cat(losses)
        return loss_vec.mean()

    def csr_norm(self, csr_mat, mean_flag=False):
        """Degree-normalise a scipy sparse matrix.

        Returns ``D_r^-1/2 A D_c^-1/2`` by default, or only the row scaling
        ``D_r^-1/2 A`` when ``mean_flag`` is true. A small epsilon is added
        to the row/column sums before taking the power.
        """
        rowsum = np.array(csr_mat.sum(1))
        rowsum = np.power(rowsum+1e-8, -0.5).flatten()
        rowsum[np.isinf(rowsum)] = 0.
        rowsum_diag = sp.diags(rowsum)

        colsum = np.array(csr_mat.sum(0))
        colsum = np.power(colsum+1e-8, -0.5).flatten()
        colsum[np.isinf(colsum)] = 0.
        colsum_diag = sp.diags(colsum)

        if mean_flag == False:
            return rowsum_diag*csr_mat*colsum_diag
        else:
            return rowsum_diag*csr_mat

    def matrix_to_tensor(self, cur_matrix):
        """Convert a scipy sparse matrix to a float32 sparse COO tensor on the CUDA device."""
        if type(cur_matrix) != sp.coo_matrix:
            cur_matrix = cur_matrix.tocoo()
        indices = torch.from_numpy(np.vstack((cur_matrix.row, cur_matrix.col)).astype(np.int64))
        values = torch.from_numpy(cur_matrix.data)
        shape = torch.Size(cur_matrix.shape)
        return torch.sparse_coo_tensor(indices, values, shape, dtype=torch.float32, device='cuda')

    def para_dict_to_tenser(self, para_dict):
        """
        Stack the values of a dict of equally-shaped tensors along a new first axis.

        :param para_dict: nn.ParameterDict() or plain dict of tensors
        :return: tensor of shape (len(para_dict), ...) in key iteration order
        """
        tensors = []

        for beh in para_dict.keys():
            tensors.append(para_dict[beh])
        tensors = torch.stack(tensors, dim=0)

        return tensors


    def multi_head_self_attention(self, trans_w, embedding_t_1, embedding_t):
        """Multi-head attention across the entries (modalities) of an embedding dict.

        Args:
            trans_w: Mapping providing the 'w_q' and 'w_k' projection matrices.
            embedding_t_1: dict of tensors used as keys and values.
            embedding_t: dict of tensors used as queries.
        Returns:
            ``(Z, att)``: the attended embeddings projected by
            ``w_self_attention_cat``, with one slice per entry of the dict,
            and the detached attention weights.
        """
        q = self.para_dict_to_tenser(embedding_t)
        v = k = self.para_dict_to_tenser(embedding_t_1)
        beh, N, d_h = q.shape[0], q.shape[1], args.embed_size/args.head_num

        Q = torch.matmul(q, trans_w['w_q'])
        K = torch.matmul(k, trans_w['w_k'])
        V = v

        # Split the feature axis into heads: (head, beh, N, d_h).
        Q = Q.reshape(beh, N, args.head_num, int(d_h)).permute(2, 0, 1, 3)
        # Keys must come from the key projection; this line previously
        # reshaped Q again, so 'w_k' never influenced the attention weights.
        K = K.reshape(beh, N, args.head_num, int(d_h)).permute(2, 0, 1, 3)

        Q = torch.unsqueeze(Q, 2)
        K = torch.unsqueeze(K, 1)
        V = torch.unsqueeze(V, 1)

        att = torch.mul(Q, K) / torch.sqrt(torch.tensor(d_h))
        att = torch.sum(att, dim=-1)
        att = torch.unsqueeze(att, dim=-1)
        att = F.softmax(att, dim=2)

        # NOTE(review): V broadcasts along att's query axis (dim 1), not the
        # softmax axis (dim 2), so the sum below looks like it reduces to V
        # itself — confirm whether that is intended.
        Z = torch.mul(att, V)
        Z = torch.sum(Z, dim=2)

        # Concatenate the heads along the feature axis, then project back.
        Z_list = [value for value in Z]
        Z = torch.cat(Z_list, -1)
        Z = torch.matmul(Z, self.weight_dict['w_self_attention_cat'])

        # A statement here used to compute args.model_cat_rate*F.normalize(Z, p=2, dim=2)
        # and discard the result; it had no effect and was removed.
        return Z, att.detach()

    def forward(self, ui_graph, iu_graph, image_ui_graph, image_iu_graph, text_ui_graph, text_iu_graph):
        """Compute user and item representations.

        Args:
            ui_graph, iu_graph: user->item and item->user propagation matrices.
            image_ui_graph, image_iu_graph: image-modality counterparts.
            text_ui_graph, text_iu_graph: text-modality counterparts.
            (presumably sparse tensors when ``args.sparse`` is set — confirm with caller)
        Returns:
            Tuple of 12 tensors: final user/item embeddings, image/text item
            features, image/text user features, the final user/item
            embeddings again, and the image/text ID-side embeddings for
            users and items.
        """
        image_feats = image_item_feats = self.dropout(self.image_trans(self.image_feats))
        text_feats = text_item_feats = self.dropout(self.text_trans(self.text_feats))

        # Propagate modality features and ID embeddings over the graphs.
        for i in range(args.layers):
            image_user_feats = self.mm(ui_graph, image_feats)
            image_item_feats = self.mm(iu_graph, image_user_feats)
            image_user_id = self.mm(image_ui_graph, self.item_id_embedding.weight)
            image_item_id = self.mm(image_iu_graph, self.user_id_embedding.weight)

            text_user_feats = self.mm(ui_graph, text_feats)
            text_item_feats = self.mm(iu_graph, text_user_feats)
            text_user_id = self.mm(text_ui_graph, self.item_id_embedding.weight)
            text_item_id = self.mm(text_iu_graph, self.user_id_embedding.weight)

        self.embedding_dict['user']['image'] = image_user_id
        self.embedding_dict['user']['text'] = text_user_id
        self.embedding_dict['item']['image'] = image_item_id
        self.embedding_dict['item']['text'] = text_item_id
        user_z, _ = self.multi_head_self_attention(self.weight_dict, self.embedding_dict['user'], self.embedding_dict['user'])
        item_z, _ = self.multi_head_self_attention(self.weight_dict, self.embedding_dict['item'], self.embedding_dict['item'])
        # Average over modalities, then add to the ID embeddings as a scaled unit vector.
        user_emb = user_z.mean(0)
        item_emb = item_z.mean(0)
        u_g_embeddings = self.user_id_embedding.weight + args.id_cat_rate*F.normalize(user_emb, p=2, dim=1)
        i_g_embeddings = self.item_id_embedding.weight + args.id_cat_rate*F.normalize(item_emb, p=2, dim=1)

        user_emb_list = [u_g_embeddings]
        item_emb_list = [i_g_embeddings]
        for i in range(self.n_ui_layers):
            if i == (self.n_ui_layers-1):
                # Only the last propagation layer applies a softmax.
                u_g_embeddings = self.softmax( torch.mm(ui_graph, i_g_embeddings) )
                i_g_embeddings = self.softmax( torch.mm(iu_graph, u_g_embeddings) )

            else:
                u_g_embeddings = torch.mm(ui_graph, i_g_embeddings)
                i_g_embeddings = torch.mm(iu_graph, u_g_embeddings)

            user_emb_list.append(u_g_embeddings)
            item_emb_list.append(i_g_embeddings)

        # Mean over the input embedding and every layer output.
        u_g_embeddings = torch.mean(torch.stack(user_emb_list), dim=0)
        i_g_embeddings = torch.mean(torch.stack(item_emb_list), dim=0)


        u_g_embeddings = u_g_embeddings + args.model_cat_rate*F.normalize(image_user_feats, p=2, dim=1) + args.model_cat_rate*F.normalize(text_user_feats, p=2, dim=1)
        i_g_embeddings = i_g_embeddings + args.model_cat_rate*F.normalize(image_item_feats, p=2, dim=1) + args.model_cat_rate*F.normalize(text_item_feats, p=2, dim=1)

        return u_g_embeddings, i_g_embeddings, image_item_feats, text_item_feats, image_user_feats, text_user_feats, u_g_embeddings, i_g_embeddings, image_user_id, text_user_id, image_item_id, text_item_id



class Discriminator(nn.Module):
    """MLP discriminator that produces one score per input row.

    Layout: two ``Linear -> LeakyReLU -> BatchNorm1d -> Dropout`` stages that
    shrink the width to ``dim/4`` and then ``dim/8``, followed by
    ``Linear -> Sigmoid``. The dropout rates are read from the module-level
    ``args`` (``G_drop1`` and ``G_drop2``).

    Args:
        dim: size of every input vector (the Trainer passes the number of items).
    """
    def __init__(self, dim):
        super(Discriminator, self).__init__()

        hidden_1 = int(dim/4)
        hidden_2 = int(dim/8)

        # The layer order (and therefore the Sequential indices used in
        # state_dict keys) is unchanged.
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_1),
            # LeakyReLU's first positional argument is negative_slope, not
            # inplace. The former ``nn.LeakyReLU(True)`` therefore used a slope
            # of 1.0, which makes the activation an identity map. Ask for the
            # in-place variant explicitly and keep the default slope (0.01).
            # NOTE(review): this changes the discriminator's behaviour compared
            # with runs made before the fix.
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm1d(hidden_1),
            nn.Dropout(args.G_drop1),

            nn.Linear(hidden_1, hidden_2),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm1d(hidden_2),
            nn.Dropout(args.G_drop2),

            nn.Linear(hidden_2, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Score a batch of vectors.

        Args:
            x: tensor whose last dimension is ``dim``; it is cast to float32.
        Returns:
            1-D tensor with one score per row; the sigmoid output is scaled by
            100, so every score lies in [0, 100].
        """
        output = 100*self.net(x.float())
        return output.view(-1)


In [None]:
from datetime import datetime
import math
import os
import random
import sys
from time import time
from tqdm import tqdm

import pickle
import numpy as np
import scipy.sparse as sp
from scipy.sparse import csr_matrix
#import  visdom

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.sparse as sparse
from torch import autograd


import copy

In [None]:
import time as tm


#from utility.parser import parse_args
#from Models import MMSSL, Discriminator
#from utility.batch_test import *
#from utility.logging import Logger
#from utility.norm import build_sim, build_knn_normalized_graph
#from torch.utils.tensorboard import SummaryWriter


# Parse the run configuration once. The Discriminator and Trainer definitions
# in this notebook read this module-level `args` object directly.
args = parse_args()


class Trainer(object):
    def __init__(self, data_config):
        """Load features and the training graph, then build models and optimizers.

        All hyper-parameters and paths come from the module-level ``args``.

        Args:
            data_config: accepted for interface compatibility; it is not read here.
        """
        self.task_name = "%s_%s_%s" % (datetime.now().strftime('%Y-%m-%d %H:%M:%S'), args.dataset, args.cf_model,)
        self.logger = Logger(filename=self.task_name, is_debug=args.debug)
        self.logger.logging("PID: %d" % os.getpid())
        self.logger.logging(str(args))

        # NOTE(review): eval() executes whatever expression the argument strings
        # contain. That is acceptable for a locally launched experiment, but
        # ast.literal_eval would be the safe choice if these ever come from
        # untrusted input.
        self.mess_dropout = eval(args.mess_dropout)
        self.lr = args.lr
        self.emb_dim = args.embed_size
        self.batch_size = args.batch_size
        self.weight_size = eval(args.weight_size)
        self.n_layers = len(self.weight_size)
        self.regs = eval(args.regs)
        self.decay = self.regs[0]

        # args.shuffle selects which modality is loaded from the file tagged
        # with args.p; every other case falls back to the "_0_" file.
        image_level = args.p if args.shuffle in ('all', 'image') else 0
        text_level = args.p if args.shuffle in ('all', 'text') else 0
        image_feat_path = './{}/clip_shuffled_{}_image_feats.npy'.format(args.dataset, image_level)
        text_feat_path = './{}/clip_shuffled_{}_text_feats.npy'.format(args.dataset, text_level)
        self.image_feats = np.load(image_feat_path)
        self.text_feats = np.load(text_feat_path)
        print(image_feat_path)
        print(text_feat_path)

        self.image_feat_dim = self.image_feats.shape[-1]
        self.text_feat_dim = self.text_feats.shape[-1]

        # Close the pickle file deterministically instead of leaking the handle.
        # NOTE(review): pickle.load runs arbitrary code from the file; only load
        # training matrices from a trusted source.
        train_mat_path = args.data_path + args.dataset + '/%d-core' % args.core + '/train_mat'
        with open(train_mat_path, 'rb') as train_mat_file:
            self.ui_graph = self.ui_graph_raw = pickle.load(train_mat_file)

        # Dense user-item / item-user copies on the GPU; train() later reassigns
        # these *_tmp attributes when it rebuilds the modality graphs.
        self.image_ui_graph_tmp = self.text_ui_graph_tmp = torch.tensor(self.ui_graph_raw.todense()).cuda()
        self.image_iu_graph_tmp = self.text_iu_graph_tmp = torch.tensor(self.ui_graph_raw.T.todense()).cuda()
        # Accumulators for (user, item) pairs of the modality-specific graphs.
        self.image_ui_index = {'x':[], 'y':[]}
        self.text_ui_index = {'x':[], 'y':[]}
        self.n_users = self.ui_graph.shape[0]
        self.n_items = self.ui_graph.shape[1]
        self.iu_graph = self.ui_graph.T
        # Row-normalised sparse tensors used for message passing.
        self.ui_graph = self.matrix_to_tensor(self.csr_norm(self.ui_graph, mean_flag=True))
        self.iu_graph = self.matrix_to_tensor(self.csr_norm(self.iu_graph, mean_flag=True))
        # Both modality graphs start out as the interaction graph.
        self.image_ui_graph = self.text_ui_graph = self.ui_graph
        self.image_iu_graph = self.text_iu_graph = self.iu_graph

        self.model = MMSSL(self.n_users, self.n_items, self.emb_dim, self.weight_size, self.mess_dropout, self.image_feats, self.text_feats)
        self.model = self.model.cuda()
        # The discriminator consumes one n_items-wide row per user.
        self.D = Discriminator(self.n_items).cuda()
        self.D.apply(self.weights_init)
        self.optim_D = optim.Adam(self.D.parameters(), lr=args.D_lr, betas=(0.5, 0.9))

        # Despite the "_D" suffix, this optimizer and its scheduler drive the
        # recommender (self.model); self.optim_D above drives the discriminator.
        self.optimizer_D = optim.AdamW(
            [
                {'params': self.model.parameters()},
            ],
            lr=self.lr)
        self.scheduler_D = self.set_lr_scheduler()


    def set_lr_scheduler(self):
        fac = lambda epoch: 0.96 ** (epoch / 50)
        scheduler_D = optim.lr_scheduler.LambdaLR(self.optimizer_D, lr_lambda=fac)
        return scheduler_D  

    def csr_norm(self, csr_mat, mean_flag=False):
        rowsum = np.array(csr_mat.sum(1))
        rowsum = np.power(rowsum+1e-8, -0.5).flatten()
        rowsum[np.isinf(rowsum)] = 0.
        rowsum_diag = sp.diags(rowsum)

        colsum = np.array(csr_mat.sum(0))
        colsum = np.power(colsum+1e-8, -0.5).flatten()
        colsum[np.isinf(colsum)] = 0.
        colsum_diag = sp.diags(colsum)

        if mean_flag == False:
            return rowsum_diag*csr_mat*colsum_diag
        else:
            return rowsum_diag*csr_mat

    def matrix_to_tensor(self, cur_matrix):
        if type(cur_matrix) != sp.coo_matrix:
            cur_matrix = cur_matrix.tocoo()  #
        indices = torch.from_numpy(np.vstack((cur_matrix.row, cur_matrix.col)).astype(np.int64))  #
        values = torch.from_numpy(cur_matrix.data)  #
        shape = torch.Size(cur_matrix.shape)

        return torch.sparse.FloatTensor(indices, values, shape).to(torch.float32).cuda()  #

    def innerProduct(self, u_pos, i_pos, u_neg, j_neg):  
        pred_i = torch.sum(torch.mul(u_pos,i_pos), dim=-1) 
        pred_j = torch.sum(torch.mul(u_neg,j_neg), dim=-1)  
        return pred_i, pred_j

    def sampleTrainBatch_dgl(self, batIds, pos_id=None, g=None, g_neg=None, sample_num=None, sample_num_neg=None):
        """Sample a fixed number of out-neighbours per user from DGL graph(s).

        Args:
            batIds: user ids used as seed nodes of node type ``'user'``.
            pos_id: unused.
            g: graph to sample from; it is moved to the CPU first.
            g_neg: optional second graph to sample from (used as given).
            sample_num: neighbours drawn per user from ``g``.
            sample_num_neg: neighbours drawn per user from ``g_neg``.
        Returns:
            ``(row, col)`` reshaped to ``(len(batIds), sample_num)`` when
            ``g_neg`` is None, otherwise ``(row, col, col_neg)`` with ``col_neg``
            reshaped to ``(len(batIds), sample_num_neg)``.
        """
        # NOTE(review): the reshapes assume every seed user contributes exactly
        # `sample_num` edges (sampling uses replace=True) -- confirm that no
        # seed user is without out-edges.
        sub_g = dgl.sampling.sample_neighbors(g.cpu(), {'user':batIds}, sample_num, edge_dir='out', replace=True)
        row, col = sub_g.edges()
        row = row.reshape(len(batIds), sample_num)
        col = col.reshape(len(batIds), sample_num)

        # Identity check: `== None` would go through the graph object's __eq__.
        if g_neg is None:
            return row, col

        sub_g_neg = dgl.sampling.sample_neighbors(g_neg, {'user':batIds}, sample_num_neg, edge_dir='out', replace=True)
        row_neg, col_neg = sub_g_neg.edges()
        col_neg = col_neg.reshape(len(batIds), sample_num_neg)
        return row, col, col_neg

    def weights_init(self, m):
        if isinstance(m, nn.Linear):
            nn.init.kaiming_normal_(m.weight)
            m.bias.data.fill_(0)

    def gradient_penalty(self, D, xr, xf):

        LAMBDA = 0.3

        xf = xf.detach()
        xr = xr.detach()

        alpha = torch.rand(args.batch_size*2, 1).cuda()
        alpha = alpha.expand_as(xr)

        interpolates = alpha * xr + ((1 - alpha) * xf)
        interpolates.requires_grad_()

        disc_interpolates = D(interpolates)

        gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates,
                                grad_outputs=torch.ones_like(disc_interpolates),
                                create_graph=True, retain_graph=True, only_inputs=True)[0]

        gp = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * LAMBDA

        return gp

    def weighted_sum(self, anchor, nei, co):  

        ac = torch.multiply(anchor, co).sum(-1).sum(-1)  
        nc = torch.multiply(nei, co).sum(-1).sum(-1)  

        an = (anchor.permute(1, 0, 2)[0])
        ne = (nei.permute(1, 0, 2)[0])

        an_w = an*(ac.unsqueeze(-1).repeat(1, args.embed_size))
        ne_w = ne*(nc.unsqueeze(-1).repeat(1, args.embed_size))                                     
  
        res = (args.anchor_rate*an_w + (1-args.anchor_rate)*ne_w).reshape(-1, args.sample_num_ii, args.embed_size).sum(1)

        return res


    def sample_topk(self, u_sim, users, emb_type=None):
        """Sample ``args.ad_topk`` items per user from their highest-scoring candidates.

        Args:
            u_sim: score matrix with one row per entry of ``users`` and one
                column per item.
            users: user ids of the batch.
            emb_type: key under which the sampled item ids are stored in
                ``self.gene_fake``.
        Returns:
            (topk_p, topk_id): scores from ``u_sim`` at the sampled items and
            the sampled item ids, both ``(len(users), args.ad_topk)``.
        """
        # Use one candidate count for both the top-k call and the row ids. The
        # previous code took `ad_topk*10` candidates but repeated each user
        # `ad_topk*ad_topk_multi_num` times, which breaks csr_matrix whenever
        # the multiplier is not 10. Clamp so topk never asks for more columns
        # than u_sim has.
        n_candidates = min(args.ad_topk*args.ad_topk_multi_num, u_sim.size(-1))

        topk_p, topk_id = torch.topk(u_sim, n_candidates, dim=-1)
        topk_data = topk_p.reshape(-1).cpu()
        topk_col = topk_id.reshape(-1).cpu().int()
        topk_row = torch.tensor(np.array(users)).unsqueeze(1).repeat(1, n_candidates).reshape(-1).int()
        topk_csr = csr_matrix((topk_data.detach().numpy(), (topk_row.detach().numpy(), topk_col.detach().numpy())), shape=(self.n_users, self.n_items))
        # nonzero() keeps only candidates whose stored score is not exactly zero.
        topk_g = dgl.heterograph({('user','ui','item'):topk_csr.nonzero()})
        _, topk_id = self.sampleTrainBatch_dgl(users, g=topk_g, sample_num=args.ad_topk, pos_id=None, g_neg=None, sample_num_neg=None)
        self.gene_fake[emb_type] = topk_id

        # u_sim rows are batch positions, not user ids, hence the arange.
        topk_id_u = torch.arange(len(users)).unsqueeze(1).repeat(1, args.ad_topk)
        topk_p = u_sim[topk_id_u, topk_id]
        return topk_p, topk_id

    def ssl_loss_calculation(self, ssl_image_logit, ssl_text_logit, ssl_common_logit):
        ssl_label_1_s2 = torch.ones(1, self.n_items).cuda()
        ssl_label_0_s2 = torch.zeros(1, self.n_items).cuda()
        ssl_label_s2 = torch.cat((ssl_label_1_s2, ssl_label_0_s2), 1)
        ssl_image_s2 = self.bce(ssl_image_logit, ssl_label_s2)
        ssl_text_s2 = self.bce(ssl_text_logit, ssl_label_s2)
        ssl_loss_s2 = ssl_image_s2 + ssl_text_s2

        ssl_label_1_c2 = torch.ones(1, self.n_items*2).cuda()
        ssl_label_0_c2 = torch.zeros(1, self.n_items*2).cuda()
        ssl_label_c2 = torch.cat((ssl_label_1_c2, ssl_label_0_c2), 1)
        ssl_result_c2 = self.bce(ssl_common_logit, ssl_label_c2)  
        ssl_loss_c2 = ssl_result_c2

        ssl_loss2 = args.ssl_s_rate*ssl_loss_s2 + args.ssl_c_rate*ssl_loss_c2 
        return ssl_loss2


    def sim(self, z1, z2):
        z1 = F.normalize(z1)  
        z2 = F.normalize(z2)
        # z1 = z1/((z1**2).sum(-1) + 1e-8)
        # z2 = z2/((z2**2).sum(-1) + 1e-8)
        return torch.mm(z1, z2.t())

    def batched_contrastive_loss(self, z1, z2, batch_size=1024):

        device = z1.device
        num_nodes = z1.size(0)
        num_batches = (num_nodes - 1) // batch_size + 1
        f = lambda x: torch.exp(x / args.tau)   #       

        indices = torch.arange(0, num_nodes).to(device)
        losses = []

        for i in range(num_batches):
            tmp_i = indices[i * batch_size:(i + 1) * batch_size]

            tmp_refl_sim_list = []
            tmp_between_sim_list = []
            for j in range(num_batches):
                tmp_j = indices[j * batch_size:(j + 1) * batch_size]
                tmp_refl_sim = f(self.sim(z1[tmp_i], z1[tmp_j]))  
                tmp_between_sim = f(self.sim(z1[tmp_i], z2[tmp_j]))  

                tmp_refl_sim_list.append(tmp_refl_sim)
                tmp_between_sim_list.append(tmp_between_sim)

            refl_sim = torch.cat(tmp_refl_sim_list, dim=-1)
            between_sim = torch.cat(tmp_between_sim_list, dim=-1)

            losses.append(-torch.log(between_sim[:, i * batch_size:(i + 1) * batch_size].diag()/ (refl_sim.sum(1) + between_sim.sum(1) - refl_sim[:, i * batch_size:(i + 1) * batch_size].diag())+1e-8))

            del refl_sim, between_sim, tmp_refl_sim_list, tmp_between_sim_list
                   
        loss_vec = torch.cat(losses)
        return loss_vec.mean()


    def feat_reg_loss_calculation(self, g_item_image, g_item_text, g_user_image, g_user_text):
        feat_reg = 1./2*(g_item_image**2).sum() + 1./2*(g_item_text**2).sum() \
            + 1./2*(g_user_image**2).sum() + 1./2*(g_user_text**2).sum()        
        feat_reg = feat_reg / self.n_items
        feat_emb_loss = args.feat_reg_decay * feat_reg
        return feat_emb_loss


    def fake_gene_loss_calculation(self, u_emb, i_emb, emb_type=None):
        if self.gene_u!=None:
            gene_real_loss = (-F.logsigmoid((u_emb[self.gene_u]*i_emb[self.gene_real]).sum(-1)+1e-8)).mean()
            gene_fake_loss = (1-(-F.logsigmoid((u_emb[self.gene_u]*i_emb[self.gene_fake[emb_type]]).sum(-1)+1e-8))).mean()

            gene_loss = gene_real_loss + gene_fake_loss
        else:
            gene_loss = 0

        return gene_loss

    def reward_loss_calculation(self, users, re_u, re_i, topk_id, topk_p):
        self.gene_u = torch.tensor(np.array(users)).unsqueeze(1).repeat(1, args.ad_topk)
        reward_u = re_u[self.gene_u]
        reward_i = re_i[topk_id]
        reward_value = (reward_u*reward_i).sum(-1)

        reward_loss = -(((topk_p*reward_value).sum(-1)).mean()+1e-8).log()
        
        return reward_loss



    def u_sim_calculation(self, users, user_final, item_final):
        topk_u = user_final[users]
        u_ui = torch.tensor(self.ui_graph_raw[users].todense()).cuda()

        num_batches = (self.n_items - 1) // args.batch_size + 1
        indices = torch.arange(0, self.n_items).cuda()
        u_sim_list = []

        for i_b in range(num_batches):
            index = indices[i_b * args.batch_size:(i_b + 1) * args.batch_size]
            sim = torch.mm(topk_u, item_final[index].T)
            sim_gt = torch.multiply(sim, (1-u_ui[:, index]))
            u_sim_list.append(sim_gt)
                
        u_sim = F.normalize(torch.cat(u_sim_list, dim=-1), p=2, dim=1)   
        return u_sim


    def test(self, users_to_test, is_val):
        """Evaluate the current model on the given users.

        Runs the model in eval mode without gradients and passes the first two
        outputs (user and item embeddings) to ``test_torch``.

        Args:
            users_to_test: user ids to evaluate.
            is_val: forwarded unchanged to ``test_torch``.
        Returns:
            Whatever ``test_torch`` returns.
        """
        self.model.eval()
        with torch.no_grad():
            outputs = self.model(self.ui_graph, self.iu_graph, self.image_ui_graph, self.image_iu_graph, self.text_ui_graph, self.text_iu_graph)
        user_embeddings, item_embeddings = outputs[0], outputs[1]
        return test_torch(user_embeddings, item_embeddings, users_to_test, is_val)

    def _discriminator_step(self, users):
        """Run one optimisation step of the discriminator ``self.D``.

        The "generated" input stacks the image and text score rows of the batch
        (from ``u_sim_calculation``, computed without gradients). The "observed"
        input is built from the raw interaction rows: noise is subtracted,
        softmax is applied, the detached collaborative scores scaled by
        ``args.ui_pre_scale`` are added, and each row is L2-normalised. It is
        stacked twice so both inputs have the same number of rows.

        Args:
            users: user ids of the current batch.
        Returns:
            The detached discriminator loss.
        """
        with torch.no_grad():
            outputs = self.model(self.ui_graph, self.iu_graph, self.image_ui_graph, self.image_iu_graph, self.text_ui_graph, self.text_iu_graph)
        ua_embeddings, ia_embeddings, image_item_embeds, text_item_embeds, image_user_embeds, text_user_embeds = outputs[:6]

        ui_u_sim_detach = self.u_sim_calculation(users, ua_embeddings, ia_embeddings).detach()
        image_u_sim_detach = self.u_sim_calculation(users, image_user_embeds, image_item_embeds).detach()
        text_u_sim_detach = self.u_sim_calculation(users, text_user_embeds, text_item_embeds).detach()

        inputf = torch.cat((image_u_sim_detach, text_u_sim_detach), dim=0)
        lossf = self.D(inputf).mean()

        u_ui = torch.tensor(self.ui_graph_raw[users].todense()).cuda()
        # log(-log(U)) noise drawn on the CPU, then moved to the GPU
        uniform_noise = torch.empty((u_ui.shape[0], u_ui.shape[1]), dtype=torch.float32).uniform_(0,1).cuda()
        log_log_noise = torch.log(-torch.log(uniform_noise+1e-8)+1e-8)
        u_ui = F.softmax(u_ui - args.log_log_scale*log_log_noise/args.real_data_tau, dim=1)
        u_ui += ui_u_sim_detach*args.ui_pre_scale
        u_ui = F.normalize(u_ui, dim=1)

        inputr = torch.cat((u_ui, u_ui), dim=0)
        lossr = -self.D(inputr).mean()

        gp = self.gradient_penalty(self.D, inputr, inputf.detach())
        loss_D = lossr + lossf + args.gp_rate*gp
        self.optim_D.zero_grad()
        loss_D.backward()
        self.optim_D.step()
        return loss_D.detach()

    def _build_modality_graphs(self, edge_index):
        """Turn accumulated ``{'x': users, 'y': items}`` pairs into graphs.

        Returns:
            ``(ui_raw, iu_raw, ui_graph, iu_graph)``: the scipy user-item matrix,
            its transpose, and their row-normalised sparse CUDA tensors.
        """
        ui_raw = csr_matrix(
            (np.ones(len(edge_index['x']), dtype=np.float32), (edge_index['x'], edge_index['y'])),
            shape=(self.n_users, self.n_items))
        iu_raw = ui_raw.T
        ui_graph = self.sparse_mx_to_torch_sparse_tensor(self.csr_norm(ui_raw, mean_flag=True)).cuda()
        iu_graph = self.sparse_mx_to_torch_sparse_tensor(self.csr_norm(iu_raw, mean_flag=True)).cuda()
        return ui_raw, iu_raw, ui_graph, iu_graph

    def _update_modality_graphs(self, idx, users, image_u_sim, text_u_sim):
        """Accumulate top-k modality edges, or rebuild the graphs every ``args.T`` batches.

        On batches where ``idx % args.T == 0`` (except ``idx == 0``) the image
        and text graphs are rebuilt from the accumulated pairs and the
        accumulators are reset. On every other batch the top
        ``int(n_items * args.m_topk_rate)`` items per user are appended.

        Args:
            idx: batch index within the epoch.
            users: user ids of the batch (one per row of the score matrices).
            image_u_sim: detached image-based user-item scores.
            text_u_sim: detached text-based user-item scores.
        """
        if idx%args.T==0 and idx!=0:
            (self.image_ui_graph_tmp, self.image_iu_graph_tmp,
             self.image_ui_graph, self.image_iu_graph) = self._build_modality_graphs(self.image_ui_index)
            (self.text_ui_graph_tmp, self.text_iu_graph_tmp,
             self.text_ui_graph, self.text_iu_graph) = self._build_modality_graphs(self.text_ui_index)

            self.image_ui_index = {'x':[], 'y':[]}
            self.text_ui_index = {'x':[], 'y':[]}
            return

        topk = int(self.n_items*args.m_topk_rate)
        # The flattened top-k item ids are row-major (all items of the first
        # user, then the second, ...), so each user id must be repeated `topk`
        # times consecutively. The previous `tensor(users).repeat(1, topk)`
        # tiled the whole user list instead, pairing items with the wrong users.
        user_ids = np.repeat(np.asarray(users), topk).tolist()

        _, image_ui_id = torch.topk(image_u_sim, topk, dim=-1)
        self.image_ui_index['x'] += user_ids
        self.image_ui_index['y'] += image_ui_id.cpu().view(-1).tolist()

        _, text_ui_id = torch.topk(text_u_sim, topk, dim=-1)
        self.text_ui_index['x'] += user_ids
        self.text_ui_index['y'] += text_ui_id.cpu().view(-1).tolist()

    def _generator_batch_loss(self, users, pos_items, neg_items, idx):
        """Forward the recommender on one batch and assemble its training loss.

        Also updates the modality graphs / edge accumulators through
        ``_update_modality_graphs``. No optimizer step is taken here.

        Returns:
            ``(batch_loss, mf_loss, emb_loss, reg_loss, contrastive_loss)`` where
            ``batch_loss = mf + emb + reg + feature regulariser
            + args.cl_rate * contrastive + args.G_rate * adversarial``.
        """
        G_ua_embeddings, G_ia_embeddings, G_image_item_embeds, G_text_item_embeds, G_image_user_embeds, G_text_user_embeds \
            , G_user_emb, _, G_image_user_id, G_text_user_id, _, _ \
            = self.model(self.ui_graph, self.iu_graph, self.image_ui_graph, self.image_iu_graph, self.text_ui_graph, self.text_iu_graph)

        G_batch_mf_loss, G_batch_emb_loss, G_batch_reg_loss = self.bpr_loss(
            G_ua_embeddings[users], G_ia_embeddings[pos_items], G_ia_embeddings[neg_items])

        G_image_u_sim = self.u_sim_calculation(users, G_image_user_embeds, G_image_item_embeds)
        G_text_u_sim = self.u_sim_calculation(users, G_text_user_embeds, G_text_item_embeds)
        self._update_modality_graphs(idx, users, G_image_u_sim.detach(), G_text_u_sim.detach())

        feat_emb_loss = self.feat_reg_loss_calculation(G_image_item_embeds, G_text_item_embeds, G_image_user_embeds, G_text_user_embeds)

        batch_contrastive_loss = self.batched_contrastive_loss(G_image_user_id[users], G_user_emb[users]) \
            + self.batched_contrastive_loss(G_text_user_id[users], G_user_emb[users])

        # adversarial term: the recommender wants D to score its rows highly
        G_inputf = torch.cat((G_image_u_sim, G_text_u_sim), dim=0)
        G_lossf = -(self.D(G_inputf).mean())

        batch_loss = G_batch_mf_loss + G_batch_emb_loss + G_batch_reg_loss + feat_emb_loss + args.cl_rate*batch_contrastive_loss + args.G_rate*G_lossf
        return batch_loss, G_batch_mf_loss, G_batch_emb_loss, G_batch_reg_loss, batch_contrastive_loss

    def train(self):
        """Train the recommender and the discriminator with early stopping.

        Each batch first updates the discriminator, then the recommender. After
        every epoch the model is evaluated on the validation users. Whenever
        ``recall[1]`` improves, the test users are evaluated and logged. When it
        does not improve, a patience counter is incremented until it reaches
        ``args.early_stopping_patience``; the next non-improving epoch stops
        training. The last test result is appended to ``test_mmssl_ret.txt``.

        NOTE(review): this method never calls ``self.scheduler_D.step()``, so
        the learning-rate schedule built in ``__init__`` has no effect.

        Returns:
            ``(best_recall, run_time)``: the best validation ``recall[1]`` and
            the start timestamp formatted as ``%Y_%m_%d__%H_%M_%S``.
        """
        run_time = datetime.strftime(datetime.now(), '%Y_%m_%d__%H_%M_%S')

        stopping_step = 0
        best_recall = 0
        test_ret = None  # stays None if validation recall never improves
        n_batch = data_generator.n_train // args.batch_size + 1
        users_to_test = list(data_generator.test_set.keys())
        users_to_val = list(data_generator.val_set.keys())

        for epoch in range(args.epoch):
            t1 = time()
            loss, mf_loss, emb_loss, reg_loss = 0., 0., 0., 0.
            contrastive_loss = 0.
            self.gene_u, self.gene_real, self.gene_fake = None, None, {}
            self.topk_p_dict, self.topk_id_dict = {}, {}

            # self.test() below switches to eval mode, so re-enable training
            # mode at the start of every epoch.
            self.model.train()
            for idx in tqdm(range(n_batch)):
                users, pos_items, neg_items = data_generator.sample()

                self._discriminator_step(users)

                batch_loss, batch_mf_loss, batch_emb_loss, batch_reg_loss, batch_cl_loss = \
                    self._generator_batch_loss(users, pos_items, neg_items, idx)

                # self.optimizer_D optimises the recommender (see __init__)
                self.optimizer_D.zero_grad()
                batch_loss.backward(retain_graph=False)
                self.optimizer_D.step()

                loss += float(batch_loss)
                mf_loss += float(batch_mf_loss)
                emb_loss += float(batch_emb_loss)
                reg_loss += float(batch_reg_loss)
                # previously never accumulated, so the log always showed 0
                contrastive_loss += float(batch_cl_loss)

            if math.isnan(loss):
                self.logger.logging('ERROR: loss is nan.')
                sys.exit()

            # guard against args.verbose == 0 (modulo by zero)
            if args.verbose > 0 and (epoch + 1) % args.verbose != 0:
                perf_str = 'Epoch %d [%.1fs]: train==[%.5f=%.5f + %.5f + %.5f  + %.5f]' % (
                    epoch, time() - t1, loss, mf_loss, emb_loss, reg_loss, contrastive_loss)
                self.logger.logging(perf_str)

            t2 = time()
            ret = self.test(users_to_val, is_val=True)
            t3 = time()

            if args.verbose > 0:
                perf_str = 'Epoch %d [%.1fs + %.1fs]: train==[%.5f=%.5f + %.5f + %.5f], recall=[%.5f, %.5f, %.5f, %.5f], ' \
                           'precision=[%.5f, %.5f, %.5f, %.5f], hit=[%.5f, %.5f, %.5f, %.5f], ndcg=[%.5f, %.5f, %.5f, %.5f], mrr=[%.5f, %.5f]' % \
                           (epoch, t2 - t1, t3 - t2, loss, mf_loss, emb_loss, reg_loss, ret['recall'][0], ret['recall'][1], ret['recall'][2],
                            ret['recall'][-1],
                            ret['precision'][0], ret['precision'][1], ret['precision'][2], ret['precision'][-1], ret['hit_ratio'][0], ret['hit_ratio'][1], ret['hit_ratio'][2], ret['hit_ratio'][-1],
                            ret['ndcg'][0], ret['ndcg'][1], ret['ndcg'][2], ret['ndcg'][-1],ret['mrr'][0], ret['mrr'][1])
                self.logger.logging(perf_str)

            if ret['recall'][1] > best_recall:
                best_recall = ret['recall'][1]
                test_ret = self.test(users_to_test, is_val=False)
                self.logger.logging("Test_Recall@%d: %.5f,  precision=[%.5f], ndcg=[%.5f]" % (eval(args.Ks)[1], test_ret['recall'][1], test_ret['precision'][1], test_ret['ndcg'][1]))
                stopping_step = 0
            elif stopping_step < args.early_stopping_patience:
                stopping_step += 1
                self.logger.logging('#####Early stopping steps: %d #####' % stopping_step)
            else:
                self.logger.logging('#####Early stop! #####')
                break

        self.logger.logging(str(test_ret))
        print(test_ret)
        # close the results file instead of leaking the handle
        with open('test_mmssl_ret.txt', 'a') as ret_file:
            print(args.dataset,args.shuffle,args.p,tm.strftime("%a %b %d %H:%M:%S %Y", tm.localtime()),test_ret, file=ret_file)

        return best_recall, run_time

    def val_loss(self):
        """Accumulate the recommender's batch loss over one pass, without updating it.

        The number of batches is derived from ``data_generator.n_train`` and the
        batches come from ``data_generator.sample()``, exactly as in ``train``.

        NOTE(review): as in the original code, every batch still performs a
        discriminator optimisation step and updates the modality graphs /
        accumulators, so calling this method changes training state. Confirm
        whether a side-effect-free validation pass was intended.

        Returns:
            Sum of the per-batch recommender losses as a float.
        """
        loss = 0.
        n_batch = data_generator.n_train // args.batch_size + 1
        self.gene_u, self.gene_real, self.gene_fake = None, None, {}
        self.topk_p_dict, self.topk_id_dict = {}, {}

        for idx in tqdm(range(n_batch)):
            users, pos_items, neg_items = data_generator.sample()

            # The logging lists the original code appended to here were never
            # defined in this method and raised NameError on the first batch.
            self._discriminator_step(users)
            batch_loss = self._generator_batch_loss(users, pos_items, neg_items, idx)[0]

            loss += float(batch_loss)
        return loss


    def bpr_loss(self, users, pos_items, neg_items):
        pos_scores = torch.sum(torch.mul(users, pos_items), dim=1)
        neg_scores = torch.sum(torch.mul(users, neg_items), dim=1)

        regularizer = 1./2*(users**2).sum() + 1./2*(pos_items**2).sum() + 1./2*(neg_items**2).sum()        
        regularizer = regularizer / self.batch_size

        maxi = F.logsigmoid(pos_scores - neg_scores)
        mf_loss = -torch.mean(maxi)

        emb_loss = self.decay * regularizer
        reg_loss = 0.0
        return mf_loss, emb_loss, reg_loss

    def sparse_mx_to_torch_sparse_tensor(self, sparse_mx):
        """Convert a scipy sparse matrix to a torch sparse tensor."""
        sparse_mx = sparse_mx.tocoo().astype(np.float32)
        indices = torch.from_numpy(
            np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64))
        values = torch.from_numpy(sparse_mx.data)
        shape = torch.Size(sparse_mx.shape)
        return torch.sparse.FloatTensor(indices, values, shape)

def set_seed(seed):
    """Seed every random number generator used by the notebook.

    Applies ``seed`` to NumPy's global RNG, Python's ``random`` module, the
    torch default generator and the torch CUDA generators, in that order.
    """
    seeders = (
        np.random.seed,
        random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

if __name__ == '__main__':
    # Expose only the GPU identified by args.gpu_id to CUDA, then seed the RNGs.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_id)
    set_seed(args.seed)

    # Dataset sizes the trainer receives as its data configuration.
    config = {
        'n_users': data_generator.n_users,
        'n_items': data_generator.n_items,
    }

    trainer = Trainer(data_config=config)
    trainer.train()

In [None]:
#shuffle image 0.2
#{'precision': array([0.00613525, 0.00478015, 0.00325534]), 'recall': array([0.0581856 , 0.09064272, 0.15365523]), 'ndcg': array([0.03131403, 0.03982799, 0.05289263]), 'hit_ratio': array([0.06094112, 0.09457444, 0.16050399]), 'auc': 0.0}
#shuffle image 0.5
#{'precision': array([0.0059244 , 0.00460015, 0.00314322]), 'recall': array([0.05598587, 0.08734232, 0.14823411]), 'ndcg': array([0.03014553, 0.03830779, 0.05098226]), 'hit_ratio': array([0.05878118, 0.09123168, 0.15443559]), 'auc': 0.0}

In [None]:
#shuffle all 0.2
#{'precision': array([0.0059244 , 0.00464644, 0.00319259, 0.00744664]), 'recall': array([0.05599138, 0.08787973, 0.1509533 , 0.03531377]), 'ndcg': array([0.03049136, 0.03891664, 0.05194823, 0.0235323 ]), 'hit_ratio': array([0.05878118, 0.09200309, 0.15746979, 0.03713037]), 'mrr': array([0.02226298, 0.02454933, 0.02658688, 0.01941287]), 'auc': 0.0}

In [None]:
import requests

# Send a WeChat notification through the AutoDL message API.
#
# The API token is a credential and must not be stored in the notebook: it is
# read from the AUTODL_TOKEN environment variable instead. A token that was
# previously committed in this cell should be treated as leaked and revoked.
autodl_token = os.environ.get("AUTODL_TOKEN")
if not autodl_token:
    # The notification is best-effort: skip it rather than fail the run.
    print("AUTODL_TOKEN is not set; skipping the WeChat notification.")
else:
    headers = {"Authorization": autodl_token}
    # A timeout keeps the cell from hanging forever if the service is unreachable.
    resp = requests.post("https://www.autodl.com/api/v1/wechat/message/send",
                         json={
                             "title": "my_clip_rob-best-try",
                             "name": "my_clip_rob-best-try",
                             "content": "my_clip_rob-best-try"
                         }, headers=headers, timeout=30)
    print(resp.content.decode())

In [None]:
# Shell escape: runs the system `shutdown` command when this cell executes.
# NOTE(review): presumably meant to release the rented machine after training;
# confirm the intended delay/flags before relying on it.
!shutdown