In [1]:
import torch
import pandas as pd
import json
from torchtext.vocab import vocab
from collections import Counter
from torch.utils.data import Dataset, DataLoader
from torch.nn import CosineSimilarity as cosine
from reimsiam import ReimsGen

### Load hyperparameters and letters vocab

In [2]:
# Load the shared configuration; encoding is explicit so reading does not
# depend on the platform's locale default.
with open("hyperparams.json", "r", encoding="utf-8") as f:
    hyperparams = json.load(f)

# Resolve the letters-vocab path from the config (which records a
# "letters_vocab" entry) instead of hard-coding it; fall back to the historical
# filename so older hyperparams.json files without the key still work.
# NOTE(review): torch.load unpickles arbitrary objects — only load trusted files.
letters_vocab = torch.load(hyperparams.get("letters_vocab", "letters_vocab"))

### Creating Words Vocab

In [3]:
# Display the loaded configuration (the bare last expression renders as the cell output).
hyperparams

{'numnegs': 3,
 'maxlen': 10,
 'model_name': 'Siamnet_params',
 'model_name_final': 'reims_gen_final',
 'words_vocab': 'words_vocab',
 'letters_vocab': 'letters_vocab',
 'model_name_omnistanza': 'Siamnet_params_omnistanza',
 'letters_vocab_omnistanza': 'letters_vocab_omnistanza'}

In [4]:
# Construct the rhyme generator from the model name and max length in the
# config plus the letters vocab, then run its preprocessing over reims.csv.
# NOTE(review): the meaning of preprocess_data's second argument (100) is not
# visible here — confirm against ReimsGen.preprocess_data and consider naming it.
reims_gen = ReimsGen(
    hyperparams["model_name"],
    letters_vocab,
    hyperparams["maxlen"],
)
reims_gen.preprocess_data("reims.csv", 100)

# Record the output filenames that the save step reads back from hyperparams.
hyperparams.update(
    {
        "model_name_final": "reims_gen_final",
        "words_vocab": "words_vocab",
        "letters_vocab": "letters_vocab",
    }
)

In [5]:
# Persist the generator's state dict and its words vocab under the filenames
# registered in hyperparams.
final_model_path = hyperparams["model_name_final"]
torch.save(reims_gen.state_dict(), final_model_path)

words_vocab_path = hyperparams["words_vocab"]
torch.save(reims_gen.words_vocab, words_vocab_path)

In [17]:
# Sample query: ask the generator for rhymes of "Schaum"; the returned list renders as output.
reims_gen.generate_reim("Schaum")

['raum', 'saum', 'traum', 'flaum', 'baum']

In [18]:
# Sample query: ask the generator for rhymes of "planete"; the returned list renders as output.
reims_gen.generate_reim("planete")

['beete', 'wehe', 'gebete', 'täte', 'we']

In [24]:
# Write the current configuration (including the newly registered output
# filenames) back to hyperparams.json, serialized with default JSON settings.
with open("hyperparams.json", "w") as config_file:
    config_file.write(json.dumps(hyperparams))

## Omnistanza

In [10]:
# Reload the configuration so this section starts from what is on disk;
# encoding is explicit so reading does not depend on the platform locale.
with open("hyperparams.json", "r", encoding="utf-8") as f:
    hyperparams = json.load(f)

# Resolve the omnistanza letters-vocab path from the config (which records a
# "letters_vocab_omnistanza" entry) rather than hard-coding it; fall back to the
# historical filename for configs that lack the key.
# NOTE(review): torch.load unpickles arbitrary objects — only load trusted files.
letters_vocab = torch.load(
    hyperparams.get("letters_vocab_omnistanza", "letters_vocab_omnistanza")
)

In [11]:
# Construct the omnistanza rhyme generator from its own model name, the
# omnistanza letters vocab and the shared max length, then preprocess
# reims_pairs.csv.
# NOTE(review): the meaning of preprocess_data's second argument (100) is not
# visible here — confirm against ReimsGen.preprocess_data.
omni_gen = ReimsGen(
    hyperparams["model_name_omnistanza"],
    letters_vocab,
    hyperparams["maxlen"],
)
omni_gen.preprocess_data("reims_pairs.csv", 100)

# Register the output filenames for the omnistanza artifacts.
hyperparams.update(
    {
        "omnistanza_final": "reims_omni_gen",
        "omnistanza_words_vocab": "omni_vocab",
    }
)

In [12]:
# Same sample query as the first model, now against the omnistanza generator, for comparison.
omni_gen.generate_reim("Schaum")

['wellenschaum', 'champagnerschaum', 'seifenschaum', 'flutenschaum', 'pflaum']

In [13]:
# Same sample query as the first model, now against the omnistanza generator, for comparison.
omni_gen.generate_reim("planete")

['anbete', 'klagete', 'agnette', 'tapete', 'annette']

In [14]:
# Persist the omnistanza generator's state dict and words vocab under the
# filenames registered in hyperparams.
omni_model_path = hyperparams["omnistanza_final"]
torch.save(omni_gen.state_dict(), omni_model_path)

omni_vocab_path = hyperparams["omnistanza_words_vocab"]
torch.save(omni_gen.words_vocab, omni_vocab_path)

In [15]:
# Display the configuration, now including the omnistanza output filenames.
hyperparams

{'numnegs': 3,
 'maxlen': 10,
 'model_name': 'Siamnet_params',
 'model_name_final': 'reims_gen_final',
 'words_vocab': 'words_vocab',
 'letters_vocab': 'letters_vocab',
 'model_name_omnistanza': 'Siamnet_params_omnistanza',
 'letters_vocab_omnistanza': 'letters_vocab_omnistanza',
 'omnistanza_final': 'reims_omni_gen',
 'omnistanza_words_vocab': 'omni_vocab'}

In [16]:
# Write the configuration (with the omnistanza entries) back to hyperparams.json,
# serialized with default JSON settings.
with open("hyperparams.json", "w") as config_file:
    config_file.write(json.dumps(hyperparams))