In [2]:
import csv
import math
import os.path
import urllib.request

import numpy as np

In [3]:
def mwu(expert_classification_list, correct_values_list, e=None):
    """
        Run the Multiplicative Weights Update (MWU) algorithm on expert predictions.

        expert_classification_list: list with one prediction sequence per expert;
            every sequence must have one entry per element of correct_values_list.
        correct_values_list: the true value for each step t.
        e: learning-rate parameter of the MWU update
            w_i <- w_i * exp(-e * loss_i), where loss_i is 1 if expert i was
            wrong at step t and 0 otherwise. Defaults to sqrt(ln(m) / T) for
            m experts and T steps.

        Returns: (p, loss)
            p: list of final probabilities, one per expert (sums to 1).
            loss: list of length T + 1; loss[t] is the cumulative expected loss
                of the algorithm over the first t steps (loss[0] == 0).

        Raises: ValueError if there are no experts, or if some expert's
            prediction sequence is not the same length as correct_values_list.
    """
    m = len(expert_classification_list)  # number of experts
    if m == 0:
        raise ValueError("expert_classification_list must contain at least one expert")
    T = len(correct_values_list)  # total steps
    if any(len(predictions) != T for predictions in expert_classification_list):
        raise ValueError("every expert needs exactly one prediction per correct value")
    if e is None:  # use default; with no steps the value is never used
        e = math.sqrt(math.log(m) / T) if T > 0 else 0.0

    p = [1 / m] * m  # start from the uniform distribution over experts
    loss = [0]

    for t, correct in enumerate(correct_values_list):
        # To follow a single expert instead of mixing, sample expert i with
        # probability p[i] at this point.
        expert_losses = [1 if predictions[t] != correct else 0
                         for predictions in expert_classification_list]
        # Expected loss of this step under the current distribution p.
        loss.append(loss[t] + sum(p_i * l_i for p_i, l_i in zip(p, expert_losses)))

        # Multiplicative update. Updating the normalised p (instead of raw,
        # ever-shrinking weights) and shifting by the smallest loss gives the
        # same distribution after normalisation, but keeps at least one factor
        # equal to 1 so the weights cannot all underflow to 0.
        best = min(expert_losses)
        w = [p_i * math.exp(-e * (l_i - best)) for p_i, l_i in zip(p, expert_losses)]
        sum_weights = sum(w)
        p = [w_i / sum_weights for w_i in w]

    return p, loss

In [5]:
# Load MNIST in CSV form, keeping only the digits listed in `allowed_numbers`.
# Source: https://pjreddie.com/projects/mnist-in-csv/
# Each CSV row is `label, pixel values...`.

file_url = 'https://pjreddie.com/media/files/'

test_data = 'mnist_test.csv'
train_data = 'mnist_train.csv'

# Extract only zeros and ones.
allowed_numbers = [0, 1]


def download_if_missing(file_name, base_url=file_url):
    """Download `base_url + file_name` to `file_name` unless it already exists.

    The data is fetched into `<file_name>.part` and renamed only after the
    download completes, so an interrupted download never leaves a truncated
    CSV that a later run would mistake for a finished one.
    """
    if os.path.isfile(file_name):
        return
    print(f"downloading {file_name}...")
    partial_name = file_name + '.part'
    try:
        urllib.request.urlretrieve(base_url + file_name, partial_name)
    except BaseException:
        # Remove the half-written file, then let the error propagate.
        if os.path.exists(partial_name):
            os.remove(partial_name)
        raise
    os.replace(partial_name, file_name)


def load_mnist_csv(path, labels):
    """Parse an MNIST CSV file, keeping only rows whose label is in `labels`.

    Returns (features, targets): `features` is a list of integer rows made of
    every column after the first; `targets` is the matching list of integer
    labels taken from the first column. Blank lines are skipped.
    """
    wanted = set(labels)
    features, targets = [], []
    # newline='' is what the csv module expects for files passed to csv.reader.
    with open(path, newline='') as csv_file:
        for row in csv.reader(csv_file, delimiter=','):
            if not row:
                continue
            label = int(row[0])
            # Check the label first so filtered-out rows are never fully parsed.
            if label in wanted:
                features.append([int(value) for value in row[1:]])
                targets.append(label)
    return features, targets


# download files if they do not exist yet:
for file_name in [test_data, train_data]:
    download_if_missing(file_name)

print("parsing files...")
train_X, train_Y = load_mnist_csv(train_data, allowed_numbers)
test_X, test_Y = load_mnist_csv(test_data, allowed_numbers)
print('finish')

parsing files...
