In [294]:
import torch as T
import torch.nn as nn
#from torchtext import data, datasets
#from torchtext.vocab import Vocab
import torch.optim as optim
import time
import copy
import torch
import torch.nn.functional as F
from torchsummary import summary
import math

# Common imports
import os
import glob
import numpy as np
import pandas as pd
import xarray as xr
import dask
import math
import datetime
from collections import OrderedDict
from datagenerator import *
from util_data import * 

In [295]:
print("Cuda Avaliable :", torch.cuda.is_available())

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

Cuda Avaliable : False
cpu


In [296]:
# open data from WeatherBench
DATADIR = '/Users/noeliaotero/Documents/CAS_ML/WeatherBench-master/data/WeatherBench/5.625deg/'
# Load the entire dataset
z500 = xr.open_mfdataset(f'{DATADIR}geopotential_500/*.nc', combine='by_coords').z
t850 = xr.open_mfdataset(f'{DATADIR}temperature_850/*.nc', combine='by_coords').t.drop('level')
ds = xr.merge([z500, t850])

In [285]:
# only load a subset of the training data
ds_train = ds.sel(time=slice('2015', '2016'))  
ds_test = ds.sel(time=slice('2017', '2018'))

In [286]:

# then we need a dictionary for all the variables and levels we want to extract from the dataset
dic = OrderedDict({'z': None, 't': None})
lead_time =1
bs = 32
# Create a training and validation data generator. Use the train mean and std for validation as well.
dg_train = DataGenerator(
    ds_train.sel(time=slice('2015', '2015')), dic, lead_time, batch_size=bs, load=True)
dg_valid = DataGenerator(
    ds_train.sel(time=slice('2016', '2016')), dic, lead_time, batch_size=bs, mean=dg_train.mean, std=dg_train.std, shuffle=False)

Loading data into RAM
Loading data into RAM


In [298]:
X,y=dg_train[0]
print(X.shape)
Xt = torch.as_tensor(X)
print(Xt.shape)

(32, 2, 32, 64)
torch.Size([32, 2, 32, 64])


In [313]:
class PatchEmbedding(nn.Module):
    """Split image into patches and then embed them.
    Parameters
    ----------
    img_size :  Size of the image (my case: it is not a square).
    patch_size : Size of the patch (my case: it is not a square).
    in_chans : int Number of input channels.
    embed_dim : int The emmbedding dimension. This determines how big an embeding of our patch is going to be.
                Embedding stays constant across entire network ; flattened length for each token (or patch)
    Attributes
    ----------
    n_patches : int
        Number of patches inside of our image.
    proj : nn.Conv2d
        Convolutional layer that does both the splitting into patches
        and their embedding.
    """
    def __init__(self, img_size, patch_size, in_chans, embed_dim):

        self.img_size = img_size
        self.patch_size = patch_size
        self.in_chans = in_chans
        self.embed_dim = embed_dim

        super().__init__()
        def _make_tuple(x):
            if not isinstance(x, (list, tuple)):
                return (x, x)
            return x
        img_size, patch_size = _make_tuple(img_size), _make_tuple(patch_size)
        self.n_patches = (img_size[0] // patch_size[0]) * (
            img_size[1] // patch_size[1])
        print(self.n_patches)
        # here, we take the kernel size and the stride size as the patch_size, this way sliding 
        # the kernel along the input tensor, we'll never slide it in an overlappig way and the 
        # kernel will fall into these patches that we're trying to divide our image into.
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
              

    def forward(self, X):
        """ Run forward pass
        Parameters
        -----------
        x: torch.Tensor
            Shape(n_samples, in_chans, img_size, img_size)
        Returns
        --------
        torch.Tensor (n_smaples, n_patches, emb_dim) """
        # or ..Output shape: (batch size, no. of patches, no. of channels)
        x = torch.as_tensor(X)
        x = self.proj(x) # (n_samples, embed_dim, n_patches ** 0.5, n_patches ** 0.5)
        x = x.flatten(2)  #  we take the last two dimmesions that represent the grid of patches and 
                          #  flatten them into single dim (n_samples, embed_dim, n_patches)
        x = x.transpose(1, 2)  # (n_samples, n_patches, embed_dim)

        
        return x 

In [364]:
class Attention(nn.Module):
    """Attention mechanism.
    Parameters
    ----------
    dim : int
        The input and out dimension of per token features.
    n_heads : int
        Number of attention heads.
    qkv_bias : bool
        If True then we include bias to the query, key and value projections.
    attn_p : float
        Dropout probability applied to the query, key and value tensors.
    proj_p : float
        Dropout probability applied to the output tensor.
    Attributes
    ----------
    scale : float
        Normalizing consant for the dot product.
    qkv : nn.Linear
        Linear projection for the query, key and value.
    proj : nn.Linear
        Linear mapping that takes in the concatenated output of all attention
        heads and maps it into a new space.
    attn_drop, proj_drop : nn.Dropout
        Dropout layers.
    """
    def __init__(self, dim, n_heads, qkv_bias=True, attn_p=0., proj_p=0.):
        super().__init__()
        self.n_heads = n_heads
        self.dim = dim
        self.head_dim = dim // n_heads
        self.scale = self.head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_p)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_p)

    def forward(self, x):
        """Run forward pass.
        Parameters
        ----------
        x : torch.Tensor
            Shape `(n_samples, n_patches + 1, dim)`.
        Returns
        -------
        torch.Tensor
            Shape `(n_samples, n_patches + 1, dim)`.
        """
        n_samples, n_tokens, dim = x.shape
    

        if dim != self.dim:
            raise ValueError

        qkv = self.qkv(x)  # (n_samples, n_patches + 1, 3 * dim)
        qkv = qkv.reshape(
                n_samples, n_tokens, 3, self.n_heads, self.head_dim
        ) 
        # (n_smaples, n_patches + 1, 3, n_heads, head_dim)
        qkv = qkv.permute(
                2, 0, 3, 1, 4
        )  # (3, n_samples, n_heads, n_patches + 1, head_dim)
       
        q, k , v = qkv[0], qkv[1], qkv[2]
        
        k_t = k.transpose(-2, -1)  # (n_samples, n_heads, head_dim, n_patches + 1)
      
        dp = (
           q @ k_t
        ) * self.scale # (n_samples, n_heads, n_patches + 1, n_patches + 1)
        attn = dp.softmax(dim=-1)  # (n_samples, n_heads, n_patches + 1, n_patches + 1)
        attn = self.attn_drop(attn)

        weighted_avg = attn @ v  # (n_samples, n_heads, n_patches +1, head_dim)
        weighted_avg = weighted_avg.transpose(
                1, 2
        )  # (n_samples, n_patches + 1, n_heads, head_dim)
        weighted_avg = weighted_avg.flatten(2)  # (n_samples, n_patches + 1, dim)

        x = self.proj(weighted_avg)  # (n_samples, n_patches + 1, dim)
        x = self.proj_drop(x)  # (n_samples, n_patches + 1, dim)

        return x

In [365]:
class MLP(nn.Module):
    """Multilayer perceptron.
    Parameters
    ----------
    in_features : int
        Number of input features.
    hidden_features : int
        Number of nodes in the hidden layer.
    out_features : int
        Number of output features.
    p : float
        Dropout probability.
    Attributes
    ----------
    fc : nn.Linear
        The First linear layer.
    act : nn.GELU
        GELU activation function.
    fc2 : nn.Linear
        The second linear layer.
    drop : nn.Dropout
        Dropout layer.
    """
    def __init__(self, in_features, hidden_features, out_features, p=0.):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(p)

    def forward(self, x):
        """Run forward pass.
        Parameters
        ----------
        x : torch.Tensor
            Shape `(n_samples, n_patches + 1, in_features)`.
        Returns
        -------
        torch.Tensor
            Shape `(n_samples, n_patches +1, out_features)`
        """
        x = self.fc1(
                x
        ) # (n_samples, n_patches + 1, hidden_features)
        x = self.act(x)  # (n_samples, n_patches + 1, hidden_features)
        x = self.drop(x)  # (n_samples, n_patches + 1, hidden_features)
        x = self.fc2(x)  # (n_samples, n_patches + 1, hidden_features)
        x = self.drop(x)  # (n_samples, n_patches + 1, hidden_features)

        return x

In [380]:
class Block(nn.Module):
    """Transformer block.
    Parameters
    ----------
    dim : int
        Embeddinig dimension.
    n_heads : int
        Number of attention heads.
    mlp_ratio : float
        Determines the hidden dimension size of the `MLP` module with respect
        to `dim`.
    qkv_bias : bool
        If True then we include bias to the query, key and value projections.
    p, attn_p : float
        Dropout probability.
    Attributes
    ----------
    norm1, norm2 : LayerNorm
        Layer normalization.
    attn : Attention
        Attention module.
    mlp : MLP
        MLP module.
    """
    def __init__(self, dim, n_heads, mlp_ratio=4.0, qkv_bias=True, p=0., attn_p=0.):
        # dim = 64
        super().__init__()
       
        self.dim = dim
        self.n_head = n_heads
        self.norm1 = nn.LayerNorm(self.dim, eps=1e-6)    #This is the first layer norm
        # This is the MSA 
        self.attn = Attention(
                self.dim,
                n_heads=self.n_head,
                qkv_bias=qkv_bias,
                attn_p=attn_p,
                proj_p=p
        )

        self.norm2 = nn.LayerNorm(dim, eps=1e-6)   # second layer norm
        hidden_features = int(dim * mlp_ratio)
        # This is where the MLP layer comes in 
        self.mlp = MLP(
                in_features=dim,
                hidden_features=hidden_features,
                out_features=dim,
        )

    def forward(self, x):
        """Run forward pass.
        Parameters
        ----------
        x : torch.Tensor
            Shape `(n_samples, n_patches + 1, dim)`.
        Returns
        -------
        torch.Tensor
            Shape `(n_samples, n_patches + 1, dim)`.
        """
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))

        return x

In [293]:
p = PatchEmbedding((32,64), (16,16), 2, 64)
p(Xt).shape

8


torch.Size([32, 8, 64])

In [367]:
class ViTencoder(nn.Module):
    def __init__(self, embed_dim, n_heads, depth):
        super().__init__()
        self.embed_dim = embed_dim
        self.depth = depth
        self.n_heads = n_heads
        self.layer = nn.ModuleList()
        self.encoder_norm = nn.LayerNorm(self.embed_dim, eps=1e-6)
        for _ in range(self.depth):
            layer = Block(self.embed_dim, self.n_heads)
            self.layer.append(copy.deepcopy(layer))

    def forward(self, hidden_states):


        for layer_block in self.layer:
            hidden_states = layer_block(hidden_states)

        encoded = self.encoder_norm(hidden_states)
        return encoded

In [368]:
class Conv2dReLU(nn.Sequential):
    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            padding=0,
            stride=2,
            use_batchnorm=True,
    ):
        conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            bias=not (use_batchnorm),
        )
        relu = nn.ReLU(inplace=True)

        bn = nn.BatchNorm2d(out_channels)

        super(Conv2dReLU, self).__init__(conv, bn, relu)

In [411]:
# now merge everything 
class ViT(nn.Module):
    def __init__(
            self,
            img_size=(32,64),
            patch_size=(1,2), #(4,8),
            in_chans= 2   ,   #2,    
            embed_dim= 64,       #64,
            depth=12,           #This is 12 for TransUNet
            n_heads=8,
    ):
        super(ViT, self).__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.in_chans = in_chans
        self.depth = depth
        self.n_heads = n_heads
        self.patch_embed = PatchEmbedding(
                img_size=self.img_size,
                patch_size=self.patch_size,
                in_chans=self.in_chans,
                embed_dim=self.embed_dim,
        )
       
        self.encoder = ViTencoder(self.embed_dim, self.n_heads, self.depth)
        
        self.down_factor = 0 # 4, 2
        self.conv_more = Conv2dReLU(self.embed_dim, 128, kernel_size=1, use_batchnorm=True)

    def forward(self, input_ids):
        # embedding_output, features = self.embeddings(input_ids)
        # encoded, attn_weights = self.encoder(embedding_output)  # (B, n_patch, hidden)
        # return encoded, attn_weights, features
        embedding_output = self.patch_embed(input_ids)
        encoded = self.encoder(embedding_output)  # (B, n_patch, hidden)
        print(encoded.shape)
        B, n_patch, hidden = encoded.size()
        l, h = (self.img_size[0] // 2**self.down_factor // self.patch_size[0]), (self.img_size[1] // 2**self.down_factor // self.patch_size[1])
    
        x = encoded.permute(0, 2, 1)
        print(B)
        print(hidden)
        print(l)
        print(h)
        x = x.contiguous().view(B, hidden, l, h)

        x = self.conv_more(x)
        return x

In [328]:
p = PatchEmbedding((32,64), (2,4), 2, 128)
img_size = (32,64)
patch_size = (2,4)

256


In [412]:
encoder = ViT(img_size=(32,64), patch_size=(4,8), in_chans=2, embed_dim=64, depth = 12)

64


In [405]:
(32 // 2**0// 4), (64// 2**0 // 8)

(8, 8)

In [413]:
x = encoder(Xt)

torch.Size([32, 64, 64])
32
64
8
8


In [414]:
x.shape

torch.Size([32, 128, 4, 4])

In [183]:
yy = x.contiguous().view(32, 64, 16, 16)

In [184]:
yy.shape

torch.Size([32, 64, 16, 16])

In [185]:
Conv2dReLU(64, 2, kernel_size=1, use_batchnorm=True)(yy).shape

torch.Size([32, 2, 8, 8])

In [415]:
Xt.shape

torch.Size([32, 2, 32, 64])