In [None]:
import torch
import torch.nn.functional as F
import numpy as np
import pandas as pd

from sklearn import model_selection
import time
import copy
import pickle

import matplotlib.pyplot as plt
import matplotlib.patches as  mpatches
from matplotlib.patches import Arc, Rectangle, ConnectionPatch
from matplotlib.offsetbox import  OffsetImage
import seaborn as sns

import sklearn.preprocessing as preprocessing
import tqdm

RANDOM_SEED = 43

# Seed NumPy's global RNG and torch's RNG so data shuffling, parameter init and
# latent sampling are repeatable. Looping (instead of a bare final expression)
# keeps torch.manual_seed's returned Generator object out of the cell output.
for _seed_fn in (np.random.seed, torch.manual_seed):
    _seed_fn(RANDOM_SEED)

In [None]:
import warnings
warnings.simplefilter("ignore", UserWarning)

In [None]:
# Path to the processed rally DataFrame (a pandas pickle). Must be set before running.
filename = ''
if not filename:
    # Fail with an actionable message instead of read_pickle('')'s opaque error.
    raise ValueError("Set `filename` to the path of the processed-data pickle before running this cell.")

# SECURITY: read_pickle unpickles arbitrary objects (can execute code) — only
# load files from a trusted source.
proc_data = pd.read_pickle(filename)

In [None]:
N_PLAYERS: int = 20  # keep only rallies of the N players with the most counted shots

In [None]:
# Per rally: number of entries equal to 1 in player_mask (presumably the shots hit
# by the player of interest). NOTE(review): assumes each player_mask is a numpy
# array — for a plain Python list, `mask == 1` is a single bool, not elementwise.
proc_data['poi_shots'] = proc_data['player_mask'].apply(lambda mask: np.sum(mask == 1))

# The N_PLAYERS players of interest with the highest total shot count.
top_players = (
    proc_data.groupby('poi')['poi_shots']
    .sum()
    .sort_values(ascending=False)
    .head(N_PLAYERS)
    .index
)

# Keep only rallies whose player of interest is one of those top players.
proc_data = proc_data.loc[proc_data['poi'].isin(top_players)]

In [None]:
def get_state_act_pairs(row):
    """Collect (state, action) pairs for the player of interest from one rally.

    Walks consecutive shot pairs. A pair is kept when the first shot's column 0
    (the hitter's player id) differs from ``row['poi']`` and the next shot's
    column 0 equals ``row['poi']``. The earlier shot is the state and the poi's
    reply is the action.

    Args:
        row: mapping (e.g. a DataFrame row) with ``'poi'`` (player of interest id)
            and ``'fv'`` (the rally's ordered sequence of per-shot feature vectors).

    Returns:
        list of ``(shot, next_shot)`` tuples in rally order; ``[]`` if none qualify
        (including rallies with fewer than two shots).
    """
    poi = row['poi']
    shots = row['fv']
    return [
        (shot, next_shot)
        for shot, next_shot in zip(shots, shots[1:])
        if shot[0] != poi and next_shot[0] == poi
    ]

# One list of (state, action) pairs per rally, stacked into a single array whose
# axis 1 is [state, action]. Rallies with no qualifying opponent→poi transition
# return [], which np.asarray turns into a (0,)-shaped array that np.concatenate
# cannot join with the (k, 2, n_features) arrays — so they are skipped.
state_action_pairs = np.concatenate(
    [np.asarray(pairs) for pairs in proc_data.apply(get_state_act_pairs, axis=1) if len(pairs) > 0],
    axis=0,
)

In [None]:
# One min-max scaler to [-1, 1] for feature columns 9:-1, fitted on state and
# action rows pooled together so both halves of a pair share the same scale.
# Columns 0-8 and the last column are left untouched.
scaler = preprocessing.MinMaxScaler(feature_range=(-1, 1))
n_features = state_action_pairs.shape[-1]
scaler.fit(state_action_pairs.reshape(-1, n_features)[..., 9:-1])

# NOTE(review): scales state_action_pairs in place, so re-running this cell
# re-fits on already-scaled data — run it exactly once per fresh load.
for _half in (0, 1):  # 0 = state row, 1 = action row of every pair
    state_action_pairs[:, _half, 9:-1] = scaler.transform(state_action_pairs[:, _half, 9:-1])

In [None]:
class Generator(torch.nn.Module):
    """Gaussian policy conditioned on state, latent noise and two player identities.

    ``[state, z, emb(poi), emb(opp)]`` goes through two ReLU layers into separate
    mean and log-std heads of a diagonal Gaussian. :meth:`sample` squashes draws
    with ``tanh``, so sampled actions lie in [-1, 1].

    Args:
        input_shape: width of ``state`` plus width of ``z``; the two player
            embeddings are added on top of this internally.
        output_shape: action dimensionality.
        hidden_size: width of both hidden layers.
        embedding_size: width of each player embedding.
        num_players: number of rows in the shared player-embedding table; player
            indices passed to the network must lie in ``[0, num_players)``.
    """

    def __init__(self, input_shape, output_shape, hidden_size=256, embedding_size=8, num_players=250):
        super().__init__()
        # Layer creation order is kept fixed so seeded parameter init is reproducible.
        # +2*embedding_size: poi and opponent embeddings are concatenated to the input.
        self.l1 = torch.nn.Linear(input_shape + 2 * embedding_size, hidden_size)
        self.l2 = torch.nn.Linear(hidden_size, hidden_size)

        self.mean_head = torch.nn.Linear(hidden_size, output_shape)
        self.log_std_head = torch.nn.Linear(hidden_size, output_shape)

        self.player_embeddings = torch.nn.Embedding(num_players, embedding_size)

    def forward(self, state, z, poi_idxs, opp_idxs):
        """Return ``(mean, log_std)`` of the pre-squash Gaussian.

        ``poi_idxs`` / ``opp_idxs`` may be float tensors; they are cast to long
        for the embedding lookup. ``log_std`` is clamped to [-20, 2] to keep the
        std numerically sane.
        """
        poi_embeddings = self.player_embeddings(poi_idxs.long())
        opp_embeddings = self.player_embeddings(opp_idxs.long())

        x = torch.cat((state, z, poi_embeddings, opp_embeddings), dim=-1).float()

        x = F.relu(self.l1(x))
        x = F.relu(self.l2(x))
        mean = self.mean_head(x)
        log_std = self.log_std_head(x).clamp(min=-20, max=2)

        return mean, log_std

    def sample(self, state, z, poi_idxs, opp_idxs):
        """Draw a tanh-squashed action with its log-probability.

        Returns:
            ``(action, log_prob, deterministic_action)``: ``action`` is
            ``tanh(x)`` for a reparameterised draw ``x``; ``log_prob`` has shape
            ``(..., 1)`` and includes the tanh change-of-variables correction;
            ``deterministic_action`` is ``tanh(mean)``.
        """
        normal = self.get_dist(state, z, poi_idxs, opp_idxs)
        x_t = normal.rsample()  # reparameterised draw: gradients flow to mean/log_std
        action = torch.tanh(x_t)
        # log|d tanh/dx| correction; 1e-5 guards against log(0) when |action| -> 1.
        log_prob = normal.log_prob(x_t) - torch.log((1 - action.pow(2)) + 1e-5)
        log_prob = log_prob.sum(-1, keepdim=True)

        return action, log_prob, torch.tanh(normal.mean)

    def get_dist(self, state, z, poi_idxs, opp_idxs):
        """Return the (unsquashed) diagonal ``Normal(mean, exp(log_std))``."""
        mean, log_std = self.forward(state, z, poi_idxs, opp_idxs)
        return torch.distributions.normal.Normal(mean, log_std.exp())
    
    
class Discriminator(torch.nn.Module):
    """Scores a (state, action) pair conditioned on two player identities.

    ``[state, action, emb(poi), emb(opp)]`` goes through two ReLU layers into a
    single unbounded output per row (no sigmoid), shape ``(..., 1)``.

    Args:
        input_shape: width of ``state`` plus width of ``action``; the two player
            embeddings are added on top of this internally.
        hidden_size: width of both hidden layers.
        embedding_size: width of each player embedding.
        num_players: number of rows in the player-embedding table; player
            indices passed to the network must lie in ``[0, num_players)``.
    """

    def __init__(self, input_shape, hidden_size=256, embedding_size=8, num_players=250):
        super().__init__()
        # Layer creation order is kept fixed so seeded parameter init is reproducible.
        # +2*embedding_size: poi and opponent embeddings are concatenated to the input.
        self.l1 = torch.nn.Linear(input_shape + 2 * embedding_size, hidden_size)
        self.l2 = torch.nn.Linear(hidden_size, hidden_size)
        self.l3 = torch.nn.Linear(hidden_size, 1)

        self.player_embeddings = torch.nn.Embedding(num_players, embedding_size)

    def forward(self, state, action, poi_idxs, opp_idxs):
        """Return the raw score for each row; indices are cast to long for lookup."""
        poi_embeddings = self.player_embeddings(poi_idxs.long())
        opp_embeddings = self.player_embeddings(opp_idxs.long())

        x = torch.cat((state, action, poi_embeddings, opp_embeddings), dim=-1).float()

        x = F.relu(self.l1(x))
        x = F.relu(self.l2(x))
        x = self.l3(x)

        return x

In [None]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(device)

hidden_size = 256
embedding_size = 8
latent_size = 8

# Column 0 of every row is a player index (fed to the embedding table, not the
# MLP), so the network's state input is the remaining columns.
state_dim = state_action_pairs.shape[-1] - 1
# NOTE(review): `cols_of_interest` is not defined in any visible cell — it must be
# defined (as column indices into a full action row) before this cell runs.
action_dim = len(cols_of_interest)

generator = Generator(
    state_dim + latent_size,
    action_dim,
    hidden_size=hidden_size,
    embedding_size=embedding_size,
).to(device)
discriminator = Discriminator(
    state_dim + action_dim,
    hidden_size=hidden_size,
    embedding_size=embedding_size,
).to(device)

def get_n_params(model):
    """Return the total number of scalar entries across all of ``model.parameters()``.

    Every parameter is counted, whether or not it requires grad.
    """
    return sum(param.numel() for param in model.parameters())

# Report model sizes (all parameters, trainable or not).
for _label, _net in (('Generator params: ', generator), ('Discriminator params: ', discriminator)):
    print(_label, get_n_params(_net))

In [None]:
# Each item is one (state, action) pair; batches are reshuffled every epoch.
trainLoader = torch.utils.data.DataLoader(
    dataset=state_action_pairs,
    batch_size=256,
    shuffle=True,
    pin_memory=True,
)

In [None]:
# Identical Adam settings (lr=1e-4) for both networks; G is built before D.
optimizerG, optimizerD = (
    torch.optim.Adam(net.parameters(), lr=1e-4) for net in (generator, discriminator)
)

# Per-iteration loss histories, appended to by the training loop.
generator_losses, discriminator_losses = [], []

# Loss hyper-parameters used by the training loop:
#   tau   - weight on the generator's log-prob term, and scale on the target
#           log-likelihood inside the discriminator's soft fake target;
#   alpha - how far below 1 that soft target may go (target in [1 - alpha, 1]).
tau, alpha = 0.4, 0.9

In [None]:
%%time

pbar = tqdm.tqdm(range(1000))

I = 100  # every I iterations, show on the progress bar the mean of the last I losses

# Fixed std of the Gaussian centred on each real action that scores fake actions.
# It is constant, so build it once instead of on every iteration.
target_std = torch.tensor([0.3]).to(device)

for epoch in pbar:
    for iteration, sample in enumerate(trainLoader):
        # The DataLoader already yields tensors. Move/cast directly rather than
        # copy-constructing with torch.tensor(), which copies and emits a UserWarning.
        sample = sample.to(device).float()
        state = sample[:, 0]
        action = sample[:, 1]

        # Column 0 of each row is the hitter's player index: the state shot was hit
        # by the opponent, the action shot by the player of interest.
        opp_idxs = state[..., 0]
        state = state[..., 1:]
        poi_idxs = action[..., 0]
        # NOTE(review): `cols_of_interest` is not defined in any visible cell.
        action = action[..., cols_of_interest]

        # ---- Discriminator step ----
        with torch.no_grad():
            z = torch.randn(state.size(0), latent_size).to(device)
            fake_action, _, _ = generator.sample(state, z, poi_idxs, opp_idxs)

        real_weight = discriminator(state, action, poi_idxs, opp_idxs)
        fake_weight = discriminator(state, fake_action, poi_idxs, opp_idxs)

        # Real pairs are regressed to 1. Fake pairs are regressed to the soft target
        # 1 + alpha * clip(tau * log N(fake | real, 0.3), -1, 0), which lies in
        # [1 - alpha, 1] and is higher the closer the fake action is to the real one.
        target_dist = torch.distributions.normal.Normal(action, target_std)
        target_score = 1+alpha*(tau*target_dist.log_prob(fake_action).sum(dim=-1, keepdim=True)).clip(max=0, min=-1)
        discriminator_loss = ((real_weight-1)**2).mean() + ((fake_weight - target_score)**2).mean()

        optimizerD.zero_grad()
        discriminator_loss.backward()
        optimizerD.step()

        discriminator_losses.append(discriminator_loss.item())

        # ---- Generator step ----
        z = torch.randn(state.size(0), latent_size).to(device)
        gen_action, log_prob, _ = generator.sample(state, z, poi_idxs, opp_idxs)
        gen_weight = discriminator(state, gen_action, poi_idxs, opp_idxs)

        # Maximise the discriminator score, with a tau-weighted log-prob penalty.
        # This backward also writes grads into the discriminator's parameters; they
        # are cleared by optimizerD.zero_grad() before the next discriminator step.
        generator_loss = (-gen_weight).mean() + tau * log_prob.mean()

        optimizerG.zero_grad()
        generator_loss.backward()
        optimizerG.step()

        generator_losses.append(generator_loss.item())

        # Logging only: detached so no autograd graph is built for this metric.
        mse_loss = ((action - gen_action.detach())**2).mean()

        if iteration % I == 0:
            pbar.set_postfix({'GenL': np.mean(generator_losses[-I:]),
                              'DisL': np.mean(discriminator_losses[-I:]),
                              'MSE': mse_loss.item(),
                              'RW': real_weight.mean().item(),
                              'GW': fake_weight.mean().item(),})
