In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.utils.data
import numpy as np

In [10]:
# Training hyperparameters, read as globals by the cells below.
# NOTE(review): the recorded epoch-0 log further down hovers around 0.69 (~ln 2,
# the chance-level cross-entropy for two classes) for the whole epoch; the
# learning rate may be too high for this network — confirm before a long run.
batch_size = 32   # samples per DataLoader batch
epoch_num = 100   # number of passes over the training set
lr_init = 0.1     # learning rate the optimizer starts with
lr_change = 30    # epochs between successive x0.1 learning-rate drops

In [3]:
# create the network
class Net(nn.Module):
    """Small CNN that maps a 3-channel image to two class logits.

    ``cnn`` is five conv -> ReLU -> 2x2 max-pool stages. Each stage halves the
    spatial size, and the channels widen 3 -> 16 -> 32 -> 64 -> 128 -> 128.
    ``clf`` is a three-layer MLP whose first layer expects 7*7*128 features,
    i.e. the feature map produced by a 224x224 input (224 / 2**5 = 7).
    """

    # Flattened feature count the classifier head is built for.
    FEATURE_DIM = 7 * 7 * 128

    def __init__(self):
        super().__init__()
        # Layer order is deliberately untouched so Sequential indices (and thus
        # state_dict keys such as "cnn.12.weight") stay the same.
        self.cnn = nn.Sequential(
            nn.Conv2d(3, 16, 5, padding=2), nn.ReLU(), nn.MaxPool2d(2, 2),
            nn.Conv2d(16, 32, 5, padding=2), nn.ReLU(), nn.MaxPool2d(2, 2),
            nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2, 2),
            nn.Conv2d(64, 128, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2, 2),
            nn.Conv2d(128, 128, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2, 2)
        )
        self.clf = nn.Sequential(
            nn.Linear(self.FEATURE_DIM, 512), nn.ReLU(),
            nn.Linear(512, 128), nn.ReLU(),
            nn.Linear(128, 2)
        )

    def forward(self, x):
        """Return logits of shape (batch, 2) for an input of shape (batch, 3, H, W).

        H and W must reduce to a 7x7 feature map (e.g. 224x224); any other size
        raises a shape error at the first linear layer.
        """
        x = self.cnn(x)
        # Flatten each sample on its own. Reshaping with view(-1, 7*7*128) would
        # let a larger input spill spatial positions into the batch dimension
        # and silently return more rows than there are samples.
        x = torch.flatten(x, 1)
        return self.clf(x)

In [4]:
# read the data
# Training images: random resized crop to 224 plus random horizontal flip.
# Validation images: resize to 256, then centre crop to 224.
# Both pipelines finish with ToTensor and the same per-channel normalisation.
train_dataset = datasets.ImageFolder(
    root='data_files/dogs-vs-cats/train',
    transform=transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]),
)
val_dataset = datasets.ImageFolder(
    root='data_files/dogs-vs-cats/val',
    transform=transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]),
)

# Only the training loader shuffles; validation keeps the dataset order.
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset, batch_size=batch_size, shuffle=True
)
val_loader = torch.utils.data.DataLoader(
    dataset=val_dataset, batch_size=batch_size, shuffle=False
)

In [5]:
# train (func)
def train(epoch):
    """Run one optimisation pass over ``train_loader``.

    Works on the notebook-level ``model``, ``criterion`` and ``optimizer``.
    Prints the loss of every batch, then the mean loss over all batches.

    Args:
        epoch: Epoch index; only used in the printed messages.
    """
    model.train()

    total_batches = len(train_loader)
    batch_losses = []

    for step, (inputs, labels) in enumerate(train_loader):
        batch_loss = criterion(model(inputs), labels)

        # Read the scalar once and reuse it for both the log line and the average.
        loss_value = batch_loss.item()
        print('Epoch {}, batch {}/{}, train loss={}'.format(epoch, step, total_batches, loss_value))
        batch_losses.append(loss_value)

        # Clear stale gradients, backpropagate, then update the weights.
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

    print('Epoch {}, avg train loss={}'.format(epoch, np.mean(batch_losses)))
        

In [11]:
# eval (func)
def validate(epoch):
    """Evaluate the notebook-level ``model`` on ``val_loader``.

    Puts the model in eval mode, prints the loss of every validation batch and
    then the mean loss over all batches. No weights are updated.

    Args:
        epoch: Epoch index; only used in the printed messages.
    """
    model.eval()

    losses = []
    batch_num = len(val_loader)

    # Nothing here calls backward(), so disable autograd: otherwise every forward
    # pass records a graph that is built and thrown away, costing memory and time.
    with torch.no_grad():
        for i, (images, targets) in enumerate(val_loader):
            output = model(images)
            loss = criterion(output, targets)

            print('Epoch {}, batch {}/{}, val loss={}'.format(epoch, i, batch_num, loss.item()))
            losses.append(loss.item())

    print('=== Epoch {}, avg val loss={} ==='.format(epoch, np.mean(losses)))

In [7]:
# maybe control learning rate schedule
def change_lr(epoch):
    """Set every optimizer parameter group to the step-decayed rate for ``epoch``.

    The rate is ``lr_init`` scaled by 0.1 once for each full block of
    ``lr_change`` epochs contained in ``epoch``.

    Args:
        epoch: Epoch index the learning rate is computed for.
    """
    completed_steps = epoch // lr_change
    new_lr = lr_init * (0.1 ** completed_steps)
    for group in optimizer.param_groups:
        group['lr'] = new_lr

In [12]:
# create network object, create loss function, optimizer
model = Net()
criterion = nn.CrossEntropyLoss()  # applied to the model's two raw logits
optimizer = optim.SGD(
    params=model.parameters(),
    lr=lr_init,
    momentum=0.9,
    weight_decay=0.01,
)

In [None]:
# training cycle
for epoch in range(epoch_num):
    # Apply the schedule before training on this epoch so the new rate takes
    # effect from the boundary epoch itself. The rate only changes at multiples
    # of lr_change, hence the `== 0` test; change_lr needs the epoch to compute
    # the rate, so it must be passed explicitly.
    if epoch % lr_change == 0:
        change_lr(epoch)

    train(epoch)
    validate(epoch)

    # TODO: save checkpoint
    # TODO: visualize training process

Epoch 0, batch 0/625, train loss=0.6956437230110168
Epoch 0, batch 1/625, train loss=0.6946680545806885
Epoch 0, batch 2/625, train loss=0.6916537284851074
Epoch 0, batch 3/625, train loss=0.7081995010375977
Epoch 0, batch 4/625, train loss=0.704748809337616
Epoch 0, batch 5/625, train loss=0.694390594959259
Epoch 0, batch 6/625, train loss=0.6867038011550903
Epoch 0, batch 7/625, train loss=0.6997222900390625
Epoch 0, batch 8/625, train loss=0.6915119886398315
Epoch 0, batch 9/625, train loss=0.6929004788398743
Epoch 0, batch 10/625, train loss=0.7121329307556152
Epoch 0, batch 11/625, train loss=0.6946157813072205
Epoch 0, batch 12/625, train loss=0.7223030924797058
Epoch 0, batch 13/625, train loss=0.7180026173591614
Epoch 0, batch 14/625, train loss=0.710652768611908
Epoch 0, batch 15/625, train loss=0.6911607980728149
Epoch 0, batch 16/625, train loss=0.6934406757354736
Epoch 0, batch 17/625, train loss=0.6916505694389343
Epoch 0, batch 18/625, train loss=0.7008422613143921
Epoch 

Epoch 0, batch 155/625, train loss=0.6853950619697571
Epoch 0, batch 156/625, train loss=0.669978678226471
Epoch 0, batch 157/625, train loss=0.728427529335022
Epoch 0, batch 158/625, train loss=0.7306479215621948
Epoch 0, batch 159/625, train loss=0.6913806796073914
Epoch 0, batch 160/625, train loss=0.6928365230560303
Epoch 0, batch 161/625, train loss=0.691790759563446
Epoch 0, batch 162/625, train loss=0.7019097805023193
Epoch 0, batch 163/625, train loss=0.6958920955657959
Epoch 0, batch 164/625, train loss=0.6861026883125305
Epoch 0, batch 165/625, train loss=0.6918165683746338
Epoch 0, batch 166/625, train loss=0.6720803380012512
Epoch 0, batch 167/625, train loss=0.6776589751243591
Epoch 0, batch 168/625, train loss=0.7028496861457825
Epoch 0, batch 169/625, train loss=0.70401930809021
Epoch 0, batch 170/625, train loss=0.7040043473243713
Epoch 0, batch 171/625, train loss=0.7379370927810669
Epoch 0, batch 172/625, train loss=0.7280022501945496
Epoch 0, batch 173/625, train los

Epoch 0, batch 308/625, train loss=0.6817317008972168
Epoch 0, batch 309/625, train loss=0.6917080879211426
Epoch 0, batch 310/625, train loss=0.6904004216194153
Epoch 0, batch 311/625, train loss=0.6993986368179321
Epoch 0, batch 312/625, train loss=0.6925867795944214
Epoch 0, batch 313/625, train loss=0.6930519342422485
Epoch 0, batch 314/625, train loss=0.6892719864845276
Epoch 0, batch 315/625, train loss=0.7036823034286499
Epoch 0, batch 316/625, train loss=0.6872789263725281
Epoch 0, batch 317/625, train loss=0.7058111429214478
Epoch 0, batch 318/625, train loss=0.6863223910331726
Epoch 0, batch 319/625, train loss=0.7128068208694458
Epoch 0, batch 320/625, train loss=0.7103201746940613
Epoch 0, batch 321/625, train loss=0.698043167591095
Epoch 0, batch 322/625, train loss=0.6862731575965881
Epoch 0, batch 323/625, train loss=0.692210853099823
Epoch 0, batch 324/625, train loss=0.6965720653533936
Epoch 0, batch 325/625, train loss=0.6988438963890076
Epoch 0, batch 326/625, train 

Epoch 0, batch 461/625, train loss=0.6881989240646362
Epoch 0, batch 462/625, train loss=0.7004512548446655
Epoch 0, batch 463/625, train loss=0.6947242021560669
Epoch 0, batch 464/625, train loss=0.6942763328552246
Epoch 0, batch 465/625, train loss=0.6915109753608704
Epoch 0, batch 466/625, train loss=0.6912045478820801
Epoch 0, batch 467/625, train loss=0.7198827862739563
Epoch 0, batch 468/625, train loss=0.6866614818572998
Epoch 0, batch 469/625, train loss=0.6866626739501953
Epoch 0, batch 470/625, train loss=0.6675612330436707
Epoch 0, batch 471/625, train loss=0.7090970873832703
Epoch 0, batch 472/625, train loss=0.6857061982154846
Epoch 0, batch 473/625, train loss=0.7048378586769104
Epoch 0, batch 474/625, train loss=0.667069137096405
Epoch 0, batch 475/625, train loss=0.6788381338119507
Epoch 0, batch 476/625, train loss=0.6853504776954651
Epoch 0, batch 477/625, train loss=0.7408403158187866
Epoch 0, batch 478/625, train loss=0.7138084173202515
Epoch 0, batch 479/625, train