# Ozone Transformer Example

## Environment Setup

1. Create a Python virtual environment named `venv`
   - `python -m venv venv`
2. Activate the new virtual environment `venv`
   - `source venv/bin/activate`
3. Install the required Python libraries from the `ot_requirements.txt` file
   - `pip install -r ot_requirements.txt`
4. Select `venv` as the Python kernel for this Jupyter notebook

In [1]:
import functools

from tqdm import tqdm

import numpy as np
import xarray as xr
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

import torch
from torch.utils.data import DataLoader, TensorDataset


## Prepare Dataset

In [2]:
# Open the ozone NetCDF file as an xarray Dataset. The path is relative to the
# notebook's working directory.
xr_ot = xr.open_dataset("data/mljc_workshop_o3_L25.nc")

In [3]:
# Pull the raw array behind the "SpeciesConcVV_O3" variable out of the dataset.
ozone_arr = xr_ot["SpeciesConcVV_O3"].data
# Drop the last two entries along axis 1, presumably so that its length is a
# multiple of the block size 11 used when the array is split into patches.
# NOTE(review): assumes a 3-D array, presumably (time, lat, lon) — confirm.
ozone_arr_trimmed = ozone_arr[:, :-2, :]

In [4]:
# Split axis 1 into (block, 11) and axis 2 into (block, 12), then bring the
# axes into the order (axis-2 block, axis-1 block, axis 0, 11, 12) with a
# single transpose.
ozone_chunked = ozone_arr_trimmed.reshape(
    ozone_arr_trimmed.shape[0],
    ozone_arr_trimmed.shape[1] // 11, 11,
    ozone_arr_trimmed.shape[2] // 12, 12,
).transpose(3, 1, 0, 2, 4)
# Merge the two block axes into one leading "patch" axis, keeping the
# remaining (axis 0, 11, 12) axes untouched.
ozone_chunked = ozone_chunked.reshape(-1, *ozone_chunked.shape[2:])

In [5]:
# Drop the last 20 steps along axis 1, presumably so that the remaining length
# is a multiple of the 100-step window the series is cut into next.
ozone_chunked_trimmed = ozone_chunked[:, :-20, :, :]

In [6]:
# Cut axis 1 into consecutive windows of 100 steps, then fold the leading axis
# and the window index together so every row is one 100-step sequence of maps.
ozone_chunked = ozone_chunked_trimmed.reshape(
    (
        ozone_chunked_trimmed.shape[0],
        ozone_chunked_trimmed.shape[1] // 100,
        100,
        *ozone_chunked_trimmed.shape[2:4],
    )
).reshape((-1, 100, *ozone_chunked_trimmed.shape[2:4]))

In [7]:
# Flatten everything after the first two axes into a single feature axis.
# The two leading axes are named explicitly (rather than `shape[:-2]`) so the
# cell is idempotent: re-running it on the already-flattened array is a no-op
# instead of also collapsing the sequence axis into the features.
ozone_chunked = ozone_chunked.reshape(
    ozone_chunked.shape[0], ozone_chunked.shape[1], -1
)

In [8]:
# Shape of one spatial patch; matches the block sizes 11 and 12 used when the
# array is split into patches.
MAP_SHAPE = (11, 12)

In [9]:
# Sanity check: display the shape of the prepared dataset.
ozone_chunked.shape

(696, 100, 132)

In [10]:
def shift_tensor_right(x, shift, dim):
    """Shift ``x`` by ``shift`` positions toward higher indices along ``dim``.

    The vacated leading positions are filled with zeros and the entries pushed
    past the end are dropped, so the result has the same shape as ``x``.

    Args:
        x: Input tensor.
        shift: Number of positions to shift by.
        dim: Dimension along which to shift.

    Returns:
        A tensor shaped like ``x``; all zeros when ``shift`` is at least the
        size of ``dim``.
    """
    length = x.size(dim)
    remaining = length - shift
    if remaining <= 0:
        # Nothing of the original survives the shift.
        return torch.zeros_like(x)

    # Part of the input that stays inside the tensor after shifting.
    kept = x.narrow(dim, 0, remaining)
    # Zero block of width `shift` with x's dtype and device.
    leading_zeros = torch.zeros_like(x.narrow(dim, 0, shift))

    return torch.cat((leading_zeros, kept), dim=dim)

# Convenience wrapper: shift by one position along dim 1 (the sequence axis of
# the notebook's (batch, sequence, embedding) tensors).
shift_sequence_dim_right = functools.partial(shift_tensor_right, shift=1, dim=1)


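Each tensor in the dataset has shape `(batch_size, sequence_size, embedding_size)`.

We have the following dataset:
- input: `ozone_chunked`
- output: `ozone_chunked` shifted right by one step along the sequence dimension (built with `shift_sequence_dim_right`)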

In [11]:
# Hold out 20% of the sequences for testing. `random_state` is fixed so the
# split (and everything downstream) is reproducible after Restart & Run All.
SRC_train, SRC_test = train_test_split(
    ozone_chunked, train_size=0.8, random_state=42
)

In [12]:
# Convert the splits to float32 tensors, the dtype the model's parameters use.
# `torch.as_tensor` accepts both arrays and tensors, so re-running this cell
# is a no-op instead of triggering torch.tensor's copy-construct warning.
SRC_train = torch.as_tensor(SRC_train, dtype=torch.float32)
SRC_test = torch.as_tensor(SRC_test, dtype=torch.float32)

In [13]:
# Targets are the sources delayed by one step along the sequence axis; the
# first step of every target sequence is zero-filled.
TGT_train = shift_sequence_dim_right(SRC_train)
TGT_test = shift_sequence_dim_right(SRC_test)

The data is divided into train and test sets (a validation set is left out for simplicity):
- Train:
    - `SRC_train`
    - `TGT_train`
- Test:
    - `SRC_test`
    - `TGT_test`

In [14]:
# Pair every source sequence with its shifted target.
dataset_train = TensorDataset(SRC_train, TGT_train)
dataset_test = TensorDataset(SRC_test, TGT_test)

# Shuffle only the training data. The test loader keeps a fixed order so the
# batch-averaged evaluation metrics are deterministic (the last batch is
# smaller, so batch composition affects the mean of per-batch means).
dataloader_train = DataLoader(dataset_train, batch_size=16, shuffle=True)
dataloader_test = DataLoader(dataset_test, batch_size=16, shuffle=False)

## Model Setup

In [15]:
def get_sinusoidal_positional_embedding(sequence_size, embedding_size):
    """Build a sinusoidal positional-encoding table.

    Even columns hold ``sin(pos / 10000**(i / embedding_size))`` and odd
    columns hold the matching ``cos`` value, where ``i`` is the even column
    index.

    Args:
        sequence_size: Number of positions (rows) in the table.
        embedding_size: Width of each encoding vector; may be even or odd.

    Returns:
        A float tensor of shape ``(sequence_size, embedding_size)``.
    """
    pos = torch.arange(sequence_size, dtype=torch.float).unsqueeze(1)
    i = torch.arange(0, embedding_size, 2, dtype=torch.float)
    angle_rates = 1 / (10000 ** (i / embedding_size))

    angle_rads = pos * angle_rates  # [sequence_size, ceil(embedding_size / 2)]

    pe = torch.zeros(sequence_size, embedding_size)
    pe[:, 0::2] = torch.sin(angle_rads)
    # For odd widths there is one fewer odd column than even columns, so the
    # cosine part is truncated to the number of odd columns.
    pe[:, 1::2] = torch.cos(angle_rads[:, : embedding_size // 2])

    return pe  # [sequence_size, embedding_size]

In [16]:
# Model dimensions read off the prepared tensors: feature width (last axis)
# and sequence length (axis 1).
EMBEDDING_SIZE = SRC_test.size(2)
SEQUENCE_SIZE = SRC_train.size(1)

In [17]:
class TransformerModel(torch.nn.Module):
    """Encoder-decoder transformer driven by a sequence and its shifted copy.

    The input gets a positional encoding added, is passed to the encoder
    as-is, and is passed to the decoder after ``sequence_dim_shift_func`` has
    been applied to it.

    Args:
        embedding_size: Feature width of the input (the transformer's d_model).
        n_head: Number of attention heads.
        n_encoder_layer: Number of encoder layers.
        n_decoder_layer: Number of decoder layers.
        n_feedforward: Hidden width of the feed-forward sublayers.
        positional_embedding_func: Callable ``(sequence_size, embedding_size)``
            returning a tensor that is added to the input.
        sequence_dim_shift_func: Callable applied to the encoded input to
            produce the decoder input.
    """

    def __init__(
            self,
            embedding_size,
            n_head, n_encoder_layer, n_decoder_layer, n_feedforward,
            positional_embedding_func,
            sequence_dim_shift_func,
        ):
        super().__init__()

        self.positional_encoding_func = positional_embedding_func
        self.sequence_dim_shift_func = sequence_dim_shift_func

        self.transformer = torch.nn.Transformer(
            d_model=embedding_size,
            nhead=n_head,
            num_encoder_layers=n_encoder_layer,
            num_decoder_layers=n_decoder_layer,
            dim_feedforward=n_feedforward,
            batch_first=True,
        )

    def forward(self, x):
        """Return the transformer output for a batch-first input ``x``.

        ``x`` is indexed as ``(batch, sequence, embedding)``: its axis-1 and
        axis-2 sizes are handed to the positional-encoding function. The
        caller's tensor is left unmodified.
        """
        positional = self.positional_encoding_func(x.shape[1], x.shape[2])
        # Match the input's device and dtype so the model also works off-CPU.
        positional = positional.to(device=x.device, dtype=x.dtype)

        # Out-of-place add: `x += ...` would silently overwrite the caller's
        # tensor and accumulate the encoding on every repeated call.
        x = x + positional
        x_shifted = self.sequence_dim_shift_func(x)

        # NOTE(review): no causal `tgt_mask` is passed, so the decoder can
        # attend to every position of `x_shifted` — confirm this is intended.
        return self.transformer(x, x_shifted)

## Training

In [18]:
# Instantiate the model with explicit keyword arguments so each
# hyperparameter is labelled at the call site.
model = TransformerModel(
    embedding_size=EMBEDDING_SIZE,
    n_head=12,
    n_encoder_layer=6,
    n_decoder_layer=6,
    n_feedforward=2048,
    positional_embedding_func=get_sinusoidal_positional_embedding,
    sequence_dim_shift_func=shift_sequence_dim_right,
)
# Mean-squared error between predicted and target sequences.
loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [19]:
# Show the module hierarchy. Printing `model.parameters` (without calling it)
# only displays the bound-method repr wrapped around this same summary.
print(model)

<bound method Module.parameters of TransformerModel(
  (transformer): Transformer(
    (encoder): TransformerEncoder(
      (layers): ModuleList(
        (0-5): 6 x TransformerEncoderLayer(
          (self_attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=132, out_features=132, bias=True)
          )
          (linear1): Linear(in_features=132, out_features=2048, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
          (linear2): Linear(in_features=2048, out_features=132, bias=True)
          (norm1): LayerNorm((132,), eps=1e-05, elementwise_affine=True)
          (norm2): LayerNorm((132,), eps=1e-05, elementwise_affine=True)
          (dropout1): Dropout(p=0.1, inplace=False)
          (dropout2): Dropout(p=0.1, inplace=False)
        )
      )
      (norm): LayerNorm((132,), eps=1e-05, elementwise_affine=True)
    )
    (decoder): TransformerDecoder(
      (layers): ModuleList(
        (0-5): 6 x TransformerDecoderLayer(
   

In [20]:
# Sanity check: display the shape of the training inputs.
SRC_train.shape

torch.Size([556, 100, 132])

In [21]:
def train_epoch(model, train_loader, loss_fn, optimizer, device=None):
    """Run one optimisation pass over ``train_loader``.

    Args:
        model: Module called as ``model(inputs)``.
        train_loader: Iterable of ``(inputs, targets)`` batches that supports
            ``len()``.
        loss_fn: Callable ``(outputs, targets)`` returning a scalar loss.
        optimizer: Optimizer updating ``model``'s parameters.
        device: Device the batches are moved to. Defaults to the device of
            the model's first parameter (CPU if the model has none), so no
            global ``device`` variable is required.

    Returns:
        Tuple ``(mean_loss, mean_abs_error)``, each averaged over batches.
        The second value is the mean absolute error, since this is a
        regression task: the previous argmax-based "accuracy" compared class
        indices with continuous targets and could not be broadcast against
        them.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.train()
    total_batches = len(train_loader)
    epoch_loss = 0.0
    epoch_mae = 0.0

    for inputs, targets in tqdm(train_loader):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()

        # Forward pass
        outputs = model(inputs)

        # Compute loss and gradients
        loss = loss_fn(outputs, targets)
        loss.backward()

        # Update model parameters
        optimizer.step()

        epoch_loss += loss.item()
        # Detach so the metric does not extend the autograd graph.
        epoch_mae += (outputs.detach() - targets).abs().mean().item()

    return epoch_loss / total_batches, epoch_mae / total_batches

def evaluate(model, test_loader, loss_fn, device=None):
    """Evaluate ``model`` on ``test_loader`` without updating it.

    Args:
        model: Module called as ``model(inputs)``.
        test_loader: Iterable of ``(inputs, targets)`` batches that supports
            ``len()``.
        loss_fn: Callable ``(outputs, targets)`` returning a scalar loss.
        device: Device the batches are moved to. Defaults to the device of
            the model's first parameter (CPU if the model has none), so no
            global ``device`` variable is required.

    Returns:
        Tuple ``(mean_loss, mean_abs_error)``, each averaged over batches.
        The second value is the mean absolute error, since this is a
        regression task: the previous argmax-based "accuracy" compared class
        indices with continuous targets and could not be broadcast against
        them.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    total_batches = len(test_loader)
    epoch_loss = 0.0
    epoch_mae = 0.0

    with torch.no_grad():
        for inputs, targets in tqdm(test_loader):
            inputs, targets = inputs.to(device), targets.to(device)

            # Forward pass
            outputs = model(inputs)

            epoch_loss += loss_fn(outputs, targets).item()
            epoch_mae += (outputs - targets).abs().mean().item()

    return epoch_loss / total_batches, epoch_mae / total_batches