In [1]:
import argparse
import os
import pandas as pd
import snowflake.connector
import json
import boto3
import xxhash
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
from sklearn.metrics import roc_auc_score
import numpy as np

from typing import Dict

from tqdm import tqdm


sf_account_id = "lnb99345.us-east-1"
sf_secret_id = "snowflake_credentials"
warehouse = "XSMALL"
database = "PRODUCTION"
schema = "SIGNALS"


In [2]:
def get_credentials(secret_id: str, region_name: str) -> str:
    session = boto3.session.Session(profile_name="ml-staging-admin")
    client = session.client('secretsmanager', region_name=region_name)
    response = client.get_secret_value(SecretId=secret_id)
    secrets_value = json.loads(response['SecretString'])    
    
    return secrets_value

In [3]:
session = boto3.session.Session(profile_name="ml-staging-admin")
region = session.region_name

In [4]:
def connect(secret_id: str, account: str, warehouse: str, database: str, schema: str, region: str) -> snowflake.connector.SnowflakeConnection:
    
    secret_value = get_credentials(secret_id, region)
    sf_user = secret_value['username']
    sf_password = secret_value['password']
    sf_account = account
    sf_warehouse = warehouse
    sf_database = database
    sf_schema = schema
    sf_protocol = "https"
    
    print(f"sf_user={sf_user}, sf_password=****, sf_account={sf_account}, sf_warehouse={sf_warehouse}, "
          f"sf_database={sf_database}, sf_schema={sf_schema}, sf_protocol={sf_protocol}")    
    
    # Read to connect to snowflake
    ctx = snowflake.connector.connect(user=sf_user,
                                      password=sf_password,
                                      account=sf_account,
                                      warehouse=sf_warehouse,
                                      database=sf_database,
                                      schema=sf_schema,
                                      protocol=sf_protocol)
    
    return ctx

In [5]:
ctx = connect(sf_secret_id, sf_account_id, warehouse, database, schema, "us-east-1")
ctx

sf_user=MGAIDUK, sf_password=****, sf_account=lnb99345.us-east-1, sf_warehouse=XSMALL, sf_database=PRODUCTION, sf_schema=SIGNALS, sf_protocol=https


<snowflake.connector.connection.SnowflakeConnection at 0x7f3558eaaaa0>

In [6]:
def collect_dataset(ctx: snowflake.connector.SnowflakeConnection, input: str) -> pd.DataFrame:
    # Collect dataset
    sql = f"""
    select * from {input}
    """
    df = pd.read_sql(sql, ctx)
    return df

df = collect_dataset(ctx, input="TRAIN_DATASET_V4")
df

  df = pd.read_sql(sql, ctx)


Unnamed: 0,USERID,MEDIAID,MEDIATAKENBYID,REQUESTID,NVIEWS,TIMESPENT,MIN_TIMESTAMP,NREACTIONS
0,4b7b954f-c93d-4002-8e4b-532c465e9d33,01HM001V4DFK6JCS62FBZ64763,1ef97e43-8f70-4323-b9c4-e1f30a53e958,e122b254-72c2-4c12-aff2-a61fe1034bb6,1,0.53400,2024-02-20 02:18:43.857,0
1,a22fd578-a14f-40dc-bb43-aa3d14732444,01HEY3B6FVH5D47FVMH1PKET0C,24ba8dbd-a973-460a-900e-3e87010d50df,ab5addbc-248d-4c71-832a-d7f3aef63e1b,1,3.77994,2024-02-20 01:00:27.825,0
2,a22fd578-a14f-40dc-bb43-aa3d14732444,01HHT0YJ76N3MD1NGM7TZ0DRFK,d0ee7391-61c2-4c60-96d3-6b1db6e7d596,ab5addbc-248d-4c71-832a-d7f3aef63e1b,1,0.70114,2024-02-20 01:00:33.806,0
3,452e8205-5784-4f9d-b79c-17603fc33c9a,01H98XMRXPH02T0WMNTSD6DMZX,79e5dd58-6599-4816-a42d-c4f7d228c19e,736bb5eb-8ea1-4636-aefb-8624cf0d3d68,1,0.50839,2024-02-20 02:19:06.239,0
4,3d22cca2-5c39-46e6-8ca6-bcf9b3d13f03,01H87NGJ3VCP5QXX8078ZDWC8S,38d64e8b-a8e8-4f22-bb5c-f599e900504b,f1ba176b-db4d-4d6f-9d54-74abe97c9108,1,1.18214,2024-02-20 02:19:03.149,0
...,...,...,...,...,...,...,...,...
2608465,a79a17db-1235-457d-a973-0ba0e277c28e,01H3GGGZCP7P09S3GF7Q4Y1M7B,b034bdab-5756-450c-85c2-9d3ffb9e1c7b,5e0431aa-2d86-4591-b609-ca90d886e5b8,1,1.58241,2024-02-28 22:47:24.149,0
2608466,a79a17db-1235-457d-a973-0ba0e277c28e,01HCD35M0GFP3F6G05DQ5PMJ6C,41ab2b5f-aef5-4262-8231-3b04a5f1c840,2fdb8e8d-5a96-45ba-9f1b-26a175415f55,1,1.26669,2024-02-28 22:47:32.883,0
2608467,a79a17db-1235-457d-a973-0ba0e277c28e,01HMKJWJNP1ZNQP866N4CB8QNQ,ca22e8c7-c877-4b68-81ef-2c6d39038206,2fdb8e8d-5a96-45ba-9f1b-26a175415f55,1,1.49944,2024-02-28 22:47:34.161,1
2608468,a79a17db-1235-457d-a973-0ba0e277c28e,01GZMHG070J34V02BH6WW9V90T,e07ea9b5-6d93-487a-a91c-e77963e88737,2fdb8e8d-5a96-45ba-9f1b-26a175415f55,1,1.99909,2024-02-28 22:47:42.548,1


In [7]:
df = df.drop("REQUESTID", axis=1)
df

Unnamed: 0,USERID,MEDIAID,MEDIATAKENBYID,NVIEWS,TIMESPENT,MIN_TIMESTAMP,NREACTIONS
0,4b7b954f-c93d-4002-8e4b-532c465e9d33,01HM001V4DFK6JCS62FBZ64763,1ef97e43-8f70-4323-b9c4-e1f30a53e958,1,0.53400,2024-02-20 02:18:43.857,0
1,a22fd578-a14f-40dc-bb43-aa3d14732444,01HEY3B6FVH5D47FVMH1PKET0C,24ba8dbd-a973-460a-900e-3e87010d50df,1,3.77994,2024-02-20 01:00:27.825,0
2,a22fd578-a14f-40dc-bb43-aa3d14732444,01HHT0YJ76N3MD1NGM7TZ0DRFK,d0ee7391-61c2-4c60-96d3-6b1db6e7d596,1,0.70114,2024-02-20 01:00:33.806,0
3,452e8205-5784-4f9d-b79c-17603fc33c9a,01H98XMRXPH02T0WMNTSD6DMZX,79e5dd58-6599-4816-a42d-c4f7d228c19e,1,0.50839,2024-02-20 02:19:06.239,0
4,3d22cca2-5c39-46e6-8ca6-bcf9b3d13f03,01H87NGJ3VCP5QXX8078ZDWC8S,38d64e8b-a8e8-4f22-bb5c-f599e900504b,1,1.18214,2024-02-20 02:19:03.149,0
...,...,...,...,...,...,...,...
2608465,a79a17db-1235-457d-a973-0ba0e277c28e,01H3GGGZCP7P09S3GF7Q4Y1M7B,b034bdab-5756-450c-85c2-9d3ffb9e1c7b,1,1.58241,2024-02-28 22:47:24.149,0
2608466,a79a17db-1235-457d-a973-0ba0e277c28e,01HCD35M0GFP3F6G05DQ5PMJ6C,41ab2b5f-aef5-4262-8231-3b04a5f1c840,1,1.26669,2024-02-28 22:47:32.883,0
2608467,a79a17db-1235-457d-a973-0ba0e277c28e,01HMKJWJNP1ZNQP866N4CB8QNQ,ca22e8c7-c877-4b68-81ef-2c6d39038206,1,1.49944,2024-02-28 22:47:34.161,1
2608468,a79a17db-1235-457d-a973-0ba0e277c28e,01GZMHG070J34V02BH6WW9V90T,e07ea9b5-6d93-487a-a91c-e77963e88737,1,1.99909,2024-02-28 22:47:42.548,1


In [8]:
def my_hash(s: str) -> int:
    return xxhash.xxh32(s).intdigest()

df[['USERID', 'MEDIAID', 'MEDIATAKENBYID']] = df[['USERID', 'MEDIAID', 'MEDIATAKENBYID']].applymap(my_hash)
df

  df[['USERID', 'MEDIAID', 'MEDIATAKENBYID']] = df[['USERID', 'MEDIAID', 'MEDIATAKENBYID']].applymap(my_hash)


Unnamed: 0,USERID,MEDIAID,MEDIATAKENBYID,NVIEWS,TIMESPENT,MIN_TIMESTAMP,NREACTIONS
0,2115200940,3602277063,1934183761,1,0.53400,2024-02-20 02:18:43.857,0
1,1480281949,77365951,1777925837,1,3.77994,2024-02-20 01:00:27.825,0
2,1480281949,2759622470,3550137425,1,0.70114,2024-02-20 01:00:33.806,0
3,3514747355,1310479642,1614259863,1,0.50839,2024-02-20 02:19:06.239,0
4,1413447837,2137444560,2269554751,1,1.18214,2024-02-20 02:19:03.149,0
...,...,...,...,...,...,...,...
2608465,2785304444,1740171158,1363737099,1,1.58241,2024-02-28 22:47:24.149,0
2608466,2785304444,3260630200,4188537382,1,1.26669,2024-02-28 22:47:32.883,0
2608467,2785304444,1778714852,1229274141,1,1.49944,2024-02-28 22:47:34.161,1
2608468,2785304444,3206389750,3471329674,1,1.99909,2024-02-28 22:47:42.548,1


In [9]:
val_df = df[df["MIN_TIMESTAMP"] > "2024-02-28"]
train_df = df[df["MIN_TIMESTAMP"] <= "2024-02-28"]
train_df.shape, val_df.shape

((2312781, 7), (295689, 7))

In [10]:
df.drop(["NVIEWS", "MIN_TIMESTAMP"], axis=1)

Unnamed: 0,USERID,MEDIAID,MEDIATAKENBYID,TIMESPENT,NREACTIONS
0,2115200940,3602277063,1934183761,0.53400,0
1,1480281949,77365951,1777925837,3.77994,0
2,1480281949,2759622470,3550137425,0.70114,0
3,3514747355,1310479642,1614259863,0.50839,0
4,1413447837,2137444560,2269554751,1.18214,0
...,...,...,...,...,...
2608465,2785304444,1740171158,1363737099,1.58241,0
2608466,2785304444,3260630200,4188537382,1.26669,0
2608467,2785304444,1778714852,1229274141,1.49944,1
2608468,2785304444,3206389750,3471329674,1.99909,1


In [11]:
print(torch.cuda.is_available())
device = "cuda"

True


In [12]:
class CustomDataset(Dataset):
    def __init__(self, dataframe, device):
        self.dataframe = dataframe
        self.features = torch.tensor(dataframe[['USERID', 'MEDIAID', 'MEDIATAKENBYID']].values, dtype=torch.int32)
        self.labels = torch.tensor(dataframe[['TIMESPENT', 'NREACTIONS']].values, dtype=torch.float32)

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        features = self.features[idx]
        labels = self.labels[idx]
        has_reactions = (labels[1] > 0).float()
        return {
            'features': {'USERID': features[0], 'MEDIAID': features[1], 'MEDIATAKENBYID': features[2]},
            'labels': {'TIMESPENT': labels[0], 'HAS_REACTIONS': has_reactions}
        }

In [13]:
train_dt = CustomDataset(train_df, device=device)
val_dt = CustomDataset(val_df, device=device)
batch_size = 4096  # Adjust based on your needs
shuffle = True  # Shuffle the data
num_workers = 4
train_dataloader = DataLoader(train_dt, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=True)
val_dataloader = DataLoader(val_dt, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)

In [14]:
for batch in train_dataloader:
    features = batch['features']
    labels = batch['labels']
    print(f"Features: {features}, Labels: {labels}")
    break

Features: {'USERID': tensor([-2015451549, -1417535959,   917589081,  ...,   757513723,
         -710679938,   477431031], dtype=torch.int32), 'MEDIAID': tensor([ 1225535373,  1453674218, -1035981605,  ...,  -820313917,
        -1804015320,    73740826], dtype=torch.int32), 'MEDIATAKENBYID': tensor([-1791085126, -1837680475, -1801801503,  ...,  -614054548,
         1809868701,     2591818], dtype=torch.int32)}, Labels: {'TIMESPENT': tensor([1.0842, 1.5843, 1.1504,  ..., 1.5378, 0.7516, 1.1207]), 'HAS_REACTIONS': tensor([0., 0., 0.,  ..., 0., 0., 0.])}


In [34]:
import math

class CollaborativeFilteringModel(nn.Module):
    def __init__(self, num_users, num_posts, embedding_dim):
        super(CollaborativeFilteringModel, self).__init__()
        self.num_users = num_users
        self.num_posts = num_posts
        self.user_embedding = nn.Embedding(num_users, embedding_dim)
        self.post_embedding = nn.Embedding(num_posts, embedding_dim)
        
    def forward(self, user_ids, post_ids):
        user_ids %= self.num_users
        post_ids %= self.num_posts
        user_embedded = self.user_embedding(user_ids)
        post_embedded = self.post_embedding(post_ids)
        interaction = (user_embedded * post_embedded).sum(dim=1)
        prediction = torch.sigmoid(interaction)
        return prediction

In [16]:
def dict_to_device(dic: Dict, device: str) -> Dict:
    return {k: v.to(device) for k, v in dic.items()}

In [17]:
import wandb
wandb.login(key="5558b41450b126a5fc0e47c271d4252b06e32fe6")

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33madensur[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/mak/.netrc


True

In [37]:
# train lööp
print(len(df["USERID"].unique()), len(df["MEDIAID"].unique()))
num_users = 50000
num_posts = 1800
embedding_dim = 4
model = CollaborativeFilteringModel(num_users=num_users, num_posts=num_posts, embedding_dim=embedding_dim).to(device)

criterion = nn.BCELoss()
lr = 1
optimizer = optim.SGD(model.parameters(), lr=lr)

num_epochs = 500
wandb.init(
      # Set the project where this run will be logged
      project="lapse-cf-pytorch-2", 
      # We pass a run name (otherwise it’ll be randomly assigned, like sunshine-lollypop-10)
      name=f"SGD_lr1", 
      # Track hyperparameters and run metadata
      config={
        "learning_rate": lr,
        "architecture": "CF_user_post",
        "epochs": num_epochs,
        "train_start": train_df["MIN_TIMESTAMP"].min(),
        "train_end": train_df["MIN_TIMESTAMP"].max(),
        "train_count": train_df.shape[0],
        "val_start": val_df["MIN_TIMESTAMP"].min(),
        "val_end": val_df["MIN_TIMESTAMP"].max(),
        "val_count": val_df.shape[0],
        "num_users": num_users,
        "num_posts": num_posts,
        "embedding_dim": embedding_dim,
        "batch_size": batch_size,
        "init": "uniform/4",
      })


for epoch in range(num_epochs):
    for batch in tqdm(train_dataloader):
        features = batch['features']
        labels = batch['labels']
        features = dict_to_device(features, device)
        labels = dict_to_device(labels, device)
        optimizer.zero_grad()
        outputs = model(features["USERID"], features["MEDIAID"])
        targets = labels["HAS_REACTIONS"]
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        
    print(f'Epoch {epoch+1}, train loss: {loss.item()}')
    wandb.log({"train_loss": loss.item(), "epoch": epoch+1})

    # Validation phase
    model.eval() # Set the model to evaluation mode
    val_loss = 0.0
    all_targets = []
    all_outputs = []
    with torch.no_grad():
         for batch in val_dataloader:
            features = batch['features']
            labels = batch['labels']
            features = dict_to_device(features, device)
            labels = dict_to_device(labels, device)
            outputs = model(features["USERID"], features["MEDIAID"])
            targets = labels["HAS_REACTIONS"]
            loss = criterion(outputs, targets)
            val_loss += loss.item()
            all_targets.extend(targets.cpu().numpy())  # Collect all targets
            all_outputs.extend(outputs.cpu().numpy())  # Collect all outputs
    auc_roc_score = roc_auc_score(all_targets, all_outputs)
    
    # Print average validation loss per epoch
    print(f'Epoch {epoch+1}, Validation Loss: {val_loss / len(val_dataloader)}, AUC-ROC: {auc_roc_score}')
    wandb.log({"val_loss": loss.item(), "epoch": epoch+1, "roc-auc": auc_roc_score})

297677 1712




0,1
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
roc-auc,▁▁▁▁▁▁▁▁▁▁▁▂▂▂▂▃▃▃▃▄▄▄▅▅▅▆▆▆▆▇▇▇▇▇▇▇████
train_loss,█▆▅▄▄▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_loss,█▆▅▄▃▃▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,500.0
roc-auc,0.50649
train_loss,0.6904
val_loss,0.71891


  0%|          | 0/565 [00:00<?, ?it/s]Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f3567850280>
Traceback (most recent call last):
  File "/home/mak/.local/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1478, in __del__
    self._shutdown_workers()
Exception ignored in: Exception ignored in:   File "/home/mak/.local/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1461, in _shutdown_workers
    <function _MultiProcessingDataLoaderIter.__del__ at 0x7f3567850280>if w.is_alive():
<function _MultiProcessingDataLoaderIter.__del__ at 0x7f3567850280>
Traceback (most recent call last):
  File "/usr/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
      File "/home/mak/.local/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1478, in __del__
assert self._parent_pid == os.getpid(), 'can only test a child process'    self._shutdown_workers()


  File "/home/mak/.local/lib/python3.10/site-packa

Epoch 1, train loss: 0.9885744452476501



Exception ignored in: Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f3567850280><function _MultiProcessingDataLoaderIter.__del__ at 0x7f3567850280>
Traceback (most recent call last):

Traceback (most recent call last):
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f3567850280>  File "/home/mak/.local/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1478, in __del__
    self._shutdown_workers()
  File "/home/mak/.local/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1461, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process

  File "/home/mak/.local/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1478, in __del__
    Traceback (most recent call last):
  File "/home/mak/.local/lib/pyt

Epoch 1, Validation Loss: 0.9751552531163986, AUC-ROC: 0.5028617860982029


  0%|          | 0/565 [00:00<?, ?it/s]Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f3567850280>Exception ignored in: 
Traceback (most recent call last):
<function _MultiProcessingDataLoaderIter.__del__ at 0x7f3567850280>
Exception ignored in:   File "/home/mak/.local/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1478, in __del__
Traceback (most recent call last):
    <function _MultiProcessingDataLoaderIter.__del__ at 0x7f3567850280>self._shutdown_workers()
  File "/home/mak/.local/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1478, in __del__
  File "/home/mak/.local/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1461, in _shutdown_workers
        Exception ignored in: 
if w.is_alive():self._shutdown_workers()Traceback (most recent call last):

<function _MultiProcessingDataLoaderIter.__del__ at 0x7f3567850280>  File "/usr/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
   

Epoch 2, train loss: 0.8915073871612549



Exception ignored in: Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f3567850280><function _MultiProcessingDataLoaderIter.__del__ at 0x7f3567850280>
Traceback (most recent call last):
  File "/home/mak/.local/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1478, in __del__
    self._shutdown_workers()
  File "/home/mak/.local/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1461, in _shutdown_workers

Traceback (most recent call last):
    if w.is_alive():  File "/home/mak/.local/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1478, in __del__

      File "/usr/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
self._shutdown_workers()    
assert self._parent_pid == os.getpid(), 'can only test a child process'  File "/home/mak/.local/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1461, in _shutdown_workers

    AssertionErrorif w.is_alive():: can only test a child p

Epoch 2, Validation Loss: 0.9279087904381426, AUC-ROC: 0.5029866228127615


  0%|          | 0/565 [00:00<?, ?it/s]Exception ignored in: Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f3567850280>Exception ignored in: 
<function _MultiProcessingDataLoaderIter.__del__ at 0x7f3567850280><function _MultiProcessingDataLoaderIter.__del__ at 0x7f3567850280>

Traceback (most recent call last):
Exception ignored in: Traceback (most recent call last):
  File "/home/mak/.local/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1478, in __del__
<function _MultiProcessingDataLoaderIter.__del__ at 0x7f3567850280>Traceback (most recent call last):
    
  File "/home/mak/.local/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1478, in __del__
self._shutdown_workers()Traceback (most recent call last):
  File "/home/mak/.local/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1478, in __del__
      File "/home/mak/.local/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1478, in _

Epoch 3, train loss: 0.8527818322181702



Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f3567850280>Exception ignored in: Exception ignored in: 
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f3567850280>Traceback (most recent call last):
<function _MultiProcessingDataLoaderIter.__del__ at 0x7f3567850280>

Traceback (most recent call last):
  File "/home/mak/.local/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1478, in __del__
  File "/home/mak/.local/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1478, in __del__
Traceback (most recent call last):
        self._shutdown_workers()self._shutdown_workers()

  File "/home/mak/.local/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1461, in _shutdown_workers
  File "/home/mak/.local/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1461, in _shutdown_workers
        if w.is_alive():if w.is_alive():

  File "/usr/lib/python3.10/multiprocessing/pr

Epoch 3, Validation Loss: 0.8893542208083688, AUC-ROC: 0.5031221955140314


  0%|          | 0/565 [00:00<?, ?it/s]Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f3567850280>
Traceback (most recent call last):
  File "/home/mak/.local/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1478, in __del__
    self._shutdown_workers()
  File "/home/mak/.local/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1461, in _shutdown_workers
    if w.is_alive():Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f3567850280>
  File "/usr/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
Traceback (most recent call last):
  File "/home/mak/.local/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1478, in __del__

    Exception ignored in: self._shutdown_workers()AssertionError
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f3567850280>  File "/home/m