In [1]:
import os

import numpy as np

In [2]:
# Directory of raw BridgeBase hand records; its entries are kept in sorted order.
path = 'data/www.bridgebase.com'
filenames = os.listdir(path)
filenames.sort()

In [3]:
# Reference: https://github.com/morgoth/lin/blob/master/lib/lin/parser.rb

import re
from pprint import pprint

class Card:
    """A single playing card: a suit plus a one-character rank.

    `suit` is either a suit letter accepted by `Suit` or an existing `Suit`
    instance; `value` must be exactly one character of 'AKQJT98765432'.

    Raises:
        TypeError: if `suit` is neither a str nor a `Suit`.
        ValueError: if `value` is not a single known rank character.
    """

    VALUES = 'AKQJT98765432'

    def __init__(self, suit, value):
        if not isinstance(suit, (str, Suit)):
            raise TypeError(f'Unknown card suit type "{type(suit).__name__}"')
        # A bare substring test would also accept '' or multi-character
        # strings such as 'AK', so insist on exactly one rank character.
        if len(value) != 1 or value not in self.VALUES:
            raise ValueError(f'Unknown card value "{value}"')
        self.suit = Suit(suit) if isinstance(suit, str) else suit
        self.value = value

    def __repr__(self):
        # e.g. '♠A' — the suit renders via Suit.__repr__.
        return f'{self.suit}{self.value}'
    
    
class Suit:
    """A strain identified by one letter: 'S', 'H', 'D', 'C', or 'N'.

    A `Suit` compares equal both to another `Suit` with the same letter and
    to the bare letter itself, so `Suit('S') == 'S'` is True.

    Raises:
        ValueError: if `value` is not one of the known letters.
    """

    # Letter -> display symbol used by __repr__.
    suit_face = {
        'S': '♠',
        'H': '♥',
        'D': '♦',
        'C': '♣',
        'N': 'N',
    }

    def __init__(self, value):
        # Raise instead of assert so validation survives `python -O`.
        if value not in self.suit_face:
            raise ValueError(f'Unknown suit "{value}"')
        self.value = value

    def __eq__(self, other):
        if isinstance(other, Suit):
            return self.value == other.value
        return self.value == other

    def __hash__(self):
        # Defining __eq__ removes the inherited hash. Hash by the letter so
        # the hash agrees with equality against plain strings too.
        return hash(self.value)

    def __repr__(self):
        return self.suit_face[self.value]
    
    
class Bid:
    """A single call in an auction.

    `value` is one character of '1234567PDRN'. Contract bids carry a level
    '1'-'7' plus a `suit` (letter or `Suit`). Other calls such as 'P', 'D'
    and 'R' (pass / double / redouble) have `suit=None`.

    Raises:
        TypeError: if `suit` is not None, a str, or a `Suit`.
        ValueError: if `value` is not a single known call character.
    """

    VALUES = '1234567PDRN'

    def __init__(self, value, suit=None):
        if suit is not None and not isinstance(suit, (str, Suit)):
            raise TypeError(f'Unknown bid suit type "{type(suit).__name__}"')
        # Exactly one character: a substring test alone would accept '' or '12'.
        if len(value) != 1 or value not in self.VALUES:
            raise ValueError(f'Unknown bid value "{value}"')
        self.suit = Suit(suit) if isinstance(suit, str) else suit
        self.value = value

    def __eq__(self, other):
        # Defer to Python for non-Bid operands instead of raising AttributeError.
        if not isinstance(other, Bid):
            return NotImplemented
        return self.value == other.value and self.suit == other.suit

    def __hash__(self):
        # Consistent with __eq__: equal Bids share value and suit letter.
        suit_letter = None if self.suit is None else self.suit.value
        return hash((self.value, suit_letter))

    def __repr__(self):
        if self.suit is not None:
            return f'{self.value}{self.suit}'
        return self.value

    
class LinParser:
    """Parse BridgeBase (BBO) LIN hand records into plain game dictionaries.

    Each game is a dict with keys 'hands', 'dealer', 'bids', 'played_cards',
    'claimed_tricks' and 'vulnerability' (see `parse_game`). Only games that
    pass `is_valid_game` are kept.
    """

    # Patterns are compiled once per class instead of on every call.
    HANDS_PATTERN = re.compile(r'md\|\d(\w+),(\w+),(\w+),(\w+)')
    DEALER_PATTERN = re.compile(r'md\|(\d)')
    # A bid tag must follow a '|' or start a line. re.MULTILINE makes '^'
    # match at every line start; without it '^mb' could only match at the
    # very start of a game chunk, so bids that begin a line were dropped.
    BID_PATTERN = re.compile(r'(?:\|mb|^mb)\|(\w\w?)', re.MULTILINE)
    PLAYED_CARD_PATTERN = re.compile(r'pc\|([cdhsCDHS](?:\d\d|\d|\w))')
    CLAIM_PATTERN = re.compile(r'mc\|(\d+)')
    VULNERABILITY_PATTERN = re.compile(r'sv\|(\w)')

    # Seats of the four hands in the `md` tag, in the order they are listed.
    HAND_SEATS = ('south', 'west', 'north', 'east')
    CARD_SUITS = ('S', 'H', 'D', 'C')
    DEALER_CODES = {
        '1': 'south',
        '2': 'west',
        '3': 'north',
        '4': 'east',
    }
    VULNERABILITY_CODES = {
        'o': 'NONE',
        '0': 'NONE',
        'n': 'NS',
        's': 'NS',
        'e': 'EW',
        'w': 'EW',
        'b': 'BOTH',
    }

    def __init__(self):
        # Games kept by the most recent `parse` call.
        self.games = []

    def parse(self, filename):
        """Parse every valid game in the LIN file `filename`.

        The result is stored in `self.games` and returned. On any error the
        file name is printed before the exception propagates.
        """
        try:
            with open(filename) as f:
                content = f.read()
            self.games = self.parse_games(content)
        except Exception:
            print('Filename:', filename)
            raise
        return self.games

    def parse_games(self, content):
        """Return the games in `content` that pass `is_valid_game`."""
        parsed = (self.parse_game(chunk) for chunk in self.extract_games(content))
        return [game for game in parsed if self.is_valid_game(game)]

    def extract_games(self, content):
        """Split raw LIN text into per-game chunks at each 'qx|' tag."""
        return content.split('qx|')

    def parse_game(self, game_content):
        """Parse one game chunk into a dict of its components."""
        return {
            'hands': self.parse_hands(game_content),
            'dealer': self.parse_dealer(game_content),
            'bids': self.parse_bids(game_content),
            'played_cards': self.parse_played_cards(game_content),
            'claimed_tricks': self.parse_claimed_tricks(game_content),
            'vulnerability': self.parse_vulnerability(game_content),
        }

    def parse_hands(self, game_content):
        """Return {seat: [Card, ...]} from the `md` tag, or {} if absent.

        NOTE(review): the pattern requires all four hands to be written out;
        a record that leaves one hand empty yields {} and the game is later
        rejected. Confirm that is intended.
        """
        hands = {}
        try:
            match = self.HANDS_PATTERN.search(game_content)
            if match:
                for player, hand in zip(self.HAND_SEATS, match.groups()):
                    hands[player] = self.parse_hand(hand)
        except Exception:
            print(game_content)
            raise
        return hands

    def parse_hand(self, hand_content):
        """Turn a hand string such as 'SAKHQ2...' into a list of Cards.

        Raises:
            ValueError: if a rank appears before any suit letter.
        """
        hand = []
        suit = None
        for char in hand_content:
            char = char.upper()
            if char in self.CARD_SUITS:
                suit = char
                continue
            if suit is None:
                raise ValueError(
                    f'Card value "{char}" appears before any suit in hand "{hand_content}"')
            # Kept from the original "Bug fix": 'N' is read as the nine.
            value = '9' if char == 'N' else char
            hand.append(Card(suit=suit, value=value))
        return hand

    def parse_dealer(self, game_content):
        """Return the dealer seat from the `md` tag.

        Returns '' when the tag is missing or carries an unknown code, so
        `is_valid_game` rejects the game instead of the parse crashing.
        """
        match = self.DEALER_PATTERN.search(game_content)
        if not match:
            return ''
        return self.DEALER_CODES.get(match.group(1), '')

    def parse_bids(self, game_content):
        """Return the auction as a list of Bids, in the order they appear."""
        bids = []
        try:
            for bid in self.BID_PATTERN.findall(game_content):
                bids.append(self.parse_bid(bid.upper()))
        except Exception:
            print(game_content)
            raise
        return bids

    def parse_bid(self, bid):
        """Convert an upper-cased call string ('P', '1S', '3N', ...) into a Bid."""
        if len(bid) > 1:
            bid = bid.replace('NT', 'N')  # bug fix: accept 'NT' for no trump
            value, suit = bid
            return Bid(value=value, suit=suit)
        return Bid(value=bid)

    def parse_played_cards(self, game_content):
        """Return every card in the `pc` tags, in play order."""
        cards = []
        try:
            for card in self.PLAYED_CARD_PATTERN.findall(game_content):
                # Bug fix: the ten may be written '10'; normalise to 'T'.
                suit, value = card.replace('10', 'T').upper()
                cards.append(Card(suit=suit, value=value))
        except Exception:
            print(game_content)
            raise
        return cards

    def parse_claimed_tricks(self, game_content):
        """Return the number in the first `mc` tag, or 0 if there is none."""
        match = self.CLAIM_PATTERN.search(game_content)
        return int(match.group(1)) if match else 0

    def parse_vulnerability(self, game_content):
        """Return 'NONE', 'NS', 'EW' or 'BOTH' from the `sv` tag.

        Returns '' when the tag is missing or carries an unknown code.
        """
        try:
            match = self.VULNERABILITY_PATTERN.search(game_content)
            vulnerability = ''
            if match:
                vulnerability = self.VULNERABILITY_CODES.get(match.group(1).lower(), '')
        except Exception:
            print(game_content)
            raise
        return vulnerability

    def is_valid_game(self, game):
        """Return True when every component of `game` is present and non-empty.

        NOTE(review): `claimed_tricks` is 0 both for a claim of zero tricks
        and when no `mc` tag exists, so such games are rejected here. Confirm
        that is intended.
        """
        return bool(game['hands']
                    and game['bids']
                    and game['dealer']
                    and game['played_cards']
                    and game['claimed_tricks']
                    and game['vulnerability'])
    

# Smoke test: parse the first file in the directory and show two of its games.
parser = LinParser()
filename = os.path.join(path, filenames[0])
games = parser.parse(filename)
print('Number of games: ', len(games))
print('Games:')
pprint(games[:2])

Number of games:  29
Games:
[{'bids': [P, 2♠, P, P, P],
  'claimed_tricks': 6,
  'dealer': 'north',
  'hands': {'east': [♠A, ♠Q, ♠T, ♠7, ♠2, ♥T, ♥9, ♥3, ♦K, ♦5, ♦2, ♣T, ♣3],
            'north': [♠J, ♠6, ♠5, ♥A, ♥Q, ♥J, ♦Q, ♦7, ♦6, ♦4, ♣Q, ♣7, ♣6],
            'south': [♠8, ♠4, ♥8, ♥5, ♥4, ♥2, ♦A, ♦9, ♦3, ♣A, ♣K, ♣9, ♣4],
            'west': [♠K, ♠9, ♠3, ♥K, ♥7, ♥6, ♦J, ♦T, ♦8, ♣J, ♣8, ♣5, ♣2]},
  'played_cards': [♣K,
                   ♣2,
                   ♣7,
                   ♣3,
                   ♥2,
                   ♥K,
                   ♥A,
                   ♥3,
                   ♥Q,
                   ♥9,
                   ♥4,
                   ♥6,
                   ♥J,
                   ♥T,
                   ♥5,
                   ♥7,
                   ♣Q,
                   ♣T,
                   ♣4,
                   ♣5,
                   ♣6,
                   ♠2,
                   ♣9,
                   ♣8,
                   ♠A,
                   ♠4,
   

In [4]:
class FeatureConverter:
    """Turn a parsed game into a flat 0/1 feature vector from one seat's view.

    Layout of the `convert_game` output:
        [0:52)    own hand, one-hot over 52 cards (index 0 = 2C, 51 = AS)
        [52:104)  partner's hand, same encoding
        [104:106) vulnerability: [own side vulnerable, opponents vulnerable]
        [106:)    bidding: 3 initial-pass slots, then 9 slots for each of the
                  35 contract bids 1C, 1D, 1H, 1S, 1N, ..., 7N (1 bit "this
                  bid was made" + 8 bits for the calls that followed it)
    """

    PARTNERS = {
        'north': 'south',
        'south': 'north',
        'east': 'west',
        'west': 'east',
    }
    CARD_SUITS = ('S', 'H', 'D', 'C')
    CARD_VALUES = ('A', 'K', 'Q', 'J', 'T', '9', '8', '7', '6', '5', '4', '3', '2')
    BID_LEVELS = ('1', '2', '3', '4', '5', '6', '7')
    BID_STRAINS = ('C', 'D', 'H', 'S', 'N')
    # Number of slots for passes made before the first contract bid.
    MAX_INITIAL_PASSES = 3

    def convert_game(self, game, player):
        """Return the full feature vector for `game` as seen by `player`."""
        return (self.convert_own_hand(game, player)
                + self.convert_partner_hand(game, player)
                + self.convert_vulnerability(game, player)
                + self.convert_bidding_sequence(game))

    def convert_own_hand(self, game, player):
        return self.convert_hand(game['hands'][player])

    def convert_partner_hand(self, game, player):
        return self.convert_hand(game['hands'][self.get_partner(player)])

    def get_partner(self, player):
        return self.PARTNERS[player]

    def convert_hand(self, hand):
        """One-hot encode a list of cards into 52 slots."""
        result = [0] * 52
        for card in hand:
            result[self.calculate_card_index(card)] = 1
        return result

    def calculate_card_index(self, card):
        """Map a card to 0..51: clubs low to spades high, 2 low to A high."""
        suit_index = self.CARD_SUITS.index(card.suit)
        value_index = self.CARD_VALUES.index(card.value)
        return 51 - (suit_index * 13 + value_index)

    def convert_vulnerability(self, game, player):
        """Return [own side vulnerable, opponents vulnerable]."""
        vulnerability = game['vulnerability']
        if vulnerability == 'NONE':
            return [0, 0]
        if vulnerability == 'BOTH':
            return [1, 1]
        if vulnerability == self.get_team(player):
            return [1, 0]
        return [0, 1]

    def get_team(self, player):
        return 'NS' if player in ['north', 'south'] else 'EW'

    def convert_bidding_sequence(self, game):
        """Encode `game['bids']` (a complete or partial auction)."""
        bids = game['bids']
        final_passes = [Bid(value='P'), Bid(value='P'), Bid(value='P')]
        # Drop the three passes that close a finished auction. Require more
        # than three calls: a history of exactly three passes is unfinished
        # (the fourth seat is still to call) and must keep its initial passes.
        if len(bids) > 3 and bids[-3:] == final_passes:
            bids = bids[:-3]

        groups = self.split_contract_bids(bids)

        result = []
        if groups and groups[0]:
            result.extend(self.convert_initial_passes(groups[0]))
        else:
            result.extend([0] * self.MAX_INITIAL_PASSES)

        # Walk the 35 contract bids in rank order. Groups after the first one
        # each start with a contract bid, in auction (= rank) order.
        contract_groups = groups[1:]
        next_group = 0
        for level in self.BID_LEVELS:
            for strain in self.BID_STRAINS:
                contract = Bid(value=level, suit=strain)
                if (next_group < len(contract_groups)
                        and contract_groups[next_group][0] == contract):
                    result.append(1)
                    result.extend(self.convert_non_contract_bids(contract_groups[next_group][1:]))
                    next_group += 1
                else:
                    result.extend([0] * 9)
        return result

    def split_contract_bids(self, bids):
        """Split an auction at each contract bid.

        The first group holds the calls made before the first contract bid
        (it is [] when the auction opens with a contract bid). Each later
        group is a contract bid followed by the non-contract calls after it.
        An empty auction gives [].
        """
        groups = []
        current = []
        for bid in bids:
            if bid.suit:
                groups.append(current)
                current = [bid]
            else:
                current.append(bid)
        if current:
            groups.append(current)
        return groups

    def convert_initial_passes(self, bids):
        """Thermometer-encode the number of initial passes in exactly 3 slots."""
        count = min(len(bids), self.MAX_INITIAL_PASSES)
        return [1] * count + [0] * (self.MAX_INITIAL_PASSES - count)

    def convert_non_contract_bids(self, bids):
        """Greedily match `bids` against the longest non-contract sequence.

        `bids` is not modified. Calls that do not fit the template are
        ignored.
        """
        sequence = [Bid(value='P'), Bid(value='P'), Bid(value='D'),
                    Bid(value='P'), Bid(value='P'), Bid(value='R'),
                    Bid(value='P'), Bid(value='P')]
        result = [0] * len(sequence)
        position = 0
        for slot, expected in enumerate(sequence):
            if position < len(bids) and expected == bids[position]:
                result[slot] = 1
                position += 1
        return result
        
    
# parser = LinParser()
# print(filename)
# games = parser.parse(filename)
# print(games)
# game = parser.games[0]
# converter = FeatureConverter()
# north_feature_vector = converter.convert_game(game, player='north')
# print(game['bids'])
# print('Feature vector:', north_feature_vector)

# assert north_feature_vector[0:52] == [
#     1, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0,
#     0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0,
#     0, 0, 1, 0, 1, 0, 0, 0, 1, 1, 0, 0, 1,
#     0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0
# ]
# assert sum(north_feature_vector[0:52]) == 13
# assert north_feature_vector[52:104] == [
#     0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 1,
#     0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1,
#     0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0,
#     0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0
# ]
# assert sum(north_feature_vector[52:104]) == 13
# assert north_feature_vector[104:106] == [0, 0]
# assert north_feature_vector[106:] == [
#     1, 0, 0, # 1 initial pass
#     0, 0, 0, 0, 0, 0, 0, 0, 0, # 1C
#     0, 0, 0, 0, 0, 0, 0, 0, 0, # 1D
#     0, 0, 0, 0, 0, 0, 0, 0, 0, # 1H
#     1, 0, 0, 0, 0, 0, 0, 0, 0, # 1S
#     0, 0, 0, 0, 0, 0, 0, 0, 0, # 1NT
#     1, 0, 0, 0, 0, 0, 0, 0, 0, # 2C
#     0, 0, 0, 0, 0, 0, 0, 0, 0, # 2D
#     0, 0, 0, 0, 0, 0, 0, 0, 0, # 2H
#     1, 0, 0, 1, 0, 0, 0, 0, 0, # 2S, double
#     0, 0, 0, 0, 0, 0, 0, 0, 0, # 2NT
#     0, 0, 0, 0, 0, 0, 0, 0, 0, # 3C
#     0, 0, 0, 0, 0, 0, 0, 0, 0, # 3D
#     0, 0, 0, 0, 0, 0, 0, 0, 0, # 3H
#     1, 1, 1, 0, 0, 0, 0, 0, 0, # 3S, pass, pass
#     0, 0, 0, 0, 0, 0, 0, 0, 0, # 3NT
#     1, 0, 0, 0, 0, 0, 0, 0, 0, # 4C, (pass, pass, pass)
#     0, 0, 0, 0, 0, 0, 0, 0, 0, # 4D
#     0, 0, 0, 0, 0, 0, 0, 0, 0, # 4H
#     0, 0, 0, 0, 0, 0, 0, 0, 0, # 4S
#     0, 0, 0, 0, 0, 0, 0, 0, 0, # 4NT
#     0, 0, 0, 0, 0, 0, 0, 0, 0, # 5C
#     0, 0, 0, 0, 0, 0, 0, 0, 0, # 5D
#     0, 0, 0, 0, 0, 0, 0, 0, 0, # 5H
#     0, 0, 0, 0, 0, 0, 0, 0, 0, # 5S
#     0, 0, 0, 0, 0, 0, 0, 0, 0, # 5NT
#     0, 0, 0, 0, 0, 0, 0, 0, 0, # 6C
#     0, 0, 0, 0, 0, 0, 0, 0, 0, # 6D
#     0, 0, 0, 0, 0, 0, 0, 0, 0, # 6H
#     0, 0, 0, 0, 0, 0, 0, 0, 0, # 6S
#     0, 0, 0, 0, 0, 0, 0, 0, 0, # 6NT
#     0, 0, 0, 0, 0, 0, 0, 0, 0, # 7C
#     0, 0, 0, 0, 0, 0, 0, 0, 0, # 7D
#     0, 0, 0, 0, 0, 0, 0, 0, 0, # 7H
#     0, 0, 0, 0, 0, 0, 0, 0, 0, # 7S
#     0, 0, 0, 0, 0, 0, 0, 0, 0, # 7NT
# ]

# south_feature_vector = converter.convert_game(game, player='south')
# print('Feature vector:', south_feature_vector)

# assert south_feature_vector[0:52] == north_feature_vector[52:104]
# assert south_feature_vector[52:104] == north_feature_vector[0:52]
# assert south_feature_vector[104:106] == north_feature_vector[104:106]
# assert south_feature_vector[106:] == north_feature_vector[106:]

# east_feature_vector = converter.convert_game(game, player='east')
# print('Feature vector:', east_feature_vector)
# assert east_feature_vector[0:52] != north_feature_vector[52:104]
# assert east_feature_vector[52:104] != north_feature_vector[0:52]
# assert east_feature_vector[106:] == north_feature_vector[106:]

# west_feature_vector = converter.convert_game(game, player='west')
# print('Feature vector:', west_feature_vector)
# assert west_feature_vector[0:52] == east_feature_vector[52:104]
# assert west_feature_vector[52:104] == east_feature_vector[0:52]
# assert west_feature_vector[104:106] == east_feature_vector[104:106]
# assert west_feature_vector[106:] == east_feature_vector[106:]

In [5]:
class ENNDataGenerator:
    """Build (X, Y) arrays for the hand-estimation network from LIN files.

    Every valid game contributes one sample per seat. X is that seat's own
    hand followed by the vulnerability and bidding-sequence features; Y is
    the partner's hand.
    """

    def generate_dataset(self, path):
        """Return (X, Y) numpy arrays built from every .lin file in `path`."""
        lin_files = sorted(
            os.path.join(path, name) for name in os.listdir(path) if name.endswith('.lin')
        )
        features, targets = [], []
        for lin_file in lin_files:
            file_features, file_targets = self.generate_dataset_from_lin_file(lin_file)
            features += file_features
            targets += file_targets
        return np.array(features), np.array(targets)

    def generate_dataset_from_lin_file(self, filename):
        """Return (X, Y) lists for every valid game in one LIN file."""
        features, targets = [], []
        for game in LinParser().parse(filename):
            game_features, game_targets = self.convert_game_to_features(game)
            features += game_features
            targets += game_targets
        return features, targets

    def convert_game_to_features(self, game):
        """Return (X, Y) lists with one entry per seat: N, S, E, W."""
        converter = FeatureConverter()
        features, targets = [], []
        for seat in ('north', 'south', 'east', 'west'):
            vector = converter.convert_game(game, player=seat)
            own_hand, partner_hand, context = vector[:52], vector[52:104], vector[104:]
            features.append(own_hand + context)
            targets.append(partner_hand)
        return features, targets
        
        
class PNNDataGenerator:
    """Build (X, Y) arrays for the bidding-policy network from LIN files.

    One sample per call in each auction. X is the `FeatureConverter` vector
    for the seat making the call, computed from the auction before that call.
    Y is the call itself, one-hot encoded by `BidFeatureConverter`.
    """

    # Seat that calls after each seat (the dealer calls first).
    NEXT_SEAT = {
        'north': 'east',
        'east': 'south',
        'south': 'west',
        'west': 'north',
    }

    def generate_dataset(self, path):
        """Return (X, Y) numpy arrays built from every .lin file in `path`."""
        filenames = sorted([os.path.join(path, f) for f in os.listdir(path) if f.endswith('.lin')])
        X = []
        Y = []
        for filename in filenames:
            X_file, Y_file = self.generate_dataset_from_lin_file(filename)
            X.extend(X_file)
            Y.extend(Y_file)
        return np.array(X), np.array(Y)

    def generate_dataset_from_lin_file(self, filename):
        """Return (X, Y) lists for every valid game in one LIN file."""
        X = []
        Y = []
        for game in LinParser().parse(filename):
            X_game, Y_game = self.generate_dataset_for_game(game)
            X.extend(X_game)
            Y.extend(Y_game)
        return X, Y

    def generate_dataset_for_game(self, game):
        """Return (X, Y) lists with one sample per call in `game`'s auction."""
        X = []
        Y = []
        converter = FeatureConverter()
        bid_converter = BidFeatureConverter()
        for player, decisions in self.generate_bid_sequences(game).items():
            for decision in decisions:
                game_copy = dict(game)
                game_copy['bids'] = decision['history']
                X.append(converter.convert_game(game_copy, player=player))
                Y.append(bid_converter.convert_bid(decision['bid']))
        return X, Y

    def generate_bid_sequences(self, game):
        """Group the auction's calls by the seat that made them.

        Returns {seat: [{'history': calls before this one, 'bid': call}, ...]}.
        The dealer makes the first call and seats then call in NEXT_SEAT
        order, so call i is credited to the seat that actually made it.
        """
        bid_sequences = {
            'north': [],
            'east': [],
            'south': [],
            'west': [],
        }
        bids = game['bids']
        player = game['dealer']
        for i, bid in enumerate(bids):
            bid_sequences[player].append({
                'history': bids[:i],
                'bid': bid,
            })
            player = self.NEXT_SEAT[player]
        return bid_sequences
            
        
        
class BidFeatureConverter:
    """One-hot encode a single call over every possible call.

    Index order: 1C, 1D, 1H, 1S, 1N, 2C, ..., 7N, then P, D, R (38 slots).
    """

    def __init__(self):
        self.bid_order = self.generate_bid_order()
        self.num_bids = len(self.bid_order)

    def convert_bid(self, bid):
        """Return a list with a single 1 at `bid`'s position in `bid_order`."""
        encoded = [0] * self.num_bids
        encoded[self.bid_order.index(bid)] = 1
        return encoded

    def generate_bid_order(self):
        """Build the ordered list of every call: contract bids, then P/D/R."""
        strains = [Suit(letter) for letter in 'CDHSN']
        contract_calls = [
            Bid(value=level, suit=strain)
            for level in '1234567'
            for strain in strains
        ]
        return contract_calls + [Bid(value=call) for call in 'PDR']
        

In [19]:
class DataGenerator:
    """Stream ENN and PNN training samples, one game at a time, from LIN files.

    For every call in a game's auction, both generators emit one sample for
    the seat making the call, built from the auction before that call:
      * ENN: X = own hand + vulnerability + bidding features, Y = partner's hand.
      * PNN: X = full FeatureConverter vector, Y = the call, one-hot encoded.
    """

    # Seat that calls after each seat (the dealer calls first).
    NEXT_SEAT = {
        'north': 'east',
        'east': 'south',
        'south': 'west',
        'west': 'north',
    }

    def generate_data(self, path):
        """Yield (X_enn, Y_enn, X_pnn, Y_pnn) for each valid game, file by file in sorted order."""
        filenames = sorted([os.path.join(path, f) for f in os.listdir(path) if f.endswith('.lin')])
        for filename in filenames:
            for game in LinParser().parse(filename):
                X_enn, Y_enn = self.generate_enn_data(game)
                X_pnn, Y_pnn = self.generate_pnn_data(game)
                yield X_enn, Y_enn, X_pnn, Y_pnn

    def generate_enn_data(self, game):
        """Return (X, Y) hand-estimation samples, one per call in the auction."""
        X = []
        Y = []
        converter = FeatureConverter()
        for player, decisions in self.generate_bid_sequences(game).items():
            for decision in decisions:
                game_copy = dict(game)
                game_copy['bids'] = decision['history']
                # Features come from the auction so far, not the finished auction.
                feature = converter.convert_game(game_copy, player=player)
                X.append(feature[:52] + feature[104:])  # own hand, vulnerability, bidding sequence
                Y.append(feature[52:104])  # partner's hand
        return X, Y

    def generate_pnn_data(self, game):
        """Return (X, Y) bidding samples, one per call in the auction."""
        X = []
        Y = []
        converter = FeatureConverter()
        bid_converter = BidFeatureConverter()
        for player, decisions in self.generate_bid_sequences(game).items():
            for decision in decisions:
                game_copy = dict(game)
                game_copy['bids'] = decision['history']
                X.append(converter.convert_game(game_copy, player=player))
                Y.append(bid_converter.convert_bid(decision['bid']))
        return X, Y

    def generate_bid_sequences(self, game):
        """Group the auction's calls by the seat that made them.

        Returns {seat: [{'history': calls before this one, 'bid': call}, ...]}.
        The dealer makes the first call and seats then call in NEXT_SEAT
        order, so call i is credited to the seat that actually made it.
        """
        bid_sequences = {
            'north': [],
            'east': [],
            'south': [],
            'west': [],
        }
        bids = game['bids']
        player = game['dealer']
        for i, bid in enumerate(bids):
            bid_sequences[player].append({
                'history': bids[:i],
                'bid': bid,
            })
            player = self.NEXT_SEAT[player]
        return bid_sequences

In [20]:
import json

def process_bridge_files(input_path, output_path, print_every=1000):
    """Write ENN/PNN training data for every game under `input_path` as JSON.

    Game i (0-based, in the order DataGenerator.generate_data yields them)
    is written to `<output_path>/<i>.json` with keys 'X_enn', 'Y_enn',
    'X_pnn' and 'Y_pnn'. `output_path` is created if it does not exist.

    Args:
        input_path: directory containing .lin files.
        output_path: directory receiving one JSON file per game.
        print_every: print the game index every this many games; values
            <= 0 disable progress output.
    """
    os.makedirs(output_path, exist_ok=True)
    data_generator = DataGenerator()
    for i, (X_enn, Y_enn, X_pnn, Y_pnn) in enumerate(data_generator.generate_data(input_path)):
        data = {
            'X_enn': X_enn,
            'Y_enn': Y_enn,
            'X_pnn': X_pnn,
            'Y_pnn': Y_pnn,
        }
        with open(os.path.join(output_path, f'{i}.json'), 'w') as f:
            json.dump(data, f)

        if print_every > 0 and i % print_every == 0:
            print(i)
            
# Convert every raw LIN file into per-game JSON training data.
process_bridge_files(
    input_path='data/www.bridgebase.com/',
    output_path='processed_data/',
)

0
1000
2000
3000
4000
5000
6000
7000
8000
9000
10000
11000
12000
13000
14000
15000
16000
17000
18000
19000
20000
21000
22000
23000
24000
25000
26000
27000
28000
29000
30000
31000
32000
33000
34000
35000
36000
37000
38000
39000
40000
41000
42000
43000
44000
45000
46000
47000
48000
49000
50000
51000
52000
53000
54000
55000
56000
57000
58000
59000
60000
61000
62000
63000
64000
65000
66000
67000
68000
69000
70000
71000
72000
73000
74000
75000
76000
77000
78000
79000
80000
81000
82000
83000
84000
85000
86000
87000
88000
89000
90000
91000
92000
93000
94000
95000
96000
97000
98000
99000
100000
101000
102000
103000
104000
105000
106000
107000
108000
109000
110000
111000
112000
113000
114000
115000
116000
117000
118000
119000
120000
121000
122000
123000
124000
125000
126000
127000
128000
129000
130000
131000
132000
133000
134000
135000
136000
137000
138000
139000
140000
141000
142000
143000
144000
145000
146000
147000
148000
149000
150000
151000
152000
153000
154000
155000
156000
157000
158000


In [19]:
import json
import math
import os
import random

def consolidate_files(files, files_per_output, output_path, prefix='out'):
    """Merge many per-game JSON files into shuffled, numbered output files.

    Every `files_per_output` input files are concatenated, their ENN and PNN
    samples shuffled (each X stays paired with its Y), and the result is
    written to `<output_path>/<prefix>-<n>.json` with n = 1, 2, .... Any
    remaining inputs are shuffled and written the same way under the next
    number, so no output file is ever overwritten.

    Args:
        files: paths of JSON files holding 'X_enn', 'Y_enn', 'X_pnn', 'Y_pnn' lists.
        files_per_output: number of input files per output file (>= 1).
        output_path: existing directory that receives the output files.
        prefix: output file-name prefix.

    Raises:
        ValueError: if `files_per_output` is less than 1.
    """
    if files_per_output < 1:
        raise ValueError(f'files_per_output must be >= 1, got {files_per_output}')

    keys = ('X_enn', 'Y_enn', 'X_pnn', 'Y_pnn')

    def shuffled_pairs(xs, ys):
        # Shuffle samples while keeping every x aligned with its y.
        pairs = list(zip(xs, ys))
        random.shuffle(pairs)
        return [x for x, _ in pairs], [y for _, y in pairs]

    def save(index, buffers):
        X_enn, Y_enn = shuffled_pairs(buffers['X_enn'], buffers['Y_enn'])
        X_pnn, Y_pnn = shuffled_pairs(buffers['X_pnn'], buffers['Y_pnn'])
        filename = f'{prefix}-{index}.json'
        print(filename)
        with open(os.path.join(output_path, filename), 'w') as out:
            data = {
                'X_enn': X_enn,
                'Y_enn': Y_enn,
                'X_pnn': X_pnn,
                'Y_pnn': Y_pnn,
            }
            json.dump(data, out)

    buffers = {key: [] for key in keys}
    num_saved_files = 0

    for count, path in enumerate(files, start=1):
        with open(path) as f:
            data = json.load(f)
        for key in keys:
            buffers[key].extend(data[key])

        if count % files_per_output == 0:
            num_saved_files += 1
            save(num_saved_files, buffers)
            buffers = {key: [] for key in keys}

    # Flush the remainder under the next number so it cannot overwrite the
    # last full output file.
    if any(buffers.values()):
        num_saved_files += 1
        save(num_saved_files, buffers)
                
        

def split_dataset(input_path, output_path, training_ratio=0.7, validation_ratio=0.1, test_ratio=0.2,
                  files_per_output=1000, seed=None):
    """Shuffle the per-game JSON files and consolidate them into train/validation/test.

    Each split is written by `consolidate_files` into `<output_path>/<split>/`
    (created if missing) with file prefix `<split>`.

    Args:
        input_path: directory of per-game .json files.
        output_path: root directory for the 'training', 'validation' and 'test' folders.
        training_ratio, validation_ratio, test_ratio: fractions that must sum to 1.
        files_per_output: input files merged into each output file.
        seed: optional seed; with a seed the split is reproducible.

    Raises:
        ValueError: if the three ratios do not sum to 1 (within float tolerance).
    """
    # Compare with a tolerance: e.g. 0.6 + 0.3 + 0.1 is not exactly 1.0 in floating point.
    if not math.isclose(training_ratio + validation_ratio + test_ratio, 1.0):
        raise ValueError('training_ratio + validation_ratio + test_ratio must sum to 1')

    files = [os.path.join(input_path, f) for f in os.listdir(input_path) if f.endswith('.json')]
    print(len(files))
    print(files[:10])

    # os.listdir order is arbitrary; sort first so a given seed always yields the same split.
    files.sort()
    if seed is None:
        random.shuffle(files)
    else:
        random.Random(seed).shuffle(files)
    print(files[:10])

    training_size = math.floor(len(files) * training_ratio)
    validation_size = math.floor(len(files) * validation_ratio)

    split = {
        'training': files[:training_size],
        'validation': files[training_size:training_size + validation_size],
        'test': files[training_size + validation_size:],
    }

    for name, split_files in split.items():
        split_dir = os.path.join(output_path, name)
        os.makedirs(split_dir, exist_ok=True)
        consolidate_files(split_files, files_per_output, split_dir, prefix=name)

# Shuffle the per-game files and consolidate them into training/validation/test splits.
split_dataset(
    input_path='processed_data/',
    output_path='split_data/',
)

1075226
['processed_data/228132.json', 'processed_data/829401.json', 'processed_data/446327.json', 'processed_data/753873.json', 'processed_data/1058343.json', 'processed_data/410547.json', 'processed_data/798977.json', 'processed_data/588343.json', 'processed_data/477883.json', 'processed_data/543247.json']
['processed_data/1041289.json', 'processed_data/1043858.json', 'processed_data/945254.json', 'processed_data/303549.json', 'processed_data/372768.json', 'processed_data/615070.json', 'processed_data/129298.json', 'processed_data/987919.json', 'processed_data/667279.json', 'processed_data/409648.json']
training-1.json
training-2.json
training-3.json
training-4.json
training-5.json
training-6.json
training-7.json
training-8.json
training-9.json
training-10.json
training-11.json
training-12.json
training-13.json
training-14.json
training-15.json
training-16.json
training-17.json
training-18.json
training-19.json
training-20.json
training-21.json
training-22.json
training-23.json
train

training-429.json
training-430.json
training-431.json
training-432.json
training-433.json
training-434.json
training-435.json
training-436.json
training-437.json
training-438.json
training-439.json
training-440.json
training-441.json
training-442.json
training-443.json
training-444.json
training-445.json
training-446.json
training-447.json
training-448.json
training-449.json
training-450.json
training-451.json
training-452.json
training-453.json
training-454.json
training-455.json
training-456.json
training-457.json
training-458.json
training-459.json
training-460.json
training-461.json
training-462.json
training-463.json
training-464.json
training-465.json
training-466.json
training-467.json
training-468.json
training-469.json
training-470.json
training-471.json
training-472.json
training-473.json
training-474.json
training-475.json
training-476.json
training-477.json
training-478.json
training-479.json
training-480.json
training-481.json
training-482.json
training-483.json
training-4

test-25.json
test-26.json
test-27.json
test-28.json
test-29.json
test-30.json
test-31.json
test-32.json
test-33.json
test-34.json
test-35.json
test-36.json
test-37.json
test-38.json
test-39.json
test-40.json
test-41.json
test-42.json
test-43.json
test-44.json
test-45.json
test-46.json
test-47.json
test-48.json
test-49.json
test-50.json
test-51.json
test-52.json
test-53.json
test-54.json
test-55.json
test-56.json
test-57.json
test-58.json
test-59.json
test-60.json
test-61.json
test-62.json
test-63.json
test-64.json
test-65.json
test-66.json
test-67.json
test-68.json
test-69.json
test-70.json
test-71.json
test-72.json
test-73.json
test-74.json
test-75.json
test-76.json
test-77.json
test-78.json
test-79.json
test-80.json
test-81.json
test-82.json
test-83.json
test-84.json
test-85.json
test-86.json
test-87.json
test-88.json
test-89.json
test-90.json
test-91.json
test-92.json
test-93.json
test-94.json
test-95.json
test-96.json
test-97.json
test-98.json
test-99.json
test-100.json
test-101.js