# Link Prediction for Heterogeneous graph

## 0. Enviroment setup

In [1]:
# !pip uninstall torch torchvision torchaudio --yes
# !pip install torch==2.2.1 torchvision==0.17.1 torchaudio==2.2.1 --index-url https://download.pytorch.org/whl/cu121
# !pip install lightning torch_geometric
# !pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.2.0+cu121.html
# !pip install wandb

In [2]:
import os
import shutil
import wandb
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from tqdm import tqdm

from torch_geometric.utils import negative_sampling
import torch_geometric.transforms as T
from torch_geometric.utils import train_test_split_edges
# from torch_geometric.loader import LinkNeighborLoader
# from torch_geometric.data.lightning import LightningLinkData

import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import WandbLogger

In [3]:
from hive_analysis.models.homo_link_prediction import *
from hive_analysis.dataloaders import hive_preprocessing

In [4]:
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

In [5]:
# Setup device agnostic code
device = "cuda" if torch.cuda.is_available() else "cpu"
device

'cuda'

## 1. Data Pre-processing

In [6]:
DATA_VERSION = 'final_v1'

In [7]:
# data = hive_preprocessing(
#     f'dataset/hive/{DATA_VERSION}/nodes_labelled.csv',
#     f'dataset/hive/{DATA_VERSION}/edges_labelled.csv',
#     to_undirected = True,
# )
# torch.save(data, f'dataset/hive/{DATA_VERSION}/hive.pt')
data = torch.load(f'dataset/hive/{DATA_VERSION}/hive.pt')
data

HeteroData(
  user={
    x=[18645, 5],
    y=[18645],
  },
  comment={
    x=[125111, 1],
    y=[125111],
  },
  post={
    x=[13540, 1],
    y=[13540],
  },
  (user, upvote, comment)={
    edge_index=[2, 423638],
    edge_attr=[423638, 1],
    y=[423638],
  },
  (user, upvote, post)={
    edge_index=[2, 554131],
    edge_attr=[554131, 1],
    y=[554131],
  },
  (user, write, comment)={
    edge_index=[2, 78696],
    edge_attr=[78696, 1],
    y=[78696],
  },
  (user, write, post)={
    edge_index=[2, 12958],
    edge_attr=[12958, 1],
    y=[12958],
  },
  (user, downvote, comment)={
    edge_index=[2, 6819],
    edge_attr=[6819, 1],
    y=[6819],
  },
  (user, downvote, post)={
    edge_index=[2, 2934],
    edge_attr=[2934, 1],
    y=[2934],
  },
  (comment, belong_to, comment)={
    edge_index=[2, 58911],
    edge_attr=[58911, 1],
    y=[58911],
  },
  (comment, belong_to, post)={
    edge_index=[2, 19838],
    edge_attr=[19838, 1],
    y=[19838],
  },
  (comment, rev_upvote, user)={


In [8]:
homo_data = data.to_homogeneous()

In [17]:
homo_data

Data(edge_index=[2, 2315850], x=[157296, 5], y=[2315850], edge_attr=[2315850, 1], node_type=[157296], edge_type=[2315850])

In [9]:
transform = T.RandomLinkSplit(
    num_val=0.1,
    num_test=0.1,
    disjoint_train_ratio=0.3, 
    add_negative_train_samples=True,
    neg_sampling_ratio=2.0,
)

train_data, val_data, test_data = transform(homo_data)

In [10]:
train_data

Data(edge_index=[2, 1296876], x=[157296, 5], y=[1296876], edge_attr=[1296876, 1], node_type=[157296], edge_type=[1296876], edge_label=[1667412], edge_label_index=[2, 1667412])

## 2. Training

In [11]:
models = {
    'GraphConv': GraphConvNet, 
    'GATv2': GATv2, 
    'GraphSAGE': GraphSAGE, 
    'GAT':GAT
}

In [12]:
models = { k: m(
    in_channels=-1,  
    out_channels=128,
    hidden_channels=[64, 128, 256, 256, 512], 
    # aggr_scheme='mean',
) for k, m in models.items()}

In [13]:
class GraphDataset(Dataset):
    def __init__(
        self,
        data,
        key='edge_label',
    ):
        self.data = data
        self.key = key

    def __len__(self):
        return 1

    def __getitem__(self, idx):
        return self.data, self.key
    
def collate_fn(input):
    data, key = zip(*input)
    return data[0], key[0]

train_loader = DataLoader(
    GraphDataset(train_data),
    batch_size=1,
    shuffle=True,
    drop_last=False,
    pin_memory=True,
    num_workers=4,
    collate_fn=collate_fn,
)
val_loader = DataLoader(
    GraphDataset(val_data),
    batch_size=1,
    shuffle=False,
    drop_last=False,
    pin_memory=True,
    num_workers=4,
    collate_fn=collate_fn
)

In [14]:
# for edge_types, rev_edge_types in edges:
for mtype, model in models.items():
    log_dir = 'results/log/homo/lp/' + mtype.lower()
    loss_checkpoint_dir = f'results/checkpoints/homo/lp/{mtype.lower()}/loss'
    auc_checkpoint_dir = f'results/checkpoints/homo/lp/{mtype.lower()}/roc_auc'
    acc_checkpoint_dir = f'results/checkpoints/homo/lp/{mtype.lower()}/acc'

    os.makedirs(log_dir, exist_ok=True)
    os.makedirs(loss_checkpoint_dir, exist_ok=True)
    os.makedirs(auc_checkpoint_dir, exist_ok=True)
    os.makedirs(acc_checkpoint_dir, exist_ok=True)


    lr = 1e-3
    optim = torch.optim.Adam(model.parameters(), lr=lr)
    model.set_optimizer(optim)

    wandb_logger = WandbLogger(
        project="Homo_LinkPrediction_finalv1",
        log_model=True,
        save_dir=log_dir,
        name=mtype,
        entity='ssc_project'

    )

    loss_checkpoint_callback = ModelCheckpoint(
        monitor=f'val_loss',
        dirpath=loss_checkpoint_dir,
        filename='HomoLinkPred-{epoch:02d}-{val_loss:.2f}',
        save_top_k=3,
        save_last=True,
        mode='min',
        every_n_epochs=1
    )
    roc_auc_checkpoint_callback = ModelCheckpoint(
        monitor=f'val_roc_auc',
        dirpath=auc_checkpoint_dir,
        filename='HomoLinkPred-{epoch:02d}-{val_roc_auc:.2f}',
        save_top_k=3,
        save_last=True,
        mode='max',
        every_n_epochs=1
    )
    acc_checkpoint_callback = ModelCheckpoint(
        monitor=f'val_accuracy',
        dirpath=acc_checkpoint_dir,
        filename='HomoLinkPred-{epoch:02d}-{val_accuracy:.2f}',
        save_top_k=3,
        save_last=True,
        mode='max',
        every_n_epochs=1
    )

    trainer = L.Trainer(
        max_epochs=500,
        check_val_every_n_epoch=10,
        callbacks=[
            loss_checkpoint_callback, 
            roc_auc_checkpoint_callback,
            acc_checkpoint_callback,
        ],
        logger=wandb_logger,
        log_every_n_steps=4
    )


    trainer.fit(model, train_loader, val_loader)
    wandb.finish()

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA RTX A6000') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
[34m[1mwandb[0m: Currently logged in as: [33mhontrn9122[0m ([33mssc_project[0m). Use [1m`wandb login --relogin`[0m to force relogin


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/usr/local/lib/python3.9/dist-packages/lightning/pytorch/utilities/model_summary/model_summary.py:454: A layer with UninitializedParameter was found. Thus, the total number of parameters detected may be inaccurate.

  | Name    | Type              | Params
----------------------------------------------
0 | encoder | Sequential        | 607 K 
1 | crit    | BCEWithLogitsLoss | 0     
----------------------------------------------
607 K     Trainable params
0         Non-trainable params
607 K     Total params
2.430     Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/usr/local/lib/python3.9/dist-packages/lightning/pytorch/loops/fit_loop.py:298: The number of training batches (1) is smaller than the logging interval Trainer(log_every_n_steps=4). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=500` reached.


0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train_loss,█▄▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
trainer/global_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
val_accuracy,▁▆▇▅▄▆▄▅▆▆▆▆▇▆▆▆▇▆▆▆▇▇▇▇▆▆▇▇▆▇▇██▇▆████▇
val_f1,▁▆▇▅▄▆▄▅▆▆▆▆▇▆▆▆▇▆▆▆▇▇▇▇▆▆▇▇▆▇▇██▇▆████▇
val_loss,█▃▂▄▅▂▅▄▃▄▃▃▂▄▃▄▂▃▄▃▃▂▃▃▃▃▂▂▃▂▂▁▁▂▃▂▁▂▁▂
val_precision,▁▆▇▅▄▆▄▅▆▅▆▆▇▆▆▆▇▆▆▆▇▇▆▇▆▆▇▇▆▇▇██▇▆▇█▇█▇
val_recall,▁▃▂▅▇▅█▇▇███▇███▇▇██▇▇▇▇▇█▇▇█▇█▇▇▇█▇▇█▇█
val_roc_auc,▁▆▇▅▄▆▅▆▆▆▆▆▇▆▆▆▇▆▆▆▇▇▇▇▇▇▇▇▆▇▇██▇▇█████

0,1
epoch,499.0
train_loss,0.59811
trainer/global_step,499.0
val_accuracy,0.657
val_f1,0.657
val_loss,0.63871
val_precision,0.49255
val_recall,0.95811
val_roc_auc,0.73228


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016669420703935126, max=1.0…

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type              | Params
----------------------------------------------
0 | encoder | Sequential        | 611 K 
1 | crit    | BCEWithLogitsLoss | 0     
----------------------------------------------
611 K     Trainable params
0         Non-trainable params
611 K     Total params
2.446     Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/usr/local/lib/python3.9/dist-packages/lightning/pytorch/loops/fit_loop.py:298: The number of training batches (1) is smaller than the logging interval Trainer(log_every_n_steps=4). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=500` reached.


0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train_loss,█▆▅▃▃▃▂▂▂▃▃▃▃▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
trainer/global_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
val_accuracy,▆▆▄▇▆▄▅▅▅█▇▇████▃▁▄▄▅▄▄▅▇███████▇▇▇▇▆▆▆▆
val_f1,▆▆▄▇▆▄▅▅▅█▇▇████▃▁▄▄▅▄▄▅▇███████▇▇▇▇▆▆▆▆
val_loss,▄▂▄▁▃▅▄▄▄▃▄▄▃▄▃▂▆█▆▆▄▅▆▅▁▁▁▁▁▁▁▂▂▂▂▂▂▂▂▃
val_precision,▆▇▅▇▇▅▆▆▅█▇▇████▃▁▅▅▅▅▄▆████████▇▇▇▇▇▇▇▆
val_recall,██▄█▇▆▇▇▆██████▆▃▁▇▇▅▅▄██▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇
val_roc_auc,▇▇▄█▇▅▆▆▅██████▇▃▁▅▅▅▅▄▆████████▇▇▇▇▇▇▇▇

0,1
epoch,499.0
train_loss,0.6163
trainer/global_step,499.0
val_accuracy,0.56272
val_f1,0.56272
val_loss,0.71395
val_precision,0.42562
val_recall,0.89223
val_roc_auc,0.6451


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01666948914838334, max=1.0)…

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type              | Params
----------------------------------------------
0 | encoder | Sequential        | 607 K 
1 | crit    | BCEWithLogitsLoss | 0     
----------------------------------------------
607 K     Trainable params
0         Non-trainable params
607 K     Total params
2.430     Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/usr/local/lib/python3.9/dist-packages/lightning/pytorch/loops/fit_loop.py:298: The number of training batches (1) is smaller than the logging interval Trainer(log_every_n_steps=4). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=500` reached.


0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train_loss,█▃▃▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
trainer/global_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
val_accuracy,▁▄▄▄▁▂▁▄▃▄▃▄▄▄▄▅▄▅▅▆▆▆▇▆▆▇▆▇▇▇▆▆▇▇▇▇▇▇█▇
val_f1,▁▄▄▄▁▂▁▄▃▄▃▄▄▄▄▅▄▅▅▆▆▆▇▆▆▇▆▇▇▇▆▆▇▇▇▇▇▇█▇
val_loss,▇▄▅▄▇▆█▅▆▅▆▅▅▅▆▄▅▄▄▃▃▃▃▃▃▂▃▂▂▂▃▂▂▂▂▂▂▂▁▃
val_precision,▁▃▄▄▁▂▁▄▃▄▃▄▄▄▃▅▄▅▅▆▆▆▆▆▆▇▆▇▇▆▆▆▇▇▇▇▇▇█▇
val_recall,▁▂▄▄▇▇█▇███████▇██▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇
val_roc_auc,▁▃▄▄▂▃▂▅▄▄▄▅▅▅▄▅▅▅▆▆▆▆▇▆▆▇▇▇▇▇▆▇▇█▇▇▇▇█▇

0,1
epoch,499.0
train_loss,0.59795
trainer/global_step,499.0
val_accuracy,0.66653
val_f1,0.66653
val_loss,0.6369
val_precision,0.49989
val_recall,0.95757
val_roc_auc,0.73929


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016668892333594462, max=1.0…

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type              | Params
----------------------------------------------
0 | encoder | Sequential        | 307 K 
1 | crit    | BCEWithLogitsLoss | 0     
----------------------------------------------
307 K     Trainable params
0         Non-trainable params
307 K     Total params
1.229     Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/usr/local/lib/python3.9/dist-packages/lightning/pytorch/loops/fit_loop.py:298: The number of training batches (1) is smaller than the logging interval Trainer(log_every_n_steps=4). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=500` reached.


0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train_loss,▄▃▃▂▃▂▂▂▃▃▃▂▂▂▂▂▂▂▃▂▂▁▂▁▁▃█▆▅▅▅▅▅▃▂▂▁▁▁▁
trainer/global_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
val_accuracy,▂▃▃▄▃▃▅▆▄▁▃▆▆▃▃▃▄▂▄▄▃▄▆▃▆▁▃▄▅▅▅▅▄▅▄▄▄▄▅█
val_f1,▂▃▃▄▃▃▅▆▄▁▃▆▆▃▃▃▄▂▄▄▃▄▆▃▆▁▃▄▅▅▅▅▄▅▄▄▄▄▅█
val_loss,█▅▆▃██▄▃▅▇█▃▃▇▆▇▇▇▅▅▆▄▃▆▃▇▆▆▅▅▄▅▅▄▄▄▆▇▄▁
val_precision,▄▄▄▅▃▄▅▆▅▁▄▆▆▅▃▅▅▄▃▃▁▃▆▂▆▃▂▃▄▄▅▄▄▅▄▃▂▂▅█
val_recall,▇███▅▅▆▇█▂█▆▆█▄███▃▃▁▂▇▂▇▆▂▂▄▃▅▄▃▆▃▂▁▁▃▆
val_roc_auc,▄▅▅▆▄▄▆▇▆▁▅▆▇▆▃▆▆▅▃▃▂▄▇▃▇▃▃▃▄▄▆▄▄▆▄▃▃▂▅█

0,1
epoch,499.0
train_loss,0.63522
trainer/global_step,499.0
val_accuracy,0.69643
val_f1,0.69643
val_loss,0.6664
val_precision,0.53113
val_recall,0.7617
val_roc_auc,0.71275


In [15]:
print('done')

done
