In [13]:
import torch, os
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
import torch.utils.data as data
import torchvision.transforms as transforms
import torchvision.models as models
from train_model import train_model
from test_model import test_model
%matplotlib inline

In [14]:
# --- Data pipeline -----------------------------------------------------------
# Root folder; each split is read from data_dir/<split> via ImageFolder.
data_dir = 'tiny-images/'
splits = ['train', 'val', 'test']
num_workers = {'train': 4, 'val': 0, 'test': 0}

# Per-channel normalisation constants shared by every split.
norm_mean = [0.4802, 0.4481, 0.3975]
norm_std = [0.2302, 0.2265, 0.2262]

# Deterministic pipeline used for evaluation ('val' and 'test'): no augmentation.
eval_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(224),
    transforms.Normalize(norm_mean, norm_std),
])

data_transforms = {
    # Training adds random rotation / horizontal flip before the shared steps.
    'train': transforms.Compose([
        transforms.RandomRotation(20),
        transforms.RandomHorizontalFlip(0.5),
        transforms.ToTensor(),
        transforms.Resize(224),
        transforms.Normalize(norm_mean, norm_std),
    ]),
    'val': eval_transform,
    'test': eval_transform,
}

image_datasets = {
    split: datasets.ImageFolder(os.path.join(data_dir, split), data_transforms[split])
    for split in splits
}

# ImageFolder enumerates samples class by class, so an unshuffled training
# loader would feed batches made of a single class. Shuffle the training split
# only; evaluation order is left fixed so per-batch logs stay reproducible.
dataloaders = {
    split: data.DataLoader(
        image_datasets[split],
        batch_size=100,
        shuffle=(split == 'train'),
        num_workers=num_workers[split],
    )
    for split in splits
}

dataset_sizes = {split: len(image_datasets[split]) for split in splits}

In [15]:
# --- Evaluate a saved ResNet-18 checkpoint on the validation split -----------
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = models.resnet18()
model.avgpool = nn.AdaptiveAvgPool2d(1)

# Read the checkpoint first so the classifier head can be sized from it.
state_dict = torch.load('models/224/model_10_epoch.pt', map_location=device)

# Setting `model.fc.out_features = 200` only overwrites an attribute: the
# layer's weight/bias tensors keep their constructed shape, so the head is not
# actually resized. Build a real nn.Linear whose width matches the saved
# `fc.weight`, which keeps existing checkpoints loadable and lets a properly
# sized (e.g. 200-way) checkpoint load as well.
# NOTE(review): if this reports more outputs than the dataset has classes, the
# checkpoint was trained with an oversized head — consider retraining with
# nn.Linear(in_features, <num_classes>).
num_outputs = state_dict['fc.weight'].shape[0]
model.fc = nn.Linear(model.fc.in_features, num_outputs)

model.load_state_dict(state_dict)
model = model.to(device)

criterion = nn.CrossEntropyLoss()

# test_model takes an optimizer argument, so one is constructed here even
# though only the 'val' phase is requested.
# NOTE(review): confirm test_model puts the model in eval mode for 'val'.
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

test_model(model, dataloaders, dataset_sizes, criterion, optimizer, phases=['val'])

Iteration: 1/100, Loss: 44.255492091178894.Running corrects: 87/100

Iteration: 2/100, Loss: 200.26464462280273.Running corrects: 42/100

Iteration: 3/100, Loss: 195.12858390808105.Running corrects: 50/100

Iteration: 4/100, Loss: 108.98988246917725.Running corrects: 78/100

Iteration: 5/100, Loss: 113.1702184677124.Running corrects: 75/100

Iteration: 6/100, Loss: 111.78959608078003.Running corrects: 70/100

Iteration: 7/100, Loss: 95.64816355705261.Running corrects: 78/100

Iteration: 8/100, Loss: 128.11332941055298.Running corrects: 65/100

Iteration: 9/100, Loss: 113.72027397155762.Running corrects: 70/100

Iteration: 10/100, Loss: 146.6173529624939.Running corrects: 62/100

Iteration: 11/100, Loss: 108.71661901473999.Running corrects: 71/100

Iteration: 12/100, Loss: 57.514458894729614.Running corrects: 83/100

Iteration: 13/100, Loss: 96.09923362731934.Running corrects: 77/100

Iteration: 14/100, Loss: 141.75066947937012.Running corrects: 64/100

Iteration: 15/100, Loss: 154.6946