In [2]:
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from custom_models import cnn
import torch.backends.cudnn as cudnn
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
import time
import os
from PIL import Image
from torchvision.transforms import v2
from tempfile import TemporaryDirectory
import shutil
import random

cudnn.benchmark = True
plt.ion()   # interactive mode

<contextlib.ExitStack at 0x7f4b6ef7ca00>

#### Data Transformations

In [3]:
# data transformations to loop through
minimal_transforms = {
    'synthetic_train': transforms.Compose([
        v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)]),
        v2.Grayscale(num_output_channels=1),
        v2.Normalize([0.5, ],[0.5, ]),
        v2.Resize((128, 128))
    ]),
    'test': transforms.Compose([
        v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)]),
        v2.Grayscale(num_output_channels=1),
        v2.Normalize([0.5, ],[0.5, ]),
        v2.Resize((128, 128))
    ]),
}

basic_transforms = {
    'synthetic_train': transforms.Compose([
        v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)]),
        v2.Grayscale(num_output_channels=1),
        v2.Normalize([0.5, ],[0.5, ]),
        v2.RandomHorizontalFlip(p=0.5),
        v2.Resize((128, 128))
    ]),
    'test': transforms.Compose([
        v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)]),
        v2.Grayscale(num_output_channels=1),
        v2.Normalize([0.5, ],[0.5, ]),
        v2.Resize((128, 128))
    ]),
}

auto_transforms = {
    'synthetic_train': transforms.Compose([
        v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)]),
        v2.Grayscale(num_output_channels=1),
        v2.Normalize([0.5, ],[0.5, ]),
        v2.AutoAugment(policy=v2.AutoAugmentPolicy.IMAGENET),
        v2.Resize((128, 128))
    ]),
    'test': transforms.Compose([
        v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)]),
        v2.Grayscale(num_output_channels=1),
        v2.Normalize([0.5, ],[0.5, ]),
        v2.Resize((128, 128))
    ]),
}

#### Create Training Set with Real/Synthetic Images

In [4]:
def makeSyntheticTrain(train_directory, synthetic_train_directory, train_dict, synthetic_dict):

    # Remove any existing images in directory
    try:
        shutil.rmtree(synthetic_train_directory)
    except:
        print("directory does not exist")

    # Loop through subfolders, generate synthetic images
    subfolders = [f for f in os.listdir(train_directory)]

    for s in subfolders:
        # for each subfolder in the train directory, make the same in the synthetic train directory
        os.makedirs(f"{synthetic_train_directory}/{s}", exist_ok=True)
        
        # get a random sample from each subfolder
        subfolder_path = f"{train_directory}/{s}"
        files = os.listdir(subfolder_path)
        sample_files = random.sample(files, round(len(files)*train_dict[s]))
        
        # create synthetic sample based on sampled original images
        synthetic_subfolder_path = subfolder_path.replace('train','synthetic')
        synthetic_files = [f for f in os.listdir(synthetic_subfolder_path) if int(f.replace('.png','').split('_')[1]) in [int(f.replace('.png','').split('_')[1]) for f in sample_files]]
        synthetic_sample_files = random.sample(synthetic_files, round(len(files)*synthetic_dict[s]))
        
        # Move sample files to synthetic directory
        for f in sample_files:
            
            image_path = f"{subfolder_path}/{f}"
            destination_directory = f"{synthetic_train_directory}/{s}/"
            shutil.copyfile(image_path, destination_directory + image_path.split('/')[-1])

        # Move synthetic sample files to synthetic directory
        for f in synthetic_sample_files:

            image_path = f"{synthetic_subfolder_path}/{f}"
            destination_directory = f"{synthetic_train_directory}/{s}/"
            shutil.copyfile(image_path, destination_directory + image_path.split('/')[-1])

#### Read Data

In [5]:
def get_data(data_dir, data_sets, data_transforms):
    
    image_datasets = {
        x: datasets.ImageFolder(
            os.path.join(data_dir, x),
            data_transforms[x]
        )
        for x in data_sets
    }

    dataloaders = {
        x: DataLoader(
            image_datasets[x],
            batch_size=16,
            shuffle=True,
        )
        for x in data_sets
    }

    dataset_sizes = {
        x: len(image_datasets[x]) 
        for x in data_sets
    }

    class_names = image_datasets['synthetic_train'].classes

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    return image_datasets, dataloaders, dataset_sizes, class_names, device

#### Train Model

In [6]:
def train_model(model, criterion, optimizer, dataloaders, dataset_sizes, num_epochs=10):
    since = time.time()

    # Create a temporary directory to save training checkpoints
    with TemporaryDirectory() as tempdir:
        best_model_params_path = os.path.join(tempdir, 'best_model_params.pt')

        torch.save(model.state_dict(), best_model_params_path)
        
        best_test_loss = float('inf')
        patience = 2  # Number of epochs to wait for improvement before stopping
        test_losses = []
        train_losses = []
        test_acc = []
        train_acc = []
        patience_counter = 0

        for epoch in range(num_epochs):
            print(f'Epoch {epoch}/{num_epochs - 1}')
            print('-' * 10)

            # Each epoch has a training and validation phase
            for phase in ['synthetic_train', 'test']:
                if phase == 'synthetic_train':
                    model.train()  # Set model to training mode
                else:
                    model.eval()   # Set model to evaluate mode

                running_loss = 0.0
                running_corrects = 0

                # Iterate over data.
                for inputs, labels in dataloaders[phase]:
                    inputs = inputs.to(device)
                    labels = labels.to(device)

                    # zero the parameter gradients
                    optimizer.zero_grad()

                    # forward
                    # track history if only in train
                    with torch.set_grad_enabled(phase == 'synthetic_train'):
                        outputs = model(inputs)
                        _, preds = torch.max(outputs, 1)
                        loss = criterion(outputs, labels)

                        # backward + optimize only if in training phase
                        if phase == 'synthetic_train':
                            loss.backward()
                            optimizer.step()

                    # statistics
                    running_loss += loss.item() * inputs.size(0)
                    running_corrects += torch.sum(preds == labels.data)
                
                epoch_loss = running_loss / dataset_sizes[phase]
                epoch_acc = running_corrects.double().item() / dataset_sizes[phase]

                if phase =='synthetic_train':
                    train_losses.append(epoch_loss)
                    train_acc.append(epoch_acc)
                else:
                    test_losses.append(epoch_loss)
                    test_acc.append(epoch_acc)
                    
                print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

                # deep copy the model
                if phase == 'test' and epoch_loss <= best_test_loss:
                    best_test_loss = epoch_loss
                    patience_counter = 0  # Reset counter
                    torch.save(model.state_dict(), best_model_params_path)
                elif phase == 'test' and epoch_loss > best_test_loss:
                    patience_counter += 1
                
                    # Early stopping check
                    if patience_counter >= patience:
                        print("Stopping early due to no improvement in validation loss.")
                        break

        # store results in dataframe
        dat = {
            "epoch": range(len(test_losses)),
            "test_losses": test_losses,
            "train_losses": train_losses,
            "test_accuracies": test_acc,
            "train_accuracies": train_acc
        }

        result = pd.DataFrame(data=dat)
        print()
        time_elapsed = time.time() - since
        print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
        print(f'Best test Loss: {best_test_loss:4f}')

        # load best model weights
        model.load_state_dict(torch.load(best_model_params_path, weights_only=True))
    return result, model

#### Loop through different training scenarios

In [7]:
num_epochs = 15
num_sims = 20

train_percentage_dict = {
    'NonDemented':1.0,
    'VeryMildDemented':1.0,
    'MildDemented':1.0,
    'ModerateDemented':1.0,
}

synth_percentage_dict = {
    'NonDemented':0.2,
    'VeryMildDemented':0.2,
    'MildDemented':0.2,
    'ModerateDemented':0.2,
}

transforms = {
    'minimal': minimal_transforms,
    'basic': basic_transforms,
    'auto': auto_transforms
}

for active_transform in transforms.keys():
    df_all_results = pd.DataFrame()
    for n in range(num_sims):

        # make the synthetic training dataset
        makeSyntheticTrain(
            train_directory='../data/alzheimer_mri/train',
            synthetic_train_directory='../data/alzheimer_mri/synthetic_train', 
            train_dict=train_percentage_dict, 
            synthetic_dict=synth_percentage_dict
        )

        # get and load datasets
        data_dir = '../data/alzheimer_mri'
        data_sets = ['synthetic_train','test']
        image_datasets, dataloaders, dataset_sizes, class_names, device = get_data(
            data_dir=data_dir, 
            data_sets=data_sets, 
            data_transforms=transforms.get(active_transform)
        )

        # instantiate model
        model = cnn.CNN(in_channels=1, num_classes=4)

        # Set the size of each output sample to nn.Linear(num_ftrs, len(class_names))
        #num_ftrs = 4 #model.fc.in_features
        #model.fc = nn.Linear(num_ftrs, len(class_names))
        model = model.to(device)

        # train the model
        df_results,_ = model = train_model(
            model=model, 
            criterion = nn.CrossEntropyLoss(),
            optimizer = optim.Adam(model.parameters(), lr=0.001), 
            dataloaders=dataloaders,
            dataset_sizes=dataset_sizes,
            num_epochs=num_epochs
        )

        #df_results['train_percentage'] = train_percentage
        #df_results['synth_percentage'] = synthetic_percentage
        df_results['train_synth_ratio'] = '__'.join([k+str(v1)+'_'+str(v2) for k,v1,v2 in zip(train_percentage_dict.keys(),train_percentage_dict.values(),synth_percentage_dict.values())])
        df_results['transform'] = active_transform
        df_results['sim_num'] = n
        df_results['category'] = df_results.apply(lambda row: row['transform']+'_'+'__'.join([k+str(v1)+'_'+str(v2) for k,v1,v2 in zip(train_percentage_dict.keys(),train_percentage_dict.values(),synth_percentage_dict.values())]), axis=1)
        df_all_results = pd.concat([df_all_results, df_results],ignore_index=True)
    df_all_results.to_csv(f'../results/results_cnn_{active_transform}' + '__'.join([k+str(v1)+'_'+str(v2) for k,v1,v2 in zip(train_percentage_dict.keys(),train_percentage_dict.values(),synth_percentage_dict.values())]) + '.csv')

Epoch 0/14
----------
synthetic_train Loss: 0.9455 Acc: 0.5387
test Loss: 0.8193 Acc: 0.5820
Epoch 1/14
----------
synthetic_train Loss: 0.5870 Acc: 0.7493
test Loss: 0.4248 Acc: 0.8375
Epoch 2/14
----------
synthetic_train Loss: 0.2602 Acc: 0.9030
test Loss: 0.1945 Acc: 0.9313
Epoch 3/14
----------
synthetic_train Loss: 0.0898 Acc: 0.9674
test Loss: 0.3026 Acc: 0.9094
Epoch 4/14
----------
synthetic_train Loss: 0.0565 Acc: 0.9814
test Loss: 0.1549 Acc: 0.9508
Epoch 5/14
----------
synthetic_train Loss: 0.0369 Acc: 0.9893
test Loss: 0.1132 Acc: 0.9711
Epoch 6/14
----------
synthetic_train Loss: 0.0198 Acc: 0.9927
test Loss: 0.1498 Acc: 0.9641
Epoch 7/14
----------
synthetic_train Loss: 0.0282 Acc: 0.9909
test Loss: 0.1820 Acc: 0.9578
Stopping early due to no improvement in validation loss.
Epoch 8/14
----------
synthetic_train Loss: 0.0102 Acc: 0.9972
test Loss: 0.2261 Acc: 0.9570
Stopping early due to no improvement in validation loss.
Epoch 9/14
----------
synthetic_train Loss: 0.025