In [264]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils import data
# from torch.optim.lr_scheduler import StepLR
import numpy as np
from DQN import DQLearning, Conv2D
from tqdm import tqdm

from tensorboardX import SummaryWriter

writer = SummaryWriter()

mps_device = torch.device('mps')
torch.backends.cudnn.benchmark = True
np.random.seed(1)
torch.manual_seed(1)
torch.mps.manual_seed(1)

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [269]:
# add code for datasets (we always use train and validation/ test set)
data_transforms = transforms.Compose([
    # transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# get train dataset
# train_dataset = datasets.ImageFolder(
#     root=os.path.join(opt.path_to_data, "train"),
#     transform=data_transforms)
train_dataset = datasets.MNIST('../data', train=True, download=True, transform=data_transforms)

train_data_loader = data.DataLoader(train_dataset, 64)


# test_dataset = datasets.ImageFolder(
#     root=os.path.join(opt.path_to_data, "test"),
#     transform=data_transforms)
test_dataset = datasets.MNIST('../data', train=False, download=True, transform=data_transforms)

test_data_loader = data.DataLoader(test_dataset, 10000)

In [270]:

net = DQLearning(inputs=784, classes=10, hidden_units=[10,10,10])
# net = Conv2D()
def init_weights(m):
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform(m.weight)
        m.bias.data.fill_(0.01)
net.apply(init_weights)
net.to(mps_device)
print(net)

Linear(in_features=784, out_features=10, bias=True) ReLU() Linear(in_features=10, out_features=10, bias=True) ReLU() Linear(in_features=10, out_features=10, bias=True) ReLU() Linear(in_features=10, out_features=10, bias=True) LogSoftmax(dim=1)
DQLearning(
  (net): Sequential(
    (0): Linear(in_features=784, out_features=10, bias=True)
    (1): ReLU()
    (2): Linear(in_features=10, out_features=10, bias=True)
    (3): ReLU()
    (4): Linear(in_features=10, out_features=10, bias=True)
    (5): ReLU()
    (6): Linear(in_features=10, out_features=10, bias=True)
    (7): LogSoftmax(dim=1)
  )
)


  torch.nn.init.xavier_uniform(m.weight)


In [271]:
criterion = nn.CrossEntropyLoss()
optim = torch.optim.Adam(net.parameters(), lr=0.01)

In [272]:
import time


# now we start the main loop
epochs = 100
for epoch in range(0, epochs):

    # set models to train mode
    net.train()

    # use prefetch_generator and tqdm for iterating through data
    start_time = time.time()


    # for loop going through dataset
    pbar = tqdm(enumerate(train_data_loader),
                    total=len(train_data_loader))
    
    start_time = time.time()
    loss = 0
    for batch_idx, (img, label) in pbar:
        img, label = img.to(mps_device), label.to(mps_device)
        
        # data preparation
        flatten_img = torch.flatten(img, 1, -1)
        
        
        # # It's very good practice to keep track of preparation time and computation time using tqdm to find any issues in your dataloader
        prepare_time = start_time-time.time()
        
        output = net(flatten_img)
        loss = criterion(output, label)
    
        # # forward and backward pass
        optim.zero_grad()
        loss.backward()
        optim.step()
        
        # # compute computation time and *compute_efficiency*
        process_time = start_time-time.time()-prepare_time
        pbar.set_description("Compute efficiency: {:.2f}, epoch: {}/{}".format(
            process_time/(process_time+prepare_time), epoch, epochs))
        start_time = time.time()
    
    writer.add_scalar("Loss/train", loss, epoch)
    
    # # maybe do a test pass every x epochs
    if epoch % 5 == 0:
        # bring models to evaluation mode
        net.eval()
        # Iterate over the test data loader
        num_correct = 0
        num_samples = 0
        for i, (test_img, test_label) in enumerate(test_data_loader):
            test_img, test_label = test_img.to(mps_device), test_label.to(mps_device)
            output_test = net(torch.flatten(test_img, 1, -1))
            # output_test = net(test_img)
            
            # Get predictions
            predictions = torch.argmax(output_test, 1)
            # Update accuracy statistics
            num_correct += (predictions == test_label).sum().item()
            num_samples += test_label.size(0)

        # Calculate and print accuracy
        accuracy = float(num_correct) / float(num_samples) * 100 if num_samples > 0 else 0.0
        print(f'Got {num_correct} / {num_samples} with accuracy {accuracy:.2f}%')

writer.flush()
writer.close()

Compute efficiency: 0.65, epoch: 0/100: 100%|██████████| 938/938 [00:08<00:00, 109.17it/s]


Got 7977 / 10000 with accuracy 79.77%


Compute efficiency: 0.67, epoch: 1/100: 100%|██████████| 938/938 [00:08<00:00, 108.46it/s]
Compute efficiency: 0.61, epoch: 2/100: 100%|██████████| 938/938 [00:08<00:00, 110.12it/s]
Compute efficiency: 0.61, epoch: 3/100: 100%|██████████| 938/938 [00:08<00:00, 108.13it/s]
Compute efficiency: 0.64, epoch: 4/100: 100%|██████████| 938/938 [00:08<00:00, 106.51it/s]
Compute efficiency: 0.58, epoch: 5/100: 100%|██████████| 938/938 [00:08<00:00, 110.95it/s]


Got 8493 / 10000 with accuracy 84.93%


Compute efficiency: 0.65, epoch: 6/100: 100%|██████████| 938/938 [00:08<00:00, 110.58it/s]
Compute efficiency: 0.60, epoch: 7/100: 100%|██████████| 938/938 [00:08<00:00, 113.14it/s]
Compute efficiency: 0.64, epoch: 8/100: 100%|██████████| 938/938 [00:08<00:00, 111.76it/s]
Compute efficiency: 0.56, epoch: 9/100: 100%|██████████| 938/938 [00:08<00:00, 115.74it/s]
Compute efficiency: 0.63, epoch: 10/100: 100%|██████████| 938/938 [00:08<00:00, 116.43it/s]


Got 8418 / 10000 with accuracy 84.18%


Compute efficiency: 0.64, epoch: 11/100: 100%|██████████| 938/938 [00:08<00:00, 114.81it/s]
Compute efficiency: 0.62, epoch: 12/100: 100%|██████████| 938/938 [00:08<00:00, 115.49it/s]
Compute efficiency: 0.60, epoch: 13/100: 100%|██████████| 938/938 [00:08<00:00, 116.14it/s]
Compute efficiency: 0.60, epoch: 14/100: 100%|██████████| 938/938 [00:08<00:00, 116.21it/s]
Compute efficiency: 0.62, epoch: 15/100: 100%|██████████| 938/938 [00:08<00:00, 116.33it/s]


Got 8378 / 10000 with accuracy 83.78%


Compute efficiency: 0.59, epoch: 16/100: 100%|██████████| 938/938 [00:08<00:00, 115.61it/s]
Compute efficiency: 0.58, epoch: 17/100: 100%|██████████| 938/938 [00:08<00:00, 115.35it/s]
Compute efficiency: 0.62, epoch: 18/100: 100%|██████████| 938/938 [00:08<00:00, 116.11it/s]
Compute efficiency: 0.61, epoch: 19/100: 100%|██████████| 938/938 [00:08<00:00, 114.67it/s]
Compute efficiency: 0.59, epoch: 20/100: 100%|██████████| 938/938 [00:08<00:00, 115.42it/s]


Got 8473 / 10000 with accuracy 84.73%


Compute efficiency: 0.63, epoch: 21/100: 100%|██████████| 938/938 [00:08<00:00, 113.04it/s]
Compute efficiency: 0.66, epoch: 22/100: 100%|██████████| 938/938 [00:08<00:00, 109.38it/s]
Compute efficiency: 0.91, epoch: 23/100: 100%|██████████| 938/938 [00:08<00:00, 104.57it/s]
Compute efficiency: 0.59, epoch: 24/100: 100%|██████████| 938/938 [00:08<00:00, 108.08it/s]
Compute efficiency: 0.68, epoch: 25/100: 100%|██████████| 938/938 [00:08<00:00, 105.10it/s]


Got 8180 / 10000 with accuracy 81.80%


Compute efficiency: 0.64, epoch: 26/100: 100%|██████████| 938/938 [00:09<00:00, 96.90it/s] 
Compute efficiency: 0.61, epoch: 27/100: 100%|██████████| 938/938 [00:08<00:00, 111.54it/s]
Compute efficiency: 0.62, epoch: 28/100: 100%|██████████| 938/938 [00:08<00:00, 113.14it/s]
Compute efficiency: 0.60, epoch: 29/100: 100%|██████████| 938/938 [00:08<00:00, 114.58it/s]
Compute efficiency: 0.59, epoch: 30/100: 100%|██████████| 938/938 [00:08<00:00, 114.85it/s]


Got 8457 / 10000 with accuracy 84.57%


Compute efficiency: 0.66, epoch: 31/100: 100%|██████████| 938/938 [00:08<00:00, 114.26it/s]
Compute efficiency: 0.55, epoch: 32/100: 100%|██████████| 938/938 [00:08<00:00, 108.96it/s]
Compute efficiency: 0.63, epoch: 33/100: 100%|██████████| 938/938 [00:08<00:00, 112.12it/s]
Compute efficiency: 0.66, epoch: 34/100: 100%|██████████| 938/938 [00:08<00:00, 107.72it/s]
Compute efficiency: 0.65, epoch: 35/100: 100%|██████████| 938/938 [00:08<00:00, 110.98it/s]


Got 8442 / 10000 with accuracy 84.42%


Compute efficiency: 0.65, epoch: 36/100: 100%|██████████| 938/938 [00:08<00:00, 111.04it/s]
Compute efficiency: 0.67, epoch: 37/100: 100%|██████████| 938/938 [00:08<00:00, 110.56it/s]
Compute efficiency: 0.63, epoch: 38/100: 100%|██████████| 938/938 [00:08<00:00, 112.23it/s]
Compute efficiency: 0.53, epoch: 39/100:  49%|████▊     | 457/938 [00:04<00:04, 107.47it/s]


KeyboardInterrupt: 