In [1]:
import cartopy.crs as ccrs
import xarray as xr
import matplotlib.pyplot as plt
import numpy as np
from itertools import product
import pandas as pd
import os
import time
from datetime import timedelta
import rasterio.warp as rasteriowarp

In [2]:
# Input data locations ('~' is expanded to the user's home directory).
SATELLITE_DATA_PATH, PV_DATA_FILENAME, PV_METADATA_FILENAME = map(os.path.expanduser, [
    '~/data/EUMETSAT/reprojected_subsetted/',
    '~/data/pvoutput.org/UK_PV_timeseries_batch.nc',
    '~/data/pvoutput.org/UK_PV_metadata.csv',
])

# Target projection for PV locations and satellite imagery: Transverse Mercator on the
# WGS84 ellipsoid, in metres. No lon_0 / lat_0 / k / x_0 / y_0 are given, so PROJ defaults apply.
DST_CRS = dict(ellps='WGS84', proj='tmerc', units='m')

# Geospatial boundary of the study area, in DST_CRS coordinates (metres)
SOUTH, NORTH = 5513500, 6613500
WEST, EAST = -889500, 410500

## Load and convert PV metadata

In [3]:
# PV system metadata indexed by system_id; systems without a location are useless here.
pv_metadata = (
    pd.read_csv(PV_METADATA_FILENAME, index_col='system_id')
    .dropna(subset=['longitude', 'latitude'], how='any'))

In [4]:
# Convert lat/lons to Transverse Mercator (DST_CRS), the grid the satellite rectangles are cut on.
pv_metadata['x'], pv_metadata['y'] = rasteriowarp.transform(
    src_crs={'init': 'EPSG:4326'},  # WGS84 longitude / latitude
    dst_crs=DST_CRS,
    xs=pv_metadata['longitude'].values,
    ys=pv_metadata['latitude'].values)

# Keep only PV systems inside the study bounding box (inclusive on all edges).
# This filters 3 PV systems which apparently aren't in the UK!
inside_x = pv_metadata['x'].between(WEST, EAST)
inside_y = pv_metadata['y'].between(SOUTH, NORTH)
pv_metadata = pv_metadata[inside_x & inside_y]

pv_metadata.shape[0]

2548

## Load and normalise PV power data

In [None]:
%%time
pv_power = xr.load_dataset(PV_DATA_FILENAME)

In [None]:
pv_power_selected = pv_power.loc[dict(datetime=slice('2018-06-01', '2019-07-01'))]

In [None]:
pv_power_df = pv_power_selected.to_dataframe().dropna(axis='columns', how='all')
pv_power_df = pv_power_df.clip(lower=0, upper=5E7)
pv_power_df.columns = [np.int64(col) for col in pv_power_df.columns]
pv_power_df = pv_power_df.tz_localize('Europe/London').tz_convert('UTC')

In [None]:
del pv_power
del pv_power_selected

In [None]:
# A bit of hand-crafted cleaning: blank out a (presumably known-bad) period for PV system 30248.
# A single .loc assignment is used because chained assignment (df[col][rows] = ...) writes to a
# temporary and never updates pv_power_df under pandas Copy-on-Write (the default from pandas 3).
# np.nan rather than np.NaN, which was removed in NumPy 2.0.
pv_power_df.loc['2018-10-29':'2019-01-03', 30248] = np.nan

In [None]:
# Only pick PV systems for which we have good metadata
def align_pv_system_ids(pv_metadata, pv_power_df):
    """Restrict metadata and power data to the PV systems present in both.

    Args:
        pv_metadata: DataFrame indexed by PV system ID.
        pv_power_df: DataFrame whose columns are PV system IDs.

    Returns:
        (pv_metadata, pv_power_df), each restricted to the common system IDs in
        ascending ID order.
    """
    common_ids = np.sort(pv_metadata.index.intersection(pv_power_df.columns))
    aligned_power = pv_power_df[common_ids]
    aligned_metadata = pv_metadata.loc[common_ids]
    return aligned_metadata, aligned_power
    
pv_metadata, pv_power_df = align_pv_system_ids(pv_metadata, pv_power_df)

In [None]:
# Scale to the range [0, 1]
pv_power_min = pv_power_df.min()
pv_power_max = pv_power_df.max()

pv_power_df -= pv_power_min
pv_power_df /= pv_power_max

In [None]:
# Drop systems which are producing over night
NIGHT_YIELD_THRESHOLD = 0.4
night_hours = list(range(21, 24)) + list(range(0, 4))
bad_systems = np.where(
    (pv_power_df[pv_power_df.index.hour.isin(night_hours)] > NIGHT_YIELD_THRESHOLD).sum()
)[0]
bad_systems = pv_power_df.columns[bad_systems]
print(len(bad_systems), 'bad systems found.')

#ax = pv_power_df[bad_systems].plot(figsize=(40, 10), alpha=0.5)
#ax.set_title('Bad PV systems');

In [None]:
pv_power_df.drop(bad_systems, axis='columns', inplace=True)

In [None]:
%%time
# Interpolate up to 15 minutes ahead.
pv_power_df = pv_power_df.interpolate(limit=3)

In [None]:
# Align again, after removing dud PV systems
pv_metadata, pv_power_df = align_pv_system_ids(pv_metadata, pv_power_df)

In [None]:
len(pv_power_df.columns)

In [None]:
#pv_power_df.plot(figsize=(40, 10), alpha=0.5, legend=False);

In [None]:
pv_power_df.head()

## Load satellite data

In [None]:
from glob import glob
from torch.utils.data import Dataset
from datetime import datetime

In [None]:
RECTANGLE_WIDTH_M = 128000  # in meters
RECTANGLE_HEIGHT_M = RECTANGLE_WIDTH_M

METERS_PER_PIXEL = 1000
# `np.int` was only an alias of the builtin `int` and was removed in NumPy 1.24.
RECTANGLE_WIDTH_PIXELS = int(RECTANGLE_WIDTH_M / METERS_PER_PIXEL)
RECTANGLE_HEIGHT_PIXELS = int(RECTANGLE_HEIGHT_M / METERS_PER_PIXEL)

# Satellite pixel statistics used to normalise images (presumably measured on this dataset).
SAT_IMAGE_MEAN = 20.444992
SAT_IMAGE_STD = 8.766013


def get_rectangle(data_array, centre_x, centre_y, width=RECTANGLE_WIDTH_M, height=RECTANGLE_HEIGHT_M):
    """Select a `width` x `height` (metres) window of `data_array` centred on (centre_x, centre_y).

    Selection is label-based on the `x` and `y` coordinates. The y slice runs north -> south,
    which presumably relies on `y` being stored in descending order — TODO confirm per file.
    """
    x_bounds = slice(centre_x - width / 2, centre_x + width / 2)
    y_bounds = slice(centre_y + height / 2, centre_y - height / 2)
    return data_array.loc[{'x': x_bounds, 'y': y_bounds}]


class SatelliteLoader(Dataset):
    """Indexes a set of satellite NetCDF files by timestep and lazily opens them on demand.

    Attributes:
        index: pd.Series which maps from UTC datetime to full filename of satellite data.
        _data_array_cache: The last lazily opened xr.DataArray that __getitem__ was asked to open.
            Useful so that we don't have to re-open the DataArray if we're asked to get
            data from the same file on several different calls.
        _last_filename_requested: The filename `_data_array_cache` was opened from, or None
            when nothing is cached.
    """
    def __init__(self, file_pattern):
        self._load_sat_index(file_pattern)
        self._data_array_cache = None
        self._last_filename_requested = None

    def __getitem__(self, dt: datetime) -> xr.DataArray:
        """Returns the lazily-opened image for datetime `dt`.

        `dt` may be naive (interpreted as UTC) or tz-aware in any zone. It is converted to
        tz-aware UTC to look up `self.index`, and to naive UTC to select from the file:
        the file's `time` values are naive (`_load_sat_index` has to tz_localize them).

        Raises:
            KeyError: if `dt` is not in `self.index`.
        """
        timestamp = pd.Timestamp(dt)
        if timestamp.tzinfo is None:
            timestamp = timestamp.tz_localize('UTC')
        else:
            timestamp = timestamp.tz_convert('UTC')

        sat_filename = self.index[timestamp]
        if sat_filename != self._last_filename_requested:
            self._data_array_cache = xr.open_dataarray(sat_filename)
            self._last_filename_requested = sat_filename
        return self._data_array_cache.sel(time=timestamp.tz_localize(None))

    def close(self):
        """Close the cached DataArray (if any) and forget it, so the next access reopens the file."""
        if self._data_array_cache is not None:
            self._data_array_cache.close()
        self._data_array_cache = None
        self._last_filename_requested = None

    def __len__(self):
        return len(self.index)

    def _load_sat_index(self, file_pattern):
        """Opens all satellite files matching `file_pattern` and loads their datetime indices into self.index."""
        sat_filenames = sorted(glob(file_pattern))

        n_filenames = len(sat_filenames)
        sat_index = []
        for i_filename, sat_filename in enumerate(sat_filenames):
            if i_filename % 10 == 0 or i_filename == (n_filenames - 1):
                print('\r {:5d} of {:5d}'.format(i_filename + 1, n_filenames), end='', flush=True)
            # Only the time coordinate is needed; close each file as soon as it has been read.
            with xr.open_dataarray(sat_filename, drop_variables=['x', 'y']) as data_array:
                sat_index.extend([(sat_filename, t) for t in data_array.time.values])

        # Select the column explicitly: `.squeeze()` would turn a 1-row frame into a scalar.
        sat_index = (
            pd.DataFrame(sat_index, columns=['filename', 'datetime'])
            .set_index('datetime')['filename'])
        assert not sat_index.index.duplicated().any()
        self.index = sat_index.tz_localize('UTC')

    def get_rectangles_for_all_data(self, centre_x, centre_y, width=RECTANGLE_WIDTH_M, height=RECTANGLE_HEIGHT_M):
        """Iterate through all satellite files (in filename order), yielding the rectangle
        centred on (centre_x, centre_y) for every timestep in each file.

        Each yielded DataArray is lazily loaded from its file.
        """
        for sat_filename in np.unique(self.index.values):  # np.unique returns sorted values
            data_array = xr.open_dataarray(sat_filename)
            # Module-level get_rectangle: crops x / y only and keeps all timesteps.
            yield get_rectangle(data_array, centre_x, centre_y, width, height)

    def get_rectangle(self, time, centre_x, centre_y, width=RECTANGLE_WIDTH_M, height=RECTANGLE_HEIGHT_M):
        """Return the rectangle centred on (centre_x, centre_y) from the image at datetime `time`."""
        data_array = self[time]
        # Inside the method body this name resolves to the module-level function, not this method.
        return get_rectangle(data_array, centre_x, centre_y, width, height)

In [None]:
%%time
sat_loader = SatelliteLoader(os.path.join(SATELLITE_DATA_PATH, '*.nc'))
print()

In [None]:
len(sat_loader)

## Plot a test rectangle of satellite imagery around one PV system

In [None]:
# Test get rectangle
dt = pd.Timestamp('2019-02-21 10:15')
pv_system_id = pv_metadata.index[1]
x, y = pv_metadata.loc[pv_system_id][['x', 'y']]

In [None]:
%%time
sat_data = sat_loader.get_rectangle(time=dt, centre_x=x, centre_y=y) #, width=512000, height=512000)

In [None]:
fig = plt.figure(figsize=(10, 10))
# The satellite pixels and the PV (x, y) are in DST_CRS metres: '+proj=tmerc' on WGS84 with
# PROJ defaults (lon_0=0, lat_0=0, k=1, x_0=0, y_0=0). Cartopy's TransverseMercator()
# defaults are instead the British National Grid parameters (central_longitude=-2,
# central_latitude=49, false easting/northing, scale 0.9996...), which would draw the
# coastlines in the wrong place, so pass the PROJ defaults explicitly.
crs = ccrs.TransverseMercator(
    central_longitude=0.0,
    central_latitude=0.0,
    false_easting=0.0,
    false_northing=0.0,
    scale_factor=1.0)
ax = plt.axes(projection=crs)
ax.coastlines(resolution='10m', alpha=0.5, color='pink')

sat_data.plot.imshow(ax=ax, cmap='gray', origin='upper', add_colorbar=True)
ax.scatter(x=x, y=y, alpha=0.7);

## Test clearsky

In [None]:
import pvlib
from pvlib.location import Location

In [None]:
def get_pvlib_location(pv_system_id):
    """Build a pvlib `Location` (tz='UTC') for one PV system from the global `pv_metadata`.

    Uses the system's 'latitude', 'longitude' and 'system_name' metadata columns.
    """
    latitude = pv_metadata['latitude'][pv_system_id]
    longitude = pv_metadata['longitude'][pv_system_id]
    system_name = pv_metadata['system_name'][pv_system_id]
    return Location(latitude=latitude, longitude=longitude, tz='UTC', name=system_name)

location = get_pvlib_location(pv_system_id=pv_system_id)
location

In [None]:
fig, ax = plt.subplots(figsize=(20, 7))
plot_window = timedelta(hours=48)
pv_data_to_plot = pv_power_df[pv_system_id][dt - plot_window:dt + plot_window]
ax.plot(pv_data_to_plot, label='PV yield')
#ax.plot((dt, dt), (0, 1), linewidth=1, color='black', label='datetime of image above')
ax.set_title(dt)
ax.set_ylim((0, 1))

# pvlib clear-sky irradiance, one line per column, on a secondary y-axis
ax2 = ax.twinx()
clearsky = location.get_clearsky(pv_data_to_plot.index)
for column_name in clearsky.columns:
    ax2.plot(clearsky[column_name], label=column_name)
ax2.legend(loc='upper left');

## Align satellite datetime index with PV datetime index

In [None]:
datetime_index = pv_power_df.index.intersection(sat_loader.index.index)

In [None]:
# Filter by datetimes when sun is shining!
daylight_mask = location.get_clearsky(datetime_index)['ghi'] > 0
datetime_index = datetime_index[daylight_mask]

In [None]:
pv_power_df = pv_power_df.reindex(datetime_index)

In [None]:
len(datetime_index)

In [None]:
datetime_index.tz

In [None]:
# Split train & test by days
days = np.unique(datetime_index.date)
len(days)

In [None]:
# Use every 5th day for testing
testing_days = days[::5]
len(testing_days)

In [None]:
training_days = np.array(list(set(days) - set(testing_days)))
training_days = np.sort(training_days)
len(training_days)

In [None]:
def get_datetime_index_for_days(training_or_testing_days, datetimes=None):
    """Return the datetimes whose calendar date is one of `training_or_testing_days`.

    Args:
        training_or_testing_days: collection of `datetime.date` objects.
        datetimes: DatetimeIndex to filter. Defaults to the notebook-global
            `datetime_index` (the original behaviour). Pass it explicitly to avoid
            depending on hidden notebook state.

    Returns:
        The matching subset of `datetimes`, in its original order.
    """
    if datetimes is None:
        datetimes = datetime_index
    return datetimes[pd.Series(datetimes.date).isin(training_or_testing_days)]

training_datetimes, testing_datetimes = (
    get_datetime_index_for_days(days_subset) for days_subset in (training_days, testing_days))
assert set(training_datetimes).isdisjoint(testing_datetimes)

training_datetimes.size, testing_datetimes.size

### Load testing batch

In [None]:
import torch

In [None]:
def new_full_array(size, fill_value=np.nan, dtype=np.float16):
    """Allocate a new NumPy array of shape `size` filled with `fill_value`.

    NaN by default, so any slot that is never written is easy to detect. Uses `np.nan`
    because `np.NaN` was removed in NumPy 2.0 (the old default failed at definition time).
    """
    return np.full(shape=size, fill_value=fill_value, dtype=dtype)

In [None]:
TESTING_BATCH_SIZE = 256

In [None]:
testing_batch = {
    'sat_images': new_full_array(
        size=(TESTING_BATCH_SIZE, 1, RECTANGLE_WIDTH_PIXELS, RECTANGLE_HEIGHT_PIXELS),
        dtype=np.float32),  # use float32 to minimise problems with normalisation
    'pv_yield': new_full_array(
        size=(TESTING_BATCH_SIZE, 1)),
    'pv_system_id': np.zeros(shape=TESTING_BATCH_SIZE, dtype=np.int32),
    'datetime_index': testing_datetimes[:TESTING_BATCH_SIZE]}

In [None]:
for i, dt in enumerate(testing_batch['datetime_index']):
    # Randomly sample from PV systems which have data for this datetime
    pv_data_for_dt = pv_power_df.loc[dt].dropna()
    pv_system_id = np.random.choice(pv_data_for_dt.index)
    pv_yield = pv_data_for_dt[pv_system_id]
    
    # Load satellite image
    x, y = pv_metadata.loc[pv_system_id][['x', 'y']]
    sat_data = sat_loader.get_rectangle(time=dt, centre_x=x, centre_y=y)
    
    # Put into super batch
    testing_batch['sat_images'][i, 0] = sat_data.values
    testing_batch['pv_yield'][i, 0] = pv_yield
    testing_batch['pv_system_id'][i] = pv_system_id

In [None]:
# Normalise satellite images
testing_batch['sat_images'] -= SAT_IMAGE_MEAN
testing_batch['sat_images'] /= SAT_IMAGE_STD

## Load training super batch

In [None]:
N_RECTANGLES_PER_SAT_IMAGE = 32
N_DATETIMES_PER_SUPERBATCH = 4096

SUPER_BATCH_SIZE = N_RECTANGLES_PER_SAT_IMAGE * N_DATETIMES_PER_SUPERBATCH
# NOTE(review): 2 bytes/pixel matches the float16 GPU super batch. The CPU super batch is
# allocated as float32 (4 bytes/pixel), so it takes twice the size printed here.
BYTES_PER_PIXEL = 2  # float16
pixels_per_image = RECTANGLE_HEIGHT_PIXELS * RECTANGLE_WIDTH_PIXELS
size_of_each_image_mb = pixels_per_image * BYTES_PER_PIXEL / 1E6
super_batch_size_mb = SUPER_BATCH_SIZE * size_of_each_image_mb
print('Size of super batch: {:8.1f} MB'.format(super_batch_size_mb))
print('                     {:6d}   examples'.format(SUPER_BATCH_SIZE))

### Load CPU super batch from individual images

In [None]:
%%time
cpu_super_batch = {
    'sat_images': new_full_array(
        size=(SUPER_BATCH_SIZE, 1, RECTANGLE_WIDTH_PIXELS, RECTANGLE_HEIGHT_PIXELS), 
        dtype=np.float32),  # use float32 to minimise issues with normalisation
    'pv_yield': new_full_array(
        size=(SUPER_BATCH_SIZE, 1)),
    'pv_system_id': np.zeros(shape=SUPER_BATCH_SIZE, dtype=np.int32),
    'datetime_index': np.zeros(shape=SUPER_BATCH_SIZE, dtype='datetime64[s]')}

In [None]:
HALF_RECTANGLE_WIDTH_M = RECTANGLE_WIDTH_M / 2
HALF_RECTANGLE_HEIGHT_M = RECTANGLE_HEIGHT_M / 2


def _sample_super_batch_datetimes():
    """Randomly sample (with replacement) N_DATETIMES_PER_SUPERBATCH training datetimes, sorted.

    Only datetimes at which at least one PV system has data are eligible: picking PV systems
    for a datetime with none would make np.random.choice raise ValueError. Sorting makes
    consecutive datetimes tend to hit SatelliteLoader's single-file cache.
    """
    has_pv_data = pv_power_df.reindex(training_datetimes).notna().any(axis='columns').to_numpy()
    eligible_datetimes = training_datetimes[has_pv_data]
    sampled = np.random.choice(eligible_datetimes.to_numpy(dtype=object), size=N_DATETIMES_PER_SUPERBATCH)
    return pd.DatetimeIndex(np.sort(sampled))


def load_data_into_cpu_super_batch():
    """Refill the global `cpu_super_batch` with freshly sampled training examples.

    For each of N_DATETIMES_PER_SUPERBATCH sampled training datetimes, picks
    N_RECTANGLES_PER_SAT_IMAGE PV systems with data at that datetime and stores, per example:
    the satellite rectangle centred on the system, its PV yield, its ID and the datetime.
    Satellite images are then normalised with SAT_IMAGE_MEAN / SAT_IMAGE_STD.

    Returns:
        The (mutated) global `cpu_super_batch` dict. Its 'datetime_index' entry is replaced
        by a tz-aware UTC DatetimeIndex with one entry per example.
    """
    super_batch_datetimes = _sample_super_batch_datetimes()

    # Load satellite data and PV data
    for image_i, dt in enumerate(super_batch_datetimes):
        print('\r{:6d} of {:d}'.format(image_i+1, N_DATETIMES_PER_SUPERBATCH), end='', flush=True)

        # Randomly sample PV systems which have data for this datetime; sample with
        # replacement only when there are fewer systems than rectangles needed.
        pv_data_for_dt = pv_power_df.loc[dt].dropna()
        replace = len(pv_data_for_dt) < N_RECTANGLES_PER_SAT_IMAGE
        pv_system_ids = np.random.choice(pv_data_for_dt.index, size=N_RECTANGLES_PER_SAT_IMAGE, replace=replace)
        locations = pv_metadata.loc[pv_system_ids][['x', 'y']]

        # Bounding box covering every rectangle needed for this image
        north = locations['y'].max() + HALF_RECTANGLE_HEIGHT_M
        south = locations['y'].min() - HALF_RECTANGLE_HEIGHT_M
        west = locations['x'].min() - HALF_RECTANGLE_WIDTH_M
        east = locations['x'].max() + HALF_RECTANGLE_WIDTH_M

        # Load just that box into memory once; the per-system crops below then slice
        # the in-memory array instead of going back to the file.
        data_array = sat_loader[dt]
        data_array = data_array.loc[dict(
            x=slice(west, east),
            y=slice(north, south))]
        data_array = data_array.load()

        first_example_i = image_i * N_RECTANGLES_PER_SAT_IMAGE
        for example_i, (pv_system_id, row) in enumerate(locations.iterrows(), start=first_example_i):
            sat_data = get_rectangle(data_array, centre_x=row.x, centre_y=row.y)
            cpu_super_batch['sat_images'][example_i, 0] = sat_data.values
            cpu_super_batch['pv_yield'][example_i, 0] = pv_data_for_dt[pv_system_id]
            cpu_super_batch['pv_system_id'][example_i] = pv_system_id

    # Each sampled datetime fills N_RECTANGLES_PER_SAT_IMAGE consecutive examples.
    cpu_super_batch['datetime_index'] = (
        super_batch_datetimes.repeat(N_RECTANGLES_PER_SAT_IMAGE).tz_convert('UTC'))

    # Normalise satellite images (in place, so the array stays float32)
    cpu_super_batch['sat_images'] -= SAT_IMAGE_MEAN
    cpu_super_batch['sat_images'] /= SAT_IMAGE_STD

    print()
    return cpu_super_batch

In [None]:
%%time
cpu_super_batch = load_data_into_cpu_super_batch()

In [None]:
def plot(batch_dict, i):
    plt.imshow(batch_dict['sat_images'][i, 0].astype(np.float32))
    print(batch_dict['datetime_index'][i])
    print('PV yield', batch_dict['pv_yield'][i])
    for key in ['clearsky', 'hours_of_day']:
        try:
            print(key, batch_dict[key][i])
        except KeyError:
            pass

plot(cpu_super_batch, 5000)

### Compute hour of day and clearsky

In [None]:
# Normalisation statistics for the hour-of-day feature (presumably measured on the training data)
HOURS_OF_DAY_MEAN = 11.628418
HOURS_OF_DAY_STD = 4.1584363

def compute_hour_of_day(batch_dict):
    """Add a normalised hour-of-day feature to `batch_dict` (mutated and returned).

    Reads batch_dict['datetime_index'] (a DatetimeIndex) and stores
    batch_dict['hours_of_day'], a float32 array of shape (n_examples, 1) holding
    (hour - HOURS_OF_DAY_MEAN) / HOURS_OF_DAY_STD. Minutes are ignored.
    """
    hours = np.asarray(batch_dict['datetime_index'].hour, dtype=np.float32)
    normalised_hours = (hours - HOURS_OF_DAY_MEAN) / HOURS_OF_DAY_STD
    batch_dict['hours_of_day'] = normalised_hours.reshape(-1, 1)
    return batch_dict

In [None]:
cpu_super_batch = compute_hour_of_day(cpu_super_batch)

In [None]:
testing_batch = compute_hour_of_day(testing_batch)

#### Clearsky

In [None]:
# Per-column normalisation statistics for pvlib's clear-sky output (columns in get_clearsky order)
CLEARSKY_MEAN = np.array([373.1623 , 538.70374,  80.82757], dtype=np.float32)
CLEARSKY_STD = np.array([268.6872  , 254.62102 ,  42.651264], dtype=np.float32)

def compute_clearsky(batch_dict):
    """Add normalised pvlib clear-sky irradiance features to `batch_dict` (mutated and returned).

    For each example, computes the clear-sky irradiance at its PV system's location and
    datetime, then normalises it with CLEARSKY_MEAN / CLEARSKY_STD. Stores
    batch_dict['clearsky'] as float32 of shape (n_examples, 3).

    Raises:
        AssertionError: if any clear-sky value is missing / NaN.
    """
    n_examples = len(batch_dict['datetime_index'])
    # np.nan (np.NaN was removed in NumPy 2.0); NaN marks rows not yet filled.
    clearsky = np.full(shape=(n_examples, 3), fill_value=np.nan, dtype=np.float32)
    pv_ids_and_datetimes = pd.DataFrame(
        {'pv_system_id': batch_dict['pv_system_id'],
         'datetime_index': batch_dict['datetime_index']})

    # Clear-sky irradiance depends on location, so make one pvlib call per PV system.
    for pv_system_id, df in pv_ids_and_datetimes.groupby('pv_system_id'):
        dt_index = pd.DatetimeIndex(df['datetime_index'])
        location = get_pvlib_location(pv_system_id)
        clearsky_for_location = location.get_clearsky(dt_index)
        # df.index is the default RangeIndex, i.e. each row's position in the batch.
        clearsky[df.index] = clearsky_for_location.values

    # Vectorised check (much faster than Python's any() over every element).
    assert not np.isnan(clearsky).any(), 'clearsky contains NaN (unfilled rows or pvlib NaNs)'

    clearsky -= CLEARSKY_MEAN
    clearsky /= CLEARSKY_STD

    batch_dict['clearsky'] = clearsky
    return batch_dict

In [None]:
%%time
cpu_super_batch = compute_clearsky(cpu_super_batch)

In [None]:
%%time
testing_batch = compute_clearsky(testing_batch)

### GPU super batch

In [None]:
from copy import copy

In [None]:
def new_full_tensor(size, fill_value=np.nan, dtype=torch.float16, device='cuda'):
    """Allocate a new torch tensor of shape `size` on `device`, filled with `fill_value`.

    NaN by default, so unwritten slots are detectable. Uses `np.nan` because `np.NaN`
    was removed in NumPy 2.0 (the old default failed at definition time).
    """
    return torch.full(size=size, fill_value=fill_value, dtype=dtype, device=device)

In [None]:
%%time
gpu_super_batch = {
    'sat_images': new_full_tensor(
        size=(SUPER_BATCH_SIZE, 1, RECTANGLE_WIDTH_PIXELS, RECTANGLE_HEIGHT_PIXELS)),
    'pv_yield': new_full_tensor(size=(SUPER_BATCH_SIZE, 1)),
    'hours_of_day': new_full_tensor(size=(SUPER_BATCH_SIZE, 1)),
    'clearsky': new_full_tensor(size=(SUPER_BATCH_SIZE, 3))
}

In [None]:
def move_superbatch_to_gpu(cpu_super_batch):
    """Copy a CPU super batch into the pre-allocated global `gpu_super_batch`.

    Numeric arrays are copied into the existing GPU tensors in place, so no new GPU memory
    is allocated. 'datetime_index' and 'pv_system_id' are kept on the CPU as copies.
    Any error is re-raised after printing which key caused it.

    Returns:
        The global `gpu_super_batch` dict (mutated).
    """
    for key, value in cpu_super_batch.items():
        if key in ['datetime_index', 'pv_system_id']:
            gpu_super_batch[key] = copy(value)
            continue
        try:
            # Tensor.copy_ converts dtype (e.g. float32 -> float16) and transfers to the GPU in
            # one step. This avoids the legacy `torch.HalfTensor(...)` constructor and its
            # intermediate CPU half-precision copy. as_tensor shares memory with NumPy arrays.
            gpu_super_batch[key].copy_(torch.as_tensor(value))
        except Exception:
            print('Problem with', key)
            raise

    return gpu_super_batch

In [None]:
%%time
gpu_super_batch = move_superbatch_to_gpu(cpu_super_batch)

In [None]:
gpu_super_batch.keys()

In [None]:
# Move testing batch into GPU memory
for key in ['sat_images', 'pv_yield', 'hours_of_day', 'clearsky']:
    testing_batch[key] = torch.cuda.HalfTensor(testing_batch[key])

## Define neural net

In [None]:
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [None]:
class Net(nn.Module):
    """Small CNN regressing PV yield from a satellite image plus time / clear-sky features.

    forward() inputs:
        x: satellite images, (batch, 1, image_size, image_size).
        hour_of_day: (batch, 1) hour-of-day feature.
        clearsky: (batch, 3) clear-sky irradiance features.
    Output: (batch, 1) predicted PV yield.

    Args:
        dropout_proportion: dropout probability applied between layers.
        image_size: side length of the square input images (default 128, as before).
    """
    def __init__(self, dropout_proportion=0.1, image_size=128):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=12, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=12, out_channels=16, kernel_size=5)
        HOURS_OF_DAY_CHANNELS = 1
        CLEARSKY_CHANNELS = 3
        # Spatial size after conv1 (k=5, no padding) -> pool -> conv2 (k=5) -> pool,
        # e.g. 128 -> 124 -> 62 -> 58 -> 29.
        conv_output_size = ((image_size - 4) // 2 - 4) // 2
        self.fc1 = nn.Linear(16 * conv_output_size * conv_output_size, 120)
        self.fc2 = nn.Linear(120 + HOURS_OF_DAY_CHANNELS + CLEARSKY_CHANNELS, 84)
        self.fc3 = nn.Linear(84, 1)
        self.dropout_layer = nn.Dropout(p=dropout_proportion)

    def forward(self, x, hour_of_day, clearsky):
        x = self.pool(F.relu(self.conv1(x)))
        # x is now (batch, 12, 62, 62) for 128 x 128 input: 62 = (128 - 4) / 2.
        x = self.dropout_layer(x)
        x = self.pool(F.relu(self.conv2(x)))
        # x is now (batch, 16, 29, 29) for 128 x 128 input: 29 = (62 - 4) / 2.
        # Flatten per example. Unlike view(-1, N), this never silently re-batches an input of
        # the wrong spatial size; fc1 raises a shape error instead.
        x = torch.flatten(x, start_dim=1)
        x = self.dropout_layer(x)
        x = F.relu(self.fc1(x))
        x = self.dropout_layer(x)
        # Append the non-image features before the second fully-connected layer.
        x = torch.cat((x, hour_of_day, clearsky), dim=1)
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

net = Net().cuda().half()

In [None]:
optimizer = optim.SGD(net.parameters(), lr=0.01)
loss_func = nn.MSELoss()
mae_loss_func = nn.L1Loss()

In [None]:
train_losses = []
train_mae_losses = []
test_losses = []
test_mae_losses = []

training_index_len_minus_1 = SUPER_BATCH_SIZE - 1

In [None]:
TRAINING_BATCH_SIZE = 128

BATCHES_PER_EPOCH = int(SUPER_BATCH_SIZE / TRAINING_BATCH_SIZE)
STATS_PERIOD = int(BATCHES_PER_EPOCH / 4)
N_EPOCHS = 7
N_LOADS = 7
N_BATCHES_TO_TRAIN = BATCHES_PER_EPOCH * N_EPOCHS

TESTING_INPUTS = testing_batch['sat_images']
TESTING_TARGET = testing_batch['pv_yield']
TESTING_HOURS_OF_DAY = testing_batch['hours_of_day']
TESTING_CLEARSKY = testing_batch['clearsky']

In [None]:
%%time

for i_load in range(N_LOADS):
    print('loading', i_load, 'of', N_LOADS)
    t0 = time.time()
    running_train_loss = 0.0
    running_train_mae = 0.0
    for i_batch in range(N_BATCHES_TO_TRAIN):
        print('\rBatch: {:4d} of {}'.format(i_batch + 1, N_BATCHES_TO_TRAIN), end='', flush=True)

        # Create batch
        batch_index = np.random.randint(low=0, high=training_index_len_minus_1, size=TRAINING_BATCH_SIZE)
        inputs = gpu_super_batch['sat_images'][batch_index]
        hours_of_day_for_batch = gpu_super_batch['hours_of_day'][batch_index]
        clearsky_for_batch = gpu_super_batch['clearsky'][batch_index]
        target = gpu_super_batch['pv_yield'][batch_index]

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        net.train()
        outputs = net(inputs, hours_of_day_for_batch, clearsky_for_batch)
        train_loss = loss_func(outputs, target)
        train_loss.backward()
        optimizer.step()
        running_train_loss += train_loss.item()

        # MAE
        train_mae = mae_loss_func(outputs, target)
        running_train_mae += train_mae.item()

        # print statistics
        if i_batch == 0 or i_batch % STATS_PERIOD == STATS_PERIOD - 1:    # print every STATS_PERIOD mini-batches
            t1 = time.time()

            # Train loss
            if i_batch == 0:
                mean_train_loss = running_train_loss
                mean_train_mae = running_train_mae
            else:
                mean_train_loss = running_train_loss / STATS_PERIOD
                mean_train_mae = running_train_mae / STATS_PERIOD

            train_losses.append(mean_train_loss)
            train_mae_losses.append(mean_train_mae)

            # Test loss
            net.eval()
            test_outputs = net(TESTING_INPUTS, TESTING_HOURS_OF_DAY, TESTING_CLEARSKY)
            test_loss = loss_func(test_outputs, TESTING_TARGET).item()
            test_losses.append(test_loss)
            test_mae = mae_loss_func(test_outputs, TESTING_TARGET).item()
            test_mae_losses.append(test_mae)

            print(
                '\n        time =   {:.2f} milli seconds per batch.\n'
                '   train loss = {:8.5f}\n'
                '    train MAE = {:8.5f}\n'
                '    test loss = {:8.5f}\n'
                '     test MAE = {:8.5f}'.format(
                    ((t1 - t0) / STATS_PERIOD) * 1000,
                    mean_train_loss, 
                    mean_train_mae,
                    test_loss,
                    test_mae
                ))
            running_train_loss = 0.0
            running_train_mae = 0.0
            t0 = time.time()
          
    print()
    print('Loading new data!')
    cpu_super_batch = load_data_into_cpu_super_batch()
    cpu_super_batch = compute_hour_of_day(cpu_super_batch)
    cpu_super_batch = compute_clearsky(cpu_super_batch)
    gpu_super_batch = move_superbatch_to_gpu(cpu_super_batch)

print()
print('Finished Training')

In [None]:
fig, (ax1, ax2) = plt.subplots(nrows=2, sharex=True, figsize=(20, 10))

# (axes, testing curve, training curve, title, y label) for each panel
loss_panels = [
    (ax1, test_losses, train_losses, 'MSE (training objective)', 'MSE'),
    (ax2, test_mae_losses, train_mae_losses, 'MAE', 'MAE'),
]
for panel_ax, test_curve, train_curve, title, ylabel in loss_panels:
    panel_ax.plot(test_curve, label='testing')
    panel_ax.plot(train_curve, label='training')
    panel_ax.set_title(title)
    panel_ax.set_ylabel(ylabel)
    panel_ax.legend()

In [None]:
i = 15
plt.imshow(inputs[i, 0].cpu().numpy().astype(np.float32))
dt = gpu_super_batch['datetime_index'][batch_index]
dt[i]

In [None]:
target[i]

In [None]:
clearsky_for_batch[i]

In [None]:
np.corrcoef(target.cpu(), clearsky_for_batch[:, 0].cpu())