In [None]:
import torch
from torch import nn, optim
from tqdm import tqdm
import math
from torch.utils.data import Dataset, DataLoader
import os
import json
import pandas as pd
import numpy as np

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
class ContinuousDiffusion(nn.Module):
    """Noising / posterior-sampling wrapper around a denoising ``model``.

    A linear schedule of ``noise_steps`` betas between ``beta_start`` and
    ``beta_end`` is precomputed, together with ``alpha = 1 - beta`` and the
    cumulative product ``alpha_hat``.

    Args:
        model: module called as ``model(x_t[None], context)`` to predict noise.
        noise_steps: number of entries in the beta schedule.
        beta_start: first beta of the linear schedule.
        beta_end: last beta of the linear schedule.
    """

    def __init__(self, model, noise_steps, beta_start, beta_end):
        super().__init__()
        self.noise_steps = noise_steps
        self.beta_start = beta_start
        self.beta_end = beta_end

        self.beta = self.prepare_noise_schedule().to(device)
        self.alpha = 1. - self.beta
        self.alpha_hat = torch.cumprod(self.alpha, dim=0)
        self.model = model
        self.device = device

    def prepare_noise_schedule(self):
        """Return the linear beta schedule as a 1-D tensor of ``noise_steps`` values."""
        return torch.linspace(self.beta_start, self.beta_end, self.noise_steps)

    def directional_noise(self, x, mean, std_dev):
        """Sample Gaussian noise ``mean + std_dev * eps`` and give it the sign of ``x``.

        The magnitude of the sampled noise is kept while its sign is replaced
        element-wise by ``sign(x)`` (broadcast against ``x``).
        """
        eps = torch.randn_like(std_dev)
        bar_eps = mean + torch.multiply(std_dev, eps)
        eps_prime = torch.multiply(torch.sign(x), torch.abs(bar_eps))
        return eps_prime

    def isotropic_noise(self, x=None):
        """Return standard-normal noise shaped like ``x``.

        The original implementation read ``self.std_dev``, an attribute that is
        never assigned anywhere in this class, so every call raised
        ``AttributeError``. Passing ``x`` makes the method usable; calling it
        without ``x`` still falls back to ``self.std_dev`` for compatibility.
        """
        reference = x if x is not None else self.std_dev
        return torch.randn_like(reference)

    def make_noise(self, x, t):
        """Noise ``x`` to step ``t``.

        Returns:
            ``(x_t, noise)`` where
            ``x_t = sqrt(alpha_hat[t]) * x + sqrt(1 - alpha_hat[t]) * noise`` and
            ``noise`` is directional noise built from the statistics of ``x``
            along dim 0.
        """
        mean = torch.mean(x, dim=0)
        std_dev = torch.std(x, dim=0)
        sqrt_alpha_hat = torch.sqrt(self.alpha_hat[t]).to(self.device)
        sqrt_one_minus_alpha_hat = torch.sqrt(1 - self.alpha_hat[t]).to(self.device)
        noise = self.directional_noise(x, mean, std_dev).to(self.device)

        return sqrt_alpha_hat * x + sqrt_one_minus_alpha_hat * noise, noise

    def forward(self, x_0, context):
        """Draw a random step, noise ``x_0`` and predict the noise with the model.

        Returns:
            ``(noise, pred_noise, t, x_t)``; ``t`` is a 1-element tensor drawn
            uniformly from ``[1, noise_steps)``.
        """
        t = torch.randint(low=1, high=self.noise_steps, size=(1,))
        x_t, noise = self.make_noise(x_0, t)
        pred_noise = self.model(torch.unsqueeze(x_t, dim=0), context)
        pred_noise = torch.squeeze(pred_noise)
        return noise, pred_noise, t, x_t

    def predict(self, x_0, x_t, eps, t):
        """Take one reverse step: ``mean(x_0, x_t) + sqrt(beta[t]) * eps``.

        ``mean`` is the weighted combination of ``x_0`` and ``x_t`` using the
        two schedule-derived coefficients computed below.

        The original code returned ``mean * x_t + std_dev * eps``, multiplying
        the mean by ``x_t`` a second time although ``x_t`` is already part of
        ``mean``; the reparameterised sample is ``mean + std_dev * eps``.
        """
        x_0_co = torch.sqrt(self.alpha_hat[t - 1]) * self.beta[t] / (1 - self.alpha_hat[t])
        x_0_co = x_0_co.to(self.device)
        x_t_co = torch.sqrt(self.alpha[t]) * (1 - self.alpha_hat[t - 1]) / (1 - self.alpha_hat[t])
        x_t_co = x_t_co.to(self.device)
        mean = x_0_co * x_0 + x_t_co * x_t
        std_dev = torch.sqrt(self.beta[t]).to(self.device)
        return mean + std_dev * eps

class LightGCN(nn.Module):
    """LightGCN-style encoder over a joint user/item node space.

    Users occupy node ids ``[0, num_users)`` and items the following
    ``num_items`` ids (the two embedding tables are concatenated in that order).

    Args:
        num_users: number of user nodes.
        num_items: number of item nodes.
        latent_dim: embedding size.
        n_layers: number of propagation steps.
    """

    def __init__(self, num_users, num_items, latent_dim=64, n_layers=3) -> None:
        super(LightGCN, self).__init__()

        self.num_users = num_users
        self.num_items = num_items
        self.latent_dim = latent_dim
        self.n_layers = n_layers
        self.embedding_users = nn.Embedding(num_embeddings=self.num_users,
                                            embedding_dim=self.latent_dim)
        self.embedding_items = nn.Embedding(num_embeddings=self.num_items,
                                            embedding_dim=self.latent_dim)
        nn.init.xavier_uniform_(self.embedding_users.weight, gain=1)
        nn.init.xavier_uniform_(self.embedding_items.weight, gain=1)

    def forward(self, pairs):
        """Propagate embeddings over the batch graph and build neighbour context.

        Args:
            pairs: two equal-length 1-D index sequences ``(row, col)``; each
                ``(row[i], col[i])`` is one edge of the batch graph.

        Returns:
            ``(output_embs, context)``, both of shape ``(num_nodes, latent_dim)``.
            ``output_embs`` is the mean of the layer-0 embeddings and the
            ``n_layers`` propagated embeddings. ``context[v]`` is the mean of
            ``output_embs`` over every edge endpoint paired with ``v`` in the
            batch, and a zero vector when ``v`` has no edge in the batch (the
            previous implementation produced NaN from 0/0 in that case).
        """
        # Work on the device the parameters live on so everything stays together.
        dev = self.embedding_users.weight.device
        row = torch.as_tensor(pairs[0]).to(dev).long()
        col = torch.as_tensor(pairs[1]).to(dev).long()
        num_nodes = self.num_items + self.num_users

        index = torch.stack([row, col], dim=0)
        values = torch.ones(row.shape[0], device=dev)
        graph = torch.sparse_coo_tensor(index, values, size=(num_nodes, num_nodes))

        all_emb = torch.cat([self.embedding_users.weight, self.embedding_items.weight])
        embs = [all_emb]
        for _ in range(self.n_layers):
            all_emb = torch.sparse.mm(graph, all_emb)
            embs.append(all_emb)
        output_embs = torch.mean(torch.stack(embs, dim=1), dim=1)

        # Vectorised replacement of the per-edge Python loop: every edge adds
        # the embedding of one endpoint to the context of the other endpoint.
        context = torch.zeros_like(output_embs)
        context.index_add_(0, row, output_embs[col])
        context.index_add_(0, col, output_embs[row])

        # Number of contributions per node; clamp so isolated nodes divide by 1.
        degree = torch.bincount(torch.cat([row, col]), minlength=num_nodes)
        degree = degree.clamp(min=1).to(output_embs.dtype).unsqueeze(1)
        context = context / degree
        return output_embs, context

def route_args(router, args, depth):
    """Distribute keyword arguments to the two sub-layers of each depth level.

    Args:
        router: maps an argument name to a per-layer sequence of
            ``(to_f, to_g)`` flags.
        args: keyword arguments to distribute; names missing from ``router``
            are ignored.
        depth: number of layers.

    Returns:
        A list of ``depth`` tuples ``(f_kwargs, g_kwargs)``.
    """
    routed_args = [({}, {}) for _ in range(depth)]

    for key, val in args.items():
        if key not in router:
            continue
        for layer_idx, (current, routes) in enumerate(zip(routed_args, router[key])):
            f_args, g_args = current
            f_route, g_route = routes
            f_extra = {key: val} if f_route else {}
            g_extra = {key: val} if g_route else {}
            routed_args[layer_idx] = ({**f_args, **f_extra}, {**g_args, **g_extra})
    return routed_args

def layer_drop(layers, prob):
    """Randomly drop entries of ``layers``, each with probability ``prob``.

    When every entry would be dropped, the first one is kept (``layers[:1]``).
    """
    drop_mask = torch.empty(len(layers)).uniform_(0, 1) < prob
    kept = []
    for block, dropped in zip(layers, drop_mask):
        if not dropped:
            kept.append(block)
    if not kept:
        return layers[:1]
    return kept

def default(val, default_val):
    """Return ``val`` unless it is ``None``, in which case return ``default_val``."""
    if val is None:
        return default_val
    return val

def init_(tensor):
    """Fill ``tensor`` in place with U(-b, b), b = 1 / sqrt(last dim), and return it."""
    bound = 1 / math.sqrt(tensor.shape[-1])
    tensor.uniform_(-bound, bound)
    return tensor

# helper classes

class Residual(nn.Module):
    """Wrap ``fn`` with a skip connection: ``forward(x) = x + fn(x)``."""

    def __init__(self, fn):
        super(Residual, self).__init__()
        self.fn = fn

    def forward(self, x):
        branch = self.fn(x)
        return x + branch

class PreNorm(nn.Module):
    """Apply ``LayerNorm(dim)`` to the input before calling ``fn`` on it."""

    def __init__(self, dim, fn):
        super(PreNorm, self).__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        return self.fn(self.norm(x))

class GELU_(nn.Module):
    """Tanh approximation of GELU, used when ``torch.nn`` provides no ``GELU``."""

    def forward(self, x):
        inner = math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))
        return 0.5 * x * (1 + torch.tanh(inner))

# Prefer the built-in activation when this torch version ships one.
GELU = getattr(nn, 'GELU', GELU_)

class FeedForward(nn.Module):
    """Two-layer position-wise MLP with an optional gated (GLU) first layer.

    Args:
        dim: input and output feature size.
        mult: hidden size multiplier (hidden size is ``dim * mult``).
        dropout: dropout probability applied after the activation.
        activation: activation class; defaults to ``GELU``.
        glu: when true, the first layer produces twice the hidden size and one
            half gates the other.
    """

    def __init__(self, dim, mult = 4, dropout = 0., activation = None, glu = False):
        super().__init__()
        activation = default(activation, GELU)
        hidden_dim = dim * mult

        self.glu = glu
        self.w1 = nn.Linear(dim, hidden_dim * 2 if glu else hidden_dim)
        self.act = activation()
        self.dropout = nn.Dropout(dropout)
        self.w2 = nn.Linear(hidden_dim, dim)

    def forward(self, x, **kwargs):
        hidden = self.w1(x)
        if self.glu:
            # Split into the activated half and the gating half.
            gate, value = hidden.chunk(2, dim=-1)
            hidden = self.act(gate) * value
        else:
            hidden = self.act(hidden)

        return self.w2(self.dropout(hidden))

class LinformerSelfAttention(nn.Module):
    """Multi-head attention whose keys/values are projected from length n to ``k``.

    Args:
        dim: model feature size; must be divisible by ``heads``.
        seq_len: maximum key/value sequence length (rows of the projections).
        k: projected sequence length.
        heads: number of attention heads.
        dim_head: per-head size; defaults to ``dim // heads``.
        one_kv_head: share a single key/value head across all query heads.
        share_kv: reuse keys (and their projection) as values.
        dropout: dropout probability on the attention weights.
    """

    def __init__(self, dim, seq_len, k = 256, heads = 8, dim_head = None, one_kv_head = False, share_kv = False, dropout = 0.):
        super().__init__()
        assert (dim % heads) == 0, 'dimension must be divisible by the number of heads'

        self.seq_len = seq_len
        self.k = k
        self.heads = heads
        self.dim_head = default(dim_head, dim // heads)

        inner_dim = self.dim_head * heads
        kv_dim = self.dim_head if one_kv_head else inner_dim

        self.to_q = nn.Linear(dim, inner_dim, bias = False)

        self.to_k = nn.Linear(dim, kv_dim, bias = False)
        self.proj_k = nn.Parameter(init_(torch.zeros(seq_len, k)))

        self.share_kv = share_kv
        if not share_kv:
            self.to_v = nn.Linear(dim, kv_dim, bias = False)
            self.proj_v = nn.Parameter(init_(torch.zeros(seq_len, k)))

        self.dropout = nn.Dropout(dropout)
        self.to_out = nn.Linear(inner_dim, dim)

    def forward(self, x, context = None, **kwargs):
        """Attend from ``x`` (batch, n, dim) to ``context`` (or to ``x`` itself)."""
        b, n, _ = x.shape
        d_h, h, k = self.dim_head, self.heads, self.k

        kv_input = x if context is None else context
        kv_len = kv_input.shape[1]
        assert kv_len <= self.seq_len, f'the sequence length of the key / values must be {self.seq_len} - {kv_len} given'

        queries = self.to_q(x)
        keys = self.to_k(kv_input)

        if self.share_kv:
            values = keys
            proj_k = proj_v = self.proj_k
        else:
            values = self.to_v(kv_input)
            proj_k, proj_v = self.proj_k, self.proj_v

        # Shorter sequences only use the leading rows of the projections.
        if kv_len < self.seq_len:
            proj_k, proj_v = proj_k[:kv_len], proj_v[:kv_len]

        # Compress the sequence axis of keys and values down to k.
        keys = torch.einsum('bnd,nk->bkd', keys, proj_k)
        values = torch.einsum('bnd,nk->bkd', values, proj_v)

        # Move heads next to the batch axis; a single kv head is broadcast.
        queries = queries.reshape(b, n, h, -1).transpose(1, 2)
        keys = keys.reshape(b, k, -1, d_h).transpose(1, 2).expand(-1, h, -1, -1)
        values = values.reshape(b, k, -1, d_h).transpose(1, 2).expand(-1, h, -1, -1)

        # Scaled dot-product attention over the k projected positions.
        dots = torch.einsum('bhnd,bhkd->bhnk', queries, keys) * (d_h ** -0.5)
        attn = self.dropout(dots.softmax(dim=-1))
        out = torch.einsum('bhnk,bhkd->bhnd', attn, values)

        # Merge the heads back into the feature axis.
        out = out.transpose(1, 2).reshape(b, n, -1)
        return self.to_out(out)
    
class SequentialSequence(nn.Module):
    """Run ``(f, g)`` layer pairs in order, each with a residual connection.

    Args:
        layers: sequence of ``(f, g)`` module pairs.
        args_route: maps a keyword-argument name to one ``(to_f, to_g)`` flag
            pair per layer; only routed keyword arguments reach the layers.
            Defaults to no routing.
        layer_dropout: probability of skipping a layer pair during training.
    """

    def __init__(self, layers, args_route = None, layer_dropout = 0.):
        super().__init__()
        # A literal `{}` default would be one dict shared by every instance
        # (it is stored on `self`), so create a fresh one per instance instead.
        args_route = {} if args_route is None else args_route
        assert all(len(route) == len(layers) for route in args_route.values()), 'each argument route map must have the same depth as the number of sequential layers'
        self.layers = layers
        self.args_route = args_route
        self.layer_dropout = layer_dropout

    def forward(self, x, **kwargs):
        args = route_args(self.args_route, kwargs, len(self.layers))
        layers_and_args = list(zip(self.layers, args))

        if self.training and self.layer_dropout > 0:
            layers_and_args = layer_drop(layers_and_args, self.layer_dropout)

        for (f, g), (f_args, g_args) in layers_and_args:
            x = x + f(x, **f_args)
            x = x + g(x, **g_args)
        return x

class Linformer(nn.Module):
    """Stack of ``depth`` pre-norm (attention, feed-forward) residual layers.

    NOTE(review): ``forward`` hands ``context`` to ``SequentialSequence`` as a
    keyword argument, but the sequence is built without an ``args_route``, so
    the keyword is not forwarded to any layer and the attention always runs as
    self-attention on ``x``.
    """

    def __init__(self, dim, seq_len, depth, k = 256, heads = 8, dim_head = None, one_kv_head = False, share_kv = False, dropout = 0.):
        super().__init__()

        def make_layer():
            # Attention is created before the feed-forward so parameter
            # initialisation order stays the same.
            attention = LinformerSelfAttention(dim, seq_len, k = k, heads = heads, dim_head = dim_head, one_kv_head = one_kv_head, share_kv = share_kv, dropout = dropout)
            feed_forward = FeedForward(dim, dropout = dropout)
            return nn.ModuleList([PreNorm(dim, attention), PreNorm(dim, feed_forward)])

        layers = nn.ModuleList([make_layer() for _ in range(depth)])
        self.net = SequentialSequence(layers)

    def forward(self, x, context):
        return self.net(x, context=context)
    

def cosine_similarity(emb_1, emb_2):
    """Return the cosine of the angle between two vectors given as array-likes."""
    vec_a, vec_b = np.array(emb_1), np.array(emb_2)
    norm_product = np.linalg.norm(vec_a) * np.linalg.norm(vec_b)
    return np.dot(vec_a, vec_b) / norm_product

def topK(target, dictionary, top_k=5):
    """Pair ``target`` with the ``top_k`` keys of ``dictionary`` having the largest values.

    Returns a list of ``(target, key)`` tuples ordered by descending value.
    """
    ranked = sorted(dictionary.items(), key=lambda pair: pair[1], reverse=True)
    result = []
    for key, _score in ranked[:top_k]:
        result.append((target, key))
    return result



class TrainDataset(Dataset):
    """Edge dataset: similarity-augmented pairs plus the training interactions.

    If ``{dataset_path}/augmented_pair_{n_enrich}.csv`` exists, the pairs are
    read from it. Otherwise they are built from the ``n_enrich`` most
    cosine-similar users per user, the ``n_enrich`` most similar items per
    item, and every training interaction in both directions.

    Args:
        dataset_path: directory that may contain the cached pair file.
        user_feature: mapping of user id to feature vector.
        item_feature: mapping of item id to feature vector.
        UI_train: table with ``user_ID`` and ``item_ID`` columns.
        n_enrich: number of similar neighbours added per user / item.
    """

    def __init__(self, dataset_path, user_feature, item_feature, UI_train, n_enrich) -> None:
        super(TrainDataset, self).__init__()

        self.UI_train = UI_train
        self.num_users = len(user_feature)
        self.num_items = len(item_feature)

        cache_file = f'{dataset_path}/augmented_pair_{n_enrich}.csv'
        if os.path.isfile(cache_file):
            import csv
            with open(cache_file, 'r') as file:
                augmented_pair = [(int(row[0]), int(row[1])) for row in csv.reader(file)]
        else:
            augmented_pair = []
            augmented_pair.extend(self._similar_pairs(user_feature, n_enrich))
            augmented_pair.extend(self._similar_pairs(item_feature, n_enrich))

            interactions = [
                (int(user), int(item))
                for user, item in zip(self.UI_train['user_ID'], self.UI_train['item_ID'])
            ]
            augmented_pair.extend(interactions)
            augmented_pair.extend([(item, user) for user, item in interactions])
        self.augmented_pair = augmented_pair

    @staticmethod
    def _similar_pairs(features, n_enrich):
        """Return ``(id, neighbour_id)`` for the ``n_enrich`` most similar entries of each id.

        Ids are cast to ``int``: feature dicts loaded with ``json.load`` have
        string keys, while the cached-CSV branch and ``get_sparse_graph`` work
        with integer node ids.
        NOTE(review): assumes the feature keys already are node ids in the
        joint user/item index space - confirm against the data files.
        """
        pairs = []
        for anchor in features:
            similarity = {
                other: cosine_similarity(features[anchor], features[other])
                for other in features
                if other != anchor
            }
            pairs.extend((int(src), int(dst)) for src, dst in topK(anchor, similarity, n_enrich))
        return pairs

    def get_sparse_graph(self):
        """Return the pairs as a sparse ``(num_nodes, num_nodes)`` tensor with unit weights."""
        data = [1.0] * len(self.augmented_pair)
        num_nodes = self.num_items + self.num_users
        return torch.sparse_coo_tensor(torch.tensor(self.augmented_pair).t(), torch.tensor(data), size=(num_nodes, num_nodes))

    def __getitem__(self, index):
        return self.augmented_pair[index]

    def __len__(self):
        return len(self.augmented_pair)

class TestDataset(Dataset):
    """Edge dataset holding every ``(user, item)`` interaction in both directions."""

    def __init__(self, UI, num_users, num_items):
        self.num_users = num_users
        self.num_items = num_items
        forward_pairs = list(zip(UI['user_ID'], UI['item_ID']))
        reverse_pairs = [(item, user) for user, item in forward_pairs]
        # All user->item pairs first, followed by their item->user mirrors.
        self.pairs = forward_pairs + reverse_pairs

    def __getitem__(self, index):
        return self.pairs[index]

    def __len__(self):
        return len(self.pairs)

    def get_sparse_graph(self):
        """Return the pairs as a sparse ``(num_nodes, num_nodes)`` tensor with unit weights."""
        num_nodes = self.num_items + self.num_users
        edge_index = torch.tensor(self.pairs).t()
        edge_weight = torch.tensor([1.0] * len(self.pairs))
        return torch.sparse_coo_tensor(edge_index, edge_weight, size=(num_nodes, num_nodes))
        


class Datasets:
    """Load the feature files and interaction tables and build the three data loaders.

    Reads ``item_feature.json``, ``user_feature.json`` and the
    ``user_item_{train,val,test}.csv`` files from ``dataset_path``.

    Args:
        dataset_path: directory containing the files listed above.
        batch_size: batch size of the training loader (validation and test
            loaders return their whole dataset as a single batch).
        n_enrich: number of similar neighbours per node, passed to ``TrainDataset``.
    """

    def __init__(self, dataset_path, batch_size, n_enrich) -> None:
        self.batch_size = batch_size
        # Use context managers so the JSON file handles are closed right away.
        with open(f'{dataset_path}/item_feature.json') as item_file:
            item_feature = json.load(item_file)
        with open(f'{dataset_path}/user_feature.json') as user_file:
            user_feature = json.load(user_file)
        self.UI_train = pd.read_csv(f'{dataset_path}/user_item_train.csv')
        self.UI_test = pd.read_csv(f'{dataset_path}/user_item_test.csv')
        self.UI_val = pd.read_csv(f'{dataset_path}/user_item_val.csv')
        self.num_users = len(user_feature)
        self.num_items = len(item_feature)

        self.train_data = TrainDataset(dataset_path, user_feature, item_feature, self.UI_train, n_enrich)
        self.val_data = TestDataset(self.UI_val, self.num_users, self.num_items)
        self.test_data = TestDataset(self.UI_test, self.num_users, self.num_items)
        self.train_dataloader = DataLoader(self.train_data, batch_size=self.batch_size, shuffle=True)
        self.val_dataloader = DataLoader(self.val_data, batch_size=len(self.val_data), shuffle=False)
        self.test_dataloader = DataLoader(self.test_data, batch_size=len(self.test_data), shuffle=False)

In [None]:
def calculate_recall(prediction, ground_truth):
    """Pooled recall: hits over all users divided by all ground-truth labels.

    Args:
        prediction: indexable by user id; each entry is the sequence (list or
            tensor row) of predicted labels for that user.
        ground_truth: dict mapping user id to the list of true labels.

    Returns:
        Fraction of ground-truth labels found in the corresponding prediction,
        or ``0.0`` when there are no ground-truth labels at all (previously a
        ``ZeroDivisionError``).
    """
    total_true_positives = 0
    total_ground_truth = 0

    for user_idx, ground_truth_labels in ground_truth.items():
        user_prediction = prediction[user_idx]
        # Tensor / array rows are turned into plain Python values so that
        # membership is a hash lookup instead of an element-wise comparison.
        if hasattr(user_prediction, 'tolist'):
            user_prediction = user_prediction.tolist()
        predicted = set(user_prediction)

        total_true_positives += sum(1 for label in ground_truth_labels if label in predicted)
        total_ground_truth += len(ground_truth_labels)

    if total_ground_truth == 0:
        return 0.0
    return total_true_positives / total_ground_truth

def calculate_ndcg(prediction, ground_truth):
    """Mean binary-relevance NDCG over the users in ``ground_truth``.

    For each user the ranked prediction is turned into a 0/1 gain list
    (1 where the predicted label is a ground-truth label), so every hit is
    discounted by its actual rank. The ideal DCG places all relevant labels
    (capped at the prediction length) at the top.

    The previous implementation fed the label ids themselves into
    ``calculate_dcg`` as gains and re-enumerated only the hits, so the score
    depended on item id magnitudes and ignored the rank of each hit.

    Args:
        prediction: indexable by user id; each entry is the ranked sequence
            (list or tensor row) of predicted labels.
        ground_truth: dict mapping user id to the list of true labels.

    Returns:
        Average NDCG as a float; ``0.0`` when ``ground_truth`` is empty. A user
        without any ground-truth label contributes ``1.0``, as before.
    """
    if len(ground_truth) == 0:
        return 0.0

    total_ndcg = 0.0
    for user_idx, ground_truth_labels in ground_truth.items():
        user_prediction = prediction[user_idx]
        # Tensor rows become plain Python ints so set membership works by value.
        if hasattr(user_prediction, 'tolist'):
            user_prediction = user_prediction.tolist()

        relevant = set(ground_truth_labels)
        if not relevant:
            total_ndcg += 1.0
            continue

        gains = [1.0 if label in relevant else 0.0 for label in user_prediction]
        ideal_dcg = calculate_dcg([1.0] * min(len(relevant), len(gains)))
        total_ndcg += calculate_dcg(gains) / ideal_dcg if ideal_dcg > 0 else 0.0

    return total_ndcg / len(ground_truth)
import math
def calculate_dcg(labels):
    """Discounted cumulative gain of ``labels``, read as relevance scores in rank order.

    The entry at 0-based rank ``i`` contributes ``label / log2(i + 2)``.
    """
    return sum(label / math.log2(rank + 2) for rank, label in enumerate(labels))

In [None]:
# Experiment configuration and data.
n_enrich = 10
dataset_path = '/kaggle/input/diffgt/datasets/foursquare'
dataset = Datasets(dataset_path=dataset_path, batch_size=2048, n_enrich=n_enrich)

# Graph encoder producing the node embeddings that are diffused.
graph_encoder = LightGCN(dataset.num_users, dataset.num_items)
graph_encoder.train()
graph_encoder.to(device)

# Linformer used by the diffusion wrapper as the noise predictor.
decoder = Linformer(dim=64, seq_len=dataset.num_users + dataset.num_items, depth=3)
decoder.train()
decoder.to(device)
dataloader = dataset.train_dataloader
diffusion = ContinuousDiffusion(decoder, noise_steps=100, beta_start=0, beta_end=1)
diffusion.to(device)

# Optimise the encoder embeddings together with the decoder. The optimizer was
# previously built from `decoder.parameters()` only, so the LightGCN embeddings
# received gradients but were never updated (and never had them zeroed).
trainable_params = list(graph_encoder.parameters()) + list(decoder.parameters())
optimizer = optim.AdamW(trainable_params, lr=1e-3)
diffusion_loss = nn.MSELoss()


def contrastive_loss(a, b, temp=0.2):
    """InfoNCE loss between row-aligned embeddings ``a`` and ``b``.

    Rows are L2-normalised; row ``i`` of ``a`` is the positive for row ``i`` of
    ``b`` and every other row of ``b`` acts as a negative. ``temp`` scales the
    similarity logits.
    """
    a_unit = nn.functional.normalize(a, dim=-1)
    b_unit = nn.functional.normalize(b, dim=-1)
    logits = torch.mm(a_unit, b_unit.T) / temp
    targets = torch.arange(a.shape[0]).to(a.device)
    return nn.functional.cross_entropy(logits, targets)
import random
def bpr_loss(input, emb, num_users=None, num_items=None):
    """BPR loss over the user-item edges contained in a batch.

    Every batch edge with one endpoint in the user range ``[0, num_users)`` and
    the other in the item range ``[num_users, num_users + num_items)`` is a
    (user, positive item) sample; one negative item is drawn uniformly at
    random per sample. The loss is ``mean(softplus(neg_score - pos_score))``
    with scores being embedding dot products.

    Fixes over the previous version:
      * positives/negatives were stored by batch position in arrays of length
        ``num_users`` (``IndexError`` once the batch is larger than
        ``num_users``, and misaligned with the user embeddings otherwise);
      * scores were summed over ``dim=0`` (across users) instead of over the
        embedding dimension;
      * users without a positive in the batch were scored against themselves.

    Args:
        input: pair ``(first, second)`` of equal-length 1-D index tensors.
        emb: ``(num_nodes, dim)`` node embeddings, users first.
        num_users: number of users; defaults to ``dataset.num_users``.
        num_items: number of items; defaults to ``dataset.num_items``.

    Returns:
        Scalar loss tensor (zero when the batch has no user-item edge).
    """
    if num_users is None:
        num_users = dataset.num_users
    if num_items is None:
        num_items = dataset.num_items

    first = torch.as_tensor(input[0]).to(emb.device).long()
    second = torch.as_tensor(input[1]).to(emb.device).long()

    # Order each edge so the smaller id comes first; edges are stored in both
    # directions, so the user is always the smaller endpoint of a user-item edge.
    low = torch.min(first, second)
    high = torch.max(first, second)
    is_user_item = (low < num_users) & (high >= num_users)
    users, pos_items = low[is_user_item], high[is_user_item]

    if users.numel() == 0:
        return emb.new_zeros(())

    neg_items = torch.randint(num_users, num_users + num_items, users.shape, device=emb.device)

    user_emb = emb[users]
    pos_scores = torch.sum(user_emb * emb[pos_items], dim=-1)
    neg_scores = torch.sum(user_emb * emb[neg_items], dim=-1)
    return torch.mean(torch.nn.functional.softplus(neg_scores - pos_scores))

def evaluate(dataloader, num_users, num_items, topK):
    """Print Recall@topK and NDCG@topK for every batch of ``dataloader``.

    Uses the module-level ``graph_encoder``, ``decoder`` and ``diffusion``.
    Both networks are put in eval mode for the duration of the call and back in
    train mode afterwards.
    """
    graph_encoder.eval()
    decoder.eval()

    with torch.no_grad():
        for input in dataloader:
            x, context = graph_encoder(input)
            _, pred_noise, t, x_t = diffusion(x, context)
            emb = diffusion.predict(x, x_t, pred_noise, t)

            # Score every user against every item and keep the best topK items.
            user_emb, item_emb = emb[:num_users, :], emb[num_users:, :]
            ranking = torch.matmul(user_emb, item_emb.t())
            predicted_indices = torch.topk(ranking, topK, dim=1)[1]
            # Shift item-local columns into the joint node id space.
            predicted_indices += num_users

            target_indices = {user_id: [] for user_id in range(num_users)}
            for first, second in zip(input[0], input[1]):
                user, item = first.item(), second.item()
                if user > item:
                    # Reversed edge: the smaller id is the user.
                    user, item = item, user
                target_indices[user].append(item)

            recall = calculate_recall(predicted_indices, target_indices)
            NDCG = calculate_ndcg(predicted_indices, target_indices)
            print(f'Recall@{topK}: {recall}, NDCG@{topK}: {NDCG}')

    graph_encoder.train()
    decoder.train()

In [None]:
import time

# Training loop: encode the batch graph, diffuse / denoise the embeddings and
# optimise the sum of the diffusion, contrastive and BPR losses.
for epoch in range(100):
    total_loss = 0.0
    num_batches = 0
    for batch_idx, input in enumerate(tqdm(dataloader, total=len(dataloader))):
        start = time.time()
        x, context = graph_encoder(input)
        end_encoder = time.time()
        noise, pred_noise, t, x_t = diffusion(x, context)
        end_diffusion = time.time()
        emb = diffusion.predict(x, x_t, pred_noise, t)
        end_predict = time.time()
        print("LightGCN:", end_encoder - start)
        print("Diffusion:", end_diffusion - end_encoder)
        print("Prediction:", end_predict - end_diffusion)

        loss = diffusion_loss(noise, pred_noise) + contrastive_loss(x, emb) + bpr_loss(input, emb)
        end_loss = time.time()
        print("Loss:", end_loss - end_predict)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Accumulate the detached scalar exactly once. The loss used to be
        # added twice, the second time as a tensor that turned `total_loss`
        # into a graph-attached tensor.
        total_loss += loss.item()
        num_batches += 1

    # Average over the number of batches seen; dividing by the last
    # `batch_idx` was off by one and zero for a single-batch epoch.
    mean_loss = total_loss / max(num_batches, 1)
    print("\tEpoch", epoch + 1, "complete!", "\tTotal loss: ", mean_loss)
    evaluate(dataset.val_dataloader, dataset.num_users, dataset.num_items, topK=20)