# Генерация названий групп

Данные

1. Датасет с именами групп
2. Фильтруем те, что хотя бы начинаются с ascii символов, чтобы не попадал бред (там мало другого, ничего все равно не выйдет)
3. Данные нужны только как обучение нграмм, поэтому по сути это строчки, нам не важно их делить по группам, просто перемешать, тогда переносы строк мы сделаем eos и потом учтем при генерации
4. Для удобства возьмем только латиницу, пробел и перенос строки (конец слова)

In [2]:
import string
import random

import numpy as np
import pandas as pd

import torch as tt
import torch.nn as nn
from torch.autograd import Variable

from tqdm.autonotebook import tqdm

In [3]:
df = pd.read_csv('./death-metal/bands.csv')
df[df['name']<= 'Z'][['name']].to_csv('bands.txt', index=None, header=None)

In [4]:
with open('bands.txt') as f:
    names = f.readlines()
    random.shuffle(names)
    names = ''.join(names)
train = names[:int(len(names)*0.7)]
valid = names[int(len(names)*0.7):]

In [5]:
all_characters = ' '+'\n'+string.ascii_letters

In [7]:
def char_tensor(string):
    tensor = tt.zeros(len(string)).long()
    for ci in range(len(string)):
        try:
            tensor[ci] = all_characters.index(string[ci])
        except:
            pass
    return tensor

def random_training_set(chunk_len, batch_size, file):
    limit = len(file) - chunk_len
    inp = tt.LongTensor(batch_size, chunk_len)
    target = tt.LongTensor(batch_size, chunk_len)
    
    for bi in range(batch_size):
        start_index = random.randint(0, limit)
        chunk = file[start_index : start_index + chunk_len + 1]
        inp[bi] = char_tensor(chunk[:-1])
        target[bi] = char_tensor(chunk[1:])
    
    return inp, target

def perplexity(x):
    return 2**x

In [8]:
def _train_epoch(inp, target, model, optimizer, criterion, curr_epoch):

    decoder.train()
    hidden = decoder.init_hidden(batch_size)
    decoder.zero_grad()
    
    train_loss = 0
    perplexities = []
    
    for ci in range(chunk_len):
        optimizer.zero_grad()
        
        output, hidden = decoder(inp[:,ci], hidden)
        loss = criterion(output.view(batch_size, -1), target[:,ci])
        perplexities.append(perplexity(loss.item()))
        
        current_loss = loss.data.cpu().detach().item()
        loss_smoothing = ci / (ci+1)
        train_loss = loss_smoothing * train_loss + (1 - loss_smoothing) * current_loss
    
    loss.backward()
    optimizer.step()
    
    result_perplexity = np.mean(perplexities)
    return train_loss, result_perplexity

def _test_epoch(inp, target, model, criterion):
    
    model.eval()
    
    epoch_loss = 0
    loss = 0
    perplexities = []
    
    hidden = decoder.init_hidden(batch_size)
    
    with tt.no_grad():
        for ci in range(chunk_len):
            output, hidden = decoder(inp[:,ci], hidden)
            loss = criterion(output.view(batch_size, -1), target[:,ci])
            perplexities.append(perplexity(loss.item()))
            epoch_loss += loss.data.item()
    
    result_perplexity = np.mean(perplexities)
    return epoch_loss / chunk_len, result_perplexity


def nn_train(model, train, valid, criterion, optimizer, n_epochs=100, scheduler=None, early_stopping=0):
    
    print('EPOCH\tValid Loss\t Train Loss\tV.Perplexity\tT.Perplexity')
    
    best_epoch = None
    prev_loss = 100500
    es_epochs = 0
    train_losses = []
    valid_losses = []
    
    for epoch in tqdm(range(n_epochs)):
        try:
            train_loss, train_per = _train_epoch(*random_training_set(chunk_len, 
                                                                      batch_size, 
                                                                      train),
                                                 model, optimizer, criterion, epoch)
            valid_loss, valid_per = _test_epoch(*random_training_set(chunk_len, 
                                                                     batch_size, 
                                                                     valid),
                                                model, criterion)
            train_losses.append(train_loss)
            valid_losses.append(valid_loss)
            
            if epoch % 100 == 0 or epoch == n_epochs-1:
                print('%s \t %.5f \t %.5f \t %.5f \t %.5f' % (str(epoch),
                                                                 valid_loss,
                                                                 train_loss,
                                                                 valid_per,
                                                                 train_per))
        except KeyboardInterrupt:
            break
        
        if early_stopping > 0:
            if valid_loss > prev_loss:
                es_epochs += 1
            else:
                es_epochs = 0
            if es_epochs >= early_stopping:
                break
            prev_loss = min(prev_loss, valid_loss)

In [9]:
class MyModel(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        
        super(MyModel, self).__init__()
        self.encoder = nn.Embedding(input_size, hidden_size)
        self.rnn = nn.GRU(hidden_size, hidden_size, 1)
        self.decoder = nn.Linear(hidden_size, output_size)
        self.hidden_size = hidden_size

    def forward(self, input, hidden):
        batch_size = input.size(0)
        
        embed = self.encoder(input)
        output, hidden = self.rnn(embed.view(1, batch_size, -1), hidden)
        output = self.decoder(output.view(batch_size, -1))
        
        return output, hidden

    def init_hidden(self, batch_size):
        return tt.zeros(1, batch_size, self.hidden_size)

In [10]:
hidden_size = 128
batch_size = 32
chunk_len = 256 # побольше

decoder = MyModel(input_size=len(all_characters),
                  hidden_size=hidden_size, 
                  output_size=len(all_characters))

optimizer = tt.optim.Adam(decoder.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()

In [11]:
nn_train(decoder, train, valid, criterion, optimizer, n_epochs=1000, early_stopping=100)
tt.save(decoder, 'model.pt')

EPOCH	Valid Loss	 Train Loss	V.Perplexity	T.Perplexity


HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

0 	 3.87222 	 4.02740 	 14.65978 	 16.30981
100 	 2.76785 	 2.76734 	 6.92133 	 6.91242
200 	 2.76340 	 2.73453 	 6.88434 	 6.75855


  "type " + obj.__name__ + ". It won't be checked "


In [11]:
def generate(decoder, prime_str='\n', predict_len=30, temperature=0.8):
    hidden = decoder.init_hidden(1)
    prime_input = char_tensor(prime_str).unsqueeze(0)
    predicted = ''

    for p in range(len(prime_str) - 1):
        _, hidden = decoder(prime_input[:,p], hidden)
        
    inp = prime_input[:,-1]
    
    for p in range(predict_len):
        output, hidden = decoder(inp, hidden)
        output_dist = output.data.view(-1).div(temperature).exp()
        top_i = tt.multinomial(output_dist, 1)[0]
        predicted_char = all_characters[top_i]
        
        if predicted and predicted_char == '\n':
            break
        else:
            predicted += predicted_char
            inp = char_tensor(predicted_char).unsqueeze(0)

    return predicted

In [13]:
for x in range(100):
    print(generate(decoder))

Imbryiad
Sawncre
Mulphaw Inster Ins
Intesionar Kargons
Abrgge Pred Cargomtenc
ISchenc Cad
Kiclerepte
Eredonst
Grormerbutardar
Edombadong
Svadaumne
Mated
Dead
Varwace
Elvear
Mornphare
Paromeredenca
Premist
If Vheadarw
ERadeadong Embre
Gerns Caind tore
Lesorta us Cemnc
Gredemstamum
Sompeca
GM Insticcar Catha
Inctredarcwomiad
AKilarge
Rorosister
Efmpspmsaderbaaring Diad  Bere
Sa Crpitcarara
Diarg Bets
Propsy Ertatar
Ivicely a Kessitler
Edtar
Eng H
Portaatare
Iploncer
IHestarbon
Saruphicage
Chatorwuariaw Diad
Aong Verist
LIdmbonic
Cumudt
Comn
Oviargemer
Graapter
Echado of Vit
Everncy
Empshadora
Gortal Baw
Prophiarbad Bicang Bupige
IGprpads
Imbombs
Mil Had
IKOprtatatar
Paermad Wher 
Gred DScyne
Sponfonc
Dear
Margomeron
Incadynem
Groceremarcad Vermsmn
Ipncrememonon
Graw
jmfemonst Dres Cromars
Etaroo
Phitariaondart
Demoran
IL Diads
Gartrol
Furppets
SInfett Darnon
Pariatar Cawne
Sarewn
Burepsl
Murggemt Crsptormen
Igrapon
Ededernads
Gurone
Sadph Ins Bade
IqYamont Epembaton
IIHmphare
Goritembare