In [1]:
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import torch.nn as nn
import torch.optim as optim
import os
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import accuracy_score, f1_score, classification_report
from tqdm.notebook import tqdm
import random
import pickle
from collections import Counter


In [2]:
csv_path = "C:/Users/HP/Downloads/GLEASON/Train_with_Gleason_and_Labels.csv"
df = pd.read_csv(csv_path)

df['image_path'] = df['image_path'].apply(lambda x: os.path.join('C:/Users/HP/Downloads/GLEASON/Train Imgs', os.path.basename(x)))

def map_risk_label(gleason_score):
    if gleason_score in [0, 6]:
        return 0
    elif gleason_score == 7:
        return 1
    elif gleason_score in [8, 9, 10]:
        return 2
    else:
        return -1

df['risk_label'] = df['gleason_score'].apply(map_risk_label)
df = df[df['risk_label'] != -1]

In [3]:
class MRIOnlyDataset(Dataset):
    def __init__(self, dataframe, transform=None):
        self.dataframe = dataframe.reset_index(drop=True)
        self.transform = transform

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        row = self.dataframe.iloc[idx]
        image = Image.open(row['image_path']).convert('RGB')

        if self.transform:
            image = self.transform(image)

        label = torch.tensor(row['risk_label']).long()
        return image, label

In [4]:
image_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5],
                         std=[0.5, 0.5, 0.5])
])

In [5]:
train_df, val_df = train_test_split(df, test_size=0.2, stratify=df['risk_label'], random_state=42)
train_dataset = MRIOnlyDataset(train_df, transform=image_transform)
val_dataset = MRIOnlyDataset(val_df, transform=image_transform)

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)

In [9]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class CNN3ChannelDeep(nn.Module):
    def __init__(self):
        super(CNN3ChannelDeep, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2),

            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2),

            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1))
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(64, 3)
        )

    def forward(self, x):
        x = self.features(x)
        x = self.classifier(x)
        return x

model = CNN3ChannelDeep().to(device)

In [11]:
def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(42)

if os.path.exists("train_labels.pkl"):
    with open("train_labels.pkl", "rb") as f:
        train_labels = pickle.load(f)
else:
    train_labels = [int(label) for _, label in tqdm(train_dataset)]
    with open("train_labels.pkl", "wb") as f:
        pickle.dump(train_labels, f)

print(Counter(train_labels))

class_weights = compute_class_weight(class_weight='balanced',
                                     classes=np.unique(train_labels),
                                     y=np.array(train_labels))
weights_tensor = torch.tensor(class_weights, dtype=torch.float).to(device)

Counter({2: 383, 0: 284, 1: 269})


In [15]:
from sklearn.metrics import confusion_matrix
best_f1 = 0.0
patience = 5
patience_counter = 0
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss(weight=weights_tensor)

for epoch in range(1, 21):
    model.train()
    running_loss = 0.0

    for images, labels in tqdm(train_loader, desc=f"Epoch {epoch}"):
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    # === Evaluation ===
    model.eval()
    all_preds, all_labels = [], []
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            preds = torch.argmax(outputs, dim=1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    acc = accuracy_score(all_labels, all_preds)
    f1 = f1_score(all_labels, all_preds, average='macro')
    cm = confusion_matrix(all_labels, all_preds)

    print("\n" + "="*50)
    print(f"Epoch {epoch} Summary:")
    print(f"Train Loss: {running_loss:.4f}")
    print(f"Val Acc:   {acc:.4f}")
    print(f"Val F1:    {f1:.4f}")
    print("\nClassification Report:")
    print(classification_report(all_labels, all_preds, target_names=["No/Low Risk", "Intermediate", "High"], zero_division=0))

    print("Prediction Distribution:", torch.bincount(torch.tensor(all_preds)).tolist())

    print("\nConfusion Matrix:")
    print(f"{'':15}{'Pred 0':>10}{'Pred 1':>10}{'Pred 2':>10}")
    for i, row in enumerate(cm):
        print(f"True {i:<11}" + "".join(f"{val:10d}" for val in row))
    print("="*50 + "\n")

    if f1 > best_f1:
        best_f1 = f1
        patience_counter = 0
        torch.save(model.state_dict(), "best_model_mri_only.pt")
        print("Best model saved.")
    else:
        patience_counter += 1

    if patience_counter >= patience:
        print("Early stopping triggered.")
        break


Epoch 1:   0%|          | 0/59 [00:00<?, ?it/s]


Epoch 1 Summary:
Train Loss: 60.0976
Val Acc:   0.4936
Val F1:    0.4386

Classification Report:
              precision    recall  f1-score   support

 No/Low Risk       0.63      0.40      0.49        72
Intermediate       0.46      0.16      0.24        67
        High       0.46      0.79      0.58        96

    accuracy                           0.49       235
   macro avg       0.52      0.45      0.44       235
weighted avg       0.51      0.49      0.46       235

Prediction Distribution: [46, 24, 165]

Confusion Matrix:
                   Pred 0    Pred 1    Pred 2
True 0                  29         8        35
True 1                   2        11        54
True 2                  15         5        76

Best model saved.


Epoch 2:   0%|          | 0/59 [00:00<?, ?it/s]


Epoch 2 Summary:
Train Loss: 57.6281
Val Acc:   0.3660
Val F1:    0.2527

Classification Report:
              precision    recall  f1-score   support

 No/Low Risk       0.34      0.99      0.50        72
Intermediate       0.00      0.00      0.00        67
        High       0.71      0.16      0.26        96

    accuracy                           0.37       235
   macro avg       0.35      0.38      0.25       235
weighted avg       0.39      0.37      0.26       235

Prediction Distribution: [211, 3, 21]

Confusion Matrix:
                   Pred 0    Pred 1    Pred 2
True 0                  71         1         0
True 1                  61         0         6
True 2                  79         2        15



Epoch 3:   0%|          | 0/59 [00:00<?, ?it/s]


Epoch 3 Summary:
Train Loss: 56.8983
Val Acc:   0.5106
Val F1:    0.4534

Classification Report:
              precision    recall  f1-score   support

 No/Low Risk       0.50      0.74      0.59        72
Intermediate       0.31      0.12      0.17        67
        High       0.58      0.61      0.60        96

    accuracy                           0.51       235
   macro avg       0.46      0.49      0.45       235
weighted avg       0.48      0.51      0.47       235

Prediction Distribution: [107, 26, 102]

Confusion Matrix:
                   Pred 0    Pred 1    Pred 2
True 0                  53         8        11
True 1                  27         8        32
True 2                  27        10        59

Best model saved.


Epoch 4:   0%|          | 0/59 [00:00<?, ?it/s]


Epoch 4 Summary:
Train Loss: 56.8420
Val Acc:   0.5489
Val F1:    0.5004

Classification Report:
              precision    recall  f1-score   support

 No/Low Risk       0.54      0.75      0.63        72
Intermediate       0.43      0.18      0.25        67
        High       0.59      0.66      0.62        96

    accuracy                           0.55       235
   macro avg       0.52      0.53      0.50       235
weighted avg       0.53      0.55      0.52       235

Prediction Distribution: [100, 28, 107]

Confusion Matrix:
                   Pred 0    Pred 1    Pred 2
True 0                  54         8        10
True 1                  21        12        34
True 2                  25         8        63

Best model saved.


Epoch 5:   0%|          | 0/59 [00:00<?, ?it/s]


Epoch 5 Summary:
Train Loss: 55.9789
Val Acc:   0.4979
Val F1:    0.3672

Classification Report:
              precision    recall  f1-score   support

 No/Low Risk       0.83      0.35      0.49        72
Intermediate       0.00      0.00      0.00        67
        High       0.45      0.96      0.61        96

    accuracy                           0.50       235
   macro avg       0.43      0.44      0.37       235
weighted avg       0.44      0.50      0.40       235

Prediction Distribution: [30, 0, 205]

Confusion Matrix:
                   Pred 0    Pred 1    Pred 2
True 0                  25         0        47
True 1                   1         0        66
True 2                   4         0        92



Epoch 6:   0%|          | 0/59 [00:00<?, ?it/s]


Epoch 6 Summary:
Train Loss: 54.9764
Val Acc:   0.4681
Val F1:    0.4308

Classification Report:
              precision    recall  f1-score   support

 No/Low Risk       0.44      0.90      0.59        72
Intermediate       0.30      0.21      0.25        67
        High       0.78      0.32      0.46        96

    accuracy                           0.47       235
   macro avg       0.50      0.48      0.43       235
weighted avg       0.54      0.47      0.44       235

Prediction Distribution: [148, 47, 40]

Confusion Matrix:
                   Pred 0    Pred 1    Pred 2
True 0                  65         6         1
True 1                  45        14         8
True 2                  38        27        31



Epoch 7:   0%|          | 0/59 [00:00<?, ?it/s]


Epoch 7 Summary:
Train Loss: 54.9472
Val Acc:   0.4553
Val F1:    0.4149

Classification Report:
              precision    recall  f1-score   support

 No/Low Risk       0.43      0.93      0.59        72
Intermediate       0.32      0.24      0.27        67
        High       0.80      0.25      0.38        96

    accuracy                           0.46       235
   macro avg       0.52      0.47      0.41       235
weighted avg       0.55      0.46      0.41       235

Prediction Distribution: [155, 50, 30]

Confusion Matrix:
                   Pred 0    Pred 1    Pred 2
True 0                  67         4         1
True 1                  46        16         5
True 2                  42        30        24



Epoch 8:   0%|          | 0/59 [00:00<?, ?it/s]


Epoch 8 Summary:
Train Loss: 54.6486
Val Acc:   0.3702
Val F1:    0.2847

Classification Report:
              precision    recall  f1-score   support

 No/Low Risk       0.37      0.97      0.54        72
Intermediate       0.17      0.07      0.10        67
        High       0.67      0.12      0.21        96

    accuracy                           0.37       235
   macro avg       0.40      0.39      0.28       235
weighted avg       0.43      0.37      0.28       235

Prediction Distribution: [187, 30, 18]

Confusion Matrix:
                   Pred 0    Pred 1    Pred 2
True 0                  70         2         0
True 1                  56         5         6
True 2                  61        23        12



Epoch 9:   0%|          | 0/59 [00:00<?, ?it/s]


Epoch 9 Summary:
Train Loss: 54.0817
Val Acc:   0.5021
Val F1:    0.4604

Classification Report:
              precision    recall  f1-score   support

 No/Low Risk       0.66      0.46      0.54        72
Intermediate       0.33      0.19      0.25        67
        High       0.49      0.75      0.60        96

    accuracy                           0.50       235
   macro avg       0.50      0.47      0.46       235
weighted avg       0.50      0.50      0.48       235

Prediction Distribution: [50, 39, 146]

Confusion Matrix:
                   Pred 0    Pred 1    Pred 2
True 0                  33        16        23
True 1                   3        13        51
True 2                  14        10        72

Early stopping triggered.
