In [1]:
## Progress bar
from tqdm.notebook import tqdm

## PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torch.optim as optim
# Torchvision
import torchvision
from torchvision.datasets import CIFAR10
from torchvision import transforms
# PyTorch Lightning
import pytorch_lightning as pl
import numpy as np
from pprint import pprint
from functools import partial

from pytorch_lightning import seed_everything, LightningModule, Trainer
from pytorch_lightning.callbacks import TQDMProgressBar, LearningRateMonitor, ModelCheckpoint
import pandas as pd
import torch
import os
from src.lightning_model.lit_sorting_model import LitSortingModel

# Seed Python, NumPy and torch RNGs via Lightning so the notebook is reproducible.
# (The model cell below re-seeds with its own value right before building the model.)
SEED = 42
pl.seed_everything(SEED)

# Ensure that all operations are deterministic on GPU (if used) for reproducibility
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Autograd anomaly detection is a debugging aid only: it adds extra checks to every
# backward pass (and records forward tracebacks), which slows training down.
# Flip to True only while hunting a NaN/inf gradient.
DETECT_ANOMALY = False
torch.autograd.set_detect_anomaly(DETECT_ANOMALY)

# Device that is available on this machine; note the Trainer below may still be
# pinned to a specific accelerator independently of this value.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
# Ending the cell with print() (returns None) avoids echoing the
# set_detect_anomaly object repr as cell output.
print(device)

Global seed set to 42


cuda:0


<torch.autograd.anomaly_mode.set_detect_anomaly at 0x7f21e8173640>

In [2]:
%load_ext autoreload
%autoreload 2

## Model: build `LitSortingModel` and train it with PyTorch Lightning

In [3]:
# Re-seed immediately before building the model so everything below is reproducible
# regardless of which cells ran earlier; workers=True also seeds DataLoader workers.
random_state = 12345
seed_everything(random_state, workers=True)

model = LitSortingModel(
    gat_head=1,
    feature_encoded_dim=16,
    dropout=0,
    num_node=4,
    num_train=10000,
    num_val=50,
    num_test=50,
    learning_rate=0.001,
)

# Checkpoints and logs are written under this directory (Trainer.default_root_dir).
# exist_ok=True makes the cell idempotent and avoids a check-then-create race.
saved_model_path = './saved_model/lit_sorting_model/'
os.makedirs(saved_model_path, exist_ok=True)

# Logs the optimizer learning rate every training step. It only runs if it is
# registered in the Trainer's `callbacks` list below.
lr_monitor = LearningRateMonitor(logging_interval='step')

# Keep the checkpoint with the lowest 'val_loss' (LitSortingModel must log this key).
checkpoint_callback = ModelCheckpoint(monitor='val_loss', mode='min')

trainer = Trainer(
    max_epochs=10,
    callbacks=[TQDMProgressBar(refresh_rate=100), lr_monitor, checkpoint_callback],
    val_check_interval=0.2,  # validate after every 20% of a training epoch
    # Pinned to CPU even when a GPU is present; use
    # 'gpu' if torch.cuda.is_available() else 'cpu' to train on the GPU instead.
    accelerator='cpu',
    check_val_every_n_epoch=1,
    default_root_dir=saved_model_path,
)
trainer.fit(model)

Global seed set to 12345
GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(
  rank_zero_warn(

  | Name  | Type         | Params
---------------------------------------
0 | model | SortingModel | 4.3 K 
---------------------------------------
4.3 K     Trainable params
0         Non-trainable params
4.3 K     Total params
0.017     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

training loss 0 samples: nan
Validation Set 2 samples (@epoch:0): loss=5.006947994232178, accuracy=0.375


  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

training loss 2000 samples: 5.121769109660379
Validation Set 50 samples (@epoch:0): loss=5.03189754486084, accuracy=0.3100000023841858


Validation: 0it [00:00, ?it/s]

training loss 2000 samples: 4.895734986716525
Validation Set 50 samples (@epoch:0): loss=4.725570201873779, accuracy=0.4950000047683716


Validation: 0it [00:00, ?it/s]

training loss 2000 samples: 4.73097776690552
Validation Set 50 samples (@epoch:0): loss=4.63943338394165, accuracy=0.5400000214576721


Validation: 0it [00:00, ?it/s]

training loss 2000 samples: 4.709050725346556
Validation Set 50 samples (@epoch:0): loss=4.552489757537842, accuracy=0.5649999976158142


Validation: 0it [00:00, ?it/s]

training loss 2000 samples: 4.688118024579781
Validation Set 50 samples (@epoch:0): loss=4.463769435882568, accuracy=0.6850000023841858


Validation: 0it [00:00, ?it/s]

training loss 2000 samples: 4.625839987805384
Validation Set 50 samples (@epoch:1): loss=4.439435958862305, accuracy=0.6200000047683716


Validation: 0it [00:00, ?it/s]

training loss 2000 samples: 4.6164106632952
Validation Set 50 samples (@epoch:1): loss=4.376199245452881, accuracy=0.7149999737739563


Validation: 0it [00:00, ?it/s]

training loss 2000 samples: 4.574834891463703
Validation Set 50 samples (@epoch:1): loss=4.351279258728027, accuracy=0.7099999785423279


Validation: 0it [00:00, ?it/s]

training loss 2000 samples: 4.547865849563096
Validation Set 50 samples (@epoch:1): loss=4.315001487731934, accuracy=0.7300000190734863


Validation: 0it [00:00, ?it/s]

training loss 2000 samples: 4.509550902450534
Validation Set 50 samples (@epoch:1): loss=4.199983596801758, accuracy=0.824999988079071


Validation: 0it [00:00, ?it/s]

training loss 2000 samples: 4.460482119742162
Validation Set 50 samples (@epoch:2): loss=4.2723469734191895, accuracy=0.7749999761581421


Validation: 0it [00:00, ?it/s]

training loss 2000 samples: 4.420215816788871
Validation Set 50 samples (@epoch:2): loss=4.409130096435547, accuracy=0.7099999785423279


Validation: 0it [00:00, ?it/s]

training loss 2000 samples: 4.443922955388456
Validation Set 50 samples (@epoch:2): loss=4.28354024887085, accuracy=0.7950000166893005


Validation: 0it [00:00, ?it/s]

training loss 2000 samples: 4.366521888152178
Validation Set 50 samples (@epoch:2): loss=4.2170281410217285, accuracy=0.824999988079071


Validation: 0it [00:00, ?it/s]

training loss 2000 samples: 4.319839660523922
Validation Set 50 samples (@epoch:2): loss=4.224913120269775, accuracy=0.824999988079071


Validation: 0it [00:00, ?it/s]

training loss 2000 samples: 4.320554247561804
Validation Set 50 samples (@epoch:3): loss=4.249591827392578, accuracy=0.8149999976158142


Validation: 0it [00:00, ?it/s]

training loss 2000 samples: 4.340185454279447
Validation Set 50 samples (@epoch:3): loss=4.3566484451293945, accuracy=0.7599999904632568


Validation: 0it [00:00, ?it/s]

training loss 2000 samples: 4.357900169194334
Validation Set 50 samples (@epoch:3): loss=4.245549201965332, accuracy=0.824999988079071


Validation: 0it [00:00, ?it/s]

training loss 2000 samples: 4.316133110558804
Validation Set 50 samples (@epoch:3): loss=4.221246719360352, accuracy=0.8199999928474426


Validation: 0it [00:00, ?it/s]

training loss 2000 samples: 4.335528484486412
Validation Set 50 samples (@epoch:3): loss=4.368467807769775, accuracy=0.7450000047683716


Validation: 0it [00:00, ?it/s]

training loss 2000 samples: 4.38782465143273
Validation Set 50 samples (@epoch:4): loss=4.205171585083008, accuracy=0.824999988079071


Validation: 0it [00:00, ?it/s]

training loss 2000 samples: 4.255166881163069
Validation Set 50 samples (@epoch:4): loss=4.11562967300415, accuracy=0.8899999856948853


Validation: 0it [00:00, ?it/s]

training loss 2000 samples: 4.23142329468585
Validation Set 50 samples (@epoch:4): loss=4.111280918121338, accuracy=0.9049999713897705


Validation: 0it [00:00, ?it/s]

training loss 2000 samples: 4.217863148427913
Validation Set 50 samples (@epoch:4): loss=4.180266857147217, accuracy=0.824999988079071


Validation: 0it [00:00, ?it/s]

training loss 2000 samples: 4.272381652938167
Validation Set 50 samples (@epoch:4): loss=4.291440010070801, accuracy=0.7799999713897705


Validation: 0it [00:00, ?it/s]

training loss 2000 samples: 4.284621783969583
Validation Set 50 samples (@epoch:5): loss=4.069979667663574, accuracy=0.9200000166893005


Validation: 0it [00:00, ?it/s]

training loss 2000 samples: 4.171570102024711
Validation Set 50 samples (@epoch:5): loss=4.096098899841309, accuracy=0.9100000262260437


Validation: 0it [00:00, ?it/s]

training loss 2000 samples: 4.157663438063327
Validation Set 50 samples (@epoch:5): loss=4.038896083831787, accuracy=0.925000011920929


Validation: 0it [00:00, ?it/s]

training loss 2000 samples: 4.247090940464461
Validation Set 50 samples (@epoch:5): loss=4.2870869636535645, accuracy=0.7749999761581421


Validation: 0it [00:00, ?it/s]

training loss 2000 samples: 4.3445559737003725
Validation Set 50 samples (@epoch:5): loss=4.3031792640686035, accuracy=0.7749999761581421


Validation: 0it [00:00, ?it/s]

training loss 2000 samples: 4.350030405024569
Validation Set 50 samples (@epoch:6): loss=4.265058994293213, accuracy=0.7599999904632568


Validation: 0it [00:00, ?it/s]

training loss 2000 samples: 4.384684399443244
Validation Set 50 samples (@epoch:6): loss=4.274550437927246, accuracy=0.7549999952316284


Validation: 0it [00:00, ?it/s]

training loss 2000 samples: 4.404536173766766
Validation Set 50 samples (@epoch:6): loss=4.335656642913818, accuracy=0.699999988079071


Validation: 0it [00:00, ?it/s]

training loss 2000 samples: 4.437758281104489
Validation Set 50 samples (@epoch:6): loss=4.303549289703369, accuracy=0.7049999833106995


Validation: 0it [00:00, ?it/s]

training loss 2000 samples: 4.413241245872043
Validation Set 50 samples (@epoch:6): loss=4.295295715332031, accuracy=0.7200000286102295


Validation: 0it [00:00, ?it/s]

training loss 2000 samples: 4.41858468170472
Validation Set 50 samples (@epoch:7): loss=4.31782341003418, accuracy=0.7049999833106995


Validation: 0it [00:00, ?it/s]

training loss 2000 samples: 4.37902720699142
Validation Set 50 samples (@epoch:7): loss=4.236795425415039, accuracy=0.7900000214576721


Validation: 0it [00:00, ?it/s]

training loss 2000 samples: 4.316944850626756
Validation Set 50 samples (@epoch:7): loss=4.141088485717773, accuracy=0.8600000143051147


Validation: 0it [00:00, ?it/s]

training loss 2000 samples: 4.318018833261131
Validation Set 50 samples (@epoch:7): loss=4.216789722442627, accuracy=0.8550000190734863


Validation: 0it [00:00, ?it/s]

training loss 2000 samples: 4.324021360017282
Validation Set 50 samples (@epoch:7): loss=4.192513465881348, accuracy=0.8550000190734863


Validation: 0it [00:00, ?it/s]

training loss 2000 samples: 4.307494526696008
Validation Set 50 samples (@epoch:8): loss=4.198485851287842, accuracy=0.8700000047683716


Validation: 0it [00:00, ?it/s]

training loss 2000 samples: 4.290646634278899
Validation Set 50 samples (@epoch:8): loss=4.251878261566162, accuracy=0.8349999785423279


Validation: 0it [00:00, ?it/s]

training loss 2000 samples: 4.273833579372898
Validation Set 50 samples (@epoch:8): loss=4.121134281158447, accuracy=0.9449999928474426


Validation: 0it [00:00, ?it/s]

training loss 2000 samples: 4.2302941379615255
Validation Set 50 samples (@epoch:8): loss=4.107246398925781, accuracy=0.9350000023841858


Validation: 0it [00:00, ?it/s]

training loss 2000 samples: 4.196265462453893
Validation Set 50 samples (@epoch:8): loss=4.09552001953125, accuracy=0.9449999928474426


Validation: 0it [00:00, ?it/s]

training loss 2000 samples: 4.1747404050753225
Validation Set 50 samples (@epoch:9): loss=4.096024036407471, accuracy=0.949999988079071


Validation: 0it [00:00, ?it/s]

training loss 2000 samples: 4.199683591076276
Validation Set 50 samples (@epoch:9): loss=4.080599308013916, accuracy=0.9700000286102295


Validation: 0it [00:00, ?it/s]

training loss 2000 samples: 4.17438329092188
Validation Set 50 samples (@epoch:9): loss=4.0733184814453125, accuracy=0.9800000190734863


Validation: 0it [00:00, ?it/s]

training loss 2000 samples: 4.169117197131895
Validation Set 50 samples (@epoch:9): loss=4.107851982116699, accuracy=0.9549999833106995


Validation: 0it [00:00, ?it/s]

training loss 2000 samples: 4.183241829531929
Validation Set 50 samples (@epoch:9): loss=4.104586601257324, accuracy=0.9599999785423279


`Trainer.fit` stopped: `max_epochs=10` reached.
