In [1]:
import os
import sys
pdir = os.path.dirname(os.getcwd())
sys.path.append(pdir)

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
import wandb

from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import SubsetRandomSampler

import utils
from simpleview_pytorch import SimpleView

from torch.utils.data.dataset import Dataset

In [2]:
wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33mmja2106[0m (use `wandb login --relogin` to force relogin)


True

In [3]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

### Load the data:

In [4]:
trees_data = torch.load('trees_128.pt')
print(trees_data.counts)
print('Species: ', trees_data.species)
print('Labels: ', trees_data.labels)
print('Total count: ', len(trees_data))

QUEFAG     1116
PINNIG      581
QUEILE      364
PINSYL      277
PINPIN      140
JUNIPE        2
NA            2
QUERCUS       2
DEAD          1
Name: sp, dtype: int64
Species:  ['DEAD', 'JUNIPE', 'NA', 'PINNIG', 'PINPIN', 'PINSYL', 'QUEFAG', 'QUEILE', 'QUERCUS']
Labels:  tensor([8, 3, 6,  ..., 7, 3, 6])
Total count:  2485


In [19]:
params = {
    "batch_size":64,
    "validation_split":.1,
    "shuffle_dataset":True,
    "random_seed":0,
    "learning_rate":0.0005,
    "momentum":0.9,
    "epochs":100,
    "loss_fn":"cross-entropy",
    "optimizer":"sgd",
    "jitter":False,
    "random_rotation":False,
    "random_scaling":False,
    "random_translation":False,
    "voting":"None",
    
    "model":"SimpleView",
    
    "image_dim":trees_data.image_dim,
    "camera_fov_deg":trees_data.camera_fov_deg,
    "f":trees_data.f,
    "camera_dist":trees_data.camera_dist,
    "num_views":trees_data.depth_images.shape[1],
    "depth_averaging" = "min",
    
    "species":["QUEFAG", "PINNIG", "QUEILE", "PINSYL", "PINPIN"]
    "data_resolution":"2.5cm"
}

if trees_data.soft_min_k:
    params["soft_min_k"] = trees_data.soft_min_k

experiment_name = wandb.util.generate_id()
    
run = wandb.init(
    project='laser-trees',
    group=experiment_name,
    config=params)    

config = wandb.config


[34m[1mwandb[0m: wandb version 0.10.28 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


#### Remove low-count species:

In [20]:
for specie in list(set(trees_data.species) - set(config.species)):
    print("Removing: {}".format(specie))
    trees_data.remove_species(specie)
    
print(trees_data.counts)
print('Species: ', trees_data.species)
print('Labels: ', trees_data.labels)
print('Total count: ', len(trees_data))

QUEFAG    1116
PINNIG     581
QUEILE     364
PINSYL     277
PINPIN     140
Name: sp, dtype: int64
Species:  ['PINNIG', 'PINPIN', 'PINSYL', 'QUEFAG', 'QUEILE']
Labels:  tensor([0, 3, 3,  ..., 4, 0, 3])
Total count:  2478


#### Train-validation split:

In [21]:
dataset_size = len(trees_data)
indices = list(range(dataset_size))
split = int(np.floor(config.validation_split * dataset_size))

if config.shuffle_dataset :
    np.random.seed(config.random_seed)
    np.random.shuffle(indices)
train_indices, val_indices = indices[split:], indices[:split]

In [22]:
train_sampler = SubsetRandomSampler(train_indices)
valid_sampler = SubsetRandomSampler(val_indices)

train_loader = torch.utils.data.DataLoader(trees_data, batch_size=config.batch_size, 
                                           sampler=train_sampler)
validation_loader = torch.utils.data.DataLoader(trees_data, batch_size=config.batch_size,
                                                sampler=valid_sampler)

In [23]:
#Run to print no. instances of each class
if None:
    for x in train_loader:
        print(torch.unique(x['labels'], return_counts = True))
    print()

    for x in validation_loader:
        print(torch.unique(x['labels'], return_counts = True))
    print()

### Define model, loss fn, optimiser:

In [24]:
assert set(config.species) == set(trees_data.species)

if config.model=="SimpleView":
    model = SimpleView(
        num_views=config.num_views,
        num_classes=len(config.species)
    )

model = model.to(device=device)

if config.loss_fn=="cross-entropy":
    loss_fn = nn.CrossEntropyLoss()

if config.optimizer=="sgd":
    optimizer = optim.SGD(model.parameters(), lr=config.learning_rate, momentum=config.momentum)

### Train & Test Loops:

In [25]:
wandb.watch(model)
for epoch in range(config.epochs):  # loop over the dataset multiple times
    
    #Training loop============================================
    model.train()
    running_loss = 0.0
    for i, data in enumerate(train_loader, 0):
        depth_images = data['depth_images']
        labels = data['labels']

        depth_images = depth_images.to(device=device)
        labels = labels.to(device=device)
        
        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(depth_images)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if i % 5 == 4:    # print every 2000 mini-batches
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 2))
            running_loss = 0.0
                
    #Test loop================================================
    num_train_correct = 0
    num_train_samples = 0
    
    num_val_correct = 0
    num_val_samples = 0
    
    running_train_loss = 0
    running_val_loss = 0
    
    model.eval()  
    with torch.no_grad():
        #Train set eval==============
        for data in train_loader:
            depth_images = data['depth_images']
            labels = data['labels']

            depth_images = depth_images.to(device=device)
            labels = labels.to(device=device)
            
            scores = model(depth_images)
            _, predictions = scores.max(1)
            num_train_correct += (predictions == labels).sum()
            num_train_samples += predictions.size(0)
            
            running_train_loss += loss_fn(scores, labels)
        
        train_acc = float(num_train_correct)/float(num_train_samples)
        train_loss = running_train_loss/len(validation_loader)
        
        #Test set eval===============
        for data in validation_loader:
            depth_images = data['depth_images']
            labels = data['labels']

            depth_images = depth_images.to(device=device)
            labels = labels.to(device=device)
            
            scores = model(depth_images)
            _, predictions = scores.max(1)
            num_val_correct += (predictions == labels).sum()
            num_val_samples += predictions.size(0)
            
            running_val_loss += loss_fn(scores, labels)
        
        val_acc = float(num_val_correct)/float(num_val_samples)
        val_loss = running_val_loss/len(validation_loader)
        
        print(f'Got {num_val_correct} / {num_val_samples} with accuracy {val_acc*100:.2f}')
        
        wandb.log({
            "Train Loss":train_loss,
            "Validation Loss":val_loss,
            "Train Accuracy":train_acc,
            "Validation Accuracy":val_acc
            })
        

print('Finished Training')
run.finish()

[1,     5] loss: 4.299
[1,    10] loss: 3.608
[1,    15] loss: 3.536
[1,    20] loss: 3.639
[1,    25] loss: 3.550
[1,    30] loss: 3.261
[1,    35] loss: 3.505
Got 109 / 247 with accuracy 44.13
[2,     5] loss: 3.369
[2,    10] loss: 3.411
[2,    15] loss: 3.089
[2,    20] loss: 3.026
[2,    25] loss: 3.206
[2,    30] loss: 3.172
[2,    35] loss: 3.228
Got 130 / 247 with accuracy 52.63
[3,     5] loss: 2.869
[3,    10] loss: 3.062
[3,    15] loss: 3.109
[3,    20] loss: 3.096
[3,    25] loss: 2.946
[3,    30] loss: 3.001
[3,    35] loss: 3.043
Got 132 / 247 with accuracy 53.44
[4,     5] loss: 2.903
[4,    10] loss: 2.778
[4,    15] loss: 2.820
[4,    20] loss: 2.951
[4,    25] loss: 2.934
[4,    30] loss: 2.911
[4,    35] loss: 2.818
Got 135 / 247 with accuracy 54.66
[5,     5] loss: 2.767
[5,    10] loss: 2.952
[5,    15] loss: 2.746
[5,    20] loss: 2.788
[5,    25] loss: 2.720
[5,    30] loss: 2.598
[5,    35] loss: 2.667
Got 144 / 247 with accuracy 58.30
[6,     5] loss: 2.664
[6

VBox(children=(Label(value=' 0.05MB of 0.05MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
Train Loss,0.1211
Validation Loss,1.81661
Train Accuracy,1.0
Validation Accuracy,0.62348
_step,99.0
_runtime,219.0
_timestamp,1620036703.0


0,1
Train Loss,█▇▆▆▆▅▅▅▅▅▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Validation Loss,▄▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▂▂▃▅▃▄▃▄▅▅▅▅▆█▆▆▇▇▇▇▇▇
Train Accuracy,▁▂▃▃▃▄▄▄▄▅▅▅▅▆▆▆▆▇▇▇▇▇██████████████████
Validation Accuracy,▁▄▆▇▇▇▇▇▇▇▇██▇▇██▇█▇█▇▇█▆▇▆▇▆▆▇▆▆▆▆▆▇▇▆▆
_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_runtime,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_timestamp,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███


### Update old object to new class (Don't run every time):

In [12]:
if None:
    metadata_file = "../data/treesXYZ/meta/META.csv"
    data_dir = "../data/treesXYZ/"
    trees_new = utils.TreeSpeciesDataset(data_dir, metadata_file)

    trees_old=torch.load('trees_old.pt')

    trees_new.depth_images = trees_old.depth_images
    trees_new.labels = trees_old.labels.long()

    torch.save(trees_new, 'trees_new.pt')

In [13]:
if None: 
    metadata_file = "../data/treesXYZ/meta/META.csv"
    data_dir = "../data/treesXYZ/"
    trees_tmp = utils.TreeSpeciesDataset(data_dir=data_dir, metadata_file=metadata_file)
    trees_tmp.depth_images=trees_data.depth_images
    trees_tmp.labels=trees_data.labels
    trees_tmp.image_dim=128
    trees_tmp.camera_fov_deg=90
    trees_tmp.f=1
    trees_tmp.camera_dist=1.4

    torch.save(trees_tmp, 'trees_128.pt')