In [13]:
import torch 
import torch.nn as nn
import torchvision
import numpy as np
from torchvision.transforms import transforms
from torch import optim
from torch.optim import lr_scheduler
import time
import copy

In [14]:
# Load a ResNet-50 backbone initialised with ImageNet weights.
# `pretrained` is a boolean flag (any truthy value enables the weights), so pass
# True explicitly rather than a string such as 'imagenet'.
# NOTE(review): newer torchvision releases deprecate `pretrained` in favour of
# `weights=` -- confirm against the installed version before migrating.
net = torchvision.models.resnet50(pretrained=True)

In [15]:
# Fine-tune the whole network: flag every parameter as trainable.
for params in net.parameters():
    params.requires_grad_(True)

In [16]:

# Replace the classification head with a fresh 10-way linear layer that keeps
# the input width of the layer it replaces.
net.fc = nn.Linear(in_features=net.fc.in_features, out_features=10)

In [20]:
# Wrap the network in DataParallel, then move the wrapper onto CUDA.
model = torch.nn.DataParallel(net)
model = model.to('cuda')

In [None]:
#net = torch.nn.DataParallel(model)

In [None]:
# Loss for multi-class classification: takes raw logits and integer class labels.
criterion = torch.nn.CrossEntropyLoss()

In [None]:
# SGD with momentum over every parameter of the network.
optimizer = optim.SGD(
    params=net.parameters(),
    lr=1e-3,
    momentum=0.9,
)
# Step schedule: multiply the learning rate by 0.1 every 7 scheduler steps.
exp_lr_scheduler = lr_scheduler.StepLR(
    optimizer=optimizer,
    step_size=7,
    gamma=0.1,
)

In [None]:
# Per-split preprocessing pipelines, keyed by phase name.
#   train: random 224x224 resized crop + random horizontal flip (augmentation)
#   val:   deterministic resize to 224
# Both splits are then converted to tensors and normalised per channel with the
# same mean / std triples.
data_trainsformation = dict(
    train=transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ]),
    val=transforms.Compose([
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ]),
)

In [None]:
# CIFAR-10 train / test splits, downloaded into ./data on first use and paired
# with the matching preprocessing pipeline.
trainset = torchvision.datasets.CIFAR10(
    root='data',
    train=True,
    download=True,
    transform=data_trainsformation['train'],
)
testset = torchvision.datasets.CIFAR10(
    root='data',
    train=False,
    download=True,
    transform=data_trainsformation['val'],
)

# One loader per phase: the training loader shuffles, the validation loader does not.
dataloader = dict(
    train=torch.utils.data.DataLoader(
        trainset, batch_size=128, shuffle=True, num_workers=2,
    ),
    val=torch.utils.data.DataLoader(
        testset, batch_size=100, shuffle=False, num_workers=2,
    ),
)

# Sample counts per phase, used to average the running loss / accuracy.
datasizes = {
    split: len(dataset)
    for split, dataset in (('train', trainset), ('val', testset))
}

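# NOTE: the list below is the CIFAR-100 fine-label set, not the labels of the
# CIFAR-10 dataset loaded above. The ten CIFAR-10 classes, in label order, are:
# airplane, automobile, bird, cat, deer, dog, frog, horse, ship, truck.
# CIFAR-100 classes = (beaver, dolphin, otter, seal, whale, 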
# aquarium fish, flatfish, ray, shark, trout, 
# orchids, poppies, roses, sunflowers, tulips, 
# bottles, bowls, cans, cups, plates, 
# apples, mushrooms, oranges, pears, sweet peppers, 
# clock, computer keyboard, lamp, telephone, television, 
# bed, chair, couch, table, wardrobe, 
# bee, beetle, butterfly, caterpillar, cockroach, 
# bear, leopard, lion, tiger, wolf, 
# bridge, castle, house, road, skyscraper, 
# cloud, forest, mountain, plain, sea, 
# camel, cattle, chimpanzee, elephant, kangaroo, 
# fox, porcupine, possum, raccoon, skunk, 
# crab, lobster, snail, spider, worm, 
# baby, boy, girl, man, woman, 
# crocodile, dinosaur, lizard, snake, turtle, 
# hamster, mouse, rabbit, shrew, squirrel, 
# maple, oak, palm, pine, willow, 
# bicycle, bus, motorcycle, pickup truck, train, 
# lawn-mower, rocket, streetcar, tank, tractor

In [18]:
def train_model(model,criterian,optimizer,scheduler,num_epochs=15,
                dataloaders=None,dataset_sizes=None):
    """Train ``model`` and return it loaded with its best validation weights.

    Every epoch runs a ``'train'`` phase followed by a ``'val'`` phase. The
    weights with the highest validation accuracy seen so far are snapshotted
    and restored into ``model`` before returning.

    Args:
        model: ``torch.nn.Module`` to optimise. Batches are moved to the device
            of its first parameter before the forward pass.
        criterian: Loss callable invoked as ``criterian(outputs, labels)``.
            (The parameter name is kept as-is for backward compatibility.)
        optimizer: Optimizer that updates ``model``'s parameters.
        scheduler: Learning-rate scheduler, stepped once per epoch after the
            training phase. ``None`` disables scheduling.
        num_epochs: Number of train + val passes to run.
        dataloaders: Optional mapping ``{'train': iterable, 'val': iterable}``
            yielding ``(inputs, labels)`` batches. Defaults to the
            notebook-level ``dataloader``.
        dataset_sizes: Optional mapping ``{'train': int, 'val': int}`` with the
            number of samples per phase. Defaults to the notebook-level
            ``datasizes``.

    Returns:
        The same ``model`` object, with the best validation weights loaded.
    """
    # Fall back to the notebook globals so existing call sites keep working.
    if dataloaders is None:
        dataloaders = dataloader
    if dataset_sizes is None:
        dataset_sizes = datasizes

    # Follow the model's own device rather than hard-coding CUDA.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else 'cuda'

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        since = time.time()

        for phase in ['train', 'val']:
            is_training = phase == 'train'
            if is_training:
                model.train()
            else:
                model.eval()

            running_loss = 0.0
            running_corrects = 0

            for inputs, labels in dataloaders[phase]:
                inputs, labels = inputs.to(device), labels.to(device)
                optimizer.zero_grad()

                # Only build the autograd graph while training.
                with torch.set_grad_enabled(is_training):
                    outputs = model(inputs)
                    _, pred = torch.max(outputs, 1)
                    # Use the loss that was passed in, not a module-level global.
                    loss = criterian(outputs, labels)
                    if is_training:
                        loss.backward()
                        optimizer.step()

                # loss is a batch mean; re-weight by batch size so the epoch
                # average is correct even when the last batch is smaller.
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(pred == labels).item()

            # Advance the LR schedule once per epoch, after the optimizer steps.
            if is_training and scheduler is not None:
                scheduler.step()

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects / dataset_sizes[phase]

            # Report metrics for both phases, not just the last one.
            print('Epoch: {} | {} Loss: {:.4f} Acc: {:.4f}'.format(epoch+1,phase, epoch_loss, epoch_acc))

            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

        time_elapsed = time.time() - since
        print('This epoch complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))

    model.load_state_dict(best_model_wts)
    return model

        

In [None]:
# Run the 15-epoch fine-tuning job and keep the model that train_model returns.
best_model_ = train_model(
    model,
    criterion,
    optimizer,
    exp_lr_scheduler,
    num_epochs=15,
)