# 2 Layer Neural Network

## Imports

In [1]:
import os
import pandas as pd

## Import data and create .csv

In [2]:
# Directory to MRI Images
base_dir = '/Users/benrandoing/Downloads/archive/AugmentedAlzheimerDataset'
classes = ['MildDemented', 'ModerateDemented', 'NonDemented', 'VeryMildDemented']

if not os.path.exists('dataset.csv'):
    data = []

    for label, class_name in enumerate(classes):
        class_dir = os.path.join(base_dir, class_name)
        for image_name in os.listdir(class_dir):
            image_path = os.path.join(class_dir, image_name)
            data.append([image_path, label])

    df = pd.DataFrame(data, columns=['path', 'label'])
    df.to_csv('dataset.csv', index=False)

## Data Loading and Pre-Processing

In [3]:
import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset, random_split
from PIL import Image

class CustomDataset(Dataset):
    def __init__(self, csv_file, transform=None):
        self.dataframe = pd.read_csv(csv_file)
        self.dataframe = self.dataframe.sample(frac=1).reset_index(drop=True)  # Shuffle the dataframe
        self.transform = transform

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        img_name = self.dataframe.iloc[idx, 0]
        image = Image.open(img_name)
        label = int(self.dataframe.iloc[idx, 1])

        if self.transform:
            image = self.transform(image)

        return image, label

transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize image to 224 X 224
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    #transforms.Lambda(lambda x: x.view(-1))  # Flatten the image
])

# Load the entire dataset
dataset = CustomDataset(csv_file='dataset.csv', transform=transform)

# Split the dataset into train, validation, and test sets
train_size = int(0.7 * len(dataset))
val_size = int(0.15 * len(dataset))
test_size = len(dataset) - train_size - val_size
train_dataset, val_dataset, test_dataset = random_split(dataset, [train_size, val_size, test_size])

In [4]:
print(f'Training Samples: {len(train_dataset)}')
print(f'Validation Samples: {len(val_dataset)}')
print(f'Test Samples: {len(test_dataset)}')

Training Samples: 23788
Validation Samples: 5097
Test Samples: 5099


In [5]:
print(train_dataset)

<torch.utils.data.dataset.Subset object at 0x7f9d1261b400>


In [6]:
import matplotlib.pyplot as plt
#plt.plot(train_dataset[0][0])
#plt.show()

## Pytorch Logistic Regression Model

In [7]:
class FiveLayerNN(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim1, hidden_dim2, hidden_dim3, hidden_dim4, output_dim):
        super(FiveLayerNN, self).__init__()
        self.fc1 = torch.nn.Linear(input_dim, hidden_dim1)
        self.fc2 = torch.nn.Linear(hidden_dim1, hidden_dim2)
        self.fc3 = torch.nn.Linear(hidden_dim2, hidden_dim3)
        self.fc4 = torch.nn.Linear(hidden_dim3, hidden_dim4)
        self.fc5 = torch.nn.Linear(hidden_dim4, output_dim)

    def forward(self, x):
        #x = x.view(-1, 224 * 224 * 3)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = torch.relu(self.fc3(x))
        x = torch.relu(self.fc4(x))
        x = self.fc5(x)
        return x

input_dim = 224 * 224 * 3  # Assuming the images are RGB
hidden_dim1 = 512
hidden_dim2 = 256
hidden_dim3 = 128
hidden_dim4 = 64
output_dim = 4  # Four classes

#model = FiveLayerNN(input_dim, hidden_dim1, hidden_dim2, hidden_dim3, hidden_dim4, output_dim)
import torchvision
from torchvision.models import resnet18, ResNet18_Weights
model = torchvision.models.resnet18(weights=ResNet18_Weights.DEFAULT)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)

model.fc = torch.nn.Linear(model.fc.in_features, output_dim)


## Training Loop

In [None]:
learning_rate = 0.001
num_epochs = 20
batch_size = 64

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Define the loss function and optimizer
criterion = torch.nn.CrossEntropyLoss()

import torch.optim
optimizer = torch.optim.RMSprop(model.parameters(), lr=learning_rate)

#optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)


# Training loop
for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0

    i = 0
    for inputs, labels in train_loader:
        i = i + 1
        print("Training continues..", total_loss, "current i = ", i)    
        optimizer.zero_grad()
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        loss = criterion(outputs, labels)        
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    #lr decay
    #lr_scheduler.step()

    print(f"Iteration {epoch} - Loss = {total_loss / len(train_loader)}")

Training continues.. 0.0 current i =  1
Training continues.. 1.4263499975204468 current i =  2
Training continues.. 5.120266556739807 current i =  3
Training continues.. 8.428125977516174 current i =  4
Training continues.. 10.543232798576355 current i =  5
Training continues.. 12.810245633125305 current i =  6
Training continues.. 14.947182774543762 current i =  7
Training continues.. 16.718878746032715 current i =  8
Training continues.. 18.46541678905487 current i =  9
Training continues.. 19.88349986076355 current i =  10
Training continues.. 21.31885039806366 current i =  11
Training continues.. 22.89302897453308 current i =  12
Training continues.. 24.299073576927185 current i =  13
Training continues.. 25.665361642837524 current i =  14
Training continues.. 27.026344180107117 current i =  15
Training continues.. 28.625990986824036 current i =  16
Training continues.. 29.91148614883423 current i =  17
Training continues.. 31.349323630332947 current i =  18
Training continues.. 32

Training continues.. 183.3124237060547 current i =  149
Training continues.. 184.1912700533867 current i =  150
Training continues.. 185.0742319226265 current i =  151
Training continues.. 185.83030462265015 current i =  152
Training continues.. 186.66683346033096 current i =  153
Training continues.. 187.74933570623398 current i =  154
Training continues.. 188.72585892677307 current i =  155
Training continues.. 189.67435824871063 current i =  156
Training continues.. 190.4920238852501 current i =  157
Training continues.. 191.34994304180145 current i =  158
Training continues.. 192.24146115779877 current i =  159
Training continues.. 193.18362218141556 current i =  160
Training continues.. 193.91529697179794 current i =  161
Training continues.. 194.75204783678055 current i =  162
Training continues.. 195.5033985376358 current i =  163
Training continues.. 196.37418788671494 current i =  164
Training continues.. 197.19197416305542 current i =  165
Training continues.. 197.88320571184

Training continues.. 293.7495861053467 current i =  294
Training continues.. 294.48993796110153 current i =  295
Training continues.. 295.14964908361435 current i =  296
Training continues.. 295.76019990444183 current i =  297
Training continues.. 296.57848542928696 current i =  298
Training continues.. 297.3032215833664 current i =  299
Training continues.. 297.9525447487831 current i =  300
Training continues.. 298.73300701379776 current i =  301
Training continues.. 299.307635307312 current i =  302
Training continues.. 299.94731307029724 current i =  303
Training continues.. 300.61159855127335 current i =  304
Training continues.. 301.2545365691185 current i =  305
Training continues.. 301.79890835285187 current i =  306
Training continues.. 302.4129822254181 current i =  307
Training continues.. 303.0713435411453 current i =  308
Training continues.. 303.7832933664322 current i =  309
Training continues.. 304.5155639052391 current i =  310
Training continues.. 305.27224987745285 c

Training continues.. 39.36983558535576 current i =  69
Training continues.. 39.80686354637146 current i =  70
Training continues.. 40.23073253035545 current i =  71
Training continues.. 40.67722341418266 current i =  72
Training continues.. 41.22879698872566 current i =  73
Training continues.. 41.916723161935806 current i =  74
Training continues.. 42.74812349677086 current i =  75
Training continues.. 43.442818373441696 current i =  76
Training continues.. 44.3321273624897 current i =  77
Training continues.. 44.98985508084297 current i =  78
Training continues.. 45.601724058389664 current i =  79
Training continues.. 46.10538312792778 current i =  80
Training continues.. 46.56560042500496 current i =  81
Training continues.. 47.22177121043205 current i =  82
Training continues.. 47.85394695401192 current i =  83
Training continues.. 48.31217882037163 current i =  84
Training continues.. 48.79040586948395 current i =  85
Training continues.. 49.44301116466522 current i =  86
Training

Training continues.. 113.29497250914574 current i =  216
Training continues.. 113.91010156273842 current i =  217
Training continues.. 114.36842009425163 current i =  218
Training continues.. 114.68945625424385 current i =  219
Training continues.. 115.25883111357689 current i =  220
Training continues.. 115.65998527407646 current i =  221
Training continues.. 116.17230948805809 current i =  222
Training continues.. 116.48448231816292 current i =  223
Training continues.. 116.88683071732521 current i =  224
Training continues.. 117.31376591324806 current i =  225
Training continues.. 117.77589213848114 current i =  226
Training continues.. 118.11076191067696 current i =  227
Training continues.. 118.54855480790138 current i =  228
Training continues.. 118.94482830166817 current i =  229
Training continues.. 119.23791888356209 current i =  230
Training continues.. 119.62017717957497 current i =  231
Training continues.. 120.04772049188614 current i =  232
Training continues.. 120.607810

Training continues.. 169.43899676203728 current i =  361
Training continues.. 169.7824985086918 current i =  362
Training continues.. 170.24552205204964 current i =  363
Training continues.. 170.60231897234917 current i =  364
Training continues.. 170.88507282733917 current i =  365
Training continues.. 171.07451730966568 current i =  366
Training continues.. 171.29620280861855 current i =  367
Training continues.. 171.6477279663086 current i =  368
Training continues.. 171.87743411958218 current i =  369
Training continues.. 172.23293574154377 current i =  370
Training continues.. 172.63215158879757 current i =  371
Training continues.. 172.89146776497364 current i =  372
Iteration 1 - Loss = 0.4655208310052272
Training continues.. 0.0 current i =  1
Training continues.. 0.4144386351108551 current i =  2
Training continues.. 0.5976619273424149 current i =  3
Training continues.. 0.9036467224359512 current i =  4
Training continues.. 1.251486673951149 current i =  5
Training continues.

Training continues.. 36.14903808385134 current i =  136
Training continues.. 36.36758612841368 current i =  137
Training continues.. 36.6295011267066 current i =  138
Training continues.. 36.80062898248434 current i =  139
Training continues.. 36.976204328238964 current i =  140
Training continues.. 37.22168356925249 current i =  141
Training continues.. 37.49145341664553 current i =  142
Training continues.. 37.62768816202879 current i =  143
Training continues.. 37.83570263534784 current i =  144
Training continues.. 38.103677697479725 current i =  145
Training continues.. 38.47539187222719 current i =  146
Training continues.. 38.72031632810831 current i =  147
Training continues.. 38.867501862347126 current i =  148
Training continues.. 39.30091122537851 current i =  149
Training continues.. 39.473563857376575 current i =  150
Training continues.. 39.69163876026869 current i =  151
Training continues.. 40.04173978418112 current i =  152
Training continues.. 40.2431346103549 current

Training continues.. 69.08722611516714 current i =  282
Training continues.. 69.21466960757971 current i =  283
Training continues.. 69.41847639530897 current i =  284
Training continues.. 69.5885721668601 current i =  285
Training continues.. 69.75077164918184 current i =  286
Training continues.. 69.94692077487707 current i =  287
Training continues.. 70.09913142770529 current i =  288
Training continues.. 70.1760020032525 current i =  289
Training continues.. 70.29681654274464 current i =  290
Training continues.. 70.4418064802885 current i =  291
Training continues.. 70.53558108955622 current i =  292
Training continues.. 70.60388338565826 current i =  293
Training continues.. 70.72232811897993 current i =  294
Training continues.. 70.8886965289712 current i =  295
Training continues.. 71.0151039287448 current i =  296
Training continues.. 71.14447254687548 current i =  297
Training continues.. 71.22470985352993 current i =  298
Training continues.. 71.43717458844185 current i =  2

Training continues.. 7.38602520339191 current i =  58
Training continues.. 7.4418116342276335 current i =  59
Training continues.. 7.721503568813205 current i =  60
Training continues.. 8.160542769357562 current i =  61
Training continues.. 8.275759099051356 current i =  62
Training continues.. 8.32932079397142 current i =  63
Training continues.. 8.428493602201343 current i =  64
Training continues.. 8.499560778960586 current i =  65
Training continues.. 8.565170934423804 current i =  66
Training continues.. 8.652436038479209 current i =  67
Training continues.. 8.691313883289695 current i =  68
Training continues.. 8.73435296677053 current i =  69
Training continues.. 8.811100458726287 current i =  70
Training continues.. 8.884149899706244 current i =  71
Training continues.. 8.922008180990815 current i =  72
Training continues.. 8.97881349362433 current i =  73
Training continues.. 9.111321905627847 current i =  74
Training continues.. 9.135503748431802 current i =  75
Training cont

Training continues.. 27.219597302377224 current i =  204
Training continues.. 27.267809715121984 current i =  205
Training continues.. 27.393414687365294 current i =  206
Training continues.. 27.46031006798148 current i =  207
Training continues.. 27.56142073497176 current i =  208
Training continues.. 27.679152186959982 current i =  209
Training continues.. 27.820708122104406 current i =  210
Training continues.. 27.858247946947813 current i =  211
Training continues.. 27.937841709703207 current i =  212
Training continues.. 28.007064532488585 current i =  213
Training continues.. 28.07944344356656 current i =  214
Training continues.. 28.333222407847643 current i =  215
Training continues.. 28.356533970683813 current i =  216
Training continues.. 28.551719900220633 current i =  217
Training continues.. 28.600446987897158 current i =  218
Training continues.. 28.868714798241854 current i =  219
Training continues.. 28.995180550962687 current i =  220
Training continues.. 29.0511132031

Training continues.. 47.298360504209995 current i =  350
Training continues.. 47.409863986074924 current i =  351
Training continues.. 47.43405199982226 current i =  352
Training continues.. 47.54083116538823 current i =  353
Training continues.. 47.630503276363015 current i =  354
Training continues.. 47.72581229545176 current i =  355
Training continues.. 47.84404354728758 current i =  356
Training continues.. 47.90073680691421 current i =  357
Training continues.. 47.961995193734765 current i =  358
Training continues.. 48.08354468084872 current i =  359
Training continues.. 48.34317890740931 current i =  360
Training continues.. 48.50391101278365 current i =  361
Training continues.. 48.69423119165003 current i =  362
Training continues.. 48.76933057047427 current i =  363
Training continues.. 48.852411234751344 current i =  364
Training continues.. 49.04600844718516 current i =  365
Training continues.. 49.154876040294766 current i =  366
Training continues.. 49.30193886347115 cur

Training continues.. 10.85827820841223 current i =  126
Training continues.. 10.894641318358481 current i =  127
Training continues.. 10.948317998088896 current i =  128
Training continues.. 11.065633953548968 current i =  129
Training continues.. 11.108565990813076 current i =  130
Training continues.. 11.136850350536406 current i =  131
Training continues.. 11.217252128757536 current i =  132
Training continues.. 11.337877676822245 current i =  133
Training continues.. 11.366178666241467 current i =  134
Training continues.. 11.42918207962066 current i =  135
Training continues.. 11.440186542458832 current i =  136
Training continues.. 11.579512101598084 current i =  137
Training continues.. 11.647416417486966 current i =  138
Training continues.. 11.700343758799136 current i =  139
Training continues.. 11.766680807806551 current i =  140
Training continues.. 11.965469346381724 current i =  141
Training continues.. 12.095379025675356 current i =  142
Training continues.. 12.259339273

Training continues.. 22.411455892026424 current i =  271
Training continues.. 22.453161120414734 current i =  272
Training continues.. 22.47893986478448 current i =  273
Training continues.. 22.496838346123695 current i =  274
Training continues.. 22.54138421267271 current i =  275
Training continues.. 22.68954337388277 current i =  276
Training continues.. 22.9338633492589 current i =  277
Training continues.. 23.071681313216686 current i =  278
Training continues.. 23.116086658090353 current i =  279
Training continues.. 23.169192291796207 current i =  280
Training continues.. 23.255671232938766 current i =  281
Training continues.. 23.32889563590288 current i =  282
Training continues.. 23.39395473897457 current i =  283
Training continues.. 23.41503018513322 current i =  284
Training continues.. 23.537308763712645 current i =  285
Training continues.. 23.554905211552978 current i =  286
Training continues.. 23.566326931118965 current i =  287
Training continues.. 23.58576206304133 

Training continues.. 1.8978794128634036 current i =  44
Training continues.. 1.9212970775552094 current i =  45
Training continues.. 2.0002861809916794 current i =  46
Training continues.. 2.0988921341486275 current i =  47
Training continues.. 2.145326551515609 current i =  48
Training continues.. 2.1533113536424935 current i =  49
Training continues.. 2.1618494209833443 current i =  50
Training continues.. 2.2358464491553605 current i =  51
Training continues.. 2.28062003524974 current i =  52
Training continues.. 2.323721725959331 current i =  53
Training continues.. 2.350558251608163 current i =  54
Training continues.. 2.3775996253825724 current i =  55
Training continues.. 2.3979181335307658 current i =  56
Training continues.. 2.444048708770424 current i =  57
Training continues.. 2.4829732277430594 current i =  58
Training continues.. 2.5085088689811528 current i =  59
Training continues.. 2.559577312786132 current i =  60
Training continues.. 2.5973539580591023 current i =  61

Training continues.. 11.797932791989297 current i =  191
Training continues.. 11.852777648251504 current i =  192
Training continues.. 11.941053296905011 current i =  193
Training continues.. 12.035048361402005 current i =  194
Training continues.. 12.092915325891227 current i =  195
Training continues.. 12.134564898442477 current i =  196
Training continues.. 12.215563669335097 current i =  197
Training continues.. 12.29970571352169 current i =  198
Training continues.. 12.350403800141066 current i =  199
Training continues.. 12.383271187078208 current i =  200
Training continues.. 12.430196057539433 current i =  201
Training continues.. 12.528793822508305 current i =  202
Training continues.. 12.593178618233651 current i =  203
Training continues.. 12.717173803132027 current i =  204
Training continues.. 12.740290034096688 current i =  205
Training continues.. 12.811302687507123 current i =  206
Training continues.. 12.863578181248158 current i =  207
Training continues.. 13.03960001

Training continues.. 21.605446920497343 current i =  336
Training continues.. 21.669425279600546 current i =  337
Training continues.. 21.76550351898186 current i =  338
Training continues.. 21.799254291458055 current i =  339
Training continues.. 21.856953867478296 current i =  340
Training continues.. 21.877423519967124 current i =  341
Training continues.. 21.919846343575045 current i =  342
Training continues.. 21.95831316220574 current i =  343
Training continues.. 22.008680320112035 current i =  344
Training continues.. 22.01890587876551 current i =  345
Training continues.. 22.029456155141816 current i =  346
Training continues.. 22.05101174651645 current i =  347
Training continues.. 22.094815088203177 current i =  348
Training continues.. 22.14370827912353 current i =  349
Training continues.. 22.165281794732437 current i =  350
Training continues.. 22.172985561890528 current i =  351
Training continues.. 22.21748545090668 current i =  352
Training continues.. 22.3172796450089

## Evaluation

In [None]:
# Validation loop
model.eval()
correct = 0
total = 0

with torch.no_grad():
    for inputs, labels in val_loader:
        outputs = model(inputs)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

validation_accuracy = 100 * correct / total
print(f"Validation Accuracy: {validation_accuracy}%")

In [None]:
# Testing loop
model.eval()
correct = 0
total = 0

with torch.no_grad():
    for inputs, labels in test_loader:
        outputs = model(inputs)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

test_accuracy = 100 * correct / total
print(f"Test Accuracy: {test_accuracy}%")