In [1]:
import torch
from torchvision import transforms, datasets
from torch.utils.data import Subset, DataLoader
from torch import nn

In [2]:
image_path ='./'
train_dataset = datasets.CelebA(
    image_path, split='train',
    target_type='attr', download=True
)
valid_dataset = datasets.CelebA(
    image_path, split='valid',
    target_type='attr', download=True
)
test_dataset = datasets.CelebA(
    image_path, split='test',
    target_type='attr', download=True
)

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


In [3]:
get_smile = lambda attr: attr[18]

transform_train = transforms.Compose([
    transforms.RandomCrop([178, 178]),
    transforms.RandomHorizontalFlip(),
    transforms.Resize([64, 64]),
    transforms.ToTensor(),
])

transform = transforms.Compose([
    transforms.CenterCrop([178, 178]),
    transforms.Resize([64, 64]),
    transforms.ToTensor(),
])

In [4]:
train_dataset = datasets.CelebA(
    image_path, split='train',
    target_type='attr', download=False,
    transform=transform_train, target_transform=get_smile
)
valid_dataset = datasets.CelebA(
    image_path, split='valid',
    target_type='attr', download=False,
    transform=transform, target_transform=get_smile
)
test_dataset = datasets.CelebA(
    image_path, split='test',
    target_type='attr', download=False,
    transform=transform, target_transform=get_smile
)
train_dataset = Subset(
    train_dataset, torch.arange(16_000)
)
valid_dataset = Subset(
    valid_dataset, torch.arange(1_000)
)

In [5]:
batch_size = 32
train_loader = DataLoader(
    train_dataset, batch_size,
    shuffle=True
)
valid_loader = DataLoader(
    valid_dataset, batch_size,
    shuffle=False
)
test_loader = DataLoader(
    test_dataset, batch_size,
    shuffle=False
)

In [6]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cuda'

In [7]:
class SmileClassificationNet(nn.Module):
    def __init__(self, device):
        super().__init__()
        self.classifier = nn.Sequential(
            nn.Conv2d(
                in_channels=3, out_channels=32,
                kernel_size=3, padding=1
            ),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.Dropout(p=0.5),
            nn.Conv2d(
                in_channels=32, out_channels=64,
                kernel_size=3, padding=1
            ),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.Dropout(p=0.5),
            nn.Conv2d(
                in_channels=64, out_channels=128,
                kernel_size=3, padding=1
            ),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(
                in_channels=128, out_channels=256,
                kernel_size=3, padding=1
            ),
            nn.ReLU(),
            nn.AvgPool2d(kernel_size=8),
            nn.Flatten(),
            nn.Linear(256, 1),
            nn.Sigmoid()
        ).to(device)

    def forward(self, x):
        return self.classifier(x)
    
model = SmileClassificationNet(device)

In [8]:
loss = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

In [9]:
def train(model, n_epochs, train_loader, valid_loader):
    loss_history_train = [0] * n_epochs
    acc_history_train = [0] * n_epochs
    loss_history_valid = [0] * n_epochs
    acc_history_valid = [0] * n_epochs

    for epoch in range(n_epochs):
        model.train()
        for x_batch, y_batch in train_loader:
            x_batch = x_batch.to(device)
            y_batch = y_batch.to(device)

            prediction = model(x_batch)[:, 0]
            loss_val = loss(prediction, y_batch.float())
            loss_val.backward()
            optimizer.step()
            optimizer.zero_grad()

            loss_history_train[epoch] += loss_val.item() * y_batch.size(0)
            is_correct = ((prediction >= 0.5).float() == y_batch).float()
            acc_history_train[epoch] += is_correct.sum()
        loss_history_train[epoch] /= len(train_loader.dataset)
        acc_history_train[epoch] /= len(train_loader.dataset)

        model.eval()
        with torch.no_grad():
            for x_batch, y_batch in valid_loader:
                x_batch = x_batch.to(device)
                y_batch = y_batch.to(device)
                
                prediction = model(x_batch)[:, 0]
                loss_val = loss(prediction, y_batch.float())
                loss_history_valid[epoch] += loss_val.item() * y_batch.size(0)
                is_correct = ((prediction >= 0.5).float() == y_batch).float()
                acc_history_valid[epoch] += is_correct.sum()
        loss_history_valid[epoch] /= len(valid_loader.dataset)
        acc_history_valid[epoch] /= len(valid_loader.dataset)

        print(f'Epoch {epoch+1} accuracy:'
              f'{acc_history_train[epoch]:.4f} val_accuracy: '
              f'{acc_history_valid[epoch]:.4f}')
    return loss_history_train, loss_history_valid,\
           acc_history_train, acc_history_valid

In [10]:
n_epochs = 30

hist = train(model, n_epochs, train_loader, valid_loader)

Epoch 1 accuracy:0.6163 val_accuracy: 0.6430
Epoch 2 accuracy:0.6880 val_accuracy: 0.7290
Epoch 3 accuracy:0.7261 val_accuracy: 0.7550
Epoch 4 accuracy:0.7380 val_accuracy: 0.7730
Epoch 5 accuracy:0.7511 val_accuracy: 0.7660
Epoch 6 accuracy:0.7614 val_accuracy: 0.7710
Epoch 7 accuracy:0.7730 val_accuracy: 0.7950
Epoch 8 accuracy:0.7931 val_accuracy: 0.8110
Epoch 9 accuracy:0.8096 val_accuracy: 0.8190
Epoch 10 accuracy:0.8250 val_accuracy: 0.8330
Epoch 11 accuracy:0.8336 val_accuracy: 0.8520
Epoch 12 accuracy:0.8384 val_accuracy: 0.8630
Epoch 13 accuracy:0.8371 val_accuracy: 0.8620
Epoch 14 accuracy:0.8512 val_accuracy: 0.8560
Epoch 15 accuracy:0.8479 val_accuracy: 0.8730
Epoch 16 accuracy:0.8572 val_accuracy: 0.8740
Epoch 17 accuracy:0.8588 val_accuracy: 0.8680
Epoch 18 accuracy:0.8654 val_accuracy: 0.8720
Epoch 19 accuracy:0.8631 val_accuracy: 0.8690
Epoch 20 accuracy:0.8606 val_accuracy: 0.8820
Epoch 21 accuracy:0.8679 val_accuracy: 0.8880
Epoch 22 accuracy:0.8664 val_accuracy: 0.88

In [11]:
torch.save(model.state_dict(), './model.pt')