# Dataset Setup

In [1]:
import zarr
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import TensorDataset, DataLoader


# Open the pre-processed zarr archive read-only; the trailing expression renders
# its group/array hierarchy as the cell output.
store = zarr.open(
    'processed_data_cse151b_v2_corrupted_ssp245.zarr',
    mode='r',
)
store.tree()


ModuleNotFoundError: No module named 'zarr'

In [11]:
def train_test_split(data, val_size=120):
    """Split an array whose leading axis indexes four groups into test/train/val.

    Group 1 is returned untouched as the test set. Groups 0, 3 and 2 are
    concatenated (in that order) along their first axis; the last ``val_size``
    rows of that concatenation become the validation set and everything
    before them the training set.

    Parameters
    ----------
    data : array-like
        Indexable by ``data[0]`` .. ``data[3]``; each group must be
        concatenable along axis 0.
    val_size : int, default 120
        Number of trailing rows held out for validation. ``0`` yields an
        empty validation set and keeps every row for training.

    Returns
    -------
    (test, train, val)
        Note the order: test comes first.

    Raises
    ------
    ValueError
        If ``val_size`` is negative.
    """
    if val_size < 0:
        raise ValueError(f"val_size must be non-negative, got {val_size}")

    test = data[1]

    temp = np.concatenate([data[0], data[3], data[2]])
    # Split on an explicit index rather than negative slicing: temp[:-0] would
    # be empty, so val_size=0 would otherwise discard the whole training set.
    # Clamping at 0 keeps the behaviour for arrays shorter than val_size
    # (empty train, everything in val).
    split = max(len(temp) - val_size, 0)
    train = temp[:split]
    val = temp[split:]
    return test, train, val


# Forcing fields read from the store and reshaped to (4, 1021, 3456).
f1, f2, f3 = (
    store[key][:].reshape(4, 1021, 3456)
    for key in ("BC", "SO2", "rsdt")
)

# CO2 / CH4 get a new trailing axis that is repeated 3456 times so they line up
# with the reshaped fields above (assumes the stored arrays are 2-D, (4, 1021)
# -- TODO confirm against the store).
f4, f5 = (
    np.repeat(store[key][:][:, :, np.newaxis], 3456, axis=2)
    for key in ("CO2", "CH4")
)

# The time array is stacked four times along a new leading axis, then repeated
# along a new trailing axis in the same way.
f6 = store["time"][:]
f6 = np.array([f6] * 4)
f6 = np.repeat(f6[:, :, np.newaxis], 3456, axis=2)

# Targets: averaged over axis 2 of the stored array, then reshaped like the inputs.
t1, t2 = (
    store[key][:].mean(axis=2).reshape(4, 1021, 3456)
    for key in ("pr", "tas")
)

# Transposing every array, stacking, and transposing back puts the stacked
# variables on the last axis: (4, 1021, 3456, n_variables).
data = np.array([arr.T for arr in (f1, f2, f3, f4, f5, f6, t1, t2)]).T

X = np.array([arr.T for arr in (f1, f2, f3, f4, f5, f6)]).T
y = t2  # only tas is used as the target

# train_test_split returns (test, train, val) in that order.
X_test, X_train, X_val = train_test_split(X)
y_test, y_train, y_val = train_test_split(y)

# Transformer Model

In [19]:
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
from tqdm import tqdm
import numpy as np

# === Positional encoding (add to input) ===
H, W = 48, 72
# Width of X as stacked in Dataset Setup (f1..f6), before any positional channels.
N_BASE_FEATURES = 6
grid_y, grid_x = np.meshgrid(np.linspace(0, 1, H), np.linspace(0, 1, W), indexing='ij')
pos = np.stack([grid_y, grid_x], axis=-1).reshape(H * W, 2)

# Add 2D position to input (X_train and X_val must be defined already).
# Slice back to the base features first: this cell overwrites X_train / X_val,
# so without the slice every re-run (or any other cell that appends `pos`)
# would stack two more channels and break the model's expected input width.
X_train = np.concatenate(
    [X_train[..., :N_BASE_FEATURES], np.broadcast_to(pos, (X_train.shape[0], H * W, 2))],
    axis=-1,
)
X_val = np.concatenate(
    [X_val[..., :N_BASE_FEATURES], np.broadcast_to(pos, (X_val.shape[0], H * W, 2))],
    axis=-1,
)
assert X_train.shape[-1] == N_BASE_FEATURES + 2, X_train.shape
assert X_val.shape[-1] == N_BASE_FEATURES + 2, X_val.shape

# === Target scaling ===
# Targets are divided by a constant before training; the RMSE printed by the
# training loop multiplies it back in.
SCALE_FACTOR = 300.0
y_train_scaled = y_train / SCALE_FACTOR
y_val_scaled = y_val / SCALE_FACTOR

X_train_tensor = torch.tensor(X_train, dtype=torch.float32)
y_train_tensor = torch.tensor(y_train_scaled, dtype=torch.float32)
X_val_tensor = torch.tensor(X_val, dtype=torch.float32)
y_val_tensor = torch.tensor(y_val_scaled, dtype=torch.float32)

train_dataset = TensorDataset(X_train_tensor, y_train_tensor)
val_dataset = TensorDataset(X_val_tensor, y_val_tensor)
train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=2)

# === Model ===
class PositionWiseTransformer(nn.Module):
    """Conv1d embedding followed by a Transformer encoder and a scalar head.

    Input is ``(B, L, input_dim)``; the output is ``(B, L)`` with one
    predicted value per sequence position.

    Parameters
    ----------
    input_dim : int
        Number of features per position.
    model_dim : int
        Channel width after the convolution and inside the encoder.
    num_heads : int
        Attention heads per encoder layer.
    num_layers : int
        Number of stacked encoder layers.
    dropout : float
        Dropout probability inside each encoder layer.
    """

    def __init__(self, input_dim=8, model_dim=64, num_heads=2, num_layers=1, dropout=0.1):
        super().__init__()
        # Width-5 convolution along the sequence axis; padding=2 keeps the length.
        self.cnn = nn.Conv1d(
            in_channels=input_dim,
            out_channels=model_dim,
            kernel_size=5,
            padding=2,
        )
        self.norm = nn.LayerNorm(model_dim)
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=model_dim,
                nhead=num_heads,
                dim_feedforward=128,
                dropout=dropout,
                batch_first=True,
            ),
            num_layers=num_layers,
        )
        self.output_proj = nn.Linear(model_dim, 1)

    def forward(self, x):
        """Map ``(B, L, input_dim)`` features to ``(B, L)`` predictions."""
        # Conv1d wants channels on axis 1, the encoder wants them last.
        channels_first = torch.transpose(x, 1, 2)                  # (B, input_dim, L)
        embedded = torch.transpose(self.cnn(channels_first), 1, 2)  # (B, L, model_dim)
        encoded = self.transformer(self.norm(embedded))
        return self.output_proj(encoded).squeeze(-1)               # (B, L)

# === Setup ===
# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Default hyper-parameters: input_dim=8, i.e. the 6 stacked features plus the
# 2 positional channels appended earlier in this cell.
model = PositionWiseTransformer().to(device)

# Xavier init
def init_weights(m):
    """Xavier-uniform re-initialise the weight of ``nn.Linear`` modules.

    Intended for ``model.apply``; any module that is not an ``nn.Linear`` is
    left untouched, and biases are never modified.
    """
    if not isinstance(m, nn.Linear):
        return
    nn.init.xavier_uniform_(m.weight)
# Module.apply visits every sub-module, so init_weights reaches each nn.Linear
# in the model, not only the output head.
model.apply(init_weights)

# Mean-squared error on the scaled targets; Adam at a base learning rate of 1e-4
# (the scheduler below scales this during warm-up).
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# LR Warmup scheduler
def lr_lambda(step):
    """Learning-rate multiplier for a linear warm-up.

    Rises linearly from ``1/200`` at step 0 to ``1.0`` at step 199 and stays
    at ``1.0`` afterwards.
    """
    warmup_steps = 200
    fraction = (step + 1) / warmup_steps
    return 1.0 if 1.0 < fraction else fraction
# LambdaLR multiplies the optimizer's base lr by lr_lambda(n), where n counts
# scheduler.step() calls -- the training loop calls it once per batch.
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)

# Release cached, unused GPU memory held by PyTorch's allocator before training.
torch.cuda.empty_cache()

# === Training ===
# 30 epochs of: one optimisation pass over train_loader (the LR scheduler is
# stepped once per batch), then a no-grad pass over val_loader. Both RMSEs are
# reported in original target units by multiplying SCALE_FACTOR back in.
step = 0
for epoch in range(30):
    model.train()
    train_loss = 0
    for xb, yb in tqdm(train_loader, desc=f"Epoch {epoch+1} - Train"):
        xb, yb = xb.to(device), yb.to(device)
        optimizer.zero_grad()
        preds = model(xb)
        loss = criterion(preds, yb)
        loss.backward()
        optimizer.step()
        scheduler.step()
        step += 1
        # criterion returns the mean over the batch. Weight it by the batch
        # size so a short final batch is not over-represented in the epoch
        # average.
        train_loss += loss.item() * xb.shape[0]

    model.eval()
    val_loss = 0
    with torch.no_grad():
        for xb, yb in tqdm(val_loader, desc=f"Epoch {epoch+1} - Val"):
            xb, yb = xb.to(device), yb.to(device)
            val_preds = model(xb)
            val_loss += criterion(val_preds, yb).item() * xb.shape[0]

    # Per-sample averages: the accumulated sums are divided by the number of
    # samples, not the number of batches.
    avg_train_loss = train_loss / len(train_loader.dataset)
    avg_val_loss = val_loss / len(val_loader.dataset)

    print(f"Epoch {epoch+1} | Train RMSE: {SCALE_FACTOR * avg_train_loss**0.5:.2f} | Val RMSE: {SCALE_FACTOR * avg_val_loss**0.5:.2f}")



Epoch 1 - Train: 100%|██████████| 1472/1472 [03:04<00:00,  7.99it/s]
Epoch 1 - Val: 100%|██████████| 60/60 [00:00<00:00, 306.56it/s]


Epoch 1 | Train RMSE: 110.82 | Val RMSE: 20.76


Epoch 2 - Train: 100%|██████████| 1472/1472 [03:04<00:00,  7.97it/s]
Epoch 2 - Val: 100%|██████████| 60/60 [00:00<00:00, 308.43it/s]


Epoch 2 | Train RMSE: 35.16 | Val RMSE: 21.85


Epoch 3 - Train: 100%|██████████| 1472/1472 [03:04<00:00,  7.97it/s]
Epoch 3 - Val: 100%|██████████| 60/60 [00:00<00:00, 307.59it/s]


Epoch 3 | Train RMSE: 26.57 | Val RMSE: 22.89


Epoch 4 - Train: 100%|██████████| 1472/1472 [03:04<00:00,  7.96it/s]
Epoch 4 - Val: 100%|██████████| 60/60 [00:00<00:00, 305.42it/s]


Epoch 4 | Train RMSE: 23.58 | Val RMSE: 21.46


Epoch 5 - Train: 100%|██████████| 1472/1472 [03:05<00:00,  7.94it/s]
Epoch 5 - Val: 100%|██████████| 60/60 [00:00<00:00, 295.35it/s]


Epoch 5 | Train RMSE: 21.95 | Val RMSE: 21.01


Epoch 6 - Train: 100%|██████████| 1472/1472 [03:04<00:00,  7.97it/s]
Epoch 6 - Val: 100%|██████████| 60/60 [00:00<00:00, 307.25it/s]


Epoch 6 | Train RMSE: 20.37 | Val RMSE: 19.43


Epoch 7 - Train: 100%|██████████| 1472/1472 [03:04<00:00,  7.97it/s]
Epoch 7 - Val: 100%|██████████| 60/60 [00:00<00:00, 304.86it/s]


Epoch 7 | Train RMSE: 18.64 | Val RMSE: 17.98


Epoch 8 - Train: 100%|██████████| 1472/1472 [03:04<00:00,  7.97it/s]
Epoch 8 - Val: 100%|██████████| 60/60 [00:00<00:00, 301.15it/s]


Epoch 8 | Train RMSE: 17.54 | Val RMSE: 18.08


Epoch 9 - Train:   6%|▌         | 83/1472 [00:10<02:56,  7.87it/s]


KeyboardInterrupt: 

# Vision Transformer

In [None]:
import torch
import torch.nn as nn
import numpy as np
from torch.utils.data import TensorDataset, DataLoader
from tqdm import tqdm

# === Parameters ===
H, W = 48, 72
PATCH_H, PATCH_W = 6, 6
NUM_PATCHES = (H // PATCH_H) * (W // PATCH_W)
PATCH_SIZE = PATCH_H * PATCH_W
# Width of X as stacked in Dataset Setup (f1..f6), before any positional channels.
N_BASE_FEATURES = 6
IN_CHANNELS = 8  # N_BASE_FEATURES + 2 positional channels
PATCH_DIM = PATCH_SIZE * IN_CHANNELS
EMBED_DIM = 128
# NOTE(review): y as built in Dataset Setup is t2 (tas) only, so the targets
# below carry a single variable -- confirm whether pr is meant to be added.
OUT_DIM = 2  # pr and tas

# === Dataset setup (same as before, with positional encoding) ===
grid_y, grid_x = np.meshgrid(np.linspace(0, 1, H), np.linspace(0, 1, W), indexing='ij')
pos = np.stack([grid_y, grid_x], axis=-1).reshape(H * W, 2)
# Slice back to the base features before appending `pos`. X_train / X_val are
# overwritten here, and they may already carry positional channels from an
# earlier run of this or another model cell; appending again would give 10+
# channels, which the reshape to IN_CHANNELS below would scramble or reject.
X_train = np.concatenate(
    [X_train[..., :N_BASE_FEATURES], np.broadcast_to(pos, (X_train.shape[0], H * W, 2))],
    axis=-1,
)
X_val = np.concatenate(
    [X_val[..., :N_BASE_FEATURES], np.broadcast_to(pos, (X_val.shape[0], H * W, 2))],
    axis=-1,
)
assert X_train.shape[-1] == IN_CHANNELS, X_train.shape
assert X_val.shape[-1] == IN_CHANNELS, X_val.shape

# Targets are divided by a constant before training; the RMSE printed by the
# training loop multiplies it back in.
SCALE_FACTOR = 300.0
y_train_scaled = y_train / SCALE_FACTOR
y_val_scaled = y_val / SCALE_FACTOR

# Inputs are unflattened from (N, H*W, C) to the (N, H, W, C) layout the ViT expects.
X_train_tensor = torch.tensor(X_train, dtype=torch.float32).reshape(-1, H, W, IN_CHANNELS)
y_train_tensor = torch.tensor(y_train_scaled, dtype=torch.float32)
X_val_tensor = torch.tensor(X_val, dtype=torch.float32).reshape(-1, H, W, IN_CHANNELS)
y_val_tensor = torch.tensor(y_val_scaled, dtype=torch.float32)

train_dataset = TensorDataset(X_train_tensor, y_train_tensor)
val_dataset = TensorDataset(X_val_tensor, y_val_tensor)
train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=2)

# === ViT-style model ===
class ViTRegression(nn.Module):
    """Patch-based Transformer encoder that regresses values for every grid cell.

    The ``(B, H, W, C)`` input is cut into non-overlapping ``patch_size`` x
    ``patch_size`` patches, each patch is linearly embedded, a learned
    positional embedding is added, and the patch tokens go through a
    Transformer encoder. The head predicts ``out_dim`` values for *every
    pixel* of each patch, and the patches are stitched back into row-major
    grid order so the output lines up with per-pixel targets.

    Parameters
    ----------
    in_channels : int
        Number of feature channels ``C`` in the input.
    patch_size : int
        Side length of the square patches.
    embed_dim : int
        Token width inside the encoder.
    out_dim : int
        Number of values predicted per grid cell.
    image_size : tuple[int, int]
        ``(H, W)`` of the input grid; both must be multiples of ``patch_size``.

    Raises
    ------
    ValueError
        If ``image_size`` is not divisible by ``patch_size``.
    """

    def __init__(self, in_channels=8, patch_size=6, embed_dim=128, out_dim=2, image_size=(48, 72)):
        super().__init__()
        self.patch_h, self.patch_w = patch_size, patch_size
        self.H, self.W = image_size
        if self.H % self.patch_h or self.W % self.patch_w:
            # A ragged border cannot be mapped back to H*W outputs.
            raise ValueError(
                f"image_size {image_size} must be divisible by patch_size {patch_size}"
            )
        self.grid_h = self.H // self.patch_h
        self.grid_w = self.W // self.patch_w
        self.num_patches = self.grid_h * self.grid_w
        self.out_dim = out_dim
        patch_dim = in_channels * patch_size * patch_size

        self.proj = nn.Linear(patch_dim, embed_dim)
        self.pos_embed = nn.Parameter(torch.randn(self.num_patches, embed_dim))

        encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=4, batch_first=True)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=2)

        # One out_dim-vector per pixel of the patch, not one per patch:
        # the training targets are per grid cell.
        self.regressor = nn.Sequential(
            nn.Linear(embed_dim, embed_dim),
            nn.ReLU(),
            nn.Linear(embed_dim, out_dim * patch_size * patch_size)
        )

    def forward(self, x):
        """Map ``(B, H, W, C)`` inputs to ``(B, H*W, out_dim)`` predictions."""
        B, H, W, C = x.shape
        assert H == self.H and W == self.W

        # Divide into patches: (B, grid_h, grid_w, C, patch_h, patch_w)
        x = x.unfold(1, self.patch_h, self.patch_h).unfold(2, self.patch_w, self.patch_w)
        x = x.contiguous().view(B, -1, C * self.patch_h * self.patch_w)  # (B, N_patches, C*ph*pw)

        x = self.proj(x) + self.pos_embed  # (B, N_patches, D)
        x = self.transformer(x)
        out = self.regressor(x)  # (B, N_patches, ph*pw*out_dim)

        # Un-patchify: interleave patch rows/cols back into full-grid rows/cols
        # so the flattened H*W axis is in the same row-major order as the input.
        out = out.view(B, self.grid_h, self.grid_w, self.patch_h, self.patch_w, self.out_dim)
        out = out.permute(0, 1, 3, 2, 4, 5)  # (B, grid_h, patch_h, grid_w, patch_w, out_dim)
        return out.reshape(B, H * W, self.out_dim)

# === Setup ===
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Size the head from the targets actually loaded: a 2-D y tensor (N, H*W)
# carries one variable per grid cell; a 3-D one carries shape[-1] variables.
N_TARGETS = 1 if y_train_tensor.dim() == 2 else y_train_tensor.shape[-1]
model = ViTRegression(out_dim=N_TARGETS).to(device)

def init_weights(m):
    """Xavier-uniform re-initialise the weight of ``nn.Linear`` modules (biases untouched)."""
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)
model.apply(init_weights)

criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

def lr_lambda(step):
    """Linear warm-up multiplier: ``(step + 1) / 200`` capped at ``1.0``."""
    warmup_steps = 200
    return min((step + 1) / warmup_steps, 1.0)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)

# === Training loop ===
# 30 epochs of: one optimisation pass over train_loader (scheduler stepped per
# batch), then a no-grad pass over val_loader. RMSE is reported in original
# target units by multiplying SCALE_FACTOR back in.
step = 0
for epoch in range(30):
    model.train()
    train_loss = 0
    for xb, yb in tqdm(train_loader, desc=f"Epoch {epoch+1} - Train"):
        # Targets are reshaped to (B, H*W, N_TARGETS) to match the model output.
        xb, yb = xb.to(device), yb.to(device).view(xb.shape[0], -1, N_TARGETS)
        optimizer.zero_grad()
        preds = model(xb)
        loss = criterion(preds, yb)
        loss.backward()
        optimizer.step()
        scheduler.step()
        step += 1
        # Weight the batch-mean loss by batch size so a short final batch is
        # not over-represented in the epoch average.
        train_loss += loss.item() * xb.shape[0]

    model.eval()
    val_loss = 0
    with torch.no_grad():
        for xb, yb in tqdm(val_loader, desc=f"Epoch {epoch+1} - Val"):
            xb, yb = xb.to(device), yb.to(device).view(xb.shape[0], -1, N_TARGETS)
            val_preds = model(xb)
            val_loss += criterion(val_preds, yb).item() * xb.shape[0]

    avg_train_loss = train_loss / len(train_loader.dataset)
    avg_val_loss = val_loss / len(val_loader.dataset)
    rmse_train = SCALE_FACTOR * avg_train_loss**0.5
    rmse_val = SCALE_FACTOR * avg_val_loss**0.5
    print(f"Epoch {epoch+1} | Train RMSE: {rmse_train:.2f} | Val RMSE: {rmse_val:.2f}")


# Result Checker

In [1]:
import matplotlib.pyplot as plt
import seaborn as sns


# First validation sample with a leading batch axis of 1. Its per-sample
# layout is whatever the most recently run model cell gave X_val_tensor.
x_sample = X_val_tensor[0].unsqueeze(0).to(device)
y_true = y_val_tensor[0].cpu().numpy().reshape(48, 72)

# Predict
model.eval()
with torch.no_grad():
    y_pred = model(x_sample).cpu().numpy().squeeze().reshape(48, 72)

# Both maps are in scaled units (target / SCALE_FACTOR). Use one colour range
# for both panels; with independent scales the same colour would stand for
# different values in each plot.
vmin = float(min(y_true.min(), y_pred.min()))
vmax = float(max(y_true.max(), y_pred.max()))

# Plot
fig, axes = plt.subplots(1, 2, figsize=(12, 5))

sns.heatmap(y_true, ax=axes[0], cmap="viridis", vmin=vmin, vmax=vmax)
axes[0].set_title("Ground Truth")

sns.heatmap(y_pred, ax=axes[1], cmap="viridis", vmin=vmin, vmax=vmax)
axes[1].set_title("Model Prediction")

plt.tight_layout()
plt.show()
