# CAFormer TPU training

In [None]:
import torch
#if not torch.cuda.is_available():
#    raise RuntimeError("Requires GPUs with CUDA enabled.")
try: 
    import monai
except: 
    !pip install --no-deps monai -q

!pip uninstall -y tensorflow
!pip install tensorflow-cpu

In [None]:
%%writefile _cfg.py
from types import SimpleNamespace
import torch

# Experiment configuration; the training script reads it via `from _cfg import cfg`.
cfg = SimpleNamespace(
    seed=3,
    subsample=None,  # if set, keep at most this many rows per (dataset, fold) group
    ema=True,
    ema_decay=0.99,
    backbone="caformer_b36.sail_in22k_ft_in1k",
    epochs=5,
    batch_size=32,
    RUN_VALID=True,
    RUN_TEST=True,
    RUN_TRAIN=True,
    RUN_TRAIN_ALL=True,
)

# Scheduler

I have removed the training loop from this notebook; it is the same as in previous notebooks. 

The only difference is the use of a custom learning rate scheduler, which keeps the learning rate constant at first and then switches to cosine annealing. A learning rate of 1e-4 seems to work well at the beginning, but a lower learning rate is needed to reach lower training and validation MAE.

In [None]:
%%writefile _scheduler.py
import math

from torch.optim.lr_scheduler import _LRScheduler

class ConstantCosineLR(_LRScheduler):
    """
    Constant learning rate followed by cosine annealing.

    For the first ``int(total_steps * (1 - pct_cosine))`` steps every parameter
    group keeps its base learning rate. Over the remaining steps the rate
    follows half a cosine from the base rate down to ``min_lr``. Once the
    cosine phase is exhausted the rate stays at ``min_lr`` instead of climbing
    back up the periodic cosine.

    Note: the schedule position is ``last_epoch + 1``, so annealing runs one
    step ahead of ``last_epoch``.

    Args:
        optimizer: Wrapped optimizer.
        total_steps: Number of ``step()`` calls the schedule spans.
        pct_cosine: Fraction of ``total_steps`` (in ``[0, 1]``) spent annealing.
        last_epoch: Index of the last step; ``-1`` starts a fresh schedule.
        min_lr: Learning rate reached at the end of the cosine phase
            (default ``0.0``, which matches the previous behaviour).

    Raises:
        ValueError: If ``pct_cosine`` is outside ``[0, 1]``.
    """
    def __init__(
        self, 
        optimizer,
        total_steps, 
        pct_cosine, 
        last_epoch=-1,
        min_lr=0.0,
        ):
        if not 0.0 <= pct_cosine <= 1.0:
            raise ValueError(f"pct_cosine must be in [0, 1], got {pct_cosine}")
        self.total_steps = total_steps
        self.milestone = int(total_steps * (1 - pct_cosine))
        # At least one step so the cosine division below is always defined.
        self.cosine_steps = max(total_steps - self.milestone, 1)
        # Must be set before super().__init__(), which already calls get_lr().
        self.min_lr = min_lr
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        step = self.last_epoch + 1
        if step <= self.milestone:
            factor = 1.0
        else:
            # Clamp so the rate stays at min_lr after the schedule ends
            # rather than rising again along the cosine.
            s = min(step - self.milestone, self.cosine_steps)
            factor = 0.5 * (1 + math.cos(math.pi * s / self.cosine_steps))
        return [self.min_lr + (lr - self.min_lr) * factor for lr in self.base_lrs]

In [None]:
import torch
import matplotlib.pyplot as plt
from _scheduler import ConstantCosineLR

# A throwaway one-layer model is enough to drive the optimizer.
n_steps = 10_000
model = torch.nn.Linear(1, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1.0)

# Hold the LR for the first half of the steps, cosine-anneal the second half.
scheduler = ConstantCosineLR(optimizer, total_steps=n_steps, pct_cosine=0.5)

# Record the learning rate reported after every scheduler step.
arr = []
for _ in range(n_steps):
    scheduler.step()
    arr.append(scheduler.get_last_lr()[0])

fig, ax = plt.subplots()
ax.plot(arr)
ax.set_xlabel("Step")
ax.set_ylabel("LR")
ax.set_title("ConstantCosineLR")
plt.show()


## Dataset

In [None]:
%%writefile _dataset.py

import os
import glob

import numpy as np
import pandas as pd
from tqdm import tqdm

import torch

class CustomDataset(torch.utils.data.Dataset):
    """
    Seismic input / velocity target pairs backed by memory-mapped ``.npy`` files.

    Each row of ``folds.csv`` names one data file that is indexed as holding
    ``samples_per_file`` samples. The matching label file is found by replacing
    ``seis`` -> ``vel`` and ``data`` -> ``model`` in the resolved path.
    Fold 0 forms the validation split; every other fold is training data.

    Args:
        cfg: Config namespace; ``cfg.subsample`` (rows per (dataset, fold)
            group, or ``None`` for all) is read here.
        mode: ``"train"`` selects folds != 0 and enables the random flip;
            any other value selects fold 0 with no augmentation.
    """

    # Number of samples each .npy file is indexed as containing.
    samples_per_file = 500

    # Roots searched, in order, for the file named by ``data_fpath``.
    _SEARCH_ROOTS = (
        "/kaggle/input/open-wfi-1/openfwi_float16_1/",
        "/kaggle/input/open-wfi-2/openfwi_float16_2/",
    )

    def __init__(
        self, 
        cfg,
        mode = "train", 
    ):
        self.cfg = cfg
        self.mode = mode
        
        self.data, self.labels, self.records = self.load_metadata()

    def load_metadata(self):
        """Read the fold table and memory-map every data/label file of this split.

        Returns:
            ``(data, labels, records)``: lists of memory-mapped input arrays,
            memory-mapped label arrays, and the ``dataset`` name of each row.
        """
        # Select rows
        df = pd.read_csv("/kaggle/input/openfwi-preprocessed-72x72/folds.csv")
        if self.cfg.subsample is not None:
            df = df.groupby(["dataset", "fold"]).head(self.cfg.subsample)

        if self.mode == "train":
            df = df[df["fold"] != 0]
        else:
            df = df[df["fold"] == 0]

        data = []
        labels = []
        records = []
        mmap_mode = "r"

        for _, row in tqdm(df.iterrows(), total=len(df)):
            row = row.to_dict()

            # Map the data file to its label file
            fdata = self._resolve_data_path(row["data_fpath"])
            flbl = fdata.replace('seis', 'vel').replace('data', 'model')

            # Load lazily (memory-mapped) so the whole split need not fit in RAM
            arr = np.load(fdata, mmap_mode=mmap_mode)
            lbl = np.load(flbl, mmap_mode=mmap_mode)

            data.append(arr)
            labels.append(lbl)
            records.append(row["dataset"])

        return data, labels, records

    def _resolve_data_path(self, rel_path):
        """Return the first existing file matching ``rel_path`` under the search roots.

        For each root, the exact relative path is tried first, then the same
        file name one directory level deeper (``<first>/*/<last>``).

        Raises:
            FileNotFoundError: If no candidate exists.
        """
        parts = rel_path.split("/")
        head, tail = parts[0], parts[-1]
        matches = []
        for root in self._SEARCH_ROOTS:
            matches += glob.glob(os.path.join(root, rel_path))
            matches += glob.glob(os.path.join(root, head, "*", tail))
        if not matches:
            raise FileNotFoundError(
                f"No data file found for {rel_path!r} under {self._SEARCH_ROOTS}"
            )
        return matches[0]

    def __getitem__(self, idx):
        """Return ``(x, y)`` for global sample ``idx``.

        ``idx`` selects file ``idx // samples_per_file`` and the sample at
        ``idx % samples_per_file`` inside it.
        """
        row_idx, col_idx = divmod(idx, self.samples_per_file)

        x = self.data[row_idx][col_idx, ...]
        y = self.labels[row_idx][col_idx, ...]

        # Augmentation (train only, p=0.5): reverse axes 0 and 2 of the input
        # and the last axis of the target. Axis 1 is NOT reversed, so this is
        # presumably a left-right (source/receiver) mirror rather than a time
        # flip — TODO confirm the array layout.
        if self.mode == "train" and np.random.random() < 0.5:
            x = x[::-1, :, ::-1]
            y = y[..., ::-1]

        # Copy out of the memory map and materialise any negative strides
        x = x.copy()
        y = y.copy()

        return x, y

    def __len__(self):
        return len(self.records) * self.samples_per_file



# Model

This time we use the `CAFormer` backbone from timm. See more info on this backbone [here](https://huggingface.co/timm/caformer_b36.sail_in22k_ft_in1k) and the original paper [here](https://arxiv.org/abs/2210.13452).


### Encoder

As with ConvNeXt, we modify the encoder so that its feature maps align with the target output shape. There is likely room for improvement at the `nn.ReflectionPad2d` step: the model currently adds a lot of padding there, and I suspect the shallowest feature map lacks fine detail as a result.

### Decoder

The biggest changes in this notebook are to the decoder. 

First, we use PixelShuffle for upsampling. PixelShuffle typically works well when fine detail is important, though it is more computationally expensive. Second, we add SCSE blocks. These are commonly used to increase decoder capacity with a minimal increase in parameter count and runtime. Finally, we add intermediate convolutions on the encoder skip features before they enter the decoder blocks. I believe this trick was first introduced on Kaggle in the 3rd place solution of the Contrails Competition [here](https://www.kaggle.com/competitions/google-research-identify-contrails-redu), and it also increases decoder capacity.

In [None]:
%%writefile _model.py

from copy import deepcopy
from types import MethodType

import torch
import torch.nn as nn
import torch.nn.functional as F

import timm

from monai.networks.blocks import UpSample, SubpixelUpsample

####################
## EMA + Ensemble ##
####################

class ModelEMA(nn.Module):
    """
    Exponential moving average (EMA) of a model's parameters and buffers.

    Holds a deep copy of ``model`` in eval mode. ``update`` blends the live
    values into the copy as ``decay * ema + (1 - decay) * model``; ``set``
    overwrites the copy with the live values.

    Non-floating-point entries (integer buffers such as BatchNorm's
    ``num_batches_tracked``) are copied directly. Blending them would produce
    a float that is truncated when copied back into the integer tensor.

    Args:
        model: Module whose ``state_dict`` is tracked.
        decay: Weight given to the existing EMA value on each update.
        device: Optional device for the EMA copy; live tensors are moved
            there before being combined.
    """
    def __init__(self, model, decay=0.99, device=None):
        super().__init__()
        self.module = deepcopy(model)
        self.module.eval()
        self.decay = decay
        self.device = device
        if self.device is not None:
            self.module.to(device=device)

    def _update(self, model, update_fn):
        # Both state_dicts come from the same architecture (deepcopy), so
        # their entries line up positionally.
        with torch.no_grad():
            for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()):
                if self.device is not None:
                    model_v = model_v.to(device=self.device)
                if ema_v.is_floating_point():
                    ema_v.copy_(update_fn(ema_v, model_v))
                else:
                    ema_v.copy_(model_v)

    def update(self, model):
        self._update(model, update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m)

    def set(self, model):
        self._update(model, update_fn=lambda e, m: m)


class EnsembleModel(nn.Module):
    """
    Combine several models by the element-wise median of their outputs.

    The median is ``torch.quantile(..., 0.5)`` over the member axis. With an
    even number of members it interpolates between the two middle values.
    Members are switched to eval mode at construction.

    Args:
        models: Iterable of modules producing equally shaped outputs.
    """
    def __init__(self, models):
        super().__init__()
        self.models = nn.ModuleList(models)
        self.models.eval()

    def forward(self, x):
        stacked = torch.stack([member(x) for member in self.models])
        return torch.quantile(stacked, 0.5, dim=0)
        

#############
## Decoder ##
#############

class ConvBnAct2d(nn.Module):
    """
    Bias-free ``nn.Conv2d`` followed by an optional norm and an activation.

    Args:
        in_channels, out_channels, kernel_size, padding, stride:
            Forwarded to ``nn.Conv2d`` (always built with ``bias=False``).
        norm_layer: Norm class called as ``norm_layer(out_channels)``; the
            default ``nn.Identity`` means no normalisation.
        act_layer: Activation class, instantiated with ``inplace=True``.
    """
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        padding: int = 0,
        stride: int = 1,
        norm_layer: nn.Module = nn.Identity,
        act_layer: nn.Module = nn.ReLU,
    ):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels, out_channels, kernel_size,
            stride=stride, padding=padding, bias=False,
        )
        if norm_layer == nn.Identity:
            self.norm = nn.Identity()
        else:
            self.norm = norm_layer(out_channels)
        self.act = act_layer(inplace=True)

    def forward(self, x):
        return self.act(self.norm(self.conv(x)))


class SCSEModule2d(nn.Module):
    """
    Concurrent spatial and channel squeeze-and-excitation (scSE) block.

    Returns ``x * cSE(x) + x * sSE(x)``, where ``cSE`` yields per-channel
    gates of shape [B, C, 1, 1] and ``sSE`` yields per-pixel gates of shape
    [B, 1, H, W].

    Args:
        in_channels: Channels of the input feature map.
        reduction: Bottleneck ratio of the channel branch. The bottleneck has
            ``in_channels // reduction`` channels, floored at 1 so inputs
            narrower than ``reduction`` still get a working channel gate.
    """
    def __init__(self, in_channels, reduction=16):
        super().__init__()
        # Floor at 1: a zero-width bottleneck would leave the channel gate
        # with no input-dependent path.
        hidden = max(1, in_channels // reduction)
        self.cSE = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, hidden, 1),
            nn.Tanh(),
            nn.Conv2d(hidden, in_channels, 1),
            nn.Sigmoid(),
        ) # Output [B, C, 1, 1]
        self.sSE = nn.Sequential(
            nn.Conv2d(in_channels, 1, 1), 
            nn.Sigmoid(),
            ) # Output [B, 1, H, W]

    def forward(self, x):
        return x * self.cSE(x) + x * self.sSE(x)


class Attention2d(nn.Module):
    """
    Select an attention module by name.

    Args:
        name: ``None`` for a pass-through ``nn.Identity`` or ``"scse"`` for
            ``SCSEModule2d``.
        **params: Keyword arguments forwarded to the chosen module's constructor.

    Raises:
        ValueError: For any other ``name``.
    """
    def __init__(self, name, **params):
        super().__init__()
        if name is not None and name != "scse":
            raise ValueError("Attention {} is not implemented".format(name))
        factory = nn.Identity if name is None else SCSEModule2d
        self.attention = factory(**params)

    def forward(self, x):
        out = self.attention(x)
        return out


class DecoderBlock2d(nn.Module):
    """
    One U-Net decoder stage: upsample, optionally fuse a skip map, refine.

    Forward pass:
      1. upsample ``x`` by ``scale_factor``;
      2. if intermediate convs are enabled, apply them to ``skip`` (or to
         ``x`` when no skip is given);
      3. if a skip is given, concatenate it to ``x`` on the channel axis and
         apply ``attention1``;
      4. two 3x3 ``ConvBnAct2d`` layers, then ``attention2``.

    Args:
        in_channels: Channels of the incoming (deeper) feature map.
        skip_channels: Channels of the skip map, 0 if this stage has none.
        out_channels: Channels produced by this stage.
        norm_layer: Norm class for the two refinement convs.
        attention_type: ``None`` or ``"scse"`` (see ``Attention2d``).
        intermediate_conv: Add two 3x3 convs on the skip (or on ``x``).
        upsample_mode: ``"pixelshuffle"`` uses MONAI ``SubpixelUpsample``;
            any other value is passed as ``mode`` to MONAI ``UpSample``.
        scale_factor: Spatial upsampling factor.
    """
    def __init__(
        self,
        in_channels,
        skip_channels,
        out_channels,
        norm_layer: nn.Module = nn.Identity,
        attention_type: str = None,
        intermediate_conv: bool = False,
        upsample_mode: str = "deconv",
        scale_factor: int = 2,
    ):
        super().__init__()
        fused_channels = in_channels + skip_channels

        if upsample_mode == "pixelshuffle":
            self.upsample = SubpixelUpsample(
                spatial_dims=2,
                in_channels=in_channels,
                scale_factor=scale_factor,
            )
        else:
            self.upsample = UpSample(
                spatial_dims=2,
                in_channels=in_channels,
                out_channels=in_channels,
                scale_factor=scale_factor,
                mode=upsample_mode,
            )

        if not intermediate_conv:
            self.intermediate_conv = None
        else:
            # Runs on the skip map when there is one, otherwise on x.
            ch = in_channels if skip_channels == 0 else skip_channels
            self.intermediate_conv = nn.Sequential(
                ConvBnAct2d(ch, ch, 3, padding=1),
                ConvBnAct2d(ch, ch, 3, padding=1),
            )

        self.attention1 = Attention2d(name=attention_type, in_channels=fused_channels)
        self.conv1 = ConvBnAct2d(
            fused_channels, out_channels,
            kernel_size=3, padding=1, norm_layer=norm_layer,
        )
        self.conv2 = ConvBnAct2d(
            out_channels, out_channels,
            kernel_size=3, padding=1, norm_layer=norm_layer,
        )
        self.attention2 = Attention2d(name=attention_type, in_channels=out_channels)

    def forward(self, x, skip=None):
        x = self.upsample(x)
        refine = self.intermediate_conv

        if skip is None:
            if refine is not None:
                x = refine(x)
        else:
            if refine is not None:
                skip = refine(skip)
            x = self.attention1(torch.cat([x, skip], dim=1))

        x = self.conv2(self.conv1(x))
        return self.attention2(x)


class UnetDecoder2d(nn.Module):
    """
    U-Net decoder built from ``DecoderBlock2d`` stages.
    Source: https://arxiv.org/abs/1505.04597

    ``encoder_channels`` is ordered deepest-first. Stage ``i`` upsamples the
    previous output and fuses encoder map ``i + 1`` as its skip; the final
    stage has no skip (``skip_channels`` ends with 0 by default).

    Args:
        encoder_channels: Channel counts of the encoder maps, deepest first.
        skip_channels: Per-stage skip channels; derived from
            ``encoder_channels`` when ``None``.
        decoder_channels: Output channels per stage. With exactly four
            encoder maps the first entry is dropped, giving three stages.
        scale_factors: Upsampling factor per stage.
        norm_layer, attention_type, intermediate_conv, upsample_mode:
            Forwarded to every ``DecoderBlock2d``.
    """
    def __init__(
        self,
        encoder_channels: tuple[int],
        skip_channels: tuple[int] = None,
        decoder_channels: tuple = (256, 128, 64, 32),
        scale_factors: tuple = (2,2,2,2),
        norm_layer: nn.Module = nn.Identity,
        attention_type: str = "scse",
        intermediate_conv: bool = True,
        upsample_mode: str = "pixelshuffle",
    ):
        super().__init__()

        if len(encoder_channels) == 4:
            decoder_channels = decoder_channels[1:]
        self.decoder_channels = decoder_channels

        if skip_channels is None:
            skip_channels = [*encoder_channels[1:], 0]

        # Stage inputs: the deepest encoder map, then each previous stage's output.
        stage_inputs = [encoder_channels[0], *decoder_channels[:-1]]
        self.blocks = nn.ModuleList(
            DecoderBlock2d(
                ic, sc, dc,
                norm_layer=norm_layer,
                attention_type=attention_type,
                intermediate_conv=intermediate_conv,
                upsample_mode=upsample_mode,
                scale_factor=scale_factors[stage],
            )
            for stage, (ic, sc, dc) in enumerate(zip(stage_inputs, skip_channels, decoder_channels))
        )

    def forward(self, feats: list[torch.Tensor]):
        """Return ``[feats[0], stage_1_out, ..., stage_n_out]``."""
        outputs = [feats[0]]
        skips = feats[1:]
        n_skips = len(skips)

        for stage, block in enumerate(self.blocks):
            skip = skips[stage] if stage < n_skips else None
            outputs.append(block(outputs[-1], skip=skip))

        return outputs
        
class SegmentationHead2d(nn.Module):
    """
    Output head: a ``Conv2d`` to ``out_channels`` followed by MONAI ``UpSample``.

    The conv uses padding ``kernel_size // 2`` (size-preserving for odd
    kernels).

    Args:
        in_channels: Channels of the decoder output.
        out_channels: Channels of the prediction.
        scale_factor: Passed to ``UpSample``.
        kernel_size: Conv kernel size.
        mode: ``UpSample`` mode.
    """
    def __init__(
        self,
        in_channels,
        out_channels,
        scale_factor: tuple[int] = (2,2),
        kernel_size: int = 3,
        mode: str = "nontrainable",
    ):
        super().__init__()
        same_pad = kernel_size // 2
        self.conv = nn.Conv2d(
            in_channels, out_channels,
            kernel_size=kernel_size, padding=same_pad,
        )
        self.upsample = UpSample(
            spatial_dims=2,
            in_channels=out_channels,
            out_channels=out_channels,
            scale_factor=scale_factor,
            mode=mode,
        )

    def forward(self, x):
        return self.upsample(self.conv(x))
        

#############
## Encoder ##
#############

class Net(nn.Module):
    """
    CAFormer encoder + U-Net decoder regressing a single-channel map.

    The timm backbone is created with ``in_chans=5`` and ``features_only=True``,
    then its stem is modified in place (see ``_update_stem``). Encoder maps are
    decoded deepest-first by ``UnetDecoder2d`` and projected to one channel by
    ``SegmentationHead2d``. The head output is cropped by one pixel on every
    side and mapped through ``x * 1500 + 3000`` (presumably de-normalising to
    velocity units — TODO confirm).

    In eval mode ``forward`` also predicts on the input flipped along dims -3
    and -1, flips that prediction back along the last dim, and returns the
    mean of both predictions (test-time augmentation).

    Args:
        backbone: timm model name.
        pretrained: Load pretrained timm weights.
    """
    def __init__(
        self,
        backbone: str,
        pretrained: bool = True,
    ):
        super().__init__()
        
        # Encoder
        self.backbone= timm.create_model(
            backbone,
            in_chans= 5,
            pretrained= pretrained,
            features_only= True,
            drop_path_rate=0.0,
            )
        # Channel counts deepest-first, matching the decoder's input order.
        ecs= [_["num_chs"] for _ in self.backbone.feature_info][::-1]

        # Decoder
        self.decoder= UnetDecoder2d(
            encoder_channels= ecs,
        )

        self.seg_head= SegmentationHead2d(
            in_channels= self.decoder.decoder_channels[-1],
            out_channels= 1,
            scale_factor= 1,
        )
        
        self._update_stem(backbone)

    def _update_stem(self, backbone):
        """Adapt the backbone stem in place.

        Sets the stem conv to stride (4, 1) and padding (0, 4), replaces
        ``stages_0.downsample`` with a (4, 1) average pool, and prepends
        reflection padding of 78 rows above and below the input. Assumes a
        timm module exposing ``stem.conv`` and ``stages_0`` (as CAFormer does).
        ``backbone`` (the model name) is currently unused.
        """
        m = self.backbone

        m.stem.conv.stride=(4,1)
        m.stem.conv.padding=(0,4)
        m.stages_0.downsample = nn.AvgPool2d(kernel_size=(4,1), stride=(4,1))
        m.stem= nn.Sequential(
            nn.ReflectionPad2d((0,0,78,78)),
            m.stem,
        )

    def _predict(self, x):
        """Encoder -> decoder -> head; returns the cropped, rescaled prediction."""
        feats = self.backbone(x)[::-1]
        feats = self.decoder(feats)
        x_seg = self.seg_head(feats[-1])
        x_seg = x_seg[..., 1:-1, 1:-1]
        return x_seg * 1500 + 3000

    def proc_flip(self, x_in):
        """Predict on ``x_in`` flipped along dims -3 and -1, then un-flip the last dim."""
        x_seg = self._predict(torch.flip(x_in, dims=[-3, -1]))
        return torch.flip(x_seg, dims=[-1])

    def forward(self, batch):
        x_seg = self._predict(batch)

        if self.training:
            return x_seg

        # Eval-time TTA: average with the prediction on the flipped input.
        p1 = self.proc_flip(batch)
        return torch.mean(torch.stack([x_seg, p1]), dim=0)

### Utils

Same as previous notebook. 

In [None]:
%%writefile _utils.py

import datetime

def format_time(elapsed):
    """Round ``elapsed`` seconds to a whole second and render it as ``str(timedelta)``, e.g. ``1:01:01``."""
    whole_seconds = int(round(elapsed))
    return f"{datetime.timedelta(seconds=whole_seconds)}"

# Train and Valid

In [None]:
%%writefile train_fwi_xla_8tpu.py
import sys
import os
from tqdm import tqdm
import numpy as np
import time 
import gc
import random

import torch
import torch.nn as nn
from torch.amp import autocast
from torch.utils.data import DistributedSampler

import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.utils as xu
from torch_xla.amp import syncfree


from _cfg import cfg
from _dataset import CustomDataset
from _model import ModelEMA, Net
from _utils import format_time
from _scheduler import ConstantCosineLR


def _mp_fn(index, cfg):
    """
    Per-replica train/validate loop launched by ``xmp.spawn``.

    Each process builds sharded train/valid loaders and loads the starting
    checkpoint into ``Net``. For ``cfg.epochs`` epochs it trains with L1 loss
    (SGD + ``ConstantCosineLR``), optionally keeping an EMA copy. It then
    validates the EMA weights when enabled, otherwise the live model, and
    reports the L1 loss over predictions gathered from all replicas.

    Checkpoints written each epoch (by the master ordinal, via ``xm.save``):
        ``epoch_{epoch}_train.pt``  raw training weights
        ``epoch_{epoch}.pt``        the weights that were just validated
        ``best.pt``                 validated weights with the lowest loss so far

    Args:
        index: Process index passed by ``xmp.spawn``. Unused; the ordinal
            is read from ``xm.get_ordinal()``.
        cfg: Config namespace from ``_cfg``.
    """
    # Setup
    torch.manual_seed(cfg.seed)
    world_size = xm.xrt_world_size()
    local_rank = xm.get_ordinal()
    device = xm.xla_device()
    xm.master_print(f"Process {local_rank} initialized on device: {device}")

    # Prepare Dataset
    train_ds = CustomDataset(cfg=cfg, mode="train")
    valid_ds = CustomDataset(cfg=cfg, mode="valid")
    xm.rendezvous('Dataset preparing complete')

    # Train Data
    sampler_tr = DistributedSampler(
        train_ds,
        num_replicas=world_size,
        rank=local_rank,
        shuffle=True,
    )
    train_pytorch_dl = torch.utils.data.DataLoader(
        train_ds,
        batch_size=cfg.batch_size,
        sampler=sampler_tr,
        num_workers=1,
    )
    train_dl = pl.MpDeviceLoader(train_pytorch_dl, device)

    # Valid Data: the metric is order-independent, so there is no need to shuffle.
    sampler_vl = DistributedSampler(
        valid_ds,
        num_replicas=world_size,
        rank=local_rank,
        shuffle=False,
    )
    valid_pytorch_dl = torch.utils.data.DataLoader(
        valid_ds,
        batch_size=cfg.batch_size,
        sampler=sampler_vl,
        num_workers=1,
    )
    valid_dl = pl.MpDeviceLoader(valid_pytorch_dl, device)

    # Define Models
    model = Net(
        backbone=cfg.backbone,
        pretrained=False,
    )
    state_dict = torch.load(
        "/kaggle/input/openfwi-preprocessed-72x72/models_1000x70/unet2d_caformer_seed3_epochbest.pt",
        map_location=torch.device('cpu'),
        weights_only=True,
    )
    # Remove torch.compile() prefix
    state_dict = {k.removeprefix("_orig_mod."): v for k, v in state_dict.items()}

    model.load_state_dict(state_dict)
    model = model.to(device)

    if cfg.ema:
        xm.master_print("Initializing EMA model..")
        ema_model = ModelEMA(
            model,
            decay=cfg.ema_decay,
            device=device,
        )
    else:
        ema_model = None

    # Learning parameters
    criterion = nn.L1Loss()
    optimizer = syncfree.SGD(model.parameters(), lr=1e-6)
    # One scheduler step per optimizer step over all epochs.
    scheduler = ConstantCosineLR(optimizer, total_steps=len(train_dl)*cfg.epochs, pct_cosine=0.7)

    # ================  Train and Valid ===================
    val_best_loss = float("inf")
    tstart = time.time()
    xm.master_print("Start Training")
    for epoch in range(cfg.epochs):
        # =============== Train Loop ===============
        tracker = xm.RateTracker()
        model.train()
        tr_total_loss = []
        sampler_tr.set_epoch(epoch)
        for i, (inputs, targets) in enumerate(train_dl):
            optimizer.zero_grad(set_to_none=True)
            with autocast('xla'):
                logits = model(inputs)
                loss = criterion(logits, targets)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 2.0)
            xm.optimizer_step(optimizer, barrier=True)
            scheduler.step()

            # Update the EMA from the weights *after* this optimizer step.
            if ema_model is not None:
                ema_model.update(model)

            # Keep losses on device: .item() on every step forces an XLA
            # execution + host transfer. Sync only when logging.
            tr_total_loss.append(loss.detach())
            tracker.add(cfg.batch_size)
            if (i+1) % 100 == 0 or i == 0:
                train_loss = torch.stack(tr_total_loss).float().mean().item()
                tr_total_loss = []
                xm.master_print(f"Epoch: {epoch}, Step: {i+1}/{len(train_dl)}, Loss: {train_loss}, Rate: {tracker.rate()} samples/sec")

        # For insurance. xm.save writes only on the master but ends with a
        # rendezvous, so every replica must call it (never inside a
        # master-only branch).
        xm.save(model.state_dict(), f'epoch_{epoch}_train.pt')

        # =============== Valid Loop ===============
        eval_model = ema_model.module if ema_model is not None else model
        model.eval()
        val_logits = []
        val_targets = []
        with torch.inference_mode():
            for inputs, targets in valid_dl:
                with autocast('xla'):
                    out = eval_model(inputs)
                val_logits.append(out)
                val_targets.append(targets)

            gathered_logits = xm.all_gather(torch.cat(val_logits, dim=0))
            gathered_targets = xm.all_gather(torch.cat(val_targets, dim=0))
            # Computed on every replica so all of them take the same
            # "new best" branch below (each xm.save is a collective).
            val_loss = criterion(gathered_logits, gathered_targets).item()

        xm.master_print(f'Epoch {epoch} Val Loss={val_loss} Time={time.asctime()}')
        # Save the weights that were actually validated (EMA when enabled).
        xm.save(eval_model.state_dict(), f'epoch_{epoch}.pt')
        xm.master_print('Model checkpoint saved!')
        if val_loss < val_best_loss:
            val_best_loss = val_loss
            xm.save(eval_model.state_dict(), 'best.pt')
            xm.master_print(f'New best Val Loss={val_loss}, saved best.pt')

        xm.master_print(f"Finish epoch {epoch}, elapsed {format_time(time.time() - tstart)}")

if __name__ == '__main__':
    # Drop any inherited TPU_PROCESS_ADDRESSES before spawning (presumably so
    # xmp.spawn configures its own process addresses). Pass a default so a
    # missing variable does not raise KeyError. `os` is imported at module top.
    os.environ.pop("TPU_PROCESS_ADDRESSES", None)
    xmp.spawn(fn=_mp_fn, args=(cfg,), start_method='fork')

In [None]:
!python train_fwi_xla_8tpu.py