In [1]:
import os
import sys
pdir = os.path.dirname(os.getcwd())
sys.path.append(pdir)

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
import wandb

from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import SubsetRandomSampler

import utils
from simpleview_pytorch import SimpleView

from torch.utils.data.dataset import Dataset

In [2]:
wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33mmja2106[0m (use `wandb login --relogin` to force relogin)


True

In [3]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

### Load the data:

In [4]:
trees_data = torch.load('trees_256.pt')
print(trees_data.counts)
print('Species: ', trees_data.species)
print('Labels: ', trees_data.labels)
print('Total count: ', len(trees_data))

QUEFAG     1116
PINNIG      581
QUEILE      364
PINSYL      277
PINPIN      140
JUNIPE        2
QUERCUS       2
NA            2
DEAD          1
Name: sp, dtype: int64
Species:  ['DEAD', 'JUNIPE', 'NA', 'PINNIG', 'PINPIN', 'PINSYL', 'QUEFAG', 'QUEILE', 'QUERCUS']
Labels:  tensor([8, 3, 6,  ..., 7, 3, 6])
Total count:  2485


In [5]:
params = {
    "dataset_type":type(trees_data),
    "batch_size":128,
    "validation_split":.2,
    "shuffle_dataset":True,
    "random_seed":0,
    "learning_rate":0.0005,
    "momentum":0.9,
    "epochs":100,
    "loss_fn":"smooth-loss",
    "optimizer":"adam",
    "jitter":False,
    "random_rotation":False,
    "random_scaling":False,
    "random_translation":False,
    "voting":"None",
    
    "model":"SimpleView",
    
    "image_dim":trees_data.image_dim,
    "camera_fov_deg":trees_data.camera_fov_deg,
    "f":trees_data.f,
    "camera_dist":trees_data.camera_dist,
    "depth_averaging":"min",
    
    "species":["QUEFAG", "PINNIG", "QUEILE", "PINSYL", "PINPIN"],
    "data_resolution":"2.5cm"
}


if params["dataset_type"] == utils.dataset.TreeSpeciesPointDataset: #Change these by hand using point dataset
    params["image_dim"] = 256
    params["camera_fov_deg"] = 90 
    params["f"] = 1
    params["camera_dist"] = 1.4
    params["depth_averaging"] = "min"
    params["soft_min_k"] = 50
    params["num_views"] = 6
    
    trees_data.set_params(image_dim = params["image_dim"],
                         camera_fov_deg = params["camera_fov_deg"],
                         f = params["f"],
                         camera_dist = params["camera_dist"],
                         soft_min_k = params["soft_min_k"])  
    
elif params["dataset_type"] == utils.dataset.TreeSpeciesDataset:
    params["image_dim"] = trees_data.image_dim
    params["camera_fov_deg"] = trees_data.camera_fov_deg
    params["f"] = trees_data.f
    params["camera_dist"] = trees_data.camera_dist
    params["num_views"] = trees_data.depth_images.shape[1]
    params["depth_averaging"] = "min"

    
if trees_data.soft_min_k:
    params["soft_min_k"] = trees_data.soft_min_k

experiment_name = wandb.util.generate_id()
    
run = wandb.init(
    project='laser-trees',
    group=experiment_name,
    config=params)    

config = wandb.config
torch.manual_seed(config.random_seed)

[34m[1mwandb[0m: wandb version 0.10.30 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


<torch._C.Generator at 0x7f189bb5ff10>

#### Remove low-count species:

In [6]:
for specie in list(set(trees_data.species) - set(config.species)):
    print("Removing: {}".format(specie))
    trees_data.remove_species(specie)
    
print(trees_data.counts)
print('Species: ', trees_data.species)
print('Labels: ', trees_data.labels)
print('Total count: ', len(trees_data))

Removing: QUERCUS
Removing: JUNIPE
Removing: NA
Removing: DEAD
QUEFAG    1116
PINNIG     581
QUEILE     364
PINSYL     277
PINPIN     140
Name: sp, dtype: int64
Species:  ['PINNIG', 'PINPIN', 'PINSYL', 'QUEFAG', 'QUEILE']
Labels:  tensor([0, 3, 3,  ..., 4, 0, 3])
Total count:  2478


#### Train-validation split:

In [7]:
dataset_size = len(trees_data)
indices = list(range(dataset_size))
split = int(np.floor(config.validation_split * dataset_size))

if config.shuffle_dataset :
    np.random.seed(config.random_seed)
    np.random.shuffle(indices)
train_indices, val_indices = indices[split:], indices[:split]

In [8]:
train_sampler = SubsetRandomSampler(train_indices)
valid_sampler = SubsetRandomSampler(val_indices)

train_loader = torch.utils.data.DataLoader(trees_data, batch_size=config.batch_size, 
                                           sampler=train_sampler)
validation_loader = torch.utils.data.DataLoader(trees_data, batch_size=config.batch_size,
                                                sampler=valid_sampler)

In [9]:
#Run to print no. instances of each class
if None:
    for x in train_loader:
        print(torch.unique(x['labels'], return_counts = True))
    print()

    for x in validation_loader:
        print(torch.unique(x['labels'], return_counts = True))
    print()

### Define model, loss fn, optimiser:

In [10]:
assert set(config.species) == set(trees_data.species)

if config.model=="SimpleView":
    model = SimpleView(
        num_views=config.num_views,
        num_classes=len(config.species)
    )

model = model.to(device=device)

if config.loss_fn=="cross-entropy":
    loss_fn = nn.CrossEntropyLoss()
if config.loss_fn=="smooth-loss":
    loss_fn = utils.smooth_loss

if config.optimizer=="sgd":
    optimizer = optim.SGD(model.parameters(), lr=config.learning_rate, momentum=config.momentum)
elif config.optimizer=="adam":
    optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)

### Train & Test Loops:

In [11]:
#wandb.watch(model)
for epoch in range(config.epochs):  # loop over the dataset multiple times
    
    #Training loop============================================
    model.train()
    running_loss = 0.0
    for i, data in enumerate(train_loader, 0):
        depth_images = data['depth_images']
        labels = data['labels']

        depth_images = depth_images.to(device=device)
        labels = labels.to(device=device)
        
        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(depth_images)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if i % 5 == 4:    # print every 2000 mini-batches
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 2))
            running_loss = 0.0
                
    #Test loop================================================
    num_train_correct = 0
    num_train_samples = 0
    
    num_val_correct = 0
    num_val_samples = 0
    
    running_train_loss = 0
    running_val_loss = 0
    
    model.eval()  
    with torch.no_grad():
        #Train set eval==============
        for data in train_loader:
            depth_images = data['depth_images']
            labels = data['labels']

            depth_images = depth_images.to(device=device)
            labels = labels.to(device=device)
            
            scores = model(depth_images)
            _, predictions = scores.max(1)
            num_train_correct += (predictions == labels).sum()
            num_train_samples += predictions.size(0)
            
            running_train_loss += loss_fn(scores, labels)
        
        train_acc = float(num_train_correct)/float(num_train_samples)
        train_loss = running_train_loss/len(validation_loader)
        
        #Test set eval===============
        for data in validation_loader:
            depth_images = data['depth_images']
            labels = data['labels']

            depth_images = depth_images.to(device=device)
            labels = labels.to(device=device)
            
            scores = model(depth_images)
            _, predictions = scores.max(1)
            num_val_correct += (predictions == labels).sum()
            num_val_samples += predictions.size(0)
            
            running_val_loss += loss_fn(scores, labels)
        
        val_acc = float(num_val_correct)/float(num_val_samples)
        val_loss = running_val_loss/len(validation_loader)
        
        print(f'Got {num_val_correct} / {num_val_samples} with accuracy {val_acc*100:.2f}')
        
        wandb.log({
            "Train Loss":train_loss,
            "Validation Loss":val_loss,
            "Train Accuracy":train_acc,
            "Validation Accuracy":val_acc
            })
        

print('Finished Training')
run.finish()

[1,     5] loss: 3.661
[1,    10] loss: 3.485
[1,    15] loss: 3.387
Got 46 / 495 with accuracy 9.29
[2,     5] loss: 3.261
[2,    10] loss: 3.134
[2,    15] loss: 3.134
Got 130 / 495 with accuracy 26.26
[3,     5] loss: 3.030
[3,    10] loss: 3.075
[3,    15] loss: 3.006
Got 305 / 495 with accuracy 61.62
[4,     5] loss: 2.931
[4,    10] loss: 2.930
[4,    15] loss: 2.886
Got 339 / 495 with accuracy 68.48
[5,     5] loss: 2.792
[5,    10] loss: 2.780
[5,    15] loss: 2.769
Got 298 / 495 with accuracy 60.20
[6,     5] loss: 2.640
[6,    10] loss: 2.545
[6,    15] loss: 2.596
Got 204 / 495 with accuracy 41.21
[7,     5] loss: 2.433
[7,    10] loss: 2.417
[7,    15] loss: 2.396
Got 295 / 495 with accuracy 59.60
[8,     5] loss: 2.269
[8,    10] loss: 2.201
[8,    15] loss: 2.229
Got 286 / 495 with accuracy 57.78
[9,     5] loss: 2.121
[9,    10] loss: 2.123
[9,    15] loss: 2.110
Got 298 / 495 with accuracy 60.20
[10,     5] loss: 2.046
[10,    10] loss: 2.048
[10,    15] loss: 2.044
Got

KeyboardInterrupt: 

In [None]:
trees_data.__getitem__(1)["depth_images"].shape