# Import Dependencies

In [None]:
import json
import torch
import matplotlib.pyplot as plt
import numpy as np
from sklearn.model_selection import ParameterGrid
from sklearn.metrics import f1_score, mean_squared_error, confusion_matrix, ConfusionMatrixDisplay
from Scripts import CNN
from Scripts import VoxelizeData as vd

# Optional CUDA debugging switches (IPython magics); uncomment them to get
# synchronous kernel launches and device-side assertions when debugging.
#%set_env CUDA_LAUNCH_BLOCKING=1
#%set_env TORCH_USE_CUDA_DSA=1

# Choose the compute device: the GPU when PyTorch can see one, else the CPU.
# The GPU path has had issues with the small batch sizes used in this
# notebook; uncomment the CPU override below to force CPU execution.
dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(torch.__version__)

# dev = torch.device("cpu")  # force CPU

print(dev)

### Load Dataset To Train Model

Change **training_set** to any of the following:

- ModelNet40
- ShapeNet
- Toys

If you have 16GB of RAM or less, I recommend using the Toys dataset and skipping the cross-dataset evaluation.

In [None]:
# Dataset to train on; one of "ModelNet40", "ShapeNet" or "Toys".
training_set = "ModelNet40"

# vd.load_dataset yields four items, unpacked here as the training
# grids/labels followed by the test grids/labels.
(train_data, train_labels,
 test_data, test_labels) = vd.load_dataset(training_set)

### Find Best Parameters Using Grid Search

In [None]:
# Grid-search SGD hyper-parameters for the voxel CNN.
#
# For every combination in `param_grid` a fresh CNN is trained. After each
# epoch it is scored on the train and test loaders, and the per-epoch
# accuracy (%), weighted F1, MSE and confusion matrix are appended to that
# combination's histories in `results['train'][i]` / `results['test'][i]`.
# `results['best']` maps each test-set metric to the index (into
# `results['params']`) of the combination with the best final-epoch score.
# Everything is serialised to `json_string` at the end.

batch_size = 32
classes = np.unique(train_labels)
num_classes = len(classes)
print(num_classes)

# Make a parameter grid for grid search
param_grid = {'epochs': [40, 50],
        'learning_rate': [0.01, 0.001, 0.0001],
        'momentum': [0.9, 0.95, 0.99, 0.999],
        'weight_decay': [0.01, 0.001],
        'decay_iter': [10000, 20000, 40000, 80000]
        }

# Store index of the best iteration for each metric
best_results = {'f1': 0, 'acc': 0, 'mse': 0}

# Keys of the per-combination metric histories. Each combination (and each of
# train/test) gets its own fresh lists built from these keys; a shallow
# `metrics.copy()` would make every history share the same list objects.
metrics = {'f1_scores': [], 'accuracies': [], 'mses': [], 'conf_matrices': []}
results = {'train': [], 'test': [], 'params': [], 'best': best_results}

# Load data in batch sizes
train_loader = vd.build_dataloader(train_data, train_labels, batch_size=batch_size, shuffle=True)
test_loader = vd.build_dataloader(test_data, test_labels, batch_size=batch_size, shuffle=True)

# Best final-epoch test scores seen so far; updated as combinations finish
best_acc = 0
best_f1 = 0
best_mse = float('inf')


def _predict(model, loader):
    """Return ``(predictions, labels)`` as Python lists for every batch in *loader*.

    Puts *model* in eval mode and runs without gradient tracking. Batches
    shorter than ``batch_size`` are skipped, mirroring the training loop.
    """
    model.eval()
    predictions, labels = [], []
    with torch.no_grad():
        for grids, batch_labels in loader:
            if len(grids) < batch_size:
                continue
            output = model(grids.to(dev))
            _, batch_predictions = torch.max(output, 1)
            predictions += batch_predictions.cpu().tolist()
            labels += batch_labels.cpu().tolist()
    return predictions, labels


def _record(history, labels, predictions):
    """Append accuracy (%), weighted F1, MSE and confusion matrix to *history*.

    Returns ``(accuracy, f1, mse)``.
    Raises ValueError when *labels* is empty (every batch was shorter than
    ``batch_size``) instead of failing with a ZeroDivisionError.
    """
    if not labels:
        raise ValueError(f"no full batches of size {batch_size} to evaluate")
    accuracy = sum(p == t for p, t in zip(predictions, labels)) / len(labels) * 100.0
    f1 = f1_score(labels, predictions, average='weighted')
    mse = mean_squared_error(labels, predictions)
    history['accuracies'].append(accuracy)
    history['f1_scores'].append(f1)
    history['mses'].append(mse)
    history['conf_matrices'].append(
        confusion_matrix(labels, predictions, labels=list(range(num_classes))))
    return accuracy, f1, mse


def _to_builtin(obj):
    """json.dumps fallback: convert NumPy arrays/scalars (e.g. confusion matrices) to built-ins."""
    if isinstance(obj, (np.ndarray, np.generic)):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


# Apply grid search using the sklearn ParameterGrid
for i, params in enumerate(ParameterGrid(param_grid)):
    print(f'iteration {i}: {params}')

    # Get the parameters
    learning_rate = params['learning_rate']
    momentum = params['momentum']
    weight_decay = params['weight_decay']
    decay_iter = params['decay_iter']
    epochs = params['epochs']

    # Fresh, unshared metric histories for this combination
    results['test'].append({key: [] for key in metrics})
    results['train'].append({key: [] for key in metrics})

    # Store the parameters for this iteration
    results['params'].append(params)

    # Initialize the model
    model = CNN.CNN(num_classes)
    model = model.float()
    model.to(dev)

    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum, weight_decay=weight_decay)
    criterion = torch.nn.CrossEntropyLoss()

    batch_number = 0

    # Train and test the model for the specified number of epochs
    for epoch in range(epochs):
        # _predict() leaves the model in eval mode; switch back for training
        model.train()

        for vox_grids, vox_labels in train_loader:
            vox_grids = vox_grids.to(dev)
            vox_labels = vox_labels.to(dev)

            batch_number += 1

            # Step decay: every `decay_iter` batches cut the LR by 10x. The new
            # value must be written into the optimizer or it has no effect.
            if batch_number % decay_iter == 0:
                print("Decreasing learning rate")
                learning_rate *= 0.1
                for group in optimizer.param_groups:
                    group['lr'] = learning_rate

            # Skip the trailing partial batch
            if len(vox_grids) < batch_size:
                continue

            output = model(vox_grids)
            loss = criterion(output, vox_labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        final_epoch = epoch == epochs - 1

        # Evaluate on the training data
        predictions_train, labels_train = _predict(model, train_loader)
        accuracy, f1, mse = _record(results['train'][i], labels_train, predictions_train)
        if final_epoch:
            print("Final Training Accuracy:", accuracy)
            print("Final Training F1 Score:", f1)
            print("Final Training MSE:", mse)

        # Evaluate on the test data
        predictions_test, labels_test = _predict(model, test_loader)
        accuracy, f1, mse = _record(results['test'][i], labels_test, predictions_test)
        if final_epoch:
            print("Final Testing Accuracy:", accuracy)
            print("Final Testing F1 Score:", f1)
            print("Final Testing MSE:", mse)

            # Compare this combination's final test scores with the best so far
            if accuracy > best_acc:
                best_acc = accuracy
                results['best']['acc'] = i
            if f1 > best_f1:
                best_f1 = f1
                results['best']['f1'] = i
            if mse < best_mse:
                best_mse = mse
                results['best']['mse'] = i

# Convert the results to a JSON string (confusion matrices are NumPy arrays)
json_string = json.dumps(results, indent=4, default=_to_builtin)