In [None]:
import pandas as pd
import numpy as np
from tqdm.notebook import tqdm
from random import shuffle

from sentence_transformers import SentenceTransformer

In [None]:
import os
import pandas as pd
import openai
 
# The key is read from the environment so it is never hard-coded in the notebook.
# .get() leaves it as None when OPENAI_API_KEY is unset, so runs that only use
# BERT embeddings (which never call the OpenAI API) do not fail on this cell.
openai.api_key = os.environ.get("OPENAI_API_KEY")
if openai.api_key is None:
    print("OPENAI_API_KEY is not set: Ada embeddings will not be available.")

import tiktoken

from tqdm.auto import tqdm

In [None]:
# Dataset selection: set exactly one of these flags to True for each run
lingorank = True
goodreads = False
ml_100k = False
tomplay = False

# `== 1` rejects both "several datasets selected" and "no dataset selected",
# so the message has to cover both situations.
assert sum([lingorank, goodreads, ml_100k, tomplay]) == 1, "Select exactly one dataset"

# Embedding back-ends to compute
ada = False
bert = True

### Ada embeddings

In [None]:
def compute_api_call_cost(prompt_tokens):
    """Return the cost of sending ``prompt_tokens`` tokens, at 0.0001 per 1000 tokens.

    The ``$`` formatting used by the callers suggests the unit is US dollars.
    """
    price_per_token = 0.0001 / 1000
    cost = prompt_tokens * price_per_token
    return cost

# Source: https://platform.openai.com/docs/guides/embeddings/use-cases
model = "text-embedding-ada-002"
def get_embedding(text, model=model):
    """Request the embedding of ``text`` from the OpenAI API.

    Newlines in ``text`` are replaced by spaces before the request is made.

    Returns:
        A tuple ``(embedding, cost, result)`` where ``result`` is the raw API
        response and ``cost`` is derived from its reported token usage.
        If anything fails while calling the API or reading the response, the
        error is printed and ``(None, None, None)`` is returned instead.
    """
    single_line_text = text.replace("\n", " ")
    try:
        response = openai.Embedding.create(input=single_line_text, model=model)
        vector = response['data'][0]['embedding']
        price = compute_api_call_cost(response['usage']['total_tokens'])
    except Exception as error:
        # Best effort: report the problem and let the caller carry on with the other texts
        print("Embedding not generated")
        print(error)
        return None, None, None

    return vector, price, response

# Running total of the tokens seen by count_tokens (kept at module level so it
# accumulates across calls)
total_tokens = 0
encoding = tiktoken.encoding_for_model(model)


def count_tokens(text, index, max_tokens=8000):
    """Count the tokens of ``text`` and truncate it to at most ``max_tokens`` tokens.

    The token count of the *untruncated* text is added to the global
    ``total_tokens``.

    Args:
        text: The string to tokenize.
        index: Row label of the text. Not used by the computation; kept so
            existing callers keep working.
        max_tokens: Maximum number of tokens to keep. The default of 8000 is
            presumably chosen to stay below the embedding model's input
            limit -- TODO confirm against the OpenAI documentation.

    Returns:
        ``text`` unchanged when it fits in ``max_tokens`` tokens, otherwise the
        decoded first ``max_tokens`` tokens.
    """
    global total_tokens
    # Tokenize once and reuse the ids for both the count and the truncation
    token_ids = encoding.encode(text)
    total_tokens += len(token_ids)
    if len(token_ids) <= max_tokens:
        return text
    return encoding.decode(token_ids[:max_tokens])

def add_ada_embeddings_to_df(df):
    """Add Ada embeddings of ``df['content']`` to ``df`` and report tokens and cost.

    Two columns are added to ``df`` (which is modified in place and returned):
    'truncated_content', the text after ``count_tokens`` truncation, and
    'ada_embedding', the embedding returned by ``get_embedding`` (``None`` for
    rows whose request failed).

    Args:
        df: DataFrame with a 'content' column of strings.

    Returns:
        The same DataFrame, with the two columns added.
    """
    # count_tokens accumulates into the global counter; remember where it stood so
    # the figures printed below describe this call only, even when the cell is re-run.
    tokens_before = total_tokens

    # Iterate over (row label, text) pairs directly instead of searching the whole
    # frame for each text, which was quadratic in the number of rows.
    df['truncated_content'] = [
        count_tokens(text, row_label) for row_label, text in df['content'].items()
    ]

    new_tokens = total_tokens - tokens_before
    print(f"Total number of tokens: {new_tokens}")
    print(f"Total estimated cost: ${compute_api_call_cost(new_tokens):.2f}")

    tqdm.pandas(desc="Progress")
    # Each element is the (embedding, cost, result) tuple returned by get_embedding
    results = list(df['truncated_content'].progress_apply(get_embedding))

    df['ada_embedding'] = [embedding for embedding, _, _ in results]

    # Failed requests report a cost of None and are skipped. Summing here avoids
    # creating (and then deleting) a temporary 'cost' column on the caller's frame,
    # and also works when the frame is empty.
    total_cost = sum(cost for _, cost, _ in results if cost is not None)

    print(f"Total real cost: ${total_cost:.2f}")

    return df

### BERT embeddings

In [None]:
def list_to_string(lst):
    """Render an iterable as text of the form ``[a, b, c]``.

    Every element is converted with ``str`` and the pieces are separated by
    a comma followed by a space.
    """
    pieces = [str(element) for element in lst]
    return f"[{', '.join(pieces)}]"

In [None]:
def get_bert_embedding(text, model):
    """Return whatever ``model.encode`` produces for ``text``."""
    return model.encode(text)

In [None]:
def add_BERT_embeddings_to_df(df):
    """Add a 'bert_embedding' column holding the embedding of each ``df['content']``.

    Each text is encoded with the
    'sentence-transformers/paraphrase-xlm-r-multilingual-v1' model and the
    resulting vector is stored as its ``list_to_string`` text form.
    ``df`` is modified in place and returned.
    """
    encoder = SentenceTransformer('sentence-transformers/paraphrase-xlm-r-multilingual-v1')

    # Enables Series.progress_apply, used below to show a progress bar
    tqdm.pandas(desc="Progress")
    vectors = df['content'].progress_apply(lambda text: get_bert_embedding(text, encoder))

    # Only the text form of each vector is kept in the frame
    df['bert_embedding'] = vectors.apply(list_to_string)

    return df

### Generate the embeddings

In [None]:
# One entry per dataset:
# (selection flag, input file, output file for the embeddings, reader, extra reader arguments)
dataset_sources = [
    (
        lingorank,
        "../results/recommendation/Zeegu/article.csv",
        "../results/recommendation/embeddings_strategy2.csv.gz",
        pd.read_csv,
        {},
    ),
    (
        ml_100k,
        "../results/recommendation/ml-100k/items.csv",
        "../results/recommendation/embeddings_ml-100k.csv.gz",
        pd.read_csv,
        {},
    ),
    (
        goodreads,
        "../results/recommendation/Goodreads/goodreads_books_children.json.gz",
        "../results/recommendation/embeddings_goodreads_children.csv.gz",
        pd.read_json,
        {"lines": True, "compression": "gzip"},
    ),
    (
        tomplay,
        "../results/recommendation/Tomplay/items.csv",
        "../results/recommendation/embeddings_tomplay.csv.gz",
        pd.read_csv,
        {},
    ),
]

# Load the items of every selected dataset, in the order listed above
for is_selected, source_path, target_path, reader, reader_kwargs in dataset_sources:
    if is_selected:
        file_path = source_path
        embeddings_file = target_path
        df = reader(file_path, **reader_kwargs)

In [None]:
if lingorank:
    data_full = pd.read_csv("../results/recommendation/Zeegu/strategy2.csv")

    # Number of distinct articles before anything is filtered out
    original_unique_articles = data_full['article_id'].nunique()

    # Articles whose highest rating is <= 0 have no positive rating: drop their rows
    best_rating_per_article = data_full.groupby('article_id')['rating'].max()
    articles_to_remove = best_rating_per_article[best_rating_per_article <= 0].index.tolist()
    keep_rows = ~data_full['article_id'].isin(articles_to_remove)
    data_full = data_full[keep_rows]

    # Keep only the rows whose rating is not 0
    data = data_full[data_full['rating'] != 0].copy()

    unique_user_ids = data['user_id'].unique()
    unique_article_ids = data['article_id'].unique()

    # Articles without an id cannot be matched; the remaining ids are cast to int
    df = df[df['id'].notnull()].copy()
    df['id'] = df['id'].astype(int)

    # Restrict the articles to those that still appear in the filtered ratings
    df = df[df['id'].isin(unique_article_ids)]
 
    # Convert content stored as bytes to string
    # df['content'] = df['content'].apply(lambda x: x[2:-1].encode('utf-8').decode('unicode_escape').encode('latin1').decode('utf-8'))
    def format_info(row):
        """Decode ``row['content']`` from its bytes-literal text form to a plain string.

        The first two characters and the last one are dropped (the ``b'`` prefix
        and closing quote, going by the comment above this function), backslash
        escapes are resolved, and the resulting bytes are read as UTF-8.
        If any step fails, the error is printed and the content is returned as is.
        """
        raw_content = row['content']
        try:
            escaped_body = raw_content[2:-1]
            # Turn the escape sequences into the characters they stand for...
            unescaped = escaped_body.encode('utf-8').decode('unicode_escape')
            # ...then map those characters back to bytes and decode them as UTF-8
            return unescaped.encode('latin1').decode('utf-8')
        except Exception as error:
            print(f"Error formatting row: {error}")
            return raw_content  # Fallback to original content in case of error


if ml_100k: 
    
    def format_info(row):
        """Build the multi-line text description of one movie row.

        The description lists title, summary, release date, cast, director,
        genres, runtime, rating and number of ratings, one per line. Genres
        are the genre columns whose value is 1, or 'N/A' when there is none.
        """
        release_date = format(row['release date'])

        genre_columns = ['unknown', 'Action', 'Adventure', 'Animation', "Children's",
                         'Comedy', 'Crime', 'Documentary', 'Drama', 'Fantasy', 'Film-Noir',
                         'Horror', 'Musical', 'Mystery', 'Romance', 'Sci-Fi', 'Thriller',
                         'War', 'Western']
        # Collect the genres flagged with 1, in the order of the columns above
        active_genres = []
        for column in genre_columns:
            if row[column] == 1:
                active_genres.append(column)
        if active_genres:
            formatted_genres = ', '.join(active_genres)
        else:
            formatted_genres = 'N/A'

        return "".join([
            f"Title: {row['movie title']}\n",
            f"Summary: {row['Summary']}\n",
            f"Release Date: {release_date}\n",
            f"Cast: {row['Cast']}\n",
            f"Director: {row['Director']}\n",
            f"Genres: {formatted_genres}\n",
            f"Runtime: {row['Runtime']}\n",
            f"Rating: {row['Rating']}\n",
            f"No. of Ratings: {row['No. of ratings']}",
        ])

if goodreads:

    def try_convert_to_int(value):
        """Return ``int(value)``, or NaN when the conversion raises ValueError or TypeError."""
        try:
            converted = int(value)
        except (ValueError, TypeError):
            converted = np.nan
        return converted

    def format_info(row):
        """Build the multi-line text description of one book row.

        The publication date is written as day/month/year when all three parts
        convert to integers, and as NaN otherwise.
        """
        date_parts = [
            try_convert_to_int(row[field])
            for field in ('publication_day', 'publication_month', 'publication_year')
        ]
        if any(np.isnan(part) for part in date_parts):
            date = np.nan
        else:
            date = "/".join(str(part) for part in date_parts)

        return "".join([
            f"Title: {row['title']}\n",
            f"Description: {row['description']}\n",
            f"Date: {date}\n",
            f"Publisher: {row['publisher']}\n",
            f"Format: {row['format']}\n",
            f"Number of Pages: {row['num_pages']}\n",
            f"Text Reviews Count: {row['text_reviews_count']}\n",
            f"Country Code: {row['country_code']}\n",
            f"Language Code: {row['language_code']}\n",
            f"Average Rating: {row['average_rating']}\n",
            f"Ratings Count: {row['ratings_count']}",
        ])
        
if tomplay:
    def format_info(row):
        """Build the multi-line text description of one Tomplay row.

        The labels are in French; every line, including the last one, ends
        with a newline.
        """
        return "".join([
            f"Titre: {row['NAME']}\n",
            f"Compositeur: {row['COMPOSER']}\n",
            f"Style: {row['STYLE']}\n",
            f"Accompagnement: {row['ACC_TYPE']}\n",
            f"Niveau: {row['LEVEL']}\n",
            f"Instruments: {row['INSTRUMENT']}\n",
        ])


# Build the text to embed for every item: `format_info` (the dataset-specific
# formatter defined above) is called once per row, and its result replaces or
# creates the 'content' column.
df['content'] = df.apply(format_info, axis=1)

In [None]:
# Compute the embeddings selected by the `ada` / `bert` flags.
# The two checks are independent, so both kinds are added when both flags are set.
if ada:
    df = add_ada_embeddings_to_df(df)
if bert:
    df = add_BERT_embeddings_to_df(df)

In [None]:
# Independent in-memory copy of the frame as it stands after the embedding step
# (presumably a safety net so the embeddings need not be recomputed -- it is not
# used in the cells visible here).
df_save=df.copy()

In [None]:
# Write the frame to `embeddings_file` as a gzip-compressed CSV, without the index column
df.to_csv(embeddings_file, index=False, compression='gzip')