In [None]:
import copy
import fnmatch
import os
import string
from shutil import copyfile

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torchvision
from sacred import Experiment
from sacred.observers import FileStorageObserver
from skimage import io, transform
from torch.utils.data import Dataset, DataLoader, Subset, random_split
from torchvision import transforms, datasets
from tqdm import tqdm

In [None]:
class ImageDataset(Dataset):
    '''
    Dataset of Braille letter images laid out as <data_dir>/sorted_data/<letter>/<image file>.

    Each item is an (image, label) pair: the image is a float32 tensor scaled by its own
    maximum value, and the label is the integer index of the image's letter in `self.classes`.
    '''

    def __init__(self, data_dir="Braille Dataset/"):
        # Initialization of data directory and list of all of the paths to each image in the data
        self.data_dir = data_dir
        self.image_path_list = sorted(self._find_files(data_dir))
        # Build the label -> integer mapping once here instead of re-scanning every path
        # on each __getitem__ call.
        self.classes = sorted({self._label_from_path(path) for path in self.image_path_list})
        self.class_to_idx = {label: idx for idx, label in enumerate(self.classes)}

    def __len__(self):
        '''
        Function to get the length of the dataset
        '''
        return len(self.image_path_list)

    def __getitem__(self, index):
        '''
        Function to be able to select images and corresponding labels from the dataset

        Returns a tuple (image, label) where image is a float32 tensor and label is an int.
        '''
        image_path_ex = self.image_path_list[index]
        label_ex = self.class_to_idx[self._label_from_path(image_path_ex)]

        # Load the image as stored on disk (no grayscale conversion is applied)
        image_ex = io.imread(image_path_ex)
        # Normalize image (make values between 0 and 1); an all-zero image is left as zeros
        # instead of being turned into NaNs by a division by zero
        max_value = np.max(image_ex)
        if max_value > 0:
            image_ex = image_ex / max_value
        else:
            image_ex = image_ex.astype(np.float64)

        # Named `to_tensor` so the local does not shadow the imported `skimage.transform` module
        to_tensor = torchvision.transforms.Compose([torchvision.transforms.ToTensor(),
                                                    torchvision.transforms.ConvertImageDtype(dtype=torch.float32)])
        image_ex = to_tensor(image_ex)

        return image_ex, label_ex

    @staticmethod
    def _label_from_path(path):
        '''
        Function to get the letter label of an image, i.e. the name of the directory that contains it
        '''
        # Using the parent directory name works whether or not data_dir ends with a separator
        return os.path.basename(os.path.dirname(path))

    def prep_data(self, data_dir):
        '''
        Function to organize the letter images into one directory for each letter
        - this function assumes the dataset was downloaded and unzipped in the same directory as the python scripts
        '''
        sorted_dir = os.path.join(data_dir, 'sorted_data')
        os.makedirs(sorted_dir, exist_ok=True)  # Creates a sorted_data directory within Braille Dataset directory
        for root, dirs, files in os.walk(os.path.join(data_dir, 'Braille Dataset')):
            for file in tqdm(sorted(files)):
                if file.endswith('.jpg'):
                    # The first character of the file name is the letter the image shows
                    letter_dir = os.path.join(sorted_dir, file[0])
                    os.makedirs(letter_dir, exist_ok=True)  # Adds a directory for each letter of the alphabet
                    copyfile(os.path.join(root, file),
                             os.path.join(letter_dir, file))  # Adds each letter image to its corresponding directory

    def _find_files(self, directory):
        '''
        Function to get all files in data directory

        If <directory>/sorted_data does not exist yet, it is created first with prep_data.
        '''
        image_path_list = []
        sorted_dir = os.path.join(directory, "sorted_data")
        if not os.path.isdir(sorted_dir):
            print("Processing the data.")
            # Prepare the same directory that is about to be listed
            self.prep_data(directory)
        for letter in string.ascii_lowercase:
            curr_dir = os.path.join(sorted_dir, letter)
            image_path_list += [os.path.join(curr_dir, f) for f in os.listdir(curr_dir)]
        return image_path_list

In [None]:
class CNNClassif(nn.Module):
    """Convolutional neural network classifier for Braille letter images.

    Three conv -> ReLU -> max-pool stages are followed by three fully connected layers.
    ``forward`` returns raw, unbounded class scores (logits) of shape
    ``(batch, num_classes)``, which is what ``nn.CrossEntropyLoss`` expects.

    The first fully connected layer is sized for a 4x4 feature map after the third
    stage. With the paddings used here that corresponds to 3-channel inputs whose
    height and width are between 24 and 31 pixels (e.g. 28x28).

    Args:
        num_channels1: output channels of the first convolution.
        num_channels2: output channels of the second convolution.
        num_channels3: output channels of the third convolution.
        num_lin_channels1: output features of the first linear layer.
        num_lin_channels2: output features of the second linear layer.
        num_classes: number of output classes.
    """
    def __init__(self, num_channels1=16, num_channels2=32, num_channels3=64,
                 num_lin_channels1=128, num_lin_channels2=64, num_classes=26):
        super(CNNClassif, self).__init__()
        # Convolutional channel values
        self.num_channels1 = num_channels1
        self.num_channels2 = num_channels2
        self.num_channels3 = num_channels3
        # Linear channel values
        self.num_lin_channels1 = num_lin_channels1
        self.num_lin_channels2 = num_lin_channels2
        # Number of classes
        self.num_classes = num_classes

        # kernel_size=5 with padding=2 keeps the spatial size; each max-pool halves it
        self.cnn_layer1 = nn.Sequential(nn.Conv2d(3, num_channels1, kernel_size=5, padding=2),
                                        nn.ReLU(),
                                        nn.MaxPool2d(kernel_size=2))
        self.cnn_layer2 = nn.Sequential(nn.Conv2d(num_channels1, num_channels2, kernel_size=5, padding=2),
                                        nn.ReLU(),
                                        nn.MaxPool2d(kernel_size=2))
        # kernel_size=3 with padding=2 grows the map by 2 pixels before pooling;
        # linear_layer1's input size below relies on that
        self.cnn_layer3 = nn.Sequential(nn.Conv2d(num_channels2, num_channels3, kernel_size=3, padding=2),
                                        nn.ReLU(),
                                        nn.MaxPool2d(kernel_size=2))
        self.linear_layer1 = nn.Sequential(nn.Linear(num_channels3*4*4, num_lin_channels1), nn.ReLU())
        self.linear_layer2 = nn.Sequential(nn.Linear(num_lin_channels1, num_lin_channels2), nn.ReLU())
        # No activation on the output layer: a ReLU here would clamp every negative logit
        # to zero before the loss. The nn.Sequential wrapper is kept so the parameter
        # names (linear_layer3.0.weight / .bias) stay the same for existing checkpoints.
        self.linear_layer3 = nn.Sequential(nn.Linear(num_lin_channels2, num_classes))

    def forward(self, x):
        """Return class logits of shape (batch, num_classes) for a batch of images."""
        w = self.cnn_layer1(x)
        y = self.cnn_layer2(w)
        z = self.cnn_layer3(y)
        # Flatten everything except the batch dimension for the linear layers
        z2 = z.reshape(z.shape[0], -1)
        lin1 = self.linear_layer1(z2)
        lin2 = self.linear_layer2(lin1)
        out = self.linear_layer3(lin2)
        return out

def init_weights(m):
    """Initialize one module in place; intended for use with ``model.apply(init_weights)``.

    Linear and Conv2d modules get Xavier-uniform weights and, when they have a bias,
    a constant bias of 0.01. Every other module type is left untouched.

    Args:
        m: the module to initialize.

    Returns:
        None.
    """
    if isinstance(m, (nn.Linear, nn.Conv2d)):
        torch.nn.init.xavier_uniform_(m.weight)
        # Layers created with bias=False have m.bias set to None
        if m.bias is not None:
            torch.nn.init.constant_(m.bias, 0.01)

In [None]:
# Sacred experiment object named 'braille_cnn'. The @ex.config, @ex.capture and
# @ex.automain decorators further down all register against this instance, so it
# has to be created before those definitions.
# NOTE(review): if this code really runs inside a notebook kernel, Sacred may need
# Experiment(..., interactive=True) and ex.main/ex.run() instead of automain — confirm.
ex = Experiment('braille_cnn')
# Attach a file-storage observer that uses the 'runs' directory.
ex.observers.append(FileStorageObserver('runs'))

@ex.config
def config():
    """Configuration of the Braille Image Classifier experiment."""
    # Seed used in run() for the dataset split generator and for torch.manual_seed
    seed = 0
    # Batch size of the train/validation/test DataLoaders built in run()
    batch_size = 8
    # Maximum number of training epochs (training may stop earlier, see patience)
    num_epochs = 50
    # Loss function handed to training_cnn_classifier
    # NOTE(review): this stores a module instance in the config — confirm the observers serialize it acceptably
    loss_fn = nn.CrossEntropyLoss()
    # Learning rate of the SGD optimizer
    learning_rate = 0.01
    # Number of epochs without validation-accuracy improvement after which training stops
    patience = 10
    # Output channels of the three convolutional layers of CNNClassif
    # (these values replace the constructor defaults of 16/32/64)
    num_channels1 = 32
    num_channels2 = 64
    num_channels3 = 128
    # Output features of the first two linear layers of CNNClassif
    num_lin_channels1 = 128
    num_lin_channels2 = 64
    
    
@ex.capture
def training_cnn_classifier(model, train_dataloader, val_dataloader, num_epochs, loss_fn,
                            learning_rate, patience, verbose=True, checkpoint_path='test_model.pt'):
    """Train a copy of ``model`` with SGD and early stopping on validation accuracy.

    The input model is not modified; a deep copy is trained. After every epoch the
    copy is evaluated on ``val_dataloader`` and, whenever the validation accuracy is
    the best seen so far (including the first epoch), its weights are saved to
    ``checkpoint_path``. Training stops early once the accuracy has failed to improve
    for ``patience`` consecutive epochs. The mean training loss of each epoch is
    logged to the experiment as the 'loss' metric.

    Args:
        model: the network to train (left untouched).
        train_dataloader: iterable of (images, labels) training batches.
        val_dataloader: iterable of (images, labels) validation batches.
        num_epochs: maximum number of epochs.
        loss_fn: callable computing the loss from (predictions, labels).
        learning_rate: learning rate of the SGD optimizer.
        patience: number of non-improving epochs tolerated before stopping.
        verbose: when True, print the loss and validation accuracy of each epoch.
        checkpoint_path: file the best weights (state_dict) are saved to.

    Returns:
        A tuple (model_tr, loss_all_epochs): the trained copy as it is after the last
        epoch that ran (the best weights are the ones in ``checkpoint_path``), and the
        list of mean training losses, one per epoch.

    Raises:
        ValueError: if ``train_dataloader`` yields no batches.
    """
    model_tr = copy.deepcopy(model)

    optimizer = torch.optim.SGD(model_tr.parameters(), lr=learning_rate)

    loss_all_epochs = []
    best_acc = float('-inf')  # any first-epoch accuracy counts as an improvement
    no_improve = 0  # value to track for how many epochs validation accuracy is not improving

    for epoch in range(num_epochs):
        # The validation pass at the end of an epoch may leave the model in eval mode,
        # so training mode is (re)enabled at the start of every epoch
        model_tr.train()
        loss_current_epoch = 0.0
        num_batches = 0

        for images, labels in train_dataloader:
            y_pred = model_tr(images)
            loss = loss_fn(y_pred, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            loss_current_epoch += loss.item()
            num_batches += 1

        if num_batches == 0:
            raise ValueError('train_dataloader yielded no batches; cannot train on an empty dataset')

        mean_loss = loss_current_epoch / num_batches
        loss_all_epochs.append(mean_loss)
        # Log the same per-batch mean that is printed and returned, whatever `verbose` is
        ex.log_scalar('loss', mean_loss, step=epoch + 1)

        val_accuracy = eval_cnn_classifier(model_tr, eval_dataloader=val_dataloader)
        # Early stopping implementation
        if val_accuracy > best_acc:
            best_acc = val_accuracy
            torch.save(model_tr.state_dict(), checkpoint_path)
            no_improve = 0
        else:
            no_improve += 1

        if verbose:
            print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {mean_loss:.4f}')
            print(f'-----> Validation Accuracy: {val_accuracy:.3f}%')

        if no_improve >= patience:
            break

    return model_tr, loss_all_epochs


def eval_cnn_classifier(model, eval_dataloader):
    """Compute the classification accuracy of ``model`` on ``eval_dataloader``.

    The model is put in evaluation mode for the duration of the call and gradients
    are disabled; the train/eval mode the model had on entry is restored on exit.

    Args:
        model: network mapping a batch of images to per-class scores of shape
            (batch, num_classes).
        eval_dataloader: iterable of (images, labels) batches.

    Returns:
        The accuracy as a percentage between 0 and 100.

    Raises:
        ValueError: if ``eval_dataloader`` yields no samples.
    """
    was_training = model.training
    model.eval()

    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for images, labels in eval_dataloader:
                y_predicted = model(images)
                # The predicted class is the index of the highest score
                _, label_predicted = torch.max(y_predicted, 1)
                total += labels.size(0)
                correct += (label_predicted == labels).sum().item()
    finally:
        # Do not leak eval mode into a caller that is in the middle of training
        model.train(was_training)

    if total == 0:
        raise ValueError('eval_dataloader yielded no samples; accuracy is undefined')

    accuracy = 100 * correct / total

    return accuracy


@ex.automain
def run(seed, batch_size, num_epochs,
        num_channels1, num_channels2, num_channels3,
        num_lin_channels1, num_lin_channels2):
    """Train and evaluate the Braille classifier; entry point of the experiment.

    Steps: build the dataset, split it 80/10/10 into train/validation/test with a
    generator seeded by ``seed``, train a CNNClassif with early stopping, store the
    checkpoint and the loss curve as artifacts, then evaluate the saved checkpoint
    on the test split.

    Args:
        seed: seed for the dataset split and for the weight initialization.
        batch_size: batch size of all three DataLoaders.
        num_epochs: accepted from the config; the training function receives it
            through its own config capture.
        num_channels1, num_channels2, num_channels3: convolution channel counts.
        num_lin_channels1, num_lin_channels2: linear layer sizes.

    Returns:
        The test accuracy formatted as a percentage string, e.g. '93.590%'.
    """
    checkpoint_path = 'test_model.pt'

    # Instantiating the dataset
    dataset = ImageDataset()
    # Splitting the dataset 80/10/10; the sizes are derived from the dataset length
    # (1560 images give the 1248/156/156 split)
    num_samples = len(dataset)
    num_val = num_samples // 10
    num_test = num_samples // 10
    num_train = num_samples - num_val - num_test
    split_data = random_split(dataset, [num_train, num_val, num_test],
                              generator=torch.Generator().manual_seed(seed))
    train_data, val_data, test_data = split_data

    train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
    test_dataloader = DataLoader(test_data, batch_size=batch_size, shuffle=True)
    val_dataloader = DataLoader(val_data, batch_size=batch_size, shuffle=True)

    # Count the classes from the letter directories of the whole dataset. This avoids
    # loading every training image just to read its label, and does not depend on
    # which classes happen to land in the training split.
    num_classes = len({os.path.basename(os.path.dirname(path)) for path in dataset.image_path_list})
    print("Number of classes: ", num_classes)
    batch_data, _ = next(iter(train_dataloader))
    print(f'Batch shape [batch_size, image_shape]: {batch_data.shape}')
    print('Number of batches:', len(train_dataloader))

    print("== Initializing model...")
    model = CNNClassif(num_channels1, num_channels2, num_channels3,
                       num_lin_channels1, num_lin_channels2, num_classes)
    torch.manual_seed(seed)
    model.apply(init_weights)
    num_params = sum(p.numel() for p in model.parameters())
    ex.log_scalar('number_of_params', num_params)
    print(model)

    print("== Training...")
    # Remove a checkpoint left over from a previous run so that it can never be
    # mistaken for this run's best model
    if os.path.exists(checkpoint_path):
        os.remove(checkpoint_path)
    model, loss_total = training_cnn_classifier(model, train_dataloader, val_dataloader)
    # Best model is saved within training function; fall back to the returned model
    # if no checkpoint was written
    if not os.path.exists(checkpoint_path):
        torch.save(model.state_dict(), checkpoint_path)
    ex.add_artifact(checkpoint_path)

    # Loss curve, drawn on its own figure and closed afterwards
    fig, ax = plt.subplots()
    ax.plot(range(1, len(loss_total) + 1), loss_total)
    ax.set(title='Training loss', xlabel='Epoch', ylabel='Mean loss per batch')
    fig.savefig('loss.png')
    plt.close(fig)
    ex.add_artifact('loss.png')

    print("== Evaluating...")
    # Instantiating our model and loading the best model checkpoint from training
    model_eval = CNNClassif(num_channels1, num_channels2, num_channels3,
                            num_lin_channels1, num_lin_channels2, num_classes)
    model_eval.load_state_dict(torch.load(checkpoint_path))
    accuracy = eval_cnn_classifier(model_eval, test_dataloader)
    ex.log_scalar('accuracy', accuracy)
    return f'{accuracy:.3f}%'