# 2 Layer Neural Network

## Imports

In [1]:
import os
import pandas as pd

## Import data and create .csv

In [2]:
# Directory to MRI Images
base_dir = '/Users/benrandoing/Downloads/archive/AugmentedAlzheimerDataset'
classes = ['MildDemented', 'ModerateDemented', 'NonDemented', 'VeryMildDemented']

if not os.path.exists('dataset.csv'):
    data = []

    for label, class_name in enumerate(classes):
        class_dir = os.path.join(base_dir, class_name)
        for image_name in os.listdir(class_dir):
            image_path = os.path.join(class_dir, image_name)
            data.append([image_path, label])

    df = pd.DataFrame(data, columns=['path', 'label'])
    df.to_csv('dataset.csv', index=False)

## Data Loading and Pre-Processing

In [3]:
import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset, random_split
from PIL import Image

class CustomDataset(Dataset):
    def __init__(self, csv_file, transform=None):
        self.dataframe = pd.read_csv(csv_file)
        self.dataframe = self.dataframe.sample(frac=1).reset_index(drop=True)  # Shuffle the dataframe
        self.transform = transform

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        img_name = self.dataframe.iloc[idx, 0]
        image = Image.open(img_name)
        label = int(self.dataframe.iloc[idx, 1])

        if self.transform:
            image = self.transform(image)

        return image, label

transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize image to 224 X 224
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    #transforms.Lambda(lambda x: x.view(-1))  # Flatten the image
])

# Load the entire dataset
dataset = CustomDataset(csv_file='dataset.csv', transform=transform)

# Split the dataset into train, validation, and test sets
train_size = int(0.7 * len(dataset))
val_size = int(0.15 * len(dataset))
test_size = len(dataset) - train_size - val_size
train_dataset, val_dataset, test_dataset = random_split(dataset, [train_size, val_size, test_size])

In [4]:
print(f'Training Samples: {len(train_dataset)}')
print(f'Validation Samples: {len(val_dataset)}')
print(f'Test Samples: {len(test_dataset)}')

Training Samples: 23788
Validation Samples: 5097
Test Samples: 5099


In [5]:
print(train_dataset)

<torch.utils.data.dataset.Subset object at 0x7fe46f6d4310>


In [6]:
import matplotlib.pyplot as plt
#plt.plot(train_dataset[0][0])
#plt.show()

## Pytorch Logistic Regression Model

In [7]:
class FiveLayerNN(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim1, hidden_dim2, hidden_dim3, hidden_dim4, output_dim):
        super(FiveLayerNN, self).__init__()
        self.fc1 = torch.nn.Linear(input_dim, hidden_dim1)
        self.fc2 = torch.nn.Linear(hidden_dim1, hidden_dim2)
        self.fc3 = torch.nn.Linear(hidden_dim2, hidden_dim3)
        self.fc4 = torch.nn.Linear(hidden_dim3, hidden_dim4)
        self.fc5 = torch.nn.Linear(hidden_dim4, output_dim)

    def forward(self, x):
        #x = x.view(-1, 224 * 224 * 3)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = torch.relu(self.fc3(x))
        x = torch.relu(self.fc4(x))
        x = self.fc5(x)
        return x

input_dim = 224 * 224 * 3  # Assuming the images are RGB
hidden_dim1 = 512
hidden_dim2 = 256
hidden_dim3 = 128
hidden_dim4 = 64
output_dim = 4  # Four classes

#model = FiveLayerNN(input_dim, hidden_dim1, hidden_dim2, hidden_dim3, hidden_dim4, output_dim)
import torchvision
from torchvision.models import resnet18, ResNet18_Weights
model = torchvision.models.resnet18(weights=ResNet18_Weights.DEFAULT)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)

model.fc = torch.nn.Linear(model.fc.in_features, output_dim)


Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /Users/venkatasaikrishna.gudladona/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100.0%


## Training Loop

In [8]:
learning_rate = 0.001
num_epochs = 10
batch_size = 64

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Define the loss function and optimizer
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

# Training loop
for epoch in range(num_epochs):
    print("num_epochs",num_epochs)
    model.train()
    total_loss = 0.0

    i = 0
    for inputs, labels in train_loader:
        i = i + 1
        print("Training continues..", total_loss)    
        optimizer.zero_grad()
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        loss = criterion(outputs, labels)        
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    print(f"Iteration {epoch} - Loss = {total_loss / len(train_loader)}")

num_epochs 10
Training continues.. 0.0
Training continues.. 1.5003341436386108
Training continues.. 2.9215891361236572
Training continues.. 4.453795909881592
Training continues.. 5.941872954368591
Training continues.. 7.392863512039185
Training continues.. 8.770994782447815
Training continues.. 10.117185235023499
Training continues.. 11.480648636817932
Training continues.. 12.96809732913971
Training continues.. 14.38499140739441
Training continues.. 15.781501173973083
Training continues.. 17.225167632102966
Training continues.. 18.595067381858826
Training continues.. 20.047435522079468
Training continues.. 21.464048981666565
Training continues.. 22.89060342311859
Training continues.. 24.31051743030548
Training continues.. 25.625083208084106
Training continues.. 27.006391048431396
Training continues.. 28.38243842124939
Training continues.. 29.74627196788788
Training continues.. 31.098942637443542
Training continues.. 32.41326177120209
Training continues.. 33.77719533443451
Training cont

Training continues.. 256.28248512744904
Training continues.. 257.31386172771454
Training continues.. 258.4779050350189
Training continues.. 259.54441690444946
Training continues.. 260.61894404888153
Training continues.. 261.72780895233154
Training continues.. 262.8543293476105
Training continues.. 263.93751060962677
Training continues.. 264.99278569221497
Training continues.. 266.0736118555069
Training continues.. 267.11607670783997
Training continues.. 268.1723461151123
Training continues.. 269.3209328651428
Training continues.. 270.39528250694275
Training continues.. 271.45183289051056
Training continues.. 272.46773767471313
Training continues.. 273.5633808374405
Training continues.. 274.65116453170776
Training continues.. 275.64660930633545
Training continues.. 276.64222425222397
Training continues.. 277.6154003739357
Training continues.. 278.7015183568001
Training continues.. 279.810181081295
Training continues.. 280.8164733052254
Training continues.. 281.8959828019142
Training con

Training continues.. 38.464529514312744
Training continues.. 39.29984134435654
Training continues.. 40.11619472503662
Training continues.. 40.94776624441147
Training continues.. 41.83887076377869
Training continues.. 42.688169836997986
Training continues.. 43.59857827425003
Training continues.. 44.6246218085289
Training continues.. 45.528781950473785
Training continues.. 46.44905632734299
Training continues.. 47.25319766998291
Training continues.. 48.147756934165955
Training continues.. 48.95941227674484
Training continues.. 49.79781687259674
Training continues.. 50.664252161979675
Training continues.. 51.56045514345169
Training continues.. 52.59926122426987
Training continues.. 53.44268226623535
Training continues.. 54.33246046304703
Training continues.. 55.20475256443024
Training continues.. 56.094626903533936
Training continues.. 57.022112429142
Training continues.. 57.834626734256744
Training continues.. 58.61421847343445
Training continues.. 59.53823697566986
Training continues.. 

Training continues.. 210.51274210214615
Training continues.. 211.28048676252365
Training continues.. 212.07923287153244
Training continues.. 212.8750103712082
Training continues.. 213.59399032592773
Training continues.. 214.33833372592926
Training continues.. 215.0806645154953
Training continues.. 215.92350256443024
Training continues.. 216.74889916181564
Training continues.. 217.53458499908447
Training continues.. 218.22846615314484
Training continues.. 219.0228374004364
Training continues.. 219.90993005037308
Training continues.. 220.67727041244507
Training continues.. 221.42483747005463
Training continues.. 222.2897514104843
Training continues.. 223.05891531705856
Training continues.. 223.80280542373657
Training continues.. 224.63708597421646
Training continues.. 225.56425893306732
Training continues.. 226.27449142932892
Training continues.. 227.03208029270172
Training continues.. 227.74543952941895
Training continues.. 228.51450437307358
Training continues.. 229.24550485610962
Trai

Training continues.. 61.977045595645905
Training continues.. 62.59680885076523
Training continues.. 63.23989248275757
Training continues.. 63.961911618709564
Training continues.. 64.7553151845932
Training continues.. 65.42182916402817
Training continues.. 66.10650897026062
Training continues.. 66.84278231859207
Training continues.. 67.52621978521347
Training continues.. 68.23362994194031
Training continues.. 68.94355964660645
Training continues.. 69.62423759698868
Training continues.. 70.24223411083221
Training continues.. 71.07859015464783
Training continues.. 71.796122610569
Training continues.. 72.42774403095245
Training continues.. 73.17909586429596
Training continues.. 73.82068240642548
Training continues.. 74.47009325027466
Training continues.. 75.34447407722473
Training continues.. 76.09124100208282
Training continues.. 76.86599546670914
Training continues.. 77.58952206373215
Training continues.. 78.27082926034927
Training continues.. 79.04916560649872
Training continues.. 79.88

Training continues.. 204.57681930065155
Training continues.. 205.2857465147972
Training continues.. 206.07707518339157
Training continues.. 206.76275426149368
Training continues.. 207.52547496557236
Training continues.. 208.13849633932114
Training continues.. 208.7429414987564
Training continues.. 209.33810150623322
Training continues.. 209.98373401165009
Training continues.. 210.64724546670914
Training continues.. 211.34069126844406
Training continues.. 211.98547792434692
Training continues.. 212.653822183609
Training continues.. 213.49163782596588
Training continues.. 214.18268525600433
Training continues.. 214.76779317855835
Training continues.. 215.4430251121521
Training continues.. 216.07965517044067
Training continues.. 216.90135407447815
Training continues.. 217.51536452770233
Training continues.. 218.179079413414
Training continues.. 218.81965386867523
Training continues.. 219.44550615549088
Training continues.. 220.06781208515167
Training continues.. 220.58358359336853
Trainin

Training continues.. 80.95762354135513
Training continues.. 81.58998990058899
Training continues.. 82.16535288095474
Training continues.. 82.86776506900787
Training continues.. 83.39771175384521
Training continues.. 83.98869061470032
Training continues.. 84.60117048025131
Training continues.. 85.34197348356247
Training continues.. 85.92281752824783
Training continues.. 86.54925084114075
Training continues.. 87.07535934448242
Training continues.. 87.67800068855286
Training continues.. 88.3841187953949
Training continues.. 88.92864298820496
Training continues.. 89.60622322559357
Training continues.. 90.171790599823
Training continues.. 90.82531887292862
Training continues.. 91.43102651834488
Training continues.. 92.03835010528564
Training continues.. 92.67936110496521
Training continues.. 93.31703907251358
Training continues.. 93.95932883024216
Training continues.. 94.62813949584961
Training continues.. 95.2251181602478
Training continues.. 95.8742378950119
Training continues.. 96.558881

Training continues.. 207.0104037821293
Training continues.. 207.57402619719505
Training continues.. 208.17590817809105
Training continues.. 208.905361443758
Training continues.. 209.50347891449928
Training continues.. 210.09804233908653
Training continues.. 210.71492591500282
Training continues.. 211.30155017971992
Training continues.. 211.85471758246422
Training continues.. 212.35121458768845
Training continues.. 212.81796273589134
Training continues.. 213.44312927126884
Training continues.. 214.01055166125298
Training continues.. 214.6424082815647
Training continues.. 215.3042530119419
Training continues.. 215.96619352698326
Training continues.. 216.56938323378563
Training continues.. 217.19292613863945
Training continues.. 217.69452694058418
Training continues.. 218.27171322703362
Training continues.. 218.83744022250175
Training continues.. 219.39723166823387
Training continues.. 219.9431088268757
Training continues.. 220.46334365010262
Training continues.. 221.09753915667534
Traini

Training continues.. 94.95870858430862
Training continues.. 95.49061530828476
Training continues.. 96.16882854700089
Training continues.. 96.69646441936493
Training continues.. 97.22710573673248
Training continues.. 97.77465391159058
Training continues.. 98.3980250954628
Training continues.. 99.02155536413193
Training continues.. 99.55269020795822
Training continues.. 100.0714670419693
Training continues.. 100.6236457824707
Training continues.. 101.13568478822708
Training continues.. 101.63410115242004
Training continues.. 102.1428250670433
Training continues.. 102.7057626247406
Training continues.. 103.20550119876862
Training continues.. 103.80327850580215
Training continues.. 104.33151257038116
Training continues.. 105.02456772327423
Training continues.. 105.50536170601845
Training continues.. 106.11819383502007
Training continues.. 106.62074884772301
Training continues.. 107.06356757879257
Training continues.. 107.60210078954697
Training continues.. 108.19056111574173
Training conti

Training continues.. 2.8762634098529816
Training continues.. 3.4140922129154205
Training continues.. 3.979842871427536
Training continues.. 4.372456610202789
Training continues.. 4.844874680042267
Training continues.. 5.290949702262878
Training continues.. 5.659269541501999
Training continues.. 6.210281282663345
Training continues.. 6.700534850358963
Training continues.. 7.349573224782944
Training continues.. 7.789117395877838
Training continues.. 8.358679473400116
Training continues.. 8.72166445851326
Training continues.. 9.307283490896225
Training continues.. 9.877208024263382
Training continues.. 10.29994022846222
Training continues.. 10.931970596313477
Training continues.. 11.49390023946762
Training continues.. 12.012692391872406
Training continues.. 12.3861083984375
Training continues.. 13.043296694755554
Training continues.. 13.505269855260849
Training continues.. 13.974385023117065
Training continues.. 14.592040359973907
Training continues.. 15.155127882957458
Training continues

Training continues.. 106.58829766511917
Training continues.. 106.99073800444603
Training continues.. 107.43168243765831
Training continues.. 107.84056866168976
Training continues.. 108.30699652433395
Training continues.. 108.89279460906982
Training continues.. 109.3301619887352
Training continues.. 109.81817218661308
Training continues.. 110.3491282761097
Training continues.. 110.82620841264725
Training continues.. 111.32792520523071
Training continues.. 111.74870154261589
Training continues.. 112.19309517741203
Training continues.. 112.72046425938606
Training continues.. 113.28237774968147
Training continues.. 113.73221379518509
Training continues.. 114.12214684486389
Training continues.. 114.60078272223473
Training continues.. 115.03445175290108
Training continues.. 115.49677109718323
Training continues.. 115.93587562441826
Training continues.. 116.34999072551727
Training continues.. 116.90667933225632
Training continues.. 117.4337130188942
Training continues.. 118.04179793596268
Tra

Training continues.. 21.170723974704742
Training continues.. 21.689005613327026
Training continues.. 22.049361765384674
Training continues.. 22.49346151947975
Training continues.. 22.892235726118088
Training continues.. 23.478890746831894
Training continues.. 23.995106011629105
Training continues.. 24.39838805794716
Training continues.. 24.901491194963455
Training continues.. 25.358389675617218
Training continues.. 25.878132820129395
Training continues.. 26.322750210762024
Training continues.. 26.9504354596138
Training continues.. 27.348380506038666
Training continues.. 27.748483061790466
Training continues.. 28.136681973934174
Training continues.. 28.524559289216995
Training continues.. 29.025404423475266
Training continues.. 29.395657658576965
Training continues.. 29.960613250732422
Training continues.. 30.42141929268837
Training continues.. 30.806321173906326
Training continues.. 31.33890399336815
Training continues.. 31.78272956609726
Training continues.. 32.15987694263458
Training

Training continues.. 112.78610649704933
Training continues.. 113.19400569796562
Training continues.. 113.62058833241463
Training continues.. 114.00517362356186
Training continues.. 114.34245616197586
Training continues.. 114.7051708996296
Training continues.. 115.14301013946533
Training continues.. 115.57337212562561
Training continues.. 116.02441614866257
Training continues.. 116.39996057748795
Training continues.. 116.80924054980278
Training continues.. 117.20808127522469
Training continues.. 117.68226999044418
Training continues.. 118.11124467849731
Training continues.. 118.4291604757309
Training continues.. 118.83886107802391
Training continues.. 119.27108180522919
Training continues.. 119.68437397480011
Training continues.. 120.1104366183281
Training continues.. 120.6308012008667
Training continues.. 120.96030589938164
Training continues.. 121.33640050888062
Training continues.. 121.72622174024582
Training continues.. 122.06114235520363
Training continues.. 122.42890286445618
Trai

Training continues.. 35.99694350361824
Training continues.. 36.344687819480896
Training continues.. 36.67611852288246
Training continues.. 37.017181277275085
Training continues.. 37.38259422779083
Training continues.. 37.86085256934166
Training continues.. 38.3287858068943
Training continues.. 38.645299047231674
Training continues.. 38.98699280619621
Training continues.. 39.35627889633179
Training continues.. 39.78743356466293
Training continues.. 40.23144602775574
Training continues.. 40.56980940699577
Training continues.. 41.06157049536705
Training continues.. 41.42237198352814
Training continues.. 41.834830582141876
Training continues.. 42.255999594926834
Training continues.. 42.702450692653656
Training continues.. 43.146363258361816
Training continues.. 43.48755303025246
Training continues.. 43.86063453555107
Training continues.. 44.3322978913784
Training continues.. 44.68627879023552
Training continues.. 44.99230748414993
Training continues.. 45.37796354293823
Training continues..

Training continues.. 114.63925431668758
Training continues.. 115.01152639091015
Training continues.. 115.47365520894527
Training continues.. 115.81965081393719
Training continues.. 116.243626460433
Training continues.. 116.60152031481266
Training continues.. 117.07016687095165
Training continues.. 117.4462124556303
Training continues.. 117.8416572958231
Training continues.. 118.25684924423695
Training continues.. 118.57596434652805
Training continues.. 118.88402386009693
Training continues.. 119.22030831873417
Training continues.. 119.63581727445126
Training continues.. 119.96977816522121
Training continues.. 120.30851499736309
Training continues.. 120.6628097742796
Training continues.. 121.05455596745014
Training continues.. 121.39277245104313
Training continues.. 121.7208604067564
Training continues.. 121.98388250172138
Training continues.. 122.33136190474033
Training continues.. 122.66782413423061
Training continues.. 123.03515563905239
Training continues.. 123.469537332654
Training

Training continues.. 45.46091492474079
Training continues.. 45.811736926436424
Training continues.. 46.05620823800564
Training continues.. 46.3033741414547
Training continues.. 46.58445319533348
Training continues.. 46.792448088526726
Training continues.. 47.128706738352776
Training continues.. 47.398185178637505
Training continues.. 47.65236885845661
Training continues.. 47.94891966879368
Training continues.. 48.20311151444912
Training continues.. 48.455529525876045
Training continues.. 48.7575551122427
Training continues.. 49.1162516027689
Training continues.. 49.407301530241966
Training continues.. 49.77687497437
Training continues.. 50.06083790957928
Training continues.. 50.42003472149372
Training continues.. 50.75782190263271
Training continues.. 51.19907648861408
Training continues.. 51.58816812932491
Training continues.. 51.885988131165504
Training continues.. 52.23137886822224
Training continues.. 52.46744781732559
Training continues.. 52.83043074607849
Training continues.. 53.

Training continues.. 111.59266105294228
Training continues.. 111.99387401342392
Training continues.. 112.2480499446392
Training continues.. 112.5602317750454
Training continues.. 112.96343559026718
Training continues.. 113.31610116362572
Training continues.. 113.62340787053108
Training continues.. 113.98132386803627
Training continues.. 114.32718056440353
Training continues.. 114.58284592628479
Training continues.. 114.86625880002975
Training continues.. 115.19208216667175
Training continues.. 115.38280454277992
Training continues.. 115.75302219390869
Training continues.. 116.10359516739845
Training continues.. 116.37559109926224
Training continues.. 116.72572296857834
Training continues.. 117.03228148818016
Training continues.. 117.41908431053162
Training continues.. 117.67784005403519
Training continues.. 117.94858610630035
Training continues.. 118.23817598819733
Training continues.. 118.50220438838005
Iteration 8 - Loss = 0.31927151144832694
num_epochs 10
Training continues.. 0.0
Tr

Training continues.. 51.36725701391697
Training continues.. 51.60734781622887
Training continues.. 51.84012673795223
Training continues.. 52.07931046187878
Training continues.. 52.296122550964355
Training continues.. 52.53217801451683
Training continues.. 52.79763105511665
Training continues.. 53.14176094532013
Training continues.. 53.40663546323776
Training continues.. 53.633723959326744
Training continues.. 53.896506026387215
Training continues.. 54.22600491344929
Training continues.. 54.66295354068279
Training continues.. 54.96837244927883
Training continues.. 55.22846142947674
Training continues.. 55.44802577793598
Training continues.. 55.70440621674061
Training continues.. 56.022621020674706
Training continues.. 56.3091526478529
Training continues.. 56.50366294384003
Training continues.. 56.755234479904175
Training continues.. 56.9500337690115
Training continues.. 57.19842268526554
Training continues.. 57.46142067015171
Training continues.. 57.71751447021961
Training continues.. 5

## Evaluation

In [9]:
# Validation loop
model.eval()
correct = 0
total = 0

with torch.no_grad():
    for inputs, labels in val_loader:
        outputs = model(inputs)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

validation_accuracy = 100 * correct / total
print(f"Validation Accuracy: {validation_accuracy}%")

Validation Accuracy: 87.18854227977242%


In [10]:
# Testing loop
model.eval()
correct = 0
total = 0

with torch.no_grad():
    for inputs, labels in test_loader:
        outputs = model(inputs)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

test_accuracy = 100 * correct / total
print(f"Test Accuracy: {test_accuracy}%")

Test Accuracy: 87.40929594038046%
