In [None]:
from pathlib import Path
import itertools
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.models as models
import Datasets.ModelData as md
from session import *
from LR_Schedule.lr_find import lr_find
from callbacks import *
from validation import *
from validation import _AccuracyMeter
from LR_Schedule.cos_anneal import CosAnneal

In [None]:
# Reload edited project modules (session, callbacks, validation, ...) automatically
# before each cell executes, so library changes are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
# Human-readable names for the class ids in the label column (0-9).
idtohand = dict(enumerate([
    "Nothing in hand; not a recognized poker hand",
    "One pair; one pair of equal ranks within five cards",
    "Two pairs; two pairs of equal ranks within five cards",
    "Three of a kind; three equal ranks within five cards",
    "Straight; five cards, sequentially ranked with no gaps",
    "Flush; five cards with the same suit",
    "Full house; pair + different rank three of a kind",
    "Four of a kind; four equal ranks within five cards",
    "Straight flush; straight + flush",
    "Royal flush; {Ace, King, Queen, Jack, Ten} + flush",
]))

# One output per hand class.
num_classes = len(idtohand)

# Card component ids are the CSV values minus one (PokerHandDataset.parse_csv subtracts 1):
# suits index the 4 rows and ranks the 13 columns of the hand tensor.
idtosuit = dict(enumerate(["Hearts", "Spades", "Diamonds", "Clubs"]))
idtorank = dict(enumerate(["Ace", *(str(n) for n in range(2, 11)), "Jack", "Queen", "King"]))

In [None]:
class SuitShiftTransform():
    """Training-time augmentation that randomly relabels the four suits of a hand.

    Which suit is which does not change a poker hand's class, so permuting the
    suit rows yields another valid example of the same label.
    """

    def transform_x(self, hand):
        """Return ``hand`` with its dim-1 axis (the 4 suit rows of a (1, 4, 13) hand tensor) shuffled."""
        suit_order = torch.randperm(4)
        return hand[:, suit_order]

In [None]:
class PokerHandDataset(Dataset):
    """Poker hands read from a headerless CSV: 10 card columns followed by a class label.

    Each row holds five 1-based (suit, rank) pairs and the class id. An item is
    ``(x, y)`` where ``x`` is a (1, 4, 13) float tensor (+1 at each held card,
    -1 elsewhere) and ``y`` is a one-hot label vector.

    Args:
        file: path or buffer accepted by ``pd.read_csv``.
        tfm: optional object whose ``transform_x`` is applied to every hand tensor.
        percentage: fraction of rows to keep (random subset without replacement) when < 1.
        balanced: if True, ``self.sampler`` is a ``WeightedRandomSampler`` giving every
            class the same total weight; otherwise ``self.sampler`` is None.
    """

    def __init__(self, file, tfm=None, percentage=1, balanced=False):
        self.hands, self.labels = self.parse_csv(file, percentage)
        self.tfm = tfm
        self.sampler = self._balanced_sampler(self.labels) if balanced else None

    @staticmethod
    def _balanced_sampler(labels):
        """Build a sampler weighting each instance by N / (number of instances of its class)."""
        labels = np.asarray(labels, dtype=np.int64)
        counts = np.bincount(labels)
        # Indexing counts by the labels only touches classes that actually occur,
        # so a class missing from this split never causes a division by zero.
        instance_weights = len(labels) / counts[labels]
        return torch.utils.data.sampler.WeightedRandomSampler(instance_weights, len(instance_weights))

    @staticmethod
    def parse_csv(file, percentage=1):
        """Read the CSV and return ``(hands, labels)`` as NumPy arrays.

        ``hands`` has shape (N, 5, 2) holding 0-based (suit, rank) per card and
        ``labels`` has shape (N,). When ``percentage < 1`` a random subset of
        ``int(N * percentage)`` rows is kept (via NumPy's global RNG; seed it for
        a reproducible subset), applied to hands and labels alike.
        """
        df = pd.read_csv(file, header=None)
        # Columns 0-9 are five 1-based (suit, rank) pairs; shift them to 0-based ids.
        hands = df.iloc[:, :10].to_numpy().reshape(-1, 5, 2) - 1
        labels = df.iloc[:, -1].to_numpy()

        if percentage < 1:
            idxs = np.random.choice(len(df), int(len(df) * percentage), replace=False)
            hands = hands[idxs]
            labels = labels[idxs]

        # Return the (possibly subsampled) labels so they stay aligned with `hands`.
        return hands, labels

    @staticmethod
    def make_tensor(hand):
        """Encode a (5, 2) array of 0-based (suit, rank) pairs as a (1, 4, 13) float tensor.

        Held cards are +1 and every other position -1; the leading dim is the
        single input channel.
        """
        cards = torch.as_tensor(np.asarray(hand), dtype=torch.long).reshape(-1, 2)
        hand_tensor = torch.full((4, 13), -1.0)
        hand_tensor[cards[:, 0], cards[:, 1]] = 1
        return hand_tensor.unsqueeze(0)

    @staticmethod
    def make_one_hot(label):
        """Return a float one-hot vector of length ``num_classes`` with a 1 at ``label``."""
        tensor = torch.zeros(num_classes)
        tensor[label] = 1
        return tensor

    def __len__(self): return len(self.hands)

    def __getitem__(self, i):
        """Return the encoded hand (transformed by ``tfm`` if set) and its one-hot label."""
        hand, label = self.hands[i], self.labels[i]
        x, y = self.make_tensor(hand), self.make_one_hot(label)

        if self.tfm is not None:
            x = self.tfm.transform_x(x)

        return x, y

In [None]:
# CSV splits live under Datasets/PokerHands, relative to the notebook's working directory.
data_path = Path("Datasets") / "PokerHands"
# Only the training split gets random suit permutation as augmentation.
train_dataset = PokerHandDataset(data_path / 'training.csv', tfm=SuitShiftTransform())
test_dataset = PokerHandDataset(data_path / 'testing.csv')

In [None]:
# Sanity check: one (1x4x13 hand tensor, one-hot label) pair; the hand is suit-shuffled by SuitShiftTransform.
train_dataset[0]

## Neural network

In [None]:
# Loader batch sizes: 32 for training, 256 for evaluation.
train_bs, eval_bs = 32, 256
data = md.ModelData({'train': train_dataset}, train_bs)
test = md.ModelData({'test': test_dataset}, eval_bs)

In [None]:
# Peek at the first training batch (its structure depends on md.ModelData — presumably (x, y) tensors).
next(iter(data['train']))

In [None]:
class Network(nn.Module):
    """Two-branch CNN over the (1, 4, 13) suit x rank hand encoding.

    ``conv1`` uses a (4, 1) kernel spanning all four suit rows of each rank
    column; ``conv2`` uses a (1, 13) kernel spanning all ranks of each suit row.
    Their flattened outputs are concatenated and passed through a two-layer MLP
    ending in a softmax over the classes.

    Args:
        n_filters: output channels of each conv branch (default 32).
        hidden: width of the hidden fully connected layer (default 100).
        p_drop: dropout probability applied before each linear layer (default 0.6).
        n_classes: number of output classes; defaults to the notebook's ``num_classes``.
    """

    def __init__(self, n_filters=32, hidden=100, p_drop=.6, n_classes=None):
        super().__init__()
        if n_classes is None:
            n_classes = num_classes
        # Submodule names are kept stable so saved state_dicts stay loadable.
        self.conv1 = nn.Conv2d(1, n_filters, (4, 1))    # -> (n_filters, 1, 13)
        self.conv2 = nn.Conv2d(1, n_filters, (1, 13))   # -> (n_filters, 4, 1)
        self.relu1 = nn.ReLU(inplace=True)
        self.drop1 = nn.Dropout(p=p_drop)
        self.fc1 = nn.Linear(n_filters * 13 + n_filters * 4, hidden)
        self.relu2 = nn.ReLU(inplace=True)
        self.drop2 = nn.Dropout(p=p_drop)
        self.fc2 = nn.Linear(hidden, n_classes)

    def forward(self, x):
        """Map a (batch, 1, 4, 13) hand batch to (batch, n_classes) class probabilities.

        The output is softmax probabilities, not logits: pair it with a
        probability-space loss such as ``nn.BCELoss``, not a ``*WithLogits`` loss.
        """
        col_feats = torch.flatten(self.conv1(x), 1)
        row_feats = torch.flatten(self.conv2(x), 1)
        x = torch.cat([col_feats, row_feats], dim=1)
        x = self.drop1(self.relu1(x))
        x = self.drop2(self.relu2(self.fc1(x)))
        return F.softmax(self.fc2(x), dim=1)

# Seed torch's global RNG before building the model so weight initialisation
# (and subsequent draws from torch's global RNG) is reproducible across restarts.
SEED = 0
torch.manual_seed(SEED)
model = Network()

In [None]:
class FocalLoss(nn.Module):
    """Sigmoid focal loss on one-hot targets (Lin et al., 2017), over columns 1..C-1.

    NOTE(review): column 0 is dropped from both predictions and targets, a pattern
    from detection code where class 0 is background. Here class 0 is the real
    "Nothing in hand" class — confirm excluding it is intended.
    NOTE(review): ``pred`` must be raw logits; ``Network.forward`` returns softmax
    probabilities, so it cannot be fed to this loss directly.

    Args:
        num_classes: stored on the module; not used in the computation.
        alpha: weight for positive targets (negatives get ``1 - alpha``). Default 0.25.
        gamma: focusing exponent applied to ``1 - p_t``. Default 2.
    """

    def __init__(self, num_classes, alpha=0.25, gamma=2):
        super().__init__()
        self.num_classes = num_classes
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, pred, label):
        """Return the mean focal-weighted BCE between logits ``pred[:, 1:]`` and targets ``label[:, 1:]``."""
        logits, targets = pred[:, 1:], label[:, 1:]
        weight = self.get_weight(logits, targets)
        return F.binary_cross_entropy_with_logits(logits, targets, weight)

    def get_weight(self, x, t):
        """Per-element focal weight ``alpha_t * (1 - p_t) ** gamma`` for logits ``x`` and targets ``t``."""
        # x are logits, so map them to probabilities first; using x directly
        # would produce weights outside [0, 1].
        p = torch.sigmoid(x)
        pt = p * t + (1 - p) * (1 - t)
        w = self.alpha * t + (1 - self.alpha) * (1 - t)
        return w * (1 - pt).pow(self.gamma)

In [None]:
class PokerHandAccuracy(_AccuracyMeter):
    """Top-1 accuracy meter that also accumulates a ``num_classes`` x ``num_classes`` confusion matrix.

    ``self.confusion[true][pred]`` counts samples whose ground-truth argmax is
    ``true`` and whose predicted argmax is ``pred``.
    """

    def __init__(self):
        # reset() initialises every counter, including the confusion matrix.
        self.reset()

    def reset(self):
        """Zero the running counts and the confusion matrix."""
        self.num_correct = 0
        self.count = 0
        self.confusion = [[0] * num_classes for _ in range(num_classes)]

    def accuracy(self):
        """Fraction of samples since the last reset that were classified correctly (0.0 if none seen)."""
        return self.num_correct / self.count if self.count else 0.0

    def update(self, actn, label, log=False):
        """Accumulate one batch.

        Args:
            actn: model scores with classes along dim 1; the argmax is the prediction.
            label: one-hot targets with classes along dim 1; the argmax is the truth.
            log: if True, print predictions, ground truth and the running correct count.
        """
        _, preds = torch.max(actn, 1)
        if log: print("Preds: ", preds)
        _, gt = torch.max(label, 1)
        if log: print("GT   : ", gt)
        self.num_correct += torch.sum(preds == gt).item()
        if log: print("Num Correct: ", self.num_correct)
        self.count += label.shape[0]

        # Convert once to Python ints: indexing nested lists with 0-d tensors costs
        # a tensor op (and a device sync on GPU) per element.
        for lab, pred in zip(gt.tolist(), preds.tolist()):
            self.confusion[lab][pred] += 1

    def plot_confusion_matrix(self,
                              title='Confusion matrix',
                              cmap=plt.cm.Blues):
        """Draw the accumulated confusion matrix as an annotated heat map.

        Rows are true classes and columns predicted classes. Each cell shows its
        count, drawn in white when above half the maximum so it stays legible.
        """
        cm = np.array(self.confusion)
        fig, ax = plt.subplots(figsize=(12, 12))
        img = ax.imshow(cm, interpolation='nearest', cmap=cmap)
        fig.colorbar(img, ax=ax)
        ax.set_title(title)

        tick_marks = np.arange(num_classes)
        # Short class name = text before ';' in idtohand.
        class_names = [idtohand[k].split(';')[0] for k in range(num_classes)]
        ax.set_xticks(tick_marks)
        ax.set_xticklabels(class_names, rotation=45, ha='right')
        ax.set_yticks(tick_marks)
        ax.set_yticklabels(class_names)

        thresh = cm.max() / 2.
        for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
            ax.text(j, i, format(cm[i, j], 'd'),
                    horizontalalignment="center",
                    color="white" if cm[i, j] > thresh else "black")

        ax.set_ylabel('True label')
        ax.set_xlabel('Predicted label')

In [None]:
# Network.forward returns softmax probabilities, so use the probability-space
# BCE loss against the one-hot targets.
criterion = nn.BCELoss()
optim_fn = optim.Adam
initial_lr = 1e-2
sess = Session(model, criterion, optim_fn, initial_lr)

In [None]:
# Grab one training batch to exercise the untrained model and the accuracy meter below.
x, y = next(iter(data['train']))
x[0], y[0]

In [None]:
# Forward pass of the (still untrained) model on that batch, plus a fresh meter to score it.
actn = sess.forward(x)
accuracy = PokerHandAccuracy()

In [None]:
# Score the single batch, printing predictions, ground truth and the correct count.
accuracy.update(actn, y, log=True)

In [None]:
# Confusion matrix for that one untrained batch.
accuracy.plot_confusion_matrix()

In [None]:
# Learning-rate range test from 1e-4 to 2 on the training loader (lr_find is project code —
# presumably it reports loss vs. LR); use the result to choose the rate set in the next cell.
lr_find(sess, data['train'], start_lr=1e-4, end_lr=2)

In [None]:
# Reset the learning rate to 1e-2 for training (lr_find may leave a different rate on the session — confirm).
sess.set_lr(1e-2)

In [None]:
accuracy = PokerHandAccuracy()
# Evaluate on the held-out test loader: `data` only holds a 'train' split,
# the test split lives in `test`.
validator = Validator(test['test'], accuracy)
# lr_schedule = CosAnneal(len(data['train']), lr_min=1e-5, T_mult=2)
schedule = TrainingSchedule(data['train'])

In [None]:
# Train with the schedule above (the 5 is presumably the epoch count — confirm against Session.train).
sess.train(schedule, 5)

In [None]:
# Run the validator once by hand so `accuracy` holds results for the validator's loader,
# passing a fresh LossMeter.
validator.on_epoch_end(sess, LossMeter())

In [None]:
# Confusion matrix for the validation pass above.
accuracy.plot_confusion_matrix()

In [None]:
# `lr_schedule` only exists when the CosAnneal line in the validator/schedule cell is
# uncommented; guard so Restart & Run All does not stop here with a NameError.
if 'lr_schedule' in globals():
    lr_schedule.plot()

In [None]:
# Name the checkpoint after the measured evaluation accuracy (in percent) rather than a
# hard-coded figure that goes stale whenever the notebook is re-run.
sess.save(f"PokerHand-{accuracy.accuracy() * 100:.1f}")

In [None]:
# Same evaluation on the training loader, to compare training accuracy with the held-out result.
t_accuracy = PokerHandAccuracy()
t_validator = Validator(data['train'], t_accuracy)

In [None]:
# Evaluate the trained model on the training data.
t_validator.on_epoch_end(sess, LossMeter())

In [None]:
# Training-set confusion matrix.
t_accuracy.plot_confusion_matrix()

In [None]:
# Size of the training loader (number of batches if it is a torch DataLoader — TODO confirm for md.ModelData).
len(data['train'])

In [None]:
# Size of the test loader (number of batches if it is a torch DataLoader). The test
# loader lives in `test`; `data` only holds a 'train' split.
len(test['test'])