In [1]:
import os
import sys
import copy
pdir = os.path.dirname(os.getcwd())
sys.path.append(pdir)

import torch
import numpy as np

import utils
import wandb

# Paths handed to utils.train: `model_dir` is passed as its model_dir argument
# (presumably where trained models/checkpoints are written).
# NOTE: despite its name, `data_dir` is the path of a single serialized .pt
# file, not a directory; the name is kept because later cells pass it as
# `utils.train(data_dir=...)` and `torch.load(data_dir)`.
model_dir, data_dir = 'models', 'tree_points.pt'

In [2]:
# Load the serialized tree point-cloud dataset. In this notebook the object is
# only used to record `type(trees_data)` in interface_to_train; utils.train is
# given `data_dir` and presumably loads the data itself.
# map_location="cpu" keeps this load device-independent: it still works on a
# CPU-only host even if tensors in the file were saved from a GPU.
# NOTE(review): torch.load unpickles arbitrary Python objects -- only load
# trusted files here.
trees_data = torch.load(data_dir, map_location="cpu")

In [3]:
# W&B sweep definition: Bayesian search maximizing "Best_acc" (the metric that
# appears in the run summaries logged by training).
#
# Every parameter names its distribution explicitly and continuous bounds are
# written as floats. W&B infers `int_uniform` when both bounds are integers,
# which is why the previous `max_scale: {min: 1, max: 2}` only ever sampled
# the values 1 or 2.
#
# NOTE: this dict only takes effect when a NEW sweep is created from it; an
# existing sweep id keeps the configuration it was registered with.
sweep_config = {
    "name": "laser-trees-bayes",
    "method": "bayes",
    "metric": {
        "name": "Best_acc",
        "goal": "maximize",
    },
    "parameters": {
        "loss_fn": {
            "values": ["smooth-loss", "cross-entropy"],
        },
        # Initial learning rate for the step LR scheduler.
        "lr_init": {"distribution": "uniform", "min": 0.0001, "max": 0.001},
        # Scheduler step_size -- intentionally an integer.
        "lr_step": {"distribution": "int_uniform", "min": 10, "max": 100},
        # Scheduler multiplicative decay factor (gamma).
        "lr_gamma": {"distribution": "uniform", "min": 0.1, "max": 0.9},
        # Upper bound of the random rotation augmentation (presumably radians,
        # given the 2*pi limit).
        "max_rotation": {"distribution": "uniform", "min": 0.0, "max": 2 * np.pi},
        "max_translation": {"distribution": "uniform", "min": 0.0, "max": 2.0},
        # log_uniform samples exp(U(min, max)), i.e. roughly 9.1e-4 .. 0.135.
        "jitter_std": {"distribution": "log_uniform", "min": -7.0, "max": -2.0},
        # Scaling range: min_scale in [0.3, 1], max_scale in [1, 2], so the
        # sampled lower bound can never exceed the upper bound.
        "min_scale": {"distribution": "uniform", "min": 0.3, "max": 1.0},
        "max_scale": {"distribution": "uniform", "min": 1.0, "max": 2.0},
    },
}

In [4]:
def interface_to_train():
    """Run a single W&B sweep trial.

    Intended as the ``function`` passed to ``wandb.agent``: it takes no
    arguments, opens a run, reads the hyperparameters sampled for that run
    from the run's config, merges them with the fixed settings below into the
    ``params`` dict, and calls ``utils.train`` with ``init_wandb=False`` (the
    run opened here is the one in use).
    """
    with wandb.init() as run:
        cfg = run.config  # same object as the global wandb.config

        params = {
            # --- data handling / splits ---
            "dataset_type": type(trees_data),
            "batch_size": 128,
            "validation_split": .15,
            "test_split": .15,
            "shuffle_dataset": True,
            "random_seed": 0,

            # --- optimisation ---
            # [init, step_size, gamma] for the step LR scheduler
            "learning_rate": [cfg["lr_init"], cfg["lr_step"], cfg["lr_gamma"]],
            "momentum": 0.9,  # only used by the sgd optimizer
            "epochs": 150,
            "loss_fn": cfg["loss_fn"],
            "optimizer": "adam",
            "voting": "None",
            "train_sampler": "balanced",

            # --- model ---
            "model": "SimpleView",

            # --- multi-view rendering ---
            "image_dim": 256,
            "camera_fov_deg": 90,
            "f": 1,
            "camera_dist": 1.4,
            "depth_averaging": "min",
            "soft_min_k": 50,
            "num_views": 6,

            # --- augmentation (upper bounds come from the sweep) ---
            "transforms": ['rotation', 'translation', 'scaling'],
            "min_rotation": 0,
            "max_rotation": cfg["max_rotation"],
            "min_translation": 0,
            "max_translation": cfg["max_translation"],
            "jitter_std": cfg["jitter_std"],
            "min_scale": cfg["min_scale"],
            "max_scale": cfg["max_scale"],

            # --- labels / data version ---
            "species": ["QUEFAG", "PINNIG", "QUEILE", "PINSYL", "PINPIN"],
            "data_resolution": "2.5cm",
        }

        utils.train(data_dir=data_dir, model_dir=model_dir, params=params, init_wandb=False)

In [5]:
# Launch a W&B agent that runs `count` trials of interface_to_train.
# By default this resumes the already-registered sweep below; set `sweep_id`
# to None to register a new sweep from `sweep_config` instead (replaces the
# previously commented-out wandb.sweep call).
project_name = 'laser-trees-bayes'
sweep_id = '3gczm5qj'  # set to None to create a fresh sweep from sweep_config
if sweep_id is None:
    sweep_id = wandb.sweep(sweep_config, project=project_name)

count = 3  # number of trials this agent runs before exiting
wandb.agent(sweep_id, project=project_name, function=interface_to_train, count=count)

[34m[1mwandb[0m: Agent Starting Run: 1tgkv5p9 with config:
[34m[1mwandb[0m: 	jitter_std: 0.0011448061215974888
[34m[1mwandb[0m: 	loss_fn: cross-entropy
[34m[1mwandb[0m: 	lr_gamma: 0.6542670409600482
[34m[1mwandb[0m: 	lr_init: 0.000251283891952897
[34m[1mwandb[0m: 	lr_step: 98
[34m[1mwandb[0m: 	max_rotation: 1.4077954480041475
[34m[1mwandb[0m: 	max_scale: 1
[34m[1mwandb[0m: 	max_translation: 0
[34m[1mwandb[0m: 	min_scale: 0.5651380075506061
[34m[1mwandb[0m: Currently logged in as: [33mmja2106[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.12.2 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade




QUEFAG     1116
PINNIG      581
QUEILE      364
PINSYL      277
PINPIN      140
NA            2
JUNIPE        2
QUERCUS       2
DEAD          1
Name: sp, dtype: int64
Species:  ['DEAD', 'JUNIPE', 'NA', 'PINNIG', 'PINPIN', 'PINSYL', 'QUEFAG', 'QUEILE', 'QUERCUS']
Labels:  tensor([8, 3, 6,  ..., 7, 3, 6])
Total count:  2485
Removing: QUERCUS
Removing: JUNIPE
Removing: NA
Removing: DEAD
Train Dataset:
QUEFAG    1116
PINNIG     581
QUEILE     364
PINSYL     277
PINPIN     140
Name: sp, dtype: int64
Species:  ['PINNIG', 'PINPIN', 'PINSYL', 'QUEFAG', 'QUEILE']
Labels:  tensor([0, 3, 3,  ..., 4, 0, 3])
Total count:  2478

Validation Dataset (should match):
QUEFAG    1116
PINNIG     581
QUEILE     364
PINSYL     277
PINPIN     140
Name: sp, dtype: int64
Species:  ['PINNIG', 'PINPIN', 'PINSYL', 'QUEFAG', 'QUEILE']
Labels:  tensor([0, 3, 3,  ..., 4, 0, 3])
Total count:  2478

Shuffling dataset...
Using balanced sampling...
Using cross-entropy loss...
Optimizing with AdaM...
Using step LR schedul

  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


[1,     5] loss: 3.832
[1,    10] loss: 3.609
[1,    15] loss: 3.544
[1,    20] loss: 3.564
OVERALL: Got 18 / 372 with accuracy 4.84
PINNIG: Got 0/93 with accuracy 0.00
PINPIN: Got 18/18 with accuracy 100.00
PINSYL: Got 0/40 with accuracy 0.00
QUEFAG: Got 0/177 with accuracy 0.00
QUEILE: Got 0/44 with accuracy 0.00
[2,     5] loss: 3.469
[2,    10] loss: 3.461
[2,    15] loss: 3.435
[2,    20] loss: 3.401
OVERALL: Got 41 / 372 with accuracy 11.02
PINNIG: Got 1/93 with accuracy 1.08
PINPIN: Got 0/18 with accuracy 0.00
PINSYL: Got 40/40 with accuracy 100.00
QUEFAG: Got 0/177 with accuracy 0.00
QUEILE: Got 0/44 with accuracy 0.00
[3,     5] loss: 3.338
[3,    10] loss: 3.207
[3,    15] loss: 3.242
[3,    20] loss: 3.189
OVERALL: Got 91 / 372 with accuracy 24.46
PINNIG: Got 29/93 with accuracy 31.18
PINPIN: Got 5/18 with accuracy 27.78
PINSYL: Got 19/40 with accuracy 47.50
QUEFAG: Got 1/177 with accuracy 0.56
QUEILE: Got 37/44 with accuracy 84.09
[4,     5] loss: 3.169
[4,    10] loss: 3.0

VBox(children=(Label(value=' 0.24MB of 0.24MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
PINNIG Accuracy,0.7957
PINPIN Accuracy,0.55556
PINSYL Accuracy,0.225
QUEFAG Accuracy,0.72316
QUEILE Accuracy,0.79545
Best_acc,0.76344
Best_min_acc,0.525
Train Loss,1.28396
Validation Loss,1.23675
Train Accuracy,0.92575


0,1
PINNIG Accuracy,▁▁▅▄▄▄▄▄▅▇▃▅▄▆▆▄▅▆▆▄█▂▃▇█▅▅▇▅▆▃▄▄▆▆▅▄▃▆▇
PINPIN Accuracy,▁▂▁▁▂█▄▄▇▅▇█▆▃▆▅▄▅▄▅▅▄▆▃▃▆▅▃▅▄▅▅▄▄▃▄▄▅▃▅
PINSYL Accuracy,█▁▃▅▂▄▇▇▂▃▅▂▅▅▂▇▃▅▅▄▂▂▅▂▃▂▄▅▆▅▅▅▇▄▂▆▆▄▃▂
QUEFAG Accuracy,▁▇▃▁▁▄▄▃▃▃▅▆▆▇▂▄▆▄▅▇▆█▇▇▃▇█▇▇▅▇▇▇▅▇▇▇▇▆▆
QUEILE Accuracy,▁▅█▇█▅▅▇▇█▇▅▆▆█▅▇▇▆▅▄▅▅▅▅▅▄▅▆▇▆▇▅▇▆▆▅▇▇▇
Best_acc,▁▅▅▇▇▇▇▇▇▇▇▇▇███████████████████████████
Best_min_acc,▁▁▂▄▅▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇████████████
Train Loss,█▆▅▇▆▄▄▄▃▃▃▄▃▂▄▃▃▂▂▂▂▇▂▃▄▂▂▁▁▁▁▁▂▁▁▁▁▂▁▁
Validation Loss,▆▃▃██▄▃▄▄▄▃▃▂▁▅▄▃▃▃▂▃▆▃▃▇▂▃▂▂▃▃▂▄▄▂▂▃▄▄▃
Train Accuracy,▁▃▄▃▄▅▄▅▅▆▆▅▆▇▆▆▆▇▇▇▇▄▇▆▅▇▇█████▇████▇▇█


[34m[1mwandb[0m: Agent Starting Run: koc0ak3l with config:
[34m[1mwandb[0m: 	jitter_std: 0.005986998259688522
[34m[1mwandb[0m: 	loss_fn: cross-entropy
[34m[1mwandb[0m: 	lr_gamma: 0.5278767400158753
[34m[1mwandb[0m: 	lr_init: 0.00019876135101088387
[34m[1mwandb[0m: 	lr_step: 83
[34m[1mwandb[0m: 	max_rotation: 0.3059247160550057
[34m[1mwandb[0m: 	max_scale: 2
[34m[1mwandb[0m: 	max_translation: 0
[34m[1mwandb[0m: 	min_scale: 0.5301797735621492
[34m[1mwandb[0m: wandb version 0.12.2 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade




QUEFAG     1116
PINNIG      581
QUEILE      364
PINSYL      277
PINPIN      140
NA            2
JUNIPE        2
QUERCUS       2
DEAD          1
Name: sp, dtype: int64
Species:  ['DEAD', 'JUNIPE', 'NA', 'PINNIG', 'PINPIN', 'PINSYL', 'QUEFAG', 'QUEILE', 'QUERCUS']
Labels:  tensor([8, 3, 6,  ..., 7, 3, 6])
Total count:  2485
Removing: QUERCUS
Removing: JUNIPE
Removing: NA
Removing: DEAD
Train Dataset:
QUEFAG    1116
PINNIG     581
QUEILE     364
PINSYL     277
PINPIN     140
Name: sp, dtype: int64
Species:  ['PINNIG', 'PINPIN', 'PINSYL', 'QUEFAG', 'QUEILE']
Labels:  tensor([0, 3, 3,  ..., 4, 0, 3])
Total count:  2478

Validation Dataset (should match):
QUEFAG    1116
PINNIG     581
QUEILE     364
PINSYL     277
PINPIN     140
Name: sp, dtype: int64
Species:  ['PINNIG', 'PINPIN', 'PINSYL', 'QUEFAG', 'QUEILE']
Labels:  tensor([0, 3, 3,  ..., 4, 0, 3])
Total count:  2478

Shuffling dataset...
Using balanced sampling...
Using cross-entropy loss...
Optimizing with AdaM...
Using step LR schedul

VBox(children=(Label(value=' 0.37MB of 0.37MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
PINNIG Accuracy,0.86022
PINPIN Accuracy,0.22222
PINSYL Accuracy,0.275
QUEFAG Accuracy,0.81921
QUEILE Accuracy,0.65909
Best_acc,0.77151
Best_min_acc,0.58065
Train Loss,0.9241
Validation Loss,0.99172
Train Accuracy,0.94956


0,1
PINNIG Accuracy,▁▁▃▆▃▅▄▆▅▅▃▆▆▄▇█▃▆▆▇█▆▇▇▇▅▇▇▅▇▅▅▅▇▆▇▇▆▇█
PINPIN Accuracy,▃▄▇▁▄▇█▂▅▅█▅▇▁▇▄▅▇▇▄▄▅▄▇▆▄▃▄▅▂▄▃▆▂▇▅▄▄▅▃
PINSYL Accuracy,█▂▁▆▄▂▁▆▂█▃▁▂▅▂▃▅▁▁▁▁▃▄▁▂▅▄▅▇▃▅▇▅▃▄▃▅▃▁▂
QUEFAG Accuracy,▁█▅▁▄▄▆▅▆▄▇▆▆▇▂▅█▅▇▆▇▆▆▇▇▇█▆▇▆▇▇▇▆▇█▆▆▇▇
QUEILE Accuracy,▁▂▇▄█▅▄▆▇▆▅▇▅▅▆▅▄▇▆▇▅▆▅▅▅▆▃▄▃▆▅▅▅▆▅▄▅▆▅▅
Best_acc,▁▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇███████████████████████
Best_min_acc,▁▂▅▅▅▅▆▇▇▇▇▇▇▇▇▇████████████████████████
Train Loss,██▆█▆▅▅▅▅▄▅▃▄▅▄▃▄▄▂▃▂▂▂▁▁▁▂▂▂▁▁▂▁▁▁▁▁▁▁▁
Validation Loss,▇▃▃█▄▃▂▂▂▃▃▁▂▂▇▂▂▄▁▂▁▂▁▁▁▁▁▂▂▂▂▂▁▂▁▁▂▃▂▂
Train Accuracy,▁▂▄▂▄▄▅▅▅▅▅▆▆▅▆▆▆▆▇▆▇▇████▇▇▇██▇████████


[34m[1mwandb[0m: Agent Starting Run: gy1go4iu with config:
[34m[1mwandb[0m: 	jitter_std: 0.00183876272732213
[34m[1mwandb[0m: 	loss_fn: cross-entropy
[34m[1mwandb[0m: 	lr_gamma: 0.560165284763611
[34m[1mwandb[0m: 	lr_init: 0.00016579628647283914
[34m[1mwandb[0m: 	lr_step: 97
[34m[1mwandb[0m: 	max_rotation: 1.4767700393425809
[34m[1mwandb[0m: 	max_scale: 2
[34m[1mwandb[0m: 	max_translation: 0
[34m[1mwandb[0m: 	min_scale: 0.8998507321290261
[34m[1mwandb[0m: wandb version 0.12.2 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade




QUEFAG     1116
PINNIG      581
QUEILE      364
PINSYL      277
PINPIN      140
NA            2
JUNIPE        2
QUERCUS       2
DEAD          1
Name: sp, dtype: int64
Species:  ['DEAD', 'JUNIPE', 'NA', 'PINNIG', 'PINPIN', 'PINSYL', 'QUEFAG', 'QUEILE', 'QUERCUS']
Labels:  tensor([8, 3, 6,  ..., 7, 3, 6])
Total count:  2485
Removing: QUERCUS
Removing: JUNIPE
Removing: NA
Removing: DEAD
Train Dataset:
QUEFAG    1116
PINNIG     581
QUEILE     364
PINSYL     277
PINPIN     140
Name: sp, dtype: int64
Species:  ['PINNIG', 'PINPIN', 'PINSYL', 'QUEFAG', 'QUEILE']
Labels:  tensor([0, 3, 3,  ..., 4, 0, 3])
Total count:  2478

Validation Dataset (should match):
QUEFAG    1116
PINNIG     581
QUEILE     364
PINSYL     277
PINPIN     140
Name: sp, dtype: int64
Species:  ['PINNIG', 'PINPIN', 'PINSYL', 'QUEFAG', 'QUEILE']
Labels:  tensor([0, 3, 3,  ..., 4, 0, 3])
Total count:  2478

Shuffling dataset...
Using balanced sampling...
Using cross-entropy loss...
Optimizing with AdaM...
Using step LR schedul

VBox(children=(Label(value=' 0.51MB of 0.51MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
PINNIG Accuracy,0.78495
PINPIN Accuracy,0.5
PINSYL Accuracy,0.425
QUEFAG Accuracy,0.59322
QUEILE Accuracy,0.75
Best_acc,0.75269
Best_min_acc,0.58065
Train Loss,1.60567
Validation Loss,0.92473
Train Accuracy,0.90274


0,1
PINNIG Accuracy,▁▂▁▆▄▄▄▅▂▅▃▆▆▆▅▆▅▆▇▇▅▂▆▆▄▆█▇▄▇▆▆▅▅▆▆▆▅▅▇
PINPIN Accuracy,▂▁▆▂▃▇▂▆▇▄▄▇█▆█▄█▃▆▄▇▆▇█▃▅▆▄▄▃▃▂▄▄▆▅▅▄▄▄
PINSYL Accuracy,█▄▂▄▃▂▇▃▄▆▅▃▂▃▂▅▂▄▁▂▂▇▃▁█▅▁▄█▃▅▆▇▆▂▃▆▅▂▃
QUEFAG Accuracy,▁██▄▅▇▄▄█▄█▆▆▅▅▇▆▅▇▆▆▅▆██▇▇▇▇▇▇██▅▆▇▇▇▇▆
QUEILE Accuracy,▁▃▅▆█▅▃▇▆▇▆▆▆█▇▆▆▇▇▇█▆▇▆▆▇▆▆▆▆▆▆▆██▆▆▆▇▇
Best_acc,▁▆▆▇▇▇▇▇▇▇▇▇▇▇▇█████████████████████████
Best_min_acc,▁▂▂▄▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▇▇███████████████
Train Loss,█▆▇▅▅▅▆▅▄▄▄▃▃▃▃▃▄▃▂▂▃▄▂▃▂▂▃▁▃▂▁▁▂▂▂▁▁▁▂▁
Validation Loss,█▃▃▄▃▃▄▄▃▄▂▂▃▃▄▁▃▂▁▂▃▅▂▂▂▁▂▁▄▂▁▁▂▃▃▁▂▂▂▂
Train Accuracy,▁▃▃▄▄▅▄▅▅▆▅▆▆▆▆▇▆▇▇▇▆▆▇▇▇▇▇█▆▇██▇▇▇███▇█
