In [22]:
import pandas as pd
import numpy as np
from create_flop_data import FlopDataset, flop_collate_fn
import torch
from torchvision import transforms
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F
from utils import flop_to_vector, flop_look_up, vec_to_flop
from utils import (
    SUITEDNESS_DICT, 
    PAIRNESS_DICT,
    CONNECTEDNESS_DICT,
    HIGH_LOW_TEXTURE_DICT,
    HIGH_CARD_DICT,
    STRAIGHTNESS_DICT
    )

In [2]:
df = pd.read_parquet('flopdata.parquet')

In [3]:
df.head()

Unnamed: 0,flop,flop_encoded,suitedness,pairness,connectedness,high_low_texture,high_card,straightness,card1,card2,card3
0,"[2c, 2h, 2s]","[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",0.0,1.0,0.0,0.0,0.0,0.0,2c,2h,2s
1,"[2d, 2h, 2s]","[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",0.0,1.0,0.0,0.0,0.0,0.0,2d,2h,2s
2,"[2h, 2s, 3s]","[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, ...",1.0,1.0,1.0,0.0,1.0,0.0,2h,2s,3s
3,"[2h, 2s, 3h]","[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, ...",1.0,1.0,1.0,0.0,1.0,0.0,2h,2s,3h
4,"[2h, 2s, 3c]","[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, ...",0.0,1.0,1.0,0.0,1.0,0.0,2h,2s,3c


In [4]:
batch_size = 64
data = FlopDataset(data=df)
trainloader = DataLoader(data, batch_size=batch_size, shuffle=True, collate_fn=flop_collate_fn)
ex_batch = next(iter(trainloader))

In [11]:
ex_batch[0].shape

torch.Size([64, 51])

In [12]:
class BaseVAE(nn.Module):
    def __init__(self, latent_dim = 6):
        super().__init__()
        self.latent_dim = latent_dim
        self.encoder = nn.Sequential(
            nn.Linear(in_features=51, out_features=40),
            nn.LeakyReLU(),
            nn.Linear(in_features=40, out_features=30),
            nn.LeakyReLU(),
            nn.Linear(in_features=30, out_features=20),
            nn.LeakyReLU()
        )
        
        self.fc_mu = nn.Linear(20, latent_dim)
        self.fc_log_var = nn.Linear(20, latent_dim) 
        
        self.decoder = nn.Sequential(
            nn.Linear(in_features=self.latent_dim, out_features=20),
            nn.LeakyReLU(),
            nn.Linear(in_features=20, out_features=30),
            nn.LeakyReLU(),
            nn.Linear(in_features=30, out_features=40),
            nn.LeakyReLU(),
            nn.Linear(in_features=40, out_features=51),
            nn.Sigmoid()
            
        )
        
    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar) 
        eps = torch.randn_like(std)
        return mu + eps * std 
    
    def forward(self, x):
        x = self.encoder(x)
        mu = self.fc_mu(x)
        logvar = self.fc_log_var(x)
        z = self.reparameterize(mu, logvar)
        x_recon = self.decoder(z)
        return x_recon, mu, logvar
        

In [62]:
def loss_function(x_recon, x, mu, log_var):
    recon_loss = F.binary_cross_entropy(x_recon, x, reduction='sum')
    kl_divergence = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    return recon_loss + kl_divergence



In [63]:
model = BaseVAE()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
epochs = 100

In [64]:

for epoch in range(epochs):
    model.train() 
    tot_loss = 0
    for flop_vec, flop, suitedness,pairness, connectedness, high_low_texture, high_card, straightness in trainloader:
        
        optimizer.zero_grad()
        x = flop_vec.to(torch.float32)
        x_recon, mu, log_var = model(x) 
    
        loss = loss_function(x_recon, x, mu, log_var)
        tot_loss += loss.item()
        loss.backward()
        optimizer.step()
    avg_loss = tot_loss / len(trainloader)
    if (epoch % 10) == 0:
        print(f"Epoch [{epoch+1}/{epochs}], Loss: {avg_loss:.4f}")

Epoch [1/100], Loss: 1204.5489
Epoch [11/100], Loss: 1016.5307
Epoch [21/100], Loss: 987.0770
Epoch [31/100], Loss: 966.7626
Epoch [41/100], Loss: 951.4088
Epoch [51/100], Loss: 946.6270
Epoch [61/100], Loss: 946.3825
Epoch [71/100], Loss: 944.7838
Epoch [81/100], Loss: 943.6098
Epoch [91/100], Loss: 942.5611


In [105]:
def generate_flop_comp(model, num_flops = 1, device="cpu"):
    model.eval()  

    z_sample = torch.randn(num_flops, model.latent_dim).to(device)

    with torch.no_grad():
        flop_recon = model.decoder(z_sample)
    return flop_recon


In [106]:
def genflop_to_binary(flop_recon):
    flop_binary = torch.zeros_like(flop_recon)

    flop_binary[:, :13] = torch.nn.functional.one_hot(flop_recon[:, :13].argmax(dim=1), num_classes=13)
    flop_binary[:, 13:17] = torch.nn.functional.one_hot(flop_recon[:, 13:17].argmax(dim=1), num_classes=4)
    
    flop_binary[:, 17:30] = torch.nn.functional.one_hot(flop_recon[:, 17:30].argmax(dim=1), num_classes=13)
    flop_binary[:, 30:34] = torch.nn.functional.one_hot(flop_recon[:, 30:34].argmax(dim=1), num_classes=4)
    
    flop_binary[:, 34:47] = torch.nn.functional.one_hot(flop_recon[:, 34:47].argmax(dim=1), num_classes=13)
    flop_binary[:, 47:] = torch.nn.functional.one_hot(flop_recon[:, 47:].argmax(dim=1), num_classes=4)

    return flop_binary

In [107]:
def decode_binary_flop(flop_batch):
    RANKS = [str(i) for i in range(2, 10)] + ['T', 'J', 'Q', 'K', 'A']
    SUITS = ['s', 'h', 'c', 'd']
    
    card1_ranks = np.argmax(flop_batch[:, :13], axis=1)
    card1_suits = np.argmax(flop_batch[:, 13:17], axis=1)
    
    card2_ranks = np.argmax(flop_batch[:, 17:30], axis=1)
    card2_suits = np.argmax(flop_batch[:, 30:34], axis=1)
    
    card3_ranks = np.argmax(flop_batch[:, 34:47], axis=1)
    card3_suits = np.argmax(flop_batch[:, 47:], axis=1)
    flops = [
        [f"{RANKS[card1_ranks[i]]}{SUITS[card1_suits[i]]}",
         f"{RANKS[card2_ranks[i]]}{SUITS[card2_suits[i]]}",
         f"{RANKS[card3_ranks[i]]}{SUITS[card3_suits[i]]}"]
        for i in range(flop_batch.shape[0])
    ]
    return flops

In [108]:
def generate_flop_human(model, num_flops = 1, device="cpu"):
    flop_recon = generate_flop_comp(model, num_flops, device)
    flop_bin = genflop_to_binary(flop_recon)
    decoded_flops = decode_binary_flop(flop_bin)
    return decoded_flops

In [113]:
generate_flop_human(model, num_flops=10)

[['3d', '3s', '7h'],
 ['9d', 'Kd', 'Qs'],
 ['Js', 'Tc', 'Ts'],
 ['9d', 'Kd', 'Qs'],
 ['2d', '8d', 'Th'],
 ['2d', '9d', 'Ah'],
 ['8d', '9d', 'Ah'],
 ['4d', '5d', '7h'],
 ['9d', 'Ad', 'Ts'],
 ['4s', 'Qc', 'Ts']]