In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from PIL import Image, ImageFile

# Let PIL decode image files that are cut short instead of raising on them.
ImageFile.LOAD_TRUNCATED_IMAGES = True

In [2]:
def check_image(path, verbose=False):
    """Return True if PIL can open ``path`` as an image, False otherwise.

    Passed to ``ImageFolder(is_valid_file=...)`` so files PIL cannot open are
    skipped instead of aborting the dataset scan.

    Args:
        path: filesystem path of the candidate file.
        verbose: if True, print each path as it is checked. Off by default so
            scanning a large dataset does not flood the notebook output.

    Returns:
        bool: whether ``Image.open`` succeeded on ``path``.
    """
    if verbose:
        print(path)
    try:
        # Image.open is lazy and keeps the file open; the context manager
        # closes the handle immediately instead of leaking one per file.
        with Image.open(path):
            return True
    except Exception:
        # Narrower than a bare ``except:`` so KeyboardInterrupt / SystemExit
        # still propagate while any open/identify failure marks the file invalid.
        return False

In [3]:
CCHANNEL = 3  # color channels per image; must be 3 (RGB) for this pipeline
imgSizeW, imgSizeH = 64, 64  # normalized size for training, not the original size
imgCategoryCnt = 3  # number of classes (one sub-folder per class in each data set)
SEED = 42  # seed for the random weight initialization, so runs are reproducible

# The Normalize transform supplies exactly three per-channel statistics and the
# network's input layer is sized from CCHANNEL, so any other value breaks
# training; fail here with a clear message instead.
assert CCHANNEL == 3, "CCHANNEL must be 3: images are normalized as RGB"
torch.manual_seed(SEED)


In [4]:
# Preprocessing applied to every image: resize to the training resolution,
# convert the PIL image to a tensor, then standardize each of the three
# channels with the well-known ImageNet mean/std values.
img_transforms = transforms.Compose(
    [
        # Resize takes (height, width), not (width, height).
        transforms.Resize((imgSizeH, imgSizeW)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

In [5]:
# One sub-folder per class under the training directory. Requires a
# torchvision whose ImageFolder accepts ``is_valid_file`` (the recorded
# TypeError below comes from an older release).
train_data_path = "./Data_Sets/Training_Data"
train_data = torchvision.datasets.ImageFolder(
    root=train_data_path,
    transform=img_transforms,
    is_valid_file=check_image,  # skip files PIL cannot open
)
print("*****  total number of images in 'Training_Data' is  %4d" % len(train_data))

TypeError: __init__() got an unexpected keyword argument 'is_valid_file'

In [None]:
# Validation images, laid out and preprocessed exactly like the training set.
val_data_path = "./Data_Sets/Validation_Data"
val_data = torchvision.datasets.ImageFolder(
    root=val_data_path,
    transform=img_transforms,
    is_valid_file=check_image,  # skip files PIL cannot open
)
print("*****  total number of images in 'Validation_Data' is  %4d" % len(val_data))

In [None]:
# Held-out test images, laid out and preprocessed exactly like the training set.
test_data_path = "./Data_Sets/Test_Data"
test_data = torchvision.datasets.ImageFolder(
    root=test_data_path,
    transform=img_transforms,
    is_valid_file=check_image,  # skip files PIL cannot open
)
print("*****  total number of images in 'Test_Data' is  %4d" % len(test_data))

In [None]:
# Number of samples per mini-batch, shared by all three data loaders.
batch_size = 64

In [None]:
# Shuffle the training set every epoch: ImageFolder yields samples grouped by
# class, so an unshuffled loader feeds mini-batches of (almost) a single class.
train_data_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)
# Evaluation metrics do not depend on order, so keep these deterministic.
val_data_loader = torch.utils.data.DataLoader(val_data, batch_size=batch_size)
test_data_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size)

In [None]:
inputNodeCnt   = imgSizeW * imgSizeH * CCHANNEL  # elements per flattened image
hidden1NodeCnt = 84 # this number is somewhat arbitrary, should be greater than hidden2NodeCnt
hidden2NodeCnt = 50 # this number is somewhat arbitrary, should be greater than outputNodeCnt
outputNodeCnt  = imgCategoryCnt


class SimpleNet(nn.Module):
    """Three-layer fully connected classifier over flattened images.

    ``forward`` takes a batch whose samples each hold ``inputNodeCnt``
    elements and returns unnormalized class scores of shape
    ``(batch, outputNodeCnt)``; no softmax is applied because the training
    loss (CrossEntropyLoss) expects raw scores.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(inputNodeCnt, hidden1NodeCnt)   # Linear = fully connected
        self.fc2 = nn.Linear(hidden1NodeCnt, hidden2NodeCnt)
        self.fc3 = nn.Linear(hidden2NodeCnt, outputNodeCnt)

    def forward(self, x):
        # Flatten every dimension except the batch one. Unlike
        # x.view(-1, inputNodeCnt), this never regroups elements across
        # samples: an image of the wrong size/channel count now raises a shape
        # error in fc1 instead of silently producing a different number of
        # rows than there are targets.
        x = torch.flatten(x, start_dim=1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

In [None]:
# Fresh model instance; its weights come from nn.Linear's random initialization.
simplenet = SimpleNet()

In [None]:
# Adam over every trainable parameter of the model, fixed learning rate 1e-3.
# NOTE(review): the torch.optim docs advise moving a model to its device before
# building its optimizer; this ordering still works because Module.to() updates
# the parameters in place, but consider creating the optimizer after .to(device).
optimizer = optim.Adam(params=simplenet.parameters(), lr=1e-3)

In [None]:
# Prefer the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

simplenet.to(device)

In [None]:
def _train_one_epoch(model, optimizer, loss_fn, loader, device):
    """Run one optimization pass over ``loader``; return the mean per-sample loss.

    Returns NaN when the loader's dataset is empty (instead of dividing by zero).
    """
    model.train()
    total_loss = 0.0
    for inputs, targets in loader:
        inputs = inputs.to(device)
        targets = targets.to(device)
        optimizer.zero_grad()
        loss = loss_fn(model(inputs), targets)
        loss.backward()
        optimizer.step()
        # loss_fn returns a batch mean; weight by batch size to average per sample.
        total_loss += loss.item() * inputs.size(0)
    num_samples = len(loader.dataset)
    return total_loss / num_samples if num_samples else float("nan")


def _evaluate(model, loss_fn, loader, device):
    """Return ``(mean per-sample loss, accuracy)`` of ``model`` over ``loader``.

    Runs under ``torch.no_grad()`` so no autograd graph is built. Both values
    are NaN when the loader yields no samples.
    """
    model.eval()
    total_loss = 0.0
    num_correct = 0
    num_examples = 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            output = model(inputs)
            total_loss += loss_fn(output, targets).item() * inputs.size(0)
            # softmax is monotonic, so the argmax of the raw scores is the prediction.
            num_correct += (output.argmax(dim=1) == targets).sum().item()
            num_examples += targets.size(0)
    num_samples = len(loader.dataset)
    mean_loss = total_loss / num_samples if num_samples else float("nan")
    accuracy = num_correct / num_examples if num_examples else float("nan")
    return mean_loss, accuracy


def train(model, optimizer, loss_fn, train_loader, val_loader, epochs=20, device="cpu"):
    """Train ``model`` for ``epochs`` epochs and print metrics after each one.

    Each epoch does one optimization pass over ``train_loader`` followed by an
    evaluation pass over ``val_loader``; the model is left in eval mode.

    Args:
        model: the ``nn.Module`` to train (assumed already on ``device``).
        optimizer: optimizer built over ``model``'s parameters.
        loss_fn: callable ``(output, targets) -> scalar loss tensor``.
        train_loader: DataLoader yielding ``(inputs, targets)`` training batches.
        val_loader: DataLoader yielding ``(inputs, targets)`` validation batches.
        epochs: number of passes over the training data.
        device: device (or device string) batches are moved to.
    """
    for epoch in range(epochs):
        training_loss = _train_one_epoch(model, optimizer, loss_fn, train_loader, device)
        valid_loss, accuracy = _evaluate(model, loss_fn, val_loader, device)
        print('Epoch: {}, Training Loss: {:.2f}, Validation Loss: {:.2f}, accuracy = {:.2f}'.format(
            epoch, training_loss, valid_loss, accuracy))

In [None]:
# Five epochs of training, reporting validation metrics after each epoch.
train(
    simplenet,
    optimizer,
    loss_fn=torch.nn.CrossEntropyLoss(),
    train_loader=train_data_loader,
    val_loader=val_data_loader,
    epochs=5,
    device=device,
)