In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision.transforms import transforms
from torchvision import datasets
import sys
from tqdm import tqdm

## Model construction

In [2]:
class Net(torch.nn.Module):
    def __init__(self,init_weights=False):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 48, kernel_size=11, stride=4, padding=2)
        self.relu1 = nn.ReLU(inplace=True)
        self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2)
        self.conv2 = nn.Conv2d(48, 128, kernel_size=5, padding=2)
        self.relu2 = nn.ReLU(inplace=True)
        self.maxpool2 = nn.MaxPool2d(kernel_size=3, stride=2)
        self.conv3 = nn.Conv2d(128, 192, kernel_size=3, padding=1)
        self.relu3 = nn.ReLU(inplace=True)
        self.conv4 = nn.Conv2d(192, 128, kernel_size=3, padding=1)
        self.relu4 = nn.ReLU(inplace=True)
        self.maxpool4 = nn.MaxPool2d(kernel_size=3, stride=2,)
        self.dropout1 = nn.Dropout(p=0.5)
        self.fc1 = nn.Linear(128 * 6 * 6, 2048)
        self.relu_fc1 = nn.ReLU(inplace=True)
        self.dropout2 = nn.Dropout(p=0.5)
        self.fc2 = nn.Linear(2048, 2048)
        self.relu_fc2 = nn.ReLU(inplace=True)
        self.fc3 = nn.Linear(2048, 2)
        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        x = self.relu1(self.conv1(x))
        x = self.maxpool1(x)
        x = self.relu2(self.conv2(x))
        x = self.maxpool2(x)
        x = self.relu3(self.conv3(x))
        x = self.relu4(self.conv4(x))
        x = self.maxpool4(x)
        x = x.view(x.size(0), -1)
        x = self.dropout1(x)
        x = self.relu_fc1(self.fc1(x))
        x = self.dropout2(x)
        x = self.relu_fc2(self.fc2(x))
        x = self.fc3(x)
        return x

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

# Train,  Validation  and test Model

In [4]:
batch_size=64
train_transform=transforms.Compose([transforms.RandomResizedCrop(224),
                              transforms.RandomHorizontalFlip(),
                            transforms.ToTensor(),
                            transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))])
val_transform=transforms.Compose([transforms.Resize((224,224)),
                                  transforms.ToTensor(),
                                  transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))])
train_dataset=datasets.ImageFolder(root='Desktop/Training',transform=train_transform)
train_loader=torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataset=datasets.ImageFolder(root='Desktop/Validation',transform=val_transform)
val_loader=torch.utils.data.DataLoader(val_dataset,batch_size=batch_size,shuffle=False)
val_num=len(val_dataset)
test_dataset=datasets.ImageFolder(root='Desktop/Test',transform=val_transform)
test_loader=torch.utils.data.DataLoader(test_dataset,batch_size=batch_size,shuffle=False)
train_steps=len(train_loader)

model = Net(init_weights=True)
device = torch.device('cuda:0' if torch.cuda.is_available() else "cpu")
model.to(device)
loss_function = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.0002)

# Data processing

In [5]:
epochs=10
save_path='./Net.pth'
def train(epoch):
    running_loss = 0.0
    best_acc=0.0
    train_bar = tqdm(train_loader, file=sys.stdout)
    for step, data in enumerate(train_bar):
        inputs, target = data
        inputs, target = inputs.to(device), target.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_function(outputs, target)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1, epochs, loss)

    correct = 0
    total = 0
    val_bar = tqdm(val_loader, file=sys.stdout)
    with torch.no_grad():
        for data in val_bar:
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            val_bar.desc = "val epoch[{}/{}]".format(epoch + 1, epochs)

    accuracy = 100 * correct / total
    print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' % (epoch + 1, running_loss / len(train_loader), accuracy))

    if accuracy > best_acc:
        best_acc = accuracy
        torch.save(model.state_dict(), save_path)

def test():
    correct = 0
    total = 0
    model.load_state_dict(torch.load(save_path))
    model.eval()
    with torch.no_grad():
        for data in test_loader:
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    accuracy = 100 * correct / total
    print('Accuracy on test set: %.2f%%' % accuracy)

# Performance

In [6]:
if __name__=='__main__':
    for epoch in range(10):
        train(epoch)
        test()

train epoch[1/10] loss:0.311: 100%|██████████████████████████████████████████████████| 735/735 [04:35<00:00,  2.67it/s]
val epoch[1/10]: 100%|███████████████████████████████████████████████████████████████| 165/165 [00:56<00:00,  2.94it/s]
[epoch 1] train_loss: 0.457  val_accuracy: 91.794
Accuracy on test set: 93.45%
train epoch[2/10] loss:0.179: 100%|██████████████████████████████████████████████████| 735/735 [01:38<00:00,  7.49it/s]
val epoch[2/10]: 100%|███████████████████████████████████████████████████████████████| 165/165 [00:19<00:00,  8.54it/s]
[epoch 2] train_loss: 0.287  val_accuracy: 94.478
Accuracy on test set: 94.76%
train epoch[3/10] loss:0.362: 100%|██████████████████████████████████████████████████| 735/735 [01:38<00:00,  7.49it/s]
val epoch[3/10]: 100%|███████████████████████████████████████████████████████████████| 165/165 [00:19<00:00,  8.50it/s]
[epoch 3] train_loss: 0.252  val_accuracy: 94.335
Accuracy on test set: 93.89%
train epoch[4/10] loss:0.131: 100%|████████