In [1]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import numpy
from PIL import Image

In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"

#커스텀 신경망 클래스 선언

class NN(nn.Module) :
    def __init__(self) :
        super().__init__()
        self.f = nn.Sequential(
            nn.Conv2d(in_channels=3,out_channels=32,kernel_size=(3,3),padding=(1,1)), #(3,32,32)>(32,32,32)
            nn.ReLU(),
            nn.Conv2d(in_channels=32,out_channels=32,kernel_size=(3,3),padding=(1,1)), #(32,32,32)>(32,32,32)
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2,stride=2), #(32,32,32)>(32,16,16)
            nn.Dropout(0.25),

            nn.Conv2d(in_channels=32,out_channels=64,kernel_size=(3,3),padding=(1,1)), #(32,16,16)>(64,16,16)
            nn.ReLU(),
            nn.Conv2d(in_channels=64,out_channels=64,kernel_size=(3,3),padding=(1,1)), #(64,16,16)>(64,16,16)
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2,stride=2), #(64,16,16)>(64,8,8)
            nn.Dropout(0.25),

            nn.Conv2d(in_channels=64,out_channels=128,kernel_size=(3,3),padding=(1,1)), #(64,8,8)>(128,8,8)
            nn.MaxPool2d(kernel_size=2,stride=2), #(128,8,8)>(128,4,4)
            nn.ReLU(),
            nn.Dropout(0.25),

            nn.Conv2d(in_channels=128,out_channels=128,kernel_size=(3,3),padding=(1,1)), #(128,4,4)>(128,4,4)
            nn.MaxPool2d(kernel_size=2,stride=2), #(128,4,4)>(128,2,2)
            nn.ReLU(),
            nn.Dropout(0.25),

            nn.Conv2d(in_channels=128,out_channels=256,kernel_size=(3,3),padding=(1,1)), #(128,2,2)>(256,2,2)
            nn.MaxPool2d(kernel_size=2,stride=2), #(256,2,2)>(256,1,1)
            nn.ReLU(),
            nn.Dropout(0.25),
        )
        self.g = nn.Sequential(
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, 10),
        )
    def forward(self, x) :
        x = self.f(x)
        x = x.reshape(-1, 256)
        x = self.g(x)
        return x

In [3]:
#커스텀 신경망 불러오기

F = torch.load("CIFAR10.pt", weights_only=False)
F = F.to(device)

In [4]:
import os
li = os.listdir("image/")
print(li)

['airplane.jpg', 'automobile.jpg', 'bird.jpg', 'cat.jpg', 'deer.jpg', 'dog.jpg', 'frog.jpg', 'horse.jpg', 'ship.jpg', 'truck.jpg']


In [8]:
with open("list.txt", mode = "r") as f:
    target_list = f.readlines()

print(target_list)

['Airplane\n', 'Automobile\n', 'Bird\n', 'Cat\n', 'Deer\n', 'Dog\n', 'Frog\n', 'Horse\n', 'Ship\n', 'Truck']


In [21]:
F.eval()
print(target_list)
for name in li :
    if name.find(".jpg") < 0 :
        continue
    img = Image.open("image/"+name)
    img = img.convert(mode = "RGB").resize((32,32))
    img.save("image/converted_"+name)
    img = numpy.transpose(numpy.array(img),(2,0,1)) / 255
    img = torch.tensor(img, dtype = torch.float, device = device)
    
    y = F(img)
    prop = (nn.functional.softmax(y, dim = -1) * 100).squeeze().type(torch.long).numpy().tolist()
    for i in range(len(prop)) :
        print(f"{target_list[i].strip()}:{prop[i]}, ", end = "")
    print()
    print("y : ", target_list[y.argmax().item()].strip())
    print("t : ", name)
    print()

['Airplane\n', 'Automobile\n', 'Bird\n', 'Cat\n', 'Deer\n', 'Dog\n', 'Frog\n', 'Horse\n', 'Ship\n', 'Truck']
Airplane:100, Automobile:0, Bird:0, Cat:0, Deer:0, Dog:0, Frog:0, Horse:0, Ship:0, Truck:0, 
y :  Airplane
t :  airplane.jpg

Airplane:0, Automobile:100, Bird:0, Cat:0, Deer:0, Dog:0, Frog:0, Horse:0, Ship:0, Truck:0, 
y :  Automobile
t :  automobile.jpg

Airplane:0, Automobile:0, Bird:99, Cat:0, Deer:0, Dog:0, Frog:0, Horse:0, Ship:0, Truck:0, 
y :  Bird
t :  bird.jpg

Airplane:0, Automobile:0, Bird:0, Cat:99, Deer:0, Dog:0, Frog:0, Horse:0, Ship:0, Truck:0, 
y :  Cat
t :  cat.jpg

Airplane:0, Automobile:0, Bird:0, Cat:0, Deer:100, Dog:0, Frog:0, Horse:0, Ship:0, Truck:0, 
y :  Deer
t :  deer.jpg

Airplane:0, Automobile:0, Bird:0, Cat:2, Deer:2, Dog:68, Frog:0, Horse:25, Ship:0, Truck:0, 
y :  Dog
t :  dog.jpg

Airplane:0, Automobile:0, Bird:0, Cat:0, Deer:0, Dog:0, Frog:100, Horse:0, Ship:0, Truck:0, 
y :  Frog
t :  frog.jpg

Airplane:0, Automobile:0, Bird:0, Cat:0, Deer:0, Do