In [2]:
import sys
import scipy.io as sio
import h5py
import numpy as np
from os.path import join as oj
import matplotlib.pyplot as plt
# %matplotlib inline
from mpl_toolkits.axes_grid1 import make_axes_locatable
import seaborn as sns
sns.set(style="white", rc={"axes.facecolor": (0, 0, 0, 0)})
import pandas as pd
import torch
sys.path.insert(1, oj(sys.path[0], '..'))  # insert parent path
from tqdm import tqdm
from sklearn import linear_model
from sklearn.model_selection import cross_val_score
from sklearn import decomposition
import matplotlib.gridspec as grd
from sklearn import neural_network
from torch.nn import functional as F
from torch import nn, optim
import torchvision.utils as vutils
import torchvision.models as models


%load_ext autoreload
%autoreload 2

num_gpu = 1 if torch.cuda.is_available() else 0
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [14]:
import stringer_dset
ims, resps, ims_val, resps_val = stringer_dset.get_data()

# fit latent space

In [19]:
import models
G = models.get_generator()

In [21]:
reg_model = models.get_reg_model()



In [None]:
import utils    
    
learning_rate = 1e-11
lambda_reg = 0.1
    
out_dir = 'out'
os.makedirs(out_dir, exist_ok=True)
its = 10000
loss_fn = torch.nn.MSELoss(reduction='sum')
learning_rate = 1e-11 # 1e-12 works
model = models.GenNet(G).to(device)
optimizer = torch.optim.SGD(model.fc1.parameters(), 
                            lr=learning_rate)
lambda_reg = 0.1
divisor = 34 * 45 * resps.shape[0]

print('training...')        
for it in range(its):
    # lr step down
    if it == 100:
        optimizer.param_groups[0]['lr'] *= 0.1
    if it == 600:
        optimizer.param_groups[0]['lr'] *= 0.5
    if it == 1000:
        optimizer.param_groups[0]['lr'] *= 0.25    
    if it == 20000:
        optimizer.param_groups[0]['lr'] *= 0.5    
    if it == 50000:
        optimizer.param_groups[0]['lr'] *= 0.5        
    
    ims_pred = model(resps)
    loss = loss_fn(ims_pred, ims) + lambda_reg * 1 - utils.lay1_sim(reg_model, ims, ims_pred)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if it % 20 == 0:
        print(it, 'loss', loss.detach().item() / divisor, 'lr', optimizer.param_groups[0]['lr'])
    if torch.sum(model.fc1.weight.grad).detach().item() == 0:
        print('zero grad!')
        print('w', torch.sum(model.fc1.weight))    
        break

    if it % 100 == 0:
        utils.save_ims(out_dir, ims_pred, ims, it, num_ims=50)
        print('\tloss mse', loss_fn(ims_pred, ims).detach().item() / divisor)
        print('\tloss reg', 1 - utils.lay1_sim(reg_model, ims_pred, ims).detach().item())
        with torch.no_grad():
            ims_pred_val = model(resps_val)
            utils.save_ims(out_dir, ims_pred_val, ims_val, it, num_ims=50, val=True)
            print('\tval loss mse', loss_fn(ims_pred_val, ims_val).detach().item() / (34 * 45 * resps_val.shape[0]))
            print('\tval loss reg', 1 - utils.lay1_sim(reg_model, ims_pred_val, ims_val).detach().item())
    if it % 1000 == 0:
        torch.save(model.state_dict(), oj(out_dir, 'model_' + str(it) + '.pth'))

training...
0 loss 1.2036013071895424 lr 1e-11
	loss mse 1.2036013888888888
	loss reg 0.6945621371269226
	val loss mse 1.0336727941176471
	val loss reg 0.6905852854251862
20 loss 1.1349286764705881 lr 1e-11
40 loss 1.0494067810457517 lr 1e-11
60 loss 1.0254953431372549 lr 1e-11
80 loss 1.0218225490196078 lr 1e-11
100 loss 1.354948611111111 lr 1e-12
	loss mse 1.354948611111111
	loss reg 0.7492066919803619
	val loss mse 1.1334281045751633
	val loss reg 0.7260875701904297
120 loss 1.0028217320261439 lr 1e-12


**generate random ims**

In [None]:
def generate_random_ims():
    batch_size = 25
    latent_size = 100

    fixed_noise = torch.randn(batch_size, latent_size, 1, 1).to(device)
    print(fixed_noise.shape)
    fake_images = G(fixed_noise)

    fake_images_np = fake_images.cpu().detach().numpy()
    print(fake_images_np.shape)
    fake_images_np = fake_images_np.reshape(fake_images_np.shape[0], 34, 45)
    # fake_images_np = fake_images_np.transpose((0, 2, 3, 1))
    plt.figure(figsize=(4.5, 3.4), dpi=100)
    R, C = 5, 5
    for i in range(batch_size):
        plt.subplot(R, C, i + 1)
        plt.imshow(fake_images_np[i], interpolation='bilinear', cmap='gray')
        plt.axis('off')
        plt.tight_layout()
        plt.subplots_adjust(hspace=0, wspace=0, left=0)
    plt.show()
# generate_random_ims()