In [None]:
# Re-import edited project modules (dataloader, layers, train_utils, ...) automatically
# so changes to them are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline (presumably used by the confusion-matrix plot at the end).
%matplotlib inline

In [None]:
import os
# Reduce TensorFlow's C++ log output. Set before `import tensorflow` below,
# presumably so the setting is in place when TensorFlow loads.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import numpy as np
import pandas as pd
import tensorflow as tf

# Project-local modules: DL builds the input pipelines (DL.InputFunction),
# CLayers provides the model class, TU provides the trainer.
import dataloader as DL
import layers as CLayers
import train_utils as TU

# Only used by the (commented-out) split-generation cell below.
from sklearn.model_selection import train_test_split

In [None]:
# Enable memory growth on every visible GPU so TensorFlow allocates device memory
# as needed rather than reserving it all up front. Uses the stable tf.config
# device-listing APIs; memory growth itself is only exposed under tf.config.experimental.
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    try:
        # Currently, memory growth needs to be the same across GPUs
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        logical_gpus = tf.config.list_logical_devices('GPU')
        print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
    except RuntimeError as e:
        # Memory growth must be set before GPUs have been initialized
        print(e)

## Load data

In [None]:
# Experiment identifiers used in checkpoint/figure names and to pick the label column.
date = '20240118_2300'
input_name = 'MagGrad'
timepoint = '2Y'

# Drop rows whose conversion state at `timepoint` is -1
# (presumably "unknown/unlabeled" — TODO confirm).
df = (
    pd.read_csv('./data/ADNI_saliency.csv', low_memory=False)
    .loc[lambda frame: frame[f'CONV_STATE_{timepoint}'] != -1]
)
unique_rids = df['RID'].unique()

In [None]:
# One-off split generation (disabled). Presumably already run once; the resulting
# RID lists are read back from ./data/*_rids_{timepoint}.txt in a later cell.
# Uncomment together with the next cell only to regenerate the split.
# train_rids, valid_rids = train_test_split(unique_rids, test_size=0.2, random_state=2024)

In [None]:
# One-off (disabled): write the generated split to disk, one RID per line.
# Re-enabling this overwrites the files that the next cell reads.
# with open(f'./data/train_rids_{timepoint}.txt', 'w') as file:
#     for rid in train_rids:
#         file.write(f"{rid}\n")
# with open(f'./data/val_rids_{timepoint}.txt', 'w') as file:
#     for rid in valid_rids:
#         file.write(f"{rid}\n")

In [None]:
def _read_rids(path):
    """Read one RID per line from ``path`` and return them as a list of ints.

    Each value is parsed via ``float`` first (presumably because RIDs were written
    from a float column, e.g. ``'123.0'``). Blank lines, such as a trailing empty
    line, are skipped instead of raising ``ValueError``.
    """
    with open(path, 'r') as file:
        return [int(float(line)) for line in map(str.strip, file) if line]


# Load training RIDs from the text file
train_rids = _read_rids(f'./data/train_rids_{timepoint}.txt')
# Load validation RIDs from the text file
valid_rids = _read_rids(f'./data/val_rids_{timepoint}.txt')

# The same RID in both splits would leak data into validation.
assert not set(train_rids) & set(valid_rids), "train and validation RIDs overlap"

In [None]:
# Input-pipeline configuration shared by the train and validation datasets.
data_file   = './data/ADNI_saliency.csv'
batch_size  = 10
num_classes = 3
image_shape = (84, 48, 42)
use_pe      = False
augmented   = True

# Seed Python, NumPy and TensorFlow RNGs so dataset shuffling and model weight
# initialization are reproducible; matches the random_state used for the RID split.
seed = 2024
tf.keras.utils.set_random_seed(seed)


def _num_years(label_time: str) -> int:
    """Return the leading integer of a timepoint label, e.g. '2Y' -> 2, '10Y' -> 10.

    Raises:
        ValueError: if ``label_time`` does not start with a digit.
    """
    import re

    match = re.match(r'\d+', label_time)
    if match is None:
        raise ValueError(f"timepoint must start with a number, got {label_time!r}")
    return int(match.group())


# years * 2 + 1 (presumably baseline plus one step per 6-month visit — TODO confirm).
# All leading digits are parsed, not just timepoint[0], so multi-digit horizons
# like '10Y' give 21 rather than 3.
timesteps = _num_years(timepoint) * 2 + 1
def _make_dataset(rids, **overrides):
    """Build a ``DL.InputFunction`` over ``rids`` using the shared dataset config.

    Args:
        rids: RID values to include in this dataset.
        **overrides: extra keyword arguments forwarded to ``DL.InputFunction``
            (e.g. ``drop_last``), on top of the shared settings below.

    Returns:
        The constructed ``DL.InputFunction``.
    """
    return DL.InputFunction(
        filepath=data_file,
        list_rids=rids,
        labeltime=timepoint,
        num_classes=num_classes,
        image_shape=image_shape,
        use_pe=use_pe,
        augmented=augmented,
        timesteps=timesteps,
        batch_size=batch_size,
        shuffle=True,
        **overrides,
    )


# Only the training set drops its trailing partial batch; validation keeps the
# InputFunction default for drop_last.
train_ds = _make_dataset(train_rids, drop_last=True)
# NOTE(review): validation uses the same `augmented` flag and shuffle=True as
# training. If `augmented` enables random augmentation, validation metrics are
# non-deterministic — confirm this is intended.
valid_ds = _make_dataset(valid_rids)

In [None]:
# Optimizer steps per training epoch. If the trailing partial batch is kept
# (drop_last is False) and full batches do not cover every sample, one more step runs.
train_steps = train_ds.steps_per_epoch()
if not train_ds.drop_last:
    if train_ds.steps_per_epoch() * batch_size < train_ds.size():
        train_steps += 1

In [None]:
# Debug helper (disabled): scan the validation set for the largest size along
# axis 1 of feat['deltas'] (presumably the sequence/time length) and print it.
# max_time = 0
# for i, (feat, label) in enumerate(valid_ds()):
#     temp_time = feat['deltas'].shape[1]
#     if temp_time > max_time:
#         max_time = temp_time
# print(max_time)

In [None]:
# Debug helper (disabled): stop at the first training batch, leaving it in
# `feat` / `label` for interactive inspection.
# for i, (feat, label) in enumerate(train_ds()):
#     if i==0:
#         break

In [None]:
# Debug helper (disabled): inspect the 'conv_lb_neg2' entry of the batch labels
# captured by the previous debug cell.
# label['conv_lb_neg2']

## Create model

#### Define the unimodal model

##### Model #1: CNN with attention (`UnimodelS1_CNN_Attention`)

In [None]:
# Keyword arguments for CLayers.UnimodelS1_CNN_Attention. The meaning of each
# entry is defined in the project's `layers` module.
model_config = {
    'input_name': input_name,
    'num_classes': num_classes,
    'num_filters': [4, 8, 8, 16, 16],
    'bap_filters': 16,
    'fc_units': [50, 20],
    'kernel_size': 3,
    'pool_size': 2,
    'dropout': 0.45,
}

In [None]:
# Instantiate the model from model_config.
uni_model = CLayers.UnimodelS1_CNN_Attention(**model_config)

## Create trainer

In [None]:
# Optimization settings consumed by the trainer below.
learning_rate = 3e-3
num_epochs = 50
optimizer = tf.keras.optimizers.Adam(
    learning_rate=learning_rate,
    weight_decay=1e-5,
)

In [None]:
# Trainer for uni_model. model_dir is named after the input, run date and label
# timepoint (presumably where checkpoints are written).
trainer = TU.TrainAndEvaluateS1_WSL(
    model=uni_model,
    optimizer=optimizer,
    num_epochs=num_epochs,
    train_steps=train_steps,
    input_name=input_name,
    model_dir=f"./checkpoints/unimodel_wsl_{input_name}_{date}_{timepoint}",
    # The trainer receives the objects returned by calling each InputFunction.
    train_dataset=train_ds(),
    eval_dataset=valid_ds(),
)

In [None]:
# Run training with periodic evaluation on the validation dataset.
trainer.train_and_evaluate()

In [None]:
# Save the trained model under the name 'save_model'; the Predictor below is
# constructed with ckpt_name='save_model'.
trainer.save_model('save_model')

## Prediction on the validation set

In [None]:
import predict_utils as PU

In [None]:
# Predictor over the trained model. Reuse the `augmented` config flag instead of a
# hard-coded True so prediction and the datasets cannot drift apart.
# ckpt_name presumably selects the checkpoint saved by trainer.save_model('save_model').
predictor = PU.Predictor(trainer.model, trainer.model_dir, input_name, augmented=augmented, ckpt_name='save_model')

In [None]:
# Run inference on the validation set. Note: the InputFunction object itself is
# passed here, whereas the trainer received valid_ds().
predictor.predict(valid_ds)

In [None]:
# Collect performance metrics for the predictions above (the metric set is
# defined in predict_utils).
results = predictor.get_performance()

In [None]:
# Display the metrics.
results

In [None]:
# Confusion matrix for the validation predictions. Class names are listed in
# label-index order (presumably 0/1/2 — verify against the dataloader's encoding).
# Create the output directory first so saving works on a fresh checkout.
os.makedirs('figures', exist_ok=True)
predictor.plot_confusion_matrix(columns=['Non-Converted', 'Converted', 'AD'], save_path=f'figures/unimodel_wsl_{input_name}_{date}_{timepoint}.png')