# De-captcha

In [1]:
import os

# TODO: change ID
ID = "attn-v5"

# TODO: change device
DEVICE = "cuda:0"  # "cuda:i" or "cpu"
OMP_NUM_THREADS = 10
SEED = 42

TRAIN_TEST_PIVOT = 100000  # 0 - 100000 for train, 100000 - 120000 for test
DATASET_PATH = "./words_captcha/spec_train_val.txt"
IMAGE_DIR = "./words_captcha/"

IMAGE_SIZE = 256
BATCH_SIZE = 32
MAX_SEQ_LEN = 10
EMBEDDING_DIM = 256
HIDDEN_DIM = 512

# TODO: change epochs
START_EPOCH = 0
EPOCHS = 0
LEARNING_RATE = 1e-4

CHECKPOINT_DIR = os.path.join("./ckpts/", ID)
CHECKPOINT_NAME = "ckpt"
CHECKPOINT_NAME_ENC = "encoder"
CHECKPOINT_NAME_DEC = "decoder"
OUTPUT_DIR = os.path.join("./outputs/", ID)

In [2]:
import os
import random
import torch
import warnings

# Check if CUDA is available
if torch.cuda.is_available():
    gpus = torch.cuda.device_count()
    print(f"Number of GPUs: {gpus}")
    device = torch.device(DEVICE)
else:
    print("No GPU available, using the CPU instead.")
    device = torch.device("cpu")
print(f"Device: {device}")

os.environ["OMP_NUM_THREADS"] = str(OMP_NUM_THREADS)

random.seed(SEED)
torch.manual_seed(SEED)

warnings.filterwarnings("ignore")

Number of GPUs: 4
Device: cuda:0


## Dataset

In [3]:
import os
from torch.utils.data import Dataset, DataLoader
import cv2
import albumentations as A
from albumentations.pytorch import ToTensorV2


def load_labels(path: str):
    assert os.path.exists(path)
    labels: list[tuple[str, str]] = []
    with open(path, "r") as f:
        for idx, line in enumerate(f):
            line = line.strip()
            parts = line.split(" ")
            assert len(parts) == 2, f"Invalid line at {idx}: {line}"
            image_name, label = parts
            labels.append((image_name, label))
    return labels


class DatasetGenerator(Dataset):
    def __init__(
        self,
        image_dir: str,
        image_size: int,
        labels: list[tuple[str, str]],
        augmentation: bool = False,
        target_char_count: int = 20000,
        max_dataset_size: int = 500000,
        return_name: bool = False,
        verbose: bool = False,
    ):
        """
        :param image_dir str: Directory containing images
        :param image_size int: Size of the image
        :param labels list[tuple[str, str]]: List of tuples containing image name and label
        :param augmentation bool: Whether to apply train-time augmentations
        """
        self.image_dir = image_dir
        self.image_size = image_size
        self.return_name = return_name
        self.labels = labels.copy()

        if augmentation:
            # Augment the dataset to have at least target_char_count of each character
            # Set up data structures
            char_list: list[str] = []
            char_freq: dict[str, int] = {}
            char_images: dict[str, list[str]] = {}
            image_labels: dict[str, str] = {}
            # Count appearances of each character
            for image_name, label in labels:
                image_labels[image_name] = label
                for char in label:
                    char_freq[char] = char_freq.get(char, 0) + 1
            # Set up character list and image count
            for char in char_freq.keys():
                char_list.append(char)
                char_images[char] = set()
            char_list = sorted(char_list)
            for image_name, label in labels:
                for char in label:
                    char_images[char].add(image_name)
            for char in char_freq.keys():
                char_images[char] = list(char_images[char])
            if verbose:
                print(f"Character frequencies: {char_freq}")
                print(f"Dataset size: {len(labels)}")

            # Start augmenting the dataset until all characters have been seen at least target_char_count times or max_dataset_size is reached
            while len(self.labels) < max_dataset_size:
                if all(
                    freq >= target_char_count for freq in char_freq.values()
                ):
                    # All characters have been seen at least target_char_count times
                    break
                # Select image of lowest frequency character
                top_char_freq = sorted(
                    char_freq.items(), key=lambda x: x[1], reverse=True
                )
                top_char = [char for char, _ in top_char_freq]
                candidate_images = char_images[top_char[-1]]
                image_name = random.choice(candidate_images)
                label = image_labels[image_name]
                # Skip the image if it contains the top frequency character
                if any(char in top_char[:3] for char in label):
                    continue

                # Add the image to the dataset
                self.labels.append((image_name, label))
                # Update character frequency
                for char in label:
                    char_freq[char] += 1

            image_freq = {}
            for image_name, label in self.labels:
                image_freq[image_name] = image_freq.get(image_name, 0) + 1
            if verbose:
                print(f"Modified character frequencies: {char_freq}")
                print(f"Modified dataset size: {len(self.labels)}")
                print(
                    f"Top Image Frequencies: {sorted(image_freq.items(), key=lambda x: x[1], reverse=True)[:10]}"
                )

            # Define image transformations
            appearance_transforms = [
                A.Sharpen(alpha=(0.1, 0.4), lightness=(0.1, 0.4), p=0.2),
                A.SomeOf(
                    [
                        A.RGBShift(
                            r_shift_limit=(-13, 13),
                            g_shift_limit=(-13, 13),
                            b_shift_limit=(-13, 13),
                            p=1,
                        ),
                        A.RandomBrightnessContrast(p=1),
                        A.RandomGamma(p=1),
                    ],
                    n=1,
                    p=1,
                ),
            ]
            self.transfom = A.Compose(
                [
                    A.Resize(image_size, image_size),
                    *appearance_transforms,
                    A.Normalize(
                        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                    ),  # Rescale [0,1] to [-1,1]
                    ToTensorV2(),
                ],
            )
        else:
            self.transfom = A.Compose(
                [
                    A.Resize(image_size, image_size),
                    A.Normalize(
                        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                    ),  # Rescale [0,1] to [-1,1]
                    ToTensorV2(),
                ],
            )

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        image_name, label = self.labels[idx]
        image = cv2.imread(self.image_dir + image_name + ".png")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Apply transformations
        transformed = self.transfom(image=image)
        image = transformed["image"]
        if self.return_name:
            return image, label, image_name
        return image, label


labels = load_labels(DATASET_PATH)
train_labels = labels[:TRAIN_TEST_PIVOT]
test_labels = labels[TRAIN_TEST_PIVOT:]

train_dataset = DatasetGenerator(
    image_dir=IMAGE_DIR,
    image_size=IMAGE_SIZE,
    labels=train_labels,
    augmentation=False,
    verbose=True,
)
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=8,
    pin_memory=True,
)
test_dataset = DatasetGenerator(
    image_dir=IMAGE_DIR,
    image_size=IMAGE_SIZE,
    labels=test_labels,
    augmentation=False,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=8,
    pin_memory=True,
)

# Tokenizer

In [4]:
class Tokenizer:
    PAD_TOKEN = "<pad>"
    SOS_TOKEN = "<sos>"
    EOS_TOKEN = "<eos>"

    PAD_INDEX = 0
    SOS_INDEX = 1
    EOS_INDEX = 2

    def __init__(self):
        self.word2index = {}
        self.index2word = {}
        self.n_words = 0

        self.word2index[self.PAD_TOKEN] = self.PAD_INDEX
        self.word2index[self.SOS_TOKEN] = self.SOS_INDEX
        self.word2index[self.EOS_TOKEN] = self.EOS_INDEX
        self.index2word[self.PAD_INDEX] = self.PAD_TOKEN
        self.index2word[self.SOS_INDEX] = self.SOS_TOKEN
        self.index2word[self.EOS_INDEX] = self.EOS_TOKEN
        self.n_words = 3

    def add_word(self, word: str):
        if word not in self.word2index:
            self.word2index[word] = self.n_words
            self.index2word[self.n_words] = word
            self.n_words += 1

    def add_sentence(self, sentence: str):
        word_list = [c for c in sentence]
        for word in word_list:
            self.add_word(word)

    def add_sentences(self, sentences: list[str]):
        for sentence in sentences:
            self.add_sentence(sentence)

    def sentence_to_tensor(
        self,
        sentence: str,
        add_sos=True,
        add_eos=True,
        add_pad=False,
        max_length=None,
    ):
        word_list = [c for c in sentence]
        if add_sos:
            word_list = [self.SOS_TOKEN] + word_list
        if add_eos:
            word_list = word_list + [self.EOS_TOKEN]
        if add_pad and max_length is not None:
            word_list = word_list[:max_length]
            word_list += [self.PAD_TOKEN] * (max_length - len(word_list))
        indices = [self.word2index[word] for word in word_list]
        return torch.tensor(indices, dtype=torch.long)

    def sentences_to_tensor(
        self,
        sentences: list[str],
        add_sos=True,
        add_eos=True,
        add_pad=False,
        max_length=None,
    ):
        tensors = []
        for sentence in sentences:
            tensor = self.sentence_to_tensor(
                sentence, add_sos, add_eos, add_pad, max_length
            )
            tensors.append(tensor)
        return torch.stack(tensors)

    def tensor_to_sentence(
        self,
        tensor: torch.Tensor,
        remove_sos=True,
        remove_eos=True,
        remove_pad=True,
    ):
        assert isinstance(
            tensor, torch.Tensor
        ), "Input must be a tensor, got {}".format(type(tensor))
        if tensor.dim() == 1:
            tensor = tensor.unsqueeze(0)
        assert tensor.dim() == 2, "Input must be a 1D or 2D tensor"
        sentences = []
        for i in range(tensor.size(0)):
            sentence = ""
            for j in range(tensor.size(1)):
                if remove_sos and tensor[i, j] == self.SOS_INDEX:
                    continue
                if remove_eos and tensor[i, j] == self.EOS_INDEX:
                    continue
                if remove_pad and tensor[i, j] == self.PAD_INDEX:
                    continue
                sentence += self.index2word[tensor[i, j].item()]
            sentences.append(sentence)
        return sentences

    def __len__(self):
        return self.n_words


tokenizer = Tokenizer()
tokenizer.add_sentences([label for _, label in train_labels])
print(f"Number of tokens: {len(tokenizer)}")

Number of tokens: 29


# Model

In [5]:
import torch
import torch.nn as nn
from torchvision import models


class DenseNet121Encoder(nn.Module):
    def __init__(self, embedding_dim: int):
        super(DenseNet121Encoder, self).__init__()
        self.embedding_dim = embedding_dim

        densenet = models.densenet121(
            weights=models.DenseNet121_Weights.DEFAULT
        )
        self.cnn = densenet.features  # (batch_size, 1024, 8, 8)
        self.fc = nn.Linear(
            1024, embedding_dim
        )  # (batch_size, 64, embedding_dim)
        self.relu = nn.ReLU()

    def forward(self, x: torch.Tensor):
        """
        :param x torch.Tensor: (batch_size, 3, 256, 256)
        """
        x = self.cnn(x)  # (batch_size, 1024, 8, 8)
        x = x.view(x.size(0), 1024, 64)  # (batch_size, 1024, 64)
        x = x.permute(0, 2, 1)  # (batch_size, 64, 1024)
        x = self.fc(x)  # (batch_size, 64, embedding_dim)
        x = self.relu(x)  # (batch_size, 64, embedding_dim)
        return x


@torch.no_grad()
def test_encoder(device):
    encoder = DenseNet121Encoder(embedding_dim=EMBEDDING_DIM).to(device)
    sample_input = torch.randn(BATCH_SIZE, 3, IMAGE_SIZE, IMAGE_SIZE).to(
        device
    )
    sample_output = encoder(sample_input)
    print(sample_output.shape)


test_encoder(device)

torch.Size([32, 64, 256])


In [6]:
import torch
import torch.nn as nn


class Attention(nn.Module):
    def __init__(self, embedding_dim, hidden_dim):
        super(Attention, self).__init__()
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim

        self.W1 = nn.Linear(embedding_dim, hidden_dim)  # For features
        self.W2 = nn.Linear(hidden_dim, hidden_dim)  # For hidden state
        self.V = nn.Linear(hidden_dim, 1)  # For attention score

    def forward(self, features: torch.Tensor, hidden: torch.Tensor):
        """
        :param features torch.Tensor: (batch_size, seq_len, embedding_dim)
        :param hidden torch.Tensor: (batch_size, hidden_dim)
        """
        hidden_with_time_axis = hidden.unsqueeze(
            1
        )  # (batch_size, 1, hidden_dim)
        score = self.V(
            torch.tanh(self.W1(features) + self.W2(hidden_with_time_axis))
        )  # (batch_size, seq_len, 1)
        attention_weights = torch.softmax(
            score, dim=1
        )  # (batch_size, seq_len, 1)
        context_vector = (
            attention_weights * features
        )  # (batch_size, seq_len, embedding_dim)
        context_vector = torch.sum(
            context_vector, dim=1
        )  # (batch_size, embedding_dim)
        return context_vector, attention_weights


@torch.no_grad()
def test_attention(device):
    attention = Attention(
        embedding_dim=EMBEDDING_DIM, hidden_dim=HIDDEN_DIM
    ).to(device)
    sample_features = torch.randn(16, 10, EMBEDDING_DIM).to(device)
    sample_hidden = torch.randn(16, HIDDEN_DIM).to(device)
    sample_context, sample_attention_weight = attention(
        sample_features, sample_hidden
    )
    print(sample_context.shape, sample_attention_weight.shape)


test_attention(device)

torch.Size([16, 256]) torch.Size([16, 10, 1])


In [7]:
import torch
import torch.nn as nn


class RNNDecoder(nn.Module):
    def __init__(
        self,
        embedding_dim: int,
        hidden_dim: int,
        vocab_size: int,
    ):
        super(RNNDecoder, self).__init__()
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.vocab_size = vocab_size

        self.attention = Attention(
            embedding_dim=embedding_dim, hidden_dim=hidden_dim
        )  # (batch_size, hidden_dim), (batch_size, 1)

        self.embedding = nn.Embedding(
            vocab_size, embedding_dim
        )  # (batch_size, 1, embedding_dim)
        self.gru = nn.GRU(
            input_size=embedding_dim + embedding_dim,
            hidden_size=hidden_dim,
            batch_first=True,
        )  # (batch_size, 1, hidden_dim)
        self.fc1 = nn.Linear(
            hidden_dim, hidden_dim
        )  # (batch_size, 1, hidden_dim)
        self.fc2 = nn.Linear(
            hidden_dim, vocab_size
        )  # (batch_size, 1, vocab_size)

    def forward(
        self,
        x: torch.Tensor,
        features: torch.Tensor,
        hidden: torch.Tensor,
    ):
        """
        :param x torch.Tensor: (batch_size, 1)
        :param features torch.Tensor: (batch_size, 1, embedding_dim)
        :param hidden torch.Tensor: (batch_size, hidden_dim)
        """
        context, attention_weight = self.attention(
            features, hidden
        )  # (batch_size, embedding_dim), (batch_size, 1)
        x = self.embedding(x)  # (batch_size, 1, embedding_dim)
        context = context.unsqueeze(1)  # (batch_size, 1, embedding_dim)
        x = torch.cat(
            [context, x], dim=2
        )  # (batch_size, 1, embedding_dim + embedding_dim)
        x, hidden = self.gru(
            x, hidden.unsqueeze(0)
        )  # (batch_size, 1, hidden_dim), (1, batch_size, hidden_dim)
        x = self.fc1(x)  # (batch_size, 1, hidden_dim)
        x = x.view(-1, x.size(2))  # (batch_size * 1, hidden_dim)
        x = torch.relu(x)  # (batch_size * 1, hidden_dim)
        x = self.fc2(x)  # (batch_size * 1, vocab_size)
        hidden = hidden.squeeze(0)
        return x, hidden, attention_weight


@torch.no_grad()
def test_decoder(device):
    decoder = RNNDecoder(
        embedding_dim=EMBEDDING_DIM,
        hidden_dim=HIDDEN_DIM,
        vocab_size=len(tokenizer),
    ).to(device)
    sample_input = torch.randint(0, len(tokenizer), (16, 1)).to(device)
    sample_features = torch.randn(16, 1, EMBEDDING_DIM).to(device)
    sample_hidden = torch.randn(16, HIDDEN_DIM).to(device)
    sample_output, sample_hidden, sample_attention_weight = decoder(
        sample_input, sample_features, sample_hidden
    )
    print(
        sample_output.shape,
        sample_hidden.shape,
        sample_attention_weight.shape,
    )


test_decoder(device)

torch.Size([16, 29]) torch.Size([16, 512]) torch.Size([16, 1, 1])


In [8]:
import torch
import torch.nn as nn


class AttnBased(nn.Module):
    def __init__(
        self,
        embedding_dim: int,
        hidden_dim: int,
        vocab_size: int,
    ):
        super(AttnBased, self).__init__()
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.vocab_size = vocab_size

        self.encoder = DenseNet121Encoder(embedding_dim=embedding_dim)
        self.decoder = RNNDecoder(
            embedding_dim=embedding_dim,
            hidden_dim=hidden_dim,
            vocab_size=vocab_size,
        )

    def forward(
        self,
        images: torch.Tensor,
        targets: torch.Tensor,
    ):
        """
        :param images torch.Tensor: (batch_size, 3, 256, 256)
        :param targets torch.Tensor: (batch_size, seq_len)
        :rtype: torch.Tensor
        :return: (batch_size, seq_len, vocab_size)
        """
        device = images.device
        batch_size = images.size(0)
        seq_len = targets.size(1)

        hidden = torch.zeros(batch_size, self.hidden_dim).to(device)
        outputs = torch.zeros(batch_size, seq_len, self.vocab_size).to(device)
        outputs[:, 0, 1] = 1  # Set SOS token

        features = self.encoder(images)  # (batch_size, 64, embedding_dim)
        for t in range(seq_len - 1):
            output, hidden, _ = self.decoder(
                targets[:, t].unsqueeze(1), features, hidden
            )
            outputs[:, t + 1] = output
        return outputs

    def inference(self, images: torch.Tensor, max_length: int = 10):
        """
        :param images torch.Tensor: (batch_size, 3, 256, 256)
        :param max_length int: Maximum length of the output sequence
        :rtype: torch.Tensor
        :return: (batch_size, seq_len, vocab_size)
        """
        device = images.device
        batch_size = images.size(0)

        hidden = torch.zeros(batch_size, self.hidden_dim).to(device)
        outputs = torch.zeros(batch_size, max_length, self.vocab_size).to(
            device
        )
        outputs[:, 0, 1] = 1

        features = self.encoder(images)  # (batch_size, 64, embedding_dim)
        for t in range(max_length - 1):
            output, hidden, _ = self.decoder(
                outputs[:, t].argmax(1).unsqueeze(1), features, hidden
            )
            outputs[:, t + 1] = output
        return outputs


@torch.no_grad()
def test_model(device):
    model = AttnBased(
        embedding_dim=EMBEDDING_DIM,
        hidden_dim=HIDDEN_DIM,
        vocab_size=len(tokenizer),
    ).to(device)
    sample_images = torch.randn(16, 3, IMAGE_SIZE, IMAGE_SIZE).to(device)
    sample_targets = torch.randint(0, len(tokenizer), (16, MAX_SEQ_LEN)).to(
        device
    )
    sample_outputs = model(sample_images, sample_targets)
    print(sample_outputs.shape)


test_model(device)

torch.Size([16, 10, 29])


In [9]:
model = AttnBased(
    embedding_dim=EMBEDDING_DIM,
    hidden_dim=HIDDEN_DIM,
    vocab_size=len(tokenizer),
).to(device)

trainable_params = sum(
    p.numel() for p in model.parameters() if p.requires_grad
)
params = sum(p.numel() for p in model.parameters())
print(f"Trainable parameters: {trainable_params}")
print(f"Total parameters: {params}")

Trainable parameters: 9471902
Total parameters: 9471902


# Training

In [10]:
import torch
from torch.functional import F


def loss_function(
    predict: torch.Tensor,
    target: torch.Tensor,
):
    """
    :param predict torch.Tensor: (batch_size, seq_len, vocab_size)
    :param target torch.Tensor: (batch_size, seq_len)
    """
    mask = target != 0
    predict = predict.view(
        -1, predict.size(2)
    )  # (batch_size * seq_len, vocab_size)
    target = target.view(-1)  # (batch_size * seq_len)
    mask = mask.view(-1)  # (batch_size * seq_len)
    loss = F.cross_entropy(predict, target, reduction="none")
    loss = loss * mask
    return loss.sum() / mask.sum()


@torch.no_grad()
def test_loss_function():
    sample_predict = tokenizer.sentences_to_tensor(
        ["helo", "wo"], add_pad=True, max_length=MAX_SEQ_LEN
    )  # (2, 10)
    sample_predict = F.one_hot(
        sample_predict, num_classes=len(tokenizer)
    ).float()  # (2, 10, 29)
    sample_target = tokenizer.sentences_to_tensor(
        ["helo", "wo"], add_pad=True, max_length=MAX_SEQ_LEN
    )  # (2, 10)
    sample_loss = loss_function(sample_predict, sample_target)
    print(sample_loss)


test_loss_function()

tensor(2.4249)


In [11]:
import os
import torch
import torch.nn as nn


def save_checkpoint(epoch, model, optimizer, checkpoint_dir, checkpoint_name):
    """
    Save a checkpoint to a specified directory.

    :param epoch: The epoch number to save in the checkpoint.
    :param model: The model to save in the checkpoint.
    :param optimizer: The optimizer to save in the checkpoint.
    :param checkpoint_dir: The directory to save the checkpoint in.
    :param checkpoint_name: The name of the checkpoint file, will be appended with the epoch number.
    """
    os.makedirs(checkpoint_dir, exist_ok=True)
    checkpoint_path = os.path.join(
        checkpoint_dir, f"{checkpoint_name}_{epoch:03d}.pt"
    )
    torch.save(
        {
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
        },
        checkpoint_path,
    )
    print(f"Saved checkpoint for epoch {epoch} at {checkpoint_path}")


def load_checkpoint(
    model,
    checkpoint_dir,
    checkpoint_name,
    optimizer=None,
    epoch=None,
    device=None,
) -> bool:
    """
    Load a checkpoint from a specified directory.

    :param model: The model to load the checkpoint into.
    :param checkpoint_dir: The directory to search for the checkpoint.
    :param checkpoint_name: The name of the checkpoint file, will be appended with the epoch number.
    :param optimizer: The optimizer to load the checkpoint into. If None, the optimizer is not loaded.
    :param epoch: The epoch to load the checkpoint from. If None, the latest checkpoint is loaded.
    If no checkpoint is found, the function prints a message and returns.
    """
    if epoch is None:
        if not os.path.exists(checkpoint_dir):
            print("No checkpoint found.")
            return False
        # Search for the latest checkpoint
        files = os.listdir(checkpoint_dir)
        files = [
            file
            for file in files
            if file.startswith(checkpoint_name) and file.endswith(".pt")
        ]
        if not files:
            print("No checkpoint found.")
            return False
        files.sort()
        checkpoint_path = os.path.join(checkpoint_dir, files[-1])
    else:
        checkpoint_path = os.path.join(
            checkpoint_dir, f"{checkpoint_name}_{epoch:03d}.pt"
        )
    if not os.path.exists(checkpoint_path):
        print(f"Checkpoint for epoch {epoch} not found.")
        return False
    checkpoint = torch.load(
        checkpoint_path, weights_only=False, map_location=device
    )
    model.load_state_dict(checkpoint["model_state_dict"])
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    print(f"Loaded checkpoint from {checkpoint_path}")
    return True

In [12]:
import math
import time
import torch
from torch import optim


# Training step function
def train_step(
    model: nn.Module,
    optimizer: optim.Optimizer,
    loss_funcion: nn.Module,
    images: torch.Tensor,
    targets: torch.Tensor,
):
    model.train()  # Set model to training mode
    optimizer.zero_grad()  # Zero out gradients
    outputs = model(images, targets)  # Forward pass
    loss = loss_funcion(outputs, targets)  # Compute loss
    loss_metric = loss.item()
    loss.backward()  # Backward pass
    optimizer.step()  # Update weights
    return loss_metric


def evaluate(
    model: nn.Module,
):
    match_count = 0
    total_count = 0
    for idx, (images, labels) in enumerate(test_loader):
        images = images.to(device)
        model.eval()
        with torch.no_grad():
            outputs = model.inference(images, max_length=MAX_SEQ_LEN)
            sentence = tokenizer.tensor_to_sentence(outputs.argmax(2))
        for pred, label in zip(sentence, labels):
            total_count += 1
            if pred == label:
                match_count += 1
    return match_count / total_count


optimizer = None

if EPOCHS > 0:
    # Initialize
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    if START_EPOCH > 0:
        # Load checkpoint
        os.makedirs(CHECKPOINT_DIR, exist_ok=True)
        load_checkpoint(
            model,
            CHECKPOINT_DIR,
            CHECKPOINT_NAME,
            optimizer,
            START_EPOCH,
            device,
        )

train_loss_list = []
for epoch in range(START_EPOCH + 1, START_EPOCH + EPOCHS + 1):
    start_time = time.time()
    train_loss = 0

    for idx, (images, labels) in enumerate(train_loader):
        images = images.to(device)
        targets = tokenizer.sentences_to_tensor(
            labels,
            add_sos=True,
            add_eos=True,
            add_pad=True,
            max_length=MAX_SEQ_LEN,
        ).to(device)
        loss = train_step(model, optimizer, loss_function, images, targets)
        train_loss += loss

        if math.isnan(loss) or math.isinf(loss):
            print("Loss is {}, stopping training".format(loss))

        if idx % 50 == 0:
            print(
                "epoch {:03d}/{:03d}, batch {:03d}/{:03d}, loss {:.4f}".format(
                    epoch,
                    START_EPOCH + EPOCHS,
                    idx + 1,
                    len(train_loader),
                    loss,
                )
            )

    if math.isnan(loss) or math.isinf(loss):
        print(f"Loss is {loss}, stopping training")

    precision = evaluate(model)

    avg_loss = train_loss / len(train_loader)
    print(
        "epoch {:03d}/{:03d}, loss {:.4f}, precision {:.4f}, time {:.2f}s".format(
            epoch,
            START_EPOCH + EPOCHS,
            avg_loss,
            precision,
            time.time() - start_time,
        )
    )
    save_checkpoint(epoch, model, optimizer, CHECKPOINT_DIR, CHECKPOINT_NAME)

## Record

```
epoch 001/020, loss 0.6165, precision 0.9760, time 671.98s
epoch 002/020, loss 0.4132, precision 0.9884, time 668.52s
epoch 003/020, loss 0.4099, precision 0.9909, time 667.80s
epoch 004/020, loss 0.4090, precision 0.9935, time 663.10s
```

# Predict

In [13]:
epoch = None
load_checkpoint(
    model, CHECKPOINT_DIR, CHECKPOINT_NAME, epoch=epoch, device=device
)

Loaded checkpoint from ./ckpts/attn-v5/ckpt_004.pt


True

In [14]:
from IPython import get_ipython


def evaluate():
    test_dataset = DatasetGenerator(
        image_dir=IMAGE_DIR,
        image_size=IMAGE_SIZE,
        labels=test_labels,
        augmentation=False,
        return_name=True,
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=8,
        pin_memory=True,
    )

    os.makedirs(OUTPUT_DIR, exist_ok=True)
    output_path = os.path.join(OUTPUT_DIR, f"output.txt")

    match_count = 0
    total_count = 0
    with open(output_path, "w") as f:
        for idx, (images, labels, names) in enumerate(test_loader):
            images = images.to(device)
            model.eval()
            with torch.no_grad():
                outputs = model.inference(images, max_length=MAX_SEQ_LEN)
                sentence = tokenizer.tensor_to_sentence(outputs.argmax(2))
            for pred, label, name in zip(sentence, labels, names):
                total_count += 1
                if pred == label:
                    match_count += 1
                f.write(f"{name} {pred}\n")
    return match_count / total_count


ipy = get_ipython()
if ipy is not None:
    score = evaluate()
    print(f"Precision: {score}")

Precision: 0.99345


# Report

First, I have tried to extract the features with densenet and train only on the features.
However, the accuracy is only less than 1%.

Then, I have tried to also train the cnn model, and the accuracy is increased to 99.35%.

I have ported the code from tensorflow to pytorch as a practice.

I found that the shape written in Lab12-2_slide is wrong. The output shape of the context vector in attention layer should be (batch_size, embedding_dim) instead of (batch_size, hidden_dim).

I trained the model only for 4 epochs because it is enough to get the accuracy of 99.35%.
And I train in the extracted python script so there is no training output in this notebook.