In [13]:
import pandas as pd
import torch
from torchvision import datasets , transforms
from torchvision.transforms import v2
from torch.utils.data import DataLoader
from src.models.cnn import CNN
from src.utility.generate_images import makeSyntheticTrain
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

In [15]:
import os

# Exploration: which real NonDemented training images have no synthetic counterpart?
# Images are matched on the integer id that follows the first "_" in the filename.
# NOTE(review): this reads data/alzheimer_mri/synthetic/, not the synthetic_dir
# ("synthetic_train") used by the experiment below — confirm that is intended.
# NOTE(review): synthetic names are parsed without stripping an extension, so this
# assumes the id token is followed by another "_" (not ".png") — confirm naming.
# The synthetic ids are collected ONCE into a set so each training file is an O(1)
# lookup, instead of re-listing the synthetic folder for every training file.
synthetic_ids = {
    int(name.split("_")[1])
    for name in os.listdir('data/alzheimer_mri/synthetic/NonDemented')
}
[
    name
    for name in os.listdir('data/alzheimer_mri/train/NonDemented')
    if int(name.split("_")[1].split(".")[0]) not in synthetic_ids
]

['NonDemented_1803.png',
 'NonDemented_846.png',
 'NonDemented_758.png',
 'NonDemented_179.png',
 'NonDemented_841.png',
 'NonDemented_2268.png',
 'NonDemented_256.png',
 'NonDemented_330.png',
 'NonDemented_876.png',
 'NonDemented_535.png',
 'NonDemented_1908.png',
 'NonDemented_2492.png',
 'NonDemented_1325.png',
 'NonDemented_117.png',
 'NonDemented_287.png',
 'NonDemented_2129.png',
 'NonDemented_2483.png',
 'NonDemented_1505.png',
 'NonDemented_1425.png',
 'NonDemented_1611.png',
 'NonDemented_1413.png',
 'NonDemented_507.png',
 'NonDemented_1421.png',
 'NonDemented_197.png',
 'NonDemented_1152.png',
 'NonDemented_596.png',
 'NonDemented_312.png',
 'NonDemented_232.png',
 'NonDemented_503.png',
 'NonDemented_940.png',
 'NonDemented_89.png',
 'NonDemented_1579.png',
 'NonDemented_1860.png',
 'NonDemented_1961.png',
 'NonDemented_1749.png',
 'NonDemented_314.png',
 'NonDemented_1511.png',
 'NonDemented_195.png',
 'NonDemented_1509.png',
 'NonDemented_101.png',
 'NonDemented_630.png'

#### Data Transformation

In [14]:
# Image preprocessing pipelines compared in the experiment. All three share the same
# prefix (image -> float32 tensor in [0, 1] -> single grayscale channel) and the same
# final resize; they differ only in the augmentation applied in between.
# torchvision's v2 API deprecates ToTensor; ToImage + ToDtype(float32, scale=True)
# is its documented replacement and produces the same [0, 1] float values.
IMAGE_SIZE = (128, 128)


def _build_transform(*augmentations):
    """Return a v2 pipeline: to float tensor -> grayscale -> *augmentations -> resize."""
    return v2.Compose(
        [
            v2.ToImage(),
            v2.ToDtype(torch.float32, scale=True),
            v2.Grayscale(num_output_channels=1),
            *augmentations,
            v2.Resize(IMAGE_SIZE),  # Resize to a fixed size
        ]
    )


# No random augmentation: deterministic, so also suitable for validation/test images.
transforms_minimal = _build_transform()

# Random horizontal flip with probability 0.5.
transforms_basic = _build_transform(v2.RandomHorizontalFlip(p=0.5))

# AutoAugment using the ImageNet policy.
transforms_auto = _build_transform(v2.AutoAugment(policy=v2.AutoAugmentPolicy.IMAGENET))


#### Train Model

In [15]:
def _run_train_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over `loader`; return the loss summed over samples."""
    model.train()
    running_loss = 0.0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # CrossEntropyLoss defaults to mean reduction; re-weight by batch size so the
        # caller can divide by the dataset size to get a per-sample average.
        running_loss += loss.item() * images.size(0)
    return running_loss


def _run_eval_epoch(model, loader, criterion, device):
    """Evaluate `model` on `loader`; return (loss summed over samples, #correct predictions)."""
    model.eval()
    running_loss = 0.0
    n_correct = 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            running_loss += criterion(outputs, labels).item() * images.size(0)
            n_correct += (outputs.argmax(dim=1) == labels).sum().item()
    return running_loss, n_correct


def train_model(
        train_dir,
        test_dir,
        train_perc,
        synthetic_perc,
        transform,
        transform_name,
        eval_transform=None,
        num_epochs=10,
        batch_size=64,
        patience=2,
        lr=0.001,
        checkpoint_path='best_model.pth',
):
    """Train a fresh CNN on the ImageFolder at ``train_dir`` and return per-epoch metrics.

    The folder is split randomly 80/20 into training and validation subsets. The model
    is trained with Adam + cross-entropy, with early stopping on validation loss. The
    best weights so far are saved to ``checkpoint_path``.

    Args:
        train_dir: ImageFolder root (one sub-directory per class) to train/validate on.
        test_dir: held-out ImageFolder root. NOTE(review): accepted for interface
            compatibility but not used yet — no test-set evaluation is performed.
        train_perc: recorded in the "train_percentage" column only.
        synthetic_perc: recorded in the "synthetic_percentage" column only.
        transform: transform applied to training images.
        transform_name: label stored in the "transform" column.
        eval_transform: transform for validation images. Defaults to ``transform``,
            which means random augmentation is also applied during validation. Pass a
            deterministic pipeline to get validation metrics that are comparable
            across augmentation settings.
        num_epochs: maximum number of epochs.
        batch_size: batch size for both the training and validation loaders.
        patience: epochs without validation-loss improvement before stopping early.
        lr: Adam learning rate.
        checkpoint_path: file the best state_dict is written to. It is overwritten on
            every improvement and shared across calls unless changed.

    Returns:
        pandas.DataFrame with one row per completed epoch. Its columns are
        train_percentage, synthetic_percentage, transform, epoch (0-based),
        val_losses, train_losses and val_accuracies (Python floats).
    """
    # Two views of the same folder, so validation can use its own transform.
    # ImageFolder lists files in sorted order, so the indices of both views line up.
    train_full = datasets.ImageFolder(train_dir, transform=transform)
    if eval_transform is None:
        eval_full = train_full
    else:
        eval_full = datasets.ImageFolder(train_dir, transform=eval_transform)

    # Split the dataset into training and validation sets (80-20 split)
    train_size = int(0.8 * len(train_full))
    val_size = len(train_full) - train_size
    train_data, val_split = torch.utils.data.random_split(train_full, [train_size, val_size])
    val_data = torch.utils.data.Subset(eval_full, val_split.indices)

    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_data, batch_size=batch_size)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # One output per class sub-directory found by ImageFolder.
    model = CNN(in_channels=1, num_classes=len(train_full.classes)).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    best_val_loss = float('inf')
    epochs_without_improvement = 0
    train_losses, val_losses, val_accuracies = [], [], []

    for epoch in range(num_epochs):
        train_loss = _run_train_epoch(model, train_loader, criterion, optimizer, device) / len(train_data)
        val_loss_sum, val_correct = _run_eval_epoch(model, val_loader, criterion, device)
        val_loss = val_loss_sum / len(val_data)
        val_accuracy = val_correct / len(val_data)

        train_losses.append(train_loss)
        val_losses.append(val_loss)
        val_accuracies.append(val_accuracy)

        print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.4f}")

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            epochs_without_improvement = 0
            torch.save(model.state_dict(), checkpoint_path)
        else:
            epochs_without_improvement += 1

        # Early stopping check
        if epochs_without_improvement >= patience:
            print("Stopping early due to no improvement in validation loss.")
            break

    n_epochs_run = len(val_losses)
    return pd.DataFrame(data={
        "train_percentage": [train_perc] * n_epochs_run,
        "synthetic_percentage": [synthetic_perc] * n_epochs_run,
        "transform": transform_name,
        "epoch": range(n_epochs_run),
        "val_losses": val_losses,
        "train_losses": train_losses,
        "val_accuracies": val_accuracies,
    })

In [16]:
train_dir = "data/alzheimer_mri/train"
test_dir = "data/alzheimer_mri/test"
synthetic_dir = "data/alzheimer_mri/synthetic_train"
n_sims = 3

# Experiment grid: every train percentage is crossed with every
# (synthetic percentage, transform, transform name) configuration.
# Full sweep previously considered: train 0.6-1.0, synthetic 0.1 and 0.2.
train_percs = [0.8]
experiment_configs = [
    (0.1, transforms_minimal, 'minimal'),
    (0.1, transforms_basic, 'basic'),
    (0.1, transforms_auto, 'auto'),
]

result_frames = []
for train_perc in train_percs:
    for synthetic_perc, transform, transform_name in experiment_configs:
        combo_frames = []
        for n in range(n_sims):
            # make synthetic + real mix (rebuilt for every simulation)
            makeSyntheticTrain(train_dir, synthetic_dir, train_perc, synthetic_perc)

            # Train on the mix just built in synthetic_dir; test_dir is the held-out set.
            # NOTE(review): assumes makeSyntheticTrain writes the real+synthetic training
            # mix into synthetic_dir — confirm against src.utility.generate_images.
            df_sim_results = train_model(
                train_dir = synthetic_dir,
                test_dir = test_dir,
                train_perc = train_perc,
                synthetic_perc = synthetic_perc,
                transform = transform,
                transform_name = transform_name
            )
            df_sim_results.insert(0, "sim_num", n)
            combo_frames.append(df_sim_results)

        if combo_frames:
            df_combo = pd.concat(combo_frames, ignore_index=True)
            # Save each configuration as soon as it finishes, under a name that matches
            # its contents. Percentages are written as integers (0.8 -> "80"); round()
            # avoids float artefacts such as 0.7 * 10 -> 7.000000000000001.
            df_combo.to_csv(
                f"results/sim_results_train{round(train_perc * 100)}"
                f"_synth{round(synthetic_perc * 100)}_trans{transform_name}.csv"
            )
            result_frames.append(df_combo)

# All simulations across the grid, built with a single concat.
df_all_results = (
    pd.concat(result_frames, ignore_index=True)
    if result_frames
    else pd.DataFrame(columns=[
        "sim_num", "train_percentage", "synthetic_percentage", "transform",
        "epoch", "val_losses", "train_losses", "val_accuracies"
    ])
)

KeyboardInterrupt: 