In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
import xarray as xr
import yaml
from types import SimpleNamespace
from dask.diagnostics import ProgressBar
from tqdm.notebook import tqdm
import torch.nn as nn
import torch.nn.functional as F
import os
import numpy as np
import pandas as pd
import math

def dict_to_namespace(data):
	"""Recursively turn parsed YAML/JSON-style data into attribute-access objects.

	Dicts become ``SimpleNamespace`` instances (their values converted
	recursively). Lists are rebuilt with each dict element converted; non-dict
	list elements, including nested lists, are kept as-is. Any other value is
	returned unchanged.
	"""
	if isinstance(data, list):
		return [dict_to_namespace(elem) if isinstance(elem, dict) else elem for elem in data]
	if not isinstance(data, dict):
		return data
	converted = {key: dict_to_namespace(value) for key, value in data.items()}
	return SimpleNamespace(**converted)

In [None]:
class ERA5Pollutants(Dataset):
	"""Tabular dataset pairing ERA5 meteorological inputs with pollutant targets.

	Each CSV row is one sample identified by ``time``, ``latitude`` and
	``longitude``; the columns named in ``config.input_vars`` are the model
	inputs and those named in ``config.output_vars`` the targets.

	Args:
		config: namespace with ``load_data``, ``input_vars``, ``output_vars`` and
			``train``/``val``/``test`` entries, each carrying a two-element
			``time_slice`` (inclusive bounds). When ``load_data`` is falsy the
			inputs are extracted from ERA5 (``config.era5.path``) at the rows of
			``config.haaq.path`` (grid spacing ``config.haaq.res``) and the
			resulting table is written to ``config.processed_data``.
		mode: ``'train'``, ``'val'`` or ``'test'`` -- selects the time slice.
		means, stds: per-variable normalisation statistics. Pass the training
			set's ``means``/``stds`` for val/test; if either is None they are
			computed from this split (mean and sample standard deviation).

	Attributes:
		data: the time-filtered, gap-filled and normalised DataFrame.
		means, stds: dicts of the statistics used for normalisation.

	Raises:
		ValueError: for an unknown ``mode``.
	"""

	_MODES = ('train', 'val', 'test')

	def __init__(self, config, mode = 'train', means = None, stds = None):
		super().__init__()
		# Validate before any (potentially very expensive) data loading.
		if mode not in self._MODES:
			raise ValueError('Unknown data mode, use train/val/test')
		# Keep our own copies so __getitem__ never depends on a global `config`.
		self.input_vars = list(config.input_vars)
		self.output_vars = list(config.output_vars)
		all_vars = self.input_vars + self.output_vars

		if config.load_data:
			data = pd.read_csv(config.load_data)
		else:
			data = self._build_processed_frame(config)

		data['time'] = pd.to_datetime(data['time'])
		# pd.Timestamp accepts strings as well as the datetime.date objects that
		# yaml.safe_load produces for unquoted dates.
		start, end = (pd.Timestamp(bound) for bound in getattr(config, mode).time_slice)
		# .copy(): the column writes below must not target a view of the full frame.
		data = data[(data['time'] >= start) & (data['time'] <= end)].copy()

		# Fill target gaps: linear interpolation in row order, then back-fill the
		# leading NaNs interpolation cannot reach.
		data[self.output_vars] = data[self.output_vars].interpolate(axis=0).bfill()

		if means is None or stds is None:
			means = {var: data[var].mean() for var in all_vars}
			stds = {var: data[var].std() for var in all_vars}
		for var in all_vars:
			data[var] = (data[var] - means[var]) / stds[var]

		self.means, self.stds = means, stds
		self.data = data
		# Dense float32 matrices make __getitem__ a cheap row lookup. They are a
		# snapshot: later in-place edits of `self.data` are not reflected.
		self._inputs = data[self.input_vars].to_numpy(dtype=np.float32)
		self._outputs = data[self.output_vars].to_numpy(dtype=np.float32)

	@classmethod
	def _build_processed_frame(cls, config):
		"""Attach the ERA5 input variables to every HAAQ row and cache the table as CSV."""
		assert config.processed_data != '', 'Pass processed_data in the config - file will be stored here'
		data = pd.read_csv(config.haaq.path)

		era5 = xr.open_zarr(config.era5.path)
		# Drop unused data variables first (coordinates such as `level` are kept).
		era5 = era5.drop_vars([var for var in era5.data_vars if var not in config.input_vars])
		# Average consecutive groups of 4 time steps and label each by its calendar day.
		# NOTE(review): assumes 4 steps per day aligned to 00 UTC -- confirm for this store.
		era5 = era5.coarsen(time=4).mean()
		era5 = era5.assign_coords(time=era5['time'].dt.floor('D'))

		# Regrid onto the HAAQ grid: cell-centred latitudes, longitudes on [0, 360).
		new_lats = np.arange(-90 + config.haaq.res / 2, 90, config.haaq.res)
		new_lons = np.arange(0, 360, config.haaq.res)
		era5 = era5.interp(latitude=new_lats, longitude=new_lons, method='linear')

		# Drop every listed pressure level; the remaining one must be unique (see below).
		drop_level = [50, 100, 150, 200, 250, 300, 400, 500, 600, 700, 850, 925]
		era5 = era5.drop_sel(level=drop_level)
		era5.load()

		sampled = era5.isel(cls._point_indexers(era5, data))
		for var in sampled.data_vars:
			values = sampled[var]
			if 'level' in values.dims:
				# Raises if more than one level is left, as the old per-row `.item()` would.
				values = values.squeeze('level', drop=True)
			data[var] = values.to_numpy()

		# index=False: otherwise the reloaded CSV gains a spurious "Unnamed: 0" column.
		data.to_csv(config.processed_data, index=False)
		return data

	@staticmethod
	def _point_indexers(grid, frame):
		"""Build pointwise integer indexers selecting one grid cell per row of ``frame``.

		The day must exist on the grid exactly; latitude and longitude snap to the
		nearest node, so float round-off against the regridded axes cannot make a
		lookup fail.
		"""
		days = pd.DatetimeIndex(pd.to_datetime(frame['time'])).floor('D')
		time_idx = grid.get_index('time').get_indexer(days)
		if (time_idx < 0).any():
			missing = days[time_idx < 0].unique()
			raise KeyError(f'ERA5 has no data for {len(missing)} HAAQ day(s), e.g. {missing[0]}')
		lat_idx = grid.get_index('latitude').get_indexer(frame['latitude'].to_numpy(), method='nearest')
		# `% 360` maps longitudes given on [-180, 180) onto the [0, 360) grid; no-op otherwise.
		lon_idx = grid.get_index('longitude').get_indexer(frame['longitude'].to_numpy() % 360, method='nearest')
		# A shared 'points' dimension makes isel pick (time, lat, lon) triples, not a cube.
		return {
			'time': xr.DataArray(time_idx, dims='points'),
			'latitude': xr.DataArray(lat_idx, dims='points'),
			'longitude': xr.DataArray(lon_idx, dims='points'),
		}

	def __len__(self):
		return len(self.data)

	def __getitem__(self, idx):
		"""Return ``(inputs, outputs)`` float32 tensors for positional row ``idx``."""
		return torch.tensor(self._inputs[idx]), torch.tensor(self._outputs[idx])

	@classmethod
	def collate_fn(cls, batch):
		"""Stack a list of ``(inputs, outputs)`` pairs into two batch tensors."""
		x_batch, y_batch = zip(*batch)
		return torch.stack(x_batch), torch.stack(y_batch)

In [None]:
class ShallowNet(nn.Module):
    """Small MLP with one residual hidden layer.

    ``in_ch -> hidden_dim -> hidden_dim (+ skip) -> out_ch`` with ReLU after the
    first two layers. The attribute names ``mlp1``/``mlp2``/``mlp3`` form the
    ``state_dict`` keys, so they are kept stable for checkpoint compatibility.

    Args:
        in_ch: number of input features.
        out_ch: number of output features.
        hidden_dim: width of both hidden layers (default 5, the original width).
    """

    def __init__(self, in_ch, out_ch, hidden_dim=5):
        super().__init__()
        self.mlp1 = nn.Linear(in_ch, hidden_dim)
        self.mlp2 = nn.Linear(hidden_dim, hidden_dim)
        self.mlp3 = nn.Linear(hidden_dim, out_ch)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Map ``(..., in_ch)`` features to ``(..., out_ch)`` predictions."""
        hidden = self.relu(self.mlp1(x))
        hidden = self.relu(self.mlp2(hidden) + hidden)  # residual connection
        return self.mlp3(hidden)

In [None]:
class Climatology():
	"""Constant baseline predicting the per-variable mean of a reference table.

	Exposes ``eval()`` and ``__call__`` so it can be passed wherever a model is
	evaluated (e.g. ``validate_model``). It is a single global mean per output
	variable; it does not vary with time or location.

	Args:
		config: namespace providing ``output_vars``.
		data: DataFrame containing those columns (e.g. the training split).
	"""

	def __init__(self, config, data):
		# float32 to match the dataset's target tensors; a list of numpy float64
		# means would otherwise yield a float64 tensor.
		self.output = torch.tensor([float(data[var].mean()) for var in config.output_vars], dtype=torch.float32)

	def eval(self):
		"""No-op kept for interface compatibility with ``nn.Module.eval``."""
		return self

	def __call__(self, batch):
		"""Return the constant prediction, shape ``(batch.shape[0], n_outputs)``, on ``batch.device``."""
		return self.output.to(batch.device).unsqueeze(0).repeat(batch.shape[0], 1)

In [None]:
def load_model(model, model_path, map_location=None):
    """Load a saved ``state_dict`` into ``model`` and switch it to eval mode.

    Args:
        model: module whose architecture matches the checkpoint.
        model_path: file written by ``torch.save(model.state_dict(), ...)``.
        map_location: forwarded to ``torch.load``. Pass e.g. ``'cpu'`` to load a
            checkpoint saved on GPU on a machine without CUDA. ``None`` keeps
            the devices recorded in the file.

    Returns:
        The same ``model`` instance, now in eval mode.
    """
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints
    # (recent torch versions also accept weights_only=True for plain state dicts).
    state_dict = torch.load(model_path, map_location=map_location)
    model.load_state_dict(state_dict)
    model.eval()
    print(f"Model loaded from {model_path}")
    return model

In [None]:
def validate_model(config, model, val_dataloader, epoch, device = None):
    """Compute the RMSE of ``model`` over ``val_dataloader``.

    Args:
        config: namespace providing ``output_vars`` (one loss slot per variable).
        model: callable with an ``eval()`` method -- an ``nn.Module`` or a
            baseline such as ``Climatology``. It is left in eval mode.
        val_dataloader: yields ``(x, y)`` batches with ``y`` shaped
            ``(batch, len(output_vars))``.
        epoch: unused; kept for call-site compatibility.
        device: where batches are moved. ``None`` (default) uses the device of
            the model's first parameter, else CUDA when available, else CPU.

    Returns:
        ``(rmse_per_var, rmse_total)``: the per-variable RMSE over all samples,
        and the square root of the mean of the per-variable MSEs.

    Raises:
        ValueError: if the dataloader's dataset is empty.
    """
    if device is None:
        first_param = next(iter(model.parameters()), None) if hasattr(model, 'parameters') else None
        if first_param is not None:
            device = first_param.device
        else:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    n_samples = len(val_dataloader.dataset)
    if n_samples == 0:
        raise ValueError('validate_model: the dataloader has no samples')

    model.eval()
    criterion = nn.MSELoss(reduction='none')
    sq_err_sum = torch.zeros(len(config.output_vars), device=device)
    with torch.no_grad():
        for x, y in tqdm(val_dataloader):
            x, y = x.to(device), y.to(device)
            # Sum squared errors over the batch axis, keeping per-variable totals.
            sq_err_sum += criterion(model(x), y).sum(dim=0)
    mse_per_var = sq_err_sum / n_samples
    return torch.sqrt(mse_per_var), torch.sqrt(mse_per_var.mean())

In [None]:
def train_model(config, model, train_dataloader, val_dataloader, device = None, epochs = 10, learning_rate = 1e-2):
    """Train ``model`` with Adam on MSE, checkpointing the best validation epoch.

    After every epoch the model is scored with ``validate_model``. Whenever the
    validation RMSE improves, the weights are saved to
    ``<config.model_dir>/best_model.pth``.

    Args:
        config: namespace with ``model_dir`` (must exist) plus whatever
            ``validate_model`` reads.
        model: the ``nn.Module`` to train (updated in place).
        train_dataloader, val_dataloader: yield ``(x, y)`` batches.
        device: where batches are moved. ``None`` (default) uses the device of
            the model's first parameter, else CUDA when available, else CPU.
        epochs: number of passes over ``train_dataloader``.
        learning_rate: Adam learning rate.

    Returns:
        ``model`` after the final epoch -- not necessarily the best checkpoint;
        reload ``best_model.pth`` for that.
    """
    if device is None:
        first_param = next(iter(model.parameters()), None)
        if first_param is not None:
            device = first_param.device
        else:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    criterion = nn.MSELoss(reduction='mean')
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    best_path = os.path.join(config.model_dir, 'best_model.pth')
    min_val_loss = float('inf')

    for epoch in range(epochs):
        # validate_model switches the model to eval mode, so re-enable training
        # at the start of every epoch.
        model.train()
        epoch_loss = 0.0
        for x, y in tqdm(train_dataloader):
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            loss = criterion(model(x), y)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()

        _, val_loss = validate_model(config, model, val_dataloader, epoch, device)
        val_loss = float(val_loss)
        # Track the running minimum so only improving epochs overwrite the checkpoint.
        if val_loss < min_val_loss:
            min_val_loss = val_loss
            torch.save(model.state_dict(), best_path)

        print(f"Epoch {epoch + 1}, Train Loss: {epoch_loss / len(train_dataloader)}, Validation Loss: {val_loss}")
    return model

In [None]:
# Load the run configuration; `with` closes the file handle once it has been parsed.
with open('config.yaml') as config_file:
    config = dict_to_namespace(yaml.safe_load(config_file))
# Each job gets its own checkpoint directory under model_dir.
config.model_dir = os.path.join(config.model_dir, config.job_name)
os.makedirs(config.model_dir, exist_ok=True)

In [None]:
# Seed torch so weight init and the shuffling order are reproducible. An optional
# `seed` entry in config.yaml overrides the default of 0.
torch.manual_seed(getattr(config, 'seed', 0))

train_dataset = ERA5Pollutants(config, mode = 'train')
# val reuses the training statistics so both splits share one normalisation.
val_dataset = ERA5Pollutants(config, mode = 'val', means = train_dataset.means, stds = train_dataset.stds)
# Shuffle training batches: rows arrive in CSV file order (likely sorted by
# time/location), which would otherwise bias every SGD step.
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=ERA5Pollutants.collate_fn)
val_dataloader = DataLoader(val_dataset, batch_size=64, shuffle=False, collate_fn=ERA5Pollutants.collate_fn)

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model = ShallowNet(len(config.input_vars), len(config.output_vars)).to(device)
model = train_model(config, model, train_dataloader, val_dataloader, device)

  0%|          | 0/1308 [00:00<?, ?it/s]

  0%|          | 0/72 [00:00<?, ?it/s]

Epoch 1, Train Loss: 0.8840233082982742, Validation Loss: 0.7445871233940125


  0%|          | 0/1308 [00:00<?, ?it/s]

  0%|          | 0/72 [00:00<?, ?it/s]

Epoch 2, Train Loss: 0.8722660151916907, Validation Loss: 0.7468861937522888


  0%|          | 0/1308 [00:00<?, ?it/s]

  0%|          | 0/72 [00:00<?, ?it/s]

Epoch 3, Train Loss: 0.8634095921406141, Validation Loss: 0.7487437129020691


  0%|          | 0/1308 [00:00<?, ?it/s]

  0%|          | 0/72 [00:00<?, ?it/s]

Epoch 4, Train Loss: 0.8570885845895754, Validation Loss: 0.7494612336158752


  0%|          | 0/1308 [00:00<?, ?it/s]

  0%|          | 0/72 [00:00<?, ?it/s]

Epoch 5, Train Loss: 0.8533816747367382, Validation Loss: 0.7504296898841858


  0%|          | 0/1308 [00:00<?, ?it/s]

  0%|          | 0/72 [00:00<?, ?it/s]

Epoch 6, Train Loss: 0.8516010539292924, Validation Loss: 0.7560070753097534


  0%|          | 0/1308 [00:00<?, ?it/s]

  0%|          | 0/72 [00:00<?, ?it/s]

Epoch 7, Train Loss: 0.8520636805176552, Validation Loss: 0.7527524828910828


  0%|          | 0/1308 [00:00<?, ?it/s]

  0%|          | 0/72 [00:00<?, ?it/s]

Epoch 8, Train Loss: 0.8515152884504846, Validation Loss: 0.7549209594726562


  0%|          | 0/1308 [00:00<?, ?it/s]

  0%|          | 0/72 [00:00<?, ?it/s]

Epoch 9, Train Loss: 0.8512806089008256, Validation Loss: 0.7518259882926941


  0%|          | 0/1308 [00:00<?, ?it/s]

  0%|          | 0/72 [00:00<?, ?it/s]

Epoch 10, Train Loss: 0.8495013088853717, Validation Loss: 0.7504407167434692


In [None]:
# Evaluate the best validation checkpoint written by train_model when it exists;
# otherwise fall back to the final-epoch weights currently held in `model`.
best_model_path = os.path.join(config.model_dir, 'best_model.pth')
if os.path.exists(best_model_path):
    model = load_model(model, best_model_path)
climatology = Climatology(config, train_dataset.data)
test_dataset = ERA5Pollutants(config, mode = 'test', means = train_dataset.means, stds = train_dataset.stds)
test_dataloader = DataLoader(test_dataset, batch_size=64, shuffle=False, collate_fn=ERA5Pollutants.collate_fn)
# Pass the device chosen in the training cell so evaluation does not assume CUDA.
test_loss_var, test_loss = validate_model(config, model, test_dataloader, 0, device)
test_loss_climatology_var, test_loss_climatology = validate_model(config, climatology, test_dataloader, 0, device)
print(f"Model: {test_loss_var}, {test_loss}")
print(f'Climatology: {test_loss_climatology_var}, {test_loss_climatology}')

  0%|          | 0/73 [00:00<?, ?it/s]

  0%|          | 0/73 [00:00<?, ?it/s]

Model: tensor([0.8147, 0.7439, 0.7101], device='cuda:0'), 0.7575061321258545
Climatology: tensor([0.8675, 0.7408, 0.7943], device='cuda:0'), 0.8025245070457458
