## Constants

In [5]:
# Size of the style vector produced by the encoder and consumed by the decoder.
style_size = 2
# Size of the content vector produced by the encoder and consumed by the decoder.
content_size = 200
# Size of one word embedding: the encoder LSTM input size and the decoder hidden size.
embedding_size = 304
# Fixed sentence length: passed to StyleDataset as sentence_size, used as the
# decoder output length, and used as the padding target in embed_sentence.
max_words = 15

## Data

In [6]:
from torch.utils.data import DataLoader
from load_data import load_data, StyleDataset
from sklearn.model_selection import train_test_split

In [7]:
import os

# Location of the fastText data handed to load_data. It can be overridden with
# the FASTTEXT_LOCATION environment variable so the notebook is not tied to one
# machine; the fallback is the path this notebook originally hard-coded.
fasttext_location = os.environ.get('FASTTEXT_LOCATION',
                                   'C:/Users/Vadim/Documents/data/wiki.simple')

# One corpus directory per style.
X, y, numerator = load_data([
    'data/trump',
    'data/musk'
], fasttext_location=fasttext_location)

# A fixed random_state keeps the train/test split reproducible between runs.
X_train, X_test, y_train, y_test = train_test_split(X, y, shuffle=True, random_state=420)

In [8]:
# Mini-batches of 10 training items, drawn in shuffled order. StyleDataset gets
# the embedding table from the numerator and the fixed sentence length.
train_loader = DataLoader(
    dataset=StyleDataset(X_train, y_train, numerator.embeddings, sentence_size=max_words),
    shuffle=True,
    batch_size=10,
)

## VadNet

In [9]:
import torch
from torch.nn.utils import rnn
from torch import nn
import torch.nn.functional as F
import numpy as np
from torch.autograd import Variable

In [10]:
# Global switch: set to True to see the diagnostic output of debug_print.
debug_prints = False


def debug_print(text):
    """Print ``text`` only while the module-level ``debug_prints`` flag is truthy."""
    if not debug_prints:
        return
    print(text)

In [11]:
class CycleLSTM(nn.Module):
    """LSTM cell unrolled for a fixed number of steps, fed with its own hidden state.

    At every step the current hidden state is used as the cell input. The
    emitted sequence consists of the hidden state seen *before* each step, so
    its first element is the initial hidden state passed to ``forward`` and the
    state computed by the last step is not emitted.
    """

    def __init__(self, hidden_size, output_size):
        """
        Args:
            hidden_size: size of the hidden/cell state (also the cell input size).
            output_size: number of steps to unroll, i.e. the output sequence length.
        """
        super().__init__()

        self.lstm_cell = nn.LSTMCell(hidden_size, hidden_size)
        self.output_size = output_size

    def forward(self, hc):
        """Unroll the cell starting from the ``(hidden, cell)`` pair ``hc``.

        Returns:
            Tensor of shape (batch, output_size, hidden_size).
        """
        emitted = []
        state = hc

        for _ in range(self.output_size):
            previous_hidden = state[0]
            emitted.append(previous_hidden)
            # The previous hidden state doubles as this step's input.
            state = self.lstm_cell(previous_hidden, state)

        # Stacked as (steps, batch, hidden); swap to batch-first.
        return torch.stack(emitted).transpose(0, 1)

In [12]:
class LSTMEncoder(nn.Module):
    """Encode a batch of embedded sentences into a content vector and a style vector.

    The sentence is run through a single-layer LSTM. The final hidden state and
    the final cell state are each projected linearly and the two projections
    are summed, once for the content code and once for the style code.
    """

    def __init__(self, input_size, hidden_size, content_size, style_size):
        """
        Args:
            input_size: size of one input embedding.
            hidden_size: size of the LSTM hidden and cell states.
            content_size: size of the returned content vector.
            style_size: size of the returned style vector.
        """
        super().__init__()

        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
        self.hidden_to_content = nn.Linear(hidden_size, content_size)
        self.cell_to_content = nn.Linear(hidden_size, content_size)
        self.hidden_to_style = nn.Linear(hidden_size, style_size)
        self.cell_to_style = nn.Linear(hidden_size, style_size)

        self.hidden_size = hidden_size

    def forward(self, text):
        """Encode ``text``.

        Args:
            text: tensor of shape (batch, words, input_size); it must live on
                the same device and have the same dtype as the module.

        Returns:
            Tuple ``(content, style)`` with shapes (batch, content_size) and
            (batch, style_size).
        """
        # nn.LSTM starts from all-zero hidden and cell states when none are
        # given, matching the module's own dtype and device, so no explicit
        # zero tensors (or a global CUDA flag) are needed here.
        #
        # NOTE(review): the input is deliberately not re-wrapped in Variable.
        # Since PyTorch 0.4 tensors work with autograd directly, and wrapping a
        # tensor produced by the decoder would detach it from the graph, so no
        # gradient would reach the decoder through this encoder.
        _, (hidden_state, cell_state) = self.lstm(text)

        # Drop the leading layer axis (a single layer, so its size is 1).
        hidden_state = hidden_state[0]
        cell_state = cell_state[0]

        content = self.hidden_to_content(hidden_state) + self.cell_to_content(cell_state)
        # The cell state goes through its own projection (cell_to_style), in
        # parallel with how the content code is built.
        style = self.hidden_to_style(hidden_state) + self.cell_to_style(cell_state)

        debug_print(str(content.shape) + ' ' + str(style.shape))

        return content, style

In [13]:
class LSTMDecoder(nn.Module):
    """Decode a (content, style) pair into a fixed-length sequence of embeddings.

    Content and style are each projected linearly to an initial hidden state
    and an initial cell state (the two projections are summed), and a CycleLSTM
    is unrolled from that state.
    """

    def __init__(self, content_size, style_size, hidden_size, output_size):
        """
        Args:
            content_size: size of the content vector.
            style_size: size of the style vector.
            hidden_size: size of the LSTM state, i.e. of each emitted vector.
            output_size: number of vectors to emit per sentence.
        """
        super().__init__()

        self.lstm = CycleLSTM(hidden_size, output_size)
        self.content_to_hidden = nn.Linear(content_size, hidden_size)
        self.content_to_cell = nn.Linear(content_size, hidden_size)
        self.style_to_hidden = nn.Linear(style_size, hidden_size)
        self.style_to_cell = nn.Linear(style_size, hidden_size)

    def forward(self, content, style):
        """Decode one batch.

        Args:
            content: tensor of shape (batch, content_size).
            style: tensor of shape (batch, style_size).

        Returns:
            Whatever the CycleLSTM emits for the derived initial state.
        """
        # Both arguments are used as given. Previously `style` was overwritten
        # with the content tensor, so the requested style was ignored.
        #
        # NOTE(review): the inputs are deliberately not re-wrapped in Variable.
        # Since PyTorch 0.4 tensors work with autograd directly, and wrapping
        # an encoder output would detach it from the graph.
        hidden_state = self.content_to_hidden(content) + self.style_to_hidden(style)
        cell_state = self.content_to_cell(content) + self.style_to_cell(style)

        output = self.lstm((hidden_state, cell_state))

        debug_print(output.shape)
        return output

In [14]:
# Use the GPU only when PyTorch reports one as available; a hard-coded True
# makes every later .cuda() call fail on CPU-only machines.
use_cuda = torch.cuda.is_available()

In [15]:
# Encoder LSTM state holds content and style together; both models use
# float64 parameters (.double()).
encoder = LSTMEncoder(embedding_size, content_size + style_size, content_size, style_size).double()
decoder = LSTMDecoder(content_size, style_size, embedding_size, max_words).double()

# Move the models to the GPU *before* building the optimizer: the torch.optim
# documentation asks for this order, because the optimizer must be constructed
# from the parameters that will actually be trained.
if use_cuda:
    encoder = encoder.cuda()
    decoder = decoder.cuda()

# A single optimizer updates the encoder and decoder jointly.
opt = torch.optim.Adam(list(encoder.parameters()) + list(decoder.parameters()))
vector_distance = F.mse_loss
sequence_distance = F.mse_loss  # TODO: switch to a sequence loss

RuntimeError: cuda runtime error (30) : unknown error at ..\src\THC\THCGeneral.cpp:70

In [None]:
import time
from tqdm import tqdm
from IPython.display import clear_output

def train_epoch(loss_history = None):
    """Run one full pass over ``train_loader`` and update encoder and decoder.

    For every batch the text is encoded into (content, style), the style
    vectors are rolled by one position along the batch axis, the decoder
    renders the content with the rolled style, and the result is encoded
    again. The loss is the sum of:

    * the distance between the encoded style and the true style label, and
    * the distance between (rolled style, content) and what the encoder
      recovers from the decoded text.

    Args:
        loss_history: optional list; the mean batch loss of this epoch is
            appended to it (an empty list is accepted).
    """
    # In each epoch, we do a full pass over the training data:
    start_time = time.monotonic()  # monotonic clock: safe for measuring durations
    encoder.train(True)  # enable dropout / batch_norm training behavior
    decoder.train(True)
    epoch_loss = 0
    epoch_size = 0
    for true_text, true_style in tqdm(train_loader):
        if use_cuda:
            true_text = true_text.cuda()
            true_style = true_style.cuda()

        content, style = encoder(true_text)
        # Roll by one along the batch axis: item i receives the style/content
        # of item i - 1 (the first item receives the last one's).
        style_shifted = torch.cat((style[-1:], style[:-1]), 0)
        content_shifted = torch.cat((content[-1:], content[:-1]), 0)

        transfered_text = decoder(content, style_shifted)
        transfered_content, transfered_style = encoder(transfered_text)
        #transfered_style_unshifted = torch.cat((transfered_style[1:], transfered_style[:1]), 0)
        #restored_text = decoder(transfered_content, transfered_style_unshifted)

        #cycle_consistency_loss = sequence_distance(true_text, restored_text)
        #transfer_similarity_loss = sequence_distance(true_text, transfered_text)
        style_encoding_loss = vector_distance(style, true_style)
        # Computed for experimentation only; currently NOT part of `loss`.
        content_diversity_loss = - vector_distance(content, content_shifted)
        autoencoder_loss = vector_distance(style_shifted, transfered_style) + vector_distance(content, transfered_content)

        loss = style_encoding_loss + autoencoder_loss

        # .cpu() first: .numpy() cannot be called on a CUDA tensor.
        epoch_loss += loss.data.cpu().numpy()
        loss.mean().backward()  # loss vectors are supported
        opt.step()
        opt.zero_grad()
        epoch_size += 1

        if use_cuda:
            torch.cuda.empty_cache()

    # An empty loader leaves the loss at 0 instead of dividing by zero.
    if epoch_size:
        epoch_loss /= epoch_size
    # `is not None` rather than truthiness: an empty history list must still
    # receive the first epoch's loss.
    if loss_history is not None:
        loss_history.append(epoch_loss)

    # Say no to visual pollution!
    clear_output()

    # Then we print the results for this epoch:
    print("Epoch took {:.3f}s".format(time.monotonic() - start_time))
    print("  training loss (in-iteration):" + str(epoch_loss) + "(mean " + str(np.mean(epoch_loss)) + ")")

In [None]:
# Accumulator handed to train_epoch, intended to collect one mean loss per epoch.
loss_history = []

In [57]:
# Run a single training epoch, passing the history list defined above.
train_epoch(loss_history)

  0%|                                                                               | 0/2768 [00:00<?, ?it/s]


CuDNNError: 8: b'CUDNN_STATUS_EXECUTION_FAILED'

In [36]:
def embed_sentence(sentence):
    """Turn a sentence into a (max_words, embedding_size) array of word embeddings.

    The sentence has to have single spaces between all tokens. It is
    lower-cased, every token is embedded with ``numerator.embed``, and the
    result is zero-padded to exactly ``max_words`` rows. Sentences with more
    than ``max_words`` tokens are truncated; previously they raised because
    the padding size became negative.
    """
    tokens = str(sentence).lower().strip().split(' ')[:max_words]
    embedded = np.vstack([numerator.embed(word) for word in tokens])
    padding = np.zeros((max_words - len(embedded), embedding_size))
    return np.vstack([embedded, padding])

In [38]:
def unembed_sentence(sentence_embeddings):
    """Map a sequence of embeddings back to a space-joined string of tokens.

    Each embedding is converted with ``numerator.unembed``; decoding stops at
    the first 'PAD' or 'EOS' token, which is not included in the result.
    """
    words = []
    for vector in sentence_embeddings:
        word = numerator.unembed(vector)
        if word in ('PAD', 'EOS'):
            break
        words.append(word)

    return ' '.join(words)

In [22]:
def transfer_style(sentence, target_style):
    """Re-render ``sentence`` with ``target_style`` and return the decoded text.

    Args:
        sentence: either a string (embedded with ``embed_sentence``) or an
            already embedded sentence as a 2-D array of word embeddings.
        target_style: 1-D style vector to decode with.

    Returns:
        The decoded sentence as produced by ``unembed_sentence``.
    """
    if isinstance(sentence, str):
        sentence = embed_sentence(sentence)

    # Build float64 tensors with a leading batch axis of size 1.
    sentence = torch.from_numpy(np.asarray(sentence, dtype=np.float64))[None]
    target_style = torch.from_numpy(np.asarray(target_style, dtype=np.float64))[None]

    # The models live on the GPU when use_cuda is set, so the inputs must too.
    if use_cuda:
        sentence = sentence.cuda()
        target_style = target_style.cuda()

    # The sentence's own style is discarded; only its content is re-decoded.
    content, _ = encoder(sentence)
    result_embeddings = decoder(content, target_style).cpu().data.numpy()

    return unembed_sentence(result_embeddings[0])

In [20]:
# One-hot target style vectors of length 2 (== style_size).
# NOTE(review): index 0 is presumably the 'data/trump' corpus and index 1 the
# 'data/musk' corpus, following the order passed to load_data — confirm against
# the labels load_data actually produces.
trump_style = np.array([1, 0])
musk_style = np.array([0, 1])

In [65]:
# Demo: decode a short phrase using musk_style as the target style vector.
transfer_style("best nukes", musk_style)

'again--my @cwbyfn12 SOS SOS @dsf2020 @dpc1975 @ff7429b1dbb84b6 @ff7429b1dbb84b6 @ff7429b1dbb84b6 @ff7429b1dbb84b6 @ff7429b1dbb84b6 @ff7429b1dbb84b6 @ff7429b1dbb84b6 @ff7429b1dbb84b6 @ff7429b1dbb84b6'

## And, finally

![Pick up](https://img.artlebedev.ru/everything/silver-lake/poo-sign.gif)

In [30]:
# Ask PyTorch to release its cached GPU memory (see the note below this cell).
torch.cuda.empty_cache()

RuntimeError: cuda runtime error (77) : an illegal memory access was encountered at torch/csrc/cuda/Module.cpp:245

Release the GPU memory for everyone to use