In [1]:
import os
import sys
import copy
pdir = os.path.dirname(os.getcwd())
sys.path.append(pdir)

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
import wandb
import random

from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
from torch.optim.lr_scheduler import StepLR

import utils
from simpleview_pytorch import SimpleView

from torch.utils.data.dataset import Dataset

model_dir = 'models'

In [2]:
wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33mmja2106[0m (use `wandb login --relogin` to force relogin)


True

In [3]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

### Load the data:

In [4]:
dataset_name = 'tree_points.pt'

trees_data = torch.load(dataset_name)
val_data = torch.load(dataset_name)
print(trees_data.counts)
print('Species: ', trees_data.species)
print('Labels: ', trees_data.labels)
print('Total count: ', len(trees_data))

QUEFAG     1116
PINNIG      581
QUEILE      364
PINSYL      277
PINPIN      140
NA            2
JUNIPE        2
QUERCUS       2
DEAD          1
Name: sp, dtype: int64
Species:  ['DEAD', 'JUNIPE', 'NA', 'PINNIG', 'PINPIN', 'PINSYL', 'QUEFAG', 'QUEILE', 'QUERCUS']
Labels:  tensor([8, 3, 6,  ..., 7, 3, 6])
Total count:  2485


In [5]:
type([1,2,3]) == list

True

In [6]:
params = {
    "dataset_type":type(trees_data),
    "batch_size":128,
    "validation_split":.2,
    "shuffle_dataset":True,
    "random_seed":0,
    "learning_rate":[0.0005, 100, 0.5],  #[init, step_size, gamma] for scheduler
    "momentum":0.9, #Only used for sgd
    "epochs":300,
    "loss_fn":"cross-entropy",
    "optimizer":"adam",
    "voting":"None",
    
    "model":"SimpleView",
    
    "image_dim":trees_data.image_dim,
    "camera_fov_deg":trees_data.camera_fov_deg,
    "f":trees_data.f,
    "camera_dist":trees_data.camera_dist,
    "depth_averaging":"min",
    
    "transforms":['rotation','translation','jitter'],
    "min_rotation":0,
    "max_rotation":2*np.pi,
    "min_translation":0,
    "max_translation":0.5,
    "jitter_std":3e-4, 
    
    "species":["QUEFAG", "PINNIG", "QUEILE", "PINSYL", "PINPIN"],
    "data_resolution":"2.5cm"
}


if params["dataset_type"] == utils.dataset.TreeSpeciesPointDataset: #Change these by hand using point dataset
    params["image_dim"] = 256
    params["camera_fov_deg"] = 90 
    params["f"] = 1
    params["camera_dist"] = 1.4
    params["depth_averaging"] = "min"
    params["soft_min_k"] = 50
    params["num_views"] = 6
    
    trees_data.set_params(image_dim = params["image_dim"],
                         camera_fov_deg = params["camera_fov_deg"],
                         f = params["f"],
                         camera_dist = params["camera_dist"],
                         soft_min_k = params["soft_min_k"],
                         transforms = params["transforms"],
                         min_rotation = params["min_rotation"],
                         max_rotation = params["max_rotation"],
                         min_translation = params["min_translation"],
                         max_translation = params["max_translation"],
                         jitter_std = params["jitter_std"]
                         )
    
    val_data.set_params(image_dim = params["image_dim"],
                         camera_fov_deg = params["camera_fov_deg"],
                         f = params["f"],
                         camera_dist = params["camera_dist"],
                         soft_min_k = params["soft_min_k"],
                         transforms = params["transforms"],
                         min_rotation = params["min_rotation"],
                         max_rotation = params["max_rotation"],
                         min_translation = params["min_translation"],
                         max_translation = params["max_translation"],
                         jitter_std = params["jitter_std"]
                         )
    
elif params["dataset_type"] == utils.dataset.TreeSpeciesDataset:
    params["image_dim"] = trees_data.image_dim
    params["camera_fov_deg"] = trees_data.camera_fov_deg
    params["f"] = trees_data.f
    params["camera_dist"] = trees_data.camera_dist
    params["num_views"] = trees_data.depth_images.shape[1]
    params["depth_averaging"] = "min"

    
if trees_data.soft_min_k:
    params["soft_min_k"] = trees_data.soft_min_k

experiment_name = wandb.util.generate_id()
    
run = wandb.init(
    project='laser-trees',
    group=experiment_name,
    config=params)    

config = wandb.config
torch.manual_seed(config.random_seed)
torch.cuda.manual_seed(config.random_seed)
np.random.seed(config.random_seed)
random.seed(config.random_seed)

[34m[1mwandb[0m: wandb version 0.10.30 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


#### Remove low-count species:

In [7]:
for specie in list(set(trees_data.species) - set(config.species)):
    print("Removing: {}".format(specie))
    trees_data.remove_species(specie)
    val_data.remove_species(specie)

print('Train Dataset:')
print(trees_data.counts)
print('Species: ', trees_data.species)
print('Labels: ', trees_data.labels)
print('Total count: ', len(trees_data))
print()

print('Validation Dataset (should match):')
print(val_data.counts)
print('Species: ', val_data.species)
print('Labels: ', val_data.labels)
print('Total count: ', len(val_data))
print()

assert len(val_data) == len(trees_data)

Removing: QUERCUS
Removing: DEAD
Removing: NA
Removing: JUNIPE
Train Dataset:
QUEFAG    1116
PINNIG     581
QUEILE     364
PINSYL     277
PINPIN     140
Name: sp, dtype: int64
Species:  ['PINNIG', 'PINPIN', 'PINSYL', 'QUEFAG', 'QUEILE']
Labels:  tensor([0, 3, 3,  ..., 4, 0, 3])
Total count:  2478

Validation Dataset (should match):
QUEFAG    1116
PINNIG     581
QUEILE     364
PINSYL     277
PINPIN     140
Name: sp, dtype: int64
Species:  ['PINNIG', 'PINPIN', 'PINSYL', 'QUEFAG', 'QUEILE']
Labels:  tensor([0, 3, 3,  ..., 4, 0, 3])
Total count:  2478



#### Train-validation split:

In [8]:
dataset_size = len(trees_data)
indices = list(range(dataset_size))
split = int(np.floor(config.validation_split * dataset_size))

if config.shuffle_dataset :
    
    np.random.shuffle(indices)
train_indices, val_indices = indices[split:], indices[:split]

In [9]:
train_sampler = SubsetRandomSampler(train_indices)
valid_sampler = SubsetRandomSampler(val_indices)

train_loader = torch.utils.data.DataLoader(trees_data, batch_size=config.batch_size, 
                                           sampler=train_sampler)

val_data.set_params(transforms=['none']) #Turn off transforms for the validation dataset - DON'T GIVE IT AN EMPTY LIST
validation_loader = torch.utils.data.DataLoader(trees_data, batch_size=config.batch_size,
                                                sampler=valid_sampler)

In [10]:
#Run to print no. instances of each class
if None:
    for x in train_loader:
        print(torch.unique(x['labels'], return_counts = True))
    print()

    for x in validation_loader:
        print(torch.unique(x['labels'], return_counts = True))
    print()

### Define model, loss fn, optimiser:

In [11]:
assert set(config.species) == set(trees_data.species)

if config.model=="SimpleView":
    model = SimpleView(
        num_views=config.num_views,
        num_classes=len(config.species)
    )

model = model.to(device=device)

if config.loss_fn=="cross-entropy":
    loss_fn = nn.CrossEntropyLoss()
if config.loss_fn=="smooth-loss":
    loss_fn = utils.smooth_loss

if type(config.learning_rate) == list:
    lr = config.learning_rate[0]
    step_size = config.learning_rate[1]
    gamma = config.learning_rate[2]
else:
    lr = config.learning_rate
    
if config.optimizer=="sgd":
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=config.momentum)
elif config.optimizer=="adam":
    optimizer = optim.Adam(model.parameters(), lr=lr)
    
if type(config.learning_rate) == list:
    scheduler = StepLR(optimizer, step_size=step_size, gamma=gamma)

### Train & Test Loops:

In [12]:
#wandb.watch(model)
best_acc = 0
for epoch in range(config.epochs):  # loop over the dataset multiple times
    
    #Training loop============================================
    model.train()
    running_loss = 0.0
    for i, data in enumerate(train_loader, 0):
        depth_images = data['depth_images']
        labels = data['labels']

        depth_images = depth_images.to(device=device)
        labels = labels.to(device=device)
        
        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(depth_images)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if i % 5 == 4:    # print every 2000 mini-batches
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 2))
            running_loss = 0.0
                
    #Test loop================================================
    num_train_correct = 0
    num_train_samples = 0
    
    num_val_correct = 0
    num_val_samples = 0
    
    running_train_loss = 0
    running_val_loss = 0
    
    model.eval()  
    with torch.no_grad():
        #Train set eval==============
        for data in train_loader:
            depth_images = data['depth_images']
            labels = data['labels']

            depth_images = depth_images.to(device=device)
            labels = labels.to(device=device)
            
            scores = model(depth_images)
            _, predictions = scores.max(1)
            num_train_correct += (predictions == labels).sum()
            num_train_samples += predictions.size(0)
            
            running_train_loss += loss_fn(scores, labels)
        
        train_acc = float(num_train_correct)/float(num_train_samples)
        train_loss = running_train_loss/len(validation_loader)
        
        #Test set eval===============
        for data in validation_loader:
            depth_images = data['depth_images']
            labels = data['labels']

            depth_images = depth_images.to(device=device)
            labels = labels.to(device=device)
            
            scores = model(depth_images)
            _, predictions = scores.max(1)
            num_val_correct += (predictions == labels).sum()
            num_val_samples += predictions.size(0)
            
            running_val_loss += loss_fn(scores, labels)
        
        val_acc = float(num_val_correct)/float(num_val_samples)
        val_loss = running_val_loss/len(validation_loader)
        
        print(f'Got {num_val_correct} / {num_val_samples} with accuracy {val_acc*100:.2f}')
        
        if val_acc > best_acc:
            best_model_state = copy.deepcopy(model.state_dict())
            best_acc = val_acc
        
        wandb.log({
            "Train Loss":train_loss,
            "Validation Loss":val_loss,
            "Train Accuracy":train_acc,
            "Validation Accuracy":val_acc,
            "Learning Rate":optimizer.param_groups[0]['lr']
            })
        
        scheduler.step()
        

print('Finished Training')

print('Saving best model...')
torch.save(best_model_state,
           '{model_dir}/{fname}.pt'.format(
               model_dir=model_dir,
               fname=experiment_name+'_best.pt')
          )
print('Saved!')

run.finish()

  x_world_tilde=torch.cat((torch.tensor(cloud), torch.ones(cloud.shape[0],1)), 1).transpose(0,1)


[1,     5] loss: 3.429
[1,    10] loss: 3.067
[1,    15] loss: 2.931
Got 76 / 495 with accuracy 15.35
[2,     5] loss: 2.863
[2,    10] loss: 2.677
[2,    15] loss: 2.646
Got 124 / 495 with accuracy 25.05
[3,     5] loss: 2.471
[3,    10] loss: 2.380
[3,    15] loss: 2.552
Got 155 / 495 with accuracy 31.31
[4,     5] loss: 2.315
[4,    10] loss: 2.387
[4,    15] loss: 2.203
Got 163 / 495 with accuracy 32.93
[5,     5] loss: 2.354
[5,    10] loss: 2.274
[5,    15] loss: 2.256
Got 161 / 495 with accuracy 32.53
[6,     5] loss: 2.205
[6,    10] loss: 2.356
[6,    15] loss: 2.103
Got 341 / 495 with accuracy 68.89
[7,     5] loss: 2.132
[7,    10] loss: 2.149
[7,    15] loss: 2.226
Got 279 / 495 with accuracy 56.36
[8,     5] loss: 2.069
[8,    10] loss: 2.161
[8,    15] loss: 2.217
Got 327 / 495 with accuracy 66.06
[9,     5] loss: 2.215
[9,    10] loss: 2.221
[9,    15] loss: 1.999
Got 329 / 495 with accuracy 66.46
[10,     5] loss: 2.155
[10,    10] loss: 2.042
[10,    15] loss: 2.010
Go

VBox(children=(Label(value=' 0.04MB of 0.04MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
Train Loss,0.83618
Validation Loss,0.82606
Train Accuracy,0.91679
Validation Accuracy,0.79394
Learning Rate,0.00013
_step,299.0
_runtime,36400.0
_timestamp,1621447345.0


0,1
Train Loss,▇▃▆█▅▃▃▅▃▅▂▃▅▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▂▁▁▁▁▁▁▁▁
Validation Loss,▇▂▅█▄▂▃▄▂▅▁▂▄▁▁▁▁▂▂▂▂▂▂▁▂▂▂▂▁▂▂▂▂▂▂▂▂▂▁▂
Train Accuracy,▁▅▃▁▄▅▅▄▆▃▆▆▄▇▇▇▇▆▇▇▆▇▇▇▇▇▇██▇█▇████████
Validation Accuracy,▂▆▄▁▅▅▆▄▇▃▇▇▅█▇▇█▇▇▇▇▇▇▇▇▇▇▇██▇▇███▇███▇
Learning Rate,██████████████▃▃▃▃▃▃▃▃▃▃▃▃▃▁▁▁▁▁▁▁▁▁▁▁▁▁
_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
_runtime,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇███
_timestamp,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇███


In [13]:
best_acc

0.8484848484848485