In [1]:
from agent import QLearningAgent
from verify_data import FishGame, ParseError

In [None]:
import os
import pickle
import torch
import multiprocessing as mp

def process_file(filename, data_dir='data', n_shuffles=100):
    """Parse one game log and return its augmented training memories.

    Reads ``data_dir/filename``, builds a ``FishGame`` from its lines and, for
    every player in the game, records ``n_shuffles`` memories, calling
    ``game.shuffle()`` before each ``game.memory(player)``.

    Kept at module level so it can be dispatched to ``multiprocessing.Pool``
    workers.

    Args:
        filename: Name of the log file inside ``data_dir``.
        data_dir: Directory containing the game logs.
        n_shuffles: Number of shuffled memories recorded per player
            (default 100, the value previously hard-coded).

    Returns:
        A list of ``game.memory(player)`` results, or ``[]`` when a
        ``ParseError`` is raised. The error is printed rather than re-raised
        so one malformed log does not abort a whole batch.
    """
    filepath = os.path.join(data_dir, filename)
    try:
        # Only hold the file open while reading it; the game is built from
        # the list of lines, so the handle is not needed afterwards.
        with open(filepath, 'r') as f:
            print(f"{filename}")
            lines = f.readlines()
        game = FishGame(lines)
        file_memories = []
        for player in game.players:
            for _ in range(n_shuffles):
                # NOTE(review): assumes shuffle() re-randomizes the game in
                # place so each memory is a distinct augmentation — confirm.
                game.shuffle()
                file_memories.append(game.memory(player))
        return file_memories
    except ParseError as e:
        # Partial results from a malformed log are discarded.
        print(f"{filename}: {e}")
        return []

# Cache of parsed + augmented memories; delete this file to force a rebuild.
MEMORIES_CACHE = 'memories_extended.pkl'

if os.path.isfile(MEMORIES_CACHE):
    # pickle.load can execute arbitrary code: only load a cache this notebook
    # wrote itself, never one obtained from an untrusted source.
    with open(MEMORIES_CACHE, 'rb') as f:
        memories = pickle.load(f)
else:
    # Sort for a deterministic memory order (os.listdir order is arbitrary),
    # and skip non-file entries such as `.ipynb_checkpoints`.
    filenames = sorted(
        name for name in os.listdir('data')
        if os.path.isfile(os.path.join('data', name))
    )

    if torch.cuda.is_available():
        # NOTE(review): CUDA availability is used as the switch for parallel
        # parsing even though parsing is CPU work — presumably because the GPU
        # host is where fork-based Pool works with notebook-defined functions.
        # Confirm before changing this condition.
        num_processes = min(mp.cpu_count(), 8)
        with mp.Pool(processes=num_processes) as pool:
            results = pool.map(process_file, filenames)
    else:
        # Same per-file handling as the parallel path: a ParseError skips that
        # file only (previously it aborted processing of all remaining files).
        results = [process_file(filename) for filename in filenames]

    memories = [memory for file_memories in results for memory in file_memories]

    with open(MEMORIES_CACHE, 'wb') as f:
        pickle.dump(memories, f)

12-6_11:08.txt12-3_11:12.txt1-15_11:15.txt12-3_11:30.txt12-3_14:27.txt12-4_11:11.txt12-10_11:07.txt12-3_15:27.txt







12-3_14:05.txt


In [None]:
# Train a fresh agent offline on the `memories` list built/loaded above.
# NOTE(review): the positional arguments are not explained here. The progress
# bar in this cell's output counts to /500, so 500 is presumably the number of
# epochs; the meaning of 0 is not visible — confirm against
# QLearningAgent.train_on_data and consider passing them by keyword.
agent = QLearningAgent()
agent.train_on_data(memories, 0, 500)

  accuracies = ((one_hot * episode['hands']).sum((1,2)) - guarantee) / (cards_remaining - guarantee)
Training Hand Predictor epoch 28 train loss 1.09382 test loss 1.32096 train acc 0.26 test acc 0.24 lr 0.0008:   6%|▌         | 29/500 [07:58<2:08:54, 16.42s/it]

In [None]:
# Ad-hoc sanity check: run the hand predictor on a single stored sample.
# NOTE(review): memories[3][15] is an arbitrary element. Which file/player it
# comes from depends on the order in which `memories` was built (or on the
# cached pickle, if one was loaded), so the result may differ between machines.
# NOTE(review): `*batch['mask_dep'].values()` passes values positionally, which
# assumes the dict's insertion order matches action_masks' parameter order —
# confirm in QLearningAgent.
batch = agent.unpack_memory([memories[3][15]])
agent.hand_predictor(agent.tensor(batch['state']), agent.action_masks(*batch['mask_dep'].values())['hands'])

In [None]:
from datetime import datetime

# Persist the trained agent under a timestamped name so earlier runs are
# never overwritten.
run_stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
os.makedirs('models', exist_ok=True)
model_path = f'models/fish_agent_{run_stamp}.pth'
agent.save_model(model_path)

In [None]:
# Restore a specific, previously saved checkpoint into a fresh agent.
# NOTE(review): hard-coded to one past run's timestamp; update when retraining.
checkpoint_path = 'models/fish_agent_20250417_192317.pth'
agent = QLearningAgent()
agent.load_model(checkpoint_path)

In [None]:
# Continue training the loaded agent via self-play.
# NOTE(review): `path` repeats the literal checkpoint path used by the load
# cell. If train_self_play writes checkpoints to `path`, it will overwrite
# that file — confirm, and keep the two literals in sync.
# TODO: document what 2000 counts (episodes? iterations?); not visible here.
agent.train_self_play(2000, update_rate=1, hand_epochs=5, q_epochs=5, path='models/fish_agent_20250417_192317.pth')