In [32]:
import torch
import datasets, networks, sampling, completion, evaluation
import matplotlib.pyplot as plt
from torch import nn

In [33]:
# ----- Hyperparameters -----

# Number of passes made over the training data by the training loop.
num_epochs = 25

# TODO: select an optimizer -- one of 'adam', 'adamw', 'rmsprop'.
# Any other value makes the training cell fall back to SGD with momentum.
optimizer_option = 'adam'

# TODO: select a learning-rate scheduler -- one of 'step', 'cosine', 'exponential'.
lr_scheduler_option = 'step'

# TODO: select a batch size.
batch_size = 64

# TODO: select a learning rate.
lr = 1e-5

# Network size settings.
# NOTE(review): only num_kernels is passed to networks.PixelCNN in the visible
# cells; num_residual appears unused there -- confirm whether it should be wired in.
num_residual, num_kernels = 8, 64

In [34]:
# Build the 'mnist' dataset wrapper with the configured batch size, then fetch
# its train and test data loaders (train first, test second).
ds = datasets.Dataset('mnist', batch_size=batch_size)
training_data, test_data = ds.get_train_data_loader(), ds.get_test_data_loader()

# Optional: uncomment to preview the training data.
# ds.visualize_dataset(training_data)

In [35]:
# Training
import time
from torch.optim.lr_scheduler import StepLR, ExponentialLR, CosineAnnealingLR
from torch.optim import Adam, AdamW, RMSprop, SGD 
from torch.autograd import Variable

# Build the model; num_kernels comes from the hyperparameter cell.
PixelCNN = networks.PixelCNN(num_kernels=num_kernels)

# Select device: use the first CUDA GPU when one is available, otherwise fall
# back to the CPU so the notebook still runs on machines without a GPU
# (a hard-coded 'cuda:0' fails at PixelCNN.to(device) on such machines).
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')
print('=========================================')

PixelCNN.to(device)

# Bookkeeping filled in by the training loop and read by the plots below.
learning_rates = []      # learning rate recorded after every optimizer step
train_loss_curve = []    # averaged training loss at each logging point
test_loss_curve = []     # test loss per epoch (stays empty while evaluation is disabled)
train_loss_epochs = []   # x positions matching train_loss_curve
test_loss_epochs = []    # x positions matching test_loss_curve
optimizer = None         # assigned by the optimizer selection that follows
criterion = nn.CrossEntropyLoss()

# Select optimizer: map each supported option name to its torch.optim class.
optimizer_classes = {
  "adam": torch.optim.Adam,
  "adamw": torch.optim.AdamW,
  "rmsprop": torch.optim.RMSprop,
}
optimizer_cls = optimizer_classes.get(optimizer_option)

if optimizer_cls is None:
  # Unrecognized option: fall back to SGD with momentum.
  optimizer = torch.optim.SGD(PixelCNN.parameters(), lr, momentum=0.9)
else:
  optimizer = optimizer_cls(PixelCNN.parameters(), lr)


# Select scheduler. `scheduler` stays None when lr_scheduler_option is not one
# of 'step', 'exponential' or 'cosine', so any code stepping it must check for None.
scheduler = None
if lr_scheduler_option == 'step':
    # Multiply the learning rate by 0.9 every 2 scheduler steps.
    scheduler = StepLR(optimizer, step_size=2, gamma=0.9)

elif lr_scheduler_option == 'exponential':
    # Multiply the learning rate by 0.9 on every scheduler step.
    scheduler = ExponentialLR(optimizer, gamma=0.9)

elif lr_scheduler_option == 'cosine':
    # eta_min is the floor the schedule anneals towards, so it must lie below
    # the starting learning rate. A fixed 0.0001 is larger than a small `lr`
    # (e.g. 0.00001) and would make the schedule *increase* the rate; cap the
    # floor at one tenth of `lr` in that case. For lr >= 0.001 this is unchanged.
    # NOTE(review): T_max=50 is independent of num_epochs; with fewer than 50
    # scheduler steps the floor is never reached -- confirm this is intended.
    scheduler = CosineAnnealingLR(optimizer, T_max=50, eta_min=min(0.0001, lr * 0.1))

overall_start_time = time.time()

# Batches per epoch; used to place logged losses at fractional epoch positions
# so the loss curve does not stack every point of an epoch on one x value.
batches_per_epoch = len(training_data)

# Training loop: one pass over training_data per epoch, logging the mean loss
# every 100 batches, stepping the LR scheduler and drawing samples per epoch.
for epoch in range(num_epochs):

  epoch_start_time = time.time()
  # Switch to training mode.
  PixelCNN.train()

  losses = []

  for batch_idx, (images, labels) in enumerate(training_data):

    images = images.to(device)
    # Integer class targets from channel 0 scaled by 255.
    # round() before long(): long() truncates, and the float scaling can land
    # just below the intended integer, which would shift the class down by one.
    # Assumes pixel values are in [0, 1] -- TODO confirm against the dataset.
    target = (images[:, 0, :, :] * 255).round().long()

    optimizer.zero_grad()
    output = PixelCNN(images)

    loss = criterion(output, target)
    loss.backward()
    # torch.nn.utils.clip_grad_norm_(PixelCNN.parameters(), 1)
    optimizer.step()

    # Record the current learning rate (one entry per parameter group per
    # batch); the learning-rate plot averages these per epoch.
    for param_group in optimizer.param_groups:
      learning_rates.append(param_group['lr'])

    losses.append(loss.detach().clone())

    if batch_idx % 100 == 0:
      # Mean of the losses accumulated since the previous logging point.
      average_loss = torch.stack(losses).mean().item()
      train_loss_curve.append(average_loss)
      # Fractional epoch position: reaches epoch + 1 at the last batch.
      train_loss_epochs.append(epoch + (batch_idx + 1) / batches_per_epoch)
      losses = []
      print(f'Epoch: {epoch + 1:3d}/{num_epochs:3d}, Batch {batch_idx + 1:5d}, Loss: {average_loss:.4f}')

  # Advance the learning-rate schedule once per epoch. Without this call the
  # configured scheduler never changes the learning rate. The scheduler is
  # None when lr_scheduler_option is not recognized.
  if scheduler is not None:
    scheduler.step()

  epoch_end_time = time.time()
  print('-----------------------------------------')
  print(f'Epoch: {epoch + 1:3d} took {epoch_end_time - epoch_start_time:.2f}s')
  # test_loss = evaluation.evaluate(model=PixelCNN, test_data_loader=test_data, device=device, batch_size=batch_size)
  # test_loss_curve.append(test_loss)
  # test_loss_epochs.append(epoch + 1)
  # print(f'Epoch: {epoch + 1:3d}, Test Loss: {test_loss:.4f}')
  # print('-----------------------------------------')

  # Draw samples with the model in evaluation mode; PixelCNN.train() at the
  # top of the loop switches back to training mode for the next epoch.
  PixelCNN.eval()
  sampling.samplemnist(PixelCNN, num_samples=5)

overall_end_time = time.time()
print('=========================================')
print(f'Training took {overall_end_time - overall_start_time:.2f}s')

# Loss Curve Plot: training loss as a line, test loss as red markers.
loss_fig, loss_ax = plt.subplots(figsize=(10, 5))
loss_ax.plot(train_loss_epochs, train_loss_curve, label='Train Loss')
loss_ax.scatter(test_loss_epochs, test_loss_curve, color='red', label='Test Loss')
loss_ax.set_xlabel('Epochs')
loss_ax.set_ylabel('Loss')
loss_ax.set_title('Loss Curve')
loss_ax.legend()
plt.show()

# Learning Rate Plot: mean recorded learning rate for each epoch that ran.
num_batches = len(training_data)
# The training loop records one value per parameter group per batch.
entries_per_epoch = num_batches * len(optimizer.param_groups)

# Average over the values actually recorded for each epoch. Dividing by the
# real chunk length (instead of always num_batches over num_epochs slices)
# keeps the plot correct when training stopped early or mid-epoch, where
# fixed-size slices would be short or empty and give understated / zero means.
learning_rates_res = []
if entries_per_epoch > 0:
  for start in range(0, len(learning_rates), entries_per_epoch):
    epoch_lrs = learning_rates[start:start + entries_per_epoch]
    learning_rates_res.append(sum(epoch_lrs) / len(epoch_lrs))

plt.figure(figsize=(10, 5))
# Epochs are numbered from 1, matching the loss curve plot and the training log.
plt.plot(range(1, len(learning_rates_res) + 1), learning_rates_res)
plt.xlabel('Epochs')
plt.ylabel('Learning Rate')
plt.title('Learning Rate over Time')
plt.show()

Using device: cuda:0
Epoch:   1/ 25, Batch     1, Loss: 5.5609
Epoch:   1/ 25, Batch   101, Loss: 5.5005
Epoch:   1/ 25, Batch   201, Loss: 5.3078
Epoch:   1/ 25, Batch   301, Loss: 4.5070
Epoch:   1/ 25, Batch   401, Loss: 2.3677
Epoch:   1/ 25, Batch   501, Loss: 2.0467
Epoch:   1/ 25, Batch   601, Loss: 1.8921
Epoch:   1/ 25, Batch   701, Loss: 1.7636
Epoch:   1/ 25, Batch   801, Loss: 1.6417
Epoch:   1/ 25, Batch   901, Loss: 1.5363
-----------------------------------------
Epoch:   1 took 73.45s
Epoch:   2/ 25, Batch     1, Loss: 1.4383
Epoch:   2/ 25, Batch   101, Loss: 1.4073
Epoch:   2/ 25, Batch   201, Loss: 1.3352
Epoch:   2/ 25, Batch   301, Loss: 1.2782
Epoch:   2/ 25, Batch   401, Loss: 1.2409
Epoch:   2/ 25, Batch   501, Loss: 1.2134
Epoch:   2/ 25, Batch   601, Loss: 1.1955
Epoch:   2/ 25, Batch   701, Loss: 1.1790
Epoch:   2/ 25, Batch   801, Loss: 1.1737
Epoch:   2/ 25, Batch   901, Loss: 1.1605
-----------------------------------------
Epoch:   2 took 75.02s
Epoch:   

KeyboardInterrupt: 