In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import natsort
import glob
import numpy as np
import pandas as pd
import torchvision.models.vgg as vgg
from PIL import Image

import torch.optim as optim
from torch.utils.data import Dataset,DataLoader

import torchvision
import torchvision.transforms as transforms

In [None]:
# Seed the default generator first, then pick the compute device; when a GPU
# is visible, every CUDA generator gets the same fixed seed.
torch.manual_seed(1218)

if torch.cuda.is_available():
    device = 'cuda'
    torch.cuda.manual_seed_all(1218)
else:
    device = 'cpu'

In [None]:
# Preprocessing shared by the training and test images: convert to a tensor,
# resize to 224x224, then normalize each of the three channels with
# mean 0.5 and std 0.5.
trans = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Resize(size=(224, 224)),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

# ImageFolder derives the class labels from the sub-directory names of root.
train_data = torchvision.datasets.ImageFolder(root='your_path', transform=trans)

In [None]:
# Shuffled mini-batches of 32 training samples, loaded in the main process
# (num_workers=0).
data_loader = DataLoader(
    train_data,
    batch_size=32,
    shuffle=True,
    num_workers=0,
)

In [None]:
class Net(nn.Module):
    """VGG-16-style convolutional classifier with three output classes.

    Five convolutional stages (each ending in ReLU + 2x2 max-pooling) are
    followed by global average pooling and a single linear layer that maps
    the 512 pooled channels to 3 logits.

    The global pooling is adaptive, so the network accepts any spatial input
    size that survives the five 2x down-samplings (at least 32x32); for the
    224x224 images used in this notebook it is equivalent to the 7x7 average
    pooling the stage comments assume.

    NOTE(review): only the last convolution of each stage is followed by a
    ReLU. Consecutive convolutions with no activation between them compose
    into a single linear map, whereas the reference VGG-16 applies a ReLU
    after every convolution. Confirm whether the omission is intentional;
    inserting the activations would shift the nn.Sequential indices and
    therefore the state_dict keys of existing checkpoints.
    """

    def __init__(self):
        super().__init__()
        # Stage comments: input channels -> output channels, and the spatial
        # size after pooling for a 224x224 input.
        self.conv = nn.Sequential(
            # Stage 1: 3 -> 64 channels, 224 -> 112
            nn.Conv2d(3, 64, 3, padding=1),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),

            # Stage 2: 64 -> 128 channels, 112 -> 56
            nn.Conv2d(64, 128, 3, padding=1),
            nn.Conv2d(128, 128, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),

            # Stage 3: 128 -> 256 channels, 56 -> 28
            nn.Conv2d(128, 256, 3, padding=1),
            nn.Conv2d(256, 256, 3, padding=1),
            nn.Conv2d(256, 256, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),

            # Stage 4: 256 -> 512 channels, 28 -> 14
            nn.Conv2d(256, 512, 3, padding=1),
            nn.Conv2d(512, 512, 3, padding=1),
            nn.Conv2d(512, 512, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),

            # Stage 5: 512 -> 512 channels, 14 -> 7
            nn.Conv2d(512, 512, 3, padding=1),
            nn.Conv2d(512, 512, 3, padding=1),
            nn.Conv2d(512, 512, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        )

        # Global average pooling down to 1x1 per channel. The adaptive variant
        # has no parameters (state_dict is unchanged) and does not hard-code
        # the 7x7 feature-map size.
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # 512 pooled channels -> 3 class logits.
        self.classifier = nn.Linear(512, 3)

    def forward(self, x):
        """Return raw class logits of shape (batch, 3) for an image batch.

        Args:
            x: tensor of shape (batch, 3, H, W).
        """
        features = self.conv(x)
        x = self.avg_pool(features)
        # Collapse (batch, 512, 1, 1) to (batch, 512) for the linear layer.
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

In [None]:
# Build the classifier and move its parameters to the selected device.
net = Net()
net = net.to(device)

In [None]:
# Adam over all model parameters with an initial learning rate of 0.005.
optimizer = optim.Adam(params=net.parameters(), lr=0.005)

# Cross-entropy on the raw logits produced by the network.
loss_func = nn.CrossEntropyLoss()
loss_func = loss_func.to(device)

# Every 5 scheduler steps the learning rate is multiplied by gamma.
# NOTE(review): gamma=0.0001 takes the rate from 0.005 to 5e-7 at the first
# decay, which practically stops learning -- confirm this is intended.
lr_sche = optim.lr_scheduler.StepLR(
    optimizer,
    step_size=5,
    gamma=0.0001,
)

In [None]:
epochs = 10

for epoch in range(epochs):  # loop over the dataset multiple times
    running_loss = 0.0
    for i, data in enumerate(data_loader):
        # Move the mini-batch to the device the model lives on.
        inputs, labels = data
        inputs = inputs.to(device)
        labels = labels.to(device)

        # Clear gradients accumulated by the previous step.
        optimizer.zero_grad()

        # Forward pass, loss, backward pass, parameter update.
        outputs = net(inputs)
        loss = loss_func(outputs, labels)
        loss.backward()
        optimizer.step()

        # Report the mean loss over each window of 30 mini-batches.
        running_loss += loss.item()
        if i % 30 == 29:
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 30))
            running_loss = 0.0

    # Advance the learning-rate schedule once per epoch, after the optimizer
    # updates. Stepping the scheduler before optimizer.step() skips the first
    # value of the schedule and triggers a PyTorch warning.
    lr_sche.step()

print('Finished Training')

In [None]:
# Persist only the learned parameters (the state dict), not the module object.
torch.save(
    obj=net.state_dict(),
    f='your_path.pth',
)

In [None]:
# Fresh, untrained instance that will receive the saved weights.
new_net = Net()
new_net = new_net.to(device)

In [None]:
# Restore the saved parameters. map_location remaps the stored tensors onto
# the current device, so a checkpoint written on a GPU machine can still be
# loaded when only the CPU is available.
new_net.load_state_dict(torch.load('your_path.pth', map_location=device))

In [None]:
class CustomDataSet(Dataset):
    """Unlabelled image dataset backed by the files of a single directory.

    Files are served in natural sort order of their names. Each item is the
    image converted to RGB and passed through ``transform``; no label is
    returned.

    Args:
        main_dir: directory containing the image files.
        transform: callable applied to each RGB PIL image.
    """

    def __init__(self, main_dir, transform):
        self.main_dir = main_dir
        self.transform = transform

        # Keep regular files only: a sub-directory entry cannot be opened as
        # an image and would make __getitem__ fail for that index.
        all_imgs = [
            name
            for name in os.listdir(main_dir)
            if os.path.isfile(os.path.join(main_dir, name))
        ]
        self.total_imgs = natsort.natsorted(all_imgs)

    def __len__(self):
        """Number of image files found in the directory."""
        return len(self.total_imgs)

    def __getitem__(self, idx):
        """Load, convert to RGB and transform the idx-th image."""
        img_loc = os.path.join(self.main_dir, self.total_imgs[idx])

        # Image.open is lazy and keeps the file handle open; the context
        # manager closes it as soon as the RGB copy has been made.
        with Image.open(img_loc) as img:
            image = img.convert("RGB")

        tensor_image = self.transform(image)

        return tensor_image



In [None]:
# Unlabelled test images, preprocessed with the same pipeline as training.
test_data = CustomDataSet(
    main_dir='your_path',
    transform=trans,
)

In [None]:
# Batches of 32 test images served in dataset order (no shuffling requested).
test_set = DataLoader(
    test_data,
    batch_size=32,
)