In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
from pathlib import Path
import csv
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from tqdm import tqdm


In [2]:
# Collect every input file (names containing 'seis' or 'data') under train_samples.
# rglob yields files in arbitrary, filesystem-dependent order; sort so the
# every-other train/valid split made later is reproducible across runs/machines.
all_inputs = sorted(
    f
    for f in
    Path('/kaggle/input/waveform-inversion/train_samples').rglob('*.npy')
    if ('seis' in f.stem) or ('data' in f.stem)
)

def inputs_files_to_output_files(input_files):
    """Map each input .npy path to its paired target path.

    The renaming ('seis' -> 'vel', then 'data' -> 'model') is applied only to
    the file name and its immediate parent directory name, e.g.
    ``.../CurveVel_A/data/data1.npy`` -> ``.../CurveVel_A/model/model1.npy`` and
    ``.../FlatFault_A/seis2_1_0.npy`` -> ``.../FlatFault_A/vel2_1_0.npy``.
    Ancestor directories are left untouched, so a root such as ``/mnt/data``
    is not corrupted.

    Args:
        input_files: iterable of ``Path`` (or path strings).

    Returns:
        list of ``Path`` in the same order as ``input_files``.
    """
    def _swap(text):
        return text.replace('seis', 'vel').replace('data', 'model')

    def _to_output(path):
        path = Path(path)
        parent = path.parent
        # Path('.') and the filesystem root have an empty name; with_name would raise.
        if parent.name:
            parent = parent.with_name(_swap(parent.name))
        return parent / _swap(path.name)

    return [_to_output(f) for f in input_files]

all_outputs = inputs_files_to_output_files(all_inputs)

# Fail loudly, and say which targets are absent, if any input lacks a target file.
missing_outputs = [f for f in all_outputs if not f.exists()]
assert not missing_outputs, (
    f"{len(missing_outputs)} target file(s) missing, e.g. {missing_outputs[:3]}"
)

# Alternate files between train (even positions) and valid (odd positions).
# Slicing is O(n); the previous `f not in train_inputs` scan was O(n^2).
train_inputs = all_inputs[0::2]
valid_inputs = all_inputs[1::2]

train_outputs = inputs_files_to_output_files(train_inputs)
valid_outputs = inputs_files_to_output_files(valid_inputs)

class SeismicDataset(Dataset):
    """Map-style dataset that flattens a list of (input, target) .npy file pairs.

    Every file is assumed to hold exactly ``n_examples_per_file`` samples along
    axis 0. Global index ``idx`` maps to file ``idx // n`` and row ``idx % n``.

    Args:
        inputs_files: paths of the input arrays.
        output_files: paths of the matching target arrays (same length/order).
        n_examples_per_file: samples stored in each file (default 500).
    """

    def __init__(self, inputs_files, output_files, n_examples_per_file=500):
        assert len(inputs_files) == len(output_files)
        self.inputs_files = inputs_files
        self.output_files = output_files
        self.n_examples_per_file = n_examples_per_file

    def __len__(self):
        n_files = len(self.inputs_files)
        return n_files * self.n_examples_per_file

    def __getitem__(self, idx):
        file_idx, sample_idx = divmod(idx, self.n_examples_per_file)

        # Memory-map both files so only the requested row is read, then copy the
        # row out so the returned arrays are independent of the mapping.
        inputs_map = np.load(self.inputs_files[file_idx], mmap_mode='r')
        targets_map = np.load(self.output_files[file_idx], mmap_mode='r')

        try:
            return inputs_map[sample_idx].copy(), targets_map[sample_idx].copy()
        finally:
            del inputs_map, targets_map

In [3]:
dstrain = SeismicDataset(train_inputs, train_outputs)
dsvalid = SeismicDataset(valid_inputs, valid_outputs)

train_loader = DataLoader(dstrain, batch_size=16, shuffle=True, num_workers=4)
# Validation loss is averaged over the whole set, so shuffling adds nothing;
# a fixed order keeps evaluation deterministic and reads each file's rows in sequence.
val_loader = DataLoader(dsvalid, batch_size=16, shuffle=False, num_workers=4)

In [4]:
class DoubleConv(nn.Module):
    """Two stacked (3x3 Conv -> BatchNorm -> GELU) stages.

    Padding 1 keeps H and W unchanged. Bias is omitted from each conv since a
    BatchNorm follows it. Layers live in ``self.net`` in the order
    conv, bn, gelu, conv, bn, gelu (state_dict keys ``net.0`` ... ``net.5``).
    """

    def __init__(self, in_c, out_c):
        super().__init__()
        stages = []
        for stage_in in (in_c, out_c):
            stages.extend([
                nn.Conv2d(stage_in, out_c, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(out_c),
                nn.GELU(),
            ])
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        out = self.net(x)
        return out


class Down(nn.Module):
    """Encoder stage: 2x2 max-pool (halves H and W, flooring odd sizes), then DoubleConv.

    Stored as ``self.net = Sequential(pool, double_conv)`` so checkpoint keys
    are ``net.1.net.*``.
    """

    def __init__(self, in_c, out_c):
        super().__init__()
        pool = nn.MaxPool2d(kernel_size=2)
        convs = DoubleConv(in_c, out_c)
        self.net = nn.Sequential(pool, convs)

    def forward(self, x):
        out = self.net(x)
        return out

class Up(nn.Module):
    """Decoder stage: transposed-conv upsample -> concat with skip -> DoubleConv.

    Args:
        in_c: channels of the incoming (coarser) feature map.
        out_c: channels produced by the upsampling conv and by this stage.
        skip_c: channels of the ``skip`` tensor passed to ``forward``.
            Defaults to ``2 * out_c``, which reproduces the previous
            hard-coded concat width of ``out_c * 3``.
    """

    def __init__(self, in_c, out_c, skip_c=None):
        super().__init__()
        if skip_c is None:
            skip_c = 2 * out_c
        # kernel_size=2, stride=2 doubles H and W.
        self.up   = nn.ConvTranspose2d(in_c, out_c, kernel_size=2, stride=2)
        self.conv = DoubleConv(skip_c + out_c, out_c)

    def forward(self, x, skip):
        x = self.up(x)
        # MaxPool2d floors odd sizes, so after a x2 upsample the tensor can be
        # a pixel smaller than its skip; resize the skip to match.
        if x.shape[-2:] != skip.shape[-2:]:
            skip = F.interpolate(skip, size=x.shape[-2:], mode="bilinear",
                                 align_corners=False)
        x = torch.cat([skip, x], dim=1)     # channels: skip_c + out_c
        return self.conv(x)

# ---------- the network ---------- #
class UNet70(nn.Module):
    """Four-level U-Net producing a single-channel map of fixed size.

    The head output is bilinearly resized to ``out_size`` and then rescaled as
    ``y * out_scale + out_shift`` (presumably to put predictions in the
    target's physical range — confirm against the data).

    Args:
        in_channels: channels of the input ``x`` (N x in_channels x H x W).
        base_c: width of the first encoder stage; deeper stages use
            2x, 4x, 8x and 8x (bottleneck) this value.
        out_scale: multiplier applied to the head output.
        out_shift: offset added after scaling.
        out_size: (height, width) of the returned map; default (70, 70).
    """

    def __init__(self, in_channels=5, base_c=64, out_scale=1000, out_shift=1500,
                 out_size=(70, 70)):
        super().__init__()
        self.scale  = out_scale
        self.shift  = out_shift
        self.out_size = tuple(out_size)

        # encoder
        self.inc   = DoubleConv(in_channels,  base_c)        #  64
        self.down1 = Down(base_c,        base_c * 2)         # 128
        self.down2 = Down(base_c * 2,    base_c * 4)         # 256
        self.down3 = Down(base_c * 4,    base_c * 8)         # 512
        self.down4 = Down(base_c * 8,    base_c * 8)         # bottleneck

        # decoder (each skip has 2 * out_c channels, matching Up's default)
        self.up1 = Up(base_c * 8, base_c * 4)                # 256
        self.up2 = Up(base_c * 4, base_c * 2)                # 128
        self.up3 = Up(base_c * 2, base_c)                    #  64
        self.up4 = Up(base_c,     base_c // 2)               #  32

        # head
        self.out_conv = nn.Conv2d(base_c // 2, 1, 1)         # 1 × H × W

    def forward(self, x):
        # ---------- encoder ----------
        x1 = self.inc(x)     # N×64×H×W
        x2 = self.down1(x1)  # N×128×H/2×W/2
        x3 = self.down2(x2)  # N×256×H/4×W/4
        x4 = self.down3(x3)  # N×512×H/8×W/8
        x5 = self.down4(x4)  # N×512×H/16×W/16 (bottleneck)

        # ---------- decoder ----------
        y = self.up1(x5, x4)
        y = self.up2(y,  x3)
        y = self.up3(y,  x2)
        y = self.up4(y,  x1)               # back at x1's resolution (H×W)

        y = self.out_conv(y)               # (N,1,H,W)

        # resize to the fixed output grid regardless of input size
        y = F.interpolate(y, size=self.out_size, mode="bilinear", align_corners=False)

        return y * self.scale + self.shift


In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)

# Seed torch's global RNG before building the model so weight init and the
# DataLoader shuffle order (drawn from this RNG when no generator is given)
# are repeatable. CUDA kernels may still be nondeterministic.
SEED = 42
torch.manual_seed(SEED)

model = UNet70().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()

epochs = 20

best_val_loss = float('inf')

for epoch in range(epochs):
    # ── 1. TRAIN ───────────────────────────────────────────────────────
    model.train()
    running_loss, n_train = 0.0, 0
    for xb, yb in tqdm(train_loader, desc=f"[Train] Epoch {epoch+1}"):
        xb, yb = xb.to(device), yb.to(device).squeeze(1)
        optimizer.zero_grad()
        preds = model(xb).squeeze(1)
        loss  = loss_fn(preds, yb)
        loss.backward()
        optimizer.step()
        # loss is a per-batch mean; weight by batch size so a short final
        # batch does not skew the epoch average.
        running_loss += loss.item() * xb.size(0)
        n_train += xb.size(0)

    train_loss = running_loss / n_train

    # ── 2. VALIDATE ────────────────────────────────────────────────────
    model.eval()
    val_loss_sum, n_val = 0.0, 0
    with torch.no_grad():
        for xb, yb in val_loader:
            xb, yb = xb.to(device), yb.to(device).squeeze(1)
            preds  = model(xb).squeeze(1)
            val_loss_sum += loss_fn(preds, yb).item() * xb.size(0)
            n_val += xb.size(0)

    val_loss = val_loss_sum / n_val
    print(f"Epoch {epoch+1} | train {train_loss:.4f} | val {val_loss:.4f}")

    # ── 3. CHECKPOINT ON VAL LOSS ─────────────────────────────────────
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), '/kaggle/working/best_model.pth')
        print(f"✅  New best model saved (val loss {best_val_loss:.4f})")

Using device: cuda


[Train] Epoch 1:  49%|████▊     | 152/313 [01:20<01:27,  1.85it/s]

In [None]:
%%time
test_files = list(Path('/kaggle/input/waveform-inversion/test').glob('*.npy'))
len(test_files)

In [None]:
# Submission columns: the row id, then the odd x positions x_1, x_3, ..., x_69.
x_cols = ['x_' + str(col) for col in range(1, 70, 2)]
fieldnames = ['oid_ypos', *x_cols]

In [None]:
class TestDataset(Dataset):
    """Unlabelled dataset over test .npy files.

    Each item is ``(array, stem)``: the fully loaded file contents and the
    file's stem, which is used downstream as the sample id.
    """

    def __init__(self, test_files):
        self.test_files = test_files

    def __len__(self):
        return len(self.test_files)

    def __getitem__(self, i):
        path = self.test_files[i]
        sample = np.load(path)
        return sample, path.stem

In [None]:
# Batch the test files in test_files order (DataLoader does not shuffle by
# default); each batch carries the file stems alongside the arrays.
ds = TestDataset(test_files)
dl = DataLoader(
    ds,
    batch_size=128,
    num_workers=4,
    pin_memory=True,  # page-locked host buffers for faster host-to-GPU copies
)

In [None]:
PATH = "/kaggle/working/best_model.pth"
# map_location lets a checkpoint saved on GPU be restored on a CPU-only runtime.
model.load_state_dict(torch.load(PATH, map_location=device, weights_only=True))
model.eval()

with open('submission.csv', 'wt', newline='') as csvfile:
    writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
    writer.writeheader()

    for inputs, oids_test in tqdm(dl, desc='test'):
        inputs = inputs.to(device)
        with torch.inference_mode():
            outputs = model(inputs)

        # (N, 1, H, W) -> (N, H, len(x_cols)): keep only the odd x columns
        # (1, 3, ..., 69) that the submission contains.
        y_preds = outputs[:, 0, :, 1::2].cpu().numpy()

        for y_pred, oid_test in zip(y_preds, oids_test):
            for y_pos in range(70):
                row = dict(zip(x_cols, y_pred[y_pos]))
                row['oid_ypos'] = f"{oid_test}_y_{y_pos}"
                writer.writerow(row)