In [21]:
import os
import cv2
%matplotlib inline
import numpy as np
import time
import torch
import torch.nn.functional as F
import pandas as pd
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchsummary import summary
from torchvision import transforms
import matplotlib.pyplot as plt

from dsd import DSDTraining
from utils import set_all_seed, SaveBestModel, plot_wb

In [None]:
# Config
# Filesystem locations: dataset CSVs and the pretrained checkpoint directory.
data_path, pretrained_path = "./data", "./pretrained"

# Input pipeline: spatial size fed to the network and mini-batch size.
image_size = (224, 224)
batch_size = 128

# Optimisation: learning rate the run starts from.
initial_lr = 1e-4


# I/ Dataset

In [22]:
class FER2013(Dataset):
    """One FER2013 split served as normalised 3-channel tensors with class labels.

    Reads ``<data_path>/<stage>.csv`` (``data_path`` and ``image_size`` are the
    notebook-level config values). The CSV must provide a ``pixels`` column of
    space-separated grey values for a 48x48 image and an ``emotion`` column
    holding the class label.

    Args:
        stage: Split name, also used as the CSV file stem. ``"train"`` enables
            random horizontal flipping; any other value uses the deterministic
            pipeline.
    """

    def __init__(self, stage):
        self._stage = stage
        self._image_size = image_size
        self._data = pd.read_csv(os.path.join(data_path, "{}.csv".format(stage)))
        self._pixels = self._data["pixels"].tolist()
        # One-hot view of the labels, retained for code that inspects it.
        self._emotions = pd.get_dummies(self._data["emotion"])
        # Labels are read once here instead of materialising a one-hot row and
        # calling idxmax() on it for every sample.
        self._targets = self._data["emotion"].tolist()

        # Steps shared by both pipelines; only the training one adds the flip.
        to_tensor_and_scale = [
            transforms.ToTensor(),
            transforms.Normalize(mean=[0, 0, 0], std=[0.1, 0.1, 0.1]),
        ]
        self._train_transform = transforms.Compose(
            [
                transforms.ToPILImage(),
                transforms.RandomHorizontalFlip(p=0.5),
                *to_tensor_and_scale,
            ]
        )
        self._val_test_transform = transforms.Compose(
            [transforms.ToPILImage(), *to_tensor_and_scale]
        )

    def __len__(self):
        """Return the number of samples in the split."""
        return len(self._pixels)

    def __getitem__(self, idx):
        """Return ``(image, target)`` for sample ``idx``.

        ``image`` is the transformed tensor built from the 48x48 grey image
        after resizing to ``image_size`` and replicating it over 3 channels;
        ``target`` is the value of the ``emotion`` column for that row.
        """
        grey_values = list(map(int, self._pixels[idx].split(" ")))
        image = np.asarray(grey_values).reshape(48, 48).astype(np.uint8)
        # NOTE(review): cv2.resize takes dsize as (width, height); this only
        # matters if image_size is ever made non-square.
        image = cv2.resize(image, self._image_size)
        # Stack the single grey channel three times: the network has 3 input channels.
        image = np.dstack([image] * 3)

        if self._stage == "train":
            image = self._train_transform(image)
        else:
            image = self._val_test_transform(image)

        return image, self._targets[idx]


# One dataset object per CSV split, built in train / val / test order.
train_set, val_set, test_set = (
    FER2013(stage=split_name) for split_name in ("train", "val", "test")
)

# Only the training loader reshuffles; evaluation loaders keep file order.
train_loader = DataLoader(dataset=train_set, batch_size=batch_size, shuffle=True)
val_loader, test_loader = (
    DataLoader(dataset=eval_set, batch_size=batch_size, shuffle=False)
    for eval_set in (val_set, test_set)
)


# II/ Model

In [23]:
class VGG_16(nn.Module):
    """
    VGG-face: 13 3x3 convolutions in five stages followed by three linear layers.

    Layers are exposed as ``conv_<stage>_<index>``, ``fc6``, ``fc7`` and ``fc8``.
    """

    # One row per convolutional stage:
    # (stage number, input channels, output channels, number of convolutions).
    _STAGES = (
        (1, 3, 64, 2),
        (2, 64, 128, 2),
        (3, 128, 256, 3),
        (4, 256, 512, 3),
        (5, 512, 512, 3),
    )

    def __init__(self):
        """
        Build every layer, registering them in stage order.
        """
        super().__init__()
        for stage, in_ch, out_ch, n_convs in self._STAGES:
            for idx in range(1, n_convs + 1):
                # Only the first convolution of a stage changes the channel count.
                conv = nn.Conv2d(
                    in_ch if idx == 1 else out_ch, out_ch, 3, stride=1, padding=1
                )
                setattr(self, f"conv_{stage}_{idx}", conv)
        self.fc6 = nn.Linear(512 * 7 * 7, 4096)
        self.fc7 = nn.Linear(4096, 4096)
        self.fc8 = nn.Linear(4096, 2622)

    def forward(self, x):
        """Pytorch forward
        Args:
            x: input image (224x224)
        Returns: class logits
        """
        for stage, _, _, n_convs in self._STAGES:
            for idx in range(1, n_convs + 1):
                conv = getattr(self, f"conv_{stage}_{idx}")
                x = F.relu(conv(x))
            # Each stage ends by halving the spatial resolution.
            x = F.max_pool2d(x, 2, 2)

        # Flatten everything except the batch dimension for the linear layers.
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc6(x))
        x = F.dropout(x, p=0.7, training=self.training)
        x = F.relu(self.fc7(x))
        return self.fc8(x)


In [None]:
model = VGG_16()
# map_location="cpu" keeps the load working on CPU-only machines even when the
# checkpoint was serialised from CUDA tensors; model.to(device) below moves the
# weights to the selected device afterwards.
model.load_state_dict(
    torch.load(f"{pretrained_path}/VGG_FACE_converted.pth", map_location="cpu")
)

# Freeze the whole network, then re-enable only the two pretrained FC layers.
for param in model.parameters():
    param.requires_grad = False
for layer in (model.fc6, model.fc7):
    for param in layer.parameters():
        param.requires_grad = True

# Seed before the random fc8 initialisation so the starting point is
# reproducible (the training cell only seeds after this cell has run).
set_all_seed(42)

# Replace the pretrained classification layer with a fresh 7-way one. A new
# nn.Linear is trainable by default, so no requires_grad toggling is needed.
model.fc8 = nn.Linear(4096, 7)
nn.init.normal_(model.fc8.weight, mean=0.0, std=0.1)
nn.init.zeros_(model.fc8.bias)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
dsd_model = DSDTraining(model, sparsity=0.6, only_fc=True)
# Run the summary on the device the model actually lives on.
summary(dsd_model, (3, 224, 224), device=device.type)

# III/ Train

In [25]:
# Loss: cross-entropy between the network output and the integer labels.
criterion = nn.CrossEntropyLoss()

# Optimiser: SGD with Nesterov momentum, starting from the configured rate.
optimizer = optim.SGD(
    params=dsd_model.parameters(),
    lr=initial_lr,
    momentum=0.9,
    nesterov=True,
)

# Checkpoint helper invoked at the end of every epoch by the training loop.
save_best_model = SaveBestModel(path="./best_model.pth")

In [26]:
def train_dsd(
    dsd_model,
    EPOCH_DENSE1,
    EPOCH_SPARSE1,
    EPOCH_DENSE2,
    EPOCH_SPARSE2,
    EPOCH_DENSE3,
    NB_TRAIN_EXAMPLES,
    NB_VAL_EXAMPLES,
):
    """Train ``dsd_model`` through a dense/sparse/dense/sparse/dense schedule.

    Uses the notebook-level ``train_loader``, ``val_loader``, ``optimizer``,
    ``criterion``, ``device``, ``initial_lr``, ``save_best_model`` and
    ``plot_wb``. At the start of every sparse epoch ``update_masks()`` is
    called; during sparse epochs the masked entries of every layer's weights,
    biases and gradients are zeroed before each optimiser step.

    Args:
        dsd_model: Model wrapper exposing ``train_on_sparse``, ``update_masks()``,
            ``layers`` and ``masks``.
        EPOCH_DENSE1: Number of epochs in the first dense phase.
        EPOCH_SPARSE1: Number of epochs in the first sparse phase.
        EPOCH_DENSE2: Number of epochs in the second dense phase.
        EPOCH_SPARSE2: Number of epochs in the second sparse phase.
        EPOCH_DENSE3: Number of epochs in the final dense phase.
        NB_TRAIN_EXAMPLES: Sample count used to average the training loss/accuracy.
        NB_VAL_EXAMPLES: Sample count used to average the validation loss/accuracy.

    Returns:
        ``(train_costs, val_costs)``: per-epoch mean training and validation loss.
    """
    # Cumulative epoch counts at which each phase ends.
    end_dense1 = EPOCH_DENSE1
    end_sparse1 = end_dense1 + EPOCH_SPARSE1
    end_dense2 = end_sparse1 + EPOCH_DENSE2
    end_sparse2 = end_dense2 + EPOCH_SPARSE2
    EPOCHS = end_sparse2 + EPOCH_DENSE3

    # Plot written once the completed-epoch count reaches a phase boundary.
    # Checked in order; only the first matching boundary produces a plot.
    phase_plots = (
        (end_dense1, "fer2013_dense1.png"),
        (end_sparse1, "fer2013_sparse1.png"),
        (end_dense2, "fer2013_dense2.png"),
        (end_sparse2, "fer2013_sparse2.png"),
        (EPOCHS, "fer2013_dense3.png"),
    )

    train_costs, val_costs = [], []
    val_loss_decrease_counter = 0
    prev_val_loss = 0
    current_lr = initial_lr

    # Training phase.
    for epoch in range(EPOCHS):
        print(
            f"{'Epoch':^7} | {'Batch':^7} | {'Train Loss':^12} | {'Train Acc':^12} | {'Val Loss':^10} | {'Val Acc':^10} | {'Elapsed':^9}"
        )
        print("-" * 85)

        # Measure the elapsed time of each epoch
        t0_epoch, t0_batch = time.time(), time.time()

        # DSD: sparse training is active only inside the two sparse windows.
        dsd_model.train_on_sparse = (
            end_dense1 <= epoch < end_sparse1 or end_dense2 <= epoch < end_sparse2
        )

        if dsd_model.train_on_sparse:
            dsd_model.update_masks()

        # ------------------------------------------------
        #                 TRAINING
        # ------------------------------------------------

        train_loss, correct_train = 0, 0
        batch_loss, correct_batch, batch_counts = 0, 0, 0

        # Switch to training mode unconditionally. The previous
        # `device == "cuda"` guard compared a torch.device with a str, which is
        # never equal, so train()/eval() were silently skipped.
        dsd_model.train()

        for step, (inputs, labels) in enumerate(train_loader):
            # Load data to GPU.
            inputs, labels = inputs.to(device), labels.to(device)
            n_samples = inputs.shape[0]

            # Zero the parameter gradients.
            optimizer.zero_grad()

            # Forward pass.
            prediction = dsd_model(inputs)

            # Compute the loss.
            loss = criterion(prediction, labels)

            # Backward pass.
            loss.backward()

            # Sparse-phase: zero masked values and their gradients before the step.
            # NOTE(review): the optimiser uses momentum, so masked entries can
            # still move during step(); confirm whether masks should also be
            # re-applied afterwards.
            if dsd_model.train_on_sparse:
                for (w, b), (mask_w, mask_b) in zip(dsd_model.layers, dsd_model.masks):
                    # Values
                    w.data[mask_w] = 0
                    b.data[mask_b] = 0
                    # Grad
                    w.grad.data[mask_w] = 0
                    b.grad.data[mask_b] = 0

            # Optimize.
            optimizer.step()

            # Compute training accuracy.
            _, predicted = torch.max(prediction.detach(), 1)
            n_correct = (predicted == labels).sum().item()
            correct_train += n_correct
            correct_batch += n_correct

            # Accumulate the sample-weighted loss (criterion returns a batch mean).
            weighted_loss = loss.item() * n_samples
            batch_loss += weighted_loss
            train_loss += weighted_loss

            # Count this batch before reporting so the running averages divide by
            # exactly the samples accumulated above (and never by zero).
            batch_counts += n_samples

            # Print the loss values and time elapsed for every 100 batches
            if (step % 100 == 0 and step != 0) or (step == len(train_loader) - 1):
                time_elapsed = time.time() - t0_batch

                print(
                    f"{epoch + 1:^7} | {step:^7} | {batch_loss / batch_counts:^12.6f} | {correct_batch / batch_counts:^12.6f} | {'-':^10} | {'-':^10} |  {time_elapsed:^9.2f}"
                )
                batch_loss, correct_batch, batch_counts = 0, 0, 0
                t0_batch = time.time()

        train_loss /= NB_TRAIN_EXAMPLES
        train_costs.append(train_loss)
        train_acc = correct_train / NB_TRAIN_EXAMPLES

        print("-" * 85)

        # ------------------------------------------------
        #                 VALIDATION
        # ------------------------------------------------

        val_loss = 0
        correct_val = 0

        # Evaluation mode disables dropout so validation metrics are deterministic.
        dsd_model.eval()

        with torch.no_grad():
            for inputs, labels in val_loader:
                # Load data to GPU.
                inputs, labels = inputs.to(device), labels.to(device)

                # Forward pass.
                prediction = dsd_model(inputs)

                # Compute the loss.
                loss = criterion(prediction, labels)

                # Compute validation accuracy.
                _, predicted = torch.max(prediction, 1)
                correct_val += (predicted == labels).sum().item()

                # Accumulate the sample-weighted loss.
                val_loss += loss.item() * inputs.shape[0]

        val_loss /= NB_VAL_EXAMPLES
        val_costs.append(val_loss)
        val_acc = correct_val / NB_VAL_EXAMPLES

        time_elapsed = time.time() - t0_epoch

        # checking and updating lr if needed
        # NOTE(review): the rate is divided by 10 after 10 consecutive epochs in
        # which the validation loss *decreased*; a plateau-based schedule would
        # count non-improving epochs instead. Confirm this is intended.
        if val_loss < prev_val_loss:
            val_loss_decrease_counter = val_loss_decrease_counter + 1
        else:
            val_loss_decrease_counter = 0
        if val_loss_decrease_counter == 10:
            current_lr = current_lr / 10
            for g in optimizer.param_groups:
                g["lr"] = current_lr
            print(f"New learning rate: {current_lr}")
            val_loss_decrease_counter = 0
        prev_val_loss = val_loss

        info = "[Epoch {}/{}]: train_on_sparse = {} | train-loss = {:0.6f} | train-acc = {:0.6f} | val-loss = {:0.6f} | val-acc = {:0.6f} | time_elapsed = {:0.2f}"
        print(
            info.format(
                epoch + 1,
                EPOCHS,
                dsd_model.train_on_sparse,
                train_loss,
                train_acc,
                val_loss,
                val_acc,
                time_elapsed,
            )
        )

        save_best_model(val_loss, epoch, dsd_model, optimizer, criterion)

        # Save plots.
        for boundary, filename in phase_plots:
            if epoch + 1 == boundary:
                plot_wb(dsd_model, filename, only_fc=True)
                break

    return train_costs, val_costs


In [None]:
# Phase lengths, in epochs, in the order the schedule runs them.
EPOCH_DENSE1, EPOCH_SPARSE1 = 200, 50
EPOCH_DENSE2, EPOCH_SPARSE2 = 50, 50
EPOCH_DENSE3 = 50

# Dataset sizes used to turn summed losses / hit counts into per-sample means.
NB_TRAIN_EXAMPLES = len(train_loader.dataset)
NB_VAL_EXAMPLES = len(val_loader.dataset)

set_all_seed(42)
train_costs, val_costs = train_dsd(
    dsd_model=dsd_model,
    EPOCH_DENSE1=EPOCH_DENSE1,
    EPOCH_SPARSE1=EPOCH_SPARSE1,
    EPOCH_DENSE2=EPOCH_DENSE2,
    EPOCH_SPARSE2=EPOCH_SPARSE2,
    EPOCH_DENSE3=EPOCH_DENSE3,
    NB_TRAIN_EXAMPLES=NB_TRAIN_EXAMPLES,
    NB_VAL_EXAMPLES=NB_VAL_EXAMPLES,
)

In [None]:
# Draw both per-epoch loss curves first, then decorate the axes.
for cost_curve in (train_costs, val_costs):
    plt.plot(cost_curve)

plt.xlabel("epoch")
plt.ylabel("Cost")
plt.title("Training/Validation error")
plt.legend(["train-loss", "val-loss"], loc="upper right")

# Persist the figure next to the notebook.
plt.savefig("./training-validation.png")
