In [None]:
import torch
from lightning import Trainer
from dataloader import InsectDatamodule
from model_20 import ResNet
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning import LightningModule
from pytorch_lightning.loggers import TensorBoardLogger
import os
import yaml

In [None]:
# Run inference with the model over the datamodule's prediction loader and
# keep the returned predictions for later use.
# NOTE(review): `trainer`, `resnet` and `datamodule` are only created in cells
# further down in this notebook; on a fresh kernel those cells must be run
# first, otherwise this cell raises NameError.
predictions = trainer.predict(
    model=resnet,
    dataloaders=datamodule.predict_dataloader(),
)

In [None]:
# Experiment setup: collects every tunable setting in one place and builds the
# datamodule, model, logger and trainer used by the following cells.

# True  -> start a fresh run: create the log directory and store the settings
#          in it (fails if the version directory already exists).
# False -> only rebuild the objects below without touching the log directory.
train_new_model = False

# log directory layout: <save_dir>/<sub_dir>/<version>
save_dir = './lightning_logs/'
sub_dir = 'all_data'
version = 'version06'

# select Dataset
csv_paths = ['../data/Cicadidae.csv', '../data/Orthoptera.csv']
# csv_paths = ['../data/Orthoptera.csv']

# data loading
batch_size = 10
num_workers = 0

# audio preprocessing settings handed to the datamodule
n_fft = 1024
n_mels = None
top_db = None

# number of epochs without 'val_loss' improvement before training stops
patience = 30

# model hyperparameters
in_channels = 1
base_channels = 8
kernel_size = 3
n_max_pool = 3
n_res_blocks = 4
learning_rate = 0.001

log_every_n_steps = 20

if train_new_model:
    # Build the path with os.path.join so the trailing slash of `save_dir`
    # does not produce a doubled separator in the path or the error message.
    log_dir = os.path.join(save_dir, sub_dir, version)

    # Refuse to reuse a version directory so earlier runs are never overwritten.
    if os.path.exists(log_dir):
        raise FileExistsError(f'{log_dir} already exists. Please change the version.')
    os.makedirs(log_dir)

    parameters = {
        'csv_paths': csv_paths,
        'batch_size': batch_size,
        'num_workers': num_workers,
        'n_fft': n_fft,
        'n_mels': n_mels,
        'top_db': top_db,
        'patience': patience,
        'in_channels': in_channels,
        'base_channels': base_channels,
        'kernel_size': kernel_size,
        'n_max_pool': n_max_pool,
        'n_res_blocks': n_res_blocks,
        'learning_rate': learning_rate,
        'log_every_n_steps': log_every_n_steps,
    }

    # Write parameters to a YAML file next to the logs of this run
    with open(os.path.join(log_dir, 'all_parameters.yaml'), 'w') as file:
        yaml.dump(parameters, file)

datamodule = InsectDatamodule(
    csv_paths=csv_paths,
    batch_size=batch_size,
    num_workers=num_workers,
    n_fft=n_fft,
    n_mels=n_mels,
    top_db=top_db,
)

# The number of classes and the class weights are taken from the datamodule.
resnet = ResNet(
    in_channels=in_channels,
    base_channels=base_channels,
    kernel_size=kernel_size,
    n_max_pool=n_max_pool,
    n_res_blocks=n_res_blocks,
    num_classes=datamodule.num_classes,
    learning_rate=learning_rate,
    class_weights=datamodule.class_weights,
)

# NOTE(review): `Trainer` and the callbacks are imported from `lightning`,
# while `TensorBoardLogger` comes from `pytorch_lightning` — confirm that
# mixing the two packages works with the installed versions.
logger = TensorBoardLogger(
    save_dir=save_dir,
    name=sub_dir,
    version=version,
)

# Both callbacks watch 'val_loss' and treat lower values as better; the
# checkpoint callback keeps only the single best checkpoint, named 'best'.
trainer = Trainer(
    logger=logger,
    log_every_n_steps=log_every_n_steps,
    callbacks=[
        EarlyStopping(monitor='val_loss', mode='min', patience=patience),
        ModelCheckpoint(monitor='val_loss', save_top_k=1, mode='min', filename='best'),
    ],
)

In [None]:
# Train the model on the training loader and validate on the validation
# loader, both taken from the datamodule.
train_loader = datamodule.train_dataloader()
val_loader = datamodule.val_dataloader()

trainer.fit(
    model=resnet,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
)

In [None]:
# Evaluate on the test loader using the 'best' checkpoint.
# No model is passed here, so the trainer has to know it from a previous run
# (e.g. the fit cell) before this cell is executed.
test_loader = datamodule.test_dataloader()
trainer.test(dataloaders=test_loader, ckpt_path='best')

In [None]:
# Restore a trained model for inference: rebuild the datamodule and load the
# model weights from the checkpoint file of an earlier run.
# NOTE(review): the settings below presumably have to match the ones used when
# the checkpoint was trained — confirm against that run's stored parameters.

csv_paths = ['../data/Cicadidae.csv', '../data/Orthoptera.csv']

batch_size = 10
num_workers = 0

n_fft = 1024
n_mels = None
top_db = None

datamodule = InsectDatamodule(
    csv_paths=csv_paths,
    batch_size=batch_size,
    num_workers=num_workers,
    n_fft=n_fft,
    n_mels=n_mels,
    top_db=top_db,
)

# Single source of truth for the checkpoint location (previously the same
# literal was repeated in the load call and this variable was never used).
ckpt_path = './lightning_logs/all_data/version04/checkpoints/best.ckpt'

# `in_channels` is passed explicitly in addition to the checkpoint path.
resnet = ResNet.load_from_checkpoint(ckpt_path, in_channels=1)

1) Das Dataset gibt nun einen Index zurück. Diesen nutzen wir später, um die Prediction x mit dem CSV zu verknüpfen.
2) Im `predict_step` brauchen wir nun die Predictions und den Index.
3) `trainer.predict()` iteriert über den Prediction-Loader und gibt eine Liste von Predictions zurück.
4) Nun können wir im Dataset-CSV eine neue Spalte anlegen und die Prediction jeweils beim richtigen Index eintragen.