-
Notifications
You must be signed in to change notification settings - Fork 0
/
HITL_in_GANs.py
429 lines (344 loc) · 17.6 KB
/
HITL_in_GANs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
import torch
from torch import nn
from torchvision import datasets, transforms
import math
import time
import logging
import matplotlib.pyplot as plt
import itertools
import numpy as np
from tqdm import tqdm
import torchvision.utils as vutils
import os
import textwrap
import torch.optim as optim
#optuna
import optuna
from optuna.trial import TrialState
from optuna.artifacts import FileSystemArtifactStore
from optuna.artifacts import upload_artifact
#optuna dashboard packages
from optuna_dashboard import save_note, register_objective_form_widgets, ChoiceWidget
from optuna_dashboard.artifact import get_artifact_path
torch.manual_seed(111)  # fix the RNG seed so weight init and noise sampling are reproducible
device = "cuda" if torch.cuda.is_available() else "cpu"  # prefer GPU when available
# Build the MNIST train/test data loaders.
def get_mnist_loaders(train_batch_size, test_batch_size):
    """
    Construct shuffled train and test DataLoaders for MNIST.

    Images are converted to tensors and normalized with mean 0.5 / std 0.5,
    i.e. scaled to roughly [-1, 1], matching the generator's Tanh output.
    The training set is downloaded into '../data' if not already present.

    :param train_batch_size: batch size used by the training loader
    :param test_batch_size: batch size used by the test loader
    :return: tuple (train_loader, test_loader)
    """
    mnist_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])
    train_set = datasets.MNIST('../data', train=True, download=True,
                               transform=mnist_transform)
    test_set = datasets.MNIST('../data', train=False,
                              transform=mnist_transform)
    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=train_batch_size, shuffle=True)
    test_loader = torch.utils.data.DataLoader(
        test_set, batch_size=test_batch_size, shuffle=True)
    return train_loader, test_loader
#architecture for the discriminator
# Binary classifier mapping a flattened 28x28 image to a real/fake probability.
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        # 784 -> 1024 -> 512 -> 256 -> 1 MLP; LeakyReLU + Dropout(0.2) between
        # layers, final Sigmoid squashes the score into (0, 1).
        layers = [
            nn.Linear(784, 1024), nn.LeakyReLU(), nn.Dropout(0.2),
            nn.Linear(1024, 512), nn.LeakyReLU(), nn.Dropout(0.2),
            nn.Linear(512, 256), nn.LeakyReLU(), nn.Dropout(0.2),
            nn.Linear(256, 1), nn.Sigmoid(),
        ]
        self.model = nn.Sequential(*layers)
    def forward(self, x):
        # Flatten (N, 1, 28, 28) images into (N, 784) vectors for the MLP.
        flat = x.view(x.size(0), 784)
        return self.model(flat)
# Generator: fully connected network that maps a 128-dim latent vector to a
# 1x28x28 image whose pixels lie in [-1, 1].
#architecture for the generator
class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        # 128 -> 256 -> 512 -> 1024 -> 784 MLP; the final Tanh bounds pixel
        # values to [-1, 1] to match the normalized real images.
        self.model = nn.Sequential(
            nn.Linear(128, 256), nn.LeakyReLU(),
            nn.Linear(256, 512), nn.LeakyReLU(),
            nn.Linear(512, 1024), nn.LeakyReLU(),
            nn.Linear(1024, 784), nn.Tanh(),
        )
    def forward(self, x):
        flat = self.model(x)
        # Reshape (N, 784) vectors back into (N, 1, 28, 28) image tensors.
        return flat.view(x.size(0), 1, 28, 28)
#function to train the discriminator
def train_discriminator(discriminator, images, real_labels, fake_images, fake_labels, criterion, d_optimizer):
    """
    Run one discriminator update on a batch of real and fake images.

    The loss is the sum of the criterion applied to the real batch (target
    ``real_labels``) and to the fake batch (target ``fake_labels``); a single
    backward pass and optimizer step are taken on that combined loss.

    :param discriminator: the discriminator network (nn.Module)
    :param images: batch of real images
    :param real_labels: 1-D target labels for real images (typically ones)
    :param fake_images: batch of generator outputs
    :param fake_labels: 1-D target labels for fake images (typically zeros)
    :param criterion: loss function, e.g. nn.BCELoss
    :param d_optimizer: optimizer over the discriminator's parameters
    :return: tuple (d_loss, real_score, fake_score) where the scores are the
        raw discriminator outputs for the real and fake batches
    """
    discriminator.zero_grad()
    # Real pass: discriminator should score these close to real_labels.
    outputs = discriminator(images)
    # unsqueeze(1) turns the (N,) label vector into (N, 1) to match the output shape.
    real_loss = criterion(outputs, real_labels.unsqueeze(1))
    real_score = outputs
    # Fake pass: discriminator should score these close to fake_labels.
    outputs = discriminator(fake_images)
    fake_loss = criterion(outputs, fake_labels.unsqueeze(1))
    fake_score = outputs
    # One step on the combined loss updates on both signals at once.
    d_loss = real_loss + fake_loss
    d_loss.backward()
    d_optimizer.step()
    return d_loss, real_score, fake_score
#function to train the generator
def train_generator(generator, discriminator_outputs, real_labels, criterion, g_optimizer):
    """
    Run one generator update from pre-computed discriminator scores.

    The generator improves when the discriminator scores its fakes as real,
    so the loss compares ``discriminator_outputs`` against ``real_labels``.
    Gradients flow back through the discriminator outputs into the generator,
    and one optimizer step is taken.

    :param generator: the generator network (nn.Module)
    :param discriminator_outputs: discriminator scores for generated images
    :param real_labels: 1-D tensor of "real" targets (typically ones)
    :param criterion: loss function, e.g. nn.BCELoss
    :param g_optimizer: optimizer over the generator's parameters
    :return: the scalar generator loss tensor
    """
    generator.zero_grad()
    # unsqueeze(1) reshapes (N,) labels to (N, 1) to match the score shape.
    targets = real_labels.unsqueeze(1)
    g_loss = criterion(discriminator_outputs, targets)
    g_loss.backward()
    g_optimizer.step()
    return g_loss
# Sample a batch of images from the generator and save them as one grid figure.
def generate_new_images(generator, sample_images, latent_dim, img_dir):
    """
    Sample random latent vectors, run them through the generator, and save
    the resulting images as a single 5-per-row grid figure.

    :param generator: generator network used to produce the images
    :param sample_images: number of images to generate
    :param latent_dim: dimensionality of the latent noise fed to the generator
    :param img_dir: file path the grid figure is saved to (passed to plt.savefig)
    """
    # Draw `sample_images` fresh latent vectors on every call.
    fixed_noise = torch.randn(sample_images, latent_dim).to(device)
    fake_images = generator(fixed_noise).to(device)
    plt.figure(figsize=(5, 5))
    plt.axis("off")
    plt.title("Generated Images")
    # make_grid lays the batch out 5 per row; transpose CHW -> HWC for imshow.
    plt.imshow(
        np.transpose(
            vutils.make_grid(fake_images, nrow=5, padding=1, normalize=True).cpu().numpy(),
            (1, 2, 0)
        )
    )
    plt.savefig(img_dir)
    # NOTE(review): plt.show() can block with interactive backends and pause
    # the training loop — confirm this is intended for HITL inspection.
    plt.show()
    plt.close()
#function to train GANs
def train_GANs(study: optuna.Study,
               artifact_store: FileSystemArtifactStore):
    """
    Run one Optuna trial: train a fresh GAN on MNIST with the trial's
    suggested hyperparameters, then upload a grid of generated images and a
    markdown note to the trial for human evaluation in the dashboard.

    :param study: Optuna study; a new trial is obtained via study.ask() and
        its objective values are later supplied by a human through the dashboard
    :type study: optuna.Study
    :param artifact_store: store that receives the generated-image grid
    :type artifact_store: FileSystemArtifactStore
    :return: tuple (g_loss, d_loss) as Python floats from the last batch
    """
    trial = study.ask()  # start a trial; its objectives are human-scored later
    print(f"running trial number: {trial.number}")
    latent_dim = 128  # must match the Generator's first Linear(128, 256) layer
    # fresh generator and discriminator for every trial
    discriminator = Discriminator().to(device=device)
    generator = Generator().to(device=device)
    # hyperparameters sampled by Optuna for this trial
    cfg = {
        "train_batch_size": trial.suggest_categorical("train_batch_size", [64, 128]),
        "device": "cuda" if torch.cuda.is_available() else "cpu",
        "num_epochs": 100,
        "lr": trial.suggest_float("lr", 1e-5, 1e-3, log=True),
        "optimizer": trial.suggest_categorical("optimizer", ["Adam", "AdamW"])
    }
    # data loader (the test loader is unused here)
    batch_size = cfg["train_batch_size"]
    train_loader, _ = get_mnist_loaders(batch_size, batch_size)
    # look the optimizer class up by name on torch.optim
    lr = cfg['lr']
    optimizer_name = cfg['optimizer']
    d_optimizer = getattr(optim, optimizer_name)(discriminator.parameters(), lr=lr)  # Instantiate optimizer from name
    g_optimizer = getattr(optim, optimizer_name)(generator.parameters(), lr=lr)  # Instantiate optimizer from name
    # BCE pairs with the discriminator's Sigmoid output
    criterion = nn.BCELoss()
    print(f"Batch Size: {batch_size}\nLearning Rate: {lr}\nOptimizer: {optimizer_name}")
    for epoch in range(cfg['num_epochs']):
        print(f"running epoch number: {epoch+1}")
        for n, (images, _) in tqdm(enumerate(train_loader)):
            images = images.to(device)
            real_labels = torch.ones(images.size(0)).to(device)
            noise = torch.randn(images.size(0), latent_dim).to(device)
            fake_images = generator(noise)
            fake_labels = torch.zeros(images.size(0)).to(device)
            # Train the discriminator on this real/fake pair
            d_loss, real_score, fake_score = train_discriminator(discriminator, images,
                                                                 real_labels, fake_images, fake_labels,
                                                                 criterion, d_optimizer)
            # Fresh fakes for the generator step, scored by the just-updated
            # discriminator.
            noise = torch.randn(images.size(0), latent_dim).to(device)
            fake_images = generator(noise)
            outputs = discriminator(fake_images)
            # Train the generator to push `outputs` toward the real label
            g_loss = train_generator(generator, outputs, real_labels, criterion, g_optimizer)
            # (n+1) % len(train_loader) == 0 holds only on the final batch,
            # so this logs exactly once per epoch.
            if (n+1) % len(train_loader) == 0:
                print('Epoch [%d/%d], Step[%d/%d], d_loss: %.4f, g_loss: %.4f, '
                      'D(x): %.2f, D(G(z)): %.2f'
                      % (epoch + 1, cfg['num_epochs'], n + 1, len(train_loader), d_loss.item(), g_loss.item(),
                         real_score.mean().item(), fake_score.mean().item()))
    # NOTE(review): relative "tmp/..." path — main() creates tmp/ next to this
    # file, so this assumes the script is run from that directory; confirm.
    img_path = f"tmp/generated_image-{trial.number}.png"
    generate_new_images(generator, 30, latent_dim, img_path)
    # Upload the grid and embed it in a dashboard note for the human reviewer.
    artifacts_id = upload_artifact(trial, img_path, artifact_store)
    artifact_path = get_artifact_path(trial, artifacts_id)
    # 4. Save Note
    note = textwrap.dedent(
        f"""\
        ## Trial {trial.number}
        Grid of GAN generated images!!
        ![generated-images]({artifact_path})
        d_loss: {d_loss.item():.2f}\n g_loss: {g_loss.item():.2f}
        """
    )
    save_note(trial, note)
    return g_loss.item(), d_loss.item()
#start optimisation
def start_optimization(artifact_store: FileSystemArtifactStore):
    """
    Create (or reopen) the HITL study, register the dashboard widgets that
    collect human scores, and keep launching GAN training trials so that at
    most `n_batch` trials are awaiting human evaluation at any time.

    This function loops forever; it never returns.

    :param artifact_store: store used by train_GANs to save generated-image grids
    :type artifact_store: FileSystemArtifactStore
    """
    # 1. Create Study — persisted in a local SQLite DB so the dashboard can
    # attach to the same storage; reopened across restarts via load_if_exists.
    storage = "sqlite:///db.sqlite3"
    study = optuna.create_study(study_name="HITL_with_optuna_for_digit_generation",
                                directions=['minimize', 'maximize'],
                                storage=storage,
                                load_if_exists=True)
    # 2. Set an objective name — one label per human-scored objective,
    # in the same order as `directions` above.
    study.set_metric_names(["Are you satisfied with the model's generator's performance?", "Are you satisfied with the discriminator's performance?"])
    # 3. Register ChoiceWidget
    register_objective_form_widgets(
        study,
        widgets=[
            # First objective ('minimize'): "Yes" maps to -1, so satisfied
            # reviews drive the objective down.
            ChoiceWidget(
                choices=["Yes 👍", "Somewhat 👌", "No 👎"],
                values=[-1, 0, 1],
                description="Please input your score for generated images!",
            ),
            # Second objective ('maximize'): "Yes" maps to +1.
            ChoiceWidget(
                choices=["Yes 👍", "Somewhat 👌", "No 👎"],
                values=[1, 0, -1],
                description="Please input your score for model performance!",
            ),
        ],
    )
    # 4. Start Human-in-the-loop Optimization
    n_batch = 6  # cap on trials concurrently RUNNING (i.e. awaiting human input)
    while True:
        running_trials = study.get_trials(deepcopy=False, states=(TrialState.RUNNING,))
        if len(running_trials) >= n_batch:
            time.sleep(1)  # Avoid busy-loop
            continue
        train_GANs(study, artifact_store)
#main function to define artifact store and start optimisation
def main():
    """
    Create the artifact store and scratch directories next to this script,
    then hand control to the human-in-the-loop optimization loop.
    """
    base_dir = os.path.dirname(__file__)
    # scratch directory for freshly generated image grids
    tmp_path = os.path.join(base_dir, "tmp")
    # 1. Create Artifact Store backed by a local directory
    artifact_path = os.path.join(base_dir, "artifact")
    artifact_store = FileSystemArtifactStore(artifact_path)
    print(f"paths : {tmp_path}, {artifact_path}")
    # Make sure both directories exist before any trial tries to write to them.
    for path in (artifact_path, tmp_path):
        if not os.path.exists(path):
            os.mkdir(path)
    # 2. Run optimize loop
    start_optimization(artifact_store)
#run the script
if __name__ == "__main__":
    main()  # entry point: set up stores/directories and start the HITL loop