In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms as T
from torchvision.transforms.functional import to_pil_image
from PIL import Image
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [7]:
csv_path = "C:/Users/HP/Downloads/GLEASON/Train_with_Gleason_and_Labels.csv"
df = pd.read_csv(csv_path)
df['risk_class'] = df['risk_label'].map({
    0: 0,  # Gleason 0 → Low Risk
    1: 0,  # Gleason 6 → Low Risk
    2: 1,  # Gleason 7 → Intermediate Risk
    3: 2   # Gleason 8–10 → High Risk
})
df['image_path'] = df['image_path'].apply(lambda x: os.path.join("C:/Users/HP/Downloads/GLEASON/Train Imgs", os.path.basename(x)))
df['mask_path'] = df['mask_path'].apply(lambda x: os.path.join("C:/Users/HP/Downloads/GLEASON", x.strip('./')))

print("Loaded:", len(df), "rows")
print("Example:", df.iloc[0])

Loaded: 1171 rows
Example: image_path         C:/Users/HP/Downloads/GLEASON/Train Imgs\slide...
mask_path          C:/Users/HP/Downloads/GLEASON\Maps1_T/slide001...
expert_id                                                          1
primary_grade                                                      4
secondary_grade                                                    4
gleason_score                                                      8
risk_label                                                         3
risk_class                                                         2
Name: 0, dtype: object


In [8]:
class GleasonGradeDataset(Dataset):
    def __init__(self, df, target="primary"):
        self.df = df.reset_index(drop=True)
        self.target = target
        self.grade_map = {3: 0, 4: 1, 5: 2}
        self.transform_img = T.Compose([
            T.Resize((512, 512)),
            T.ToTensor(),
            T.Normalize([0.5]*3, [0.5]*3)
        ])
        self.transform_mask = T.Compose([
            T.Resize((512, 512)),
            T.ToTensor()
        ])

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        image = Image.open(row['image_path']).convert("RGB")
        mask = Image.open(row['mask_path']).convert("L")

        image_tensor = self.transform_img(image)
        mask_tensor = self.transform_mask(mask)
        input_tensor = torch.cat([image_tensor, mask_tensor], dim=0)  # 4-channel input

        mask_np = np.array(mask)
        valid = [3, 4, 5]
        grade_pixels = {v: np.sum(mask_np == v) for v in valid}
        sorted_grades = sorted(grade_pixels.items(), key=lambda x: x[1], reverse=True)
        primary = sorted_grades[0][0] if sorted_grades else 3
        secondary = sorted_grades[1][0] if len(sorted_grades) > 1 else primary
        label = primary if self.target == "primary" else secondary

        return input_tensor, self.grade_map.get(label, 0)

In [9]:
class CNNGradePredictor(nn.Module):
    def __init__(self):
        super(CNNGradePredictor, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(4, 32, 3, padding=1), nn.BatchNorm2d(32), nn.ReLU(), nn.MaxPool2d(2),  # 256x256
            nn.Conv2d(32, 64, 3, padding=1), nn.BatchNorm2d(64), nn.ReLU(), nn.MaxPool2d(2),  # 128x128
            nn.Conv2d(64, 128, 3, padding=1), nn.BatchNorm2d(128), nn.ReLU(), nn.MaxPool2d(2),  # 64x64
            nn.Conv2d(128, 256, 3, padding=1), nn.BatchNorm2d(256), nn.ReLU(), nn.MaxPool2d(2),  # 32x32
            nn.Conv2d(256, 512, 3, padding=1), nn.BatchNorm2d(512), nn.ReLU(), nn.AdaptiveAvgPool2d((1, 1))  # 1x1
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(512, 128), nn.ReLU(), nn.Dropout(0.5),
            nn.Linear(128, 3)  # 3 classes: grade 3, 4, 5
        )

    def forward(self, x):
        x = self.features(x)
        return self.classifier(x)

In [10]:
def train_model(model, train_loader, val_loader, class_weights, save_path, label_name=""):
    model = model.to(device)
    criterion = nn.CrossEntropyLoss(weight=class_weights)
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    best_f1 = 0
    patience, trigger = 5, 0

    for epoch in range(1, 12):
        model.train()
        total_loss, preds, labels = 0, [], []

        for x, y in tqdm(train_loader, desc=f"Epoch {epoch} - Training"):
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            out = model(x)
            loss = criterion(out, y)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            preds.extend(torch.argmax(out, 1).cpu().numpy())
            labels.extend(y.cpu().numpy())

        train_acc = accuracy_score(labels, preds)
        train_f1 = f1_score(labels, preds, average='macro')

        model.eval()
        val_preds, val_labels = [], []

        with torch.no_grad():
            for x, y in tqdm(val_loader, desc=f"Epoch {epoch} - Validation"):
                x, y = x.to(device), y.to(device)
                out = model(x)
                pred = torch.argmax(out, dim=1)
                val_preds.extend(pred.cpu().numpy())
                val_labels.extend(y.cpu().numpy())

        val_acc = accuracy_score(val_labels, val_preds)
        val_f1 = f1_score(val_labels, val_preds, average='macro')
        cm = confusion_matrix(val_labels, val_preds)

        print(f"\nEpoch {epoch} Results:")
        print(f"Train Loss: {total_loss:.4f} | Train Acc: {train_acc:.4f} | Train F1: {train_f1:.4f}")
        print(f"Val   Acc: {val_acc:.4f} | Val   F1: {val_f1:.4f}")
        print("Confusion Matrix:\n", cm)

        if val_f1 > best_f1:
            best_f1 = val_f1
            trigger = 0
            torch.save(model.state_dict(), save_path)
            print("Best model saved.")
        else:
            trigger += 1
            print(f"No improvement. Patience: {trigger}/{patience}")
            if trigger >= patience:
                print("Early stopping triggered.")
                break

In [11]:
train_df, val_df = train_test_split(df, test_size=0.2, stratify=df['risk_class'], random_state=42)

train_primary = GleasonGradeDataset(train_df, target="primary")
val_primary = GleasonGradeDataset(val_df, target="primary")
train_secondary = GleasonGradeDataset(train_df, target="secondary")
val_secondary = GleasonGradeDataset(val_df, target="secondary")

train_loader_primary = DataLoader(train_primary, batch_size=8, shuffle=True)
val_loader_primary = DataLoader(val_primary, batch_size=8)
train_loader_secondary = DataLoader(train_secondary, batch_size=8, shuffle=True)
val_loader_secondary = DataLoader(val_secondary, batch_size=8)

y_primary = [train_primary[i][1] for i in tqdm(range(len(train_primary)))]
y_secondary = [train_secondary[i][1] for i in tqdm(range(len(train_secondary)))]

class_weights_primary = compute_class_weight(class_weight="balanced", classes=np.array([0, 1, 2]), y=y_primary)
class_weights_secondary = compute_class_weight(class_weight="balanced", classes=np.array([0, 1, 2]), y=y_secondary)

class_weights_primary_tensor = torch.tensor(class_weights_primary, dtype=torch.float32).to(device)
class_weights_secondary_tensor = torch.tensor(class_weights_secondary, dtype=torch.float32).to(device)

  0%|          | 0/936 [00:00<?, ?it/s]

  0%|          | 0/936 [00:00<?, ?it/s]

In [12]:
model_primary = CNNGradePredictor().to(device)
train_model(
    model=model_primary,
    train_loader=train_loader_primary,
    val_loader=val_loader_primary,
    class_weights=class_weights_primary_tensor,
    save_path="cnn_primary_model4.pth",
    label_name="Primary"
)

model_secondary = CNNGradePredictor().to(device)
train_model(
    model=model_secondary,
    train_loader=train_loader_secondary,
    val_loader=val_loader_secondary,
    class_weights=class_weights_secondary_tensor,
    save_path="cnn_secondary_model4.pth",
    label_name="Secondary"
)

Epoch 1 - Training:   0%|          | 0/117 [00:00<?, ?it/s]

Epoch 1 - Validation:   0%|          | 0/30 [00:00<?, ?it/s]


Epoch 1 Results:
Train Loss: 131.4047 | Train Acc: 0.5032 | Train F1: 0.3642
Val   Acc: 0.4936 | Val   F1: 0.3995
Confusion Matrix:
 [[38 46 21]
 [20 74 28]
 [ 1  3  4]]
Best model saved.


Epoch 2 - Training:   0%|          | 0/117 [00:00<?, ?it/s]

Epoch 2 - Validation:   0%|          | 0/30 [00:00<?, ?it/s]


Epoch 2 Results:
Train Loss: 120.0629 | Train Acc: 0.5684 | Train F1: 0.3902
Val   Acc: 0.5191 | Val   F1: 0.3593
Confusion Matrix:
 [[55 48  2]
 [47 67  8]
 [ 1  7  0]]
No improvement. Patience: 1/5


Epoch 3 - Training:   0%|          | 0/117 [00:00<?, ?it/s]

Epoch 3 - Validation:   0%|          | 0/30 [00:00<?, ?it/s]


Epoch 3 Results:
Train Loss: 116.9993 | Train Acc: 0.5759 | Train F1: 0.4253
Val   Acc: 0.6128 | Val   F1: 0.3923
Confusion Matrix:
 [[ 38  67   0]
 [ 16 106   0]
 [  1   7   0]]
No improvement. Patience: 2/5


Epoch 4 - Training:   0%|          | 0/117 [00:00<?, ?it/s]

Epoch 4 - Validation:   0%|          | 0/30 [00:00<?, ?it/s]


Epoch 4 Results:
Train Loss: 115.2319 | Train Acc: 0.5630 | Train F1: 0.4140
Val   Acc: 0.4894 | Val   F1: 0.3365
Confusion Matrix:
 [[103   1   1]
 [104  10   8]
 [  5   1   2]]
No improvement. Patience: 3/5


Epoch 5 - Training:   0%|          | 0/117 [00:00<?, ?it/s]

Epoch 5 - Validation:   0%|          | 0/30 [00:00<?, ?it/s]


Epoch 5 Results:
Train Loss: 108.0969 | Train Acc: 0.6068 | Train F1: 0.4800
Val   Acc: 0.6170 | Val   F1: 0.4601
Confusion Matrix:
 [[58 46  1]
 [29 86  7]
 [ 1  6  1]]
Best model saved.


Epoch 6 - Training:   0%|          | 0/117 [00:00<?, ?it/s]

Epoch 6 - Validation:   0%|          | 0/30 [00:00<?, ?it/s]


Epoch 6 Results:
Train Loss: 107.5065 | Train Acc: 0.6047 | Train F1: 0.4810
Val   Acc: 0.3149 | Val   F1: 0.2991
Confusion Matrix:
 [[38 30 37]
 [16 30 76]
 [ 1  1  6]]
No improvement. Patience: 1/5


Epoch 7 - Training:   0%|          | 0/117 [00:00<?, ?it/s]

Epoch 7 - Validation:   0%|          | 0/30 [00:00<?, ?it/s]


Epoch 7 Results:
Train Loss: 104.1744 | Train Acc: 0.6058 | Train F1: 0.4998
Val   Acc: 0.1277 | Val   F1: 0.1341
Confusion Matrix:
 [[  9  24  72]
 [  1  14 107]
 [  1   0   7]]
No improvement. Patience: 2/5


Epoch 8 - Training:   0%|          | 0/117 [00:00<?, ?it/s]

Epoch 8 - Validation:   0%|          | 0/30 [00:00<?, ?it/s]


Epoch 8 Results:
Train Loss: 98.9250 | Train Acc: 0.6528 | Train F1: 0.5326
Val   Acc: 0.5787 | Val   F1: 0.4476
Confusion Matrix:
 [[98  6  1]
 [55 34 33]
 [ 1  3  4]]
No improvement. Patience: 3/5


Epoch 9 - Training:   0%|          | 0/117 [00:00<?, ?it/s]

Epoch 9 - Validation:   0%|          | 0/30 [00:00<?, ?it/s]


Epoch 9 Results:
Train Loss: 103.4344 | Train Acc: 0.6026 | Train F1: 0.4952
Val   Acc: 0.5660 | Val   F1: 0.4599
Confusion Matrix:
 [[87 15  3]
 [46 41 35]
 [ 1  2  5]]
No improvement. Patience: 4/5


Epoch 10 - Training:   0%|          | 0/117 [00:00<?, ?it/s]

Epoch 10 - Validation:   0%|          | 0/30 [00:00<?, ?it/s]


Epoch 10 Results:
Train Loss: 100.7699 | Train Acc: 0.5972 | Train F1: 0.4942
Val   Acc: 0.6383 | Val   F1: 0.4838
Confusion Matrix:
 [[98  6  1]
 [57 50 15]
 [ 1  5  2]]
Best model saved.


Epoch 11 - Training:   0%|          | 0/117 [00:00<?, ?it/s]

Epoch 11 - Validation:   0%|          | 0/30 [00:00<?, ?it/s]


Epoch 11 Results:
Train Loss: 96.0142 | Train Acc: 0.6485 | Train F1: 0.5288
Val   Acc: 0.6298 | Val   F1: 0.4277
Confusion Matrix:
 [[78 27  0]
 [51 70  1]
 [ 3  5  0]]
No improvement. Patience: 1/5


Epoch 1 - Training:   0%|          | 0/117 [00:00<?, ?it/s]

Epoch 1 - Validation:   0%|          | 0/30 [00:00<?, ?it/s]


Epoch 1 Results:
Train Loss: 125.9472 | Train Acc: 0.5043 | Train F1: 0.3437
Val   Acc: 0.5489 | Val   F1: 0.2859
Confusion Matrix:
 [[119   5   0]
 [ 96  10   0]
 [  4   1   0]]
Best model saved.


Epoch 2 - Training:   0%|          | 0/117 [00:00<?, ?it/s]

Epoch 2 - Validation:   0%|          | 0/30 [00:00<?, ?it/s]


Epoch 2 Results:
Train Loss: 119.7283 | Train Acc: 0.5395 | Train F1: 0.3656
Val   Acc: 0.6255 | Val   F1: 0.4205
Confusion Matrix:
 [[81 43  0]
 [40 66  0]
 [ 2  3  0]]
Best model saved.


Epoch 3 - Training:   0%|          | 0/117 [00:00<?, ?it/s]

Epoch 3 - Validation:   0%|          | 0/30 [00:00<?, ?it/s]


Epoch 3 Results:
Train Loss: 122.3036 | Train Acc: 0.5534 | Train F1: 0.3739
Val   Acc: 0.6000 | Val   F1: 0.3941
Confusion Matrix:
 [[48 76  0]
 [13 93  0]
 [ 2  3  0]]
No improvement. Patience: 1/5


Epoch 4 - Training:   0%|          | 0/117 [00:00<?, ?it/s]

Epoch 4 - Validation:   0%|          | 0/30 [00:00<?, ?it/s]


Epoch 4 Results:
Train Loss: 122.8812 | Train Acc: 0.5406 | Train F1: 0.3651
Val   Acc: 0.5915 | Val   F1: 0.3774
Confusion Matrix:
 [[101  23   0]
 [ 68  38   0]
 [  3   2   0]]
No improvement. Patience: 2/5


Epoch 5 - Training:   0%|          | 0/117 [00:00<?, ?it/s]

Epoch 5 - Validation:   0%|          | 0/30 [00:00<?, ?it/s]


Epoch 5 Results:
Train Loss: 117.5812 | Train Acc: 0.5299 | Train F1: 0.3578
Val   Acc: 0.5745 | Val   F1: 0.3635
Confusion Matrix:
 [[ 35  89   0]
 [  6 100   0]
 [  1   4   0]]
No improvement. Patience: 3/5


Epoch 6 - Training:   0%|          | 0/117 [00:00<?, ?it/s]

Epoch 6 - Validation:   0%|          | 0/30 [00:00<?, ?it/s]


Epoch 6 Results:
Train Loss: 117.1016 | Train Acc: 0.6047 | Train F1: 0.4086
Val   Acc: 0.6426 | Val   F1: 0.4302
Confusion Matrix:
 [[88 36  0]
 [43 63  0]
 [ 2  3  0]]
Best model saved.


Epoch 7 - Training:   0%|          | 0/117 [00:00<?, ?it/s]

Epoch 7 - Validation:   0%|          | 0/30 [00:00<?, ?it/s]


Epoch 7 Results:
Train Loss: 115.7530 | Train Acc: 0.6357 | Train F1: 0.4309
Val   Acc: 0.4979 | Val   F1: 0.2718
Confusion Matrix:
 [[ 12 112   0]
 [  1 105   0]
 [  0   5   0]]
No improvement. Patience: 1/5


Epoch 8 - Training:   0%|          | 0/117 [00:00<?, ?it/s]

Epoch 8 - Validation:   0%|          | 0/30 [00:00<?, ?it/s]


Epoch 8 Results:
Train Loss: 114.8302 | Train Acc: 0.6335 | Train F1: 0.4292
Val   Acc: 0.6809 | Val   F1: 0.4542
Confusion Matrix:
 [[97 27  0]
 [43 63  0]
 [ 2  3  0]]
Best model saved.


Epoch 9 - Training:   0%|          | 0/117 [00:00<?, ?it/s]

Epoch 9 - Validation:   0%|          | 0/30 [00:00<?, ?it/s]


Epoch 9 Results:
Train Loss: 115.8263 | Train Acc: 0.6549 | Train F1: 0.4439
Val   Acc: 0.5915 | Val   F1: 0.3776
Confusion Matrix:
 [[ 38  86   0]
 [  5 101   0]
 [  0   5   0]]
No improvement. Patience: 1/5


Epoch 10 - Training:   0%|          | 0/117 [00:00<?, ?it/s]

Epoch 10 - Validation:   0%|          | 0/30 [00:00<?, ?it/s]


Epoch 10 Results:
Train Loss: 110.3805 | Train Acc: 0.6966 | Train F1: 0.4721
Val   Acc: 0.6638 | Val   F1: 0.4366
Confusion Matrix:
 [[103  21   0]
 [ 53  53   0]
 [  2   3   0]]
No improvement. Patience: 2/5


Epoch 11 - Training:   0%|          | 0/117 [00:00<?, ?it/s]

Epoch 11 - Validation:   0%|          | 0/30 [00:00<?, ?it/s]


Epoch 11 Results:
Train Loss: 112.2556 | Train Acc: 0.7147 | Train F1: 0.4844
Val   Acc: 0.7319 | Val   F1: 0.4907
Confusion Matrix:
 [[ 72  52   0]
 [  6 100   0]
 [  2   3   0]]
Best model saved.


In [13]:
model_primary.load_state_dict(torch.load("cnn_primary_model4.pth"))
model_secondary.load_state_dict(torch.load("cnn_secondary_model4.pth"))

model_primary.eval()
model_secondary.eval()

CNNGradePredictor(
  (features): Sequential(
    (0): Conv2d(4, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (5): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): ReLU()
    (7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (8): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): ReLU()
    (11): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (12): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_sta

In [14]:
class CNN4Dataset(Dataset):
    def __init__(self, df, cnn1_model, cnn2_model, transform=None):
        self.df = df.reset_index(drop=True)
        self.cnn1 = cnn1_model.eval()
        self.cnn2 = cnn2_model.eval()
        self.transform_img = transform or T.Compose([
            T.Resize((512, 512)),
            T.ToTensor(),
            T.Normalize([0.5]*3, [0.5]*3)
        ])

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        img_path = row['image_path']
        mask_path = row['mask_path']

        image = Image.open(img_path).convert("RGB")
        mask = Image.open(mask_path).convert("L")
        img_tensor = self.transform_img(image)

        # Dummy mask needed to form 4-channel input to CNN1/2
        mask_tensor = T.Resize((512, 512))(T.ToTensor()(mask))
        img4_tensor = torch.cat([img_tensor, mask_tensor], dim=0).unsqueeze(0).to(device)

        with torch.no_grad():
            pred_primary = self.cnn1(img4_tensor)
            pred_secondary = self.cnn2(img4_tensor)

        primary_grade = torch.argmax(pred_primary, dim=1).item() + 3
        secondary_grade = torch.argmax(pred_secondary, dim=1).item() + 3
        gleason_score = primary_grade + secondary_grade

        # Risk class logic: 0 = 0/6, 1 = 7, 2 = 8/9/10
        if gleason_score in [0, 6]:
            risk = 0
        elif gleason_score == 7:
            risk = 1
        elif gleason_score in [8, 9, 10]:
            risk = 2
        else:
            risk = 0

        # Construct 6-channel input: RGB + pri + sec + score
        input_tensor = torch.cat([
            img_tensor,
            torch.tensor([[primary_grade / 5.0]]).repeat(1, 512, 512),
            torch.tensor([[secondary_grade / 5.0]]).repeat(1, 512, 512),
            torch.tensor([[gleason_score / 10.0]]).repeat(1, 512, 512)
        ], dim=0)

        return input_tensor, torch.tensor(risk, dtype=torch.long)

In [15]:
cnn4_train_dataset = CNN4Dataset(train_df, model_primary, model_secondary)
cnn4_val_dataset = CNN4Dataset(val_df, model_primary, model_secondary)

cnn4_train_loader = DataLoader(cnn4_train_dataset, batch_size=8, shuffle=True)
cnn4_val_loader = DataLoader(cnn4_val_dataset, batch_size=8, shuffle=False)

# Check risk label distribution for weights
risk_labels_train = [cnn4_train_dataset[i][1].item() for i in tqdm(range(len(cnn4_train_dataset)))]
print("✅ Risk class distribution:", pd.Series(risk_labels_train).value_counts())

# Compute class weights
class_weights_cnn4 = compute_class_weight(class_weight='balanced', classes=np.array([0,1,2]), y=risk_labels_train)
class_weights_cnn4_tensor = torch.tensor(class_weights_cnn4, dtype=torch.float32).to(device)

  0%|          | 0/936 [00:00<?, ?it/s]

✅ Risk class distribution: 1    666
2    183
0     87
Name: count, dtype: int64


In [16]:
class CNN4RiskClassifier(nn.Module):
    def __init__(self):
        super(CNN4RiskClassifier, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(6, 64, kernel_size=3, padding=1),  # Input: 6-channel image
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2),

            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2),

            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(2),

            nn.Conv2d(256, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1))
        )

        self.classifier = nn.Sequential(
            nn.Flatten(),  # [B, 512]
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, 3)  # Final 3-class risk output
        )

    def forward(self, x):
        x = self.features(x)
        x = self.classifier(x)
        return x

In [17]:
cnn4_model = CNN4RiskClassifier().to(device)
loss_fn = nn.CrossEntropyLoss(weight=class_weights_cnn4_tensor)
optimizer = torch.optim.Adam(cnn4_model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=2)

In [18]:
epochs = 20
patience = 5
best_f1 = 0
early_stop = 0

for epoch in range(1, epochs + 1):
    cnn4_model.train()
    train_loss, all_preds, all_labels = 0, [], []

    for inputs, targets in tqdm(cnn4_train_loader, desc=f"Epoch {epoch} - Training"):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = cnn4_model(inputs)
        loss = loss_fn(outputs, targets)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        all_preds.extend(torch.argmax(outputs, dim=1).cpu().numpy())
        all_labels.extend(targets.cpu().numpy())

    train_acc = accuracy_score(all_labels, all_preds)
    train_f1 = f1_score(all_labels, all_preds, average='macro')

    # === Validation ===
    cnn4_model.eval()
    val_preds, val_labels = [], []

    with torch.no_grad():
        for inputs, targets in tqdm(cnn4_val_loader, desc=f"Epoch {epoch} - Validation"):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = cnn4_model(inputs)
            preds = torch.argmax(outputs, dim=1)
            val_preds.extend(preds.cpu().numpy())
            val_labels.extend(targets.cpu().numpy())

    val_acc = accuracy_score(val_labels, val_preds)
    val_f1 = f1_score(val_labels, val_preds, average='macro')
    cm = confusion_matrix(val_labels, val_preds)

    print(f"\n📊 Epoch {epoch} Results:")
    print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f} | Train F1: {train_f1:.4f}")
    print(f"Val   Acc: {val_acc:.4f} | Val   F1: {val_f1:.4f}")
    print("Confusion Matrix:\n", cm)

    scheduler.step(val_f1)

    if val_f1 > best_f1:
        best_f1 = val_f1
        early_stop = 0
        torch.save(cnn4_model.state_dict(), "Fused_Final_cnn4.pth")
        print("✅ Best model saved.")
    else:
        early_stop += 1
        print(f"No improvement. Patience: {early_stop}/{patience}")
        if early_stop >= patience:
            print("⏹️ Early stopping triggered.")
            break

Epoch 1 - Training:   0%|          | 0/117 [00:00<?, ?it/s]

Epoch 1 - Validation:   0%|          | 0/30 [00:00<?, ?it/s]


📊 Epoch 1 Results:
Train Loss: 39.2999 | Train Acc: 0.8974 | Train F1: 0.8468
Val   Acc: 0.9915 | Val   F1: 0.9891
Confusion Matrix:
 [[ 21   0   0]
 [  0 176   2]
 [  0   0  36]]
✅ Best model saved.


Epoch 2 - Training:   0%|          | 0/117 [00:00<?, ?it/s]

Epoch 2 - Validation:   0%|          | 0/30 [00:00<?, ?it/s]


📊 Epoch 2 Results:
Train Loss: 19.0040 | Train Acc: 0.9370 | Train F1: 0.9120
Val   Acc: 0.2723 | Val   F1: 0.4573
Confusion Matrix:
 [[ 21   0   0]
 [  0   7 171]
 [  0   0  36]]
No improvement. Patience: 1/5


Epoch 3 - Training:   0%|          | 0/117 [00:00<?, ?it/s]

Epoch 3 - Validation:   0%|          | 0/30 [00:00<?, ?it/s]


📊 Epoch 3 Results:
Train Loss: 22.8947 | Train Acc: 0.9252 | Train F1: 0.8922
Val   Acc: 0.9447 | Val   F1: 0.9364
Confusion Matrix:
 [[ 21   0   0]
 [  0 165  13]
 [  0   0  36]]
No improvement. Patience: 2/5


Epoch 4 - Training:   0%|          | 0/117 [00:00<?, ?it/s]

Epoch 4 - Validation:   0%|          | 0/30 [00:00<?, ?it/s]


📊 Epoch 4 Results:
Train Loss: 10.7937 | Train Acc: 0.9701 | Train F1: 0.9517
Val   Acc: 1.0000 | Val   F1: 1.0000
Confusion Matrix:
 [[ 21   0   0]
 [  0 178   0]
 [  0   0  36]]
✅ Best model saved.


Epoch 5 - Training:   0%|          | 0/117 [00:00<?, ?it/s]

KeyboardInterrupt: 