In [40]:
from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange, Reduce

from config_data.config import load_config
from utils import pair

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import os
from pathlib import Path
import h5py

import matplotlib.pyplot as plt
from tqdm import tqdm

In [9]:
# Load the experiment configuration via the project helper `load_config`.
# Fields read later in this notebook: pred_len, param_len, patch_size, memory,
# emb_size, image_size (consumed by Model) and gpu (used as the torch device).
configs = load_config()

In [43]:
from torch.utils.data import Dataset, DataLoader, random_split

# Relative path to the HDF5 file consumed by H5Dataset; the file must contain
# the 'X_data' and 'y_data' datasets.
h5_file_path = 'pdsi_data_dims_60_90_train_frames_16.h5'

class H5Dataset(Dataset):
    """Map-style dataset over the 'X_data' / 'y_data' arrays of an HDF5 file.

    The file is opened read-only when the dataset is constructed and stays
    open until `close()` is called (or the dataset is used as a context
    manager), so samples are read from disk on demand.

    Args:
        h5_file_path: path of the HDF5 file; it must contain the datasets
            'X_data' and 'y_data'.

    Raises:
        KeyError: if one of the expected datasets is missing (the file handle
            is released before the error propagates).
    """

    def __init__(self, h5_file_path):
        self.data = h5py.File(h5_file_path, 'r')
        try:
            self.x_data = self.data['X_data']
            self.y_data = self.data['y_data']
        except Exception:
            # Do not leak the open handle when the file has an unexpected layout.
            self.close()
            raise

    def __len__(self):
        """Number of samples, taken from the first axis of 'X_data'."""
        return len(self.x_data)

    def __getitem__(self, index):
        """Return the `(x, y)` pair at `index`, each with a new leading axis of size 1."""
        x_item = np.expand_dims(self.x_data[index], axis=0)
        y_item = np.expand_dims(self.y_data[index], axis=0)
        return x_item, y_item

    def close(self):
        """Release the underlying file handle; safe to call more than once."""
        handle = getattr(self, 'data', None)
        if handle is not None:
            handle.close()
            # Drop the reference so repeated calls are no-ops.
            self.data = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

dataset = H5Dataset(h5_file_path)

# Fixed-seed generator so the 70/30 split is identical on every run.
seed = torch.Generator()
seed.manual_seed(42)
# NOTE(review): fractional lengths in random_split need a recent PyTorch
# release (1.13+, to be confirmed against the installed version).
train_dataset, test_dataset = random_split(dataset, [0.7, 0.3], generator=seed)

train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
# The held-out split is used for validation: keep its order fixed so the
# per-epoch validation loss (a mean of per-batch means) is comparable between
# epochs and does not depend on which samples land in the last, smaller batch.
val_loader = DataLoader(test_dataset, batch_size=4, shuffle=False)

In [36]:
class Model(nn.Module):
    """
    Patch-wise linear forecaster built from two stacked Linear layers.

    The 5-D input (batch, 1, ts, H, W) is cut into non-overlapping
    patch_size x patch_size patches. Every patch, flattened together with its
    time axis, is embedded into `emb_size` features and then mapped to
    `pred_len * patch_size**2` values, which are stitched back into an
    output of shape (batch, 1, pred_len, image_size_h, image_size_w).
    There is no activation between the two Linear layers.

    Config fields read: pred_len, param_len, patch_size, memory, emb_size,
    image_size.
    """
    def __init__(self, configs):
        super().__init__()
        self.pred_len = configs.pred_len
        self.param_len = configs.param_len
        self.patch_size = configs.patch_size
        self.memory = configs.memory
        self.emb_size = configs.emb_size

        # `pair` is a project helper; presumably it expands image_size into
        # (height, width) -- verify against utils.pair.
        self.image_size_h, self.image_size_w = pair(configs.image_size)
        p = self.patch_size
        height_ok = self.image_size_h % p == 0
        width_ok = self.image_size_w % p == 0
        assert height_ok and width_ok, 'image dimensions must be divisible by the patch size'

        # Split the frames into patches and flatten each patch across time;
        # the Linear layer expects one * ts == memory input frames.
        to_patches = Rearrange(
            'batch_size one ts (h p1) (w p2) -> batch_size (h w) (one ts p1 p2)',
            p1 = p, p2 = p,
        )
        embed = nn.Linear(self.memory * p * p, self.emb_size)
        self.projection = nn.Sequential(to_patches, embed)

        # Per-patch prediction head: emb_size -> pred_len flattened patches.
        self.linear = nn.Linear(self.emb_size, self.pred_len * p**2)

        # Inverse of the patching step, using the patch grid from the config.
        grid_h = self.image_size_h // p
        grid_w = self.image_size_w // p
        self.convertion = Rearrange(
            'batch_size (h w) (ts p1 p2) -> batch_size 1 ts (h p1) (w p2)',
            h = grid_h, w = grid_w, p1 = p, p2 = p,
        )

    def forward(self, x):
        # x: [batch, 1, ts, H, W]
        patch_embeddings = self.projection(x)           # [batch, h*w, emb_size]
        patch_forecasts = self.linear(patch_embeddings) # [batch, h*w, pred_len*p*p]
        # -> [batch, 1, pred_len, image_size_h, image_size_w]
        return self.convertion(patch_forecasts)

In [37]:
# Instantiate the patch-based linear model with the hyperparameters from `configs`.
model = Model(configs)

In [41]:
# Target device for both the model and the batches (taken from the config).
device = configs.gpu
# The training loop moves every batch to `device`, so the parameters must live
# there as well; move the model before the optimizer is built from them.
model = model.to(device)

optim = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=1e-5) # TODO
criterion = nn.MSELoss(reduction="mean") # TODO

In [46]:
train_loss_epoch = []
val_loss_epoch = []
# Effectively "train until the cell is interrupted manually".
epochs = 999999999

# Destination of the best checkpoint; create the directory up front so the
# first torch.save cannot fail on a missing folder.
best_model_path = "./output/eathtransformer/temp/best_model2_2.pth"
Path(best_model_path).parent.mkdir(parents=True, exist_ok=True)
# Lowest validation loss seen so far. A plain `<` comparison never accepts
# NaN, so a diverged epoch cannot overwrite the best checkpoint.
best_val_loss = float("inf")

for epoch in tqdm(range(epochs)):
    #########
    # TRAIN #
    #########
    train_loss = []
    model.train()

    for batch in tqdm(train_loader, leave=False):
        optim.zero_grad()
        # The loader yields numpy-backed tensors; cast to float32 on the target device.
        batch[0], batch[1] = batch[0].to(device).float(), batch[1].to(device).float()
        pred = model(batch[0])
        loss = criterion(pred, batch[1])

        loss.backward()
        # Clip the global gradient norm to 1.0 before the parameter update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optim.step()
        train_loss.append(loss.item())

    train_loss_epoch.append(np.mean(train_loss))

    #######
    # VAL #
    #######
    model.eval()
    val_loss = []

    with torch.no_grad():
        for batch in tqdm(val_loader, leave=False):
            batch[0], batch[1] = batch[0].to(device).float(), batch[1].to(device).float()

            pred = model(batch[0])
            loss = criterion(pred, batch[1])
            val_loss.append(loss.item())

    val_loss_epoch.append(np.mean(val_loss))

    # Keep only the weights of the epoch with the strictly lowest validation loss.
    if val_loss_epoch[-1] < best_val_loss:
        best_val_loss = val_loss_epoch[-1]
        torch.save(model.state_dict(), best_model_path)
    print("train loss:", np.round(train_loss_epoch[-1], 5), "\tval loss:", np.round(val_loss_epoch[-1], 5))

  0%|          | 0/999999999 [00:13<?, ?it/s]


KeyboardInterrupt: 