In [3]:
import os

import matplotlib.pyplot as plt
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image as PImage
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

# Currently unused (kept for reference):
# from __future__ import print_function
# from torchvision import datasets
# from skimage.io import imread
# from io import open
# import glob
# from os import listdir
# import matplotlib.image as img
# import importlib
# from sam.sam import *
# import time

12

0.13.0


In [None]:
class ResNetBasicBlock(nn.Module):
    """Two stacked convolutions: conv1 -> BatchNorm -> ReLU -> conv2.

    The output of ``conv2`` is returned as-is (no normalisation or activation
    after it), so the caller is responsible for adding a shortcut and
    normalising.

    Args:
        in_ch: channels of the input tensor.
        out_ch: channels produced by both convolutions.
        kernel_size: kernel size shared by both convolutions.
        stride: stride of the first convolution (the second always uses 1).
        padding: padding shared by both convolutions.
    """

    def __init__(self, in_ch, out_ch, kernel_size, stride, padding):
        super().__init__()
        # Registration order (conv1, bn, conv2) fixes state_dict key order.
        self.conv1 = nn.Conv2d(in_ch, out_ch, kernel_size, stride, padding)
        self.bn = nn.BatchNorm2d(out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, kernel_size, 1, padding)

        # He-normal (fan_out) init for the conv weights; biases keep PyTorch's
        # default initialisation.
        for conv in (self.conv1, self.conv2):
            nn.init.kaiming_normal_(conv.weight, mode='fan_out', nonlinearity='relu')

    def forward(self, x):
        hidden = F.relu(self.bn(self.conv1(x)))
        return self.conv2(hidden)

class ResNetBlock(nn.Module):
    """A stage of three residual units built from ``ResNetBasicBlock``.

    When ``in_ch != out_ch`` the first unit uses stride 2, and the identity
    shortcut is replaced by a strided 1x1 projection followed by batch norm
    (``proj`` -> ``bn_proj``) so the shapes match. Each unit computes
    ``BN(block(x) + shortcut)``; ReLU follows the first two units but not the
    last one, so callers apply their own activation to the stage output.

    Args:
        in_ch: channels of the input tensor.
        out_ch: channels of the output tensor.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.in_ch = in_ch
        self.out_ch = out_ch
        # Projection shortcut, only used when the channel count changes. It is
        # created unconditionally so state_dict keys stay the same for every
        # configuration.
        self.proj = nn.Conv2d(in_ch, out_ch, 1, 2)
        self.bn_proj = nn.BatchNorm2d(out_ch)

        first_stride = 1 if in_ch == out_ch else 2
        self.block1 = ResNetBasicBlock(in_ch, out_ch, 3, first_stride, 1)
        self.bn1 = nn.BatchNorm2d(out_ch)
        self.block2 = ResNetBasicBlock(out_ch, out_ch, 3, 1, 1)
        self.bn2 = nn.BatchNorm2d(out_ch)
        self.block3 = ResNetBasicBlock(out_ch, out_ch, 3, 1, 1)
        self.bn3 = nn.BatchNorm2d(out_ch)

    def forward(self, x):
        if self.in_ch == self.out_ch:
            shortcut = x
        else:
            # Normalise the projected shortcut; bn_proj was previously defined
            # but never applied.
            shortcut = self.bn_proj(self.proj(x))

        # Out-of-place adds: shortcut tensors are never mutated, so no clone()
        # is needed to protect them.
        out = F.relu(self.bn1(self.block1(x) + shortcut))
        out = F.relu(self.bn2(self.block2(out) + out))
        return self.bn3(self.block3(out) + out)

class ResNet(nn.Module):
    """Small ResNet: stem conv, three residual stages, global average pool, linear head.

    Channel widths are 16 -> 16 -> 32 -> 64; the second and third stages halve
    the spatial size. The head outputs unnormalised logits.

    Args:
        num_classes: number of output logits (default 10, the original value).
        in_channels: channels of the input images (default 3, the original value).
    """

    def __init__(self, num_classes=10, in_channels=3):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, 16, 3, 1, 1)
        self.bn0 = nn.BatchNorm2d(16)
        self.resBlock1 = ResNetBlock(16, 16)
        self.resBlock2 = ResNetBlock(16, 32)
        self.resBlock3 = ResNetBlock(32, 64)
        # Global average pooling. For 32x32 inputs the final feature map is
        # 8x8, so this equals the former AvgPool2d(8); unlike it, it also
        # yields one vector per sample for other input sizes.
        self.avgPool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(64, num_classes)

    def forward(self, x):
        x = F.relu(self.bn0(self.conv1(x)))

        x = F.relu(self.resBlock1(x))
        x = F.relu(self.resBlock2(x))
        x = F.relu(self.resBlock3(x))

        x = self.avgPool(x)
        # Keep the batch dimension explicit; view(-1, 64) would silently fold
        # any leftover spatial positions into the batch.
        x = torch.flatten(x, 1)
        return self.fc(x)

In [None]:
def train(model, device, train_loader, optimizer, cross_entropy, epoch, isSam):
    """Run one training epoch and return the mean per-sample training loss.

    Args:
        model: network to optimise; switched to training mode here.
        device: device each batch is moved to.
        train_loader: iterable of ``(data, target)`` batches supporting
            ``len()`` and exposing ``.dataset`` (used only for progress logs).
        optimizer: optimizer stepped once per batch.
        cross_entropy: loss callable ``(output, target) -> scalar tensor``;
            assumed to use mean reduction, so ``loss * batch_size`` is the
            batch's summed loss.
        epoch: epoch number, used only in the progress log.
        isSam: accepted for interface compatibility but currently ignored;
            the SAM two-step update is disabled and ``optimizer.step()`` is
            always used.

    Returns:
        float: sample-weighted average loss over the epoch, directly
        comparable with ``validate``'s average loss. NaN if the loader
        yields no samples.
    """
    model.train()
    train_loss = 0.0
    n_samples = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = cross_entropy(output, target)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        batch_size = data.size(0)
        train_loss += batch_loss * batch_size
        n_samples += batch_size

        if batch_idx % 20 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), batch_loss))

    if n_samples == 0:
        return float('nan')
    return train_loss / n_samples
       
def validate(model, device, validation_loader, cross_entropy):
    """Evaluate ``model`` on ``validation_loader``, print a summary, return the mean loss.

    Each batch's loss is weighted by its batch size (``cross_entropy`` is
    assumed to use mean reduction) and the sum is divided by the dataset
    length. Top-1 accuracy is printed alongside the loss.
    """
    model.eval()
    total_loss = 0
    n_correct = 0
    with torch.no_grad():
        for inputs, labels in validation_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            logits = model(inputs)
            total_loss += cross_entropy(logits, labels).item() * inputs.size(0)
            # Predicted class = index of the largest logit.
            predicted = logits.argmax(dim=1, keepdim=True)
            n_correct += predicted.eq(labels.view_as(predicted)).sum().item()

    n_total = len(validation_loader.dataset)
    mean_loss = total_loss / n_total

    summary = '\nValidation set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'
    print(summary.format(mean_loss, n_correct, n_total, 100. * n_correct / n_total))
    return mean_loss

def test(model, device, test_loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        print('Test set: Accuracy: {}/{} ({:.0f}%)'.format(
            correct, total, 100. * correct / total))

In [None]:
# Training run: fit ResNet with Adam, plot loss curves, save and evaluate.
SEED = 0          # seed for reproducible init / data shuffling
NUM_EPOCHS = 50

torch.manual_seed(SEED)

use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

# NOTE(review): loadDatabase is defined outside this view; assumed to return
# (train_loader, validation_loader, test_loader) — confirm.
train_loader, validation_loader, test_loader = loadDatabase(isData1=True)

model = ResNet().to(device)

cross_entropy = nn.CrossEntropyLoss()

adam_optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0.0001)

counter, train_losses, valid_losses = [], [], []

for epoch in range(1, NUM_EPOCHS + 1):
    train_losses.append(
        train(model, device, train_loader, adam_optimizer, cross_entropy, epoch, False))
    valid_losses.append(validate(model, device, validation_loader, cross_entropy))
    counter.append(epoch)

fig, ax = plt.subplots(figsize=(9, 6))
ax.plot(counter, train_losses, "r", label="Train loss")
ax.plot(counter, valid_losses, "b", label="Validation loss")
ax.set(title="Loss", xlabel="Number of Epochs", ylabel="Loss")
# Without a legend the labelled curves cannot be told apart.
ax.legend()
plt.show()

torch.save(model.state_dict(), 'model1.ckpt')

test(model, device, test_loader)