In [1]:
# standard modules
import numpy as np
import pandas as pd
import importlib

# PyTorch modules
import torch
import pytorch_lightning as pl
from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

# import data + model modules
import ukbb_data
import ukbb_ica_models
# in case of changes 
importlib.reload(ukbb_data)
importlib.reload(ukbb_ica_models)

# import custom functions
import utils
# in case of changes
importlib.reload(utils)

# visualisation
import matplotlib.pyplot as plt
import seaborn as sns

In [2]:
# prepare data paths
ukbb_dir = '/ritter/share/data/UKBB/ukb_data/'

In [3]:
torch.cuda.is_available()

True

# Testing different learning rates for the ICA25 simple 1D CNN

In [None]:
# define learning rates to test 
lrs = [10e-2, 10e-3, 10e-4, 10e-5, 10e-6]

# define logging dirs
log_dir = 'ICA25/LearningRates/'

# initialise DataModule
data = ukbb_data.UKBBDataModule(ukbb_dir)

for lr in lrs:
    print(f'\n>>Training model with learning rate {lr}...')
    # initialise model
    simple_CNN = ukbb_ica_models.simple1DCNN(lr=lr)

    # initialise logger
    logger = CSVLogger(save_dir=log_dir+str(lr)+'/', name='Logs')

    # set callbacks
    early_stopping = EarlyStopping(monitor='val_loss')

    checkpoint = ModelCheckpoint(dirpath=log_dir+str(lr)+'/Checkpoint/',
                                 filename='models-{epoch:02d}-{valid_loss:.2f}',
                                 monitor='val_loss',
                                 save_top_k=1,
                                 mode='min')

    # initialise trainer
    trainer = pl.Trainer(accelerator='gpu',
                         devices=[2],
                         max_epochs=500, ## which number would be good?
                         logger=logger,
                         log_every_n_steps=10,
                         callbacks=[early_stopping, checkpoint],
                         deterministic=True)

    # train model
    trainer.fit(simple_CNN, datamodule=data)
    print('Training complete.')
    print(f'\nTesting model with learning rate {lr}...')
    
    # test model
    trainer.test(simple_CNN, datamodule=data)
    
    ### print pic?