This notebook trains a ResNet model to classify emojis. The result is used to filter out emojis that are difficult to recognize.

# Emoji Dataset Class

In [None]:
from typing import Optional, Tuple

import numpy as np
import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

from qsr_learning.entity import emoji_names, load_emoji

class EmojiDataset(Dataset):
    """Map-style dataset yielding one rendered image per emoji name.

    Each item is a pair ``(image, idx)``; ``idx`` is the position of the emoji
    in the dataset and doubles as its class label.

    Args:
        size: Forwarded to ``load_emoji`` as its ``size`` argument.
        transform: Optional callable applied to the RGB PIL image before it
            is returned.
        max_images: When truthy, only the first ``max_images`` entries of
            ``emoji_names`` are used; ``0`` or ``None`` keeps all of them.
    """

    def __init__(
        self,
        size: Optional[Tuple[int, int]] = None,
        transform=None,
        max_images=0,
    ):
        super().__init__()
        self.size = size
        self.transform = transform
        names = emoji_names[:max_images] if max_images else emoji_names
        # Two-way lookup between class index and emoji name.
        self.idx2name = dict(enumerate(names))
        self.name2idx = {name: idx for idx, name in self.idx2name.items()}

    def __getitem__(self, idx):
        """Return ``(image, idx)`` for the emoji stored at position ``idx``.

        Raises:
            IndexError: If ``idx`` is not a valid position. Raising
                ``IndexError`` (rather than the ``KeyError`` of the underlying
                dict) lets plain ``for item in dataset`` iteration terminate.
        """
        try:
            name = self.idx2name[idx]
        except KeyError:
            raise IndexError(f"emoji index {idx!r} is out of range") from None
        image = load_emoji(name, size=self.size)
        # Use black background and remove the alpha channel
        # (assumes load_emoji returns an RGBA image, as alpha_composite needs).
        background = Image.new("RGBA", image.size, (0, 0, 0))
        image = Image.alpha_composite(background, image).convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image, idx

    def __len__(self):
        """Return the number of emojis in the dataset."""
        return len(self.idx2name)

# Training 

In [None]:
import torch.nn as nn
from munch import Munch
from torchvision.models import resnet18

from tqdm.auto import trange

def report_result(epoch, phases, result, data_loader):
    """Print one dict with the epoch number and per-phase loss and accuracy.

    For every phase, the accumulated ``total_loss`` and ``num_correct`` stored
    in ``result[phase]`` are divided by the number of samples in
    ``data_loader[phase].dataset``.
    """
    summary = {"epoch": epoch}
    for phase in phases:
        stats = result[phase]
        num_samples = len(data_loader[phase].dataset)
        summary[f"{phase}_loss"] = stats.total_loss / num_samples
        summary[f"{phase}_accuracy"] = stats.num_correct / num_samples
    print(summary)


def step(model, criterion, optimizer, phase, batch, result, device):
    """Run one batch through the model and accumulate metrics for ``phase``.

    In the ``"train"`` phase the model is put in training mode, gradients are
    computed and the optimizer takes one step. In any other phase the model is
    put in evaluation mode and no gradients are tracked.

    Args:
        model: Network mapping a batch of images to class scores.
        criterion: Loss function; its value is divided by the batch size here,
            so it is expected to return the loss summed over the batch.
        optimizer: Optimizer stepped once per training batch.
        phase: Key into ``result``; ``"train"`` enables the parameter update.
        batch: Sequence whose first two items are the images and the targets.
        result: Mapping from phase to an object with ``total_loss`` and
            ``num_correct`` attributes, both updated in place.
        device: Device the batch tensors are moved to.
    """
    images, targets = batch[0].to(device), batch[1].to(device)
    batch_size = images.shape[0]
    is_training = phase == "train"

    model.train(is_training)
    if is_training:
        model.zero_grad()
    with torch.set_grad_enabled(is_training):
        out = model(images)
        # Per-sample mean keeps the gradient scale independent of batch size.
        loss = criterion(out, targets) / batch_size
    if is_training:
        loss.backward()
        optimizer.step()

    # Accumulate the loss summed over samples (undoing the division above),
    # so that dividing the total by the dataset size yields the per-sample
    # average instead of a value that is too small by a factor of batch_size.
    result[phase].total_loss += loss.item() * batch_size
    result[phase].num_correct += (out.argmax(dim=-1) == targets).sum().item()


def train(config, device):
    """Fine-tune a ResNet-18 to classify emojis and return the trained model.

    Args:
        config: Object with ``config.data[phase]`` holding the keyword
            arguments for ``EmojiDataset`` (for the ``"train"`` and
            ``"validation"`` phases) and ``config.train`` providing
            ``batch_size`` and ``num_epochs``.
        device: Device the model and the batches are moved to.

    Returns:
        The model after the last epoch, so that it can be reused for
        evaluation or filtering instead of being discarded.
    """
    phases = ["train", "validation"]
    data = Munch({phase: EmojiDataset(**config.data[phase]) for phase in phases})
    data_loader = Munch(
        {
            phase: DataLoader(
                data[phase],
                batch_size=config.train.batch_size,
                shuffle=True,
                num_workers=4,
            )
            for phase in phases
        }
    )
    model = resnet18(pretrained=True)
    # Replace the classification head with one output per emoji; the input
    # width is read from the existing head rather than hard-coded.
    model.fc = nn.Linear(model.fc.in_features, len(data.train))
    model.to(device)
    # Summed loss: step() divides it by the batch size.
    criterion = nn.CrossEntropyLoss(reduction="sum")
    optimizer = torch.optim.Adam(model.parameters())
    result = Munch()
    for epoch in trange(config.train.num_epochs):
        for phase in phases:
            # Metrics are reset at the start of every phase of every epoch.
            result[phase] = Munch(total_loss=0, num_correct=0)
            for batch in data_loader[phase]:
                step(model, criterion, optimizer, phase, batch, result, device)
        report_result(epoch, phases, result, data_loader)
    return model

In [None]:
def _training_transform():
    """Build the augmentation + normalization pipeline used for training."""
    augment = transforms.RandomAffine(
        degrees=(0, 360),
        translate=(0.345, 0.345),
        scale=(0.16, 0.32),
        shear=(-20.0, 20.0, -20.0, 20.0),
    )
    jitter = transforms.ColorJitter(
        brightness=0.5, contrast=0, saturation=0.05, hue=0.05
    )
    # NOTE(review): these statistics differ from the ones used in the test
    # cell further down -- confirm which set is current.
    normalize = transforms.Normalize(
        mean=[0.0130, 0.0114, 0.0095], std=[0.0995, 0.0885, 0.0788]
    )
    return transforms.Compose([augment, jitter, transforms.ToTensor(), normalize])


def _phase_config():
    """Return the dataset settings shared by the train and validation phases."""
    return Munch(size=(224, 224), transform=_training_transform(), max_images=None)


# Both phases use the same settings, each with its own transform instance.
config = Munch(
    data=Munch(train=_phase_config(), validation=_phase_config()),
    train=Munch(batch_size=128, num_epochs=1000),
)

# Prefer the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
train(config, device=device)

# Tests

## Test the transformed images

In [None]:
def test_dataloader(loader, channel_means, channel_stds):
    """Display the first image of one batch from ``loader`` with its emoji name.

    The normalization is undone with the given per-channel statistics so that
    the image is shown in its original value range. Relies on ``display``
    being available, as it is in IPython/Jupyter.

    Args:
        loader: DataLoader over a dataset that exposes ``idx2name``; batches
            are ``(images, labels)`` with normalized ``(3, H, W)`` images.
        channel_means: Three per-channel means used for normalization.
        channel_stds: Three per-channel standard deviations used for
            normalization.
    """
    batch = next(iter(loader))
    image = batch[0][0]
    label = batch[1][0].item()
    # Look the name up on the loader's own dataset instead of a global one.
    name = loader.dataset.idx2name[label]

    stds = torch.tensor(channel_stds).view(3, 1, 1)
    means = torch.tensor(channel_means).view(3, 1, 1)
    # Undo the normalization; clamp so that rounding error or mismatched
    # statistics cannot overflow the uint8 conversion below.
    restored = (image * stds + means).clamp(0.0, 1.0)
    pixels = (255 * restored.permute(1, 2, 0).numpy()).astype("uint8")
    display(Image.fromarray(pixels), name)


# Per-channel statistics used to normalize the images here and to undo the
# normalization again inside test_dataloader.
channel_means = [0.0128, 0.0111, 0.0096]
channel_stds = [0.0981, 0.0869, 0.0797]

# Random geometric and color augmentation, followed by tensor conversion and
# normalization.
preview_steps = [
    transforms.RandomAffine(
        degrees=(0, 360),
        translate=(0.345, 0.345),
        scale=(0.16, 0.32),
        shear=(-20.0, 20.0, -20.0, 20.0),
    ),
    transforms.ColorJitter(brightness=0.5, contrast=0, saturation=0.05, hue=0.05),
    transforms.ToTensor(),
    transforms.Normalize(mean=channel_means, std=channel_stds),
]
emoji_dataset = EmojiDataset(
    size=(224, 224), transform=transforms.Compose(preview_steps)
)
loader = DataLoader(emoji_dataset, batch_size=128, shuffle=True)

# Show one transformed sample together with its emoji name.
test_dataloader(loader, channel_means=channel_means, channel_stds=channel_stds)

## Compute the mean and the std of the emoji images for each channel

In [None]:
# Same augmentation as for training, but without normalization, so that the
# statistics are computed on the raw tensor values.
stats_transform = transforms.Compose(
    [
        transforms.RandomAffine(
            degrees=(0, 360),
            translate=(0.345, 0.345),
            scale=(0.16, 0.32),
            shear=(-20.0, 20.0, -20.0, 20.0),
        ),
        transforms.ColorJitter(brightness=0.5, contrast=0, saturation=0.05, hue=0.05),
        transforms.ToTensor(),
    ]
)
emoji_dataset = EmojiDataset(size=(224, 224), transform=stats_transform)

# With batch_size=1 every batch holds exactly one image at position 0.
loader = DataLoader(emoji_dataset, batch_size=1, shuffle=True)

# Flatten each image to (3, H*W) and join all images along the pixel axis,
# giving one long row of values per channel.
flattened_images = [img[0].reshape(3, -1) for img, _ in loader]
channel_values = torch.cat(flattened_images, dim=1)
print("mean:", channel_values.mean(dim=1))
print("std:", channel_values.std(dim=1))

- mean = [0.1948, 0.2264, 0.1711]
- std = [0.3252, 0.3362, 0.3034]