In [2]:
from agent import QLearningAgent
from verify_data import FishGame, ParseError

In [4]:
import os
import pickle
import torch
import multiprocessing as mp

def process_file(filename, data_dir='data'):
    """Parse one recorded game file and return shuffled memories for every player.

    Args:
        filename: Name of the game file inside ``data_dir``.
        data_dir: Directory that holds the game files (defaults to ``'data'``).

    Returns:
        A list with 100 entries per player in ``game.players``; each entry is
        the result of ``game.memory(player)`` taken right after a call to
        ``game.shuffle()``. An empty list is returned when a ``ParseError``
        is raised while handling the file (the error is printed).
    """
    source_path = os.path.join(data_dir, filename)

    try:
        with open(source_path, 'r') as handle:
            print(f"{filename}")
            game = FishGame(handle.readlines())

            def reshuffled_memory(player):
                # Shuffle first so every sample is taken from a fresh shuffle.
                game.shuffle()
                return game.memory(player)

            collected = [
                reshuffled_memory(player)
                for player in game.players
                for _ in range(100)
            ]
    except ParseError as err:
        print(f"{filename}: {err}")
        return []

    return collected

# Build the training memories once and cache them in 'memories_extended.pkl';
# later runs load the cache instead of re-parsing every game file.
if os.path.isfile('memories_extended.pkl'):
    # NOTE: pickle.load can execute arbitrary code while unpickling. Only load
    # a cache file that this notebook produced itself.
    with open('memories_extended.pkl', 'rb') as f:
        memories = pickle.load(f)
else:
    memories = []
    filenames = os.listdir('data')

    if torch.cuda.is_available():
        # NOTE(review): the pool only does CPU-side parsing; gating it on CUDA
        # availability looks like a proxy for "running on the big machine".
        # Kept as-is -- confirm the intent before changing it.
        num_processes = min(mp.cpu_count(), 8)
        with mp.Pool(processes=num_processes) as pool:
            results = pool.map(process_file, filenames)
        for result in results:
            memories.extend(result)
    else:
        # Reuse process_file so the sequential path behaves exactly like the
        # parallel one: a file that raises ParseError is reported and skipped.
        # The previous inline copy used `break`, which silently dropped every
        # remaining file and then cached the truncated dataset.
        for filename in filenames:
            memories.extend(process_file(filename))

    with open('memories_extended.pkl', 'wb') as f:
        pickle.dump(memories, f)

In [None]:
# Create an agent, restore saved weights, and train it on the cached memories.
agent = QLearningAgent()
# NOTE(review): this loads 'fish_agent.pth' from the working directory, while
# the self-play cell further down loads 'models/fish_agent.pth' -- confirm
# which checkpoint is intended.
agent.load_model('fish_agent.pth')
# NOTE(review): the meaning of the positional 0 and 50 is not visible here.
# The recorded output shows 50 hand-predictor iterations and an empty
# Q-network progress bar, so presumably they are epoch counts for the
# Q-network and hand predictor respectively -- TODO confirm against
# QLearningAgent.train_on_data.
agent.train_on_data(memories, 0, 50)

Memory loaded in 10.31 seconds


  accuracies = ((one_hot * episode['hands']).sum((1,2)) - guarantee) / (cards_remaining - guarantee)
Training Hand Predictor epoch 49 train loss 1.05107 test loss 1.06269 train acc 0.34 test acc 0.34 lr 0.0008: 100%|██████████| 50/50 [14:45<00:00, 17.71s/it]
Training Q-Network: 0it [00:00, ?it/s]


In [None]:
# Spot-check the hand predictor on a single stored entry, memories[3][15].
batch = agent.unpack_memory([memories[3][15]])
# Run the unpacked state through the hand predictor; the second argument is
# the 'hands' entry of the action masks built from the values of
# batch['mask_dep'].
agent.hand_predictor(agent.tensor(batch['state']), agent.action_masks(*batch['mask_dep'].values())['hands'])

In [None]:
# Save the current agent under models/ with a timestamped file name.
from datetime import datetime

# Create the output directory if it does not exist yet.
os.makedirs('models', exist_ok=True)
# The YYYYMMDD_HHMMSS suffix gives each save (at one-second resolution) its
# own file name.
model_path = f'models/fish_agent_{datetime.now().strftime("%Y%m%d_%H%M%S")}.pth'
agent.save_model(model_path)

In [None]:
# Start from a fresh agent, restore a checkpoint, and continue with self-play.
from agent import QLearningAgent
agent = QLearningAgent() 
# NOTE(review): 'models/fish_agent.pth' is not the timestamped file written by
# the save cell above; it has to exist already -- confirm where it comes from.
agent.load_model('models/fish_agent.pth')
# NOTE(review): presumably 2000 is the number of self-play games and `path` is
# where the model is written back during training -- TODO confirm against
# QLearningAgent.train_self_play.
agent.train_self_play(2000, update_rate=1, hand_epochs=10, q_epochs=10, path='models/fish_agent.pth')

loading hand predictor
loading q vals
Z2 Z7 CK 0
Z7 Z4 C4 0
Z4 Z1 HQ 0
Z1 Z6 D5 0
Z6 Z3 D5 0
Z3 Z8 D5 1
Z3 Z8 D5 0
Z8 Z5 HQ 0
Z5 Z2 HQ 0
Z2 Z5 C4 0
Z5 Z2 HQ 0
Z2 Z7 C4 1
Z2 Z7 C4 0
Z7 Z4 C9 0
helping ask... C3
Z4 Z1 C3 1
helping call... {'Z4': {'C5', 'C3'}, 'Z6': {'C2'}, 'Z8': {'C6'}, 'Z2': {'C7', 'C4'}}
Z4 Z4:{C5,C3} Z6:{C2} Z8:{C6} Z2:{C7,C4} 1
helping ask... S3
Z4 Z1 S3 1
helping call... {'Z4': {'S3', 'S2'}, 'Z6': {'S7'}, 'Z8': set(), 'Z2': {'S4', 'S6', 'S5'}}
Z4 Z4:{S3,S2} Z6:{S7} Z2:{S4,S6,S5} 1
helping ask... D9
Z4 Z1 D9 1
Z4 Z1 HQ 0
Z1 Z6 CK 0
Z6 Z3 D9 0
Z3 Z8 D5 0
Z8 Z5 HQ 0
helping ask... SQ
Z5 Z2 SQ 1
Z5 Z6 DT 0
Z6 Z3 D9 0
Z3 Z8 D5 0
Z8 Z5 HQ 0
Z5 Z6 DT 0
helping ask... D3
Z6 Z3 D3 1
helping ask... D6
Z6 Z3 D6 1
Z6 Z3 D5 1
Z6 Z3 H4 0
Z3 Z8 CK 1
Z3 Z8 CK 0
Z8 Z5 HQ 0
Z5 Z2 D9 0
Z2 Z5 CK 0
Z5 Z2 D9 0
Z2 Z7 CK 0
Z7 Z4 H5 0
Z4 Z1 D9 0
helping ask... CJ
Z1 Z2 CJ 1
helping call... {'Z1': {'CT', 'CQ', 'CJ'}, 'Z3': {'CK', 'CA'}, 'Z5': set(), 'Z7': {'C9'}}
Z1 Z1:{CT,CQ,CJ} Z3:{CK,CA} 