In [None]:
from agent import QLearningAgent
from verify_data import FishGame, ParseError

In [2]:
import os
import pickle
import torch
import multiprocessing as mp

def process_file(filename, data_dir='data', n_shuffles=100):
    """Parse one recorded game file and build augmented training memories.

    The file is read into a list of lines and handed to ``FishGame``. For
    every player in ``game.players`` the game is shuffled ``n_shuffles``
    times, and after each shuffle ``game.memory(player)`` is appended to the
    result.

    Args:
        filename: Name of the game file, relative to ``data_dir``.
        data_dir: Directory that contains the game files.
        n_shuffles: Number of shuffle/memory samples taken per player.
            Defaults to 100, the value that used to be hard-coded.

    Returns:
        A list with ``len(game.players) * n_shuffles`` memory entries, or an
        empty list if a ``ParseError`` is raised while processing the file
        (the error is printed and any partial result is discarded).
    """
    filepath = os.path.join(data_dir, filename)
    try:
        with open(filepath, 'r') as f:
            print(f"{filename}")
            lines = f.readlines()
        # The handle is closed here; only the in-memory lines are needed for
        # the (comparatively long) shuffle loop below.
        game = FishGame(lines)
        file_memories = []
        for player in game.players:
            for _ in range(n_shuffles):
                game.shuffle()
                file_memories.append(game.memory(player))
        return file_memories
    except ParseError as e:
        print(f"{filename}: {e}")
        return []

# Build the augmented memory dataset from the recorded games in 'data', or
# reload it from the on-disk cache when one exists.
if os.path.isfile('memories_extended.pkl'):
    # NOTE: pickle.load can execute arbitrary code embedded in the file; only
    # load a cache that this notebook wrote itself.
    with open('memories_extended.pkl', 'rb') as f:
        memories = pickle.load(f)
else:
    memories = []
    # os.listdir returns entries in arbitrary order; sort so the dataset (and
    # any positional indexing into `memories`) is reproducible across runs.
    filenames = sorted(os.listdir('data'))

    if torch.cuda.is_available():
        # NOTE(review): the parallel path is gated on CUDA availability even
        # though parsing is CPU-only -- presumably a proxy for the machine
        # the notebook runs on; confirm before changing.
        num_processes = min(mp.cpu_count(), 8)
        with mp.Pool(processes=num_processes) as pool:
            results = pool.map(process_file, filenames)
    else:
        # Same per-file logic as the parallel path: a file that fails to
        # parse contributes nothing and the remaining files are still
        # processed (the old inline loop stopped at the first ParseError).
        results = [process_file(filename) for filename in filenames]

    for result in results:
        memories.extend(result)

    with open('memories_extended.pkl', 'wb') as f:
        pickle.dump(memories, f)

In [None]:
# Create an agent with default constructor arguments and load state from the
# checkpoint at models/fish_agent.pth.
# NOTE(review): q_network=False presumably skips restoring the Q-network part
# of the checkpoint -- confirm in QLearningAgent.load_model.
agent = QLearningAgent()
agent.load_model('models/fish_agent.pth', q_network=False)
# Train on the memories built from the recorded games.
# NOTE(review): the positional 5 and 0 are unnamed here; presumably epoch
# counts (hand predictor / Q-network) -- confirm against train_on_data.
agent.train_on_data(memories, 5, 0)

In [None]:
# Spot-check the hand predictor on a single stored sample (entry 15 of
# memories[3]), wrapped in a one-element list for unpack_memory.
batch = agent.unpack_memory([memories[3][15]])

# Build the two predictor inputs step by step, in the same order the nested
# one-line call evaluated them.
predict_hands = agent.hand_predictor
sample_state = agent.tensor(batch['state'])
sample_hand_mask = agent.action_masks(*batch['mask_dep'].values())['hands']

# Last expression of the cell: its value is what the notebook displays.
predict_hands(sample_state, sample_hand_mask)

In [None]:
from datetime import datetime

# Make sure the checkpoint directory exists (no error if it already does).
os.makedirs('models', exist_ok=True)
# Alternative kept for reference: a timestamped filename so each save gets its
# own file instead of overwriting the fixed path below.
# model_path = f'models/fish_agent_{datetime.now().strftime("%Y%m%d_%H%M%S")}.pth'
# Fixed path: the same file the other cells pass to agent.load_model, so saving
# here overwrites that checkpoint.
model_path = 'models/fish_agent.pth'
agent.save_model(model_path)

In [None]:
from agent import QLearningAgent
# Rebuild the agent, this time passing the recorded-game memories to the
# constructor, and reload the checkpoint with load_model's default arguments.
agent = QLearningAgent(memories) 
agent.load_model('models/fish_agent.pth')
# Run self-play training; the same checkpoint path is passed as `path`.
# NOTE(review): 2000 is presumably the number of self-play games and `path`
# the file checkpoints are written to during training -- confirm against
# QLearningAgent.train_self_play.
agent.train_self_play(2000, update_rate=1, hand_epochs=10, q_epochs=5, path='models/fish_agent.pth')

loading hand predictor
loading q vals
Z1 Z5:{SK} Z7:{SJ,SA,SQ,S9,ST} 0
Z8 Z4:{C6} Z6:{C4,C3,C7,C5,C2} 0
Z7 Z3:{HK} Z5:{HT,H9,HJ,HQ,HA} 0
Z6 Z2:{CK} Z4:{C9,CT,CQ,CA,CJ} 0
Z5 Z1:{D6} Z3:{D7,D5,D2,D4,D3} 0
Z6 Z6:{S7,S2,S4,S3} Z8:{S5} Z2:{S6} 0
Z3 Z7:{DK} Z1:{DQ,DJ,DT,DA,D9} 0
Z7 Z3:{H6} Z5:{H2,H5,H4,H3,H7} 0
Game 0 finished, 6800 memories collected
Z1 Z5:{SJ} Z7:{SA,SQ,S9,SK,ST} 0
Z8 Z4:{C6} Z6:{C4,C3,C7,C5,C2} 0
Z7 Z3:{CJ} Z5:{CT,CQ,CA,CK,C9} 0
Z6 Z2:{H9} Z4:{HK,HJ,HT,HQ,HA} 0
Z5 Z1:{H3} Z3:{H6,H2,H5,H4,H7} 0
Z5 Z1:{DA} Z3:{DK,DQ,DJ,DT,D9} 0
Z3 Z3:{*S,D8,*B,H8} Z1:{S8,C8} 0
{'Z5': ['S6', 'S2'], 'Z6': ['S4', 'S3'], 'Z7': [], 'Z8': ['S7', 'S5'], 'Z1': [], 'Z2': [], 'Z3': [], 'Z4': []}
helping ask... S7
Z5 Z2 S7 1
{'Z5': ['S7', 'S6', 'S2'], 'Z6': ['S4', 'S3'], 'Z7': [], 'Z8': ['S5'], 'Z1': [], 'Z2': [], 'Z3': [], 'Z4': []}
Z5 Z8 S4 1
Z5 Z5:{S2,S6,S7,S3,S4,S5} 0
Game 1 finished, 6800 memories collected
Memory loaded in 0.04 seconds


Training Hand Predictor epoch 9 train loss 0.8426 test loss 1.24174 train acc 0.27 test acc 0.13 lr 0.0008: 100%|██████████| 10/10 [00:07<00:00,  1.37it/s]
Training Q-Network epoch 4 train loss 0.20294 test loss 0.22174 lr 0.001: 100%|██████████| 5/5 [00:05<00:00,  1.14s/it]


Z1 Z1:{H8} Z5:{C8} Z7:{*S,S8,D8,*B} 0
{'Z4': ['S3', 'DA', 'S9', 'SK', 'H9', 'HK'], 'Z5': ['S7', 'S6', 'S5', 'D4', 'D7', 'D5', 'D3', 'DT', 'DQ', 'SQ', 'SJ', 'ST', 'C4', 'C5', 'C3', 'H2', 'H6', 'HT', 'HA'], 'Z6': ['C6'], 'Z7': ['S2'], 'Z8': ['D6', 'D2', 'HQ'], 'Z1': ['DK', 'C2', 'HJ'], 'Z2': [], 'Z3': ['S4', 'DJ', 'D9', 'SA', 'C7', 'H4', 'H7', 'H3', 'H5']}
Z4 Z3 SJ 0
{'Z3': ['S7', 'S6', 'S5', 'D4', 'C4', 'HQ'], 'Z4': ['D6', 'D2', 'D5', 'D3', 'DK', 'DT', 'D9', 'S9', 'SQ', 'SJ', 'ST', 'SA', 'SK', 'C7', 'C2', 'H6', 'HT', 'H9', 'HA', 'HJ'], 'Z5': ['S2', 'S3', 'S4', 'D7', 'DJ', 'DQ', 'C6', 'C3', 'H4', 'H7', 'H5'], 'Z6': ['DA', 'HK'], 'Z7': ['H3'], 'Z8': ['C5'], 'Z1': ['H2'], 'Z2': []}
Z3 Z2 D2 0
{'Z2': ['DJ', 'C6', 'H7', 'HJ'], 'Z3': ['S6', 'S3', 'D6', 'D4', 'D2', 'D7', 'D5', 'D3', 'DT', 'C4', 'C5', 'H3', 'HT'], 'Z4': ['S4', 'S9', 'SQ', 'SJ', 'ST', 'SA', 'SK', 'C7', 'H6'], 'Z5': ['DA', 'C2', 'C3', 'H2', 'H4', 'H5', 'H9', 'HQ', 'HA', 'HK'], 'Z6': ['DQ'], 'Z7': ['S5'], 'Z8': ['S7', 'S2'], 'Z1':

Training Hand Predictor epoch 9 train loss 0.92071 test loss 0.97361 train acc 0.25 test acc 0.21 lr 0.0008: 100%|██████████| 10/10 [00:03<00:00,  2.56it/s]
Training Q-Network epoch 4 train loss 0.08889 test loss 0.06524 lr 0.001: 100%|██████████| 5/5 [00:05<00:00,  1.09s/it]


{'Z4': ['C9', '*S', 'H8', 'C6', 'H3', 'H5', 'DT', 'HA', 'SK'], 'Z5': ['CJ', 'S8', 'D8', 'D4', 'C3', 'DK', 'D9', 'DA', 'DJ', 'HT', 'SQ', 'ST', 'SJ'], 'Z6': ['CK', 'H7', 'S6'], 'Z1': ['*B', 'D2', 'D6', 'D3', 'C7', 'C5', 'C4', 'C2', 'H6', 'DQ', 'HJ', 'HQ', 'HK', 'S5', 'S7', 'S9', 'SA'], 'Z2': ['CA', 'CQ', 'S2'], 'Z3': ['CT', 'C8', 'D5', 'D7', 'H2', 'H4', 'H9', 'S4', 'S3']}
Z4 Z5 D9 0
{'Z5': ['CJ', '*B', 'D2', 'D6', 'D5', 'C4', 'H6', 'DJ', 'SA'], 'Z6': ['CA', 'CQ', 'CK', 'S8', 'D8', 'D4', 'C3', 'H7', 'SQ', 'ST', 'SJ'], 'Z1': ['C9', '*S'], 'Z2': ['D3', 'C7', 'C5', 'C2', 'DQ', 'HQ', 'HK', 'S7', 'S3', 'S9', 'SK'], 'Z3': ['H8', 'HA', 'S5', 'S6', 'S2'], 'Z4': ['CT', 'C8', 'D7', 'C6', 'H2', 'H4', 'H3', 'H5', 'DK', 'D9', 'DT', 'DA', 'HJ', 'H9', 'HT', 'S4']}
Z5 Z6 D9 0
{'Z6': ['D8', 'D3', 'C7', 'H2', 'S4', 'S5', 'S3', 'S6', 'S2'], 'Z1': ['C9', 'CQ', 'CK', 'CJ', 'S8', 'D4', 'C3', 'C6', 'H7', 'DQ', 'D9', 'DT', 'DJ', 'HQ', 'SK', 'SQ', 'ST', 'SJ'], 'Z2': ['CA', '*S', 'D2'], 'Z3': ['*B', 'C5', 'C4', 'C

Training Hand Predictor epoch 9 train loss 0.91295 test loss 0.84509 train acc 0.37 test acc 0.34 lr 0.0008: 100%|██████████| 10/10 [00:03<00:00,  2.60it/s]
Training Q-Network epoch 4 train loss 0.0454 test loss 0.11452 lr 0.001: 100%|██████████| 5/5 [00:24<00:00,  4.84s/it] 


Memory loaded in 0.37 seconds


  accuracies = ((one_hot * episode['hands']).sum((1,2)) - guarantee) / (cards_remaining - guarantee)
Training Hand Predictor epoch 29 train loss 1.00891 test loss 1.09123 train acc 0.44 test acc 0.36 lr 0.0008: 100%|██████████| 30/30 [00:23<00:00,  1.28it/s]
Training Q-Network epoch 2 train loss 0.04389 test loss 0.05831 lr 0.001:  20%|██        | 3/15 [00:35<02:21, 11.75s/it]