In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np

# Last expression of the cell, so the notebook displays whether a CUDA device is usable.
torch.cuda.is_available()

In [None]:
# Load the training array from a .npy file (path is relative to the notebook's working directory).
# NOTE(review): later cells feed each sample to a 64-input discriminator next to tanh outputs,
# so this presumably holds 8x8 boards already scaled to [-1, 1] — confirm against the data file.
dataset = np.load('1m_markers_1_1.npy')

In [None]:
class Generator(nn.Module):
    """Fully connected generator mapping a latent vector to a 1x8x8 image in [-1, 1].

    The network is a stack of Linear -> LeakyReLU(0.2) -> BatchNorm1d stages that widen
    from ``latent_dim`` up to 2048 units, followed by a 64-unit Linear layer and Tanh.
    """

    def __init__(self, latent_dim):
        """
        :param latent_dim: Size of the latent (noise) vector fed to the first layer.
        """
        super().__init__()

        # Widths of the hidden stages; consecutive pairs give (in_features, out_features).
        widths = [latent_dim, 128, 256, 512, 1024, 2048]

        layers = []
        for in_features, out_features in zip(widths[:-1], widths[1:]):
            layers += [
                nn.Linear(in_features, out_features),
                nn.LeakyReLU(0.2),
                nn.BatchNorm1d(out_features),
            ]

        # Output head: 64 values squashed to [-1, 1]
        layers += [nn.Linear(widths[-1], 8 * 8), nn.Tanh()]

        # Layer order (and therefore the state_dict keys model.0 ... model.15) is preserved.
        self.model = nn.Sequential(*layers)

    def forward(self, z):
        """
        :param z: Latent batch of shape (batch_size, latent_dim).
        :return: Generated images of shape (batch_size, 1, 8, 8).
        """
        flat = self.model(z)
        # Reshape the 64-vector into a single-channel 8x8 image
        return flat.view(flat.size(0), 1, 8, 8)

In [None]:
class Discriminator(nn.Module):
    """Fully connected critic that scores an 8x8 image with a single unbounded scalar.

    The network narrows 64 -> 2048 -> 1024 -> 512 -> 256 with LeakyReLU(0.2) between
    layers and ends in a 1-unit Linear layer with no activation.
    """

    def __init__(self):
        super().__init__()

        # Input width followed by the hidden widths
        widths = [8 * 8, 2048, 1024, 512, 256]

        layers = []
        for in_features, out_features in zip(widths, widths[1:]):
            layers.append(nn.Linear(in_features, out_features))
            layers.append(nn.LeakyReLU(0.2))

        # Output is a scalar (score)
        layers.append(nn.Linear(widths[-1], 1))

        self.model = nn.Sequential(*layers)

    def forward(self, img):
        """
        :param img: Image batch whose trailing dimensions hold 64 values per sample.
        :return: Scores of shape (batch_size, 1).
        """
        # Flatten each image into a 64-vector before scoring it
        return self.model(img.view(img.size(0), -1))

In [None]:
def custom_generator_loss(fake_images, fake_validity, penalty_weight=1.0):
    """
    Generator loss: WGAN term plus a weighted penalty for violating board conditions.

    Images are denormalized from [-1, 1] to [0, 12]; a pixel counts as "value v" when it
    falls in [v - 0.5, v + 0.5). Violations are counted per image and averaged over the batch.

    Note: the counts come from hard threshold comparisons, which carry no gradient. The
    penalty therefore changes the reported loss value but adds nothing to the gradient
    that reaches the generator; only the WGAN term trains it.

    :param fake_images: Generated images (in range [-1, 1]), shape (batch, 1, 8, 8).
    :param fake_validity: Discriminator output for generated images.
    :param penalty_weight: Multiplier applied to the summed penalty before it is added
        to the WGAN term.
    :return: Scalar tensor ``wgan_loss + penalty_weight * penalty``.
    """
    # Denormalization: transform images to range [0, 12]
    denormalized_images = fake_images * 6 + 6

    def count_in(values, low, high, dims, inclusive_high=False):
        """Per-image count of entries in [low, high) (or [low, high]), as float."""
        below_high = (values <= high) if inclusive_high else (values < high)
        return torch.sum((values >= low) & below_high, dim=dims).float()

    # Main loss function (WGAN loss)
    wgan_loss = -torch.mean(fake_validity)

    image_dims = (1, 2, 3)
    row_dims = (1, 2)

    # Condition 1: exactly one pixel with value 11 and one with value 12
    count_11 = count_in(denormalized_images, 10.5, 11.5, image_dims)
    count_12 = count_in(denormalized_images, 11.5, 12.0, image_dims, inclusive_high=True)
    penalty = 2.0 * torch.mean(torch.abs(count_11 - 1))
    penalty = penalty + 2.0 * torch.mean(torch.abs(count_12 - 1))

    # Condition 2: cap on the number of pixels with value 1 and with value 2.
    # NOTE(review): the earlier comment said "no more than 8" but the threshold in code
    # is 4; the code's value is kept here — confirm which one is intended.
    count_1 = count_in(denormalized_images, 0.5, 1.5, image_dims)
    count_2 = count_in(denormalized_images, 1.5, 2.5, image_dims)
    penalty = penalty + 3.0 * torch.mean(torch.relu(count_1 - 4))
    penalty = penalty + 3.0 * torch.mean(torch.relu(count_2 - 4))

    # Condition 3: no pixels with value 1 or 2 in the first and eighth row
    first_row = denormalized_images[:, :, 0, :]   # row index 0
    eighth_row = denormalized_images[:, :, 7, :]  # row index 7
    penalty_first_row = count_in(first_row, 0.5, 2.5, row_dims)
    penalty_eighth_row = count_in(eighth_row, 0.5, 2.5, row_dims)
    penalty = penalty + 4.0 * torch.mean(penalty_first_row + penalty_eighth_row)

    # Total loss: the weight was previously accepted but never applied
    return wgan_loss + penalty_weight * penalty

In [None]:
def compute_gradient_penalty(discriminator, real_images, fake_images, device):
    """Computes the WGAN-GP gradient penalty on random real/fake interpolations.

    For each sample a coefficient alpha ~ U[0, 1) mixes the real and fake image; the
    penalty is the batch mean of ``(||dD/dx||_2 - 1) ** 2`` at those mixed points.

    Both inputs are detached before mixing, so backpropagating the penalty only reaches
    the discriminator's parameters and never the network that produced ``fake_images``.

    :param discriminator: Callable mapping an image batch to per-sample scores.
    :param real_images: Batch of real samples; the first dimension is the batch.
    :param fake_images: Batch of generated samples, same shape as ``real_images``.
    :param device: Device on which the random mixing coefficients are created.
    :return: Scalar tensor with the gradient penalty (differentiable w.r.t. the
        discriminator's parameters, because the graph is created with ``create_graph``).
    """
    real_images = real_images.detach()
    fake_images = fake_images.detach()

    batch_size = real_images.size(0)

    # One random coefficient per sample, broadcast over all remaining dimensions
    alpha_shape = (batch_size,) + (1,) * (real_images.dim() - 1)
    alpha = torch.rand(alpha_shape, device=device)

    interpolates = (alpha * real_images + (1 - alpha) * fake_images).requires_grad_(True)
    d_interpolates = discriminator(interpolates)

    gradients = torch.autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        # All-ones weights matching whatever shape the discriminator returns
        grad_outputs=torch.ones_like(d_interpolates),
        create_graph=True,   # keep the graph so the penalty itself can be backpropagated
        retain_graph=True,
        only_inputs=True,
    )[0]

    # One flat gradient vector per sample
    gradients = gradients.reshape(batch_size, -1)
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    return gradient_penalty

In [None]:
# Copy the loaded NumPy array into a float32 tensor.
data_tensor = torch.tensor(dataset, dtype=torch.float32)
# Create Dataset and DataLoader
# TensorDataset yields 1-tuples, which is why the training loop indexes each batch with [0].
datasett = TensorDataset(data_tensor)
# NOTE(review): the batch size 64 is hard-coded here and repeated as `batch_size` in the
# training cell — keep the two in sync.
dataloader = DataLoader(datasett, batch_size=64, shuffle=True)

In [None]:
# Parameters
# Hyperparameters and target device
latent_dim = 100
img_shape = (1, 8, 8)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Build the two networks (generator first, then discriminator) and move them to the device.
# nn.Module.to() moves parameters in place and returns the module itself.
generator = Generator(latent_dim)
generator.to(device)
discriminator = Discriminator()
discriminator.to(device)

# Both optimizers share the same Adam settings
adam_kwargs = dict(lr=0.0002, betas=(0.5, 0.9))
optimizer_G = optim.Adam(generator.parameters(), **adam_kwargs)
optimizer_D = optim.Adam(discriminator.parameters(), **adam_kwargs)

In [None]:
# Training configuration
n_epochs = 10000
batch_size = 64
lambda_gp = 10          # weight of the gradient penalty in the critic loss
n_critic = 5            # the generator is updated once every n_critic batches
g_penalty_weight = 2.0  # penalty_weight passed to custom_generator_loss

from IPython.display import clear_output

for epoch in range(n_epochs):
    for i, real_images in enumerate(dataloader):
        # The DataLoader yields 1-tuples; add the channel dimension expected by the models
        real_images = real_images[0].to(device)
        real_images = real_images.unsqueeze(1)

        # ---- Discriminator (critic) update ----
        optimizer_D.zero_grad()

        z = torch.randn(real_images.size(0), latent_dim).to(device)
        fake_images = generator(z)
        # Detach once: nothing in the critic loss should backpropagate into the generator.
        # Passing the non-detached batch to the gradient penalty (which builds its graph with
        # create_graph=True) would run the critic's backward pass through the generator too.
        fake_detached = fake_images.detach()

        real_validity = discriminator(real_images)
        fake_validity = discriminator(fake_detached)

        gradient_penalty = compute_gradient_penalty(discriminator, real_images, fake_detached, device)
        d_loss = -torch.mean(real_validity) + torch.mean(fake_validity) + lambda_gp * gradient_penalty
        # The generator's graph is not part of d_loss, so no retain_graph is needed
        d_loss.backward()
        optimizer_D.step()

        # ---- Generator update (every n_critic batches) ----
        if i % n_critic == 0:
            optimizer_G.zero_grad()
            # Re-score the same fakes with the just-updated critic, this time keeping the graph
            fake_validity = discriminator(fake_images)

            # Custom loss function
            g_loss = custom_generator_loss(fake_images, fake_validity, g_penalty_weight)
            g_loss.backward()
            optimizer_G.step()

        # Output progress (100 is a multiple of n_critic, so g_loss is always defined here)
        if i % 100 == 0:
            clear_output(wait=True)
            print(
                f"[Epoch {epoch}/{n_epochs}] [Batch {i}/{len(dataloader)}] "
                f"[D loss: {d_loss.item()}] [G loss: {g_loss.item()}]"
            )

In [None]:
from utils import array_to_fen, display_board
import chess

def eval_rand(nsize=64):
    """Sample boards from the generator and display those that python-chess accepts as valid.

    Draws ``nsize`` latent vectors, generates images, maps them from [-1, 1] to integer
    values in [0, 12], converts each 8x8 array to a FEN string with ``array_to_fen`` and
    displays every board for which ``chess.Board(fen).is_valid()`` is true.

    Sampling runs without autograd and with the generator temporarily in eval mode (so
    BatchNorm uses its running statistics rather than the statistics of this batch); the
    generator's previous train/eval mode is restored afterwards.

    :param nsize: Number of boards to sample. All ``nsize`` samples are checked.
    """
    z = torch.randn(nsize, latent_dim).to(device)

    was_training = generator.training
    generator.eval()
    try:
        with torch.no_grad():
            gen_img = generator(z)
    finally:
        # Leave the generator in whatever mode the caller had it in
        generator.train(was_training)

    # (nsize, 1, 8, 8) in [-1, 1] -> (nsize, 8, 8) integer values in [0, 12]
    boards = np.round(gen_img.cpu().numpy()[:, 0] * 6 + 6).astype(np.int8)

    res = []
    # Iterate over every sample (the previous range(0, nsize - 1) skipped the last one)
    for image_array in boards:
        fen = array_to_fen(image_array)
        board = chess.Board(fen)
        if board.is_valid():
            res.append(fen)

    for r in res:
        display_board(r)

In [None]:
# Sample one batch worth of boards and display the ones that pass the validity check.
eval_rand(batch_size)
# Save only the weights (state_dict); the model classes must be instantiated before loading them.
torch.save(generator.state_dict(), 'gen')
torch.save(discriminator.state_dict(), 'disc')