In [16]:
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, datasets
import pytorch_lightning as pl
from torch.utils.data import random_split, DataLoader, Subset
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split

In [17]:
# Report the PyTorch build and whether it can see a CUDA device.
print(
    f"Torch version: {torch.__version__}\n"
    f"CUDA available: {torch.cuda.is_available()}\n"
    f"torch.version.cuda: {torch.version.cuda}\n"
    f"torch.cuda.device_count(): {torch.cuda.device_count()}"
)

Torch version: 2.10.0.dev20251211+cu130
CUDA available: True
torch.version.cuda: 13.0
torch.cuda.device_count(): 1


In [18]:
print(torch.version.cuda)
# Run on the GPU whenever PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using: " + str(device))

13.0
Using: cuda


# Datasets

In [19]:
# Training pipeline: light random augmentation, then a fixed 224x224 tensor
# (the spatial size the CNN's first fully-connected layer is sized for).
transform = transforms.Compose(
    [
        transforms.RandomRotation(degrees=5),
        transforms.RandomHorizontalFlip(p=0.5),
        # Shorter side -> 224, then a centred 224x224 crop.
        transforms.Resize(size=224),
        transforms.CenterCrop(size=224),
        transforms.ToTensor(),
    ]
)

In [20]:
# Index ./dataset/ once to discover the label set; ImageFolder treats each sub-directory as one class.
dataset0 = datasets.ImageFolder(root='./dataset/', transform=None)
class_names = dataset0.classes
print(class_names, len(class_names), sep='\n')

['Mild Dementia', 'Moderate Dementia', 'Non Demented', 'Very mild Dementia']
4


In [21]:
class DataModule(pl.LightningDataModule):
    """Stratified train/test split over an ``ImageFolder`` directory.

    The image folder is indexed twice: once with the (augmenting) training
    ``transform`` and once with the deterministic ``test_transform``. A single
    stratified index split is applied to both, so every class appears in the
    test set and no image lands in both splits.

    Args:
        transform: Transform applied to training images (defaults to the
            module-level ``transform`` at class-definition time).
        test_transform: Transform applied to test images. ``None`` means
            Resize(224) + CenterCrop(224) + ToTensor, with no augmentation.
        batch_size: Batch size for both loaders.
        test_size: Fraction of images held out for testing. It is floored to
            an integer count.
        random_seed: ``random_state`` for the split, so repeated ``setup()``
            calls hold out the same images.
        root_dir: Root of the image folder (one sub-directory per class).
        num_workers: DataLoader worker processes; 0 loads in the main process.
    """

    def __init__(self, transform=transform, test_transform=None, batch_size=32, test_size=0.3, random_seed=42,
                 root_dir='./dataset/', num_workers=0):
        super().__init__()
        self.root_dir = root_dir
        self.transform = transform
        if test_transform is None:
            test_transform = transforms.Compose([
                transforms.Resize(224),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
            ])
        self.test_transform = test_transform
        self.batch_size = batch_size
        self.test_size = test_size
        self.random_seed = random_seed
        self.num_workers = num_workers

    def setup(self, stage=None):
        """Build ``self.train_dataset`` / ``self.test_dataset`` Subsets.

        ``stage`` is accepted for the Lightning hook signature but ignored:
        both splits are always built.

        Raises:
            RuntimeError: if the two scans of ``root_dir`` disagree. The split
                indices would then point at different images in each dataset.
        """
        full_train_ds = datasets.ImageFolder(root=self.root_dir, transform=self.transform)
        full_test_ds = datasets.ImageFolder(root=self.root_dir, transform=self.test_transform)
        # One index split is applied to both datasets, so they must list the same files in the same order.
        if full_train_ds.samples != full_test_ds.samples:
            raise RuntimeError(f"{self.root_dir!r} changed while it was being indexed; re-run setup().")
        n_data = len(full_train_ds)
        # Explicit floored count, unchanged, so the split matches earlier runs (and the saved checkpoint's training data).
        n_test = int(self.test_size * n_data)
        train_idx, test_idx = train_test_split(
            list(range(n_data)),
            test_size=n_test,
            stratify=full_train_ds.targets,
            random_state=self.random_seed,
        )
        self.train_dataset = Subset(full_train_ds, train_idx)
        self.test_dataset = Subset(full_test_ds, test_idx)

    def _make_loader(self, dataset, shuffle):
        """Wrap ``dataset`` in a DataLoader using this module's batch size and worker settings."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            # Reuse worker processes across epochs; only valid when workers exist.
            persistent_workers=self.num_workers > 0,
        )

    def train_dataloader(self):
        """Shuffled loader over the training split (requires ``setup()`` first)."""
        return self._make_loader(self.train_dataset, shuffle=True)

    def test_dataloader(self):
        """Loader over the held-out split, in fixed order (requires ``setup()`` first)."""
        return self._make_loader(self.test_dataset, shuffle=False)

# CNN

In [22]:
class ConvulationalNetwork(pl.LightningModule):
    """Small LeNet-style CNN for 3x224x224 images.

    The network has two conv + 2x2 max-pool stages followed by four
    fully-connected layers. ``forward`` returns per-class log-probabilities
    (``log_softmax``), so losses use ``F.nll_loss``. ``F.cross_entropy``, used
    previously, applies ``log_softmax`` a second time.

    Args:
        num_classes: Number of output classes. ``None`` uses
            ``len(class_names)`` (the notebook-level class list).
        lr: Adam learning rate. The default of 0.01 reproduces the original run.
        class_weights: Optional sequence of ``num_classes`` per-class weights
            for the training loss, to counter class imbalance. ``None`` means
            unweighted. Validation and test losses are always unweighted.

    NOTE(review): the checkpointed run predicts "Non Demented" for every test
    image (see the classification report below). Consider ``lr=1e-3`` and/or
    inverse-frequency ``class_weights`` before retraining.
    """

    def __init__(self, num_classes=None, lr=0.01, class_weights=None):
        super().__init__()
        # Record constructor args in checkpoints so load_from_checkpoint rebuilds the same configuration.
        self.save_hyperparameters()
        if num_classes is None:
            num_classes = len(class_names)
        self.lr = lr
        # Plain floats rather than a tensor buffer, so state_dicts without weights still load strictly.
        self.class_weights = None if class_weights is None else [float(w) for w in class_weights]
        self.conv1 = nn.Conv2d(3, 6, 3, 1)
        self.conv2 = nn.Conv2d(6, 16, 3, 1)
        # 224 -conv3-> 222 -pool2-> 111 -conv3-> 109 -pool2-> 54, with 16 channels.
        self.fc1 = nn.Linear(16 * 54 * 54, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 20)
        self.fc4 = nn.Linear(20, num_classes)

    def forward(self, X):
        """Map a (batch, 3, 224, 224) image batch to (batch, num_classes) log-probabilities."""
        X = F.max_pool2d(F.relu(self.conv1(X)), 2, 2)
        X = F.max_pool2d(F.relu(self.conv2(X)), 2, 2)
        # Flatten per sample. A wrong input size now fails loudly at fc1 instead of being re-batched by view(-1, ...).
        X = torch.flatten(X, 1)
        X = F.relu(self.fc1(X))
        X = F.relu(self.fc2(X))
        X = F.relu(self.fc3(X))
        return F.log_softmax(self.fc4(X), dim=1)

    def configure_optimizers(self):
        """Adam over all parameters with learning rate ``self.lr``."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    def _shared_step(self, batch, stage):
        """Compute NLL loss and accuracy for one batch, log them as ``{stage}_loss`` / ``{stage}_acc``, and return the loss."""
        X, y = batch
        log_probs = self(X)
        weight = None
        if stage == "train" and self.class_weights is not None:
            weight = torch.tensor(self.class_weights, dtype=log_probs.dtype, device=log_probs.device)
        loss = F.nll_loss(log_probs, y, weight=weight)
        # Kept as a tensor (no .item()) to avoid a device sync every step.
        acc = (log_probs.argmax(dim=1) == y).float().mean()
        batch_size = y.size(0)
        self.log(f"{stage}_loss", loss, batch_size=batch_size)
        self.log(f"{stage}_acc", acc, batch_size=batch_size)
        return loss

    def training_step(self, train_batch, batch_idx):
        """Return the (optionally class-weighted) training loss for one batch."""
        return self._shared_step(train_batch, "train")

    def validation_step(self, val_batch, batch_idx):
        """Log validation loss/accuracy for one batch."""
        self._shared_step(val_batch, "val")

    def test_step(self, test_batch, batch_idx):
        """Log test loss/accuracy for one batch."""
        self._shared_step(test_batch, "test")

In [23]:
if __name__ == '__main__':
    datamodule = DataModule()
    # Seed Python, NumPy and torch (plus DataLoader workers) so weight init, augmentation
    # and shuffling are reproducible. Reuse the split seed so there is a single knob.
    pl.seed_everything(datamodule.random_seed, workers=True)
    model = ConvulationalNetwork()
    trainer = pl.Trainer(max_epochs=20)
    # fit() calls datamodule.setup() itself; no manual call is needed beforehand.
    trainer.fit(model=model, datamodule=datamodule)
    datamodule.setup(stage='test')
    test_loader = datamodule.test_dataloader()
    # Restore the checkpoint written by fit() (as in the logged run) and evaluate on the held-out split.
    trainer.test(model=model, dataloaders=test_loader, ckpt_path="best")

💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
c:\Users\Alvin\Downloads\PythonProjects\RM-Project\.venv\Lib\site-packages\pytorch_lightning\trainer\connectors\logger_connector\logger_connector.py:76: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default
c:\Users\Alvin\Downloads\PythonProjects\RM-Project\.venv\Lib\site-packages\pytorch_lightning\trainer\configuration_validator.py:70: You defined a `valida

Epoch 19: 100%|██████████| 1891/1891 [02:15<00:00, 13.92it/s, v_num=5]    

`Trainer.fit` stopped: `max_epochs=20` reached.


Epoch 19: 100%|██████████| 1891/1891 [02:16<00:00, 13.89it/s, v_num=5]


Restoring states from the checkpoint path at c:\Users\Alvin\Downloads\PythonProjects\RM-Project\lightning_logs\version_5\checkpoints\epoch=19-step=37820.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at c:\Users\Alvin\Downloads\PythonProjects\RM-Project\lightning_logs\version_5\checkpoints\epoch=19-step=37820.ckpt
c:\Users\Alvin\Downloads\PythonProjects\RM-Project\.venv\Lib\site-packages\pytorch_lightning\utilities\_pytree.py:21: `isinstance(treespec, LeafSpec)` is deprecated, use `isinstance(treespec, TreeSpec) and treespec.is_leaf()` instead.
c:\Users\Alvin\Downloads\PythonProjects\RM-Project\.venv\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:434: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.


Testing DataLoader 0: 100%|██████████| 811/811 [04:25<00:00,  3.06it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_acc            0.7777177691459656
        test_loss            0.681828498840332
────────────────

# Load Model

In [24]:
# Lightning run to evaluate: the `version_N` directory written by the training cell (`v_num=5` in its progress bar).
latest_version = 5
checkpoint_dir = Path('./lightning_logs') / f'version_{latest_version}' / 'checkpoints'
# Take the newest .ckpt instead of hard-coding `epoch=19-step=37820`, which changes whenever
# max_epochs, batch_size or the dataset size changes. If none exists, keep the expected name
# so load_from_checkpoint reports exactly which file is missing.
found_checkpoints = sorted(checkpoint_dir.glob('*.ckpt'), key=lambda p: p.stat().st_mtime)
checkpoint_path = str(found_checkpoints[-1] if found_checkpoints else checkpoint_dir / 'epoch=19-step=37820.ckpt')

In [25]:
# Use a deterministic test transform (no random augmentations)
test_transform = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
])
# Default seed/test_size as in training, so setup() rebuilds the same held-out split.
datamodule = DataModule(transform=transform, test_transform=test_transform)
datamodule.setup(stage='test')
# map_location lets a checkpoint saved on GPU load on a CPU-only machine.
model = ConvulationalNetwork.load_from_checkpoint(checkpoint_path, map_location=device)
model.to(device)
model.eval()
y_true = []
y_pred = []
with torch.no_grad():
    for test_images, test_labels in datamodule.test_dataloader():
        preds = model(test_images.to(device)).argmax(dim=1)
        # Labels are only needed on the host; no round-trip through the GPU.
        y_true.extend(test_labels.tolist())
        y_pred.extend(preds.cpu().tolist())

# labels= pins one row per class in class_names order. zero_division=0 makes the value for
# never-predicted classes explicit instead of emitting UndefinedMetricWarning.
print(classification_report(
    y_true,
    y_pred,
    labels=list(range(len(class_names))),
    target_names=class_names,
    digits=4,
    zero_division=0,
))

                    precision    recall  f1-score   support

     Mild Dementia     0.0000    0.0000    0.0000      1501
 Moderate Dementia     0.0000    0.0000    0.0000       146
      Non Demented     0.7777    1.0000    0.8750     20167
Very mild Dementia     0.0000    0.0000    0.0000      4117

          accuracy                         0.7777     25931
         macro avg     0.1944    0.2500    0.2187     25931
      weighted avg     0.6048    0.7777    0.6805     25931



  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
