In [8]:
import torch
from lightning import Trainer
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger, CSVLogger
from dataloader import InsectDatamodule
from model_20 import ResNet
import os
import numpy as np
import pandas as pd
import yaml

from utils import PredictionWriter

In [9]:
# initialize the datamodule and the model

train_new_model = False

# log directory
save_dir='./lightning_logs/'
sub_dir='all_data'
version='version06'

# select Dataset
# csv_paths = ['../data/Cicadidae.csv', '../data/Orthoptera.csv']
csv_paths = ['../data/Orthoptera.csv']

# parameters

batch_size = 10
num_workers = 0

n_fft = 1024
n_mels = None
top_db = None

patience = 30

in_channels=1
base_channels=8
kernel_size=3
n_max_pool=3
n_res_blocks=4
learning_rate=0.001

log_every_n_steps=20

# create log directory
log_dir = f'{save_dir}/{sub_dir}/{version}'

if train_new_model:
    if os.path.exists(log_dir):
        raise FileExistsError(f'{log_dir} already exists. Please change the version.')
    else:
        os.makedirs(log_dir)

    parameters = {
        'csv_paths': csv_paths,
        'batch_size': batch_size,
        'num_workers': num_workers,
        'n_fft': n_fft,
        'n_mels': n_mels,
        'top_db': top_db,
        'patience': patience,
        'in_channels': in_channels,
        'base_channels': base_channels,
        'kernel_size': kernel_size,
        'n_max_pool': n_max_pool,
        'n_res_blocks': n_res_blocks,
        'learning_rate': learning_rate,
        'log_every_n_steps': log_every_n_steps
    }

    # Write parameters to a YAML file
    with open(f'{log_dir}/all_parameters.yaml', 'w') as file:
        yaml.dump(parameters, file)

datamodule = InsectDatamodule(
    csv_paths = csv_paths,
    batch_size = batch_size,
    num_workers = num_workers,
    n_fft = n_fft,
    n_mels = n_mels,
    top_db = top_db)

resnet = ResNet(
    in_channels=in_channels,
    base_channels=base_channels,
    kernel_size=kernel_size,
    n_max_pool=n_max_pool,
    n_res_blocks=n_res_blocks,
    num_classes=datamodule.num_classes,
    learning_rate=learning_rate,
    class_weights=datamodule.class_weights)

tb_logger = TensorBoardLogger(
    save_dir=save_dir,
    name=sub_dir,
    version=version,  # You can customize this
)

csv_logger = CSVLogger(
    save_dir=save_dir,
    name=sub_dir,
    version=version,  # You can customize this
)

trainer = Trainer(
    logger=[tb_logger, csv_logger],
    log_every_n_steps=log_every_n_steps,
    callbacks=[
        EarlyStopping(monitor='val_loss', patience=patience),
        ModelCheckpoint(monitor='val_loss', save_top_k=1, mode='min', filename='best'),
        PredictionWriter(),
    ]
)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [10]:
trainer.fit(
    resnet,
    train_dataloaders=datamodule.train_dataloader(),
    val_dataloaders=datamodule.val_dataloader()
)

# trainer.test(ckpt_path='best', dataloaders=datamodule.test_dataloader())

/Users/kraftb/miniforge3/envs/torch_insect/lib/python3.11/site-packages/lightning/pytorch/loops/utilities.py:73: `max_epochs` was not set. Setting it to 1000 epochs. To train without an epoch limit, set `max_epochs=-1`.
/Users/kraftb/miniforge3/envs/torch_insect/lib/python3.11/site-packages/lightning/fabric/loggers/csv_logs.py:268: Experiment logs directory ./lightning_logs/all_data/version06 exists and is not empty. Previous log files in this directory will be deleted when the new ones are saved!
/Users/kraftb/miniforge3/envs/torch_insect/lib/python3.11/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:653: Checkpoint directory ./lightning_logs/all_data/version06/checkpoints exists and is not empty.

  | Name                  | Type              | Params
------------------------------------------------------------
0 | conv1                 | Conv2d            | 80    
1 | batchnorm1            | BatchNorm2d       | 16    
2 | relu                  | ReLU              | 0  

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/Users/kraftb/miniforge3/envs/torch_insect/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


                                                                           

/Users/kraftb/miniforge3/envs/torch_insect/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
/Users/kraftb/miniforge3/envs/torch_insect/lib/python3.11/site-packages/lightning/pytorch/loops/fit_loop.py:298: The number of training batches (10) is smaller than the logging interval Trainer(log_every_n_steps=20). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Epoch 3:  10%|█         | 1/10 [00:00<00:06,  1.46it/s, v_num=on06] 

/Users/kraftb/miniforge3/envs/torch_insect/lib/python3.11/site-packages/lightning/pytorch/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...


In [7]:
trainer.test(ckpt_path='best', dataloaders=datamodule.test_dataloader())

Restoring states from the checkpoint path at ./lightning_logs/all_data/version06/checkpoints/best-v3.ckpt
Loaded model weights from the checkpoint at ./lightning_logs/all_data/version06/checkpoints/best-v3.ckpt
/Users/kraftb/miniforge3/envs/torch_insect/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Testing DataLoader 0: 100%|██████████| 8/8 [00:38<00:00,  0.21it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_acc            0.06756756454706192
        test_loss           0.6842144131660461
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_loss': 0.6842144131660461, 'test_acc': 0.06756756454706192}]

1) Dataset gibt nun ein index zurück. Das nutzten wir später, um prediction x mit dem CSV zu verknuepfen.
2) Im predict_step brauchen wir nun die predictions und den index.
3) `trainer.predict()` iteriert den prediction loader und gibt eine liste von predictions zurück.
4) Nun können wir im dataset CSV eine neue Spalte machen, und die prediction jeweils zum richtigen index schreiben.

In [11]:
trainer.predict(resnet, datamodule.predict_dataloader(), return_predictions=False)

/Users/kraftb/miniforge3/envs/torch_insect/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Predicting DataLoader 0: 100%|██████████| 143/143 [00:42<00:00,  3.33it/s]
