In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, datasets
import pytorch_lightning as pl
from torch.utils.data import DataLoader, Subset
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split

In [3]:
print("Torch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
print("torch.version.cuda:", torch.version.cuda)
print("torch.cuda.device_count():", torch.cuda.device_count())

Torch version: 2.10.0.dev20251211+cu130
CUDA available: True
torch.version.cuda: 13.0
torch.cuda.device_count(): 1


In [4]:
print(torch.version.cuda)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using: {device}")

13.0
Using: cuda


#  Datasets

In [5]:
transform = transforms.Compose([
    transforms.RandomRotation(5),
    transforms.RandomHorizontalFlip(),
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor()
])

In [6]:
dataset0  = datasets.ImageFolder(root='./dataset/',transform=None)
class_names = dataset0.classes
print(class_names)
print(len(class_names))

['Mild Dementia', 'Moderate Dementia', 'Non Demented', 'Very mild Dementia']
4


In [7]:
class DataModule(pl.LightningDataModule):
    def __init__(self, transform = transform, test_transform=None, batch_size = 32, test_size=0.3, random_seed=42, max_samples_non_demented=None, non_demented_label='Non Demented'):
        super().__init__()
        self.root_dir = './dataset/'
        self.transform = transform
        if test_transform is None:
            self.test_transform = transforms.Compose([
                transforms.Resize(224),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
            ])
        else:
            self.test_transform = test_transform
        self.batch_size = batch_size
        self.test_size = test_size
        self.random_seed = random_seed
        # Limit for Non Demented class in the training pool (None => no limit)
        self.max_samples_non_demented = max_samples_non_demented
        self.non_demented_label = non_demented_label
    
    def setup(self,stage=None):
        # Prepare separate train/test datasets with stratified split so all classes appear in test set
        full_train_ds = datasets.ImageFolder(root=self.root_dir, transform=self.transform)
        full_test_ds = datasets.ImageFolder(root=self.root_dir, transform=self.test_transform)

        # Optionally undersample the over-represented "Non Demented" class in the training pool
        indices = list(range(len(full_train_ds)))
        targets = full_train_ds.targets

        # Print distribution before sampling (useful for debugging)
        try:
            from collections import Counter
            orig_counts = Counter(targets)
            print("Original class counts:", {class_names[k]: v for k, v in orig_counts.items()})
        except Exception:
            pass

        subset_indices = indices
        subset_targets = targets
        if self.max_samples_non_demented is not None and self.non_demented_label in class_names:
            nd_idx = class_names.index(self.non_demented_label)
            # indices belonging to Non Demented
            nd_indices = [i for i, t in enumerate(targets) if t == nd_idx]
            keep_nd = min(len(nd_indices), self.max_samples_non_demented)
            import random
            rnd = random.Random(self.random_seed)
            kept_nd_indices = set(rnd.sample(nd_indices, keep_nd)) if keep_nd < len(nd_indices) else set(nd_indices)

            # Build subset indices keeping all other classes and the sampled Non Demented
            kept_indices = [i for i in indices if (targets[i] != nd_idx) or (i in kept_nd_indices)]
            subset_indices = kept_indices
            subset_targets = [targets[i] for i in subset_indices]

            try:
                print(f"After undersampling '{self.non_demented_label}':", {class_names[k]: v for k, v in Counter(subset_targets).items()})
            except Exception:
                pass

        # Recompute split sizes based on (possibly) pruned subset
        n_data = len(subset_indices)
        n_test = int(self.test_size * n_data)

        train_idx, test_idx = train_test_split(subset_indices, test_size=n_test, stratify=subset_targets, random_state=self.random_seed)
        print(f"Train: {train_idx}\nTest: {test_idx}")

        train_subset = Subset(full_train_ds, train_idx)
        test_subset = Subset(full_test_ds, test_idx)

        self.train_dataset = DataLoader(train_subset,batch_size=self.batch_size, shuffle=True)
        self.test_dataset = DataLoader(test_subset,batch_size=self.batch_size,shuffle=False)

    def train_dataloader(self):
        return self.train_dataset
    
    def test_dataloader(self):
        return self.test_dataset

# CNN

In [None]:
class ConvulationalNetwork(pl.LightningModule):
    def __init__(self):
        super(ConvulationalNetwork,self).__init__()
        self.conv1 = nn.Conv2d(3,6,3,1)
        self.conv2 = nn.Conv2d(6,16,3,1)
        self.fc1 = nn.Linear(16*54*54,120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 20)
        self.fc4 = nn.Linear(20, len(class_names))
        
    def forward(self, X):
        X = F.relu(self.conv1(X))
        X = F.max_pool2d(X, 2, 2)
        X = F.relu(self.conv2(X))
        X = F.max_pool2d(X, 2, 2)
        X = X.view(-1, 16 * 54 * 54)
        X = F.relu(self.fc1(X))
        X = F.relu(self.fc2(X))
        X = F.relu(self.fc3(X))
        X = self.fc4(X)
        return F.log_softmax(X, dim=1)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=0.01)
        return optimizer

    def training_step(self, train_batch, batch_idx):
        X, y = train_batch
        y_hat = self(X)
        loss = F.cross_entropy(y_hat, y)
        pred = y_hat.argmax(dim=1, keepdim=True)
        acc = pred.eq(y.view_as(pred)).sum().item() / y.shape[0]
        self.log("train_loss", loss)
        self.log("train_acc", acc)
        return loss

    def validation_step(self, val_batch, batch_idx):
        X, y = val_batch
        y_hat = self(X)
        loss = F.cross_entropy(y_hat, y)
        pred = y_hat.argmax(dim=1, keepdim=True)
        acc = pred.eq(y.view_as(pred)).sum().item() / y.shape[0]
        self.log("val_loss", loss)
        self.log("val_acc", acc)

    def test_step(self, test_batch, batch_idx):
        X, y = test_batch
        y_hat = self(X)
        loss = F.cross_entropy(y_hat, y)
        pred = y_hat.argmax(dim=1, keepdim=True)
        acc = pred.eq(y.view_as(pred)).sum().item() / y.shape[0]
        self.log("test_loss", loss)
        self.log("test_acc", acc)

In [9]:
if __name__ == '__main__':
    datamodule = DataModule()
    datamodule.setup()
    model = ConvulationalNetwork()
    trainer = pl.Trainer(max_epochs=20)
    trainer.fit(model = model, datamodule= datamodule, )
    datamodule.setup(stage='test')
    test_loader = datamodule.test_dataloader()
    trainer.test(dataloaders=test_loader)

ðŸ’¡ Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
c:\Users\Alvin\Downloads\PythonProjects\RM-Project\.venv\Lib\site-packages\pytorch_lightning\trainer\connectors\logger_connector\logger_connector.py:76: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default
c:\Users\Alvin\Downloads\PythonProjects\RM-Project\.venv\Lib\site-packages\pytorch_lightning\trainer\configuration_validator.py:70: You defined a `valida

Original class counts: {'Mild Dementia': 5002, 'Moderate Dementia': 488, 'Non Demented': 67222, 'Very mild Dementia': 13725}
Train: [51669, 85304, 57860, 15929, 24415, 53517, 325, 65214, 85631, 11515, 23241, 25928, 81557, 17247, 80135, 62280, 74806, 79046, 69611, 85502, 86247, 47633, 5570, 28528, 32621, 52986, 49460, 54511, 20469, 48930, 16394, 71656, 25064, 33381, 52608, 63822, 73827, 8884, 22791, 34792, 59438, 935, 35379, 64158, 80030, 9137, 77801, 53107, 2123, 16296, 65956, 40905, 15787, 33324, 24710, 31129, 79217, 32268, 39296, 28688, 7643, 78606, 42544, 65692, 47419, 20360, 70358, 55194, 65309, 60316, 34393, 39864, 38599, 2488, 66032, 26550, 26100, 7768, 47039, 82253, 37580, 86026, 83952, 68530, 44705, 46411, 70491, 55524, 51240, 44634, 55719, 21918, 12247, 51161, 60121, 78840, 63582, 54343, 12595, 1667, 13512, 53939, 27191, 54035, 36786, 44579, 78076, 69984, 12892, 59142, 70633, 62869, 66763, 57956, 14154, 9100, 18843, 12715, 41599, 71957, 57152, 68291, 81348, 24303, 1978, 17486,

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type   | Params | Mode  | FLOPs
-------------------------------------------------
0 | conv1 | Conv2d | 168    | train | 0    
1 | conv2 | Conv2d | 880    | train | 0    
2 | fc1   | Linear | 5.6 M  | train | 0    
3 | fc2   | Linear | 10.2 K | train | 0    
4 | fc3   | Linear | 1.7 K  | train | 0    
5 | fc4   | Linear | 84     | train | 0    
-------------------------------------------------
5.6 M     Trainable params
0         Non-trainable params
5.6 M     Total params
22.447    Total estimated model params size (MB)
6         Modules in train mode
0         Modules in eval mode
0         Total Flops
c:\Users\Alvin\Downloads\PythonProjects\RM-Project\.venv\Lib\site-packages\pytorch_lightning\utilities\_pytree.py:21: `isinstance(treespec, LeafSpec)` is deprecated, use `isinstance(treespec, TreeSpec) and treespec.is_leaf()` instead.
c:\Users\Alvin\Downloads\PythonProjects\RM-Project\.venv\Lib\site-packages\pytorch_lightning\traine

Original class counts: {'Mild Dementia': 5002, 'Moderate Dementia': 488, 'Non Demented': 67222, 'Very mild Dementia': 13725}
Train: [51669, 85304, 57860, 15929, 24415, 53517, 325, 65214, 85631, 11515, 23241, 25928, 81557, 17247, 80135, 62280, 74806, 79046, 69611, 85502, 86247, 47633, 5570, 28528, 32621, 52986, 49460, 54511, 20469, 48930, 16394, 71656, 25064, 33381, 52608, 63822, 73827, 8884, 22791, 34792, 59438, 935, 35379, 64158, 80030, 9137, 77801, 53107, 2123, 16296, 65956, 40905, 15787, 33324, 24710, 31129, 79217, 32268, 39296, 28688, 7643, 78606, 42544, 65692, 47419, 20360, 70358, 55194, 65309, 60316, 34393, 39864, 38599, 2488, 66032, 26550, 26100, 7768, 47039, 82253, 37580, 86026, 83952, 68530, 44705, 46411, 70491, 55524, 51240, 44634, 55719, 21918, 12247, 51161, 60121, 78840, 63582, 54343, 12595, 1667, 13512, 53939, 27191, 54035, 36786, 44579, 78076, 69984, 12892, 59142, 70633, 62869, 66763, 57956, 14154, 9100, 18843, 12715, 41599, 71957, 57152, 68291, 81348, 24303, 1978, 17486,


Detected KeyboardInterrupt, attempting graceful shutdown ...


SystemExit: 1

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)


# Load Model

In [None]:
latest_version = 5 
checkpoint_path = f'./lightning_logs/version_{latest_version}/checkpoints/epoch=19-step=37820.ckpt'

In [None]:
# Use a deterministic test transform (no random augmentations)
test_transform = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
])
datamodule = DataModule(transform=transform, test_transform=test_transform)
datamodule.setup(stage='test')
model  = ConvulationalNetwork.load_from_checkpoint(checkpoint_path)
model.to(device)
model.eval()
y_true=[]
y_pred=[]
with torch.no_grad():
    for test_images, test_labels in datamodule.test_dataloader():
        test_images = test_images.to(device)
        test_labels = test_labels.to(device)
        outputs = model(test_images)
        preds = outputs.argmax(dim=1)
        y_true.extend(test_labels.cpu().tolist())
        y_pred.extend(preds.cpu().tolist())

print(classification_report(y_true,y_pred,target_names=class_names,digits=4))

                    precision    recall  f1-score   support

     Mild Dementia     0.0000    0.0000    0.0000      1501
 Moderate Dementia     0.0000    0.0000    0.0000       146
      Non Demented     0.7777    1.0000    0.8750     20167
Very mild Dementia     0.0000    0.0000    0.0000      4117

          accuracy                         0.7777     25931
         macro avg     0.1944    0.2500    0.2187     25931
      weighted avg     0.6048    0.7777    0.6805     25931



  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
