In [1]:
from pgn2gif import chess
import numpy as np
from openTSNE import TSNE
from openTSNE.callbacks import ErrorApproximations
from matplotlib import pyplot as plt
from scipy import interpolate
import re

FILE_NAME = '/mnt/d/Work/CG Institute/chess/lichess data/lichess_db_standard_rated_2021-08.pgn'

In [2]:
def state_to_vector(state):
    piece_dict = {
        'wr': [1,0,0,0,0,0,0,0,0,0,0,0,0],
        'wn': [0,1,0,0,0,0,0,0,0,0,0,0,0],
        'wb': [0,0,1,0,0,0,0,0,0,0,0,0,0],
        'wk': [0,0,0,1,0,0,0,0,0,0,0,0,0],
        'wq': [0,0,0,0,1,0,0,0,0,0,0,0,0],
        'wp': [0,0,0,0,0,1,0,0,0,0,0,0,0],
        'br': [0,0,0,0,0,0,1,0,0,0,0,0,0],
        'bn': [0,0,0,0,0,0,0,1,0,0,0,0,0],
        'bb': [0,0,0,0,0,0,0,0,1,0,0,0,0],
        'bk': [0,0,0,0,0,0,0,0,0,1,0,0,0],
        'bq': [0,0,0,0,0,0,0,0,0,0,1,0,0],
        'bp': [0,0,0,0,0,0,0,0,0,0,0,1,0],
        '':   [0,0,0,0,0,0,0,0,0,0,0,0,1],
    }    
    state_list = list(state.values())    
    vector = []
    for piece in state_list:
        vector.append(piece_dict[piece])
    return np.array(vector).ravel()

def vector_to_state(vector):
    vec_dict = {
        '1000000000000': "wr",
        '0100000000000': "wn",
        '0010000000000': "wb",
        '0001000000000': "wk",
        '0000100000000': "wq",
        '0000010000000': "wp",
        '0000001000000': "br",
        '0000000100000': "bn",
        '0000000010000': "bb",
        '0000000001000': "bk",
        '0000000000100': "bq",
        '0000000000010': "bp",
        '0000000000001': ""
    }
    
    return vec_dict[vector]

In [3]:
def game_to_vectors(file):
    game = chess.ChessGame(file)
    vectors = [state_to_vector(game.state)]
    while not game.is_finished:
        try:
            game.next()
        except:
            pass
        vectors.append(state_to_vector(game.state))
    return np.stack(vectors)

In [4]:
def get_moves_from_pgn(pgn):
    with open(pgn) as p:
        data = p.read()
        data = re.sub(r'\{.*?\}', '', data)  # Removes pgn comments
        moves = re.findall(
            r'[a-h]x?[a-h]?[1-8]=?[BKNRQ]?|O-O-?O?|[BKNRQ][a-h1-8]?[a-h1-8]?x?[a-h][1-8]',
            data)
        return [move.replace('x', '') for move in moves]

# TODO all metadata

In [5]:
def get_metadata_from_pgn(pgn):
    with open(pgn) as p:
        data = p.read()
        data = re.sub(r'\{.*?\}', '', data)  # Removes pgn comments
        m = re.findall(r'\[(.*) "(.*)"]',data)
        metadata_keys = [i[0] for i in m]
        metadata_values = [i[1].replace(',',';') for i in m]
        return dict(zip(metadata_keys, metadata_values))
        
#         metadata = {}
#         metadata['white'] = re.findall(
#             r'\[White "(.*)"]',
#             data)[0]
#         metadata['black'] = re.findall(
#             r'\[Black "(.*)"]',
#             data)[0]
#         metadata['result'] = re.findall(
#             r'\[Result "(.*)"]',
#             data)[0]
#         return metadata

In [6]:
s = """[White \"abc\"]
[Black \"xyz,asd\"]
1. e4 { [%eval 0.24] [%clk 0:03:00] } 1... c5 { [%eval 0.32] [%clk 0:03:00] } 2. c3 { [%eval 0.0] [%clk 0:02:56] } 2... Nc6 { [%eval 0.0] [%clk 0:03:00] } 3. Nf3 { [%eval 0.12] [%clk 0:02:56] } 3... b5? { [%eval 2.4] [%clk 0:03:00] } 4. Na3? { [%eval 0.22] [%clk 0:02:53] } 4... Rb8? { [%eval 2.46] [%clk 0:02:59] } 5. Bxb5 { [%eval 2.5] [%clk 0:02:41] } 5... Na5? { [%eval 3.78] [%clk 0:02:51] } 6. Ng5? { [%eval 2.32] [%clk 0:02:23] } 6... Nf6 { [%eval 2.78] [%clk 0:02:50] } 7. f3? { [%eval 0.72] [%clk 0:02:20] } 7... h6 { [%eval 0.78] [%clk 0:02:48] } 8. Nh3 { [%eval 0.96] [%clk 0:02:16] } 8... Rg8? { [%eval 2.47] [%clk 0:02:48] } 9. b4?! { [%eval 1.5] [%clk 0:02:17] } 9... cxb4 { [%eval 1.58] [%clk 0:02:47] } 10. cxb4 { [%eval 1.42] [%clk 0:02:16] } 10... Nb7? { [%eval 2.77] [%clk 0:02:42] } 11. d4 { [%eval 2.99] [%clk 0:02:17] } 11... Nd6? { [%eval 4.41] [%clk 0:02:40] } 12. Nf2?? { [%eval -2.37] [%clk 0:02:17] } 12... a6?? { [%eval 4.93] [%clk 0:02:36] } 13. Ba4 { [%eval 4.01] [%clk 0:02:10] } 13... g5 { [%eval 4.85] [%clk 0:02:24] } 14. O-O?! { [%eval 3.86] [%clk 0:02:04] } 14... Bb7? { [%eval 6.5] [%clk 0:02:12] } 15. e5 { [%eval 6.68] [%clk 0:02:03] } 15... Nd5 { [%eval 6.58] [%clk 0:01:55] } 16. exd6 { [%eval 6.67] [%clk 0:02:01] } 16... exd6 { [%eval 8.05] [%clk 0:01:56] } 17. Ne4 { [%eval 7.37] [%clk 0:01:58] } 17... Rc8 { [%eval 9.93] [%clk 0:01:54] } 18. f4?? { [%eval 6.18] [%clk 0:01:42] } 18... gxf4 { [%eval 5.94] [%clk 0:01:53] } 19. Bxf4?? { [%eval 0.62] [%clk 0:01:37] } 19... Nxf4?? { [%eval 5.44] [%clk 0:01:53] } 20. Rxf4 { [%eval 4.88] [%clk 0:01:39] } 20... Qh4?? { [%eval 15.7] [%clk 0:01:47] } 21. Rxh4 { [%eval 15.79] [%clk 0:01:40] } 21... d5 { [%eval 19.78] [%clk 0:01:35] } 22. Nc3 { [%eval 14.09] [%clk 0:01:34] } 22... Bxb4 { [%eval 15.79] [%clk 0:01:35] } 23. Qb3 { [%eval 16.01] [%clk 0:01:24] } 23... Bxa3?! { [%eval #9] [%clk 0:01:30] } 24. Qxa3?! { [%eval 20.1] [%clk 0:01:23] } 24... Rd8?! { [%eval #1] [%clk 0:01:18] } 25. Nxd5 { [%eval #8] [%clk 0:01:22] } 25... Bxd5 { [%eval #8] [%clk 0:01:11] } 26. Rxh6 { [%eval #7] [%clk 0:01:11] } 26... Bxg2 { [%eval #7] [%clk 0:01:11] } 27. h3?! { [%eval 19.96] [%clk 0:00:50] } 27... Bxh3+?! { [%eval #8] [%clk 0:01:08] } 28. Kh2 { [%eval #7] [%clk 0:00:43] } 28... Bf5 { [%eval #4] [%clk 0:01:00] } 29. Rf1 { [%eval #7] [%clk 0:00:43] } 29... Be6 { [%eval #5] [%clk 0:00:55] } 30. Qe3?! { [%eval 18.7] [%clk 0:00:33] } 30... Rc8 { [%eval 42.63] [%clk 0:00:54] } 31. Bb3 { [%eval 15.13] [%clk 0:00:29] } 31... Rc7 { [%eval 16.35] [%clk 0:00:24] } 32. d5 { [%eval 15.05] [%clk 0:00:29] } 1-0"""

m = re.findall(r'\[(.*) "(.*)"]',s)

metadata_keys = [i[0] for i in m]
metadata_values = [i[1].replace(',',';') for i in m]
print(metadata_keys)
print(metadata_values)
d = dict(zip(metadata_keys, metadata_values))
print([k for k in d])
print([d[k] for k in d])


['White', 'Black']
['abc', 'xyz;asd']
['White', 'Black']
['abc', 'xyz;asd']


In [7]:
print("start loading")

start loading


# Extract individual games from PGN file

In [8]:
lines_to_read = 500000

In [9]:
from pathlib import Path
Path("games").mkdir(parents=True, exist_ok=True)

with open(FILE_NAME, 'r') as f:
    all_games = ''.join([f.readline() for i in range(lines_to_read)])

span = 2
all_games = all_games.split("\n\n")
split_games  = ["\n\n".join(all_games[i:i+span]) for i in range(0, len(all_games), span)]
print(split_games[-2])
print(len(split_games))

[Event "Rated Rapid game"]
[Site "https://lichess.org/NCnCFGfQ"]
[Date "2021.08.01"]
[Round "-"]
[White "milesurquhart"]
[Black "rafaahasan"]
[Result "0-1"]
[UTCDate "2021.08.01"]
[UTCTime "00:18:39"]
[WhiteElo "1844"]
[BlackElo "1716"]
[WhiteRatingDiff "-89"]
[BlackRatingDiff "+7"]
[ECO "B00"]
[Opening "Ware Defense"]
[TimeControl "600+5"]
[Termination "Normal"]

1. e4 { [%clk 0:10:00] } 1... a5 { [%clk 0:10:00] } 2. Bc4 { [%clk 0:09:42] } 2... e6 { [%clk 0:09:59] } 3. a3 { [%clk 0:09:40] } 3... c6 { [%clk 0:10:00] } 4. Qe2 { [%clk 0:09:12] } 4... b5 { [%clk 0:10:01] } 5. Bd3 { [%clk 0:08:55] } 5... h6 { [%clk 0:09:58] } 6. e5 { [%clk 0:08:44] } 6... g5 { [%clk 0:09:58] } 7. Be4 { [%clk 0:08:07] } 7... Bg7 { [%clk 0:09:49] } 8. d4 { [%clk 0:07:57] } 8... Ne7 { [%clk 0:09:50] } 9. Nf3 { [%clk 0:07:41] } 9... Bb7 { [%clk 0:09:52] } 10. Nfd2 { [%clk 0:07:27] } 10... O-O { [%clk 0:09:50] } 11. Nb3 { [%clk 0:07:22] } 11... d5 { [%clk 0:09:51] } 12. Bf3 { [%clk 0:07:10] } 12... Nd7 { [%clk 

## filter out games without eval score or without clk

In [10]:
filtered = []
for game in split_games:
    if 'eval' in game and 'clk' in game:
        filtered.append(game)
split_games = filtered
print(len(split_games))

2106


## keep n games

In [11]:
n_games = 2000
split_games = split_games[:n_games]
print(len(split_games))
print(split_games[-1])

2000
[Event "Rated Bullet game"]
[Site "https://lichess.org/WiJEcrEp"]
[Date "2021.08.01"]
[Round "-"]
[White "henryln1"]
[Black "OturanGeyik"]
[Result "1-0"]
[UTCDate "2021.08.01"]
[UTCTime "00:17:46"]
[WhiteElo "1659"]
[BlackElo "1638"]
[WhiteRatingDiff "+5"]
[BlackRatingDiff "-5"]
[ECO "A22"]
[Opening "English Opening: King's English Variation, Two Knights Variation, Reversed Dragon"]
[TimeControl "60+0"]
[Termination "Time forfeit"]

1. c4 { [%eval 0.2] [%clk 0:01:00] } 1... e5 { [%eval 0.12] [%clk 0:01:00] } 2. Nc3 { [%eval 0.16] [%clk 0:00:59] } 2... Nf6 { [%eval 0.32] [%clk 0:00:58] } 3. g3 { [%eval 0.0] [%clk 0:00:59] } 3... d5 { [%eval 0.26] [%clk 0:00:58] } 4. cxd5 { [%eval 0.18] [%clk 0:00:57] } 4... Bf5? { [%eval 2.21] [%clk 0:00:57] } 5. Bg2 { [%eval 2.16] [%clk 0:00:56] } 5... Bd6 { [%eval 2.28] [%clk 0:00:56] } 6. Nf3 { [%eval 1.92] [%clk 0:00:55] } 6... Nbd7? { [%eval 3.19] [%clk 0:00:55] } 7. O-O { [%eval 2.99] [%clk 0:00:54] } 7... h6 { [%eval 2.93] [%clk 0:00:53] } 8

## store individual games as PGN files

In [12]:
for i in range(len(split_games)):
    with open('games/game-{:05d}.pgn'.format(i+1),'w') as f:
        f.write(split_games[i])

In [13]:
notrandgames = ['games/game-{:05d}.pgn'.format(n+1) for n in range(10000)]
print(len(notrandgames))

10000


# Loading games from individual PGN files

In [14]:
notrandgames_checked = []
metadata = []
old_md_keys = None
for id, g in enumerate(notrandgames):
    try:
        game_to_vectors(g)
    except:
        pass
    else:
        notrandgames_checked.append((id,g))
        metadata_dict = get_metadata_from_pgn(g)
        # get least common denominator among keys in all samples such that there aren't outlier samples that have more metadata than others
        md_keys = [k for k in metadata_dict]
        if old_md_keys:
            md_keys = list(set(md_keys).intersection(old_md_keys))
        old_md_keys = md_keys
        metadata.append(metadata_dict)
        
# remove outlier metadata such that only shared metadata among all samples remains
for d in metadata:
    keys = [k for k in d]
    dif = list(set(keys) - set(md_keys))
    for k in dif:
        d.pop(k)
        
print(md_keys)
print(metadata[:100])

['Opening', 'TimeControl', 'UTCTime', 'White', 'Event', 'Black', 'ECO', 'BlackElo', 'UTCDate', 'Site', 'Termination', 'WhiteElo', 'Round', 'Result', 'Date']
[{'Event': 'Rated Blitz tournament https://lichess.org/tournament/zTLnP8ob', 'Site': 'https://lichess.org/q5HJFu3Z', 'Date': '2021.08.01', 'Round': '-', 'White': 'Gersonz', 'Black': 'Scheyla_Perdomo26', 'Result': '1-0', 'UTCDate': '2021.08.01', 'UTCTime': '00:00:24', 'WhiteElo': '1552', 'BlackElo': '1321', 'ECO': 'B22', 'Opening': 'Sicilian Defense: Alapin Variation', 'TimeControl': '180+2', 'Termination': 'Time forfeit'}, {'Event': 'Rated Bullet game', 'Site': 'https://lichess.org/hsdob0QP', 'Date': '2021.08.01', 'Round': '-', 'White': 'zoki-pantelic', 'Black': 'hbju', 'Result': '1-0', 'UTCDate': '2021.08.01', 'UTCTime': '00:00:25', 'WhiteElo': '1789', 'BlackElo': '1793', 'ECO': 'C24', 'Opening': "Bishop's Opening: Berlin Defense", 'TimeControl': '60+0', 'Termination': 'Normal'}, {'Event': 'Rated Bullet game', 'Site': 'https://lic

In [15]:
d1 = {'a': 1, 'b': 2, 'c': 3}
d2 = {'a': 1, 'b': 2}
d1k = [k for k in d1]
d2k = [k for k in d2]
keys = list(set(d1k).difference(d2k))
print(keys)
for k in keys:
    d1.pop(k)
print(d1)
    

['c']
{'a': 1, 'b': 2}


In [16]:
# add all games regardless of first move
firstmoves = [(g[0],get_moves_from_pgn(g[1])[0]) for g in notrandgames_checked]
indices = []
for idx, fm in firstmoves:
    indices.append(idx)
games = ['games/game-{:05d}.pgn'.format(n+1) for n in np.array(indices)]
game_matrices = [game_to_vectors(g) for g in games]

### handling an error where the last 2 states of each game are equivalent - remove redundant one

In [17]:
# game matrices is num_games * num_turns * 832 
for game in range(len(game_matrices)):
    # check for each game whether last 2 game states are equivalent
    if np.all(game_matrices[game][-2] == game_matrices[game][-1]):
        # if so, remove the last state
        game_matrices[game] = game_matrices[game][:-1]

## keep opening moves only
https://www.365chess.com/eco.php lists the longest openings at 10 moves, therefore keep 10 moves for each game

In [18]:
game_matrices = [x[:10] for x in game_matrices]

In [19]:
final_data = np.concatenate(game_matrices)

In [20]:
tsne = TSNE(
    perplexity=200,
    n_jobs=6,
    metric='euclidean',
    random_state=42
)

In [21]:
%time embedding = tsne.fit(np.array(final_data))

CPU times: user 8min 8s, sys: 10.9 s, total: 8min 19s
Wall time: 1min 51s


In [22]:
embedding_split = np.array_split(embedding, np.add.accumulate([len(l) for l in game_matrices]))[:-1]

In [23]:
# fig = plt.figure(figsize=(15,15))
# ax = fig.add_subplot(111)
# ax.set_aspect(1)
# for game in embedding_split[:10]:
#     tck, u = interpolate.splprep(game.transpose(), s=0)
#     unew = np.arange(0, 1.01, 0.01)
#     out = interpolate.splev(unew, tck)
#     ax.plot(out[0], out[1], '-r', alpha=0.03, color='red')
#     ax.scatter(game[:,0], game[:,1], s=0.1, color='red')
# #for game in embedding_split[10:800]:
# #    tck, u = interpolate.splprep(game.transpose(), s=0)
# #    unew = np.arange(0, 1.01, 0.01)
# #    out = interpolate.splev(unew, tck)
# #    ax.plot(out[0], out[1], '-r', alpha=0.03, color='blue')
# #    ax.scatter(game[:,0], game[:,1], s=0.1, color='blue')
# plt.xlim((-40,50));
# plt.ylim((-60,40));
# #for game in embedding_split[100:]:
# #    ax.plot(game[:,0], game[:,1], '-r', alpha=0.1, color='blue')


# write header
csv = open("lichess_tsne.csv", "w")
features = "x,y,line,cp,algo,player,age,"
# TODO join regex metadata with comma separator
print(md_keys)
features += ','.join(md_keys)
features += ",a8,b8,c8,d8,e8,f8,g8,h8,a7,b7,c7,d7,e7,f7,g7,h7,a6,b6,c6,d6,e6,f6,g6,h6,a5,b5,c5,d5,e5,f5,g5,h5,a4,b4,c4,d4,e4,f4,g4,h4,a3,b3,c3,d3,e3,f3,g3,h3,a2,b2,c2,d2,e2,f2,g2,h2,a1,b1,c1,d1,e1,f1,g1,h1"
csv.write(features)
csv.write("\n")
idx = 0

# for gameIndex, game in enumerate(embedding_split[:450]):
for gameIndex, game in enumerate(embedding_split):
    pi = 0
    for idx, pos in enumerate(game):
        csv.write(str(pos[0]))
        csv.write(",")
        csv.write(str(pos[1]))
        
        # number of game
        csv.write(",")
        csv.write(str(gameIndex))
        
        # checkpoint
        csv.write(",")
        if idx == 0:
            csv.write("1")
        elif idx == len(game) - 1:
            csv.write("1")
        else:
            csv.write("0")
        
        
        
        # method, in this case the starting move
        csv.write(",")
        # TODO replace with which player wins this game
        winner = metadata[gameIndex]['Result']
        winner = winner.replace('1-0', metadata[gameIndex]['White'])
        winner = winner.replace('0-1', metadata[gameIndex]['Black'])
        csv.write(winner)

        # player - whose turn is it
        csv.write(",")
        if idx % 2 == 0:
            csv.write(metadata[gameIndex]['Black'])
        else:
            csv.write(metadata[gameIndex]['White'])
        
        # age
        csv.write(",")
        csv.write(str(idx))
        csv.write(",")
        
        # TODO all metadata from regex
        md_values = [metadata[gameIndex][k] for k in metadata[gameIndex]]
        csv.write(','.join(md_values))
        
        for n in range(0, 64):
            csv.write(",")
            str1 = ''.join(str(e) for e in game_matrices[gameIndex][idx][n * 13: (n+1) * 13])
            csv.write(vector_to_state(str1))
            
        csv.write("\n")
    gameIndex += 1
        
csv.close()

['Opening', 'TimeControl', 'UTCTime', 'White', 'Event', 'Black', 'ECO', 'BlackElo', 'UTCDate', 'Site', 'Termination', 'WhiteElo', 'Round', 'Result', 'Date']


In [24]:
import umap.umap_ as umap
from matplotlib import pyplot as plt
np.random.seed(0)

In [25]:
reducer = umap.UMAP()

In [26]:
umap_embedding = reducer.fit_transform(final_data)

In [27]:
umap_embedding_split = np.array_split(umap_embedding, np.add.accumulate([len(l) for l in game_matrices]))[:-1]

In [28]:
import pandas as pd
df = pd.read_csv('lichess_tsne.csv')
df.head()

Unnamed: 0,x,y,line,cp,algo,player,age,Opening,TimeControl,UTCTime,...,g2,h2,a1,b1,c1,d1,e1,f1,g1,h1
0,-11.335345,-11.39539,0,1,Gersonz,Scheyla_Perdomo26,0,Rated Blitz tournament https://lichess.org/tou...,https://lichess.org/q5HJFu3Z,2021.08.01,...,wp,wp,wr,wn,wb,wq,wk,wb,wn,wr
1,5.156773,-14.059902,0,0,Gersonz,Gersonz,1,Rated Blitz tournament https://lichess.org/tou...,https://lichess.org/q5HJFu3Z,2021.08.01,...,wp,wp,wr,wn,wb,wq,wk,wb,wn,wr
2,5.155588,-2.414431,0,0,Gersonz,Scheyla_Perdomo26,2,Rated Blitz tournament https://lichess.org/tou...,https://lichess.org/q5HJFu3Z,2021.08.01,...,wp,wp,wr,wn,wb,wq,wk,wb,wn,wr
3,4.304133,-2.025064,0,0,Gersonz,Gersonz,3,Rated Blitz tournament https://lichess.org/tou...,https://lichess.org/q5HJFu3Z,2021.08.01,...,wp,wp,wr,wn,wb,wq,wk,wb,wn,wr
4,4.662256,-1.438602,0,0,Gersonz,Scheyla_Perdomo26,4,Rated Blitz tournament https://lichess.org/tou...,https://lichess.org/q5HJFu3Z,2021.08.01,...,wp,wp,wr,wn,wb,wq,wk,wb,wn,wr


In [29]:
df['x'] = umap_embedding[:,0]
df['y'] = umap_embedding[:,1]
df.head()
# store

Unnamed: 0,x,y,line,cp,algo,player,age,Opening,TimeControl,UTCTime,...,g2,h2,a1,b1,c1,d1,e1,f1,g1,h1
0,11.815928,17.154633,0,1,Gersonz,Scheyla_Perdomo26,0,Rated Blitz tournament https://lichess.org/tou...,https://lichess.org/q5HJFu3Z,2021.08.01,...,wp,wp,wr,wn,wb,wq,wk,wb,wn,wr
1,1.292129,15.669477,0,0,Gersonz,Gersonz,1,Rated Blitz tournament https://lichess.org/tou...,https://lichess.org/q5HJFu3Z,2021.08.01,...,wp,wp,wr,wn,wb,wq,wk,wb,wn,wr
2,17.470268,-5.515378,0,0,Gersonz,Scheyla_Perdomo26,2,Rated Blitz tournament https://lichess.org/tou...,https://lichess.org/q5HJFu3Z,2021.08.01,...,wp,wp,wr,wn,wb,wq,wk,wb,wn,wr
3,3.014392,3.620153,0,0,Gersonz,Gersonz,3,Rated Blitz tournament https://lichess.org/tou...,https://lichess.org/q5HJFu3Z,2021.08.01,...,wp,wp,wr,wn,wb,wq,wk,wb,wn,wr
4,2.951796,3.480122,0,0,Gersonz,Scheyla_Perdomo26,4,Rated Blitz tournament https://lichess.org/tou...,https://lichess.org/q5HJFu3Z,2021.08.01,...,wp,wp,wr,wn,wb,wq,wk,wb,wn,wr


In [30]:
df.to_csv('lichess_umap_seed0.csv')

In [31]:
# fig = plt.figure(figsize=(8,8))
# ax = fig.add_subplot(111)
# ax.set_aspect(1)
# for game in umap_embedding_split[:100]:
#     ax.plot(game[:,0], game[:,1], '-r', alpha=0.1, color='red')
# for game in umap_embedding_split[100:]:
#     ax.plot(game[:,0], game[:,1], '-r', alpha=0.1, color='blue')

## default parameters for UMAP over multiple seeds

In [32]:
# import pandas as pd
# import numpy as np

# for seed in range(20):
#     np.random.seed(seed)
#     reducer = umap.UMAP()
#     umap_embedding = reducer.fit_transform(final_data)
#     df = pd.read_csv('tsne.csv')
#     df['x'] = umap_embedding[:,0]
#     df['y'] = umap_embedding[:,1]
#     save_path = 'umap_seed_'+str(seed)+'.csv'
#     print('storing to', save_path)
#     df.to_csv(save_path)

## UMAP Hparam Search

In [33]:
# import pandas as pd
# import umap.umap_ as umap
# from matplotlib import pyplot as plt
# OUTPUT_FILE_NAME = 'umap_outputs/'

In [34]:
# def run_umap(data, path, learning_rate, nn, n_epochs, min_dist=0.1
#     reducer = umap.UMAP(n_neighbors=nn, learning_rate=learning_rate, n_epochs=n_epochs, min_dist=min_dist)
    
#     csv_path = path+'_nn'+str(nn)+'_lr'+str(learning_rate)+'_nepochs'+str(n_epochs)+'_mindist'+str(min_dist)+'.csv'
#     image_file_name = path+'_nn'+str(nn)+'_lr'+str(learning_rate)+'_nepochs'+str(n_epochs)+'_mindist'+str(min_dist)+'.png'
    
#     print('fitting umap embedding for', csv_path)
#     umap_embedding = reducer.fit_transform(data)
    
#     umap_df = pd.DataFrame(umap_embedding)
#     umap_df.to_csv(csv_path)
    
#     plt.figure()
#     plt.scatter(umap_embedding[:,0],umap_embedding[:,1])
#     print('storing to', csv_path)
#     plt.savefig(image_file_name)
#     plt.close()

In [35]:
# nns = [25,30,35,40,45]
# n_epochs_settings = [200, 300, 400, 500, 600, 700,]
# min_dist = [0.1]
# learning_rate = [1.0]
# for nn in nns:
#     for n_epochs in n_epochs_settings:
#         for md in min_dist:
#             for lr in learning_rate:
#                 run_umap(data=final_data, path=OUTPUT_FILE_NAME, learning_rate=lr, nn=nn, n_epochs=n_epochs, min_dist=md)