In [None]:
# prompt: mount

from google.colab import drive
drive.mount('/content/drive')


Mounted at /content/drive


In [None]:
!pip install imagecodecs --quiet
!pip install timm torchvision --quiet

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m45.6/45.6 MB[0m [31m51.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m2.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m93.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m91.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m56.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m2.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m9.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.3/56.3 MB[0m [31m25.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [None]:
# Cell 2 ▶ Consolidated imports
import os
import math
import time

import pandas as pd
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split

import timm
from torchvision import transforms

from tifffile import imread


In [None]:
# Cell 1 ▶ Dataset (224×224 targets, no down-scaling + 6-h sequence)
class LSTDataset(Dataset):
    def __init__(self, df, patches_dir, weather_cols, n_hours=6):
        self.df           = df.reset_index(drop=True)
        self.patches_dir  = patches_dir
        self.weather_cols = weather_cols
        self.n_hours      = n_hours
        self.n_vars       = len(weather_cols) // n_hours
        self.transform    = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((224,224)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485,0.456,0.406],
                std =[0.229,0.224,0.225],
            ),
        ])

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row  = self.df.loc[idx]
        arr  = imread(os.path.join(self.patches_dir, row["patch_filename"])
                     ).astype(np.float32)            # (4,H,W)

        # ── image input ───────────────────────────────────
        img_np = arr[[1,2,3]].transpose(1,2,0).astype(np.uint8)
        img    = self.transform(img_np)              # [3,224,224]

        # ── target LST ───────────────────────────────────
        tar_np = arr[0]                              # (H,W)
        tar    = torch.tensor(tar_np, dtype=torch.float32).unsqueeze(0)
        tar    = F.interpolate(
            tar.unsqueeze(0),
            size=(224,224),
            mode='bilinear',
            align_corners=False
        ).squeeze(0)                                 # [1,224,224]

        # ── meteorological sequence [6×5] ───────────────
        w_flat = row[self.weather_cols].values.astype(np.float32)
        w_seq  = torch.from_numpy(w_flat).view(self.n_hours, self.n_vars)

        return img, w_seq, tar


In [None]:
# Cell 3 ▶ Load merged patch+6h-meteo CSV & build DataLoaders
csv_path    = "/content/drive/MyDrive/patch_with_meteo_last6h.csv"
patches_dir = "/content/drive/MyDrive/PatchedOutput_Cleaned"

# 1) read & drop truly missing
df = pd.read_csv(csv_path, parse_dates=['date'])
df = df.dropna(subset=['patch_filename','date']).reset_index(drop=True)

# 2) pick exactly the 6h-lag columns
seq_cols = [c for c in df.columns if "_t-" in c]
assert len(seq_cols) == 6*5, f"expected 30 lag cols, got {len(seq_cols)}"

print("Using 6-hour sequence cols:", seq_cols[:5], "... total =", len(seq_cols))

# 3) build dataset with those only
dataset = LSTDataset(df, patches_dir, seq_cols, n_hours=6)

# 4) split & loaders
n_train = int(0.8 * len(dataset))
train_ds, val_ds = random_split(dataset, [n_train, len(dataset)-n_train])

train_loader = DataLoader(train_ds, batch_size=4, shuffle=True,
                          num_workers=0, pin_memory=False)
val_loader   = DataLoader(val_ds,   batch_size=4, shuffle=False,
                          num_workers=0, pin_memory=False)


Using 6-hour sequence cols: ['air_temp_C_t-5h', 'dew_point_C_t-5h', 'relative_humidity_percent_t-5h', 'wind_speed_m_s_t-5h', 'precipitation_in_t-5h'] ... total = 30


In [None]:
# Cell 4 ▶ ViT + LSTM fusion → 224×224 decoder
class PretrainedViTLSTModel(nn.Module):
    def __init__(self,
                 weather_dim=5,      # vars per time step
                 hidden_dim=768,
                 vit_name="vit_base_patch16_224",
                 lstm_layers=1,
                 lstm_dropout=0.1,
                 num_transformer_layers=2,
                 num_heads=8):
        super().__init__()
        # ViT backbone
        self.vit = timm.create_model(vit_name, pretrained=True, num_classes=0)
        for p in self.vit.parameters():
            p.requires_grad = False

        # LSTM for weather sequence → one embedding
        self.weather_encoder = nn.LSTM(
            input_size=weather_dim,
            hidden_size=hidden_dim,
            num_layers=lstm_layers,
            batch_first=True,
            dropout=lstm_dropout if lstm_layers > 1 else 0.0
        )

        # fusion transformer
        enc = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=num_heads,
            dim_feedforward=hidden_dim * 4,
            dropout=0.1
        )
        self.transformer = nn.TransformerEncoder(enc, num_transformer_layers)

        # decoder up-sample 14→28→56→112→224
        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(hidden_dim,   hidden_dim // 2, 2, 2),
            nn.BatchNorm2d(hidden_dim // 2), nn.ReLU(inplace=True),
            nn.ConvTranspose2d(hidden_dim // 2, hidden_dim // 4, 2, 2),
            nn.BatchNorm2d(hidden_dim // 4), nn.ReLU(inplace=True),
            nn.ConvTranspose2d(hidden_dim // 4, hidden_dim // 8, 2, 2),
            nn.BatchNorm2d(hidden_dim // 8), nn.ReLU(inplace=True),
            nn.ConvTranspose2d(hidden_dim // 8,      1,         2, 2),
        )

    def forward(self, images, weather_seq):
        # ViT → [B,197,768]
        feats   = self.vit.forward_features(images)
        cls_tok = feats[:, :1]       # [B,1,768]
        patch_t = feats[:, 1:]       # [B,196,768]

        # LSTM → (outputs, (h_n, c_n)); take last layer's h_n
        _, (h_n, _) = self.weather_encoder(weather_seq)
        w_tok = h_n[-1].unsqueeze(1) # [B,1,hidden_dim]

        # concat → [B,198,hidden_dim]
        tokens = torch.cat([patch_t, w_tok, cls_tok], dim=1)

        # transformer expects [seq, batch, dim]
        t = self.transformer(tokens.permute(1, 0, 2)).permute(1, 0, 2)
        patch_out = t[:, :-2, :]     # drop weather + CLS → [B,196,dim]

        # reshape 196 → 14×14
        B, N, D = patch_out.size()
        G = int(math.sqrt(N))
        x = patch_out.transpose(1, 2).view(B, D, G, G)

        return self.deconv(x)        # [B,1,224,224]


In [None]:
# Cell 5 ▶ Instantiate & unfreeze last ViT layers
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = PretrainedViTLSTModel(
    weather_dim            = len(seq_cols) // 6,   # seq_cols was defined in Cell 3
    hidden_dim             = 768,
    vit_name               = "vit_base_patch16_224",
    lstm_layers            = 1,                   # keep as desired
    lstm_dropout           = 0.1,
    num_transformer_layers = 2,                   # matches __init__ signature
    num_heads              = 8
).to(device)

# unfreeze final ViT blocks
for name, p in model.vit.named_parameters():
    if any(layer in name for layer in ["blocks.10", "blocks.11", "norm"]):
        p.requires_grad = True




In [None]:
# Cell 6 ▶ Optimizer, loss & scheduler
opt       = torch.optim.AdamW(
    filter(lambda p: p.requires_grad, model.parameters()),
    lr=1e-4, weight_decay=1e-2
)
loss_fn   = nn.SmoothL1Loss()
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    opt, mode='min', factor=0.5, patience=3, verbose=True
)




In [None]:
import torch
from tqdm import tqdm
from pathlib import Path
import math

# — optionally re-init your model, optimizer & scheduler here —
# model     = ViT_LSTM(...).to(device)
# opt       = torch.optim.Adam(model.parameters(), lr=1e-4)
# scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, factor=0.5, patience=5)

num_epochs = 20
save_dir = Path("/content/drive/MyDrive/Model_vit_lstm_Checkpoints")
save_dir.mkdir(parents=True, exist_ok=True)

for epoch in range(num_epochs):
    # —— TRAIN ——
    model.train()
    train_loss = 0.0
    seen       = 0
    train_bar  = tqdm(train_loader, desc=f"Epoch {epoch+1:02d} Train", unit="batch")
    for imgs, w_seq, tgt in train_bar:
        imgs, w_seq, tgt = imgs.to(device), w_seq.to(device), tgt.to(device)
        opt.zero_grad()
        out      = model(imgs, w_seq)
        loss     = loss_fn(out, tgt)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        opt.step()

        bsz = imgs.size(0)
        train_loss += loss.item() * bsz
        seen      += bsz
        train_bar.set_postfix(
            batch_loss=f"{loss.item():.4f}",
            avg_loss  =f"{train_loss/seen:.4f}"
        )

    train_rmse = math.sqrt(train_loss / len(train_loader.dataset))

    # —— VALIDATE ——
    model.eval()
    val_loss = 0.0
    seen     = 0
    val_bar  = tqdm(val_loader, desc=f"Epoch {epoch+1:02d}   Val", unit="batch")
    with torch.no_grad():
        for imgs, w_seq, tgt in val_bar:
            imgs, w_seq, tgt = imgs.to(device), w_seq.to(device), tgt.to(device)
            out      = model(imgs, w_seq)
            l_val    = loss_fn(out, tgt).item()
            bsz      = imgs.size(0)
            val_loss += l_val * bsz
            seen     += bsz
            val_bar.set_postfix(
                batch_loss=f"{l_val:.4f}",
                avg_loss  =f"{val_loss/seen:.4f}"
            )

    val_rmse = math.sqrt(val_loss / len(val_loader.dataset))
    scheduler.step(val_loss)   # or scheduler.step() if epoch‐based

    print(f"Epoch {epoch+1:02d} ▶ Train RMSE: {train_rmse:.3f} | Val RMSE: {val_rmse:.3f}")

    # —— SAVE CHECKPOINT ——
    ckpt = {
        'epoch': epoch+1,
        'model_state_dict':      model.state_dict(),
        'optimizer_state_dict':  opt.state_dict(),
        'scheduler_state_dict':  scheduler.state_dict(),
        'train_rmse':            train_rmse,
        'val_rmse':              val_rmse
    }
    torch.save(ckpt, save_dir/f"vit_lstm_epoch{epoch+1:02d}.pt")
    print("✅ Saved checkpoint.")

print("✅ Training complete (0 → 20 epochs)")


Epoch 01 Train: 100%|██████████| 2295/2295 [02:34<00:00, 14.82batch/s, avg_loss=1.3455, batch_loss=0.3028]
Epoch 01   Val: 100%|██████████| 574/574 [00:28<00:00, 20.13batch/s, avg_loss=0.9677, batch_loss=0.8528]


Epoch 01 ▶ Train RMSE: 1.160 | Val RMSE: 0.984
✅ Saved checkpoint.


Epoch 02 Train: 100%|██████████| 2295/2295 [02:49<00:00, 13.56batch/s, avg_loss=0.5450, batch_loss=0.3008]
Epoch 02   Val: 100%|██████████| 574/574 [00:28<00:00, 20.27batch/s, avg_loss=0.4168, batch_loss=0.0811]


Epoch 02 ▶ Train RMSE: 0.738 | Val RMSE: 0.646
✅ Saved checkpoint.


Epoch 03 Train: 100%|██████████| 2295/2295 [02:50<00:00, 13.46batch/s, avg_loss=0.3595, batch_loss=0.4615]
Epoch 03   Val: 100%|██████████| 574/574 [00:28<00:00, 20.19batch/s, avg_loss=0.2523, batch_loss=0.1417]


Epoch 03 ▶ Train RMSE: 0.600 | Val RMSE: 0.502
✅ Saved checkpoint.


Epoch 04 Train: 100%|██████████| 2295/2295 [02:49<00:00, 13.52batch/s, avg_loss=0.2902, batch_loss=0.1785]
Epoch 04   Val: 100%|██████████| 574/574 [00:28<00:00, 19.98batch/s, avg_loss=0.2330, batch_loss=0.0539]


Epoch 04 ▶ Train RMSE: 0.539 | Val RMSE: 0.483
✅ Saved checkpoint.


Epoch 05 Train: 100%|██████████| 2295/2295 [02:49<00:00, 13.51batch/s, avg_loss=0.2472, batch_loss=0.8123]
Epoch 05   Val: 100%|██████████| 574/574 [00:28<00:00, 20.11batch/s, avg_loss=0.1844, batch_loss=0.0668]


Epoch 05 ▶ Train RMSE: 0.497 | Val RMSE: 0.429
✅ Saved checkpoint.


Epoch 06 Train: 100%|██████████| 2295/2295 [02:49<00:00, 13.50batch/s, avg_loss=0.2134, batch_loss=0.0727]
Epoch 06   Val: 100%|██████████| 574/574 [00:28<00:00, 20.15batch/s, avg_loss=0.1463, batch_loss=0.0511]


Epoch 06 ▶ Train RMSE: 0.462 | Val RMSE: 0.383
✅ Saved checkpoint.


Epoch 07 Train: 100%|██████████| 2295/2295 [02:49<00:00, 13.52batch/s, avg_loss=0.1871, batch_loss=0.0211]
Epoch 07   Val: 100%|██████████| 574/574 [00:28<00:00, 20.20batch/s, avg_loss=0.2056, batch_loss=0.0316]


Epoch 07 ▶ Train RMSE: 0.433 | Val RMSE: 0.453
✅ Saved checkpoint.


Epoch 08 Train: 100%|██████████| 2295/2295 [02:50<00:00, 13.43batch/s, avg_loss=0.1817, batch_loss=0.3791]
Epoch 08   Val: 100%|██████████| 574/574 [00:28<00:00, 19.98batch/s, avg_loss=0.1513, batch_loss=0.0412]


Epoch 08 ▶ Train RMSE: 0.426 | Val RMSE: 0.389
✅ Saved checkpoint.


Epoch 09 Train: 100%|██████████| 2295/2295 [02:50<00:00, 13.48batch/s, avg_loss=0.1594, batch_loss=0.0447]
Epoch 09   Val: 100%|██████████| 574/574 [00:28<00:00, 20.16batch/s, avg_loss=0.1389, batch_loss=0.0610]


Epoch 09 ▶ Train RMSE: 0.399 | Val RMSE: 0.373
✅ Saved checkpoint.


Epoch 10 Train: 100%|██████████| 2295/2295 [02:49<00:00, 13.52batch/s, avg_loss=0.1603, batch_loss=0.0541]
Epoch 10   Val: 100%|██████████| 574/574 [00:28<00:00, 20.07batch/s, avg_loss=0.1247, batch_loss=0.0363]


Epoch 10 ▶ Train RMSE: 0.400 | Val RMSE: 0.353
✅ Saved checkpoint.


Epoch 11 Train: 100%|██████████| 2295/2295 [02:49<00:00, 13.51batch/s, avg_loss=0.1411, batch_loss=0.2063]
Epoch 11   Val: 100%|██████████| 574/574 [00:28<00:00, 19.97batch/s, avg_loss=0.1437, batch_loss=0.0359]


Epoch 11 ▶ Train RMSE: 0.376 | Val RMSE: 0.379
✅ Saved checkpoint.


Epoch 12 Train: 100%|██████████| 2295/2295 [02:37<00:00, 14.61batch/s, avg_loss=0.1378, batch_loss=0.0435]
Epoch 12   Val: 100%|██████████| 574/574 [00:28<00:00, 20.09batch/s, avg_loss=0.1154, batch_loss=0.0448]


Epoch 12 ▶ Train RMSE: 0.371 | Val RMSE: 0.340
✅ Saved checkpoint.


Epoch 13 Train: 100%|██████████| 2295/2295 [02:37<00:00, 14.58batch/s, avg_loss=0.1310, batch_loss=0.0814]
Epoch 13   Val: 100%|██████████| 574/574 [00:28<00:00, 20.10batch/s, avg_loss=0.1078, batch_loss=0.0333]


Epoch 13 ▶ Train RMSE: 0.362 | Val RMSE: 0.328
✅ Saved checkpoint.


Epoch 14 Train: 100%|██████████| 2295/2295 [02:37<00:00, 14.56batch/s, avg_loss=0.1243, batch_loss=0.0346]
Epoch 14   Val: 100%|██████████| 574/574 [00:28<00:00, 20.14batch/s, avg_loss=0.0932, batch_loss=0.0388]


Epoch 14 ▶ Train RMSE: 0.353 | Val RMSE: 0.305
✅ Saved checkpoint.


Epoch 15 Train: 100%|██████████| 2295/2295 [02:37<00:00, 14.59batch/s, avg_loss=0.1176, batch_loss=0.0403]
Epoch 15   Val: 100%|██████████| 574/574 [00:28<00:00, 20.16batch/s, avg_loss=0.0980, batch_loss=0.0298]


Epoch 15 ▶ Train RMSE: 0.343 | Val RMSE: 0.313
✅ Saved checkpoint.


Epoch 16 Train: 100%|██████████| 2295/2295 [02:37<00:00, 14.58batch/s, avg_loss=0.1190, batch_loss=0.0167]
Epoch 16   Val: 100%|██████████| 574/574 [00:29<00:00, 19.78batch/s, avg_loss=0.0765, batch_loss=0.0405]


Epoch 16 ▶ Train RMSE: 0.345 | Val RMSE: 0.277
✅ Saved checkpoint.


Epoch 17 Train: 100%|██████████| 2295/2295 [02:39<00:00, 14.37batch/s, avg_loss=0.1113, batch_loss=0.0169]
Epoch 17   Val: 100%|██████████| 574/574 [00:28<00:00, 19.84batch/s, avg_loss=0.1145, batch_loss=0.0315]


Epoch 17 ▶ Train RMSE: 0.334 | Val RMSE: 0.338
✅ Saved checkpoint.


Epoch 18 Train: 100%|██████████| 2295/2295 [02:36<00:00, 14.64batch/s, avg_loss=0.1063, batch_loss=0.0387]
Epoch 18   Val: 100%|██████████| 574/574 [00:28<00:00, 20.15batch/s, avg_loss=0.0954, batch_loss=0.0418]


Epoch 18 ▶ Train RMSE: 0.326 | Val RMSE: 0.309
✅ Saved checkpoint.


Epoch 19 Train: 100%|██████████| 2295/2295 [02:36<00:00, 14.63batch/s, avg_loss=0.1004, batch_loss=3.9148]
Epoch 19   Val: 100%|██████████| 574/574 [00:28<00:00, 20.05batch/s, avg_loss=0.1944, batch_loss=0.0678]


Epoch 19 ▶ Train RMSE: 0.317 | Val RMSE: 0.441
✅ Saved checkpoint.


Epoch 20 Train: 100%|██████████| 2295/2295 [02:37<00:00, 14.62batch/s, avg_loss=0.1015, batch_loss=0.0807]
Epoch 20   Val: 100%|██████████| 574/574 [00:28<00:00, 20.18batch/s, avg_loss=0.0864, batch_loss=0.0289]


Epoch 20 ▶ Train RMSE: 0.319 | Val RMSE: 0.294
✅ Saved checkpoint.
✅ Training complete (0 → 20 epochs)
