In [1]:
from ai.TrashDetector import TrashDetector
from ai.DataLoader2 import DataLoader
from env import *

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

In [2]:
ETA = 1e-3    # Adam learning rate
EPOCHS = 30   # number of passes over the training set
SEED = 42     # seed for the numpy and torch global RNGs

# Seed the global RNGs before the data loaders and the detector are built so a
# Restart & Run All reproduces weight initialisation (and, if the loaders draw
# from numpy's global RNG, their noise). torch.manual_seed also seeds all CUDA
# devices.
# NOTE(review): if ai.DataLoader2 draws from Python's `random` module, seed it too.
torch.manual_seed(SEED)
np.random.seed(SEED)

In [3]:
# One loader per split; both share the category list and the same options.
# Unpacking consumes the generator in order, so the train loader is built first.
# NOTE(review): the validation loader also uses noise=True. If `noise` enables
# random augmentation, validation metrics (and best-checkpoint selection) become
# stochastic; consider noise=False for validation. Left unchanged here.
train_dloader, valid_dloader = (
    DataLoader(split_path, DETECTOR_CAT, noise=True)
    for split_path in (DETECTOR_TRAIN_DATA_PATH, DETECTOR_VALID_DATA_PATH)
)

Number of data: 288
Number of batch: 12
Number of data: 47
Number of batch: 1


In [4]:
# Instantiate the detector, then move it to the default CUDA device.
detector = TrashDetector()
detector = detector.cuda()

FeatureAE was loaded.


In [5]:
# Adam over every detector parameter, using the learning rate from the config cell.
optimizer = optim.Adam(params=detector.parameters(), lr=ETA)
# The training cell treats detector outputs as log-probabilities, so the
# matching loss is negative log-likelihood.
criterion = nn.NLLLoss()

In [6]:

# Training cell: each epoch runs one optimisation pass over train_dloader and
# one evaluation pass over valid_dloader, then checkpoints the detector
# whenever validation accuracy improves.

# Each sample's label is repeated this many times along axis 0 before being
# compared with the detector output.
# NOTE(review): presumably detector(x) returns LABEL_REPEAT rows of
# log-probabilities per input sample; confirm against TrashDetector.
LABEL_REPEAT = 8


def run_epoch(model, dloader, criterion, optimizer=None,
              label_repeat=LABEL_REPEAT, device="cuda"):
    """Run one pass over ``dloader`` and return ``(mean_loss, mean_acc)``.

    Both metrics are averaged over the batches actually yielded by
    ``dloader.next_batch()`` rather than divided by ``len(dloader)``, so the
    result does not depend on what ``DataLoader.__len__`` counts.

    Args:
        model: network whose output is fed to ``criterion`` as log-probabilities;
            its argmax along dim 1 is taken as the predicted class.
        dloader: object whose ``next_batch()`` yields ``(x_batch, y_batch)``.
        criterion: loss called as ``criterion(log_probs, targets)``.
        optimizer: if given, the pass trains ``model`` (train mode, gradients,
            parameter updates). If ``None``, the pass only evaluates it (eval
            mode, no gradients).
        label_repeat: number of times each label is repeated along axis 0.
        device: device the batches are moved to.

    Returns:
        ``(mean_loss, mean_acc)`` as plain Python floats.

    Raises:
        RuntimeError: if ``dloader`` yields no batches.
    """
    training = optimizer is not None
    model.train(training)  # train(False) is equivalent to eval()

    total_loss = 0.0
    total_acc = 0.0
    n_batches = 0
    # Build the autograd graph only when we are going to backpropagate.
    with torch.set_grad_enabled(training):
        for x_batch, y_batch in dloader.next_batch():
            x = torch.FloatTensor(x_batch).to(device)
            y = torch.LongTensor(np.repeat(y_batch, label_repeat, axis=0)).to(device)

            logps = model(x)
            loss = criterion(logps, y)

            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            total_loss += loss.item()
            # exp() is monotonic, so the argmax of the log-probabilities equals
            # the argmax of the probabilities; no need to exponentiate.
            preds = logps.detach().argmax(dim=1)
            total_acc += (preds == y).float().mean().item()
            n_batches += 1

    if n_batches == 0:
        raise RuntimeError("data loader yielded no batches")
    return total_loss / n_batches, total_acc / n_batches


top_valid_acc = 0.0
for e in range(EPOCHS):
    train_loss, train_acc = run_epoch(detector, train_dloader, criterion, optimizer)
    valid_loss, valid_acc = run_epoch(detector, valid_dloader, criterion)

    print(f"Epochs {e+1}/{EPOCHS}")
    print(f"Train loss: {train_loss:.8f}")
    print(f"Train acc: {train_acc:.4f}")
    print(f"Valid loss: {valid_loss:.8f}")
    print(f"Valid acc: {valid_acc:.4f}")

    # Keep only the checkpoint with the best validation accuracy seen so far.
    if valid_acc > top_valid_acc:
        top_valid_acc = valid_acc
        detector.save(DET_CKPT_PATH)
    

Epochs 1/30
Train loss: 0.50667866
Train acc: 0.7413
Valid loss: 0.57872659
Valid acc: 0.6667
Detector was saved.
Epochs 2/30
Train loss: 0.41269368
Train acc: 0.8299
Valid loss: 0.21542799
Valid acc: 0.9167
Detector was saved.
Epochs 3/30
Train loss: 0.31811236
Train acc: 0.8733
Valid loss: 0.21827489
Valid acc: 0.9375
Detector was saved.
Epochs 4/30
Train loss: 0.26474833
Train acc: 0.9097
Valid loss: 0.29734591
Valid acc: 0.8958
Epochs 5/30
Train loss: 0.25571543
Train acc: 0.9115
Valid loss: 0.18532193
Valid acc: 0.9167
Epochs 6/30
Train loss: 0.27284602
Train acc: 0.8958
Valid loss: 0.19505970
Valid acc: 0.9375
Epochs 7/30
Train loss: 0.26174472
Train acc: 0.9080
Valid loss: 0.44475320
Valid acc: 0.8333
Epochs 8/30
Train loss: 0.18920416
Train acc: 0.9253
Valid loss: 0.25634494
Valid acc: 0.8958
Epochs 9/30
Train loss: 0.17023488
Train acc: 0.9392
Valid loss: 0.28903630
Valid acc: 0.8958
Epochs 10/30
Train loss: 0.18798455
Train acc: 0.9358
Valid loss: 0.13472039
Valid acc: 0.9375