# **NATIONALITY PREDICTION**

The goal of this notebook is to create a model that can predict nationalities from name strings.

In [7]:
import pandas as pd
import numpy as np
import country_converter as coco
import mlflow
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchinfo import summary
from utils.data import NameNationalityData, NameNationalityDataStream
from utils.model import RNN_Nationality_Predictor
from sklearn.metrics import roc_auc_score, average_precision_score
import lightning as L
from pytorch_lightning.loggers import MLFlowLogger
from lightning.pytorch.callbacks import LearningRateMonitor, EarlyStopping
from torch.optim.lr_scheduler import SequentialLR, LinearLR, CosineAnnealingLR, ConstantLR

# pick the fastest available backend: CUDA first, then Apple MPS, otherwise CPU
device: str = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using {device} device")

# parameters
SEED: int = 42 # seed for the numpy and torch random number generators
MAXIMUM_NAME_LENGTH: int = 40 # maximum number of characters
BATCH_SIZE: int = 2048 # number of training examples per batch
N_EVAL: int = 100 # evaluate loss every n batches
MAX_EPOCHS: int = 10

# hyperparameters
ARCHITECTURE = 'LSTM' # one of 'RNN', 'GRU' or 'LSTM'
EMBEDDING_DIM = 64 # number of dimensions of embedded tensor
HIDDEN_SIZE = 128 # number of neurons in hidden layer of rnn
NUM_RNN_LAYERS = 3 # number of stacked rnn layers
DROPOUT = 0.3 # dropout probability

# seed the global numpy and torch generators so a "Restart & Run All" is repeatable
np.random.seed(SEED)
torch.manual_seed(SEED)

# read country codes (one code per line)
# the encoding is given explicitly so the result does not depend on the OS locale
with open('./data/.country_codes', 'r', encoding='utf-8') as f:
    COUNTRY_CODES: list = f.read().splitlines()
print(f'Country codes: {", ".join(COUNTRY_CODES)}')

#read vocabulary (all unique characters used in the dataset)
# the vocabulary is mostly non-ASCII, so decoding with a locale-dependent default
# codec (e.g. cp1252 on Windows) would silently change the character -> index mapping
# NOTE(review): assumes the data files were written as UTF-8 - confirm
with open('./data/.vocabulary', 'r', encoding='utf-8') as f:
    VOCABULARY: str = f.read()
print(f'Vocabulary: {VOCABULARY}')

# generate country code mappings
target_class: str = 'UNregion' # see country_converter documentation on PyPI for available classes
COUNTRY_MAPPING: dict = {cc: coco.convert(names=cc, to=target_class) for cc in COUNTRY_CODES} 
# sorted so the printed class list is identical between runs (set order of strings is not)
print(f'Target classes: {", ".join(sorted(set(COUNTRY_MAPPING.values())))}')

Using mps device
Country codes: MY, CR, AZ, TM, AL, BW, MX, MO, NA, TN, AO, BG, UY, ZA, BF, NG, BD, BR, BE, CA, LY, IR, IE, KZ, FJ, EG, ID, IS, HU, IQ, FI, EE, PS, QA, PE, PR, SI, ES, HT, JO, IT, GH, PA, DE, KH, EC, ET, SY, PT, HR, JM, IL, DK, DJ, KR, HK, SV, SA, PL, RS, GE, GR, IN, HN, DZ, FR, SD, PH, SE, JP, GB, SG, RU, GT, KW, LT, BH, CL, TR, CZ, AE, CM, BI, AR, LB, LU, MD, CO, AF, CY, CN, OM, MA, MV, BN, YE, BO, AT, NL, MU, US, TW, CH, MT, NO
Vocabulary:  !#$%&()*-./:;<=ABCDEFGHIJKLMNOPQRSTUVWXYZ[\]^_`abcdefghijklmnopqrstuvwxyz{|}¡¢£¤¥¦§¨©ª«¬®¯°±´µ¶·¸º»¼½¾¿ÀÁÂÃÄÅÆÇÈÉÊËÌÍÎÏÐÑÒÓÔÕÖ×ØÙÚÛÜÝÞßàáâãäåæçèéêëìíîïðñòóôõöøùúûüýþÿĀāĂăĄąĆćĈĉĊċČčĎďĐđĒēĔĕĖėĘęĚěĜĝĞğĠġĢģĤĥĦħĨĩĪīĬĭĮįİıĲĳĴĵĶķĸĹĺĻļĽľĿŀŁłŃńŅņŇňŉŊŋŌōŎŏŐőŒœŔŕŖŗŘřŚśŜŝŞşŠšŢţŤťŦŧŨũŪūŬŭŮůŰűŲųŴŵŶŷŸŹźŻżŽžſƀƁƂƃƄƅƆƇƈƉƊƋƌƍƎƏƐƑƒƓƔƕƖƗƘƙƚƛƜƝƞƟƠơƢƣƤƥƦƧƨƩƪƫƬƭƮƯưƱƲƳƴƵƶƸƹƺƻƼƽƾƿǀǂǅǆǍǎǏǐǑǒǓǔǕǖǗǘǙǚǛǜǝǞǟǠǡǢǣǤǥǦǧǨǩǪǫǬǭǮǯǰǳǴǵǶǷǸǹǺǻǼǽǾǿȀȁȂȃȄȅȆȇȈȉȊȋȌȍȎȏȐȑȒȓȔȕȖȗȘșȚțȜȝȞȟȠȡȢȣȤȥȦȧȨȩȪȫȬȭȮȯȰȱȲȳȴȵȶȷȸȹȺȻȼȽȾȿɀɃɄɅɆɇɈɉɊɋɌɍɎɏɐɑɒɓɔɕɖɗɘəɚɛɜɝɞɟɠɡɢɣɤɥɦɧɨɩɪɫɭɮɯɱɲɳɴɵɶɷɸɹɺɽɾɿʀʁʂʃ

### **IMPORT DATA**

- train.csv gets streamed in chunks
- val.csv will be loaded into memory as a whole
- name strings will be encoded as integer tensors where index i maps to the i-th character in the vocabulary
- zero will be used as padding index, names longer than max_name_length will be truncated
- the tensors will have a shape of (batch_size, max_name_length)
- the dataset also generates a tensor of shape (batch_size) that holds the sequence length (number of characters) of each name in the batch
- countries will be converted to one-hot-encoded tensors of shape (batch_size, n_countries+1) where n_countries is the number of output classes in the COUNTRY_MAPPING dictionary

In [2]:
# Training set: train.csv is streamed in chunks of BATCH_SIZE rows rather than
# being loaded into memory as a whole.
train_data = NameNationalityDataStream(
    vocabulary=VOCABULARY,
    maximum_name_length=MAXIMUM_NAME_LENGTH,
    country_codes=COUNTRY_CODES,
    country_mapping=COUNTRY_MAPPING,
    data_file='./data/train.csv',
    chunksize=BATCH_SIZE,
)

# Batches of BATCH_SIZE examples are assembled by the DataLoader.
train_dataloader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE)

In [3]:
# Validation set: val.csv is loaded into memory as a whole and encoded with the
# same vocabulary / country mapping as the training stream.
val_data = NameNationalityData(
    vocabulary=VOCABULARY,
    maximum_name_length=MAXIMUM_NAME_LENGTH,
    country_codes=COUNTRY_CODES,
    country_mapping=COUNTRY_MAPPING,
    data_file='./data/val.csv',
)

# drop_last=True discards the final, incomplete batch.
# NOTE(review): shuffle=True makes Lightning warn about shuffled validation data;
# presumably it is kept so that every validation check (limited via limit_val_batches
# in the Trainer) sees a different random subset - confirm this is intended.
val_dataloader = DataLoader(
    dataset=val_data,
    batch_size=BATCH_SIZE,
    drop_last=True,
    shuffle=True,
)

Dataset has 4882176 records.


### **MODELING**

- Create simple model using character embeddings, rnn layers and a dense layer
- embedding layer maps input tensor of shape (batch_size, max_name_length) to embedding tensor of shape (batch_size, max_name_length, embedding_dim)
- the embedding tensor and sequence_lengths tensor will be used to [pack a padded batch](https://pytorch.org/docs/stable/generated/torch.nn.utils.rnn.pack_padded_sequence.html), which enables variable length inputs
- the packed sequence will be passed to the rnn layer 
- the hidden state of the last rnn layer will be passed through a dense layer to create an output of shape (batch_size, n_countries+1), where n_countries is the number of output classes in the COUNTRY_MAPPING dictionary

In [4]:
# PyTorch Lightning Wrapper
class LightningModelWrapper(L.LightningModule):
    """Lightning module that trains a name classifier and reports to MLflow.

    Args:
        mlflow_logger: logger providing ``experiment.log_artifact``, ``run_id``
            and ``log_hyperparams`` (an ``MLFlowLogger`` in this notebook).
        model: network called as ``model(X, seq_lengths)``. It must expose the
            attributes ``architecture``, ``embedding_dim``, ``hidden_size``,
            ``num_rnn_layers`` and ``dropout``, which are logged as hyperparameters.
        criterion: callable ``criterion(logits, y)`` returning the loss; either a
            plain function or a loss-module instance.
    """

    def __init__(self, mlflow_logger, model, criterion):
        super().__init__()
        self.mlflow_logger = mlflow_logger

        # log model summary and model hyperparameters
        self.model = model
        # explicit encoding: the torchinfo summary may contain non-ASCII characters,
        # which a non-UTF-8 default locale codec could fail to encode
        with open("model_summary.txt", "w", encoding="utf-8") as f:
            f.write(str(summary(self.model)))
        self.mlflow_logger.experiment.log_artifact(local_path="model_summary.txt", run_id=self.mlflow_logger.run_id)
        hyperparams = {
            "architecture": self.model.architecture,
            "embedding_dim": self.model.embedding_dim,
            "hidden_size": self.model.hidden_size,
            "num_rnn_layers": self.model.num_rnn_layers,
            "dropout": self.model.dropout
        }
        self.mlflow_logger.log_hyperparams(hyperparams)
        
        # log criterion
        self.criterion = criterion
        # functions carry __name__; loss-module instances do not, so fall back to the class name
        criterion_name = getattr(self.criterion, '__name__', type(self.criterion).__name__)
        self.mlflow_logger.log_hyperparams({'criterion': criterion_name})

    def training_step(self, batch):
        """Compute and log the training loss for one ``(X, y, seq_lengths)`` batch."""
        X, y, seq_lengths = batch
        logits = self.model(X, seq_lengths)
        loss = self.criterion(logits, y)
        self.log('train_loss', loss)
        return loss
    
    def validation_step(self, batch):
        """Compute and log validation loss plus micro-averaged ranking metrics."""
        X, y, seq_lengths = batch
        logits = self.model(X, seq_lengths)
        loss = self.criterion(logits, y)
        self.log('val_loss', loss)

        # scikit-learn works on CPU arrays; targets and scores are passed as
        # (n_samples, n_classes) indicator / score matrices
        y_true = y.detach().cpu().numpy()
        y_score = logits.detach().cpu().numpy()
        self.log('micro_average_precision_score', average_precision_score(y_true, y_score, average='micro'))
        # average='micro' pools every (sample, class) pair, which is what the metric
        # name promises. The previous call relied on the default average='macro' over
        # transposed inputs, and its multi_class argument has no effect on indicator
        # matrices. Micro-averaging also stays defined when a class column has no
        # positive example in the batch.
        self.log('micro_average_roc_auc_score', roc_auc_score(y_true, y_score, average='micro'))
        return loss

    def configure_optimizers(self):
        """Build AdamW with a warm-up -> cosine decay -> constant LR schedule, stepped per batch."""
        # learning rate schedule settings
        warmup_steps = 1000 # number of steps for warmup
        cosine_steps = 10000 # number of steps for cosine annealing
        max_lr = 1e-2 # maximum learning rate when warmup ends and cosine annealing starts
        min_lr = 1e-5 # minimum learning rate after cosine annealing ends

        # instantiate optimizer
        optimizer = optim.AdamW(self.parameters(), lr=max_lr)
        self.mlflow_logger.log_hyperparams({'optimizer': optimizer.__class__.__name__})

        # Warmup: scales LR from 1e-4×base to base LR over `warmup_steps`
        warmup_scheduler = LinearLR(optimizer, start_factor=1e-4, end_factor=1.0, total_iters=warmup_steps)

        # Cosine annealing: decays LR from base to η_min (1e-5) over `cosine_steps`
        cosine_scheduler = CosineAnnealingLR(optimizer, T_max=cosine_steps, eta_min=min_lr)

        # Constant phase: hold the LR at eta_min (base LR × min_lr/max_lr).
        # total_iters is an integer step count chosen to be effectively unreachable.
        constant_scheduler = ConstantLR(optimizer, factor=min_lr/max_lr, total_iters=10**10)

        # Combine all three using SequentialLR.
        scheduler = SequentialLR(
            optimizer,
            schedulers=[warmup_scheduler, cosine_scheduler, constant_scheduler],
            milestones=[warmup_steps, warmup_steps + cosine_steps]
        )
        return [optimizer], [{"scheduler": scheduler, "interval": "step", "frequency": 1}]

In [5]:
# MLflow logger for the 'Nationality Predictor' experiment
mlflow_logger = MLFlowLogger(log_model=True, experiment_name='Nationality Predictor')

# record the training parameters on the run
params = dict(max_epochs=MAX_EPOCHS, batch_size=BATCH_SIZE)
mlflow_logger.log_hyperparams(params)

# build the recurrent network and wrap it in the Lightning module
# (the "+1" on input and output sizes adds one extra index on top of the
# vocabulary characters / target classes)
lightning_model = LightningModelWrapper(
    mlflow_logger=mlflow_logger,
    criterion=F.binary_cross_entropy_with_logits,
    model=RNN_Nationality_Predictor(
        architecture=ARCHITECTURE,
        input_size=len(VOCABULARY)+1,
        output_size=len(set(COUNTRY_MAPPING.values()))+1,
        embedding_dim=EMBEDDING_DIM,
        hidden_size=HIDDEN_SIZE,
        num_rnn_layers=NUM_RNN_LAYERS,
        dropout=DROPOUT,
    ).to(device),
)

# callbacks: per-step learning-rate logging and early stopping on the validation loss
lr_monitor = LearningRateMonitor(logging_interval='step')
# patience counts checks, not steps, so changing val_check_interval in Trainer instantiation changes this behaviour
early_stopping = EarlyStopping(monitor='val_loss', patience=10)
callbacks = [lr_monitor, early_stopping]
mlflow_logger.log_hyperparams({'callbacks': ', '.join(type(cb).__name__ for cb in callbacks)})

# trainer: validate on N_EVAL batches every N_EVAL training batches and log at the same cadence
trainer = L.Trainer(
    logger=mlflow_logger,
    callbacks=callbacks,
    max_epochs=MAX_EPOCHS,
    val_check_interval=N_EVAL,
    limit_val_batches=N_EVAL,
    log_every_n_steps=N_EVAL,
    gradient_clip_val=1.0,
)

# run the training loop
trainer.fit(
    lightning_model,
    train_dataloaders=train_dataloader,
    val_dataloaders=val_dataloader,
)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name  | Type                      | Params | Mode 
------------------------------------------------------------
0 | model | RNN_Nationality_Predictor | 1.1 M  | train
------------------------------------------------------------
1.1 M     Trainable params
0         Non-trainable params
1.1 M     Total params
4.515     Total estimated model params size (MB)
6         Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/Users/haukesteffen/miniconda3/envs/LearningPyTorch/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:476: Your `val_dataloader`'s sampler has shuffling enabled, it is strongly recommended that you turn shuffling off for val/test dataloaders.
/Users/haukesteffen/miniconda3/envs/LearningPyTorch/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.
/Users/haukesteffen/miniconda3/envs/LearningPyTorch/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]



Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]



Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

In [6]:
# Test names for all 19 regions (expected class -> one well-known name)
# NOTE(review): the keys must match the class names the model can emit (the values
# of COUNTRY_MAPPING). "Oceania" may not be one of them - verify against the
# printed target classes.
test_names = {
    # Africa
    "Northern Africa": "Abdel Fattah el-Sisi", # Egypt
    "Middle Africa": "João Lourenço", # Angola
    "Western Africa": "Bola Ahmed Tinubu", # Nigeria
    "Eastern Africa": "Taye Atske Selassie", # Ethiopia
    "Southern Africa": "Cyril Ramaphosa", # South Africa

    # Asia
    "Central Asia": "Qassym-Schomart Kemeluly Toqajew", # Kazakhstan
    "Eastern Asia": "Xi Jinping", # China
    "South-Eastern Asia": "Prabowo Subianto", # Indonesia
    "Southern Asia": "Droupadi Murmu", # India
    "Western Asia": "Recep Tayyip Erdoğan", # Turkey

    # Europe
    "Northern Europe": "Ulf Kristersson", # Sweden
    "Western Europe": "Olaf Scholz", # Germany
    "Southern Europe": "Giorgia Meloni", # Italy
    "Eastern Europe": "Andrzej Sebastian Duda", # Poland

    # Americas
    "Northern America": "Donald Trump", # United States
    "Central America": "Andrés Manuel López Obrador", # Mexico
    "Caribbean": "Andrew Holness", # Jamaica
    "South America": "Luiz Inácio Lula da Silva", # Brazil

    # Oceania
    "Oceania": "Anthony Albanese" # Australia
}

# run test on test names
lightning_model.model.eval()
lightning_model.model.to(device)
tensor, length = train_data._encode_name(list(test_names.values()))
tensor = tensor.to(device)
# inference only: skip building the autograd graph
with torch.no_grad():
    logits = lightning_model.model(tensor, length)
countries_list = train_data._decode_country(logits)
preds = dict(zip(test_names.values(), countries_list))

# define column widths
name_width = 40
actual_width = 20
predicted_width = 20
correct_width = 10

# print output header
header = f"{'Name':<{name_width}} {'Actual Class':<{actual_width}} {'Predicted Class':<{predicted_width}} {'Correct?':<{correct_width}}"
print(header)
# rule spans the full header, including the single spaces between columns
print("-" * len(header))

# loop through test names and format outputs
total = 0
correct_count = 0
for actual_class, name in test_names.items():
    predicted_class = preds.get(name, "N/A")
    # compare case-insensitively: the decoded class is spelled "South-eastern Asia"
    # while the expected key is "South-Eastern Asia", so an exact match can never succeed
    is_correct = predicted_class.casefold() == actual_class.casefold()
    correct_str = "Yes" if is_correct else "No"
    if is_correct:
        correct_count += 1
    total += 1
    row = f"{name:<{name_width}} {actual_class:<{actual_width}} {predicted_class:<{predicted_width}} {correct_str:<{correct_width}}"
    print(row)
accuracy = (correct_count / total) * 100
print(f'\nAccuracy: {accuracy:.2f}%')

Name                                     Actual Class         Predicted Class      Correct?  
------------------------------------------------------------------------------------------
Abdel Fattah el-Sisi                     Northern Africa      Northern Africa      Yes       
João Lourenço                            Middle Africa        Southern Europe      No        
Bola Ahmed Tinubu                        Western Africa       Western Africa       Yes       
Taye Atske Selassie                      Eastern Africa       Western Europe       No        
Cyril Ramaphosa                          Southern Africa      Southern Africa      Yes       
Qassym-Schomart Kemeluly Toqajew         Central Asia         Western Asia         No        
Xi Jinping                               Eastern Asia         South-eastern Asia   No        
Prabowo Subianto                         South-Eastern Asia   Western Asia         No        
Droupadi Murmu                           Southern Asia        S