In [3]:
from agent import QLearningAgent
from verify_data import FishGame, ParseError

In [3]:
import os
import pickle
import torch
import multiprocessing as mp

def process_file(filename, data_dir='data'):
    """Parse one game log and collect shuffled memories for each of its players.

    The file ``data_dir/filename`` is read and handed to ``FishGame``. For
    every player in the game, the game is shuffled 100 times and
    ``game.memory(player)`` is recorded after each shuffle.

    Args:
        filename: Name of the log file inside ``data_dir``.
        data_dir: Directory that holds the log files.

    Returns:
        A list with one ``game.memory(player)`` result per (player, shuffle)
        pair, or an empty list if a ``ParseError`` was raised while handling
        the file (the error is printed, not propagated).
    """
    source_path = os.path.join(data_dir, filename)
    collected = []
    try:
        with open(source_path, 'r') as handle:
            # Echo the file name as a simple progress indicator.
            print(f"{filename}")
            raw_lines = handle.readlines()
            game = FishGame(raw_lines)
            for player in game.players:
                for _ in range(100):
                    game.shuffle()
                    snapshot = game.memory(player)
                    collected.append(snapshot)
    except ParseError as err:
        # A malformed log contributes nothing; report it and move on.
        print(f"{filename}: {err}")
        return []
    else:
        return collected

# Load the cached memories if present; otherwise build them from every file in
# 'data' via process_file and write the cache for the next run.
if os.path.isfile('memories_extended.pkl'):
    # NOTE: pickle.load can execute arbitrary code; only load cache files that
    # this notebook itself produced.
    with open('memories_extended.pkl', 'rb') as f:
        memories = pickle.load(f)
else:
    memories = []
    filenames = os.listdir('data')

    # NOTE(review): CUDA availability is used as the switch for multiprocessing,
    # although the work below is CPU-only parsing — confirm this gate is intended.
    if torch.cuda.is_available():
        num_processes = min(mp.cpu_count(), 8)
        with mp.Pool(processes=num_processes) as pool:
            results = pool.map(process_file, filenames)
    else:
        # Sequential path uses the same per-file helper as the parallel path, so
        # a file that fails to parse is reported and skipped. (Previously a
        # ParseError hit `break`, which silently dropped every remaining file
        # and then cached the incomplete result.)
        results = [process_file(filename) for filename in filenames]

    for result in results:
        memories.extend(result)

    with open('memories_extended.pkl', 'wb') as f:
        pickle.dump(memories, f)

In [5]:
import pickle
# Load the pickled contents of 'call_memories.pkl'.
# NOTE: this rebinds `memories`, replacing whatever the memories_extended.pkl
# cell assigned; later cells see whichever of the two cells ran last.
# NOTE: pickle.load can execute arbitrary code; only load trusted files.
with open('call_memories.pkl', 'rb') as f:
    memories = pickle.load(f)

In [18]:
# Spot-check a single entry; the recorded output shows a list whose first
# element is a dict with 'state' and 'hands' arrays.
memories[1501]

[{'state': array([2, 0, 0, 0, 0, 0, 6, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
         0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1]),
  'hands': array([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 

In [6]:
# Create an agent, restore it from the saved checkpoint, then train it on
# `memories` (whichever cell assigned that name most recently).
agent = QLearningAgent()
agent.load_model('models/fish_agent.pth')
# NOTE(review): the positional arguments 15 and 0 are opaque from here. The
# recorded progress bar counts to 15, so the first is presumably an epoch
# count — confirm both against QLearningAgent.train_on_data.
# NOTE(review): the recorded run reports a train loss of ~1.6e14 next to a test
# loss of 0.82 before it was interrupted; that gap deserves investigation.
agent.train_on_data(memories, 15, 0)

loading hand predictor
loading q vals


Training Q-Network:   0%|          | 0/15 [00:00<?, ?it/s]

Training Q-Network epoch 5 train loss 162760419704832.0 test loss 0.82417 lr 0.0008:  40%|████      | 6/15 [00:18<00:27,  3.04s/it]


KeyboardInterrupt: 

In [None]:
# Unpack a single memory item into a batch and run the hand predictor on it.
# NOTE(review): assumes memories[3] has at least 16 items — verify.
batch = agent.unpack_memory([memories[3][15]])
# The values of batch['mask_dep'] are splatted positionally into action_masks,
# so this relies on the dict's insertion order matching that signature.
agent.hand_predictor(agent.tensor(batch['state']), agent.action_masks(*batch['mask_dep'].values())['hands'])

In [5]:
from datetime import datetime

# Make sure the checkpoint directory exists before saving.
os.makedirs('models', exist_ok=True)
# Alternative kept for reference: a timestamped filename (uses the `datetime`
# import at the top of this cell) so each save gets its own file.
# model_path = f'models/fish_agent_{datetime.now().strftime("%Y%m%d_%H%M%S")}.pth'
# Fixed path: the same file that the load_model calls in this notebook read.
model_path = 'models/fish_agent.pth'
agent.save_model(model_path)

Model saved to models/fish_agent.pth


In [None]:
from agent import QLearningAgent
# Build a new agent, this time passing `memories` to the constructor, and
# restore it from the saved checkpoint before self-play training.
agent = QLearningAgent(memories) 
agent.load_model('models/fish_agent.pth')
# NOTE(review): 2000 is presumably the number of self-play iterations and
# `path` the checkpoint destination (the same file loaded above) — confirm
# against QLearningAgent.train_self_play.
agent.train_self_play(2000, update_rate=1, hand_epochs=10, q_epochs=5, path='models/fish_agent.pth')