# In [None]:
import os
import numpy as np
import random
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import matplotlib.pyplot as plt

# Reproducibility: feed one shared seed to every RNG used by this script
# (Python's `random`, NumPy, and PyTorch on CPU and on all CUDA devices).
seed = 42
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(seed)

# Collapse A/B classes → 5 unified classes
raw_classes = [
    'CurveFault_A', 'CurveVel_A', 'FlatFault_A', 'FlatVel_A', 'Style_A',
    'CurveFault_B', 'CurveVel_B', 'FlatFault_B', 'FlatVel_B', 'Style_B'
]
collapsed_labels = {'CurveFault': 0, 'CurveVel': 1, 'FlatFault': 2, 'FlatVel': 3, 'Style': 4}

# Build full index: list of (filepath, sample_idx, label)
base_dir = '/kaggle/input/waveform-inversion/train_samples'
index = []
unique_files = {}

for cls in raw_classes:
    # The part before '_' is the family name, so the A and B variants of a
    # family map to the same integer label.
    label = collapsed_labels[cls.split('_')[0]]

    # Use the 'data' sub-folder when the class folder has one; otherwise the
    # .npy files are looked up directly in the class folder.
    cls_folder = os.path.join(base_dir, cls)
    data_folder = os.path.join(cls_folder, 'data')
    search_folder = data_folder if os.path.isdir(data_folder) else cls_folder

    # Sorted so the index order does not depend on os.listdir() ordering.
    fnames = sorted(f for f in os.listdir(search_folder)
                    if f.endswith('.npy') and f.startswith(('data', 'seis')))
    for fname in fnames:
        file_path = os.path.join(search_folder, fname)
        unique_files[file_path] = None  # mmap will go here later

        # Take the number of samples from the array's leading dimension rather
        # than assuming a fixed count per file. mmap_mode='r' avoids loading
        # the array contents just to inspect the shape.
        n_samples = np.load(file_path, mmap_mode='r').shape[0]
        index.extend((file_path, i, label) for i in range(n_samples))

# Train/val split at sample level: 10% of the index entries are held out,
# stratified on the label (third element of each index tuple).
_strata = [entry[2] for entry in index]
train_idx, val_idx = train_test_split(
    index,
    test_size=0.1,
    stratify=_strata,
    random_state=seed,
)

# Lazy mmap dataset
class SeismicLazyMmapDataset(Dataset):
    """Serve per-sample normalised tensors from lazily memory-mapped ``.npy`` files.

    A file is opened with ``np.load(..., mmap_mode='r')`` the first time one of
    its samples is requested; the opened array is then cached in ``file_map``.

    Args:
        index: Sequence of ``(file_path, sample_idx, label)`` tuples.
        file_map: Optional dict mapping ``file_path`` to an already opened
            array or ``None``. It is mutated in place as files are opened.
            Paths missing from the dict are opened on demand as well.
    """

    def __init__(self, index, file_map=None):
        self.index = index
        self.file_map = {} if file_map is None else file_map

    def __len__(self):
        """Return the number of indexed samples."""
        return len(self.index)

    def __getitem__(self, idx):
        """Return ``(sample, label)`` for index entry ``idx``.

        ``sample`` is a float32 tensor standardised with its own mean and
        standard deviation (plus 1e-6 to avoid division by zero); ``label`` is
        an int64 scalar tensor.
        """
        file_path, sample_idx, label = self.index[idx]

        # .get() so that a path absent from file_map is treated like an
        # unopened one instead of raising KeyError.
        arr = self.file_map.get(file_path)
        if arr is None:
            arr = np.load(file_path, mmap_mode='r')
            self.file_map[file_path] = arr

        # Original author's note: arr is expected to be (500,5,1000,70).
        # Copy the sample out of the memmap once, as a plain writable float32
        # array, so the statistics and the normalisation all work on memory.
        sample = np.array(arr[sample_idx], dtype=np.float32)
        mean = sample.mean()
        std = sample.std() + 1e-6
        sample -= mean
        sample /= std

        # from_numpy shares the buffer of the freshly made copy (no extra copy).
        return torch.from_numpy(sample), torch.tensor(label, dtype=torch.long)

# Initialize mmap-enabled datasets; each one gets its own copy of the
# path -> handle map so they never share cached file handles.
train_dataset = SeismicLazyMmapDataset(train_idx, dict(unique_files))
val_dataset = SeismicLazyMmapDataset(val_idx, dict(unique_files))

# Dataloaders with persistent workers and prefetching.
# Everything except `shuffle` is common to both loaders.
_loader_kwargs = dict(
    batch_size=32,
    num_workers=2,
    pin_memory=True,
    persistent_workers=True,
    prefetch_factor=4,
)
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_kwargs)

# Model (same as before, but for completeness)
class SeismicNetV2(nn.Module):
    """CNN classifier: four conv stages followed by a two-layer MLP head.

    Each stage is Conv2d(3x3, padding=1) -> BatchNorm2d -> ReLU -> MaxPool2d(2),
    with channels going 5 -> 32 -> 64 -> 128 -> 256. The head flattens the
    feature map and maps it to 5 logits. The first Linear layer expects
    256*62*4 features, which is what a 1000x70 input gives after four
    floor-halving pools (1000 -> 62, 70 -> 4).
    """

    def __init__(self):
        super().__init__()

        # Stages are created in order so the layer positions inside
        # `features` (0..15) are the same as with an explicit listing.
        widths = (5, 32, 64, 128, 256)
        stage_layers = []
        for in_ch, out_ch in zip(widths[:-1], widths[1:]):
            stage_layers.extend([
                nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(),
                nn.MaxPool2d(2),
            ])
        self.features = nn.Sequential(*stage_layers)

        flat_features = 256 * 62 * 4
        head_layers = [
            nn.Flatten(),
            nn.Linear(flat_features, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, 5),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return the 5 class logits for a batch ``x``."""
        return self.classifier(self.features(x))

# Device setup: run on CUDA when it is available, otherwise on the CPU.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)
model = SeismicNetV2().to(device)

# Wrap the model in DataParallel when more than one GPU is visible.
if torch.cuda.device_count() > 1:
    print("Using", torch.cuda.device_count(), "GPUs!")
    model = nn.DataParallel(model)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=1e-4)

# Training with checkpointing
num_epochs = 20
train_acc_list, val_acc_list = [], []
best_val_acc = 0.0


def _accuracy_pass(batches, training):
    """Run `model` over every batch in `batches`; return correct / seen.

    When `training` is true the model is switched to train mode and every
    batch also takes one optimizer step on the `criterion` loss. Otherwise
    the model is switched to eval mode and gradients are disabled.
    """
    if training:
        model.train()
    else:
        model.eval()

    hits, seen = 0, 0
    with torch.set_grad_enabled(training):
        for batch_x, batch_y in batches:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)
            if training:
                optimizer.zero_grad()
            logits = model(batch_x)
            if training:
                criterion(logits, batch_y).backward()
                optimizer.step()
            hits += (logits.argmax(1) == batch_y).sum().item()
            seen += batch_y.size(0)
    return hits / seen


for epoch in range(num_epochs):
    epoch_no = epoch + 1  # 1-based epoch number used in progress/log text

    train_acc = _accuracy_pass(
        tqdm(train_loader, desc=f"Epoch {epoch_no} Train", leave=False),
        training=True,
    )
    train_acc_list.append(train_acc)

    # Validation
    val_acc = _accuracy_pass(val_loader, training=False)
    val_acc_list.append(val_acc)

    # Checkpoint whenever validation accuracy strictly improves. The wrapped
    # module is saved when the model has a `.module` attribute (DataParallel).
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        model_to_save = getattr(model, "module", model)
        torch.save(model_to_save.state_dict(), "/kaggle/working/best_seismicnet.pth")
        print(f"✅ Epoch {epoch_no}: New best val acc = {val_acc:.4f} → model saved.")

    print(f"Epoch {epoch_no}: Train Acc={train_acc:.4f}, Val Acc={val_acc:.4f}")

# Final save: write the last-epoch weights (unwrapped when the model has a
# `.module` attribute), regardless of validation accuracy.
model_to_save = getattr(model, "module", model)
torch.save(model_to_save.state_dict(), "/kaggle/working/final_seismicnet.pth")

# Accuracy curve: one line per history, saved to disk and then displayed.
for _history, _legend_name in ((train_acc_list, 'Train Acc'), (val_acc_list, 'Val Acc')):
    plt.plot(_history, label=_legend_name)
plt.title('Train vs Validation Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.savefig("/kaggle/working/accuracy.png")
plt.show()