In [1]:
import torch
import numpy as np
import torchvision
import torch.nn as nn
from torch.autograd import Variable
import os
import os.path
import matplotlib.pyplot as plt

In [2]:
Batch_size = 256
Learning_rate = 0.001
seed = 3407
torch.manual_seed(seed)

<torch._C.Generator at 0x7fcee764ba10>

In [3]:
def try_gpu(i=0):
    if torch.cuda.device_count() >= i + 1:
        return torch.device(f'cuda:{i}')
    return torch.device('cpu')

def save_model(model, path, optimizer=None):
    """
    Saves the state_dict of a torch model and optional optimizer to 'path'
    Returns: None
    """
    state = {"model": model.state_dict()}
    if optimizer is not None:
        state["optimizer"] = optimizer.state_dict()
    torch.save(state, path)


def load_model(model, path, optimizer=None):
    """
    Loads the state_dict of a torch model and optional optimizer from 'path'
    Returns: None
    """
    state = torch.load(path)
    model.load_state_dict(state["model"])
    if optimizer is not None:
        optimizer.load_state_dict(state["optimizer"])

In [4]:
# Data Preprocessing
transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor(), torchvision.transforms.Normalize(mean=(0.1307,), std=(0.3081,))])
transform90 = torchvision.transforms.Compose(
    [torchvision.transforms.ToTensor(), torchvision.transforms.RandomRotation([90,90]),
    torchvision.transforms.Normalize(mean=(0.1307,), std=(0.3081,))])
transform180 = torchvision.transforms.Compose(
    [torchvision.transforms.ToTensor(), torchvision.transforms.RandomRotation([180,180]), 
    torchvision.transforms.Normalize(mean=(0.1307,), std=(0.3081,))])
transform270 = torchvision.transforms.Compose(
    [torchvision.transforms.ToTensor(), torchvision.transforms.RandomRotation([270,270]), 
    torchvision.transforms.Normalize(mean=(0.1307,), std=(0.3081,))])

trainset_0 = torchvision.datasets.MNIST(root='./data',  train=True, download=True, transform=transform)
trainset_0.targets[trainset_0.targets < 11] = 0
trainset90 = torchvision.datasets.MNIST(root='./data',  train=True, download=True, transform=transform90)
trainset90.targets[trainset90.targets < 11] = 1
trainset180 = torchvision.datasets.MNIST(root='./data',  train=True, download=True, transform=transform180)
trainset180.targets[trainset180.targets < 11] = 2
trainset270 = torchvision.datasets.MNIST(root='./data',  train=True, download=True, transform=transform270)
trainset270.targets[trainset270.targets < 11] = 3
trainset = torchvision.datasets.MNIST(root='./data',  train=True, download=True, transform=transform)

rotate_trainset_all = torch.utils.data.ConcatDataset([trainset_0, trainset90, trainset180, trainset270])

rotate_train_loader = torch.utils.data.DataLoader(rotate_trainset_all, batch_size=Batch_size, shuffle=True, num_workers=8, pin_memory=True)
basic_train_loader = torch.utils.data.DataLoader(trainset, batch_size=Batch_size, shuffle=True, num_workers=8, pin_memory=True)
testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(testset, batch_size=Batch_size, shuffle=False, num_workers=8, pin_memory=True)

In [5]:
def imshow(img):
    img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()

classes = ('0', '90', '180', '270')
# get some random training images
dataiter = iter(rotate_train_loader)
images, labels = dataiter.next()

# show images
imshow(torchvision.utils.make_grid(images))
# print labels
print(' '.join(f'{classes[labels[j]]:5s}' for j in range(Batch_size)))

In [6]:
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.extract = nn.Sequential(
            nn.Conv2d(1, 6, 5),
            nn.BatchNorm2d(6),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(6, 16, 5),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Flatten(),
            nn.Linear(16 * 4 * 4, 84),
            nn.ReLU(),
        )
        self.fc3 = nn.Linear(84, 4)

    def forward(self, x):
        x = self.extract(x)
        x = self.fc3(x)
        return x


In [7]:
net = Net()
net.to(device=try_gpu())
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=Learning_rate, weight_decay=1e-5)

In [8]:
n_epoch = 30
pre_loss = []
for epoch in range(n_epoch):  # loop over the dataset multiple times
    running_loss = 0.0
    for i, data in enumerate(rotate_train_loader, 0):
        inputs, labels = data
        inputs = inputs.to(device=try_gpu())
        labels = labels.to(device=try_gpu())
        
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
    print(f'epoch [{epoch} / {n_epoch}] loss: {running_loss / i:.3f}')
    pre_loss.append(running_loss / i)
    running_loss = 0.0
    if epoch % n_epoch == 0:
        save_model(model=net, path=f"./rotnet_epoch_{epoch+1}.ckpt")

np.save("rot_train_loss.npy", pre_loss)
print('Finished Training')

In [10]:
class Classifiernet(nn.Module):
    def __init__(self):
        super().__init__()
        self.extract = nn.Sequential(
            nn.Conv2d(1, 6, 5),
            nn.BatchNorm2d(6),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(6, 16, 5),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)
        )
        self.cur = nn.Sequential(
            nn.Flatten(),
            nn.Linear(16 * 4 * 4, 128),
            nn.ReLU(),
            nn.Linear(128, 10)
        )

    def forward(self, x):
        x = self.extract(x)
        x = self.cur(x)
        return x

classifiernet = Classifiernet()
classifiernet.to(device=try_gpu())

Classifiernet(
  (extract): Sequential(
    (0): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
    (1): BatchNorm2d(6, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
    (5): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): ReLU()
    (7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (cur): Sequential(
    (0): Flatten(start_dim=1, end_dim=-1)
    (1): Linear(in_features=256, out_features=128, bias=True)
    (2): ReLU()
    (3): Linear(in_features=128, out_features=10, bias=True)
  )
)

In [11]:
pre_trained = Net()
pre_trained.to(device=try_gpu())
classifiernet.to(device=try_gpu())
load_model(pre_trained, f"./rotnet_epoch_{n_epoch}.ckpt")
classifier_dict = classifiernet.state_dict()
pre = {k: v for k, v in pre_trained.state_dict().items() if k in classifier_dict}
classifier_dict.update(pre)
classifiernet.load_state_dict(classifier_dict)

<All keys matched successfully>

In [12]:
def train_classifier(net, train_iter, criterion, optimizer):
    if isinstance(net, nn.Module):
        net.train()
    loss_record = 0
    total = 0
    for _, data in enumerate (train_iter, 0):
        img, labels = data
        img = Variable(img)
        label = Variable(labels)
        img = img.to(device=try_gpu())
        label = label.to(device=try_gpu())
        output = net(img)
        loss = criterion(output, label)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loss_record += loss
        total += 1

    return loss_record / total

def eval_accuracy(net, data_iter):
    if isinstance(net, nn.Module):
        net.eval()
    with torch.no_grad():
        correct = 0
        total = 0
        for data, target in data_iter:
            eval_imgs = Variable(data).to(device=try_gpu())
            target = Variable(target).to(device=try_gpu())
            output = net(eval_imgs)
            pred = output.max(dim=1)[1]
            correct += (pred == target).sum().item()
            total += target.size(0)
        return correct / total

In [13]:
for name, param in classifiernet.named_parameters():
    if param.requires_grad and "extract" in name:
        param.requires_grad = False

classifiernet.to(device=try_gpu())
num_epochs = 10
Learning_rate = 0.001
classifier_criterion = nn.CrossEntropyLoss().to(device=try_gpu())
classifier_optimizer = torch.optim.Adam(classifiernet.parameters(), lr=Learning_rate, weight_decay=1e-5)

train_loss = []
val_acc =[]
for epoch in range(num_epochs):
    epoch_train_loss = train_classifier(classifiernet, basic_train_loader, classifier_criterion, classifier_optimizer)
    epoch_val_acc = eval_accuracy(classifiernet, test_loader, classifier_criterion)
    print('epoch [{}/{}], train_loss:{:.4f}'.format(epoch+1, num_epochs, epoch_train_loss.item()))    
    print('epoch [{}/{}], val_acc:{:.4f}'.format(epoch+1, num_epochs, epoch_val_acc))
    train_loss.append(epoch_train_loss.item())
    val_acc.append(epoch_val_acc)

np.save("rotateclassvalaccfix.npy", val_acc)

epoch [1/10], train_loss:0.4677
epoch [1/10], val_acc:0.9405
epoch [2/10], train_loss:0.1965
epoch [2/10], val_acc:0.9514
epoch [3/10], train_loss:0.1626
epoch [3/10], val_acc:0.9557
epoch [4/10], train_loss:0.1434
epoch [4/10], val_acc:0.9610
epoch [5/10], train_loss:0.1281
epoch [5/10], val_acc:0.9616
epoch [6/10], train_loss:0.1167
epoch [6/10], val_acc:0.9656
epoch [7/10], train_loss:0.1083
epoch [7/10], val_acc:0.9675
epoch [8/10], train_loss:0.0990
epoch [8/10], val_acc:0.9671
epoch [9/10], train_loss:0.0951
epoch [9/10], val_acc:0.9688
epoch [10/10], train_loss:0.0868
epoch [10/10], val_acc:0.9694
