In [49]:
import argparse
import os
import pandas as pd
import snowflake.connector
import json
import boto3
import xxhash
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from typing import Dict


sf_account_id = "lnb99345.us-east-1"
sf_secret_id = "snowflake_credentials"
warehouse = "XSMALL"
database = "PRODUCTION"
schema = "SIGNALS"


In [17]:
def get_credentials(secret_id: str, region_name: str) -> str:
    session = boto3.session.Session(profile_name="ml-staging-admin")
    client = session.client('secretsmanager', region_name=region_name)
    response = client.get_secret_value(SecretId=secret_id)
    secrets_value = json.loads(response['SecretString'])    
    
    return secrets_value

In [7]:
session = boto3.session.Session(profile_name="ml-staging-admin")
region = session.region_name

In [11]:
def connect(secret_id: str, account: str, warehouse: str, database: str, schema: str, region: str) -> snowflake.connector.SnowflakeConnection:
    
    secret_value = get_credentials(secret_id, region)
    sf_user = secret_value['username']
    sf_password = secret_value['password']
    sf_account = account
    sf_warehouse = warehouse
    sf_database = database
    sf_schema = schema
    sf_protocol = "https"
    
    print(f"sf_user={sf_user}, sf_password=****, sf_account={sf_account}, sf_warehouse={sf_warehouse}, "
          f"sf_database={sf_database}, sf_schema={sf_schema}, sf_protocol={sf_protocol}")    
    
    # Read to connect to snowflake
    ctx = snowflake.connector.connect(user=sf_user,
                                      password=sf_password,
                                      account=sf_account,
                                      warehouse=sf_warehouse,
                                      database=sf_database,
                                      schema=sf_schema,
                                      protocol=sf_protocol)
    
    return ctx

In [19]:
ctx = connect(sf_secret_id, sf_account_id, warehouse, database, schema, region)
ctx

sf_user=MGAIDUK, sf_password=****, sf_account=lnb99345.us-east-1, sf_warehouse=XSMALL, sf_database=PRODUCTION, sf_schema=SIGNALS, sf_protocol=https


<snowflake.connector.connection.SnowflakeConnection at 0x286f64c10>

In [21]:
def collect_dataset(ctx: snowflake.connector.SnowflakeConnection, input: str) -> pd.DataFrame:
    # Collect dataset
    sql = f"""
    select * from {input}
    """
    df = pd.read_sql(sql, ctx)
    return df

df = collect_dataset(ctx, input="TRAIN_DATASET_V4")
df

  df = pd.read_sql(sql, ctx)


Unnamed: 0,USERID,MEDIAID,MEDIATAKENBYID,REQUESTID,NVIEWS,TIMESPENT,MIN_TIMESTAMP,NREACTIONS
0,09fcf6f7-4f74-466e-9db3-74c44b6d7385,01HDG2MY11SFE5RYRKHENZ6FH6,2b3b34d7-3d34-48ec-8ce9-92f16327994d,8ddb3845-aa7c-4ee2-a122-ea02e11a0614,1,1.26803,2024-02-20 05:12:15.851,0
1,b20ec925-5695-4349-a7e9-80477650ac67,01HH2MB8ZP2KZFDH917M3A2ESD,60054332-d2f9-4e64-aa41-3324ffce6987,68b1810b-d800-4102-b1cf-cc43e1829a38,1,8.00857,2024-02-20 05:12:49.900,0
2,5dd393fb-4d6b-4511-9aa1-a1781e211fc8,01HM1SYBDB3ZGZXED42357JJCK,b530a6ef-2245-4ffa-8634-8b61d22191bd,7cea17d1-2a75-4eba-943f-7ffe81c223b2,1,11.93254,2024-02-20 05:12:13.613,0
3,5dd393fb-4d6b-4511-9aa1-a1781e211fc8,01HGQ1HMYCGRVB2403KTVYZ9B6,1ea30812-404e-415e-8053-12945304a53d,7cea17d1-2a75-4eba-943f-7ffe81c223b2,1,1.33329,2024-02-20 05:12:26.668,0
4,5dd393fb-4d6b-4511-9aa1-a1781e211fc8,01HKT8H4PCEMW6R92J5TZ47QPS,e12addee-f283-4493-b333-9edd552d9c12,01f36258-a91d-4215-929c-2addd5b4917d,1,0.78313,2024-02-20 05:12:35.218,0
...,...,...,...,...,...,...,...,...
2189459,86d41a8e-3224-4c92-894e-c7d21facbe81,01HAB8WH2305DMMZZ1YR29PAM7,3d534ea7-91d4-48fa-ae54-fd8d73250b93,d3b886ac-be1e-44c0-bb8e-2833eb7ab1c8,1,1.38527,2024-02-27 05:16:00.341,0
2189460,68c4526d-e07b-47d2-b7c5-2d4acd53e6a7,01H4HYKSBS143G5AFSS4FFWQSQ,cd11e760-7bfc-4683-8641-c8f0b4a6c028,48360048-0648-403d-a2cf-96c7f8e7e7bd,1,0.87564,2024-02-27 05:16:39.196,0
2189461,10f7ed84-b0e1-4f6d-b9df-fd6452cec211,01HA123PZ78B5ZJ3N18QDDJT25,9b9f666f-3062-462b-9911-d795b4b795d1,3f19a5b8-2205-45e1-a6de-eac9a1bbc4de,1,1.93350,2024-02-27 05:16:30.480,0
2189462,b7ba7a72-dce6-4e21-ba6f-6d747686aad8,01HHB1XE25DC81FP5CBEFWGKMH,d9c5d37f-0733-4b07-88f3-ab8acc02ab4d,4adb58f1-6b17-403f-96de-167a300e307f,1,0.94993,2024-02-27 05:17:35.099,0


In [23]:
df = df.drop("REQUESTID", axis=1)
df

Unnamed: 0,USERID,MEDIAID,MEDIATAKENBYID,NVIEWS,TIMESPENT,MIN_TIMESTAMP,NREACTIONS
0,09fcf6f7-4f74-466e-9db3-74c44b6d7385,01HDG2MY11SFE5RYRKHENZ6FH6,2b3b34d7-3d34-48ec-8ce9-92f16327994d,1,1.26803,2024-02-20 05:12:15.851,0
1,b20ec925-5695-4349-a7e9-80477650ac67,01HH2MB8ZP2KZFDH917M3A2ESD,60054332-d2f9-4e64-aa41-3324ffce6987,1,8.00857,2024-02-20 05:12:49.900,0
2,5dd393fb-4d6b-4511-9aa1-a1781e211fc8,01HM1SYBDB3ZGZXED42357JJCK,b530a6ef-2245-4ffa-8634-8b61d22191bd,1,11.93254,2024-02-20 05:12:13.613,0
3,5dd393fb-4d6b-4511-9aa1-a1781e211fc8,01HGQ1HMYCGRVB2403KTVYZ9B6,1ea30812-404e-415e-8053-12945304a53d,1,1.33329,2024-02-20 05:12:26.668,0
4,5dd393fb-4d6b-4511-9aa1-a1781e211fc8,01HKT8H4PCEMW6R92J5TZ47QPS,e12addee-f283-4493-b333-9edd552d9c12,1,0.78313,2024-02-20 05:12:35.218,0
...,...,...,...,...,...,...,...
2189459,86d41a8e-3224-4c92-894e-c7d21facbe81,01HAB8WH2305DMMZZ1YR29PAM7,3d534ea7-91d4-48fa-ae54-fd8d73250b93,1,1.38527,2024-02-27 05:16:00.341,0
2189460,68c4526d-e07b-47d2-b7c5-2d4acd53e6a7,01H4HYKSBS143G5AFSS4FFWQSQ,cd11e760-7bfc-4683-8641-c8f0b4a6c028,1,0.87564,2024-02-27 05:16:39.196,0
2189461,10f7ed84-b0e1-4f6d-b9df-fd6452cec211,01HA123PZ78B5ZJ3N18QDDJT25,9b9f666f-3062-462b-9911-d795b4b795d1,1,1.93350,2024-02-27 05:16:30.480,0
2189462,b7ba7a72-dce6-4e21-ba6f-6d747686aad8,01HHB1XE25DC81FP5CBEFWGKMH,d9c5d37f-0733-4b07-88f3-ab8acc02ab4d,1,0.94993,2024-02-27 05:17:35.099,0


In [31]:
def my_hash(s: str) -> int:
    return xxhash.xxh32(s).intdigest()

df[['USERID', 'MEDIAID', 'MEDIATAKENBYID']] = df[['USERID', 'MEDIAID', 'MEDIATAKENBYID']].applymap(my_hash)
df

  df[['USERID', 'MEDIAID', 'MEDIATAKENBYID']] = df[['USERID', 'MEDIAID', 'MEDIATAKENBYID']].applymap(my_hash)


Unnamed: 0,USERID,MEDIAID,MEDIATAKENBYID,NVIEWS,TIMESPENT,MIN_TIMESTAMP,NREACTIONS
0,2755198180,2559674953,2413341077,1,1.26803,2024-02-20 05:12:15.851,0
1,1718721920,544283077,2585866067,1,8.00857,2024-02-20 05:12:49.900,0
2,3593354338,2618651870,1610553971,1,11.93254,2024-02-20 05:12:13.613,0
3,3593354338,2429981434,509632056,1,1.33329,2024-02-20 05:12:26.668,0
4,3593354338,1049864589,1885778205,1,0.78313,2024-02-20 05:12:35.218,0
...,...,...,...,...,...,...,...
2189459,3628308869,680409234,1731226259,1,1.38527,2024-02-27 05:16:00.341,0
2189460,2039782801,511876874,1503905893,1,0.87564,2024-02-27 05:16:39.196,0
2189461,1632899809,2743535244,1346800933,1,1.93350,2024-02-27 05:16:30.480,0
2189462,3949000645,2185810922,1359448349,1,0.94993,2024-02-27 05:17:35.099,0


In [32]:
val_df = df[df["MIN_TIMESTAMP"] > "2024-02-27"]
train_df = df[df["MIN_TIMESTAMP"] <= "2024-02-27"]
train_df.shape, val_df.shape

((2082972, 7), (106492, 7))

In [35]:
df.drop(["NVIEWS", "MIN_TIMESTAMP"], axis=1)

Unnamed: 0,USERID,MEDIAID,MEDIATAKENBYID,TIMESPENT,NREACTIONS
0,2755198180,2559674953,2413341077,1.26803,0
1,1718721920,544283077,2585866067,8.00857,0
2,3593354338,2618651870,1610553971,11.93254,0
3,3593354338,2429981434,509632056,1.33329,0
4,3593354338,1049864589,1885778205,0.78313,0
...,...,...,...,...,...
2189459,3628308869,680409234,1731226259,1.38527,0
2189460,2039782801,511876874,1503905893,0.87564,0
2189461,1632899809,2743535244,1346800933,1.93350,0
2189462,3949000645,2185810922,1359448349,0.94993,0


In [44]:
class CustomDataset(Dataset):
    def __init__(self, dataframe):
        self.dataframe = dataframe
        self.features = dataframe[['USERID', 'MEDIAID', 'MEDIATAKENBYID']].values
        self.labels = dataframe[['TIMESPENT', 'NREACTIONS']].values

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        features = torch.tensor(self.features[idx], dtype=torch.int32)
        labels = torch.tensor(self.labels[idx], dtype=torch.float32)
        return {
            'features': {'USERID': features[0], 'MEDIAID': features[1], 'MEDIATAKENBYID': features[2]},
            'labels': {'TIMESPENT': labels[0], 'NREACTIONS': labels[1]}
        }

In [45]:
train_dt = CustomDataset(train_df)
batch_size = 4  # Adjust based on your needs
shuffle = True  # Shuffle the data
dataloader = DataLoader(train_dt, batch_size=batch_size, shuffle=shuffle)

In [46]:
for batch in dataloader:
    features = batch['features']
    labels = batch['labels']
    print(f"Features: {features}, Labels: {labels}")
    break

Features: {'USERID': tensor([  41980475, 1602945607,  365985616,  823174442], dtype=torch.int32), 'MEDIAID': tensor([  757884224,  -310734370, -2146239065, -2133867411], dtype=torch.int32), 'MEDIATAKENBYID': tensor([-1602968852,   259504920,   726371672,  -378004270], dtype=torch.int32)}, Labels: {'TIMESPENT': tensor([0.9328, 1.6353, 0.6208, 0.0000]), 'NREACTIONS': tensor([0., 0., 0., 0.])}


In [47]:
features

{'USERID': tensor([  41980475, 1602945607,  365985616,  823174442], dtype=torch.int32),
 'MEDIAID': tensor([  757884224,  -310734370, -2146239065, -2133867411], dtype=torch.int32),
 'MEDIATAKENBYID': tensor([-1602968852,   259504920,   726371672,  -378004270], dtype=torch.int32)}

In [None]:
class CollaborativeFilteringModel(nn.Module):
    def __init__(self, num_users, num_posts, embedding_dim):
        super(CollaborativeFilteringModel, self).__init__()
        self.user_embedding = nn.Embedding(num_users, embedding_dim)
        self.post_embedding = nn.Embedding(num_posts, embedding_dim)
        
    def forward(self, user_ids, post_ids):
        user_embedded = self.user_embedding(user_ids)
        post_embedded = self.post_embedding(post_ids)
        interaction = (user_embedded * post_embedded).sum(dim=1)
        prediction = torch.sigmoid(interaction)
        return prediction