In [None]:
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.nn import CrossEntropyLoss
from torchvision.datasets import FashionMNIST
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor, Normalize, Compose, Resize
import matplotlib.pyplot as plt

## Settings

In [None]:
# Pick the compute device once; every later cell moves tensors/modules here.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
print('use', device)

In [None]:
# Hyper-parameters
epochs = 10
batch_size = 256
# Seed torch's global RNG so weight initialisation and any DataLoader shuffling
# repeat across Restart & Run All (some GPU kernels may still be nondeterministic).
seed = 0
torch.manual_seed(seed);

In [None]:
# ToTensor scales pixels to [0, 1] and Resize upsamples 28x28 -> 96x96.
# Normalize standardises with the (approximate) FashionMNIST training-set
# mean/std; Normalize(0, 1) would compute (x - 0) / 1, i.e. do nothing.
transform = Compose([ToTensor(), Resize((96, 96)), Normalize((0.2860,), (0.3530,))])

In [None]:
train_data = FashionMNIST('data', train=True, transform=transform, download=True)
test_data = FashionMNIST('data', train=False, transform=transform, download=True)

# Reshuffle the training set every epoch so batches are not always drawn in the
# same order; evaluation order does not matter. Pinned host memory only helps
# host->GPU copies, so enable it only when training on CUDA.
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True,
                          pin_memory=(device == 'cuda'))
test_loader = DataLoader(test_data, batch_size=batch_size,
                         pin_memory=(device == 'cuda'))

## Build ResNet

In [None]:
class ResidualBlock(nn.Module):
    """ResNet basic block: two 3x3 conv + BN layers plus an identity or 1x1 shortcut.

    Args:
        num_channel: number of output channels of both 3x3 convolutions.
        protect: if True, route the shortcut through a 1x1 convolution even when
            ``strides == 1``. This is needed whenever the input channel count
            differs from ``num_channel``, since the two branches are added.
        strides: stride of the first 3x3 conv and of the 1x1 shortcut. Any value
            > 1 forces the 1x1 shortcut so both branches have matching shapes.
    """
    def __init__(self, num_channel, protect=False, strides=1):
        super().__init__()

        self.need_protect = protect or strides > 1

        self.cnn_1 = nn.LazyConv2d(num_channel, kernel_size=3, padding=1, stride=strides)
        self.bn_1 = nn.BatchNorm2d(num_channel)
        self.cnn_2 = nn.LazyConv2d(num_channel, kernel_size=3, padding=1)
        self.bn_2 = nn.BatchNorm2d(num_channel)
        self.relu = nn.ReLU()

        if self.need_protect:
            # 1x1 projection shortcut. The attribute name "protect" is kept so
            # saved state_dict keys stay compatible.
            self.protect = nn.LazyConv2d(num_channel, kernel_size=1, stride=strides)

    def forward(self, X):
        Y = self.relu(self.bn_1(self.cnn_1(X)))
        # No ReLU before the addition. The residual branch must be able to go
        # negative; the nonlinearity is applied once, after adding the shortcut
        # (conv-BN-ReLU-conv-BN, add, ReLU).
        Y = self.bn_2(self.cnn_2(Y))
        shortcut = self.protect(X) if self.need_protect else X
        return self.relu(Y + shortcut)

### Build My Model

In [None]:
class Model(nn.Module):
    """ResNet-style classifier: conv stem, four residual stages, 10-way linear head.

    The stem and head use lazy layers, so the input channel count and the
    flattened feature size are inferred from the first batch passed through.
    """
    def __init__(self):
        super().__init__()
        # (channels, stride) for each residual stage; every stage is built with
        # protect=True, i.e. a 1x1 projection shortcut.
        stages = ((64, 1), (128, 2), (256, 2), (512, 2))
        self.net = nn.Sequential(
            *(ResidualBlock(channels, protect=True, strides=stride)
              for channels, stride in stages)
        )
        # Stem: 7x7/2 conv -> BN -> ReLU -> 3x3/2 max-pool.
        self.start = nn.Sequential(
            nn.LazyConv2d(64, kernel_size=7, stride=2, padding=3),
            nn.LazyBatchNorm2d(),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        )
        # Head: global average pool -> flatten -> 10 logits.
        self.last = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.LazyLinear(10),
        )

    def forward(self, X):
        features = self.net(self.start(X))
        return self.last(features)

In [None]:
model = Model().to(device)
# The lazy layers only receive real parameter tensors on their first forward
# pass. Do that dry run now, before building the optimizer, in eval mode and
# without gradients so BatchNorm running statistics are left untouched.
# The dummy batch matches the transform's 1x96x96 output.
model.eval()
with torch.no_grad():
    model(torch.zeros(1, 1, 96, 96, device=device))
loss = CrossEntropyLoss()
optimizer = Adam(model.parameters())
# TODO: add a learning-rate scheduler

In [None]:
from torchsummary import summary
# Layer-by-layer output shapes and parameter counts for one 1x96x96 input,
# run on the same device the model lives on.
summary(model, (1, 96, 96), device=device)

## Start Training

In [None]:
# Per-epoch metric log filled by record(): training and validation accuracy/loss.
history = {key: [] for key in ('acc', 'loss', 'val_acc', 'val_loss')}

In [None]:
def record(acc, loss, val_acc, val_loss, needPrint=False):
    """Append one epoch's metrics to the global ``history`` dict.

    Args:
        acc: training accuracy.
        loss: mean training loss.
        val_acc: validation accuracy.
        val_loss: mean validation loss.
        needPrint: if True, also print both metric pairs.
    """
    entries = (('acc', acc), ('val_acc', val_acc),
               ('loss', loss), ('val_loss', val_loss))
    for key, value in entries:
        history[key].append(value)

    if not needPrint:
        return
    print(f'Training acc {acc:.4f}, loss {loss:.4f}')
    print(f'Test acc {val_acc:.4f}, loss {val_loss:.4f}')

In [None]:
# Early-stopping state shared across calls; re-run this cell to reset it.
best_val_acc = -1  # best validation accuracy seen so far
count = 0          # consecutive epochs without improving on best_val_acc
def EarlyStopping(patience=5, path='model.pt'):
    """Checkpoint the model on improvement and report whether to stop training.

    Compares the latest ``history['val_acc']`` entry with the best value seen
    so far. On improvement, the model's state_dict is saved to ``path`` and the
    no-improvement counter is reset. Otherwise the counter is incremented.

    Args:
        patience: number of consecutive non-improving epochs that triggers a stop.
        path: file the best model's state_dict is written to.

    Returns:
        True once ``patience`` consecutive epochs failed to improve, else False.
    """
    # Both names are rebound below; without `global` Python treats them as
    # locals and the first read raises UnboundLocalError.
    global best_val_acc, count
    current = history['val_acc'][-1]
    if current > best_val_acc:
        best_val_acc = current
        count = 0
        torch.save(model.state_dict(), path)
        print('\033[33m' + 'Save' + '\033[0m')
    else:
        count += 1
        print('Not improve')

    return count >= patience

In [None]:
def run_epoch(loader, train):
    """Make one pass over ``loader``; return (accuracy, mean per-batch loss).

    With ``train=True`` the model is in training mode and the optimizer steps
    after every batch. With ``train=False`` the model is in eval mode and
    gradients are disabled.
    """
    model.train(train)
    correct = 0
    total_loss = 0.0
    with torch.set_grad_enabled(train):
        for img, label in loader:
            img, label = img.to(device), label.to(device)
            output = model(img)
            output_loss = loss(output, label)
            if train:
                optimizer.zero_grad()
                output_loss.backward()
                optimizer.step()
            correct += (output.argmax(dim=1) == label).sum().item()
            total_loss += output_loss.item()
    # Accuracy is per sample; loss is the mean of the per-batch mean losses.
    return correct / len(loader.dataset), total_loss / len(loader)


for i in range(epochs):
    print(f'Epoch {i + 1} Start')

    train_acc, train_loss = run_epoch(train_loader, train=True)
    test_acc, test_loss = run_epoch(test_loader, train=False)

    # record() takes (train_acc, train_loss, val_acc, val_loss). The test split
    # is the validation set here, so early stopping watches held-out accuracy.
    record(train_acc, train_loss, test_acc, test_loss, True)
    if EarlyStopping(5): break
    print('==================================')

## Plot Outcome

In [None]:
def plot(name):
    """Plot the train and validation curves for metric ``name`` against epoch.

    Args:
        name: a key of ``history`` such as 'acc' or 'loss'. The validation
            curve is read from ``history['val_' + name]``.
    """
    train_curve = history[name]
    val_curve = history['val_' + name]
    fig, ax = plt.subplots()
    # Epochs are numbered from 1 to match the "Epoch N Start" log lines.
    ax.plot(range(1, len(train_curve) + 1), train_curve, 'b', label='train')
    ax.plot(range(1, len(val_curve) + 1), val_curve, 'r', label='val')
    ax.set_title('Performance')
    ax.set_xlabel('Epoch')
    ax.set_ylabel(name)
    ax.legend()  # without this the 'train'/'val' labels are never shown
    plt.show()

In [None]:
plot(name='acc')

In [None]:
plot(name='loss')