In [1]:
!pip install --upgrade --quiet git+https://github.com/dtonderski/DeepSudoku

[0m

In [2]:
!pip show deepsudoku

Name: deepsudoku
Version: 0.8.4
Summary: Solving Sudokus using a Neural Network assisted Monte-Carlo approach.
Home-page: https://github.com/dtonderski/DeepSudoku
Author: davton
Author-email: dtonderski@gmail.com
License: GNU GPLv3
Location: /usr/local/lib/python3.9/dist-packages
Requires: einops, numpy, py-sudoku, torch
Required-by: 


### Fine-tuning Sudoker

### Data

In [3]:
import deepsudoku as ds
from deepsudoku.utils import data_utils
import pickle as pkl
import random
import os

In [4]:
train_sudokus_raw, val_sudokus_raw, _ = data_utils.load_data()

In [5]:
val_sudokus = data_utils.make_moves(val_sudokus_raw, n_moves_distribution=data_utils.difficulty_uniform_combo_distribution)

In [6]:
old_previous_data_path = '../models/initial/previous_data.pkl'
with open(old_previous_data_path, 'rb') as f:
    previous_data = pkl.load(f)
    
previous_data_path = '../models/finetuning/previous_data.pkl'
if os.path.exists(previous_data_path):
    print(f"Loading from {previous_data_path}")
    with open(previous_data_path, 'rb') as f:
        previous_data = pkl.load(f)
print(f"{len(previous_data)=}")

Loading from models/finetuning/previous_data.pkl
len(previous_data)=24


### Network

In [7]:
from deepsudoku.dsnn.training import generate_training_data
from deepsudoku.dsnn import sudoker, loss
import torch
import os

In [9]:
models = {"ViTTiSudoker": lambda: sudoker.Sudoker(12,192,3,768,0).cuda()}
batch_sizes = {"ViTTiSudoker": 512}
lrs = {"ViTTiSudoker": 1e-5}

In [10]:
model_name = "ViTTiSudoker"
network, batch_size = models[model_name](), batch_sizes[model_name]
optimizer = torch.optim.Adam(network.parameters())
loss_fn = loss.loss

In [11]:
old_model_path = f"../models/training/{model_name}"

checkpoint = torch.load(old_model_path)
network.load_state_dict(checkpoint['model_state_dict'])

<All keys matched successfully>

In [12]:
best_model_path = f'../models/finetuning/best.pth'
if os.path.exists(best_model_path):
    print("Loading model")
    checkpoint = torch.load(best_model_path)
    network.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    losses = checkpoint['losses']
    cat_accs = checkpoint['cat_accs']
    bin_accs = checkpoint['bin_accs']
    min_percentages = checkpoint['min_percentages']
else:
    losses = []
    cat_accs = []
    bin_accs = []
    min_percentages = []
    
starting_epoch = len(losses)
best_percentage = 0 if len(min_percentages) == 0 else min_percentages[-1]
print(f"{starting_epoch=}")

Loading model
starting_epoch=100


### Training

In [13]:
from deepsudoku.dsnn.evaluation import evaluate, get_averages, print_evaluation, categorical_accuracy, binary_accuracy
from deepsudoku.utils import network_utils
from deepsudoku.montecarlo.simulation import get_n_simulations_function
from datetime import datetime

In [14]:
epochs = 1000
# Do 128 to get more training data
n_simulations_function = get_n_simulations_function(1, 128, use_builtin_difficulty=False)
valid_n_simulations_function = get_n_simulations_function(1, 16, use_builtin_difficulty=False)

min_data_size = 16384
sudokus_to_evaluate = 128
generate_and_evaluate_every_n_epochs = 10

In [None]:
train = True

if train:
    for epoch in range(starting_epoch, epochs):
        if epoch % generate_and_evaluate_every_n_epochs == 0:
            network.eval()
            if epoch > starting_epoch:
                moves_before_failure_dict, percentage_completed_dict = evaluate(val_sudokus, network, valid_n_simulations_function, 
                                                                                n_played_sudokus = sudokus_to_evaluate)

                avg_moves_dict, avg_percentage_dict = get_averages(moves_before_failure_dict, percentage_completed_dict)

                print_evaluation(avg_moves_dict, avg_percentage_dict)

                current_min_average_percentage_before_failure = min(avg_percentage_dict.values())
                min_percentages.append(current_min_average_percentage_before_failure)

                train_sudokus = data_utils.make_moves(train_sudokus_raw, n_moves_distribution=data_utils.difficulty_uniform_combo_distribution)
                generated_train_sudokus = generate_training_data(train_sudokus, network, n_simulations_function, verbose = 1, min_data_size = min_data_size)

                previous_data.append(generated_train_sudokus)

                with open(previous_data_path, 'wb') as f:
                    pkl.dump(previous_data, f)
                    
                if current_min_average_percentage_before_failure > best_percentage:
                    print(f"Min percentage increased from {best_percentage} to "
                          f"{current_min_average_percentage_before_failure}! Saving network")
                    best_percentage = current_min_average_percentage_before_failure
                    torch.save({
                        'model_state_dict': network.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'cat_accs': cat_accs,
                        'bin_accs': bin_accs,
                        'losses': losses,
                        'min_percentages': min_percentages
                        }, f'{best_model_path}')
            
            current_train_sudokus = [sudoku for simulation_sudokus in previous_data for sudoku in simulation_sudokus]
            current_fraction = sum([x[2] for x in current_train_sudokus])/len([x[2] for x in current_train_sudokus])
            print(f"{len(current_train_sudokus)=}, {current_fraction=}.")
            
            network.train()
            

            
        random.shuffle(current_train_sudokus)
        batch_losses, batch_cat_accs, batch_cat_accs_weights, batch_bin_accs = [], [], [], []
        
        for i in range(0, len(current_train_sudokus), batch_size):
            batch_sudokus = current_train_sudokus[i:i+batch_size]

            x_np, y_np = data_utils.generate_numpy_batch(batch_sudokus, augment = True)
            x, y = network_utils.numpy_batch_to_pytorch(x_np, y_np, 'cuda')
            y_pred = network(x)
            
            binary_cross_entropy_weights = loss.get_binary_cross_entropy_weights(y[1])

            batch_p_loss, batch_v_loss = loss_fn(x, y_pred, y, binary_cross_entropy_weights=binary_cross_entropy_weights)    
            batch_loss = batch_v_loss + batch_p_loss

            batch_cat_acc = categorical_accuracy(x, y, y_pred)
            
            # Weight is number of valid sudokus
            batch_cat_accs_weights.append(y[1].sum())
            
            batch_bin_acc = binary_accuracy(y, y_pred)

            batch_losses.append(batch_loss.item())
            batch_cat_accs.append(batch_cat_acc)
            batch_bin_accs.append(batch_bin_acc)

            optimizer.zero_grad()
            batch_loss.backward()
            torch.nn.utils.clip_grad_norm_(network.parameters(), 1)
            optimizer.step()

            print(f"Epoch {epoch}, batch {min(i+batch_size, len(current_train_sudokus))}/{len(current_train_sudokus)}," 
                  f"{batch_p_loss.item()=:.4f}, {batch_v_loss.item()=:.4f}, {batch_cat_acc=:.4f}, {batch_bin_acc=:.4f}", end = "\r")

        losses.append(sum(batch_losses)/len(batch_losses))
        batch_cat_accs_weights = [x/sum(batch_cat_accs_weights) for x in batch_cat_accs_weights]
        # Weighted average over cat_accs
        cat_accs.append(sum([x*y for x,y in zip(batch_cat_accs, batch_cat_accs_weights)]))
        bin_accs.append(sum(batch_bin_accs)/len(batch_bin_accs))

        print(f'Epoch {epoch}, loss = {losses[-1]:.4f}, cat_acc = {cat_accs[-1]:.4f}, bin_acc = {bin_accs[-1]:.4f}, time = {datetime.now()}.                                       ')

len(current_train_sudokus)=397284, current_fraction=0.6933125925030961.
Epoch 100, loss = 0.0575, cat_acc = 0.9893, bin_acc = 0.9899, time = 2022-12-26 06:54:08.140689.                                       
Epoch 101, loss = 0.0398, cat_acc = 0.9931, bin_acc = 0.9923, time = 2022-12-26 07:03:03.591982.                                       
Epoch 102, loss = 0.0364, cat_acc = 0.9939, bin_acc = 0.9926, time = 2022-12-26 07:11:59.908073.                                       
Epoch 103, loss = 0.0337, cat_acc = 0.9944, bin_acc = 0.9931, time = 2022-12-26 07:20:56.183865.                                       
Epoch 104, loss = 0.0320, cat_acc = 0.9947, bin_acc = 0.9935, time = 2022-12-26 07:29:50.236617.                                       
Epoch 105, loss = 0.0317, cat_acc = 0.9948, bin_acc = 0.9934, time = 2022-12-26 07:38:44.157147.                                       
Epoch 106, loss = 0.0309, cat_acc = 0.9949, bin_acc = 0.9935, time = 2022-12-26 07:47:39.700471.                