In [1]:
import copy
import os
import json
from torch.utils.data import DataLoader
%load_ext autoreload
%autoreload 2

In [2]:
from data_utils import merge_lists, gen_clean

# Directory holding the three raw JSON corpora, relative to the notebook.
DATA_DIR = os.path.join('..', 'data')


def _load_json(filename):
    """Read one JSON file from DATA_DIR and return the decoded object.

    Args:
        filename: File name inside DATA_DIR, e.g. ``'data_rap.json'``.

    Returns:
        Whatever ``json.load`` produces for that file.
    """
    # JSON text is UTF-8; being explicit avoids locale-dependent decoding
    # (e.g. cp1252 on Windows) when the corpora contain non-ASCII characters.
    with open(os.path.join(DATA_DIR, filename), 'r', encoding='utf-8') as fp:
        return json.load(fp)


lyrics_data = _load_json('data_lyrics.json')
news_data = _load_json('data_news.json')
raw_rap_data = _load_json('data_rap.json')

# Options forwarded to gen_clean; the same dict is reused for generator
# preprocessing further down the notebook.
gen_clean_control = {'lemmatize': True, 'stop_words': False, 'remove_number': True}

# Combined, cleaned text from all three corpora. It is passed to GENDataset
# later — presumably to build the shared vocabulary; TODO confirm in dataset.py.
all_data = gen_clean(
    merge_lists(lyrics_data) + merge_lists(news_data) + merge_lists(raw_rap_data),
    gen_clean_control,
)

In [3]:
from data_utils import add_some_music, gen_pre_data_preprocession, add_some_news, dis_pre_data_preprocession, get_dev_data

# --- split configuration -----------------------------------------------------
gen_percentage = 0.7      # percentage for generator pretraining from rap data
music_percentage = 0.1    # percentage of music added into rap lyrics
pre_dev_percentage = 0.1  # percentage handed to get_dev_data for the validation split

# --- generator side ----------------------------------------------------------
# Mix lyrics into the rap corpus, then derive the final-training set, the
# generator pre-training set and the rap portion reserved for the discriminator.
rap_music = add_some_music(raw_rap_data, lyrics_data, music_percentage)
final_train, gen_pre_all, dis_rap_raw = gen_pre_data_preprocession(
    rap_music,
    gen_percentage,
    gen_clean_control,
)

# --- discriminator side ------------------------------------------------------
# Combine the reserved rap data with news data before discriminator preprocessing.
rap_news = add_some_news(dis_rap_raw, news_data)
dis_pre_all = dis_pre_data_preprocession(rap_news)

# --- validation splits -------------------------------------------------------
# Each pre-training set is divided into (train, dev) by get_dev_data.
gen_pre, gen_pre_dev = get_dev_data(gen_pre_all, pre_dev_percentage)
dis_pre, dis_pre_dev = get_dev_data(dis_pre_all, pre_dev_percentage)

In [4]:
from sentence_transformers import SentenceTransformer
# Sentence encoder that is handed to DISDataset further down the notebook.
# NOTE(review): presumably fetches pretrained weights on first use — confirm
# that the environment has network access or a warm cache.
SENTENCE_ENCODER_NAME = 'bert-base-nli-mean-tokens'
sen_embed = SentenceTransformer(SENTENCE_ENCODER_NAME)

In [6]:
from dataset import GENDataset, DISDataset, basic_collate_fn
import argparse

# Run configuration. An earlier script version read these from argparse flags
# (--pre-train-epochs / --batch-size / --sequence-length, defaults 30 / 256 / 5);
# in the notebook they are plain dicts.
gen_args = dict(pre_train_epochs=10, batch_size=64, sequence_length=10)
dis_args = dict(pre_train_epochs=1)
dis_batch_size = 32

# Optional smoke-test subset — uncomment to run on a tiny slice of the data:
# gen_pre = gen_pre[0:512]
# gen_pre_dev = gen_pre
# dis_pre_dev = dis_pre[0:256] + dis_pre[-256:]
# final_train = gen_pre
# dis_pre = dis_pre[0:256] + dis_pre[-256:]

# Generator datasets: pre-training, pre-training validation, adversarial training.
gen_pre_data = GENDataset(gen_args, gen_pre, all_data)
gen_pre_dev_data = GENDataset(gen_args, gen_pre_dev, all_data)
final_train_data = GENDataset(gen_args, final_train, all_data)

# Discriminator datasets; both receive the sentence encoder.
dis_pre_data = DISDataset(dis_pre, sen_embed)
dis_pre_dev_data = DISDataset(dis_pre_dev, sen_embed)

# Generator loaders: batched for pre-training, one sample at a time for the
# validation and adversarial-training loaders.
gen_batch_size = gen_args["batch_size"]
gen_dataloader = DataLoader(gen_pre_data, batch_size=gen_batch_size)
gen_dev_loader = DataLoader(gen_pre_dev_data, batch_size=1)
final_loader = DataLoader(final_train_data, batch_size=1)

# Discriminator loaders.
dis_dataloader = DataLoader(
    dis_pre_data,
    batch_size=dis_batch_size,
    collate_fn=basic_collate_fn,
    shuffle=True,
)
# NOTE(review): unlike dis_dataloader, the dev loader uses the *generator*
# batch size and the default collate function — verify against the evaluation
# loop in train.py that this asymmetry is intended.
dis_dev_loader = DataLoader(dis_pre_dev_data, batch_size=gen_batch_size)

In [7]:
from model.generator import Generator
from model.discriminator import Discriminator

# Generator (LSTM) architecture settings.
lstm_input_size = 128
num_layers = 2
lstm_hidden_dim = 256
dropout = 0.1

# Discriminator hidden width.
dis_hidden_dim = 1024

# The generator is constructed from the pre-training dataset plus the LSTM
# settings; the discriminator only needs its hidden dimension.
generator = Generator(
    gen_pre_data,
    lstm_input_size,
    num_layers,
    lstm_hidden_dim,
    dropout,
)
discriminator = Discriminator(dis_hidden_dim)

In [78]:
from train import pre_train_generator, pre_train_discriminator

# Shared pre-training settings.
device = 'cpu'
pre_patience = 10

# Generator optimisation settings.
gen_loss_type = 'cross'
gen_optim_type = 'adam'
g_lr = 0.001
g_weight_decay = 0.00001

# Discriminator optimisation settings.
dis_loss_type = 'bce'
dis_optim_type = 'adam'
d_lr = 0.001
d_weight_decay = 0.00001

# Pre-train the generator first, then the discriminator. Each helper returns
# the model that the rest of the notebook continues with, so the names are rebound.
generator = pre_train_generator(
    gen_args, generator, gen_dataloader, gen_dev_loader,
    gen_loss_type, gen_optim_type, g_lr, g_weight_decay,
    pre_patience, device,
)
discriminator = pre_train_discriminator(
    dis_args, discriminator, dis_dataloader, dis_dev_loader,
    dis_loss_type, dis_optim_type, d_lr, d_weight_decay,
    pre_patience, device,
)

{'epoch': 0, 'batch': 0, 'loss': 11.137024879455566}
{'epoch': 0, 'batch': 1, 'loss': 11.124754905700684}
{'epoch': 0, 'batch': 0, 'loss': 0.7305691242218018}
{'epoch': 0, 'batch': 1, 'loss': 0.6041577458381653}
{'epoch': 0, 'batch': 2, 'loss': 1.2411816120147705}
{'epoch': 0, 'batch': 3, 'loss': 0.7519977688789368}
{'epoch': 0, 'batch': 4, 'loss': 1.626331090927124}
{'epoch': 0, 'batch': 5, 'loss': 0.8829678893089294}
{'epoch': 0, 'batch': 6, 'loss': 0.7123408317565918}
{'epoch': 0, 'batch': 7, 'loss': 0.8326322436332703}
{'epoch': 0, 'batch': 8, 'loss': 1.1079089641571045}
{'epoch': 0, 'batch': 9, 'loss': 0.8852540254592896}
{'epoch': 0, 'batch': 10, 'loss': 0.7931972742080688}
{'epoch': 0, 'batch': 11, 'loss': 0.6652122735977173}
{'epoch': 0, 'batch': 12, 'loss': 0.847244381904602}
{'epoch': 0, 'batch': 13, 'loss': 0.8188858032226562}
{'epoch': 0, 'batch': 14, 'loss': 0.7762765288352966}
{'epoch': 0, 'batch': 15, 'loss': 0.753746509552002}


  loss = loss_fn(res, torch.tensor(y).float())
  y_pred.append(torch.tensor(res))


In [79]:
def get_hyper_parameters():
    """Return the fixed settings used by the adversarial training cell.

    Returns:
        tuple: ``(g_para_list, d_para_list, num_epoch, patience, max_words, device)``.
        Each ``*_para_list`` is a list of optimizer-setting dicts with the keys
        ``optim_type``, ``lr`` and ``weight_decay``.
    """
    def _adam_grid():
        # Built fresh on every call so the generator and discriminator grids
        # never share list or dict objects.
        return [dict(optim_type='adam', lr=0.01, weight_decay=1e-4)]

    generator_grid = _adam_grid()
    discriminator_grid = _adam_grid()
    epochs, early_stop_patience, word_cap, run_device = 40, 10, 10, 'cpu'
    return (
        generator_grid,
        discriminator_grid,
        epochs,
        early_stop_patience,
        word_cap,
        run_device,
    )

In [81]:
import itertools
from train import train_model
from data_utils import plot_loss
from generate_rap import generate_rap
import numpy as np

g_para_list, d_para_list, num_epoch, patience, max_words, device = get_hyper_parameters()

# Grid search over optimizer settings for adversarial training.
# Fallbacks: if no combination is ever selected, the pre-trained models are kept
# and there are no stats to plot.
best_gen, best_dis, best_stats = copy.deepcopy(generator), copy.deepcopy(discriminator), None
best_g_para, best_d_para = None, None
best_dis_loss, best_gen_loss = float('-inf'), float('inf')

for g_para, d_para in itertools.product(g_para_list, d_para_list):
    # Train copies so every combination starts from the same pre-trained weights
    # rather than from whatever the previous combination left behind
    # (assumes train_model updates the models it is given in place — TODO confirm).
    g, d, stats = train_model(
        copy.deepcopy(generator), copy.deepcopy(discriminator),
        final_loader, final_train_data, num_epoch, g_para, d_para,
        gen_dev_loader, patience, max_words, device,
    )

    mean_dis_loss = np.mean(stats['dis_loss'])
    mean_gen_loss = np.mean(stats['gen_loss'])

    # A combination is selected only when it has BOTH a higher mean
    # discriminator loss and a lower mean generator loss than the current best.
    if mean_dis_loss > best_dis_loss and mean_gen_loss < best_gen_loss:
        best_dis_loss = mean_dis_loss
        best_gen_loss = mean_gen_loss
        best_gen, best_dis, best_stats = copy.deepcopy(g), copy.deepcopy(d), copy.deepcopy(stats)
        best_g_para, best_d_para = g_para, d_para

    # Report the settings that were actually searched; the previous message
    # printed placeholder variables that were never updated.
    print("\n\nBest generator parameters: {}, best discriminator parameters: {}".format(
        best_g_para, best_d_para))

print("Generator loss: {:.4f}".format(best_gen_loss))
print("Discriminator loss: {:.4f}".format(best_dis_loss))
if best_stats is not None:
    plot_loss(best_stats)
else:
    # Nothing was selected (e.g. empty grid or NaN losses), so there is no curve to draw.
    print("No hyper-parameter combination was selected; nothing to plot.")

------------------------ Start Training ------------------------
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
2

KeyboardInterrupt: 

In [70]:
from generate_rap import generate_rap

# Generation settings: seed sentence, number of sentences, and per-sentence word limit.
# Note that max_words rebinds the name set by the training cell (same value, 10).
sen_input, num_sentences, max_words = "i build a castle", 10, 10

# Requires best_gen from the adversarial-training cell and the final-training dataset.
lyrics = generate_rap(
    best_gen,
    sen_input,
    num_sentences,
    max_words,
    final_train_data,
)

# Print one generated sentence per line, terminated with a period.
for sen in lyrics:
    terminated = sen + '.'
    print(terminated)

0
1
2
3
4
5
6
7
8
9
i build a castle nickel drainage digit simile whatsoever extra budding bisque tee moan oracle.
emotionless sprightly philosophical shelf gynecology gleeful stalk nob vive.
attain turk quizzical adroit hysterical wonderfully define acolyte grasping nope blas.
plus singled bullock clan moderation craniopagus bandit cylindrical enormously.
brimstone golfer stardom seesaw stringent quart unheated catty backfire maidan herse.
kitten stagnation honest huh till fluid meadow morello coarseness indelible.
interior tractor pedophilia anticipate migraine affirmation cutlet resin.
substandard diminishment redouble sidewalk derrick vibrantly uniting disrespect.
hump cureless chub councilman pew orthodontics scraped populace potent rightist.
quorum hymen afterglow manicure hypochondriac crisscross nightfall attention schematic utopian.
