In [1]:
import os
import sys
import time
import torch
import numpy as np
import collections
import matplotlib.pyplot as plt
import torch_cae_baseline as net

In [None]:
# data & configs
# Raw fields, unpacked below as (nt, nx, ny) = (time steps, rows, cols).
# NOTE(review): file name suggests a land/sea-masked monthly-mean SST product — confirm.
sst_mask = np.load('./data/npy/sst.mnmean.mask.npy')
(nt, nx, ny) = sst_mask.shape

n_levels = 4              # passed to the CAE as its number of levels
max_fs = 3 ** n_levels    # padded image sizes are rounded up to a multiple of this
err_threshold = 1e-5      # error tolerance (its use is not visible in this section)

In [None]:
# padding sizes
# Round each spatial size up to the next multiple of max_fs (integer ceil-division),
# then split the extra rows/columns between the two sides; when the surplus is odd,
# the extra pixel goes to the bottom / right.
px = -(-nx // max_fs) * max_fs
py = -(-ny // max_fs) * max_fs

top_margin, bottom_margin = (px - nx) // 2, (px - nx) - (px - nx) // 2
left_margin, right_margin = (py - ny) // 2, (py - ny) - (py - ny) // 2

print('margins: left {}, right {}, top {}, bottom {}'.format(
    left_margin, right_margin, top_margin, bottom_margin))

In [None]:
# pad & scale the images
# Columns (last axis) are padded circularly: values wrap around from the opposite edge.
D = np.pad(sst_mask, ((0, 0), (0, 0), (left_margin, right_margin)), 'wrap')
# Rows (axis 1) are padded by mirroring the rows next to the border. NumPy's 'reflect'
# mode does NOT repeat the edge row itself (that would be mode 'symmetric').
D = np.pad(D, ((0, 0), (top_margin, bottom_margin), (0, 0)), 'reflect')
print('shape of the padded images: {} x {}'.format(D.shape[1], D.shape[2]))

plt.figure(figsize=(15, 8))
plt.imshow(D[0, :, :])
plt.title('sampled image with padding', fontsize=20)

# Min-max scale every value into [0, 1] using the global min/max over all frames.
base = D.min()
scale = D.max() - base
if scale == 0:
    # A constant field would otherwise divide by zero and yield NaN/inf silently.
    raise ValueError('cannot min-max scale: padded data is constant')
scaled_D = (D - base) / scale
print('max: {}, min: {}, mean: {}'.format(scaled_D.max(), scaled_D.min(), scaled_D.mean()))

In [None]:
# init model & load data
# Use the configured n_levels (not a literal 4) so the network depth stays in sync
# with the padding multiple max_fs = 3**n_levels derived from the same setting.
# An empty torch.nn.Sequential() is an identity module, i.e. no nonlinearity.
model = net.CAE(n_levels=n_levels, n_layers=2, n_blocks_for_each_unit=2,
                activation=torch.nn.Sequential(), std=0.01)
model.load_data(scaled_D)

In [None]:
# train on level 0
# Hyperparameters are forwarded unchanged to net.CAE.train_arch.
model.train_arch(
    max_epoch=10,
    batch_size=100,
    lr=1e-3,
    multiscale_loss=True,
    loss_type='l2',
    verbose=2,
    model_path='./model',
)

In [6]:
# Sanity check for the down/up-sampling pair: a 3x3, stride-3 conv into 9 channels
# followed by the matching 3x3, stride-3 transposed conv can represent the identity
# exactly (one-hot filter per channel, zero bias). Training the pair with Adam on
# random 81x81 inputs should therefore drive the reconstruction MSE toward zero.
torch.manual_seed(0)  # reproducible inputs and initial weights

samples = torch.rand(100, 1, 81, 81)
op = torch.nn.Conv2d(1, 9, 3, stride=3, padding=0)             # 81x81x1 -> 27x27x9
op2 = torch.nn.ConvTranspose2d(9, 1, 3, stride=3, padding=0)   # 27x27x9 -> 81x81x1

criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(
    [{'params': op.parameters()}, {'params': op2.parameters()}],
    lr=1e-3, eps=1e-3,
)
for i in range(2000):
    optimizer.zero_grad()
    out = op2(op(samples))
    loss = criterion(out, samples)
    if i % 200 == 0:
        print('step {}: loss {}'.format(i, loss.item()))
    loss.backward()
    optimizer.step()

0.5000169277191162
0.0755719244480133
0.056541915982961655
0.041283853352069855
0.029635069891810417
0.021189451217651367
0.015160070732235909
0.009920932352542877
0.004347345791757107
0.0011327755637466908


In [16]:
samples1 = torch.rand(100, 1, 81, 81)
samples2 = torch.rand(100, 1, 81, 81)
# Per-pixel largest absolute difference over the batch axis; keepdim retains a
# singleton batch dimension, so diff has shape (1, 1, 81, 81).
diff = (samples1 - samples2).abs().max(dim=0, keepdim=True).values

In [None]:
# init model & load data (toy1 dataset)
# NOTE(review): these paths are relative to '..', while the SST cells above use
# './data' and './model' — confirm which working directory the notebook expects.
# model_path / result_path are not used within this section.
data_path, model_path, result_path = (
    '../data/npy/toy1.npy',
    '../model/toy1/',
    '../result/toy1/',
)

scaled_Phi = np.load(data_path)
# Rebinds `model`: any CAE built in an earlier cell is discarded here.
model = net.CAE(n_levels=4, activation=torch.nn.Sequential())
model.load_data(scaled_Phi)

In [None]:
# Call model.deeper_op four times with std=0.0.
# NOTE(review): presumably one call per level (the model was built with n_levels=4) — confirm.
for _ in range(4):
    model.deeper_op(std=0.0)

In [None]:
# Forward the training example at index 1; the list index keeps the batch axis.
# NOTE(review): the two positional True flags appear to request the intermediate
# dicts (inp / out / hids) unpacked here — confirm against net.CAE's forward.
sample_batch = model.train[[1], :, :, :]
_, inp, out, hids = model(sample_batch, model.cur_level, True, True)

In [None]:
# Plot each entry of the `inp` dict from the forward pass, one titled panel per key.
# The panel count follows the dict size instead of a hard-coded 4 (which raised
# IndexError for more keys and left blank axes for fewer); squeeze=False keeps
# `axes` 2-D even when there is a single panel.
n_panels = max(len(inp), 1)
fig, axes = plt.subplots(1, n_panels, figsize=(4 * n_panels, 3), squeeze=False)
for ax, (k, v) in zip(axes[0], inp.items()):
    print(k, ': ', v.size())
    ax.pcolor(v.squeeze().detach().cpu().numpy())
    ax.set_title('input {}'.format(k))

In [None]:
# Plot each entry of the `hids` dict from the forward pass, one titled panel per key.
# The panel count follows the dict size instead of a hard-coded 4; squeeze=False
# keeps `axes` 2-D even when there is a single panel.
n_panels = max(len(hids), 1)
fig, axes = plt.subplots(1, n_panels, figsize=(4 * n_panels, 3), squeeze=False)
for ax, (k, v) in zip(axes[0], hids.items()):
    print(k, ': ', v.size())
    ax.pcolor(v.squeeze().detach().cpu().numpy())
    ax.set_title('hidden {}'.format(k))

In [None]:
# Plot each entry of the `out` dict from the forward pass, one titled panel per key.
# The panel count follows the dict size instead of a hard-coded 4; squeeze=False
# keeps `axes` 2-D even when there is a single panel.
n_panels = max(len(out), 1)
fig, axes = plt.subplots(1, n_panels, figsize=(4 * n_panels, 3), squeeze=False)
for ax, (k, v) in zip(axes[0], out.items()):
    print(k, ': ', v.size())
    ax.pcolor(v.squeeze().detach().cpu().numpy())
    ax.set_title('output {}'.format(k))

In [None]:
# Push inp['0'] through the module registered as 'L0_Conv_0', show it, then
# repeatedly decode with 'L{level}_deConv_0' for levels 0..3.
# Each step: replicate-pad by 1 pixel per side, apply the deConv, crop 2 pixels per side,
# then apply a sigmoid. Each step's result is drawn in its own panel.
sigmoid = torch.nn.Sigmoid()
output = model._modules['L0_Conv_0'](inp['0'])
plt.pcolor(output.squeeze().cpu().detach().numpy())

fig, axes = plt.subplots(1, 4, figsize=(16, 3))
for level, ax in zip(range(4), axes):
    tmp = torch.nn.functional.pad(output, (1, 1, 1, 1), 'replicate')
    deconv = model._modules['L{}_deConv_0'.format(level)]
    output = sigmoid(deconv(tmp)[:, :, 2:-2, 2:-2])
    ax.pcolor(output.squeeze().cpu().detach().numpy())