In [1]:
import numpy as np
from src.lightning_model.lit_sorting_model import LitSortingModel
# from src.dataset.number_sorting import convert_pred_h
import torch.nn.functional as F
from src.dataset.multidigit_mnist import MultiDigitMNISTSplits
from pytorch_lightning import seed_everything, Trainer
import pandas as pd
from src.lightning_model.lit_multidigit_mnist_model import LitMultiDigitMNISTSortingModel
from src.lightning_model.lit_sorting_model import LitSortingModel
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint, TQDMProgressBar



In [3]:
# Train a multi-digit-MNIST model on top of a previously trained sorting model.
# NOTE: call order below is kept as-is on purpose: everything that may consume
# the global RNG (model construction, dataset construction, fitting) runs after
# seeding and in the same sequence, so the run stays reproducible.

# --- reproducibility ---------------------------------------------------------
random_state = 12345
seed_everything(random_state, workers=True)

# --- models ------------------------------------------------------------------
# Sorting network restored from an earlier training run's checkpoint.
# NOTE(review): the path is pinned to one specific lightning_logs version/epoch.
lit_sorting_model = LitSortingModel.load_from_checkpoint(
    'saved_model/lit_sorting_model/lightning_logs/version_2/checkpoints/epoch=5-step=2240.ckpt'
)

# The pretrained sorter is handed to the MNIST wrapper; in the logged run its
# parameters are reported as non-trainable.
mnist_sort_model = LitMultiDigitMNISTSortingModel(
    lit_sorting_model,
    batch_size=100,
    learning_rate=0.001,
)

# --- trainer -----------------------------------------------------------------
saved_model_path = './saved_model/lit_mnist_sorting_model/'

# Keep the checkpoint with the lowest 'val_loss'; refresh the progress bar only
# every 200 batches so the notebook output stays short.
trainer_callbacks = [
    TQDMProgressBar(refresh_rate=200),
    ModelCheckpoint(monitor='val_loss', mode='min'),
]

trainer = Trainer(
    default_root_dir=saved_model_path,
    max_epochs=20,
    # A float interval means "fraction of a training epoch": validate every 20%
    # of the epoch, i.e. five validation passes per epoch (matches the log).
    val_check_interval=0.2,
    check_val_every_n_epoch=1,
    # Pinned to CPU even when CUDA is present (the log reports
    # "GPU available: True (cuda), used: False"). For GPU training use
    # accelerator='gpu' if torch.cuda.is_available() else 'cpu' (needs `import torch`).
    accelerator='cpu',
    callbacks=trainer_callbacks,
)

# --- data --------------------------------------------------------------------
# The dataset split carries its own fixed seed, independent of random_state.
# num_compare=4 is presumably the number of digits per list to sort; confirm.
mnist_dataset = MultiDigitMNISTSplits(
    'mnist',
    num_compare=4,
    num_train_list=55000,
    num_val_list=1000,
    num_test_list=50,
    seed=0,
)

# NOTE(review): the positional argument is presumably the loader batch size;
# confirm how it relates to batch_size=100 passed to the model above.
train_loader = mnist_dataset.get_train_loader(1)
val_loader = mnist_dataset.get_valid_loader(1)

# --- fit ---------------------------------------------------------------------
# NOTE(review): in the logged run validation loss falls steadily, but validation
# accuracy hovers around 0.25 (~1/num_compare, i.e. chance level) across all
# epochs; worth checking the accuracy metric or the prediction path.
trainer.fit(
    mnist_sort_model,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
)

Global seed set to 12345
GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(

  | Name                     | Type                | Params
-----------------------------------------------------------------
0 | pretrained_sorting_model | LitSortingModel     | 4.3 K 
1 | cnn_model                | MultiDigitMNISTConv | 227 K 
-----------------------------------------------------------------
227 K     Trainable params
4.3 K     Non-trainable params
231 K     Total params
0.926     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Train Set Loss = nan / Validation Set (@epoch:0): loss=0.7436648607254028, accuracy=0.25


  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Train Set Loss = 1.162320852279663 / Validation Set (@epoch:0): loss=1.1514906883239746, accuracy=0.25200000405311584


Validation: 0it [00:00, ?it/s]

Train Set Loss = 1.1231526136398315 / Validation Set (@epoch:0): loss=1.0858745574951172, accuracy=0.24175000190734863


Validation: 0it [00:00, ?it/s]

Train Set Loss = 1.1386632919311523 / Validation Set (@epoch:0): loss=1.1862174272537231, accuracy=0.25450000166893005


Validation: 0it [00:00, ?it/s]

Train Set Loss = 1.1649094820022583 / Validation Set (@epoch:0): loss=1.1755807399749756, accuracy=0.25975000858306885


Validation: 0it [00:00, ?it/s]

Train Set Loss = 1.167538046836853 / Validation Set (@epoch:0): loss=1.1909242868423462, accuracy=0.24150000512599945


Validation: 0it [00:00, ?it/s]

Train Set Loss = 1.1536585092544556 / Validation Set (@epoch:1): loss=1.1826906204223633, accuracy=0.25


Validation: 0it [00:00, ?it/s]

Train Set Loss = 1.1483746767044067 / Validation Set (@epoch:1): loss=1.0925047397613525, accuracy=0.24574999511241913


Validation: 0it [00:00, ?it/s]

Train Set Loss = 1.0895229578018188 / Validation Set (@epoch:1): loss=1.1028512716293335, accuracy=0.24899999797344208


Validation: 0it [00:00, ?it/s]

Train Set Loss = 1.081525206565857 / Validation Set (@epoch:1): loss=1.0854476690292358, accuracy=0.2460000067949295


Validation: 0it [00:00, ?it/s]

Train Set Loss = 1.065507173538208 / Validation Set (@epoch:1): loss=1.0486561059951782, accuracy=0.24124999344348907


Validation: 0it [00:00, ?it/s]

Train Set Loss = 1.0688773393630981 / Validation Set (@epoch:2): loss=1.0471723079681396, accuracy=0.25049999356269836


Validation: 0it [00:00, ?it/s]

Train Set Loss = 1.0460251569747925 / Validation Set (@epoch:2): loss=1.0593981742858887, accuracy=0.2447499930858612


Validation: 0it [00:00, ?it/s]

Train Set Loss = 1.0430772304534912 / Validation Set (@epoch:2): loss=1.0237504243850708, accuracy=0.2750000059604645


Validation: 0it [00:00, ?it/s]

Train Set Loss = 1.058922529220581 / Validation Set (@epoch:2): loss=1.0953967571258545, accuracy=0.24400000274181366


Validation: 0it [00:00, ?it/s]

Train Set Loss = 1.0697100162506104 / Validation Set (@epoch:2): loss=1.0492444038391113, accuracy=0.24525000154972076


Validation: 0it [00:00, ?it/s]

Train Set Loss = 1.0568435192108154 / Validation Set (@epoch:3): loss=1.0983325242996216, accuracy=0.23874999582767487


Validation: 0it [00:00, ?it/s]

Train Set Loss = 1.050767183303833 / Validation Set (@epoch:3): loss=1.0281959772109985, accuracy=0.2542499899864197


Validation: 0it [00:00, ?it/s]

Train Set Loss = 1.022487998008728 / Validation Set (@epoch:3): loss=1.0260131359100342, accuracy=0.25850000977516174


Validation: 0it [00:00, ?it/s]

Train Set Loss = 1.0358749628067017 / Validation Set (@epoch:3): loss=1.018021821975708, accuracy=0.2407499998807907


Validation: 0it [00:00, ?it/s]

Train Set Loss = 1.0531467199325562 / Validation Set (@epoch:3): loss=1.0282286405563354, accuracy=0.257750004529953


Validation: 0it [00:00, ?it/s]

Train Set Loss = 1.0090899467468262 / Validation Set (@epoch:4): loss=0.972534716129303, accuracy=0.24799999594688416


Validation: 0it [00:00, ?it/s]

Train Set Loss = 1.0068607330322266 / Validation Set (@epoch:4): loss=0.9854021072387695, accuracy=0.25475001335144043


Validation: 0it [00:00, ?it/s]

Train Set Loss = 1.0022836923599243 / Validation Set (@epoch:4): loss=0.9881818890571594, accuracy=0.24300000071525574


Validation: 0it [00:00, ?it/s]

Train Set Loss = 1.025205135345459 / Validation Set (@epoch:4): loss=0.9779621362686157, accuracy=0.25600001215934753


Validation: 0it [00:00, ?it/s]

Train Set Loss = 1.0132534503936768 / Validation Set (@epoch:4): loss=1.0149028301239014, accuracy=0.24824999272823334


Validation: 0it [00:00, ?it/s]

Train Set Loss = 1.0270249843597412 / Validation Set (@epoch:5): loss=1.0355032682418823, accuracy=0.26625001430511475


Validation: 0it [00:00, ?it/s]

Train Set Loss = 1.0436160564422607 / Validation Set (@epoch:5): loss=1.0153028964996338, accuracy=0.24050000309944153


Validation: 0it [00:00, ?it/s]

Train Set Loss = 1.0563464164733887 / Validation Set (@epoch:5): loss=0.9949579834938049, accuracy=0.25575000047683716


Validation: 0it [00:00, ?it/s]

Train Set Loss = 1.0191929340362549 / Validation Set (@epoch:5): loss=0.9700120687484741, accuracy=0.2472500056028366


Validation: 0it [00:00, ?it/s]

Train Set Loss = 1.0033875703811646 / Validation Set (@epoch:5): loss=0.9439972639083862, accuracy=0.25325000286102295


Validation: 0it [00:00, ?it/s]

Train Set Loss = 0.9776747226715088 / Validation Set (@epoch:6): loss=0.9560860991477966, accuracy=0.2462500035762787


Validation: 0it [00:00, ?it/s]

Train Set Loss = 0.9616147875785828 / Validation Set (@epoch:6): loss=0.9234535098075867, accuracy=0.257750004529953


Validation: 0it [00:00, ?it/s]

Train Set Loss = 0.9358506798744202 / Validation Set (@epoch:6): loss=0.8754869103431702, accuracy=0.2447499930858612


Validation: 0it [00:00, ?it/s]

Train Set Loss = 0.9281533360481262 / Validation Set (@epoch:6): loss=0.8812254667282104, accuracy=0.2590000033378601


Validation: 0it [00:00, ?it/s]

Train Set Loss = 0.9247077703475952 / Validation Set (@epoch:6): loss=0.8766561150550842, accuracy=0.23649999499320984


Validation: 0it [00:00, ?it/s]

Train Set Loss = 0.8799986839294434 / Validation Set (@epoch:7): loss=0.8767831921577454, accuracy=0.23675000667572021


Validation: 0it [00:00, ?it/s]

Train Set Loss = 0.8705244064331055 / Validation Set (@epoch:7): loss=0.820222795009613, accuracy=0.24950000643730164


Validation: 0it [00:00, ?it/s]

Train Set Loss = 0.858698308467865 / Validation Set (@epoch:7): loss=0.8061467409133911, accuracy=0.25450000166893005


Validation: 0it [00:00, ?it/s]

Train Set Loss = 0.8568124771118164 / Validation Set (@epoch:7): loss=0.7819609642028809, accuracy=0.24774999916553497


Validation: 0it [00:00, ?it/s]

Train Set Loss = 0.8506667017936707 / Validation Set (@epoch:7): loss=0.8258043527603149, accuracy=0.24824999272823334


Validation: 0it [00:00, ?it/s]

Train Set Loss = 0.8488689661026001 / Validation Set (@epoch:8): loss=0.8636955618858337, accuracy=0.2540000081062317


Validation: 0it [00:00, ?it/s]

Train Set Loss = 0.8301978707313538 / Validation Set (@epoch:8): loss=0.7846968770027161, accuracy=0.24199999868869781


Validation: 0it [00:00, ?it/s]

Train Set Loss = 0.8324891924858093 / Validation Set (@epoch:8): loss=0.7937164306640625, accuracy=0.2502500116825104


Validation: 0it [00:00, ?it/s]

Train Set Loss = 0.8460924029350281 / Validation Set (@epoch:8): loss=0.8250954747200012, accuracy=0.2529999911785126


Validation: 0it [00:00, ?it/s]

Train Set Loss = 0.7965928912162781 / Validation Set (@epoch:8): loss=0.7622081637382507, accuracy=0.24924999475479126


Validation: 0it [00:00, ?it/s]

Train Set Loss = 0.7924739122390747 / Validation Set (@epoch:9): loss=0.7642133831977844, accuracy=0.2502500116825104


Validation: 0it [00:00, ?it/s]

Train Set Loss = 0.7820923924446106 / Validation Set (@epoch:9): loss=0.7767409682273865, accuracy=0.2515000104904175


Validation: 0it [00:00, ?it/s]

Train Set Loss = 0.782423198223114 / Validation Set (@epoch:9): loss=0.7584735155105591, accuracy=0.2370000034570694


Validation: 0it [00:00, ?it/s]

Train Set Loss = 0.8011384010314941 / Validation Set (@epoch:9): loss=0.7575919032096863, accuracy=0.25099998712539673


Validation: 0it [00:00, ?it/s]

Train Set Loss = 0.8313211798667908 / Validation Set (@epoch:9): loss=0.7969219088554382, accuracy=0.23475000262260437


Validation: 0it [00:00, ?it/s]

Train Set Loss = 0.8212845325469971 / Validation Set (@epoch:10): loss=0.7770155072212219, accuracy=0.23874999582767487


Validation: 0it [00:00, ?it/s]

Train Set Loss = 0.8065509796142578 / Validation Set (@epoch:10): loss=0.7395159006118774, accuracy=0.2667500078678131


Validation: 0it [00:00, ?it/s]

Train Set Loss = 0.761134684085846 / Validation Set (@epoch:10): loss=0.6965093612670898, accuracy=0.24400000274181366


Validation: 0it [00:00, ?it/s]

Train Set Loss = 0.7421726584434509 / Validation Set (@epoch:10): loss=0.6916849613189697, accuracy=0.25224998593330383
