In [1]:
import numpy as np
import tensorflow as tf
import torch
import os
import torch
import torch.nn as nn
from torch.autograd import Variable
import torchvision
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torch.nn.functional as F
import torch.optim as optim
from dataloader import get_dataloaders
import numpy as np
import warnings
from tqdm import tqdm
from time import time
import json
from sklearn.metrics import precision_recall_fscore_support
import dni
warnings.filterwarnings("ignore")


In [2]:

def one_hot(indexes, n_classes):
    result = torch.FloatTensor(indexes.size() + (n_classes,))
    result = result.cpu()
    result.zero_()
    indexes_rank = len(indexes.size())
    result.scatter_(
        dim=indexes_rank, index=indexes.data.unsqueeze(dim=indexes_rank), value=1
    )
    return Variable(result)


class Four_Layer_SG(nn.Module):
    """4-layer CNN as described in the paper"""

    def __init__(self):
        super(Four_Layer_SG, self).__init__()

        self.block_1 = nn.Sequential(
            nn.Conv2d(
                in_channels=1,
                out_channels=32,
                kernel_size=(12, 12),
                stride=(4, 4),
                padding=(4, 4),
            ),
            nn.BatchNorm2d(32),
            nn.ReLU(),
        )

        self.classifier = nn.Sequential(
            nn.Linear(12 * 12 * 32, 512),
            nn.ReLU(True),
            nn.Dropout(p=0.65),
            nn.Linear(512, 10),
        )

        self.backward_interface_1 = dni.BackwardInterface(
            dni.BasicSynthesizer(output_dim=12, n_hidden=1)
        )

    def forward(self, x, y=None):
        x = self.block_1(x)
        if self.training:
            # context = one_hot(y, 10)
            # with dni.synthesizer_context(context):
            x = self.backward_interface_1(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return F.log_softmax(x, dim=1)


def test(model, test_loader):
    test_metrics = []

    model.eval()
    model.cuda()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.cuda(), target.cuda()
            data, target = Variable(data, volatile=True), Variable(target)
            output = model(data)
            test_loss += F.nll_loss(output, target, size_average=False).item()

            pred = output.data.max(1, keepdim=True)[
                1
            ]  # get the index of the max log-probability
            correct += pred.eq(target.data.view_as(pred)).cpu().sum()
            total += target.size(0)

    test_loss /= len(test_loader.dataset)
    test_accuracy = 100 * correct / total

    prec, recall, f1, support = precision_recall_fscore_support(
        target.cpu(), pred.cpu(), average="weighted"
    )

    test_metrics.append((test_accuracy.item(), prec, recall, f1, support))

    print(
        "\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format(
            test_loss, correct, len(test_loader.dataset), test_accuracy
        )
    )
    return test_metrics


def train_sg(model, epoch, train_loader, log_interval, verbose=False):
    model.train()
    model.cuda()
    optimizer = optim.Adam(model.parameters(), lr=3e-5)
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.cuda(), target.cuda()
        data, target = Variable(data), Variable(target)
        optimizer.zero_grad()
        output = model(data, target)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if verbose:
            if batch_idx % log_interval == 0:
                print(
                    "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
                        epoch,
                        batch_idx * len(data),
                        len(train_loader.dataset),
                        100.0 * batch_idx / len(data),
                        loss.data.item(),
                    )
                )

In [6]:
torch.cuda.device_count()

2

In [3]:

n_epochs = 10000
log_interval = 20

dev = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

batch_size_train = 256  # specified in the paper
train_loader, test_loader = get_dataloaders(1000, 200, batch_size_train)

model_dict = {
    "Four_Layer_SG": Four_Layer_SG,
}

# with open("data_SG.json", "r") as fp:
#     data = json.load(fp)

for key in model_dict.keys():
    model = model_dict[key]()
    torch.nn.DataParallel(model)
    stats = []
    test(model, test_loader)
    print(key, "\n")
    for epoch in tqdm(range(1, n_epochs + 1), leave=False):
        start_time = time()

        train_sg(model, epoch, train_loader, log_interval, verbose=False)

        end_time = time()
        time_taken = end_time - start_time

        metrics = test(model, test_loader)
        metrics.append(time_taken)

        stats.append(metrics)

    data[key] = stats
    # clear gpu memory
    torch.cuda.empty_cache()

# with open("data_SG.json", "w") as fp:
#     json.dump(data, fp)

  0%|                                                                                                                                                                                       | 0/10000 [00:00<?, ?it/s]


Test set: Avg. loss: 2.3177, Accuracy: 17/200 (8%)

Four_Layer_SG 



  0%|                                                                                                                                                                            | 1/10000 [00:03<10:57:49,  3.95s/it]


Test set: Avg. loss: 2.2386, Accuracy: 65/200 (32%)



  0%|                                                                                                                                                                             | 2/10000 [00:04<4:56:53,  1.78s/it]


Test set: Avg. loss: 2.1280, Accuracy: 104/200 (52%)



  0%|                                                                                                                                                                             | 3/10000 [00:04<3:00:08,  1.08s/it]


Test set: Avg. loss: 1.9809, Accuracy: 122/200 (61%)



  0%|                                                                                                                                                                             | 4/10000 [00:04<2:05:11,  1.33it/s]


Test set: Avg. loss: 1.7991, Accuracy: 130/200 (65%)



  0%|                                                                                                                                                                             | 5/10000 [00:04<1:35:41,  1.74it/s]


Test set: Avg. loss: 1.5796, Accuracy: 137/200 (68%)



  0%|                                                                                                                                                                             | 6/10000 [00:05<1:17:24,  2.15it/s]


Test set: Avg. loss: 1.3486, Accuracy: 141/200 (70%)



  0%|                                                                                                                                                                             | 7/10000 [00:05<1:05:30,  2.54it/s]


Test set: Avg. loss: 1.1338, Accuracy: 142/200 (71%)



  0%|▏                                                                                                                                                                              | 8/10000 [00:05<57:20,  2.90it/s]


Test set: Avg. loss: 0.9564, Accuracy: 146/200 (73%)



  0%|▏                                                                                                                                                                              | 9/10000 [00:05<52:11,  3.19it/s]


Test set: Avg. loss: 0.8123, Accuracy: 151/200 (76%)



  0%|▏                                                                                                                                                                             | 10/10000 [00:06<48:15,  3.45it/s]


Test set: Avg. loss: 0.7213, Accuracy: 155/200 (78%)



  0%|▏                                                                                                                                                                             | 11/10000 [00:06<46:49,  3.56it/s]


Test set: Avg. loss: 0.6385, Accuracy: 154/200 (77%)



  0%|▏                                                                                                                                                                             | 12/10000 [00:06<45:20,  3.67it/s]


Test set: Avg. loss: 0.5809, Accuracy: 157/200 (78%)



  0%|▏                                                                                                                                                                             | 13/10000 [00:06<44:56,  3.70it/s]


Test set: Avg. loss: 0.5402, Accuracy: 160/200 (80%)



  0%|▏                                                                                                                                                                             | 14/10000 [00:07<43:03,  3.87it/s]


Test set: Avg. loss: 0.4998, Accuracy: 159/200 (80%)



  0%|▎                                                                                                                                                                             | 15/10000 [00:07<42:39,  3.90it/s]


Test set: Avg. loss: 0.4828, Accuracy: 162/200 (81%)



  0%|▎                                                                                                                                                                             | 16/10000 [00:07<42:31,  3.91it/s]


Test set: Avg. loss: 0.4557, Accuracy: 164/200 (82%)



  0%|▎                                                                                                                                                                             | 17/10000 [00:07<42:50,  3.88it/s]


Test set: Avg. loss: 0.4325, Accuracy: 166/200 (83%)



  0%|▎                                                                                                                                                                             | 18/10000 [00:08<42:28,  3.92it/s]


Test set: Avg. loss: 0.4183, Accuracy: 167/200 (84%)



  0%|▎                                                                                                                                                                             | 19/10000 [00:08<41:59,  3.96it/s]


Test set: Avg. loss: 0.4086, Accuracy: 168/200 (84%)



  0%|▎                                                                                                                                                                             | 20/10000 [00:08<41:07,  4.04it/s]


Test set: Avg. loss: 0.3939, Accuracy: 166/200 (83%)



  0%|▎                                                                                                                                                                             | 21/10000 [00:08<41:23,  4.02it/s]


Test set: Avg. loss: 0.3807, Accuracy: 170/200 (85%)



  0%|▍                                                                                                                                                                             | 22/10000 [00:09<41:09,  4.04it/s]


Test set: Avg. loss: 0.3621, Accuracy: 172/200 (86%)



  0%|▍                                                                                                                                                                             | 23/10000 [00:09<40:44,  4.08it/s]


Test set: Avg. loss: 0.3606, Accuracy: 170/200 (85%)



  0%|▍                                                                                                                                                                             | 24/10000 [00:09<40:08,  4.14it/s]


Test set: Avg. loss: 0.3499, Accuracy: 172/200 (86%)



  0%|▍                                                                                                                                                                             | 25/10000 [00:09<40:19,  4.12it/s]


Test set: Avg. loss: 0.3549, Accuracy: 173/200 (86%)



  0%|▍                                                                                                                                                                             | 26/10000 [00:10<41:03,  4.05it/s]


Test set: Avg. loss: 0.3452, Accuracy: 174/200 (87%)



  0%|▍                                                                                                                                                                             | 27/10000 [00:10<41:03,  4.05it/s]


Test set: Avg. loss: 0.3378, Accuracy: 174/200 (87%)



  0%|▍                                                                                                                                                                             | 28/10000 [00:10<41:12,  4.03it/s]


Test set: Avg. loss: 0.3278, Accuracy: 175/200 (88%)



  0%|▌                                                                                                                                                                             | 29/10000 [00:10<41:18,  4.02it/s]


Test set: Avg. loss: 0.3221, Accuracy: 176/200 (88%)



  0%|▌                                                                                                                                                                             | 30/10000 [00:11<41:07,  4.04it/s]


Test set: Avg. loss: 0.3212, Accuracy: 180/200 (90%)



  0%|▌                                                                                                                                                                             | 31/10000 [00:11<40:38,  4.09it/s]


Test set: Avg. loss: 0.3100, Accuracy: 178/200 (89%)



  0%|▌                                                                                                                                                                             | 32/10000 [00:11<40:06,  4.14it/s]


Test set: Avg. loss: 0.3059, Accuracy: 179/200 (90%)



  0%|▌                                                                                                                                                                             | 33/10000 [00:11<40:37,  4.09it/s]


Test set: Avg. loss: 0.3002, Accuracy: 180/200 (90%)



  0%|▌                                                                                                                                                                             | 34/10000 [00:12<40:46,  4.07it/s]


Test set: Avg. loss: 0.3021, Accuracy: 178/200 (89%)



  0%|▌                                                                                                                                                                             | 35/10000 [00:12<40:10,  4.13it/s]


Test set: Avg. loss: 0.2946, Accuracy: 180/200 (90%)



  0%|▋                                                                                                                                                                             | 36/10000 [00:12<40:34,  4.09it/s]


Test set: Avg. loss: 0.3021, Accuracy: 180/200 (90%)



  0%|▋                                                                                                                                                                             | 37/10000 [00:12<40:24,  4.11it/s]


Test set: Avg. loss: 0.2976, Accuracy: 181/200 (90%)



  0%|▋                                                                                                                                                                             | 38/10000 [00:13<41:12,  4.03it/s]


Test set: Avg. loss: 0.2927, Accuracy: 180/200 (90%)



  0%|▋                                                                                                                                                                             | 39/10000 [00:13<41:17,  4.02it/s]


Test set: Avg. loss: 0.3034, Accuracy: 181/200 (90%)



  0%|▋                                                                                                                                                                             | 40/10000 [00:13<41:17,  4.02it/s]


Test set: Avg. loss: 0.2828, Accuracy: 180/200 (90%)



  0%|▋                                                                                                                                                                             | 41/10000 [00:13<40:54,  4.06it/s]


Test set: Avg. loss: 0.2876, Accuracy: 181/200 (90%)



  0%|▋                                                                                                                                                                             | 42/10000 [00:14<41:34,  3.99it/s]


Test set: Avg. loss: 0.2833, Accuracy: 182/200 (91%)



  0%|▋                                                                                                                                                                             | 43/10000 [00:14<40:41,  4.08it/s]


Test set: Avg. loss: 0.2927, Accuracy: 181/200 (90%)



  0%|▊                                                                                                                                                                             | 44/10000 [00:14<40:26,  4.10it/s]


Test set: Avg. loss: 0.2865, Accuracy: 181/200 (90%)



  0%|▊                                                                                                                                                                             | 45/10000 [00:14<40:39,  4.08it/s]


Test set: Avg. loss: 0.2849, Accuracy: 182/200 (91%)



  0%|▊                                                                                                                                                                             | 46/10000 [00:15<40:17,  4.12it/s]


Test set: Avg. loss: 0.2768, Accuracy: 181/200 (90%)



  0%|▊                                                                                                                                                                             | 47/10000 [00:15<40:04,  4.14it/s]


Test set: Avg. loss: 0.2775, Accuracy: 181/200 (90%)



  0%|▊                                                                                                                                                                             | 48/10000 [00:15<39:45,  4.17it/s]


Test set: Avg. loss: 0.2888, Accuracy: 182/200 (91%)



  0%|▊                                                                                                                                                                             | 49/10000 [00:15<39:39,  4.18it/s]


Test set: Avg. loss: 0.2675, Accuracy: 183/200 (92%)



  0%|▊                                                                                                                                                                             | 50/10000 [00:16<39:19,  4.22it/s]


Test set: Avg. loss: 0.2784, Accuracy: 182/200 (91%)



  1%|▉                                                                                                                                                                             | 51/10000 [00:16<40:35,  4.09it/s]


Test set: Avg. loss: 0.2798, Accuracy: 183/200 (92%)



                                                                                                                                                                                                                      


Test set: Avg. loss: 0.2758, Accuracy: 183/200 (92%)





KeyboardInterrupt: 