In [2]:
import os
import time
import warnings
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score
from tqdm import tqdm
from timm import create_model

warnings.filterwarnings("ignore")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# === Load CSV and define paths ===
train_dir = "data/train_data"
test_dir = "data/test_data"
train_df = pd.read_csv('data/Train.csv')
test_df = pd.read_csv('data/Test.csv')
train_df['ID_path'] = train_df['ID'].apply(lambda x: os.path.join(train_dir, f"{x}.npy"))
test_df['ID_path'] = test_df['ID'].apply(lambda x: os.path.join(test_dir, f"{x}.npy"))

# === Train/Validation Split ===
train_set, valid_set = train_test_split(train_df, stratify=train_df['label'], test_size=0.2, random_state=42)

# === Data Utilities ===
def load_and_normalize_npy_image(path):
    img = np.load(path).astype(np.float32)
    img = (img - img.min((0, 1))) / (img.max((0, 1)) - img.min((0, 1)) + 1e-5)
    return img

class NPYDataset(Dataset):
    def __init__(self, df, transform=None):
        self.df = df
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        img = load_and_normalize_npy_image(self.df.iloc[idx]['ID_path'])
        label = self.df.iloc[idx].get('label', -1)
        img = torch.from_numpy(img.transpose(2, 0, 1))
        return img, torch.tensor(label, dtype=torch.long)

class TestDataset(Dataset):
    def __init__(self, df):
        self.df = df

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        img = load_and_normalize_npy_image(self.df.iloc[idx]['ID_path'])
        img = torch.from_numpy(img.transpose(2, 0, 1))
        return img, self.df.iloc[idx]['ID']

# === Model Wrapping Utility ===
def wrap_model(model, in_channels=12, num_classes=2):
    if hasattr(model, 'conv1') and isinstance(model.conv1, nn.Conv2d):
        model.conv1 = nn.Conv2d(in_channels, model.conv1.out_channels, kernel_size=7, stride=2, padding=3, bias=False)
    if hasattr(model, 'classifier') and isinstance(model.classifier, nn.Linear):
        model.classifier = nn.Linear(model.classifier.in_features, num_classes)
    if hasattr(model, 'fc') and isinstance(model.fc, nn.Linear):
        model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model.to(device)

# === Training ===
def train_model(model, model_name):
    train_loader = DataLoader(NPYDataset(train_set), batch_size=16, shuffle=True, num_workers=4)
    valid_loader = DataLoader(NPYDataset(valid_set), batch_size=16, shuffle=False, num_workers=4)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    best_val_f1, best_weights = 0, None
    patience, counter = 10, 0

    for epoch in range(100):
        model.train()
        preds, labels = [], []
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            out = model(x)
            loss = criterion(out, y)
            loss.backward()
            optimizer.step()
            preds.extend(torch.argmax(out, 1).cpu().numpy())
            labels.extend(y.cpu().numpy())
        train_f1 = f1_score(labels, preds, average='weighted')

        model.eval()
        preds, labels = [], []
        with torch.no_grad():
            for x, y in valid_loader:
                x, y = x.to(device), y.to(device)
                out = model(x)
                preds.extend(torch.argmax(out, 1).cpu().numpy())
                labels.extend(y.cpu().numpy())
        val_f1 = f1_score(labels, preds, average='weighted')

        print(f"[{model_name}] Epoch {epoch+1} Train F1: {train_f1:.4f} | Val F1: {val_f1:.4f}")
        if val_f1 > best_val_f1:
            best_val_f1 = val_f1
            best_weights = model.state_dict()
            counter = 0
        else:
            counter += 1
        if counter >= patience:
            print(f"[{model_name}] Early stopping at epoch {epoch+1}")
            break

    model.load_state_dict(best_weights)
    torch.save(model.state_dict(), f"{model_name}_best_{best_val_f1:.4f}.pth")
    return model, best_val_f1

# === Inference ===
def predict_and_save(model, name):
    loader = DataLoader(TestDataset(test_df), batch_size=16, shuffle=False, num_workers=4)
    model.eval()
    preds, ids = [], []
    with torch.no_grad():
        for x, id_batch in tqdm(loader, desc=name):
            x = x.to(device)
            out = model(x)
            preds.extend(torch.argmax(out, 1).cpu().numpy())
            ids.extend(id_batch)

    df = pd.DataFrame({'ID': ids, 'Predicted': preds})
    os.makedirs("subs", exist_ok=True)
    df.to_csv(f"subs/{name}_submission.csv", index=False)

# === Define Model Constructors ===
model_names = {
    'resnet18': lambda: models.resnet18(weights=models.ResNet18_Weights.DEFAULT),
    'resnet34': lambda: models.resnet34(weights=models.ResNet34_Weights.DEFAULT),
    'resnet50': lambda: models.resnet50(weights=models.ResNet50_Weights.DEFAULT),
    'resnet101': lambda: models.resnet101(weights=models.ResNet101_Weights.DEFAULT),
    'resnet152': lambda: models.resnet152(weights=models.ResNet152_Weights.DEFAULT),
    'densenet121': lambda: models.densenet121(weights=models.DenseNet121_Weights.DEFAULT),
    'efficientnet_b0': lambda: create_model('efficientnet_b0', pretrained=True, in_chans=12, num_classes=2),
    'convnext_base': lambda: create_model('convnext_base', pretrained=True, in_chans=12, num_classes=2),
    'swin_base_patch4_window7_224': lambda: create_model('swin_base_patch4_window7_224', pretrained=True, in_chans=12, num_classes=2)
}

# === Train & Predict Loop ===
for name, constructor in model_names.items():
    print(f"\n======== Training {name} ========")
    model = constructor()
    model = wrap_model(model)
    model, f1 = train_model(model, name)
    predict_and_save(model, f"{name}_f1_{f1:.4f}")



[resnet18] Epoch 1 Train F1: 0.7634 | Val F1: 0.7452
[resnet18] Epoch 2 Train F1: 0.8190 | Val F1: 0.8273
[resnet18] Epoch 3 Train F1: 0.8542 | Val F1: 0.8530
[resnet18] Epoch 4 Train F1: 0.8571 | Val F1: 0.8950
[resnet18] Epoch 5 Train F1: 0.8742 | Val F1: 0.8783
[resnet18] Epoch 6 Train F1: 0.8792 | Val F1: 0.8994
[resnet18] Epoch 7 Train F1: 0.8853 | Val F1: 0.9072
[resnet18] Epoch 8 Train F1: 0.8994 | Val F1: 0.8956
[resnet18] Epoch 9 Train F1: 0.9038 | Val F1: 0.8747
[resnet18] Epoch 10 Train F1: 0.9059 | Val F1: 0.9070
[resnet18] Epoch 11 Train F1: 0.9042 | Val F1: 0.8998
[resnet18] Epoch 12 Train F1: 0.9134 | Val F1: 0.8970
[resnet18] Epoch 13 Train F1: 0.9269 | Val F1: 0.9039
[resnet18] Epoch 14 Train F1: 0.9287 | Val F1: 0.8922
[resnet18] Epoch 15 Train F1: 0.9315 | Val F1: 0.8818
[resnet18] Epoch 16 Train F1: 0.9377 | Val F1: 0.8996
[resnet18] Epoch 17 Train F1: 0.8951 | Val F1: 0.9018
[resnet18] Early stopping at epoch 17


resnet18_f1_0.9072: 100%|██████████| 338/338 [00:00<00:00, 373.10it/s]



Downloading: "https://download.pytorch.org/models/resnet34-b627a593.pth" to /home/tdx/.cache/torch/hub/checkpoints/resnet34-b627a593.pth


100%|██████████| 83.3M/83.3M [00:00<00:00, 114MB/s] 


[resnet34] Epoch 1 Train F1: 0.7509 | Val F1: 0.7468
[resnet34] Epoch 2 Train F1: 0.7873 | Val F1: 0.8239
[resnet34] Epoch 3 Train F1: 0.8538 | Val F1: 0.7788
[resnet34] Epoch 4 Train F1: 0.8102 | Val F1: 0.7540
[resnet34] Epoch 5 Train F1: 0.8476 | Val F1: 0.8461
[resnet34] Epoch 6 Train F1: 0.8716 | Val F1: 0.8240
[resnet34] Epoch 7 Train F1: 0.8723 | Val F1: 0.8887
[resnet34] Epoch 8 Train F1: 0.8204 | Val F1: 0.8307
[resnet34] Epoch 9 Train F1: 0.8585 | Val F1: 0.8815
[resnet34] Epoch 10 Train F1: 0.8781 | Val F1: 0.8885
[resnet34] Epoch 11 Train F1: 0.8827 | Val F1: 0.8855
[resnet34] Epoch 12 Train F1: 0.8935 | Val F1: 0.6923
[resnet34] Epoch 13 Train F1: 0.8924 | Val F1: 0.8998
[resnet34] Epoch 14 Train F1: 0.9065 | Val F1: 0.9068
[resnet34] Epoch 15 Train F1: 0.9031 | Val F1: 0.8903
[resnet34] Epoch 16 Train F1: 0.8865 | Val F1: 0.8828
[resnet34] Epoch 17 Train F1: 0.8863 | Val F1: 0.8979
[resnet34] Epoch 18 Train F1: 0.9036 | Val F1: 0.9023
[resnet34] Epoch 19 Train F1: 0.9203 

resnet34_f1_0.9144: 100%|██████████| 338/338 [00:01<00:00, 327.36it/s]



[resnet50] Epoch 1 Train F1: 0.7884 | Val F1: 0.8513
[resnet50] Epoch 2 Train F1: 0.8540 | Val F1: 0.8706
[resnet50] Epoch 3 Train F1: 0.8667 | Val F1: 0.8932
[resnet50] Epoch 4 Train F1: 0.8761 | Val F1: 0.9056
[resnet50] Epoch 5 Train F1: 0.9003 | Val F1: 0.8973
[resnet50] Epoch 6 Train F1: 0.9047 | Val F1: 0.8709
[resnet50] Epoch 7 Train F1: 0.9145 | Val F1: 0.8861
[resnet50] Epoch 8 Train F1: 0.9195 | Val F1: 0.9036
[resnet50] Epoch 9 Train F1: 0.9232 | Val F1: 0.9029
[resnet50] Epoch 10 Train F1: 0.9342 | Val F1: 0.8959
[resnet50] Epoch 11 Train F1: 0.9422 | Val F1: 0.9180
[resnet50] Epoch 12 Train F1: 0.9363 | Val F1: 0.9013
[resnet50] Epoch 13 Train F1: 0.9519 | Val F1: 0.9147
[resnet50] Epoch 14 Train F1: 0.9487 | Val F1: 0.9073
[resnet50] Epoch 15 Train F1: 0.9584 | Val F1: 0.9131
[resnet50] Epoch 16 Train F1: 0.9682 | Val F1: 0.9050
[resnet50] Epoch 17 Train F1: 0.9725 | Val F1: 0.9039
[resnet50] Epoch 18 Train F1: 0.9687 | Val F1: 0.8951
[resnet50] Epoch 19 Train F1: 0.9729

resnet50_f1_0.9180: 100%|██████████| 338/338 [00:01<00:00, 270.57it/s]



Downloading: "https://download.pytorch.org/models/resnet101-cd907fc2.pth" to /home/tdx/.cache/torch/hub/checkpoints/resnet101-cd907fc2.pth


100%|██████████| 171M/171M [00:01<00:00, 115MB/s]  


[resnet101] Epoch 1 Train F1: 0.7689 | Val F1: 0.7452
[resnet101] Epoch 2 Train F1: 0.8382 | Val F1: 0.8899
[resnet101] Epoch 3 Train F1: 0.8514 | Val F1: 0.8648
[resnet101] Epoch 4 Train F1: 0.8534 | Val F1: 0.8659
[resnet101] Epoch 5 Train F1: 0.8575 | Val F1: 0.8868
[resnet101] Epoch 6 Train F1: 0.8741 | Val F1: 0.8879
[resnet101] Epoch 7 Train F1: 0.8910 | Val F1: 0.8865
[resnet101] Epoch 8 Train F1: 0.8955 | Val F1: 0.9118
[resnet101] Epoch 9 Train F1: 0.8865 | Val F1: 0.8944
[resnet101] Epoch 10 Train F1: 0.8922 | Val F1: 0.8876
[resnet101] Epoch 11 Train F1: 0.8680 | Val F1: 0.7897
[resnet101] Epoch 12 Train F1: 0.8854 | Val F1: 0.9038
[resnet101] Epoch 13 Train F1: 0.8829 | Val F1: 0.8755
[resnet101] Epoch 14 Train F1: 0.9010 | Val F1: 0.8980
[resnet101] Epoch 15 Train F1: 0.9147 | Val F1: 0.9167
[resnet101] Epoch 16 Train F1: 0.9166 | Val F1: 0.8822
[resnet101] Epoch 17 Train F1: 0.9233 | Val F1: 0.9153
[resnet101] Epoch 18 Train F1: 0.9325 | Val F1: 0.8916
[resnet101] Epoch 1

resnet101_f1_0.9167: 100%|██████████| 338/338 [00:01<00:00, 190.96it/s]



Downloading: "https://download.pytorch.org/models/resnet152-f82ba261.pth" to /home/tdx/.cache/torch/hub/checkpoints/resnet152-f82ba261.pth


100%|██████████| 230M/230M [00:02<00:00, 116MB/s]  


[resnet152] Epoch 1 Train F1: 0.7701 | Val F1: 0.8192
[resnet152] Epoch 2 Train F1: 0.8371 | Val F1: 0.8639
[resnet152] Epoch 3 Train F1: 0.8679 | Val F1: 0.8646
[resnet152] Epoch 4 Train F1: 0.8661 | Val F1: 0.8285
[resnet152] Epoch 5 Train F1: 0.8609 | Val F1: 0.8091
[resnet152] Epoch 6 Train F1: 0.8164 | Val F1: 0.7452
[resnet152] Epoch 7 Train F1: 0.7472 | Val F1: 0.8093
[resnet152] Epoch 8 Train F1: 0.8202 | Val F1: 0.8500
[resnet152] Epoch 9 Train F1: 0.8521 | Val F1: 0.8121
[resnet152] Epoch 10 Train F1: 0.8596 | Val F1: 0.8454
[resnet152] Epoch 11 Train F1: 0.8356 | Val F1: 0.8795
[resnet152] Epoch 12 Train F1: 0.8645 | Val F1: 0.8827
[resnet152] Epoch 13 Train F1: 0.8765 | Val F1: 0.8985
[resnet152] Epoch 14 Train F1: 0.8905 | Val F1: 0.8752
[resnet152] Epoch 15 Train F1: 0.8914 | Val F1: 0.7761
[resnet152] Epoch 16 Train F1: 0.8815 | Val F1: 0.8649
[resnet152] Epoch 17 Train F1: 0.8889 | Val F1: 0.8768
[resnet152] Epoch 18 Train F1: 0.8898 | Val F1: 0.8319
[resnet152] Epoch 1

resnet152_f1_0.9177: 100%|██████████| 338/338 [00:02<00:00, 145.56it/s]



Downloading: "https://download.pytorch.org/models/densenet121-a639ec97.pth" to /home/tdx/.cache/torch/hub/checkpoints/densenet121-a639ec97.pth


100%|██████████| 30.8M/30.8M [00:00<00:00, 106MB/s] 


RuntimeError: Given groups=1, weight of size [64, 3, 7, 7], expected input[16, 12, 64, 64] to have 3 channels, but got 12 channels instead