In [1]:
# Automatically reload edited project modules (data.*, models.*, trainer.*, utils.*)
# so source changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np
from sklearn.utils.class_weight import compute_class_weight
from data.aware_raw import AwareRaw, AwareSpectrogram
from data.cats_and_dogs import Dataset

# Diagnosis classes to model; the commented entries are currently excluded.
classes = [
    'Control / healthy / no pulmonary disease',
    'Asthma',
    # 'CF',
    # 'COPD'
]

# Raw AWARE data: label CSV, id-map CSV and a .db file; `pickle_file` presumably
# caches the segmented recordings (NOTE(review): if it is read with pickle, only
# load trusted cache files).
dataset_raw = AwareRaw(
    "data/AWARE_DATA_LABELS_2023-12-08_1611.csv",
    "data/id_map.csv",
    "data/aware_full_1704385505.db",
    pickle_file="data/aware_segmented.pkl",
)

# Spectrogram dataset over the IR modality; each sample is laid out per
# dim_order 'BTCHW' (a single sample prints as (T, C, H, W) = (25, 3, 224, 224)).
# The only enabled output_* flag is inhale/exhale.
dataset = AwareSpectrogram(
    dataset_raw,
    target_classes=classes,
    age_balanced=False,
    output_demogr=False,
    output_spiro_raw=False,
    output_spiro_pred=False,
    output_oscil_raw=False,
    output_oscil_zscore=False,
    output_disease_label=False,
    output_inhale_exhale=True,
    relative_change=False,
    calibration=True,
    averaged=False,
    num_channels=3,
    dim_order='BTCHW',
    modality='ir'
)
# Passed as `flipud` to the plotting helpers for this dataset.
imflip = True

# dataset.save_to_pickle('data/aware_spectrogram.pkl')

# dataset = Dataset("data/cats_and_dogs", video=False, dim_order="BTCHW")
# imflip = False

print("Class Distribution:", dataset.class_distribution, sep="\n")
# NOTE(review): the saved output shows weights [1, 1] despite a 1062/1944 split,
# and compute_class_weight is imported but unused — confirm uniform weighting is intended.
print("Class Weights:", dataset.class_weights, sep="\n")

0it [00:00, ?it/s]

Class Distribution:
[1062. 1944.]
Class Weights:
[1, 1]


In [3]:
# Sanity-check a single sample: input tensor shape and dtype, then its label.
inputs, labels = dataset[0]
# inputs, demogr, labels = dataset[0]  # alternative unpacking that also yields `demogr`
print(inputs.shape, inputs.dtype, sep="\n")
# print(demogr)
print(labels)

torch.Size([25, 3, 224, 224])
torch.float32
0


In [8]:
import torch
from torchview import draw_graph
import cv2
import math
from models.utils import select_model
import matplotlib.pyplot as plt
from matplotlib import animation
from sklearn.manifold import TSNE
from sklearn.metrics import pairwise_distances
import seaborn as sns
# from utils.explainability import GradCAM
from pytorch_grad_cam import GradCAM
from peft import get_peft_model, LoraConfig, TaskType

# Experiment configuration. The training cell iterates over RANDOM_SEED.
RANDOM_SEED = [4399,114514,1234,1024,304,1,2,3,4,5]
# RANDOM_SEED = [1]
BATCH_SIZE = 16
LEARN_RATE = 1e-3
MAX_NUM_EPOCH = 10
# NOTE(review): despite the name, select_model returns a ViViT wrapper here
# (see the printed model) — confirm the name is intended.
MODEL_NAME = "vit_binary"
VISUALIZE_MODEL = False

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Current device:", device)

model = select_model(MODEL_NAME)
model.to(device)

if VISUALIZE_MODEL:
    # Trace with the real per-sample shape from the dataset (e.g. (T, C, H, W)
    # for video clips) instead of a hard-coded single image (3, 224, 224),
    # which a video model cannot consume. dataset[0][0] is the input tensor in
    # both the (inputs, labels) and (inputs, demogr, labels) layouts.
    sample_inputs = dataset[0][0]
    model_graph = draw_graph(
        model,
        input_size=(BATCH_SIZE, *sample_inputs.shape),
        expand_nested=True,
        save_graph=True,
        filename=MODEL_NAME
    )
    # display(model_graph.visual_graph)
    # print(model)

print(model)

# target_layers = [model.layer4[-1]]
# target_layers = [model.vit.vit.layernorm]
# target_layers = [model.vit.layernorm]
# cam = GradCAM(model=model, target_layers=target_layers)

def print_trainable_parameters(model):
    """
    Print how many of ``model``'s parameters are trainable.

    Sums ``numel()`` over every tensor from ``model.named_parameters()``, and
    separately over those with ``requires_grad=True``, then prints both totals
    and the trainable percentage on a single line.

    Args:
        model: any ``torch.nn.Module``.

    Returns:
        None; the summary is written to stdout. A module with no parameters
        reports a trainable percentage of 0.00 instead of raising
        ``ZeroDivisionError``.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        count = param.numel()
        all_param += count
        if param.requires_grad:
            trainable_params += count
    # Guard the percentage against an empty (parameter-less) module.
    trainable_pct = 100 * trainable_params / all_param if all_param else 0.0
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {trainable_pct:.2f}"
    )

# Wrap the model with LoRA adapters on the attention `query` and `value`
# Linear layers (the names seen in the printed VivitSelfAttention modules).
# `modules_to_save=["classifier"]` presumably keeps the classifier head fully
# trainable — confirm against the PEFT docs for the installed version.
# NOTE(review): this wrapped model is only used by the smoke test in this cell;
# the training cell builds fresh models and its get_peft_model call is commented out.
peft_config = LoraConfig(
    r=16,
    lora_alpha=16,
    target_modules=["query", "value"],
    lora_dropout=0.1,
    bias="none",
    modules_to_save=["classifier"],
)
model = get_peft_model(model, peft_config)
# Presumably resolves through the PEFT wrapper to the ViViT wrapper's `.vivit`
# backbone (the printed model shows a `vivit` submodule) — verify.
model.vivit.gradient_checkpointing_enable()
print_trainable_parameters(model)

def show_image(inputs, mask=None, flipud=False):
    """
    Plot a batch of images on a square grid, optionally with a heat-map overlay.

    Args:
        inputs: image batch indexed as ``inputs[i, c, h, w]``; presumably
            normalised to [-1, 1], since it is mapped back with ``x/2 + 0.5``
            (TODO confirm against the dataset transform).
        mask: optional per-image map indexed as ``mask[i, h, w]``, drawn with
            the ``jet`` colormap at 40% opacity on top of each image.
        flipud: if True, flip the images (dim 2) and mask (dim 1) vertically.
    """
    images = inputs / 2 + 0.5
    if flipud:
        images = images.flip(dims=(2,))
        if mask is not None:
            mask = mask.flip(dims=(1,))

    count = images.shape[0]
    # Smallest grid side such that side * side >= count, i.e. ceil(sqrt(count)).
    side = math.isqrt(count - 1) + 1
    # squeeze=False always yields a 2-D array of Axes, so a single image and a
    # multi-image grid share one code path.
    _, grid = plt.subplots(side, side, figsize=(8, 8), squeeze=False)

    for idx in range(count):
        ax = grid[idx // side, idx % side]
        ax.imshow(images[idx].permute(1, 2, 0))
        if mask is not None:
            ax.imshow(mask[idx], cmap='jet', alpha=0.4)

    plt.show()
    
def show_video(inputs, mask=None, animate=False, flipud=False):
    """
    Plot one frame of each clip in a batch on a square grid, optionally as an animation.

    Args:
        inputs: clip batch indexed as ``inputs[i, c, t, h, w]`` (time is dim 2,
            the animation axis); presumably normalised to [-1, 1], since it is
            mapped back with ``x/2 + 0.5`` — TODO confirm.
        mask: optional per-clip overlay indexed as ``mask[i, h, w]`` (torch
            tensor or NumPy array), drawn with the ``jet`` colormap at 40%
            opacity. The overlay is static and does not change per frame.
        animate: if True, animate over dim 2, display the animation inline,
            save it to ``animation.gif`` and close the figure. Otherwise show
            frame 0 with ``plt.show()``.
        flipud: if True, flip the clips (dim 3) and the mask (dim 1) vertically.
    """
    inputs = inputs/2+0.5
    if flipud:
        inputs = inputs.flip(dims=(3,))
        if mask is not None:
            # Flip the overlay too, so it stays aligned with the flipped frames.
            # torch.as_tensor lets a NumPy mask (e.g. GradCAM output) be flipped.
            mask = torch.as_tensor(mask).flip(dims=(1,))
    plt.rcParams["animation.html"] = "jshtml"
    d = math.isqrt(inputs.shape[0]-1)+1  # grid side: ceil(sqrt(batch size))
    fig, axs = plt.subplots(d, d, figsize=(8,8))
    im = []
    msk = []
    if inputs.shape[0]==1:
        im += [axs.imshow(inputs[0,:,0,:,:].permute(1,2,0))]
        if mask is not None:
            msk += [axs.imshow(mask[0,:,:], cmap='jet', alpha=0.4)]
    else:
        for i in range(inputs.shape[0]):
            # im += [axs[i//d, i%d].pcolormesh(inputs[i,0,0,:,:], shading='gouraud', cmap='gray')]
            im += [axs[i//d, i%d].imshow(inputs[i,:,0,:,:].permute(1,2,0))]
            if mask is not None:
                msk += [axs[i//d, i%d].imshow(mask[i,:,:], cmap='jet', alpha=0.4)]

    def update(frame):
        # Show time step `frame` (dim 2) of every clip; the mask overlay stays fixed.
        for i in range(inputs.shape[0]):
            im[i].set_array(inputs[i,:,frame,:,:].permute(1,2,0))

    if animate:
        ani = animation.FuncAnimation(fig=fig, func=update, frames=inputs.shape[2], interval=100)
        display(ani)
        # NOTE(review): the 'imagemagick' writer requires ImageMagick on the system.
        ani.save('animation.gif', writer='imagemagick', fps=10)
        plt.close(fig)
    else:
        plt.show()

# Smoke test: push one shuffled batch through the (untrained) model.
data_loader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
inputs, labels = next(iter(data_loader))
# inputs, demogr, labels = next(iter(data_loader))
print('Input Size:', inputs.size())

# inputs = image_processor(list(inputs.view(-1, *inputs.shape[2:])), return_tensors='pt')
# inputs['pixel_values'] = inputs['pixel_values'].view(BATCH_SIZE, -1, *inputs['pixel_values'].shape[2:])
# print(inputs['pixel_values'].size())
# plot(inputs['pixel_values'])

inputs = inputs.to(device)
# demogr = demogr.to(device)
# Inference mode: disables dropout layers (presumably including the LoRA
# lora_dropout=0.1 configured above), so the predictions are deterministic.
model.eval()
with torch.no_grad():
    # NOTE(review): in the saved run this call raised "size of tensor a (2353)
    # must match ... tensor b (3137)". With 2x16x16 tubelets on 224x224 frames,
    # 2353 = 1 + 12*196 tokens (25 frames), while 3137 = 1 + 16*196 (32 frames).
    # The clips appear shorter than the pretrained ViViT expects — TODO confirm.
#     outputs = model(pixel_values=inputs)
    outputs = model(inputs)
    # outputs = model(inputs, demogr)
predicted_label = outputs.argmax(-1)
print('Output Size:', outputs.size())
print('Labels:', labels)
print('Predicted:', predicted_label)

def get_attention_map(inputs, logits, att_mat):
    """
    Compute an attention-rollout saliency map for each sample in a batch.

    The steps are:
    1. Average attention over heads.
    2. Add the identity matrix for residual connections and re-normalise the rows.
    3. Multiply the per-layer matrices together.
    4. Read the CLS-token row of the final product as a square patch grid.
    5. Min-max scale the grid and resize it to the input's spatial size.

    Args:
        inputs: input batch. Only its last two dims (H, W) are used, as the
            target size of each map.
        logits: unused; kept for call-site compatibility.
        att_mat: sequence of per-layer attention tensors. Each is assumed to be
            (batch, heads, tokens, tokens), with token 0 a CLS token and the
            remaining tokens - 1 forming a square patch grid — TODO confirm for
            the model in use.

    Returns:
        torch.Tensor of shape (batch, H, W) with values in [0, 1]. A map that is
        constant across patches is returned as all zeros.
    """
    att_mat = torch.stack(att_mat, dim=1)  # stack layers along dim 1
    print(att_mat.size())

    # Average the attention weights across all heads.
    att_mat = torch.mean(att_mat, dim=2)
    print(att_mat.size())

    # To account for residual connections, add an identity matrix to the
    # attention matrix and re-normalize each row. Build it on the attention's
    # own device/dtype rather than relying on a global `device`.
    residual_att = torch.eye(att_mat.size(-1), device=att_mat.device, dtype=att_mat.dtype)
    aug_att_mat = att_mat + residual_att
    aug_att_mat = aug_att_mat / aug_att_mat.sum(dim=-1, keepdim=True)
    print(aug_att_mat.size())

    # Recursively multiply the per-layer matrices (layer index is dim 1).
    joint_attentions = torch.zeros_like(aug_att_mat)
    joint_attentions[:, 0] = aug_att_mat[:, 0]
    for n in range(1, aug_att_mat.size(1)):
        joint_attentions[:, n] = torch.matmul(aug_att_mat[:, n], joint_attentions[:, n - 1])

    v = joint_attentions[:, -1]
    grid_size = int(np.sqrt(aug_att_mat.size(-1)))
    # CLS row, excluding the CLS column itself, reshaped into the patch grid.
    mask = v[:, 0, 1:].reshape(v.size(0), grid_size, grid_size).detach().cpu().numpy()
    print(mask.shape)

    # cv2.resize takes dsize as (width, height), the reverse of torch's (H, W).
    height, width = inputs.shape[-2], inputs.shape[-1]
    result = []
    for i in range(mask.shape[0]):
        lo, hi = mask[i].min(), mask[i].max()
        # Min-max scale to [0, 1]; a constant map would divide by zero, so zero it.
        mask[i] = (mask[i] - lo) / (hi - lo) if hi > lo else 0.0
        result.append(cv2.resize(mask[i], (width, height)))
    result = torch.Tensor(np.array(result))

    return result

# m = get_attention_map(inputs, outputs, model.attentions)

# for i in range(12):
#     print(model.attentions[i].shape)
#     show_image(model.attentions[i][0:1,:,:,:].permute(1,0,2,3).cpu(), flipud=True)

# gradcam = cam(input_tensor=inputs[0:4,:,:,:], targets=None)
# print(gradcam.shape)

# show_video(inputs[0:4,:,:,:,:].permute(0,2,1,3,4).cpu(), animate=False, flipud=True)
# show_video(inputs[0:1,:,:,:,:].permute(0,2,1,3,4).cpu(), mask=gradcam[0:1,:,:], animate=False, flipud=True)
# show_image(inputs[0:4,:,:,:].cpu(), mask=gradcam[0:4,:,:])

# show_image(inputs[0:4,...].cpu(), flipud=imflip)
# show_image(inputs[0:4,...].cpu(), mask=m, flipud=imflip)

# t-SNE visualization
def tsne_plot(X, y, perplexity=10, max_iter=300):
    """
    Project feature vectors to 2-D with t-SNE and show a scatter plot coloured by label.

    Args:
        X: (n_samples, n_features) array-like, e.g. model outputs moved to CPU.
        y: per-sample labels, used as the scatter-plot hue.
        perplexity: requested t-SNE perplexity. scikit-learn requires
            ``perplexity < n_samples``, so it is clamped to ``n_samples - 1``.
            Without the clamp, a small batch (fewer than 11 samples at the
            default) would raise ``ValueError``.
        max_iter: number of t-SNE optimisation iterations.

    The legend labels come from the notebook-level ``classes`` list.
    """
    import inspect

    print(X.shape)
    n_samples = X.shape[0]
    # scikit-learn 1.5 renamed TSNE's ``n_iter`` to ``max_iter``, and the old name
    # is removed in 1.7. Pass the iteration count under whichever name this
    # installed version accepts.
    iter_param = 'max_iter' if 'max_iter' in inspect.signature(TSNE).parameters else 'n_iter'
    tsne = TSNE(
        n_components=2,
        verbose=0,
        perplexity=min(perplexity, max(n_samples - 1, 1)),
        **{iter_param: max_iter},
    )
    tsne_results = tsne.fit_transform(X)

    results = {'tsne-2d-one': tsne_results[:,0],
               'tsne-2d-two': tsne_results[:,1],
               'y': y}

    plt.figure(figsize=(8,6))
    sns.scatterplot(
        x="tsne-2d-one", y="tsne-2d-two",
        hue="y",
        data=results,
        palette=sns.color_palette("Set2"),
        legend="full",
    )
    plt.legend(classes)
    plt.title("t-SNE Scatter Plot")
    plt.show()

# t-SNE of the smoke-test batch's model outputs, coloured by true label.
tsne_plot(X=outputs.cpu(), y=labels.cpu())

Some weights of VivitModel were not initialized from the model checkpoint at google/vivit-b-16x2-kinetics400 and are newly initialized: ['vivit.pooler.dense.bias', 'vivit.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Current device: cuda
ViViT(
  (vivit): VivitModel(
    (embeddings): VivitEmbeddings(
      (patch_embeddings): VivitTubeletEmbeddings(
        (projection): Conv3d(3, 768, kernel_size=(2, 16, 16), stride=(2, 16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): VivitEncoder(
      (layer): ModuleList(
        (0-11): 12 x VivitLayer(
          (attention): VivitAttention(
            (attention): VivitSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): VivitSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): VivitIntermediate(
            

RuntimeError: The size of tensor a (2353) must match the size of tensor b (3137) at non-singleton dimension 1

In [None]:
from tqdm.notebook import tqdm
from torch.utils.tensorboard import SummaryWriter
from IPython.display import clear_output
import time
import jupyter_beeper
import warnings

from data.aware_raw import AwareSplitter
from data.cats_and_dogs import Splitter
from trainer.spectrogram import Trainer
from utils.others import weight_reset, BasicMetrics, BasicOutputs, RegressionMetrics
from utils.clustering import evaluate
from utils.outlier import novelty_detection

# metrics_cluster = np.zeros((5,3,3))
# Accumulators shared across all seeds/folds; read by the results cell.
metrics_val = RegressionMetrics()
metrics_test = RegressionMetrics()
outputs_test = BasicOutputs()
ig_test = torch.Tensor([])

beeper = jupyter_beeper.Beeper()
# NOTE(review): this silences every warning for the rest of the session,
# including library deprecation notices — consider narrowing the filter.
warnings.filterwarnings("ignore")  ## ignore warnings

for rand_seed in RANDOM_SEED:
    # Seed torch's global RNG so freshly initialised weights are reproducible
    # per seed (the splitter only receives the seed for its own splitting).
    torch.manual_seed(rand_seed)
    splitter = AwareSplitter(dataset, BATCH_SIZE, random_seed=rand_seed)

    timestr = time.strftime("%Y%m%d-%H%M%S")
    for split_idx, (train_loader, val_loader, test_loader) in enumerate(splitter):
        # The context manager flushes and closes the TensorBoard event file when
        # the fold ends, including via the `break` below or an exception.
        with SummaryWriter("runs/" + timestr + "-fold" + str(split_idx)) as writer:
            model = select_model(MODEL_NAME)
            model.to(device)
            # model = get_peft_model(model, peft_config)
            # model.vit.gradient_checkpointing_enable()  # NOTE(review): the ViViT wrapper exposes `.vivit`
            # print_trainable_parameters(model)

            # target_layers = [model.layer4[-1]]
            # cam = GradCAM(model=model, target_layers=target_layers)

            # clear_output(wait=True)
            print("Seed " + str(rand_seed) + " | Fold #" + str(split_idx) + " | Training...")

            trainer = Trainer(
                model,
                lr = LEARN_RATE,
                T_max = MAX_NUM_EPOCH,
                device = device,
                summarywriter = writer,
                class_weights = torch.Tensor(dataset.class_weights)
            )

            for epoch in tqdm(range(MAX_NUM_EPOCH), unit_scale=True, unit="epoch"):
                trainer.train(epoch, train_loader)
                trainer.validate(epoch, val_loader)
            print("Train:")
            trainer.test(train_loader, no_print=False)
            # tsne_plot(trainer.outputs.cpu(), trainer.labels.cpu())
            print("Validation:")
            trainer.test(val_loader, no_print=False)
            metrics_val.append_from(trainer)
            # tsne_plot(trainer.outputs.cpu(), trainer.labels.cpu())
            print("Test:")
            trainer.test(test_loader, no_print=False, calculate_ig=False)
            metrics_test.append_from(trainer)
            outputs_test.append_from(trainer)
            # tsne_plot(trainer.outputs.cpu(), trainer.labels.cpu())
            # ig_test = torch.concat((ig_test, trainer.attr_ig), dim=0)
            print()

            # Spot-check predictions on one test batch, in inference mode.
            inputs, labels = next(iter(test_loader))
            # inputs, demogr, labels = next(iter(test_loader))
            inputs = inputs.to(device)
            # demogr = demogr.to(device)
            model.eval()
            with torch.no_grad():
                outputs = model(inputs)
                # outputs = model(inputs, demogr)
            predicted_label = outputs.argmax(-1)
            print(labels)
            print(predicted_label)

            # m = get_attention_map(inputs, outputs, model.attentions)
            # show_image(inputs[0:4,:,:,:].cpu(), flipud=imflip)
            # show_image(inputs[0:4,:,:,:].cpu(), mask=m, flipud=imflip)
            # gradcam = cam(input_tensor=inputs[0:1,:,:,:,:], targets=None)
            # show_video(inputs[0:4,:,:,:,:].cpu(), animate=True, flipud=True)
            # show_video(inputs[0:1,:,:,:,:].cpu(), mask=gradcam[0:1,:,:], animate=False, flipud=True)
        # NOTE(review): only the first fold of each seed is run; remove this
        # `break` to evaluate every fold.
        break

#     beeper.beep(frequency=600, secs=0.5)
#     novelty_detection(model, train_loader, val_loader, test_loader)
#     metrics_cluster[split_idx,:,:] = evaluate(model, train_loader, val_loader, test_loader)    

In [None]:
import sklearn.metrics as M
# print(ig_test.size())
print("Final Validation Results")
display(metrics_val)
print()
print("Final Test Results")
display(metrics_test)
print()
# beeper.beep(frequency=600, secs=0.5)
display(outputs_test)

# Per-sample test targets and predictions stacked into (n_samples, n_targets) arrays.
labels = np.array(outputs_test.outputs['Labels'].to_list())
outputs = np.array(outputs_test.outputs['Outputs'].to_list())
n_targets = labels.shape[1]

display(np.sqrt(np.mean(np.square(outputs-labels), axis=0))) # RMSE
# NOTE(review): MAPE is inf/nan for any target column that contains a 0 label.
display(np.mean(np.abs(outputs-labels)/labels, axis=0)) # MAPE
# R^2 for every target column (previously hard-coded to the first four).
display(*[M.r2_score(labels[:, j], outputs[:, j]) for j in range(n_targets)])

# Predicted-vs-true scatter per target, two panels per row. Spare panels are
# hidden. Four targets reproduce the original 2x2, 8x6 layout.
n_cols = 2
n_rows = math.ceil(n_targets / n_cols)
fig, axs = plt.subplots(n_rows, n_cols, figsize=(8, 3 * n_rows), squeeze=False)
for j, ax in enumerate(axs.flat):
    if j < n_targets:
        ax.scatter(labels[:, j], outputs[:, j], marker='.')
        ax.set(xlabel='True', ylabel='Predicted', title=f'Target {j}')
    else:
        ax.axis('off')
plt.show()

In [None]:
# import ast
# import matplotlib.pyplot as plt
# import pandas as pd

# labels = outputs_test.outputs['Labels'].to_list()
# labels = pd.DataFrame(labels)
# labels.columns = ['Diagnosis']

# info = outputs_test.outputs['Info'].to_list()
# info = pd.DataFrame(info)
# info.columns = ['Age', 'Sex', 'Height', 'Weight']
#                # 'FEV1', 'FVC', 'FEV1/FVC', 'FEF2575']

# meta = pd.concat([labels, info], axis=1)

# outputs_cls = outputs_test.outputs['Outputs'].to_list()
# outputs_cls = pd.DataFrame(outputs_cls)

# meta.insert(1, 'Prediction', np.exp(outputs_cls[1])/np.sum(np.exp(outputs_cls),axis=1)) # Softmax

# idx_tp = (meta['Diagnosis']==1) & (meta['Prediction']>=0.5)
# tp = idx_tp.sum()
# idx_fp = (meta['Diagnosis']==0) & (meta['Prediction']>=0.5)
# fp = idx_fp.sum()
# idx_tn = (meta['Diagnosis']==0) & (meta['Prediction']<0.5)
# tn = idx_tn.sum()
# idx_fn = (meta['Diagnosis']==1) & (meta['Prediction']<0.5)
# fn = idx_fn.sum()
# print(tn, fp)
# print(fn, tp)

# plt.figure()
# plt.hist(meta['Prediction'][meta['Diagnosis']==0], bins=50, range=(0,1), alpha = 0.3, color='b', edgecolor='k', linewidth=1)
# plt.hist(meta['Prediction'][meta['Diagnosis']==1], bins=50, range=(0,1), alpha = 0.3, color='r', edgecolor='k', linewidth=1)
# plt.legend(['Healthy', 'Asthma'])
# plt.show()

# # cross_entropy = -(y_true*np.log10(y_pred[1]) + (1-y_true)*np.log10(y_pred[0]))

# # plt.figure()
# # # plt.hist(cross_entropy, bins=np.logspace(np.log10(0.1),np.log10(10.0), 50), edgecolor='k', linewidth=1)
# # plt.hist(cross_entropy, bins=50, edgecolor='k', linewidth=1)
# # plt.legend(['Healthy', 'Asthma'])
# # # plt.gca().set_xscale("log")
# # plt.vlines(-np.log10(0.5), 0, 500, color='r')
# # plt.show()

# # y_pred = y_pred.idxmax(axis=1)

# sens = tp/(tp+fn)
# spec = tn/(fp+tn)
# print('Sens:', sens)
# print('Spec:', spec)
# print('BalAcc:', (sens+spec)/2)

# plt.figure(figsize=(20,4))
# plt.subplot(2,4,1)
# plt.hist(pd.concat([meta['Age'][idx_tp], meta['Age'][idx_fn]], axis=1), range=(meta['Age'].min(),meta['Age'].max()), alpha = 0.3, stacked=True, edgecolor='k', linewidth=1)
# plt.legend(['True Asthma', 'False Healthy'])
# plt.title('Age')
# plt.subplot(2,4,5)
# plt.hist(pd.concat([meta['Age'][idx_tn], meta['Age'][idx_fp]], axis=1), range=(meta['Age'].min(),meta['Age'].max()), alpha = 0.3, stacked=True, edgecolor='k', linewidth=1)
# plt.legend(['True Healthy', 'False Asthma'])
# plt.subplot(2,4,2)
# plt.hist(pd.concat([meta['Sex'][idx_tp], meta['Sex'][idx_fn]], axis=1), range=(meta['Sex'].min(),meta['Sex'].max()), alpha = 0.3, stacked=True, edgecolor='k', linewidth=1)
# plt.legend(['True Asthma', 'False Healthy'])
# plt.title('Sex')
# plt.subplot(2,4,6)
# plt.hist(pd.concat([meta['Sex'][idx_tn], meta['Sex'][idx_fp]], axis=1), range=(meta['Sex'].min(),meta['Sex'].max()), alpha = 0.3, stacked=True, edgecolor='k', linewidth=1)
# plt.legend(['True Healthy', 'False Asthma'])
# plt.subplot(2,4,3)
# plt.hist(pd.concat([meta['Height'][idx_tp], meta['Height'][idx_fn]], axis=1), range=(meta['Height'].min(),meta['Height'].max()), alpha = 0.3, stacked=True, edgecolor='k', linewidth=1)
# plt.legend(['True Asthma', 'False Healthy'])
# plt.title('Height')
# plt.subplot(2,4,7)
# plt.hist(pd.concat([meta['Height'][idx_tn], meta['Height'][idx_fp]], axis=1), range=(meta['Height'].min(),meta['Height'].max()), alpha = 0.3, stacked=True, edgecolor='k', linewidth=1)
# plt.legend(['True Healthy', 'False Asthma'])
# plt.subplot(2,4,4)
# plt.hist(pd.concat([meta['Weight'][idx_tp], meta['Weight'][idx_fn]], axis=1), range=(meta['Weight'].min(),meta['Weight'].max()), alpha = 0.3, stacked=True, edgecolor='k', linewidth=1)
# plt.legend(['True Asthma', 'False Healthy'])
# plt.title('Weight')
# plt.subplot(2,4,8)
# plt.hist(pd.concat([meta['Weight'][idx_tn], meta['Weight'][idx_fp]], axis=1), range=(meta['Weight'].min(),meta['Weight'].max()), alpha = 0.3, stacked=True, edgecolor='k', linewidth=1)
# plt.legend(['True Healthy', 'False Asthma'])
# plt.show()

# plt.figure(figsize=(6,4))
# plt.subplot(2,1,1)
# plt.hist(pd.concat([meta['Age'][idx_tp], meta['Age'][idx_fn]], axis=1), bins=list(range(0, 78, 6)), alpha = 0.3, stacked=True, edgecolor='k', linewidth=1)
# plt.legend(['True Asthma', 'False Healthy'])
# plt.subplot(2,1,2)
# plt.hist(pd.concat([meta['Age'][idx_tn], meta['Age'][idx_fp]], axis=1), bins=list(range(0, 78, 6)), alpha = 0.3, stacked=True, edgecolor='k', linewidth=1)
# plt.legend(['True Healthy', 'False Asthma'])
# plt.show()

# thre = np.zeros(1001)
# sens = np.zeros(1001)
# spec = np.zeros(1001)
# for i in range(0,1001):
#     idx_tp = (meta['Diagnosis']==1) & (meta['Prediction']>=i/1000)
#     tp = idx_tp.sum()
#     idx_fp = (meta['Diagnosis']==0) & (meta['Prediction']>=i/1000)
#     fp = idx_fp.sum()
#     idx_tn = (meta['Diagnosis']==0) & (meta['Prediction']<i/1000)
#     tn = idx_tn.sum()
#     idx_fn = (meta['Diagnosis']==1) & (meta['Prediction']<i/1000)
#     fn = idx_fn.sum()
#     thre[i] = i/1000
#     sens[i] = tp/(tp+fn)
#     spec[i] = tn/(fp+tn)
    
# plt.figure()
# plt.plot(thre, sens)
# plt.plot(thre, spec)
# plt.title('Sensitivity and Specificity vs. Threshold')
# plt.legend(['Sensitivity', 'Specificity'])
# plt.xlabel('Threshold')
# plt.show()

# plt.figure()
# # skplt.metrics.plot_roc_curve(meta['Diagnosis'], outputs_cls)
# plt.plot(1-spec, sens)
# plt.plot([0,1], [0,1], '--k')
# plt.title('ROC Curve')
# plt.xlabel('False Positive Rate\n(1-Specificity)')
# plt.ylabel('True Positive Rate\n(Sensitivity)')
# plt.show()