In [None]:
# Import packages
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
import os
import glob
from PIL import Image
import cv2
import optuna
import pickle

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# The patch set is split into train/val/test directories. Each split holds the
# images in 5 class folders; the folder name later becomes the label (0...4).
_patch_set_root = '/sc/arion/projects/shenl03_ml/2022_kevin_mammo/CBIS-DDSM_patch_set'
train_path, val_path, test_path = (
    _patch_set_root + '/' + split_name for split_name in ('train', 'val', 'test')
)

def get_png_paths(path):
    """Recursively collect the paths of all PNG files below ``path``.

    Args:
        path: Root directory to search; it is traversed with ``os.walk``.

    Returns:
        A sorted list of file paths (``os.path.join(dirpath, filename)``) for
        every file whose name ends in ``.png``. The list is empty when nothing
        matches or when ``path`` does not exist.
    """
    img_paths = []
    for dirpath, _, filenames in os.walk(path):
        for filename in filenames:
            # Test the extension, not a substring: "'.png' in name" would also
            # accept files such as "scan.png.bak".
            if filename.endswith('.png'):
                img_paths.append(os.path.join(dirpath, filename))
    # os.walk yields entries in a filesystem-dependent order; sorting keeps the
    # dataset index -> file mapping reproducible from run to run.
    img_paths.sort()
    return img_paths

# Gather the PNG files of every split (train, val, test - in that order).
train_img_paths, val_img_paths, test_img_paths = (
    get_png_paths(split_dir) for split_dir in (train_path, val_path, test_path)
)

# Class folder name -> integer label used as the classification target.
mammo_classes = dict(zip(
    ('background', 'calc_ben', 'calc_mal', 'mass_ben', 'mass_mal'),
    range(5),
))

In [None]:
# Create a custom PyTorch Dataset with tuples of image array values and corresponding class represented as integer from mammo_classes dict
class MammoDataset(Dataset):
    """Dataset of image files labelled by the name of their parent folder.

    Each item is an ``(image, label)`` tuple. The image is read with
    ``cv2.imread`` and passed through ``transform`` when one is given; the
    label is the integer that ``mammo_classes`` assigns to the name of the
    directory containing the image.

    Args:
        img_paths: Sequence of image file paths.
        transform: Optional callable applied to the loaded image array.
            Defaults to ``None``, in which case the raw array is returned.
    """

    def __init__(self, img_paths, transform=None):
        self.img_paths = img_paths
        self.transform = transform

    def __len__(self):
        """Return the number of image paths held by the dataset."""
        return len(self.img_paths)

    def __getitem__(self, i):
        """Load the ``i``-th image and its integer class label.

        Raises:
            FileNotFoundError: If OpenCV cannot read the image file.
            KeyError: If the parent folder name is not a key of ``mammo_classes``.
        """
        img_path = self.img_paths[i]

        # cv2.imread reports failure by returning None instead of raising, so
        # check explicitly; otherwise the error only surfaces inside the transform.
        img = cv2.imread(img_path)
        if img is None:
            raise FileNotFoundError("Could not read image: {}".format(img_path))

        # The label is the name of the folder directly above the file. Using
        # os.path keeps this working with any path separator.
        class_name = os.path.basename(os.path.dirname(img_path))
        if class_name not in mammo_classes:
            # A silent None label would only fail later, when batches are collated.
            raise KeyError(
                "Unknown class folder '{}' for image: {}".format(class_name, img_path)
            )
        label = mammo_classes[class_name]  # turn string label into integer

        if self.transform is not None:
            # apply transformations
            # ToTensor() auto converts uint8 to range (0,1)
            img = self.transform(img)
        return img, label

In [None]:
# Per-channel statistics for Normalize(). They are applied after ToTensor(),
# which converts uint8 values to range (0,1).
_norm_mean = (0.3788,) * 3
_norm_std = (0.1508,) * 3

# Training pipeline: random flips and rotation, followed by Normalize().
train_tf = transforms.Compose([
    transforms.ToTensor(),
    transforms.RandomHorizontalFlip(p=.2),
    transforms.RandomVerticalFlip(p=.1),
    transforms.RandomRotation(degrees=30),
    transforms.Normalize(mean=_norm_mean, std=_norm_std),
])

# Validation / test pipeline: tensor conversion and Normalize() only.
val_tf = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=_norm_mean, std=_norm_std),
])

# One dataset per split; val and test share the same transform pipeline.
train_dataset = MammoDataset(img_paths=train_img_paths, transform=train_tf)
val_dataset = MammoDataset(img_paths=val_img_paths, transform=val_tf)
test_dataset = MammoDataset(img_paths=test_img_paths, transform=val_tf)

In [None]:
# Define the neural network that will be used

class mammonet(nn.Module):
    """VGG-16-style CNN with batch normalisation and a 5-way linear classifier.

    The feature extractor consists of five blocks with (2, 2, 3, 3, 3)
    convolutions of width (64, 128, 256, 512, 512). Every convolution is a
    3x3 / stride 1 / padding 1 layer followed by its own BatchNorm2d and a ReLU,
    and every block ends with a 2x2 max-pool.

    Each convolution and each batch-norm layer is a separate module on purpose:
    - a BatchNorm2d shared between several layers accumulates running
      statistics that are a blend of different activation distributions, which
      is what eval mode then normalises with;
    - calling the same Conv2d module several times ties the weights of those
      layers together.

    Input:  float tensor of shape (N, 3, H, W).
    Output: tensor of shape (N, 5) with unnormalised class scores (logits).
    """

    def __init__(self):
        super().__init__()
        # (in_channels, out_channels, number of conv layers) per block
        self.block1 = self._make_block(3, 64, 2)
        self.block2 = self._make_block(64, 128, 2)
        self.block3 = self._make_block(128, 256, 3)
        self.block4 = self._make_block(256, 512, 3)
        self.block5 = self._make_block(512, 512, 3)

        self.fc1 = nn.Linear(512*7*7, 5)

    @staticmethod
    def _make_block(in_channels, out_channels, num_convs):
        """Build ``num_convs`` x (Conv3x3 -> BatchNorm -> ReLU) followed by a 2x2 max-pool."""
        layers = []
        for _ in range(num_convs):
            layers.append(nn.Conv2d(in_channels, out_channels, 3, 1, 1))
            layers.append(nn.BatchNorm2d(out_channels))
            layers.append(nn.ReLU())
            # Only the first convolution of a block changes the channel count.
            in_channels = out_channels
        layers.append(nn.MaxPool2d(2, 2))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return the (N, 5) logits for a batch of images ``x`` of shape (N, 3, H, W)."""
        for block in (self.block1, self.block2, self.block3, self.block4, self.block5):
            x = block(x)

        # fc1 expects a 512x7x7 feature map. For 224x224 inputs the map is
        # already 7x7 and this pooling is the identity; for other input sizes it
        # resizes the map so the classifier still applies.
        x = F.adaptive_avg_pool2d(x, (7, 7))

        x = torch.flatten(x, 1)
        x = self.fc1(x)

        return x

In [None]:
def objective(trial):
    """Optuna objective: train a ``mammonet`` and return its best validation accuracy.

    Hyperparameters sampled from ``trial``:
      * ``lr``: learning rate, log-uniform in [1e-5, 1e-1]
      * ``batch_size``: integer exponent in [2, 6]; the loaders use
        ``2**batch_size`` samples per batch

    The model is trained with Adam and cross-entropy loss on ``train_dataset``
    and evaluated on ``val_dataset`` after every epoch. Whenever the validation
    accuracy improves on the best value seen so far in this trial, the model is
    pickled to ``./mammo_models/mammo_trial_<trial.number>.pickle`` and the
    epoch index is stored in the trial's ``'Epoch'`` user attribute.

    Args:
        trial: The Optuna trial used to sample hyperparameters, report
            per-epoch accuracy and query the pruner.

    Returns:
        The best validation accuracy (fraction in [0, 1]) reached during the
        trial, i.e. the accuracy of the checkpoint that was saved.

    Raises:
        optuna.exceptions.TrialPruned: If the pruner asks to stop the trial.
    """
    model = mammonet()
    model.to(device)

    lr = trial.suggest_float("lr", 1e-5, 1e-1, log=True)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # The trial stores the exponent; the actual batch size is a power of two.
    batch_size = 2 ** trial.suggest_int("batch_size", 2, 6, step=1)

    # Make Dataloaders for each dataset
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    # Sample order does not influence the accuracy, so no shuffling for validation.
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    # Make sure the checkpoint directory exists before the first save.
    os.makedirs("./mammo_models", exist_ok=True)
    checkpoint_path = "./mammo_models/mammo_trial_{}.pickle".format(trial.number)

    epochs = 20
    # Must be initialised once per trial, NOT once per epoch: otherwise every
    # epoch counts as "best" and the checkpoint is overwritten each time.
    best_val_acc = -np.inf
    for epoch in range(epochs):

        model.train()
        for images, labels in train_loader:
            # Retrieve inputs
            images, labels = images.to(device), labels.to(device)

            # clear gradient
            optimizer.zero_grad()

            # forward step
            loss = criterion(model(images), labels)

            # backward step
            loss.backward()

            # optimize
            optimizer.step()

        # Validation accuracy
        total = 0
        correct = 0
        model.eval()
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)

                predicted = model(inputs).argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        # Guard against an empty validation set instead of dividing by zero.
        accuracy = correct / total if total else 0.0

        if accuracy > best_val_acc:
            best_val_acc = accuracy
            # If this model's val acc is best, save/overwrite best model for this trial
            trial.set_user_attr('Epoch', epoch)
            # NOTE: pickle stores the whole module object; only load these
            # files from locations you trust.
            with open(checkpoint_path, "wb") as fout:
                pickle.dump(model, fout)

        trial.report(accuracy, epoch)
        # Handle pruning based on the intermediate value.
        if trial.should_prune():
            raise optuna.exceptions.TrialPruned()

    print("Finished Training")
    # Return the accuracy of the saved checkpoint so that the study's value,
    # the 'Epoch' attribute and the pickled model all describe the same epoch.
    return best_val_acc

In [None]:
# Run the hyperparameter search; larger objective values are better.
study = optuna.create_study(direction='maximize')
study.optimize(objective, n_trials=10)

# Summarise the best trial of the study.
trial = study.best_trial
print('\n'.join((
    'Accuracy: {}'.format(trial.value),
    "Best hyperparameters: {}".format(trial.params),
    'Best epoch: {}'.format(trial.user_attrs),
)))

# Load the best model.
best_model_path = "./mammo_models/mammo_trial_{}.pickle".format(trial.number)
with open(best_model_path, "rb") as fin:
    best_model = pickle.load(fin)

# Table of all trials, without the bookkeeping columns.
bookkeeping_columns = ['state', 'datetime_start', 'datetime_complete', 'duration', 'number']
df = study.trials_dataframe().drop(columns=bookkeeping_columns)
print(df)

In [None]:
# Test network
correct = 0
total = 0

# Do not rely on the device / train-eval mode that happened to be stored in the
# pickle: evaluation must use the BatchNorm running statistics (eval mode) and
# the model must live on the same device as the batches.
best_model.to(device)
best_model.eval()

# Same batch size as the best trial (the trial stores the power-of-two exponent).
test_loader = DataLoader(test_dataset, batch_size=2**trial.params["batch_size"], shuffle=False)
# no gradients necessary when testing
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)

        predicted = best_model(images).argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

if total:
    print(f'Accuracy of the network on test images: {100 * correct // total} %')
else:
    # An empty test set would otherwise end in a ZeroDivisionError.
    print('No test images found; accuracy is undefined.')