In [None]:
import random
import matplotlib.pyplot as plt
import math
import numpy as np
import seaborn as sns
import scipy.io
import copy
from itertools import islice
import torch
import torchvision
import torch.nn as nn
import torchvision.transforms as transforms
import torch.nn.functional as F
import torch.nn.init as init
from google.colab import files
from io import BytesIO

# Seed every random number generator the notebook draws from with one constant
# so that data subsets and weight initialization repeat from run to run.
# NumPy needs its own seed: random_sample / random_sample_some_removed draw their
# subsets with np.random.choice. (GPU kernels may still be nondeterministic; see TODO.)
SEED = 3242023
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED) # Seeded with a constant, so that behavior is deterministic.
torch.cuda.manual_seed_all(SEED)

# TODO not currently available torch.use_deterministic_algorithms(True)

# Use the GPU when present and make it the default device for newly created tensors.
if torch.cuda.is_available():
    device = "cuda"
    def print_memory_usage():
        """Print allocated, reserved and peak-reserved CUDA memory on device 0, in GiB."""
        print("torch.cuda.memory_allocated: %fGB"%(torch.cuda.memory_allocated(0)/1024/1024/1024))
        print("torch.cuda.memory_reserved: %fGB"%(torch.cuda.memory_reserved(0)/1024/1024/1024))
        print("torch.cuda.max_memory_reserved: %fGB"%(torch.cuda.max_memory_reserved(0)/1024/1024/1024))
else:
    device = "cpu"
    def print_memory_usage(): print("Using cpu")
torch.set_default_device(device)


print(f"Pytorch running on {device}")

In [None]:
# Normalization-layer factory used by BasicBlock and ResNet, called as n(num_channels).
# Alternative for comparing normalizations (see the BatchNorm vs GroupNorm TODO):
# n = lambda x: nn.GroupNorm(1, x)
n = nn.BatchNorm2d

def _weights_init(m):
    if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight)

class LambdaLayer(nn.Module):
    """Wrap an arbitrary callable as an nn.Module so it can sit inside a model.

    The callable is kept in `self.lambd` and applied to the input in forward().
    """

    def __init__(self, lambd):
        super().__init__()
        self.lambd = lambd

    def forward(self, x):
        fn = self.lambd
        return fn(x)


class BasicBlock(nn.Module):
    """ResNet basic block: conv3x3-norm-relu-conv3x3-norm, add a shortcut, relu.

    Normalization layers are built with the module-level factory `n`.

    Args:
        in_planes: number of input channels.
        planes: number of output channels.
        stride: stride of the first convolution; the shortcut subsamples by
            the same amount.
        option: shortcut used when the block changes shape
            (stride != 1 or in_planes != planes):
              'A' -- parameter-free: subsample spatially by `stride` and
                     zero-pad the channel dimension evenly up to `planes`;
              'B' -- strided 1x1 convolution followed by normalization.

    Raises:
        ValueError: if a shape-changing shortcut is needed and `option` is not
            'A' or 'B', or if option 'A' would have to drop channels
            (planes < in_planes).
    """
    expansion = 1

    def __init__(self, in_planes, planes, stride=1, option='A'):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = n(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = n(planes)

        # Identity shortcut unless the block changes spatial size or channel count.
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != planes:
            if option == 'A':
                # For CIFAR10 the ResNet paper uses option A.
                if planes < in_planes:
                    raise ValueError(
                        f"Shortcut option 'A' cannot reduce channels ({in_planes} -> {planes}); use option 'B'")
                # Split the extra channels between front and back. When
                # planes == 2 * in_planes and planes is a multiple of 4 this is
                # planes//4 on each side, as in the reference implementation.
                pad_total = planes - in_planes
                pad_front = pad_total // 2
                pad_back = pad_total - pad_front
                # Subsample by the block's stride so the spatial size matches conv1's output.
                self.shortcut = LambdaLayer(lambda x:
                                            F.pad(x[:, :, ::stride, ::stride], (0, 0, 0, 0, pad_front, pad_back), "constant", 0))
            elif option == 'B':
                self.shortcut = nn.Sequential(
                     nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                     n(self.expansion * planes)
                )
            else:
                raise ValueError(f"Unknown shortcut option {option!r}; expected 'A' or 'B'")

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out

class ResNet(nn.Module):
    """CIFAR-style ResNet: 3x3 stem conv, three residual stages, global average pool, linear head.

    Normalization layers are built with the module-level factory `n`, and all
    Linear/Conv2d weights are initialized by `_weights_init`.

    Args:
        block: residual block class (e.g. BasicBlock); must expose an
            `expansion` attribute and accept (in_planes, planes, stride).
        num_blocks: number of blocks in each of the three stages.
        num_classes: size of the output logits.
    """
    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = 16

        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = n(16)
        self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2)
        # The last stage emits 64 * expansion channels (64 for BasicBlock).
        self.linear = nn.Linear(64 * block.expansion, num_classes)

        self.apply(_weights_init)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stage: the first block uses `stride`, the rest stride 1.

        Advances self.in_planes to this stage's output width.
        """
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for block_stride in strides:
            layers.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes * block.expansion

        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        # Global average pool to 1x1; unlike a kernel of size W this also
        # handles non-square feature maps.
        out = F.adaptive_avg_pool2d(out, 1)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out

def resnet20():
    """Build a ResNet-20 (three stages of three BasicBlocks) and move it to `device`."""
    model = ResNet(BasicBlock, [3] * 3)
    return model.to(device)

class Identity(nn.Module):
    """A no-op module: forward() hands back its input object unchanged."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x

In [None]:
# CIFAR-10 class names, ordered by label index 0..9.
class_names = "airplane automobile bird cat deer dog frog horse ship truck".split()

def cov(mat, verbose=True):
    """
    Unbiased sample covariance of the rows of `mat`.

    mat should be a matrix whose rows are separate samples, i.e. shape
    (num_samples, num_features). Returns a (num_features, num_features)
    tensor normalized by num_samples - 1, on mat's device and in mat's dtype.

    If `verbose`, prints the squared L2 norm of every centered row
    (NaNs ignored in that diagnostic only).

    Raises:
        ValueError: if mat has fewer than two rows (covariance undefined).
    """
    num_samples = mat.size(0)
    if num_samples < 2:
        raise ValueError(f"cov needs at least 2 samples, got {num_samples}")
    xbar = mat.mean(dim=0)
    matdelta = mat - xbar
    if verbose:
        print(f"Squares of L2 norms of X-Xbar: {[torch.nansum(torch.square(row)) for row in matdelta]}")
    # Sum of per-row outer products, computed as a single matrix product.
    return matdelta.t() @ matdelta / (num_samples - 1)

In [None]:
def convert_to_one_hot(data, num_classes=None): # Adapted from https://stackoverflow.com/questions/36960320/convert-a-2d-matrix-to-a-3d-one-hot-matrix-numpy
    """
    Convert non-negative integer class labels to a one-hot int array.

    Args:
        data: array-like of integer labels (any shape). It is not modified.
        num_classes: width of the one-hot axis; defaults to data.max() + 1,
            i.e. inferred from the largest label present.

    Returns:
        int ndarray of shape data.shape + (num_classes,), passed through
        np.squeeze (so any length-1 axis, e.g. a single sample, is dropped).
    """
    labels = np.asarray(data)
    if num_classes is None:
        num_classes = labels.max() + 1
    return np.squeeze((np.arange(num_classes) == labels[..., None]).astype(int))

# Load CIFAR-10 (downloaded into ./data if missing) fully into memory as tensors.
# ToTensor yields values in [0, 1]; Normalize with mean 0.5 / std 0.5 maps each channel to [-1, 1].
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) # Deliberately no data augmentation here; flips are applied during training by augment_train_data (when train_augment=True).

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
# Materialize every (image, label) pair; images are stacked into one tensor on `device`.
train, ans = zip(*trainset)
test, test_ans = zip(*testset)
train_mat = torch.stack(train).to(device)
test_mat = torch.stack(test).to(device)
# Labels become one-hot numpy int arrays (width inferred from the largest label).
ans_mat = convert_to_one_hot(np.array(ans))
ans_mat_test = convert_to_one_hot(np.array(test_ans))
print(train_mat.shape, test_mat.shape, ans_mat.shape, ans_mat_test.shape)

def augment_train_data(train_data_in, labels):
    """Return the data set quadrupled with flips, plus matching labels.

    The output is the concatenation (along dim 0) of
    [original, hflip, vflip(original), vflip(hflip)], and the labels are
    repeated four times in the same order.
    """
    hflip = transforms.functional.hflip
    vflip = transforms.functional.vflip
    with_hflip = torch.cat([train_data_in, hflip(train_data_in)], dim=0)
    doubled_labels = torch.cat([labels, labels], dim=0)
    return (torch.cat([with_hflip, vflip(with_hflip)], dim=0),
            torch.cat([doubled_labels, doubled_labels], dim=0))

In [None]:
def get_params(layer):
    """Return a layer's weight matrix with its bias appended as an extra last column.

    Uses `.data`, so the result carries no autograd history.
    """
    weight = layer.weight.data
    bias_column = layer.bias.data[:, None]
    return torch.cat([weight, bias_column], dim=1)

def test_model(model, test_data, test_labels, microbatch_size=1000):
    """
    Return the accuracy of `model` on (test_data, test_labels), in [0, 1].

    The model runs in eval mode (BatchNorm uses its running statistics and
    does not update them with test data) under torch.no_grad(), in
    microbatches of `microbatch_size`. Its previous train/eval mode is
    restored afterwards. The last microbatch may be smaller; accuracy is
    computed from the exact number of correct predictions across all batches.

    Raises:
        ValueError: if test_data is empty.
    """
    num_samples = test_data.shape[0]
    if num_samples == 0:
        raise ValueError("test_model needs at least one test sample")
    was_training = model.training
    model.eval()
    correct = 0
    try:
        with torch.no_grad():
            for start in range(0, num_samples, microbatch_size):
                batch = test_data[start:start+microbatch_size]
                batch_labels = test_labels[start:start+microbatch_size]
                batch_out = model(batch).detach()
                # get_accuracy returns a fraction; convert back to a count so
                # unequal batch sizes are weighted correctly.
                correct += round(get_accuracy(batch_out, batch_labels) * batch.shape[0])
    finally:
        model.train(was_training)
    return correct / num_samples

def clip_group_sample_grad(grad_list, max_norm=1.0):
    """
    Rescale a list of gradient tensors, in place, so their joint L2 norm is at most max_norm.

    The norm is taken over all tensors flattened together and is printed.
    If it exceeds max_norm, every tensor is multiplied by max_norm / norm.
    Returns the same list object.
    """
    flat = torch.cat([torch.reshape(g, (-1,)) for g in grad_list], 0)
    l2_norm = torch.norm(flat, p=2)
    print(f"Current norm: {l2_norm}")

    if not l2_norm > max_norm:
        return grad_list

    scale = max_norm/l2_norm
    for g in grad_list:
        g *= scale
    return grad_list

def measure_model_stability(models):
    """
    Measure how spread out a set of models' parameters are.

    Each model's parameters are flattened into one vector. Returns
    (avg_l2, avg_l2_deviation): the mean L2 norm of those vectors, and the
    mean L2 distance of each vector from their element-wise average.
    """
    flat_params = [
        np.array(torch.cat([p.flatten() for p in model.parameters()]).detach().cpu()).flatten()
        for model in models
    ]
    trial_count = len(models)

    param_norms = [math.sqrt(np.dot(vec, vec)) for vec in flat_params]
    mean_params = sum(flat_params)/trial_count
    deviation_norms = [math.sqrt(np.dot(delta, delta)) for delta in (vec-mean_params for vec in flat_params)]

    return sum(param_norms)/trial_count, sum(deviation_norms)/trial_count

def train_models_divergent_one_step(models_in, optimizers, data_sets, label_sets, measure_stability=True, criterion=nn.CrossEntropyLoss(), test_data=None, test_labels=None, microbatch_size=100, clipping_func = lambda x: x, train_augment=True):
    """
    Take one optimizer step for every model, each on its own data set.

    Model i is trained on (data_sets[i], label_sets[i]) with optimizers[i].
    The data set is first flip-augmented by augment_train_data when
    train_augment is True. It is then cut into microbatches of
    `microbatch_size`. Each microbatch's gradient goes through
    `clipping_func`, and the per-microbatch results are averaged with equal
    weight into the gradient used for a single optimizer step.

    Args:
        models_in, optimizers, data_sets, label_sets: parallel sequences,
            one entry per model.
        measure_stability: when False, nothing is returned.
        criterion: loss, called as criterion(model_output, labels).
        test_data, test_labels: optional evaluation set; if given, each
            model's accuracy (test_model) is printed and averaged.
        microbatch_size: samples per microbatch.
        clipping_func: maps one microbatch's list of per-parameter gradients
            to the (possibly clipped) list that is accumulated.
        train_augment: augment each data set before microbatching.

    Returns:
        None if measure_stability is False. Otherwise,
        ((avg_l2, avg_l2_deviation), mean_accuracy) when test_data is given,
        else (avg_l2, avg_l2_deviation).

    Raises:
        ValueError: if a data set is empty (no microbatches to average).
    """
    assert len(data_sets) == len(label_sets) and len(data_sets) == len(models_in)
    accuracies = []
    for i in range(len(data_sets)):
        model, optimizer = models_in[i], optimizers[i]
        model.train()
        if train_augment:
            augmented_data, augmented_ans = augment_train_data(data_sets[i], label_sets[i])
        else:
            augmented_data, augmented_ans = data_sets[i], label_sets[i]

        # The previous step leaves its averaged gradient assigned to p.grad;
        # clear it so the first backward() below does not accumulate onto it.
        optimizer.zero_grad()
        all_group_sample_gradients = []
        # Walk the whole augmented set: augment_train_data appends the flipped
        # copies after the originals, so stopping at len(data_sets[i]) would
        # silently train on the unaugmented originals only.
        # NOTE(review): flipped copies of one sample can land in different
        # microbatches, so under clipping one sample influences several clipped
        # groups -- confirm this is acceptable for the privacy analysis.
        for j in range(0, len(augmented_data), microbatch_size):
            microbatch = augmented_data[j:j+microbatch_size]
            microbatch_ans = augmented_ans[j:j+microbatch_size]

            out = model(microbatch)
            loss = criterion(out, microbatch_ans)
            loss.backward()

            group_sample_gradients = clipping_func([p.grad.detach().clone() for p in model.parameters()])

            all_group_sample_gradients.append(group_sample_gradients)
            optimizer.zero_grad()

        if not all_group_sample_gradients:
            raise ValueError(f"Data set {i} is empty; there is nothing to train on")

        # MICROBATCHING IS HERE: equal-weight mean of the per-microbatch gradients.
        num_microbatches = len(all_group_sample_gradients)
        processed_grads = [sum(per_param)/num_microbatches for per_param in zip(*all_group_sample_gradients)]

        for p, grad in zip(model.parameters(), processed_grads):
            p.grad = grad

        optimizer.step()
        if test_data is not None:
            accuracy = test_model(model, test_data, test_labels)
            print(f"Model {i} has an accuracy of {accuracy}")
            accuracies.append(accuracy)
    if measure_stability:
        if len(accuracies) > 0:
            return (measure_model_stability(models_in), sum(accuracies)/len(accuracies))
        else:
            return measure_model_stability(models_in)

def train_models_divergent(train_set, train_ans, num_datapoints=10000, num_trials=8, steps=16, microbatch_size=100, test_data=None, test_labels=None, graph_results=True, clipping_threshold=1.0, base_model=None, train_augment=True, lr=0.1):
    """
    Train `num_trials` copies of one starting model, each on its own random
    subset of the training data, and track how far apart they drift.

    Args:
        train_set, train_ans: full training inputs and labels to sample from.
        num_datapoints: size of each model's random training subset.
        num_trials: number of model copies (one subset each).
        steps: number of optimizer steps (logged as "epochs"); each step uses
            the model's whole subset.
        microbatch_size: microbatch size for per-group clipping/averaging;
            <= 0 means microbatches of `num_datapoints` samples.
        test_data, test_labels: optional evaluation set; when given, the
            models' accuracy is measured after every step.
        graph_results: plot average deviation / average l2 norm per step.
        clipping_threshold: per-microbatch gradient L2 bound; <= 0 disables clipping.
        base_model: model to copy; a fresh resnet20() when None.
        train_augment: apply flip augmentation during training.
        lr: SGD learning rate.

    Returns:
        (stability_data, models), or (stability_data, models, accuracy_data)
        when test_data is given. stability_data holds one
        (avg l2 norm, avg l2 deviation) tuple per step; accuracy_data holds the
        mean test accuracy over the models after each step.
    """
    if microbatch_size <= 0:
        microbatch_size = num_datapoints
    data_sets, label_sets = random_sample(train_set, train_ans, num_datapoints, num_trials)
    # data_sets, label_sets = random_sample_some_removed(train_set, train_ans, num_datapoints, 1, num_trials)
    data_sets = torch.Tensor(data_sets).to(device)
    label_sets = torch.Tensor(label_sets).to(device)
    if test_data is not None:
        test_data = torch.Tensor(test_data).to(device)
        test_labels = torch.Tensor(test_labels).to(device)

    if base_model is None:
        base_model = resnet20()
    models = [copy.deepcopy(base_model) for _ in range(num_trials)]
    optimizers = [torch.optim.SGD(m.parameters(), lr=lr) for m in models] # TODO lower

    stability_data = []
    accuracy_data = []

    if clipping_threshold <= 0:
        temp_clip_func = lambda x: x
    else:
        temp_clip_func = lambda x: clip_group_sample_grad(x, clipping_threshold)

    for step in range(steps):
        if test_data is not None:
            print(f"Now beginning epoch {step}")
        results = train_models_divergent_one_step(models, optimizers, data_sets, label_sets, True, test_data=test_data, test_labels=test_labels, microbatch_size=microbatch_size, clipping_func=temp_clip_func, train_augment=train_augment)
        # The one-step helper pairs the stability tuple with an accuracy only
        # when test data is given; otherwise it returns the stability tuple
        # itself, so indexing results[0] would keep just the average l2 norm.
        if test_data is not None:
            stability, accuracy = results
            stability_data.append(stability)
            accuracy_data.append(accuracy)
        else:
            stability_data.append(results)

    if graph_results:
        l2_fractional_deviation = [data[1]/data[0] for data in stability_data]
        plt.plot([0] + l2_fractional_deviation)
    if test_data is not None:
        return (stability_data, models, accuracy_data)
    return (stability_data, models)

In [None]:
def add_noise_for_one_epoch(model, norm):
    """
    Add i.i.d. Gaussian noise to every parameter of `model`, in place.

    Each parameter entry receives independent N(0, norm**2) noise, so `norm`
    is the per-coordinate standard deviation. NOTE(review): the L2 norm of
    the whole noise vector is therefore about norm * sqrt(num_params), not
    `norm`; confirm this is the intended scale.

    Args:
        model: nn.Module whose parameters are perturbed.
        norm: per-coordinate standard deviation of the noise.

    Returns:
        The same `model` object.
    """
    with torch.no_grad():
        param_vector = nn.utils.parameters_to_vector(model.parameters())
        # randn_like follows the parameters' device and dtype instead of the
        # global default device.
        noise = torch.randn_like(param_vector)*norm
        param_vector.add_(noise)
        nn.utils.vector_to_parameters(param_vector, model.parameters())
    return model

def train_models_checkpointed(train_set, train_ans, test_set=None, test_ans=None, num_datapoints=10000, num_models=8, epochs_per_group=10, total_groups=10, microbatch_size=0, clipping_threshold=0, graph_results=False, noise_level=0.0, train_augment=True):
    """
    Run `total_groups` rounds of train_models_divergent, chaining the rounds.

    Each round starts from the first model of the previous round (a fresh
    model for round 0). When noise_level > 0, that model is perturbed with
    add_noise_for_one_epoch, scaled by noise_level times the round's final
    average l2 deviation.

    Returns the per-round stability data, or (stability data, accuracies)
    when a test set is given.
    """
    have_test = test_set is not None
    current_model = None
    group_stability = []
    group_accuracies = []
    for group_idx in range(total_groups):
        print(f"\nBeginning training group {group_idx} of {total_groups}\n\n")
        result = train_models_divergent(
            train_set, train_ans,
            num_datapoints=num_datapoints, num_trials=num_models, steps=epochs_per_group,
            microbatch_size=microbatch_size, test_data=test_set, test_labels=test_ans,
            graph_results=False, clipping_threshold=clipping_threshold,
            base_model=current_model, train_augment=train_augment,
        )
        stability = result[0]
        current_model = result[1][0]
        group_stability.append(stability)
        if have_test:
            group_accuracies.append(result[2])
        if noise_level > 0.0:
            add_noise_for_one_epoch(current_model, stability[-1][1]*noise_level)

    if graph_results: # TODO this needs to be completely revamped. As of right now, it's minimally useful.
        # Fractional deviation (deviation / l2 norm) per epoch, one line per group.
        for trial in group_stability:
            plt.plot([epoch[1]/epoch[0] for epoch in trial])
        plt.show()
        # Absolute deviation per epoch, one line per group.
        for trial in group_stability:
            plt.plot([epoch[1] for epoch in trial])
        plt.show() # TODO this can and should be prettier
    if have_test:
        return (group_stability, group_accuracies)
    return group_stability

In [None]:
def random_sample(data, labels, qty=1000, num_samples=1):
    """
    Draw `num_samples` random subsets of `qty` rows from (data, labels).

    Rows are drawn without replacement within a subset; different subsets
    may overlap. Uses NumPy's global RNG (np.random).

    Args:
        data: torch tensor (any device) or array-like, samples along axis 0.
        labels: torch tensor or array-like aligned with data along axis 0.
        qty: rows per subset.
        num_samples: number of subsets.

    Returns:
        (sampled_data, sampled_labels) as ndarrays of shape
        (num_samples, qty, ...).
    """
    # Move everything to host memory once, not once per subset.
    if isinstance(data, torch.Tensor):
        data = data.detach().cpu().numpy()
    else:
        data = np.asarray(data)
    if isinstance(labels, torch.Tensor):
        labels = labels.detach().cpu().numpy()
    else:
        labels = np.asarray(labels)

    ans_data = []
    ans_labels = []
    dataset_size = data.shape[0]
    for _ in range(num_samples):
        listed = np.random.choice(dataset_size, qty, replace=False)
        ans_data.append(np.take(data, listed, axis=0))
        ans_labels.append(np.take(labels, listed, axis=0))
    return np.stack(ans_data), np.stack(ans_labels)

def random_sample_some_removed(data, labels, qty=1000, samples_removed=1, num_samples=1):
    """
    Draw one pool of `qty` rows, then `num_samples` subsets that each omit
    `samples_removed` random rows of that pool.

    Sampling is without replacement and uses NumPy's global RNG (np.random).

    Args:
        data: torch tensor (any device) or array-like, samples along axis 0.
        labels: torch tensor or array-like aligned with data along axis 0.
        qty: size of the shared pool.
        samples_removed: rows dropped from the pool in each subset.
        num_samples: number of subsets.

    Returns:
        (sampled_data, sampled_labels) as ndarrays of shape
        (num_samples, qty - samples_removed, ...).
    """
    # Same host conversion as random_sample, so CUDA tensors work too.
    if isinstance(data, torch.Tensor):
        data = data.detach().cpu().numpy()
    else:
        data = np.asarray(data)
    if isinstance(labels, torch.Tensor):
        labels = labels.detach().cpu().numpy()
    else:
        labels = np.asarray(labels)

    ans_data = []
    ans_labels = []
    dataset_size = data.shape[0]
    sample_pool = np.random.choice(range(dataset_size), qty, replace=False)
    for _ in range(num_samples):
        listed = np.random.choice(sample_pool, qty-samples_removed, replace=False)
        ans_data.append(np.take(data, listed, axis=0))
        ans_labels.append(np.take(labels, listed, axis=0))
    return np.stack(ans_data), np.stack(ans_labels)

def grade_prediction_one_hot(prediction, truth):
    """
    Return whether the highest-scoring class in `prediction` is marked 1 in `truth`.

    Args:
        prediction: 1-D scores as a torch tensor (any device) or numpy array.
            test_stability passes numpy rows, which previously failed on `.cpu()`.
        truth: 1-D one-hot labels (tensor or array), indexable by int.

    Returns:
        The element-wise comparison `truth[argmax] == 1` (a bool tensor or numpy bool).
    """
    if isinstance(prediction, torch.Tensor):
        prediction = prediction.detach().cpu().numpy()
    return truth[int(np.argmax(prediction))] == 1

def get_accuracy(output, ground_truth, grader = grade_prediction_one_hot):
    """Fraction of rows (along axis 0) for which grader(output[k], ground_truth[k]) is truthy."""
    n_samples = output.shape[0]
    hits = sum(1 for k in range(n_samples) if grader(output[k], ground_truth[k]))
    return hits/n_samples

def linear_regression(train_data, train_ans):
    """
    Least-squares coefficients for train_ans ~= train_data @ coeffs.

    If train_ans has at most two dimensions, a single problem is solved.
    Otherwise the leading axis indexes a batch of independent problems
    (train_data[b], train_ans[b]) and the stacked coefficients are returned.
    """
    if len(train_ans.shape) > 2:  # batched problems
        return np.stack([
            np.linalg.lstsq(train_data[b], train_ans[b])[0]
            for b in range(train_ans.shape[0])
        ])
    return np.linalg.lstsq(train_data, train_ans)[0]

def test_stability(train_in, train_ans, num_removed = None, printout = True, train_printout = True, test_data = None, test_ans = None):
    """
    Returns a tuple whose first element is the average model l2 norm and whose
    second element is the standard deviation of l2 norms.
    """
    coeffs_list = []
    for i in range(train_in.shape[0]):
        coeffs = np.linalg.lstsq(train_in[i], train_ans[i])[0]
        if train_printout:
            train_results = np.matmul(train_in[i], coeffs)
            print("Train accuracy: ", get_accuracy(train_results, train_ans[i]))
        if test_data is not None and test_ans is not None:
            test_results = np.matmul(test_data, coeffs)
            print("Test accuracy: ", get_accuracy(test_results, test_ans))
        coeffs_list.append(coeffs)

    coeffs = np.stack(coeffs_list).squeeze()
    coeffs_list = [i.flatten() for i in coeffs_list]
    coeffs_norms = [math.sqrt(np.dot(i, i)) for i in coeffs_list]

    num_trials = len(coeffs_norms)

    avg_coeff = sum(coeffs_list)/num_trials
    coeff_diffs_list = [i-avg_coeff for i in coeffs_list]
    coeffs_diffs_norms = [math.sqrt(np.dot(i, i)) for i in coeff_diffs_list]

    avg_l2 = sum(coeffs_norms)/num_trials
    avg_l2_deviation = sum(coeffs_diffs_norms)/num_trials

    if printout:
        print("")
        if num_removed == None:
            print(f"{num_trials} were conducted. Each trial contained {train_in.shape[1]} data points.")
            print(f"Average l2 norm was {avg_l2:.2f}. We expect the average deviation to be equal to the average multiplied by the fraction of samples removed (e.g. for removing 1 sample from a set of 1000, divide by 1000)")
        else:
            print(f"{num_trials} trials were conducted. A bank of {train_in.shape[1]+num_removed} data points was used with each trial omitting {num_removed} data points")
            print(f"Average l2 norm was {avg_l2:.2f}. We expect the average deviation to be equal to 1/{(train_in.shape[1]+num_removed)/num_removed}")
        print(f"Average l2 norm deviation was {avg_l2_deviation:.2f}. This is equal to {100*avg_l2_deviation/avg_l2:.4f}% of the average l2 norm, or 1/{avg_l2/avg_l2_deviation:.1f}")

    return avg_l2, avg_l2_deviation

# TODO when using small amounts of raw data (doing this with preprocessed data would be very bad), use data augmentation (e.g. flip, crop, etc.) to act as if we have large amounts of data.
# TODO this is the important one. From here on out, train a ResNet with 1,000 samples per group and 25 groups. TODO add noise after each group proportional to the l2 norm of that group. Fine-tune the number of steps per group for the maximal privacy-utility tradeoff. TODO also try this with group-sample clipping.
# TODO the noise for the above is Gaussian noise with norm equal to the absolute l2 deviation.
# TODO test BatchNorm vs GroupNorm in the ResNet20.

# One entry per configuration: (microbatch_size, clipping_threshold).
# microbatch_size 0 means microbatches of num_datapoints samples, and threshold 0
# disables clipping; e.g. ((100, 2.5),) would enable group clipping.
stability_data = []
for run_data in ((0, 0.0),): # ((100, 2.5),):
    stability_data.append(
        train_models_checkpointed(
            train_mat, ans_mat, test_mat, ans_mat_test,
            num_datapoints=1000, num_models=16, epochs_per_group=16, total_groups=8,
            microbatch_size=run_data[0], clipping_threshold=run_data[1],
            graph_results=True, noise_level=0.0, train_augment=True,
        )
    )

In [None]:
# For every configuration in stability_data (each entry is the
# (per-group stability data, per-group accuracies) pair returned by
# train_models_checkpointed when a test set is given), plot:
#   1. the average l2 deviation per epoch, with each group's curve shifted up
#      by the last plotted value of the previous group;
#   2. the mean test accuracy per epoch, concatenated across groups.
for hyperparams in stability_data:
    group_stability, group_accuracies = hyperparams[0], hyperparams[1]

    list_of_points = [0]
    for series in group_stability:
        previous_stopping_point = list_of_points[-1]
        for point in series:
            print(point)
            list_of_points.append(point[1]+previous_stopping_point)
    plt.plot(list_of_points)
    plt.title("Average l2 deviation (groups stacked end to end)")
    plt.xlabel("Epoch")
    plt.ylabel("Stacked average l2 deviation")
    plt.show()

    list_of_points = []
    for series in group_accuracies:
        list_of_points.extend(series)
    plt.plot(list_of_points)
    plt.title("Mean test accuracy")
    plt.xlabel("Epoch")
    plt.ylabel("Accuracy")
    plt.show()

In [None]:
# For every configuration, print the last entry of each part of its result:
# the final group's per-epoch stability tuples, then the final group's accuracies.
for run_result in stability_data:
    for part in run_result:
        last_entry = part[-1]
        print(last_entry, end=", ")