In [3]:
import xarray as xr
import iris
import numpy as np

# Switch on the `netcdf_promote` option of iris's FUTURE flags object.
# NOTE(review): neither iris nor xarray is used in the visible cells — confirm
# this setting (and these imports) are still needed.
iris.FUTURE.netcdf_promote = True

In [15]:
import os
from torch.utils.data import DataLoader, Dataset
from PIL import Image
from scipy.misc import toimage

def list_mogreps_uk(folder, years, months, days, hours, fcsts):
    """Build file paths for every combination of the given date parts.

    Args:
        folder: directory prepended to each file name with ``os.path.join``.
        years, months, days, hours, fcsts: iterables of integers; every
            combination is expanded, with ``years`` varying slowest and
            ``fcsts`` fastest.

    Returns:
        A list of paths named
        ``prods_op_mogreps-uk_YYYYMMDD_HH_00_FFF.nc``. The paths are only
        constructed here; nothing checks that the files exist.
    """
    name_template = 'prods_op_mogreps-uk_{:04d}{:02d}{:02d}_{:02d}_00_{:03d}.nc'
    paths = []
    for year in years:
        for month in months:
            for day in days:
                for hour in hours:
                    for fcst in fcsts:
                        filename = name_template.format(year, month, day, hour, fcst)
                        paths.append(os.path.join(folder, filename))
    return paths

class PrecipDataset(Dataset):
    """Dataset yielding a (coarse, fine) pair of block-averaged images per item.

    Item ``idx`` is ``array[idx]`` cropped to ``CROP_ROWS`` x ``CROP_COLS`` and
    then reduced twice:

    * coarse: 16x16 block means, each value repeated 4 times along both axes;
    * fine:   4x4 block means.

    With the 544x416 crop both results are 136x104 arrays before they are
    turned into images.

    Args:
        array: indexable as ``array[idx, rows, cols]`` and supporting ``len``.
        transform: optional callable applied to each image (for example
            ``torchvision.transforms.ToTensor()``).

    ``toimage`` is resolved from the notebook's imports (``scipy.misc``).
    NOTE(review): ``scipy.misc.toimage`` is absent from recent SciPy releases —
    confirm the pinned SciPy version or swap in a PIL-based conversion.
    """

    # Crop applied to every item. 544 and 416 are both multiples of 16, so the
    # block averages never see a partial trailing block.
    CROP_ROWS = 544
    CROP_COLS = 416
    # Number of random replacement indices tried before a failure is reported.
    MAX_RETRIES = 10

    def __init__(self, array, transform=None):
        self.array = array
        self.transform = transform

    def __len__(self):
        return len(self.array)

    def _low_res(self, array, scale_factor, r=1):
        """Block-average ``array`` by ``scale_factor`` and convert it to an image.

        Args:
            array: 2-D field; its values are read through ``array.data``.
            scale_factor: side length of the square blocks that are averaged.
            r: how many times each averaged value is repeated along both axes.

        Returns:
            The image from ``toimage``, passed through ``self.transform`` when
            one is set.
        """
        # Values are read through `.data`, exactly as before; np.asarray accepts
        # both an ndarray's buffer view and an array-valued `.data` attribute.
        values = np.asarray(array.data)
        row_starts = list(range(0, values.shape[0], scale_factor))
        col_starts = list(range(0, values.shape[1], scale_factor))
        # Sum each block along rows, then along columns, and divide by the
        # block area to get the block mean.
        block_sums = np.add.reduceat(np.add.reduceat(values, row_starts),
                                     col_starts, axis=1)
        block_means = block_sums / scale_factor ** 2
        upsampled = block_means.repeat(r, 0).repeat(r, 1)

        img = toimage(upsampled)
        if self.transform is not None:
            img = self.transform(img)
        return img

    def __getitem__(self, idx):
        """Return the ``(coarse, fine)`` image pair for item ``idx``.

        An item that cannot be processed is replaced by a randomly chosen one,
        as before, but only ``MAX_RETRIES`` times.

        Raises:
            IndexError: if an integer ``idx`` is out of range. This is what
                lets ``for item in dataset`` and ``list(dataset)`` terminate.
            RuntimeError: if the item and all replacement items fail.
        """
        n_items = len(self.array)
        if isinstance(idx, (int, np.integer)) and not -n_items <= idx < n_items:
            raise IndexError(
                'index {} is out of range for {} items'.format(idx, n_items))

        last_error = None
        for _ in range(self.MAX_RETRIES + 1):
            try:
                precip = self.array[idx, :self.CROP_ROWS, :self.CROP_COLS]
                return self._low_res(precip, 16, 4), self._low_res(precip, 4)
            except Exception as err:  # deliberately not KeyboardInterrupt/SystemExit
                last_error = err
                idx = np.random.randint(0, n_items)

        raise RuntimeError(
            'could not build a sample after {} attempts; last error: {!r}'.format(
                self.MAX_RETRIES + 1, last_error))

In [16]:
# np.load on an .npz archive returns a lazy NpzFile that keeps the file open.
# Reading both members inside the context manager closes the handle as soon as
# the arrays are in memory. `arrays` stays bound afterwards but is closed;
# use `input_data` / `test_data` from here on.
with np.load('data/cloud_data.npz') as arrays:
    input_data = arrays['data']
    test_data = arrays['test']

In [17]:
# Dataset without a transform, so items come back as the images PrecipDataset
# builds rather than as tensors.
precip = PrecipDataset(input_data)

In [18]:
# Spot-check one item: a pair of images, both of size 104x136 per the recorded output.
precip[5]

(<PIL.Image.Image image mode=L size=104x136 at 0x7FD49433DC18>,
 <PIL.Image.Image image mode=L size=104x136 at 0x7FD428D02F60>)

In [27]:
# Pixel count of one 104x136 image; the same value is used for N, the VAE input width.
104*136

14144

In [20]:
# NOTE(review): scratch arithmetic; 27 and 35 are not referenced in the visible
# cells — presumably the pixel count of another image size. Confirm or remove.
27*35

945

In [21]:
from torchvision import datasets, transforms
# Training and test splits get the same preprocessing: each PrecipDataset image
# is converted to a tensor by its own torchvision ToTensor instance.
precip_train, precip_test = (
    PrecipDataset(split, transform=transforms.ToTensor())
    for split in (input_data, test_data)
)

# Materialising the datasets up front was tried and is left disabled:
# precip_train = list(precip_train)
# precip_test = list(precip_test)

In [22]:
# NOTE(review): scratch arithmetic; 3200 matches the hidden width n2, but 14522
# is not used anywhere in the visible cells. Confirm or remove.
14522 * 3200

46470400

In [34]:
from __future__ import print_function
import argparse
import torch
import torch.utils.data
from torch import nn, optim
from torch.autograd import Variable
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image

# Run configuration, exposed as attributes (args.batch_size, args.cuda, ...).
args_dict = {'batch_size': 5,
             'epochs': 100,
             # Use the GPU only when one is present; everything below is
             # already gated on args.cuda, so the notebook also runs on CPU.
             'cuda': torch.cuda.is_available(),
             'seed': 1,
             'log_interval': 10}
args = argparse.Namespace(**args_dict)


# Seed the CPU generator, and the CUDA generator when it is in use.
torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)


# A worker process and pinned host memory are only requested for GPU runs.
kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}

train_loader = torch.utils.data.DataLoader(precip_train,
                                           batch_size=args.batch_size, shuffle=True, **kwargs)
# NOTE(review): the test loader is shuffled too, so the first test batch
# differs from epoch to epoch — confirm this is intended.
test_loader = torch.utils.data.DataLoader(precip_test,
                                          batch_size=args.batch_size, shuffle=True, **kwargs)

# Layer widths for the VAE.
# N is the flattened image size: 136 * 104 = 14144 pixels.
N = 136 * 104
N2 = N
# Hidden width and latent width. Earlier candidates were int(N/2) and int(N/9).
# NOTE(review): n22 and N2 are not referenced in the visible cells.
n2 = n22 = 3200
n3 = 20

class VAE(nn.Module):
    """Fully connected variational autoencoder over flattened inputs.

    Encoder: ``input_dim -> hidden_dim`` (ReLU), then two linear heads giving
    the latent mean and log-variance (``latent_dim`` each).
    Decoder: ``latent_dim -> hidden_dim`` (ReLU) ``-> input_dim`` (sigmoid).

    Args:
        input_dim: flattened input width; defaults to the notebook-level ``N``.
        hidden_dim: hidden layer width; defaults to the notebook-level ``n2``.
        latent_dim: latent width; defaults to the notebook-level ``n3``.

    ``VAE()`` with no arguments builds the same network as before. The layers
    are created in the same order, so seeded initialisation is unchanged.
    """

    def __init__(self, input_dim=None, hidden_dim=None, latent_dim=None):
        super(VAE, self).__init__()

        # Resolve the defaults at call time, so the class itself does not
        # depend on the module constants when explicit sizes are given.
        if input_dim is None:
            input_dim = N
        if hidden_dim is None:
            hidden_dim = n2
        if latent_dim is None:
            latent_dim = n3

        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc21 = nn.Linear(hidden_dim, latent_dim)  # latent mean head
        self.fc22 = nn.Linear(hidden_dim, latent_dim)  # latent log-variance head
        self.fc3 = nn.Linear(latent_dim, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, input_dim)

        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def encode(self, x):
        """Return ``(mu, logvar)`` for a batch of flattened inputs."""
        h1 = self.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)

    def reparameterize(self, mu, logvar):
        """Sample ``mu + eps * std`` while training; return ``mu`` in eval mode."""
        if self.training:
            std = logvar.mul(0.5).exp_()
            # Standard-normal noise with the same type and size as std.
            eps = Variable(std.data.new(std.size()).normal_())
            return eps.mul(std).add_(mu)
        else:
            return mu

    def decode(self, z):
        """Map latent vectors to reconstructions in the range [0, 1]."""
        h3 = self.relu(self.fc3(z))
        return self.sigmoid(self.fc4(h3))

    def forward(self, x):
        """Flatten ``x``, encode, sample and decode.

        Returns:
            ``(reconstruction, mu, logvar)``; the reconstruction has shape
            ``(batch, input_dim)``.
        """
        # Flatten to the width the first layer was built with.
        mu, logvar = self.encode(x.view(-1, self.fc1.in_features))
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar


# Build the network; nn.Module.cuda() returns the module itself, so the GPU
# move can be folded into the construction.
model = VAE().cuda() if args.cuda else VAE()


def loss_function(recon_x, x, mu, logvar):
    """Reconstruction loss plus KL divergence, both on a per-element scale.

    Args:
        recon_x: reconstruction with values in [0, 1], shape ``(batch, features)``.
        x: target with the same number of elements as ``recon_x``; it is
            viewed to ``recon_x``'s shape.
        mu: latent means.
        logvar: latent log-variances.

    Returns:
        Mean binary cross-entropy plus the KL term divided by the number of
        elements in ``recon_x``.
    """
    BCE = F.binary_cross_entropy(recon_x, x.view_as(recon_x))

    # see Appendix B from VAE paper:
    # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
    # https://arxiv.org/abs/1312.6114
    # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    # Normalise by the same number of elements as in the reconstruction. The
    # actual tensor size is used instead of the configured batch size, so a
    # smaller final batch is scaled correctly.
    KLD /= recon_x.numel()

    return BCE + KLD


# Adam over all model parameters with a fixed learning rate of 1e-3.
optimizer = optim.Adam(model.parameters(), lr=1e-3)


def train(epoch):
    """Run one training epoch over ``train_loader`` and print progress.

    Each batch is a ``(data, target)`` pair: ``data`` is fed to the model and
    the reconstruction is scored against ``target``.

    Args:
        epoch: epoch number, used only in the printed messages.
    """
    model.train()
    train_loss = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        data = Variable(data)
        target = Variable(target)
        if args.cuda:
            data = data.cuda()
            target = target.cuda()
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(data)
        loss = loss_function(recon_batch, target, mu, logvar)
        loss.backward()
        # .item() reads the scalar loss; indexing a 0-dim tensor with
        # .data[0] raises on current PyTorch.
        batch_loss = loss.item()
        train_loss += batch_loss
        optimizer.step()
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader),
                batch_loss / len(data)))

    # NOTE(review): loss_function already returns a per-element mean, so the
    # divisions by batch/dataset size rescale it a second time. Kept as-is so
    # the numbers stay comparable with test(); revisit both together.
    print('====> Epoch: {} Average loss: {:.4f}'.format(
          epoch, train_loss / len(train_loader.dataset)))


def test(epoch):
    """Evaluate on ``test_loader`` and save a comparison image for the first batch.

    The saved grid stacks up to 8 inputs, their targets and the model's
    reconstructions, written to ``results/reconstruction_<epoch>.png``.

    Args:
        epoch: epoch number, used in the output file name.
    """
    model.eval()
    test_loss = 0
    # no_grad() replaces the removed `volatile=True` flag: no autograd graph is
    # built during evaluation.
    with torch.no_grad():
        for i, (data, target) in enumerate(test_loader):
            if args.cuda:
                data = data.cuda()
                target = target.cuda()
            recon_batch, mu, logvar = model(data)
            test_loss += loss_function(recon_batch, target, mu, logvar).item()
            if i == 0:
                n = min(data.size(0), 8)
                # -1 lets the view follow the real batch size, which may be
                # smaller than args.batch_size.
                comparison = torch.cat([data[:n],
                                        target[:n],
                                        recon_batch.view(-1, 1, 136, 104)[:n]])
                # Make sure the output directory exists before writing.
                if not os.path.isdir('results'):
                    os.makedirs('results')
                save_image(comparison.cpu(),
                           'results/reconstruction_' + str(epoch) + '.png', nrow=n)

    # NOTE(review): the summed per-batch mean loss is divided by the number of
    # samples, mirroring train(); revisit both together.
    test_loss /= len(test_loader.dataset)
    print('====> Test set loss: {:.4f}'.format(test_loss))

In [None]:
# Reconstructions and samples are written to results/; create it up front so a
# fresh checkout does not fail at the first save after a full training epoch.
if not os.path.isdir('results'):
    os.makedirs('results')

for epoch in range(1, args.epochs + 1):
    train(epoch)
    test(epoch)
    # Decode 64 standard-normal latent vectors and save them as one image grid.
    sample = Variable(torch.randn(64, n3))
    if args.cuda:
        sample = sample.cuda()
    sample = model.decode(sample).cpu()
    save_image(sample.data.view(64, 1, 136, 104),
               'results/sample_' + str(epoch) + '.png')

In [11]:
# Shape of the raw training array; the recorded output is (504, 548, 421),
# i.e. 504 entries of 548x421 before PrecipDataset crops them to 544x416.
input_data.shape

(504, 548, 421)