In [1]:
from ai.TrashDetector import TrashDetector
from ai.DataLoader2 import DataLoader
from env import *

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

In [2]:
# Optimiser learning rate and number of passes over the training set.
ETA = 1e-3
EPOCHS = 300

# Seed the RNGs this notebook imports so Restart & Run All is repeatable:
# torch (weight initialisation; torch.manual_seed also seeds CUDA devices) and
# NumPy's global generator.
# NOTE(review): if ai.DataLoader2 draws from Python's `random` module, that is
# not covered here; cuDNN kernels may also remain non-deterministic.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

In [3]:
# One loader per split (training first, then validation), both built from the
# paths and category list that come in via `from env import *`.
# NOTE(review): validation is also built with noise=True; if `noise` means
# random augmentation, validation scores (and the best-checkpoint choice made
# in the training loop) become stochastic. Confirm whether it should be False.
# NOTE(review): the loader reports 1 batch for 183 validation samples versus
# 33 batches for 3727 training samples; if incomplete batches are dropped,
# only part of the validation set is ever scored. Verify in ai/DataLoader2.
train_dloader, valid_dloader = (
    DataLoader(split_path, DETECTOR_CAT, noise=True)
    for split_path in (DETECTOR_TRAIN_DATA_PATH, DETECTOR_VALID_DATA_PATH)
)

Number of data: 3727
Number of batch: 33
Number of data: 183
Number of batch: 1


In [4]:
# Place the model on the GPU when CUDA is available and fall back to the CPU
# otherwise, instead of failing outright on a machine without a GPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
detector = TrashDetector().to(DEVICE)

In [5]:
# Adam updates every parameter of the detector with learning rate ETA.
optimizer = optim.Adam(params=detector.parameters(), lr=ETA)
# NLLLoss consumes log-probabilities, so TrashDetector is assumed to end in a
# log_softmax. NOTE(review): confirm in ai/TrashDetector.
criterion = nn.NLLLoss(reduction="mean")

In [6]:

# Each loader sample is flattened into several (IN_CHANNEL, HEIGHT, WIDTH)
# images, and its label must be repeated once per image.
# NOTE(review): 8 reflects the data layout produced by ai.DataLoader2 — confirm.
VIEWS_PER_SAMPLE = 8


def batch_to_tensors(x_batch, y_batch, device, views_per_sample=VIEWS_PER_SAMPLE):
    """Convert one loader batch into model-ready (images, labels) tensors.

    Args:
        x_batch: array-like image data for one batch; reshaped to
            (-1, IN_CHANNEL, HEIGHT, WIDTH).
        y_batch: array-like holding one label per loader sample.
        device: torch device both tensors are created on.
        views_per_sample: number of images each label is repeated for.

    Returns:
        (x, y): float32 images and int64 labels. For the loss to be valid,
        the model's output rows must equal len(y).
    """
    x = torch.as_tensor(np.asarray(x_batch), dtype=torch.float32, device=device)
    x = x.reshape(-1, IN_CHANNEL, HEIGHT, WIDTH)
    y = torch.as_tensor(
        np.repeat(y_batch, views_per_sample, axis=0), dtype=torch.long, device=device
    )
    return x, y


def run_epoch(model, dloader, criterion, optimizer=None):
    """Make one pass over `dloader`: train when `optimizer` is given, else evaluate.

    The model's train/eval mode is restored afterwards, even if an exception
    is raised. Metrics are accumulated as Python numbers, so no tensor is kept
    across batches.

    Args:
        model: module being trained/evaluated; inputs go to its parameters' device.
        dloader: object whose `next_batch()` yields (x_batch, y_batch) pairs.
        criterion: loss applied to (model output, labels); assumed to average
            over rows (reduction="mean"), which is how it is re-weighted here.
        optimizer: if given, a gradient step is taken after every batch.

    Returns:
        (mean_loss, accuracy), both weighted by the number of label rows, or
        (nan, nan) when the loader yields no batches.
    """
    training = optimizer is not None
    device = next(model.parameters()).device
    was_training = model.training
    model.train(training)

    loss_sum, n_correct, n_rows = 0.0, 0, 0
    try:
        with torch.set_grad_enabled(training):
            for x_batch, y_batch in dloader.next_batch():
                x, y = batch_to_tensors(x_batch, y_batch, device)
                logps = model(x)
                loss = criterion(logps, y)

                if training:
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                # argmax of the log-probabilities equals argmax of the
                # probabilities, so no exp() is needed for top-1 accuracy.
                rows = y.size(0)
                loss_sum += loss.item() * rows
                n_correct += (logps.argmax(dim=1) == y).sum().item()
                n_rows += rows
    finally:
        model.train(was_training)

    if n_rows == 0:
        return float("nan"), float("nan")
    return loss_sum / n_rows, n_correct / n_rows


top_valid_acc = 0.0
for e in range(EPOCHS):
    train_loss, train_acc = run_epoch(detector, train_dloader, criterion, optimizer)
    valid_loss, valid_acc = run_epoch(detector, valid_dloader, criterion)

    print(f"Epochs {e+1}/{EPOCHS}")
    print(f"Train loss: {train_loss:.8f}")
    print(f"Train acc: {train_acc:.4f}")
    print(f"Valid loss: {valid_loss:.8f}")
    print(f"Valid acc: {valid_acc:.4f}")

    # Keep only the checkpoint with the best validation accuracy seen so far
    # (a nan accuracy from an empty loader never triggers a save).
    if valid_acc > top_valid_acc:
        top_valid_acc = valid_acc
        detector.save(DET_CKPT_PATH)
    

Epochs 1/300
Train loss: 0.77377599
Train acc: 0.5579
Valid loss: 0.60723031
Valid acc: 0.7188
Detector was saved.
Epochs 2/300
Train loss: 0.59773333
Train acc: 0.6776
Valid loss: 0.51107353
Valid acc: 0.7500
Detector was saved.
Epochs 3/300
Train loss: 0.50008606
Train acc: 0.7637
Valid loss: 0.40579250
Valid acc: 0.7902
Detector was saved.
Epochs 4/300
Train loss: 0.41986846
Train acc: 0.8133
Valid loss: 0.55703872
Valid acc: 0.7768
Epochs 5/300
Train loss: 0.37322606
Train acc: 0.8387
Valid loss: 0.37211844
Valid acc: 0.8482
Detector was saved.
Epochs 6/300
Train loss: 0.33914234
Train acc: 0.8605
Valid loss: 0.33628798
Valid acc: 0.8705
Detector was saved.
Epochs 7/300
Train loss: 0.31855433
Train acc: 0.8690
Valid loss: 0.36550474
Valid acc: 0.8527
Epochs 8/300
Train loss: 0.27567348
Train acc: 0.8952
Valid loss: 0.41536981
Valid acc: 0.7902
Epochs 9/300
Train loss: 0.27879339
Train acc: 0.8884
Valid loss: 0.26724336
Valid acc: 0.8616
Epochs 10/300
Train loss: 0.24586325
Train ac

KeyboardInterrupt: 