In [1]:
import pylab
from matplotlib import gridspec
from sklearn.datasets import make_classification
import numpy as np
from ipywidgets import interact, interactive, fixed
import ipywidgets as widgets
import pickle
import os


# pick the seed for reproducability - change it to explore the effects of random variations
np.random.seed(1)
import random

In [2]:
!wget https://raw.githubusercontent.com/shwars/NeuroWorkshop/master/Data/MNIST/mnist.pkl.gz
!gzip -d mnist.pkl.gz

--2022-03-11 16:54:42--  https://raw.githubusercontent.com/shwars/NeuroWorkshop/master/Data/MNIST/mnist.pkl.gz
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.111.133, 185.199.108.133, 185.199.109.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.111.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 10392609 (9.9M) [application/octet-stream]
Saving to: 'mnist.pkl.gz'

     0K .......... .......... .......... .......... ..........  0% 1.62M 6s
    50K .......... .......... .......... .......... ..........  0% 7.08M 4s
   100K .......... .......... .......... .......... ..........  1% 2.95M 4s
   150K .......... .......... .......... .......... ..........  1% 5.75M 3s
   200K .......... .......... .......... .......... ..........  2% 2.96M 3s
   250K .......... .......... .......... .......... ..........  2% 7.06M 3s
   300K .......... .......... .......... .......... ..........  3% 7.27M 3s
   350K

In [3]:
with open('mnist.pkl', 'rb') as mnist_pickle:
    MNIST = pickle.load(mnist_pickle)

In [4]:
all_len = MNIST['Train']['Features'].shape[0]
learn_len = int(0.8 * all_len)
print("Всего:", all_len)
print("Тренировочный датасет:", learn_len)
print("Тестовый датасет:", all_len - learn_len)

Всего: 42000
Тренировочный датасет: 33600
Тестовый датасет: 8400


In [5]:
def train_graph(positive_examples, negative_examples, num_iterations = 100):
    num_dims = positive_examples.shape[1]
    weights = np.zeros((num_dims,1)) # инициализируем веса
    
    pos_count = positive_examples.shape[0]
    neg_count = negative_examples.shape[0]
    
    report_frequency = 15;
    snapshots = []
    
    for i in range(num_iterations):
        pos = random.choice(positive_examples)
        neg = random.choice(negative_examples)

        z = np.dot(pos, weights)   
        if z < 0:
            weights = weights + pos.reshape(weights.shape)

        z  = np.dot(neg, weights)
        if z >= 0:
            weights = weights - neg.reshape(weights.shape)
            
        if i % report_frequency == 0:             
            pos_out = np.dot(positive_examples, weights)
            neg_out = np.dot(negative_examples, weights)        
            pos_correct = (pos_out >= 0).sum() / float(pos_count)
            neg_correct = (neg_out < 0).sum() / float(neg_count)
            snapshots.append((np.copy(weights),(pos_correct+neg_correct)/2.0))

    return np.array(snapshots, dtype=object)

In [6]:
def set_mnist_pos_neg(positive_label, negative_labels):
    positive_indices = [i for i, j in enumerate(MNIST['Train']['Labels']) 
                          if j == positive_label and i < learn_len]
    negative_indices = [i for i, j in enumerate(MNIST['Train']['Labels']) 
                          if j in negative_labels and i < learn_len]

    positive_images = MNIST['Train']['Features'][positive_indices]
    negative_images = MNIST['Train']['Features'][negative_indices]
    
    return positive_images, negative_images

In [7]:
def train_all():
    all_weights = []
    for i in range(10):
        pos, neg = set_mnist_pos_neg(i, set(range(10)) - set([i]))
        snapshots = train_graph(pos,neg,1000)
        print(i, snapshots[-1][1])
        all_weights.append(snapshots[-1][0])
    return all_weights

In [8]:
def predict(image, all_weights):
    answer = []
    for i in range(10):
        answer.append(np.dot(image, all_weights[i]))
    return answer.index(max(answer))

In [9]:
def accuracy(all_weights):
    correct = 0
    for i in range(learn_len, all_len):
        if predict(MNIST['Train']['Features'][i], all_weights) == MNIST['Train']['Labels'][i]:
            correct += 1
    return correct / (all_len - learn_len)

In [12]:
all_weights = train_all()

0 0.9731538387283609
1 0.9770322706553833
2 0.9098389999222237
3 0.882737600725128
4 0.9514562566579376
5 0.8822740318300746
6 0.9523278345295356
7 0.940708830864123
8 0.862422048818301
9 0.8979022447995973


In [13]:
print("Точность на тестовом датасете:", accuracy(all_weights))

Точность на тестовом датасете: 0.8277380952380953
