In [1]:
import torch, torchvision
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torchvision.datasets as datasets
import torch.utils.data as data
import torchvision.transforms as transforms
from torch.autograd import Variable
import torchvision.models as models
import matplotlib.pyplot as plt
import time, os, copy, numpy as np
from utils.train_model import train_model
from torch.utils.tensorboard import SummaryWriter
from sgdhess import SGDHess
%matplotlib inline

In [2]:
%load_ext tensorboard

In [3]:
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ]),
    'val': transforms.Compose([
        transforms.ToTensor(),
    ]),
}

data_dir = 'tiny-imagenet-200'
torch.cuda.empty_cache()

image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
                                          data_transforms[x])
                  for x in ['train', 'val']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=64,
                                             shuffle=True, num_workers=4)
              for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
# class_names = image_datasets['train'].classes



In [4]:
model_ft = models.resnet18()
model_ft.avgpool = nn.AdaptiveAvgPool2d(1)
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(num_ftrs, 200)
device = torch.device('cuda')
model_ft = model_ft.to(device)

criterion = nn.CrossEntropyLoss()
optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)
# optimizer_ft = SGDHess(model_ft.parameters(), lr = 0.15, weight_decay = 1e-4)

scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer_ft, [80, 120])

In [None]:
%%time
model_ft = train_model(model_ft, dataloaders, dataset_sizes, criterion, optimizer_ft, scheduler,
                       num_epochs=15, create_graph=False) # set to true for sgdhess

Epoch 1/15
----------




Iteration: 157/157, Loss: 98.77537536621094....Train Loss: 4.9561 Acc: 0.0372
Val Loss: 6.2761 Acc: 0.0081
Best Val Accuracy: 0.0081

Epoch 2/15
----------
Iteration: 157/157, Loss: 128.69920349121094.5.Train Loss: 4.2993 Acc: 0.1059
Val Loss: 8.3084 Acc: 0.0028
Best Val Accuracy: 0.0081

Epoch 3/15
----------
Iteration: 157/157, Loss: 122.6227798461914..4.Train Loss: 3.8474 Acc: 0.1654
Val Loss: 8.1727 Acc: 0.0041
Best Val Accuracy: 0.0081

Epoch 4/15
----------
Iteration: 8/1563, Loss: 226.48585510253906.

In [None]:
%tensorboard --logdir=runs