# Training and Testing GCNN on ModelNet40

In [None]:
import torch
import matplotlib.pyplot as plt
import numpy as np
from Scripts import CNN

# Device selection. The GPU path is switched off on purpose: it had issues
# with the small batch sizes used in this notebook, so everything runs on the
# CPU. Set ``use_cuda`` to True to try CUDA again when it is available.
use_cuda = False
dev = "cuda:0" if use_cuda and torch.cuda.is_available() else "cpu"

print(dev)
device = torch.device(dev)

# Load the ModelNet40 train/test splits and wrap them in tensor datasets.

# ModelNet40 class names; a label's integer id is its position in this list.
# The parser that produced the .npz files cuts each name off at the first
# underscore, so the truncated forms are listed here to match the stored labels.
modelnet_labels = [
    'airplane', 'bathtub', 'bed', 'bench', 'bookshelf', 'bottle', 'bowl', 'car',
    'chair', 'cone', 'cup', 'curtain', 'desk', 'door', 'dresser', 'flower',
    'glass', 'guitar', 'keyboard', 'lamp', 'laptop', 'mantel', 'monitor',
    'night', 'person', 'piano', 'plant', 'radio', 'range', 'sink',
    'sofa', 'stairs', 'stool', 'table', 'tent', 'toilet', 'tv', 'vase',
    'wardrobe', 'xbox'
]

# Name -> integer id lookup, built once instead of a list search per sample.
_label_ids = {name: idx for idx, name in enumerate(modelnet_labels)}


def _load_modelnet_split(path):
    """Load one ModelNet40 split from ``path`` as a TensorDataset.

    The ``data`` array gets a singleton axis inserted at position 1
    (presumably the channel axis the model expects -- TODO confirm), and every
    string label is mapped to its index in ``modelnet_labels``. The shapes of
    the grid and label arrays are printed for inspection.

    Returns:
        (dataset, labels): a TensorDataset of (float grids, long label ids),
        both moved to ``device``, and the raw label array from the file.

    Raises:
        ValueError: if a stored label is not one of ``modelnet_labels``.
    """
    # allow_pickle=True is kept from the original loader (presumably the
    # labels are stored as an object array). It lets np.load unpickle data, so
    # only use it on trusted files. The ``with`` block closes the .npz archive
    # once both arrays have been read into memory.
    with np.load(path, allow_pickle=True) as npz:
        grids = np.expand_dims(npz['data'], axis=1)
        labels = npz['labels']

    print(grids.shape)
    print(labels.shape)

    # Convert labels to integer ids so they can be turned into a LongTensor.
    try:
        label_ids = [_label_ids[name] for name in labels]
    except KeyError as err:
        raise ValueError(
            f"{path}: label {err.args[0]!r} is not in modelnet_labels"
        ) from err

    dataset = torch.utils.data.TensorDataset(
        torch.FloatTensor(grids).to(device),
        torch.LongTensor(label_ids).to(device),
    )
    return dataset, labels


train_data, train_labels = _load_modelnet_split('Data/ModelNet40/ModelNet40Train.npz')
test_data, test_labels = _load_modelnet_split('Data/ModelNet40/ModelNet40Test.npz')

# Hyperparameters
batch_size = 32
num_classes = 40      # one output per entry in modelnet_labels
learning_rate = 0.01  # initial SGD learning rate; the training loop scales it by 0.1 every decay_iter batches
momentum = 0.9
weight_decay = 0.001
decay_iter = 10000    # number of training batches between learning-rate decays
epochs = 10
batch_number = 0      # running count of training batches, drives the decay schedule

# Batch the data. The training set is reshuffled every epoch. The test set is
# iterated in a fixed order: the evaluation loop skips the final partial batch,
# and with shuffling a different random subset of test samples would be left
# out each epoch, making the per-epoch test accuracy noisier than necessary.
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=2)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=2)

model = CNN.GCNN(num_classes)
model = model.float()  # cast floating-point parameters to float32 to match the FloatTensor inputs
model.to(device)

optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum, weight_decay=weight_decay)
criterion = torch.nn.CrossEntropyLoss()

# Per-epoch accuracies (percent), filled in by the training loop.
training_accuracy = []
testing_accuracy = []

def _evaluate_accuracy(loader):
    """Return the percentage of samples in ``loader`` that ``model`` classifies correctly.

    Switches the model to eval mode and disables gradient tracking. As in
    training, batches smaller than ``batch_size`` are skipped, so the final
    partial batch does not count toward the result. Returns NaN when no full
    batch was evaluated, instead of dividing by zero.
    """
    model.eval()
    num_correct = 0
    total = 0
    with torch.no_grad():
        for grids, labels in loader:
            if len(grids) < batch_size:
                continue

            output = model(grids)
            _, predictions = torch.max(output, 1)

            num_correct += (predictions == labels).sum().item()
            total += labels.shape[0]

    if total == 0:
        return float("nan")
    return (num_correct / total) * 100.0


for epoch in range(epochs):
    print("epoch:", epoch+1)

    # Evaluation below switches the model to eval mode, so switch back before
    # every epoch of training.
    model.train()
    for vox_grids, vox_labels in train_loader:
        # Skip the smaller final batch of the epoch; it is never trained on and
        # therefore does not count toward the learning-rate schedule either.
        if len(vox_grids) < batch_size:
            continue

        batch_number += 1
        if batch_number % decay_iter == 0:
            print("Decreasing learning rate")
            learning_rate *= 0.1
            # The optimizer keeps its own copy of the learning rate, so the new
            # value must be written into every parameter group to take effect.
            for param_group in optimizer.param_groups:
                param_group['lr'] = learning_rate

        output = model(vox_grids)
        loss = criterion(output, vox_labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Evaluate the GCNN on both splits after each epoch of training.
    accuracy = _evaluate_accuracy(train_loader)
    training_accuracy.append(accuracy)
    print("Training accuracy:", accuracy)

    accuracy = _evaluate_accuracy(test_loader)
    testing_accuracy.append(accuracy)
    print("Testing accuracy:", accuracy)

# Plot per-epoch training and testing accuracy. Epochs are numbered from 1 so
# the x axis matches the "epoch:" lines printed during training.
line_one = plt.plot(range(1, len(training_accuracy) + 1), training_accuracy, label='Training Accuracy')
line_two = plt.plot(range(1, len(testing_accuracy) + 1), testing_accuracy, label='Testing Accuracy')
plt.legend()
plt.ylim(0,100.0)
plt.ylabel('Accuracy')
plt.xlabel('Epochs')
# Render the figure when run as a plain script. In a notebook this also stops
# the cell from echoing the Text object returned by plt.xlabel.
plt.show()