# Run small convnet

In [None]:
# Train model
from model import NN
from customDataModule import CustomDataModule
import config
import pytorch_lightning as pl
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch import seed_everything
from lightning.pytorch.loggers import TensorBoardLogger
from pytorchModel import *
import torch
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import CSVLogger

# Seed torch, numpy and the stdlib RNGs so runs are reproducible. With
# workers=True, Lightning also derives unique seeds across all dataloader
# workers and processes, which ensures that e.g. data augmentations are not
# repeated across workers.
#
# NOTE: `seed_everything` and the TensorBoard logger are resolved through `pl`
# (pytorch_lightning) so that they come from the same package as `pl.Trainer`,
# `ModelCheckpoint` and `CSVLogger`. `lightning.pytorch` and
# `pytorch_lightning` are separate installs and should not be mixed in one run.
pl.seed_everything(42, workers=True)

# Keep only the single best checkpoint: the one with the lowest "val_loss".
callbacks = [ModelCheckpoint(save_top_k=1, mode="min", monitor="val_loss")]


# In a notebook kernel `__name__` is "__main__", so this body always runs here.
# The guard is kept so the cell can also be exported and run as a script.
if __name__ == "__main__":
    # Plain PyTorch network that the Lightning wrapper below trains.
    pytorch_model = pytorchModel(num_classes=config.NUM_CLASSES)

    # "tb_logs" is the output folder; `name` is the experiment/model name.
    logger = pl.loggers.TensorBoardLogger("tb_logs", name="small_m")
    logger2 = CSVLogger(save_dir="logs/", name="small_m")

    # Lightning wrapper around the network (no manual `.to(device)` call is
    # made here; the Trainer is configured with the accelerator/devices below).
    model = NN(
        model=pytorch_model,
        input_size=config.IN_CHANNELS,
        num_classes=config.NUM_CLASSES,
        learning_rate=config.LEARNING_RATE,
    )

    # A general place to start is to set num_workers equal to the number of CPU
    # cores on the machine (os.cpu_count()); note that, depending on the batch
    # size, this may overflow RAM.
    dm = CustomDataModule(
        data_dir=config.DATA_DIR,
        train_csv=config.TRAIN_CSV_1,
        val_csv=config.VAL_CSV_1,
        test_csv=config.TEST_CSV,
        batch_size=config.BATCH_SIZE,
        num_workers=config.NUM_WORKERS,
        mean=config.MEAN,
        std=config.STD,
    )

    # The CSV logger is deliberately listed first: in the recorded run the
    # checkpoints were written under its `logs/small_m/version_N` directory.
    # `deterministic` ensures random-seed reproducibility.
    trainer = pl.Trainer(
        logger=[logger2, logger],
        accelerator=config.ACCELERATOR,
        devices=config.DEVICES,
        min_epochs=config.MIN_EPOCHS,
        max_epochs=config.MAX_EPOCHS,
        deterministic=config.DETERMINISTIC,
        callbacks=callbacks,
    )

    # The Trainer picks the right dataloaders from the datamodule by itself.
    trainer.fit(model, dm)


Global seed set to 42
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Missing logger folder: logs/small_m
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type             | Params
-----------------------------------------------
0 | model     | pytorchModel     | 878 K 
1 | loss_fn   | CrossEntropyLoss | 0     
2 | train_acc | BinaryAccuracy   | 0     
3 | val_acc   | BinaryAccuracy   | 0     
-----------------------------------------------
878 K     Trainable params
0         Non-trainable params
878 K     Total params
3.514     Total estimated model params size (MB)


Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]torch.Size([32, 3, 224, 224])
Sanity Checking DataLoader 0:  50%|█████     | 1/2 [00:02<00:02,  2.05s/it]torch.Size([32, 3, 224, 224])
Epoch 0:   0%|          | 0/321 [00:00<?, ?it/s]                           torch.Size([32, 3, 224, 224])
Epoch 0:   0%|          | 1/321 [00:09<50:25,  9.46s/it, v_num=0]torch.Size([32, 3, 224, 224])
Epoch 0:   1%|          | 2/321 [00:09<25:25,  4.78s/it, v_num=0]torch.Size([32, 3, 224, 224])
Epoch 0:   1%|          | 3/321 [00:09<17:04,  3.22s/it, v_num=0]torch.Size([32, 3, 224, 224])
Epoch 0:   1%|          | 4/321 [00:09<12:53,  2.44s/it, v_num=0]torch.Size([32, 3, 224, 224])
Epoch 0:   2%|▏         | 5/321 [00:10<10:43,  2.04s/it, v_num=0]torch.Size([32, 3, 224, 224])
Epoch 0:   2%|▏         | 6/321 [00:10<08:59,  1.71s/it, v_num=0]torch.Size([32, 3, 224, 224])
Epoch 0:   2%|▏         | 7/321 [00:10<07:45,  1.48s/it, v_num=0]torch.Size([32, 3, 224, 224])
Epoch 0:   2%|▏         | 8/

In [5]:
# Show the module tree of the Lightning wrapper (network layers, loss and
# metric modules) as the cell's rich output.
model

NN(
  (model): pytorchModel(
    (features): Sequential(
      (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU()
      (2): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
      (3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): ReLU()
      (5): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
    )
    (output_layer): Linear(in_features=401408, out_features=2, bias=True)
  )
  (loss_fn): CrossEntropyLoss()
  (train_acc): BinaryAccuracy()
  (val_acc): BinaryAccuracy()
)

In [6]:
# Look up where the checkpoint callback stored the best-scoring checkpoint of
# the training run; `path` is reused when validating against that checkpoint.
checkpoint_cb = trainer.checkpoint_callback
path = checkpoint_cb.best_model_path
print(path)

logs/small_m\version_0\checkpoints\epoch=1-step=642.ckpt


In [7]:
# Re-run validation with the weights restored from the best checkpoint; the
# returned list of metric dicts is shown as the cell output.
trainer.validate(
    ckpt_path=path,
    model=model,
    datamodule=dm,
)

Restoring states from the checkpoint path at logs/small_m\version_0\checkpoints\epoch=1-step=642.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at logs/small_m\version_0\checkpoints\epoch=1-step=642.ckpt


Validation DataLoader 0:   0%|          | 0/88 [00:00<?, ?it/s]torch.Size([32, 3, 224, 224])
Validation DataLoader 0:   1%|          | 1/88 [00:00<00:16,  5.36it/s]torch.Size([32, 3, 224, 224])
Validation DataLoader 0:   2%|▏         | 2/88 [00:00<00:14,  6.10it/s]torch.Size([32, 3, 224, 224])
Validation DataLoader 0:   3%|▎         | 3/88 [00:00<00:10,  8.13it/s]torch.Size([32, 3, 224, 224])
Validation DataLoader 0:   5%|▍         | 4/88 [00:00<00:17,  4.92it/s]torch.Size([32, 3, 224, 224])
Validation DataLoader 0:   6%|▌         | 5/88 [00:01<00:18,  4.46it/s]torch.Size([32, 3, 224, 224])
Validation DataLoader 0:   7%|▋         | 6/88 [00:01<00:23,  3.44it/s]torch.Size([32, 3, 224, 224])
Validation DataLoader 0:   8%|▊         | 7/88 [00:01<00:20,  3.91it/s]torch.Size([32, 3, 224, 224])
Validation DataLoader 0:   9%|▉         | 8/88 [00:02<00:21,  3.80it/s]torch.Size([32, 3, 224, 224])
Validation DataLoader 0:  10%|█         | 9/88 [00:02<00:20,  3.95it/s]torch.Size([32, 3, 224, 224]

[{'val_loss': 0.1731247454881668, 'val_accuracy': 0.9645415544509888}]