In [1]:
import json
import pandas as pd
import tqdm
from src.scrapped_data_decoders.multi_decoder import MultiDecoder
from src.preprocessing.user_featurizer import UserFeaturizer

In [2]:
# Load the scraped dump. The encoding is pinned to UTF-8 so that decoding does
# not depend on the platform's default locale encoding.
with open("mongo_dump.json", encoding="utf-8") as file:
    data = json.load(file)

In [3]:
# Decode every entry of data['users'], then turn the decoded preferences into
# the feature matrix used by the model below.
decoder = MultiDecoder()
user_preferences = list(map(decoder.decode, data['users']))
featurizer = UserFeaturizer()
user_features = featurizer.build_feature_matrix(user_preferences)

In [4]:
# Map each record's `_id` to its position in data['communities']; that position
# serves as the dense integer id used by the edge and matrix builders.
# NOTE(review): ids are positions in data['communities'], not in data['users'] —
# confirm both lists share the same order/length before using these ids to
# index rows of `user_features`.
public_user_id_to_research_id = {
    record['_id']: research_id
    for research_id, record in enumerate(data['communities'])
}

In [5]:
from src.preprocessing.edge_builder import EdgeBuilder
# Build the graph edges from data['topology'], translating public ids through
# the public-id -> research-id mapping handed to the builder.
edge_builder = EdgeBuilder(public_user_id_to_research_id)
edges = edge_builder.build(data['topology'])

100%|██████████| 116641/116641 [00:02<00:00, 40700.26it/s]


In [7]:
from src.preprocessing.community_subscription_matrix_builder import CommunitySubscriptionMatrixBuilder
# Build the community-subscription matrix from data['communities'], using the
# same public-id -> research-id mapping as the edge builder.
# The class name must match the import above exactly; a misspelling raises
# NameError on a fresh kernel.
matrix_builder = CommunitySubscriptionMatrixBuilder()
matrix = matrix_builder.build(data['communities'], public_user_id_to_research_id)

100%|██████████| 92153/92153 [00:03<00:00, 24091.25it/s]
100%|██████████| 92153/92153 [00:08<00:00, 11241.07it/s]


In [8]:
# Largest category code observed in each categorical column; these maxima (+1)
# are later passed to NN as the embedding table sizes.
user_features.categorical.max(axis=0)

array([10, 10,  2,  9,  8,  6,  5], dtype=int32)

In [9]:
# Dimensions of the categorical feature matrix.
user_features.categorical.shape

(116641, 7)

# PyTorch model

In [10]:
from enum import Enum
class ConvType(Enum):
    """Selects which graph convolution operator `NN` stacks."""
    gcn = 1  # GCNConv
    gat = 2  # GATConv
    res = 3  # ResGatedGraphConv
    gin = 4  # GINConv wrapping an MLP

In [59]:
import torch
from torch.nn import ReLU, Dropout, Embedding
from torch_geometric.nn import MLP, Sequential, BatchNorm
from torch_geometric.nn.conv import GCNConv, ResGatedGraphConv, GATConv, GINConv

class NN(torch.nn.Module):
    """Node-level graph network: input MLP -> graph conv blocks -> output MLP -> sigmoid.

    Args:
        input_dim: Number of input feature columns fed to the input MLP.
        hidden_dim: Width of every hidden representation (MLPs and convolutions).
        output_dim: Number of outputs per node; each one is passed through a sigmoid.
        graph_conv: Which graph convolution operator to stack.
        depth: Number of graph convolution blocks.
        mlp_depth: Number of ``hidden_dim``-wide entries in each MLP's channel list.
        embeddings: Vocabulary size of each categorical feature, in the order
            city, country, sex, politics, life, people, alcohol.
        use_batchnorm: If true, insert a ``BatchNorm`` after every convolution.
        dropout: Dropout probability applied at the end of every convolution
            block; ``None`` disables the dropout layer entirely.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        graph_conv: ConvType,
        depth: int,
        mlp_depth: int,
        embeddings: list[int],
        use_batchnorm: bool=False,
        dropout: float=0,
    ):
        super().__init__()
        self.hidden_dim = hidden_dim
        # Must be set from the argument before the conv blocks are built,
        # because `_build_conv_layer` reads it.
        self.use_batchnorm = use_batchnorm
        self.dropout = dropout
        self.input_mlp = MLP([input_dim] + [hidden_dim] * mlp_depth, plain_last=False)
        self.conv_layers = Sequential(
            "x, edge_index",
            [self._build_conv_layer(graph_conv) for _ in range(depth)],
        )
        self.output_mlp = MLP([hidden_dim] * mlp_depth + [output_dim])
        self.sigmoid = torch.nn.Sigmoid()
        self._build_embeddings(embeddings)

    def _build_conv_layer(self, conv_type: ConvType):
        """Build one convolution block: conv -> [batch norm] -> ReLU -> [dropout].

        Returns:
            A ``(module, signature)`` tuple that can be placed directly into the
            outer ``Sequential``; the module is itself a ``Sequential`` holding
            the whole block.

        Raises:
            NotImplementedError: If ``conv_type`` is not a supported ``ConvType``.
        """
        if conv_type == ConvType.gcn:
            conv = GCNConv(self.hidden_dim, self.hidden_dim)
        elif conv_type == ConvType.gat:
            conv = GATConv(self.hidden_dim, self.hidden_dim)
        elif conv_type == ConvType.res:
            conv = ResGatedGraphConv(self.hidden_dim, self.hidden_dim)
        elif conv_type == ConvType.gin:
            # `train_eps` is an option of GINConv, not of the MLP it wraps.
            conv = GINConv(MLP([self.hidden_dim, self.hidden_dim]), train_eps=True)
        else:
            raise NotImplementedError()

        steps = [(conv, 'x, edge_index -> x')]
        if self.use_batchnorm:
            steps.append((BatchNorm(self.hidden_dim), 'x -> x'))
        steps.append((ReLU(inplace=True), 'x -> x'))
        if self.dropout is not None:
            steps.append((Dropout(p=self.dropout), 'x -> x'))
        block = Sequential("x, edge_index", steps)
        # Return the assembled block (not just the bare convolution) so the
        # normalisation, activation and dropout actually take part in the model.
        return block, 'x, edge_index -> x'

    def _build_embeddings(self, embedding_sizes):
        """Create one embedding table per categorical feature.

        ``embedding_sizes[i]`` is the number of rows of the i-th table. Each
        table is assigned as an attribute (which registers it as a submodule)
        and also collected, in column order, in the plain list
        ``self.embeddings`` for positional lookup.
        """
        self.city_embedding = Embedding(embedding_sizes[0], embedding_dim=4, scale_grad_by_freq=True)
        self.country_embedding = Embedding(embedding_sizes[1], embedding_dim=4, scale_grad_by_freq=True)
        self.sex_embedding = Embedding(embedding_sizes[2], embedding_dim=2, scale_grad_by_freq=True)
        self.politics_embedding = Embedding(embedding_sizes[3], embedding_dim=4, scale_grad_by_freq=True)
        self.life_embedding = Embedding(embedding_sizes[4], embedding_dim=4, scale_grad_by_freq=True)
        self.people_embedding = Embedding(embedding_sizes[5], embedding_dim=4, scale_grad_by_freq=True)
        self.alcohol_embedding = Embedding(embedding_sizes[6], embedding_dim=4, scale_grad_by_freq=True)
        self.embeddings = [
            self.city_embedding,
            self.country_embedding,
            self.sex_embedding,
            self.politics_embedding,
            self.life_embedding,
            self.people_embedding,
            self.alcohol_embedding,
        ]

    def forward(self, features, edge_index):
        """Compute per-node sigmoid outputs.

        Args:
            features: Object exposing a ``numerical`` array of node features.
                NOTE(review): presumably (num_nodes, input_dim) — confirm
                against UserFeaturizer.
            edge_index: Graph connectivity forwarded to every convolution.

        Returns:
            Tensor of sigmoid activations produced by the output MLP.
        """
        x = torch.Tensor(features.numerical)
        # TODO: categorical embeddings are built but not used yet. Enabling them
        # requires `input_dim` to include the embedding widths:
        #   categorical = torch.tensor(features.categorical, dtype=torch.long)
        #   x = torch.cat([x, self._extract_embeddings(categorical)], dim=-1)
        x = self.input_mlp(x)
        x = self.conv_layers(x, edge_index)
        x = self.output_mlp(x)
        return self.sigmoid(x)

    def _extract_embeddings(self, categorical_features) -> torch.Tensor:
        """Embed every categorical column and concatenate along the last axis.

        Column ``i`` of ``categorical_features`` is looked up in
        ``self.embeddings[i]``.
        """
        per_feature = [
            self.embeddings[column](categorical_features[:, column])
            for column in range(categorical_features.shape[1])
        ]
        return torch.cat(per_feature, dim=-1)

In [63]:
# Each embedding table needs one row per category code, i.e. largest code + 1.
embedding_sizes = user_features.categorical.max(axis=0) + 1

# Three stacked GCN layers between two MLPs, all 10 units wide.
nn = NN(
    graph_conv=ConvType.gcn,
    depth=3,
    mlp_depth=2,
    input_dim=4,
    hidden_dim=10,
    output_dim=10,
    embeddings=embedding_sizes,
)

In [69]:
# Run one forward pass over the whole graph; the model ends in a sigmoid, so
# every output lies in (0, 1).
# NOTE(review): assumes `edges` converts to the edge_index layout the graph
# convolutions expect — confirm against EdgeBuilder.
nn(user_features, torch.tensor(edges))

tensor([[0.7273, 0.4393, 0.4939,  ..., 0.2790, 0.2943, 0.5061],
        [0.4792, 0.5593, 0.4500,  ..., 0.4454, 0.5442, 0.5247],
        [0.6064, 0.4894, 0.4768,  ..., 0.3601, 0.4045, 0.5053],
        ...,
        [0.5164, 0.6017, 0.4197,  ..., 0.4047, 0.4981, 0.4788],
        [0.5442, 0.6064, 0.4262,  ..., 0.3870, 0.5016, 0.4924],
        [0.6935, 0.6084, 0.4350,  ..., 0.5555, 0.4004, 0.5845]],
       grad_fn=<SigmoidBackward0>)