In [1]:
class args:
    lr = 1e-3
    batch = 4096
    tstBat = 4096
    reg = 1e-5
    epoch = 200
    latdim = 32
    gnn_layer = 2
    topk = 10
    #data = 'yelp'
    ssl_reg = 0.1
    ib_reg = 0.1
    temp = 0.5
    tstEpoch = 1
    gpu = -1
    lambda0 = 1e-4
    gamma = -0.45
    zeta = 1.05
    init_temperature = 2.0
    temperature_decay = 0.98
    eps = 1e-3
    seed = 421

In [2]:
dataset_name='ml-100k'
dataset_folder='ml'
preflix_folder='24_07_19'

In [3]:
import torch as t
import torch.nn.functional as F

def innerProduct(usrEmbeds, itmEmbeds):
    return t.sum(usrEmbeds * itmEmbeds, dim=-1)

def pairPredict(ancEmbeds, posEmbeds, negEmbeds):
    return innerProduct(ancEmbeds, posEmbeds) - innerProduct(ancEmbeds, negEmbeds)

def calcRegLoss(model):
    ret = 0
    for W in model.parameters():
        ret += W.norm(2).square()
    return ret

def contrastLoss(embeds1, embeds2, nodes, temp):
    embeds1 = F.normalize(embeds1, p=2)
    embeds2 = F.normalize(embeds2, p=2)
    pckEmbeds1 = embeds1[nodes]
    pckEmbeds2 = embeds2[nodes]
    nume = t.exp(t.sum(pckEmbeds1 * pckEmbeds2, dim=-1) / temp)
    deno = t.exp(pckEmbeds1 @ embeds2.T / temp).sum(-1)
    return -t.log(nume / deno)

In [4]:
from torch import nn
import torch.nn.functional as F
import torch
#from Params import args
from copy import deepcopy
import numpy as np
import math
import scipy.sparse as sp
#from Utils.Utils import contrastLoss, calcRegLoss, pairPredict
import time
#import torch_sparse

init = nn.init.xavier_uniform_

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()

        self.uEmbeds = nn.Parameter(init(torch.empty(args.user, args.latdim)))
        self.iEmbeds = nn.Parameter(init(torch.empty(args.item, args.latdim)))
        self.gcnLayers = nn.Sequential(*[GCNLayer() for i in range(args.gnn_layer)])

    def forward_gcn(self, adj):
        iniEmbeds = torch.concat([self.uEmbeds, self.iEmbeds], axis=0)

        embedsLst = [iniEmbeds]
        for gcn in self.gcnLayers:
            embeds = gcn(adj, embedsLst[-1])
            embedsLst.append(embeds)
        mainEmbeds = sum(embedsLst)

        return mainEmbeds[:args.user], mainEmbeds[args.user:]

    def forward_graphcl(self, adj):
        iniEmbeds = torch.concat([self.uEmbeds, self.iEmbeds], axis=0)

        embedsLst = [iniEmbeds]
        for gcn in self.gcnLayers:
            embeds = gcn(adj, embedsLst[-1])
            embedsLst.append(embeds)
        mainEmbeds = sum(embedsLst)

        return mainEmbeds

    def forward_graphcl_(self, generator):
        iniEmbeds = torch.concat([self.uEmbeds, self.iEmbeds], axis=0)

        embedsLst = [iniEmbeds]		
        count = 0
        for gcn in self.gcnLayers:
            with torch.no_grad():
                adj = generator.generate(x=embedsLst[-1], layer=count)
            embeds = gcn(adj, embedsLst[-1])
            embedsLst.append(embeds)
            count += 1
        mainEmbeds = sum(embedsLst)

        return mainEmbeds

    def loss_graphcl(self, x1, x2, users, items):
        T = args.temp
        user_embeddings1, item_embeddings1 = torch.split(x1, [args.user, args.item], dim=0)
        user_embeddings2, item_embeddings2 = torch.split(x2, [args.user, args.item], dim=0)

        user_embeddings1 = F.normalize(user_embeddings1, dim=1)
        item_embeddings1 = F.normalize(item_embeddings1, dim=1)
        user_embeddings2 = F.normalize(user_embeddings2, dim=1)
        item_embeddings2 = F.normalize(item_embeddings2, dim=1)

        user_embs1 = F.embedding(users, user_embeddings1)
        item_embs1 = F.embedding(items, item_embeddings1)
        user_embs2 = F.embedding(users, user_embeddings2)
        item_embs2 = F.embedding(items, item_embeddings2)

        all_embs1 = torch.cat([user_embs1, item_embs1], dim=0)
        all_embs2 = torch.cat([user_embs2, item_embs2], dim=0)

        all_embs1_abs = all_embs1.norm(dim=1)
        all_embs2_abs = all_embs2.norm(dim=1)

        sim_matrix = torch.einsum('ik,jk->ij', all_embs1, all_embs2) / torch.einsum('i,j->ij', all_embs1_abs, all_embs2_abs)
        sim_matrix = torch.exp(sim_matrix / T)
        pos_sim = sim_matrix[np.arange(all_embs1.shape[0]), np.arange(all_embs1.shape[0])]
        loss = pos_sim / (sim_matrix.sum(dim=1) - pos_sim)
        loss = - torch.log(loss)

        return loss

    def getEmbeds(self):
        self.unfreeze(self.gcnLayers)
        return torch.concat([self.uEmbeds, self.iEmbeds], axis=0)

    def unfreeze(self, layer):
        for child in layer.children():
            for param in child.parameters():
                param.requires_grad = True

    def getGCN(self):
        return self.gcnLayers

class GCNLayer(nn.Module):
    def __init__(self):
        super(GCNLayer, self).__init__()

    def forward(self, adj, embeds, flag=True):
        if (flag):
            return torch.spmm(adj, embeds)
        else:
            return torch.spmm(adj, embeds)
        #torch_sparse.spmm(adj.indices(), adj.values(), adj.shape[0], adj.shape[1], embeds)

class vgae_encoder(Model):
    def __init__(self):
        super(vgae_encoder, self).__init__()
        hidden = args.latdim
        self.encoder_mean = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(inplace=True), nn.Linear(hidden, hidden))
        self.encoder_std = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(inplace=True), nn.Linear(hidden, hidden), nn.Softplus())

    def forward(self, adj):
        x = self.forward_graphcl(adj)

        x_mean = self.encoder_mean(x)
        x_std = self.encoder_std(x)
        gaussian_noise = torch.randn(x_mean.shape).cuda()
        x = gaussian_noise * x_std + x_mean
        return x, x_mean, x_std

class vgae_decoder(nn.Module):
    def __init__(self, hidden=args.latdim):
        super(vgae_decoder, self).__init__()
        self.decoder = nn.Sequential(nn.ReLU(inplace=True), nn.Linear(hidden, hidden), nn.ReLU(inplace=True), nn.Linear(hidden, 1))
        self.sigmoid = nn.Sigmoid()
        self.bceloss = nn.BCELoss(reduction='none')

    def forward(self, x, x_mean, x_std, users, items, neg_items, encoder):
        x_user, x_item = torch.split(x, [args.user, args.item], dim=0)

        edge_pos_pred = self.sigmoid(self.decoder(x_user[users] * x_item[items]))
        edge_neg_pred = self.sigmoid(self.decoder(x_user[users] * x_item[neg_items]))

        loss_edge_pos = self.bceloss( edge_pos_pred, torch.ones(edge_pos_pred.shape).cuda() )
        loss_edge_neg = self.bceloss( edge_neg_pred, torch.zeros(edge_neg_pred.shape).cuda() )
        loss_rec = loss_edge_pos + loss_edge_neg

        kl_divergence = - 0.5 * (1 + 2 * torch.log(x_std) - x_mean**2 - x_std**2).sum(dim=1)

        ancEmbeds = x_user[users]
        posEmbeds = x_item[items]
        negEmbeds = x_item[neg_items]
        scoreDiff = pairPredict(ancEmbeds, posEmbeds, negEmbeds)
        bprLoss = - (scoreDiff).sigmoid().log().sum() / args.batch
        regLoss = calcRegLoss(encoder) * args.reg

        beta = 0.1
        loss = (loss_rec + beta * kl_divergence.mean() + bprLoss + regLoss).mean()

        return loss

class vgae(nn.Module):
    def __init__(self, encoder, decoder):
        super(vgae, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, data, users, items, neg_items):
        x, x_mean, x_std = self.encoder(data)
        loss = self.decoder(x, x_mean, x_std, users, items, neg_items, self.encoder)
        return loss

    def generate(self, data, edge_index, adj):
        x, _, _ = self.encoder(data)

        edge_pred = self.decoder.sigmoid(self.decoder.decoder(x[edge_index[0]] * x[edge_index[1]]))

        vals = adj._values()
        idxs = adj._indices()
        edgeNum = vals.size()
        edge_pred = edge_pred[:, 0]
        mask = ((edge_pred + 0.5).floor()).type(torch.bool)

        newVals = vals[mask]

        newVals = newVals / (newVals.shape[0] / edgeNum[0])
        newIdxs = idxs[:, mask]

        return torch.sparse.FloatTensor(newIdxs, newVals, adj.shape)

class DenoisingNet(nn.Module):
    def __init__(self, gcnLayers, features):
        super(DenoisingNet, self).__init__()

        self.features = features

        self.gcnLayers = gcnLayers

        self.edge_weights = []
        self.nblayers = []
        self.selflayers = []

        self.attentions = []
        self.attentions.append([])
        self.attentions.append([])

        hidden = args.latdim

        self.nblayers_0 = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(inplace=True))
        self.nblayers_1 = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(inplace=True))

        self.selflayers_0 = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(inplace=True))
        self.selflayers_1 = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(inplace=True))

        self.attentions_0 = nn.Sequential(nn.Linear( 2 * hidden, 1))
        self.attentions_1 = nn.Sequential(nn.Linear( 2 * hidden, 1))

    def freeze(self, layer):
        for child in layer.children():
            for param in child.parameters():
                param.requires_grad = False

    def get_attention(self, input1, input2, layer=0):
        if layer == 0:
            nb_layer = self.nblayers_0
            selflayer = self.selflayers_0
        if layer == 1:
            nb_layer = self.nblayers_1
            selflayer = self.selflayers_1

        input1 = nb_layer(input1)
        input2 = selflayer(input2)

        input10 = torch.concat([input1, input2], axis=1)

        if layer == 0:
            weight10 = self.attentions_0(input10)
        if layer == 1:
            weight10 = self.attentions_1(input10)

        return weight10

    def hard_concrete_sample(self, log_alpha, beta=1.0, training=True):
        gamma = args.gamma
        zeta = args.zeta

        if training:
            debug_var = 1e-7
            bias = 0.0
            np_random = np.random.uniform(low=debug_var, high=1.0-debug_var, size=np.shape(log_alpha.cpu().detach().numpy()))
            random_noise = bias + torch.tensor(np_random)
            gate_inputs = torch.log(random_noise) - torch.log(1.0 - random_noise)
            gate_inputs = (gate_inputs.cuda() + log_alpha) / beta
            gate_inputs = torch.sigmoid(gate_inputs)
        else:
            gate_inputs = torch.sigmoid(log_alpha)

        stretched_values = gate_inputs * (zeta-gamma) +gamma
        cliped = torch.clamp(stretched_values, 0.0, 1.0)
        return cliped.float()

    def generate(self, x, layer=0):
        f1_features = x[self.row, :]
        f2_features = x[self.col, :]

        weight = self.get_attention(f1_features, f2_features, layer)

        mask = self.hard_concrete_sample(weight, training=False)

        mask = torch.squeeze(mask)
        adj = torch.sparse.FloatTensor(self.adj_mat._indices(), mask, self.adj_mat.shape)

        ind = deepcopy(adj._indices())
        row = ind[0, :]
        col = ind[1, :]

        rowsum = torch.sparse.sum(adj, dim=-1).to_dense()
        d_inv_sqrt = torch.reshape(torch.pow(rowsum, -0.5), [-1])
        d_inv_sqrt = torch.clamp(d_inv_sqrt, 0.0, 10.0)
        row_inv_sqrt = d_inv_sqrt[row]
        col_inv_sqrt = d_inv_sqrt[col]
        values = torch.mul(adj._values(), row_inv_sqrt)
        values = torch.mul(values, col_inv_sqrt)

        support = torch.sparse.FloatTensor(adj._indices(), values, adj.shape)

        return support

    def l0_norm(self, log_alpha, beta):
        gamma = args.gamma
        zeta = args.zeta
        gamma = torch.tensor(gamma)
        zeta = torch.tensor(zeta)
        reg_per_weight = torch.sigmoid(log_alpha - beta * torch.log(-gamma/zeta))

        return torch.mean(reg_per_weight)

    def set_fea_adj(self, nodes, adj):
        self.node_size = nodes
        self.adj_mat = adj

        ind = deepcopy(adj._indices())

        self.row = ind[0, :]
        self.col = ind[1, :]

    def call(self, inputs, training=None):
        if training:
            temperature = inputs
        else:
            temperature = 1.0

        self.maskes = []

        x = self.features.detach()
        layer_index = 0
        embedsLst = [self.features.detach()]

        for layer in self.gcnLayers:
            xs = []
            f1_features = x[self.row, :]
            f2_features = x[self.col, :]

            weight = self.get_attention(f1_features, f2_features, layer=layer_index)
            mask = self.hard_concrete_sample(weight, temperature, training)

            self.edge_weights.append(weight)
            self.maskes.append(mask)
            mask = torch.squeeze(mask)

            adj = torch.sparse.FloatTensor(self.adj_mat._indices(), mask, self.adj_mat.shape).coalesce()
            ind = deepcopy(adj._indices())
            row = ind[0, :]
            col = ind[1, :]

            rowsum = torch.sparse.sum(adj, dim=-1).to_dense() + 1e-6
            d_inv_sqrt = torch.reshape(torch.pow(rowsum, -0.5), [-1])
            d_inv_sqrt = torch.clamp(d_inv_sqrt, 0.0, 10.0)
            row_inv_sqrt = d_inv_sqrt[row]
            col_inv_sqrt = d_inv_sqrt[col]
            values = torch.mul(adj.values(), row_inv_sqrt)
            values = torch.mul(values, col_inv_sqrt)
            support = torch.sparse.FloatTensor(adj._indices(), values, adj.shape).coalesce()

            nextx = layer(support, x, False)
            xs.append(nextx)
            x = xs[0]
            embedsLst.append(x)
            layer_index += 1
        return sum(embedsLst)

    def lossl0(self, temperature):
        l0_loss = torch.zeros([]).cuda()
        for weight in self.edge_weights:
            l0_loss += self.l0_norm(weight, temperature)
        self.edge_weights = []
        return l0_loss

    def forward(self, users, items, neg_items, temperature):
        self.freeze(self.gcnLayers)
        x = self.call(temperature, True)
        x_user, x_item = torch.split(x, [args.user, args.item], dim=0)
        ancEmbeds = x_user[users]
        posEmbeds = x_item[items]
        negEmbeds = x_item[neg_items]
        scoreDiff = pairPredict(ancEmbeds, posEmbeds, negEmbeds)
        bprLoss = - (scoreDiff).sigmoid().log().sum() / args.batch
        regLoss = calcRegLoss(self) * args.reg

        lossl0 = self.lossl0(temperature) * args.lambda0
        return bprLoss + regLoss + lossl0

In [5]:
import pickle
import numpy as np
from scipy.sparse import csr_matrix, coo_matrix, dok_matrix
#from Params import args
import scipy.sparse as sp
#from Utils.TimeLogger import log
import torch as t
import torch.utils.data as data
import torch.utils.data as dataloader

class DataHandler:
    def __init__(self):
        predir='E:/datasets/'+dataset_folder+'/'+dataset_name+'/'+preflix_folder+'/'
        self.predir = predir
        self.trnfile = predir + dataset_name+'_adagcl_train_data.pkl'
        self.tstfile = predir + dataset_name+'_adagcl_test_data.pkl'

    def loadOneFile(self, filename):
        print(filename)
        with open(filename, 'rb') as fs:
            ret = (pickle.load(fs) != 0).astype(np.float32)
        if type(ret) != coo_matrix:
            ret = sp.coo_matrix(ret)
        return ret

    def normalizeAdj(self, mat):
        degree = np.array(mat.sum(axis=-1))
        dInvSqrt = np.reshape(np.power(degree, -0.5), [-1])
        dInvSqrt[np.isinf(dInvSqrt)] = 0.0
        dInvSqrtMat = sp.diags(dInvSqrt)
        return mat.dot(dInvSqrtMat).transpose().dot(dInvSqrtMat).tocoo()

    def makeTorchAdj(self, mat):
        # make ui adj
        a = sp.csr_matrix((args.user, args.user))
        b = sp.csr_matrix((args.item, args.item))
        mat = sp.vstack([sp.hstack([a, mat]), sp.hstack([mat.transpose(), b])])
        mat = (mat != 0) * 1.0
        mat = (mat + sp.eye(mat.shape[0])) * 1.0
        mat = self.normalizeAdj(mat)

        # make cuda tensor
        idxs = t.from_numpy(np.vstack([mat.row, mat.col]).astype(np.int64))
        vals = t.from_numpy(mat.data.astype(np.float32))
        shape = t.Size(mat.shape)
        return t.sparse.FloatTensor(idxs, vals, shape).cuda()

    def LoadData(self):
        trnMat = self.loadOneFile(self.trnfile)
        tstMat = self.loadOneFile(self.tstfile)
        self.trnMat = trnMat
        args.user, args.item = trnMat.shape
        self.torchBiAdj = self.makeTorchAdj(trnMat)
        trnData = TrnData(trnMat)
        self.trnLoader = dataloader.DataLoader(trnData, batch_size=args.batch, shuffle=True, num_workers=0)
        tstData = TstData(tstMat, trnMat)
        self.tstLoader = dataloader.DataLoader(tstData, batch_size=args.tstBat, shuffle=False, num_workers=0)

class TrnData(data.Dataset):
    def __init__(self, coomat):
        self.rows = coomat.row
        self.cols = coomat.col
        self.dokmat = coomat.todok()
        self.negs = np.zeros(len(self.rows)).astype(np.int32)

    def negSampling(self):
        for i in range(len(self.rows)):
            u = self.rows[i]
            while True:
                iNeg = np.random.randint(args.item)
                if (u, iNeg) not in self.dokmat:
                    break
            self.negs[i] = iNeg

    def __len__(self):
        return len(self.rows)

    def __getitem__(self, idx):
        return self.rows[idx], self.cols[idx], self.negs[idx]

class TstData(data.Dataset):
    def __init__(self, coomat, trnMat):
        self.csrmat = (trnMat.tocsr() != 0) * 1.0

        tstLocs = [None] * coomat.shape[0]
        tstUsrs = set()
        for i in range(len(coomat.data)):
            row = coomat.row[i]
            col = coomat.col[i]
            if tstLocs[row] is None:
                tstLocs[row] = list()
            tstLocs[row].append(col)
            tstUsrs.add(row)
        tstUsrs = np.array(list(tstUsrs))
        self.tstUsrs = tstUsrs
        self.tstLocs = tstLocs

    def __len__(self):
        return len(self.tstUsrs)

    def __getitem__(self, idx):
        return self.tstUsrs[idx], np.reshape(self.csrmat[self.tstUsrs[idx]].toarray(), [-1]),

In [6]:
###
import pickle
import numpy as np
from scipy.sparse import csr_matrix, coo_matrix, dok_matrix
#from Params import args
import scipy.sparse as sp
#from Utils.TimeLogger import log
import torch as t
import torch.utils.data as data
import torch.utils.data as dataloader

class my_DataHandler:
    def __init__(self):
        predir='E:/datasets/'+dataset_folder+'/'+dataset_name+'/'+preflix_folder+'/'
        self.predir = predir
        self.trnfile = predir + dataset_name+'_adagcl_train_data.pkl'
        self.tstfile = predir + dataset_name+'_adagcl_test_data.pkl'

    def loadOneFile(self, filename):
        print(filename)
        with open(filename, 'rb') as fs:
            ret = (pickle.load(fs) != 0).astype(np.float32)
        if type(ret) != coo_matrix:
            ret = sp.coo_matrix(ret)
        return ret

    def normalizeAdj(self, mat):
        degree = np.array(mat.sum(axis=-1))
        dInvSqrt = np.reshape(np.power(degree, -0.5), [-1])
        dInvSqrt[np.isinf(dInvSqrt)] = 0.0
        dInvSqrtMat = sp.diags(dInvSqrt)
        return mat.dot(dInvSqrtMat).transpose().dot(dInvSqrtMat).tocoo()

    def makeTorchAdj(self, mat):
        # make ui adj
        a = sp.csr_matrix((args.user, args.user))
        b = sp.csr_matrix((args.item, args.item))
        mat = sp.vstack([sp.hstack([a, mat]), sp.hstack([mat.transpose(), b])])
        mat = (mat != 0) * 1.0
        mat = (mat + sp.eye(mat.shape[0])) * 1.0
        mat = self.normalizeAdj(mat)

        # make cuda tensor
        idxs = t.from_numpy(np.vstack([mat.row, mat.col]).astype(np.int64))
        vals = t.from_numpy(mat.data.astype(np.float32))
        shape = t.Size(mat.shape)
        return t.sparse.FloatTensor(idxs, vals, shape).cuda()

    def LoadData(self):
        trnMat = self.loadOneFile(self.trnfile)
        tstMat = self.loadOneFile(self.tstfile)
        self.trnMat = trnMat
        args.user, args.item = trnMat.shape
        self.torchBiAdj = self.makeTorchAdj(trnMat)
        trnData = TrnData(trnMat)
        self.trnLoader = dataloader.DataLoader(trnData, batch_size=args.batch, shuffle=True, num_workers=0)
        tstData = my_TstData(tstMat, trnMat)
        self.tstLoader = dataloader.DataLoader(tstData, batch_size=args.tstBat, shuffle=False, num_workers=0)

class TrnData(data.Dataset):
    def __init__(self, coomat):
        self.rows = coomat.row
        self.cols = coomat.col
        self.dokmat = coomat.todok()
        self.negs = np.zeros(len(self.rows)).astype(np.int32)

    def negSampling(self):
        for i in range(len(self.rows)):
            u = self.rows[i]
            while True:
                iNeg = np.random.randint(args.item)
                if (u, iNeg) not in self.dokmat:
                    break
            self.negs[i] = iNeg

    def __len__(self):
        return len(self.rows)

    def __getitem__(self, idx):
        return self.rows[idx], self.cols[idx], self.negs[idx]

class my_TstData(data.Dataset):
    def __init__(self, coomat, trnMat):
        self.csrmat = (trnMat.tocsr() != 0) * 1.0
        tstLocs = [None] * coomat.shape[0]
        tstUsrs = set()
        target_user=[]
        target_item=[]
        for i in range(len(coomat.data)):
            row = coomat.row[i]
            col = coomat.col[i]
            if tstLocs[row] is None:
                tstLocs[row] = list()
            tstLocs[row].append(col)
            tstUsrs.add(row)
            target_user.append(row)
            target_item.append(col)
        tstUsrs = np.array(list(tstUsrs))
        self.tstUsrs = tstUsrs
        self.tstLocs = tstLocs
        self.target_user=target_user
        self.target_item=target_item
        print(len(self.target_user),len(self.target_item))

    def __len__(self):
        return len(self.target_user)

    def __getitem__(self, idx):
        return self.target_user[idx], self.target_item[idx]

In [7]:
def hr(rank, k):
    """Hit Rate.
    Args:
        :param rank: A list.
        :param k: A scalar(int).
    :return: hit rate.
    """
    res = 0.0
    for r in rank:
        if r < k:
            res += 1
    return res / len(rank)


def mrr(rank, k):
    """Mean Reciprocal Rank.
    Args:
        :param rank: A list.
        :param k: A scalar(int).
    :return: mrr.
    """
    mrr = 0.0
    for r in rank:
        if r < k:
            mrr += 1 / (r + 1)
    return mrr / len(rank)


def ndcg(rank, k):
    """Normalized Discounted Cumulative Gain.
    Args:
        :param rank: A list.
        :param k: A scalar(int).
    :return: ndcg.
    """
    res = 0.0
    for r in rank:
        if r < k:
            res += 1 / np.log2(r + 2)
    return res / len(rank)

In [8]:
import torch
#import Utils.TimeLogger as logger
#from Utils.TimeLogger import log
#from Params import args
#from Model import Model, vgae_encoder, vgae_decoder, vgae, DenoisingNet
#from DataHandler import DataHandler
import numpy as np
#from Utils.Utils import calcRegLoss, pairPredict
import os
from copy import deepcopy
import scipy.sparse as sp
import random

class Coach:
    def __init__(self, handler):
        self.handler = handler

        print('USER', args.user, 'ITEM', args.item)
        print('NUM OF INTERACTIONS', self.handler.trnLoader.dataset.__len__())
        self.metrics = dict()
        mets = ['Loss', 'preLoss', 'Recall', 'NDCG']
        for met in mets:
            self.metrics['Train' + met] = list()
            self.metrics['Test' + met] = list()
    def makePrint(self, name, ep, reses, save):
        ret = 'Epoch %d/%d, %s: ' % (ep, args.epoch, name)
        for metric in reses:
            val = reses[metric]
            ret += '%s = %.4f, ' % (metric, val)
            tem = name + metric
            if save and tem in self.metrics:
                self.metrics[tem].append(val)
        ret = ret[:-2] + '  '
        return ret
    def run(self):
        self.prepareModel()
        print('Model Prepared')

        recallMax = 0
        ndcgMax = 0
        bestEpoch = 0

        stloc = 0
        print('Model Initialized')

        for ep in range(stloc, args.epoch):
            temperature = max(0.05, args.init_temperature * pow(args.temperature_decay, ep))
            tstFlag = (ep % args.tstEpoch == 0)
            reses = self.trainEpoch(temperature)
            print(self.makePrint('Train', ep, reses, tstFlag))
            if tstFlag:
                f_hr, f_mrr, f_ndcg, f_acc = self.testEpoch()
                print("hr: {:5f} \t mrr: {:.5f}\t ndcg: {:.5f}\t acc: {:.5f}".format(f_hr, f_mrr, f_ndcg, f_acc))
                if (reses['Recall'] > recallMax):
                    recallMax = reses['Recall']
                    ndcgMax = reses['NDCG']
                    bestEpoch = ep
                print(self.makePrint('Test', ep, reses, tstFlag))
            print()
        print('Best epoch : ', bestEpoch, ' , Recall : ', recallMax, ' , NDCG : ', ndcgMax)
    def prepareModel(self):
        self.model = Model().cuda()

        encoder = vgae_encoder().cuda()
        decoder = vgae_decoder().cuda()
        self.generator_1 = vgae(encoder, decoder).cuda()
        self.generator_2 = DenoisingNet(self.model.getGCN(), self.model.getEmbeds()).cuda()
        self.generator_2.set_fea_adj(args.user+args.item, deepcopy(self.handler.torchBiAdj).cuda())

        self.opt = torch.optim.Adam(self.model.parameters(), lr=args.lr, weight_decay=0)
        self.opt_gen_1 = torch.optim.Adam(self.generator_1.parameters(), lr=args.lr, weight_decay=0)
        self.opt_gen_2 = torch.optim.Adam(filter(lambda p: p.requires_grad, self.generator_2.parameters()), lr=args.lr, weight_decay=0, eps=args.eps)
    def trainEpoch(self, temperature):
        trnLoader = self.handler.trnLoader
        trnLoader.dataset.negSampling()
        generate_loss_1, generate_loss_2, bpr_loss, im_loss, ib_loss, reg_loss = 0, 0, 0, 0, 0, 0
        steps = trnLoader.dataset.__len__() // args.batch

        for i, tem in enumerate(trnLoader):
            data = deepcopy(self.handler.torchBiAdj).cuda()

            data1 = self.generator_generate(self.generator_1)

            self.opt.zero_grad()
            self.opt_gen_1.zero_grad()
            self.opt_gen_2.zero_grad()

            ancs, poss, negs = tem
            ancs = ancs.long().cuda()
            poss = poss.long().cuda()
            negs = negs.long().cuda()

            out1 = self.model.forward_graphcl(data1)
            out2 = self.model.forward_graphcl_(self.generator_2)

            loss = self.model.loss_graphcl(out1, out2, ancs, poss).mean() * args.ssl_reg
            im_loss += float(loss)
            loss.backward()

            self.opt.step()
            self.opt.zero_grad()

            # info bottleneck
            _out1 = self.model.forward_graphcl(data1)
            _out2 = self.model.forward_graphcl_(self.generator_2)

            loss_ib = self.model.loss_graphcl(_out1, out1.detach(), ancs, poss) + self.model.loss_graphcl(_out2, out2.detach(), ancs, poss)
            loss= loss_ib.mean() * args.ib_reg
            ib_loss += float(loss)
            loss.backward()

            self.opt.step()
            self.opt.zero_grad()

            # BPR
            usrEmbeds, itmEmbeds = self.model.forward_gcn(data)
            ancEmbeds = usrEmbeds[ancs]
            posEmbeds = itmEmbeds[poss]
            negEmbeds = itmEmbeds[negs]
            scoreDiff = pairPredict(ancEmbeds, posEmbeds, negEmbeds)
            bprLoss = - (scoreDiff).sigmoid().log().sum() / args.batch
            regLoss = calcRegLoss(self.model) * args.reg
            loss = bprLoss + regLoss
            bpr_loss += float(bprLoss)
            reg_loss += float(regLoss)
            loss.backward()

            loss_1 = self.generator_1(deepcopy(self.handler.torchBiAdj).cuda(), ancs, poss, negs)
            loss_2 = self.generator_2(ancs, poss, negs, temperature)

            loss = loss_1 + loss_2
            generate_loss_1 += float(loss_1)
            generate_loss_2 += float(loss_2)
            loss.backward()

            self.opt.step()
            self.opt_gen_1.step()
            self.opt_gen_2.step()
            if False:
                print('Step %d/%d: gen 1 : %.3f ; gen 2 : %.3f ; bpr : %.3f ; im : %.3f ; ib : %.3f ; reg : %.3f  ' % (
                i, 
                steps,
                generate_loss_1,
                generate_loss_2,
                bpr_loss,
                im_loss,
                ib_loss,
                reg_loss,
                ))

        ret = dict()
        ret['Gen_1 Loss'] = generate_loss_1 / steps
        ret['Gen_2 Loss'] = generate_loss_2 / steps
        ret['BPR Loss'] = bpr_loss / steps
        ret['IM Loss'] = im_loss / steps
        ret['IB Loss'] = ib_loss / steps
        ret['Reg Loss'] = reg_loss / steps

        return ret
    
    def testEpoch(self):
        tstLoader = self.handler.tstLoader
        epRecall, epNdcg = [0] * 2
        i = 0
        num = tstLoader.dataset.__len__()
        #print('tst num',num)
        steps = num // args.tstBat
        hr_b,mrr_b,ndcg_b,acc_b=[],[],[],[]
        for usr, tar_i in tstLoader:
            i += 1
            usr = usr.long().cuda()
            tar_i=tar_i.long().cuda()
            #trnMask = trnMask.cuda()
            neg_indx = torch.arange(args.item).cuda().unsqueeze(0)
            #neg_indx = torch.randint(low=1, high=args.item, size=(tar_i.shape[0], 1000)).cuda() 

            usrEmbeds, itmEmbeds = self.model.forward_gcn(self.handler.torchBiAdj)
            
            tar_i_emb=itmEmbeds[tar_i].unsqueeze(1)
            tar_u_emb=usrEmbeds[usr].unsqueeze(1)
            neg_i_emb=itmEmbeds[neg_indx].expand(tar_i_emb.shape[0], -1, -1)
            
            test_i_embeds=torch.cat([tar_i_emb,neg_i_emb],dim=1)
            
            allPreds = -torch.matmul(tar_u_emb, test_i_embeds.transpose(1,2)).squeeze(1)

            print('rating shape:',allPreds.shape)
            rank = allPreds.argsort().argsort()[:, 0]
            rank=rank.cpu()
            res_1 = hr(rank, args.topk)
            res_2 = ndcg(rank, args.topk)
            res_3 = mrr(rank, args.topk)
            res_4 = hr(rank, 1)
            
            hr_b.append(res_1)
            ndcg_b.append(res_2)
            mrr_b.append(res_3)
            acc_b.append(res_4)
            
        f_hr=np.mean(hr_b)
        f_mrr=np.mean(mrr_b)
        f_ndcg=np.mean(ndcg_b)
        f_acc=np.mean(acc_b)

        return f_hr, f_mrr, f_ndcg, f_acc

    def calcRes(self, topLocs, tstLocs, batIds):
        assert topLocs.shape[0] == len(batIds)
        allRecall = allNdcg = 0
        for i in range(len(batIds)):
            temTopLocs = list(topLocs[i])
            temTstLocs = tstLocs[batIds[i]]
            tstNum = len(temTstLocs)
            maxDcg = np.sum([np.reciprocal(np.log2(loc + 2)) for loc in range(min(tstNum, args.topk))])
            recall = dcg = 0
            for val in temTstLocs:
                if val in temTopLocs:
                    recall += 1
                    dcg += np.reciprocal(np.log2(temTopLocs.index(val) + 2))
            recall = recall / tstNum
            ndcg = dcg / maxDcg
            allRecall += recall
            allNdcg += ndcg
        return allRecall, allNdcg

    def generator_generate(self, generator):
        edge_index = []
        edge_index.append([])
        edge_index.append([])
        adj = deepcopy(self.handler.torchBiAdj)
        idxs = adj._indices()

        with torch.no_grad():
            view = generator.generate(self.handler.torchBiAdj, idxs, adj)

        return view

def seed_it(seed):
    random.seed(seed)
    os.environ["PYTHONSEED"] = str(seed)
    np.random.seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True 
    torch.backends.cudnn.enabled = True
    torch.manual_seed(seed)

if __name__ == '__main__':
    with torch.cuda.device(args.gpu):
        #logger.saveDefault = True
        seed_it(args.seed)

        print('Start')
        handler = my_DataHandler()
        handler.LoadData()
        print('Load Data')

        coach = Coach(handler)
        coach.run()

Start
E:/datasets/ml/ml-100k/24_07_19/ml-100k_adagcl_train_data.pkl
E:/datasets/ml/ml-100k/24_07_19/ml-100k_adagcl_test_data.pkl
16098 16098
Load Data
USER 943 ITEM 1523
NUM OF INTERACTIONS 58196


  return t.sparse.FloatTensor(idxs, vals, shape).cuda()


Model Prepared
Model Initialized
Epoch 0/200, Train: Gen_1 Loss = 3.6185, Gen_2 Loss = 0.7031, BPR Loss = 0.7025, IM Loss = 0.7477, IB Loss = 1.4943, Reg Loss = 0.0013  
rating shape: torch.Size([4096, 1524])
rating shape: torch.Size([4096, 1524])
rating shape: torch.Size([4096, 1524])
rating shape: torch.Size([3810, 1524])
hr: 0.001727 	 mrr: 0.00045	 ndcg: 0.00074	 acc: 0.00018
Epoch 1/200, Train: Gen_1 Loss = 3.5759, Gen_2 Loss = 0.7031, BPR Loss = 0.7015, IM Loss = 0.7503, IB Loss = 1.4940, Reg Loss = 0.0013  
rating shape: torch.Size([4096, 1524])
rating shape: torch.Size([4096, 1524])
rating shape: torch.Size([4096, 1524])
rating shape: torch.Size([3810, 1524])
hr: 0.006213 	 mrr: 0.00152	 ndcg: 0.00258	 acc: 0.00037
Epoch 2/200, Train: Gen_1 Loss = 3.5406, Gen_2 Loss = 0.7031, BPR Loss = 0.6997, IM Loss = 0.7493, IB Loss = 1.4941, Reg Loss = 0.0013  
rating shape: torch.Size([4096, 1524])
rating shape: torch.Size([4096, 1524])
rating shape: torch.Size([4096, 1524])
rating shape:

rating shape: torch.Size([4096, 1524])
rating shape: torch.Size([4096, 1524])
rating shape: torch.Size([3810, 1524])
hr: 0.049297 	 mrr: 0.01370	 ndcg: 0.02182	 acc: 0.00404
Epoch 24/200, Train: Gen_1 Loss = 2.4160, Gen_2 Loss = 0.7031, BPR Loss = 0.2102, IM Loss = 0.7887, IB Loss = 1.5734, Reg Loss = 0.0365  
rating shape: torch.Size([4096, 1524])
rating shape: torch.Size([4096, 1524])
rating shape: torch.Size([4096, 1524])
rating shape: torch.Size([3810, 1524])
hr: 0.049058 	 mrr: 0.01362	 ndcg: 0.02170	 acc: 0.00392
Epoch 25/200, Train: Gen_1 Loss = 2.4271, Gen_2 Loss = 0.7031, BPR Loss = 0.2072, IM Loss = 0.7879, IB Loss = 1.5719, Reg Loss = 0.0379  
rating shape: torch.Size([4096, 1524])
rating shape: torch.Size([4096, 1524])
rating shape: torch.Size([4096, 1524])
rating shape: torch.Size([3810, 1524])
hr: 0.048110 	 mrr: 0.01352	 ndcg: 0.02142	 acc: 0.00405
Epoch 26/200, Train: Gen_1 Loss = 2.4044, Gen_2 Loss = 0.7031, BPR Loss = 0.2013, IM Loss = 0.7870, IB Loss = 1.5700, Reg Lo

Epoch 47/200, Train: Gen_1 Loss = 2.4183, Gen_2 Loss = 0.7030, BPR Loss = 0.1558, IM Loss = 0.7752, IB Loss = 1.5462, Reg Loss = 0.0575  
rating shape: torch.Size([4096, 1524])
rating shape: torch.Size([4096, 1524])
rating shape: torch.Size([4096, 1524])
rating shape: torch.Size([3810, 1524])
hr: 0.044840 	 mrr: 0.01213	 ndcg: 0.01956	 acc: 0.00385
Epoch 48/200, Train: Gen_1 Loss = 2.3944, Gen_2 Loss = 0.7030, BPR Loss = 0.1526, IM Loss = 0.7746, IB Loss = 1.5452, Reg Loss = 0.0582  
rating shape: torch.Size([4096, 1524])
rating shape: torch.Size([4096, 1524])
rating shape: torch.Size([4096, 1524])
rating shape: torch.Size([3810, 1524])
hr: 0.044831 	 mrr: 0.01211	 ndcg: 0.01955	 acc: 0.00373
Epoch 49/200, Train: Gen_1 Loss = 2.3868, Gen_2 Loss = 0.7030, BPR Loss = 0.1499, IM Loss = 0.7741, IB Loss = 1.5443, Reg Loss = 0.0588  
rating shape: torch.Size([4096, 1524])
rating shape: torch.Size([4096, 1524])
rating shape: torch.Size([4096, 1524])
rating shape: torch.Size([3810, 1524])
hr: 

rating shape: torch.Size([4096, 1524])
rating shape: torch.Size([3810, 1524])
hr: 0.042686 	 mrr: 0.01145	 ndcg: 0.01856	 acc: 0.00324
Epoch 71/200, Train: Gen_1 Loss = 2.3897, Gen_2 Loss = 0.7030, BPR Loss = 0.1334, IM Loss = 0.7692, IB Loss = 1.5345, Reg Loss = 0.0693  
rating shape: torch.Size([4096, 1524])
rating shape: torch.Size([4096, 1524])
rating shape: torch.Size([4096, 1524])
rating shape: torch.Size([3810, 1524])
hr: 0.042503 	 mrr: 0.01144	 ndcg: 0.01851	 acc: 0.00331
Epoch 72/200, Train: Gen_1 Loss = 2.3961, Gen_2 Loss = 0.7030, BPR Loss = 0.1336, IM Loss = 0.7690, IB Loss = 1.5340, Reg Loss = 0.0696  
rating shape: torch.Size([4096, 1524])
rating shape: torch.Size([4096, 1524])
rating shape: torch.Size([4096, 1524])
rating shape: torch.Size([3810, 1524])
hr: 0.042629 	 mrr: 0.01157	 ndcg: 0.01862	 acc: 0.00350
Epoch 73/200, Train: Gen_1 Loss = 2.3933, Gen_2 Loss = 0.7030, BPR Loss = 0.1326, IM Loss = 0.7687, IB Loss = 1.5334, Reg Loss = 0.0700  
rating shape: torch.Size(

KeyboardInterrupt: 