## **Setup**

In [1]:
# Shell command (IPython `!` syntax): enable the widgetsnbextension notebook extension.
!jupyter nbextension enable --py widgetsnbextension

Enabling notebook extension jupyter-js-widgets/extension...
      - Validating: ok


In [2]:
import os
import sys
import os.path as op
import numpy as np
from functools import partial
from pytorch_lightning import LightningModule, Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint, StochasticWeightAveraging
from pytorch_lightning.loggers import WandbLogger

import torch
from torch.utils.data import DataLoader
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2

sys.path.append("..")
from mtecg import ScarDataset, LVEFDataset, SingleTaskModel 
from mtecg.utils import load_ecg_dataframe


# One seed for every random-number source used in this notebook.
SEED = 42

# Seed Lightning (including dataloader workers), NumPy, and PyTorch (CPU + CUDA).
seed_everything(SEED, workers=True)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)

# Ask cuDNN for deterministic algorithms and switch off its autotuner.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

c:\Anaconda3\envs\ecg\lib\site-packages\numpy\.libs\libopenblas.FB5AE2TYXYH2IJRDKGDGQ3XBKLKTF43H.gfortran-win_amd64.dll
c:\Anaconda3\envs\ecg\lib\site-packages\numpy\.libs\libopenblas64__v0.3.21-gcc_10_3_0.dll
Global seed set to 42


In [3]:
# --- Experiment settings ------------------------------------------------------
# Task key; used to pick the dataset class and to name output folders/projects.
task = "scar"

# Threshold handed to LVEFDataset; also appears in the run name for the LVEF task.
lvef_threshold = 50
# Target size unpacked into A.Resize for both transform pipelines.
image_size = (384, 384)
# Batch size for the train and validation DataLoaders.
batch_size = 16
# Upper bound on training epochs (Trainer max_epochs).
num_epochs = 10

# Keyword arguments forwarded verbatim to SingleTaskModel(**configs).
configs = dict(
    in_channels=3,
    learning_rate=5e-3,
    use_timm=True,
    pretrained=True,
    backbone="resnet34d",
    latent_dim=512,
    num_classes=2,
    bias_head=True,
    # Specify the device.
    device="cuda",
)

# Dataset constructor for each supported task key.
# The LVEF entry has its threshold pre-bound, so both entries can be called
# with the same (dataframe, transform) arguments.
task_to_dataset_map = dict(
    scar=ScarDataset,
    lvef=partial(LVEFDataset, lvef_threshold=lvef_threshold),
)

In [None]:
# Root output folder for this task; created now if it does not exist yet.
parent_save_dir = f"../trained_models/single_task_{task}"
os.makedirs(parent_save_dir, exist_ok=True)

# Run name: "<backbone>_<image side>", with "_<lvef threshold>" appended for the LVEF task.
if task == "lvef":
    run_suffix = f"{image_size[0]}_{lvef_threshold}"
else:
    run_suffix = f"{image_size[0]}"
run_name = f"{configs['backbone']}_{run_suffix}"

## **Prepare the data**

In [4]:
# Locations handed to load_ecg_dataframe: the metadata CSV and the image folder.
csv_path = "../../ECG_EF_Clin_train_dev_new.csv"
image_dir = "../../ecg/ecg-cnn-local/siriraj_data/ECG_MRI_images_new/"

df = load_ecg_dataframe(csv_path, image_dir)

# Sanity check: row count and the split labels present in the frame.
print(f"Number of images: {len(df)}")
print(f"Unique splits: {df['split'].unique()}")

# Preview the first five rows.
df.head()

Number of images: 13343
Unique splits: ['old_train' 'old_valid' 'old_test' 'new_train' 'new_valid']


Unnamed: 0,run_num,train_80_percent,develop_10_percent,file_name,lvef,scar_cad,hcm,mri_date,month,year,...,dm,ht,mi,pci,cabg,ua,chest pain,dyspnea,path,split
0,1,1.0,,2009_420521391,59.9,0,0,2552-08-01 00:00:00,8,2009,...,0,1,0,0,0,0,1,0,../../ecg/ecg-cnn-local/siriraj_data/ECG_MRI_i...,old_train
1,2,1.0,,2009_472422791,81.7,0,0,2552-08-01 00:00:00,8,2009,...,0,1,0,0,0,0,1,0,../../ecg/ecg-cnn-local/siriraj_data/ECG_MRI_i...,old_train
2,3,1.0,,2009_451191451,64.7,0,0,2552-08-01 00:00:00,8,2009,...,0,1,0,0,0,0,1,1,../../ecg/ecg-cnn-local/siriraj_data/ECG_MRI_i...,old_train
3,4,1.0,,2009_512029431,10.7,1,0,2552-08-01 00:00:00,8,2009,...,1,0,1,1,0,0,0,1,../../ecg/ecg-cnn-local/siriraj_data/ECG_MRI_i...,old_train
4,5,1.0,,2009_461543281,19.3,1,0,2552-08-04 00:00:00,8,2009,...,0,1,0,0,0,0,1,1,../../ecg/ecg-cnn-local/siriraj_data/ECG_MRI_i...,old_train


In [5]:
# Training rows: the old and new training splits together.
is_train = df["split"].isin(["old_train", "new_train"])
train_df = df.loc[is_train].reset_index()

# Validation rows: the old and new validation splits together.
is_valid = df["split"].isin(["old_valid", "new_valid"])
valid_df = df.loc[is_valid].reset_index()

# NOTE(review): reset_index() without drop=True keeps the previous index as an
# extra column — confirm that is intended.
train_df.shape, valid_df.shape

((9393, 28), (2905, 28))

In [6]:
# Training pipeline: resize, random blur / brightness-contrast / motion-blur
# augmentations, then normalise and convert to a tensor.
train_steps = [
    A.Resize(*image_size),
    A.Blur(blur_limit=3, p=0.2),
    A.RandomBrightnessContrast(),
    A.MotionBlur(p=0.2),
    A.Normalize(),
    ToTensorV2(),
]
train_transform = A.Compose(train_steps)

# Validation pipeline: the same resize/normalise/tensor steps with no augmentation.
valid_steps = [
    A.Resize(*image_size),
    A.Normalize(),
    ToTensorV2(),
]
valid_transform = A.Compose(valid_steps)

# Pick the dataset constructor for the configured task and build both splits.
dataset = task_to_dataset_map[task]
train_ds = dataset(train_df, train_transform)
valid_ds = dataset(valid_df, valid_transform)

# Both loaders share batch size and pinned memory; only the training loader shuffles.
make_loader = partial(DataLoader, batch_size=batch_size, pin_memory=True)
train_loader = make_loader(train_ds, shuffle=True)
valid_loader = make_loader(valid_ds, shuffle=False)

## **Train**

In [8]:
# Build the model from the hyper-parameters in `configs`.
model = SingleTaskModel(**configs)

In [None]:
import wandb  # only needed for experiment tracking in this cell

# One W&B project per task.
project_name = f"ecg-single-task-{task}"

# Tell W&B which notebook this run belongs to.
os.environ["WANDB_NOTEBOOK_NAME"] = "ecg-single-task.ipynb"

# Start the run with its display name set here. Per Lightning's WandbLogger
# behaviour, a logger created while a run is already active reuses that run,
# so a name given only to the logger would not be applied.
run = wandb.init(project=project_name, name=configs["backbone"], save_code=True)
# Upload the Python sources and notebooks found under the current directory.
run.log_code(".", include_fn=lambda path: path.endswith((".py", ".ipynb")))
run.config.update({"batch_size": batch_size})

# Keep only the single best checkpoint, judged by the lowest "val_loss".
# NOTE(review): the filename embeds "val_acc" while "val_loss" is monitored;
# this assumes the model logs both metrics — confirm.
checkpoint_callback = ModelCheckpoint(
    filename=configs["backbone"] + "{val_acc:.2f}",
    save_top_k=1,
    verbose=True,
    monitor="val_loss",
    mode="min",
)

# Hand the already-started run to the Lightning logger explicitly, so it is
# clear that logger and `run` refer to the same W&B run.
logger = WandbLogger(
    project=project_name,
    name=configs["backbone"],
    experiment=run,
    # log_model = "all", # set to True to log at the end
)

logger.watch(
    model,
    # log_freq=300, # uncomment to log gradients
    log_graph=True,
)

In [None]:
# Callbacks: best-checkpoint saving plus stochastic weight averaging.
# NOTE(review): the SWA argument is positional; which parameter it binds to
# depends on the installed Lightning version — confirm it is the SWA learning rate.
training_callbacks = [
    checkpoint_callback,
    StochasticWeightAveraging(1e-3),
]

# GPU trainer that logs through `logger` and stops after `num_epochs` epochs.
trainer = Trainer(
    accelerator="gpu",
    max_epochs=num_epochs,
    logger=logger,
    callbacks=training_callbacks,
)

# Train on the training loader and validate on the validation loader.
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=valid_loader)

In [None]:
# Everything for this run is written to <parent_save_dir>/<run_name>.
run_save_dir = op.join(parent_save_dir, run_name)
# Create the run folder explicitly: the file writes below need it to exist, and
# nothing earlier in the notebook creates it.
os.makedirs(run_save_dir, exist_ok=True)

# Final model state and the configuration needed to rebuild the model.
trainer.save_checkpoint(op.join(run_save_dir, "model.ckpt"))
model.save_configs(run_save_dir)

# Serialise both transform pipelines next to the checkpoint.
A.save(train_transform, op.join(run_save_dir, "train_transform.json"))
A.save(valid_transform, op.join(run_save_dir, "transform.json"))