In [246]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from efficientnet_pytorch import EfficientNet
from PIL import Image
from sklearn.metrics import mean_squared_error
from scipy.stats import pearsonr
import numpy as np
from tqdm import tqdm
import pandas as pd
from torchvision import models
from torch.optim.lr_scheduler import _LRScheduler

from sklearn.model_selection import train_test_split
from sklearn.svm import SVR
from sklearn.metrics import mean_squared_error
from scipy.stats import pearsonr
from sklearn.model_selection import GridSearchCV
from sklearn.naive_bayes import MultinomialNB
from sklearn.neighbors import KNeighborsRegressor
from sklearn.linear_model import LogisticRegression, LinearRegression
from sklearn.svm import SVR
from xgboost import XGBRegressor
from sklearn.ensemble import AdaBoostRegressor, RandomForestRegressor
from sklearn.tree import DecisionTreeRegressor
from sklearn.ensemble import VotingRegressor, BaggingRegressor, GradientBoostingRegressor
from sklearn.ensemble import BaggingRegressor
from sklearn.metrics import mean_squared_error
from scipy.stats import pearsonr



import warnings
warnings.filterwarnings("ignore")

In [2]:
device = torch.device("mps")

In [3]:
DATA_DIR = '/Users/gufran/Desktop/PfsPredictionLungCancer/data/pathology_patches'

In [4]:
class MultiModalData(Dataset):
    def __init__(self, root_dir, df_path=None, num_patches=10, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.num_patches = num_patches
        self.slide_dirs = [d for d in os.listdir(root_dir) if os.path.isdir(os.path.join(root_dir, d))]
        self.df = pd.read_csv(df_path)
        
        self.df_id = "id"
        self.df.drop(["pfs", "dmp_pt_id" if "dmp_pt_id" in self.df.columns else "pdl1_image_id"],axis=1, inplace=True)

    def __len__(self):
        return len(self.slide_dirs)

    def __getitem__(self, idx):
        slide_dir = self.slide_dirs[idx]
        pfs_label = float(slide_dir.split('_')[-1])

        patch_paths = [os.path.join(self.root_dir, slide_dir, patch) for patch in os.listdir(os.path.join(self.root_dir, slide_dir))]
        patch_paths = patch_paths[:self.num_patches]  # Take the first 'num_patches' patches

        patches = [self.transform(Image.open(patch)) for patch in patch_paths]

        while len(patches) < self.num_patches:
            patches.append(patches[-1])
        
        path_clin_features = np.array(self.df[self.df[self.df_id] == int(slide_dir.split('_')[0])].drop(["id"], axis=1).iloc[0])
        return {'patches': torch.stack(patches), 'pfs_label': torch.tensor(pfs_label, dtype=torch.float32), "pcf": torch.tensor(path_clin_features, dtype=torch.float32)}

In [5]:
transform = transforms.Compose([
    # transforms.Resize((64, 64)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

In [23]:
batch_size = 8

train_dataset = MultiModalData(root_dir=os.path.join(DATA_DIR, "train"), df_path="/Users/gufran/Desktop/PfsPredictionLungCancer/data/train_path_clin_filtered.csv", transform=transform)
val_dataset = MultiModalData(root_dir=os.path.join(DATA_DIR, "val"), df_path="/Users/gufran/Desktop/PfsPredictionLungCancer/data/val_path_clin_filtered.csv", transform=transform)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

In [7]:
next(iter(train_loader))["patches"].shape

torch.Size([8, 10, 3, 128, 128])

In [133]:
class DataFusionModel(nn.Module):
    def __init__(self, currently_training=True):
        super(DataFusionModel, self).__init__()
        self.currently_training = currently_training
        
        self.resnet = models.resnet50(pretrained=True)
        in_features = self.resnet.fc.in_features
        self.resnet.fc = nn.Identity()
        
        self.mnet = models.mobilenet_v2(pretrained=True)
        mobilenet_in_features = self.mnet.classifier[1].in_features
        self.mnet.classifier = nn.Identity()
        
        self.fc = nn.Linear(in_features + mobilenet_in_features + 15, 256)
        
        self.regressor = nn.Sequential(
            nn.ReLU(),
            nn.Linear(256, 1)
        )
        
    def forward(self, img_batch, pcf_batch):
        image_features_resnet = []
        for patches in img_batch:
            patch_outputs = [self.resnet(patch.unsqueeze(0)) for patch in patches]
            aggregated_output = torch.stack(patch_outputs).mean(dim=0)
            aggregated_output = aggregated_output.squeeze()
            image_features_resnet.append(aggregated_output)
            
        image_features_mnet = []
        for patches in img_batch:
            patch_outputs = [self.mnet(patch.unsqueeze(0)) for patch in patches]
            aggregated_output = torch.stack(patch_outputs).mean(dim=0)
            aggregated_output = aggregated_output.squeeze()
            image_features_mnet.append(aggregated_output)
        
        multimodal_features = []
        for i in range(len(img_batch)):
            multimodal_features.append(torch.cat((image_features_resnet[i], image_features_mnet[i], pcf_batch[i]), dim=0))
        
        multimodal_features = torch.stack(multimodal_features)

        final_features = self.fc(multimodal_features)
        if self.currently_training: final_features = self.regressor(final_features)
        # else: final_features = F.relu(final_features)
        
        return final_features

In [26]:
model = DataFusionModel().to(device)

criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.00001)

num_epochs = 100

In [27]:
class LinearDecayLR(_LRScheduler):
    def __init__(self, optimizer, n_epoch, start_decay, last_epoch=-1):
        self.start_decay=start_decay
        self.n_epoch=n_epoch
        super(LinearDecayLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        last_epoch = self.last_epoch
        n_epoch=self.n_epoch
        b_lr=self.base_lrs[0]
        start_decay=self.start_decay
        if last_epoch>start_decay:
            lr=b_lr-b_lr/(n_epoch-start_decay)*(last_epoch-start_decay)
        else:
            lr=b_lr
        return [lr]
    
lr_scheduler=LinearDecayLR(optimizer, num_epochs, 10)

In [11]:
for epoch in range(num_epochs):
    model.train()

    y_true_train = []
    y_pred_train = []
    
    best_mse = 9999999
    for batch in tqdm(train_loader):
        patches = batch['patches'].to(device)
        pfs_labels = batch['pfs_label'].to(device)
        pcf_data = batch['pcf'].to(device)

        optimizer.zero_grad()
        outputs = model(patches, pcf_data)
        loss = criterion(outputs, pfs_labels)
        loss.backward()
        optimizer.step()

        y_true_train.extend(pfs_labels.cpu().numpy())
        
        for pred_val in outputs.detach().cpu().numpy():
            y_pred_train.append(float(pred_val))
      
    lr_scheduler.step()
    
    mse_train = mean_squared_error(y_true_train, y_pred_train)
    pearson_corr_train, _ = pearsonr(y_true_train, y_pred_train)
    print(f"Epoch [{epoch + 1}/{num_epochs}]")
    print(f"MSE: {mse_train} - Pearson Correlation: {pearson_corr_train}")
    print(f"Learning Rate: {lr_scheduler.get_lr()}", end="\n\n")
    
    if mse_train<best_mse:
        best_mse = mse_train
        torch.save(model.state_dict(), f"/Users/gufran/Desktop/PfsPredictionLungCancer/notebooks/checkpoint/multimodal.pt")

100%|███████████████████████████████████████████| 12/12 [01:04<00:00,  5.38s/it]


Epoch [1/100]
MSE: 174.23263907077663 - Pearson Correlation: 0.1250376785978919
Learning Rate: [5e-05]



100%|███████████████████████████████████████████| 12/12 [01:02<00:00,  5.17s/it]


Epoch [2/100]
MSE: 103.6848539237711 - Pearson Correlation: 0.0798068383199664
Learning Rate: [5e-05]



100%|███████████████████████████████████████████| 12/12 [01:04<00:00,  5.37s/it]


Epoch [3/100]
MSE: 89.9541158066473 - Pearson Correlation: 0.1639376375182244
Learning Rate: [5e-05]



100%|███████████████████████████████████████████| 12/12 [01:02<00:00,  5.18s/it]


Epoch [4/100]
MSE: 79.39935971920085 - Pearson Correlation: 0.16562692496645906
Learning Rate: [5e-05]



100%|███████████████████████████████████████████| 12/12 [01:04<00:00,  5.36s/it]


Epoch [5/100]
MSE: 77.91795566243046 - Pearson Correlation: 0.14752934724932298
Learning Rate: [5e-05]



100%|███████████████████████████████████████████| 12/12 [01:03<00:00,  5.33s/it]


Epoch [6/100]
MSE: 77.89744576479038 - Pearson Correlation: 0.14320411514438816
Learning Rate: [5e-05]



100%|███████████████████████████████████████████| 12/12 [01:04<00:00,  5.34s/it]


Epoch [7/100]
MSE: 78.26358229346023 - Pearson Correlation: 0.1355395153512915
Learning Rate: [5e-05]



100%|███████████████████████████████████████████| 12/12 [01:02<00:00,  5.18s/it]


Epoch [8/100]
MSE: 79.07952785808143 - Pearson Correlation: 0.08722082493572068
Learning Rate: [5e-05]



100%|███████████████████████████████████████████| 12/12 [01:04<00:00,  5.36s/it]


Epoch [9/100]
MSE: 79.4698915066396 - Pearson Correlation: 0.04642090294791669
Learning Rate: [5e-05]



100%|███████████████████████████████████████████| 12/12 [01:02<00:00,  5.20s/it]


Epoch [10/100]
MSE: 78.1094272561746 - Pearson Correlation: 0.1354197725325443
Learning Rate: [5e-05]



100%|███████████████████████████████████████████| 12/12 [01:04<00:00,  5.36s/it]


Epoch [11/100]
MSE: 78.51849482283001 - Pearson Correlation: 0.1066488460224826
Learning Rate: [4.9444444444444446e-05]



100%|███████████████████████████████████████████| 12/12 [01:03<00:00,  5.28s/it]


Epoch [12/100]
MSE: 77.53932646939842 - Pearson Correlation: 0.17573582759787523
Learning Rate: [4.888888888888889e-05]



100%|███████████████████████████████████████████| 12/12 [01:04<00:00,  5.41s/it]


Epoch [13/100]
MSE: 77.53215240088001 - Pearson Correlation: 0.1860637546877938
Learning Rate: [4.8333333333333334e-05]



100%|███████████████████████████████████████████| 12/12 [01:02<00:00,  5.19s/it]


Epoch [14/100]
MSE: 77.97342552349096 - Pearson Correlation: 0.1458720007006441
Learning Rate: [4.777777777777778e-05]



100%|███████████████████████████████████████████| 12/12 [01:05<00:00,  5.45s/it]


Epoch [15/100]
MSE: 78.17561521435609 - Pearson Correlation: 0.14371884259001502
Learning Rate: [4.722222222222222e-05]



100%|███████████████████████████████████████████| 12/12 [01:03<00:00,  5.27s/it]


Epoch [16/100]
MSE: 77.13598696163866 - Pearson Correlation: 0.21451442931849718
Learning Rate: [4.666666666666667e-05]



100%|███████████████████████████████████████████| 12/12 [01:03<00:00,  5.32s/it]


Epoch [17/100]
MSE: 77.92872646133436 - Pearson Correlation: 0.1522620711960016
Learning Rate: [4.6111111111111115e-05]



100%|███████████████████████████████████████████| 12/12 [01:02<00:00,  5.25s/it]


Epoch [18/100]
MSE: 78.15628239881282 - Pearson Correlation: 0.14680422870391469
Learning Rate: [4.555555555555556e-05]



100%|███████████████████████████████████████████| 12/12 [01:04<00:00,  5.35s/it]


Epoch [19/100]
MSE: 77.49518876261128 - Pearson Correlation: 0.19522217409185114
Learning Rate: [4.5e-05]



100%|███████████████████████████████████████████| 12/12 [01:02<00:00,  5.24s/it]


Epoch [20/100]
MSE: 76.46055211195079 - Pearson Correlation: 0.21503611610731155
Learning Rate: [4.4444444444444447e-05]



100%|███████████████████████████████████████████| 12/12 [01:04<00:00,  5.37s/it]


Epoch [21/100]
MSE: 79.15759257050274 - Pearson Correlation: 0.14345755744629826
Learning Rate: [4.388888888888889e-05]



100%|███████████████████████████████████████████| 12/12 [01:02<00:00,  5.22s/it]


Epoch [22/100]
MSE: 77.65494074712913 - Pearson Correlation: 0.16672460567969044
Learning Rate: [4.3333333333333334e-05]



100%|███████████████████████████████████████████| 12/12 [01:04<00:00,  5.35s/it]


Epoch [23/100]
MSE: 76.81701972008815 - Pearson Correlation: 0.22052895459243105
Learning Rate: [4.277777777777778e-05]



100%|███████████████████████████████████████████| 12/12 [01:02<00:00,  5.19s/it]


Epoch [24/100]
MSE: 76.6095801444461 - Pearson Correlation: 0.32343083279757084
Learning Rate: [4.222222222222222e-05]



100%|███████████████████████████████████████████| 12/12 [01:04<00:00,  5.38s/it]


Epoch [25/100]
MSE: 76.83207740340923 - Pearson Correlation: 0.27895220264836396
Learning Rate: [4.166666666666667e-05]



100%|███████████████████████████████████████████| 12/12 [01:04<00:00,  5.40s/it]


Epoch [26/100]
MSE: 76.84236513169473 - Pearson Correlation: 0.2795869754730711
Learning Rate: [4.111111111111111e-05]



100%|███████████████████████████████████████████| 12/12 [01:05<00:00,  5.44s/it]


Epoch [27/100]
MSE: 75.82365861598258 - Pearson Correlation: 0.308522049871519
Learning Rate: [4.055555555555556e-05]



100%|███████████████████████████████████████████| 12/12 [01:02<00:00,  5.18s/it]


Epoch [28/100]
MSE: 75.94458371009507 - Pearson Correlation: 0.35384927852230147
Learning Rate: [4e-05]



100%|███████████████████████████████████████████| 12/12 [01:04<00:00,  5.36s/it]


Epoch [29/100]
MSE: 77.92658027036973 - Pearson Correlation: 0.24465027489944316
Learning Rate: [3.944444444444445e-05]



100%|███████████████████████████████████████████| 12/12 [01:03<00:00,  5.29s/it]


Epoch [30/100]
MSE: 75.54914318566627 - Pearson Correlation: 0.33358446774514766
Learning Rate: [3.888888888888889e-05]



100%|███████████████████████████████████████████| 12/12 [01:05<00:00,  5.42s/it]


Epoch [31/100]
MSE: 75.89377759796118 - Pearson Correlation: 0.2911629083595744
Learning Rate: [3.8333333333333334e-05]



100%|███████████████████████████████████████████| 12/12 [01:02<00:00,  5.19s/it]


Epoch [32/100]
MSE: 76.8209773757571 - Pearson Correlation: 0.30533573231393285
Learning Rate: [3.777777777777778e-05]



100%|███████████████████████████████████████████| 12/12 [01:03<00:00,  5.31s/it]


Epoch [33/100]
MSE: 76.9529038484775 - Pearson Correlation: 0.27102394385755824
Learning Rate: [3.722222222222222e-05]



100%|███████████████████████████████████████████| 12/12 [01:02<00:00,  5.24s/it]


Epoch [34/100]
MSE: 76.19350369308005 - Pearson Correlation: 0.32709288063651387
Learning Rate: [3.6666666666666666e-05]



100%|███████████████████████████████████████████| 12/12 [01:07<00:00,  5.60s/it]


Epoch [35/100]
MSE: 76.63407889949977 - Pearson Correlation: 0.3291046851880026
Learning Rate: [3.611111111111111e-05]



100%|███████████████████████████████████████████| 12/12 [01:03<00:00,  5.31s/it]


Epoch [36/100]
MSE: 76.92740058294497 - Pearson Correlation: 0.23943656244770042
Learning Rate: [3.555555555555556e-05]



100%|███████████████████████████████████████████| 12/12 [01:06<00:00,  5.56s/it]


Epoch [37/100]
MSE: 74.87494988419637 - Pearson Correlation: 0.43949267364124056
Learning Rate: [3.5e-05]



100%|███████████████████████████████████████████| 12/12 [01:03<00:00,  5.29s/it]


Epoch [38/100]
MSE: 78.20121398981152 - Pearson Correlation: 0.18397785140667056
Learning Rate: [3.444444444444445e-05]



100%|███████████████████████████████████████████| 12/12 [01:04<00:00,  5.38s/it]


Epoch [39/100]
MSE: 75.54766273589438 - Pearson Correlation: 0.36816895374110326
Learning Rate: [3.3888888888888884e-05]



100%|███████████████████████████████████████████| 12/12 [01:01<00:00,  5.16s/it]


Epoch [40/100]
MSE: 75.92765837387395 - Pearson Correlation: 0.35587983361232717
Learning Rate: [3.3333333333333335e-05]



100%|███████████████████████████████████████████| 12/12 [01:04<00:00,  5.34s/it]


Epoch [41/100]
MSE: 75.44570849811244 - Pearson Correlation: 0.4172387870276071
Learning Rate: [3.277777777777778e-05]



100%|███████████████████████████████████████████| 12/12 [01:02<00:00,  5.23s/it]


Epoch [42/100]
MSE: 76.08436108261783 - Pearson Correlation: 0.32842020300915414
Learning Rate: [3.222222222222222e-05]



100%|███████████████████████████████████████████| 12/12 [01:04<00:00,  5.35s/it]


Epoch [43/100]
MSE: 75.05115720372928 - Pearson Correlation: 0.35195182486689663
Learning Rate: [3.1666666666666666e-05]



100%|███████████████████████████████████████████| 12/12 [01:02<00:00,  5.20s/it]


Epoch [44/100]
MSE: 75.98777614388294 - Pearson Correlation: 0.4286864830621773
Learning Rate: [3.111111111111111e-05]



100%|███████████████████████████████████████████| 12/12 [01:05<00:00,  5.46s/it]


Epoch [45/100]
MSE: 75.22617962766545 - Pearson Correlation: 0.4081106317908021
Learning Rate: [3.055555555555556e-05]



100%|███████████████████████████████████████████| 12/12 [01:02<00:00,  5.24s/it]


Epoch [46/100]
MSE: 74.89322608585282 - Pearson Correlation: 0.4698422081517724
Learning Rate: [3e-05]



100%|███████████████████████████████████████████| 12/12 [01:04<00:00,  5.39s/it]


Epoch [47/100]
MSE: 75.19271246520383 - Pearson Correlation: 0.4481835793369756
Learning Rate: [2.9444444444444445e-05]



100%|███████████████████████████████████████████| 12/12 [01:02<00:00,  5.20s/it]


Epoch [48/100]
MSE: 74.13761760386595 - Pearson Correlation: 0.4731940389163902
Learning Rate: [2.8888888888888888e-05]



100%|███████████████████████████████████████████| 12/12 [01:05<00:00,  5.42s/it]


Epoch [49/100]
MSE: 73.69995719239422 - Pearson Correlation: 0.45602064999860425
Learning Rate: [2.8333333333333332e-05]



100%|███████████████████████████████████████████| 12/12 [01:02<00:00,  5.18s/it]


Epoch [50/100]
MSE: 75.04995696518407 - Pearson Correlation: 0.4085274532460641
Learning Rate: [2.777777777777778e-05]



100%|███████████████████████████████████████████| 12/12 [01:06<00:00,  5.58s/it]


Epoch [51/100]
MSE: 74.86485704683162 - Pearson Correlation: 0.5054611141502366
Learning Rate: [2.7222222222222223e-05]



100%|███████████████████████████████████████████| 12/12 [01:09<00:00,  5.82s/it]


Epoch [52/100]
MSE: 74.47131259630481 - Pearson Correlation: 0.5330366371915976
Learning Rate: [2.6666666666666667e-05]



100%|███████████████████████████████████████████| 12/12 [01:11<00:00,  5.93s/it]


Epoch [53/100]
MSE: 74.54308194311551 - Pearson Correlation: 0.531908066624032
Learning Rate: [2.611111111111111e-05]



100%|███████████████████████████████████████████| 12/12 [01:09<00:00,  5.79s/it]


Epoch [54/100]
MSE: 74.67507953651064 - Pearson Correlation: 0.5058367022007613
Learning Rate: [2.5555555555555554e-05]



100%|███████████████████████████████████████████| 12/12 [01:11<00:00,  5.93s/it]


Epoch [55/100]
MSE: 75.03502147329915 - Pearson Correlation: 0.44566045169774166
Learning Rate: [2.4999999999999998e-05]



100%|███████████████████████████████████████████| 12/12 [01:09<00:00,  5.76s/it]


Epoch [56/100]
MSE: 77.01137130798078 - Pearson Correlation: 0.22607645382113556
Learning Rate: [2.4444444444444445e-05]



100%|███████████████████████████████████████████| 12/12 [01:11<00:00,  5.93s/it]


Epoch [57/100]
MSE: 73.67654550818592 - Pearson Correlation: 0.5786573645726906
Learning Rate: [2.388888888888889e-05]



100%|███████████████████████████████████████████| 12/12 [01:06<00:00,  5.53s/it]


Epoch [58/100]
MSE: 73.83004153235377 - Pearson Correlation: 0.5998905902898753
Learning Rate: [2.3333333333333332e-05]



100%|███████████████████████████████████████████| 12/12 [01:04<00:00,  5.37s/it]


Epoch [59/100]
MSE: 73.10392773630774 - Pearson Correlation: 0.603703379040097
Learning Rate: [2.2777777777777776e-05]



100%|███████████████████████████████████████████| 12/12 [01:02<00:00,  5.24s/it]


Epoch [60/100]
MSE: 72.76904022827378 - Pearson Correlation: 0.5882490363080606
Learning Rate: [2.222222222222222e-05]



100%|███████████████████████████████████████████| 12/12 [01:04<00:00,  5.35s/it]


Epoch [61/100]
MSE: 73.08164310267729 - Pearson Correlation: 0.6738229336031474
Learning Rate: [2.1666666666666667e-05]



100%|███████████████████████████████████████████| 12/12 [01:02<00:00,  5.21s/it]


Epoch [62/100]
MSE: 74.04661990938159 - Pearson Correlation: 0.651744141944015
Learning Rate: [2.111111111111111e-05]



100%|███████████████████████████████████████████| 12/12 [01:04<00:00,  5.38s/it]


Epoch [63/100]
MSE: 74.10116917228402 - Pearson Correlation: 0.6536858830097851
Learning Rate: [2.0555555555555555e-05]



100%|███████████████████████████████████████████| 12/12 [01:02<00:00,  5.25s/it]


Epoch [64/100]
MSE: 73.90082350788332 - Pearson Correlation: 0.7053872688429368
Learning Rate: [1.9999999999999998e-05]



100%|███████████████████████████████████████████| 12/12 [01:03<00:00,  5.32s/it]


Epoch [65/100]
MSE: 73.24519692390176 - Pearson Correlation: 0.6988582548459459
Learning Rate: [1.9444444444444442e-05]



100%|███████████████████████████████████████████| 12/12 [01:02<00:00,  5.22s/it]


Epoch [66/100]
MSE: 73.08699547638773 - Pearson Correlation: 0.671015951581321
Learning Rate: [1.8888888888888886e-05]



100%|███████████████████████████████████████████| 12/12 [01:03<00:00,  5.32s/it]


Epoch [67/100]
MSE: 72.8329658293521 - Pearson Correlation: 0.6826247180629346
Learning Rate: [1.833333333333333e-05]



100%|███████████████████████████████████████████| 12/12 [01:02<00:00,  5.21s/it]


Epoch [68/100]
MSE: 73.43711520472812 - Pearson Correlation: 0.7008750835684969
Learning Rate: [1.7777777777777773e-05]



100%|███████████████████████████████████████████| 12/12 [01:03<00:00,  5.32s/it]


Epoch [69/100]
MSE: 73.03404598113688 - Pearson Correlation: 0.7314333944137822
Learning Rate: [1.7222222222222224e-05]



100%|███████████████████████████████████████████| 12/12 [01:02<00:00,  5.24s/it]


Epoch [70/100]
MSE: 73.14925267475684 - Pearson Correlation: 0.6926621562855637
Learning Rate: [1.6666666666666667e-05]



100%|███████████████████████████████████████████| 12/12 [01:03<00:00,  5.30s/it]


Epoch [71/100]
MSE: 72.90716094607897 - Pearson Correlation: 0.7138298988530709
Learning Rate: [1.611111111111111e-05]



100%|███████████████████████████████████████████| 12/12 [01:02<00:00,  5.25s/it]


Epoch [72/100]
MSE: 73.28138288819474 - Pearson Correlation: 0.6958395658578677
Learning Rate: [1.5555555555555555e-05]



100%|███████████████████████████████████████████| 12/12 [01:03<00:00,  5.27s/it]


Epoch [73/100]
MSE: 72.65938877137329 - Pearson Correlation: 0.7379015294480301
Learning Rate: [1.4999999999999999e-05]



100%|███████████████████████████████████████████| 12/12 [01:02<00:00,  5.25s/it]


Epoch [74/100]
MSE: 73.36303335155543 - Pearson Correlation: 0.7111660121591132
Learning Rate: [1.4444444444444442e-05]



100%|███████████████████████████████████████████| 12/12 [01:03<00:00,  5.31s/it]


Epoch [75/100]
MSE: 73.0793942614066 - Pearson Correlation: 0.7158901061797323
Learning Rate: [1.3888888888888886e-05]



100%|███████████████████████████████████████████| 12/12 [01:03<00:00,  5.26s/it]


Epoch [76/100]
MSE: 72.84926629810288 - Pearson Correlation: 0.7256047511910007
Learning Rate: [1.333333333333333e-05]



100%|███████████████████████████████████████████| 12/12 [01:04<00:00,  5.40s/it]


Epoch [77/100]
MSE: 72.8776174636464 - Pearson Correlation: 0.6671497738823072
Learning Rate: [1.2777777777777774e-05]



100%|███████████████████████████████████████████| 12/12 [01:03<00:00,  5.31s/it]


Epoch [78/100]
MSE: 72.48884771533314 - Pearson Correlation: 0.7163734839877434
Learning Rate: [1.2222222222222217e-05]



100%|███████████████████████████████████████████| 12/12 [01:02<00:00,  5.23s/it]


Epoch [79/100]
MSE: 74.05442265280738 - Pearson Correlation: 0.5561801883064533
Learning Rate: [1.1666666666666661e-05]



100%|███████████████████████████████████████████| 12/12 [01:04<00:00,  5.36s/it]


Epoch [80/100]
MSE: 72.87287574895393 - Pearson Correlation: 0.7216256143305664
Learning Rate: [1.1111111111111112e-05]



100%|███████████████████████████████████████████| 12/12 [01:04<00:00,  5.36s/it]


Epoch [81/100]
MSE: 72.26193249339892 - Pearson Correlation: 0.7514610877610912
Learning Rate: [1.0555555555555555e-05]



100%|███████████████████████████████████████████| 12/12 [01:06<00:00,  5.56s/it]


Epoch [82/100]
MSE: 72.00038586733159 - Pearson Correlation: 0.7835098814994926
Learning Rate: [9.999999999999999e-06]



100%|███████████████████████████████████████████| 12/12 [01:05<00:00,  5.42s/it]


Epoch [83/100]
MSE: 72.20820457097516 - Pearson Correlation: 0.7830241022902007
Learning Rate: [9.444444444444443e-06]



100%|███████████████████████████████████████████| 12/12 [01:04<00:00,  5.35s/it]


Epoch [84/100]
MSE: 72.1698918278113 - Pearson Correlation: 0.8074730731006808
Learning Rate: [8.888888888888887e-06]



100%|███████████████████████████████████████████| 12/12 [01:04<00:00,  5.34s/it]


Epoch [85/100]
MSE: 71.94119989690724 - Pearson Correlation: 0.7871696590300575
Learning Rate: [8.33333333333333e-06]



100%|███████████████████████████████████████████| 12/12 [01:04<00:00,  5.38s/it]


Epoch [86/100]
MSE: 72.23658500241817 - Pearson Correlation: 0.8007752560582847
Learning Rate: [7.777777777777774e-06]



100%|███████████████████████████████████████████| 12/12 [01:03<00:00,  5.28s/it]


Epoch [87/100]
MSE: 72.01007183753137 - Pearson Correlation: 0.8064283860061023
Learning Rate: [7.222222222222218e-06]



100%|███████████████████████████████████████████| 12/12 [01:04<00:00,  5.37s/it]


Epoch [88/100]
MSE: 71.91273126986441 - Pearson Correlation: 0.8027991729471121
Learning Rate: [6.6666666666666616e-06]



100%|███████████████████████████████████████████| 12/12 [01:05<00:00,  5.47s/it]


Epoch [89/100]
MSE: 71.8665059141885 - Pearson Correlation: 0.8043151631741561
Learning Rate: [6.111111111111105e-06]



100%|███████████████████████████████████████████| 12/12 [01:07<00:00,  5.60s/it]


Epoch [90/100]
MSE: 72.01275586380777 - Pearson Correlation: 0.8061694649087342
Learning Rate: [5.555555555555556e-06]



100%|███████████████████████████████████████████| 12/12 [01:08<00:00,  5.70s/it]


Epoch [91/100]
MSE: 71.95230046024471 - Pearson Correlation: 0.8084453411716838
Learning Rate: [4.9999999999999996e-06]



100%|███████████████████████████████████████████| 12/12 [01:07<00:00,  5.62s/it]


Epoch [92/100]
MSE: 71.7467348341079 - Pearson Correlation: 0.8086958712826973
Learning Rate: [4.444444444444443e-06]



100%|███████████████████████████████████████████| 12/12 [01:08<00:00,  5.73s/it]


Epoch [93/100]
MSE: 71.62751838798953 - Pearson Correlation: 0.8222832069529389
Learning Rate: [3.888888888888887e-06]



100%|███████████████████████████████████████████| 12/12 [01:07<00:00,  5.63s/it]


Epoch [94/100]
MSE: 71.56686353731057 - Pearson Correlation: 0.8238828903961928
Learning Rate: [3.3333333333333308e-06]



100%|███████████████████████████████████████████| 12/12 [01:09<00:00,  5.82s/it]


Epoch [95/100]
MSE: 71.54359299558172 - Pearson Correlation: 0.8305595293637439
Learning Rate: [2.7777777777777745e-06]



100%|███████████████████████████████████████████| 12/12 [01:11<00:00,  5.98s/it]


Epoch [96/100]
MSE: 71.61176883473658 - Pearson Correlation: 0.8321286794952172
Learning Rate: [2.2222222222222183e-06]



100%|███████████████████████████████████████████| 12/12 [01:13<00:00,  6.16s/it]


Epoch [97/100]
MSE: 71.69255581548295 - Pearson Correlation: 0.8291619895336405
Learning Rate: [1.666666666666662e-06]



100%|███████████████████████████████████████████| 12/12 [01:16<00:00,  6.40s/it]


Epoch [98/100]
MSE: 71.53693732162591 - Pearson Correlation: 0.833741445759859
Learning Rate: [1.1111111111111057e-06]



100%|███████████████████████████████████████████| 12/12 [01:10<00:00,  5.87s/it]


Epoch [99/100]
MSE: 71.51733206104463 - Pearson Correlation: 0.8353437301656723
Learning Rate: [5.555555555555495e-07]



100%|███████████████████████████████████████████| 12/12 [01:04<00:00,  5.38s/it]


Epoch [100/100]
MSE: 71.49900769051669 - Pearson Correlation: 0.8349957451865692
Learning Rate: [-6.776263578034403e-21]



In [30]:
model = DataFusionModel().to(device)
model.load_state_dict(torch.load("/Users/gufran/Desktop/PfsPredictionLungCancer/notebooks/checkpoint/multimodal.pt"))

criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.00001)

num_epochs = 100

lr_scheduler=LinearDecayLR(optimizer, num_epochs, 10)

for epoch in range(num_epochs):
    model.train()

    y_true_train = []
    y_pred_train = []
    
    best_mse = 71.49900769051669
    for batch in tqdm(train_loader):
        patches = batch['patches'].to(device)
        pfs_labels = batch['pfs_label'].to(device)
        pcf_data = batch['pcf'].to(device)

        optimizer.zero_grad()
        outputs = model(patches, pcf_data)
        loss = criterion(outputs, pfs_labels)
        loss.backward()
        optimizer.step()

        y_true_train.extend(pfs_labels.cpu().numpy())
        
        for pred_val in outputs.detach().cpu().numpy():
            y_pred_train.append(float(pred_val))
      
    lr_scheduler.step()
    
    mse_train = mean_squared_error(y_true_train, y_pred_train)
    pearson_corr_train, _ = pearsonr(y_true_train, y_pred_train)
    print(f"Epoch [{epoch + 1}/{num_epochs}]")
    print(f"MSE: {mse_train} - Pearson Correlation: {pearson_corr_train}")
    print(f"Learning Rate: {lr_scheduler.get_lr()}", end="\n\n")
    
    if mse_train<best_mse:
        best_mse = mse_train
        torch.save(model.state_dict(), f"/Users/gufran/Desktop/PfsPredictionLungCancer/notebooks/checkpoint/multimodal_post.pt")

100%|███████████████████████████████████████████| 12/12 [01:05<00:00,  5.43s/it]


Epoch [1/100]
MSE: 72.04992576772663 - Pearson Correlation: 0.6396412682269137
Learning Rate: [1e-05]



100%|███████████████████████████████████████████| 12/12 [01:04<00:00,  5.40s/it]


Epoch [2/100]
MSE: 71.40840023543183 - Pearson Correlation: 0.8058354704341856
Learning Rate: [1e-05]



100%|███████████████████████████████████████████| 12/12 [01:07<00:00,  5.60s/it]


Epoch [3/100]
MSE: 71.50424801580701 - Pearson Correlation: 0.8194996004888343
Learning Rate: [1e-05]



100%|███████████████████████████████████████████| 12/12 [01:03<00:00,  5.30s/it]


Epoch [4/100]
MSE: 71.47579552613811 - Pearson Correlation: 0.8005195635927533
Learning Rate: [1e-05]



100%|███████████████████████████████████████████| 12/12 [01:06<00:00,  5.57s/it]


Epoch [5/100]
MSE: 71.38600025356047 - Pearson Correlation: 0.8222239509346302
Learning Rate: [1e-05]



100%|███████████████████████████████████████████| 12/12 [01:03<00:00,  5.27s/it]


Epoch [6/100]
MSE: 71.990726506076 - Pearson Correlation: 0.8179873405113912
Learning Rate: [1e-05]



100%|███████████████████████████████████████████| 12/12 [01:07<00:00,  5.59s/it]


Epoch [7/100]
MSE: 71.78531385160542 - Pearson Correlation: 0.8140082320562266
Learning Rate: [1e-05]



100%|███████████████████████████████████████████| 12/12 [01:03<00:00,  5.25s/it]


Epoch [8/100]
MSE: 71.15970103690161 - Pearson Correlation: 0.7479374003575195
Learning Rate: [1e-05]



100%|███████████████████████████████████████████| 12/12 [01:06<00:00,  5.52s/it]


Epoch [9/100]
MSE: 71.27849685668524 - Pearson Correlation: 0.7622184560524998
Learning Rate: [1e-05]



100%|███████████████████████████████████████████| 12/12 [01:07<00:00,  5.61s/it]


Epoch [10/100]
MSE: 71.41537656342375 - Pearson Correlation: 0.7697376925886051
Learning Rate: [1e-05]



100%|███████████████████████████████████████████| 12/12 [01:05<00:00,  5.49s/it]


Epoch [11/100]
MSE: 71.23168986739442 - Pearson Correlation: 0.8069832571610904
Learning Rate: [9.888888888888889e-06]



100%|███████████████████████████████████████████| 12/12 [01:06<00:00,  5.57s/it]


Epoch [12/100]
MSE: 71.22362140601567 - Pearson Correlation: 0.7338641974352145
Learning Rate: [9.777777777777779e-06]



100%|███████████████████████████████████████████| 12/12 [01:04<00:00,  5.39s/it]


Epoch [13/100]
MSE: 71.13423256626261 - Pearson Correlation: 0.7354017037204627
Learning Rate: [9.666666666666667e-06]



100%|███████████████████████████████████████████| 12/12 [01:07<00:00,  5.59s/it]


Epoch [14/100]
MSE: 71.3846499152818 - Pearson Correlation: 0.8323995239738357
Learning Rate: [9.555555555555556e-06]



100%|███████████████████████████████████████████| 12/12 [01:05<00:00,  5.47s/it]


Epoch [15/100]
MSE: 71.06439963219219 - Pearson Correlation: 0.8427193771350274
Learning Rate: [9.444444444444445e-06]



100%|███████████████████████████████████████████| 12/12 [01:08<00:00,  5.71s/it]


Epoch [16/100]
MSE: 71.5955381163756 - Pearson Correlation: 0.8309032097603677
Learning Rate: [9.333333333333334e-06]



100%|███████████████████████████████████████████| 12/12 [01:05<00:00,  5.45s/it]


Epoch [17/100]
MSE: 71.6279570242259 - Pearson Correlation: 0.8381520146959214
Learning Rate: [9.222222222222222e-06]



100%|███████████████████████████████████████████| 12/12 [01:08<00:00,  5.73s/it]


Epoch [18/100]
MSE: 70.91857035698854 - Pearson Correlation: 0.8582301007207753
Learning Rate: [9.111111111111112e-06]



100%|███████████████████████████████████████████| 12/12 [01:03<00:00,  5.32s/it]


Epoch [19/100]
MSE: 70.81551462465202 - Pearson Correlation: 0.8235113380895607
Learning Rate: [9e-06]



100%|███████████████████████████████████████████| 12/12 [01:05<00:00,  5.48s/it]


Epoch [20/100]
MSE: 70.72130269306791 - Pearson Correlation: 0.7854649091387901
Learning Rate: [8.88888888888889e-06]



100%|███████████████████████████████████████████| 12/12 [01:04<00:00,  5.39s/it]


Epoch [21/100]
MSE: 70.8143895003784 - Pearson Correlation: 0.8028127694547411
Learning Rate: [8.777777777777778e-06]



100%|███████████████████████████████████████████| 12/12 [01:05<00:00,  5.46s/it]


Epoch [22/100]
MSE: 71.13243667901563 - Pearson Correlation: 0.85032554816621
Learning Rate: [8.666666666666668e-06]



100%|███████████████████████████████████████████| 12/12 [01:06<00:00,  5.54s/it]


Epoch [23/100]
MSE: 71.27027758792173 - Pearson Correlation: 0.8587675983306532
Learning Rate: [8.555555555555556e-06]



100%|███████████████████████████████████████████| 12/12 [01:04<00:00,  5.34s/it]


Epoch [24/100]
MSE: 70.97246850729515 - Pearson Correlation: 0.8552370362892824
Learning Rate: [8.444444444444446e-06]



100%|███████████████████████████████████████████| 12/12 [01:05<00:00,  5.48s/it]


Epoch [25/100]
MSE: 70.89352418988477 - Pearson Correlation: 0.8260649862919666
Learning Rate: [8.333333333333334e-06]



100%|███████████████████████████████████████████| 12/12 [01:04<00:00,  5.34s/it]


Epoch [26/100]
MSE: 70.87856864693741 - Pearson Correlation: 0.8395218880026926
Learning Rate: [8.222222222222223e-06]



100%|███████████████████████████████████████████| 12/12 [01:06<00:00,  5.56s/it]


Epoch [27/100]
MSE: 71.01960356107148 - Pearson Correlation: 0.8477614662589092
Learning Rate: [8.111111111111112e-06]



100%|███████████████████████████████████████████| 12/12 [01:03<00:00,  5.31s/it]


Epoch [28/100]
MSE: 71.07020422322748 - Pearson Correlation: 0.8342613316659031
Learning Rate: [8.000000000000001e-06]



100%|███████████████████████████████████████████| 12/12 [01:07<00:00,  5.61s/it]


Epoch [29/100]
MSE: 71.41267927010684 - Pearson Correlation: 0.82836105836822
Learning Rate: [7.88888888888889e-06]



100%|███████████████████████████████████████████| 12/12 [01:03<00:00,  5.32s/it]


Epoch [30/100]
MSE: 71.42501462245168 - Pearson Correlation: 0.8256569417520134
Learning Rate: [7.777777777777777e-06]



100%|███████████████████████████████████████████| 12/12 [01:05<00:00,  5.46s/it]


Epoch [31/100]
MSE: 70.65997663006051 - Pearson Correlation: 0.8353722423041772
Learning Rate: [7.666666666666667e-06]



100%|███████████████████████████████████████████| 12/12 [01:05<00:00,  5.43s/it]


Epoch [32/100]
MSE: 70.6286581397296 - Pearson Correlation: 0.8462807100828443
Learning Rate: [7.555555555555556e-06]



100%|███████████████████████████████████████████| 12/12 [01:04<00:00,  5.39s/it]


Epoch [33/100]
MSE: 70.50107994740836 - Pearson Correlation: 0.8413607709735521
Learning Rate: [7.444444444444445e-06]



100%|███████████████████████████████████████████| 12/12 [01:05<00:00,  5.47s/it]


Epoch [34/100]
MSE: 70.31064176201188 - Pearson Correlation: 0.816495226451709
Learning Rate: [7.333333333333334e-06]



100%|███████████████████████████████████████████| 12/12 [01:04<00:00,  5.37s/it]


Epoch [35/100]
MSE: 70.35041688373754 - Pearson Correlation: 0.8443016746063308
Learning Rate: [7.222222222222223e-06]



100%|███████████████████████████████████████████| 12/12 [01:06<00:00,  5.52s/it]


Epoch [36/100]
MSE: 70.43196607068117 - Pearson Correlation: 0.8627392247238188
Learning Rate: [7.111111111111112e-06]



100%|███████████████████████████████████████████| 12/12 [01:03<00:00,  5.32s/it]


Epoch [37/100]
MSE: 70.94217968609368 - Pearson Correlation: 0.8507573330509859
Learning Rate: [7.000000000000001e-06]



100%|███████████████████████████████████████████| 12/12 [01:06<00:00,  5.55s/it]


Epoch [38/100]
MSE: 70.66004557187426 - Pearson Correlation: 0.8642580167077997
Learning Rate: [6.888888888888889e-06]



100%|███████████████████████████████████████████| 12/12 [01:03<00:00,  5.30s/it]


Epoch [39/100]
MSE: 70.68906251188115 - Pearson Correlation: 0.8528247920649664
Learning Rate: [6.777777777777779e-06]



100%|███████████████████████████████████████████| 12/12 [01:07<00:00,  5.63s/it]


Epoch [40/100]
MSE: 70.36241694386422 - Pearson Correlation: 0.8715112920365371
Learning Rate: [6.666666666666667e-06]



100%|███████████████████████████████████████████| 12/12 [01:03<00:00,  5.27s/it]


Epoch [41/100]
MSE: 70.47030004627486 - Pearson Correlation: 0.879905983683281
Learning Rate: [6.555555555555556e-06]



100%|███████████████████████████████████████████| 12/12 [01:05<00:00,  5.50s/it]


Epoch [42/100]
MSE: 70.3823342161246 - Pearson Correlation: 0.8866366520013421
Learning Rate: [6.4444444444444445e-06]



100%|███████████████████████████████████████████| 12/12 [01:04<00:00,  5.37s/it]


Epoch [43/100]
MSE: 70.38459227072786 - Pearson Correlation: 0.8886362245685064
Learning Rate: [6.333333333333334e-06]



100%|███████████████████████████████████████████| 12/12 [01:05<00:00,  5.50s/it]


Epoch [44/100]
MSE: 70.42634873266655 - Pearson Correlation: 0.8909330784223877
Learning Rate: [6.222222222222222e-06]



100%|███████████████████████████████████████████| 12/12 [01:07<00:00,  5.62s/it]


Epoch [45/100]
MSE: 70.36058372496423 - Pearson Correlation: 0.8875358457571679
Learning Rate: [6.111111111111111e-06]



100%|███████████████████████████████████████████| 12/12 [01:05<00:00,  5.44s/it]


Epoch [46/100]
MSE: 69.9631014111765 - Pearson Correlation: 0.8876010883771789
Learning Rate: [6e-06]



100%|███████████████████████████████████████████| 12/12 [01:07<00:00,  5.63s/it]


Epoch [47/100]
MSE: 70.24455596243466 - Pearson Correlation: 0.8869990348063979
Learning Rate: [5.888888888888889e-06]



100%|███████████████████████████████████████████| 12/12 [01:03<00:00,  5.31s/it]


Epoch [48/100]
MSE: 70.01032448988929 - Pearson Correlation: 0.8974880794258966
Learning Rate: [5.777777777777778e-06]



100%|███████████████████████████████████████████| 12/12 [01:06<00:00,  5.58s/it]


Epoch [49/100]
MSE: 69.9850459766836 - Pearson Correlation: 0.8966386355403072
Learning Rate: [5.666666666666667e-06]



100%|███████████████████████████████████████████| 12/12 [01:03<00:00,  5.32s/it]


Epoch [50/100]
MSE: 69.86455293527399 - Pearson Correlation: 0.9025348555595396
Learning Rate: [5.555555555555556e-06]



100%|███████████████████████████████████████████| 12/12 [01:08<00:00,  5.67s/it]


Epoch [51/100]
MSE: 69.76006938345472 - Pearson Correlation: 0.9011407346608095
Learning Rate: [5.444444444444445e-06]



100%|███████████████████████████████████████████| 12/12 [01:04<00:00,  5.34s/it]


Epoch [52/100]
MSE: 69.76254010271778 - Pearson Correlation: 0.9018928116226436
Learning Rate: [5.333333333333334e-06]



100%|███████████████████████████████████████████| 12/12 [01:06<00:00,  5.51s/it]


Epoch [53/100]
MSE: 70.09558311587506 - Pearson Correlation: 0.8856008974300913
Learning Rate: [5.2222222222222226e-06]



100%|███████████████████████████████████████████| 12/12 [01:04<00:00,  5.39s/it]


Epoch [54/100]
MSE: 69.85514095048218 - Pearson Correlation: 0.8963305866823449
Learning Rate: [5.1111111111111115e-06]



100%|███████████████████████████████████████████| 12/12 [01:05<00:00,  5.42s/it]


Epoch [55/100]
MSE: 69.68263299385329 - Pearson Correlation: 0.8900048355327761
Learning Rate: [5e-06]



100%|███████████████████████████████████████████| 12/12 [01:05<00:00,  5.49s/it]


Epoch [56/100]
MSE: 69.70115072670771 - Pearson Correlation: 0.8805536755328571
Learning Rate: [4.888888888888889e-06]



100%|███████████████████████████████████████████| 12/12 [01:04<00:00,  5.38s/it]


Epoch [57/100]
MSE: 69.46152547896222 - Pearson Correlation: 0.8754134649183246
Learning Rate: [4.777777777777778e-06]



100%|███████████████████████████████████████████| 12/12 [01:06<00:00,  5.57s/it]


Epoch [58/100]
MSE: 69.71114077553221 - Pearson Correlation: 0.8896311496460947
Learning Rate: [4.666666666666667e-06]



100%|███████████████████████████████████████████| 12/12 [01:04<00:00,  5.35s/it]


Epoch [59/100]
MSE: 69.73093044724087 - Pearson Correlation: 0.8989326380264414
Learning Rate: [4.555555555555556e-06]



100%|███████████████████████████████████████████| 12/12 [01:07<00:00,  5.60s/it]


Epoch [60/100]
MSE: 69.5799862684293 - Pearson Correlation: 0.897518011515794
Learning Rate: [4.444444444444445e-06]



100%|███████████████████████████████████████████| 12/12 [01:03<00:00,  5.28s/it]


Epoch [61/100]
MSE: 69.71337620317392 - Pearson Correlation: 0.9093165440665916
Learning Rate: [4.333333333333334e-06]



100%|███████████████████████████████████████████| 12/12 [01:07<00:00,  5.63s/it]


Epoch [62/100]
MSE: 69.99297215476867 - Pearson Correlation: 0.8969708617089371
Learning Rate: [4.222222222222223e-06]



100%|███████████████████████████████████████████| 12/12 [01:04<00:00,  5.36s/it]


Epoch [63/100]
MSE: 70.10891611952792 - Pearson Correlation: 0.8647128556651807
Learning Rate: [4.111111111111112e-06]



100%|███████████████████████████████████████████| 12/12 [01:06<00:00,  5.55s/it]


Epoch [64/100]
MSE: 70.34545547615265 - Pearson Correlation: 0.8387546141302821
Learning Rate: [4e-06]



100%|███████████████████████████████████████████| 12/12 [01:05<00:00,  5.46s/it]


Epoch [65/100]
MSE: 70.40183799699315 - Pearson Correlation: 0.8415462869653094
Learning Rate: [3.888888888888889e-06]



100%|███████████████████████████████████████████| 12/12 [01:04<00:00,  5.41s/it]


Epoch [66/100]
MSE: 69.79539001231531 - Pearson Correlation: 0.8902126591791115
Learning Rate: [3.7777777777777777e-06]



100%|███████████████████████████████████████████| 12/12 [01:05<00:00,  5.49s/it]


Epoch [67/100]
MSE: 69.5986448978567 - Pearson Correlation: 0.9000110359261068
Learning Rate: [3.6666666666666666e-06]



100%|███████████████████████████████████████████| 12/12 [01:04<00:00,  5.40s/it]


Epoch [68/100]
MSE: 69.65577195594047 - Pearson Correlation: 0.8947819369253357
Learning Rate: [3.5555555555555555e-06]



100%|███████████████████████████████████████████| 12/12 [01:06<00:00,  5.53s/it]


Epoch [69/100]
MSE: 69.324014246819 - Pearson Correlation: 0.9056557571028266
Learning Rate: [3.4444444444444444e-06]



100%|███████████████████████████████████████████| 12/12 [01:04<00:00,  5.33s/it]


Epoch [70/100]
MSE: 69.3717485758897 - Pearson Correlation: 0.9038198991342198
Learning Rate: [3.3333333333333333e-06]



100%|███████████████████████████████████████████| 12/12 [01:07<00:00,  5.61s/it]


Epoch [71/100]
MSE: 69.65138651381577 - Pearson Correlation: 0.9035169775933836
Learning Rate: [3.2222222222222222e-06]



100%|███████████████████████████████████████████| 12/12 [01:03<00:00,  5.26s/it]


Epoch [72/100]
MSE: 69.32958397193754 - Pearson Correlation: 0.9099260984146247
Learning Rate: [3.111111111111111e-06]



100%|███████████████████████████████████████████| 12/12 [01:08<00:00,  5.68s/it]


Epoch [73/100]
MSE: 69.28275038841493 - Pearson Correlation: 0.8944190081380218
Learning Rate: [3e-06]



100%|███████████████████████████████████████████| 12/12 [01:03<00:00,  5.33s/it]


Epoch [74/100]
MSE: 69.16152802978137 - Pearson Correlation: 0.896002702533736
Learning Rate: [2.888888888888889e-06]



100%|███████████████████████████████████████████| 12/12 [01:05<00:00,  5.47s/it]


Epoch [75/100]
MSE: 69.33312334833498 - Pearson Correlation: 0.9054628805653
Learning Rate: [2.777777777777778e-06]



100%|███████████████████████████████████████████| 12/12 [01:04<00:00,  5.35s/it]


Epoch [76/100]
MSE: 69.26908651415303 - Pearson Correlation: 0.9074988864320427
Learning Rate: [2.666666666666667e-06]



100%|███████████████████████████████████████████| 12/12 [01:04<00:00,  5.40s/it]


Epoch [77/100]
MSE: 69.28863250475308 - Pearson Correlation: 0.9093996122552304
Learning Rate: [2.5555555555555557e-06]



100%|███████████████████████████████████████████| 12/12 [01:05<00:00,  5.46s/it]


Epoch [78/100]
MSE: 69.25525567693177 - Pearson Correlation: 0.907936859987348
Learning Rate: [2.4444444444444447e-06]



100%|███████████████████████████████████████████| 12/12 [01:04<00:00,  5.41s/it]


Epoch [79/100]
MSE: 69.26283888894764 - Pearson Correlation: 0.9084234000239021
Learning Rate: [2.3333333333333336e-06]



100%|███████████████████████████████████████████| 12/12 [01:09<00:00,  5.82s/it]


Epoch [80/100]
MSE: 69.25229030734998 - Pearson Correlation: 0.9111618298216327
Learning Rate: [2.2222222222222217e-06]



100%|███████████████████████████████████████████| 12/12 [01:03<00:00,  5.28s/it]


Epoch [81/100]
MSE: 69.27456915244926 - Pearson Correlation: 0.9107408341209426
Learning Rate: [2.1111111111111114e-06]



100%|███████████████████████████████████████████| 12/12 [01:06<00:00,  5.52s/it]


Epoch [82/100]
MSE: 69.27591534931622 - Pearson Correlation: 0.9112088715500037
Learning Rate: [1.9999999999999995e-06]



100%|███████████████████████████████████████████| 12/12 [01:03<00:00,  5.30s/it]


Epoch [83/100]
MSE: 69.22548796988224 - Pearson Correlation: 0.9077277513827463
Learning Rate: [1.8888888888888893e-06]



100%|███████████████████████████████████████████| 12/12 [01:14<00:00,  6.21s/it]


Epoch [84/100]
MSE: 69.15818388472667 - Pearson Correlation: 0.9092328091481908
Learning Rate: [1.7777777777777773e-06]



100%|███████████████████████████████████████████| 12/12 [01:11<00:00,  5.92s/it]


Epoch [85/100]
MSE: 69.14906353641686 - Pearson Correlation: 0.9100899390508379
Learning Rate: [1.666666666666667e-06]



100%|███████████████████████████████████████████| 12/12 [01:12<00:00,  6.08s/it]


Epoch [86/100]
MSE: 69.11250944147186 - Pearson Correlation: 0.9069654576592887
Learning Rate: [1.5555555555555552e-06]



100%|███████████████████████████████████████████| 12/12 [01:11<00:00,  6.00s/it]


Epoch [87/100]
MSE: 69.06483740702417 - Pearson Correlation: 0.9071548319086588
Learning Rate: [1.444444444444445e-06]



100%|███████████████████████████████████████████| 12/12 [01:12<00:00,  6.02s/it]


Epoch [88/100]
MSE: 69.10608282814377 - Pearson Correlation: 0.9086028114482246
Learning Rate: [1.333333333333333e-06]



100%|███████████████████████████████████████████| 12/12 [01:12<00:00,  6.02s/it]


Epoch [89/100]
MSE: 69.12836480505709 - Pearson Correlation: 0.9114859451117415
Learning Rate: [1.2222222222222228e-06]



100%|███████████████████████████████████████████| 12/12 [01:11<00:00,  5.95s/it]


Epoch [90/100]
MSE: 69.14566042804982 - Pearson Correlation: 0.9109221581374147
Learning Rate: [1.1111111111111108e-06]



100%|███████████████████████████████████████████| 12/12 [01:12<00:00,  6.06s/it]


Epoch [91/100]
MSE: 69.16383656099008 - Pearson Correlation: 0.9144395175019036
Learning Rate: [1.0000000000000006e-06]



100%|███████████████████████████████████████████| 12/12 [01:04<00:00,  5.33s/it]


Epoch [92/100]
MSE: 69.20131721715727 - Pearson Correlation: 0.9149564537690525
Learning Rate: [8.888888888888887e-07]



100%|███████████████████████████████████████████| 12/12 [01:06<00:00,  5.57s/it]


Epoch [93/100]
MSE: 69.20160197245657 - Pearson Correlation: 0.9154533138673849
Learning Rate: [7.777777777777784e-07]



100%|███████████████████████████████████████████| 12/12 [01:03<00:00,  5.27s/it]


Epoch [94/100]
MSE: 69.19100169953443 - Pearson Correlation: 0.9162526413075558
Learning Rate: [6.666666666666665e-07]



100%|███████████████████████████████████████████| 12/12 [01:07<00:00,  5.62s/it]


Epoch [95/100]
MSE: 69.21287701990327 - Pearson Correlation: 0.9156327829101465
Learning Rate: [5.555555555555563e-07]



100%|███████████████████████████████████████████| 12/12 [01:05<00:00,  5.45s/it]


Epoch [96/100]
MSE: 69.19831856656872 - Pearson Correlation: 0.9160851275572897
Learning Rate: [4.4444444444444433e-07]



100%|███████████████████████████████████████████| 12/12 [01:05<00:00,  5.46s/it]


Epoch [97/100]
MSE: 69.18624407709989 - Pearson Correlation: 0.9159565148335377
Learning Rate: [3.333333333333324e-07]



100%|███████████████████████████████████████████| 12/12 [01:04<00:00,  5.38s/it]


Epoch [98/100]
MSE: 69.19580020934723 - Pearson Correlation: 0.9154263682197712
Learning Rate: [2.2222222222222217e-07]



100%|███████████████████████████████████████████| 12/12 [01:04<00:00,  5.41s/it]


Epoch [99/100]
MSE: 69.18132682285652 - Pearson Correlation: 0.9156773205400714
Learning Rate: [1.1111111111111024e-07]



100%|███████████████████████████████████████████| 12/12 [01:05<00:00,  5.49s/it]


Epoch [100/100]
MSE: 69.18152593553712 - Pearson Correlation: 0.9157353772478032
Learning Rate: [0.0]



In [429]:
model = DataFusionModel().to(device)
# model.load_state_dict(torch.load("/Users/gufran/Desktop/PfsPredictionLungCancer/notebooks/checkpoint/multimodal.pt"))
model.load_state_dict(torch.load("/Users/gufran/Desktop/PfsPredictionLungCancer/notebooks/checkpoint/multimodal_post.pt"))

<All keys matched successfully>

In [430]:
model.eval()
y_true_val = []
y_pred_val = []

for batch in tqdm(val_loader):
    patches = batch['patches'].to(device)
    pfs_labels = batch['pfs_label'].to(device)
    pcf_data = batch['pcf'].to(device)

    outputs = model(patches, pcf_data)
    y_true_val.extend(pfs_labels.cpu().numpy())
    
    for pred_val in outputs.detach().cpu().numpy():
        y_pred_val.append(float(pred_val))

mse_val = mean_squared_error(y_true_val, y_pred_val)
pearson_corr_val, _ = pearsonr(y_true_val, y_pred_val)
print(f"MSE: {mse_val} - Pearson Correlation: {pearson_corr_val}")

 43%|███████████████████▎                         | 3/7 [00:05<00:07,  1.95s/it]


KeyboardInterrupt: 

In [431]:
def create_df(model, loader):
    model_ = DataFusionModel(currently_training=False).to(device)
    model_.load_state_dict(model.state_dict())

    model_.eval()

    features_list = []
    pfs_values_list = []
    for batch in tqdm(loader):
        patches = batch['patches'].to(device)
        pfs_labels = batch['pfs_label'].to(device)
        pcf_data = batch['pcf'].to(device)

        with torch.no_grad():
            outputs = model_(patches, pcf_data)
            
            for o in outputs:
                features_list.append(list(o.cpu().numpy()))
        
        pfs_values_list.extend(pfs_labels.cpu().numpy())
        
    df = pd.DataFrame(data=features_list, columns=[f'feature_{i}' for i in range(len(features_list[0]))])
    df['pfs'] = pfs_values_list

    return df

model = DataFusionModel()
# model.load_state_dict(torch.load("/Users/gufran/Desktop/PfsPredictionLungCancer/notebooks/checkpoint/multimodal.pt"))
model.load_state_dict(torch.load("/Users/gufran/Desktop/PfsPredictionLungCancer/notebooks/checkpoint/multimodal_post.pt"))

dftr = create_df(model, train_loader)
dfte = create_df(model, val_loader)

100%|███████████████████████████████████████████| 12/12 [00:18<00:00,  1.51s/it]
100%|█████████████████████████████████████████████| 7/7 [00:10<00:00,  1.44s/it]


In [432]:
dftr.shape,dfte.shape

((94, 257), (50, 257))

In [312]:
from sklearn.preprocessing import MinMaxScaler

In [369]:
# dftr.to_csv("/Users/gufran/Desktop/PfsPredictionLungCancer/data/multimodal_train.csv", index=False)
# dfte.to_csv("/Users/gufran/Desktop/PfsPredictionLungCancer/data/multimodal_val.csv", index=False)

In [433]:
correlation_threshold = 0.01
correlations = dftr.drop("pfs", axis=1).corrwith(dftr['pfs'])
selected_columns = correlations[(correlations >= correlation_threshold) | (correlations <= -correlation_threshold)].index

selected_columns = list(selected_columns)
# selected_columns

In [434]:
for sc in selected_columns:
    if sc in ["dmp_pt_id", "pfs"]: continue
    print(f"Correlation with {sc}: {dftr['pfs'].corr(dftr[sc])}")

Correlation with feature_0: 0.10029362259466426
Correlation with feature_1: -0.13271471427599973
Correlation with feature_2: -0.09982973758072936
Correlation with feature_3: -0.12855659023538688
Correlation with feature_4: 0.11845733146837749
Correlation with feature_5: -0.12640003813339756
Correlation with feature_6: 0.1197360294789298
Correlation with feature_7: -0.12649468462515984
Correlation with feature_8: -0.13113971337685268
Correlation with feature_9: 0.15578891084159424
Correlation with feature_10: 0.1393835395433827
Correlation with feature_11: -0.15029075916342893
Correlation with feature_12: -0.17499163926315653
Correlation with feature_13: -0.1303680371988665
Correlation with feature_14: -0.12325858143177537
Correlation with feature_15: -0.12997120196023348
Correlation with feature_16: -0.14257498006767808
Correlation with feature_17: -0.14664703979795815
Correlation with feature_18: 0.15611935369007796
Correlation with feature_19: 0.13629201945111172
Correlation with fea

In [435]:
# dftr_s = dftr[selected_columns+["pfs"]]
# dfte_s = dfte[selected_columns+["pfs"]]

dftr_s = dftr
dfte_s = dfte

In [436]:
X_train, y_train = dftr_s.drop(["pfs"], axis=1), dftr_s["pfs"]
X_test, y_test = dfte_s.drop(["pfs"], axis=1), dfte_s["pfs"]

In [437]:
classifiers_reg = {
    'SVM': {
        'name': 'Support Vector Machine',
        'classifier': SVR(),
        'param_grid': {'C': [0.1, 1.0, 5.0, 10.0], 'kernel': ['linear', 'rbf']}
    },
    'XGBoost': {
        'name': 'XGBoost',
        'classifier': XGBRegressor(),
        'param_grid': {'n_estimators': [50, 100, 200], 'max_depth': [3, 4, 5]}
    },
    'AdaBoost': {
        'name': 'AdaBoost',
        'classifier': AdaBoostRegressor(),
        'param_grid': {'n_estimators': [50, 100, 200], 'learning_rate': [0.1, 0.5, 1.0]}
    },
    'RandomForest': {
        'name': 'Random Forest',
        'classifier': RandomForestRegressor(),
        'param_grid': {'n_estimators': [50, 100, 200], 'max_depth': [None, 10, 20]}
    },
    'DecisionTree': {
        'name': 'Decision Tree',
        'classifier': DecisionTreeRegressor(),
        'param_grid': {'max_depth': [None, 10, 20]}
    }
}

In [401]:
# best_models = {}

# for clf_name, clf_info in classifiers_reg.items():
#     print(f"Performing GridSearchCV for {clf_info['name']}...")
    
#     clf = clf_info['classifier']
#     param_grid = clf_info['param_grid']
    
#     grid_search = GridSearchCV(clf, param_grid, cv=5, scoring='neg_mean_squared_error', n_jobs=-1)
#     grid_search.fit(X_train, y_train)

#     best_models[clf_name] = grid_search.best_estimator_

# print()

# for clf_name, best_model in best_models.items():
#     print(f"Evaluating {clf_name} on test data...")
    
#     y_pred = best_model.predict(X_test)

#     mse = mean_squared_error(y_test, y_pred)
#     pearson_coefficient, _ = pearsonr(y_test, y_pred)
    
#     print(f"Mean Squared Error on Test Data: {mse:.4f}")
#     print(f"Pearson Correlation Coefficient on Test Data: {pearson_coefficient:.4f}\n")

In [441]:
from sklearn.ensemble import BaggingRegressor

bagging_reg = BaggingRegressor(base_estimator=SVR(kernel='rbf', epsilon=0.2, C=1.0), n_estimators=2, random_state=42)  # You can customize parameters
bagging_reg.fit(X_train, y_train)

y_pred_bagging = bagging_reg.predict(X_test)
mse_bagging = mean_squared_error(y_test, y_pred_bagging)
pearson_coefficient_bagging, _ = pearsonr(y_test, y_pred_bagging)

# if mse_bagging<40:
print(f"Mean Squared Error on Test Data (Bagging): {mse_bagging:.4f}")
print(f"Pearson Correlation Coefficient on Test Data (Bagging): {pearson_coefficient_bagging:.4f}\n")

Mean Squared Error on Test Data (Bagging): 35.6142
Pearson Correlation Coefficient on Test Data (Bagging): 0.3929



In [287]:
best_models["BaggingRegressor"] = bagging_reg

In [288]:
best_models.items()

dict_items([('SVM', SVR(C=0.1)), ('AdaBoost', AdaBoostRegressor(learning_rate=0.1, n_estimators=100)), ('RandomForest', RandomForestRegressor(max_depth=10, n_estimators=200)), ('DecisionTree', DecisionTreeRegressor()), ('BaggingRegressor', BaggingRegressor(base_estimator=SVR(), n_estimators=2, random_state=42))])

In [388]:
from sklearn.ensemble import VotingRegressor

voting_reg = VotingRegressor(estimators=list(best_models.items()))
voting_reg.fit(X_train, y_train)

y_pred_voting = voting_reg.predict(X_test)
mse_voting = mean_squared_error(y_test, y_pred_voting)
pearson_coefficient_voting, _ = pearsonr(y_test, y_pred_voting)

print(f"Mean Squared Error on Test Data (Voting): {mse_voting:.4f}")
print(f"Pearson Correlation Coefficient on Test Data (Voting): {pearson_coefficient_voting:.4f}\n")

Mean Squared Error on Test Data (Voting): 67.7944
Pearson Correlation Coefficient on Test Data (Voting): -0.0827

