In [1]:
import os
import sys
pdir = os.path.dirname(os.getcwd())
sys.path.append(pdir)

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd

from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import SubsetRandomSampler

import utils
from simpleview_pytorch import SimpleView

from torch.utils.data.dataset import Dataset

In [2]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

### Load the data:

In [3]:
trees_data = torch.load('trees_new.pt')
print(trees_data.counts)
print('Species: ', trees_data.species)
print('Labels: ', trees_data.labels)
print('Total count: ', len(trees_data))

QUEFAG     1116
PINNIG      581
QUEILE      364
PINSYL      277
PINPIN      140
QUERCUS       2
JUNIPE        2
NA            2
DEAD          1
Name: sp, dtype: int64
Species:  ['DEAD', 'JUNIPE', 'NA', 'PINNIG', 'PINPIN', 'PINSYL', 'QUEFAG', 'QUEILE', 'QUERCUS']
Labels:  tensor([8, 3, 6,  ..., 7, 3, 6])
Total count:  2485


#### Remove low-count species:

In [4]:
for specie in ['NA', 'QUERCUS', 'JUNIPE', 'DEAD']:
    trees_data.remove_species(specie)
    
print(trees_data.counts)
print('Species: ', trees_data.species)
print('Labels: ', trees_data.labels)
print('Total count: ', len(trees_data))

QUEFAG    1116
PINNIG     581
QUEILE     364
PINSYL     277
PINPIN     140
Name: sp, dtype: int64
Species:  ['PINNIG', 'PINPIN', 'PINSYL', 'QUEFAG', 'QUEILE']
Labels:  tensor([0, 3, 3,  ..., 4, 0, 3])
Total count:  2478


#### Train-validation split:

In [18]:
batch_size = 128
validation_split = .2
shuffle_dataset = True
random_seed = 0

In [19]:
dataset_size = len(trees_data)
indices = list(range(dataset_size))
split = int(np.floor(validation_split * dataset_size))

if shuffle_dataset :
    np.random.seed(random_seed)
    np.random.shuffle(indices)
train_indices, val_indices = indices[split:], indices[:split]

In [20]:
train_sampler = SubsetRandomSampler(train_indices)
valid_sampler = SubsetRandomSampler(val_indices)

train_loader = torch.utils.data.DataLoader(trees_data, batch_size=batch_size, 
                                           sampler=train_sampler)
validation_loader = torch.utils.data.DataLoader(trees_data, batch_size=batch_size,
                                                sampler=valid_sampler)

In [21]:
for x in train_loader:
    print(torch.unique(x['labels'], return_counts = True))
print()

for x in validation_loader:
    print(torch.unique(x['labels'], return_counts = True))
print()

(tensor([0, 1, 2, 3, 4]), tensor([28,  9, 21, 49, 21]))
(tensor([0, 1, 2, 3, 4]), tensor([27,  5, 18, 59, 19]))
(tensor([0, 1, 2, 3, 4]), tensor([38,  6, 12, 58, 14]))
(tensor([0, 1, 2, 3, 4]), tensor([20, 16,  9, 62, 21]))
(tensor([0, 1, 2, 3, 4]), tensor([28,  5, 16, 53, 26]))
(tensor([0, 1, 2, 3, 4]), tensor([30,  9, 13, 59, 17]))
(tensor([0, 1, 2, 3, 4]), tensor([27,  8, 15, 58, 20]))
(tensor([0, 1, 2, 3, 4]), tensor([39,  6, 13, 47, 23]))
(tensor([0, 1, 2, 3, 4]), tensor([24,  8, 15, 63, 18]))
(tensor([0, 1, 2, 3, 4]), tensor([23,  2, 14, 71, 18]))
(tensor([0, 1, 2, 3, 4]), tensor([34,  3, 20, 49, 22]))
(tensor([0, 1, 2, 3, 4]), tensor([34,  8, 11, 52, 23]))
(tensor([0, 1, 2, 3, 4]), tensor([34,  4, 20, 47, 23]))
(tensor([0, 1, 2, 3, 4]), tensor([23,  4, 14, 76, 11]))
(tensor([0, 1, 2, 3, 4]), tensor([27,  9, 17, 54, 21]))
(tensor([0, 1, 2, 3, 4]), tensor([16,  5,  7, 28,  7]))

(tensor([0, 1, 2, 3, 4]), tensor([32,  9, 12, 61, 14]))
(tensor([0, 1, 2, 3, 4]), tensor([29,  9, 12, 6

### Define model, loss fn, optimiser:

In [27]:
model = SimpleView(
    num_views=trees_data.depth_images.shape[1],
    num_classes=len(trees_data.species)
)

model = model.to(device=device)

loss_fn = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.0005, momentum=0.9)

### Train & Test Loops:

In [28]:
for epoch in range(100):  # loop over the dataset multiple times

    model.train()
    
    running_loss = 0.0
    #Training loop
    for i, data in enumerate(train_loader, 0):
        depth_images = data['depth_images']
        labels = data['labels']

        depth_images = depth_images.to(device=device)
        labels = labels.to(device=device)
        
        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(depth_images)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if i % 5 == 4:    # print every 2000 mini-batches
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 2))
            running_loss = 0.0
            
    #Test loop
    num_correct = 0
    num_samples = 0
    model.eval()  
    with torch.no_grad():
  
        
        for data in validation_loader:
            depth_images = data['depth_images']
            labels = data['labels']

            depth_images = depth_images.to(device=device)
            labels = labels.to(device=device)
            
            scores = model(depth_images)
            _, predictions = scores.max(1)
            num_correct += (predictions == labels).sum()
            num_samples += predictions.size(0)
            
        print(f'Got {num_correct} / {num_samples} with accuracy {float(num_correct)/float(num_samples)*100:.2f}')

print('Finished Training')

[1,     5] loss: 4.104
[1,    10] loss: 3.582
[1,    15] loss: 3.630
Got 231 / 495 with accuracy 46.67
[2,     5] loss: 3.320
[2,    10] loss: 3.323
[2,    15] loss: 3.313
Got 141 / 495 with accuracy 28.48
[3,     5] loss: 3.196
[3,    10] loss: 3.149
[3,    15] loss: 3.089
Got 245 / 495 with accuracy 49.49
[4,     5] loss: 3.072
[4,    10] loss: 3.088
[4,    15] loss: 3.024
Got 280 / 495 with accuracy 56.57
[5,     5] loss: 2.981
[5,    10] loss: 2.954
[5,    15] loss: 2.959
Got 285 / 495 with accuracy 57.58
[6,     5] loss: 2.891
[6,    10] loss: 2.904
[6,    15] loss: 2.910
Got 278 / 495 with accuracy 56.16
[7,     5] loss: 2.857
[7,    10] loss: 2.806
[7,    15] loss: 2.891
Got 290 / 495 with accuracy 58.59
[8,     5] loss: 2.743
[8,    10] loss: 2.825
[8,    15] loss: 2.856
Got 290 / 495 with accuracy 58.59
[9,     5] loss: 2.801
[9,    10] loss: 2.810
[9,    15] loss: 2.691
Got 287 / 495 with accuracy 57.98
[10,     5] loss: 2.837
[10,    10] loss: 2.707
[10,    15] loss: 2.650
G

KeyboardInterrupt: 

### Update old object to new class (Don't run every time):

In [2]:
metadata_file = "../data/treesXYZ/meta/META.csv"
data_dir = "../data/treesXYZ/"
trees_new = utils.TreeSpeciesDataset(data_dir, metadata_file)

trees_old=torch.load('trees_old.pt')

trees_new.depth_images = trees_old.depth_images
trees_new.labels = trees_old.labels.long()

torch.save(trees_new, 'trees_new.pt')