### This notebook runs everything in one go: train the model, save it wherever you want, resume training, and view predictions.


In [11]:
import torch

# Prefer the first CUDA GPU; fall back to the CPU when CUDA is unavailable so the
# notebook still runs (slowly) on machines without a GPU instead of crashing.
# NOTE(review): if load_checkpoint (presumably in predict.py) calls torch.load
# without map_location, a checkpoint saved on GPU may still fail to load on CPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

In [2]:
import re

import numpy as np
import torch

from baseline_preprocess_input import *
from models import *
from base_train import *


In [3]:
# Load the raw text corpus.
with open('data/sme-freecorpus.txt', 'r', encoding='utf-8') as f:
    text = f.read()

# Drop rare special symbols so they do not end up in the vocabulary.
# A single translate() pass deletes every listed character (same result as
# calling .replace(ch, '') once per character).
SPECIAL_CHARS = "¶•□§\uf03dπ●µº文中⅞½⅓¾¹³\t"
text = text.translate(str.maketrans('', '', SPECIAL_CHARS))
# remove ASCII digits
text = re.sub(r'[0-9]+', '', text)
# Remove Russian (Cyrillic) letters that appear in the data. Ё/ё (U+0401/U+0451)
# lie outside the А-я range (U+0410-U+044F), so they must be listed explicitly.
text = re.sub(r"[А-Яа-яЁё]", '', text)
# remove punctuation: anything that is neither a word character nor whitespace
text = re.sub(r"[^\w\s]", "", text)

# Build the vocabulary:
#   int2char: integer id -> character
#   char2int: character  -> integer id
# The characters are SORTED so that the id assignment is identical in every
# kernel session. Iterating a set of str follows hash order, and str hashes are
# randomized per process (PYTHONHASHSEED), so an unsorted tuple(set(text))
# changes the mapping after each restart. A checkpoint trained in one session
# would then be resumed or sampled with mismatched ids.
# NOTE: checkpoints produced with an unsorted mapping should be retrained.
chars = tuple(sorted(set(text)))
int2char = dict(enumerate(chars))
char2int = {ch: ii for ii, ch in int2char.items()}

# Encode the whole corpus as an array of integer ids.
encoded = np.array([char2int[ch] for ch in text])

In [4]:
# Model hyper-parameters.
n_hidden = 512   # size of each LSTM hidden state
n_layers = 3     # number of stacked LSTM layers
bidirectional = False

# Default values kept for reference only: neither is passed to the model below.
# The printed model shows dropout=0.5, presumably the LSTM default in models.py.
drop_prob = 0.5
lr = 1  # NOTE(review): unused here; the training call uses lr=0.0001

# Build one of the architectures from models.py. Its input/output size is
# presumably len(chars) (224 in the run shown).
model = LSTM(
    chars,
    device,
    bidirectional,
    n_hidden,
    n_layers,
)
print(model)

LSTM(
  (lstm): LSTM(224, 512, num_layers=3, batch_first=True, dropout=0.5)
  (dropout): Dropout(p=0.5, inplace=False)
  (fc): Linear(in_features=512, out_features=224, bias=True)
)


In [12]:
# Training configuration.
seq_length = 300
batch_size = 128
n_epochs = 5  # kept small for quick test runs

# Train and save to lstm3_epoch.pt; resume_from_saved=True continues from that
# checkpoint (the output shows "Resuming lstm3_epoch.pt from epoch 5").
# NOTE(review): with n_epochs=5 and a resume at epoch 5, confirm whether
# `epochs` is a total or an additional count. If it is a total, no new
# training happens.
train_and_save(
    model,
    encoded,
    device,
    model_name='lstm3_epoch.pt',
    epochs=n_epochs,
    batch_size=batch_size,
    seq_length=seq_length,
    lr=1e-4,
    resume_from_saved=True,
    bidirectional=bidirectional,
)

Resuming lstm3_epoch.pt from epoch 5 ...
Epoch: 5... Loss: 1.3181...


In [13]:
# Adam optimizer over the model's parameters (same learning rate as training);
# load_checkpoint below takes it as an argument.
opt = torch.optim.Adam(params=model.parameters(), lr=1e-4)

In [14]:
from predict import * 

# Restore the saved model. load_checkpoint returns six values; only the model
# (the first) is needed here. Indexing instead of unpacking into `_` avoids
# rebinding `_` in the notebook namespace. Rebinding it makes IPython stop
# updating its `_`/`__`/`___` output-history variables for the rest of the session.
model = load_checkpoint("lstm3_epoch.pt", model, opt)[0]

# Generate text starting from the prompt 'ja'. 2000 is presumably the sample
# length and top_k=5 presumably restricts sampling to the 5 most likely
# characters (see show_sample in predict.py).
show_sample(model, 2000, device, prime='ja', top_k=5)

'ja gozihanbálvalusas ja dušše go dat doaibmabijuiguin dárbbašuvvojit guovddáš servodateallima ja dutkandihke stuora dasa ahte sámi geavaheami dihtejit stuorra ovddideami doaitma mearrida sámi kultuvrra servodagas \n Sámediggi go sámegillii \nOlbmot dihto doarjja loahpalaččat \n Doarjjaortnega bokte \nSámediggi lea váldá dasa ahte sámi duoji oasit dan muhto guoskevaš ovddasvástádusa birra leat máksojuvvon go guovlluin leat dásseárvosuodjalusdearvvašvuođalága  ru sámi kulturmuitosuodjalusa oahppaneavtouorga johtui ja golmma ja ođđaáigásaš bušeahttamearkkaid giellaortnega dásis maid sámediggejoavku lea maiddái doarjaga oahpahusa deaŧalaš oassin lea dohkkehuvvon geavahit sámi ásahusaid birra mii lea dáin lága mat guoská oahppoplánaide ja dáiddáriiguin mat galget galget ollu go luondduguovddážis sáhttet ovddaskas stuorra doarjaga muhtun sámegielat mat leat mii oaidnit mielde dieđut guolástusdieđáhusa dan ovdánahttinguovlluid mielas lea meroštallojuvvon mii dain dahkkojuvvon sámi dáiddariik