In [1]:
import torch
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as transforms

In [2]:
# Load the collision-avoidance images. ImageFolder treats each sub-folder of
# 'dataset_avoidance_v2' as one class and assigns labels in sorted folder-name order.
# NOTE(review): ColorJitter is a training-time augmentation, but because the later
# random_split shares this single transform it is also applied to the test split.
dataset = datasets.ImageFolder(
    'dataset_avoidance_v2',
    transforms.Compose([
        transforms.ColorJitter(0.1, 0.1, 0.1, 0.1),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        # ImageNet channel mean / std, matching the pretrained ResNet-18 backbone.
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
)

# Fail fast if a hidden folder (typically Jupyter's `.ipynb_checkpoints`) was picked up
# as a class: it sorts before the real class names, shifts every label by one, and
# otherwise only surfaces much later as an out-of-range target in F.cross_entropy.
hidden_classes = [name for name in dataset.classes if name.startswith('.')]
if hidden_classes:
    raise RuntimeError(
        'Hidden folders were loaded as classes: %s. Delete them (e.g. '
        'rm -rf dataset_avoidance_v2/.ipynb_checkpoints) and re-run this cell.'
        % hidden_classes
    )

In [None]:
# Hold out a fixed number of images for evaluation; everything else is used for training.
# Seeding the RNG right before the split makes it reproducible across kernel restarts.
TEST_SIZE = 200
SPLIT_SEED = 0

if len(dataset) <= TEST_SIZE:
    raise ValueError('dataset has only %d images; need more than %d to hold out a test split'
                     % (len(dataset), TEST_SIZE))

torch.manual_seed(SPLIT_SEED)
train_dataset, test_dataset = torch.utils.data.random_split(
    dataset, [len(dataset) - TEST_SIZE, TEST_SIZE])

# Report the actual split sizes; their sum should equal the number of images on disk.
print(len(train_dataset), len(test_dataset))

In [4]:
# Mini-batch loaders sharing one batch size. Only the training loader is shuffled;
# the test loader iterates in a fixed order so evaluation is repeatable.
BATCH_SIZE = 8

train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0,
)

test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=0,
)

In [5]:
# ResNet-18 backbone initialised with torchvision's pretrained (ImageNet) weights.
# NOTE(review): torchvision >= 0.13 deprecates `pretrained=` in favour of `weights=`;
# kept as-is for compatibility with older torchvision installs.
model = torchvision.models.resnet18(pretrained=True)

In [6]:
# Replace ResNet-18's 1000-way ImageNet classifier with a fresh 2-way head for this task.
# The input width is read from the existing layer instead of hard-coding 512, and the
# dataset is checked to really have 2 classes (a stray hidden folder would add a third).
NUM_CLASSES = 2
assert len(dataset.classes) == NUM_CLASSES, (
    'expected %d class folders, found %s' % (NUM_CLASSES, dataset.classes))
model.fc = torch.nn.Linear(model.fc.in_features, NUM_CLASSES)

In [7]:
# Train on the GPU when one is available, otherwise fall back to the CPU
# (torch.device('cuda') alone fails as soon as the model is moved on a CPU-only machine).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

**<font size = 10 color = blue>Start training the model!</font>**

**<font size = 5 color = black>報錯修正</font>**

<font color = blue>如果模型無法執行，且錯誤發生在 loss.backward()，通常是因為 ImageFolder 把 .ipynb_checkpoints 資料夾誤當成一個類別，使標籤超出輸出層的範圍。
<br>請刪除資料集資料夾（包含其內部各子資料夾）中的 .ipynb_checkpoints。
<br>清除方式：</font>
<br><font color = red>du -chd 1 | sort -h</font>：在資料夾內執行，列出所有子資料夾（含隱藏資料夾）以便查詢
<br><font color = red>rm -rf .ipynb_checkpoints</font>：在含有該資料夾的目錄中執行以進行清除
<br><font color = blue>驗證：確認 train_dataset 與 test_dataset 的大小總和等於實際影像數量。</font>

In [None]:
NUM_EPOCHS = 1
BEST_MODEL_PATH = 'best_model_resnet18_avoidance_test.pth'
best_accuracy = 0.0

optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

for epoch in range(NUM_EPOCHS):
    # --- training pass ---
    # train() must be re-enabled every epoch because the evaluation pass below
    # switches the model to eval mode.
    model.train()
    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = F.cross_entropy(outputs, labels)
        loss.backward()
        optimizer.step()

    # --- evaluation pass ---
    # eval() makes batch-norm use (and not update) its running statistics, so test
    # batches do not alter the model; no_grad() skips building the autograd graph.
    model.eval()
    test_error_count = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            # Count mispredictions directly; correct for any number of classes.
            test_error_count += int((outputs.argmax(1) != labels).sum())

    test_accuracy = 1.0 - float(test_error_count) / float(len(test_dataset))
    print('%d: %f' % (epoch, test_accuracy))
    # Keep only the checkpoint with the best test accuracy seen so far.
    if test_accuracy > best_accuracy:
        torch.save(model.state_dict(), BEST_MODEL_PATH)
        best_accuracy = test_accuracy
print("Finish!!")