In [1]:
import torch
from lightning import Trainer
from dataloader import InsectDatamodule
from model_20 import ResNet
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning import LightningModule
from pytorch_lightning.loggers import TensorBoardLogger
import os
import yaml

In [3]:
# ---------------------------------------------------------------------------
# Run configuration: every value a reader might want to tune lives here.
# ---------------------------------------------------------------------------

# True -> a new log directory is created and the parameters are saved to YAML
# before the datamodule, model and trainer are built.
train_new_model = True

# Where the run is logged: <save_dir>/<sub_dir>/<version>
save_dir = './lightning_logs/'
sub_dir = 'all_data'
version = 'version06'

# Dataset selection (one CSV per insect group).
# To train on a single group, use e.g. ['../data/Orthoptera.csv'] instead.
csv_paths = [
    '../data/Cicadidae.csv',
    '../data/Orthoptera.csv',
]

# Data loading
batch_size = 10
num_workers = 0

# Spectrogram settings forwarded unchanged to InsectDatamodule
n_fft = 1024
n_mels = None
top_db = None

# Network architecture / optimisation settings forwarded to ResNet
in_channels = 1
base_channels = 8
kernel_size = 3
n_max_pool = 3
n_res_blocks = 4
learning_rate = 0.001

# Training loop
patience = 30            # EarlyStopping patience on 'val_loss'
log_every_n_steps = 20   # Trainer logging interval

if train_new_model:
    # Build the run directory with os.path.join: `save_dir` already ends with a
    # slash, so plain string formatting would produce a doubled separator
    # ('./lightning_logs//all_data/...').
    log_dir = os.path.join(save_dir, sub_dir, version)

    # Create the directory and detect a version clash in a single step.
    # os.makedirs raises FileExistsError when the leaf already exists, which
    # avoids the gap between a separate existence check and the creation.
    try:
        os.makedirs(log_dir)
    except FileExistsError:
        raise FileExistsError(f'{log_dir} already exists. Please change the version.') from None

    # Snapshot of every tunable value used for this run.
    parameters = {
        'csv_paths': csv_paths,
        'batch_size': batch_size,
        'num_workers': num_workers,
        'n_fft': n_fft,
        'n_mels': n_mels,
        'top_db': top_db,
        'patience': patience,
        'in_channels': in_channels,
        'base_channels': base_channels,
        'kernel_size': kernel_size,
        'n_max_pool': n_max_pool,
        'n_res_blocks': n_res_blocks,
        'learning_rate': learning_rate,
        'log_every_n_steps': log_every_n_steps
    }

    # Write parameters to a YAML file next to the logs of this version
    with open(os.path.join(log_dir, 'all_parameters.yaml'), 'w') as file:
        yaml.dump(parameters, file)

# Data pipeline: receives the dataset selection, loader settings and
# spectrogram settings chosen in the configuration above.
datamodule = InsectDatamodule(
    csv_paths=csv_paths,
    batch_size=batch_size,
    num_workers=num_workers,
    n_fft=n_fft,
    n_mels=n_mels,
    top_db=top_db,
)

# Classifier: the number of classes and the class weights are read from the
# datamodule, so the datamodule has to be constructed first.
resnet = ResNet(
    in_channels=in_channels,
    base_channels=base_channels,
    kernel_size=kernel_size,
    n_max_pool=n_max_pool,
    n_res_blocks=n_res_blocks,
    num_classes=datamodule.num_classes,
    learning_rate=learning_rate,
    class_weights=datamodule.class_weights,
)

# TensorBoard logs (and the checkpoints of this run) are written below
# <save_dir>/<sub_dir>/<version>.
# NOTE(review): `Trainer` is imported from `lightning` while `TensorBoardLogger`
# comes from `pytorch_lightning`; consider importing both from the same package.
logger = TensorBoardLogger(
    save_dir=save_dir,
    name=sub_dir,
    version=version,  # fixed version name instead of an auto-incremented one
)

trainer = Trainer(
    logger=logger,
    # Explicit epoch cap. 1000 is the value Lightning falls back to when
    # `max_epochs` is omitted; stating it avoids the "`max_epochs` was not set"
    # warning. In practice EarlyStopping is expected to end training earlier.
    max_epochs=1000,
    log_every_n_steps=log_every_n_steps,
    callbacks=[
        # Stop once 'val_loss' has not improved for `patience` validation checks.
        EarlyStopping(monitor='val_loss', mode='min', patience=patience),
        # Keep only the single best checkpoint, stored as checkpoints/best.ckpt.
        ModelCheckpoint(monitor='val_loss', save_top_k=1, mode='min', filename='best'),
    ],
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [5]:
# Train on the training split and validate on the validation split.
train_loader = datamodule.train_dataloader()
val_loader = datamodule.val_dataloader()
trainer.fit(resnet, train_dataloaders=train_loader, val_dataloaders=val_loader)

# Evaluate the best checkpoint tracked by ModelCheckpoint on the test split.
# The test loader is requested only after fitting has returned.
test_loader = datamodule.test_dataloader()
trainer.test(dataloaders=test_loader, ckpt_path='best')

c:\Users\kraft\.conda\envs\torch_cuda\Lib\site-packages\lightning\pytorch\loops\utilities.py:73: `max_epochs` was not set. Setting it to 1000 epochs. To train without an epoch limit, set `max_epochs=-1`.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name                  | Type              | Params
------------------------------------------------------------
0 | conv1                 | Conv2d            | 80    
1 | batchnorm1            | BatchNorm2d       | 16    
2 | relu                  | ReLU              | 0     
3 | res_blocks            | Sequential        | 306 K 
4 | avgpool               | AdaptiveAvgPool2d | 0     
5 | convout               | Conv2d            | 4.1 K 
6 | softmax               | Softmax           | 0     
7 | cross_entropy_loss_fn | CrossEntropyLoss  | 0     
------------------------------------------------------------
310 K     Trainable params
0         Non-trainable params
310 K     Total params
1.242     Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

c:\Users\kraft\.conda\envs\torch_cuda\Lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


                                                                           

c:\Users\kraft\.conda\envs\torch_cuda\Lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Epoch 1:  29%|██▊       | 6/21 [00:07<00:19,  0.79it/s, v_num=on06] 

c:\Users\kraft\.conda\envs\torch_cuda\Lib\site-packages\lightning\pytorch\trainer\call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...
Restoring states from the checkpoint path at ./lightning_logs/all_data\version06\checkpoints\best.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at ./lightning_logs/all_data\version06\checkpoints\best.ckpt
c:\Users\kraft\.conda\envs\torch_cuda\Lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Testing DataLoader 0:   0%|          | 0/8 [00:00<?, ?it/s]

In [None]:
# Re-run only the test step (the same call that ends the training cell).
# Needs `trainer` and `datamodule` from the cells above to exist in the kernel.
test_loader = datamodule.test_dataloader()
trainer.test(dataloaders=test_loader, ckpt_path='best')

In [2]:
# Reload a trained model without retraining.
# The settings below repeat the values of the training configuration so that
# this cell can be run on its own after the imports.
csv_paths = [
    '../data/Cicadidae.csv',
    '../data/Orthoptera.csv',
]

batch_size = 10
num_workers = 0

n_fft = 1024
n_mels = None
top_db = None

datamodule = InsectDatamodule(
    csv_paths=csv_paths,
    batch_size=batch_size,
    num_workers=num_workers,
    n_fft=n_fft,
    n_mels=n_mels,
    top_db=top_db,
)

# Best checkpoint written by the 'version06' training run.
ckpt_path = './lightning_logs/all_data/version06/checkpoints/best.ckpt'
resnet = ResNet.load_from_checkpoint(checkpoint_path=ckpt_path)

# Alternative kept from an earlier run (version04 checkpoint, with a
# constructor argument passed explicitly):
# resnet = ResNet.load_from_checkpoint('./lightning_logs/all_data/version04/checkpoints/best.ckpt', in_channels=1)