In [3]:
import torch

from bh2vec.vectors_net import VectorsNetNorm
from bh2vec.tools import hand_to_vec, predict_tricks, vec_to_hand

In [4]:
# Checkpoint holding the trained VectorsNetNorm weights (loaded via load_state_dict).
MODEL_PATH = "model.pth"

net = VectorsNetNorm()
# map_location="cpu" lets a checkpoint saved during a GPU run be loaded on a
# CPU-only machine; the freshly constructed model lives on the CPU anyway.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints
# (on torch >= 1.13 consider passing weights_only=True as well).
net.load_state_dict(torch.load(MODEL_PATH, map_location="cpu"))
# Inference mode: the BatchNorm1d layer uses its running statistics instead of
# per-batch statistics. Left as the last expression so the module repr is shown.
net.eval()

VectorsNetNorm(
  (emb1): Linear(in_features=52, out_features=32, bias=True)
  (emb2): Linear(in_features=32, out_features=32, bias=True)
  (emb3): Linear(in_features=32, out_features=8, bias=True)
  (act): ELU(alpha=1.0)
  (batch_norm): BatchNorm1d(8, eps=1e-05, momentum=0.1, affine=False, track_running_stats=True)
  (hid1): Linear(in_features=16, out_features=128, bias=True)
  (hid2): Linear(in_features=128, out_features=128, bias=True)
  (out): Linear(in_features=128, out_features=5, bias=True)
)

In [5]:
# Two example 13-card hands, each written as four dot-separated suit holdings
# (presumably PBN order spades.hearts.diamonds.clubs -- TODO confirm in bh2vec.tools).
hand_n, hand_s = 'AQ72.KT632.AJ3.9', 'K943.A.Q9876.KQ2'

In [6]:
# Predicted trick counts for the pair of hands, one value per model output.
tricks = predict_tricks(net, hand_n, hand_s).tolist()
print(f"Expected number of tricks for hands {hand_n} and {hand_s}:")
# Label each value with its strain; assumes the model's 5 outputs are ordered
# clubs, diamonds, hearts, spades, no-trump -- TODO confirm against bh2vec.
strain_lines = [f" {strain}: {count}" for strain, count in zip('cdhsn', tricks)]
print("\n".join(strain_lines))

Expected number of tricks for hands AQ72.KT632.AJ3.9 and K943.A.Q9876.KQ2:
 c: 7.047391891479492
 d: 11.495141983032227
 h: 9.340585708618164
 s: 11.595678329467773
 n: 10.315614700317383


In [13]:
# Show the learned embedding vector of each example hand, North first.
for hand in (hand_n, hand_s):
    print(f"Hand {hand} embedding:\n{hand_to_vec(net, hand)}")

Hand AQ72.KT632.AJ3.9 embedding:
[-1.5299157   1.4496934   0.38851678  0.99312484  1.1259247  -0.47265995
  0.85016537 -1.1033434 ]
Hand K943.A.Q9876.KQ2 embedding:
[ 1.3208004  -0.56439435 -1.1750376   1.2731268  -1.0533564  -0.34576744
 -0.54109645 -0.38238   ]


In [15]:
# Hand-crafted point in the 8-dimensional embedding space (matches emb3's
# out_features=8): ones in the first four coordinates, zeros in the rest.
embedding = [1] * 4 + [0] * 4
# Decode the point back into a concrete hand with vec_to_hand.
nearest_hand = vec_to_hand(net, embedding)
print(f"Hand nearest to {embedding}: {nearest_hand}")

Hand nearest to [1, 1, 1, 1, 0, 0, 0, 0]: KQJ4.A8.K82.Q842


In [16]:
# Mirror hand_n's embedding through the origin and decode the result;
# hand_n_embedding is reused by the scaling cell below.
hand_n_embedding = hand_to_vec(net, hand_n)
opposite_embedding = -hand_n_embedding
opposite_hand = vec_to_hand(net, opposite_embedding)
print(f"Hand opposite to {hand_n}: {opposite_hand}")

Hand opposite to AQ72.KT632.AJ3.9: 432.J.T542.KJT97


In [20]:
# Decode hand_n's embedding scaled away from (x1.5) and towards (x0.5) the
# origin; requires hand_n_embedding from the previous cell.
stronger_features_hand, weaker_features_hand = (
    vec_to_hand(net, hand_n_embedding * factor) for factor in (1.5, 0.5)
)
print(f"Hand {hand_n} lies between {stronger_features_hand} and {weaker_features_hand}")

Hand AQ72.KT632.AJ3.9 lies between AK92.AQ6542.KJ4. and AT93.A987.QT9.JT
