In [1]:
import argparse
import os
import pandas as pd
import snowflake.connector
import json
import boto3
import xxhash
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
from sklearn.metrics import roc_auc_score
import numpy as np

from typing import Dict

from tqdm import tqdm


sf_account_id = "lnb99345.us-east-1"
sf_secret_id = "snowflake_credentials"
warehouse = "XSMALL"
database = "PRODUCTION"
schema = "SIGNALS"


In [6]:
def get_credentials(secret_id: str, region_name: str) -> str:
    session = boto3.session.Session(profile_name="ml-staging-admin")
    client = session.client('secretsmanager', region_name=region_name)
    response = client.get_secret_value(SecretId=secret_id)
    secrets_value = json.loads(response['SecretString'])    
    
    return secrets_value

In [4]:
def connect(secret_id: str, account: str, warehouse: str, database: str, schema: str, region: str) -> snowflake.connector.SnowflakeConnection:
    
    secret_value = get_credentials(secret_id, region)
    sf_user = secret_value['username']
    sf_password = secret_value['password']
    sf_account = account
    sf_warehouse = warehouse
    sf_database = database
    sf_schema = schema
    sf_protocol = "https"
    
    print(f"sf_user={sf_user}, sf_password=****, sf_account={sf_account}, sf_warehouse={sf_warehouse}, "
          f"sf_database={sf_database}, sf_schema={sf_schema}, sf_protocol={sf_protocol}")    
    
    # Read to connect to snowflake
    ctx = snowflake.connector.connect(user=sf_user,
                                      password=sf_password,
                                      account=sf_account,
                                      warehouse=sf_warehouse,
                                      database=sf_database,
                                      schema=sf_schema,
                                      protocol=sf_protocol)
    
    return ctx

In [5]:
ctx = connect(sf_secret_id, sf_account_id, warehouse, database, schema, "us-east-1")
ctx

TokenRetrievalError: Error when retrieving token from sso: Token has expired and refresh failed

In [12]:
def collect_dataset(ctx: snowflake.connector.SnowflakeConnection, input: str) -> pd.DataFrame:
    # Collect dataset
    sql = f"""
    select * from {input}
    """
    df = pd.read_sql(sql, ctx)
    return df

df = collect_dataset(ctx, input="TRAIN_DATASET_V4")
df

  df = pd.read_sql(sql, ctx)


Unnamed: 0,USERID,MEDIAID,MEDIATAKENBYID,REQUESTID,NVIEWS,TIMESPENT,MIN_TIMESTAMP,NREACTIONS
0,09fcf6f7-4f74-466e-9db3-74c44b6d7385,01HDG2MY11SFE5RYRKHENZ6FH6,2b3b34d7-3d34-48ec-8ce9-92f16327994d,8ddb3845-aa7c-4ee2-a122-ea02e11a0614,1,1.26803,2024-02-20 05:12:15.851,0
1,b20ec925-5695-4349-a7e9-80477650ac67,01HH2MB8ZP2KZFDH917M3A2ESD,60054332-d2f9-4e64-aa41-3324ffce6987,68b1810b-d800-4102-b1cf-cc43e1829a38,1,8.00857,2024-02-20 05:12:49.900,0
2,5dd393fb-4d6b-4511-9aa1-a1781e211fc8,01HM1SYBDB3ZGZXED42357JJCK,b530a6ef-2245-4ffa-8634-8b61d22191bd,7cea17d1-2a75-4eba-943f-7ffe81c223b2,1,11.93254,2024-02-20 05:12:13.613,0
3,5dd393fb-4d6b-4511-9aa1-a1781e211fc8,01HGQ1HMYCGRVB2403KTVYZ9B6,1ea30812-404e-415e-8053-12945304a53d,7cea17d1-2a75-4eba-943f-7ffe81c223b2,1,1.33329,2024-02-20 05:12:26.668,0
4,5dd393fb-4d6b-4511-9aa1-a1781e211fc8,01HKT8H4PCEMW6R92J5TZ47QPS,e12addee-f283-4493-b333-9edd552d9c12,01f36258-a91d-4215-929c-2addd5b4917d,1,0.78313,2024-02-20 05:12:35.218,0
...,...,...,...,...,...,...,...,...
2189459,86d41a8e-3224-4c92-894e-c7d21facbe81,01HAB8WH2305DMMZZ1YR29PAM7,3d534ea7-91d4-48fa-ae54-fd8d73250b93,d3b886ac-be1e-44c0-bb8e-2833eb7ab1c8,1,1.38527,2024-02-27 05:16:00.341,0
2189460,68c4526d-e07b-47d2-b7c5-2d4acd53e6a7,01H4HYKSBS143G5AFSS4FFWQSQ,cd11e760-7bfc-4683-8641-c8f0b4a6c028,48360048-0648-403d-a2cf-96c7f8e7e7bd,1,0.87564,2024-02-27 05:16:39.196,0
2189461,10f7ed84-b0e1-4f6d-b9df-fd6452cec211,01HA123PZ78B5ZJ3N18QDDJT25,9b9f666f-3062-462b-9911-d795b4b795d1,3f19a5b8-2205-45e1-a6de-eac9a1bbc4de,1,1.93350,2024-02-27 05:16:30.480,0
2189462,b7ba7a72-dce6-4e21-ba6f-6d747686aad8,01HHB1XE25DC81FP5CBEFWGKMH,d9c5d37f-0733-4b07-88f3-ab8acc02ab4d,4adb58f1-6b17-403f-96de-167a300e307f,1,0.94993,2024-02-27 05:17:35.099,0


In [13]:
df = df.drop("REQUESTID", axis=1)
df

Unnamed: 0,USERID,MEDIAID,MEDIATAKENBYID,NVIEWS,TIMESPENT,MIN_TIMESTAMP,NREACTIONS
0,09fcf6f7-4f74-466e-9db3-74c44b6d7385,01HDG2MY11SFE5RYRKHENZ6FH6,2b3b34d7-3d34-48ec-8ce9-92f16327994d,1,1.26803,2024-02-20 05:12:15.851,0
1,b20ec925-5695-4349-a7e9-80477650ac67,01HH2MB8ZP2KZFDH917M3A2ESD,60054332-d2f9-4e64-aa41-3324ffce6987,1,8.00857,2024-02-20 05:12:49.900,0
2,5dd393fb-4d6b-4511-9aa1-a1781e211fc8,01HM1SYBDB3ZGZXED42357JJCK,b530a6ef-2245-4ffa-8634-8b61d22191bd,1,11.93254,2024-02-20 05:12:13.613,0
3,5dd393fb-4d6b-4511-9aa1-a1781e211fc8,01HGQ1HMYCGRVB2403KTVYZ9B6,1ea30812-404e-415e-8053-12945304a53d,1,1.33329,2024-02-20 05:12:26.668,0
4,5dd393fb-4d6b-4511-9aa1-a1781e211fc8,01HKT8H4PCEMW6R92J5TZ47QPS,e12addee-f283-4493-b333-9edd552d9c12,1,0.78313,2024-02-20 05:12:35.218,0
...,...,...,...,...,...,...,...
2189459,86d41a8e-3224-4c92-894e-c7d21facbe81,01HAB8WH2305DMMZZ1YR29PAM7,3d534ea7-91d4-48fa-ae54-fd8d73250b93,1,1.38527,2024-02-27 05:16:00.341,0
2189460,68c4526d-e07b-47d2-b7c5-2d4acd53e6a7,01H4HYKSBS143G5AFSS4FFWQSQ,cd11e760-7bfc-4683-8641-c8f0b4a6c028,1,0.87564,2024-02-27 05:16:39.196,0
2189461,10f7ed84-b0e1-4f6d-b9df-fd6452cec211,01HA123PZ78B5ZJ3N18QDDJT25,9b9f666f-3062-462b-9911-d795b4b795d1,1,1.93350,2024-02-27 05:16:30.480,0
2189462,b7ba7a72-dce6-4e21-ba6f-6d747686aad8,01HHB1XE25DC81FP5CBEFWGKMH,d9c5d37f-0733-4b07-88f3-ab8acc02ab4d,1,0.94993,2024-02-27 05:17:35.099,0


In [14]:
def my_hash(s: str) -> int:
    return xxhash.xxh32(s).intdigest()

df[['USERID', 'MEDIAID', 'MEDIATAKENBYID']] = df[['USERID', 'MEDIAID', 'MEDIATAKENBYID']].applymap(my_hash)
df

  df[['USERID', 'MEDIAID', 'MEDIATAKENBYID']] = df[['USERID', 'MEDIAID', 'MEDIATAKENBYID']].applymap(my_hash)


Unnamed: 0,USERID,MEDIAID,MEDIATAKENBYID,NVIEWS,TIMESPENT,MIN_TIMESTAMP,NREACTIONS
0,2755198180,2559674953,2413341077,1,1.26803,2024-02-20 05:12:15.851,0
1,1718721920,544283077,2585866067,1,8.00857,2024-02-20 05:12:49.900,0
2,3593354338,2618651870,1610553971,1,11.93254,2024-02-20 05:12:13.613,0
3,3593354338,2429981434,509632056,1,1.33329,2024-02-20 05:12:26.668,0
4,3593354338,1049864589,1885778205,1,0.78313,2024-02-20 05:12:35.218,0
...,...,...,...,...,...,...,...
2189459,3628308869,680409234,1731226259,1,1.38527,2024-02-27 05:16:00.341,0
2189460,2039782801,511876874,1503905893,1,0.87564,2024-02-27 05:16:39.196,0
2189461,1632899809,2743535244,1346800933,1,1.93350,2024-02-27 05:16:30.480,0
2189462,3949000645,2185810922,1359448349,1,0.94993,2024-02-27 05:17:35.099,0


In [15]:
val_df = df[df["MIN_TIMESTAMP"] > "2024-02-27"]
train_df = df[df["MIN_TIMESTAMP"] <= "2024-02-27"]
train_df.shape, val_df.shape

((2082972, 7), (106492, 7))

In [16]:
df.drop(["NVIEWS", "MIN_TIMESTAMP"], axis=1)

Unnamed: 0,USERID,MEDIAID,MEDIATAKENBYID,TIMESPENT,NREACTIONS
0,2755198180,2559674953,2413341077,1.26803,0
1,1718721920,544283077,2585866067,8.00857,0
2,3593354338,2618651870,1610553971,11.93254,0
3,3593354338,2429981434,509632056,1.33329,0
4,3593354338,1049864589,1885778205,0.78313,0
...,...,...,...,...,...
2189459,3628308869,680409234,1731226259,1.38527,0
2189460,2039782801,511876874,1503905893,0.87564,0
2189461,1632899809,2743535244,1346800933,1.93350,0
2189462,3949000645,2185810922,1359448349,0.94993,0


In [None]:
print(torch.cuda.is_available())
device = "cuda"

In [91]:
class CustomDataset(Dataset):
    def __init__(self, dataframe, device):
        self.dataframe = dataframe
        self.features = torch.tensor(dataframe[['USERID', 'MEDIAID', 'MEDIATAKENBYID']].values, dtype=torch.int32)
        self.labels = torch.tensor(dataframe[['TIMESPENT', 'NREACTIONS']].values, dtype=torch.float32)

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        features = self.features[idx]
        labels = self.labels[idx]
        has_reactions = (labels[1] > 0).float()
        return {
            'features': {'USERID': features[0], 'MEDIAID': features[1], 'MEDIATAKENBYID': features[2]},
            'labels': {'TIMESPENT': labels[0], 'HAS_REACTIONS': has_reactions}
        }

In [101]:
train_dt = CustomDataset(train_df, device=device)
val_dt = CustomDataset(val_df, device=device)
batch_size = 4096  # Adjust based on your needs
shuffle = True  # Shuffle the data
num_workers = 8
train_dataloader = DataLoader(train_dt, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=True)
val_dataloader = DataLoader(val_dt, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)

In [79]:
for batch in train_dataloader:
    features = batch['features']
    labels = batch['labels']
    print(f"Features: {features}, Labels: {labels}")
    break

Features: {'USERID': tensor([ 200985892, 1749910036, 1246860241,  ...,  116364340,  636480120,
        1697649444], device='cuda:0', dtype=torch.int32), 'MEDIAID': tensor([  975837297, -1621105989,   611672514,  ...,  1493603796,
           73740826, -1035981605], device='cuda:0', dtype=torch.int32), 'MEDIATAKENBYID': tensor([ -674431733,  -232262996, -1229820614,  ...,  -784276765,
            2591818, -1801801503], device='cuda:0', dtype=torch.int32)}, Labels: {'TIMESPENT': tensor([0.0000, 0.8832, 0.6834,  ..., 0.0000, 0.9038, 0.0000], device='cuda:0'), 'HAS_REACTIONS': tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0')}


In [82]:
class CollaborativeFilteringModel(nn.Module):
    def __init__(self, num_users, num_posts, embedding_dim):
        super(CollaborativeFilteringModel, self).__init__()
        self.num_users = num_users
        self.num_posts = num_posts
        self.user_embedding = nn.Embedding(num_users, embedding_dim)
        self.post_embedding = nn.Embedding(num_posts, embedding_dim)
        
    def forward(self, user_ids, post_ids):
        user_ids %= self.num_users
        post_ids %= self.num_posts
        user_embedded = self.user_embedding(user_ids)
        post_embedded = self.post_embedding(post_ids)
        interaction = (user_embedded * post_embedded).sum(dim=1)
        prediction = torch.sigmoid(interaction)
        return prediction

In [65]:
def dict_to_device(dic: Dict, device: str) -> Dict:
    return {k: v.to(device) for k, v in dic.items()}

In [66]:
import wandb
wandb.login(key="5558b41450b126a5fc0e47c271d4252b06e32fe6")



True

In [102]:
# train lööp
print(len(df["USERID"].unique()), len(df["MEDIAID"].unique()))
num_users = 250000
num_posts = 1000
embedding_dim = 8
model = CollaborativeFilteringModel(num_users=num_users, num_posts=num_posts, embedding_dim=embedding_dim).to(device)

criterion = nn.BCELoss()
lr = 0.001
optimizer = optim.Adam(model.parameters(), lr=lr)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=150, gamma=0.1)

num_epochs = 500
wandb.init(
      # Set the project where this run will be logged
      project="lapse-cf-pytorch", 
      # We pass a run name (otherwise it’ll be randomly assigned, like sunshine-lollypop-10)
      name=f"experiment_12", 
      # Track hyperparameters and run metadata
      config={
        "learning_rate": lr,
        "architecture": "CF_user_post",
        "epochs": num_epochs,
        "train_start": train_df["MIN_TIMESTAMP"].min(),
        "train_end": train_df["MIN_TIMESTAMP"].max(),
        "train_count": train_df.shape[0],
        "val_start": val_df["MIN_TIMESTAMP"].min(),
        "val_end": val_df["MIN_TIMESTAMP"].max(),
        "val_count": val_df.shape[0],
        "num_users": num_users,
        "num_posts": num_posts,
        "embedding_dim": embedding_dim,
        "batch_size": batch_size,
        "lr_step_size": 150,
      })


for epoch in range(num_epochs):
    for batch in tqdm(train_dataloader):
        features = batch['features']
        labels = batch['labels']
        features = dict_to_device(features, device)
        labels = dict_to_device(labels, device)
        optimizer.zero_grad()
        outputs = model(features["USERID"], features["MEDIAID"])
        targets = labels["HAS_REACTIONS"]
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        
    print(f'Epoch {epoch+1}, train loss: {loss.item()}')
    wandb.log({"train_loss": loss.item(), "epoch": epoch+1})

    # Validation phase
    model.eval() # Set the model to evaluation mode
    val_loss = 0.0
    all_targets = []
    all_outputs = []
    with torch.no_grad():
         for batch in val_dataloader:
            features = batch['features']
            labels = batch['labels']
            features = dict_to_device(features, device)
            labels = dict_to_device(labels, device)
            outputs = model(features["USERID"], features["MEDIAID"])
            targets = labels["HAS_REACTIONS"]
            loss = criterion(outputs, targets)
            val_loss += loss.item()
            all_targets.extend(targets.cpu().numpy())  # Collect all targets
            all_outputs.extend(outputs.cpu().numpy())  # Collect all outputs
    auc_roc_score = roc_auc_score(all_targets, all_outputs)
    
    # Print average validation loss per epoch
    print(f'Epoch {epoch+1}, Validation Loss: {val_loss / len(val_dataloader)}, AUC-ROC: {auc_roc_score}')
    wandb.log({"val_loss": loss.item(), "epoch": epoch+1, "roc-auc": auc_roc_score})

    scheduler.step()

255444 1367




  0%|          | 1/509 [00:00<02:09,  3.93it/s]Exception ignored in: <function _ConnectionBase.__del__ at 0x7f1460607880>
Traceback (most recent call last):
Traceback (most recent call last):
  File "/usr/lib/python3.10/multiprocessing/connection.py", line 132, in __del__
  File "/usr/lib/python3.10/multiprocessing/queues.py", line 239, in _feed
    reader_close()
    self._close()  File "/usr/lib/python3.10/multiprocessing/connection.py", line 177, in close
    self._close()

  File "/usr/lib/python3.10/multiprocessing/connection.py", line 361, in _close
  File "/usr/lib/python3.10/multiprocessing/connection.py", line 361, in _close
    _close(self._handle)
    OSError: [Errno 9] Bad file descriptor
_close(self._handle)
OSError: [Errno 9] Bad file descriptor
 40%|███▉      | 203/509 [00:25<00:37,  8.11it/s]


KeyboardInterrupt: 