In [1]:
# Auto-reload edited modules (including the project's data/, models/, trainer/
# and utils/ packages imported below) before each cell executes, so changes to
# them are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np
from sklearn.utils.class_weight import compute_class_weight
from data.aware_new import AwareCSA
from data.cats_and_dogs import Dataset

# Diagnosis groups handed to AwareCSA as the target class set
# (presumably mapped to class indices in this order — confirm in data.aware_new).
classes = [
    'Control / healthy / no pulmonary disease',
    'Asthma',
    'CF',
    'COPD',
]

# Source files: feature/outcome/info CSVs, the REDCap label export and the ID map.
# The output_* flags presumably select which fields each sample returns; with the
# settings below a sample is (inputs, demographics, raw spirometry targets).
dataset = AwareCSA(
    csv_data='data/exhale_data_v8_ave.csv',
    csv_outcome='data/exhale_outcome_v8_ave.csv',
    csv_info='data/exhale_verbose_v8_ave.csv',
    redcap_csv_file="data/AWARE_DATA_LABELS_2023-12-08_1611.csv",
    id_map_file="data/id_map.csv",
    target_classes=classes,
    age_balanced=False,
    output_demogr=True,
    output_spiro_raw=True,
    output_spiro_pred=False,
    output_oscil_raw=False,
    output_oscil_zscore=False,
    output_disease_label=False,
)

# Alternatives kept for reference (disabled):
# dataset.save_to_pickle('data/aware_spectrogram.pkl')
# dataset = Dataset("data/cats_and_dogs", video=False, dim_order="BTCHW")

print("Class Weights:", dataset.class_weights, sep="\n")

Class Weights:
[ 0.86694915  0.465       1.51331361 28.41666667]


In [3]:
# Peek at one sample: model input shape and dtype, demographics vector, targets.
inputs, demogr, labels = dataset[0]
print(inputs.shape, inputs.dtype, demogr, labels, sep="\n")

(84,)
float32
[ 30.7    1.   160.02  80.  ]
[2.94 3.34 0.88 3.97]


In [4]:
import torch
from torchview import draw_graph
import math
from models.utils import select_model
import matplotlib.pyplot as plt
from matplotlib import animation
from utils.explainability import GradCAM
from peft import get_peft_model, LoraConfig, TaskType

# Experiment configuration. The training loop further down runs once per entry
# of RANDOM_SEED.
RANDOM_SEED = [4399, 114514, 1234, 1024, 304, 1, 2, 3, 4, 5]
# RANDOM_SEED = [1]
BATCH_SIZE = 16
LEARN_RATE = 1e-4
MAX_NUM_EPOCH = 50
MODEL_NAME = "cnn1d_reg"
VISUALIZE_MODEL = False

# Prefer the GPU when one is visible to torch.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print("Current device:", device)

# Instance used for the architecture printout and the sanity forward pass below;
# the training loop builds its own fresh model per fold.
model = select_model(MODEL_NAME).to(device)

if VISUALIZE_MODEL:
    # The model is called as model(inputs, demogr) (see the sanity forward pass
    # below), so feed draw_graph two dummy tensors shaped like a real sample
    # rather than a hard-coded 5-D video shape, which this model cannot accept.
    sample_inputs, sample_demogr, _ = dataset[0]
    dummy_inputs = torch.zeros((BATCH_SIZE, *tuple(np.shape(sample_inputs))), dtype=torch.float32)
    dummy_demogr = torch.zeros((BATCH_SIZE, *tuple(np.shape(sample_demogr))), dtype=torch.float32)
    model_graph = draw_graph(
        model,
        input_data=[dummy_inputs, dummy_demogr],
        expand_nested=True,
        save_graph=True,
        filename=MODEL_NAME,
        device=device,
    )
    # display(model_graph.visual_graph)
    # print(model)

# target_layers = [model.layer4[-1]]
# target_layers = [model.vivit.layernorm]
# cam = GradCAM(model=model, target_layers=target_layers)

def print_trainable_parameters(model):
    """
    Print how many of ``model``'s parameters are trainable.

    Sums ``numel()`` over ``model.named_parameters()`` and prints the trainable
    element count, the total element count and the trainable percentage.

    Args:
        model: any object exposing ``named_parameters()`` (e.g. a ``torch.nn.Module``).

    Returns:
        None; the summary is printed.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        count = param.numel()
        all_param += count
        if param.requires_grad:
            trainable_params += count
    # A parameter-less model would otherwise raise ZeroDivisionError.
    trainable_pct = 100 * trainable_params / all_param if all_param else 0.0
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {trainable_pct:.2f}"
    )

# LoRA adapter settings for optional PEFT fine-tuning. Currently unused: the
# get_peft_model call below is commented out.
# NOTE(review): target_modules "query"/"value" and modules_to_save "classifier"
# do not appear in the visible CNN1D printout — confirm the module names before
# enabling PEFT with this model.
peft_config = LoraConfig(
    r=16,
    lora_alpha=16,
    lora_dropout=0.1,
    bias="none",
    target_modules=["query", "value"],
    modules_to_save=["classifier"],
)
print(model)

# Optional PEFT wrapping (disabled):
# model = get_peft_model(model, peft_config)
# model.vit.gradient_checkpointing_enable()
# print_trainable_parameters(model)

def show_image(inputs, musk=None, flipud=False):
    """
    Plot a batch of images in a square grid, optionally overlaying a heat-map mask.

    Args:
        inputs: tensor indexed as (N, C, H, W). Values are mapped with x/2 + 0.5,
            i.e. presumably normalized to [-1, 1] — TODO confirm.
        musk: optional per-image mask indexed as (N, H, W), drawn over each image
            with the 'jet' colormap at 40% opacity (e.g. a GradCAM map).
        flipud: if True, flip each image along the height axis before plotting.

    Returns:
        None; the figure is shown with ``plt.show()``. An empty batch draws nothing.
    """
    inputs = inputs / 2 + 0.5
    if flipud:
        inputs = inputs.flip(dims=(2,))
    n_images = inputs.shape[0]
    if n_images == 0:
        return  # nothing to draw; math.isqrt would raise on a negative argument
    d = math.isqrt(n_images - 1) + 1  # smallest d with d * d >= n_images
    # squeeze=False always yields a 2-D array of Axes, so one image needs no
    # special-case branch; axs.flat walks the grid row-major.
    fig, axs = plt.subplots(d, d, figsize=(8, 8), squeeze=False)
    for i, ax in enumerate(axs.flat):
        if i >= n_images:
            ax.axis('off')  # hide the unused cells of the grid
            continue
        ax.imshow(inputs[i, :, :, :].permute(1, 2, 0))
        if musk is not None:
            ax.imshow(musk[i, :, :], cmap='jet', alpha=0.4)
    plt.show()
    
def show_video(inputs, musk=None, animate=False, flipud=False, save_path='animation.gif'):
    """
    Plot (or animate) a batch of clips in a square grid.

    Args:
        inputs: tensor indexed as (N, C, T, H, W). Values are mapped with
            x/2 + 0.5 (presumably normalized to [-1, 1] — TODO confirm).
        musk: optional static per-clip mask indexed as (N, H, W), overlaid with
            the 'jet' colormap at 40% opacity (e.g. a GradCAM map).
        animate: if True, animate over the T axis, display the animation as an
            inline JavaScript player and save it as a GIF; otherwise show frame 0.
        flipud: if True, flip along the height axis before plotting.
        save_path: GIF destination used when ``animate`` is True.

    Returns:
        None. An empty batch draws nothing.
    """
    inputs = inputs / 2 + 0.5
    if flipud:
        inputs = inputs.flip(dims=(3,))
    n_clips = inputs.shape[0]
    if n_clips == 0:
        return  # nothing to draw; math.isqrt would raise on a negative argument
    d = math.isqrt(n_clips - 1) + 1  # smallest d with d * d >= n_clips
    # squeeze=False always yields a 2-D array of Axes, so one clip needs no
    # special-case branch; axs.flat walks the grid row-major.
    fig, axs = plt.subplots(d, d, figsize=(8, 8), squeeze=False)
    frame_artists = []  # one image artist per clip, updated by the animation
    for i, ax in enumerate(axs.flat):
        if i >= n_clips:
            ax.axis('off')  # hide the unused cells of the grid
            continue
        frame_artists.append(ax.imshow(inputs[i, :, 0, :, :].permute(1, 2, 0)))
        if musk is not None:
            ax.imshow(musk[i, :, :], cmap='jet', alpha=0.4)

    def update(frame):
        # Swap every clip's image data to time step `frame`.
        for i, artist in enumerate(frame_artists):
            artist.set_array(inputs[i, :, frame, :, :].permute(1, 2, 0))
        return frame_artists

    if animate:
        ani = animation.FuncAnimation(fig=fig, func=update, frames=inputs.shape[2], interval=100)
        # Scope the inline-HTML setting to this display instead of mutating the
        # global rcParams for the rest of the session.
        with plt.rc_context({"animation.html": "jshtml"}):
            display(ani)
        # Pillow is a matplotlib dependency; 'imagemagick' needs an external binary.
        ani.save(save_path, writer='pillow', fps=10)
        plt.close(fig)
    else:
        plt.show()

# Sanity check: push one shuffled batch through the untrained model to confirm
# the input and output shapes line up.
data_loader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
inputs, demogr, labels = next(iter(data_loader))
print(inputs.size())

# inputs = image_processor(list(inputs.view(-1, *inputs.shape[2:])), return_tensors='pt')
# inputs['pixel_values'] = inputs['pixel_values'].view(BATCH_SIZE, -1, *inputs['pixel_values'].shape[2:])
# print(inputs['pixel_values'].size())
# plot(inputs['pixel_values'])

inputs = inputs.to(device)
demogr = demogr.to(device)
# eval() makes BatchNorm use, rather than update, its running statistics;
# no_grad() alone does not stop train-mode BatchNorm from updating them.
model.eval()
with torch.no_grad():
    # outputs = model(pixel_values=input)
    outputs = model(inputs, demogr)
# The model is a regressor (cnn1d_reg), so there is no class label to argmax;
# compare the raw outputs against the targets instead.
print(outputs.size())
print(labels)
print(outputs)

# gradcam = cam(input_tensor=inputs[0:1,:,:,:,:], targets=None)
# print(gradcam.shape)

# show_video(inputs[0:4,:,:,:,:].permute(0,2,1,3,4).cpu(), animate=False, flipud=True)
# show_video(inputs[0:1,:,:,:,:].permute(0,2,1,3,4).cpu(), musk=gradcam[0:1,:,:], animate=False, flipud=True)
# show_image(inputs[0:4,:,:,:].cpu(), musk=gradcam[0:4,:,:])
# show_image(inputs[0:4,:,:,:].cpu(), flipud=True)

Current device: cuda
CNN1D(
  (encoder): Sequential(
    (0): Conv1d(1, 8, kernel_size=(3,), stride=(1,), padding=(1,))
    (1): BatchNorm1d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Conv1d(8, 16, kernel_size=(3,), stride=(1,), padding=(1,))
    (5): BatchNorm1d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): ReLU()
    (7): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (8): Conv1d(16, 16, kernel_size=(3,), stride=(1,), padding=(1,))
    (9): BatchNorm1d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): ReLU()
    (11): AvgPool1d(kernel_size=(3,), stride=(3,), padding=(0,))
    (12): Flatten(start_dim=1, end_dim=-1)
    (13): Linear(in_features=112, out_features=32, bias=True)
  )
  (output): Sequential(
    (0): Linear(in_features=36, out_features=64, bias=True)
  

In [5]:
from tqdm.notebook import tqdm
from torch.utils.tensorboard import SummaryWriter
from IPython.display import clear_output
import time
import jupyter_beeper
import warnings

from data.aware_raw import AwareSplitter
from data.cats_and_dogs import Splitter
from trainer.spectrogram_reg import Trainer
from utils.others import weight_reset, BasicMetrics, BasicOutputs, RegressionMetrics
from utils.clustering import evaluate
from utils.outlier import novelty_detection

# Accumulators filled by the training loop below, once per trained fold.
# metrics_cluster = np.zeros((5,3,3))
metrics_val, metrics_test = RegressionMetrics(), RegressionMetrics()  # validation / test metrics
outputs_test = BasicOutputs()  # pooled test 'Labels' / 'Outputs', read by the evaluation cell
ig_test = torch.Tensor([])  # attribution accumulator; its concat in the loop is currently disabled

beeper = jupyter_beeper.Beeper()
# NOTE(review): this silences every warning process-wide for the rest of the session.
warnings.filterwarnings("ignore")  ## ignore warnings

for rand_seed in RANDOM_SEED:
    # Seed torch's global RNG so model initialisation (and, presumably, any
    # torch-based shuffling inside the loaders) is reproducible per seed; the
    # splitter receives the same seed explicitly.
    torch.manual_seed(rand_seed)
    splitter = AwareSplitter(dataset, BATCH_SIZE, random_seed=rand_seed)

    timestr = time.strftime("%Y%m%d-%H%M%S")
    for split_idx, (train_loader, val_loader, test_loader) in enumerate(splitter):
        writer = SummaryWriter("runs/" + timestr + "-fold" + str(split_idx))
        try:
            # Fresh model per fold so no weights carry over between splits.
            model = select_model(MODEL_NAME)
            model.to(device)
            # model = get_peft_model(model, peft_config)
            # model.vit.gradient_checkpointing_enable()
            # print_trainable_parameters(model)

            # target_layers = [model.layer4[-1]]
            # cam = GradCAM(model=model, target_layers=target_layers)

            # clear_output(wait=True)
            print("Seed " + str(rand_seed) + " | Fold #" + str(split_idx) + " | Training...")

            # NOTE(review): class weights are passed to a regression trainer
            # (trainer.spectrogram_reg); confirm Trainer actually uses them.
            trainer = Trainer(
                model,
                lr = LEARN_RATE,
                T_max = MAX_NUM_EPOCH,
                device = device,
                summarywriter = writer,
                class_weights = torch.Tensor(dataset.class_weights)
            )

            for epoch in tqdm(range(MAX_NUM_EPOCH), unit_scale=True, unit="epoch"):
                trainer.train(epoch, train_loader)
                trainer.validate(epoch, val_loader)
            print("Validation:")
            trainer.test(val_loader, no_print=False)
            metrics_val.append_from(trainer)
            print("Test:")
            trainer.test(test_loader, no_print=False, calculate_ig=False)
            # NOTE(review): the last recorded run raised
            # "AttributeError: 'Trainer' object has no attribute 'info'" right after
            # the test step — presumably in one of the append_from calls below;
            # check what utils.others expects from Trainer.
            metrics_test.append_from(trainer)
            outputs_test.append_from(trainer)
            # ig_test = torch.concat((ig_test, trainer.attr_ig), dim=0)
            print()

            # inputs, labels = next(iter(test_loader))
            # inputs = inputs.to(device)
            # with torch.no_grad():
            #     # outputs = model(pixel_values=input)
            #     outputs = model(inputs)
            # predicted_label = outputs.argmax(-1)
            # print(labels)
            # print(predicted_label)

            # gradcam = cam(input_tensor=inputs[0:1,:,:,:,:], targets=None)
            # show_video(inputs[0:4,:,:,:,:].cpu(), animate=True, flipud=True)
            # show_video(inputs[0:1,:,:,:,:].cpu(), musk=gradcam[0:1,:,:], animate=False, flipud=True)
        finally:
            # Flush and close the TensorBoard event file even if the fold fails.
            writer.close()
        # Only the first split yielded for each seed is trained; remove this
        # break to run every fold.
        break

#     beeper.beep(frequency=600, secs=0.5)
#     novelty_detection(model, train_loader, val_loader, test_loader)
#     metrics_cluster[split_idx,:,:] = evaluate(model, train_loader, val_loader, test_loader)    

Seed 4399 | Fold #0 | Training...


  0%|          | 0.00/50.0 [00:00<?, ?epoch/s]

Validataion:
Best testing loss: 0.58
On which epoch reach the highest accuracy: 25
Test:
Best testing loss: 0.63
On which epoch reach the highest accuracy: 25


AttributeError: 'Trainer' object has no attribute 'info'

In [None]:
import sklearn.metrics as M

# print(ig_test.size())
print("Final Validation Results")
display(metrics_val)
print()
print("Final Test Results")
display(metrics_test)
print()
# beeper.beep(frequency=600, secs=0.5)
display(outputs_test)
outputs_test.outputs.to_excel("outputs.xlsx")

# Per-target error metrics on the pooled test predictions; each column of
# `labels` / `outputs` is one regression target.
labels = np.array(outputs_test.outputs['Labels'].to_list())
outputs = np.array(outputs_test.outputs['Outputs'].to_list())
errors = outputs - labels
display(np.sqrt(np.mean(np.square(errors), axis=0)))  # RMSE per target
display(np.mean(np.abs(errors) / labels, axis=0))  # MAPE per target, as a fraction (assumes no zero targets)
display(M.r2_score(labels, outputs, multioutput='raw_values'))  # R^2 per target

# Predicted-vs-true scatter for every target, with y = x as the reference line.
n_targets = labels.shape[1]
n_cols = 2
n_rows = math.ceil(n_targets / n_cols)
fig, axs = plt.subplots(n_rows, n_cols, figsize=(8, 3 * n_rows), squeeze=False)
for t, ax in enumerate(axs.flat):
    if t >= n_targets:
        ax.axis('off')  # hide unused panels when the target count is odd
        continue
    ax.scatter(labels[:, t], outputs[:, t], marker='.')
    lo = min(labels[:, t].min(), outputs[:, t].min())
    hi = max(labels[:, t].max(), outputs[:, t].max())
    ax.plot([lo, hi], [lo, hi], '--k', linewidth=1)
    ax.set(title=f"Target {t}", xlabel="True", ylabel="Predicted")
plt.show()

In [None]:
# import ast
# import matplotlib.pyplot as plt
# import pandas as pd

# labels = outputs_test.outputs['Labels'].to_list()
# labels = pd.DataFrame(labels)
# labels.columns = ['Diagnosis']

# info = outputs_test.outputs['Info'].to_list()
# info = pd.DataFrame(info)
# info.columns = ['Age', 'Sex', 'Height', 'Weight']
#                # 'FEV1', 'FVC', 'FEV1/FVC', 'FEF2575']

# meta = pd.concat([labels, info], axis=1)

# outputs_cls = outputs_test.outputs['Outputs'].to_list()
# outputs_cls = pd.DataFrame(outputs_cls)

# meta.insert(1, 'Prediction', np.exp(outputs_cls[1])/np.sum(np.exp(outputs_cls),axis=1)) # Softmax

# idx_tp = (meta['Diagnosis']==1) & (meta['Prediction']>=0.5)
# tp = idx_tp.sum()
# idx_fp = (meta['Diagnosis']==0) & (meta['Prediction']>=0.5)
# fp = idx_fp.sum()
# idx_tn = (meta['Diagnosis']==0) & (meta['Prediction']<0.5)
# tn = idx_tn.sum()
# idx_fn = (meta['Diagnosis']==1) & (meta['Prediction']<0.5)
# fn = idx_fn.sum()
# print(tn, fp)
# print(fn, tp)

# plt.figure()
# plt.hist(meta['Prediction'][meta['Diagnosis']==0], bins=50, range=(0,1), alpha = 0.3, color='b', edgecolor='k', linewidth=1)
# plt.hist(meta['Prediction'][meta['Diagnosis']==1], bins=50, range=(0,1), alpha = 0.3, color='r', edgecolor='k', linewidth=1)
# plt.legend(['Healthy', 'Asthma'])
# plt.show()

# # cross_entropy = -(y_true*np.log10(y_pred[1]) + (1-y_true)*np.log10(y_pred[0]))

# # plt.figure()
# # # plt.hist(cross_entropy, bins=np.logspace(np.log10(0.1),np.log10(10.0), 50), edgecolor='k', linewidth=1)
# # plt.hist(cross_entropy, bins=50, edgecolor='k', linewidth=1)
# # plt.legend(['Healthy', 'Asthma'])
# # # plt.gca().set_xscale("log")
# # plt.vlines(-np.log10(0.5), 0, 500, color='r')
# # plt.show()

# # y_pred = y_pred.idxmax(axis=1)

# sens = tp/(tp+fn)
# spec = tn/(fp+tn)
# print('Sens:', sens)
# print('Spec:', spec)
# print('BalAcc:', (sens+spec)/2)

# plt.figure(figsize=(20,4))
# plt.subplot(2,4,1)
# plt.hist(pd.concat([meta['Age'][idx_tp], meta['Age'][idx_fn]], axis=1), range=(meta['Age'].min(),meta['Age'].max()), alpha = 0.3, stacked=True, edgecolor='k', linewidth=1)
# plt.legend(['True Asthma', 'False Healthy'])
# plt.title('Age')
# plt.subplot(2,4,5)
# plt.hist(pd.concat([meta['Age'][idx_tn], meta['Age'][idx_fp]], axis=1), range=(meta['Age'].min(),meta['Age'].max()), alpha = 0.3, stacked=True, edgecolor='k', linewidth=1)
# plt.legend(['True Healthy', 'False Asthma'])
# plt.subplot(2,4,2)
# plt.hist(pd.concat([meta['Sex'][idx_tp], meta['Sex'][idx_fn]], axis=1), range=(meta['Sex'].min(),meta['Sex'].max()), alpha = 0.3, stacked=True, edgecolor='k', linewidth=1)
# plt.legend(['True Asthma', 'False Healthy'])
# plt.title('Sex')
# plt.subplot(2,4,6)
# plt.hist(pd.concat([meta['Sex'][idx_tn], meta['Sex'][idx_fp]], axis=1), range=(meta['Sex'].min(),meta['Sex'].max()), alpha = 0.3, stacked=True, edgecolor='k', linewidth=1)
# plt.legend(['True Healthy', 'False Asthma'])
# plt.subplot(2,4,3)
# plt.hist(pd.concat([meta['Height'][idx_tp], meta['Height'][idx_fn]], axis=1), range=(meta['Height'].min(),meta['Height'].max()), alpha = 0.3, stacked=True, edgecolor='k', linewidth=1)
# plt.legend(['True Asthma', 'False Healthy'])
# plt.title('Height')
# plt.subplot(2,4,7)
# plt.hist(pd.concat([meta['Height'][idx_tn], meta['Height'][idx_fp]], axis=1), range=(meta['Height'].min(),meta['Height'].max()), alpha = 0.3, stacked=True, edgecolor='k', linewidth=1)
# plt.legend(['True Healthy', 'False Asthma'])
# plt.subplot(2,4,4)
# plt.hist(pd.concat([meta['Weight'][idx_tp], meta['Weight'][idx_fn]], axis=1), range=(meta['Weight'].min(),meta['Weight'].max()), alpha = 0.3, stacked=True, edgecolor='k', linewidth=1)
# plt.legend(['True Asthma', 'False Healthy'])
# plt.title('Weight')
# plt.subplot(2,4,8)
# plt.hist(pd.concat([meta['Weight'][idx_tn], meta['Weight'][idx_fp]], axis=1), range=(meta['Weight'].min(),meta['Weight'].max()), alpha = 0.3, stacked=True, edgecolor='k', linewidth=1)
# plt.legend(['True Healthy', 'False Asthma'])
# plt.show()

# plt.figure(figsize=(6,4))
# plt.subplot(2,1,1)
# plt.hist(pd.concat([meta['Age'][idx_tp], meta['Age'][idx_fn]], axis=1), bins=list(range(0, 78, 6)), alpha = 0.3, stacked=True, edgecolor='k', linewidth=1)
# plt.legend(['True Asthma', 'False Healthy'])
# plt.subplot(2,1,2)
# plt.hist(pd.concat([meta['Age'][idx_tn], meta['Age'][idx_fp]], axis=1), bins=list(range(0, 78, 6)), alpha = 0.3, stacked=True, edgecolor='k', linewidth=1)
# plt.legend(['True Healthy', 'False Asthma'])
# plt.show()

# thre = np.zeros(1001)
# sens = np.zeros(1001)
# spec = np.zeros(1001)
# for i in range(0,1001):
#     idx_tp = (meta['Diagnosis']==1) & (meta['Prediction']>=i/1000)
#     tp = idx_tp.sum()
#     idx_fp = (meta['Diagnosis']==0) & (meta['Prediction']>=i/1000)
#     fp = idx_fp.sum()
#     idx_tn = (meta['Diagnosis']==0) & (meta['Prediction']<i/1000)
#     tn = idx_tn.sum()
#     idx_fn = (meta['Diagnosis']==1) & (meta['Prediction']<i/1000)
#     fn = idx_fn.sum()
#     thre[i] = i/1000
#     sens[i] = tp/(tp+fn)
#     spec[i] = tn/(fp+tn)
    
# plt.figure()
# plt.plot(thre, sens)
# plt.plot(thre, spec)
# plt.title('Sensitivity and Specificity vs. Threshold')
# plt.legend(['Sensitivity', 'Specificity'])
# plt.xlabel('Threshold')
# plt.show()

# plt.figure()
# # skplt.metrics.plot_roc_curve(meta['Diagnosis'], outputs_cls)
# plt.plot(1-spec, sens)
# plt.plot([0,1], [0,1], '--k')
# plt.title('ROC Curve')
# plt.xlabel('False Positive Rate\n(1-Specificity)')
# plt.ylabel('True Positive Rate\n(Sensitivity)')
# plt.show()