In [1]:
import os
import sys
import copy
# Make the parent of the current working directory importable so that the
# project-local `utils` module (imported further down in this cell) resolves.
# The membership guard keeps this cell idempotent: re-running it no longer
# appends a duplicate entry to sys.path each time.
pdir = os.path.dirname(os.getcwd())
if pdir not in sys.path:
    sys.path.append(pdir)

import torch
import numpy as np

import utils

# Paths used by the training call below:
#   model_dir -> directory handed to utils.train as its model output location
#   data_dir  -> despite the name, a single serialized file read with torch.load
model_dir, data_dir = 'models', 'tree_points.pt'

In [2]:
# Deserialize the dataset object from disk. torch.load uses pickle under the
# hood, so only point it at files from a trusted source.
trees_data = torch.load(f=data_dir)

In [3]:
# Training configurations: each dict in this list is passed unchanged to
# utils.train in the next cell (one training run per entry).
params_list = [
    {
        # --- data loading / optimisation ---
        "dataset_type": type(trees_data),
        "batch_size": 128,
        "validation_split": .2,
        "shuffle_dataset": True,
        "random_seed": 0,
        "learning_rate": [0.0005, 100, 0.5],  # [initial lr, scheduler step_size, scheduler gamma]
        "momentum": 0.9,  # only used when the optimizer is SGD
        "epochs": 300,
        "loss_fn": "cross-entropy",
        "optimizer": "adam",
        "voting": "None",
        "train_sampler": "balanced",

        # --- model architecture ---
        "model": "SimpleView",

        # --- camera / image rendering settings ---
        "image_dim": 256,
        "camera_fov_deg": 90,
        "f": 1,
        "camera_dist": 1.4,
        "depth_averaging": "min",
        "soft_min_k": 50,
        "num_views": 6,

        # --- input transforms and their ranges ---
        "transforms": ['rotation', 'translation'],
        "min_rotation": 0,
        "max_rotation": 2 * np.pi,
        "min_translation": 0,
        "max_translation": 0.5,
        "jitter_std": 0,

        # --- classes kept; other species in the data are dropped
        #     (see the "Removing: ..." lines in the training output) ---
        "species": ["QUEFAG", "PINNIG", "QUEILE", "PINSYL", "PINPIN"],
        "data_resolution": "2.5cm",
    },
]

In [4]:
# Run one training job per configuration defined above, reading the dataset
# from data_dir and handing model_dir to utils.train as its output location.
for params in params_list:
    utils.train(
        data_dir=data_dir,
        model_dir=model_dir,
        params=params,
    )

[34m[1mwandb[0m: Currently logged in as: [33mmja2106[0m (use `wandb login --relogin` to force relogin)


QUEFAG     1116
PINNIG      581
QUEILE      364
PINSYL      277
PINPIN      140
NA            2
JUNIPE        2
QUERCUS       2
DEAD          1
Name: sp, dtype: int64
Species:  ['DEAD', 'JUNIPE', 'NA', 'PINNIG', 'PINPIN', 'PINSYL', 'QUEFAG', 'QUEILE', 'QUERCUS']
Labels:  tensor([8, 3, 6,  ..., 7, 3, 6])
Total count:  2485


[34m[1mwandb[0m: wandb version 0.10.31 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


Removing: JUNIPE
Removing: DEAD
Removing: QUERCUS
Removing: NA
Train Dataset:
QUEFAG    1116
PINNIG     581
QUEILE     364
PINSYL     277
PINPIN     140
Name: sp, dtype: int64
Species:  ['PINNIG', 'PINPIN', 'PINSYL', 'QUEFAG', 'QUEILE']
Labels:  tensor([0, 3, 3,  ..., 4, 0, 3])
Total count:  2478

Validation Dataset (should match):
QUEFAG    1116
PINNIG     581
QUEILE     364
PINSYL     277
PINPIN     140
Name: sp, dtype: int64
Species:  ['PINNIG', 'PINPIN', 'PINSYL', 'QUEFAG', 'QUEILE']
Labels:  tensor([0, 3, 3,  ..., 4, 0, 3])
Total count:  2478

Shuffling dataset...
Using balanced sampling...
Using cross-entropy loss...
Optimizing with AdaM...
Using step LR scheduler...
[1,     5] loss: 3.727
[1,    10] loss: 3.541
[1,    15] loss: 3.557
[1,    20] loss: 3.603
OVERALL: Got 42 / 495 with accuracy 8.48
PINNIG: Got 0/129 with accuracy 0.00
PINPIN: Got 0/33 with accuracy 0.00
PINSYL: Got 42/42 with accuracy 100.00
QUEFAG: Got 0/231 with accuracy 0.00
QUEILE: Got 0/60 with accuracy 0.00


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
PINNIG Accuracy,0.81395
PINPIN Accuracy,0.60606
PINSYL Accuracy,0.59524
QUEFAG Accuracy,0.8355
QUEILE Accuracy,0.81667
Train Loss,0.55482
Validation Loss,0.93376
Train Accuracy,0.95642
Validation Accuracy,0.79192
Learning Rate,0.00013


0,1
PINNIG Accuracy,▁▅▂██▆▃▇█▆▇▅█▇▆▆▆█▇▆▇▇▇▆▇█▇▇▇▆▇██▇▆▇▇█▇▇
PINPIN Accuracy,▂▁▅▁▅▆▅▂▃▆▂█▃▆▃▅▅▅▇▅▅▅▄▅▅▅▅▅▅▅▆▅▆▆▅▆▅▆▆▆
PINSYL Accuracy,▁▂▂▂▂▃▅▅▄▄▆▆▅▂▆▆▃▄▆▇▇▆▆█▆▆▃▇▆▇▆▅▆▇█▆▇▆▆▆
QUEFAG Accuracy,█▄█▃▅▆▇▁▄▂▄▅▄▇▇▆▇▇▇▇▆▇▅▃▇▇▇▇▇▇▇▆▇▇▇▇▇▇▇▇
QUEILE Accuracy,▁█▄█▆█▆███▆▆▆█▅█▇▇█▇█▇▇▆▇▇▆▇▇▆▇█▇▇▇█▇▆▆▇
Train Loss,█▄▄▃▃▂▃▃▂▂▃▂▂▂▂▁▂▁▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Validation Loss,█▃▃▃▂▁▂▅▂▄▄▃▃▁▂▁▁▁▁▂▂▁▂▅▁▁▁▂▁▂▁▂▁▁▂▂▁▂▂▂
Train Accuracy,▁▃▃▄▅▆▅▄▆▆▄▆▆▇▆▇▆▇▇▇▇▇▇▆██▇█████████████
Validation Accuracy,▂▃▄▃▅▆▅▁▅▂▃▄▄▇▆▆▇██▇▇█▆▃▇████▇█▇███▇████
Learning Rate,██████████████▃▃▃▃▃▃▃▃▃▃▃▃▃▁▁▁▁▁▁▁▁▁▁▁▁▁
