In [4]:
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader

from recycle_cnn import CNN

In [5]:
# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Fix the RNG state so weight initialisation and DataLoader shuffling repeat
# from run to run. The CUDA generators are seeded explicitly as well when a
# GPU is in use.
SEED = 777
torch.manual_seed(SEED)
if device == 'cuda':
    torch.cuda.manual_seed_all(SEED)

# --- Training hyperparameters ---
learning_rate = 1e-3   # Adam step size
epochs = 20            # full passes over the training set
batch_size = 16        # images per mini-batch

In [6]:
# ToTensor turns each PIL image into a channels-first float tensor
# (values scaled to [0, 1] for 8-bit images). No resize/normalisation is applied.
trans = transforms.ToTensor()

# ImageFolder treats every sub-directory of data/train as one class.
# NOTE(review): with no Resize, default batching needs every image to share one
# size, and fc1 expects a fixed spatial size — assumes images are pre-sized; confirm.
train_set = torchvision.datasets.ImageFolder(root='data/train', transform=trans)

# Reshuffled each epoch; the global torch seed set earlier makes the order repeatable.
train_loader = DataLoader(
    dataset=train_set,
    batch_size=batch_size,
    shuffle=True,
)

In [7]:
# Build the network and move its parameters to the selected device.
model = CNN().to(device)

# CrossEntropyLoss consumes un-normalised class scores plus integer labels.
criterion = nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)

# Mini-batches per epoch; the training loop divides by this to average the loss.
total_batch = len(train_loader)

In [8]:
# Print the layer-by-layer summary of the network.
print(repr(model))

CNN(
  (layer1): Sequential(
    (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (layer2): Sequential(
    (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (fc1): Linear(in_features=786432, out_features=128, bias=True)
  (drop): Dropout(p=0.25, inplace=False)
  (fc2): Linear(in_features=128, out_features=6, bias=True)
)


In [9]:
# Train for `epochs` passes over train_loader, printing the mean batch loss per epoch.
for epoch in range(epochs):
    # Running mean of the per-batch loss, kept as a plain Python float
    # (a float start value also keeps the '{:>.9}' format valid if the loader is empty).
    avg_cost = 0.0
    # Training mode: Dropout active, BatchNorm uses batch statistics.
    model.train()
    for x, y in train_loader:
        x = x.to(device)
        y = y.to(device)

        optimizer.zero_grad()
        # Presumably raw class scores (logits); CrossEntropyLoss applies log-softmax itself.
        hypothesis = model(x)
        cost = criterion(hypothesis, y)
        cost.backward()
        optimizer.step()

        # .item() extracts the scalar. Accumulating the tensor itself would keep
        # every batch's autograd graph referenced until the epoch ends (memory leak).
        avg_cost += cost.item() / total_batch

    print('[Epoch: {:>4}] cost = {:>.9}'.format(epoch + 1, avg_cost))

[Epoch:    1] cost = 575.635803
[Epoch:    2] cost = 304.373383
[Epoch:    3] cost = 187.066635
[Epoch:    4] cost = 138.106552
[Epoch:    5] cost = 74.1338806
[Epoch:    6] cost = 38.3252754
[Epoch:    7] cost = 32.8701439
[Epoch:    8] cost = 17.5870667
[Epoch:    9] cost = 12.4766979
[Epoch:   10] cost = 15.6199522
[Epoch:   11] cost = 12.0677338
[Epoch:   12] cost = 8.50937748
[Epoch:   13] cost = 11.8942966
[Epoch:   14] cost = 8.63452911
[Epoch:   15] cost = 6.82889795
[Epoch:   16] cost = 6.91486883
[Epoch:   17] cost = 8.89770031
[Epoch:   18] cost = 5.48392534
[Epoch:   19] cost = 3.94427967
[Epoch:   20] cost = 3.60839772


In [10]:
# Serialise the entire module object (architecture + weights) via pickle.
# NOTE(review): a full-module pickle stores the CNN class by reference, so
# loading needs `recycle_cnn.CNN` importable. Recent PyTorch releases also
# default torch.load to weights_only=True, which rejects such pickles.
# Saving model.state_dict() is the documented, more portable option. It is
# kept as-is here in case downstream code expects a full module in model.pt.
torch.save(obj=model, f='model.pt')