In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchsummary import summary
from torch.utils.data import DataLoader

import gensim.downloader

print('Accessories downloaded')

from tools.generator import NLPDataset
print('Generator imported')
from tools.custom_models import RNN,LSTM
print('Models imported')
from tools.processor_functions import train_with_translate,evaluate_with_translate,generate
print('Processor imported')


Accessories downloaded
Generator imported
Models imported
Processor imported


In [7]:
# Shell escape: runs training_script.py in a separate python3 process, so nothing
# it creates is available to later cells of this kernel.
# NOTE(review): this cell carries execution count 7 but sits between cells 1 and 2,
# i.e. it was run out of order; confirm it is still needed for Restart & Run All.
!python3 training_script.py

Starting training_script.py
tool generator import done
tool model imports done
tool train import done
tool val import done
Imports done
Datasets computed
Dataloaders ready
W2v model loaded
cuda:0
The model has 21,294,576 trainable parameters
  0%|                                                | 0/384639 [00:00<?, ?it/s]899.7579345703125


In [2]:
# Build the training and validation datasets from their serialized files
# (training first, then validation).
train_ds, val_ds = (NLPDataset(path) for path in ('data/train_1.pt', 'data/val.pt'))
print('Datasets computed')

# Both loaders rely on DataLoader's defaults (no explicit batch size or shuffling).
train_loader, val_loader = DataLoader(train_ds), DataLoader(val_ds)

Datasets computed


In [3]:
# Use the first CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

n_epochs = 10
# Used as both input and hidden size of the model below.
# Presumably matches the 300-d 'word2vec-google-news-300' vectors -- TODO confirm.
feature_num = 300

# Active architecture: the project's RNN wrapper.
model = RNN(input_size=feature_num, hidden_size=feature_num).to(device)
# Forwarded to the train/evaluate helpers as `n_hiddens`.
n_hidden = 1

# Alternative architectures that were tried (kept as real comments rather than
# bare string literals, which are evaluated and discarded at runtime):
#
#   model = nn.LSTM(input_size = 300,hidden_size = 300,batch_first = True).to(device)
#   n_hidden = 2
#
#   model = nn.Transformer(batch_first = True,nhead=10,d_model=feature_num).to(device)
#   n_hidden = 0

# Same value as the previous `10e-5`, written in normalized scientific notation.
lr = 1e-4
optimizer = optim.AdamW(model.parameters(), lr=lr)
# Summed (not averaged) squared error.
criterion = nn.MSELoss(reduction='sum')
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'The model has {num_params:,} trainable parameters')

cuda:0
The model has 180,600 trainable parameters


In [4]:
gs_w2v_pretrained = gensim.downloader.load('word2vec-google-news-300')

In [5]:
# RNN or LSTM training.
#
# One warm-up pass is run with epoch index -1 to obtain an initial `hidden`
# state; every later epoch feeds the previous `hidden` back in. The exact
# contract of train_with_translate / evaluate_with_translate lives in
# tools.processor_functions -- the argument order below mirrors the first
# evaluate call, which is the only form visible in this notebook.
hidden, loss = train_with_translate(
    model, gs_w2v_pretrained, optimizer, criterion, train_loader, -1, device,
    hidden_exists=True, n_hiddens=n_hidden, verbose=True,
)
evaluate_with_translate(
    model, gs_w2v_pretrained, optimizer, criterion, val_loader, -1, device,
    # Use the configured n_hidden instead of a literal 1 so evaluation stays in
    # step with training when the model is switched (e.g. LSTM uses 2).
    hidden_exists=True, n_hiddens=n_hidden, hidden=hidden,
)

for epoch in range(n_epochs):
    hidden, loss = train_with_translate(
        model, gs_w2v_pretrained, optimizer, criterion, train_loader, epoch, device,
        hidden_exists=True, n_hiddens=n_hidden, hidden=hidden, verbose=True,
    )
    # Fix: the embedding model was previously missing here, which shifted every
    # positional argument by one compared with the warm-up evaluate call above.
    evaluate_with_translate(
        model, gs_w2v_pretrained, optimizer, criterion, val_loader, epoch, device,
        hidden_exists=True, n_hiddens=n_hidden, hidden=hidden,
    )
    # Checkpoint the whole module after every epoch.
    # NOTE(review): the file name says "LSTM" even when `model` is the RNN;
    # the prefix is kept so existing checkpoint names do not change.
    torch.save(model, f'LSTM_model_epoch={epoch}_loss={loss}.pt')

# Transformer variant (disabled). Kept as comments instead of a bare string
# literal. NOTE(review): `train` / `evaluate` are not imported in this notebook.
#
#   loss = train(model,optimizer,criterion,train_loader,-1 ,device,hidden_exists = False)
#   for epoch in range(n_epochs):
#       loss = train(model,optimizer,criterion,train_loader,epoch,device,hidden_exists = False)
#       evaluate(model,optimizer,criterion,val_loader,epoch,device,hidden_exists = False)

  0%|                                                                                                     | 0/384639 [00:00<?, ?it/s]

1630.44921875


  0%|                                                                                         | 14/384639 [00:04<25:00:09,  4.27it/s]

 
  ##########PREDICTED##########:  
 involving_tugs_barges Banks_Tony_Yayo Jared_Sullinger, 
 ##########TARGET##########:  
 Chronicles no no


  0%|▏                                                                                         | 980/384639 [00:07<20:15, 315.64it/s]

58635.2890625


  0%|▏                                                                                      | 1048/384639 [00:55<30:04:20,  3.54it/s]

 
  ##########PREDICTED##########:  
 horns Vavruska hazel_eyes_flared Vavruska Vavruska hazel_eyes_flared Vavruska jarringly_incorrect Vavruska Vavruska Vavruska Vavruska Vavruska Vavruska Vavruska Vavruska Vavruska Vavruska Vavruska Vavruska Vavruska hazel_eyes_flared Vavruska Vavruska Vavruska Vavruska Vavruska Vavruska Vavruska HHWW jarringly_incorrect Vavruska Vavruska Vavruska Vavruska Vavruska Vavruska Vavruska hazel_eyes_flared Vavruska Vavruska Vavruska Vavruska Vavruska Vavruska Vavruska hazel_eyes_flared Vavruska Vavruska Vavruska hazel_eyes_flared Vizzutti Vavruska jarringly_incorrect hazel_eyes_flared Vavruska Vavruska Vavruska Vavruska Vavruska Vavruska hazel_eyes_flared Vavruska Vavruska Vavruska Vavruska Vavruska Vavruska Vavruska Vavruska Vavruska Vavruska hazel_eyes_flared Benzel Vavruska Vavruska Vavruska Vavruska Vavruska Vavruska Vavruska Vavruska Vavruska Vavruska Vavruska Vavruska Schieferstein Vavruska Vavruska Vavruska Vavruska hazel_eyes_flared Vavruska, 
 ###

  1%|▍                                                                                        | 1997/384639 [00:58<14:36, 436.36it/s]

10677.0732421875


  1%|▍                                                                                       | 2042/384639 [01:07<6:40:35, 15.92it/s]

 
  ##########PREDICTED##########:  
 Renewable_Obligation Jack_Niedenthal Jack_Niedenthal Jack_Niedenthal Jack_Niedenthal Jack_Niedenthal Renewable_Obligation Jack_Niedenthal Renewable_Obligation Jack_Niedenthal Atoll Atoll Jack_Niedenthal Operation_Valkyrie Jack_Niedenthal Jack_Niedenthal Jack_Niedenthal, 
 ##########TARGET##########:  
 is incompatible with strong acids halogenated acids oxidizers when exposed newly formed hydrogen it may The The


  1%|▋                                                                                        | 2977/384639 [01:09<16:03, 396.22it/s]

8490.625


  1%|▋                                                                                       | 3064/384639 [01:16<3:42:00, 28.65it/s]

 
  ##########PREDICTED##########:  
 Cheverus Cheverus Cheverus poisoned_baits Cheverus Cheverus Cheverus ENTERPRISE Cheverus Cheverus RSA_FraudAction_Research Cheverus Cheverus, 
 ##########TARGET##########:  
 on Palestinian attitudes during the Gulf war Journal Palestine Studies 3 sur sur


  1%|▉                                                                                        | 3966/384639 [01:19<17:38, 359.52it/s]

197268.25


  1%|▉                                                                                      | 4042/384639 [04:01<94:17:14,  1.12it/s]

 
  ##########PREDICTED##########:  
 Peter_Prodromou Commissioner_Jamie_Zaninovich Commissioner_Jamie_Zaninovich Kate_Bostock Kate_Bostock Kate_Bostock Kate_Bostock Kate_Bostock Kate_Bostock Kate_Bostock Kate_Bostock Kate_Bostock Kate_Bostock Peter_Prodromou Kate_Bostock Kate_Bostock Kate_Bostock Kate_Bostock Kate_Bostock Kate_Bostock Kate_Bostock Kate_Bostock Kate_Bostock Kate_Bostock Kate_Bostock Kate_Bostock Kate_Bostock Kate_Bostock Kate_Bostock Commissioner_Jamie_Zaninovich Kate_Bostock styling_cues Kate_Bostock Kate_Bostock Bertone Dasburg Kate_Bostock Peter_Prodromou Kate_Bostock Kate_Bostock Kate_Bostock Kate_Bostock Kate_Bostock Peter_Prodromou Kate_Bostock Peter_Prodromou Kate_Bostock Kate_Bostock Commissioner_Jamie_Zaninovich Kate_Bostock Kate_Bostock Peter_Prodromou Kate_Bostock Kate_Bostock Kate_Bostock Kate_Bostock Kate_Bostock Dasburg Kate_Bostock Kate_Bostock Kate_Bostock Kate_Bostock Nordstrom.com Kate_Bostock Kate_Bostock Peter_Prodromou Kate_Bostock Kate_Bostock Das

  1%|█▏                                                                                       | 4991/384639 [04:04<17:44, 356.65it/s]

1960.3035888671875


  1%|█▏                                                                                      | 5028/384639 [04:06<1:41:55, 62.07it/s]

 
  ##########PREDICTED##########:  
 Gamebird Gamebird Gamebird, 
 ##########TARGET##########:  
 Ellis Warren Warren


  2%|█▍                                                                                       | 5996/384639 [04:08<15:50, 398.48it/s]

76156.3046875


  2%|█▎                                                                                     | 6032/384639 [05:09<42:52:01,  2.45it/s]

 
  ##########PREDICTED##########:  
 SINGAPORE_Reuters Mile_Island RIO_TINTO_RIO.L Golden_Gate_Bridge Onchan_Commissioners ID_nN######## Mile_Island Mile_Island union_Bectu ID_nN######## Mile_Island RIO_TINTO_RIO.L FRANKFURT_Reuters Mile_Island Mile_Island Mile_Island Mile_Island Mile_Island LAGOS_Reuters Mile_Island Mile_Island Mile_Island Mile_Island Dreamspace Mile_Island Mile_Island Mile_Island Mile_Island Mile_Island Mile_Island Mile_Island Mile_Island Mile_Island Mile_Island Mile_Island Mile_Island Mile_Island Mile_Island Golden_Gate_Bridge Mile_Island Mile_Island Mile_Island Mile_Island Mile_Island Mile_Island Mile_Island Mile_Island SINGAPORE_Reuters Mile_Island union_Bectu www.bnd.com Mile_Island Mile_Island Mile_Island Mile_Island Mile_Island Mile_Island Mile_Island Mile_Island Mile_Island RIO_TINTO_RIO.L Telemark_Lodge Mile_Island substitute_Erik_Nevland Mile_Island Mile_Island Mile_Island Mile_Island Mile_Island Mile_Island Mile_Island Mile_Island Mile_Island union_Bectu I

  2%|█▌                                                                                       | 6979/384639 [05:12<19:30, 322.78it/s]

65764.8046875


  2%|█▌                                                                                     | 7040/384639 [06:03<32:50:30,  3.19it/s]

 
  ##########PREDICTED##########:  
 Insider_Trading Insider_Trading Insider_Trading thinnish persistant Insider_Trading Insider_Trading apprehends Dispelling_rumors pound_wahoo option_expiries apprehends Insider_Trading Insider_Trading Insider_Trading Insider_Trading option_expiries Insider_Trading persistant Insider_Trading Insider_Trading Insider_Trading Insider_Trading Insider_Trading persistant persistant Notch_kimberlite Insider_Trading Insider_Trading Insider_Trading Aravali_hills Insider_Trading Insider_Trading Insider_Trading Insider_Trading Insider_Trading option_expiries schnitzeling Insider_Trading Insider_Trading option_expiries Aravali_hills trico Insider_Trading Ashwani_Gujral_Market Insider_Trading Insider_Trading Insider_Trading Insider_Trading Insider_Trading Ashwani_Gujral_Market Insider_Trading Insider_Trading Insider_Trading Insider_Trading Insider_Trading Insider_Trading Insider_Trading Insider_Trading Insider_Trading uridashi_issuance Insider_Trading Insider_Tra

  2%|█▊                                                                                       | 7998/384639 [06:06<18:05, 347.09it/s]

639.5805053710938


  2%|█▊                                                                                       | 8034/384639 [06:07<47:35, 131.89it/s]

 
  ##########PREDICTED##########:  
 Michelle_Larcher_De, 
 ##########TARGET##########:  
 As


  2%|██                                                                                       | 8982/384639 [06:09<18:09, 344.73it/s]

95111.125


  2%|██                                                                                     | 9019/384639 [07:26<58:40:29,  1.78it/s]

 
  ##########PREDICTED##########:  
 NYCN Gumaa Syed_Bukhari PSRB PSRB PSRB Aysha_Bary goalkeeper_Jaime_Penedo sidestepped Mayor_Gene_Pielin Aysha_Bary PSRB Aysha_Bary Syed_Bukhari Mayor_Gene_Pielin Aysha_Bary Syed_Bukhari Syed_Bukhari sidestepped NYCN Afrikaner_nationalists sidestepped PSRB Aysha_Bary Aysha_Bary NYCN Afrikaner_nationalists goalkeeper_Jaime_Penedo Aysha_Bary PSRB goalkeeper_Jaime_Penedo sidestepped Sjodin_abduction Sjodin_abduction Sjodin_abduction PSRB goalkeeper_Jaime_Penedo PSRB PSRB sidestepped Afrikaner_nationalists Afrikaner_nationalists Aysha_Bary PSRB sidestepped PSRB goalkeeper_Jaime_Penedo Afrikaner_nationalists PSRB PSRB sidestepped sidestepped sidestepped NYCN PSRB Afrikaner_nationalists sidestepped PSRB sidestepped goalkeeper_Jaime_Penedo PSRB sidestepped goalkeeper_Jaime_Penedo sidestepped Afrikaner_nationalists sidestepped PSRB PSRB Syed_Bukhari PSRB PSRB Aysha_Bary Aysha_Bary Syed_Bukhari economist_Samir_Hleileh Niesiolowski NYCN NYCN goalkeeper_Jaime_

  3%|██▎                                                                                      | 9996/384639 [07:29<18:57, 329.36it/s]

644.109619140625


  3%|██▎                                                                                     | 10030/384639 [07:29<45:37, 136.82it/s]

 
  ##########PREDICTED##########:  
 amusement_arcade, 
 ##########TARGET##########:  
 Just


  3%|██▌                                                                                     | 10985/384639 [07:32<17:37, 353.29it/s]

83722.21875


  3%|██▍                                                                                   | 11057/384639 [08:35<31:50:47,  3.26it/s]

 
  ##########PREDICTED##########:  
 riotously_colorful Bittersweets Seo_Hee Bittersweets riotously_colorful riotously_colorful riotously_colorful Bittersweets Bittersweets riotously_colorful Bittersweets Seo_Hee riotously_colorful Bittersweets Bittersweets riotously_colorful riotously_colorful Bittersweets THE_COLOR Bittersweets THE_COLOR THE_COLOR DISTORTED Bittersweets Bittersweets riotously_colorful Bittersweets Bittersweets Seo_Hee Sparkling_Planet Bittersweets Bittersweets Bittersweets Bittersweets Bittersweets Bittersweets Bittersweets riotously_colorful Bittersweets Bittersweets Yoo Bittersweets Bittersweets Bittersweets Bittersweets Bittersweets Seo_Hee DISTORTED Bittersweets Bittersweets Bittersweets THE_COLOR Bittersweets Bittersweets riotously_colorful Bittersweets Bittersweets riotously_colorful Bittersweets Farrokhzad Bittersweets Bittersweets riotously_colorful Seo_Hee Bittersweets Bittersweets Bittersweets riotously_colorful Bittersweets Bittersweets Bittersweets Bitte

  3%|██▋                                                                                     | 11965/384639 [08:37<17:08, 362.36it/s]

636.5502319335938


  3%|██▊                                                                                     | 12041/384639 [08:38<34:11, 181.66it/s]

 
  ##########PREDICTED##########:  
 massive_sulfide, 
 ##########TARGET##########:  
 series


  3%|██▉                                                                                     | 12990/384639 [08:41<19:29, 317.70it/s]

43964.06640625


  3%|██▉                                                                                   | 13033/384639 [09:15<26:08:51,  3.95it/s]

 
  ##########PREDICTED##########:  
 Telstra_ADSL Telstra_ADSL Harness BIXI_bikes JW_Steakhouse Samoline Telstra_ADSL Telstra_ADSL Telstra_ADSL Truckstop.net Telstra_ADSL Telstra_ADSL Telstra_ADSL Telstra_ADSL Telstra_ADSL Telstra_ADSL Telstra_ADSL Telstra_ADSL Samoline Telstra_ADSL Telstra_ADSL Truckstop.net Telstra_ADSL Telstra_ADSL Samoline Telstra_ADSL Telstra_ADSL JW_Steakhouse Telstra_ADSL Telstra_ADSL Telstra_ADSL Telstra_ADSL Telstra_ADSL Telstra_ADSL Telstra_ADSL Telstra_ADSL Telstra_ADSL Telstra_ADSL Telstra_ADSL Telstra_ADSL Telstra_ADSL Telstra_ADSL Telstra_ADSL Telstra_ADSL Telstra_ADSL Telstra_ADSL Telstra_ADSL Telstra_ADSL Telstra_ADSL Telstra_ADSL Telstra_ADSL Telstra_ADSL Telstra_ADSL Telstra_ADSL Telstra_ADSL Telstra_ADSL Telstra_ADSL Telstra_ADSL Telstra_ADSL actively_pursuing Truckstop.net Telstra_ADSL Telstra_ADSL Telstra_ADSL Telstra_ADSL Telstra_ADSL Telstra_ADSL Telstra_ADSL Telstra_ADSL Samoline, 
 ##########TARGET##########:  
 with Weldon Irvine Simone turne

  4%|███▏                                                                                    | 13988/384639 [09:18<17:28, 353.58it/s]

8978.87890625


  4%|███▏                                                                                   | 14064/384639 [09:25<4:10:12, 24.68it/s]

 
  ##########PREDICTED##########:  
 dataport dataport dataport intermittency dataport Bounty_Hunter_Duane dataport Marni_Phillips Marni_Phillips dataport CNBC_Dennis_Kneale Marni_Phillips colleagues_Dewell dataport, 
 ##########TARGET##########:  
 MTV Video Music Awards is an annual awards ceremony established in by Hawaii Hawaii


  4%|███▍                                                                                    | 14964/384639 [09:28<16:15, 378.97it/s]

649.97705078125


  4%|███▍                                                                                    | 15042/384639 [09:28<30:24, 202.54it/s]

 
  ##########PREDICTED##########:  
 tailback_Jesse_Lumsden, 
 ##########TARGET##########:  
 Early


  4%|███▋                                                                                    | 15997/384639 [09:31<15:42, 391.07it/s]

42118.7890625


  4%|███▌                                                                                   | 16000/384639 [09:46<3:45:09, 27.29it/s]


KeyboardInterrupt: 

In [None]:
# Manual checkpoint, e.g. after interrupting the training cell.
# `epoch` and `loss` only exist once the training cell has got past its warm-up
# pass; if it was interrupted earlier they are undefined and the f-string would
# raise NameError. Fall back to the warm-up index (-1) and NaN in that case.
# When both names exist the file name is identical to the per-epoch checkpoints.
ckpt_epoch = globals().get('epoch', -1)
ckpt_loss = globals().get('loss', float('nan'))
torch.save(model, f'LSTM_model_epoch={ckpt_epoch}_loss={ckpt_loss}.pt')

In [None]:
# Reload the pretrained word2vec vectors and a saved model for inference.
gs_w2v_pretrained = gensim.downloader.load('word2vec-google-news-300')

# SECURITY: torch.load unpickles the file, which can execute arbitrary code --
# only load checkpoints you produced yourself.
# map_location lets a checkpoint saved from a CUDA run (see training above) be
# loaded on a CPU-only machine.
# NOTE(review): checkpoints in this notebook are saved as
# 'LSTM_model_epoch=..._loss=....pt'; confirm 'model.pt' is the intended file.
# NOTE(review): recent PyTorch releases default to weights_only=True, which
# rejects a fully pickled module; pass weights_only=False there if needed.
model = torch.load('model.pt', map_location=torch.device('cpu'))
# Put the module in inference mode (no-op for layers without train/eval
# behaviour). Assigned back so the cell does not echo the module repr.
model = model.eval()

In [None]:
# Generate a continuation of `prompt` on the CPU.
prompt = 'whisper words of wisdom'
device = torch.device("cpu")

# The model may still live on the GPU (it is moved to cuda:0 when built for
# training); put it on the same device as the hidden state created below.
model = model.to(device)

# Zero initial hidden state. 300 matches the hidden_size the model is built with;
# presumably shaped (num_layers, batch, hidden_size) -- TODO confirm against generate().
hidden = torch.zeros(1, 1, 300, device=device)

# The positional 10 is passed straight to generate(); presumably the number of
# words to produce -- TODO confirm.
generated_words = generate(
    model, gs_w2v_pretrained, prompt, 10, device,
    hidden_exists=True, n_hiddens=1, hidden=hidden,
)
# Display the result as the cell output.
generated_words