<a href="https://colab.research.google.com/github/greenkode/pytorch/blob/master/Pytorch_Chapter_2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image, ImageFile

ImageFile.LOAD_TRUNCATED_IMAGES=True

In [2]:
# !python download.py

In [3]:
def check_image(path):
    try:
        im = Image.open(path)
        return True
    except:
        return False

In [4]:
train_data_path = "./train/"
transforms = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456,0.406],
                         std=[0.229, 0.224, 0.225])
])

train_data = torchvision.datasets.ImageFolder(root=train_data_path, transform=transforms, is_valid_file=check_image)

In [5]:
val_data_path = "./val/"
val_data = torchvision.datasets.ImageFolder(root=val_data_path, transform=transforms, is_valid_file=check_image)

In [6]:
test_data_path = "./test/"
test_data = torchvision.datasets.ImageFolder(root=test_data_path, transform=transforms, is_valid_file=check_image)

In [7]:
batch_size=64
train_data_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size)
val_data_loader = torch.utils.data.DataLoader(val_data, batch_size=batch_size)
test_data_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size)

In [8]:
class SimpleNet(nn.Module):
  def __init__(self):
    super(SimpleNet, self).__init__()
    self.fc1 = nn.Linear(12288, 84)
    self.fc2 = nn.Linear(84, 50)
    self.fc3 = nn.Linear(50, 2)

  def forward(self, x):
    x = x.view(-1, 12288)
    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))
    x = self.fc3(x)
    return x

simplenet = SimpleNet()

In [9]:
optimizer = optim.Adam(simplenet.parameters(), lr=0.001)

In [10]:
if torch.cuda.is_available():
  device = torch.device("cuda")
else:
  device = torch.device("cpu")

simplenet.to(device)

SimpleNet(
  (fc1): Linear(in_features=12288, out_features=84, bias=True)
  (fc2): Linear(in_features=84, out_features=50, bias=True)
  (fc3): Linear(in_features=50, out_features=2, bias=True)
)

In [11]:
def train(model, optimizer, loss_fn, train_loader, val_loader, epochs=20, device="cpu"):
    for epoch in range(epochs):
        training_loss = 0.0
        valid_loss = 0.0
        model.train()
        for batch in train_loader:
            optimizer.zero_grad()
            inputs, targets = batch
            inputs = inputs.to(device)
            targets = targets.to(device)
            output = model(inputs)
            loss = loss_fn(output, targets)
            loss.backward()
            optimizer.step()
            training_loss += loss.data.item() * inputs.size(0)
        training_loss /= len(train_loader.dataset)
        
        model.eval()
        num_correct = 0 
        num_examples = 0
        for batch in val_loader:
            inputs, targets = batch
            inputs = inputs.to(device)
            output = model(inputs)
            targets = targets.to(device)
            loss = loss_fn(output,targets) 
            valid_loss += loss.data.item() * inputs.size(0)
            correct = torch.eq(torch.max(F.softmax(output, dim=1), dim=1)[1], targets)
            num_correct += torch.sum(correct).item()
            num_examples += correct.shape[0]
        valid_loss /= len(val_loader.dataset)

        print('Epoch: {}, Training Loss: {:.2f}, Validation Loss: {:.2f}, accuracy = {:.2f}'.format(epoch+1, training_loss,
        valid_loss, num_correct / num_examples))

In [12]:
train(simplenet, optimizer, torch.nn.CrossEntropyLoss(), train_data_loader, test_data_loader, 10, device)

Epoch: 1, Training Loss: 3.36, Validation Loss: 2.24, accuracy = 0.48
Epoch: 2, Training Loss: 1.93, Validation Loss: 1.17, accuracy = 0.73
Epoch: 3, Training Loss: 1.75, Validation Loss: 0.69, accuracy = 0.73
Epoch: 4, Training Loss: 0.57, Validation Loss: 0.62, accuracy = 0.75
Epoch: 5, Training Loss: 0.50, Validation Loss: 0.65, accuracy = 0.72
Epoch: 6, Training Loss: 0.48, Validation Loss: 0.55, accuracy = 0.75
Epoch: 7, Training Loss: 0.31, Validation Loss: 0.61, accuracy = 0.76
Epoch: 8, Training Loss: 0.37, Validation Loss: 0.58, accuracy = 0.74
Epoch: 9, Training Loss: 0.26, Validation Loss: 0.58, accuracy = 0.78
Epoch: 10, Training Loss: 0.26, Validation Loss: 0.59, accuracy = 0.76


In [13]:
labels = ['cat', 'fish']

img = Image.open("./val/fish/100_1422.JPG")
img = transforms(img).to(device)

prediction = F.softmax(simplenet(img), dim=1)
prediction = prediction.argmax()
print(labels[prediction])

fish


In [14]:
torch.save(simplenet, "/tmp/simplenet")
simplenet = torch.load("/tmp/simplenet")

In [15]:
torch.save(simplenet.state_dict(), "/tmp/simplenet")
simplenet = SimpleNet()
simplenet_state_dict = torch.load("/tmp/simplenet")
simplenet.load_state_dict(simplenet_state_dict)

<All keys matched successfully>