# 얼굴검출 학습하기

In [1]:
import torch
import torchvision.datasets as dsets
import torchvision.transforms as transforms
import torch.nn.init
import torchvision

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# for reproducibility
torch.manual_seed(777)
if device == 'cuda':
    torch.cuda.manual_seed_all(777)

trans = torchvision.transforms.Compose([
transforms.Resize((56, 56)),
transforms.ToTensor()
])
    
class CNN(torch.nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.layer1 = torch.nn.Sequential(
            torch.nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2, stride=2))
        self.layer2 = torch.nn.Sequential(
            torch.nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2, stride=2))
        self.layer3 = torch.nn.Sequential(
            torch.nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2, stride=2))
        
        # Final FC 7x7x64 inputs -> 10 outputs
        self.fc = torch.nn.Linear(7 * 7* 128, 625, bias=True)
        torch.nn.ReLU()
        self.fc2 = torch.nn.Linear(625, 2, bias=True)
        torch.nn.init.xavier_uniform_(self.fc.weight)

    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = self.layer3(out)
        out = out.view(out.size(0), -1)   # Flatten them for FC
        out = self.fc(out)
        out = self.fc2(out)
        return out

if __name__ == "__main__":
    learning_rate = 0.001
    training_epochs = 10
    batch_size = 1

    mnist_train = torchvision.datasets.ImageFolder(root='./Dataset/',transform = trans)
    
    data_loader = torch.utils.data.DataLoader(dataset=mnist_train,
                                              batch_size=batch_size,
                                              shuffle=True,
                                              drop_last=True)

    model = CNN().to(device)

    # define cost/loss & optimizer
    criterion = torch.nn.CrossEntropyLoss().to(device)  # Softmax is internally computed.
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    # train my model
    total_batch = len(data_loader)
    print('Learning started. It takes sometime.')
    for epoch in range(training_epochs):
        avg_cost = 0
        step = 0

        for X, Y in data_loader:
            # image is already size of (28x28), no reshape
            # label is not one-hot encoded
            X = X.to(device)
            Y = Y.to(device)
            if step % 100 == 0:
                print(step, ": ", X.shape)

            optimizer.zero_grad()
            print(X.shape)
            hypothesis = model(X)
            cost = criterion(hypothesis, Y)
            cost.backward()
            optimizer.step()
            avg_cost += cost / total_batch
            step = step + 1
            
        torch.save(model.state_dict(),'param/netD_epoch_%d.pth'%(epoch))
        #model.load_state_dict(torch.load(model))
    
    print('Learning Finished!')

    test_data = torchvision.datasets.ImageFolder(root='./training/',transform = trans)
    test_set = torch.utils.data.DataLoader(dataset=test_data, batch_size=len(test_data))
    with torch.no_grad():
        for num, data in enumerate(test_set):
            imgs, label = data
            imgs = imgs.to(device)
            label = label.to(device)
            print(imgs.shape)

            prediction = model(imgs)
            print(prediction)
            print(torch.argmax(prediction, 1))
            print(label)

            correct_prediction = torch.argmax(prediction, 1) == label

            accuracy = correct_prediction.float().mean()
            print('Accuracy:', accuracy.item())

Learning started. It takes sometime.
0 :  torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1

torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
300 :  torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56,

torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
600 :  torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56,

torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
t

torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
t

torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
t

torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
t

torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
t

torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
200 :  torch.Size([1, 3, 56,

torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
500 :  torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56,

torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
t

torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
t

torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
t

torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
200 :  torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56,

torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
500 :  torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56,

torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
t

torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
t

torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
t

torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
200 :  torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56,

torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
500 :  torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56,

torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
t

torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
t

torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
t

torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
t

torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
t

# 결과값 확인하기

In [32]:
    test_data = torchvision.datasets.ImageFolder(root='./training/',transform = trans)
    test_set = torch.utils.data.DataLoader(dataset=test_data, batch_size=len(test_data))   
    with torch.no_grad():
        for num, data in enumerate(test_set):
            imgs, label = data
            imgs = imgs.to(device)
            label = label.to(device)
            print(label)
            print(imgs.shape)

            prediction = model(imgs)
            print(prediction)
            print(torch.argmax(prediction, 1))
            print(label)

            correct_prediction = torch.argmax(prediction, 1) == label
            #1이면 얼굴 0이면 아님
            accuracy = correct_prediction.float().mean()
            print('Accuracy:', accuracy.item())

tensor([1])
torch.Size([1, 3, 56, 56])
tensor([[-5.5577,  5.7315]])
tensor([1])
tensor([1])
Accuracy: 1.0


# 화재데이터 학습

In [34]:
import torch
import torchvision.datasets as dsets
import torchvision.transforms as transforms
import torch.nn.init
import torchvision

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# for reproducibility
torch.manual_seed(777)
if device == 'cuda':
    torch.cuda.manual_seed_all(777)

trans = torchvision.transforms.Compose([
transforms.Resize((56, 56)),
transforms.ToTensor()
])
    
class CNN(torch.nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.layer1 = torch.nn.Sequential(
            torch.nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2, stride=2))
        self.layer2 = torch.nn.Sequential(
            torch.nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2, stride=2))
        self.layer3 = torch.nn.Sequential(
            torch.nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2, stride=2))
        
        # Final FC 7x7x64 inputs -> 10 outputs
        self.fc = torch.nn.Linear(7 * 7* 128, 625, bias=True)
        torch.nn.ReLU()
        self.fc2 = torch.nn.Linear(625, 2, bias=True)
        torch.nn.init.xavier_uniform_(self.fc.weight)

    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = self.layer3(out)
        out = out.view(out.size(0), -1)   # Flatten them for FC
        out = self.fc(out)
        out = self.fc2(out)
        return out

if __name__ == "__main__":
    learning_rate = 0.001
    training_epochs = 10
    batch_size = 1

    mnist_train = torchvision.datasets.ImageFolder(root='./Fire_dataset/',transform = trans)
    
    data_loader = torch.utils.data.DataLoader(dataset=mnist_train,
                                              batch_size=batch_size,
                                              shuffle=True,
                                              drop_last=True)

    model = CNN().to(device)

    # define cost/loss & optimizer
    criterion = torch.nn.CrossEntropyLoss().to(device)  # Softmax is internally computed.
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    # train my model
    total_batch = len(data_loader)
    print('Learning started. It takes sometime.')
    for epoch in range(training_epochs):
        avg_cost = 0
        step = 0

        for X, Y in data_loader:
            # image is already size of (28x28), no reshape
            # label is not one-hot encoded
            X = X.to(device)
            Y = Y.to(device)
            if step % 100 == 0:
                print(step, ": ", X.shape)

            optimizer.zero_grad()
            print(X.shape)
            hypothesis = model(X)
            cost = criterion(hypothesis, Y)
            cost.backward()
            optimizer.step()
            avg_cost += cost / total_batch
            step = step + 1
            
        torch.save(model.state_dict(),'param/netD_epoch_%d.pth'%(epoch))
        #model.load_state_dict(torch.load(model))
    
    print('Learning Finished!')

    test_data = torchvision.datasets.ImageFolder(root='./Fire_training/',transform = trans)
    test_set = torch.utils.data.DataLoader(dataset=test_data, batch_size=len(test_data))
    with torch.no_grad():
        for num, data in enumerate(test_set):
            imgs, label = data
            imgs = imgs.to(device)
            label = label.to(device)
            print(imgs.shape)

            prediction = model(imgs)
            print(prediction)
            print(torch.argmax(prediction, 1))
            print(label)

            correct_prediction = torch.argmax(prediction, 1) == label

            accuracy = correct_prediction.float().mean()
            print('Accuracy:', accuracy.item())

Learning started. It takes sometime.
0 :  torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1

torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
0 :  torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 5

torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
0 :  torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 5

torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
0 :  torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 5

torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
t

torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
torch.Size([1, 3, 56, 56])
t

# 결과값 확인

In [40]:
    test_data = torchvision.datasets.ImageFolder(root='./training/',transform = trans)
    test_set = torch.utils.data.DataLoader(dataset=test_data, batch_size=len(test_data))
    with torch.no_grad():
        for num, data in enumerate(test_set):
            imgs, label = data
            imgs = imgs.to(device)
            label = label.to(device)
            print(imgs.shape)

            prediction = model(imgs)
            print(prediction)
            print(torch.argmax(prediction, 1))
            print(label)

            correct_prediction = torch.argmax(prediction, 1) == label

            accuracy = correct_prediction.float().mean()
            print('Accuracy:', accuracy.item())

torch.Size([1, 3, 56, 56])
tensor([[ 6.3748, -5.5555]])
tensor([0])
tensor([1])
Accuracy: 0.0
