#### Libraries/Imports

In [115]:
import requests

from bson import ObjectId
from pydantic.json import ENCODERS_BY_TYPE

from pydantic import BaseModel, Field
from typing import List, Optional

import numpy as np

import torch
from transformers import AutoTokenizer, AutoModel

##### Helper Models

In [116]:
class PydanticObjectId(ObjectId):
    """
    Object Id field. Compatible with Pydantic.
    """

    @classmethod
    def __get_validators__(cls):
        yield cls.validate

    @classmethod
    def validate(cls, v):
        return PydanticObjectId(v)

    @classmethod
    def __modify_schema__(cls, field_schema: dict):
        field_schema.update(
            type="string",
        )


ENCODERS_BY_TYPE[PydanticObjectId] = str

In [117]:
class UserForum(BaseModel):
    id: Optional[PydanticObjectId] = Field(None, alias="_id")
    following_ids: List[PydanticObjectId]

class Post(BaseModel):
    id: Optional[PydanticObjectId] = Field(None, alias="_id")
    author_id: PydanticObjectId
    title: str
    content: str
    hashtags: List[str]
    response_to_id: Optional[PydanticObjectId]

##### API URL

In [118]:
url = 'http://localhost:5000/posts'

In [119]:
try:
    response = requests.get(url)

    if response.status_code == 200:
        print("Response content:")
        response_content = response.json()
        print(response_content)
    else:
        print(f"Failed to retrieve data. Status code: {response.status_code}")

except requests.exceptions.RequestException as e:
    print(f"Error occurred: {e}")

Response content:
{'posts': [{'_id': '65e3204ce64d1e43b6dd7875', 'author_id': '65e1f82be64d1e71f2a9226b', 'content': 'My special thanks to @at0mul @z @abc', 'hashtags': ['hub_page'], 'title': 'Hello, #hub_page'}, {'_id': '65e326d7e64d1eb175d8bd33', 'author_id': '65e3050de64d1eb2ddb0c678', 'content': '', 'hashtags': [], 'response_to_id': '65e3204ce64d1e43b6dd7875', 'title': 'fr, really cool'}, {'_id': '65e32799e64d1ec521b63709', 'author_id': '65e249baf69a9c082c820154', 'content': '', 'hashtags': [], 'response_to_id': '65e3204ce64d1e43b6dd7875', 'title': 'hmmm'}, {'_id': '65e726e8e64d1e32b88e928b', 'author_id': '65e6e573e64d1e32dacf9881', 'content': 'Astazi facem review la shaormeria din spatele blocului.\n\nNu pot pune in cuvinte ce inseamna aceasta locatie pentru mine. Imi aduc cu drag de inima aminte momentul in care maicuta mea m-a dus sa imi ia primul donner din aceasta locatie.\n\nAstazi, privind locatie cu alti ochi, pot spune ca nu impresioneaza in niciun aspect, dar cu siguranta

In [120]:
all_posts = [Post(**post_dict) for post_dict in response_content.get("posts", [])]
print("All posts:")
for post in all_posts:
    print(post)

liked_posts = all_posts[:4] + all_posts[-4:]
# liked_posts = []
print("Liked posts:")
for post in liked_posts:
    print(post)

disliked_posts = all_posts[12:16] + [all_posts[31]]
# disliked_posts = all_posts[:4] + all_posts[-4:]
# disliked_posts = []
print("Disliked posts:")
for post in disliked_posts:
    print(post)

current_user_id = PydanticObjectId("65d25cd3c2ef35ebebb785e6")
following_ids = [PydanticObjectId("65eda222e64d1e63721f1b1b"), 
                      PydanticObjectId("65db417ff69a9c1ee871447e"),
                      PydanticObjectId("65e475a5d831837d3a72eac5"),
                      PydanticObjectId("65e6e573e64d1e32dacf9881")]
# following_ids = []

current_user_forum = UserForum(_id=current_user_id, following_ids=following_ids)

print(f"User Data: {current_user_forum}")

All posts:
id=ObjectId('65e3204ce64d1e43b6dd7875') author_id=ObjectId('65e1f82be64d1e71f2a9226b') title='Hello, #hub_page' content='My special thanks to @at0mul @z @abc' hashtags=['hub_page'] response_to_id=None
id=ObjectId('65e326d7e64d1eb175d8bd33') author_id=ObjectId('65e3050de64d1eb2ddb0c678') title='fr, really cool' content='' hashtags=[] response_to_id=ObjectId('65e3204ce64d1e43b6dd7875')
id=ObjectId('65e32799e64d1ec521b63709') author_id=ObjectId('65e249baf69a9c082c820154') title='hmmm' content='' hashtags=[] response_to_id=ObjectId('65e3204ce64d1e43b6dd7875')
id=ObjectId('65e726e8e64d1e32b88e928b') author_id=ObjectId('65e6e573e64d1e32dacf9881') title='Imi place sa mananc aicea' content='Astazi facem review la shaormeria din spatele blocului.\n\nNu pot pune in cuvinte ce inseamna aceasta locatie pentru mine. Imi aduc cu drag de inima aminte momentul in care maicuta mea m-a dus sa imi ia primul donner din aceasta locatie.\n\nAstazi, privind locatie cu alti ochi, pot spune ca nu im

##### Compute Product Similarity

In [121]:
tokenizer = AutoTokenizer.from_pretrained('Twitter/twhin-bert-base')
model = AutoModel.from_pretrained('Twitter/twhin-bert-base')

def process_post_data(posts):
    post_texts = [post.title + " " + post.content for post in posts]
    inputs = tokenizer(post_texts, return_tensors="pt", padding=True, truncation=True)
    return inputs

def get_recommendations(posts, liked_posts, disliked_posts, following_ids):
    all_post_inputs = process_post_data(posts) if posts else None
    liked_post_inputs = process_post_data(liked_posts) if liked_posts else None
    disliked_post_inputs = process_post_data(disliked_posts) if disliked_posts else None


    with torch.no_grad():
        all_outputs = model(**all_post_inputs)
        all_embeddings = all_outputs.last_hidden_state[:, 0, :]

        if liked_post_inputs:
            liked_outputs = model(**liked_post_inputs)
            liked_embeddings = liked_outputs.last_hidden_state[:, 0, :]
        else:
            liked_embeddings = torch.zeros(1, model.config.hidden_size)

        if disliked_post_inputs:
            disliked_outputs = model(**disliked_post_inputs)
            disliked_embeddings = disliked_outputs.last_hidden_state[:, 0, :]
        else:
            disliked_embeddings = torch.zeros(1, model.config.hidden_size)

    similarities_liked = torch.matmul(all_embeddings, liked_embeddings.T)
    similarities_disliked = torch.matmul(all_embeddings, disliked_embeddings.T)

    recommendation_scores = torch.zeros(len(posts))

    for i, post in enumerate(posts):
        if post.author_id in following_ids:
            recommendation_scores[i] += 50

    recommendation_scores += similarities_liked.mean(dim=1) - similarities_disliked.mean(dim=1)

    sorted_indices = torch.argsort(recommendation_scores, descending=True)

    recommended_posts = [posts[i] for i in sorted_indices]
    return recommended_posts

Some weights of BertModel were not initialized from the model checkpoint at Twitter/twhin-bert-base and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


##### Test Implementation

In [122]:
recommended_posts = get_recommendations(all_posts, liked_posts, disliked_posts, current_user_forum.following_ids)
print("Recommended posts:")
for post in recommended_posts:
    print(post)

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


Recommended posts:
id=ObjectId('65ecf3e9d831834e9e21a97f') author_id=ObjectId('65e475a5d831837d3a72eac5') title='La lautari' content='Unde mergem sa petrecem\nCand plecam din cluburi tari?!\nLa lautari, unde canta lautari\nhttps://www.versuri.ro/w/8jt2\nUnde intra spritu bine\nSi se fac petreceri mari\nLa lautari, unde canta lautari' hashtags=[] response_to_id=None
id=ObjectId('65ecf558d831834e9e21a982') author_id=ObjectId('65e475a5d831837d3a72eac5') title='Nefiu' content='Eu cu Amtilb stăm în spate, patru bagaboante\nPar puțin fumate, deci sunt sparte, parfumate\nO împart fiindcă sunt dulce cofetar de\nStradă mușc din savarină și biscuiți cu lapte\nEu cu Amtilb stăm în spate, patru bagaboante\nNe conduc mașina așa șofer n-ai văzut frate\nLe duc până la cofetărie-n spate\nIau mini-eclere și le întreb: „parlez-vous français”?' hashtags=[] response_to_id=ObjectId('65ecf3e9d831834e9e21a97f')
id=ObjectId('65ecf5a2d831834e9e21a984') author_id=ObjectId('65e475a5d831837d3a72eac5') title='Bili