# Imports

In [9]:
import os
import pandas as pd
import numpy as np
from torch.utils.data import DataLoader
import pytorch_lightning as pl

In [None]:
cd ../src

In [11]:
import rsna_config as config
from rsna_models import CT_3DModel
from rsna_data import LoadingDataset
from rsna_utils import split_group, SimpleLogger
from rsna_visualization import plot_per_task_accuracy, plot_loss, get_best_epoch

In [10]:
# Both locations are relative paths, resolved against the current working directory.
# INPUT_PATH: directory that is scanned for ".pt" files.
INPUT_PATH = "rsna-3voxel-float16"
# DATASET_PATH: directory holding train.csv and train_series_meta.csv.
DATASET_PATH = "rsna-2023-abdominal-trauma-detection"

# Dataloader

In [14]:
# TODO: download csv
# Load train_series_meta.csv and pull out its "aortic_hu" column as a NumPy
# array; that array is what the normalizer is fitted on.
meta_df = pd.read_csv(f"{DATASET_PATH}/train_series_meta.csv")
aortic_hues = meta_df.loc[:, "aortic_hu"].to_numpy()

In [15]:
from sklearn.preprocessing import StandardScaler

# Fit a StandardScaler on the aortic HU values. The scaler needs a 2D input,
# so the values are reshaped into a single column first.
ah_normalizer = StandardScaler()
ah_normalizer.fit(aortic_hues.reshape(-1, 1))

In [None]:
# Load train.csv and add a "filepath" column of the form "data_<patient_id>.pt".
# NOTE(review): the value is a bare file name — INPUT_PATH is not joined in.
# Presumably LoadingDataset resolves the directory itself; confirm.
label_df = pd.read_csv(f"{DATASET_PATH}/train.csv").assign(
    filepath=lambda frame: "data_" + frame["patient_id"].astype(str) + ".pt"
)

In [19]:
# Every ".pt" entry found directly inside INPUT_PATH, prefixed with that directory.
data_paths = [
    os.path.join(INPUT_PATH, entry)
    for entry in os.listdir(INPUT_PATH)
    if entry.endswith(".pt")
]

In [None]:
# Number of ".pt" files found in INPUT_PATH
print(len(data_paths))

In [None]:
# df is a second name for label_df (no copy is made here)
df = label_df

In [None]:
# Row count of df, shown as the cell output
len(df)

In [23]:
# TODO: Remove.
# Debug shortcut: when the config flag is set, keep only the first 30 rows
# so that iterations run faster.
if config.DEBUGGIN:
    df = df.iloc[:30]


In [24]:
# Columns whose value combinations define the groups that are split.
# This is different from TARGET_COLS, as otherwise
# there is redundancy in groups
GROUP_COLS = [
    "bowel_injury", "extravasation_injury",
    "kidney_healthy", "kidney_low", "kidney_high",
    "liver_healthy", "liver_low", "liver_high",
    "spleen_healthy", "spleen_low", "spleen_high",
]

# Split every group on its own (split_group is expected to handle single-sample
# groups) and collect the pieces. The pieces are concatenated once after the
# loop, which avoids re-copying the accumulated frame on every iteration.
train_parts, val_parts = [], []
for _, group in df.groupby(GROUP_COLS):
    train_group, val_group = split_group(group)
    train_parts.append(train_group)
    val_parts.append(val_group)

# pd.concat raises on an empty list, so fall back to an empty frame that keeps
# df's columns when no group was produced.
train_df = pd.concat(train_parts, ignore_index=True) if train_parts else df.iloc[:0].copy()
val_df = pd.concat(val_parts, ignore_index=True) if val_parts else df.iloc[:0].copy()

# File names handed to the datasets.
train_data_paths = train_df["filepath"].to_list()
val_data_paths = val_df["filepath"].to_list()

In [26]:
# Both datasets receive the same scaler fitted on the aortic HU values.
train_ds = LoadingDataset(train_data_paths, ah_normalizer=ah_normalizer)
val_ds = LoadingDataset(val_data_paths, ah_normalizer=ah_normalizer)

# No batch_size is passed, so DataLoader's default applies; only the training
# loader shuffles.
train_dl = DataLoader(train_ds, shuffle=True)
val_dl = DataLoader(val_ds)

# Train

In [29]:
# Logger instance passed to the Trainer and, after training, to the plotting helpers
logger = SimpleLogger()

In [30]:
# Build the Lightning module with its default arguments.
model = CT_3DModel()

# Trainer that is handed to the Tuner for the learning-rate search; nothing is
# trained in this cell.
trainer = pl.Trainer(
    max_epochs=config.EPOCHS,
    accelerator="auto",
    logger=logger,
    precision="16-mixed",  # mixed precision
    # Gradients are accumulated over config.BATCH_SIZE batches (virtual batch size).
    accumulate_grad_batches=config.BATCH_SIZE,
    # Disabled options, kept for reference:
    # gradient_clip_val=1e-1,
    # detect_anomaly=True,
)

2023-09-21 02:12:44.380756: F tensorflow/compiler/xla/xla_client/pjrt_computation_client.cc:81] Non-OK-status: pjrt::LoadPjrtPlugin( "tpu", sys_util::GetEnvString(env::kEnvTpuLibraryPath, "libtpu.so")) status: INVALID_ARGUMENT: Unexpected PJRT_Api size: expected 496, got 512. Check installed software versions.
*** Begin stack trace ***
	tsl::CurrentStackTrace()
	
	xla::ComputationClient::Create()
	
	
	xla::ComputationClient::Get()
	
	
	
	_PyObject_MakeTpCall
	_PyEval_EvalFrameDefault
	
	PyVectorcall_Call
	_PyEval_EvalFrameDefault
	_PyEval_EvalCodeWithName
	_PyFunction_Vectorcall
	_PyEval_EvalFrameDefault
	
	_PyEval_EvalFrameDefault
	
	PyVectorcall_Call
	_PyEval_EvalFrameDefault
	_PyEval_EvalCodeWithName
	_PyFunction_Vectorcall
	PyVectorcall_Call
	_PyEval_EvalFrameDefault
	
	_PyEval_EvalFrameDefault
	_PyEval_EvalCodeWithName
	_PyFunction_Vectorcall
	
	_PyEval_EvalFrameDefault
	
	_PyEval_EvalFrameDefault
	
	_PyObject_FastCallDict
	_PyObject_Call_Prepend
	
	
	_PyObject_MakeTpCall
	_PyEval

In [31]:
from pytorch_lightning.tuner import Tuner

# Wrap the trainer in a Tuner.
tuner = Tuner(trainer)

# Search for a learning rate between min_lr and max_lr on the training loader.
# finds learning rate automatically
# sets hparams.lr or hparams.learning_rate to that learning rate
lr_finder = tuner.lr_find(
    model,
    train_dataloaders=train_dl,
    min_lr=1e-20,
    max_lr=1,
)

  rank_zero_warn(
Loading `train_dataloader` to estimate number of stepping batches.
  rank_zero_warn(


ValueError: Expected 2D array, got 1D array instead:
array=[174.].
Reshape your data either using array.reshape(-1, 1) if your data has a single feature or array.reshape(1, -1) if it contains a single sample.

In [None]:
# Raw results recorded by the learning-rate finder
print(lr_finder.results)

# Plot the search; suggest=True asks the plot to include the suggested learning rate
fig = lr_finder.plot(suggest=True)
fig.show()

# Pick a point based on the plot, or take the finder's suggestion
new_lr = lr_finder.suggestion()

In [None]:
# Take the finder's suggested learning rate and show it as the cell output
new_lr = lr_finder.suggestion()
new_lr

In [None]:
# Rebuild the Lightning module, this time with the learning rate taken from the finder.
model = CT_3DModel(lr=new_lr)

# Fresh Trainer for the actual run; anomaly detection is switched on here.
trainer = pl.Trainer(
    max_epochs=config.EPOCHS,
    accelerator="auto",
    logger=logger,
    precision="16-mixed",  # mixed precision
    # Gradients are accumulated over config.BATCH_SIZE batches (virtual batch size).
    accumulate_grad_batches=config.BATCH_SIZE,
    # gradient_clip_val=1,  # disabled, kept for reference
    detect_anomaly=True,
)

# Train on train_dl and validate on val_dl.
trainer.fit(model, train_dl, val_dl)

# Visualize Results

In [None]:
# Project helper from rsna_visualization; it is given the logger that was passed to the Trainer
plot_per_task_accuracy(logger)

In [None]:
# Project helper from rsna_visualization; it is given the logger that was passed to the Trainer
plot_loss(logger)

In [None]:
# Project helper from rsna_visualization; its return value is shown as the cell output
get_best_epoch(logger)

# Save Model

In [None]:
# Write a checkpoint through the Lightning trainer (this is trainer.save_checkpoint,
# not a bare torch.save of the model).
SAVE_PATH = "rsna-atd_ct_3d.pth"
trainer.save_checkpoint(SAVE_PATH)