In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.datasets import DatasetFolder
from tqdm import tqdm
# Publish the installed torch version as $TORCH and echo it.
# NOTE(review): presumably consumed by a PyG wheel-install command that builds its
# URL from ${TORCH}; nothing in this notebook reads it — confirm before removing.
os.environ.update(TORCH=torch.__version__)
print(torch.__version__)

In [None]:
from dataset import TerrainDataset
from pointnet import PointNet

# Smoke-test that the training split loads from disk.
# NOTE(review): `dataset` is not used later; the training cell reloads this same split.
dataset = TerrainDataset(
    root='/home/atas/traversablity_estimation_net/data',
    train=True,
)


In [None]:
import sys

from IPython.display import Javascript, display  # Restrict height of output cell (Colab only).
from torch_geometric.loader import DataLoader  # PyG loader: collates samples into a Batch exposing `data.batch`.

# `google.colab.output.setIframeHeight` exists only inside Google Colab; in a plain
# Jupyter session the injected script just raises a JavaScript error in the browser.
if 'google.colab' in sys.modules:
    display(Javascript('''google.colab.output.setIframeHeight(0, true, {maxHeight: 300})'''))

# Dataset root. Defaults to the original author's machine; set $TERRAIN_DATA_ROOT elsewhere.
DATA_ROOT = os.environ.get('TERRAIN_DATA_ROOT', '/home/atas/traversablity_estimation_net/data')
SEED = 0
# Seed torch's global RNG so train-loader shuffling (and, presumably, PointNet's
# weight initialisation) is reproducible across Restart & Run All.
torch.manual_seed(SEED)

train_dataset = TerrainDataset(root=DATA_ROOT, train=True)
test_dataset = TerrainDataset(root=DATA_ROOT, train=False)

train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1)

model = PointNet()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# Mean absolute error (L1 loss). Swap in torch.nn.MSELoss() to train on squared error.
criterion = torch.nn.L1Loss()

def train(model, optimizer, loader, loss_fn=None):
    """Run one training epoch over ``loader`` and return the mean loss per batch.

    Args:
        model: network called as ``model(pos, batch)``; switched to train mode here.
        optimizer: optimizer over ``model``'s parameters; stepped once per batch.
        loader: iterable of PyG-style batches exposing ``pos``, ``y`` and ``batch``.
        loss_fn: loss ``loss_fn(prediction, target)``. Defaults to the notebook-level
            ``criterion`` when omitted, so existing calls keep working.

    Returns:
        float: average of the per-batch loss values, so the number does not grow
        with dataset size; ``nan`` if ``loader`` yields no batches.
    """
    if loss_fn is None:
        loss_fn = criterion
    model.train()
    total_loss = 0.0
    num_batches = 0
    for data in loader:
        optimizer.zero_grad()                   # Clear gradients.
        # Add a trailing singleton dim, (N, D) -> (N, D, 1), before the forward pass.
        # A local is used so the batch object handed out by the loader is not mutated.
        pos = data.pos.reshape((data.pos.shape[0], data.pos.shape[1], 1))
        logits = model(pos, data.batch)         # Forward pass.
        target = data.y
        # Avoid silent broadcasting, e.g. logits (B, 1) vs target (B,) -> (B, B).
        if logits.shape != target.shape and logits.numel() == target.numel():
            logits = logits.reshape_as(target)
        loss = loss_fn(logits, target)          # Loss computation.
        loss.backward()                         # Backward pass.
        optimizer.step()                        # Update model parameters.
        total_loss += loss.item()
        num_batches += 1

    if num_batches == 0:
        return float('nan')
    return total_loss / num_batches


@torch.no_grad()
def test(model, loader):
    model.eval()

    error = 0.0
    for data in loader:
        data.pos = data.pos.reshape((data.pos.shape[0], data.pos.shape[1], 1))
        pred = model(data.pos, data.batch)
        error += torch.pow((pred - data.y), 2).sum().item()
    
    # convert error to percentage accuracy
    return error

In [None]:
NUM_EPOCHS = 300  # epochs run are 1..NUM_EPOCHS inclusive
SAVE_EVERY = 50   # write a checkpoint every SAVE_EVERY epochs

train_loss = []
test_mse = []

for epoch in range(1, NUM_EPOCHS + 1):
    loss = train(model, optimizer, train_loader)
    mse = test(model, test_loader)
    if epoch % SAVE_EVERY == 0:
        # Relative path: checkpoints land in the current working directory.
        # NOTE(review): the evaluation cell loads from .../weights/ instead — keep in sync.
        torch.save(model.state_dict(), f'epoch_{epoch}.pt')
    print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, MSE: {mse:.4f}')
    train_loss.append(loss)
    test_mse.append(mse)

# Plot against 1-based epoch numbers so the x axis matches the printed log.
epochs = range(1, len(train_loss) + 1)
fig, (ax1, ax2) = plt.subplots(2, figsize=(12, 6), sharex=True)
ax1.plot(epochs, train_loss)
ax1.set_ylabel("training loss")
ax1.set_title("Training loss and test error per epoch")
ax2.plot(epochs, test_mse)
ax2.set_ylabel("mse error")
ax2.set_xlabel("epochs")
plt.show()

In [None]:
visual_test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

# The training cell saves checkpoints relative to this directory.
print(os.getcwd())

WEIGHTS_PATH = '/home/atas/traversablity_estimation_net/weights/epoch_200.pt'
net = PointNet()
# map_location='cpu' lets a checkpoint written from a GPU session load on a CPU-only kernel.
net.load_state_dict(torch.load(WEIGHTS_PATH, map_location='cpu'))
net.eval()

squared_error_sum = 0.0
num_elements = 0
with torch.no_grad():  # inference only: skip building an autograd graph
    for data in visual_test_loader:
        inputs = data.pos.reshape((data.pos.shape[0], data.pos.shape[1], 1))
        outputs = net(inputs, data.batch)
        labels = data.y
        # Avoid silent broadcasting, e.g. outputs (B, 1) vs labels (B,) -> (B, B).
        if outputs.shape != labels.shape and outputs.numel() == labels.numel():
            outputs = outputs.reshape_as(labels)
        squared_error_sum += torch.pow(outputs - labels, 2).sum().item()
        num_elements += labels.numel()

# Mean over target elements (not over batches), so this is a true MSE.
error = squared_error_sum / num_elements if num_elements else float('nan')
print("MSE: " + str(error))
