In [1]:
from data import LOOCV_datasets
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torchvision
from tqdm.notebook import tqdm
from scipy.stats import spearmanr, pearsonr

if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(device)
print(torch.version.cuda)

cuda
12.4


In [2]:
image_size = (420, 420)
datasets = LOOCV_datasets(data_folders=["data/train", "data/test"], size=image_size)

  0%|          | 0/880 [00:00<?, ?it/s]

  0%|          | 0/220 [00:00<?, ?it/s]

Loaded 55 patients


In [5]:
def one_loocv_cycle(train_dataset, test_dataset, batch_size=2, epochts=10):
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=2, shuffle=False)

    model = torchvision.models.densenet121(weights=torchvision.models.DenseNet121_Weights.DEFAULT)
    model.classifier = torch.nn.Linear(1024, 1)
    model.to(device)

    # Training
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    criterion = nn.L1Loss()

    for epoch in tqdm(range(epochts), total=epochts):
        # Training loop
        model.train()
        train_loss = 0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
        train_loss /= len(train_loader)
        # Validation loop
        model.eval()
        with torch.no_grad():
            val_loss = 0
            for images, labels in test_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                loss = criterion(outputs, labels)
                val_loss += loss.item()
        val_loss /= len(test_loader)
        print(f"Epoch {epoch + 1}/{epochts}, Training loss: {train_loss}")
        print(f"Epoch {epoch + 1}/{epochts}, Validation loss: {val_loss}")

    # Testing
    model.eval()
    test_targets = [i[1].item() for i in test_dataset]
    model_outputs = []
    for x, y in test_dataset:
        with torch.no_grad():
            pred = model(x.unsqueeze(0).to(device))
            model_outputs.append(pred[0][0].item())
    data_for_scc = list(zip(test_targets, model_outputs))

    return data_for_scc

In [7]:
scc_data = []
i = 0

for train_set, test_set in [i for i in datasets][0:1]:
    print(f"LOOCV {i+1}")
    res = one_loocv_cycle(train_set, test_set, batch_size=4, epochts=10)
    scc_data.extend(res)
    i += 1

LOOCV 1


  0%|          | 0/10 [00:00<?, ?it/s]

Epoch 1/10, Training loss: 20.754129551075124
Epoch 1/10, Validation loss: 53.58882179260254
Epoch 2/10, Training loss: 14.553373741441304
Epoch 2/10, Validation loss: 52.59878845214844
Epoch 3/10, Training loss: 11.680445620748731
Epoch 3/10, Validation loss: 51.924472427368165
Epoch 4/10, Training loss: 9.133814634217156
Epoch 4/10, Validation loss: 43.61979026794434
Epoch 5/10, Training loss: 6.876561914991449
Epoch 5/10, Validation loss: 57.16317481994629
Epoch 6/10, Training loss: 6.519555062828241
Epoch 6/10, Validation loss: 37.91698322296143
Epoch 7/10, Training loss: 5.604056947540354
Epoch 7/10, Validation loss: 33.466727066040036
Epoch 8/10, Training loss: 4.96964771328149
Epoch 8/10, Validation loss: 44.94395217895508
Epoch 9/10, Training loss: 4.340466026409908
Epoch 9/10, Validation loss: 45.53731575012207
Epoch 10/10, Training loss: 4.251912812723054
Epoch 10/10, Validation loss: 26.926817321777342


In [27]:
test_targets = [i[0] for i in scc_data]
model_outputs = [-i[1] for i in scc_data]

print("Spearman correlation:", spearmanr(test_targets, model_outputs))
print("Pearson correlation:", pearsonr(test_targets, model_outputs))

Spearman correlation: SignificanceResult(statistic=np.float64(-0.3805967709584713), pvalue=np.float64(2.7076316952826272e-08))
Pearson correlation: PearsonRResult(statistic=np.float64(-0.36004741692614517), pvalue=np.float64(1.6355037132651728e-07))


In [28]:
for i in range(50):
    print(f"{i+1}) True: {test_targets[i]:.2f} — Predicted: {model_outputs[i]:.2f}")

1) True: 75.00 — Predicted: -50.38
2) True: 75.00 — Predicted: -50.74
3) True: 75.00 — Predicted: -43.59
4) True: 75.00 — Predicted: -48.82
5) True: 75.00 — Predicted: -47.98
6) True: 75.00 — Predicted: -46.72
7) True: 75.00 — Predicted: -48.59
8) True: 75.00 — Predicted: -50.08
9) True: 75.00 — Predicted: -47.78
10) True: 75.00 — Predicted: -48.37
11) True: 75.00 — Predicted: -49.59
12) True: 75.00 — Predicted: -49.90
13) True: 75.00 — Predicted: -42.67
14) True: 75.00 — Predicted: -47.99
15) True: 75.00 — Predicted: -47.08
16) True: 75.00 — Predicted: -45.84
17) True: 75.00 — Predicted: -47.76
18) True: 75.00 — Predicted: -49.23
19) True: 75.00 — Predicted: -46.83
20) True: 75.00 — Predicted: -47.59
21) True: 70.00 — Predicted: -37.33
22) True: 70.00 — Predicted: -36.41
23) True: 70.00 — Predicted: -34.83
24) True: 70.00 — Predicted: -37.42
25) True: 70.00 — Predicted: -36.81
26) True: 70.00 — Predicted: -35.75
27) True: 70.00 — Predicted: -35.89
28) True: 70.00 — Predicted: -33.46
2