In [None]:
import numpy as np
import torch
from torch.autograd import Variable
import torch.nn.functional as F

from early_stopping import EarlyStopping
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as dsets
import torch.nn as nn

from Custom_Dataset import NewDataset
from lenet import Lenet
from infogan_layers import ConvTranspose2d,Linear
from InfoGAN import InfoGAN

import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
# Pick the compute device. The second GPU ("cuda:1") is still preferred, but
# only when the machine actually exposes more than one GPU; otherwise fall
# back to the first GPU, and to the CPU when CUDA is not available at all.
# (Unconditionally asking for "cuda:1" fails on single-GPU machines.)
device = torch.device(
    ("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
    if torch.cuda.is_available()
    else "cpu"
)

In [None]:
# Building blocks for the GAN.
# The layer subclasses defined in this cell override reset_parameters so that
# every weight uses Glorot (a.k.a. Xavier) uniform initialisation and every
# bias starts at zero.

# Sizes of the individual parts of the generator input.
c1_len, c2_len, c3_len = 10, 2, 2  # multinomial, Gaussian, Bernoulli codes
z_len = 114                        # noise vector length
embedding_len = 128                # output width of the discriminator's fc1_q layer

class Conv2d(nn.Conv2d):
    """2-D convolution whose weights are Glorot-uniform and whose bias is zero."""

    def reset_parameters(self):
        """Re-draw the weights from U(-limit, limit) and clear the bias.

        ``limit = sqrt(6 / ((in_channels + out_channels) * kernel_area))``
        where ``kernel_area`` is the product of the kernel dimensions.
        """
        kernel_area = np.prod(self.kernel_size)
        fan_total = (self.in_channels + self.out_channels) * kernel_area
        limit = np.sqrt(6 / fan_total)
        with torch.no_grad():
            self.weight.uniform_(-limit, limit)
            if self.bias is not None:
                self.bias.zero_()

class ConvTranspose2d(nn.ConvTranspose2d):
    """Transposed 2-D convolution with Glorot-uniform weights and zero bias."""

    def reset_parameters(self):
        """Re-draw the weights from U(-limit, limit) and clear the bias.

        ``limit = sqrt(6 / ((in_channels + out_channels) * kernel_area))``
        where ``kernel_area`` is the product of the kernel dimensions.
        """
        kernel_area = np.prod(self.kernel_size)
        fan_total = (self.in_channels + self.out_channels) * kernel_area
        limit = np.sqrt(6 / fan_total)
        with torch.no_grad():
            self.weight.uniform_(-limit, limit)
            if self.bias is not None:
                self.bias.zero_()

class Linear(nn.Linear):
    """Fully connected layer with Glorot-uniform weights and zero bias."""

    def reset_parameters(self):
        """Re-draw the weights from U(-limit, limit) and clear the bias.

        ``limit = sqrt(6 / (in_features + out_features))``.
        """
        fan_total = self.in_features + self.out_features
        limit = np.sqrt(6 / fan_total)
        with torch.no_grad():
            self.weight.uniform_(-limit, limit)
            if self.bias is not None:
                self.bias.zero_()

class Generator(nn.Module):
    """Generator network: latent vector -> single-channel image.

    The first layer expects ``z_len + c1_len + c2_len + c3_len`` input
    features (noise vector plus the three latent codes). Two fully connected
    layers expand the input to a 128 x 7 x 7 feature map; two stride-2
    transposed convolutions then double the spatial size twice
    (7 -> 14 -> 28), ending in one output channel.
    """

    def __init__(self):
        super(Generator, self).__init__()
        # NOTE: layer creation order is kept as-is; it determines parameter
        # order and the order in which the RNG is consumed at initialisation.
        self.fc1 = Linear(z_len + c1_len + c2_len + c3_len, 1024)
        self.fc2 = Linear(1024, 7 * 7 * 128)

        self.convt1 = ConvTranspose2d(128, 64, kernel_size = 4, stride = 2, padding = 1)
        self.convt2 = ConvTranspose2d(64, 1, kernel_size = 4, stride = 2, padding = 1)

        self.bn1 = nn.BatchNorm1d(1024)
        self.bn2 = nn.BatchNorm1d(7 * 7 * 128)
        self.bn3 = nn.BatchNorm2d(64)

    def forward(self, x):
        """Generate images from latent inputs.

        Args:
            x: tensor of shape ``(batch, z_len + c1_len + c2_len + c3_len)``.

        Returns:
            Tensor of shape ``(batch, 1, 28, 28)`` with sigmoid-squashed
            values, i.e. every pixel lies in (0, 1).
        """
        x = F.relu(self.bn1(self.fc1(x)))
        # Reshape the flat activations into a 128-channel 7 x 7 feature map.
        x = F.relu(self.bn2(self.fc2(x))).view(-1, 128, 7, 7)

        x = F.relu(self.bn3(self.convt1(x)))
        x = self.convt2(x)

        # torch.sigmoid replaces the deprecated F.sigmoid (same result).
        return torch.sigmoid(x)

class Discriminator(nn.Module):
    """Discriminator trunk with an extra embedding head.

    Two stride-2 convolutions reduce a 1 x 28 x 28 image to 128 x 7 x 7, a
    fully connected layer maps that to 1024 features, and two heads read from
    those features: ``fc2`` (one output unit, no activation) and ``fc1_q``
    (``embedding_len`` units, batch-normalised and passed through leaky ReLU).
    """

    def __init__(self):
        super(Discriminator, self).__init__()

        # Convolutional trunk: 28 x 28 -> 14 x 14 -> 7 x 7.
        self.conv1 = Conv2d(1, 64, kernel_size=4, stride=2, padding=1)
        self.conv2 = Conv2d(64, 128, kernel_size=4, stride=2, padding=1)

        # Shared dense layer followed by the two heads.
        self.fc1 = Linear(7 * 7 * 128, 1024)
        self.fc2 = Linear(1024, 1)
        self.fc1_q = Linear(1024, embedding_len)

        self.bn1 = nn.BatchNorm2d(128)
        self.bn2 = nn.BatchNorm1d(1024)
        self.bn_q1 = nn.BatchNorm1d(embedding_len)

    def forward(self, x):
        """Return ``(score, embedding)`` for a batch of images.

        ``score`` is the raw ``fc2`` output of shape ``(batch, 1)``
        (presumably the real/fake logit - confirm against the InfoGAN
        wrapper); ``embedding`` has shape ``(batch, embedding_len)``.
        """
        features = F.leaky_relu(self.conv1(x))
        features = F.leaky_relu(self.bn1(self.conv2(features)))
        flattened = features.view(-1, 128 * 7 * 7)

        hidden = F.leaky_relu(self.bn2(self.fc1(flattened)))
        score = self.fc2(hidden)
        embedding = F.leaky_relu(self.bn_q1(self.fc1_q(hidden)))
        return score, embedding

In [None]:
# Build both networks on the selected device (generator first, then
# discriminator, so parameter initialisation order is unchanged).
infogan_gen = Generator()
infogan_gen = infogan_gen.to(device)
infogan_dis = Discriminator()
infogan_dis = infogan_dis.to(device)

# Wrap them in the project's InfoGAN helper and restore the saved weights.
G = InfoGAN(
    infogan_gen,
    infogan_dis,
    embedding_len,
    z_len,
    c1_len,
    c2_len,
    c3_len,
    device,
)
G.load('../../utils/trained_models/infogan_100_z_114/')

In [None]:
# Sample latent inputs for c1_len * 3000 images. (The former `z_dict = {}`
# placeholder was dead code: it was overwritten on the very next line.)
z_dict = G.get_z(c1_len * 3000, sequential = True)

# Concatenate all latent parts along the feature axis to form the generator
# input. The deprecated Variable wrapper is not needed on current PyTorch.
gan_input = torch.cat([z_dict[name] for name in z_dict.keys()], dim=1).to(device)

# Generation here is pure inference: the images are only previewed and saved,
# never back-propagated through. Running under no_grad avoids keeping the
# autograd graph (and all intermediate activations) for the whole batch alive.
with torch.no_grad():
    out_gen = G.gen(gan_input)

In [None]:
# Shape of the freshly generated batch.
out_gen.size()

In [None]:
# Preview the first 50 freshly generated images, 10 per row.
# This cell previously read `image_1`, which is only defined in a later cell,
# so it raised NameError on a fresh kernel; the batch available at this point
# is `out_gen`.
grid_img = torchvision.utils.make_grid(out_gen[:50].detach().cpu(), nrow=10)
# Reorder axes from channels-first to channels-last for imshow.
plt.imshow(grid_img.permute(1, 2, 0))
plt.show()

In [None]:
# Persist the generated batch to the file 'train_images_2' in the working
# directory. The tensor is saved as-is, i.e. on whatever device it lives on.
# NOTE(review): a later cell also loads 'train_images_1', which no visible
# cell writes - presumably this cell was run once before with that file name.
torch.save(out_gen,'train_images_2')

In [None]:
# Load the two saved image batches and join them along the batch axis.
# map_location='cpu' makes the load independent of the device the tensors
# were saved from (a tensor saved from a specific GPU otherwise cannot be
# loaded on a machine without that GPU). Downstream cells move batches to
# `device` themselves.
image_1 = torch.cat(
    [
        torch.load(path, map_location='cpu')
        for path in ('train_images_1', 'train_images_2')
    ],
    dim=0,
)

In [None]:
# Shape of the combined image tensor.
image_1.size()

In [None]:
# 60000 integer labels cycling through the digits 0..9 (0, 1, ..., 9, 0, 1, ...).
labels = torch.arange(10, dtype=torch.long).repeat(6000)

In [None]:
# Persist the current `labels` tensor to the file 'labels_latest'.
torch.save(labels,'labels_latest')

In [None]:
# Reload previously saved images and labels from disk.
# NOTE(review): no visible cell writes a file named 'train_images' (only
# 'train_images_2' is saved above), and neither `train_data` nor
# `train_labels` is read by any visible cell - confirm whether this cell is
# still needed or should load the files produced above.
train_data = torch.load('train_images')
train_labels = torch.load('labels_latest')

In [None]:
# Show the first 50 images of the combined set (10 per row) together with
# their labels.
grid_img = torchvision.utils.make_grid(
    image_1[:50].detach().cpu(),
    nrow=10,
)
# Reorder axes from channels-first to channels-last for imshow.
plt.imshow(grid_img.permute(1, 2, 0))
plt.show()
labels[:50]

In [None]:
# MNIST test images are only converted to tensors; no Normalize step is used.
transform = transforms.Compose([transforms.ToTensor()])

# Generated images paired with `labels`. shuffle=False keeps the batches in
# dataset order.
train_dataset = NewDataset(image_1, labels)
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=100,
    shuffle=False,
)

# Real MNIST test split (downloaded into ./newmnist on first use).
test_dataset = dsets.MNIST(
    root='./newmnist',
    train=False,
    transform=transform,
    download=True,
)
test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=100,
    shuffle=True,
)

In [None]:
# Sanity check: smallest pixel value of the first MNIST test image
# (element 0 of the sample is the image tensor).
test_dataset[0][0].min()

In [None]:
# Fix every random number generator used by the notebook.
seed = 22
np.random.seed(seed)
torch.manual_seed(seed)
# manual_seed_all seeds every visible GPU, not just the current one, so the
# result does not depend on which CUDA device the models run on. It is a
# no-op when CUDA is unavailable.
torch.cuda.manual_seed_all(seed)
# Deterministic cuDNN kernels, and no auto-tuner: the benchmark mode may pick
# different algorithms from run to run.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [None]:
# Restore the pretrained LeNet classifier and switch it to inference mode.
full_cnn_model = Lenet()
# Map the checkpoint straight onto the device selected at the top of the
# notebook. The previous hard-coded 'cuda:0' failed on CPU-only machines and
# did not match `device` when another GPU was selected.
full_cnn_model.load_state_dict(
    torch.load('../trained_models/cnn_lenet_28_sigmoid.pt', map_location=device)
)
full_cnn_model.eval()
full_cnn_model.to(device)

In [None]:
# Label every generated image with the pretrained classifier's prediction.
# One tensor of predicted class indices is collected per batch.
labels_list = []
# Inference only: no_grad avoids building an autograd graph for each batch.
with torch.no_grad():
    for images_new, _unused_targets in train_loader:
        out = F.softmax(full_cnn_model(images_new.to(device)), dim=1)
        # Index of the most probable class for each image, moved to the CPU.
        # A dedicated name is used so the global `labels` tensor is not
        # overwritten by a single batch.
        batch_labels = torch.max(out.cpu(), 1)[1]
        labels_list.append(batch_labels)

In [None]:
# Join the per-batch predictions into one flat label vector.
# torch.cat gives the same result as stack(...).view(-1) when all batches have
# equal length, and also works when the final batch is shorter (where
# torch.stack raises a size-mismatch error).
labels = torch.cat(labels_list, dim=0)

In [None]:
# Persist the classifier-predicted labels to the file 'labels'.
torch.save(labels,'labels')

In [None]:
# Number of predicted labels.
labels.size()

In [None]:
# MNIST test images are only converted to tensors.
transform = transforms.Compose([transforms.ToTensor()])

# Generated images paired with the current `labels`, shuffled for training.
train_dataset = NewDataset(image_1, labels)
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=100,
    shuffle=True,
)

# Real MNIST test split used for evaluation.
test_dataset = dsets.MNIST(
    root='./newmnist',
    train=False,
    transform=transform,
    download=True,
)
test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=100,
    shuffle=True,
)

In [None]:
# Show the first 50 images (10 per row) next to the labels currently
# assigned to them.
grid_img = torchvision.utils.make_grid(
    image_1[:50].detach().cpu(),
    nrow=10,
)
# Reorder axes from channels-first to channels-last for imshow.
plt.imshow(grid_img.permute(1, 2, 0))
plt.show()
labels[:50]

In [None]:
# Number of labels currently held in `labels`.
labels.size()

In [None]:
# Keep the first 20000 images of the combined set.
# NOTE(review): `images` is rebound by the training loop further down before
# any visible cell reads this slice - confirm whether it is still needed.
images = image_1[:20000]

In [None]:
# Keep the first 10000 labels.
# NOTE(review): this count differs from the 20000 images sliced in the
# previous cell, and `labels2` is not read by any visible cell - confirm
# the intended subset size.
labels2 = labels[:10000]

In [None]:
# Shape of the image subset.
images.size()

In [None]:
# Start from a freshly initialised LeNet (replacing the pretrained one
# loaded earlier) and place it on the selected device.
full_cnn_model = Lenet()
full_cnn_model.to(device)

# Cross-entropy objective, optimised with the AMSGrad variant of Adam.
learning_rate = 0.001
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(
    full_cnn_model.parameters(),
    lr=learning_rate,
    amsgrad=True,
)

In [None]:
# Train the classifier for 60 epochs on the generated images and evaluate it
# on the real MNIST test split after every epoch.
for epoch in range(60):
    # ---- training pass ----
    full_cnn_model.train()
    train_loss_list = []
    for images, labels in train_loader:
        # Plain tensors are used directly; the Variable wrapper is deprecated.
        images = images.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        outputs = full_cnn_model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        train_loss_list.append(loss.item())

    # ---- evaluation pass ----
    full_cnn_model.eval()
    test_loss_list = []
    correct = 0
    total = 0
    # No gradients are needed for evaluation; no_grad avoids building an
    # autograd graph for every test batch.
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = full_cnn_model(images)
            loss = criterion(outputs, labels)
            _, predicted = torch.max(outputs, 1)
            # Count over samples rather than averaging per-batch percentages,
            # so a smaller final batch cannot skew the accuracy.
            correct += (predicted == labels).sum().item()
            total += labels.size(0)
            test_loss_list.append(loss.item())

    final_accuracy = 100.0 * correct / total
    # Losses are reported as the mean of the per-batch losses.
    traininig_loss = sum(train_loss_list)/len(train_loss_list)
    testing_loss = sum(test_loss_list)/len(test_loss_list)
    print('Epoch: {}. TrainLoss: {}. TestLoss: {}. Accuracy: {}'.format(epoch, traininig_loss,testing_loss, final_accuracy))