In [10]:
from torch import nn , optim
import pandas as pd
from torch.utils.data import Dataset, DataLoader
import numpy as np
import torch

In [11]:
surnames_df = pd.read_csv("data/surnames/surnames_with_splits.csv")

In [12]:
surnames_df.head()

Unnamed: 0,nationality,nationality_index,split,surname
0,Arabic,15,train,Totah
1,Arabic,15,train,Abboud
2,Arabic,15,train,Fakhoury
3,Arabic,15,train,Srour
4,Arabic,15,train,Sayegh


In [13]:
class Surname_vocabalory(object):
    
    def __init__(self,letter_to_ix,nationality_ix,max_surname_length):
    
        self.letter_to_ix = letter_to_ix
        self.natinality_ix = nationality_ix
        self.max_surname = max_surname_length
        
        print(self.letter_to_ix)
        print(self.natinality_ix)
        print(self.max_surname)
        
    def vectorize(self,word):
        
        surname_maxtrix = np.zeros((len(self.letter_to_ix.keys()),self.max_surname),dtype=np.float32)
        
        for index , letter in enumerate(word):
            surname_maxtrix[self.letter_to_ix[letter]][index] = 1
            
        return surname_maxtrix
    
    @classmethod
    def vocab_from_dataframe(cls,dataframe):
        
        letter_to_ix = {}
        nationality_ix = {}
        max_surname_length = 0
        for ix,row in dataframe.iterrows():
            
            if len(row.surname) > max_surname_length:
                max_surname_length = len(row.surname)
            
            for letter in row.surname:
                if letter not in letter_to_ix.keys():
                    index = len(letter_to_ix)
                    letter_to_ix[letter] = index
                    
            if row.nationality not in nationality_ix.keys():
                n_index = len(nationality_ix)
                nationality_ix[row.nationality] = n_index
        
        return cls(letter_to_ix,nationality_ix,max_surname_length)
        
        

In [14]:
class Surname_dataset(Dataset):
    def __init__(self,vocab,surnames_df):
        
        self.vocab = vocab
        self.all_dataset = surnames_df
        
        self.dataset = None
        
        class_counts = surnames_df.nationality.value_counts().to_dict()
        def sort_key(item):
            return self.vocab.natinality_ix[item[0]]
        sorted_counts = sorted(class_counts.items(), key=sort_key)
        frequencies = [count for _, count in sorted_counts]
        self.class_weights = 1.0 / torch.tensor(frequencies, dtype=torch.float32)

        self.set_split("train")
    
    def set_split(self,split):        
        self.dataset = self.all_dataset[self.all_dataset["split"] == split]
    
    def __len__(self):
        return len(self.dataset)
    
    def __getitem__(self,index):
        
        row = self.dataset.iloc[index]
        
        name = self.vocab.vectorize(row.surname)
        nationality_ix = self.vocab.natinality_ix[row.nationality]
        
        return {"x_data":name,
                "y_data":nationality_ix}
        

In [15]:
surname_dataset = Surname_dataset(Surname_vocabalory.vocab_from_dataframe(surnames_df),surnames_df)

{'T': 0, 'o': 1, 't': 2, 'a': 3, 'h': 4, 'A': 5, 'b': 6, 'u': 7, 'd': 8, 'F': 9, 'k': 10, 'r': 11, 'y': 12, 'S': 13, 'e': 14, 'g': 15, 'C': 16, 'm': 17, 'H': 18, 'i': 19, 'K': 20, 'n': 21, 'W': 22, 's': 23, 'f': 24, 'G': 25, 'M': 26, 'l': 27, 'B': 28, 'z': 29, 'N': 30, 'I': 31, 'w': 32, 'D': 33, 'Q': 34, 'j': 35, 'E': 36, 'R': 37, 'Z': 38, 'c': 39, 'Y': 40, 'J': 41, 'L': 42, 'O': 43, '-': 44, 'P': 45, 'X': 46, 'p': 47, ':': 48, 'v': 49, 'U': 50, '1': 51, 'V': 52, 'x': 53, '/': 54, 'q': 55, 'é': 56, 'É': 57, "'": 58, 'ç': 59, 'ê': 60, 'ß': 61, 'ö': 62, 'ä': 63, 'ü': 64, 'ú': 65, 'à': 66, 'ò': 67, 'è': 68, 'ó': 69, 'ù': 70, 'ì': 71, 'Ś': 72, 'ą': 73, 'ń': 74, 'á': 75, 'ż': 76, 'Ż': 77, 'ł': 78, 'õ': 79, 'ã': 80, 'í': 81, 'ñ': 82, 'Á': 83}
{'Arabic': 0, 'Chinese': 1, 'Czech': 2, 'Dutch': 3, 'English': 4, 'French': 5, 'German': 6, 'Greek': 7, 'Irish': 8, 'Italian': 9, 'Japanese': 10, 'Korean': 11, 'Polish': 12, 'Portuguese': 13, 'Russian': 14, 'Scottish': 15, 'Spanish': 16, 'Vietnamese': 1

In [16]:
class Surname_classifier(nn.Module):
    def __init__(self,in_channels,num_classes,num_channels):
        
        super(Surname_classifier,self).__init__()
        self.convnet = nn.Sequential(
            nn.Conv1d(in_channels=in_channels, 
                      out_channels=num_channels, kernel_size=3),
            nn.ELU(),
            nn.Conv1d(in_channels=num_channels, out_channels=num_channels, 
                      kernel_size=3, stride=2),
            nn.ELU(),
            nn.Conv1d(in_channels=num_channels, out_channels=num_channels, 
                      kernel_size=3, stride=2),
            nn.ELU(),
            nn.Conv1d(in_channels=num_channels, out_channels=num_channels, 
                      kernel_size=3),
            nn.ELU()
        )
        
        self.fc = nn.Linear(num_channels, num_classes)

    
    def forward(self,x_surname,apply_softmax=False):
        
        features = self.convnet(x_surname).squeeze(dim=2)
       
        prediction_vector = self.fc(features)

        if apply_softmax:
            prediction_vector = F.softmax(prediction_vector, dim=1)
        return prediction_vector

def generate_batches(dataset, batch_size, shuffle=True,
                     drop_last=True, device="cpu"):
    """
    A generator function which wraps the PyTorch DataLoader. It will 
      ensure each tensor is on the write device location.
    """
    dataloader = DataLoader(dataset=dataset, batch_size=batch_size,
                            shuffle=shuffle, drop_last=drop_last)

    for data_dict in dataloader:
        out_data_dict = {}
        for name, tensor in data_dict.items():
            out_data_dict[name] = data_dict[name].to(device)
        yield out_data_dict    
        
def compute_accuracy(y_pred, y_target):
    y_pred_indices = y_pred.max(dim=1)[1]
    n_correct = torch.eq(y_pred_indices, y_target).sum().item()
    return n_correct / len(y_pred_indices) * 100        

In [17]:
classifier = Surname_classifier(in_channels=len(surname_dataset.vocab.letter_to_ix.keys()), 
                               num_classes=len(surname_dataset.vocab.natinality_ix.keys()),
                               num_channels=256).to('cuda')

In [18]:
# dataset.class_weights = surname_dataset.class_weights.to('cuda')

loss_func = nn.CrossEntropyLoss()
optimizer = optim.Adam(classifier.parameters(), lr=0.001)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer,
                                           mode='min', factor=0.5,
                                           patience=1)

for _ in range(50):
    
    running_loss = 0.0
    running_acc = 0.0
    test_running_acc = 0.0
    classifier.train()
    surname_dataset.set_split("train")
    batch_generator = generate_batches(surname_dataset, batch_size=256, device='cuda')
    
    for batch_index, batch_dict in enumerate(batch_generator):
        
        optimizer.zero_grad()

        # step 2. compute the output
        y_pred = classifier(batch_dict['x_data'])

        # step 3. compute the loss
        loss = loss_func(y_pred, batch_dict['y_data'])
        loss_t = loss.item()
        running_loss += (loss_t - running_loss) / (batch_index + 1)

        # step 4. use loss to produce gradients
        loss.backward()

        # step 5. use optimizer to take gradient step
        optimizer.step()
        # -----------------------------------------
        # compute the accuracy
        acc_t = compute_accuracy(y_pred, batch_dict['y_data'])
        running_acc += (acc_t - running_acc) / (batch_index + 1)
    
    classifier.eval()
    surname_dataset.set_split("val")
    batch_generator = generate_batches(surname_dataset, batch_size=256, device='cuda')
    
    for batch_index, batch_dict in enumerate(batch_generator):
        
        y_pred = classifier(batch_dict["x_data"])
        
        test_acc_t = compute_accuracy(y_pred, batch_dict['y_data'])
        test_running_acc += (test_acc_t - test_running_acc) / (batch_index + 1)
    
    print(running_acc,test_running_acc)    

29.6484375 47.200520833333336
53.32031249999999 55.924479166666664
59.49218750000001 59.440104166666664
63.47656250000001 61.458333333333336
65.76822916666667 65.4296875
68.125 66.2109375
69.45312500000001 66.66666666666667
70.40364583333331 66.08072916666667
72.29166666666666 69.07552083333333
73.25520833333334 68.1640625
74.24479166666666 69.46614583333333
75.36458333333334 71.41927083333333
76.47135416666666 70.24739583333333
76.88802083333331 71.61458333333333
78.48958333333333 71.2890625
79.14062499999997 72.52604166666667
80.5078125 72.0703125
81.02864583333334 71.15885416666667
82.3046875 72.265625
82.46093750000001 71.22395833333333
83.203125 71.80989583333333
84.41406250000001 72.78645833333333
85.58593750000003 72.59114583333333
86.11979166666663 72.52604166666667
87.18750000000003 70.96354166666667
87.33072916666667 71.484375
88.69791666666667 72.59114583333333
89.40104166666667 71.484375
90.49479166666666 71.94010416666667
91.23697916666667 71.6796875
91.38020833333336 72.1

KeyboardInterrupt: 