In [37]:
import os
from datetime import datetime

import numpy as np
import torch
from matplotlib import pyplot as plt
from torch.nn import functional

import deepsudoku
import deepsudoku.network
from deepsudoku.utils import data_utils

In [4]:
# Prepare the raw puzzle sets used by the rest of the notebook.
# NOTE(review): split_data() presumably (re)creates the train/val/test split
# that load_data() then reads - confirm in deepsudoku.utils.data_utils, and
# check whether calling it on every run changes the split.
data_utils.split_data()
# load_data() returns three values; only the first two (training and
# validation) are kept, the third is discarded.
train_sudokus_raw, val_sudokus_raw, _ = data_utils.load_data()


In [5]:
def natural_distribution(path="data/natural_distribution.np"):
    """Return the candidate move counts and their normalised probabilities.

    Args:
        path: File holding the unnormalised weights in a format readable by
            ``np.load``. Defaults to the file this notebook has always used,
            so calling the function without arguments behaves as before.

    Returns:
        A ``(moves, probabilities)`` tuple: ``moves`` is the list
        ``[0, 1, ..., 63]`` and ``probabilities`` is the loaded array divided
        by its total, so that it sums to 1.

    Raises:
        ValueError: If the loaded weights do not add up to a positive, finite
            number (normalising them would only produce NaN/inf values).
    """
    possible_numbers_of_moves_to_make = list(range(0, 64))
    with open(path, 'rb') as f:
        # Normalise in float64 so the result sums to 1 as accurately as possible.
        probabilities = np.asarray(np.load(f), dtype=float)

    total = probabilities.sum()
    if not np.isfinite(total) or total <= 0:
        raise ValueError(
            f"weights in {path!r} must sum to a positive finite value, got {total}"
        )
    return possible_numbers_of_moves_to_make, probabilities / total

# Run data_utils.make_moves over both raw splits with identical settings
# (n_moves_distribution=natural_distribution, invalid_sudoku_probability=0.2).
# The training split is processed first, then the validation split.
train_sudokus, val_sudokus = (
    data_utils.make_moves(
        raw_split,
        n_moves_distribution=natural_distribution,
        invalid_sudoku_probability=0.2,
    )
    for raw_split in (train_sudokus_raw, val_sudokus_raw)
)

In [29]:
# Draw one training batch and convert it to tensors for a manual sanity check.
x_train_np, y_train_np = data_utils.fast_generate_batch(train_sudokus)

# Add a channel axis: the network input is reshaped to (-1, 1, 9, 9) float32.
x_train = torch.tensor(x_train_np.reshape(-1, 1, 9, 9).astype('float32'))

# First target: values shifted down by one and stored as integer indices
# (used with cross_entropy below). Second target: kept as float32, exactly as
# in the training loop, because it is compared with a float network output via
# mse_loss - casting it to an integer tensor would truncate it and does not
# match what the training loop feeds the loss.
y_train = (
    torch.tensor(y_train_np[0] - 1).type(torch.LongTensor),
    torch.tensor(y_train_np[1].astype('float32')),
)

In [32]:
# Forward pass of the current training batch through the model.
# NOTE(review): in the visible notebook `network` is only created in a later
# cell (deepsudoku.network.Network()), so this cell fails after
# Restart Kernel -> Run All unless that cell is run first or moved above.
y_pred_train = network(x_train)

In [49]:
# Mean-squared error between the second network output and the second target component.
functional.mse_loss(y_pred_train[1], y_train[1])

tensor(0.2314, grad_fn=<MseLossBackward0>)

In [35]:
# Inspect the dtype of the first network output.
y_pred_train[0].dtype

torch.float32

In [36]:
# Inspect the dtype of the first target component (built as a LongTensor in the batch cell).
y_train[0].dtype

torch.int64

In [48]:
# Cross-entropy between the first network output and the first target
# component (the integer tensor of values shifted by -1 in the batch cell).
functional.cross_entropy(y_pred_train[0], y_train[0])

tensor(2.9158, grad_fn=<NllLoss2DBackward0>)

In [7]:
# Build the fixed validation batch (no augmentation) that the training loop
# evaluates after every optimisation step.
x_val_np, y_val_np = data_utils.generate_batch(val_sudokus, augment=False)

# Add a channel axis: the network input is reshaped to (-1, 1, 9, 9) float32.
x_val = torch.tensor(x_val_np.reshape(-1, 1, 9, 9).astype('float32'))

# First target: values shifted down by one, stored as integer indices.
# Second target: float32, matching what the training loop passes to the same
# loss function; an integer cast here would truncate the target and make the
# validation loss inconsistent with the training loss.
y_val = (
    torch.tensor(y_val_np[0] - 1).type(torch.LongTensor),
    torch.tensor(y_val_np[1]).float(),
)

In [50]:
# Fresh model and loss; the optimizer is Adam with the AMSGrad variant enabled
# and otherwise default settings.
network = deepsudoku.network.Network()
loss_fn = deepsudoku.network.my_loss
optimizer = torch.optim.Adam(network.parameters(), amsgrad=True)

# Training bookkeeping: the loop starts at `current_epoch` and appends one
# entry per iteration to each history list.
current_epoch = 0
losses, val_losses = [], []

In [53]:
# Training loop: each "epoch" is one freshly generated training batch followed
# by an evaluation on the fixed validation batch (x_val, y_val).
#
# The cell is safe to interrupt and re-run: it resumes from `current_epoch`
# (advanced at the end of every iteration) and recovers the best validation
# loss from `val_losses`, so best.pth is never replaced by a worse model.
n_epochs = 100000
checkpoint_every = 1000  # additionally write a numbered snapshot every N epochs
checkpoint_dir = './documentation/models/06-16-2022'
os.makedirs(checkpoint_dir, exist_ok=True)  # torch.save does not create missing directories

# Best validation loss seen so far; infinity when nothing has been recorded yet.
min_val_loss = min(val_losses, default=float('inf'))

# Optional: continue from a saved checkpoint instead of the in-memory state.
# checkpoint = torch.load('./documentation/models/2/model.pth')
# network.load_state_dict(checkpoint['model_state_dict'])
# optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
# current_epoch = checkpoint['epoch'] + 1
# min_val_loss = checkpoint.get('val_loss', min_val_loss)

# Everything runs on the CPU. To use a GPU, move the network and every tensor
# to the device; note that y_train / y_val are tuples, so each element has to
# be moved individually.
for epoch in range(current_epoch, n_epochs):
    # --- optimisation step on a new training batch ---
    network.train()
    x_train, y_train = data_utils.fast_generate_batch(train_sudokus)
    x_train = torch.tensor(x_train.reshape(-1, 1, 9, 9).astype('float32'))
    y_train = (
        torch.tensor(y_train[0] - 1).type(torch.LongTensor),
        torch.tensor(y_train[1].astype('float32')),
    )

    y_pred = network(x_train)
    loss = loss_fn(y_pred, y_train)

    optimizer.zero_grad()
    loss.backward()
    # Cap the global gradient norm at 1 before the parameter update.
    torch.nn.utils.clip_grad_norm_(network.parameters(), 1)
    optimizer.step()

    # --- validation ---
    network.eval()
    with torch.no_grad():
        y_pred_val = network(x_val)
        val_loss = loss_fn(y_pred_val, y_val)

    # Work with plain floats from here on: nothing below needs the autograd
    # graph, and checkpoints should hold numbers rather than live tensors.
    train_loss_value = loss.item()
    val_loss_value = val_loss.item()
    val_losses.append(val_loss_value)
    losses.append(train_loss_value)

    # --- checkpointing ---
    is_periodic = (epoch % checkpoint_every) == 0
    is_best = val_loss_value < min_val_loss
    if is_periodic or is_best:
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': network.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': train_loss_value,
            'val_loss': val_loss_value,
        }
        if is_periodic:
            torch.save(checkpoint, f'{checkpoint_dir}/epoch{epoch}_loss{val_loss_value:.3f}.pth')
        if is_best:
            torch.save(checkpoint, f'{checkpoint_dir}/best.pth')
            min_val_loss = val_loss_value

    print(f'Epoch {epoch}, loss = {train_loss_value:.4f}, validation loss = {val_loss_value:.4f}, time = {datetime.now()}.', end = "\r")

    # Remember where to continue if the cell is interrupted and run again.
    current_epoch = epoch + 1

Epoch 3, loss = 2.4137, validation loss = 2.3808, time = 2022-06-16 16:56:42.410748.

KeyboardInterrupt: 