In [2]:
import sys
import scipy.io as sio
import h5py
import numpy as np
from os.path import join as oj
import matplotlib.pyplot as plt
# %matplotlib inline
from mpl_toolkits.axes_grid1 import make_axes_locatable
import seaborn as sns
sns.set(style="white", rc={"axes.facecolor": (0, 0, 0, 0)})
import pandas as pd
import torch
sys.path.insert(1, oj(sys.path[0], '..'))  # insert parent path
from tqdm import tqdm
from sklearn import linear_model
from sklearn.model_selection import cross_val_score
from sklearn import decomposition
import matplotlib.gridspec as grd
from sklearn import neural_network
from torch.nn import functional as F
from torch import nn, optim
import torchvision.utils as vutils
import torchvision.models as models


%load_ext autoreload
%autoreload 2

In [3]:
from stringer_dset import StringerDset
num_gpu = 1 if torch.cuda.is_available() else 0
device = 'cuda' # 'cuda' if torch.cuda.is_available() else 'cpu'

In [4]:
# get data
sdset = StringerDset()


# get gan
gan_dir = '/accounts/projects/vision/chandan/gan/cifar100_dcgan_grayscale'
sys.path.insert(1, gan_dir)

# load the models
from dcgan import *

D = Discriminator_rect(ngpu=num_gpu).to(device)
G = Generator_rect(ngpu=num_gpu).to(device)

# load weights
D.load_state_dict(torch.load(oj(gan_dir, 'weights_rect/netD_epoch_299.pth')))
G.load_state_dict(torch.load(oj(gan_dir, 'weights_rect/netG_epoch_299.pth')))

# fit latent space

In [5]:
(ims, resps) = sdset[:100]
means = np.mean(ims, axis=0)
stds = np.std(ims, axis=0) + 1e-8 # stds basically just magnifies stuff in the middle, no need to multiply it back
ims_norm = (ims - means) / stds
ims = torch.Tensor(ims_norm).to(device)
# resps = (resps - np.mean(resps, axis=0)) / (np.std(resps, axis=0) + 1e-8)
resps = torch.Tensor(resps).to(device)


(ims_val, resps_val) = sdset[-100:]
means_val = np.mean(ims_val, axis=0)
stds_val = np.std(ims_val, axis=0) + 1e-8 # stds basically just magnifies stuff in the middle, no need to multiply it back
ims_norm_val = (ims_val - means_val) / stds_val
# resps_val = (resps_val - np.mean(resps_val, axis=0)) / (np.std(resps_val, axis=0) + 1e-8)
ims_val = torch.Tensor(ims_norm_val).to(device)
resps_val = torch.Tensor(resps_val).to(device)

In [7]:
vgg = models.vgg19(pretrained=True).to(device)
reg_model = list(vgg.features.modules())[1]

# vgg 
def lay1_sim(reg_model, im1, im2):
    # grayscale to 3 channel
    
    im1 = im1.expand(-1, 3, -1, -1)
    im2 = im2.expand(-1, 3, -1, -1)
    
    feat1 = reg_model(im1).flatten()
    feat2 = reg_model(im2).flatten()
    feat1 = feat1 / feat1.norm()
    feat2 = feat2 / feat2.norm()
    return torch.dot(feat1, feat2)

In [10]:
class GenNet(nn.Module):
    def __init__(self, G):
        super(GenNet, self).__init__()
        self.fc1 = nn.Linear(11449, 100) # num_neurons to latent space
        self.fc1.weight.data = 1e-3 * self.fc1.weight.data
        self.fc1.bias.data = 1e-3 * self.fc1.bias.data
        self.G = G.eval()

    def forward(self, x):
        x = self.fc1(x)
#         print('latent', x[0, :20])
        x = x.reshape(x.shape[0], x.shape[1], 1, 1)
        im = self.G(x)
        return im
    
class LinNet(nn.Module):
    def __init__(self):
        super(LinNet, self).__init__()
        self.fc1 = nn.Linear(11449, 34 * 45) # num_neurons to latent space

    def forward(self, x):
        x = self.fc1(x)
        x = x.reshape(x.shape[0], 34, 45)
        return x

    
def viz_ims(ims_pred, ims, num_ims=5):    
    plt.figure(figsize=(num_ims * 1.2, 2), dpi=100)
    R, C = 2, num_ims
    for i in range(num_ims):
        plt.subplot(R, C, i + 1)
        plt.imshow(ims_pred[i].cpu().detach().numpy().reshape(34, 45), interpolation='bilinear', cmap='gray')
        plt.axis('off')
        plt.tight_layout()
        plt.subplots_adjust(hspace=0, wspace=0, left=0)
    for i in range(num_ims):
        plt.subplot(R, C, i + 1 + num_ims)
        plt.imshow(ims[i].cpu().detach().numpy().reshape(34, 45), interpolation='bilinear', cmap='gray')
        plt.axis('off')
        plt.tight_layout()
        plt.subplots_adjust(hspace=0, wspace=0, left=0)
    plt.show()

def save_ims(ims_pred, ims, it, num_ims=5, val=False):      
    suffix = '_val' if val else ''
    
    ims_save = np.empty((2 * num_ims, 1, 34, 45), dtype=np.float32)
    ims = ims[:num_ims].cpu().detach().numpy()
    ims -= np.min(ims, axis=0)
    ims /= np.max(ims, axis=0)
    ims_save[0::2] = ims
    
    ims_pred = ims_pred[:num_ims].cpu().detach().numpy()
    ims_pred -= np.min(ims_pred, axis=0)
    ims_pred /= np.max(ims_pred, axis=0)
    ims_save[0::2] = ims
    ims_save[1::2] = ims_pred
    ims_save = torch.Tensor(ims_save)
    vutils.save_image(ims_save,
                '{}/{}_samples{}.png'.format(out_dir, it, suffix),
                normalize=False, nrow=10)    
    
    
    
    
out_dir = 'out'
os.makedirs(out_dir, exist_ok=True)
its = 10000
loss_fn = torch.nn.MSELoss(reduction='sum')
learning_rate = 1e-11 # 1e-12 works
model = GenNet(G).to(device)
optimizer = torch.optim.SGD(model.fc1.parameters(), 
                            lr=learning_rate)
num_ims = 8
lambda_reg = 0.1
divisor = 34 * 45 * resps.shape[0]

print('training...')        
for it in range(its):
    # lr step down
    if it == 100:
        optimizer.param_groups[0]['lr'] *= 0.1
    if it == 600:
        optimizer.param_groups[0]['lr'] *= 0.5
    if it == 1000:
        optimizer.param_groups[0]['lr'] *= 0.25    
    if it == 20000:
        optimizer.param_groups[0]['lr'] *= 0.5    
    if it == 50000:
        optimizer.param_groups[0]['lr'] *= 0.5        
    
    ims_pred = model(resps)
    loss = loss_fn(ims_pred, ims) + lambda_reg * 1 - lay1_sim(reg_model, ims, ims_pred)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if it % 20 == 0:
        print(it, 'loss', loss.detach().item() / divisor, 'lr', optimizer.param_groups[0]['lr'])
    if torch.sum(model.fc1.weight.grad).detach().item() == 0:
        print('zero grad!')
        print('w', torch.sum(model.fc1.weight))    
        break
#     print('pred', ims_pred[0, :20])

    if it % 100 == 0:
#         viz_ims(ims_pred, ims, num_ims)
        save_ims(ims_pred, ims, it, num_ims=50)
        print('\tloss mse', loss_fn(ims_pred, ims).detach().item() / divisor)
        print('\tloss reg', 1 - lay1_sim(reg_model, ims_pred, ims).detach().item())
        with torch.no_grad():
            ims_pred_val = model(resps_val)
            save_ims(ims_pred_val, ims_val, it, num_ims=50, val=True)
            print('\tval loss mse', loss_fn(ims_pred_val, ims_val).detach().item() / (34 * 45 * resps_val.shape[0]))
            print('\tval loss reg', 1 - lay1_sim(reg_model, ims_pred_val, ims_val).detach().item())
    if it % 1000 == 0:
        torch.save(model.state_dict(), oj(out_dir, 'model_' + str(it) + '.pth'))

training...
0 loss 1.2037367238562091 lr 1e-11
	loss mse 1.2037381535947713
	loss reg 0.6925430595874786
	val loss mse 24374585.01838235
	val loss reg 0.6907373666763306
20 loss 0.9782524509803922 lr 1e-11
40 loss 0.910654820261438 lr 1e-11
60 loss 0.8689192197712419 lr 1e-11
80 loss 0.857047079248366 lr 1e-11
100 loss 0.8064471507352942 lr 1e-12
	loss mse 0.8064491421568627
	loss reg 0.5930293202400208
	val loss mse 21675660.38602941
	val loss reg 0.6757190227508545
120 loss 0.7737780330882353 lr 1e-12
140 loss 0.7664726307189542 lr 1e-12
160 loss 0.7605632659313726 lr 1e-12
180 loss 0.7552952920751634 lr 1e-12
200 loss 0.7504363255718954 lr 1e-12
	loss mse 0.7504385212418301
	loss reg 0.5628479719161987
	val loss mse 22451666.36029412
	val loss reg 0.6801646649837494
220 loss 0.7459669628267974 lr 1e-12
240 loss 0.7417589869281046 lr 1e-12
260 loss 0.7377313112745097 lr 1e-12
280 loss 0.7339188112745098 lr 1e-12
300 loss 0.7302235498366013 lr 1e-12
	loss mse 0.7302257965686274
	loss 

KeyboardInterrupt: 

**generate random ims**

In [None]:
def generate_random_ims():
    batch_size = 25
    latent_size = 100

    fixed_noise = torch.randn(batch_size, latent_size, 1, 1).to(device)
    print(fixed_noise.shape)
    fake_images = G(fixed_noise)

    fake_images_np = fake_images.cpu().detach().numpy()
    print(fake_images_np.shape)
    fake_images_np = fake_images_np.reshape(fake_images_np.shape[0], 34, 45)
    # fake_images_np = fake_images_np.transpose((0, 2, 3, 1))
    plt.figure(figsize=(4.5, 3.4), dpi=100)
    R, C = 5, 5
    for i in range(batch_size):
        plt.subplot(R, C, i + 1)
        plt.imshow(fake_images_np[i], interpolation='bilinear', cmap='gray')
        plt.axis('off')
        plt.tight_layout()
        plt.subplots_adjust(hspace=0, wspace=0, left=0)
    plt.show()
# generate_random_ims()