In [1]:
# Imports
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.widgets import Slider, Button, RadioButtons, CheckButtons
from mpl_toolkits.mplot3d import Axes3D
import math
import os
import time
import itertools

import warnings
warnings.filterwarnings('ignore')

from matplotlib.patches import Circle, Wedge
from matplotlib.collections import PatchCollection

from scipy.spatial import ConvexHull

plt.rcParams['text.usetex'] = True
plt.rcParams['font.family'] = 'monospace'
plt.rcParams['text.latex.preamble'] = r"\usepackage{amsmath}"

%matplotlib notebook

plt.rcParams['figure.figsize'] = [7, 7]

import torch
import torch.nn
import torch.utils.data
import torchvision

In [2]:
class SimpleNet(torch.nn.Module):
    """Fully-connected 2 -> 8 -> 8 -> 8 -> 2 network with ReLU between layers.

    The layers are registered as attributes ``fc1`` .. ``fc4``; those names
    show up in ``state_dict()`` keys and in the pickled modules written by
    ``torch.save``. The last layer has no activation, so outputs are raw scores.
    """

    def __init__(self):
        super().__init__()
        widths = (2, 8, 8, 8, 2)
        # Create fc1..fc4 strictly in order: under a fixed torch seed this draws
        # the same initial parameters as constructing the layers one by one.
        for idx, (n_in, n_out) in enumerate(zip(widths[:-1], widths[1:]), start=1):
            setattr(self, 'fc{}'.format(idx), torch.nn.Linear(n_in, n_out))

    def forward(self, x):
        """Apply fc1..fc3 each followed by ReLU, then fc4 with no activation."""
        hidden = x
        for layer in (self.fc1, self.fc2, self.fc3):
            hidden = torch.nn.functional.relu(layer(hidden))
        return self.fc4(hidden)

In [3]:
def get_dataloader(train=True, r1=0.5, batch_size=100, shuffle=False,
                   n_train=100000, seed=0, grid_steps=100):
    """Build a DataLoader for the two-class "disc vs. ring" toy problem.

    Points in the square [-1, 1]^2 are labelled by their squared radius
    ``s = x**2 + y**2``:

    * ``s < r1**2``                 -> one-hot target ``[0, 1]`` (inner disc)
    * ``2*r1**2 < s < 3*r1**2``     -> one-hot target ``[1, 0]`` (outer ring)

    Points in neither region are discarded.

    Args:
        train: if True, return ``n_train`` uniformly sampled points (rejection
            sampling); otherwise return the points of a regular
            ``(grid_steps + 1) x (grid_steps + 1)`` grid that fall in a region.
        r1: radius of the inner disc; must be positive.
        batch_size, shuffle: forwarded to ``torch.utils.data.DataLoader``.
        n_train: number of accepted training points.
        seed: seed of the private ``np.random.RandomState`` used for sampling.
            It yields exactly the points the former ``np.random.seed(seed)``
            loop produced, but no longer resets the global NumPy RNG.
        grid_steps: number of intervals per axis of the evaluation grid.

    Returns:
        A DataLoader over a ``TensorDataset`` of float32 tensors with shapes
        train: x ``(N, 1, 2)``, y ``(N, 1, 2)``; test: x ``(N, 2)``, y ``(N, 1, 2)``.
        NOTE(review): the differing x shapes are historical; ``test()`` and the
        MSE loss in ``train()`` currently depend on them, so they are kept.

    Raises:
        ValueError: if ``r1 <= 0`` (no point could ever be accepted).
    """
    if not r1 > 0:
        raise ValueError("r1 must be positive, got {!r}".format(r1))

    r1sqr = r1 ** 2
    r2sqr = 2 * r1sqr
    r3sqr = 3 * r1sqr

    def _classify(points):
        """Return (keep mask, inner-disc mask) for an (M, 2) float64 array."""
        sqr = np.square(points).sum(axis=1)
        inner = sqr < r1sqr
        ring = (r2sqr < sqr) & (sqr < r3sqr)
        return inner | ring, inner

    if train:
        rng = np.random.RandomState(seed)
        kept_points = [np.empty((0, 2))]
        kept_inner = [np.empty(0, dtype=bool)]
        n_kept = 0
        while n_kept < n_train:
            # Drawing a (K, 2) block consumes the RNG stream in the same order
            # as K successive (1, 2) draws, so accepted points match exactly.
            chunk = max(1024, 3 * (n_train - n_kept))
            sample = 2 * rng.uniform(size=(chunk, 2)) - 1
            keep, inner = _classify(sample)
            kept_points.append(sample[keep])
            kept_inner.append(inner[keep])
            n_kept += int(keep.sum())
        points = np.concatenate(kept_points)[:n_train]
        is_inner = np.concatenate(kept_inner)[:n_train]
        x = points.reshape(-1, 1, 2)
    else:
        axis_vals = 2 * np.arange(grid_steps + 1) / grid_steps - 1
        gx, gy = np.meshgrid(axis_vals, axis_vals, indexing='ij')
        # Row k = i * (grid_steps + 1) + j holds (axis_vals[i], axis_vals[j]).
        grid = np.stack([gx.ravel(), gy.ravel()], axis=1)
        keep, inner = _classify(grid)
        x = grid[keep]
        is_inner = inner[keep]

    y = np.where(is_inner[:, None], [0.0, 1.0], [1.0, 0.0]).reshape(-1, 1, 2)

    tensor_x = torch.from_numpy(np.ascontiguousarray(x, dtype=np.float32))
    tensor_y = torch.from_numpy(np.ascontiguousarray(y, dtype=np.float32))

    my_dataset = torch.utils.data.TensorDataset(tensor_x, tensor_y)
    return torch.utils.data.DataLoader(my_dataset, batch_size=batch_size, shuffle=shuffle)

In [4]:
# Build the training loader first, then the evaluation-grid loader; both use
# mini-batches of 100 and are shuffled.
train_loader, test_loader = (
    get_dataloader(train=is_train, batch_size=100, shuffle=True)
    for is_train in (True, False)
)

In [5]:
# Global bookkeeping updated by `train`: checkpoint counter and per-batch accuracies.
timeSeen, accuracies = 0, []

In [6]:
def train(network, optimizer, lossFunc, train_loader, epoch,
          eval_loader=None, save_dir='trainedNetsCircleComplex'):
    """Run one epoch of training, evaluating and checkpointing after every batch.

    After each optimizer step the network is scored with ``test`` on
    ``eval_loader``; the score is appended to the global ``accuracies`` list,
    printed as ``epoch batch_idx accuracy``, and the whole module is saved to
    ``<save_dir>/network{timeSeen:09d}.pt`` before the global ``timeSeen``
    counter is incremented.

    Args:
        network: the model to train (updated in place).
        optimizer: optimizer over ``network``'s parameters.
        lossFunc: callable ``(output, target) -> scalar loss``.
        train_loader: iterable of ``(data, target)`` batches.
        epoch: epoch number, used only for the progress print.
        eval_loader: loader passed to ``test``; defaults to the module-level
            ``test_loader`` (the previous hard-wired behaviour).
        save_dir: directory for checkpoints; created if it does not exist.
    """
    global timeSeen
    if eval_loader is None:
        eval_loader = test_loader
    os.makedirs(save_dir, exist_ok=True)

    for batch_idx, (data, target) in enumerate(train_loader):
        # `test` puts the network in eval mode, so re-enable training mode for
        # every batch rather than only once before the loop.
        network.train()
        optimizer.zero_grad()
        output = network(data)
        loss = lossFunc(output, target)
        loss.backward()
        optimizer.step()

        accuracies.append(test(network, eval_loader))

        print(epoch, batch_idx, accuracies[-1])

        torch.save(network, os.path.join(save_dir, 'network{:09d}.pt'.format(timeSeen)))
        timeSeen += 1
        
def test(network, test_loader):
    network.eval()
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            output = network(data)
            pred = output.data.max(1, keepdim=True)[1]
            tar = target.data.max(2, keepdim=True)[1]
            correct += pred.eq(tar.data.view_as(pred)).sum()
    return correct.item()/len(test_loader.dataset)

In [7]:
# Re-initialise and retrain until some per-batch evaluation reaches
# TARGET_ACCURACY. Each attempt restarts the checkpoint numbering at 0:
# network000000000.pt is the untrained model, and later files are written by `train`.
TARGET_ACCURACY = 0.98
N_EPOCHS = 1
MAX_ATTEMPTS = 20
SAVE_DIR = 'trainedNetsCircleComplex'  # must match `train`'s checkpoint directory

os.makedirs(SAVE_DIR, exist_ok=True)
accuracies = []  # reset so re-running this cell retrains instead of silently skipping
attempt = 0
while max(accuracies + [0]) < TARGET_ACCURACY:
    if attempt >= MAX_ATTEMPTS:
        raise RuntimeError(
            "no run reached accuracy {} after {} attempts".format(TARGET_ACCURACY, MAX_ATTEMPTS))
    # Seed per attempt: reproducible, yet each retry starts from different weights.
    torch.manual_seed(attempt)
    attempt += 1

    criterion = torch.nn.MSELoss()
    network = SimpleNet()
    optimizer = torch.optim.SGD(network.parameters(), 0.1, momentum=0.9, weight_decay=1e-4)
    timeSeen = 0
    accuracies = []

    torch.save(network, os.path.join(SAVE_DIR, 'network{:09d}.pt'.format(timeSeen)))
    timeSeen += 1
    for i in range(N_EPOCHS):
        train(network, optimizer, criterion, train_loader, i)
        print(i, accuracies[-1])

0 0 0.49732074508803265
0 1 0.5427404950242409
0 2 0.511099770349579
0 3 0.5026792549119673
0 4 0.5026792549119673
0 5 0.5026792549119673
0 6 0.5026792549119673
0 7 0.5669813727991835
0 8 0.49732074508803265
0 9 0.49732074508803265
0 10 0.49732074508803265
0 11 0.49732074508803265
0 12 0.49732074508803265
0 13 0.49732074508803265
0 14 0.49732074508803265
0 15 0.49732074508803265
0 16 0.49732074508803265
0 17 0.5026792549119673
0 18 0.5026792549119673
0 19 0.5026792549119673
0 20 0.5026792549119673
0 21 0.5026792549119673
0 22 0.5026792549119673
0 23 0.5026792549119673
0 24 0.5026792549119673
0 25 0.7213574891553968
0 26 0.543250829293187
0 27 0.49732074508803265
0 28 0.60398060729778
0 29 0.6402143403929574
0 30 0.6624138810921153
0 31 0.6629242153610615
0 32 0.5026792549119673
0 33 0.5026792549119673
0 34 0.5026792549119673
0 35 0.5026792549119673
0 36 0.5026792549119673
0 37 0.5026792549119673
0 38 0.5026792549119673
0 39 0.5026792549119673
0 40 0.5026792549119673
0 41 0.502679254911

0 667 1.0
0 668 1.0
0 669 1.0
0 670 1.0
0 671 1.0
0 672 1.0
0 673 1.0
0 674 1.0
0 675 1.0
0 676 1.0
0 677 1.0
0 678 1.0
0 679 1.0
0 680 1.0
0 681 1.0
0 682 1.0
0 683 1.0
0 684 1.0
0 685 1.0
0 686 1.0
0 687 1.0
0 688 1.0
0 689 1.0
0 690 1.0
0 691 1.0
0 692 1.0
0 693 1.0
0 694 1.0
0 695 1.0
0 696 1.0
0 697 1.0
0 698 1.0
0 699 1.0
0 700 1.0
0 701 1.0
0 702 1.0
0 703 1.0
0 704 1.0
0 705 1.0
0 706 1.0
0 707 1.0
0 708 1.0
0 709 1.0
0 710 1.0
0 711 1.0
0 712 1.0
0 713 1.0
0 714 1.0
0 715 1.0
0 716 1.0
0 717 1.0
0 718 1.0
0 719 1.0
0 720 1.0
0 721 1.0
0 722 1.0
0 723 1.0
0 724 1.0
0 725 1.0
0 726 1.0
0 727 1.0
0 728 1.0
0 729 1.0
0 730 1.0
0 731 1.0
0 732 1.0
0 733 1.0
0 734 1.0
0 735 1.0
0 736 1.0
0 737 1.0
0 738 1.0
0 739 1.0
0 740 1.0
0 741 1.0
0 742 1.0
0 743 1.0
0 744 1.0
0 745 1.0
0 746 1.0
0 747 1.0
0 748 1.0
0 749 1.0
0 750 1.0
0 751 1.0
0 752 1.0
0 753 1.0
0 754 1.0
0 755 1.0
0 756 1.0
0 757 1.0
0 758 1.0
0 759 1.0
0 760 1.0
0 761 1.0
0 762 1.0
0 763 1.0
0 764 1.0
0 765 1.0
0 766 1.0


In [8]:
# Persist the per-batch accuracy curve and a 0-based step index of length timeSeen - 1
# (one entry per recorded accuracy).
# NOTE(review): accuracies[k] was recorded just before checkpoint network{k+1} was
# saved, so these indices are one less than the matching checkpoint file numbers.
# Confirm this is the pairing the consumer of this file expects.
np.save(file='accuraciesCircleComplex.npy', arr=accuracies)
np.save(file='indicesCircleComplex.npy', arr=range(timeSeen - 1))