In [1]:
import os, time
import matplotlib.pyplot as plt
import itertools
import pickle
import imageio
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from dcgan_discriminator import Discriminator
from dcgan_discriminator_multipleOut import Discriminator_MO
from dcgan_classifier import Classifier

In [2]:
# Training hyper-parameters, read as notebook globals by the cells below.
batch_size = 128    # mini-batch size handed to both DataLoaders
lr = 1e-4           # learning rate picked up by the Adam optimizer in train()
train_epoch = 100   # epoch budget (train() also has its own train_epoch argument)

# Data loading: size handed to transforms.Resize in the dataset cells.
img_size = 64

Run either the cell below (MNIST) or the one after it (CIFAR-10).

In [3]:
# MNIST
from SparseMNIST import SparseMNIST

# MNIST images are single-channel, so Normalize gets exactly one mean/std
# entry. A 3-entry tuple only worked on old torchvision releases, where the
# surplus entries were silently ignored; current releases reject the
# channel-count mismatch.
# NOTE(review): assumes SparseMNIST yields 1-channel images like
# torchvision's MNIST -- confirm against SparseMNIST.
transform = transforms.Compose([
    transforms.Resize(img_size),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,))
])

# Training data comes from the project's SparseMNIST wrapper (second
# positional argument 0.01, showAll=False); batches are reshuffled each epoch.
train_loader = torch.utils.data.DataLoader(
    SparseMNIST('data', 0.01, train=True, download=True, transform=transform, showAll=False),
    batch_size=batch_size, shuffle=True, num_workers=4)

# Validation uses the standard torchvision MNIST test split, in fixed order.
val_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data', train=False, download=True, transform=transform),
    batch_size=batch_size, shuffle=False, num_workers=4)


In [3]:
# CIFAR 10
from SparseCIFAR import SparseCIFAR

# Resize, convert to a tensor, then normalise each of the three colour
# channels with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.Resize(img_size),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# Training data comes from the project's SparseCIFAR wrapper; its second
# positional argument is 4000/50000. Batches are reshuffled every epoch.
train_loader = torch.utils.data.DataLoader(
    SparseCIFAR(
        'data',
        4000 / 50000,
        train=True,
        download=True,
        transform=transform,
        showAll=False,
    ),
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
)

# Validation uses the standard torchvision CIFAR-10 test split, in fixed order.
val_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10('data', train=False, download=True, transform=transform),
    batch_size=batch_size,
    shuffle=False,
    num_workers=4,
)

#print(len(datasets.CIFAR10('data', train=True, download=True, transform=transform)))

Files already downloaded and verified
Files already downloaded and verified


Choose Discriminator

In [4]:
def mnist_classifier(model_name, conv_resolution=128, channels=1, 
                     mbl_size=0, train_classes=False, 
                     freezePretrained = True):
    """Build a trainable Classifier on top of a pretrained MNIST discriminator.

    The pretrained model is loaded with ``torch.load(model_name)``, its weights
    are copied (via a temporary ``tmp.pt`` file) into a ``Discriminator_MO``
    built with the same settings, and a freshly initialised ``Classifier`` is
    stacked on top of it.

    Args:
        model_name: Path of the saved discriminator, passed to ``torch.load``.
        conv_resolution: Forwarded to ``Discriminator`` / ``Discriminator_MO``
            and to ``Classifier`` as ``d``.
        channels: Forwarded to ``Discriminator`` / ``Discriminator_MO``.
        mbl_size: Forwarded to ``Discriminator`` / ``Discriminator_MO``.
        train_classes: Forwarded to ``Discriminator`` / ``Discriminator_MO``.
        freezePretrained: Forwarded to ``Discriminator_MO`` as ``freeze``.

    Returns:
        The ``Classifier`` instance, moved to the GPU and set to training mode.

    Side effects:
        Writes ``tmp.pt`` into the current working directory and requires CUDA.
    """
    # MNIST - Load Pretrained
    # NOTE(review): this instance is replaced by the torch.load result on the
    # next line; it is kept so construction-time behaviour stays unchanged.
    D = Discriminator(conv_resolution=conv_resolution, channels=channels, 
                      mbl_size=mbl_size, train_classes=train_classes)
    # torch.load unpickles the file -- only point this at trusted checkpoints.
    D = torch.load(model_name)
    #D.weight_init(mean=0.0, std=0.02) # From scratch

    # MNIST - Save and build one with Multiple outputs
    torch.save(D.state_dict(),'tmp.pt')
    D_MO = Discriminator_MO(conv_resolution=conv_resolution, channels=channels, 
                            mbl_size=mbl_size, train_classes=train_classes, 
                            freeze=freezePretrained)
    D_MO.load_state_dict(torch.load('tmp.pt'))
    D_MO.eval()
    D_MO.cuda()

    # MNIST - Build classifier ontop
    C = Classifier(D_MO,d=conv_resolution)
    C.weight_init(mean=0.0, std=0.02)
    C.cuda()
    # NOTE(review): if Classifier registers D_MO as a sub-module, C.train()
    # also switches D_MO back out of eval mode -- confirm this is intended.
    C.train()

    # Hand the classifier back to the caller (previously the function fell
    # off the end and returned None, unlike cifar_classifier).
    return C

In [4]:
def cifar_classifier(model_name, conv_resolution=100, channels=3,
                     mbl_size=0, train_classes=False,
                     freezePretrained=True):
    """Stack a freshly initialised Classifier on a pretrained CIFAR discriminator.

    Steps: load the saved discriminator from ``model_name``, dump its
    state dict to ``tmp.pt``, load that into a ``Discriminator_MO`` built with
    the same settings, then wrap it in a ``Classifier``.

    Args:
        model_name: Path handed to ``torch.load`` for the saved discriminator.
        conv_resolution: Passed to both discriminator constructors and to
            ``Classifier`` as ``d``.
        channels, mbl_size, train_classes: Passed to both discriminator
            constructors.
        freezePretrained: Passed to ``Discriminator_MO`` as ``freeze``.

    Returns:
        The ``Classifier``, on the GPU and in training mode.
    """
    backbone_kwargs = dict(
        conv_resolution=conv_resolution,
        channels=channels,
        mbl_size=mbl_size,
        train_classes=train_classes,
    )

    # The constructed instance is superseded by the loaded checkpoint.
    pretrained = Discriminator(**backbone_kwargs)
    pretrained = torch.load(model_name)

    # Round-trip the weights through disk into the multiple-output variant.
    torch.save(pretrained.state_dict(), 'tmp.pt')
    multi_out = Discriminator_MO(freeze=freezePretrained, **backbone_kwargs)
    multi_out.load_state_dict(torch.load('tmp.pt'))
    multi_out.eval()
    multi_out.cuda()

    # New classification head on top of the multiple-output discriminator.
    # NOTE(review): if Classifier registers multi_out as a sub-module, the
    # train() call below also flips it back out of eval mode -- confirm.
    classifier = Classifier(multi_out, d=conv_resolution)
    classifier.weight_init(mean=0.0, std=0.02)
    classifier.cuda()
    classifier.train()

    return classifier

In [5]:
def _loss_to_float(loss):
    """Return a single-element loss as a plain Python number.

    Uses ``.item()`` when the object provides it and falls back to
    ``.data[0]`` otherwise, so the same code runs on PyTorch releases with
    and without zero-dimensional tensors.
    """
    if hasattr(loss, 'item'):
        return loss.item()
    return loss.data[0]


# Trains a classifier
def train(C, train_loader, val_loader, train_epoch = 100):
    """Train ``C`` with cross-entropy and validate it every few epochs.

    Args:
        C: Model to optimise; it is called on a batch and must provide
            ``parameters()``, ``zero_grad()``, ``eval()`` and ``train()``.
        train_loader: Iterable of ``(inputs, labels)`` training batches.
        val_loader: Iterable of ``(inputs, targets)`` validation batches.
        train_epoch: Number of epochs to run.

    Returns:
        None. Progress and the best validation accuracy are printed.

    Notes:
        * Reads the notebook-level ``lr`` for the Adam learning rate.
        * Training batches are moved to the GPU unconditionally, so CUDA is
          required.
        * Validation runs on epoch 0, every ``eval_nth``-th epoch and the last
          epoch; the accuracy printed on other epochs is the most recent one.
        * Accuracy is ``correct / total`` over the whole validation set, so a
          short final batch is not over-weighted.
    """
    CE_loss = nn.CrossEntropyLoss(ignore_index=-100)
    C_optimizer = optim.Adam(C.parameters(), lr=lr, betas=(0.5, 0.999))

    # NOTE(review): the history is collected but never returned or saved.
    train_hist = {}
    train_hist['C_losses'] = []
    train_hist['C_acc'] = []
    train_hist['per_epoch_ptimes'] = []
    train_hist['total_ptime'] = []

    loss_info = False #Print Train Loss
    eval_nth = 5
    best_acc = 0
    # Defined up front so the per-epoch print never hits an unbound name.
    val_acc = float('nan')

    print('training start!')
    start_time = time.time()

    for epoch in range(train_epoch):
        epoch_start_time = time.time()

        C_losses =[]

        for x_, y_real_label_ in train_loader:
            # train classifier
            C.zero_grad()
            x_ = Variable(x_.cuda())
            y_real_label_ = Variable(y_real_label_.cuda())

            # NOTE(review): squeeze() would also drop the batch axis for a
            # batch of size 1 -- confirm the loader never yields one.
            C_result = C(x_).squeeze()
            C_real_loss = CE_loss(C_result, y_real_label_)

            C_real_loss.backward()
            C_optimizer.step()

            # Extract the scalar once and reuse it for history and logging.
            loss_value = _loss_to_float(C_real_loss)
            train_hist['C_losses'].append(loss_value)
            C_losses.append(loss_value)
            if(loss_info):
                print(loss_value)

        C.eval()

        if epoch % eval_nth == 0 or epoch == train_epoch-1:
            print('Validate ...')
            correct = 0
            total = 0
            for inputs, targets in val_loader:
                inputs, targets = Variable(inputs), Variable(targets)
                if torch.cuda.is_available():
                    inputs, targets = inputs.cuda(), targets.cuda()

                outputs = C(inputs).squeeze()

                _, preds = torch.max(outputs, 1)
                hits = (preds == targets).data.cpu().numpy()
                # Count samples rather than averaging per-batch accuracies.
                correct += int(hits.sum())
                total += hits.size

            # An empty validation loader yields NaN instead of a division error.
            val_acc = correct / total if total else float('nan')
            train_hist['C_acc'].append(val_acc)

            if val_acc > best_acc:
                best_acc = val_acc

        C.train()

        epoch_end_time = time.time()
        per_epoch_ptime = epoch_end_time - epoch_start_time

        print('[%d/%d] - ptime: %.2f, loss_c: %.3f, acc_c: %.3f' % ((epoch + 1), 
            train_epoch, per_epoch_ptime, 
            torch.mean(torch.FloatTensor(C_losses)), val_acc))

        train_hist['per_epoch_ptimes'].append(per_epoch_ptime)

    end_time = time.time()
    total_ptime = end_time - start_time
    train_hist['total_ptime'].append(total_ptime)

    print("Avg per epoch ptime: %.2f, total %d epochs ptime: %.2f" % (torch.mean(torch.FloatTensor(train_hist['per_epoch_ptimes'])), train_epoch, total_ptime))
    print("Best Accuaracy: {}".format(best_acc))


In [None]:
# Build a classifier from the cifar_fm_models/d_0.model checkpoint and train it.
C = cifar_classifier(
    'cifar_fm_models/d_0.model',
    conv_resolution=100,
    channels=3,
    mbl_size=0,
    train_classes=False,
    freezePretrained=True,
)
train(C, train_loader, val_loader)

training start!


  c = F.softmax(self.conv5_c(c))


Validate ...
[1/100] - ptime: 19.58, loss_c: 2.167, acc_c: 0.354
[2/100] - ptime: 8.89, loss_c: 2.067, acc_c: 0.354
[3/100] - ptime: 8.91, loss_c: 2.026, acc_c: 0.354
[4/100] - ptime: 8.93, loss_c: 1.994, acc_c: 0.354
[5/100] - ptime: 8.93, loss_c: 1.961, acc_c: 0.354
Validate ...
[6/100] - ptime: 15.29, loss_c: 1.942, acc_c: 0.455
[7/100] - ptime: 8.98, loss_c: 1.915, acc_c: 0.455
[8/100] - ptime: 8.98, loss_c: 1.892, acc_c: 0.455
[9/100] - ptime: 8.96, loss_c: 1.866, acc_c: 0.455
[10/100] - ptime: 8.98, loss_c: 1.847, acc_c: 0.455
Validate ...
[11/100] - ptime: 15.35, loss_c: 1.821, acc_c: 0.489
[12/100] - ptime: 8.97, loss_c: 1.803, acc_c: 0.489
[13/100] - ptime: 8.97, loss_c: 1.783, acc_c: 0.489
[14/100] - ptime: 8.97, loss_c: 1.761, acc_c: 0.489
[15/100] - ptime: 8.98, loss_c: 1.745, acc_c: 0.489
Validate ...
[16/100] - ptime: 15.34, loss_c: 1.737, acc_c: 0.504
[17/100] - ptime: 8.96, loss_c: 1.714, acc_c: 0.504
[18/100] - ptime: 8.97, loss_c: 1.705, acc_c: 0.504
[19/100] - ptime:

In [None]:
# Build a classifier from the cifar_fm_models/d_1.model checkpoint and train it.
C = cifar_classifier(
    'cifar_fm_models/d_1.model',
    conv_resolution=100,
    channels=3,
    mbl_size=0,
    train_classes=False,
    freezePretrained=True,
)
train(C, train_loader, val_loader)

In [None]:
# Build a classifier from the cifar_fm_models/d_4.model checkpoint and train it.
C = cifar_classifier(
    'cifar_fm_models/d_4.model',
    conv_resolution=100,
    channels=3,
    mbl_size=0,
    train_classes=False,
    freezePretrained=True,
)
train(C, train_loader, val_loader)

In [None]:
# Build a classifier from the cifar_fm_models/d_9.model checkpoint and train it.
C = cifar_classifier(
    'cifar_fm_models/d_9.model',
    conv_resolution=100,
    channels=3,
    mbl_size=0,
    train_classes=False,
    freezePretrained=True,
)
train(C, train_loader, val_loader)

In [None]:
# Build a classifier from the cifar_fm_models/d_19.model checkpoint and train it.
C = cifar_classifier(
    'cifar_fm_models/d_19.model',
    conv_resolution=100,
    channels=3,
    mbl_size=0,
    train_classes=False,
    freezePretrained=True,
)
train(C, train_loader, val_loader)

In [None]:
# Build a classifier from the cifar_fm_mbd_models/d_0.model checkpoint
# (mbl_size=8) and train it.
C = cifar_classifier(
    'cifar_fm_mbd_models/d_0.model',
    conv_resolution=100,
    channels=3,
    mbl_size=8,
    train_classes=False,
    freezePretrained=True,
)
train(C, train_loader, val_loader)

In [None]:
# Build a classifier from the cifar_fm_mbd_models/d_1.model checkpoint
# (mbl_size=8) and train it.
C = cifar_classifier(
    'cifar_fm_mbd_models/d_1.model',
    conv_resolution=100,
    channels=3,
    mbl_size=8,
    train_classes=False,
    freezePretrained=True,
)
train(C, train_loader, val_loader)

In [None]:
# Build a classifier from the cifar_fm_mbd_models/d_2.model checkpoint
# (mbl_size=8) and train it.
C = cifar_classifier(
    'cifar_fm_mbd_models/d_2.model',
    conv_resolution=100,
    channels=3,
    mbl_size=8,
    train_classes=False,
    freezePretrained=True,
)
train(C, train_loader, val_loader)

In [None]:
# Build a classifier from the cifar_fm_mbd_models/d_4.model checkpoint
# (mbl_size=8) and train it.
C = cifar_classifier(
    'cifar_fm_mbd_models/d_4.model',
    conv_resolution=100,
    channels=3,
    mbl_size=8,
    train_classes=False,
    freezePretrained=True,
)
train(C, train_loader, val_loader)

In [None]:
# Build a classifier from the cifar_fm_mbd_models/d_9.model checkpoint
# (mbl_size=8) and train it.
C = cifar_classifier(
    'cifar_fm_mbd_models/d_9.model',
    conv_resolution=100,
    channels=3,
    mbl_size=8,
    train_classes=False,
    freezePretrained=True,
)
train(C, train_loader, val_loader)

In [None]:
# Build a classifier from the cifar_fm_mbd_models/d_19.model checkpoint
# (mbl_size=8) and train it.
C = cifar_classifier(
    'cifar_fm_mbd_models/d_19.model',
    conv_resolution=100,
    channels=3,
    mbl_size=8,
    train_classes=False,
    freezePretrained=True,
)
train(C, train_loader, val_loader)