In [34]:
import os
import sys
from os.path import join as oj

import h5py
import matplotlib.gridspec as grd
import matplotlib.pyplot as plt
# %matplotlib inline
import numpy as np
import pandas as pd
import scipy.io as sio
import seaborn as sns
import torch
from mpl_toolkits.axes_grid1 import make_axes_locatable
from sklearn import decomposition
from sklearn import linear_model
from sklearn import neural_network
from sklearn.model_selection import cross_val_score
from torch import nn, optim
from torch.nn import functional as F
from tqdm import tqdm

sns.set(style="white", rc={"axes.facecolor": (0, 0, 0, 0)})
sys.path.insert(1, oj(sys.path[0], '..'))  # insert parent path


%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [35]:
from stringer_dset import StringerDset
# Use the GPU when one is available, otherwise fall back to the CPU so the
# notebook still runs end-to-end on a CPU-only kernel.
num_gpu = 1 if torch.cuda.is_available() else 0
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [36]:
# get data
sdset = StringerDset()


# get gan
# Machine-specific default location; set the GAN_DIR environment variable to override it.
gan_dir = os.environ.get('GAN_DIR', '/accounts/projects/vision/chandan/gan/cifar100_dcgan_grayscale')
if gan_dir not in sys.path:  # keep re-runs of this cell from growing sys.path
    sys.path.insert(1, gan_dir)

# load the models (explicit names instead of `import *`, so it is clear what dcgan provides)
from dcgan import Discriminator_rect, Generator_rect

D = Discriminator_rect(ngpu=num_gpu).to(device)
G = Generator_rect(ngpu=num_gpu).to(device)

# load weights; map_location remaps the checkpoint tensors onto `device`, so a
# checkpoint saved on a GPU can also be loaded when `device` is 'cpu'
D.load_state_dict(torch.load(oj(gan_dir, 'weights_rect/netD_epoch_299.pth'), map_location=device))
G.load_state_dict(torch.load(oj(gan_dir, 'weights_rect/netG_epoch_299.pth'), map_location=device))

# fit latent space

In [49]:
# First 1000 samples of the dataset: stimulus images and the matching responses.
ims, resps = sdset[:1000]

# Per-pixel z-scoring across samples; the epsilon keeps constant pixels from
# dividing by zero.
# NOTE(review): targets become unbounded z-scores; if G ends in a bounded
# activation (e.g. tanh, common in DCGANs) it cannot reach them exactly -- confirm.
means, stds = np.mean(ims, axis=0), np.std(ims, axis=0) + 1e-8
ims_norm = (ims - means) / stds

# Default-dtype tensors on the target device. The whole right-hand side is
# evaluated before either name is rebound, so the second term still reads the
# raw `resps` returned by the dataset.
ims, resps = torch.Tensor(ims_norm).to(device), torch.Tensor(resps).to(device)

In [None]:
class GenNet(nn.Module):
    """Decode neural responses into images through a fixed generator ``G``.

    A linear layer maps each response vector to a latent code. The code is
    reshaped to ``(batch, latent_size, 1, 1)`` and passed to ``G``.

    Args:
        G: generator module. It is switched to eval mode in place (``G.eval()``
            mutates and returns the same object).
        n_neurons: input width (number of recorded neurons).
        latent_size: output width of the linear layer, i.e. G's latent size.
        init_scale: factor applied to fc1's default-initialized weight and bias,
            so that the initial latent codes start close to zero.
        freeze_generator: if True, turn off ``requires_grad`` on G's parameters.
            Only fc1 is meant to be trained (the notebook's optimizer covers
            ``model.fc1.parameters()`` only). Without this, every backward pass
            accumulates gradients into G that nothing ever zeroes.
    """

    def __init__(self, G, n_neurons=11449, latent_size=100, init_scale=1e-3,
                 freeze_generator=True):
        super(GenNet, self).__init__()
        self.fc1 = nn.Linear(n_neurons, latent_size)  # num_neurons to latent space
        # Shrink the default init in place without recording it in autograd.
        with torch.no_grad():
            self.fc1.weight.mul_(init_scale)
            self.fc1.bias.mul_(init_scale)
        self.G = G.eval()
        if freeze_generator:
            self.G.requires_grad_(False)

    def forward(self, x):
        """Map responses ``x`` (batch, n_neurons) to G's output for each sample."""
        z = self.fc1(x)
        # G is called with the latent code laid out as (batch, latent, 1, 1).
        z = z.reshape(z.shape[0], z.shape[1], 1, 1)
        return self.G(z)
    
class LinNet(nn.Module):
    """Linear baseline that maps responses directly to pixel space.

    Args:
        n_neurons: input width (number of recorded neurons).
        height, width: shape of each output image; the linear layer produces
            ``height * width`` values per sample, which are reshaped to
            ``(batch, height, width)``.
    """

    def __init__(self, n_neurons=11449, height=34, width=45):
        super(LinNet, self).__init__()
        self.height, self.width = height, width
        self.fc1 = nn.Linear(n_neurons, height * width)  # num_neurons to pixels

    def forward(self, x):
        """Map responses ``x`` (batch, n_neurons) to images (batch, height, width)."""
        x = self.fc1(x)
        return x.reshape(x.shape[0], self.height, self.width)

    
def viz_ims(ims_pred, ims, num_ims=5, im_shape=(34, 45)):
    """Show predicted images (top row) above their targets (bottom row).

    Args:
        ims_pred: batch of predicted images (torch tensors). Each item is moved
            to the CPU, detached and reshaped to ``im_shape``.
        ims: batch of target images, handled the same way.
        num_ims: number of columns to draw; clipped to the smaller batch size.
        im_shape: (height, width) each image is reshaped to for display.
    """
    num_ims = min(num_ims, len(ims_pred), len(ims))
    if num_ims <= 0:
        return
    fig, axes = plt.subplots(2, num_ims, figsize=(num_ims * 1.2, 2), dpi=100, squeeze=False)
    for row, batch in enumerate((ims_pred, ims)):
        for col in range(num_ims):
            ax = axes[row, col]
            ax.imshow(batch[col].cpu().detach().numpy().reshape(im_shape),
                      interpolation='bilinear', cmap='gray')
            ax.axis('off')
    # Lay out once, after every panel exists.
    fig.tight_layout()
    fig.subplots_adjust(hspace=0, wspace=0, left=0)
    plt.show()
    # This function is called repeatedly from the training loop; release each figure.
    plt.close(fig)
    
print('initializing...')
model = GenNet(G).to(device)
its = 10000
# Summed (not averaged) squared error, so gradient magnitude grows with the
# number of samples and pixels -- hence the very small learning rate below.
loss_fn = torch.nn.MSELoss(reduction='sum')
learning_rate = 1e-12 # 1e-12 works
# Only the response->latent layer is optimized.
optimizer = torch.optim.SGD(model.fc1.parameters(),
                            lr=learning_rate)
# `verbose` is omitted: its default was already False and it is deprecated in recent PyTorch.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.1, patience=100,
    threshold=0.0001, threshold_mode='rel', cooldown=0, min_lr=0, eps=1e-08)
num_ims = 8

print('training...')
for it in range(its):
    # Put the prediction on the target's exact shape. If G emits an extra
    # channel axis (e.g. (N, 1, 34, 45) vs targets (N, 34, 45) -- unverified
    # here), MSELoss would silently broadcast to (N, N, 34, 45) and compare
    # every prediction with every target. reshape() raises if the counts differ.
    ims_pred = model(resps).reshape(ims.shape)
    loss = loss_fn(ims_pred, ims)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    loss_val = loss.item()
    if it % 10 == 0:
        # summed loss divided by the number of target elements
        print(it, '\tloss', loss_val / ims.numel(), 'lr', optimizer.param_groups[0]['lr'])
    # All-zero check on magnitudes; summing signed entries could cancel to 0.
    if model.fc1.weight.grad.abs().max().item() == 0:
        print('zero grad!')
        print('w', torch.sum(model.fc1.weight))
    if it % 100 == 0:
        viz_ims(ims_pred, ims, num_ims)
    scheduler.step(loss_val)  # plain float metric

# generate random ims

In [None]:
def generate_random_ims(batch_size=25, latent_size=100, im_shape=(34, 45)):
    """Decode standard-normal latent codes with the global generator ``G`` and plot a grid.

    Uses the module-level ``G`` and ``device``. G's train/eval mode is left as
    it is; if G has mode-dependent layers (e.g. BatchNorm), the output depends on it.

    Args:
        batch_size: number of images to sample; the grid is roughly square.
        latent_size: size of each latent vector, fed to G as (batch, latent, 1, 1).
        im_shape: (height, width) each generated image is reshaped to for display.
    """
    noise = torch.randn(batch_size, latent_size, 1, 1).to(device)
    print(noise.shape)
    # Inference only: no autograd graph is needed through G.
    with torch.no_grad():
        fake_images = G(noise)

    fake_images_np = fake_images.cpu().detach().numpy()
    print(fake_images_np.shape)
    fake_images_np = fake_images_np.reshape(fake_images_np.shape[0], *im_shape)

    n_cols = int(np.ceil(np.sqrt(batch_size)))
    n_rows = -(-batch_size // n_cols)  # ceiling division
    # 0.9 x 0.68 inches per cell reproduces the original 4.5 x 3.4 figure for a 5x5 grid.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(0.9 * n_cols, 0.68 * n_rows),
                             dpi=100, squeeze=False)
    for ax in axes.flat:
        ax.axis('off')
    for ax, im in zip(axes.flat, fake_images_np):
        ax.imshow(im, interpolation='bilinear', cmap='gray')
    fig.tight_layout()
    fig.subplots_adjust(hspace=0, wspace=0, left=0)
    plt.show()
    plt.close(fig)


generate_random_ims()