In [None]:
import numpy as np
import torch
from torch.autograd import Variable
import torch.nn.functional as F
from early_stopping import EarlyStopping
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as dsets
import torch.nn as nn
from Custom_Dataset import NewDataset

from lenet import Lenet
from dcgan import Generator as GAN
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
# Prefer the first GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [None]:
# Load the pretrained DCGAN generator and LeNet classifier checkpoints and put both
# in eval mode on `device`.
# map_location=device (instead of a hard-coded 'cuda:0') lets these checkpoints load
# on CPU-only machines as well, consistent with how `device` is chosen.
G = GAN()
G.load_state_dict(torch.load("../trained_models/trained_dcgan.pth", map_location=device))
G.eval()
G.to(device)

full_cnn_model = Lenet()
full_cnn_model.load_state_dict(torch.load('../trained_models/cnn_lenet_28_tanh.pt', map_location=device))
full_cnn_model.eval()
full_cnn_model.to(device)

In [None]:
# GAN-generated training images and their labels.
# map_location='cpu' keeps the tensors on the host so they load even if they were
# saved from a GPU and the current machine has none; the training loop moves each
# batch to `device` explicitly.
images = torch.load('../data_gan/dcgan_data/dcgan_images', map_location='cpu')
labels = torch.load('../data_gan/dcgan_data/dcgan_labels', map_location='cpu')

In [None]:
# Sanity-check the loaded tensors: one shape per line, images first, then labels.
print(images.shape, labels.shape, sep='\n')

In [None]:
# Preview the first 50 generated samples as a grid of 10 per row (5 rows), then
# print the matching labels so they can be checked against the pictures.
# make_grid returns (C, H, W); matplotlib expects (H, W, C), hence the permute.
# NOTE(review): if the generator emits values in [-1, 1], imshow clips the negative
# part; rescaling with `grid_img * 0.5 + 0.5` would show it faithfully — TODO confirm range.
first_samples = images[:50].detach().cpu()
grid_img = torchvision.utils.make_grid(first_samples, nrow=10)
fig, ax = plt.subplots()
ax.imshow(grid_img.permute(1, 2, 0))
plt.show()
print(labels[:50])

In [None]:
BATCH_SIZE = 100

# MNIST images are single-channel, so Normalize takes one mean/std value each.
# A 3-tuple here makes current torchvision raise a broadcast RuntimeError as soon as
# the first test batch is loaded. (x - 0.5) / 0.5 maps ToTensor's [0, 1] to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])

# Training data: the GAN-generated tensors loaded earlier; no transform is passed,
# so they are presumably already in the same [-1, 1] range — TODO confirm.
train_dataset = NewDataset(images, labels)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)

# Evaluation data: the real MNIST test split (downloaded on first use).
test_dataset = dsets.MNIST(root='./newmnist', train=False, transform=transform, download=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=BATCH_SIZE,
                                          shuffle=False)

In [None]:
# Seed every RNG the notebook relies on (numpy, torch CPU, torch CUDA) and force
# deterministic cuDNN kernels, so weight init and shuffling repeat across runs.
seed = 22
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True

In [None]:
# Loss and optimiser settings for training a new LeNet on the GAN-generated data.
# NOTE: this rebinds `full_cnn_model` to a freshly constructed Lenet, replacing the
# pretrained checkpoint that was loaded into that name earlier in the notebook.
criterion = nn.CrossEntropyLoss()
learning_rate = 0.001

full_cnn_model = Lenet().to(device)
optimizer = torch.optim.Adam(full_cnn_model.parameters(), lr=learning_rate)

In [None]:
NUM_EPOCHS = 60


def train_one_epoch(model, loader):
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    Relies on the notebook-level ``criterion``, ``optimizer`` and ``device``.
    Batch variables are local so the global ``images`` / ``labels`` tensors are
    not overwritten (the original loop clobbered them, breaking re-runs of earlier cells).
    """
    model.train()
    batch_losses = []
    for batch_images, batch_labels in loader:
        batch_images = batch_images.to(device)
        batch_labels = batch_labels.to(device)
        optimizer.zero_grad()
        outputs = model(batch_images)
        loss = criterion(outputs, batch_labels)
        loss.backward()
        optimizer.step()
        batch_losses.append(loss.item())
    return sum(batch_losses) / len(batch_losses)


def evaluate(model, loader):
    """Return ``(mean batch loss, accuracy in percent)`` of ``model`` on ``loader``.

    Runs under ``torch.no_grad()`` so no autograd graph is built during evaluation.
    Accuracy is total correct / total samples, which stays exact even when the last
    batch is smaller than the others.
    """
    model.eval()
    batch_losses = []
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_images, batch_labels in loader:
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)
            outputs = model(batch_images)
            batch_losses.append(criterion(outputs, batch_labels).item())
            correct += (outputs.argmax(dim=1) == batch_labels).sum().item()
            total += batch_labels.size(0)
    return sum(batch_losses) / len(batch_losses), 100.0 * correct / total


for epoch in range(NUM_EPOCHS):
    training_loss = train_one_epoch(full_cnn_model, train_loader)
    testing_loss, final_accuracy = evaluate(full_cnn_model, test_loader)
    print('Epoch: {}. TrainLoss: {}. TestLoss: {}. Accuracy: {}'.format(epoch, training_loss, testing_loss, final_accuracy))