In [None]:
import sys
from datetime import datetime, timezone

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn

# Make the parent directory importable so the local packages below
# (config, data, submission, util) resolve.
sys.path.append('..')

import submission.keys as keys
from config import get_config
from data.random_data import get_dataloaders
from submission.resnet import ResNetPV as Model
from util import util

In [None]:
# Model/data configuration for this evaluation run.
config = get_config('../configs/resnet.yaml', [])
# Presumably caps how many samples the eval loader draws — confirm in data.random_data.
config.data.eval_subset_size = 50_000

# Weights to evaluate (loaded in the next cell).
ckpt_path = '../ckpts/resnext50_imstoopid.pt.best_ema'

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = Model(config.model.config).to(device)
# map_location remaps the saved tensors onto `device`; without it torch.load tries to
# restore them onto the device they were saved from, which fails for a GPU-saved
# checkpoint on a CPU-only machine even though this cell otherwise falls back to CPU.
model.load_state_dict(torch.load(ckpt_path, map_location=device))
model.eval()

# With load_train=False this is used below as a single evaluation loader.
dataloader = get_dataloaders(
    config=config,
    meta_features=keys.META,
    nonhrv_features=model.REQUIRED_NONHRV,
    weather_features=model.REQUIRED_WEATHER,
    future_features=None,
    load_train=False,
)

In [None]:
def eval(dataloader, model, criterion=nn.L1Loss()):
    """Run `model` over every batch of `dataloader` and collect predictions and targets.

    Note: this name shadows Python's built-in ``eval`` for the rest of the notebook.

    Args:
        dataloader: iterable of ``(pv_features, meta, nonhrv, weather, pv_targets)``
            batches that exposes a sized ``.dataset``. Batches are written into the
            output arrays in iteration order, so rows only line up with
            ``dataloader.dataset`` indices if the loader does not shuffle.
        model: module called as ``model(pv_features, meta, nonhrv, weather)``;
            it is switched to eval mode.
        criterion: per-batch loss; each batch's value is weighted by its size.

    Returns:
        ``(mean_loss, preds, gt)``: the sample-weighted mean of ``criterion`` over all
        batches, and ``(len(dataloader.dataset), 48)`` arrays of predictions and
        targets. ``mean_loss`` is NaN when the loader yields no samples.
    """
    model.eval()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tot_loss_4h, count = 0.0, 0

    n_samples = len(dataloader.dataset)
    gt = np.zeros((n_samples, 48))
    preds = np.zeros((n_samples, 48))
    with torch.no_grad():
        for pv_features, meta, nonhrv, weather, pv_targets in dataloader:
            meta = util.dict_to_device(meta)
            nonhrv = util.dict_to_device(nonhrv)
            weather = util.dict_to_device(weather)
            pv_features = pv_features.to(device, dtype=torch.float)
            pv_targets = pv_targets.to(device, dtype=torch.float)

            predictions = model(pv_features, meta, nonhrv, weather)
            loss_4h = criterion(predictions, pv_targets)

            # `count` is the number of rows already filled, so it doubles as the write
            # offset. Unlike `i * dataloader.batch_size`, this stays correct when
            # batch_size is None (batch_sampler loaders) or batches vary in size.
            size = int(pv_targets.size(0))
            gt[count:count + size] = pv_targets.cpu().numpy()
            preds[count:count + size] = predictions.cpu().numpy()

            tot_loss_4h += float(loss_4h) * size
            count += size

    if count == 0:
        # Nothing was evaluated; mirror np.mean's convention for empty input.
        return float("nan"), preds, gt

    val_loss_4h = tot_loss_4h / count
    return val_loss_4h, preds, gt

In [None]:
# Evaluate once; preds/gt are kept for the per-sample analysis in later cells.
loss, preds, gt = eval(dataloader, model=model)
loss

In [None]:
# Per-sample mean absolute error: average |pred - target| across each row's 48 steps.
losses = np.mean(np.abs(preds - gt), axis=1)
losses.shape

In [None]:
# Histogram of per-sample MAE. Uses an explicit Axes (instead of `_ = plt.hist(...)`,
# which clobbers IPython's `_` output cache) and labels the figure so it stands alone.
fig, ax = plt.subplots()
ax.hist(losses, bins=100)
ax.set(title='Distribution of per-sample MAE', xlabel='Mean absolute error', ylabel='Number of samples')
plt.show()

# Loss by time of day and day of year

In [None]:
# Mean per-sample loss binned by hour of day.
# Timestamps are re-read from `dataloader`, which assumes it yields samples in the same
# order `eval` filled `losses` (NOTE(review): confirm the eval loader does not shuffle).
# Hours are taken in UTC: `datetime.fromtimestamp` without `tz` converts to the host
# machine's local time, which would make this plot depend on where it is run.
sample_hours = np.array(
    [
        datetime.fromtimestamp(ts, tz=timezone.utc).hour
        for _, meta, _, _, _ in dataloader
        for ts in meta[keys.META.TIME]
    ],
    dtype=int,
)
assert len(sample_hours) == len(losses), "dataloader yielded a different number of samples than eval"

hour_loss_sums = np.bincount(sample_hours, weights=losses, minlength=24)
hour_counts = np.bincount(sample_hours, minlength=24)
with np.errstate(invalid='ignore', divide='ignore'):
    tod_losses = hour_loss_sums / hour_counts  # NaN (a gap in the plot) for hours with no samples

fig, ax = plt.subplots()
ax.plot(range(24), tod_losses)
ax.set(title='Loss by Time of Day', xlabel='Hour of day (UTC)', ylabel='Mean absolute error')
plt.show()
plt.title('Loss by Time of Day')

In [None]:
# Mean per-sample loss binned by day of year (tm_yday 1-366 -> bins 0-365).
# Timestamps are re-read from `dataloader`, which assumes it yields samples in the same
# order `eval` filled `losses` (NOTE(review): confirm the eval loader does not shuffle).
# Dates are taken in UTC so the binning does not depend on the host machine's timezone.
sample_days = np.array(
    [
        datetime.fromtimestamp(ts, tz=timezone.utc).timetuple().tm_yday - 1
        for _, meta, _, _, _ in dataloader
        for ts in meta[keys.META.TIME]
    ],
    dtype=int,
)
assert len(sample_days) == len(losses), "dataloader yielded a different number of samples than eval"

doy_loss_sums = np.bincount(sample_days, weights=losses, minlength=366)
doy_counts = np.bincount(sample_days, minlength=366)
with np.errstate(invalid='ignore', divide='ignore'):
    doy_losses = doy_loss_sums / doy_counts  # NaN (a gap in the plot) for days with no samples

fig, ax = plt.subplots()
ax.plot(range(366), doy_losses)
ax.set(title='Loss by Day of Year', xlabel='Day of year (0 = 1 Jan, UTC)', ylabel='Mean absolute error')
plt.show()

# Example visualizations: poorly predicted samples

In [None]:
# Plot one randomly chosen sample from the worst-scoring 10% of the evaluation set.
# Row indices of `losses`/`preds` are used to index `dataloader.dataset`, which assumes
# `eval` iterated the loader in dataset order (NOTE(review): confirm no shuffling).
SEED = None  # set to an int for a reproducible pick; None draws a new sample each run
rng = np.random.default_rng(SEED)

worst_inds = np.argsort(losses)[::-1]
# max(1, ...) keeps the draw range non-empty when there are fewer than 10 samples.
i = rng.integers(0, max(1, len(worst_inds) // 10))
ind = worst_inds[i]
pv, meta, nonhrv, weather, targets = dataloader.dataset[ind]


def to_np(a):
    """Return `a` as a NumPy array, detaching and moving it to the CPU first if it is a tensor."""
    if isinstance(a, torch.Tensor):
        return a.detach().cpu().numpy()
    return np.asarray(a)


pv_feature = to_np(pv)
pv_target = to_np(targets)
pred = preds[ind]
vis008 = to_np(nonhrv[keys.NONHRV.VIS008])

# PV history on steps 0-11, then the 48-step target and prediction on steps 12-59.
fig, ax = plt.subplots()
ax.plot(np.arange(0, 12), pv_feature, color='black', label="features")
ax.plot(np.arange(12, 60), pv_target, color='green', label="target")
ax.plot(np.arange(12, 60), pred, color='red', label="prediction")
# Join the last input point to the first target point so the curve reads as continuous.
ax.plot([11, 12], [pv_feature[-1], pv_target[0]], color='black')
ax.set(title=f'Sample {ind} (MAE {losses[ind]:.4f})', xlabel='Time step', ylabel='PV')
ax.legend()
plt.tight_layout()
plt.show()

# VIS008 frames tiled left to right via hstack; assumes (time, height, width) — TODO confirm.
plt.figure(figsize=(20, 10))
plt.imshow(np.hstack(vis008))
plt.axis('off')
plt.tight_layout()
plt.show()