In [None]:
import json
import os
from collections import Counter
from random import randrange

import numpy as np
import torch
import torch.nn as nn

class CharacterLanguageModel(nn.Module):
    """Two-layer GRU followed by a linear projection back to ``num_unique_chars`` logits.

    Inputs must have ``num_unique_chars`` features in their last dimension (the wrapper
    in this notebook feeds one-hot character vectors).

    Args:
        num_unique_chars: input feature size and number of output logits per step.
        hidden_state_dim: hidden size of each GRU layer.
    """

    def __init__(self, num_unique_chars, hidden_state_dim=512):
        super().__init__()
        # These attribute names are the state_dict keys; renaming them would break
        # loading of previously saved checkpoints.
        self.gru = nn.GRU(
            input_size=num_unique_chars,
            hidden_size=hidden_state_dim,
            num_layers=2,
        )
        self.linear_layer_1 = nn.Linear(hidden_state_dim, num_unique_chars)

    def forward(self, character_indices_tensor, h0=None):
        """Run the GRU and project every step to logits.

        Args:
            character_indices_tensor: sequence passed straight to ``nn.GRU``.
            h0: optional initial hidden state (``None`` lets the GRU use zeros).

        Returns:
            ``(logits, hidden)`` where ``hidden`` is detached from the graph, so feeding
            it back as ``h0`` does not backpropagate into the previous call.
        """
        gru_output, final_hidden = self.gru(character_indices_tensor, h0)
        logits = self.linear_layer_1(gru_output)
        return logits, final_hidden.detach()
    
class CharacterLanguageModelWrapper:
    """Vocabulary, training, sampling and checkpointing around ``CharacterLanguageModel``.

    Characters are one-hot encoded. Characters that are too rare in the training text
    (see ``load_unique_characters``) share one "unknown" slot: the last index of the
    one-hot vector, which is never produced when sampling.

    Args:
        context_window: characters fed to the model per optimizer step (the truncated
            backpropagation-through-time length).
        hidden_state_dim: hidden size of the GRU.
    """

    def __init__(self, context_window=300, hidden_state_dim=512):
        self.context_window = context_window
        self.hidden_state_dim = hidden_state_dim

    def load_data(self, file_path, encoding='utf-8'):
        """Read the whole training corpus from ``file_path`` into ``self.training_data``.

        The encoding is explicit so the file decodes the same way on every platform
        (``open``'s default depends on the locale).
        """
        with open(file_path, 'r', encoding=encoding) as file:
            self.training_data = file.read()

    def load_unique_characters(self, min_char_count=500):
        """Build the vocabulary from ``self.training_data``.

        Characters occurring more than ``min_char_count`` times get consecutive indices
        in order of first appearance. One extra slot (the last index) is reserved for
        every other character, hence ``num_unique_characters = len(vocab) + 1``.
        """
        char_counts = Counter(self.training_data)
        frequent_chars = [c for c, count in char_counts.items() if count > min_char_count]
        self.unique_characters = {c: i for i, c in enumerate(frequent_chars)}
        self.num_unique_characters = len(self.unique_characters) + 1

    def create_model(self):
        """Instantiate ``CharacterLanguageModel`` on CUDA when available, else on CPU."""
        has_cuda = torch.cuda.is_available()
        print('Has CUDA', has_cuda)
        self.device = torch.device('cuda' if has_cuda else 'cpu')
        self.model = CharacterLanguageModel(self.num_unique_characters, self.hidden_state_dim).to(self.device)

    def train_model(self, epochs=1, lr=0.0001, checkpoint_every=10,
                    checkpoint_prefix='warbreaker', sample_prompt='Working on Warbreaker '):
        """Train with Adam over consecutive windows of ``self.training_data``.

        Each step feeds ``context_window`` characters and predicts the character that
        follows each of them. The GRU hidden state is carried (detached) from one
        window to the next and reset at the start of every epoch. A trailing piece
        shorter than ``context_window + 1`` characters is not trained on.

        Every ``checkpoint_every`` epochs (starting with the first) the model is saved
        as ``{checkpoint_prefix}_epoch_{epoch}`` and a sample continuing
        ``sample_prompt`` is printed. Pass ``0`` or ``None`` to disable this.

        Returns:
            The summed training loss of each epoch, in order.

        Raises:
            ValueError: if ``context_window`` is smaller than 1 (the window would never advance).
        """
        if self.context_window < 1:
            raise ValueError('context_window must be at least 1')

        optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)
        window_len = self.context_window + 1  # inputs plus the targets shifted by one
        epoch_losses = []

        for epoch in range(epochs):
            offset = 0
            total_loss = 0
            h = None
            # `<=` so a final window ending exactly at the end of the text is used too.
            while offset + window_len <= len(self.training_data):
                optimizer.zero_grad()

                indices = self._text_to_index_tensor(self.training_data[offset:offset + window_len])
                prediction, h = self.model(self._one_hot(indices[:-1]), h)
                # Class-index targets give the same loss as one-hot probability targets
                # without materialising a float target matrix.
                loss = nn.functional.cross_entropy(prediction, indices[1:])
                loss.backward()
                optimizer.step()

                total_loss += loss.item()
                offset += self.context_window

            print(f'Epoch {epoch + 1}, Train Loss {total_loss}')
            epoch_losses.append(total_loss)

            if checkpoint_every and epoch % checkpoint_every == 0:
                self.store_model(f'{checkpoint_prefix}_epoch_{epoch}')
                print(self.predict(sample_prompt))

        return epoch_losses

    def predict(self, text, length=500):
        """Return ``text`` followed by ``length`` randomly sampled characters.

        The prompt is run through the model once. Each sampled character is then fed
        back with the carried hidden state, so each step processes one character
        instead of re-encoding the whole growing text. The unknown-character slot is
        excluded from sampling.

        Raises:
            ValueError: if ``text`` is empty (there is nothing to condition on).
        """
        if not text:
            raise ValueError('predict() needs a non-empty prompt')

        # Map from the stored indices, not from key order, so this stays correct
        # regardless of how the vocabulary dict was built or loaded.
        index_to_char = {i: c for c, i in self.unique_characters.items()}
        generated = [text]
        with torch.no_grad():  # inference only: no autograd graph needed
            logits, h = self.model(self._one_hot(self._text_to_index_tensor(text)))
            for step in range(length):
                probs = nn.functional.softmax(logits[-1], dim=0).cpu().numpy()
                known_probs = probs[:-1]  # drop the unknown-character slot
                res_index = int(np.random.choice(len(known_probs), p=known_probs / known_probs.sum()))
                res_char = index_to_char[res_index]
                generated.append(res_char)
                if step + 1 < length:  # no need to run the model after the last sample
                    logits, h = self.model(self._one_hot(self._text_to_index_tensor(res_char)), h)
        return ''.join(generated)

    def store_model(self, name):
        """Save weights to ``models/{name}.pt`` and the vocabulary to ``models/{name}.json``.

        The ``models`` directory is created if it does not exist yet.
        """
        os.makedirs('models', exist_ok=True)
        torch.save(self.model.state_dict(), f'models/{name}.pt')

        with open(f'models/{name}.json', 'w', encoding='utf-8') as f:
            json.dump({'unique_characters': self.unique_characters, 'num_unique_characters': self.num_unique_characters}, f)

    def load_model(self, name):
        """Restore the vocabulary and weights written by ``store_model(name)``.

        Weights are mapped onto this machine's device, so a checkpoint saved on a CUDA
        machine also loads on a CPU-only one. The model is left in eval mode.
        """
        with open(f'models/{name}.json', encoding='utf-8') as f:
            d = json.load(f)
        self.unique_characters = d['unique_characters']
        self.num_unique_characters = d['num_unique_characters']

        self.create_model()

        state_dict = torch.load(f'models/{name}.pt', map_location=self.device)
        self.model.load_state_dict(state_dict)
        self.model.eval()

    def _text_to_tensor(self, text_in):
        """One-hot encode ``text_in`` as a nested list of 0/1 ints, one row per character."""
        text_in_tensor = []
        for char in text_in:
            char_tensor = [0] * self.num_unique_characters
            char_tensor[self._get_char_index(char)] = 1
            text_in_tensor.append(char_tensor)
        return text_in_tensor

    def _text_to_index_tensor(self, text_in):
        """Vocabulary indices of ``text_in`` as a 1-D long tensor on ``self.device``."""
        return torch.tensor([self._get_char_index(c) for c in text_in], dtype=torch.long, device=self.device)

    def _one_hot(self, indices):
        """Float one-hot matrix of shape ``(len(indices), num_unique_characters)``."""
        return nn.functional.one_hot(indices, self.num_unique_characters).float()

    def _get_char_index(self, char):
        """Vocabulary index of ``char``; unknown characters map to the last slot."""
        return self.unique_characters.get(char, self.num_unique_characters - 1)


In [None]:
# Model training
def train(data_path='warbreaker.txt', epochs=100000, hidden_state_dim=512):
    """Build a vocabulary from ``data_path``, create a fresh model and train it.

    The defaults reproduce the original run. Returns the trained wrapper so later
    cells can sample from it or save it without reloading a checkpoint.
    """
    wrapper = CharacterLanguageModelWrapper(hidden_state_dim=hidden_state_dim)
    wrapper.load_data(data_path)
    wrapper.load_unique_characters()
    wrapper.create_model()
    wrapper.train_model(epochs=epochs)
    return wrapper


# NOTE: 100000 epochs over the full corpus is a very long run; pass a smaller
# `epochs` for a quick check.
trained_wrapper = train()


In [None]:
# Load a previously saved checkpoint and print a sample continuing the prompt.
CHECKPOINT_NAME = 'queen_epoch_1770'
SAMPLE_PROMPT = 'Patrik '

model = CharacterLanguageModelWrapper(hidden_state_dim=512)
model.load_model(CHECKPOINT_NAME)
sample_text = model.predict(SAMPLE_PROMPT)
print(sample_text)
