# Import

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
from torchsummary import summary
from torchvision.datasets import ImageFolder

from thop import profile

import matplotlib.pyplot as plt

# Prefer the first CUDA GPU; fall back to the CPU when CUDA is unavailable.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Data

In [None]:
# Training pipeline: tensor conversion, resize to 128x128, then random
# geometric and photometric augmentations (applied in this order).
training_transforms = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Resize(size=(128, 128), antialias=True),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomAffine(
            degrees=20,
            translate=(0.1, 0.1),
            scale=(0.9, 1.1),
            shear=0.1,
        ),
        transforms.RandomPerspective(distortion_scale=0.1, p=0.5),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
        transforms.RandomGrayscale(p=0.1),
    ]
)

# Evaluation pipeline: the same tensor conversion and resize, no augmentation.
testing_transforms = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Resize(size=(128, 128), antialias=True),
    ]
)

# Image datasets read from class-per-subfolder directories via ImageFolder.
# NOTE(review): the model emits a single sigmoid output, so presumably each
# directory holds exactly two class subfolders — confirm.
batch_size = 64
train_dataset = ImageFolder(root='./dataset/train', transform=training_transforms)
# Only the training data is shuffled. Accuracy does not depend on sample order,
# so the test loader iterates in a fixed, reproducible order.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
test_dataset = ImageFolder(root='./dataset/test', transform=testing_transforms)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

# Model

In [None]:
def DepthwiseSeparableConv(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=True):
    """Build a depthwise-separable 2-D convolution.

    The result is an ``nn.Sequential`` of two ``nn.Conv2d`` layers:
    a per-channel (depthwise, ``groups=in_channels``) ``kernel_size`` conv,
    followed by a 1x1 pointwise conv that mixes channels.

    Args:
        in_channels: number of input channels.
        out_channels: number of channels produced by the pointwise conv.
        kernel_size: spatial size of the depthwise kernel.
        stride: stride of the depthwise conv.
        padding: zero-padding of the depthwise conv.
        bias: whether both convs carry a learnable bias.
    """
    return nn.Sequential(
        nn.Conv2d(in_channels, in_channels, kernel_size, stride, padding, groups=in_channels, bias=bias),
        nn.Conv2d(in_channels, out_channels, 1, bias=bias),
    )

def conv(in_channels, out_channels, kernel_size=3, bias=True):
    """Project-wide conv factory: a stride-1 depthwise-separable convolution.

    Padding is ``kernel_size // 2``, which keeps H and W unchanged for odd
    kernel sizes. For the default ``kernel_size=3`` this is the same padding of 1
    used before. ``bias`` is forwarded to both underlying convolutions.
    """
    return DepthwiseSeparableConv(
        in_channels, out_channels, kernel_size, padding=kernel_size // 2, bias=bias
    )

def act():
    """Return the network's activation: LeakyReLU with its default 0.01 slope."""
    return nn.LeakyReLU(negative_slope=0.01)

def bn(channels):
    """Return a 2-D batch-norm layer over ``channels`` feature maps."""
    return nn.BatchNorm2d(num_features=channels)

def pool():
    """Return a 2x2 max-pool with stride 2 (halves H and W, flooring)."""
    return nn.MaxPool2d(kernel_size=2, stride=2)

class ResBlock(nn.Module):
    """Residual block: two 3x3 convs with dropout between them.

    Main path: conv -> BN -> LeakyReLU -> dropout -> conv -> BN -> LeakyReLU.
    The shortcut is then added to the activated output.

    When ``in_channels != out_channels`` the input cannot be added directly,
    so it is first projected with a 1x1 convolution. When they are equal the
    shortcut is the identity and no extra parameters are created, so the
    module's state_dict is the same as a plain identity-residual block.
    """

    def __init__(self, in_channels, out_channels, dropout_rate=0.2):
        super(ResBlock, self).__init__()
        self.conv1 = conv(in_channels, out_channels, 3)
        self.bn1 = bn(out_channels)
        self.act1 = act()
        self.dropout = nn.Dropout(dropout_rate)
        self.conv2 = conv(out_channels, out_channels, 3)
        self.bn2 = bn(out_channels)
        self.act2 = act()
        # Registered last so the existing layers keep their parameter order.
        # None means "identity shortcut".
        self.shortcut = None if in_channels == out_channels else nn.Conv2d(in_channels, out_channels, 1)

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.act1(out)
        out = self.dropout(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.act2(out)
        identity = x if self.shortcut is None else self.shortcut(x)
        return out + identity

class ResBlock_v2(nn.Module):
    """Residual block summing the input with three successive conv stages.

    Each stage is conv(3x3) -> BN -> LeakyReLU, and each feeds the next.
    The output is ``shortcut(x) + out1 + out2 + out3``.

    When ``in_channels != out_channels`` the input is projected with a 1x1
    conv so the sum is well-defined. Otherwise the shortcut is the identity
    and adds no parameters.

    ``dropout_rate`` is accepted for signature parity with ``ResBlock`` but is
    not used by this block.
    """

    def __init__(self, in_channels, out_channels, dropout_rate=0.2):
        super(ResBlock_v2, self).__init__()
        self.conv1 = conv(in_channels, out_channels, 3)
        self.bn1 = bn(out_channels)
        self.act1 = act()

        self.conv2 = conv(out_channels, out_channels, 3)
        self.bn2 = bn(out_channels)
        self.act2 = act()

        self.conv3 = conv(out_channels, out_channels, 3)
        self.bn3 = bn(out_channels)
        self.act3 = act()

        # Registered last so the existing layers keep their parameter order.
        # None means "identity shortcut".
        self.shortcut = None if in_channels == out_channels else nn.Conv2d(in_channels, out_channels, 1)

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out1 = self.act1(out)
        out = self.conv2(out1)
        out = self.bn2(out)
        out2 = self.act2(out)
        out = self.conv3(out2)
        out = self.bn3(out)
        out3 = self.act3(out)
        identity = x if self.shortcut is None else self.shortcut(x)
        return identity + out1 + out2 + out3


class Head(nn.Module):
    """Stem made of three conv -> BN -> LeakyReLU -> max-pool stages.

    Channels go ``in_channels -> 16 -> 32 -> 32``. Each stage's 2x2 pool halves
    H and W (flooring).

    ``shrink_times`` is accepted but not used by this implementation.
    """

    def __init__(self, in_channels, shrink_times):
        super(Head, self).__init__()
        n_feat = 16
        # (input channels, output channels) for each stage, in order.
        stage_plan = [
            (in_channels, n_feat),
            (n_feat, n_feat * 2),
            (n_feat * 2, n_feat * 2),
        ]
        layers = []
        for c_in, c_out in stage_plan:
            layers += [conv(c_in, c_out), bn(c_out), act(), pool()]
        self.head = nn.Sequential(*layers)

    def forward(self, x):
        return self.head(x)

class Body(nn.Module):
    """Sequential stack of residual blocks.

    The first block is a ``ResBlock(in_channels, out_channels)``. It is followed
    by ``n_blocks - 1`` ``ResBlock_v2(out_channels, out_channels)`` blocks (none
    when ``n_blocks <= 1``).
    """

    def __init__(self, in_channels, out_channels, n_blocks):
        super(Body, self).__init__()
        blocks = [ResBlock(in_channels, out_channels)]
        blocks.extend(ResBlock_v2(out_channels, out_channels) for _ in range(n_blocks - 1))
        self.body = nn.Sequential(*blocks)

    def forward(self, x):
        return self.body(x)


In [None]:
class Net1(torch.nn.Module):
    """Binary image classifier: Head stem -> residual Body -> pooled linear head.

    The network takes 3-channel images. Its output has shape ``(batch, 1)`` and
    holds sigmoid probabilities.
    """

    def __init__(self):
        super(Net1, self).__init__()
        n_body_feat = 32  # matches the 32 channels the Head stem outputs
        self.head = Head(3, 5)
        self.body = Body(n_body_feat, n_body_feat, 2)
        pooled = 3  # adaptive-average-pool to a 3x3 grid before the linear layer
        self.fc = nn.Sequential(
            nn.AdaptiveAvgPool2d(pooled),
            nn.Flatten(),
            nn.Linear(pooled * pooled * n_body_feat, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        features = self.body(self.head(x))
        return self.fc(features)

# Test

In [None]:
def test(model, testloader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data in testloader:
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            predicted = torch.round(outputs).squeeze()
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return correct / total

# Train

In [None]:
def train(model, trainloader, criterion, optimizer, epochs=10, test_loader=None, scheduler=None):
    """Train a single-output binary classifier and record accuracy per epoch.

    Each epoch makes one pass over ``trainloader``: forward, ``criterion``,
    backward, ``optimizer.step()``. When ``test_loader`` is given, the model is
    then evaluated with ``test``. Its weights are written to ``best_model.pth``
    after the first epoch and again whenever the test accuracy improves, so the
    file always exists after at least one evaluated epoch.

    Args:
        model: module returning one probability per sample (e.g. sigmoid output).
        trainloader: iterable of ``(inputs, labels)`` batches with 0/1 labels.
        criterion: loss taking ``(probabilities, float_labels)``, e.g. ``nn.BCELoss``.
        optimizer: optimizer over ``model``'s parameters.
        epochs: number of passes over ``trainloader``.
        test_loader: optional evaluation iterable. When None, evaluation,
            checkpointing and the best-accuracy summary are skipped.
        scheduler: optional LR scheduler stepped once per epoch.
            ``ReduceLROnPlateau`` receives the epoch's mean training loss.

    Returns:
        ``(training_accuracy_record, testing_accuracy_record)``: per-epoch
        accuracies in percent. The second list is empty without ``test_loader``.

    Raises:
        ValueError: if ``trainloader`` yields no samples in an epoch.
    """
    # Run on whatever device the model's weights live on instead of a global.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device("cpu")

    best_test_accuracy = 0
    training_accuracy_record = []
    testing_accuracy_record = []
    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        n_batches = 0
        cur_samples = 0
        cur_correct_pred = 0
        for inputs, labels in trainloader:
            inputs, labels = inputs.to(run_device), labels.to(run_device)
            optimizer.zero_grad()
            # reshape(-1), not squeeze(): squeeze() collapses a final batch of
            # one to a 0-d tensor, which BCELoss rejects against a (1,) target.
            outputs = model(inputs).reshape(-1)
            loss = criterion(outputs, labels.float())
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            n_batches += 1

            predicted = torch.round(outputs.detach())
            cur_samples += labels.size(0)
            cur_correct_pred += (predicted == labels).sum().item()
        if cur_samples == 0:
            raise ValueError("trainloader yielded no samples")
        training_accuracy = cur_correct_pred / cur_samples * 100
        training_accuracy_record.append(training_accuracy)

        if scheduler is not None:
            if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                scheduler.step(running_loss / n_batches)
            else:
                scheduler.step()

        if test_loader is None:
            print(f'Epoch {epoch + 1}, Training Accuracy: {training_accuracy:.2f}%')
            continue

        test_accuracy = test(model, test_loader) * 100
        testing_accuracy_record.append(test_accuracy)
        if epoch == 0 or test_accuracy > best_test_accuracy:
            torch.save(model.state_dict(), 'best_model.pth')
        best_test_accuracy = max(best_test_accuracy, test_accuracy)
        print(f'Epoch {epoch +1}, Training Accuracy: {training_accuracy :.2f}%, Test Accuracy: {test_accuracy:.2f}%')
    if test_loader is not None:
        print(f'BEST TEST ACCURACY: {best_test_accuracy:.2f}%')
    return training_accuracy_record, testing_accuracy_record


# Model's Info

In [None]:
def show_summary(model):
    """Print a torchsummary layer table, then thop FLOPs and parameter counts.

    The model is profiled on one random 1x3x128x128 input. Both the probe tensor
    and torchsummary's ``device`` argument follow the model's parameters, so this
    works for a model on the CPU as well as on CUDA. FLOPs are reported as
    ``2 * MACs``.
    """
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device("cpu")
    image = torch.rand(1, 3, 128, 128, device=model_device)
    # torchsummary's `device` is a string ("cuda"/"cpu") and defaults to "cuda",
    # which fails for a CPU-resident model, so pass it explicitly.
    summary(model, (3, 128, 128), device="cuda" if model_device.type == "cuda" else "cpu")
    macs, parm = profile(model, inputs=(image, ))
    print(f'FLOPS: {macs * 2 / 1e6:.3f}M, Parameters: {parm / 1e3:.3f}K')

# Execution

In [None]:
# Build the model on the selected device; BCELoss matches its sigmoid output.
model = Net1().to(device)
criteria = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)
# NOTE(review): this scheduler is not passed to the train() call below, so as
# written it is never stepped.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.6, patience=1, verbose=True
)
show_summary(model)

In [None]:
# Train for 600 epochs, evaluating on the test set after every epoch.
training_accuracy_record, testing_accuracy_record = train(
    model,
    train_loader,
    criteria,
    optimizer,
    epochs=600,
    test_loader=test_loader,
)

# Plot

In [None]:
# Plot per-epoch training vs. testing accuracy (both recorded in percent)
# and save the figure to accuracy.png.
fig, ax = plt.subplots()
ax.plot(training_accuracy_record, label='Training Accuracy')
ax.plot(testing_accuracy_record, label='Testing Accuracy')
ax.set_xlabel('Epoch')
ax.set_ylabel('Accuracy')
ax.set_title('Training Accuracy vs Testing Accuracy')
ax.legend()
fig.savefig('accuracy.png')


# Demo

In [None]:
# Evaluate the saved best checkpoint on the demo set and report its size/cost.
demo_dataset = ImageFolder(root='./dataset_demo', transform=testing_transforms)
# Accuracy does not depend on sample order, so iterate in a fixed order.
demo_dataloader = DataLoader(demo_dataset, batch_size=64, shuffle=False, num_workers=4)

demo_model = Net1()
demo_model.to(device)
# The probe input must live on the same device as the model; a hard-coded
# .cuda() would fail on CPU-only machines.
input_image = torch.rand(1, 3, 128, 128, device=device)
macs, parm = profile(demo_model, inputs=(input_image, ))

# map_location lets a checkpoint saved from a GPU run load on a CPU-only machine.
demo_model.load_state_dict(torch.load('best_model.pth', map_location=device))
demo_model.eval()
test_accuracy = test(demo_model, demo_dataloader)
print()
print(f'(1) testing accuracy of this demo dataset: {test_accuracy * 100:.2f}%')
print(f'(2) FLOPs of your model: {macs * 2 / 1e6:.3f}M')
print(f'(3) number of trainable parameters of your model: {parm / 1e3:.3f}K')