In [1]:
from agent import QLearningAgent
from verify_data import FishGame, ParseError

In [None]:
import os
import pickle
import torch
import multiprocessing as mp

def process_file(filename, data_dir='data', shuffles_per_player=100):
    """Build training memories from one recorded game file.

    The file is parsed into a ``FishGame``; then, for every player in
    ``game.players``, the game is shuffled ``shuffles_per_player`` times and
    ``game.memory(player)`` is collected after each shuffle.

    Args:
        filename: Name of the game file inside ``data_dir``.
        data_dir: Directory that contains the game files.
        shuffles_per_player: Number of shuffle/memory samples taken per player
            (previously a hard-coded 100).

    Returns:
        A list with one ``game.memory(player)`` result per sample, or an empty
        list when a ``ParseError`` is raised while parsing or sampling the
        game (the error is printed, and partial results are discarded).
    """
    filepath = os.path.join(data_dir, filename)
    try:
        with open(filepath, 'r') as f:
            print(f"{filename}")
            lines = f.readlines()
        # The file handle is no longer needed once the lines are in memory.
        game = FishGame(lines)
        file_memories = []
        for player in game.players:
            for _ in range(shuffles_per_player):
                game.shuffle()
                file_memories.append(game.memory(player))
        return file_memories
    except ParseError as e:
        print(f"{filename}: {e}")
        return []

# Build the list of training memories, or reuse the cached pickle if present.
# Delete 'memories_extended.pkl' to force regeneration from the 'data' directory.
if os.path.isfile('memories_extended.pkl'):
    # NOTE: pickle.load can execute arbitrary code; only load a cache file
    # that was written by this notebook.
    with open('memories_extended.pkl', 'rb') as f:
        memories = pickle.load(f)
else:
    memories = []
    filenames = os.listdir('data')

    if torch.cuda.is_available():
        # NOTE(review): the worker pool is CPU-only work, yet it is enabled by
        # CUDA availability -- presumably a proxy for "running on the training
        # host". Confirm before changing the condition.
        num_processes = min(mp.cpu_count(), 8)
        with mp.Pool(processes=num_processes) as pool:
            results = pool.map(process_file, filenames)
        for result in results:
            memories.extend(result)
    else:
        # Use the same helper as the parallel path so that a file which fails
        # to parse is reported and skipped. The previous inline loop used
        # `break`, which dropped every remaining file and then cached the
        # truncated dataset.
        for filename in filenames:
            memories.extend(process_file(filename))
    with open('memories_extended.pkl', 'wb') as f:
        pickle.dump(memories, f)

In [None]:
# Create an agent, restore saved weights, then train it on the memories
# loaded or generated in the previous cell.
agent = QLearningAgent()
# NOTE(review): this checkpoint path has no 'models/' prefix, while the save
# cell later in this notebook writes under 'models/' -- confirm the location.
agent.load_model('fish_agent_20250427_152210.pth')
# NOTE(review): the meaning of the positional arguments 5 and 0 is not visible
# in this notebook; check QLearningAgent.train_on_data before changing them.
agent.train_on_data(memories, 5, 0)

In [None]:
# Spot-check: unpack a single stored entry (element 15 of memories[3]) into a
# batch and display the hand predictor's output for it.
batch = agent.unpack_memory([memories[3][15]])
# The predictor receives the batch state as a tensor plus the 'hands' entry of
# the action masks built from the values of batch['mask_dep'].
agent.hand_predictor(agent.tensor(batch['state']), agent.action_masks(*batch['mask_dep'].values())['hands'])

In [None]:
from datetime import datetime

# Timestamped checkpoint name, e.g. models/fish_agent_20250427_152210.pth
model_path = f'models/fish_agent_{datetime.now():%Y%m%d_%H%M%S}.pth'

# Make sure the target directory exists before writing the checkpoint.
os.makedirs('models', exist_ok=True)
agent.save_model(model_path)

In [None]:
from agent import QLearningAgent
# Self-play training. This rebinds `agent` to a brand-new instance, so the
# agent trained in the earlier cells is no longer referenced after this line.
agent = QLearningAgent() 
agent.load_model('models/fish_agent.pth')
# Runs train_self_play with 2000 as the first argument; the same file that was
# just loaded is passed as `path`.
# NOTE(review): presumably `path` is where checkpoints are written, which would
# overwrite the loaded model -- confirm against QLearningAgent.train_self_play.
agent.train_self_play(2000, update_rate=1, hand_epochs=10, q_epochs=5, path='models/fish_agent.pth')

loading hand predictor
loading q vals
helping ask... H3
helping ask... H5
helping call... {'Z6': {'H6', 'H2'}, 'Z8': {'H4', 'H5', 'H3'}, 'Z2': set(), 'Z4': {'H7'}}
helping ask... S5
helping ask... S7
Game 0 finished, 0 memories collected
helping ask... C2
Game 1 finished, 0 memories collected
Memory loaded in 0.28 seconds


Training Hand Predictor epoch 9 train loss 0.85838 test loss 0.89675 train acc 0.28 test acc 0.25 lr 0.001: 100%|██████████| 10/10 [00:15<00:00,  1.55s/it]
Training Q-Network epoch 0 train loss 0.08811 test loss 0.01674 lr 0.001:  20%|██        | 1/5 [00:09<00:36,  9.11s/it]